diff --git a/tractor/to_asyncio.py b/tractor/to_asyncio.py index 6013fb1..5d168f7 100644 --- a/tractor/to_asyncio.py +++ b/tractor/to_asyncio.py @@ -92,14 +92,28 @@ def _run_asyncio_task( raise finally: - aio_task_complete.set() - if result != orig and aio_err is None: + if ( + result != orig and + aio_err is None and + + # in the ``open_channel_from()`` case we don't + # relay through the "return value". + not provide_channels + ): to_trio.send_nowait(result) + to_trio.close() + from_aio.close() + aio_task_complete.set() + # start the asyncio task we submitted from trio if inspect.isawaitable(coro): task = asyncio.create_task( - wait_on_coro_final_result(to_trio, coro, aio_task_complete) + wait_on_coro_final_result( + to_trio, + coro, + aio_task_complete + ) ) else: @@ -120,7 +134,7 @@ def _run_asyncio_task( cancel_scope.cancel() else: if aio_err is not None: - log.exception(f"infected task errorred:") + log.exception("infected task errorred:") from_aio._err = aio_err # order is opposite here cancel_scope.cancel() @@ -131,41 +145,20 @@ def _run_asyncio_task( return task, from_aio, to_trio, cancel_scope, aio_task_complete -async def run_task( - func: Callable, - *, +@acm +async def translate_aio_errors( - qsize: int = 2**10, - **kwargs, + from_aio: trio.MemoryReceiveChannel, + task: asyncio.Task, -) -> Any: +) -> None: ''' - Run an ``asyncio`` async function or generator in a task, return - or stream the result back to ``trio``. + Error handling context around ``asyncio`` task spawns which + appropriately translates errors and cancels into ``trio`` land. ''' - # simple async func try: - task, from_aio, to_trio, cs, _ = _run_asyncio_task( - func, - qsize=1, - **kwargs, - ) - - # return single value - with cs: - # naively expect the mem chan api to do the job - # of handling cross-framework cancellations / errors - return await from_aio.receive() - - if cs.cancelled_caught: - aio_err = from_aio._err - - # always raise from any captured asyncio error - if aio_err: - raise aio_err - - # Do we need this? + yield except ( Exception, CancelledError, @@ -190,6 +183,41 @@ async def run_task( # ... do what .. +async def run_task( + func: Callable, + *, + + qsize: int = 2**10, + **kwargs, + +) -> Any: + ''' + Run an ``asyncio`` async function or generator in a task, return + or stream the result back to ``trio``. + + ''' + # simple async func + task, from_aio, to_trio, cs, _ = _run_asyncio_task( + func, + qsize=1, + **kwargs, + ) + async with translate_aio_errors(from_aio, task): + + # return single value + with cs: + # naively expect the mem chan api to do the job + # of handling cross-framework cancellations / errors + return await from_aio.receive() + + if cs.cancelled_caught: + aio_err = from_aio._err + + # always raise from any captured asyncio error + if aio_err: + raise aio_err + + # TODO: explicitly api for the streaming case where # we pull from the mem chan in an async generator? # This ends up looking more like our ``Portal.open_stream_from()`` @@ -203,40 +231,21 @@ async def open_channel_from( ) -> AsyncIterator[Any]: - try: - task, from_aio, to_trio, cs, aio_task_complete = _run_asyncio_task( - target, - qsize=2**8, - provide_channels=True, - **kwargs, - ) - + task, from_aio, to_trio, cs, aio_task_complete = _run_asyncio_task( + target, + qsize=2**8, + provide_channels=True, + **kwargs, + ) + async with translate_aio_errors(from_aio, task): with cs: # sync to "started()" call. first = await from_aio.receive() + # stream values upward async with from_aio: yield first, from_aio - # await aio_task_complete.wait() - - except BaseException as err: - - aio_err = from_aio._err - - if aio_err is not None: - # always raise from any captured asyncio error - raise err from aio_err - else: - raise - - finally: - if cs.cancelled_caught: - # always raise from any captured asyncio error - if from_aio._err: - raise from_aio._err - - if not task.done(): - task.cancel() + await aio_task_complete.wait() def run_as_asyncio_guest( @@ -284,7 +293,7 @@ def run_as_asyncio_guest( run_sync_soon_threadsafe=loop.call_soon_threadsafe, done_callback=trio_done_callback, ) - (await trio_done_fut).unwrap() + return (await trio_done_fut).unwrap() # might as well if it's installed. try: