Add a `LinkedTaskChannel` for synced inter-loop-streaming

Wraps the pairs of underlying `trio` mem chans and the `asyncio.Queue`
with this new composite which will be delivered from `open_channel_from()`.
This allows for both sending and receiving values from the `asyncio`
task (2 way msg passing) as well controls for cancelling or waiting on
the task.

Factor `asyncio` translation and re-raising logic into a new closure
which is run on both `trio` side error handling as well as on normal
termination to avoid missing `asyncio` errors even when `trio` task
cancellation is handled first.

Only close the `trio` mem chans on `trio` task termination *iff*
the task was spawned using `open_channel_from()`:
- on `open_channel_from()` exit, mem chan closure is the desired semantic
- on `run_task()` we normally only return a single value or error and
  if the channel is closed before the error is raised we may propagate
  a `trio.EndOfChannel` instead of the desired underlying `asyncio`
  task's error
infect_asyncio
Tyler Goodlet 2021-11-22 13:08:00 -05:00
parent d27ddb7bbb
commit 44d0e9fc32
1 changed files with 80 additions and 24 deletions

View File

@ -5,6 +5,7 @@ Infection apis for ``asyncio`` loops running ``trio`` using guest mode.
import asyncio
from asyncio.exceptions import CancelledError
from contextlib import asynccontextmanager as acm
from dataclasses import dataclass
import inspect
from typing import (
Any,
@ -41,7 +42,8 @@ def _run_asyncio_task(
if not current_actor().is_infected_aio():
raise RuntimeError("`infect_asyncio` mode is not enabled!?")
# ITC (inter task comms)
# ITC (inter task comms), these channel/queue names are mostly from
# ``asyncio``'s perspective.
from_trio = asyncio.Queue(qsize) # type: ignore
to_trio, from_aio = trio.open_memory_channel(qsize) # type: ignore
@ -89,6 +91,8 @@ def _run_asyncio_task(
except BaseException as err:
aio_err = err
from_aio._err = aio_err
to_trio.close()
from_aio.close()
raise
finally:
@ -102,8 +106,12 @@ def _run_asyncio_task(
):
to_trio.send_nowait(result)
to_trio.close()
from_aio.close()
# if the task was spawned using ``open_channel_from()``
# then we close the channels on exit.
if provide_channels:
to_trio.close()
from_aio.close()
aio_task_complete.set()
# start the asyncio task we submitted from trio
@ -134,15 +142,17 @@ def _run_asyncio_task(
cancel_scope.cancel()
else:
if aio_err is not None:
aio_err.with_traceback(aio_err.__traceback__)
log.exception("infected task errorred:")
from_aio._err = aio_err
# order is opposite here
# NOTE: order is opposite here
cancel_scope.cancel()
from_aio.close()
task.add_done_callback(cancel_trio)
return task, from_aio, to_trio, cancel_scope, aio_task_complete
return task, from_aio, to_trio, from_trio, cancel_scope, aio_task_complete
@acm
@ -157,28 +167,32 @@ async def translate_aio_errors(
appropriately translates errors and cancels into ``trio`` land.
'''
err: Optional[Exception] = None
aio_err: Optional[Exception] = None
def maybe_raise_aio_err(err: Exception):
aio_err = from_aio._err
if (
aio_err is not None and
type(aio_err) != CancelledError
):
# always raise from any captured asyncio error
raise aio_err from err
try:
yield
except (
Exception,
CancelledError,
) as err:
aio_err = from_aio._err
if (
aio_err is not None and
type(aio_err) != CancelledError
):
# always raise from any captured asyncio error
raise err from aio_err
else:
raise
maybe_raise_aio_err(err)
raise
finally:
if not task.done():
if not task.done() and aio_err:
task.cancel()
maybe_raise_aio_err(err)
# if task.cancelled():
# ... do what ..
@ -197,7 +211,7 @@ async def run_task(
'''
# simple async func
task, from_aio, to_trio, cs, _ = _run_asyncio_task(
task, from_aio, to_trio, aio_q, cs, _ = _run_asyncio_task(
func,
qsize=1,
**kwargs,
@ -224,28 +238,70 @@ async def run_task(
# NB: code below is untested.
@dataclass
class LinkedTaskChannel(trio.abc.Channel):
'''
A "linked task channel" which allows for two-way synchronized msg
passing between a ``trio``-in-guest-mode task and an ``asyncio``
task.
'''
_aio_task: asyncio.Task
_to_aio: asyncio.Queue
_from_aio: trio.MemoryReceiveChannel
_aio_task_complete: trio.Event
async def aclose(self) -> None:
self._from_aio.close()
async def receive(self) -> Any:
async with translate_aio_errors(self._from_aio, self._aio_task):
return await self._from_aio.receive()
async def wait_ayncio_complete(self) -> None:
await self._aio_task_complete.wait()
def cancel_asyncio_task(self) -> None:
self._aio_task.cancel()
async def send(self, item: Any) -> None:
'''
Send a value through to the asyncio task presuming
it defines a ``from_trio`` argument, if it does not
this method will raise an error.
'''
self._to_aio.put_nowait(item)
@acm
async def open_channel_from(
target: Callable[[Any, ...], Any],
**kwargs,
) -> AsyncIterator[Any]:
'''
Open an inter-loop linked task channel for streaming between a target
spawned ``asyncio`` task and ``trio``.
task, from_aio, to_trio, cs, aio_task_complete = _run_asyncio_task(
'''
task, from_aio, to_trio, aio_q, cs, aio_task_complete = _run_asyncio_task(
target,
qsize=2**8,
provide_channels=True,
**kwargs,
)
async with translate_aio_errors(from_aio, task):
with cs:
# sync to "started()" call.
chan = LinkedTaskChannel(task, aio_q, from_aio, aio_task_complete)
with cs:
async with translate_aio_errors(from_aio, task):
# sync to a "started()"-like first delivered value from the
# ``asyncio`` task.
first = await from_aio.receive()
# stream values upward
async with from_aio:
yield first, from_aio
await aio_task_complete.wait()
yield first, chan
def run_as_asyncio_guest(