diff --git a/examples/infected_asyncio_echo_server.py b/examples/infected_asyncio_echo_server.py index ee7c45b..0250835 100644 --- a/examples/infected_asyncio_echo_server.py +++ b/examples/infected_asyncio_echo_server.py @@ -13,6 +13,7 @@ import tractor async def aio_echo_server( to_trio: trio.MemorySendChannel, from_trio: asyncio.Queue, + ) -> None: # a first message must be sent **from** this ``asyncio`` diff --git a/nooz/318.bug.rst b/nooz/318.bug.rst new file mode 100644 index 0000000..5bbf4f0 --- /dev/null +++ b/nooz/318.bug.rst @@ -0,0 +1,13 @@ +Fix a previously undetected ``trio``-``asyncio`` task lifetime linking +issue with the ``to_asyncio.open_channel_from()`` api where both sides +were not properly waiting/signalling termination and it was possible +for ``asyncio``-side errors to not propagate due to a race condition. + +The implementation fix summary is: +- add state to signal the end of the ``trio`` side task to be + read by the ``asyncio`` side and always cancel any ongoing + task in such cases. +- always wait on the ``asyncio`` task termination from the ``trio`` + side on error before maybe raising said error. +- always close the ``trio`` mem chan on exit to ensure the other + side can detect it and follow. 
diff --git a/tests/test_infected_asyncio.py b/tests/test_infected_asyncio.py index 37c85fd..976741d 100644 --- a/tests/test_infected_asyncio.py +++ b/tests/test_infected_asyncio.py @@ -11,12 +11,25 @@ import importlib import pytest import trio import tractor -from tractor import to_asyncio -from tractor import RemoteActorError +from tractor import ( + to_asyncio, + RemoteActorError, +) from tractor.trionics import BroadcastReceiver -async def sleep_and_err(sleep_for: float = 0.1): +async def sleep_and_err( + sleep_for: float = 0.1, + + # just signature placeholders for compat with + # ``to_asyncio.open_channel_from()`` + to_trio: Optional[trio.MemorySendChannel] = None, + from_trio: Optional[asyncio.Queue] = None, + +): + if to_trio: + to_trio.send_nowait('start') + await asyncio.sleep(sleep_for) assert 0 @@ -146,6 +159,80 @@ def test_trio_cancels_aio(arb_addr): trio.run(main) +@tractor.context +async def trio_ctx( + ctx: tractor.Context, +): + + await ctx.started('start') + + # this will block until the ``asyncio`` task sends a "first" + # message. + with trio.fail_after(2): + async with ( + tractor.to_asyncio.open_channel_from( + sleep_and_err, + ) as (first, chan), + + trio.open_nursery() as n, + ): + + assert first == 'start' + + # spawn another asyncio task for the cuck of it. + n.start_soon( + tractor.to_asyncio.run_task, + sleep_forever, + ) + await trio.sleep_forever() + + +@pytest.mark.parametrize( + 'parent_cancels', [False, True], + ids='parent_actor_cancels_child={}'.format +) +def test_context_spawns_aio_task_that_errors( + arb_addr, + parent_cancels: bool, +): + ''' + Verify that spawning a task via an intertask channel ctx mngr that + errors correctly propagates the error back from the `asyncio`-side + task. 
+ + ''' + async def main(): + + async with tractor.open_nursery() as n: + p = await n.start_actor( + 'aio_daemon', + enable_modules=[__name__], + infect_asyncio=True, + # debug_mode=True, + loglevel='cancel', + ) + async with p.open_context( + trio_ctx, + ) as (ctx, first): + + assert first == 'start' + + if parent_cancels: + await p.cancel_actor() + + await trio.sleep_forever() + + with pytest.raises(RemoteActorError) as excinfo: + trio.run(main) + + err = excinfo.value + assert isinstance(err, RemoteActorError) + if parent_cancels: + assert err.type == trio.Cancelled + else: + assert err.type == AssertionError + + async def aio_cancel(): '''' Cancel urself boi. @@ -385,6 +472,8 @@ async def trio_to_aio_echo_server( print('breaking aio echo loop') break + print('exiting asyncio task') + async with to_asyncio.open_channel_from( aio_echo_server, ) as (first, chan): diff --git a/tractor/to_asyncio.py b/tractor/to_asyncio.py index 6ca07ca..a19afe1 100644 --- a/tractor/to_asyncio.py +++ b/tractor/to_asyncio.py @@ -23,7 +23,6 @@ from asyncio.exceptions import CancelledError from contextlib import asynccontextmanager as acm from dataclasses import dataclass import inspect -import traceback from typing import ( Any, Callable, @@ -63,6 +62,7 @@ class LinkedTaskChannel(trio.abc.Channel): _trio_cs: trio.CancelScope _aio_task_complete: trio.Event + _trio_exited: bool = False # set after ``asyncio.create_task()`` _aio_task: Optional[asyncio.Task] = None @@ -73,7 +73,13 @@ class LinkedTaskChannel(trio.abc.Channel): await self._from_aio.aclose() async def receive(self) -> Any: - async with translate_aio_errors(self): + async with translate_aio_errors( + self, + + # XXX: obviously this will deadlock if an on-going stream is + # being processed. 
+ # wait_on_aio_task=False, + ): # TODO: do we need this to guarantee asyncio code get's # cancelled in the case where the trio side somehow creates @@ -210,10 +216,8 @@ def _run_asyncio_task( orig = result = id(coro) try: result = await coro - except GeneratorExit: - # no need to relay error - raise except BaseException as aio_err: + log.exception('asyncio task errored') chan._aio_err = aio_err raise @@ -237,6 +241,7 @@ def _run_asyncio_task( to_trio.close() aio_task_complete.set() + log.runtime(f'`asyncio` task: {task.get_name()} is complete') # start the asyncio task we submitted from trio if not inspect.isawaitable(coro): @@ -291,10 +296,12 @@ def _run_asyncio_task( elif task_err is None: assert aio_err aio_err.with_traceback(aio_err.__traceback__) - msg = ''.join(traceback.format_exception(type(aio_err))) - log.error( - f'infected task errorred:\n{msg}' - ) + log.error('infected task errorred') + + # XXX: always cancel the scope on error + # in case the trio task is blocking + # on a checkpoint. + cancel_scope.cancel() # raise any ``asyncio`` side error. raise aio_err @@ -307,6 +314,7 @@ def _run_asyncio_task( async def translate_aio_errors( chan: LinkedTaskChannel, + wait_on_aio_task: bool = False, ) -> AsyncIterator[None]: ''' @@ -318,6 +326,7 @@ async def translate_aio_errors( aio_err: Optional[BaseException] = None + # TODO: make this a channel method? def maybe_raise_aio_err( err: Optional[Exception] = None ) -> None: @@ -367,13 +376,30 @@ async def translate_aio_errors( raise finally: - # always cancel the ``asyncio`` task if we've made it this far - # and it's not done. - if not task.done() and aio_err: + if ( + # NOTE: always cancel the ``asyncio`` task if we've made it + # this far and it's not done. + not task.done() and aio_err + + # or the trio side has exited its surrounding cancel scope + # indicating the lifetime of the ``asyncio``-side task + # should also be terminated. 
+ or chan._trio_exited + ): + log.runtime( + f'Cancelling `asyncio`-task: {task.get_name()}' + ) # assert not aio_err, 'WTF how did asyncio do this?!' task.cancel() - # if any ``asyncio`` error was caught, raise it here inline + # Required to sync with the far end ``asyncio``-task to ensure + # any error is captured (via monkeypatching the + # ``channel._aio_err``) before calling ``maybe_raise_aio_err()`` + # below! + if wait_on_aio_task: + await chan._aio_task_complete.wait() + + # NOTE: if any ``asyncio`` error was caught, raise it here inline # here in the ``trio`` task maybe_raise_aio_err() @@ -398,7 +424,10 @@ async def run_task( **kwargs, ) with chan._from_aio: - async with translate_aio_errors(chan): + async with translate_aio_errors( + chan, + wait_on_aio_task=True, + ): # return single value that is the output from the # ``asyncio`` function-as-task. Expect the mem chan api to # do the job of handling cross-framework cancellations @@ -426,13 +455,21 @@ async def open_channel_from( **kwargs, ) async with chan._from_aio: - async with translate_aio_errors(chan): + async with translate_aio_errors( + chan, + wait_on_aio_task=True, + ): # sync to a "started()"-like first delivered value from the # ``asyncio`` task. first = await chan.receive() # deliver stream handle upward - yield first, chan + try: + with chan._trio_cs: + yield first, chan + finally: + chan._trio_exited = True + chan._to_trio.close() def run_as_asyncio_guest( @@ -482,7 +519,7 @@ def run_as_asyncio_guest( main_outcome.unwrap() else: trio_done_fut.set_result(main_outcome) - print(f"trio_main finished: {main_outcome!r}") + log.runtime(f"trio_main finished: {main_outcome!r}") # start the infection: run trio on the asyncio loop in "guest mode" log.info(f"Infecting asyncio process with {trio_main}")