Return channel type from `_run_asyncio_task()`

Better encapsulate all the mem-chan, Queue, sync-primitives inside our
linked task channel in order to avoid `mypy`'s complaints about monkey
patching. This also sets footing for adding an `asyncio`-side channel
API that can be used more like this `trio`-side API.
infect_asyncio
Tyler Goodlet 2021-11-28 12:38:37 -05:00
parent 9a2de90de6
commit 56cc98375e
1 changed files with 88 additions and 78 deletions

View File

@ -27,14 +27,57 @@ log = get_logger(__name__)
__all__ = ['run_task', 'run_as_asyncio_guest']
@dataclass
class LinkedTaskChannel(trio.abc.Channel):
'''
A "linked task channel" which allows for two-way synchronized msg
passing between a ``trio``-in-guest-mode task and an ``asyncio``
task scheduled in the host loop.
'''
_to_aio: asyncio.Queue
_from_aio: trio.MemoryReceiveChannel
_to_trio: trio.MemorySendChannel
_trio_cs: trio.CancelScope
_aio_task_complete: trio.Event
# set after ``asyncio.create_task()``
_aio_task: Optional[asyncio.Task] = None
_aio_err: Optional[BaseException] = None
async def aclose(self) -> None:
await self._from_aio.aclose()
async def receive(self) -> Any:
async with translate_aio_errors(self):
return await self._from_aio.receive()
async def wait_ayncio_complete(self) -> None:
await self._aio_task_complete.wait()
# def cancel_asyncio_task(self) -> None:
# self._aio_task.cancel()
async def send(self, item: Any) -> None:
'''
Send a value through to the asyncio task presuming
it defines a ``from_trio`` argument, if it does not
this method will raise an error.
'''
self._to_aio.put_nowait(item)
def _run_asyncio_task(
func: Callable,
*,
qsize: int = 1,
provide_channels: bool = False,
**kwargs,
) -> Any:
) -> LinkedTaskChannel:
'''
Run an ``asyncio`` async function or generator in a task, return
or stream the result back to ``trio``.
@ -45,11 +88,9 @@ def _run_asyncio_task(
# ITC (inter task comms), these channel/queue names are mostly from
# ``asyncio``'s perspective.
from_trio = asyncio.Queue(qsize) # type: ignore
aio_q = from_trio = asyncio.Queue(qsize) # type: ignore
to_trio, from_aio = trio.open_memory_channel(qsize) # type: ignore
from_aio._err = None
args = tuple(inspect.getfullargspec(func).args)
if getattr(func, '_tractor_steam_function', None):
@ -74,6 +115,15 @@ def _run_asyncio_task(
aio_task_complete = trio.Event()
aio_err: Optional[BaseException] = None
chan = LinkedTaskChannel(
aio_q, # asyncio.Queue
from_aio, # recv chan
to_trio, # send chan
cancel_scope,
aio_task_complete,
)
async def wait_on_coro_final_result(
to_trio: trio.MemorySendChannel,
@ -86,12 +136,13 @@ def _run_asyncio_task(
'''
nonlocal aio_err
nonlocal chan
orig = result = id(coro)
try:
result = await coro
except BaseException as err:
aio_err = err
from_aio._err = aio_err
except BaseException as aio_err:
chan._aio_err = aio_err
raise
else:
@ -116,7 +167,9 @@ def _run_asyncio_task(
aio_task_complete.set()
# start the asyncio task we submitted from trio
if inspect.isawaitable(coro):
if not inspect.isawaitable(coro):
raise TypeError(f"No support for invoking {coro}")
task = asyncio.create_task(
wait_on_coro_final_result(
to_trio,
@ -124,17 +177,15 @@ def _run_asyncio_task(
aio_task_complete
)
)
else:
raise TypeError(f"No support for invoking {coro}")
chan._aio_task = task
def cancel_trio(task: asyncio.Task) -> None:
'''
Cancel the calling ``trio`` task on error.
'''
nonlocal aio_err
aio_err = from_aio._err
nonlocal chan
aio_err = chan._aio_err
# only to avoid ``asyncio`` complaining about uncaptured
# task exceptions
@ -159,27 +210,26 @@ def _run_asyncio_task(
task.add_done_callback(cancel_trio)
return task, from_aio, to_trio, from_trio, cancel_scope, aio_task_complete
return chan
@acm
async def translate_aio_errors(
from_aio: trio.MemoryReceiveChannel,
task: asyncio.Task,
chan: LinkedTaskChannel,
) -> None:
) -> AsyncIterator[None]:
'''
Error handling context around ``asyncio`` task spawns which
appropriately translates errors and cancels into ``trio`` land.
'''
aio_err: Optional[Exception] = None
aio_err: Optional[BaseException] = None
def maybe_raise_aio_err(
err: Optional[Exception] = None
) -> None:
aio_err = from_aio._err
aio_err = chan._aio_err
if (
aio_err is not None and
type(aio_err) != CancelledError
@ -189,6 +239,9 @@ async def translate_aio_errors(
raise aio_err from err
else:
raise aio_err
task = chan._aio_task
assert task
try:
yield
except (
@ -196,7 +249,7 @@ async def translate_aio_errors(
# termination callback
trio.ClosedResourceError,
):
aio_err = from_aio._err
aio_err = chan._aio_err
if (
task.cancelled() and
type(aio_err) is CancelledError
@ -234,65 +287,26 @@ async def run_task(
'''
# simple async func
task, from_aio, to_trio, aio_q, cs, _ = _run_asyncio_task(
chan = _run_asyncio_task(
func,
qsize=1,
**kwargs,
)
with from_aio:
with chan._from_aio:
# try:
async with translate_aio_errors(from_aio, task):
async with translate_aio_errors(chan):
# return single value that is the output from the
# ``asyncio`` function-as-task. Expect the mem chan api to
# do the job of handling cross-framework cancellations
# / errors via closure and translation in the
# ``translate_aio_errors()`` in the above ctx mngr.
return await from_aio.receive()
@dataclass
class LinkedTaskChannel(trio.abc.Channel):
'''
A "linked task channel" which allows for two-way synchronized msg
passing between a ``trio``-in-guest-mode task and an ``asyncio``
task.
'''
_aio_task: asyncio.Task
_to_aio: asyncio.Queue
_from_aio: trio.MemoryReceiveChannel
_aio_task_complete: trio.Event
async def aclose(self) -> None:
self._from_aio.close()
async def receive(self) -> Any:
async with translate_aio_errors(
self._from_aio,
self._aio_task,
):
return await self._from_aio.receive()
async def wait_ayncio_complete(self) -> None:
await self._aio_task_complete.wait()
# def cancel_asyncio_task(self) -> None:
# self._aio_task.cancel()
async def send(self, item: Any) -> None:
'''
Send a value through to the asyncio task presuming
it defines a ``from_trio`` argument, if it does not
this method will raise an error.
'''
self._to_aio.put_nowait(item)
return await chan.receive()
@acm
async def open_channel_from(
target: Callable[[Any, ...], Any],
target: Callable[..., Any],
**kwargs,
) -> AsyncIterator[Any]:
@ -301,21 +315,17 @@ async def open_channel_from(
spawned ``asyncio`` task and ``trio``.
'''
task, from_aio, to_trio, aio_q, cs, aio_task_complete = _run_asyncio_task(
chan = _run_asyncio_task(
target,
qsize=2**8,
provide_channels=True,
**kwargs,
)
chan = LinkedTaskChannel(
task, aio_q, from_aio,
aio_task_complete
)
async with from_aio:
async with translate_aio_errors(from_aio, task):
async with chan._from_aio:
async with translate_aio_errors(chan):
# sync to a "started()"-like first delivered value from the
# ``asyncio`` task.
first = await from_aio.receive()
first = await chan.receive()
# stream values upward
yield first, chan