forked from goodboy/tractor
1
0
Fork 0

Re-wrap and raise `asyncio.CancelledError`

For whatever reason `trio` seems to be swallowing this exception when
raised in the `trio` task so instead wrap it in our own non-base
exception type: `AsyncioCancelled` and raise that when the `asyncio`
task cancels itself internally using `raise <err> from <src_err>` style.

Further don't bother cancelling the `trio` task (via cancel scope)
since we we can just use the recv mem chan closure error as a signal
and explicitly lookup any set asyncio error.
infect_asyncio
Tyler Goodlet 2021-11-23 16:19:19 -05:00
parent c48c68c0bc
commit 5f4094691d
2 changed files with 79 additions and 57 deletions

View File

@ -82,6 +82,15 @@ class StreamOverrun(trio.TooSlowError):
"This stream was overrun by sender"
class AsyncioCancelled(Exception):
'''
Asyncio cancelled translation (non-base) error
for use with the ``to_asyncio`` module
to be raised in the ``trio`` side task
'''
def pack_error(
exc: BaseException,
tb=None,

View File

@ -19,6 +19,7 @@ import trio
from .log import get_logger
from ._state import current_actor
from ._exceptions import AsyncioCancelled
log = get_logger(__name__)
@ -91,11 +92,9 @@ def _run_asyncio_task(
except BaseException as err:
aio_err = err
from_aio._err = aio_err
to_trio.close()
from_aio.close()
raise
finally:
else:
if (
result != orig and
aio_err is None and
@ -106,11 +105,13 @@ def _run_asyncio_task(
):
to_trio.send_nowait(result)
finally:
# if the task was spawned using ``open_channel_from()``
# then we close the channels on exit.
if provide_channels:
# only close the sender side which will relay
# a ``trio.EndOfChannel`` to the trio (consumer) side.
to_trio.close()
from_aio.close()
aio_task_complete.set()
@ -127,27 +128,26 @@ def _run_asyncio_task(
else:
raise TypeError(f"No support for invoking {coro}")
def cancel_trio(task) -> None:
def cancel_trio(task: asyncio.Task) -> None:
'''
Cancel the calling ``trio`` task on error.
'''
nonlocal aio_err
try:
aio_err = task.exception()
except CancelledError as cerr:
log.cancel("infected task was cancelled")
from_aio._err = cerr
from_aio.close()
cancel_scope.cancel()
else:
aio_err = from_aio._err
if aio_err is not None:
if type(aio_err) is CancelledError:
log.cancel("infected task was cancelled")
else:
aio_err.with_traceback(aio_err.__traceback__)
log.exception("infected task errorred:")
from_aio._err = aio_err
# NOTE: order is opposite here
cancel_scope.cancel()
# NOTE: currently mem chan closure may act as a form
# of error relay (at least in the ``asyncio.CancelledError``
# case) since we have no way to directly trigger a ``trio``
# task error without creating a nursery to throw one.
# We might want to change this in the future though.
from_aio.close()
task.add_done_callback(cancel_trio)
@ -160,6 +160,7 @@ async def translate_aio_errors(
from_aio: trio.MemoryReceiveChannel,
task: asyncio.Task,
trio_cs: trio.CancelScope,
) -> None:
'''
@ -167,34 +168,50 @@ async def translate_aio_errors(
appropriately translates errors and cancels into ``trio`` land.
'''
err: Optional[Exception] = None
aio_err: Optional[Exception] = None
def maybe_raise_aio_err(err: Exception):
def maybe_raise_aio_err(
err: Optional[Exception] = None
) -> None:
aio_err = from_aio._err
if (
aio_err is not None and
type(aio_err) != CancelledError
):
# always raise from any captured asyncio error
if err:
raise aio_err from err
else:
raise aio_err
try:
yield
except (
Exception,
CancelledError,
) as err:
maybe_raise_aio_err(err)
raise
# NOTE: see the note in the ``cancel_trio()`` asyncio task
# termination callback
trio.ClosedResourceError,
):
aio_err = from_aio._err
if (
task.cancelled() and
type(aio_err) is CancelledError
):
# if an underlying ``asyncio.CancelledError`` triggered this
# channel close, raise our (non-``BaseException``) wrapper
# error: ``AsyncioCancelled`` from that source error.
raise AsyncioCancelled from aio_err
else:
raise
finally:
# always cancel the ``asyncio`` task if we've made it this far
# and it's not done.
if not task.done() and aio_err:
# assert not aio_err, 'WTF how did asyncio do this?!'
task.cancel()
maybe_raise_aio_err(err)
# if task.cancelled():
# ... do what ..
# if any ``asyncio`` error was caught, raise it here inline
# here in the ``trio`` task
maybe_raise_aio_err()
async def run_task(
@ -216,27 +233,16 @@ async def run_task(
qsize=1,
**kwargs,
)
async with translate_aio_errors(from_aio, task):
# return single value
with cs:
# naively expect the mem chan api to do the job
# of handling cross-framework cancellations / errors
with from_aio:
# try:
async with translate_aio_errors(from_aio, task, cs):
# return single value that is the output from the
# ``asyncio`` function-as-task. Expect the mem chan api to
# do the job of handling cross-framework cancellations
# / errors via closure and translation in the
# ``translate_aio_errors()`` in the above ctx mngr.
return await from_aio.receive()
if cs.cancelled_caught:
aio_err = from_aio._err
# always raise from any captured asyncio error
if aio_err:
raise aio_err
# TODO: explicitly api for the streaming case where
# we pull from the mem chan in an async generator?
# This ends up looking more like our ``Portal.open_stream_from()``
# NB: code below is untested.
@dataclass
class LinkedTaskChannel(trio.abc.Channel):
@ -250,19 +256,24 @@ class LinkedTaskChannel(trio.abc.Channel):
_to_aio: asyncio.Queue
_from_aio: trio.MemoryReceiveChannel
_aio_task_complete: trio.Event
_trio_cs: trio.CancelScope
async def aclose(self) -> None:
self._from_aio.close()
async def receive(self) -> Any:
async with translate_aio_errors(self._from_aio, self._aio_task):
async with translate_aio_errors(
self._from_aio,
self._aio_task,
self._trio_cs,
):
return await self._from_aio.receive()
async def wait_ayncio_complete(self) -> None:
await self._aio_task_complete.wait()
def cancel_asyncio_task(self) -> None:
self._aio_task.cancel()
# def cancel_asyncio_task(self) -> None:
# self._aio_task.cancel()
async def send(self, item: Any) -> None:
'''
@ -292,15 +303,17 @@ async def open_channel_from(
provide_channels=True,
**kwargs,
)
chan = LinkedTaskChannel(task, aio_q, from_aio, aio_task_complete)
with cs:
async with translate_aio_errors(from_aio, task):
chan = LinkedTaskChannel(
task, aio_q, from_aio,
aio_task_complete, cs
)
async with from_aio:
async with translate_aio_errors(from_aio, task, cs):
# sync to a "started()"-like first delivered value from the
# ``asyncio`` task.
first = await from_aio.receive()
# stream values upward
async with from_aio:
yield first, chan