Merge pull request #318 from goodboy/aio_error_propagation

Add context test that opens an inter-task-channel that errors
aio_error_propagation
goodboy 2022-07-15 12:42:19 -04:00 committed by GitHub
commit 4902e184e9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 160 additions and 20 deletions

View File

@ -13,6 +13,7 @@ import tractor
async def aio_echo_server( async def aio_echo_server(
to_trio: trio.MemorySendChannel, to_trio: trio.MemorySendChannel,
from_trio: asyncio.Queue, from_trio: asyncio.Queue,
) -> None: ) -> None:
# a first message must be sent **from** this ``asyncio`` # a first message must be sent **from** this ``asyncio``

13
nooz/318.bug.rst 100644
View File

@ -0,0 +1,13 @@
Fix a previously undetected ``trio``-``asyncio`` task lifetime linking
issue with the ``to_asyncio.open_channel_from()`` api where both sides
where not properly waiting/signalling termination and it was possible
for ``asyncio``-side errors to not propagate due to a race condition.
The implementation fix summary is:
- add state to signal the end of the ``trio`` side task to be
read by the ``asyncio`` side and always cancel any ongoing
task in such cases.
- always wait on the ``asyncio`` task termination from the ``trio``
side on error before maybe raising said error.
- always close the ``trio`` mem chan on exit to ensure the other
side can detect it and follow.

View File

@ -11,12 +11,25 @@ import importlib
import pytest import pytest
import trio import trio
import tractor import tractor
from tractor import to_asyncio from tractor import (
from tractor import RemoteActorError to_asyncio,
RemoteActorError,
)
from tractor.trionics import BroadcastReceiver from tractor.trionics import BroadcastReceiver
async def sleep_and_err(sleep_for: float = 0.1): async def sleep_and_err(
sleep_for: float = 0.1,
# just signature placeholders for compat with
# ``to_asyncio.open_channel_from()``
to_trio: Optional[trio.MemorySendChannel] = None,
from_trio: Optional[asyncio.Queue] = None,
):
if to_trio:
to_trio.send_nowait('start')
await asyncio.sleep(sleep_for) await asyncio.sleep(sleep_for)
assert 0 assert 0
@ -146,6 +159,80 @@ def test_trio_cancels_aio(arb_addr):
trio.run(main) trio.run(main)
@tractor.context
async def trio_ctx(
ctx: tractor.Context,
):
await ctx.started('start')
# this will block until the ``asyncio`` task sends a "first"
# message.
with trio.fail_after(2):
async with (
tractor.to_asyncio.open_channel_from(
sleep_and_err,
) as (first, chan),
trio.open_nursery() as n,
):
assert first == 'start'
# spawn another asyncio task for the cuck of it.
n.start_soon(
tractor.to_asyncio.run_task,
sleep_forever,
)
await trio.sleep_forever()
@pytest.mark.parametrize(
'parent_cancels', [False, True],
ids='parent_actor_cancels_child={}'.format
)
def test_context_spawns_aio_task_that_errors(
arb_addr,
parent_cancels: bool,
):
'''
Verify that spawning a task via an intertask channel ctx mngr that
errors correctly propagates the error back from the `asyncio`-side
task.
'''
async def main():
async with tractor.open_nursery() as n:
p = await n.start_actor(
'aio_daemon',
enable_modules=[__name__],
infect_asyncio=True,
# debug_mode=True,
loglevel='cancel',
)
async with p.open_context(
trio_ctx,
) as (ctx, first):
assert first == 'start'
if parent_cancels:
await p.cancel_actor()
await trio.sleep_forever()
with pytest.raises(RemoteActorError) as excinfo:
trio.run(main)
err = excinfo.value
assert isinstance(err, RemoteActorError)
if parent_cancels:
assert err.type == trio.Cancelled
else:
assert err.type == AssertionError
async def aio_cancel(): async def aio_cancel():
'''' ''''
Cancel urself boi. Cancel urself boi.
@ -385,6 +472,8 @@ async def trio_to_aio_echo_server(
print('breaking aio echo loop') print('breaking aio echo loop')
break break
print('exiting asyncio task')
async with to_asyncio.open_channel_from( async with to_asyncio.open_channel_from(
aio_echo_server, aio_echo_server,
) as (first, chan): ) as (first, chan):

View File

@ -23,7 +23,6 @@ from asyncio.exceptions import CancelledError
from contextlib import asynccontextmanager as acm from contextlib import asynccontextmanager as acm
from dataclasses import dataclass from dataclasses import dataclass
import inspect import inspect
import traceback
from typing import ( from typing import (
Any, Any,
Callable, Callable,
@ -63,6 +62,7 @@ class LinkedTaskChannel(trio.abc.Channel):
_trio_cs: trio.CancelScope _trio_cs: trio.CancelScope
_aio_task_complete: trio.Event _aio_task_complete: trio.Event
_trio_exited: bool = False
# set after ``asyncio.create_task()`` # set after ``asyncio.create_task()``
_aio_task: Optional[asyncio.Task] = None _aio_task: Optional[asyncio.Task] = None
@ -73,7 +73,13 @@ class LinkedTaskChannel(trio.abc.Channel):
await self._from_aio.aclose() await self._from_aio.aclose()
async def receive(self) -> Any: async def receive(self) -> Any:
async with translate_aio_errors(self): async with translate_aio_errors(
self,
# XXX: obviously this will deadlock if an on-going stream is
# being procesed.
# wait_on_aio_task=False,
):
# TODO: do we need this to guarantee asyncio code get's # TODO: do we need this to guarantee asyncio code get's
# cancelled in the case where the trio side somehow creates # cancelled in the case where the trio side somehow creates
@ -210,10 +216,8 @@ def _run_asyncio_task(
orig = result = id(coro) orig = result = id(coro)
try: try:
result = await coro result = await coro
except GeneratorExit:
# no need to relay error
raise
except BaseException as aio_err: except BaseException as aio_err:
log.exception('asyncio task errored')
chan._aio_err = aio_err chan._aio_err = aio_err
raise raise
@ -237,6 +241,7 @@ def _run_asyncio_task(
to_trio.close() to_trio.close()
aio_task_complete.set() aio_task_complete.set()
log.runtime(f'`asyncio` task: {task.get_name()} is complete')
# start the asyncio task we submitted from trio # start the asyncio task we submitted from trio
if not inspect.isawaitable(coro): if not inspect.isawaitable(coro):
@ -291,10 +296,12 @@ def _run_asyncio_task(
elif task_err is None: elif task_err is None:
assert aio_err assert aio_err
aio_err.with_traceback(aio_err.__traceback__) aio_err.with_traceback(aio_err.__traceback__)
msg = ''.join(traceback.format_exception(type(aio_err))) log.error('infected task errorred')
log.error(
f'infected task errorred:\n{msg}' # XXX: alway cancel the scope on error
) # in case the trio task is blocking
# on a checkpoint.
cancel_scope.cancel()
# raise any ``asyncio`` side error. # raise any ``asyncio`` side error.
raise aio_err raise aio_err
@ -307,6 +314,7 @@ def _run_asyncio_task(
async def translate_aio_errors( async def translate_aio_errors(
chan: LinkedTaskChannel, chan: LinkedTaskChannel,
wait_on_aio_task: bool = False,
) -> AsyncIterator[None]: ) -> AsyncIterator[None]:
''' '''
@ -318,6 +326,7 @@ async def translate_aio_errors(
aio_err: Optional[BaseException] = None aio_err: Optional[BaseException] = None
# TODO: make thisi a channel method?
def maybe_raise_aio_err( def maybe_raise_aio_err(
err: Optional[Exception] = None err: Optional[Exception] = None
) -> None: ) -> None:
@ -367,13 +376,30 @@ async def translate_aio_errors(
raise raise
finally: finally:
# always cancel the ``asyncio`` task if we've made it this far if (
# and it's not done. # NOTE: always cancel the ``asyncio`` task if we've made it
if not task.done() and aio_err: # this far and it's not done.
not task.done() and aio_err
# or the trio side has exited it's surrounding cancel scope
# indicating the lifetime of the ``asyncio``-side task
# should also be terminated.
or chan._trio_exited
):
log.runtime(
f'Cancelling `asyncio`-task: {task.get_name()}'
)
# assert not aio_err, 'WTF how did asyncio do this?!' # assert not aio_err, 'WTF how did asyncio do this?!'
task.cancel() task.cancel()
# if any ``asyncio`` error was caught, raise it here inline # Required to sync with the far end ``asyncio``-task to ensure
# any error is captured (via monkeypatching the
# ``channel._aio_err``) before calling ``maybe_raise_aio_err()``
# below!
if wait_on_aio_task:
await chan._aio_task_complete.wait()
# NOTE: if any ``asyncio`` error was caught, raise it here inline
# here in the ``trio`` task # here in the ``trio`` task
maybe_raise_aio_err() maybe_raise_aio_err()
@ -398,7 +424,10 @@ async def run_task(
**kwargs, **kwargs,
) )
with chan._from_aio: with chan._from_aio:
async with translate_aio_errors(chan): async with translate_aio_errors(
chan,
wait_on_aio_task=True,
):
# return single value that is the output from the # return single value that is the output from the
# ``asyncio`` function-as-task. Expect the mem chan api to # ``asyncio`` function-as-task. Expect the mem chan api to
# do the job of handling cross-framework cancellations # do the job of handling cross-framework cancellations
@ -426,13 +455,21 @@ async def open_channel_from(
**kwargs, **kwargs,
) )
async with chan._from_aio: async with chan._from_aio:
async with translate_aio_errors(chan): async with translate_aio_errors(
chan,
wait_on_aio_task=True,
):
# sync to a "started()"-like first delivered value from the # sync to a "started()"-like first delivered value from the
# ``asyncio`` task. # ``asyncio`` task.
first = await chan.receive() first = await chan.receive()
# deliver stream handle upward # deliver stream handle upward
try:
with chan._trio_cs:
yield first, chan yield first, chan
finally:
chan._trio_exited = True
chan._to_trio.close()
def run_as_asyncio_guest( def run_as_asyncio_guest(
@ -482,7 +519,7 @@ def run_as_asyncio_guest(
main_outcome.unwrap() main_outcome.unwrap()
else: else:
trio_done_fut.set_result(main_outcome) trio_done_fut.set_result(main_outcome)
print(f"trio_main finished: {main_outcome!r}") log.runtime(f"trio_main finished: {main_outcome!r}")
# start the infection: run trio on the asyncio loop in "guest mode" # start the infection: run trio on the asyncio loop in "guest mode"
log.info(f"Infecting asyncio process with {trio_main}") log.info(f"Infecting asyncio process with {trio_main}")