diff --git a/tractor/to_asyncio.py b/tractor/to_asyncio.py index fb8f4cd..0f22300 100644 --- a/tractor/to_asyncio.py +++ b/tractor/to_asyncio.py @@ -4,6 +4,7 @@ Infection apis for ``asyncio`` loops running ``trio`` using guest mode. import asyncio import inspect from typing import ( + Any, Callable, AsyncGenerator, Awaitable, @@ -21,21 +22,26 @@ log = get_logger(__name__) __all__ = ['run_task', 'run_as_asyncio_guest'] -async def _invoke( - from_trio: trio.abc.ReceiveChannel, - to_trio: asyncio.Queue, +async def run_coro( + to_trio: trio.MemorySendChannel, coro: Awaitable, ) -> None: - """Await or stream awaiable object based on ``coro`` type into - ``trio`` memory channel. - - ``from_trio`` might eventually be used here for bidirectional streaming. + """Await ``coro`` and relay result back to ``trio``. """ - if inspect.isasyncgen(coro): - async for item in coro: - to_trio.send_nowait(item) - elif inspect.iscoroutine(coro): - to_trio.send_nowait(await coro) + to_trio.send_nowait(await coro) + + +async def consume_asyncgen( + to_trio: trio.MemorySendChannel, + coro: AsyncGenerator, +) -> None: + """Stream async generator results back to ``trio``. + + ``from_trio`` might eventually be used here for + bidirectional streaming. + """ + async for item in coro: + to_trio.send_nowait(item) async def run_task( @@ -44,15 +50,15 @@ async def run_task( qsize: int = 2**10, _treat_as_stream: bool = False, **kwargs, -) -> Union[AsyncGenerator, Awaitable]: +) -> Union[AsyncGenerator, Any]: """Run an ``asyncio`` async function or generator in a task, return or stream the result back to ``trio``. """ assert current_actor().is_infected_aio() # ITC (inter task comms) - from_trio = asyncio.Queue(qsize) - to_trio, from_aio = trio.open_memory_channel(qsize) + from_trio = asyncio.Queue(qsize) # type: ignore + to_trio, from_aio = trio.open_memory_channel(qsize) # type: ignore args = tuple(inspect.getfullargspec(func).args) @@ -66,7 +72,7 @@ async def run_task( if 'to_trio' in args: kwargs['to_trio'] = to_trio if 'from_trio' in args: - kwargs['from_trio'] = to_trio + kwargs['from_trio'] = from_trio coro = func(**kwargs) @@ -74,7 +80,13 @@ async def run_task( # start the asyncio task we submitted from trio # TODO: try out ``anyio`` asyncio based tg here - task = asyncio.create_task(_invoke(from_trio, to_trio, coro)) + if inspect.isawaitable(coro): + task = asyncio.create_task(run_coro(to_trio, coro)) + elif inspect.isasyncgen(coro): + task = asyncio.create_task(consume_asyncgen(to_trio, coro)) + else: + raise TypeError(f"No support for {coro}") + err = None def cancel_trio(task): @@ -88,27 +100,29 @@ async def run_task( # asycn gen if inspect.isasyncgen(coro) or _treat_as_stream: - async def result(): + + async def stream_results(): with cancel_scope: + # stream values upward async with from_aio: async for item in from_aio: yield item if cancel_scope.cancelled_caught and err: raise err - return result() + return stream_results() # simple async func elif inspect.iscoroutine(coro): with cancel_scope: - result = await from_aio.receive() - return result + # return single value + return await from_aio.receive() if cancel_scope.cancelled_caught and err: raise err def run_as_asyncio_guest( - trio_main: Awaitable, + trio_main: Callable, ) -> None: """Entry for an "infected ``asyncio`` actor".