diff --git a/tractor/_exceptions.py b/tractor/_exceptions.py
index 95d7533..f3beb5a 100644
--- a/tractor/_exceptions.py
+++ b/tractor/_exceptions.py
@@ -82,6 +82,15 @@ class StreamOverrun(trio.TooSlowError):
     "This stream was overrun by sender"
 
 
+class AsyncioCancelled(Exception):
+    '''
+    Asyncio cancelled translation (non-base) error
+    for use with the ``to_asyncio`` module
+    to be raised in the ``trio`` side task
+
+    '''
+
+
 def pack_error(
     exc: BaseException,
     tb=None,
diff --git a/tractor/to_asyncio.py b/tractor/to_asyncio.py
index 6132303..4e33e68 100644
--- a/tractor/to_asyncio.py
+++ b/tractor/to_asyncio.py
@@ -19,6 +19,7 @@ import trio
 
 from .log import get_logger
 from ._state import current_actor
+from ._exceptions import AsyncioCancelled
 
 log = get_logger(__name__)
 
@@ -91,11 +92,9 @@ def _run_asyncio_task(
         except BaseException as err:
             aio_err = err
             from_aio._err = aio_err
-            to_trio.close()
-            from_aio.close()
             raise
 
-        finally:
+        else:
             if (
                 result != orig and
                 aio_err is None and
@@ -106,11 +105,13 @@ def _run_asyncio_task(
             ):
                 to_trio.send_nowait(result)
 
+        finally:
             # if the task was spawned using ``open_channel_from()``
             # then we close the channels on exit.
             if provide_channels:
+                # only close the sender side which will relay
+                # a ``trio.EndOfChannel`` to the trio (consumer) side.
                 to_trio.close()
-                from_aio.close()
 
             aio_task_complete.set()
 
@@ -127,28 +128,27 @@ def _run_asyncio_task(
     else:
         raise TypeError(f"No support for invoking {coro}")
 
-    def cancel_trio(task) -> None:
+    def cancel_trio(task: asyncio.Task) -> None:
         '''
         Cancel the calling ``trio`` task on error.
 
         '''
         nonlocal aio_err
-        try:
-            aio_err = task.exception()
-        except CancelledError as cerr:
-            log.cancel("infected task was cancelled")
-            from_aio._err = cerr
-            from_aio.close()
-            cancel_scope.cancel()
-        else:
-            if aio_err is not None:
+        aio_err = from_aio._err
+
+        if aio_err is not None:
+            if type(aio_err) is CancelledError:
+                log.cancel("infected task was cancelled")
+            else:
                 aio_err.with_traceback(aio_err.__traceback__)
                 log.exception("infected task errorred:")
-            from_aio._err = aio_err
-            # NOTE: order is opposite here
-            cancel_scope.cancel()
-            from_aio.close()
+        # NOTE: currently mem chan closure may act as a form
+        # of error relay (at least in the ``asyncio.CancelledError``
+        # case) since we have no way to directly trigger a ``trio``
+        # task error without creating a nursery to throw one.
+        # We might want to change this in the future though.
+        from_aio.close()
 
     task.add_done_callback(cancel_trio)
 
@@ -160,6 +160,7 @@ async def translate_aio_errors(
 
     from_aio: trio.MemoryReceiveChannel,
     task: asyncio.Task,
+    trio_cs: trio.CancelScope,
 
 ) -> None:
     '''
@@ -167,34 +168,50 @@ async def translate_aio_errors(
     appropriately translates errors and cancels into ``trio`` land.
 
     '''
-    err: Optional[Exception] = None
     aio_err: Optional[Exception] = None
 
-    def maybe_raise_aio_err(err: Exception):
+    def maybe_raise_aio_err(
+        err: Optional[Exception] = None
+    ) -> None:
         aio_err = from_aio._err
         if (
             aio_err is not None and
             type(aio_err) != CancelledError
         ):
             # always raise from any captured asyncio error
-            raise aio_err from err
-
+            if err:
+                raise aio_err from err
+            else:
+                raise aio_err
 
     try:
         yield
     except (
-        Exception,
-        CancelledError,
-    ) as err:
-        maybe_raise_aio_err(err)
-        raise
+        # NOTE: see the note in the ``cancel_trio()`` asyncio task
+        # termination callback
+        trio.ClosedResourceError,
+    ):
+        aio_err = from_aio._err
+        if (
+            task.cancelled() and
+            type(aio_err) is CancelledError
+        ):
+            # if an underlying ``asyncio.CancelledError`` triggered this
+            # channel close, raise our (non-``BaseException``) wrapper
+            # error: ``AsyncioCancelled`` from that source error.
+            raise AsyncioCancelled from aio_err
+        else:
+            raise
     finally:
+        # always cancel the ``asyncio`` task if we've made it this far
+        # and it's not done.
         if not task.done() and aio_err:
+            # assert not aio_err, 'WTF how did asyncio do this?!'
             task.cancel()
-        maybe_raise_aio_err(err)
 
-    # if task.cancelled():
-    #     ... do what ..
+        # if any ``asyncio`` error was caught, raise it here inline
+        # here in the ``trio`` task
+        maybe_raise_aio_err()
 
 
 async def run_task(
@@ -216,27 +233,16 @@
         qsize=1,
         **kwargs,
     )
-    async with translate_aio_errors(from_aio, task):
-
-        # return single value
-        with cs:
-            # naively expect the mem chan api to do the job
-            # of handling cross-framework cancellations / errors
+    with from_aio:
+        # try:
+        async with translate_aio_errors(from_aio, task, cs):
+            # return single value that is the output from the
+            # ``asyncio`` function-as-task. Expect the mem chan api to
+            # do the job of handling cross-framework cancellations
+            # / errors via closure and translation in the
+            # ``translate_aio_errors()`` in the above ctx mngr.
             return await from_aio.receive()
 
-    if cs.cancelled_caught:
-        aio_err = from_aio._err
-
-        # always raise from any captured asyncio error
-        if aio_err:
-            raise aio_err
-
-
-# TODO: explicitly api for the streaming case where
-# we pull from the mem chan in an async generator?
-# This ends up looking more like our ``Portal.open_stream_from()``
-# NB: code below is untested.
-
 
 @dataclass
 class LinkedTaskChannel(trio.abc.Channel):
@@ -250,19 +256,24 @@ class LinkedTaskChannel(trio.abc.Channel):
     _to_aio: asyncio.Queue
     _from_aio: trio.MemoryReceiveChannel
     _aio_task_complete: trio.Event
+    _trio_cs: trio.CancelScope
 
     async def aclose(self) -> None:
         self._from_aio.close()
 
     async def receive(self) -> Any:
-        async with translate_aio_errors(self._from_aio, self._aio_task):
+        async with translate_aio_errors(
+            self._from_aio,
+            self._aio_task,
+            self._trio_cs,
+        ):
             return await self._from_aio.receive()
 
     async def wait_ayncio_complete(self) -> None:
         await self._aio_task_complete.wait()
 
-    def cancel_asyncio_task(self) -> None:
-        self._aio_task.cancel()
+    # def cancel_asyncio_task(self) -> None:
+    #     self._aio_task.cancel()
 
     async def send(self, item: Any) -> None:
         '''
@@ -292,16 +303,18 @@ async def open_channel_from(
         provide_channels=True,
         **kwargs,
     )
-    chan = LinkedTaskChannel(task, aio_q, from_aio, aio_task_complete)
-    with cs:
-        async with translate_aio_errors(from_aio, task):
+    chan = LinkedTaskChannel(
+        task, aio_q, from_aio,
+        aio_task_complete, cs
+    )
+    async with from_aio:
+        async with translate_aio_errors(from_aio, task, cs):
 
             # sync to a "started()"-like first delivered value from the
             # ``asyncio`` task.
             first = await from_aio.receive()
 
             # stream values upward
-            async with from_aio:
-                yield first, chan
+            yield first, chan
 
 
 def run_as_asyncio_guest(
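
For reference, below is a minimal (hypothetical, untested) sketch of how the new
``AsyncioCancelled`` translation is expected to read from the ``trio`` side of an
``asyncio``-infected actor; ``aio_sleep_forever()`` and ``trio_consumer()`` are
made-up names and the import paths simply mirror the modules touched above:

    import asyncio

    from tractor import to_asyncio
    from tractor._exceptions import AsyncioCancelled


    async def aio_sleep_forever() -> None:
        # plain ``asyncio`` coroutine which only ever exits via cancellation
        await asyncio.sleep(float('inf'))


    async def trio_consumer() -> None:
        # assumed to run in a ``trio`` task inside an actor spawned
        # with ``infect_asyncio=True``
        try:
            # schedule the coroutine on the ``asyncio`` loop and wait
            # on its result from this ``trio`` task
            await to_asyncio.run_task(aio_sleep_forever)

        except AsyncioCancelled:
            # the ``asyncio`` task was cancelled; the resulting mem chan
            # closure is translated into this (non-``BaseException``)
            # wrapper instead of leaking a raw ``asyncio.CancelledError``
            # into ``trio`` land
            ...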