Add a `LinkedTaskChannel` for synced inter-loop-streaming

Wraps the pairs of underlying `trio` mem chans and the `asyncio.Queue`
with this new composite which will be delivered from `open_channel_from()`.
This allows for both sending and receiving values from the `asyncio`
task (2 way msg passing) as well controls for cancelling or waiting on
the task.

Factor `asyncio` translation and re-raising logic into a new closure
which is run on both `trio` side error handling as well as on normal
termination to avoid missing `asyncio` errors even when `trio` task
cancellation is handled first.

Only close the `trio` mem chans on `trio` task termination *iff*
the task was spawned using `open_channel_from()`:
- on `open_channel_from()` exit, mem chan closure is the desired semantic
- on `run_task()` we normally only return a single value or error and
  if the channel is closed before the error is raised we may propagate
  a `trio.EndOfChannel` instead of the desired underlying `asyncio`
  task's error
infect_asyncio
Tyler Goodlet 2021-11-22 13:08:00 -05:00
parent d27ddb7bbb
commit 44d0e9fc32
1 changed files with 80 additions and 24 deletions

View File

@ -5,6 +5,7 @@ Infection apis for ``asyncio`` loops running ``trio`` using guest mode.
import asyncio import asyncio
from asyncio.exceptions import CancelledError from asyncio.exceptions import CancelledError
from contextlib import asynccontextmanager as acm from contextlib import asynccontextmanager as acm
from dataclasses import dataclass
import inspect import inspect
from typing import ( from typing import (
Any, Any,
@ -41,7 +42,8 @@ def _run_asyncio_task(
if not current_actor().is_infected_aio(): if not current_actor().is_infected_aio():
raise RuntimeError("`infect_asyncio` mode is not enabled!?") raise RuntimeError("`infect_asyncio` mode is not enabled!?")
# ITC (inter task comms) # ITC (inter task comms), these channel/queue names are mostly from
# ``asyncio``'s perspective.
from_trio = asyncio.Queue(qsize) # type: ignore from_trio = asyncio.Queue(qsize) # type: ignore
to_trio, from_aio = trio.open_memory_channel(qsize) # type: ignore to_trio, from_aio = trio.open_memory_channel(qsize) # type: ignore
@ -89,6 +91,8 @@ def _run_asyncio_task(
except BaseException as err: except BaseException as err:
aio_err = err aio_err = err
from_aio._err = aio_err from_aio._err = aio_err
to_trio.close()
from_aio.close()
raise raise
finally: finally:
@ -102,8 +106,12 @@ def _run_asyncio_task(
): ):
to_trio.send_nowait(result) to_trio.send_nowait(result)
to_trio.close() # if the task was spawned using ``open_channel_from()``
from_aio.close() # then we close the channels on exit.
if provide_channels:
to_trio.close()
from_aio.close()
aio_task_complete.set() aio_task_complete.set()
# start the asyncio task we submitted from trio # start the asyncio task we submitted from trio
@ -134,15 +142,17 @@ def _run_asyncio_task(
cancel_scope.cancel() cancel_scope.cancel()
else: else:
if aio_err is not None: if aio_err is not None:
aio_err.with_traceback(aio_err.__traceback__)
log.exception("infected task errorred:") log.exception("infected task errorred:")
from_aio._err = aio_err from_aio._err = aio_err
# order is opposite here
# NOTE: order is opposite here
cancel_scope.cancel() cancel_scope.cancel()
from_aio.close() from_aio.close()
task.add_done_callback(cancel_trio) task.add_done_callback(cancel_trio)
return task, from_aio, to_trio, cancel_scope, aio_task_complete return task, from_aio, to_trio, from_trio, cancel_scope, aio_task_complete
@acm @acm
@ -157,28 +167,32 @@ async def translate_aio_errors(
appropriately translates errors and cancels into ``trio`` land. appropriately translates errors and cancels into ``trio`` land.
''' '''
err: Optional[Exception] = None
aio_err: Optional[Exception] = None
def maybe_raise_aio_err(err: Exception):
aio_err = from_aio._err
if (
aio_err is not None and
type(aio_err) != CancelledError
):
# always raise from any captured asyncio error
raise aio_err from err
try: try:
yield yield
except ( except (
Exception, Exception,
CancelledError, CancelledError,
) as err: ) as err:
maybe_raise_aio_err(err)
aio_err = from_aio._err raise
if (
aio_err is not None and
type(aio_err) != CancelledError
):
# always raise from any captured asyncio error
raise err from aio_err
else:
raise
finally: finally:
if not task.done(): if not task.done() and aio_err:
task.cancel() task.cancel()
maybe_raise_aio_err(err)
# if task.cancelled(): # if task.cancelled():
# ... do what .. # ... do what ..
@ -197,7 +211,7 @@ async def run_task(
''' '''
# simple async func # simple async func
task, from_aio, to_trio, cs, _ = _run_asyncio_task( task, from_aio, to_trio, aio_q, cs, _ = _run_asyncio_task(
func, func,
qsize=1, qsize=1,
**kwargs, **kwargs,
@ -224,28 +238,70 @@ async def run_task(
# NB: code below is untested. # NB: code below is untested.
@dataclass
class LinkedTaskChannel(trio.abc.Channel):
'''
A "linked task channel" which allows for two-way synchronized msg
passing between a ``trio``-in-guest-mode task and an ``asyncio``
task.
'''
_aio_task: asyncio.Task
_to_aio: asyncio.Queue
_from_aio: trio.MemoryReceiveChannel
_aio_task_complete: trio.Event
async def aclose(self) -> None:
self._from_aio.close()
async def receive(self) -> Any:
async with translate_aio_errors(self._from_aio, self._aio_task):
return await self._from_aio.receive()
async def wait_ayncio_complete(self) -> None:
await self._aio_task_complete.wait()
def cancel_asyncio_task(self) -> None:
self._aio_task.cancel()
async def send(self, item: Any) -> None:
'''
Send a value through to the asyncio task presuming
it defines a ``from_trio`` argument, if it does not
this method will raise an error.
'''
self._to_aio.put_nowait(item)
@acm @acm
async def open_channel_from( async def open_channel_from(
target: Callable[[Any, ...], Any], target: Callable[[Any, ...], Any],
**kwargs, **kwargs,
) -> AsyncIterator[Any]: ) -> AsyncIterator[Any]:
'''
Open an inter-loop linked task channel for streaming between a target
spawned ``asyncio`` task and ``trio``.
task, from_aio, to_trio, cs, aio_task_complete = _run_asyncio_task( '''
task, from_aio, to_trio, aio_q, cs, aio_task_complete = _run_asyncio_task(
target, target,
qsize=2**8, qsize=2**8,
provide_channels=True, provide_channels=True,
**kwargs, **kwargs,
) )
async with translate_aio_errors(from_aio, task): chan = LinkedTaskChannel(task, aio_q, from_aio, aio_task_complete)
with cs: with cs:
# sync to "started()" call. async with translate_aio_errors(from_aio, task):
# sync to a "started()"-like first delivered value from the
# ``asyncio`` task.
first = await from_aio.receive() first = await from_aio.receive()
# stream values upward # stream values upward
async with from_aio: async with from_aio:
yield first, from_aio yield first, chan
await aio_task_complete.wait()
def run_as_asyncio_guest( def run_as_asyncio_guest(