forked from goodboy/tractor
				
			Factor error translation into a ctx mngr
Pull the common `asyncio` -> `trio` error translation logic into a common context manager and don't expect a final result to be captured when using `open_channel_from()` since it's a manager interface and it would be clunky to try and deliver some "final result" after exit.infect_asyncio
							parent
							
								
									e6687bcdc4
								
							
						
					
					
						commit
						9bc94b5ccc
					
				|  | @ -92,14 +92,28 @@ def _run_asyncio_task( | |||
|             raise | ||||
| 
 | ||||
|         finally: | ||||
|             aio_task_complete.set() | ||||
|             if result != orig and aio_err is None: | ||||
|             if ( | ||||
|                 result != orig and | ||||
|                 aio_err is None and | ||||
| 
 | ||||
|                 # in the ``open_channel_from()`` case we don't | ||||
|                 # relay through the "return value". | ||||
|                 not provide_channels | ||||
|             ): | ||||
|                 to_trio.send_nowait(result) | ||||
| 
 | ||||
|             to_trio.close() | ||||
|             from_aio.close() | ||||
|             aio_task_complete.set() | ||||
| 
 | ||||
|     # start the asyncio task we submitted from trio | ||||
|     if inspect.isawaitable(coro): | ||||
|         task = asyncio.create_task( | ||||
|             wait_on_coro_final_result(to_trio, coro, aio_task_complete) | ||||
|             wait_on_coro_final_result( | ||||
|                 to_trio, | ||||
|                 coro, | ||||
|                 aio_task_complete | ||||
|             ) | ||||
|         ) | ||||
| 
 | ||||
|     else: | ||||
|  | @ -120,7 +134,7 @@ def _run_asyncio_task( | |||
|             cancel_scope.cancel() | ||||
|         else: | ||||
|             if aio_err is not None: | ||||
|                 log.exception(f"infected task errorred:") | ||||
|                 log.exception("infected task errorred:") | ||||
|                 from_aio._err = aio_err | ||||
|                 # order is opposite here | ||||
|                 cancel_scope.cancel() | ||||
|  | @ -131,41 +145,20 @@ def _run_asyncio_task( | |||
|     return task, from_aio, to_trio, cancel_scope, aio_task_complete | ||||
| 
 | ||||
| 
 | ||||
| async def run_task( | ||||
|     func: Callable, | ||||
|     *, | ||||
| @acm | ||||
| async def translate_aio_errors( | ||||
| 
 | ||||
|     qsize: int = 2**10, | ||||
|     **kwargs, | ||||
|     from_aio: trio.MemoryReceiveChannel, | ||||
|     task: asyncio.Task, | ||||
| 
 | ||||
| ) -> Any: | ||||
| ) -> None: | ||||
|     ''' | ||||
|     Run an ``asyncio`` async function or generator in a task, return | ||||
|     or stream the result back to ``trio``. | ||||
|     Error handling context around ``asyncio`` task spawns which | ||||
|     appropriately translates errors and cancels into ``trio`` land. | ||||
| 
 | ||||
|     ''' | ||||
|     # simple async func | ||||
|     try: | ||||
|         task, from_aio, to_trio, cs, _ = _run_asyncio_task( | ||||
|             func, | ||||
|             qsize=1, | ||||
|             **kwargs, | ||||
|         ) | ||||
| 
 | ||||
|         # return single value | ||||
|         with cs: | ||||
|             # naively expect the mem chan api to do the job | ||||
|             # of handling cross-framework cancellations / errors | ||||
|             return await from_aio.receive() | ||||
| 
 | ||||
|         if cs.cancelled_caught: | ||||
|             aio_err = from_aio._err | ||||
| 
 | ||||
|             # always raise from any captured asyncio error | ||||
|             if aio_err: | ||||
|                 raise aio_err | ||||
| 
 | ||||
|     # Do we need this? | ||||
|         yield | ||||
|     except ( | ||||
|         Exception, | ||||
|         CancelledError, | ||||
|  | @ -190,6 +183,41 @@ async def run_task( | |||
|         #     ... do what .. | ||||
| 
 | ||||
| 
 | ||||
| async def run_task( | ||||
|     func: Callable, | ||||
|     *, | ||||
| 
 | ||||
|     qsize: int = 2**10, | ||||
|     **kwargs, | ||||
| 
 | ||||
| ) -> Any: | ||||
|     ''' | ||||
|     Run an ``asyncio`` async function or generator in a task, return | ||||
|     or stream the result back to ``trio``. | ||||
| 
 | ||||
|     ''' | ||||
|     # simple async func | ||||
|     task, from_aio, to_trio, cs, _ = _run_asyncio_task( | ||||
|         func, | ||||
|         qsize=1, | ||||
|         **kwargs, | ||||
|     ) | ||||
|     async with translate_aio_errors(from_aio, task): | ||||
| 
 | ||||
|         # return single value | ||||
|         with cs: | ||||
|             # naively expect the mem chan api to do the job | ||||
|             # of handling cross-framework cancellations / errors | ||||
|             return await from_aio.receive() | ||||
| 
 | ||||
|         if cs.cancelled_caught: | ||||
|             aio_err = from_aio._err | ||||
| 
 | ||||
|             # always raise from any captured asyncio error | ||||
|             if aio_err: | ||||
|                 raise aio_err | ||||
| 
 | ||||
| 
 | ||||
| # TODO: explicitly api for the streaming case where | ||||
| # we pull from the mem chan in an async generator? | ||||
| # This ends up looking more like our ``Portal.open_stream_from()`` | ||||
|  | @ -203,40 +231,21 @@ async def open_channel_from( | |||
| 
 | ||||
| ) -> AsyncIterator[Any]: | ||||
| 
 | ||||
|     try: | ||||
|     task, from_aio, to_trio, cs, aio_task_complete = _run_asyncio_task( | ||||
|         target, | ||||
|         qsize=2**8, | ||||
|         provide_channels=True, | ||||
|         **kwargs, | ||||
|     ) | ||||
| 
 | ||||
|     async with translate_aio_errors(from_aio, task): | ||||
|         with cs: | ||||
|             # sync to "started()" call. | ||||
|             first = await from_aio.receive() | ||||
| 
 | ||||
|             # stream values upward | ||||
|             async with from_aio: | ||||
|                 yield first, from_aio | ||||
|                 # await aio_task_complete.wait() | ||||
| 
 | ||||
|     except BaseException as err: | ||||
| 
 | ||||
|         aio_err = from_aio._err | ||||
| 
 | ||||
|         if aio_err is not None: | ||||
|             # always raise from any captured asyncio error | ||||
|             raise err from aio_err | ||||
|         else: | ||||
|             raise | ||||
| 
 | ||||
|     finally: | ||||
|         if cs.cancelled_caught: | ||||
|             # always raise from any captured asyncio error | ||||
|             if from_aio._err: | ||||
|                 raise from_aio._err | ||||
| 
 | ||||
|         if not task.done(): | ||||
|             task.cancel() | ||||
|                 await aio_task_complete.wait() | ||||
| 
 | ||||
| 
 | ||||
| def run_as_asyncio_guest( | ||||
|  | @ -284,7 +293,7 @@ def run_as_asyncio_guest( | |||
|             run_sync_soon_threadsafe=loop.call_soon_threadsafe, | ||||
|             done_callback=trio_done_callback, | ||||
|         ) | ||||
|         (await trio_done_fut).unwrap() | ||||
|         return (await trio_done_fut).unwrap() | ||||
| 
 | ||||
|     # might as well if it's installed. | ||||
|     try: | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue