Return channel type from `_run_asyncio_task()`

Better encapsulate all the mem-chan, Queue, sync-primitives inside our
linked task channel in order to avoid `mypy`'s complaints about monkey
patching. This also sets footing for adding an `asyncio`-side channel
API that can be used more like this `trio`-side API.
infect_asyncio
Tyler Goodlet 2021-11-28 12:38:37 -05:00
parent 9a2de90de6
commit 56cc98375e
1 changed files with 88 additions and 78 deletions

View File

@ -27,14 +27,57 @@ log = get_logger(__name__)
__all__ = ['run_task', 'run_as_asyncio_guest'] __all__ = ['run_task', 'run_as_asyncio_guest']
@dataclass
class LinkedTaskChannel(trio.abc.Channel):
'''
A "linked task channel" which allows for two-way synchronized msg
passing between a ``trio``-in-guest-mode task and an ``asyncio``
task scheduled in the host loop.
'''
_to_aio: asyncio.Queue
_from_aio: trio.MemoryReceiveChannel
_to_trio: trio.MemorySendChannel
_trio_cs: trio.CancelScope
_aio_task_complete: trio.Event
# set after ``asyncio.create_task()``
_aio_task: Optional[asyncio.Task] = None
_aio_err: Optional[BaseException] = None
async def aclose(self) -> None:
await self._from_aio.aclose()
async def receive(self) -> Any:
async with translate_aio_errors(self):
return await self._from_aio.receive()
async def wait_ayncio_complete(self) -> None:
await self._aio_task_complete.wait()
# def cancel_asyncio_task(self) -> None:
# self._aio_task.cancel()
async def send(self, item: Any) -> None:
'''
Send a value through to the asyncio task presuming
it defines a ``from_trio`` argument, if it does not
this method will raise an error.
'''
self._to_aio.put_nowait(item)
def _run_asyncio_task( def _run_asyncio_task(
func: Callable, func: Callable,
*, *,
qsize: int = 1, qsize: int = 1,
provide_channels: bool = False, provide_channels: bool = False,
**kwargs, **kwargs,
) -> Any: ) -> LinkedTaskChannel:
''' '''
Run an ``asyncio`` async function or generator in a task, return Run an ``asyncio`` async function or generator in a task, return
or stream the result back to ``trio``. or stream the result back to ``trio``.
@ -45,11 +88,9 @@ def _run_asyncio_task(
# ITC (inter task comms), these channel/queue names are mostly from # ITC (inter task comms), these channel/queue names are mostly from
# ``asyncio``'s perspective. # ``asyncio``'s perspective.
from_trio = asyncio.Queue(qsize) # type: ignore aio_q = from_trio = asyncio.Queue(qsize) # type: ignore
to_trio, from_aio = trio.open_memory_channel(qsize) # type: ignore to_trio, from_aio = trio.open_memory_channel(qsize) # type: ignore
from_aio._err = None
args = tuple(inspect.getfullargspec(func).args) args = tuple(inspect.getfullargspec(func).args)
if getattr(func, '_tractor_steam_function', None): if getattr(func, '_tractor_steam_function', None):
@ -74,6 +115,15 @@ def _run_asyncio_task(
aio_task_complete = trio.Event() aio_task_complete = trio.Event()
aio_err: Optional[BaseException] = None aio_err: Optional[BaseException] = None
chan = LinkedTaskChannel(
aio_q, # asyncio.Queue
from_aio, # recv chan
to_trio, # send chan
cancel_scope,
aio_task_complete,
)
async def wait_on_coro_final_result( async def wait_on_coro_final_result(
to_trio: trio.MemorySendChannel, to_trio: trio.MemorySendChannel,
@ -86,12 +136,13 @@ def _run_asyncio_task(
''' '''
nonlocal aio_err nonlocal aio_err
nonlocal chan
orig = result = id(coro) orig = result = id(coro)
try: try:
result = await coro result = await coro
except BaseException as err: except BaseException as aio_err:
aio_err = err chan._aio_err = aio_err
from_aio._err = aio_err
raise raise
else: else:
@ -116,7 +167,9 @@ def _run_asyncio_task(
aio_task_complete.set() aio_task_complete.set()
# start the asyncio task we submitted from trio # start the asyncio task we submitted from trio
if inspect.isawaitable(coro): if not inspect.isawaitable(coro):
raise TypeError(f"No support for invoking {coro}")
task = asyncio.create_task( task = asyncio.create_task(
wait_on_coro_final_result( wait_on_coro_final_result(
to_trio, to_trio,
@ -124,17 +177,15 @@ def _run_asyncio_task(
aio_task_complete aio_task_complete
) )
) )
chan._aio_task = task
else:
raise TypeError(f"No support for invoking {coro}")
def cancel_trio(task: asyncio.Task) -> None: def cancel_trio(task: asyncio.Task) -> None:
''' '''
Cancel the calling ``trio`` task on error. Cancel the calling ``trio`` task on error.
''' '''
nonlocal aio_err nonlocal chan
aio_err = from_aio._err aio_err = chan._aio_err
# only to avoid ``asyncio`` complaining about uncaptured # only to avoid ``asyncio`` complaining about uncaptured
# task exceptions # task exceptions
@ -159,27 +210,26 @@ def _run_asyncio_task(
task.add_done_callback(cancel_trio) task.add_done_callback(cancel_trio)
return task, from_aio, to_trio, from_trio, cancel_scope, aio_task_complete return chan
@acm @acm
async def translate_aio_errors( async def translate_aio_errors(
from_aio: trio.MemoryReceiveChannel, chan: LinkedTaskChannel,
task: asyncio.Task,
) -> None: ) -> AsyncIterator[None]:
''' '''
Error handling context around ``asyncio`` task spawns which Error handling context around ``asyncio`` task spawns which
appropriately translates errors and cancels into ``trio`` land. appropriately translates errors and cancels into ``trio`` land.
''' '''
aio_err: Optional[Exception] = None aio_err: Optional[BaseException] = None
def maybe_raise_aio_err( def maybe_raise_aio_err(
err: Optional[Exception] = None err: Optional[Exception] = None
) -> None: ) -> None:
aio_err = from_aio._err aio_err = chan._aio_err
if ( if (
aio_err is not None and aio_err is not None and
type(aio_err) != CancelledError type(aio_err) != CancelledError
@ -189,6 +239,9 @@ async def translate_aio_errors(
raise aio_err from err raise aio_err from err
else: else:
raise aio_err raise aio_err
task = chan._aio_task
assert task
try: try:
yield yield
except ( except (
@ -196,7 +249,7 @@ async def translate_aio_errors(
# termination callback # termination callback
trio.ClosedResourceError, trio.ClosedResourceError,
): ):
aio_err = from_aio._err aio_err = chan._aio_err
if ( if (
task.cancelled() and task.cancelled() and
type(aio_err) is CancelledError type(aio_err) is CancelledError
@ -234,65 +287,26 @@ async def run_task(
''' '''
# simple async func # simple async func
task, from_aio, to_trio, aio_q, cs, _ = _run_asyncio_task( chan = _run_asyncio_task(
func, func,
qsize=1, qsize=1,
**kwargs, **kwargs,
) )
with from_aio: with chan._from_aio:
# try: # try:
async with translate_aio_errors(from_aio, task): async with translate_aio_errors(chan):
# return single value that is the output from the # return single value that is the output from the
# ``asyncio`` function-as-task. Expect the mem chan api to # ``asyncio`` function-as-task. Expect the mem chan api to
# do the job of handling cross-framework cancellations # do the job of handling cross-framework cancellations
# / errors via closure and translation in the # / errors via closure and translation in the
# ``translate_aio_errors()`` in the above ctx mngr. # ``translate_aio_errors()`` in the above ctx mngr.
return await from_aio.receive() return await chan.receive()
@dataclass
class LinkedTaskChannel(trio.abc.Channel):
'''
A "linked task channel" which allows for two-way synchronized msg
passing between a ``trio``-in-guest-mode task and an ``asyncio``
task.
'''
_aio_task: asyncio.Task
_to_aio: asyncio.Queue
_from_aio: trio.MemoryReceiveChannel
_aio_task_complete: trio.Event
async def aclose(self) -> None:
self._from_aio.close()
async def receive(self) -> Any:
async with translate_aio_errors(
self._from_aio,
self._aio_task,
):
return await self._from_aio.receive()
async def wait_ayncio_complete(self) -> None:
await self._aio_task_complete.wait()
# def cancel_asyncio_task(self) -> None:
# self._aio_task.cancel()
async def send(self, item: Any) -> None:
'''
Send a value through to the asyncio task presuming
it defines a ``from_trio`` argument, if it does not
this method will raise an error.
'''
self._to_aio.put_nowait(item)
@acm @acm
async def open_channel_from( async def open_channel_from(
target: Callable[[Any, ...], Any], target: Callable[..., Any],
**kwargs, **kwargs,
) -> AsyncIterator[Any]: ) -> AsyncIterator[Any]:
@ -301,21 +315,17 @@ async def open_channel_from(
spawned ``asyncio`` task and ``trio``. spawned ``asyncio`` task and ``trio``.
''' '''
task, from_aio, to_trio, aio_q, cs, aio_task_complete = _run_asyncio_task( chan = _run_asyncio_task(
target, target,
qsize=2**8, qsize=2**8,
provide_channels=True, provide_channels=True,
**kwargs, **kwargs,
) )
chan = LinkedTaskChannel( async with chan._from_aio:
task, aio_q, from_aio, async with translate_aio_errors(chan):
aio_task_complete
)
async with from_aio:
async with translate_aio_errors(from_aio, task):
# sync to a "started()"-like first delivered value from the # sync to a "started()"-like first delivered value from the
# ``asyncio`` task. # ``asyncio`` task.
first = await from_aio.receive() first = await chan.receive()
# stream values upward # stream values upward
yield first, chan yield first, chan