From 38d03858d748e1227386b2e69078a1023ef7c600 Mon Sep 17 00:00:00 2001 From: Tyler Goodlet Date: Thu, 14 Jul 2022 16:35:41 -0400 Subject: [PATCH] Fix `asyncio`-task-sync and error propagation This fixes an previously undetected bug where if an `.open_channel_from()` spawned task errored the error would not be propagated to the `trio` side and instead would fail silently with a console log error. What was most odd is that it only seems easy to trigger when you put a slight task sleep before the error is raised (:eyeroll:). This patch adds a few things to address this and just in general improve iter-task lifetime syncing: - add `LinkedTaskChannel._trio_exited: bool` a flag set from the `trio` side when the channel block exits. - add a `wait_on_aio_task: bool` flag to `translate_aio_errors` which toggles whether to wait the `asyncio` task termination event on exit. - cancel the `asyncio` task if the trio side has ended, when `._trio_exited == True`. - always close the `trio` mem channel when the task exits such that the `asyncio` side can error on any next `.send()` call. --- tests/test_infected_asyncio.py | 2 +- tractor/to_asyncio.py | 65 +++++++++++++++++++++++++++------- 2 files changed, 54 insertions(+), 13 deletions(-) diff --git a/tests/test_infected_asyncio.py b/tests/test_infected_asyncio.py index e1228c0..81a4f3e 100644 --- a/tests/test_infected_asyncio.py +++ b/tests/test_infected_asyncio.py @@ -185,7 +185,7 @@ async def trio_ctx( tractor.to_asyncio.run_task, sleep_forever, ) - # await trio.sleep_forever() + await trio.sleep_forever() @pytest.mark.parametrize( diff --git a/tractor/to_asyncio.py b/tractor/to_asyncio.py index 6ca07ca..5168234 100644 --- a/tractor/to_asyncio.py +++ b/tractor/to_asyncio.py @@ -63,6 +63,7 @@ class LinkedTaskChannel(trio.abc.Channel): _trio_cs: trio.CancelScope _aio_task_complete: trio.Event + _trio_exited: bool = False # set after ``asyncio.create_task()`` _aio_task: Optional[asyncio.Task] = None @@ -73,7 +74,13 @@ class LinkedTaskChannel(trio.abc.Channel): await self._from_aio.aclose() async def receive(self) -> Any: - async with translate_aio_errors(self): + async with translate_aio_errors( + self, + + # XXX: obviously this will deadlock if an on-going stream is + # being procesed. + # wait_on_aio_task=False, + ): # TODO: do we need this to guarantee asyncio code get's # cancelled in the case where the trio side somehow creates @@ -210,10 +217,8 @@ def _run_asyncio_task( orig = result = id(coro) try: result = await coro - except GeneratorExit: - # no need to relay error - raise except BaseException as aio_err: + log.exception('asyncio task errored') chan._aio_err = aio_err raise @@ -237,6 +242,7 @@ def _run_asyncio_task( to_trio.close() aio_task_complete.set() + log.runtime(f'`asyncio` task: {task.get_name()} is complete') # start the asyncio task we submitted from trio if not inspect.isawaitable(coro): @@ -296,6 +302,11 @@ def _run_asyncio_task( f'infected task errorred:\n{msg}' ) + # XXX: alway cancel the scope on error + # in case the trio task is blocking + # on a checkpoint. + cancel_scope.cancel() + # raise any ``asyncio`` side error. raise aio_err @@ -307,6 +318,7 @@ def _run_asyncio_task( async def translate_aio_errors( chan: LinkedTaskChannel, + wait_on_aio_task: bool = False, ) -> AsyncIterator[None]: ''' @@ -318,6 +330,7 @@ async def translate_aio_errors( aio_err: Optional[BaseException] = None + # TODO: make thisi a channel method? def maybe_raise_aio_err( err: Optional[Exception] = None ) -> None: @@ -367,13 +380,30 @@ async def translate_aio_errors( raise finally: - # always cancel the ``asyncio`` task if we've made it this far - # and it's not done. - if not task.done() and aio_err: + if ( + # NOTE: always cancel the ``asyncio`` task if we've made it + # this far and it's not done. + not task.done() and aio_err + + # or the trio side has exited it's surrounding cancel scope + # indicating the lifetime of the ``asyncio``-side task + # should also be terminated. + or chan._trio_exited + ): + log.runtime( + f'Cancelling `asyncio`-task: {chan._aio_taskget_name()}' + ) # assert not aio_err, 'WTF how did asyncio do this?!' task.cancel() - # if any ``asyncio`` error was caught, raise it here inline + # Required to sync with the far end ``asyncio``-task to ensure + # any error is captured (via monkeypatching the + # ``channel._aio_err``) before calling ``maybe_raise_aio_err()`` + # below! + if wait_on_aio_task: + await chan._aio_task_complete.wait() + + # NOTE: if any ``asyncio`` error was caught, raise it here inline # here in the ``trio`` task maybe_raise_aio_err() @@ -398,7 +428,10 @@ async def run_task( **kwargs, ) with chan._from_aio: - async with translate_aio_errors(chan): + async with translate_aio_errors( + chan, + wait_on_aio_task=True, + ): # return single value that is the output from the # ``asyncio`` function-as-task. Expect the mem chan api to # do the job of handling cross-framework cancellations @@ -426,13 +459,21 @@ async def open_channel_from( **kwargs, ) async with chan._from_aio: - async with translate_aio_errors(chan): + async with translate_aio_errors( + chan, + wait_on_aio_task=True, + ): # sync to a "started()"-like first delivered value from the # ``asyncio`` task. first = await chan.receive() # deliver stream handle upward - yield first, chan + try: + with chan._trio_cs: + yield first, chan + finally: + chan._trio_exited = True + chan._to_trio.close() def run_as_asyncio_guest( @@ -482,7 +523,7 @@ def run_as_asyncio_guest( main_outcome.unwrap() else: trio_done_fut.set_result(main_outcome) - print(f"trio_main finished: {main_outcome!r}") + log.runtime(f"trio_main finished: {main_outcome!r}") # start the infection: run trio on the asyncio loop in "guest mode" log.info(f"Infecting asyncio process with {trio_main}")