From 44d0e9fc3273669b65995b7b19d3eb9c231fb908 Mon Sep 17 00:00:00 2001 From: Tyler Goodlet Date: Mon, 22 Nov 2021 13:08:00 -0500 Subject: [PATCH] Add a `LinkedTaskChannel` for synced inter-loop-streaming Wraps the pairs of underlying `trio` mem chans and the `asyncio.Queue` with this new composite which will be delivered from `open_channel_from()`. This allows for both sending and receiving values from the `asyncio` task (2 way msg passing) as well controls for cancelling or waiting on the task. Factor `asyncio` translation and re-raising logic into a new closure which is run on both `trio` side error handling as well as on normal termination to avoid missing `asyncio` errors even when `trio` task cancellation is handled first. Only close the `trio` mem chans on `trio` task termination *iff* the task was spawned using `open_channel_from()`: - on `open_channel_from()` exit, mem chan closure is the desired semantic - on `run_task()` we normally only return a single value or error and if the channel is closed before the error is raised we may propagate a `trio.EndOfChannel` instead of the desired underlying `asyncio` task's error --- tractor/to_asyncio.py | 104 ++++++++++++++++++++++++++++++++---------- 1 file changed, 80 insertions(+), 24 deletions(-) diff --git a/tractor/to_asyncio.py b/tractor/to_asyncio.py index 5d168f7..6132303 100644 --- a/tractor/to_asyncio.py +++ b/tractor/to_asyncio.py @@ -5,6 +5,7 @@ Infection apis for ``asyncio`` loops running ``trio`` using guest mode. import asyncio from asyncio.exceptions import CancelledError from contextlib import asynccontextmanager as acm +from dataclasses import dataclass import inspect from typing import ( Any, @@ -41,7 +42,8 @@ def _run_asyncio_task( if not current_actor().is_infected_aio(): raise RuntimeError("`infect_asyncio` mode is not enabled!?") - # ITC (inter task comms) + # ITC (inter task comms), these channel/queue names are mostly from + # ``asyncio``'s perspective. from_trio = asyncio.Queue(qsize) # type: ignore to_trio, from_aio = trio.open_memory_channel(qsize) # type: ignore @@ -89,6 +91,8 @@ def _run_asyncio_task( except BaseException as err: aio_err = err from_aio._err = aio_err + to_trio.close() + from_aio.close() raise finally: @@ -102,8 +106,12 @@ def _run_asyncio_task( ): to_trio.send_nowait(result) - to_trio.close() - from_aio.close() + # if the task was spawned using ``open_channel_from()`` + # then we close the channels on exit. + if provide_channels: + to_trio.close() + from_aio.close() + aio_task_complete.set() # start the asyncio task we submitted from trio @@ -134,15 +142,17 @@ def _run_asyncio_task( cancel_scope.cancel() else: if aio_err is not None: + aio_err.with_traceback(aio_err.__traceback__) log.exception("infected task errorred:") from_aio._err = aio_err - # order is opposite here + + # NOTE: order is opposite here cancel_scope.cancel() from_aio.close() task.add_done_callback(cancel_trio) - return task, from_aio, to_trio, cancel_scope, aio_task_complete + return task, from_aio, to_trio, from_trio, cancel_scope, aio_task_complete @acm @@ -157,28 +167,32 @@ async def translate_aio_errors( appropriately translates errors and cancels into ``trio`` land. ''' + err: Optional[Exception] = None + aio_err: Optional[Exception] = None + + def maybe_raise_aio_err(err: Exception): + aio_err = from_aio._err + if ( + aio_err is not None and + type(aio_err) != CancelledError + ): + # always raise from any captured asyncio error + raise aio_err from err + try: yield except ( Exception, CancelledError, ) as err: - - aio_err = from_aio._err - - if ( - aio_err is not None and - type(aio_err) != CancelledError - ): - # always raise from any captured asyncio error - raise err from aio_err - else: - raise + maybe_raise_aio_err(err) + raise finally: - if not task.done(): + if not task.done() and aio_err: task.cancel() + maybe_raise_aio_err(err) # if task.cancelled(): # ... do what .. @@ -197,7 +211,7 @@ async def run_task( ''' # simple async func - task, from_aio, to_trio, cs, _ = _run_asyncio_task( + task, from_aio, to_trio, aio_q, cs, _ = _run_asyncio_task( func, qsize=1, **kwargs, @@ -224,28 +238,70 @@ async def run_task( # NB: code below is untested. +@dataclass +class LinkedTaskChannel(trio.abc.Channel): + ''' + A "linked task channel" which allows for two-way synchronized msg + passing between a ``trio``-in-guest-mode task and an ``asyncio`` + task. + + ''' + _aio_task: asyncio.Task + _to_aio: asyncio.Queue + _from_aio: trio.MemoryReceiveChannel + _aio_task_complete: trio.Event + + async def aclose(self) -> None: + self._from_aio.close() + + async def receive(self) -> Any: + async with translate_aio_errors(self._from_aio, self._aio_task): + return await self._from_aio.receive() + + async def wait_ayncio_complete(self) -> None: + await self._aio_task_complete.wait() + + def cancel_asyncio_task(self) -> None: + self._aio_task.cancel() + + async def send(self, item: Any) -> None: + ''' + Send a value through to the asyncio task presuming + it defines a ``from_trio`` argument, if it does not + this method will raise an error. + + ''' + self._to_aio.put_nowait(item) + + @acm async def open_channel_from( + target: Callable[[Any, ...], Any], **kwargs, ) -> AsyncIterator[Any]: + ''' + Open an inter-loop linked task channel for streaming between a target + spawned ``asyncio`` task and ``trio``. - task, from_aio, to_trio, cs, aio_task_complete = _run_asyncio_task( + ''' + task, from_aio, to_trio, aio_q, cs, aio_task_complete = _run_asyncio_task( target, qsize=2**8, provide_channels=True, **kwargs, ) - async with translate_aio_errors(from_aio, task): - with cs: - # sync to "started()" call. + chan = LinkedTaskChannel(task, aio_q, from_aio, aio_task_complete) + with cs: + async with translate_aio_errors(from_aio, task): + # sync to a "started()"-like first delivered value from the + # ``asyncio`` task. first = await from_aio.receive() # stream values upward async with from_aio: - yield first, from_aio - await aio_task_complete.wait() + yield first, chan def run_as_asyncio_guest(