Re-wrap and raise `asyncio.CancelledError`

For whatever reason `trio` seems to swallow this exception when it is
raised in the `trio` task, so instead wrap it in our own non-base
exception type, `AsyncioCancelled`, and raise that (via
`raise <err> from <src_err>` style) whenever the `asyncio` task cancels
itself internally.

Further, don't bother cancelling the `trio` task (via its cancel scope)
since we can just use the recv mem chan closure error as a signal and
explicitly look up any set asyncio error.
infect_asyncio
Tyler Goodlet 2021-11-23 16:19:19 -05:00
parent c48c68c0bc
commit 5f4094691d
2 changed files with 79 additions and 57 deletions
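
For illustration, a minimal standalone sketch of the wrap-and-re-raise pattern described above; the `AsyncioCancelled` shape mirrors the hunk below, while `maybe_wrap_cancel()` is a made-up helper, not part of this diff:

    import asyncio


    class AsyncioCancelled(Exception):
        '''
        Non-``BaseException`` wrapper for an ``asyncio``-side cancel so
        plain ``trio`` task code can catch it without base-exception caveats.
        '''


    def maybe_wrap_cancel(task: asyncio.Task) -> None:
        # call once the ``asyncio`` task has completed
        if task.cancelled():
            # re-wrap using ``raise <err> from <src_err>`` style
            raise AsyncioCancelled(
                f'asyncio task {task.get_name()!r} cancelled itself'
            ) from asyncio.CancelledError()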

tractor/_exceptions.py

@@ -82,6 +82,15 @@ class StreamOverrun(trio.TooSlowError):
     "This stream was overrun by sender"
 
 
+class AsyncioCancelled(Exception):
+    '''
+    Asyncio cancelled translation (non-base) error
+    for use with the ``to_asyncio`` module
+    to be raised in the ``trio`` side task
+
+    '''
+
+
 def pack_error(
     exc: BaseException,
     tb=None,
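
Because `AsyncioCancelled` subclasses plain `Exception` (not `BaseException`), the `trio`-side task can catch it like any ordinary error; a hedged usage sketch (the import path is assumed from the file above, and `consume()` is illustrative only):

    from tractor._exceptions import AsyncioCancelled  # path assumed

    async def consume(chan) -> None:
        try:
            async for item in chan:
                print('got', item)
        except AsyncioCancelled:
            # the paired ``asyncio`` task cancelled itself internally;
            # handle it like a regular error, no base-exception caveats.
            print('asyncio side cancelled')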

tractor/to_asyncio.py

@@ -19,6 +19,7 @@ import trio
 
 from .log import get_logger
 from ._state import current_actor
+from ._exceptions import AsyncioCancelled
 
 log = get_logger(__name__)
@@ -91,11 +92,9 @@ def _run_asyncio_task(
         except BaseException as err:
             aio_err = err
             from_aio._err = aio_err
-            to_trio.close()
-            from_aio.close()
             raise
 
-        finally:
+        else:
             if (
                 result != orig and
                 aio_err is None and
@@ -106,11 +105,13 @@ def _run_asyncio_task(
             ):
                 to_trio.send_nowait(result)
 
+        finally:
             # if the task was spawned using ``open_channel_from()``
             # then we close the channels on exit.
             if provide_channels:
+                # only close the sender side which will relay
+                # a ``trio.EndOfChannel`` to the trio (consumer) side.
                 to_trio.close()
-                from_aio.close()
 
             aio_task_complete.set()
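
The comment added here leans on standard `trio` channel semantics: closing only the send side makes the consumer's `receive()` (or `async for`) end with `trio.EndOfChannel` rather than an error. A standalone demo of just that behaviour (not tractor code):

    import trio

    async def main() -> None:
        send, recv = trio.open_memory_channel(1)

        async def producer() -> None:
            await send.send('first')
            # closing only the sender relays ``trio.EndOfChannel`` to
            # the consumer instead of raising an error
            send.close()

        async with trio.open_nursery() as n:
            n.start_soon(producer)
            async for msg in recv:  # iteration ends cleanly on EndOfChannel
                print(msg)

    trio.run(main)
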
@@ -127,28 +128,27 @@ def _run_asyncio_task(
     else:
         raise TypeError(f"No support for invoking {coro}")
 
-    def cancel_trio(task) -> None:
+    def cancel_trio(task: asyncio.Task) -> None:
         '''
         Cancel the calling ``trio`` task on error.
 
         '''
         nonlocal aio_err
-        try:
-            aio_err = task.exception()
-        except CancelledError as cerr:
-            log.cancel("infected task was cancelled")
-            from_aio._err = cerr
-            from_aio.close()
-            cancel_scope.cancel()
-        else:
-            if aio_err is not None:
+        aio_err = from_aio._err
+
+        if aio_err is not None:
+            if type(aio_err) is CancelledError:
+                log.cancel("infected task was cancelled")
+            else:
                 aio_err.with_traceback(aio_err.__traceback__)
                 log.exception("infected task errorred:")
-                from_aio._err = aio_err
 
-        # NOTE: order is opposite here
-        cancel_scope.cancel()
-        from_aio.close()
+        # NOTE: currently mem chan closure may act as a form
+        # of error relay (at least in the ``asyncio.CancelledError``
+        # case) since we have no way to directly trigger a ``trio``
+        # task error without creating a nursery to throw one.
+        # We might want to change this in the future though.
+        from_aio.close()
 
     task.add_done_callback(cancel_trio)
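
The reworked `cancel_trio()` callback reads the captured error off `from_aio._err` rather than calling `task.exception()`; for reference, the plain `asyncio` done-callback pattern it replaces looks roughly like this (standalone example, not tractor code):

    import asyncio

    def on_done(task: asyncio.Task) -> None:
        # ``task.exception()`` itself raises ``CancelledError`` if the
        # task was cancelled, so check ``task.cancelled()`` first.
        if task.cancelled():
            print('task was cancelled')
        elif (err := task.exception()) is not None:
            print('task errored:', repr(err))

    async def main() -> None:
        task = asyncio.create_task(asyncio.sleep(1))
        task.add_done_callback(on_done)
        task.cancel()
        try:
            await task
        except asyncio.CancelledError:
            pass

    asyncio.run(main())
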
@@ -160,6 +160,7 @@ async def translate_aio_errors(
 
     from_aio: trio.MemoryReceiveChannel,
     task: asyncio.Task,
+    trio_cs: trio.CancelScope,
 
 ) -> None:
     '''
@@ -167,34 +168,50 @@
     appropriately translates errors and cancels into ``trio`` land.
 
     '''
-    err: Optional[Exception] = None
     aio_err: Optional[Exception] = None
 
-    def maybe_raise_aio_err(err: Exception):
+    def maybe_raise_aio_err(
+        err: Optional[Exception] = None
+    ) -> None:
         aio_err = from_aio._err
         if (
            aio_err is not None and
            type(aio_err) != CancelledError
        ):
            # always raise from any captured asyncio error
-            raise aio_err from err
+            if err:
+                raise aio_err from err
+            else:
+                raise aio_err
 
     try:
         yield
     except (
-        Exception,
-        CancelledError,
-    ) as err:
-        maybe_raise_aio_err(err)
-        raise
+        # NOTE: see the note in the ``cancel_trio()`` asyncio task
+        # termination callback
+        trio.ClosedResourceError,
+    ):
+        aio_err = from_aio._err
+        if (
+            task.cancelled() and
+            type(aio_err) is CancelledError
+        ):
+            # if an underlying ``asyncio.CancelledError`` triggered this
+            # channel close, raise our (non-``BaseException``) wrapper
+            # error: ``AsyncioCancelled`` from that source error.
+            raise AsyncioCancelled from aio_err
+        else:
+            raise
 
     finally:
+        # always cancel the ``asyncio`` task if we've made it this far
+        # and it's not done.
         if not task.done() and aio_err:
+            # assert not aio_err, 'WTF how did asyncio do this?!'
             task.cancel()
 
-        maybe_raise_aio_err(err)
-        # if task.cancelled():
-        #     ... do what ..
+        # if any ``asyncio`` error was caught, raise it here inline
+        # here in the ``trio`` task
+        maybe_raise_aio_err()
 
 
 async def run_task(
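
The new `except trio.ClosedResourceError` branch relies on standard `trio` semantics: closing a memory receive channel (as `cancel_trio()` does above) wakes any task blocked in `receive()` on it with `ClosedResourceError`. A standalone demo of just that behaviour:

    import trio

    async def main() -> None:
        send, recv = trio.open_memory_channel(1)

        async def closer() -> None:
            await trio.sleep(0.1)
            # closing the *receive* side wakes the blocked consumer
            # with ``ClosedResourceError``
            recv.close()

        async with trio.open_nursery() as n:
            n.start_soon(closer)
            try:
                await recv.receive()
            except trio.ClosedResourceError:
                print('channel closed out from under us')

    trio.run(main)
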
@@ -216,27 +233,16 @@
         qsize=1,
         **kwargs,
     )
-    async with translate_aio_errors(from_aio, task):
-        # try:
-        # return single value
-        with cs:
-            # naively expect the mem chan api to do the job
-            # of handling cross-framework cancellations / errors
+    with from_aio:
+        async with translate_aio_errors(from_aio, task, cs):
+            # return single value that is the output from the
+            # ``asyncio`` function-as-task. Expect the mem chan api to
+            # do the job of handling cross-framework cancellations
+            # / errors via closure and translation in the
+            # ``translate_aio_errors()`` in the above ctx mngr.
             return await from_aio.receive()
-
-        if cs.cancelled_caught:
-            aio_err = from_aio._err
-
-            # always raise from any captured asyncio error
-            if aio_err:
-                raise aio_err
-
-    # TODO: explicitly api for the streaming case where
-    # we pull from the mem chan in an async generator?
-    # This ends up looking more like our ``Portal.open_stream_from()``
-    # NB: code below is untested.
 
 
 @dataclass
 class LinkedTaskChannel(trio.abc.Channel):
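
Assuming (not shown in these hunks) that `run_task()` still takes the target `asyncio` function plus its kwargs and is called from the `trio` side of an asyncio-infected actor, usage presumably looks like this; `aio_sleep_and_add` is a made-up target:

    import asyncio
    import tractor

    async def aio_sleep_and_add(x: int, y: int) -> int:
        # runs as an ``asyncio`` task inside the infected actor
        await asyncio.sleep(0.1)
        return x + y

    async def trio_side() -> None:
        # assumes we're inside an actor running an ``asyncio`` guest
        # loop (e.g. spawned with ``infect_asyncio=True``)
        result = await tractor.to_asyncio.run_task(aio_sleep_and_add, x=1, y=2)
        assert result == 3
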
@@ -250,19 +256,24 @@ class LinkedTaskChannel(trio.abc.Channel):
     _to_aio: asyncio.Queue
     _from_aio: trio.MemoryReceiveChannel
     _aio_task_complete: trio.Event
+    _trio_cs: trio.CancelScope
 
     async def aclose(self) -> None:
         self._from_aio.close()
 
     async def receive(self) -> Any:
-        async with translate_aio_errors(self._from_aio, self._aio_task):
+        async with translate_aio_errors(
+            self._from_aio,
+            self._aio_task,
+            self._trio_cs,
+        ):
             return await self._from_aio.receive()
 
     async def wait_ayncio_complete(self) -> None:
         await self._aio_task_complete.wait()
 
-    def cancel_asyncio_task(self) -> None:
-        self._aio_task.cancel()
+    # def cancel_asyncio_task(self) -> None:
+    #     self._aio_task.cancel()
 
     async def send(self, item: Any) -> None:
         '''
@@ -292,16 +303,18 @@ async def open_channel_from(
         provide_channels=True,
         **kwargs,
     )
-    chan = LinkedTaskChannel(task, aio_q, from_aio, aio_task_complete)
-    with cs:
-        async with translate_aio_errors(from_aio, task):
+    chan = LinkedTaskChannel(
+        task, aio_q, from_aio,
+        aio_task_complete, cs
+    )
+    async with from_aio:
+        async with translate_aio_errors(from_aio, task, cs):
             # sync to a "started()"-like first delivered value from the
             # ``asyncio`` task.
             first = await from_aio.receive()
 
             # stream values upward
-            async with from_aio:
-                yield first, chan
+            yield first, chan
 
 
 def run_as_asyncio_guest(
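
Similarly, a hedged usage sketch for the streaming case: the channel-argument names on the `asyncio` side (`to_trio`, `from_trio`) are assumptions, not shown in this diff, and the demo assumes the relay buffer can hold these few items:

    import asyncio
    import tractor

    async def aio_streamer(to_trio, from_trio) -> None:
        # ``asyncio`` side; channel args are injected when spawned with
        # ``provide_channels=True`` (exact names assumed here)
        to_trio.send_nowait('start')  # becomes ``first`` below
        for i in range(3):
            to_trio.send_nowait(i)
            await asyncio.sleep(0)

    async def trio_side() -> None:
        async with tractor.to_asyncio.open_channel_from(
            aio_streamer,
        ) as (first, chan):
            assert first == 'start'
            async for msg in chan:
                print('relayed from asyncio:', msg)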