Raise from asyncio error; fixes mypy
							parent
							
								
									8f15f438c7
								
							
						
					
					
						commit
						285dea04ea
					
				| 
						 | 
					@ -6,7 +6,7 @@ import inspect
 | 
				
			||||||
from typing import (
 | 
					from typing import (
 | 
				
			||||||
    Any,
 | 
					    Any,
 | 
				
			||||||
    Callable,
 | 
					    Callable,
 | 
				
			||||||
    AsyncGenerator,
 | 
					    AsyncIterator,
 | 
				
			||||||
    Awaitable,
 | 
					    Awaitable,
 | 
				
			||||||
    Union,
 | 
					    Union,
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
| 
						 | 
					@ -33,7 +33,7 @@ async def run_coro(
 | 
				
			||||||
 | 
					
 | 
				
			||||||
async def consume_asyncgen(
 | 
					async def consume_asyncgen(
 | 
				
			||||||
    to_trio: trio.MemorySendChannel,
 | 
					    to_trio: trio.MemorySendChannel,
 | 
				
			||||||
    coro: AsyncGenerator,
 | 
					    coro: AsyncIterator,
 | 
				
			||||||
) -> None:
 | 
					) -> None:
 | 
				
			||||||
    """Stream async generator results back to ``trio``.
 | 
					    """Stream async generator results back to ``trio``.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -50,7 +50,7 @@ async def run_task(
 | 
				
			||||||
    qsize: int = 2**10,
 | 
					    qsize: int = 2**10,
 | 
				
			||||||
    _treat_as_stream: bool = False,
 | 
					    _treat_as_stream: bool = False,
 | 
				
			||||||
    **kwargs,
 | 
					    **kwargs,
 | 
				
			||||||
) -> Union[AsyncGenerator, Any]:
 | 
					) -> Any:
 | 
				
			||||||
    """Run an ``asyncio`` async function or generator in a task, return
 | 
					    """Run an ``asyncio`` async function or generator in a task, return
 | 
				
			||||||
    or stream the result back to ``trio``.
 | 
					    or stream the result back to ``trio``.
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
| 
						 | 
					@ -79,50 +79,58 @@ async def run_task(
 | 
				
			||||||
    cancel_scope = trio.CancelScope()
 | 
					    cancel_scope = trio.CancelScope()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # start the asyncio task we submitted from trio
 | 
					    # start the asyncio task we submitted from trio
 | 
				
			||||||
    # TODO: try out ``anyio`` asyncio based tg here
 | 
					 | 
				
			||||||
    if inspect.isawaitable(coro):
 | 
					    if inspect.isawaitable(coro):
 | 
				
			||||||
        task = asyncio.create_task(run_coro(to_trio, coro))
 | 
					        task = asyncio.create_task(run_coro(to_trio, coro))
 | 
				
			||||||
    elif inspect.isasyncgen(coro):
 | 
					    elif inspect.isasyncgen(coro):
 | 
				
			||||||
        task = asyncio.create_task(consume_asyncgen(to_trio, coro))
 | 
					        task = asyncio.create_task(consume_asyncgen(to_trio, coro))
 | 
				
			||||||
    else:
 | 
					    else:
 | 
				
			||||||
        raise TypeError(f"No support for {coro}")
 | 
					        raise TypeError(f"No support for invoking {coro}")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    err = None
 | 
					    aio_err = None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def cancel_trio(task):
 | 
					    def cancel_trio(task):
 | 
				
			||||||
        """Cancel the calling ``trio`` task on error.
 | 
					        """Cancel the calling ``trio`` task on error.
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
        nonlocal err
 | 
					        nonlocal err
 | 
				
			||||||
        err = task.exception()
 | 
					        aio_err = task.exception()
 | 
				
			||||||
        if err:
 | 
					        if aio_err:
 | 
				
			||||||
            log.exception(f"asyncio task errorred:\n{err}")
 | 
					            log.exception(f"asyncio task errorred:\n{aio_err}")
 | 
				
			||||||
        cancel_scope.cancel()
 | 
					        cancel_scope.cancel()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    task.add_done_callback(cancel_trio)
 | 
					    task.add_done_callback(cancel_trio)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # asycn gen
 | 
					    # async iterator
 | 
				
			||||||
    if inspect.isasyncgen(coro) or _treat_as_stream:
 | 
					    if inspect.isasyncgen(coro) or _treat_as_stream:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        async def stream_results():
 | 
					        async def stream_results():
 | 
				
			||||||
            with cancel_scope:
 | 
					            try:
 | 
				
			||||||
                # stream values upward
 | 
					                with cancel_scope:
 | 
				
			||||||
                async with from_aio:
 | 
					                    # stream values upward
 | 
				
			||||||
                    async for item in from_aio:
 | 
					                    async with from_aio:
 | 
				
			||||||
                        yield item
 | 
					                        async for item in from_aio:
 | 
				
			||||||
            if cancel_scope.cancelled_caught and err:
 | 
					                            yield item
 | 
				
			||||||
                raise err
 | 
					            except BaseException as err:
 | 
				
			||||||
 | 
					                if aio_err is not None:
 | 
				
			||||||
 | 
					                    # always raise from any captured asyncio error
 | 
				
			||||||
 | 
					                    raise err from aio_err
 | 
				
			||||||
 | 
					                else:
 | 
				
			||||||
 | 
					                    raise
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        return stream_results()
 | 
					        return stream_results()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # simple async func
 | 
					    # simple async func
 | 
				
			||||||
    elif inspect.iscoroutine(coro):
 | 
					    try:
 | 
				
			||||||
 | 
					 | 
				
			||||||
        with cancel_scope:
 | 
					        with cancel_scope:
 | 
				
			||||||
            # return single value
 | 
					            # return single value
 | 
				
			||||||
            return await from_aio.receive()
 | 
					            return await from_aio.receive()
 | 
				
			||||||
        if cancel_scope.cancelled_caught and err:
 | 
					 | 
				
			||||||
            raise err
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Do we need this?
 | 
				
			||||||
 | 
					    except BaseException as err:
 | 
				
			||||||
 | 
					        if aio_err is not None:
 | 
				
			||||||
 | 
					            # always raise from any captured asyncio error
 | 
				
			||||||
 | 
					            raise err from aio_err
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            raise
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def run_as_asyncio_guest(
 | 
					def run_as_asyncio_guest(
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue