Factor error translation into a ctx mngr

Pull the common `asyncio` -> `trio` error translation logic into
a common context manager and don't expect a final result to be captured
when using `open_channel_from()` since it's a manager interface and it
would be clunky to try and deliver some "final result" after exit.
infect_asyncio
Tyler Goodlet 2021-11-20 12:43:54 -05:00
parent e6687bcdc4
commit 9bc94b5ccc
1 changed files with 71 additions and 62 deletions

View File

@ -92,14 +92,28 @@ def _run_asyncio_task(
raise raise
finally: finally:
aio_task_complete.set() if (
if result != orig and aio_err is None: result != orig and
aio_err is None and
# in the ``open_channel_from()`` case we don't
# relay through the "return value".
not provide_channels
):
to_trio.send_nowait(result) to_trio.send_nowait(result)
to_trio.close()
from_aio.close()
aio_task_complete.set()
# start the asyncio task we submitted from trio # start the asyncio task we submitted from trio
if inspect.isawaitable(coro): if inspect.isawaitable(coro):
task = asyncio.create_task( task = asyncio.create_task(
wait_on_coro_final_result(to_trio, coro, aio_task_complete) wait_on_coro_final_result(
to_trio,
coro,
aio_task_complete
)
) )
else: else:
@ -120,7 +134,7 @@ def _run_asyncio_task(
cancel_scope.cancel() cancel_scope.cancel()
else: else:
if aio_err is not None: if aio_err is not None:
log.exception(f"infected task errorred:") log.exception("infected task errorred:")
from_aio._err = aio_err from_aio._err = aio_err
# order is opposite here # order is opposite here
cancel_scope.cancel() cancel_scope.cancel()
@ -131,41 +145,20 @@ def _run_asyncio_task(
return task, from_aio, to_trio, cancel_scope, aio_task_complete return task, from_aio, to_trio, cancel_scope, aio_task_complete
async def run_task( @acm
func: Callable, async def translate_aio_errors(
*,
qsize: int = 2**10, from_aio: trio.MemoryReceiveChannel,
**kwargs, task: asyncio.Task,
) -> Any: ) -> None:
''' '''
Run an ``asyncio`` async function or generator in a task, return Error handling context around ``asyncio`` task spawns which
or stream the result back to ``trio``. appropriately translates errors and cancels into ``trio`` land.
''' '''
# simple async func
try: try:
task, from_aio, to_trio, cs, _ = _run_asyncio_task( yield
func,
qsize=1,
**kwargs,
)
# return single value
with cs:
# naively expect the mem chan api to do the job
# of handling cross-framework cancellations / errors
return await from_aio.receive()
if cs.cancelled_caught:
aio_err = from_aio._err
# always raise from any captured asyncio error
if aio_err:
raise aio_err
# Do we need this?
except ( except (
Exception, Exception,
CancelledError, CancelledError,
@ -190,6 +183,41 @@ async def run_task(
# ... do what .. # ... do what ..
async def run_task(
func: Callable,
*,
qsize: int = 2**10,
**kwargs,
) -> Any:
'''
Run an ``asyncio`` async function or generator in a task, return
or stream the result back to ``trio``.
'''
# simple async func
task, from_aio, to_trio, cs, _ = _run_asyncio_task(
func,
qsize=1,
**kwargs,
)
async with translate_aio_errors(from_aio, task):
# return single value
with cs:
# naively expect the mem chan api to do the job
# of handling cross-framework cancellations / errors
return await from_aio.receive()
if cs.cancelled_caught:
aio_err = from_aio._err
# always raise from any captured asyncio error
if aio_err:
raise aio_err
# TODO: explicitly api for the streaming case where # TODO: explicitly api for the streaming case where
# we pull from the mem chan in an async generator? # we pull from the mem chan in an async generator?
# This ends up looking more like our ``Portal.open_stream_from()`` # This ends up looking more like our ``Portal.open_stream_from()``
@ -203,40 +231,21 @@ async def open_channel_from(
) -> AsyncIterator[Any]: ) -> AsyncIterator[Any]:
try:
task, from_aio, to_trio, cs, aio_task_complete = _run_asyncio_task( task, from_aio, to_trio, cs, aio_task_complete = _run_asyncio_task(
target, target,
qsize=2**8, qsize=2**8,
provide_channels=True, provide_channels=True,
**kwargs, **kwargs,
) )
async with translate_aio_errors(from_aio, task):
with cs: with cs:
# sync to "started()" call. # sync to "started()" call.
first = await from_aio.receive() first = await from_aio.receive()
# stream values upward # stream values upward
async with from_aio: async with from_aio:
yield first, from_aio yield first, from_aio
# await aio_task_complete.wait() await aio_task_complete.wait()
except BaseException as err:
aio_err = from_aio._err
if aio_err is not None:
# always raise from any captured asyncio error
raise err from aio_err
else:
raise
finally:
if cs.cancelled_caught:
# always raise from any captured asyncio error
if from_aio._err:
raise from_aio._err
if not task.done():
task.cancel()
def run_as_asyncio_guest( def run_as_asyncio_guest(
@ -284,7 +293,7 @@ def run_as_asyncio_guest(
run_sync_soon_threadsafe=loop.call_soon_threadsafe, run_sync_soon_threadsafe=loop.call_soon_threadsafe,
done_callback=trio_done_callback, done_callback=trio_done_callback,
) )
(await trio_done_fut).unwrap() return (await trio_done_fut).unwrap()
# might as well if it's installed. # might as well if it's installed.
try: try: