Raise from asyncio error; fixes mypy

msgspec_infect_asyncio
Tyler Goodlet 2020-10-14 12:51:41 -04:00
parent 2adb59f40f
commit 68e5c2a95f
1 changed files with 29 additions and 21 deletions

View File

@ -6,7 +6,7 @@ import inspect
from typing import ( from typing import (
Any, Any,
Callable, Callable,
AsyncGenerator, AsyncIterator,
Awaitable, Awaitable,
Union, Union,
) )
@ -33,7 +33,7 @@ async def run_coro(
async def consume_asyncgen( async def consume_asyncgen(
to_trio: trio.MemorySendChannel, to_trio: trio.MemorySendChannel,
coro: AsyncGenerator, coro: AsyncIterator,
) -> None: ) -> None:
"""Stream async generator results back to ``trio``. """Stream async generator results back to ``trio``.
@ -50,7 +50,7 @@ async def run_task(
qsize: int = 2**10, qsize: int = 2**10,
_treat_as_stream: bool = False, _treat_as_stream: bool = False,
**kwargs, **kwargs,
) -> Union[AsyncGenerator, Any]: ) -> Any:
"""Run an ``asyncio`` async function or generator in a task, return """Run an ``asyncio`` async function or generator in a task, return
or stream the result back to ``trio``. or stream the result back to ``trio``.
""" """
@ -79,50 +79,58 @@ async def run_task(
cancel_scope = trio.CancelScope() cancel_scope = trio.CancelScope()
# start the asyncio task we submitted from trio # start the asyncio task we submitted from trio
# TODO: try out ``anyio`` asyncio based tg here
if inspect.isawaitable(coro): if inspect.isawaitable(coro):
task = asyncio.create_task(run_coro(to_trio, coro)) task = asyncio.create_task(run_coro(to_trio, coro))
elif inspect.isasyncgen(coro): elif inspect.isasyncgen(coro):
task = asyncio.create_task(consume_asyncgen(to_trio, coro)) task = asyncio.create_task(consume_asyncgen(to_trio, coro))
else: else:
raise TypeError(f"No support for {coro}") raise TypeError(f"No support for invoking {coro}")
err = None aio_err = None
def cancel_trio(task): def cancel_trio(task):
"""Cancel the calling ``trio`` task on error. """Cancel the calling ``trio`` task on error.
""" """
nonlocal err nonlocal err
err = task.exception() aio_err = task.exception()
if err: if aio_err:
log.exception(f"asyncio task errorred:\n{err}") log.exception(f"asyncio task errorred:\n{aio_err}")
cancel_scope.cancel() cancel_scope.cancel()
task.add_done_callback(cancel_trio) task.add_done_callback(cancel_trio)
# asycn gen # async iterator
if inspect.isasyncgen(coro) or _treat_as_stream: if inspect.isasyncgen(coro) or _treat_as_stream:
async def stream_results(): async def stream_results():
with cancel_scope: try:
# stream values upward with cancel_scope:
async with from_aio: # stream values upward
async for item in from_aio: async with from_aio:
yield item async for item in from_aio:
if cancel_scope.cancelled_caught and err: yield item
raise err except BaseException as err:
if aio_err is not None:
# always raise from any captured asyncio error
raise err from aio_err
else:
raise
return stream_results() return stream_results()
# simple async func # simple async func
elif inspect.iscoroutine(coro): try:
with cancel_scope: with cancel_scope:
# return single value # return single value
return await from_aio.receive() return await from_aio.receive()
if cancel_scope.cancelled_caught and err:
raise err
# Do we need this?
except BaseException as err:
if aio_err is not None:
# always raise from any captured asyncio error
raise err from aio_err
else:
raise
def run_as_asyncio_guest( def run_as_asyncio_guest(