Return channel type from `_run_asyncio_task()`
Better encapsulate all the mem-chan, Queue, sync-primitives inside our linked task channel in order to avoid `mypy`'s complaints about monkey patching. This also sets footing for adding an `asyncio`-side channel API that can be used more like this `trio`-side API.infect_asyncio
parent
9a2de90de6
commit
56cc98375e
|
@ -27,14 +27,57 @@ log = get_logger(__name__)
|
||||||
__all__ = ['run_task', 'run_as_asyncio_guest']
|
__all__ = ['run_task', 'run_as_asyncio_guest']
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class LinkedTaskChannel(trio.abc.Channel):
|
||||||
|
'''
|
||||||
|
A "linked task channel" which allows for two-way synchronized msg
|
||||||
|
passing between a ``trio``-in-guest-mode task and an ``asyncio``
|
||||||
|
task scheduled in the host loop.
|
||||||
|
|
||||||
|
'''
|
||||||
|
_to_aio: asyncio.Queue
|
||||||
|
_from_aio: trio.MemoryReceiveChannel
|
||||||
|
_to_trio: trio.MemorySendChannel
|
||||||
|
|
||||||
|
_trio_cs: trio.CancelScope
|
||||||
|
_aio_task_complete: trio.Event
|
||||||
|
|
||||||
|
# set after ``asyncio.create_task()``
|
||||||
|
_aio_task: Optional[asyncio.Task] = None
|
||||||
|
_aio_err: Optional[BaseException] = None
|
||||||
|
|
||||||
|
async def aclose(self) -> None:
|
||||||
|
await self._from_aio.aclose()
|
||||||
|
|
||||||
|
async def receive(self) -> Any:
|
||||||
|
async with translate_aio_errors(self):
|
||||||
|
return await self._from_aio.receive()
|
||||||
|
|
||||||
|
async def wait_ayncio_complete(self) -> None:
|
||||||
|
await self._aio_task_complete.wait()
|
||||||
|
|
||||||
|
# def cancel_asyncio_task(self) -> None:
|
||||||
|
# self._aio_task.cancel()
|
||||||
|
|
||||||
|
async def send(self, item: Any) -> None:
|
||||||
|
'''
|
||||||
|
Send a value through to the asyncio task presuming
|
||||||
|
it defines a ``from_trio`` argument, if it does not
|
||||||
|
this method will raise an error.
|
||||||
|
|
||||||
|
'''
|
||||||
|
self._to_aio.put_nowait(item)
|
||||||
|
|
||||||
|
|
||||||
def _run_asyncio_task(
|
def _run_asyncio_task(
|
||||||
|
|
||||||
func: Callable,
|
func: Callable,
|
||||||
*,
|
*,
|
||||||
qsize: int = 1,
|
qsize: int = 1,
|
||||||
provide_channels: bool = False,
|
provide_channels: bool = False,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
|
|
||||||
) -> Any:
|
) -> LinkedTaskChannel:
|
||||||
'''
|
'''
|
||||||
Run an ``asyncio`` async function or generator in a task, return
|
Run an ``asyncio`` async function or generator in a task, return
|
||||||
or stream the result back to ``trio``.
|
or stream the result back to ``trio``.
|
||||||
|
@ -45,11 +88,9 @@ def _run_asyncio_task(
|
||||||
|
|
||||||
# ITC (inter task comms), these channel/queue names are mostly from
|
# ITC (inter task comms), these channel/queue names are mostly from
|
||||||
# ``asyncio``'s perspective.
|
# ``asyncio``'s perspective.
|
||||||
from_trio = asyncio.Queue(qsize) # type: ignore
|
aio_q = from_trio = asyncio.Queue(qsize) # type: ignore
|
||||||
to_trio, from_aio = trio.open_memory_channel(qsize) # type: ignore
|
to_trio, from_aio = trio.open_memory_channel(qsize) # type: ignore
|
||||||
|
|
||||||
from_aio._err = None
|
|
||||||
|
|
||||||
args = tuple(inspect.getfullargspec(func).args)
|
args = tuple(inspect.getfullargspec(func).args)
|
||||||
|
|
||||||
if getattr(func, '_tractor_steam_function', None):
|
if getattr(func, '_tractor_steam_function', None):
|
||||||
|
@ -74,6 +115,15 @@ def _run_asyncio_task(
|
||||||
aio_task_complete = trio.Event()
|
aio_task_complete = trio.Event()
|
||||||
aio_err: Optional[BaseException] = None
|
aio_err: Optional[BaseException] = None
|
||||||
|
|
||||||
|
chan = LinkedTaskChannel(
|
||||||
|
aio_q, # asyncio.Queue
|
||||||
|
from_aio, # recv chan
|
||||||
|
to_trio, # send chan
|
||||||
|
|
||||||
|
cancel_scope,
|
||||||
|
aio_task_complete,
|
||||||
|
)
|
||||||
|
|
||||||
async def wait_on_coro_final_result(
|
async def wait_on_coro_final_result(
|
||||||
|
|
||||||
to_trio: trio.MemorySendChannel,
|
to_trio: trio.MemorySendChannel,
|
||||||
|
@ -86,12 +136,13 @@ def _run_asyncio_task(
|
||||||
|
|
||||||
'''
|
'''
|
||||||
nonlocal aio_err
|
nonlocal aio_err
|
||||||
|
nonlocal chan
|
||||||
|
|
||||||
orig = result = id(coro)
|
orig = result = id(coro)
|
||||||
try:
|
try:
|
||||||
result = await coro
|
result = await coro
|
||||||
except BaseException as err:
|
except BaseException as aio_err:
|
||||||
aio_err = err
|
chan._aio_err = aio_err
|
||||||
from_aio._err = aio_err
|
|
||||||
raise
|
raise
|
||||||
|
|
||||||
else:
|
else:
|
||||||
|
@ -116,25 +167,25 @@ def _run_asyncio_task(
|
||||||
aio_task_complete.set()
|
aio_task_complete.set()
|
||||||
|
|
||||||
# start the asyncio task we submitted from trio
|
# start the asyncio task we submitted from trio
|
||||||
if inspect.isawaitable(coro):
|
if not inspect.isawaitable(coro):
|
||||||
task = asyncio.create_task(
|
|
||||||
wait_on_coro_final_result(
|
|
||||||
to_trio,
|
|
||||||
coro,
|
|
||||||
aio_task_complete
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
else:
|
|
||||||
raise TypeError(f"No support for invoking {coro}")
|
raise TypeError(f"No support for invoking {coro}")
|
||||||
|
|
||||||
|
task = asyncio.create_task(
|
||||||
|
wait_on_coro_final_result(
|
||||||
|
to_trio,
|
||||||
|
coro,
|
||||||
|
aio_task_complete
|
||||||
|
)
|
||||||
|
)
|
||||||
|
chan._aio_task = task
|
||||||
|
|
||||||
def cancel_trio(task: asyncio.Task) -> None:
|
def cancel_trio(task: asyncio.Task) -> None:
|
||||||
'''
|
'''
|
||||||
Cancel the calling ``trio`` task on error.
|
Cancel the calling ``trio`` task on error.
|
||||||
|
|
||||||
'''
|
'''
|
||||||
nonlocal aio_err
|
nonlocal chan
|
||||||
aio_err = from_aio._err
|
aio_err = chan._aio_err
|
||||||
|
|
||||||
# only to avoid ``asyncio`` complaining about uncaptured
|
# only to avoid ``asyncio`` complaining about uncaptured
|
||||||
# task exceptions
|
# task exceptions
|
||||||
|
@ -159,27 +210,26 @@ def _run_asyncio_task(
|
||||||
|
|
||||||
task.add_done_callback(cancel_trio)
|
task.add_done_callback(cancel_trio)
|
||||||
|
|
||||||
return task, from_aio, to_trio, from_trio, cancel_scope, aio_task_complete
|
return chan
|
||||||
|
|
||||||
|
|
||||||
@acm
|
@acm
|
||||||
async def translate_aio_errors(
|
async def translate_aio_errors(
|
||||||
|
|
||||||
from_aio: trio.MemoryReceiveChannel,
|
chan: LinkedTaskChannel,
|
||||||
task: asyncio.Task,
|
|
||||||
|
|
||||||
) -> None:
|
) -> AsyncIterator[None]:
|
||||||
'''
|
'''
|
||||||
Error handling context around ``asyncio`` task spawns which
|
Error handling context around ``asyncio`` task spawns which
|
||||||
appropriately translates errors and cancels into ``trio`` land.
|
appropriately translates errors and cancels into ``trio`` land.
|
||||||
|
|
||||||
'''
|
'''
|
||||||
aio_err: Optional[Exception] = None
|
aio_err: Optional[BaseException] = None
|
||||||
|
|
||||||
def maybe_raise_aio_err(
|
def maybe_raise_aio_err(
|
||||||
err: Optional[Exception] = None
|
err: Optional[Exception] = None
|
||||||
) -> None:
|
) -> None:
|
||||||
aio_err = from_aio._err
|
aio_err = chan._aio_err
|
||||||
if (
|
if (
|
||||||
aio_err is not None and
|
aio_err is not None and
|
||||||
type(aio_err) != CancelledError
|
type(aio_err) != CancelledError
|
||||||
|
@ -189,6 +239,9 @@ async def translate_aio_errors(
|
||||||
raise aio_err from err
|
raise aio_err from err
|
||||||
else:
|
else:
|
||||||
raise aio_err
|
raise aio_err
|
||||||
|
|
||||||
|
task = chan._aio_task
|
||||||
|
assert task
|
||||||
try:
|
try:
|
||||||
yield
|
yield
|
||||||
except (
|
except (
|
||||||
|
@ -196,7 +249,7 @@ async def translate_aio_errors(
|
||||||
# termination callback
|
# termination callback
|
||||||
trio.ClosedResourceError,
|
trio.ClosedResourceError,
|
||||||
):
|
):
|
||||||
aio_err = from_aio._err
|
aio_err = chan._aio_err
|
||||||
if (
|
if (
|
||||||
task.cancelled() and
|
task.cancelled() and
|
||||||
type(aio_err) is CancelledError
|
type(aio_err) is CancelledError
|
||||||
|
@ -234,65 +287,26 @@ async def run_task(
|
||||||
|
|
||||||
'''
|
'''
|
||||||
# simple async func
|
# simple async func
|
||||||
task, from_aio, to_trio, aio_q, cs, _ = _run_asyncio_task(
|
chan = _run_asyncio_task(
|
||||||
func,
|
func,
|
||||||
qsize=1,
|
qsize=1,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
with from_aio:
|
with chan._from_aio:
|
||||||
# try:
|
# try:
|
||||||
async with translate_aio_errors(from_aio, task):
|
async with translate_aio_errors(chan):
|
||||||
# return single value that is the output from the
|
# return single value that is the output from the
|
||||||
# ``asyncio`` function-as-task. Expect the mem chan api to
|
# ``asyncio`` function-as-task. Expect the mem chan api to
|
||||||
# do the job of handling cross-framework cancellations
|
# do the job of handling cross-framework cancellations
|
||||||
# / errors via closure and translation in the
|
# / errors via closure and translation in the
|
||||||
# ``translate_aio_errors()`` in the above ctx mngr.
|
# ``translate_aio_errors()`` in the above ctx mngr.
|
||||||
return await from_aio.receive()
|
return await chan.receive()
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class LinkedTaskChannel(trio.abc.Channel):
|
|
||||||
'''
|
|
||||||
A "linked task channel" which allows for two-way synchronized msg
|
|
||||||
passing between a ``trio``-in-guest-mode task and an ``asyncio``
|
|
||||||
task.
|
|
||||||
|
|
||||||
'''
|
|
||||||
_aio_task: asyncio.Task
|
|
||||||
_to_aio: asyncio.Queue
|
|
||||||
_from_aio: trio.MemoryReceiveChannel
|
|
||||||
_aio_task_complete: trio.Event
|
|
||||||
|
|
||||||
async def aclose(self) -> None:
|
|
||||||
self._from_aio.close()
|
|
||||||
|
|
||||||
async def receive(self) -> Any:
|
|
||||||
async with translate_aio_errors(
|
|
||||||
self._from_aio,
|
|
||||||
self._aio_task,
|
|
||||||
):
|
|
||||||
return await self._from_aio.receive()
|
|
||||||
|
|
||||||
async def wait_ayncio_complete(self) -> None:
|
|
||||||
await self._aio_task_complete.wait()
|
|
||||||
|
|
||||||
# def cancel_asyncio_task(self) -> None:
|
|
||||||
# self._aio_task.cancel()
|
|
||||||
|
|
||||||
async def send(self, item: Any) -> None:
|
|
||||||
'''
|
|
||||||
Send a value through to the asyncio task presuming
|
|
||||||
it defines a ``from_trio`` argument, if it does not
|
|
||||||
this method will raise an error.
|
|
||||||
|
|
||||||
'''
|
|
||||||
self._to_aio.put_nowait(item)
|
|
||||||
|
|
||||||
|
|
||||||
@acm
|
@acm
|
||||||
async def open_channel_from(
|
async def open_channel_from(
|
||||||
|
|
||||||
target: Callable[[Any, ...], Any],
|
target: Callable[..., Any],
|
||||||
**kwargs,
|
**kwargs,
|
||||||
|
|
||||||
) -> AsyncIterator[Any]:
|
) -> AsyncIterator[Any]:
|
||||||
|
@ -301,21 +315,17 @@ async def open_channel_from(
|
||||||
spawned ``asyncio`` task and ``trio``.
|
spawned ``asyncio`` task and ``trio``.
|
||||||
|
|
||||||
'''
|
'''
|
||||||
task, from_aio, to_trio, aio_q, cs, aio_task_complete = _run_asyncio_task(
|
chan = _run_asyncio_task(
|
||||||
target,
|
target,
|
||||||
qsize=2**8,
|
qsize=2**8,
|
||||||
provide_channels=True,
|
provide_channels=True,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
chan = LinkedTaskChannel(
|
async with chan._from_aio:
|
||||||
task, aio_q, from_aio,
|
async with translate_aio_errors(chan):
|
||||||
aio_task_complete
|
|
||||||
)
|
|
||||||
async with from_aio:
|
|
||||||
async with translate_aio_errors(from_aio, task):
|
|
||||||
# sync to a "started()"-like first delivered value from the
|
# sync to a "started()"-like first delivered value from the
|
||||||
# ``asyncio`` task.
|
# ``asyncio`` task.
|
||||||
first = await from_aio.receive()
|
first = await chan.receive()
|
||||||
|
|
||||||
# stream values upward
|
# stream values upward
|
||||||
yield first, chan
|
yield first, chan
|
||||||
|
|
Loading…
Reference in New Issue