Raise from asyncio error; fixes mypy
							parent
							
								
									6ae4d8699e
								
							
						
					
					
						commit
						762e6ad2a2
					
				| 
						 | 
				
			
			@ -6,7 +6,7 @@ import inspect
 | 
			
		|||
from typing import (
 | 
			
		||||
    Any,
 | 
			
		||||
    Callable,
 | 
			
		||||
    AsyncGenerator,
 | 
			
		||||
    AsyncIterator,
 | 
			
		||||
    Awaitable,
 | 
			
		||||
    Union,
 | 
			
		||||
)
 | 
			
		||||
| 
						 | 
				
			
			@ -33,7 +33,7 @@ async def run_coro(
 | 
			
		|||
 | 
			
		||||
async def consume_asyncgen(
 | 
			
		||||
    to_trio: trio.MemorySendChannel,
 | 
			
		||||
    coro: AsyncGenerator,
 | 
			
		||||
    coro: AsyncIterator,
 | 
			
		||||
) -> None:
 | 
			
		||||
    """Stream async generator results back to ``trio``.
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -50,7 +50,7 @@ async def run_task(
 | 
			
		|||
    qsize: int = 2**10,
 | 
			
		||||
    _treat_as_stream: bool = False,
 | 
			
		||||
    **kwargs,
 | 
			
		||||
) -> Union[AsyncGenerator, Any]:
 | 
			
		||||
) -> Any:
 | 
			
		||||
    """Run an ``asyncio`` async function or generator in a task, return
 | 
			
		||||
    or stream the result back to ``trio``.
 | 
			
		||||
    """
 | 
			
		||||
| 
						 | 
				
			
			@ -79,50 +79,58 @@ async def run_task(
 | 
			
		|||
    cancel_scope = trio.CancelScope()
 | 
			
		||||
 | 
			
		||||
    # start the asyncio task we submitted from trio
 | 
			
		||||
    # TODO: try out ``anyio`` asyncio based tg here
 | 
			
		||||
    if inspect.isawaitable(coro):
 | 
			
		||||
        task = asyncio.create_task(run_coro(to_trio, coro))
 | 
			
		||||
    elif inspect.isasyncgen(coro):
 | 
			
		||||
        task = asyncio.create_task(consume_asyncgen(to_trio, coro))
 | 
			
		||||
    else:
 | 
			
		||||
        raise TypeError(f"No support for {coro}")
 | 
			
		||||
        raise TypeError(f"No support for invoking {coro}")
 | 
			
		||||
 | 
			
		||||
    err = None
 | 
			
		||||
    aio_err = None
 | 
			
		||||
 | 
			
		||||
    def cancel_trio(task):
 | 
			
		||||
        """Cancel the calling ``trio`` task on error.
 | 
			
		||||
        """
 | 
			
		||||
        nonlocal err
 | 
			
		||||
        err = task.exception()
 | 
			
		||||
        if err:
 | 
			
		||||
            log.exception(f"asyncio task errorred:\n{err}")
 | 
			
		||||
        aio_err = task.exception()
 | 
			
		||||
        if aio_err:
 | 
			
		||||
            log.exception(f"asyncio task errorred:\n{aio_err}")
 | 
			
		||||
        cancel_scope.cancel()
 | 
			
		||||
 | 
			
		||||
    task.add_done_callback(cancel_trio)
 | 
			
		||||
 | 
			
		||||
    # asycn gen
 | 
			
		||||
    # async iterator
 | 
			
		||||
    if inspect.isasyncgen(coro) or _treat_as_stream:
 | 
			
		||||
 | 
			
		||||
        async def stream_results():
 | 
			
		||||
            with cancel_scope:
 | 
			
		||||
                # stream values upward
 | 
			
		||||
                async with from_aio:
 | 
			
		||||
                    async for item in from_aio:
 | 
			
		||||
                        yield item
 | 
			
		||||
            if cancel_scope.cancelled_caught and err:
 | 
			
		||||
                raise err
 | 
			
		||||
            try:
 | 
			
		||||
                with cancel_scope:
 | 
			
		||||
                    # stream values upward
 | 
			
		||||
                    async with from_aio:
 | 
			
		||||
                        async for item in from_aio:
 | 
			
		||||
                            yield item
 | 
			
		||||
            except BaseException as err:
 | 
			
		||||
                if aio_err is not None:
 | 
			
		||||
                    # always raise from any captured asyncio error
 | 
			
		||||
                    raise err from aio_err
 | 
			
		||||
                else:
 | 
			
		||||
                    raise
 | 
			
		||||
 | 
			
		||||
        return stream_results()
 | 
			
		||||
 | 
			
		||||
    # simple async func
 | 
			
		||||
    elif inspect.iscoroutine(coro):
 | 
			
		||||
 | 
			
		||||
    try:
 | 
			
		||||
        with cancel_scope:
 | 
			
		||||
            # return single value
 | 
			
		||||
            return await from_aio.receive()
 | 
			
		||||
        if cancel_scope.cancelled_caught and err:
 | 
			
		||||
            raise err
 | 
			
		||||
 | 
			
		||||
    # Do we need this?
 | 
			
		||||
    except BaseException as err:
 | 
			
		||||
        if aio_err is not None:
 | 
			
		||||
            # always raise from any captured asyncio error
 | 
			
		||||
            raise err from aio_err
 | 
			
		||||
        else:
 | 
			
		||||
            raise
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def run_as_asyncio_guest(
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue