forked from goodboy/tractor
				
			Propagate any spawned `asyncio` task error upwards
This should mostly maintain top level SC principles for any task spawned using `tractor.to_asyncio.run()`. When the `asyncio` task completes make sure to cancel the pertaining `trio` cancel scope and raise any error that may have resulted. Resolves #120pre_bad_close
							parent
							
								
									230c5a87f8
								
							
						
					
					
						commit
						1ba8a82dc6
					
				|  | @ -0,0 +1,99 @@ | ||||||
|  | """ | ||||||
|  | Infection apis for ``asyncio`` loops running ``trio`` using guest mode. | ||||||
|  | """ | ||||||
|  | import asyncio | ||||||
|  | import inspect | ||||||
|  | from typing import ( | ||||||
|  |     Any, | ||||||
|  |     Callable, | ||||||
|  |     AsyncGenerator, | ||||||
|  |     Awaitable, | ||||||
|  |     Union, | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | import trio | ||||||
|  | 
 | ||||||
|  | from ._state import current_actor | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | __all__ = ['run'] | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | async def _invoke( | ||||||
|  |     from_trio: trio.abc.ReceiveChannel, | ||||||
|  |     to_trio: asyncio.Queue, | ||||||
|  |     coro: Awaitable, | ||||||
|  | ) -> Union[AsyncGenerator, Awaitable]: | ||||||
|  |     """Await or stream awaiable object based on type into | ||||||
|  |     ``trio`` memory channel. | ||||||
|  |     """ | ||||||
|  |     async def stream_from_gen(c): | ||||||
|  |         async for item in c: | ||||||
|  |             to_trio.send_nowait(item) | ||||||
|  | 
 | ||||||
|  |     async def just_return(c): | ||||||
|  |         to_trio.send_nowait(await c) | ||||||
|  | 
 | ||||||
|  |     if inspect.isasyncgen(coro): | ||||||
|  |         return await stream_from_gen(coro) | ||||||
|  |     elif inspect.iscoroutine(coro): | ||||||
|  |         return await coro | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | async def run( | ||||||
|  |     func: Callable, | ||||||
|  |     qsize: int = 2**10, | ||||||
|  |     **kwargs, | ||||||
|  | ) -> Any: | ||||||
|  |     """Run an ``asyncio`` async function or generator in a task, return | ||||||
|  |     or stream the result back to ``trio``. | ||||||
|  |     """ | ||||||
|  |     assert current_actor()._infected_aio | ||||||
|  | 
 | ||||||
|  |     # ITC (inter task comms) | ||||||
|  |     from_trio = asyncio.Queue(qsize) | ||||||
|  |     to_trio, from_aio = trio.open_memory_channel(qsize) | ||||||
|  | 
 | ||||||
|  |     # allow target func to accept/stream results manually | ||||||
|  |     kwargs['to_trio'] = to_trio | ||||||
|  |     kwargs['from_trio'] = to_trio | ||||||
|  | 
 | ||||||
|  |     coro = func(**kwargs) | ||||||
|  | 
 | ||||||
|  |     cancel_scope = trio.CancelScope() | ||||||
|  | 
 | ||||||
|  |     # start the asyncio task we submitted from trio | ||||||
|  |     # TODO: try out ``anyio`` asyncio based tg here | ||||||
|  |     task = asyncio.create_task(_invoke(from_trio, to_trio, coro)) | ||||||
|  |     err = None | ||||||
|  | 
 | ||||||
|  |     # XXX: I'm not sure this actually does anything... | ||||||
|  |     def cancel_trio(task): | ||||||
|  |         """Cancel the calling ``trio`` task on error. | ||||||
|  |         """ | ||||||
|  |         nonlocal err | ||||||
|  |         err = task.exception() | ||||||
|  |         cancel_scope.cancel() | ||||||
|  | 
 | ||||||
|  |     task.add_done_callback(cancel_trio) | ||||||
|  | 
 | ||||||
|  |     # determine return type async func vs. gen | ||||||
|  |     if inspect.isasyncgen(coro): | ||||||
|  |         # simple async func | ||||||
|  |         async def result(): | ||||||
|  |             with cancel_scope: | ||||||
|  |                 return await from_aio.get() | ||||||
|  |             if cancel_scope.cancelled_caught and err: | ||||||
|  |                 raise err | ||||||
|  | 
 | ||||||
|  |     elif inspect.iscoroutine(coro): | ||||||
|  |         # asycn gen | ||||||
|  |         async def result(): | ||||||
|  |             with cancel_scope: | ||||||
|  |                 async with from_aio: | ||||||
|  |                     async for item in from_aio: | ||||||
|  |                         yield item | ||||||
|  |             if cancel_scope.cancelled_caught and err: | ||||||
|  |                 raise err | ||||||
|  | 
 | ||||||
|  |     return result() | ||||||
		Loading…
	
		Reference in New Issue