Fix `asyncio`-task-sync and error propagation
This fixes an previously undetected bug where if an `.open_channel_from()` spawned task errored the error would not be propagated to the `trio` side and instead would fail silently with a console log error. What was most odd is that it only seems easy to trigger when you put a slight task sleep before the error is raised (:eyeroll:). This patch adds a few things to address this and just in general improve iter-task lifetime syncing: - add `LinkedTaskChannel._trio_exited: bool` a flag set from the `trio` side when the channel block exits. - add a `wait_on_aio_task: bool` flag to `translate_aio_errors` which toggles whether to wait the `asyncio` task termination event on exit. - cancel the `asyncio` task if the trio side has ended, when `._trio_exited == True`. - always close the `trio` mem channel when the task exits such that the `asyncio` side can error on any next `.send()` call.aio_error_propagation^2
							parent
							
								
									98de2fab31
								
							
						
					
					
						commit
						38d03858d7
					
				|  | @ -185,7 +185,7 @@ async def trio_ctx( | |||
|                 tractor.to_asyncio.run_task, | ||||
|                 sleep_forever, | ||||
|             ) | ||||
|             # await trio.sleep_forever() | ||||
|             await trio.sleep_forever() | ||||
| 
 | ||||
| 
 | ||||
| @pytest.mark.parametrize( | ||||
|  |  | |||
|  | @ -63,6 +63,7 @@ class LinkedTaskChannel(trio.abc.Channel): | |||
| 
 | ||||
|     _trio_cs: trio.CancelScope | ||||
|     _aio_task_complete: trio.Event | ||||
|     _trio_exited: bool = False | ||||
| 
 | ||||
|     # set after ``asyncio.create_task()`` | ||||
|     _aio_task: Optional[asyncio.Task] = None | ||||
|  | @ -73,7 +74,13 @@ class LinkedTaskChannel(trio.abc.Channel): | |||
|         await self._from_aio.aclose() | ||||
| 
 | ||||
|     async def receive(self) -> Any: | ||||
|         async with translate_aio_errors(self): | ||||
|         async with translate_aio_errors( | ||||
|             self, | ||||
| 
 | ||||
|             # XXX: obviously this will deadlock if an on-going stream is | ||||
|             # being procesed. | ||||
|             # wait_on_aio_task=False, | ||||
|         ): | ||||
| 
 | ||||
|             # TODO: do we need this to guarantee asyncio code get's | ||||
|             # cancelled in the case where the trio side somehow creates | ||||
|  | @ -210,10 +217,8 @@ def _run_asyncio_task( | |||
|         orig = result = id(coro) | ||||
|         try: | ||||
|             result = await coro | ||||
|         except GeneratorExit: | ||||
|             # no need to relay error | ||||
|             raise | ||||
|         except BaseException as aio_err: | ||||
|             log.exception('asyncio task errored') | ||||
|             chan._aio_err = aio_err | ||||
|             raise | ||||
| 
 | ||||
|  | @ -237,6 +242,7 @@ def _run_asyncio_task( | |||
|                 to_trio.close() | ||||
| 
 | ||||
|             aio_task_complete.set() | ||||
|             log.runtime(f'`asyncio` task: {task.get_name()} is complete') | ||||
| 
 | ||||
|     # start the asyncio task we submitted from trio | ||||
|     if not inspect.isawaitable(coro): | ||||
|  | @ -296,6 +302,11 @@ def _run_asyncio_task( | |||
|                     f'infected task errorred:\n{msg}' | ||||
|                 ) | ||||
| 
 | ||||
|             # XXX: alway cancel the scope on error | ||||
|             # in case the trio task is blocking | ||||
|             # on a checkpoint. | ||||
|             cancel_scope.cancel() | ||||
| 
 | ||||
|             # raise any ``asyncio`` side error. | ||||
|             raise aio_err | ||||
| 
 | ||||
|  | @ -307,6 +318,7 @@ def _run_asyncio_task( | |||
| async def translate_aio_errors( | ||||
| 
 | ||||
|     chan: LinkedTaskChannel, | ||||
|     wait_on_aio_task: bool = False, | ||||
| 
 | ||||
| ) -> AsyncIterator[None]: | ||||
|     ''' | ||||
|  | @ -318,6 +330,7 @@ async def translate_aio_errors( | |||
| 
 | ||||
|     aio_err: Optional[BaseException] = None | ||||
| 
 | ||||
|     # TODO: make thisi a channel method? | ||||
|     def maybe_raise_aio_err( | ||||
|         err: Optional[Exception] = None | ||||
|     ) -> None: | ||||
|  | @ -367,13 +380,30 @@ async def translate_aio_errors( | |||
|             raise | ||||
| 
 | ||||
|     finally: | ||||
|         # always cancel the ``asyncio`` task if we've made it this far | ||||
|         # and it's not done. | ||||
|         if not task.done() and aio_err: | ||||
|         if ( | ||||
|             # NOTE: always cancel the ``asyncio`` task if we've made it | ||||
|             # this far and it's not done. | ||||
|             not task.done() and aio_err | ||||
| 
 | ||||
|             # or the trio side has exited it's surrounding cancel scope | ||||
|             # indicating the lifetime of the ``asyncio``-side task | ||||
|             # should also be terminated. | ||||
|             or chan._trio_exited | ||||
|         ): | ||||
|             log.runtime( | ||||
|                 f'Cancelling `asyncio`-task: {chan._aio_taskget_name()}' | ||||
|             ) | ||||
|             # assert not aio_err, 'WTF how did asyncio do this?!' | ||||
|             task.cancel() | ||||
| 
 | ||||
|         # if any ``asyncio`` error was caught, raise it here inline | ||||
|         # Required to sync with the far end ``asyncio``-task to ensure | ||||
|         # any error is captured (via monkeypatching the | ||||
|         # ``channel._aio_err``) before calling ``maybe_raise_aio_err()`` | ||||
|         # below! | ||||
|         if wait_on_aio_task: | ||||
|             await chan._aio_task_complete.wait() | ||||
| 
 | ||||
|         # NOTE: if any ``asyncio`` error was caught, raise it here inline | ||||
|         # here in the ``trio`` task | ||||
|         maybe_raise_aio_err() | ||||
| 
 | ||||
|  | @ -398,7 +428,10 @@ async def run_task( | |||
|         **kwargs, | ||||
|     ) | ||||
|     with chan._from_aio: | ||||
|         async with translate_aio_errors(chan): | ||||
|         async with translate_aio_errors( | ||||
|             chan, | ||||
|             wait_on_aio_task=True, | ||||
|         ): | ||||
|             # return single value that is the output from the | ||||
|             # ``asyncio`` function-as-task. Expect the mem chan api to | ||||
|             # do the job of handling cross-framework cancellations | ||||
|  | @ -426,13 +459,21 @@ async def open_channel_from( | |||
|         **kwargs, | ||||
|     ) | ||||
|     async with chan._from_aio: | ||||
|         async with translate_aio_errors(chan): | ||||
|         async with translate_aio_errors( | ||||
|             chan, | ||||
|             wait_on_aio_task=True, | ||||
|         ): | ||||
|             # sync to a "started()"-like first delivered value from the | ||||
|             # ``asyncio`` task. | ||||
|             first = await chan.receive() | ||||
| 
 | ||||
|             # deliver stream handle upward | ||||
|             yield first, chan | ||||
|             try: | ||||
|                 with chan._trio_cs: | ||||
|                     yield first, chan | ||||
|             finally: | ||||
|                 chan._trio_exited = True | ||||
|                 chan._to_trio.close() | ||||
| 
 | ||||
| 
 | ||||
| def run_as_asyncio_guest( | ||||
|  | @ -482,7 +523,7 @@ def run_as_asyncio_guest( | |||
|                 main_outcome.unwrap() | ||||
|             else: | ||||
|                 trio_done_fut.set_result(main_outcome) | ||||
|                 print(f"trio_main finished: {main_outcome!r}") | ||||
|                 log.runtime(f"trio_main finished: {main_outcome!r}") | ||||
| 
 | ||||
|         # start the infection: run trio on the asyncio loop in "guest mode" | ||||
|         log.info(f"Infecting asyncio process with {trio_main}") | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue