forked from goodboy/tractor
				
			Return channel type from `_run_asyncio_task()`
Better encapsulate all the mem-chan, Queue, sync-primitives inside our linked task channel in order to avoid `mypy`'s complaints about monkey patching. This also sets footing for adding an `asyncio`-side channel API that can be used more like this `trio`-side API.infect_asyncio
							parent
							
								
									9a2de90de6
								
							
						
					
					
						commit
						56cc98375e
					
				| 
						 | 
				
			
			@ -27,14 +27,57 @@ log = get_logger(__name__)
 | 
			
		|||
__all__ = ['run_task', 'run_as_asyncio_guest']
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@dataclass
 | 
			
		||||
class LinkedTaskChannel(trio.abc.Channel):
 | 
			
		||||
    '''
 | 
			
		||||
    A "linked task channel" which allows for two-way synchronized msg
 | 
			
		||||
    passing between a ``trio``-in-guest-mode task and an ``asyncio``
 | 
			
		||||
    task scheduled in the host loop.
 | 
			
		||||
 | 
			
		||||
    '''
 | 
			
		||||
    _to_aio: asyncio.Queue
 | 
			
		||||
    _from_aio: trio.MemoryReceiveChannel
 | 
			
		||||
    _to_trio: trio.MemorySendChannel
 | 
			
		||||
 | 
			
		||||
    _trio_cs: trio.CancelScope
 | 
			
		||||
    _aio_task_complete: trio.Event
 | 
			
		||||
 | 
			
		||||
    # set after ``asyncio.create_task()``
 | 
			
		||||
    _aio_task: Optional[asyncio.Task] = None
 | 
			
		||||
    _aio_err: Optional[BaseException] = None
 | 
			
		||||
 | 
			
		||||
    async def aclose(self) -> None:
 | 
			
		||||
        await self._from_aio.aclose()
 | 
			
		||||
 | 
			
		||||
    async def receive(self) -> Any:
 | 
			
		||||
        async with translate_aio_errors(self):
 | 
			
		||||
            return await self._from_aio.receive()
 | 
			
		||||
 | 
			
		||||
    async def wait_ayncio_complete(self) -> None:
 | 
			
		||||
        await self._aio_task_complete.wait()
 | 
			
		||||
 | 
			
		||||
    # def cancel_asyncio_task(self) -> None:
 | 
			
		||||
    #     self._aio_task.cancel()
 | 
			
		||||
 | 
			
		||||
    async def send(self, item: Any) -> None:
 | 
			
		||||
        '''
 | 
			
		||||
        Send a value through to the asyncio task presuming
 | 
			
		||||
        it defines a ``from_trio`` argument, if it does not
 | 
			
		||||
        this method will raise an error.
 | 
			
		||||
 | 
			
		||||
        '''
 | 
			
		||||
        self._to_aio.put_nowait(item)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _run_asyncio_task(
 | 
			
		||||
 | 
			
		||||
    func: Callable,
 | 
			
		||||
    *,
 | 
			
		||||
    qsize: int = 1,
 | 
			
		||||
    provide_channels: bool = False,
 | 
			
		||||
    **kwargs,
 | 
			
		||||
 | 
			
		||||
) -> Any:
 | 
			
		||||
) -> LinkedTaskChannel:
 | 
			
		||||
    '''
 | 
			
		||||
    Run an ``asyncio`` async function or generator in a task, return
 | 
			
		||||
    or stream the result back to ``trio``.
 | 
			
		||||
| 
						 | 
				
			
			@ -45,11 +88,9 @@ def _run_asyncio_task(
 | 
			
		|||
 | 
			
		||||
    # ITC (inter task comms), these channel/queue names are mostly from
 | 
			
		||||
    # ``asyncio``'s perspective.
 | 
			
		||||
    from_trio = asyncio.Queue(qsize)  # type: ignore
 | 
			
		||||
    aio_q = from_trio = asyncio.Queue(qsize)  # type: ignore
 | 
			
		||||
    to_trio, from_aio = trio.open_memory_channel(qsize)  # type: ignore
 | 
			
		||||
 | 
			
		||||
    from_aio._err = None
 | 
			
		||||
 | 
			
		||||
    args = tuple(inspect.getfullargspec(func).args)
 | 
			
		||||
 | 
			
		||||
    if getattr(func, '_tractor_steam_function', None):
 | 
			
		||||
| 
						 | 
				
			
			@ -74,6 +115,15 @@ def _run_asyncio_task(
 | 
			
		|||
    aio_task_complete = trio.Event()
 | 
			
		||||
    aio_err: Optional[BaseException] = None
 | 
			
		||||
 | 
			
		||||
    chan = LinkedTaskChannel(
 | 
			
		||||
        aio_q,  # asyncio.Queue
 | 
			
		||||
        from_aio,  # recv chan
 | 
			
		||||
        to_trio,  # send chan
 | 
			
		||||
 | 
			
		||||
        cancel_scope,
 | 
			
		||||
        aio_task_complete,
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    async def wait_on_coro_final_result(
 | 
			
		||||
 | 
			
		||||
        to_trio: trio.MemorySendChannel,
 | 
			
		||||
| 
						 | 
				
			
			@ -86,12 +136,13 @@ def _run_asyncio_task(
 | 
			
		|||
 | 
			
		||||
        '''
 | 
			
		||||
        nonlocal aio_err
 | 
			
		||||
        nonlocal chan
 | 
			
		||||
 | 
			
		||||
        orig = result = id(coro)
 | 
			
		||||
        try:
 | 
			
		||||
            result = await coro
 | 
			
		||||
        except BaseException as err:
 | 
			
		||||
            aio_err = err
 | 
			
		||||
            from_aio._err = aio_err
 | 
			
		||||
        except BaseException as aio_err:
 | 
			
		||||
            chan._aio_err = aio_err
 | 
			
		||||
            raise
 | 
			
		||||
 | 
			
		||||
        else:
 | 
			
		||||
| 
						 | 
				
			
			@ -116,25 +167,25 @@ def _run_asyncio_task(
 | 
			
		|||
            aio_task_complete.set()
 | 
			
		||||
 | 
			
		||||
    # start the asyncio task we submitted from trio
 | 
			
		||||
    if inspect.isawaitable(coro):
 | 
			
		||||
        task = asyncio.create_task(
 | 
			
		||||
            wait_on_coro_final_result(
 | 
			
		||||
                to_trio,
 | 
			
		||||
                coro,
 | 
			
		||||
                aio_task_complete
 | 
			
		||||
            )
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    else:
 | 
			
		||||
    if not inspect.isawaitable(coro):
 | 
			
		||||
        raise TypeError(f"No support for invoking {coro}")
 | 
			
		||||
 | 
			
		||||
    task = asyncio.create_task(
 | 
			
		||||
        wait_on_coro_final_result(
 | 
			
		||||
            to_trio,
 | 
			
		||||
            coro,
 | 
			
		||||
            aio_task_complete
 | 
			
		||||
        )
 | 
			
		||||
    )
 | 
			
		||||
    chan._aio_task = task
 | 
			
		||||
 | 
			
		||||
    def cancel_trio(task: asyncio.Task) -> None:
 | 
			
		||||
        '''
 | 
			
		||||
        Cancel the calling ``trio`` task on error.
 | 
			
		||||
 | 
			
		||||
        '''
 | 
			
		||||
        nonlocal aio_err
 | 
			
		||||
        aio_err = from_aio._err
 | 
			
		||||
        nonlocal chan
 | 
			
		||||
        aio_err = chan._aio_err
 | 
			
		||||
 | 
			
		||||
        # only to avoid ``asyncio`` complaining about uncaptured
 | 
			
		||||
        # task exceptions
 | 
			
		||||
| 
						 | 
				
			
			@ -159,27 +210,26 @@ def _run_asyncio_task(
 | 
			
		|||
 | 
			
		||||
    task.add_done_callback(cancel_trio)
 | 
			
		||||
 | 
			
		||||
    return task, from_aio, to_trio, from_trio, cancel_scope, aio_task_complete
 | 
			
		||||
    return chan
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@acm
 | 
			
		||||
async def translate_aio_errors(
 | 
			
		||||
 | 
			
		||||
    from_aio: trio.MemoryReceiveChannel,
 | 
			
		||||
    task: asyncio.Task,
 | 
			
		||||
    chan: LinkedTaskChannel,
 | 
			
		||||
 | 
			
		||||
) -> None:
 | 
			
		||||
) -> AsyncIterator[None]:
 | 
			
		||||
    '''
 | 
			
		||||
    Error handling context around ``asyncio`` task spawns which
 | 
			
		||||
    appropriately translates errors and cancels into ``trio`` land.
 | 
			
		||||
 | 
			
		||||
    '''
 | 
			
		||||
    aio_err: Optional[Exception] = None
 | 
			
		||||
    aio_err: Optional[BaseException] = None
 | 
			
		||||
 | 
			
		||||
    def maybe_raise_aio_err(
 | 
			
		||||
        err: Optional[Exception] = None
 | 
			
		||||
    ) -> None:
 | 
			
		||||
        aio_err = from_aio._err
 | 
			
		||||
        aio_err = chan._aio_err
 | 
			
		||||
        if (
 | 
			
		||||
            aio_err is not None and
 | 
			
		||||
            type(aio_err) != CancelledError
 | 
			
		||||
| 
						 | 
				
			
			@ -189,6 +239,9 @@ async def translate_aio_errors(
 | 
			
		|||
                raise aio_err from err
 | 
			
		||||
            else:
 | 
			
		||||
                raise aio_err
 | 
			
		||||
 | 
			
		||||
    task = chan._aio_task
 | 
			
		||||
    assert task
 | 
			
		||||
    try:
 | 
			
		||||
        yield
 | 
			
		||||
    except (
 | 
			
		||||
| 
						 | 
				
			
			@ -196,7 +249,7 @@ async def translate_aio_errors(
 | 
			
		|||
        # termination callback
 | 
			
		||||
        trio.ClosedResourceError,
 | 
			
		||||
    ):
 | 
			
		||||
        aio_err = from_aio._err
 | 
			
		||||
        aio_err = chan._aio_err
 | 
			
		||||
        if (
 | 
			
		||||
            task.cancelled() and
 | 
			
		||||
            type(aio_err) is CancelledError
 | 
			
		||||
| 
						 | 
				
			
			@ -234,65 +287,26 @@ async def run_task(
 | 
			
		|||
 | 
			
		||||
    '''
 | 
			
		||||
    # simple async func
 | 
			
		||||
    task, from_aio, to_trio, aio_q, cs, _ = _run_asyncio_task(
 | 
			
		||||
    chan = _run_asyncio_task(
 | 
			
		||||
        func,
 | 
			
		||||
        qsize=1,
 | 
			
		||||
        **kwargs,
 | 
			
		||||
    )
 | 
			
		||||
    with from_aio:
 | 
			
		||||
    with chan._from_aio:
 | 
			
		||||
        # try:
 | 
			
		||||
        async with translate_aio_errors(from_aio, task):
 | 
			
		||||
        async with translate_aio_errors(chan):
 | 
			
		||||
            # return single value that is the output from the
 | 
			
		||||
            # ``asyncio`` function-as-task. Expect the mem chan api to
 | 
			
		||||
            # do the job of handling cross-framework cancellations
 | 
			
		||||
            # / errors via closure and translation in the
 | 
			
		||||
            # ``translate_aio_errors()`` in the above ctx mngr.
 | 
			
		||||
            return await from_aio.receive()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@dataclass
 | 
			
		||||
class LinkedTaskChannel(trio.abc.Channel):
 | 
			
		||||
    '''
 | 
			
		||||
    A "linked task channel" which allows for two-way synchronized msg
 | 
			
		||||
    passing between a ``trio``-in-guest-mode task and an ``asyncio``
 | 
			
		||||
    task.
 | 
			
		||||
 | 
			
		||||
    '''
 | 
			
		||||
    _aio_task: asyncio.Task
 | 
			
		||||
    _to_aio: asyncio.Queue
 | 
			
		||||
    _from_aio: trio.MemoryReceiveChannel
 | 
			
		||||
    _aio_task_complete: trio.Event
 | 
			
		||||
 | 
			
		||||
    async def aclose(self) -> None:
 | 
			
		||||
        self._from_aio.close()
 | 
			
		||||
 | 
			
		||||
    async def receive(self) -> Any:
 | 
			
		||||
        async with translate_aio_errors(
 | 
			
		||||
            self._from_aio,
 | 
			
		||||
            self._aio_task,
 | 
			
		||||
        ):
 | 
			
		||||
            return await self._from_aio.receive()
 | 
			
		||||
 | 
			
		||||
    async def wait_ayncio_complete(self) -> None:
 | 
			
		||||
        await self._aio_task_complete.wait()
 | 
			
		||||
 | 
			
		||||
    # def cancel_asyncio_task(self) -> None:
 | 
			
		||||
    #     self._aio_task.cancel()
 | 
			
		||||
 | 
			
		||||
    async def send(self, item: Any) -> None:
 | 
			
		||||
        '''
 | 
			
		||||
        Send a value through to the asyncio task presuming
 | 
			
		||||
        it defines a ``from_trio`` argument, if it does not
 | 
			
		||||
        this method will raise an error.
 | 
			
		||||
 | 
			
		||||
        '''
 | 
			
		||||
        self._to_aio.put_nowait(item)
 | 
			
		||||
            return await chan.receive()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@acm
 | 
			
		||||
async def open_channel_from(
 | 
			
		||||
 | 
			
		||||
    target: Callable[[Any, ...], Any],
 | 
			
		||||
    target: Callable[..., Any],
 | 
			
		||||
    **kwargs,
 | 
			
		||||
 | 
			
		||||
) -> AsyncIterator[Any]:
 | 
			
		||||
| 
						 | 
				
			
			@ -301,21 +315,17 @@ async def open_channel_from(
 | 
			
		|||
    spawned ``asyncio`` task and ``trio``.
 | 
			
		||||
 | 
			
		||||
    '''
 | 
			
		||||
    task, from_aio, to_trio, aio_q, cs, aio_task_complete = _run_asyncio_task(
 | 
			
		||||
    chan = _run_asyncio_task(
 | 
			
		||||
        target,
 | 
			
		||||
        qsize=2**8,
 | 
			
		||||
        provide_channels=True,
 | 
			
		||||
        **kwargs,
 | 
			
		||||
    )
 | 
			
		||||
    chan = LinkedTaskChannel(
 | 
			
		||||
        task, aio_q, from_aio,
 | 
			
		||||
        aio_task_complete
 | 
			
		||||
    )
 | 
			
		||||
    async with from_aio:
 | 
			
		||||
        async with translate_aio_errors(from_aio, task):
 | 
			
		||||
    async with chan._from_aio:
 | 
			
		||||
        async with translate_aio_errors(chan):
 | 
			
		||||
            # sync to a "started()"-like first delivered value from the
 | 
			
		||||
            # ``asyncio`` task.
 | 
			
		||||
            first = await from_aio.receive()
 | 
			
		||||
            first = await chan.receive()
 | 
			
		||||
 | 
			
		||||
            # stream values upward
 | 
			
		||||
            yield first, chan
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue