Factor error translation into a ctx mngr

Pull the common `asyncio` -> `trio` error translation logic into
a common context manager and don't expect a final result to be captured
when using `open_channel_from()` since it's a manager interface and it
would be clunky to try and deliver some "final result" after exit.
infect_asyncio
Tyler Goodlet 2021-11-20 12:43:54 -05:00
parent e6687bcdc4
commit 9bc94b5ccc
1 changed files with 71 additions and 62 deletions

View File

@ -92,14 +92,28 @@ def _run_asyncio_task(
raise
finally:
aio_task_complete.set()
if result != orig and aio_err is None:
if (
result != orig and
aio_err is None and
# in the ``open_channel_from()`` case we don't
# relay through the "return value".
not provide_channels
):
to_trio.send_nowait(result)
to_trio.close()
from_aio.close()
aio_task_complete.set()
# start the asyncio task we submitted from trio
if inspect.isawaitable(coro):
task = asyncio.create_task(
wait_on_coro_final_result(to_trio, coro, aio_task_complete)
wait_on_coro_final_result(
to_trio,
coro,
aio_task_complete
)
)
else:
@ -120,7 +134,7 @@ def _run_asyncio_task(
cancel_scope.cancel()
else:
if aio_err is not None:
log.exception(f"infected task errorred:")
log.exception("infected task errorred:")
from_aio._err = aio_err
# order is opposite here
cancel_scope.cancel()
@ -131,41 +145,20 @@ def _run_asyncio_task(
return task, from_aio, to_trio, cancel_scope, aio_task_complete
async def run_task(
func: Callable,
*,
@acm
async def translate_aio_errors(
qsize: int = 2**10,
**kwargs,
from_aio: trio.MemoryReceiveChannel,
task: asyncio.Task,
) -> Any:
) -> None:
'''
Run an ``asyncio`` async function or generator in a task, return
or stream the result back to ``trio``.
Error handling context around ``asyncio`` task spawns which
appropriately translates errors and cancels into ``trio`` land.
'''
# simple async func
try:
task, from_aio, to_trio, cs, _ = _run_asyncio_task(
func,
qsize=1,
**kwargs,
)
# return single value
with cs:
# naively expect the mem chan api to do the job
# of handling cross-framework cancellations / errors
return await from_aio.receive()
if cs.cancelled_caught:
aio_err = from_aio._err
# always raise from any captured asyncio error
if aio_err:
raise aio_err
# Do we need this?
yield
except (
Exception,
CancelledError,
@ -190,6 +183,41 @@ async def run_task(
# ... do what ..
async def run_task(
func: Callable,
*,
qsize: int = 2**10,
**kwargs,
) -> Any:
'''
Run an ``asyncio`` async function or generator in a task, return
or stream the result back to ``trio``.
'''
# simple async func
task, from_aio, to_trio, cs, _ = _run_asyncio_task(
func,
qsize=1,
**kwargs,
)
async with translate_aio_errors(from_aio, task):
# return single value
with cs:
# naively expect the mem chan api to do the job
# of handling cross-framework cancellations / errors
return await from_aio.receive()
if cs.cancelled_caught:
aio_err = from_aio._err
# always raise from any captured asyncio error
if aio_err:
raise aio_err
# TODO: explicitly api for the streaming case where
# we pull from the mem chan in an async generator?
# This ends up looking more like our ``Portal.open_stream_from()``
@ -203,40 +231,21 @@ async def open_channel_from(
) -> AsyncIterator[Any]:
try:
task, from_aio, to_trio, cs, aio_task_complete = _run_asyncio_task(
target,
qsize=2**8,
provide_channels=True,
**kwargs,
)
task, from_aio, to_trio, cs, aio_task_complete = _run_asyncio_task(
target,
qsize=2**8,
provide_channels=True,
**kwargs,
)
async with translate_aio_errors(from_aio, task):
with cs:
# sync to "started()" call.
first = await from_aio.receive()
# stream values upward
async with from_aio:
yield first, from_aio
# await aio_task_complete.wait()
except BaseException as err:
aio_err = from_aio._err
if aio_err is not None:
# always raise from any captured asyncio error
raise err from aio_err
else:
raise
finally:
if cs.cancelled_caught:
# always raise from any captured asyncio error
if from_aio._err:
raise from_aio._err
if not task.done():
task.cancel()
await aio_task_complete.wait()
def run_as_asyncio_guest(
@ -284,7 +293,7 @@ def run_as_asyncio_guest(
run_sync_soon_threadsafe=loop.call_soon_threadsafe,
done_callback=trio_done_callback,
)
(await trio_done_fut).unwrap()
return (await trio_done_fut).unwrap()
# might as well if it's installed.
try: