diff --git a/tractor/to_asyncio.py b/tractor/to_asyncio.py index 0e85908..6ad8bf5 100644 --- a/tractor/to_asyncio.py +++ b/tractor/to_asyncio.py @@ -27,14 +27,57 @@ log = get_logger(__name__) __all__ = ['run_task', 'run_as_asyncio_guest'] +@dataclass +class LinkedTaskChannel(trio.abc.Channel): + ''' + A "linked task channel" which allows for two-way synchronized msg + passing between a ``trio``-in-guest-mode task and an ``asyncio`` + task scheduled in the host loop. + + ''' + _to_aio: asyncio.Queue + _from_aio: trio.MemoryReceiveChannel + _to_trio: trio.MemorySendChannel + + _trio_cs: trio.CancelScope + _aio_task_complete: trio.Event + + # set after ``asyncio.create_task()`` + _aio_task: Optional[asyncio.Task] = None + _aio_err: Optional[BaseException] = None + + async def aclose(self) -> None: + await self._from_aio.aclose() + + async def receive(self) -> Any: + async with translate_aio_errors(self): + return await self._from_aio.receive() + + async def wait_ayncio_complete(self) -> None: + await self._aio_task_complete.wait() + + # def cancel_asyncio_task(self) -> None: + # self._aio_task.cancel() + + async def send(self, item: Any) -> None: + ''' + Send a value through to the asyncio task presuming + it defines a ``from_trio`` argument, if it does not + this method will raise an error. + + ''' + self._to_aio.put_nowait(item) + + def _run_asyncio_task( + func: Callable, *, qsize: int = 1, provide_channels: bool = False, **kwargs, -) -> Any: +) -> LinkedTaskChannel: ''' Run an ``asyncio`` async function or generator in a task, return or stream the result back to ``trio``. @@ -45,11 +88,9 @@ def _run_asyncio_task( # ITC (inter task comms), these channel/queue names are mostly from # ``asyncio``'s perspective. - from_trio = asyncio.Queue(qsize) # type: ignore + aio_q = from_trio = asyncio.Queue(qsize) # type: ignore to_trio, from_aio = trio.open_memory_channel(qsize) # type: ignore - from_aio._err = None - args = tuple(inspect.getfullargspec(func).args) if getattr(func, '_tractor_steam_function', None): @@ -74,6 +115,15 @@ def _run_asyncio_task( aio_task_complete = trio.Event() aio_err: Optional[BaseException] = None + chan = LinkedTaskChannel( + aio_q, # asyncio.Queue + from_aio, # recv chan + to_trio, # send chan + + cancel_scope, + aio_task_complete, + ) + async def wait_on_coro_final_result( to_trio: trio.MemorySendChannel, @@ -86,12 +136,13 @@ def _run_asyncio_task( ''' nonlocal aio_err + nonlocal chan + orig = result = id(coro) try: result = await coro - except BaseException as err: - aio_err = err - from_aio._err = aio_err + except BaseException as aio_err: + chan._aio_err = aio_err raise else: @@ -116,25 +167,25 @@ def _run_asyncio_task( aio_task_complete.set() # start the asyncio task we submitted from trio - if inspect.isawaitable(coro): - task = asyncio.create_task( - wait_on_coro_final_result( - to_trio, - coro, - aio_task_complete - ) - ) - - else: + if not inspect.isawaitable(coro): raise TypeError(f"No support for invoking {coro}") + task = asyncio.create_task( + wait_on_coro_final_result( + to_trio, + coro, + aio_task_complete + ) + ) + chan._aio_task = task + def cancel_trio(task: asyncio.Task) -> None: ''' Cancel the calling ``trio`` task on error. ''' - nonlocal aio_err - aio_err = from_aio._err + nonlocal chan + aio_err = chan._aio_err # only to avoid ``asyncio`` complaining about uncaptured # task exceptions @@ -159,27 +210,26 @@ def _run_asyncio_task( task.add_done_callback(cancel_trio) - return task, from_aio, to_trio, from_trio, cancel_scope, aio_task_complete + return chan @acm async def translate_aio_errors( - from_aio: trio.MemoryReceiveChannel, - task: asyncio.Task, + chan: LinkedTaskChannel, -) -> None: +) -> AsyncIterator[None]: ''' Error handling context around ``asyncio`` task spawns which appropriately translates errors and cancels into ``trio`` land. ''' - aio_err: Optional[Exception] = None + aio_err: Optional[BaseException] = None def maybe_raise_aio_err( err: Optional[Exception] = None ) -> None: - aio_err = from_aio._err + aio_err = chan._aio_err if ( aio_err is not None and type(aio_err) != CancelledError @@ -189,6 +239,9 @@ async def translate_aio_errors( raise aio_err from err else: raise aio_err + + task = chan._aio_task + assert task try: yield except ( @@ -196,7 +249,7 @@ async def translate_aio_errors( # termination callback trio.ClosedResourceError, ): - aio_err = from_aio._err + aio_err = chan._aio_err if ( task.cancelled() and type(aio_err) is CancelledError @@ -234,65 +287,26 @@ async def run_task( ''' # simple async func - task, from_aio, to_trio, aio_q, cs, _ = _run_asyncio_task( + chan = _run_asyncio_task( func, qsize=1, **kwargs, ) - with from_aio: + with chan._from_aio: # try: - async with translate_aio_errors(from_aio, task): + async with translate_aio_errors(chan): # return single value that is the output from the # ``asyncio`` function-as-task. Expect the mem chan api to # do the job of handling cross-framework cancellations # / errors via closure and translation in the # ``translate_aio_errors()`` in the above ctx mngr. - return await from_aio.receive() - - -@dataclass -class LinkedTaskChannel(trio.abc.Channel): - ''' - A "linked task channel" which allows for two-way synchronized msg - passing between a ``trio``-in-guest-mode task and an ``asyncio`` - task. - - ''' - _aio_task: asyncio.Task - _to_aio: asyncio.Queue - _from_aio: trio.MemoryReceiveChannel - _aio_task_complete: trio.Event - - async def aclose(self) -> None: - self._from_aio.close() - - async def receive(self) -> Any: - async with translate_aio_errors( - self._from_aio, - self._aio_task, - ): - return await self._from_aio.receive() - - async def wait_ayncio_complete(self) -> None: - await self._aio_task_complete.wait() - - # def cancel_asyncio_task(self) -> None: - # self._aio_task.cancel() - - async def send(self, item: Any) -> None: - ''' - Send a value through to the asyncio task presuming - it defines a ``from_trio`` argument, if it does not - this method will raise an error. - - ''' - self._to_aio.put_nowait(item) + return await chan.receive() @acm async def open_channel_from( - target: Callable[[Any, ...], Any], + target: Callable[..., Any], **kwargs, ) -> AsyncIterator[Any]: @@ -301,21 +315,17 @@ async def open_channel_from( spawned ``asyncio`` task and ``trio``. ''' - task, from_aio, to_trio, aio_q, cs, aio_task_complete = _run_asyncio_task( + chan = _run_asyncio_task( target, qsize=2**8, provide_channels=True, **kwargs, ) - chan = LinkedTaskChannel( - task, aio_q, from_aio, - aio_task_complete - ) - async with from_aio: - async with translate_aio_errors(from_aio, task): + async with chan._from_aio: + async with translate_aio_errors(chan): # sync to a "started()"-like first delivered value from the # ``asyncio`` task. - first = await from_aio.receive() + first = await chan.receive() # stream values upward yield first, chan