Re-wrap and raise `asyncio.CancelledError`
For whatever reason `trio` seems to be swallowing this exception when raised in the `trio` task so instead wrap it in our own non-base exception type: `AsyncioCancelled` and raise that when the `asyncio` task cancels itself internally using `raise <err> from <src_err>` style. Further don't bother cancelling the `trio` task (via cancel scope) since we we can just use the recv mem chan closure error as a signal and explicitly lookup any set asyncio error.infect_asyncio
							parent
							
								
									c48c68c0bc
								
							
						
					
					
						commit
						5f4094691d
					
				|  | @ -82,6 +82,15 @@ class StreamOverrun(trio.TooSlowError): | |||
|     "This stream was overrun by sender" | ||||
| 
 | ||||
| 
 | ||||
| class AsyncioCancelled(Exception): | ||||
|     ''' | ||||
|     Asyncio cancelled translation (non-base) error | ||||
|     for use with the ``to_asyncio`` module | ||||
|     to be raised in the ``trio`` side task | ||||
| 
 | ||||
|     ''' | ||||
| 
 | ||||
| 
 | ||||
| def pack_error( | ||||
|     exc: BaseException, | ||||
|     tb=None, | ||||
|  |  | |||
|  | @ -19,6 +19,7 @@ import trio | |||
| 
 | ||||
| from .log import get_logger | ||||
| from ._state import current_actor | ||||
| from ._exceptions import AsyncioCancelled | ||||
| 
 | ||||
| log = get_logger(__name__) | ||||
| 
 | ||||
|  | @ -91,11 +92,9 @@ def _run_asyncio_task( | |||
|         except BaseException as err: | ||||
|             aio_err = err | ||||
|             from_aio._err = aio_err | ||||
|             to_trio.close() | ||||
|             from_aio.close() | ||||
|             raise | ||||
| 
 | ||||
|         finally: | ||||
|         else: | ||||
|             if ( | ||||
|                 result != orig and | ||||
|                 aio_err is None and | ||||
|  | @ -106,11 +105,13 @@ def _run_asyncio_task( | |||
|             ): | ||||
|                 to_trio.send_nowait(result) | ||||
| 
 | ||||
|         finally: | ||||
|             # if the task was spawned using ``open_channel_from()`` | ||||
|             # then we close the channels on exit. | ||||
|             if provide_channels: | ||||
|                 # only close the sender side which will relay | ||||
|                 # a ``trio.EndOfChannel`` to the trio (consumer) side. | ||||
|                 to_trio.close() | ||||
|                 from_aio.close() | ||||
| 
 | ||||
|             aio_task_complete.set() | ||||
| 
 | ||||
|  | @ -127,27 +128,26 @@ def _run_asyncio_task( | |||
|     else: | ||||
|         raise TypeError(f"No support for invoking {coro}") | ||||
| 
 | ||||
|     def cancel_trio(task) -> None: | ||||
|     def cancel_trio(task: asyncio.Task) -> None: | ||||
|         ''' | ||||
|         Cancel the calling ``trio`` task on error. | ||||
| 
 | ||||
|         ''' | ||||
|         nonlocal aio_err | ||||
|         try: | ||||
|             aio_err = task.exception() | ||||
|         except CancelledError as cerr: | ||||
|             log.cancel("infected task was cancelled") | ||||
|             from_aio._err = cerr | ||||
|             from_aio.close() | ||||
|             cancel_scope.cancel() | ||||
|         else: | ||||
|         aio_err = from_aio._err | ||||
| 
 | ||||
|         if aio_err is not None: | ||||
|             if type(aio_err) is CancelledError: | ||||
|                 log.cancel("infected task was cancelled") | ||||
|             else: | ||||
|                 aio_err.with_traceback(aio_err.__traceback__) | ||||
|                 log.exception("infected task errorred:") | ||||
|                 from_aio._err = aio_err | ||||
| 
 | ||||
|                 # NOTE: order is opposite here | ||||
|                 cancel_scope.cancel() | ||||
|             # NOTE: currently mem chan closure may act as a form | ||||
|             # of error relay (at least in the ``asyncio.CancelledError`` | ||||
|             # case) since we have no way to directly trigger a ``trio`` | ||||
|             # task error without creating a nursery to throw one. | ||||
|             # We might want to change this in the future though. | ||||
|             from_aio.close() | ||||
| 
 | ||||
|     task.add_done_callback(cancel_trio) | ||||
|  | @ -160,6 +160,7 @@ async def translate_aio_errors( | |||
| 
 | ||||
|     from_aio: trio.MemoryReceiveChannel, | ||||
|     task: asyncio.Task, | ||||
|     trio_cs: trio.CancelScope, | ||||
| 
 | ||||
| ) -> None: | ||||
|     ''' | ||||
|  | @ -167,34 +168,50 @@ async def translate_aio_errors( | |||
|     appropriately translates errors and cancels into ``trio`` land. | ||||
| 
 | ||||
|     ''' | ||||
|     err: Optional[Exception] = None | ||||
|     aio_err: Optional[Exception] = None | ||||
| 
 | ||||
|     def maybe_raise_aio_err(err: Exception): | ||||
|     def maybe_raise_aio_err( | ||||
|         err: Optional[Exception] = None | ||||
|     ) -> None: | ||||
|         aio_err = from_aio._err | ||||
|         if ( | ||||
|             aio_err is not None and | ||||
|             type(aio_err) != CancelledError | ||||
|         ): | ||||
|             # always raise from any captured asyncio error | ||||
|             if err: | ||||
|                 raise aio_err from err | ||||
| 
 | ||||
|             else: | ||||
|                 raise aio_err | ||||
|     try: | ||||
|         yield | ||||
|     except ( | ||||
|         Exception, | ||||
|         CancelledError, | ||||
|     ) as err: | ||||
|         maybe_raise_aio_err(err) | ||||
|         raise | ||||
|         # NOTE: see the note in the ``cancel_trio()`` asyncio task | ||||
|         # termination callback | ||||
|         trio.ClosedResourceError, | ||||
|     ): | ||||
|         aio_err = from_aio._err | ||||
|         if ( | ||||
|             task.cancelled() and | ||||
|             type(aio_err) is CancelledError | ||||
|         ): | ||||
|             # if an underlying ``asyncio.CancelledError`` triggered this | ||||
|             # channel close, raise our (non-``BaseException``) wrapper | ||||
|             # error: ``AsyncioCancelled`` from that source error. | ||||
|             raise AsyncioCancelled from aio_err | ||||
| 
 | ||||
|         else: | ||||
|             raise | ||||
|     finally: | ||||
|         # always cancel the ``asyncio`` task if we've made it this far | ||||
|         # and it's not done. | ||||
|         if not task.done() and aio_err: | ||||
|             # assert not aio_err, 'WTF how did asyncio do this?!' | ||||
|             task.cancel() | ||||
| 
 | ||||
|         maybe_raise_aio_err(err) | ||||
|         # if task.cancelled(): | ||||
|         #     ... do what .. | ||||
|         # if any ``asyncio`` error was caught, raise it here inline | ||||
|         # here in the ``trio`` task | ||||
|         maybe_raise_aio_err() | ||||
| 
 | ||||
| 
 | ||||
| async def run_task( | ||||
|  | @ -216,27 +233,16 @@ async def run_task( | |||
|         qsize=1, | ||||
|         **kwargs, | ||||
|     ) | ||||
|     async with translate_aio_errors(from_aio, task): | ||||
| 
 | ||||
|         # return single value | ||||
|         with cs: | ||||
|             # naively expect the mem chan api to do the job | ||||
|             # of handling cross-framework cancellations / errors | ||||
|     with from_aio: | ||||
|         # try: | ||||
|         async with translate_aio_errors(from_aio, task, cs): | ||||
|             # return single value that is the output from the | ||||
|             # ``asyncio`` function-as-task. Expect the mem chan api to | ||||
|             # do the job of handling cross-framework cancellations | ||||
|             # / errors via closure and translation in the | ||||
|             # ``translate_aio_errors()`` in the above ctx mngr. | ||||
|             return await from_aio.receive() | ||||
| 
 | ||||
|         if cs.cancelled_caught: | ||||
|             aio_err = from_aio._err | ||||
| 
 | ||||
|             # always raise from any captured asyncio error | ||||
|             if aio_err: | ||||
|                 raise aio_err | ||||
| 
 | ||||
| 
 | ||||
| # TODO: explicitly api for the streaming case where | ||||
| # we pull from the mem chan in an async generator? | ||||
| # This ends up looking more like our ``Portal.open_stream_from()`` | ||||
| # NB: code below is untested. | ||||
| 
 | ||||
| 
 | ||||
| @dataclass | ||||
| class LinkedTaskChannel(trio.abc.Channel): | ||||
|  | @ -250,19 +256,24 @@ class LinkedTaskChannel(trio.abc.Channel): | |||
|     _to_aio: asyncio.Queue | ||||
|     _from_aio: trio.MemoryReceiveChannel | ||||
|     _aio_task_complete: trio.Event | ||||
|     _trio_cs: trio.CancelScope | ||||
| 
 | ||||
|     async def aclose(self) -> None: | ||||
|         self._from_aio.close() | ||||
| 
 | ||||
|     async def receive(self) -> Any: | ||||
|         async with translate_aio_errors(self._from_aio, self._aio_task): | ||||
|         async with translate_aio_errors( | ||||
|             self._from_aio, | ||||
|             self._aio_task, | ||||
|             self._trio_cs, | ||||
|         ): | ||||
|             return await self._from_aio.receive() | ||||
| 
 | ||||
|     async def wait_ayncio_complete(self) -> None: | ||||
|         await self._aio_task_complete.wait() | ||||
| 
 | ||||
|     def cancel_asyncio_task(self) -> None: | ||||
|         self._aio_task.cancel() | ||||
|     # def cancel_asyncio_task(self) -> None: | ||||
|     #     self._aio_task.cancel() | ||||
| 
 | ||||
|     async def send(self, item: Any) -> None: | ||||
|         ''' | ||||
|  | @ -292,15 +303,17 @@ async def open_channel_from( | |||
|         provide_channels=True, | ||||
|         **kwargs, | ||||
|     ) | ||||
|     chan = LinkedTaskChannel(task, aio_q, from_aio, aio_task_complete) | ||||
|     with cs: | ||||
|         async with translate_aio_errors(from_aio, task): | ||||
|     chan = LinkedTaskChannel( | ||||
|         task, aio_q, from_aio, | ||||
|         aio_task_complete, cs | ||||
|     ) | ||||
|     async with from_aio: | ||||
|         async with translate_aio_errors(from_aio, task, cs): | ||||
|             # sync to a "started()"-like first delivered value from the | ||||
|             # ``asyncio`` task. | ||||
|             first = await from_aio.receive() | ||||
| 
 | ||||
|             # stream values upward | ||||
|             async with from_aio: | ||||
|             yield first, chan | ||||
| 
 | ||||
| 
 | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue