Factor error translation into a ctx mngr
Pull the common `asyncio` -> `trio` error translation logic into a common context manager and don't expect a final result to be captured when using `open_channel_from()` since it's a manager interface and it would be clunky to try and deliver some "final result" after exit.infect_asyncio
parent
e6687bcdc4
commit
9bc94b5ccc
|
@ -92,14 +92,28 @@ def _run_asyncio_task(
|
||||||
raise
|
raise
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
aio_task_complete.set()
|
if (
|
||||||
if result != orig and aio_err is None:
|
result != orig and
|
||||||
|
aio_err is None and
|
||||||
|
|
||||||
|
# in the ``open_channel_from()`` case we don't
|
||||||
|
# relay through the "return value".
|
||||||
|
not provide_channels
|
||||||
|
):
|
||||||
to_trio.send_nowait(result)
|
to_trio.send_nowait(result)
|
||||||
|
|
||||||
|
to_trio.close()
|
||||||
|
from_aio.close()
|
||||||
|
aio_task_complete.set()
|
||||||
|
|
||||||
# start the asyncio task we submitted from trio
|
# start the asyncio task we submitted from trio
|
||||||
if inspect.isawaitable(coro):
|
if inspect.isawaitable(coro):
|
||||||
task = asyncio.create_task(
|
task = asyncio.create_task(
|
||||||
wait_on_coro_final_result(to_trio, coro, aio_task_complete)
|
wait_on_coro_final_result(
|
||||||
|
to_trio,
|
||||||
|
coro,
|
||||||
|
aio_task_complete
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
|
@ -120,7 +134,7 @@ def _run_asyncio_task(
|
||||||
cancel_scope.cancel()
|
cancel_scope.cancel()
|
||||||
else:
|
else:
|
||||||
if aio_err is not None:
|
if aio_err is not None:
|
||||||
log.exception(f"infected task errorred:")
|
log.exception("infected task errorred:")
|
||||||
from_aio._err = aio_err
|
from_aio._err = aio_err
|
||||||
# order is opposite here
|
# order is opposite here
|
||||||
cancel_scope.cancel()
|
cancel_scope.cancel()
|
||||||
|
@ -131,41 +145,20 @@ def _run_asyncio_task(
|
||||||
return task, from_aio, to_trio, cancel_scope, aio_task_complete
|
return task, from_aio, to_trio, cancel_scope, aio_task_complete
|
||||||
|
|
||||||
|
|
||||||
async def run_task(
|
@acm
|
||||||
func: Callable,
|
async def translate_aio_errors(
|
||||||
*,
|
|
||||||
|
|
||||||
qsize: int = 2**10,
|
from_aio: trio.MemoryReceiveChannel,
|
||||||
**kwargs,
|
task: asyncio.Task,
|
||||||
|
|
||||||
) -> Any:
|
) -> None:
|
||||||
'''
|
'''
|
||||||
Run an ``asyncio`` async function or generator in a task, return
|
Error handling context around ``asyncio`` task spawns which
|
||||||
or stream the result back to ``trio``.
|
appropriately translates errors and cancels into ``trio`` land.
|
||||||
|
|
||||||
'''
|
'''
|
||||||
# simple async func
|
|
||||||
try:
|
try:
|
||||||
task, from_aio, to_trio, cs, _ = _run_asyncio_task(
|
yield
|
||||||
func,
|
|
||||||
qsize=1,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
# return single value
|
|
||||||
with cs:
|
|
||||||
# naively expect the mem chan api to do the job
|
|
||||||
# of handling cross-framework cancellations / errors
|
|
||||||
return await from_aio.receive()
|
|
||||||
|
|
||||||
if cs.cancelled_caught:
|
|
||||||
aio_err = from_aio._err
|
|
||||||
|
|
||||||
# always raise from any captured asyncio error
|
|
||||||
if aio_err:
|
|
||||||
raise aio_err
|
|
||||||
|
|
||||||
# Do we need this?
|
|
||||||
except (
|
except (
|
||||||
Exception,
|
Exception,
|
||||||
CancelledError,
|
CancelledError,
|
||||||
|
@ -190,6 +183,41 @@ async def run_task(
|
||||||
# ... do what ..
|
# ... do what ..
|
||||||
|
|
||||||
|
|
||||||
|
async def run_task(
|
||||||
|
func: Callable,
|
||||||
|
*,
|
||||||
|
|
||||||
|
qsize: int = 2**10,
|
||||||
|
**kwargs,
|
||||||
|
|
||||||
|
) -> Any:
|
||||||
|
'''
|
||||||
|
Run an ``asyncio`` async function or generator in a task, return
|
||||||
|
or stream the result back to ``trio``.
|
||||||
|
|
||||||
|
'''
|
||||||
|
# simple async func
|
||||||
|
task, from_aio, to_trio, cs, _ = _run_asyncio_task(
|
||||||
|
func,
|
||||||
|
qsize=1,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
async with translate_aio_errors(from_aio, task):
|
||||||
|
|
||||||
|
# return single value
|
||||||
|
with cs:
|
||||||
|
# naively expect the mem chan api to do the job
|
||||||
|
# of handling cross-framework cancellations / errors
|
||||||
|
return await from_aio.receive()
|
||||||
|
|
||||||
|
if cs.cancelled_caught:
|
||||||
|
aio_err = from_aio._err
|
||||||
|
|
||||||
|
# always raise from any captured asyncio error
|
||||||
|
if aio_err:
|
||||||
|
raise aio_err
|
||||||
|
|
||||||
|
|
||||||
# TODO: explicitly api for the streaming case where
|
# TODO: explicitly api for the streaming case where
|
||||||
# we pull from the mem chan in an async generator?
|
# we pull from the mem chan in an async generator?
|
||||||
# This ends up looking more like our ``Portal.open_stream_from()``
|
# This ends up looking more like our ``Portal.open_stream_from()``
|
||||||
|
@ -203,40 +231,21 @@ async def open_channel_from(
|
||||||
|
|
||||||
) -> AsyncIterator[Any]:
|
) -> AsyncIterator[Any]:
|
||||||
|
|
||||||
try:
|
|
||||||
task, from_aio, to_trio, cs, aio_task_complete = _run_asyncio_task(
|
task, from_aio, to_trio, cs, aio_task_complete = _run_asyncio_task(
|
||||||
target,
|
target,
|
||||||
qsize=2**8,
|
qsize=2**8,
|
||||||
provide_channels=True,
|
provide_channels=True,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
async with translate_aio_errors(from_aio, task):
|
||||||
with cs:
|
with cs:
|
||||||
# sync to "started()" call.
|
# sync to "started()" call.
|
||||||
first = await from_aio.receive()
|
first = await from_aio.receive()
|
||||||
|
|
||||||
# stream values upward
|
# stream values upward
|
||||||
async with from_aio:
|
async with from_aio:
|
||||||
yield first, from_aio
|
yield first, from_aio
|
||||||
# await aio_task_complete.wait()
|
await aio_task_complete.wait()
|
||||||
|
|
||||||
except BaseException as err:
|
|
||||||
|
|
||||||
aio_err = from_aio._err
|
|
||||||
|
|
||||||
if aio_err is not None:
|
|
||||||
# always raise from any captured asyncio error
|
|
||||||
raise err from aio_err
|
|
||||||
else:
|
|
||||||
raise
|
|
||||||
|
|
||||||
finally:
|
|
||||||
if cs.cancelled_caught:
|
|
||||||
# always raise from any captured asyncio error
|
|
||||||
if from_aio._err:
|
|
||||||
raise from_aio._err
|
|
||||||
|
|
||||||
if not task.done():
|
|
||||||
task.cancel()
|
|
||||||
|
|
||||||
|
|
||||||
def run_as_asyncio_guest(
|
def run_as_asyncio_guest(
|
||||||
|
@ -284,7 +293,7 @@ def run_as_asyncio_guest(
|
||||||
run_sync_soon_threadsafe=loop.call_soon_threadsafe,
|
run_sync_soon_threadsafe=loop.call_soon_threadsafe,
|
||||||
done_callback=trio_done_callback,
|
done_callback=trio_done_callback,
|
||||||
)
|
)
|
||||||
(await trio_done_fut).unwrap()
|
return (await trio_done_fut).unwrap()
|
||||||
|
|
||||||
# might as well if it's installed.
|
# might as well if it's installed.
|
||||||
try:
|
try:
|
||||||
|
|
Loading…
Reference in New Issue