forked from goodboy/tractor
Re-wrap and raise `asyncio.CancelledError`
For whatever reason `trio` seems to be swallowing this exception when raised in the `trio` task so instead wrap it in our own non-base exception type: `AsyncioCancelled` and raise that when the `asyncio` task cancels itself internally using `raise <err> from <src_err>` style. Further don't bother cancelling the `trio` task (via cancel scope) since we we can just use the recv mem chan closure error as a signal and explicitly lookup any set asyncio error.infect_asyncio
parent
c48c68c0bc
commit
5f4094691d
|
@ -82,6 +82,15 @@ class StreamOverrun(trio.TooSlowError):
|
||||||
"This stream was overrun by sender"
|
"This stream was overrun by sender"
|
||||||
|
|
||||||
|
|
||||||
|
class AsyncioCancelled(Exception):
|
||||||
|
'''
|
||||||
|
Asyncio cancelled translation (non-base) error
|
||||||
|
for use with the ``to_asyncio`` module
|
||||||
|
to be raised in the ``trio`` side task
|
||||||
|
|
||||||
|
'''
|
||||||
|
|
||||||
|
|
||||||
def pack_error(
|
def pack_error(
|
||||||
exc: BaseException,
|
exc: BaseException,
|
||||||
tb=None,
|
tb=None,
|
||||||
|
|
|
@ -19,6 +19,7 @@ import trio
|
||||||
|
|
||||||
from .log import get_logger
|
from .log import get_logger
|
||||||
from ._state import current_actor
|
from ._state import current_actor
|
||||||
|
from ._exceptions import AsyncioCancelled
|
||||||
|
|
||||||
log = get_logger(__name__)
|
log = get_logger(__name__)
|
||||||
|
|
||||||
|
@ -91,11 +92,9 @@ def _run_asyncio_task(
|
||||||
except BaseException as err:
|
except BaseException as err:
|
||||||
aio_err = err
|
aio_err = err
|
||||||
from_aio._err = aio_err
|
from_aio._err = aio_err
|
||||||
to_trio.close()
|
|
||||||
from_aio.close()
|
|
||||||
raise
|
raise
|
||||||
|
|
||||||
finally:
|
else:
|
||||||
if (
|
if (
|
||||||
result != orig and
|
result != orig and
|
||||||
aio_err is None and
|
aio_err is None and
|
||||||
|
@ -106,11 +105,13 @@ def _run_asyncio_task(
|
||||||
):
|
):
|
||||||
to_trio.send_nowait(result)
|
to_trio.send_nowait(result)
|
||||||
|
|
||||||
|
finally:
|
||||||
# if the task was spawned using ``open_channel_from()``
|
# if the task was spawned using ``open_channel_from()``
|
||||||
# then we close the channels on exit.
|
# then we close the channels on exit.
|
||||||
if provide_channels:
|
if provide_channels:
|
||||||
|
# only close the sender side which will relay
|
||||||
|
# a ``trio.EndOfChannel`` to the trio (consumer) side.
|
||||||
to_trio.close()
|
to_trio.close()
|
||||||
from_aio.close()
|
|
||||||
|
|
||||||
aio_task_complete.set()
|
aio_task_complete.set()
|
||||||
|
|
||||||
|
@ -127,28 +128,27 @@ def _run_asyncio_task(
|
||||||
else:
|
else:
|
||||||
raise TypeError(f"No support for invoking {coro}")
|
raise TypeError(f"No support for invoking {coro}")
|
||||||
|
|
||||||
def cancel_trio(task) -> None:
|
def cancel_trio(task: asyncio.Task) -> None:
|
||||||
'''
|
'''
|
||||||
Cancel the calling ``trio`` task on error.
|
Cancel the calling ``trio`` task on error.
|
||||||
|
|
||||||
'''
|
'''
|
||||||
nonlocal aio_err
|
nonlocal aio_err
|
||||||
try:
|
aio_err = from_aio._err
|
||||||
aio_err = task.exception()
|
|
||||||
except CancelledError as cerr:
|
if aio_err is not None:
|
||||||
log.cancel("infected task was cancelled")
|
if type(aio_err) is CancelledError:
|
||||||
from_aio._err = cerr
|
log.cancel("infected task was cancelled")
|
||||||
from_aio.close()
|
else:
|
||||||
cancel_scope.cancel()
|
|
||||||
else:
|
|
||||||
if aio_err is not None:
|
|
||||||
aio_err.with_traceback(aio_err.__traceback__)
|
aio_err.with_traceback(aio_err.__traceback__)
|
||||||
log.exception("infected task errorred:")
|
log.exception("infected task errorred:")
|
||||||
from_aio._err = aio_err
|
|
||||||
|
|
||||||
# NOTE: order is opposite here
|
# NOTE: currently mem chan closure may act as a form
|
||||||
cancel_scope.cancel()
|
# of error relay (at least in the ``asyncio.CancelledError``
|
||||||
from_aio.close()
|
# case) since we have no way to directly trigger a ``trio``
|
||||||
|
# task error without creating a nursery to throw one.
|
||||||
|
# We might want to change this in the future though.
|
||||||
|
from_aio.close()
|
||||||
|
|
||||||
task.add_done_callback(cancel_trio)
|
task.add_done_callback(cancel_trio)
|
||||||
|
|
||||||
|
@ -160,6 +160,7 @@ async def translate_aio_errors(
|
||||||
|
|
||||||
from_aio: trio.MemoryReceiveChannel,
|
from_aio: trio.MemoryReceiveChannel,
|
||||||
task: asyncio.Task,
|
task: asyncio.Task,
|
||||||
|
trio_cs: trio.CancelScope,
|
||||||
|
|
||||||
) -> None:
|
) -> None:
|
||||||
'''
|
'''
|
||||||
|
@ -167,34 +168,50 @@ async def translate_aio_errors(
|
||||||
appropriately translates errors and cancels into ``trio`` land.
|
appropriately translates errors and cancels into ``trio`` land.
|
||||||
|
|
||||||
'''
|
'''
|
||||||
err: Optional[Exception] = None
|
|
||||||
aio_err: Optional[Exception] = None
|
aio_err: Optional[Exception] = None
|
||||||
|
|
||||||
def maybe_raise_aio_err(err: Exception):
|
def maybe_raise_aio_err(
|
||||||
|
err: Optional[Exception] = None
|
||||||
|
) -> None:
|
||||||
aio_err = from_aio._err
|
aio_err = from_aio._err
|
||||||
if (
|
if (
|
||||||
aio_err is not None and
|
aio_err is not None and
|
||||||
type(aio_err) != CancelledError
|
type(aio_err) != CancelledError
|
||||||
):
|
):
|
||||||
# always raise from any captured asyncio error
|
# always raise from any captured asyncio error
|
||||||
raise aio_err from err
|
if err:
|
||||||
|
raise aio_err from err
|
||||||
|
else:
|
||||||
|
raise aio_err
|
||||||
try:
|
try:
|
||||||
yield
|
yield
|
||||||
except (
|
except (
|
||||||
Exception,
|
# NOTE: see the note in the ``cancel_trio()`` asyncio task
|
||||||
CancelledError,
|
# termination callback
|
||||||
) as err:
|
trio.ClosedResourceError,
|
||||||
maybe_raise_aio_err(err)
|
):
|
||||||
raise
|
aio_err = from_aio._err
|
||||||
|
if (
|
||||||
|
task.cancelled() and
|
||||||
|
type(aio_err) is CancelledError
|
||||||
|
):
|
||||||
|
# if an underlying ``asyncio.CancelledError`` triggered this
|
||||||
|
# channel close, raise our (non-``BaseException``) wrapper
|
||||||
|
# error: ``AsyncioCancelled`` from that source error.
|
||||||
|
raise AsyncioCancelled from aio_err
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise
|
||||||
finally:
|
finally:
|
||||||
|
# always cancel the ``asyncio`` task if we've made it this far
|
||||||
|
# and it's not done.
|
||||||
if not task.done() and aio_err:
|
if not task.done() and aio_err:
|
||||||
|
# assert not aio_err, 'WTF how did asyncio do this?!'
|
||||||
task.cancel()
|
task.cancel()
|
||||||
|
|
||||||
maybe_raise_aio_err(err)
|
# if any ``asyncio`` error was caught, raise it here inline
|
||||||
# if task.cancelled():
|
# here in the ``trio`` task
|
||||||
# ... do what ..
|
maybe_raise_aio_err()
|
||||||
|
|
||||||
|
|
||||||
async def run_task(
|
async def run_task(
|
||||||
|
@ -216,27 +233,16 @@ async def run_task(
|
||||||
qsize=1,
|
qsize=1,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
async with translate_aio_errors(from_aio, task):
|
with from_aio:
|
||||||
|
# try:
|
||||||
# return single value
|
async with translate_aio_errors(from_aio, task, cs):
|
||||||
with cs:
|
# return single value that is the output from the
|
||||||
# naively expect the mem chan api to do the job
|
# ``asyncio`` function-as-task. Expect the mem chan api to
|
||||||
# of handling cross-framework cancellations / errors
|
# do the job of handling cross-framework cancellations
|
||||||
|
# / errors via closure and translation in the
|
||||||
|
# ``translate_aio_errors()`` in the above ctx mngr.
|
||||||
return await from_aio.receive()
|
return await from_aio.receive()
|
||||||
|
|
||||||
if cs.cancelled_caught:
|
|
||||||
aio_err = from_aio._err
|
|
||||||
|
|
||||||
# always raise from any captured asyncio error
|
|
||||||
if aio_err:
|
|
||||||
raise aio_err
|
|
||||||
|
|
||||||
|
|
||||||
# TODO: explicitly api for the streaming case where
|
|
||||||
# we pull from the mem chan in an async generator?
|
|
||||||
# This ends up looking more like our ``Portal.open_stream_from()``
|
|
||||||
# NB: code below is untested.
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class LinkedTaskChannel(trio.abc.Channel):
|
class LinkedTaskChannel(trio.abc.Channel):
|
||||||
|
@ -250,19 +256,24 @@ class LinkedTaskChannel(trio.abc.Channel):
|
||||||
_to_aio: asyncio.Queue
|
_to_aio: asyncio.Queue
|
||||||
_from_aio: trio.MemoryReceiveChannel
|
_from_aio: trio.MemoryReceiveChannel
|
||||||
_aio_task_complete: trio.Event
|
_aio_task_complete: trio.Event
|
||||||
|
_trio_cs: trio.CancelScope
|
||||||
|
|
||||||
async def aclose(self) -> None:
|
async def aclose(self) -> None:
|
||||||
self._from_aio.close()
|
self._from_aio.close()
|
||||||
|
|
||||||
async def receive(self) -> Any:
|
async def receive(self) -> Any:
|
||||||
async with translate_aio_errors(self._from_aio, self._aio_task):
|
async with translate_aio_errors(
|
||||||
|
self._from_aio,
|
||||||
|
self._aio_task,
|
||||||
|
self._trio_cs,
|
||||||
|
):
|
||||||
return await self._from_aio.receive()
|
return await self._from_aio.receive()
|
||||||
|
|
||||||
async def wait_ayncio_complete(self) -> None:
|
async def wait_ayncio_complete(self) -> None:
|
||||||
await self._aio_task_complete.wait()
|
await self._aio_task_complete.wait()
|
||||||
|
|
||||||
def cancel_asyncio_task(self) -> None:
|
# def cancel_asyncio_task(self) -> None:
|
||||||
self._aio_task.cancel()
|
# self._aio_task.cancel()
|
||||||
|
|
||||||
async def send(self, item: Any) -> None:
|
async def send(self, item: Any) -> None:
|
||||||
'''
|
'''
|
||||||
|
@ -292,16 +303,18 @@ async def open_channel_from(
|
||||||
provide_channels=True,
|
provide_channels=True,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
chan = LinkedTaskChannel(task, aio_q, from_aio, aio_task_complete)
|
chan = LinkedTaskChannel(
|
||||||
with cs:
|
task, aio_q, from_aio,
|
||||||
async with translate_aio_errors(from_aio, task):
|
aio_task_complete, cs
|
||||||
|
)
|
||||||
|
async with from_aio:
|
||||||
|
async with translate_aio_errors(from_aio, task, cs):
|
||||||
# sync to a "started()"-like first delivered value from the
|
# sync to a "started()"-like first delivered value from the
|
||||||
# ``asyncio`` task.
|
# ``asyncio`` task.
|
||||||
first = await from_aio.receive()
|
first = await from_aio.receive()
|
||||||
|
|
||||||
# stream values upward
|
# stream values upward
|
||||||
async with from_aio:
|
yield first, chan
|
||||||
yield first, chan
|
|
||||||
|
|
||||||
|
|
||||||
def run_as_asyncio_guest(
|
def run_as_asyncio_guest(
|
||||||
|
|
Loading…
Reference in New Issue