Raise from asyncio error; fixes mypy

infect_asyncio
Tyler Goodlet 2020-10-14 12:51:41 -04:00
parent 2cf87146a3
commit 80f47dece2
1 changed files with 29 additions and 21 deletions

View File

@ -6,7 +6,7 @@ import inspect
from typing import (
Any,
Callable,
AsyncGenerator,
AsyncIterator,
Awaitable,
Union,
)
@ -33,7 +33,7 @@ async def run_coro(
async def consume_asyncgen(
to_trio: trio.MemorySendChannel,
coro: AsyncGenerator,
coro: AsyncIterator,
) -> None:
"""Stream async generator results back to ``trio``.
@ -50,7 +50,7 @@ async def run_task(
qsize: int = 2**10,
_treat_as_stream: bool = False,
**kwargs,
) -> Union[AsyncGenerator, Any]:
) -> Any:
"""Run an ``asyncio`` async function or generator in a task, return
or stream the result back to ``trio``.
"""
@ -79,50 +79,58 @@ async def run_task(
cancel_scope = trio.CancelScope()
# start the asyncio task we submitted from trio
# TODO: try out ``anyio`` asyncio based tg here
if inspect.isawaitable(coro):
task = asyncio.create_task(run_coro(to_trio, coro))
elif inspect.isasyncgen(coro):
task = asyncio.create_task(consume_asyncgen(to_trio, coro))
else:
raise TypeError(f"No support for {coro}")
raise TypeError(f"No support for invoking {coro}")
err = None
aio_err = None
def cancel_trio(task):
"""Cancel the calling ``trio`` task on error.
"""
nonlocal err
err = task.exception()
if err:
log.exception(f"asyncio task errorred:\n{err}")
aio_err = task.exception()
if aio_err:
log.exception(f"asyncio task errorred:\n{aio_err}")
cancel_scope.cancel()
task.add_done_callback(cancel_trio)
# asycn gen
# async iterator
if inspect.isasyncgen(coro) or _treat_as_stream:
async def stream_results():
with cancel_scope:
# stream values upward
async with from_aio:
async for item in from_aio:
yield item
if cancel_scope.cancelled_caught and err:
raise err
try:
with cancel_scope:
# stream values upward
async with from_aio:
async for item in from_aio:
yield item
except BaseException as err:
if aio_err is not None:
# always raise from any captured asyncio error
raise err from aio_err
else:
raise
return stream_results()
# simple async func
elif inspect.iscoroutine(coro):
try:
with cancel_scope:
# return single value
return await from_aio.receive()
if cancel_scope.cancelled_caught and err:
raise err
# Do we need this?
except BaseException as err:
if aio_err is not None:
# always raise from any captured asyncio error
raise err from aio_err
else:
raise
def run_as_asyncio_guest(