diff --git a/tractor/to_asyncio.py b/tractor/to_asyncio.py new file mode 100644 index 0000000..40c8bf1 --- /dev/null +++ b/tractor/to_asyncio.py @@ -0,0 +1,99 @@ +""" +Infection apis for ``asyncio`` loops running ``trio`` using guest mode. +""" +import asyncio +import inspect +from typing import ( + Any, + Callable, + AsyncGenerator, + Awaitable, + Union, +) + +import trio + +from ._state import current_actor + + +__all__ = ['run'] + + +async def _invoke( + from_trio: trio.abc.ReceiveChannel, + to_trio: asyncio.Queue, + coro: Awaitable, +) -> Union[AsyncGenerator, Awaitable]: + """Await or stream awaiable object based on type into + ``trio`` memory channel. + """ + async def stream_from_gen(c): + async for item in c: + to_trio.send_nowait(item) + + async def just_return(c): + to_trio.send_nowait(await c) + + if inspect.isasyncgen(coro): + return await stream_from_gen(coro) + elif inspect.iscoroutine(coro): + return await coro + + +async def run( + func: Callable, + qsize: int = 2**10, + **kwargs, +) -> Any: + """Run an ``asyncio`` async function or generator in a task, return + or stream the result back to ``trio``. + """ + assert current_actor()._infected_aio + + # ITC (inter task comms) + from_trio = asyncio.Queue(qsize) + to_trio, from_aio = trio.open_memory_channel(qsize) + + # allow target func to accept/stream results manually + kwargs['to_trio'] = to_trio + kwargs['from_trio'] = to_trio + + coro = func(**kwargs) + + cancel_scope = trio.CancelScope() + + # start the asyncio task we submitted from trio + # TODO: try out ``anyio`` asyncio based tg here + task = asyncio.create_task(_invoke(from_trio, to_trio, coro)) + err = None + + # XXX: I'm not sure this actually does anything... + def cancel_trio(task): + """Cancel the calling ``trio`` task on error. + """ + nonlocal err + err = task.exception() + cancel_scope.cancel() + + task.add_done_callback(cancel_trio) + + # determine return type async func vs. gen + if inspect.isasyncgen(coro): + # simple async func + async def result(): + with cancel_scope: + return await from_aio.get() + if cancel_scope.cancelled_caught and err: + raise err + + elif inspect.iscoroutine(coro): + # asycn gen + async def result(): + with cancel_scope: + async with from_aio: + async for item in from_aio: + yield item + if cancel_scope.cancelled_caught and err: + raise err + + return result()