Merge pull request #318 from goodboy/aio_error_propagation
Add context test that opens an inter-task-channel that errors
commit 4902e184e9

@@ -13,6 +13,7 @@ import tractor
 async def aio_echo_server(
     to_trio: trio.MemorySendChannel,
     from_trio: asyncio.Queue,
+
 ) -> None:

     # a first message must be sent **from** this ``asyncio``

@@ -0,0 +1,13 @@
+Fix a previously undetected ``trio``-``asyncio`` task lifetime linking
+issue with the ``to_asyncio.open_channel_from()`` api where both sides
+were not properly waiting/signalling termination and it was possible
+for ``asyncio``-side errors to not propagate due to a race condition.
+
+The implementation fix summary is:
+- add state to signal the end of the ``trio`` side task to be
+  read by the ``asyncio`` side and always cancel any ongoing
+  task in such cases.
+- always wait on the ``asyncio`` task termination from the ``trio``
+  side on error before maybe raising said error.
+- always close the ``trio`` mem chan on exit to ensure the other
+  side can detect it and follow.
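
The lifetime-linking pattern those bullets describe can be sketched with a plain ``asyncio``-only example. The names here (``LinkedTask``, ``peer()``, ``consumer()``) are illustrative stand-ins and not tractor's actual API; the real fix lives in ``tractor/to_asyncio.py`` as shown in the hunks below.

    import asyncio
    from typing import Optional


    class LinkedTask:
        '''
        Hypothetical stand-in for a linked-task channel: the peer task
        stashes its error here and signals termination via an event.

        '''
        def __init__(self) -> None:
            self.peer_err: Optional[BaseException] = None
            self.peer_done = asyncio.Event()
            self.peer_task: Optional[asyncio.Task] = None


    async def peer(link: LinkedTask) -> None:
        try:
            await asyncio.sleep(0.1)
            raise RuntimeError('simulated peer-side failure')
        except BaseException as err:
            # stash the error *before* signalling completion so the
            # other side can observe it once the event fires.
            link.peer_err = err
            raise
        finally:
            link.peer_done.set()


    async def consumer() -> None:
        link = LinkedTask()
        link.peer_task = asyncio.create_task(peer(link))
        try:
            await asyncio.sleep(1)
        finally:
            # the fix pattern: on exit always cancel a still-running peer
            # task, then *wait* for its termination so a late error can't
            # be lost to a race, and only then (maybe) re-raise it locally.
            if not link.peer_task.done():
                link.peer_task.cancel()
            await link.peer_done.wait()

            # retrieve the task's outcome without raising so the loop
            # doesn't warn about an unretrieved exception.
            await asyncio.gather(link.peer_task, return_exceptions=True)

            err = link.peer_err
            if err is not None and not isinstance(err, asyncio.CancelledError):
                raise err


    if __name__ == '__main__':
        try:
            asyncio.run(consumer())
        except RuntimeError as err:
            print(f'propagated: {err}')

The essential ordering mirrors the changelog bullets: flag the local side's exit, cancel any still-running peer task, wait on its termination event, and only then re-raise whatever error it stashed.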

@@ -11,12 +11,25 @@ import importlib
 import pytest
 import trio
 import tractor
-from tractor import to_asyncio
-from tractor import RemoteActorError
+from tractor import (
+    to_asyncio,
+    RemoteActorError,
+)
 from tractor.trionics import BroadcastReceiver


-async def sleep_and_err(sleep_for: float = 0.1):
+async def sleep_and_err(
+    sleep_for: float = 0.1,
+
+    # just signature placeholders for compat with
+    # ``to_asyncio.open_channel_from()``
+    to_trio: Optional[trio.MemorySendChannel] = None,
+    from_trio: Optional[asyncio.Queue] = None,
+
+):
+    if to_trio:
+        to_trio.send_nowait('start')
+
     await asyncio.sleep(sleep_for)
     assert 0

@@ -146,6 +159,80 @@ def test_trio_cancels_aio(arb_addr):
     trio.run(main)


+@tractor.context
+async def trio_ctx(
+    ctx: tractor.Context,
+):
+
+    await ctx.started('start')
+
+    # this will block until the ``asyncio`` task sends a "first"
+    # message.
+    with trio.fail_after(2):
+        async with (
+            tractor.to_asyncio.open_channel_from(
+                sleep_and_err,
+            ) as (first, chan),
+
+            trio.open_nursery() as n,
+        ):
+
+            assert first == 'start'
+
+            # spawn another asyncio task for the cuck of it.
+            n.start_soon(
+                tractor.to_asyncio.run_task,
+                sleep_forever,
+            )
+            await trio.sleep_forever()
+
+
+@pytest.mark.parametrize(
+    'parent_cancels', [False, True],
+    ids='parent_actor_cancels_child={}'.format
+)
+def test_context_spawns_aio_task_that_errors(
+    arb_addr,
+    parent_cancels: bool,
+):
+    '''
+    Verify that spawning a task via an intertask channel ctx mngr that
+    errors correctly propagates the error back from the `asyncio`-side
+    task.
+
+    '''
+    async def main():
+
+        async with tractor.open_nursery() as n:
+            p = await n.start_actor(
+                'aio_daemon',
+                enable_modules=[__name__],
+                infect_asyncio=True,
+                # debug_mode=True,
+                loglevel='cancel',
+            )
+            async with p.open_context(
+                trio_ctx,
+            ) as (ctx, first):
+
+                assert first == 'start'
+
+                if parent_cancels:
+                    await p.cancel_actor()
+
+                await trio.sleep_forever()
+
+    with pytest.raises(RemoteActorError) as excinfo:
+        trio.run(main)
+
+    err = excinfo.value
+    assert isinstance(err, RemoteActorError)
+    if parent_cancels:
+        assert err.type == trio.Cancelled
+    else:
+        assert err.type == AssertionError
+
+
 async def aio_cancel():
     '''
     Cancel urself boi.

@@ -385,6 +472,8 @@ async def trio_to_aio_echo_server(
                 print('breaking aio echo loop')
                 break

+        print('exiting asyncio task')
+
     async with to_asyncio.open_channel_from(
         aio_echo_server,
     ) as (first, chan):

@@ -23,7 +23,6 @@ from asyncio.exceptions import CancelledError
 from contextlib import asynccontextmanager as acm
 from dataclasses import dataclass
 import inspect
-import traceback
 from typing import (
     Any,
     Callable,

@@ -63,6 +62,7 @@ class LinkedTaskChannel(trio.abc.Channel):

     _trio_cs: trio.CancelScope
     _aio_task_complete: trio.Event
+    _trio_exited: bool = False

     # set after ``asyncio.create_task()``
     _aio_task: Optional[asyncio.Task] = None

@@ -73,7 +73,13 @@ class LinkedTaskChannel(trio.abc.Channel):
         await self._from_aio.aclose()

     async def receive(self) -> Any:
-        async with translate_aio_errors(self):
+        async with translate_aio_errors(
+            self,
+
+            # XXX: obviously this will deadlock if an on-going stream is
+            # being processed.
+            # wait_on_aio_task=False,
+        ):

             # TODO: do we need this to guarantee asyncio code gets
             # cancelled in the case where the trio side somehow creates

@@ -210,10 +216,8 @@ def _run_asyncio_task(
         orig = result = id(coro)
         try:
             result = await coro
-        except GeneratorExit:
-            # no need to relay error
-            raise
         except BaseException as aio_err:
+            log.exception('asyncio task errored')
             chan._aio_err = aio_err
             raise


@@ -237,6 +241,7 @@ def _run_asyncio_task(
                 to_trio.close()

             aio_task_complete.set()
+            log.runtime(f'`asyncio` task: {task.get_name()} is complete')

     # start the asyncio task we submitted from trio
     if not inspect.isawaitable(coro):

@@ -291,10 +296,12 @@ def _run_asyncio_task(
         elif task_err is None:
             assert aio_err
             aio_err.with_traceback(aio_err.__traceback__)
-            msg = ''.join(traceback.format_exception(type(aio_err)))
-            log.error(
-                f'infected task errored:\n{msg}'
-            )
+            log.error('infected task errored')
+
+            # XXX: always cancel the scope on error
+            # in case the trio task is blocking
+            # on a checkpoint.
+            cancel_scope.cancel()

         # raise any ``asyncio`` side error.
         raise aio_err

@@ -307,6 +314,7 @@ def _run_asyncio_task(
 async def translate_aio_errors(

     chan: LinkedTaskChannel,
+    wait_on_aio_task: bool = False,

 ) -> AsyncIterator[None]:
     '''

@@ -318,6 +326,7 @@ async def translate_aio_errors(

     aio_err: Optional[BaseException] = None

+    # TODO: make this a channel method?
     def maybe_raise_aio_err(
         err: Optional[Exception] = None
     ) -> None:

@@ -367,13 +376,30 @@ async def translate_aio_errors(
         raise

     finally:
-        # always cancel the ``asyncio`` task if we've made it this far
-        # and it's not done.
-        if not task.done() and aio_err:
+        if (
+            # NOTE: always cancel the ``asyncio`` task if we've made it
+            # this far and it's not done.
+            not task.done() and aio_err
+
+            # or the trio side has exited its surrounding cancel scope
+            # indicating the lifetime of the ``asyncio``-side task
+            # should also be terminated.
+            or chan._trio_exited
+        ):
+            log.runtime(
+                f'Cancelling `asyncio`-task: {task.get_name()}'
+            )
             # assert not aio_err, 'WTF how did asyncio do this?!'
             task.cancel()

-        # if any ``asyncio`` error was caught, raise it here inline
+        # Required to sync with the far end ``asyncio``-task to ensure
+        # any error is captured (via monkeypatching the
+        # ``channel._aio_err``) before calling ``maybe_raise_aio_err()``
+        # below!
+        if wait_on_aio_task:
+            await chan._aio_task_complete.wait()
+
+        # NOTE: if any ``asyncio`` error was caught, raise it here inline
         # here in the ``trio`` task
         maybe_raise_aio_err()

@@ -398,7 +424,10 @@ async def run_task(
         **kwargs,
     )
     with chan._from_aio:
-        async with translate_aio_errors(chan):
+        async with translate_aio_errors(
+            chan,
+            wait_on_aio_task=True,
+        ):
             # return single value that is the output from the
             # ``asyncio`` function-as-task. Expect the mem chan api to
             # do the job of handling cross-framework cancellations

@@ -426,13 +455,21 @@ async def open_channel_from(
         **kwargs,
     )
     async with chan._from_aio:
-        async with translate_aio_errors(chan):
+        async with translate_aio_errors(
+            chan,
+            wait_on_aio_task=True,
+        ):
             # sync to a "started()"-like first delivered value from the
             # ``asyncio`` task.
             first = await chan.receive()

             # deliver stream handle upward
-            yield first, chan
+            try:
+                with chan._trio_cs:
+                    yield first, chan
+            finally:
+                chan._trio_exited = True
+                chan._to_trio.close()


 def run_as_asyncio_guest(

@@ -482,7 +519,7 @@ def run_as_asyncio_guest(
                 main_outcome.unwrap()
             else:
                 trio_done_fut.set_result(main_outcome)
-            print(f"trio_main finished: {main_outcome!r}")
+            log.runtime(f"trio_main finished: {main_outcome!r}")

         # start the infection: run trio on the asyncio loop in "guest mode"
         log.info(f"Infecting asyncio process with {trio_main}")
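
For reference, the behaviour the new test verifies can also be exercised as a condensed, non-parametrized script. This is a sketch that assumes it lives in the same test module as the ``trio_ctx()`` and ``sleep_and_err()`` helpers added above (the child actor must be able to import them via ``enable_modules``); the actor name is purely illustrative.

    import trio
    import tractor
    from tractor import RemoteActorError


    async def main() -> None:
        async with tractor.open_nursery() as n:
            portal = await n.start_actor(
                'aio_daemon',               # illustrative actor name
                enable_modules=[__name__],  # must expose ``trio_ctx()``
                infect_asyncio=True,        # run the child with an asyncio loop
            )
            async with portal.open_context(trio_ctx) as (ctx, first):
                assert first == 'start'
                # the ``assert 0`` raised in the ``asyncio``-side
                # ``sleep_and_err()`` task should tear down the context
                # and propagate back here as a remote error.
                await trio.sleep_forever()


    if __name__ == '__main__':
        try:
            trio.run(main)
        except RemoteActorError as err:
            # the boxed remote error wraps the original ``AssertionError``
            assert err.type == AssertionError

With the lifetime-linking fix in place the ``asyncio``-side failure reliably surfaces on the ``trio`` side; previously the race described in the changelog could swallow it.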