Factor error translation into a ctx mngr
Pull the common `asyncio` -> `trio` error translation logic into a common context manager and don't expect a final result to be captured when using `open_channel_from()` since it's a manager interface and it would be clunky to try and deliver some "final result" after exit.infect_asyncio
parent
e6687bcdc4
commit
9bc94b5ccc
|
@ -92,14 +92,28 @@ def _run_asyncio_task(
|
|||
raise
|
||||
|
||||
finally:
|
||||
aio_task_complete.set()
|
||||
if result != orig and aio_err is None:
|
||||
if (
|
||||
result != orig and
|
||||
aio_err is None and
|
||||
|
||||
# in the ``open_channel_from()`` case we don't
|
||||
# relay through the "return value".
|
||||
not provide_channels
|
||||
):
|
||||
to_trio.send_nowait(result)
|
||||
|
||||
to_trio.close()
|
||||
from_aio.close()
|
||||
aio_task_complete.set()
|
||||
|
||||
# start the asyncio task we submitted from trio
|
||||
if inspect.isawaitable(coro):
|
||||
task = asyncio.create_task(
|
||||
wait_on_coro_final_result(to_trio, coro, aio_task_complete)
|
||||
wait_on_coro_final_result(
|
||||
to_trio,
|
||||
coro,
|
||||
aio_task_complete
|
||||
)
|
||||
)
|
||||
|
||||
else:
|
||||
|
@ -120,7 +134,7 @@ def _run_asyncio_task(
|
|||
cancel_scope.cancel()
|
||||
else:
|
||||
if aio_err is not None:
|
||||
log.exception(f"infected task errorred:")
|
||||
log.exception("infected task errorred:")
|
||||
from_aio._err = aio_err
|
||||
# order is opposite here
|
||||
cancel_scope.cancel()
|
||||
|
@ -131,41 +145,20 @@ def _run_asyncio_task(
|
|||
return task, from_aio, to_trio, cancel_scope, aio_task_complete
|
||||
|
||||
|
||||
async def run_task(
|
||||
func: Callable,
|
||||
*,
|
||||
@acm
|
||||
async def translate_aio_errors(
|
||||
|
||||
qsize: int = 2**10,
|
||||
**kwargs,
|
||||
from_aio: trio.MemoryReceiveChannel,
|
||||
task: asyncio.Task,
|
||||
|
||||
) -> Any:
|
||||
) -> None:
|
||||
'''
|
||||
Run an ``asyncio`` async function or generator in a task, return
|
||||
or stream the result back to ``trio``.
|
||||
Error handling context around ``asyncio`` task spawns which
|
||||
appropriately translates errors and cancels into ``trio`` land.
|
||||
|
||||
'''
|
||||
# simple async func
|
||||
try:
|
||||
task, from_aio, to_trio, cs, _ = _run_asyncio_task(
|
||||
func,
|
||||
qsize=1,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# return single value
|
||||
with cs:
|
||||
# naively expect the mem chan api to do the job
|
||||
# of handling cross-framework cancellations / errors
|
||||
return await from_aio.receive()
|
||||
|
||||
if cs.cancelled_caught:
|
||||
aio_err = from_aio._err
|
||||
|
||||
# always raise from any captured asyncio error
|
||||
if aio_err:
|
||||
raise aio_err
|
||||
|
||||
# Do we need this?
|
||||
yield
|
||||
except (
|
||||
Exception,
|
||||
CancelledError,
|
||||
|
@ -190,6 +183,41 @@ async def run_task(
|
|||
# ... do what ..
|
||||
|
||||
|
||||
async def run_task(
|
||||
func: Callable,
|
||||
*,
|
||||
|
||||
qsize: int = 2**10,
|
||||
**kwargs,
|
||||
|
||||
) -> Any:
|
||||
'''
|
||||
Run an ``asyncio`` async function or generator in a task, return
|
||||
or stream the result back to ``trio``.
|
||||
|
||||
'''
|
||||
# simple async func
|
||||
task, from_aio, to_trio, cs, _ = _run_asyncio_task(
|
||||
func,
|
||||
qsize=1,
|
||||
**kwargs,
|
||||
)
|
||||
async with translate_aio_errors(from_aio, task):
|
||||
|
||||
# return single value
|
||||
with cs:
|
||||
# naively expect the mem chan api to do the job
|
||||
# of handling cross-framework cancellations / errors
|
||||
return await from_aio.receive()
|
||||
|
||||
if cs.cancelled_caught:
|
||||
aio_err = from_aio._err
|
||||
|
||||
# always raise from any captured asyncio error
|
||||
if aio_err:
|
||||
raise aio_err
|
||||
|
||||
|
||||
# TODO: explicitly api for the streaming case where
|
||||
# we pull from the mem chan in an async generator?
|
||||
# This ends up looking more like our ``Portal.open_stream_from()``
|
||||
|
@ -203,40 +231,21 @@ async def open_channel_from(
|
|||
|
||||
) -> AsyncIterator[Any]:
|
||||
|
||||
try:
|
||||
task, from_aio, to_trio, cs, aio_task_complete = _run_asyncio_task(
|
||||
target,
|
||||
qsize=2**8,
|
||||
provide_channels=True,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
task, from_aio, to_trio, cs, aio_task_complete = _run_asyncio_task(
|
||||
target,
|
||||
qsize=2**8,
|
||||
provide_channels=True,
|
||||
**kwargs,
|
||||
)
|
||||
async with translate_aio_errors(from_aio, task):
|
||||
with cs:
|
||||
# sync to "started()" call.
|
||||
first = await from_aio.receive()
|
||||
|
||||
# stream values upward
|
||||
async with from_aio:
|
||||
yield first, from_aio
|
||||
# await aio_task_complete.wait()
|
||||
|
||||
except BaseException as err:
|
||||
|
||||
aio_err = from_aio._err
|
||||
|
||||
if aio_err is not None:
|
||||
# always raise from any captured asyncio error
|
||||
raise err from aio_err
|
||||
else:
|
||||
raise
|
||||
|
||||
finally:
|
||||
if cs.cancelled_caught:
|
||||
# always raise from any captured asyncio error
|
||||
if from_aio._err:
|
||||
raise from_aio._err
|
||||
|
||||
if not task.done():
|
||||
task.cancel()
|
||||
await aio_task_complete.wait()
|
||||
|
||||
|
||||
def run_as_asyncio_guest(
|
||||
|
@ -284,7 +293,7 @@ def run_as_asyncio_guest(
|
|||
run_sync_soon_threadsafe=loop.call_soon_threadsafe,
|
||||
done_callback=trio_done_callback,
|
||||
)
|
||||
(await trio_done_fut).unwrap()
|
||||
return (await trio_done_fut).unwrap()
|
||||
|
||||
# might as well if it's installed.
|
||||
try:
|
||||
|
|
Loading…
Reference in New Issue