Add a `LinkedTaskChannel` for synced inter-loop-streaming
Wraps the pairs of underlying `trio` mem chans and the `asyncio.Queue` with this new composite which will be delivered from `open_channel_from()`. This allows for both sending and receiving values from the `asyncio` task (2 way msg passing) as well controls for cancelling or waiting on the task. Factor `asyncio` translation and re-raising logic into a new closure which is run on both `trio` side error handling as well as on normal termination to avoid missing `asyncio` errors even when `trio` task cancellation is handled first. Only close the `trio` mem chans on `trio` task termination *iff* the task was spawned using `open_channel_from()`: - on `open_channel_from()` exit, mem chan closure is the desired semantic - on `run_task()` we normally only return a single value or error and if the channel is closed before the error is raised we may propagate a `trio.EndOfChannel` instead of the desired underlying `asyncio` task's errorinfect_asyncio
parent
d27ddb7bbb
commit
44d0e9fc32
|
@ -5,6 +5,7 @@ Infection apis for ``asyncio`` loops running ``trio`` using guest mode.
|
||||||
import asyncio
|
import asyncio
|
||||||
from asyncio.exceptions import CancelledError
|
from asyncio.exceptions import CancelledError
|
||||||
from contextlib import asynccontextmanager as acm
|
from contextlib import asynccontextmanager as acm
|
||||||
|
from dataclasses import dataclass
|
||||||
import inspect
|
import inspect
|
||||||
from typing import (
|
from typing import (
|
||||||
Any,
|
Any,
|
||||||
|
@ -41,7 +42,8 @@ def _run_asyncio_task(
|
||||||
if not current_actor().is_infected_aio():
|
if not current_actor().is_infected_aio():
|
||||||
raise RuntimeError("`infect_asyncio` mode is not enabled!?")
|
raise RuntimeError("`infect_asyncio` mode is not enabled!?")
|
||||||
|
|
||||||
# ITC (inter task comms)
|
# ITC (inter task comms), these channel/queue names are mostly from
|
||||||
|
# ``asyncio``'s perspective.
|
||||||
from_trio = asyncio.Queue(qsize) # type: ignore
|
from_trio = asyncio.Queue(qsize) # type: ignore
|
||||||
to_trio, from_aio = trio.open_memory_channel(qsize) # type: ignore
|
to_trio, from_aio = trio.open_memory_channel(qsize) # type: ignore
|
||||||
|
|
||||||
|
@ -89,6 +91,8 @@ def _run_asyncio_task(
|
||||||
except BaseException as err:
|
except BaseException as err:
|
||||||
aio_err = err
|
aio_err = err
|
||||||
from_aio._err = aio_err
|
from_aio._err = aio_err
|
||||||
|
to_trio.close()
|
||||||
|
from_aio.close()
|
||||||
raise
|
raise
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
|
@ -102,8 +106,12 @@ def _run_asyncio_task(
|
||||||
):
|
):
|
||||||
to_trio.send_nowait(result)
|
to_trio.send_nowait(result)
|
||||||
|
|
||||||
|
# if the task was spawned using ``open_channel_from()``
|
||||||
|
# then we close the channels on exit.
|
||||||
|
if provide_channels:
|
||||||
to_trio.close()
|
to_trio.close()
|
||||||
from_aio.close()
|
from_aio.close()
|
||||||
|
|
||||||
aio_task_complete.set()
|
aio_task_complete.set()
|
||||||
|
|
||||||
# start the asyncio task we submitted from trio
|
# start the asyncio task we submitted from trio
|
||||||
|
@ -134,15 +142,17 @@ def _run_asyncio_task(
|
||||||
cancel_scope.cancel()
|
cancel_scope.cancel()
|
||||||
else:
|
else:
|
||||||
if aio_err is not None:
|
if aio_err is not None:
|
||||||
|
aio_err.with_traceback(aio_err.__traceback__)
|
||||||
log.exception("infected task errorred:")
|
log.exception("infected task errorred:")
|
||||||
from_aio._err = aio_err
|
from_aio._err = aio_err
|
||||||
# order is opposite here
|
|
||||||
|
# NOTE: order is opposite here
|
||||||
cancel_scope.cancel()
|
cancel_scope.cancel()
|
||||||
from_aio.close()
|
from_aio.close()
|
||||||
|
|
||||||
task.add_done_callback(cancel_trio)
|
task.add_done_callback(cancel_trio)
|
||||||
|
|
||||||
return task, from_aio, to_trio, cancel_scope, aio_task_complete
|
return task, from_aio, to_trio, from_trio, cancel_scope, aio_task_complete
|
||||||
|
|
||||||
|
|
||||||
@acm
|
@acm
|
||||||
|
@ -157,28 +167,32 @@ async def translate_aio_errors(
|
||||||
appropriately translates errors and cancels into ``trio`` land.
|
appropriately translates errors and cancels into ``trio`` land.
|
||||||
|
|
||||||
'''
|
'''
|
||||||
|
err: Optional[Exception] = None
|
||||||
|
aio_err: Optional[Exception] = None
|
||||||
|
|
||||||
|
def maybe_raise_aio_err(err: Exception):
|
||||||
|
aio_err = from_aio._err
|
||||||
|
if (
|
||||||
|
aio_err is not None and
|
||||||
|
type(aio_err) != CancelledError
|
||||||
|
):
|
||||||
|
# always raise from any captured asyncio error
|
||||||
|
raise aio_err from err
|
||||||
|
|
||||||
try:
|
try:
|
||||||
yield
|
yield
|
||||||
except (
|
except (
|
||||||
Exception,
|
Exception,
|
||||||
CancelledError,
|
CancelledError,
|
||||||
) as err:
|
) as err:
|
||||||
|
maybe_raise_aio_err(err)
|
||||||
aio_err = from_aio._err
|
|
||||||
|
|
||||||
if (
|
|
||||||
aio_err is not None and
|
|
||||||
type(aio_err) != CancelledError
|
|
||||||
):
|
|
||||||
# always raise from any captured asyncio error
|
|
||||||
raise err from aio_err
|
|
||||||
else:
|
|
||||||
raise
|
raise
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
if not task.done():
|
if not task.done() and aio_err:
|
||||||
task.cancel()
|
task.cancel()
|
||||||
|
|
||||||
|
maybe_raise_aio_err(err)
|
||||||
# if task.cancelled():
|
# if task.cancelled():
|
||||||
# ... do what ..
|
# ... do what ..
|
||||||
|
|
||||||
|
@ -197,7 +211,7 @@ async def run_task(
|
||||||
|
|
||||||
'''
|
'''
|
||||||
# simple async func
|
# simple async func
|
||||||
task, from_aio, to_trio, cs, _ = _run_asyncio_task(
|
task, from_aio, to_trio, aio_q, cs, _ = _run_asyncio_task(
|
||||||
func,
|
func,
|
||||||
qsize=1,
|
qsize=1,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
|
@ -224,28 +238,70 @@ async def run_task(
|
||||||
# NB: code below is untested.
|
# NB: code below is untested.
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class LinkedTaskChannel(trio.abc.Channel):
|
||||||
|
'''
|
||||||
|
A "linked task channel" which allows for two-way synchronized msg
|
||||||
|
passing between a ``trio``-in-guest-mode task and an ``asyncio``
|
||||||
|
task.
|
||||||
|
|
||||||
|
'''
|
||||||
|
_aio_task: asyncio.Task
|
||||||
|
_to_aio: asyncio.Queue
|
||||||
|
_from_aio: trio.MemoryReceiveChannel
|
||||||
|
_aio_task_complete: trio.Event
|
||||||
|
|
||||||
|
async def aclose(self) -> None:
|
||||||
|
self._from_aio.close()
|
||||||
|
|
||||||
|
async def receive(self) -> Any:
|
||||||
|
async with translate_aio_errors(self._from_aio, self._aio_task):
|
||||||
|
return await self._from_aio.receive()
|
||||||
|
|
||||||
|
async def wait_ayncio_complete(self) -> None:
|
||||||
|
await self._aio_task_complete.wait()
|
||||||
|
|
||||||
|
def cancel_asyncio_task(self) -> None:
|
||||||
|
self._aio_task.cancel()
|
||||||
|
|
||||||
|
async def send(self, item: Any) -> None:
|
||||||
|
'''
|
||||||
|
Send a value through to the asyncio task presuming
|
||||||
|
it defines a ``from_trio`` argument, if it does not
|
||||||
|
this method will raise an error.
|
||||||
|
|
||||||
|
'''
|
||||||
|
self._to_aio.put_nowait(item)
|
||||||
|
|
||||||
|
|
||||||
@acm
|
@acm
|
||||||
async def open_channel_from(
|
async def open_channel_from(
|
||||||
|
|
||||||
target: Callable[[Any, ...], Any],
|
target: Callable[[Any, ...], Any],
|
||||||
**kwargs,
|
**kwargs,
|
||||||
|
|
||||||
) -> AsyncIterator[Any]:
|
) -> AsyncIterator[Any]:
|
||||||
|
'''
|
||||||
|
Open an inter-loop linked task channel for streaming between a target
|
||||||
|
spawned ``asyncio`` task and ``trio``.
|
||||||
|
|
||||||
task, from_aio, to_trio, cs, aio_task_complete = _run_asyncio_task(
|
'''
|
||||||
|
task, from_aio, to_trio, aio_q, cs, aio_task_complete = _run_asyncio_task(
|
||||||
target,
|
target,
|
||||||
qsize=2**8,
|
qsize=2**8,
|
||||||
provide_channels=True,
|
provide_channels=True,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
async with translate_aio_errors(from_aio, task):
|
chan = LinkedTaskChannel(task, aio_q, from_aio, aio_task_complete)
|
||||||
with cs:
|
with cs:
|
||||||
# sync to "started()" call.
|
async with translate_aio_errors(from_aio, task):
|
||||||
|
# sync to a "started()"-like first delivered value from the
|
||||||
|
# ``asyncio`` task.
|
||||||
first = await from_aio.receive()
|
first = await from_aio.receive()
|
||||||
|
|
||||||
# stream values upward
|
# stream values upward
|
||||||
async with from_aio:
|
async with from_aio:
|
||||||
yield first, from_aio
|
yield first, chan
|
||||||
await aio_task_complete.wait()
|
|
||||||
|
|
||||||
|
|
||||||
def run_as_asyncio_guest(
|
def run_as_asyncio_guest(
|
||||||
|
|
Loading…
Reference in New Issue