diff --git a/tractor/_entry.py b/tractor/_entry.py index ff3cce7..65ca5bb 100644 --- a/tractor/_entry.py +++ b/tractor/_entry.py @@ -12,9 +12,6 @@ from .log import get_console_log, get_logger from . import _state -__all__ = ('run',) - - log = get_logger(__name__) diff --git a/tractor/to_asyncio.py b/tractor/to_asyncio.py index 5d90f22..40c8bf1 100644 --- a/tractor/to_asyncio.py +++ b/tractor/to_asyncio.py @@ -13,22 +13,26 @@ from typing import ( import trio +from ._state import current_actor + + +__all__ = ['run'] + async def _invoke( - from_trio, - to_trio, - coro + from_trio: trio.abc.ReceiveChannel, + to_trio: asyncio.Queue, + coro: Awaitable, ) -> Union[AsyncGenerator, Awaitable]: """Await or stream awaiable object based on type into ``trio`` memory channel. """ async def stream_from_gen(c): async for item in c: - to_trio.put_nowait(item) - to_trio.put_nowait + to_trio.send_nowait(item) async def just_return(c): - to_trio.put_nowait(await c) + to_trio.send_nowait(await c) if inspect.isasyncgen(coro): return await stream_from_gen(coro) @@ -36,7 +40,6 @@ async def _invoke( return await coro -# TODO: make this some kind of tractor.to_asyncio.run() async def run( func: Callable, qsize: int = 2**10, @@ -45,6 +48,8 @@ async def run( """Run an ``asyncio`` async function or generator in a task, return or stream the result back to ``trio``. """ + assert current_actor()._infected_aio + # ITC (inter task comms) from_trio = asyncio.Queue(qsize) to_trio, from_aio = trio.open_memory_channel(qsize) @@ -55,16 +60,40 @@ async def run( coro = func(**kwargs) + cancel_scope = trio.CancelScope() + # start the asyncio task we submitted from trio # TODO: try out ``anyio`` asyncio based tg here - asyncio.create_task(_invoke(from_trio, to_trio, coro)) + task = asyncio.create_task(_invoke(from_trio, to_trio, coro)) + err = None + + # XXX: I'm not sure this actually does anything... + def cancel_trio(task): + """Cancel the calling ``trio`` task on error. + """ + nonlocal err + err = task.exception() + cancel_scope.cancel() + + task.add_done_callback(cancel_trio) # determine return type async func vs. gen if inspect.isasyncgen(coro): - await from_aio.get() - elif inspect.iscoroutine(coro): - async def gen(): - async for tick in from_aio: - yield tuple(tick) + # simple async func + async def result(): + with cancel_scope: + return await from_aio.get() + if cancel_scope.cancelled_caught and err: + raise err - return gen() + elif inspect.iscoroutine(coro): + # asycn gen + async def result(): + with cancel_scope: + async with from_aio: + async for item in from_aio: + yield item + if cancel_scope.cancelled_caught and err: + raise err + + return result()