forked from goodboy/tractor
				
			Re-wrap and raise `asyncio.CancelledError`
For whatever reason `trio` seems to be swallowing this exception when raised in the `trio` task so instead wrap it in our own non-base exception type: `AsyncioCancelled` and raise that when the `asyncio` task cancels itself internally using `raise <err> from <src_err>` style. Further don't bother cancelling the `trio` task (via cancel scope) since we we can just use the recv mem chan closure error as a signal and explicitly lookup any set asyncio error.infect_asyncio
							parent
							
								
									c48c68c0bc
								
							
						
					
					
						commit
						5f4094691d
					
				| 
						 | 
				
			
			@ -82,6 +82,15 @@ class StreamOverrun(trio.TooSlowError):
 | 
			
		|||
    "This stream was overrun by sender"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class AsyncioCancelled(Exception):
 | 
			
		||||
    '''
 | 
			
		||||
    Asyncio cancelled translation (non-base) error
 | 
			
		||||
    for use with the ``to_asyncio`` module
 | 
			
		||||
    to be raised in the ``trio`` side task
 | 
			
		||||
 | 
			
		||||
    '''
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def pack_error(
 | 
			
		||||
    exc: BaseException,
 | 
			
		||||
    tb=None,
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -19,6 +19,7 @@ import trio
 | 
			
		|||
 | 
			
		||||
from .log import get_logger
 | 
			
		||||
from ._state import current_actor
 | 
			
		||||
from ._exceptions import AsyncioCancelled
 | 
			
		||||
 | 
			
		||||
log = get_logger(__name__)
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -91,11 +92,9 @@ def _run_asyncio_task(
 | 
			
		|||
        except BaseException as err:
 | 
			
		||||
            aio_err = err
 | 
			
		||||
            from_aio._err = aio_err
 | 
			
		||||
            to_trio.close()
 | 
			
		||||
            from_aio.close()
 | 
			
		||||
            raise
 | 
			
		||||
 | 
			
		||||
        finally:
 | 
			
		||||
        else:
 | 
			
		||||
            if (
 | 
			
		||||
                result != orig and
 | 
			
		||||
                aio_err is None and
 | 
			
		||||
| 
						 | 
				
			
			@ -106,11 +105,13 @@ def _run_asyncio_task(
 | 
			
		|||
            ):
 | 
			
		||||
                to_trio.send_nowait(result)
 | 
			
		||||
 | 
			
		||||
        finally:
 | 
			
		||||
            # if the task was spawned using ``open_channel_from()``
 | 
			
		||||
            # then we close the channels on exit.
 | 
			
		||||
            if provide_channels:
 | 
			
		||||
                # only close the sender side which will relay
 | 
			
		||||
                # a ``trio.EndOfChannel`` to the trio (consumer) side.
 | 
			
		||||
                to_trio.close()
 | 
			
		||||
                from_aio.close()
 | 
			
		||||
 | 
			
		||||
            aio_task_complete.set()
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -127,28 +128,27 @@ def _run_asyncio_task(
 | 
			
		|||
    else:
 | 
			
		||||
        raise TypeError(f"No support for invoking {coro}")
 | 
			
		||||
 | 
			
		||||
    def cancel_trio(task) -> None:
 | 
			
		||||
    def cancel_trio(task: asyncio.Task) -> None:
 | 
			
		||||
        '''
 | 
			
		||||
        Cancel the calling ``trio`` task on error.
 | 
			
		||||
 | 
			
		||||
        '''
 | 
			
		||||
        nonlocal aio_err
 | 
			
		||||
        try:
 | 
			
		||||
            aio_err = task.exception()
 | 
			
		||||
        except CancelledError as cerr:
 | 
			
		||||
            log.cancel("infected task was cancelled")
 | 
			
		||||
            from_aio._err = cerr
 | 
			
		||||
            from_aio.close()
 | 
			
		||||
            cancel_scope.cancel()
 | 
			
		||||
        else:
 | 
			
		||||
            if aio_err is not None:
 | 
			
		||||
        aio_err = from_aio._err
 | 
			
		||||
 | 
			
		||||
        if aio_err is not None:
 | 
			
		||||
            if type(aio_err) is CancelledError:
 | 
			
		||||
                log.cancel("infected task was cancelled")
 | 
			
		||||
            else:
 | 
			
		||||
                aio_err.with_traceback(aio_err.__traceback__)
 | 
			
		||||
                log.exception("infected task errorred:")
 | 
			
		||||
                from_aio._err = aio_err
 | 
			
		||||
 | 
			
		||||
                # NOTE: order is opposite here
 | 
			
		||||
                cancel_scope.cancel()
 | 
			
		||||
                from_aio.close()
 | 
			
		||||
            # NOTE: currently mem chan closure may act as a form
 | 
			
		||||
            # of error relay (at least in the ``asyncio.CancelledError``
 | 
			
		||||
            # case) since we have no way to directly trigger a ``trio``
 | 
			
		||||
            # task error without creating a nursery to throw one.
 | 
			
		||||
            # We might want to change this in the future though.
 | 
			
		||||
            from_aio.close()
 | 
			
		||||
 | 
			
		||||
    task.add_done_callback(cancel_trio)
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -160,6 +160,7 @@ async def translate_aio_errors(
 | 
			
		|||
 | 
			
		||||
    from_aio: trio.MemoryReceiveChannel,
 | 
			
		||||
    task: asyncio.Task,
 | 
			
		||||
    trio_cs: trio.CancelScope,
 | 
			
		||||
 | 
			
		||||
) -> None:
 | 
			
		||||
    '''
 | 
			
		||||
| 
						 | 
				
			
			@ -167,34 +168,50 @@ async def translate_aio_errors(
 | 
			
		|||
    appropriately translates errors and cancels into ``trio`` land.
 | 
			
		||||
 | 
			
		||||
    '''
 | 
			
		||||
    err: Optional[Exception] = None
 | 
			
		||||
    aio_err: Optional[Exception] = None
 | 
			
		||||
 | 
			
		||||
    def maybe_raise_aio_err(err: Exception):
 | 
			
		||||
    def maybe_raise_aio_err(
 | 
			
		||||
        err: Optional[Exception] = None
 | 
			
		||||
    ) -> None:
 | 
			
		||||
        aio_err = from_aio._err
 | 
			
		||||
        if (
 | 
			
		||||
            aio_err is not None and
 | 
			
		||||
            type(aio_err) != CancelledError
 | 
			
		||||
        ):
 | 
			
		||||
            # always raise from any captured asyncio error
 | 
			
		||||
            raise aio_err from err
 | 
			
		||||
 | 
			
		||||
            if err:
 | 
			
		||||
                raise aio_err from err
 | 
			
		||||
            else:
 | 
			
		||||
                raise aio_err
 | 
			
		||||
    try:
 | 
			
		||||
        yield
 | 
			
		||||
    except (
 | 
			
		||||
        Exception,
 | 
			
		||||
        CancelledError,
 | 
			
		||||
    ) as err:
 | 
			
		||||
        maybe_raise_aio_err(err)
 | 
			
		||||
        raise
 | 
			
		||||
        # NOTE: see the note in the ``cancel_trio()`` asyncio task
 | 
			
		||||
        # termination callback
 | 
			
		||||
        trio.ClosedResourceError,
 | 
			
		||||
    ):
 | 
			
		||||
        aio_err = from_aio._err
 | 
			
		||||
        if (
 | 
			
		||||
            task.cancelled() and
 | 
			
		||||
            type(aio_err) is CancelledError
 | 
			
		||||
        ):
 | 
			
		||||
            # if an underlying ``asyncio.CancelledError`` triggered this
 | 
			
		||||
            # channel close, raise our (non-``BaseException``) wrapper
 | 
			
		||||
            # error: ``AsyncioCancelled`` from that source error.
 | 
			
		||||
            raise AsyncioCancelled from aio_err
 | 
			
		||||
 | 
			
		||||
        else:
 | 
			
		||||
            raise
 | 
			
		||||
    finally:
 | 
			
		||||
        # always cancel the ``asyncio`` task if we've made it this far
 | 
			
		||||
        # and it's not done.
 | 
			
		||||
        if not task.done() and aio_err:
 | 
			
		||||
            # assert not aio_err, 'WTF how did asyncio do this?!'
 | 
			
		||||
            task.cancel()
 | 
			
		||||
 | 
			
		||||
        maybe_raise_aio_err(err)
 | 
			
		||||
        # if task.cancelled():
 | 
			
		||||
        #     ... do what ..
 | 
			
		||||
        # if any ``asyncio`` error was caught, raise it here inline
 | 
			
		||||
        # here in the ``trio`` task
 | 
			
		||||
        maybe_raise_aio_err()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
async def run_task(
 | 
			
		||||
| 
						 | 
				
			
			@ -216,27 +233,16 @@ async def run_task(
 | 
			
		|||
        qsize=1,
 | 
			
		||||
        **kwargs,
 | 
			
		||||
    )
 | 
			
		||||
    async with translate_aio_errors(from_aio, task):
 | 
			
		||||
 | 
			
		||||
        # return single value
 | 
			
		||||
        with cs:
 | 
			
		||||
            # naively expect the mem chan api to do the job
 | 
			
		||||
            # of handling cross-framework cancellations / errors
 | 
			
		||||
    with from_aio:
 | 
			
		||||
        # try:
 | 
			
		||||
        async with translate_aio_errors(from_aio, task, cs):
 | 
			
		||||
            # return single value that is the output from the
 | 
			
		||||
            # ``asyncio`` function-as-task. Expect the mem chan api to
 | 
			
		||||
            # do the job of handling cross-framework cancellations
 | 
			
		||||
            # / errors via closure and translation in the
 | 
			
		||||
            # ``translate_aio_errors()`` in the above ctx mngr.
 | 
			
		||||
            return await from_aio.receive()
 | 
			
		||||
 | 
			
		||||
        if cs.cancelled_caught:
 | 
			
		||||
            aio_err = from_aio._err
 | 
			
		||||
 | 
			
		||||
            # always raise from any captured asyncio error
 | 
			
		||||
            if aio_err:
 | 
			
		||||
                raise aio_err
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# TODO: explicitly api for the streaming case where
 | 
			
		||||
# we pull from the mem chan in an async generator?
 | 
			
		||||
# This ends up looking more like our ``Portal.open_stream_from()``
 | 
			
		||||
# NB: code below is untested.
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@dataclass
 | 
			
		||||
class LinkedTaskChannel(trio.abc.Channel):
 | 
			
		||||
| 
						 | 
				
			
			@ -250,19 +256,24 @@ class LinkedTaskChannel(trio.abc.Channel):
 | 
			
		|||
    _to_aio: asyncio.Queue
 | 
			
		||||
    _from_aio: trio.MemoryReceiveChannel
 | 
			
		||||
    _aio_task_complete: trio.Event
 | 
			
		||||
    _trio_cs: trio.CancelScope
 | 
			
		||||
 | 
			
		||||
    async def aclose(self) -> None:
 | 
			
		||||
        self._from_aio.close()
 | 
			
		||||
 | 
			
		||||
    async def receive(self) -> Any:
 | 
			
		||||
        async with translate_aio_errors(self._from_aio, self._aio_task):
 | 
			
		||||
        async with translate_aio_errors(
 | 
			
		||||
            self._from_aio,
 | 
			
		||||
            self._aio_task,
 | 
			
		||||
            self._trio_cs,
 | 
			
		||||
        ):
 | 
			
		||||
            return await self._from_aio.receive()
 | 
			
		||||
 | 
			
		||||
    async def wait_ayncio_complete(self) -> None:
 | 
			
		||||
        await self._aio_task_complete.wait()
 | 
			
		||||
 | 
			
		||||
    def cancel_asyncio_task(self) -> None:
 | 
			
		||||
        self._aio_task.cancel()
 | 
			
		||||
    # def cancel_asyncio_task(self) -> None:
 | 
			
		||||
    #     self._aio_task.cancel()
 | 
			
		||||
 | 
			
		||||
    async def send(self, item: Any) -> None:
 | 
			
		||||
        '''
 | 
			
		||||
| 
						 | 
				
			
			@ -292,16 +303,18 @@ async def open_channel_from(
 | 
			
		|||
        provide_channels=True,
 | 
			
		||||
        **kwargs,
 | 
			
		||||
    )
 | 
			
		||||
    chan = LinkedTaskChannel(task, aio_q, from_aio, aio_task_complete)
 | 
			
		||||
    with cs:
 | 
			
		||||
        async with translate_aio_errors(from_aio, task):
 | 
			
		||||
    chan = LinkedTaskChannel(
 | 
			
		||||
        task, aio_q, from_aio,
 | 
			
		||||
        aio_task_complete, cs
 | 
			
		||||
    )
 | 
			
		||||
    async with from_aio:
 | 
			
		||||
        async with translate_aio_errors(from_aio, task, cs):
 | 
			
		||||
            # sync to a "started()"-like first delivered value from the
 | 
			
		||||
            # ``asyncio`` task.
 | 
			
		||||
            first = await from_aio.receive()
 | 
			
		||||
 | 
			
		||||
            # stream values upward
 | 
			
		||||
            async with from_aio:
 | 
			
		||||
                yield first, chan
 | 
			
		||||
            yield first, chan
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def run_as_asyncio_guest(
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue