Adjust linked-loop-task tear down sequence

Close the mem chan before cancelling the `trio` task in order to ensure
we retrieve whatever error is shuttled from `asyncio` before the channel
read is potentially cancelled (previously a race?).

Handle `asyncio.CancelledError` specially such that we raise it directly
(instead of `raise aio_cancelled from other_err`) since it *is* the
source error in the case where the cancellation is `asyncio` internal.
infect_asyncio
Tyler Goodlet 2021-11-17 13:20:04 -05:00
parent 56357242e9
commit 1114b6980e
1 changed files with 26 additions and 10 deletions

View File

@ -3,6 +3,7 @@ Infection apis for ``asyncio`` loops running ``trio`` using guest mode.
''' '''
import asyncio import asyncio
from asyncio.exceptions import CancelledError
from contextlib import asynccontextmanager as acm from contextlib import asynccontextmanager as acm
import inspect import inspect
from typing import ( from typing import (
@ -15,7 +16,7 @@ from typing import (
import trio import trio
from .log import get_logger, get_console_log from .log import get_logger
from ._state import current_actor from ._state import current_actor
log = get_logger(__name__) log = get_logger(__name__)
@ -110,14 +111,16 @@ def _run_asyncio_task(
nonlocal aio_err nonlocal aio_err
try: try:
aio_err = task.exception() aio_err = task.exception()
except asyncio.CancelledError as cerr: except CancelledError as cerr:
log.exception("infected task was cancelled")
# raise
aio_err = cerr aio_err = cerr
if aio_err: if aio_err:
log.exception(f"asyncio task errorred:\n{aio_err}") log.exception(f"infected task errorred with {type(aio_err)}")
from_aio._err = aio_err from_aio._err = aio_err
cancel_scope.cancel()
from_aio.close() from_aio.close()
cancel_scope.cancel()
task.add_done_callback(cancel_trio) task.add_done_callback(cancel_trio)
@ -132,10 +135,11 @@ async def run_task(
**kwargs, **kwargs,
) -> Any: ) -> Any:
"""Run an ``asyncio`` async function or generator in a task, return '''
Run an ``asyncio`` async function or generator in a task, return
or stream the result back to ``trio``. or stream the result back to ``trio``.
""" '''
# simple async func # simple async func
try: try:
task, from_aio, to_trio, cs, _ = _run_asyncio_task( task, from_aio, to_trio, cs, _ = _run_asyncio_task(
@ -151,24 +155,36 @@ async def run_task(
return await from_aio.receive() return await from_aio.receive()
if cs.cancelled_caught: if cs.cancelled_caught:
aio_err = from_aio._err
# always raise from any captured asyncio error # always raise from any captured asyncio error
if from_aio._err: if aio_err:
raise from_aio._err raise aio_err
# Do we need this? # Do we need this?
except BaseException as err: except (
Exception,
CancelledError,
) as err:
aio_err = from_aio._err aio_err = from_aio._err
if aio_err is not None: if (
aio_err is not None and
type(aio_err) != CancelledError
):
# always raise from any captured asyncio error # always raise from any captured asyncio error
raise err from aio_err raise err from aio_err
else: else:
raise raise
finally: finally:
if not task.done(): if not task.done():
task.cancel() task.cancel()
# if task.cancelled():
# ... do what ..
# TODO: explicitly api for the streaming case where # TODO: explicitly api for the streaming case where
# we pull from the mem chan in an async generator? # we pull from the mem chan in an async generator?