Don't use async gen functions for the stream API
As mentioned in prior commits there's currently a bug in Python that make async gens **not** task safe. Since this is the core cause of almost all recent problems, instead implement our own async iterator derivative of `trio.abc.ReceiveChannel` by wrapping a `trio._channel.MemoryReceiveChannel`. This fits more natively with the memory channel API in ``trio`` and adds potentially more flexibility for possible bidirectional inter-actor streaming in the future. Huge thanks to @oremanj and of course @njsmith for guidance on this one!trio_memchans
							parent
							
								
									b91d13cfea
								
							
						
					
					
						commit
						616192d853
					
				| 
						 | 
				
			
			@ -312,8 +312,11 @@ class Actor:
 | 
			
		|||
        cid = msg['cid']
 | 
			
		||||
        send_chan = self._cids2qs[(actorid, cid)]
 | 
			
		||||
        assert send_chan.cid == cid
 | 
			
		||||
        log.debug(f"Delivering {msg} from {actorid} to caller {cid}")
 | 
			
		||||
        if 'stop' in msg:
 | 
			
		||||
            log.debug(f"{send_chan} was terminated at remote end")
 | 
			
		||||
            return await send_chan.aclose()
 | 
			
		||||
        try:
 | 
			
		||||
            log.debug(f"Delivering {msg} from {actorid} to caller {cid}")
 | 
			
		||||
            # maintain backpressure
 | 
			
		||||
            await send_chan.send(msg)
 | 
			
		||||
        except trio.BrokenResourceError:
 | 
			
		||||
| 
						 | 
				
			
			@ -665,9 +668,14 @@ class Actor:
 | 
			
		|||
        # streaming IPC but it should be called
 | 
			
		||||
        # to cancel any remotely spawned task
 | 
			
		||||
        chan = ctx.chan
 | 
			
		||||
        # the ``dict.get()`` ensures the requested task to be cancelled
 | 
			
		||||
        # was indeed spawned by a request from this channel
 | 
			
		||||
        scope, func, is_complete = self._rpc_tasks[(ctx.chan, cid)]
 | 
			
		||||
        try:
 | 
			
		||||
            # this ctx based lookup ensures the requested task to
 | 
			
		||||
            # be cancelled was indeed spawned by a request from this channel
 | 
			
		||||
            scope, func, is_complete = self._rpc_tasks[(ctx.chan, cid)]
 | 
			
		||||
        except KeyError:
 | 
			
		||||
            log.warning(f"{cid} has already completed/terminated?")
 | 
			
		||||
            return
 | 
			
		||||
 | 
			
		||||
        log.debug(
 | 
			
		||||
            f"Cancelling task:\ncid: {cid}\nfunc: {func}\n"
 | 
			
		||||
            f"peer: {chan.uid}\n")
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -9,7 +9,7 @@ from functools import partial
 | 
			
		|||
from dataclasses import dataclass
 | 
			
		||||
 | 
			
		||||
import trio
 | 
			
		||||
from async_generator import asynccontextmanager, aclosing
 | 
			
		||||
from async_generator import asynccontextmanager
 | 
			
		||||
 | 
			
		||||
from ._state import current_actor
 | 
			
		||||
from ._ipc import Channel
 | 
			
		||||
| 
						 | 
				
			
			@ -48,6 +48,87 @@ async def _do_handshake(
 | 
			
		|||
    return uid
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class StreamReceiveChannel(trio.abc.ReceiveChannel):
 | 
			
		||||
    """A wrapper around a ``trio.abc.ReceiveChannel`` with
 | 
			
		||||
    special behaviour for stream termination on both ends of
 | 
			
		||||
    an inter-actor channel.
 | 
			
		||||
 | 
			
		||||
    Termination rules:
 | 
			
		||||
    - if the local task signals stop iteration a cancel signal is
 | 
			
		||||
      relayed to the remote task indicating to stop streaming
 | 
			
		||||
    - if the remote task signals the end of a stream, raise a
 | 
			
		||||
      ``StopAsyncIteration`` to terminate the local ``async for``
 | 
			
		||||
 | 
			
		||||
    """
 | 
			
		||||
    def __init__(
 | 
			
		||||
        self,
 | 
			
		||||
        cid: str,
 | 
			
		||||
        rx_chan: trio.abc.ReceiveChannel,
 | 
			
		||||
        portal: 'Portal',
 | 
			
		||||
    ):
 | 
			
		||||
        self._cid = cid
 | 
			
		||||
        self._rx_chan = rx_chan
 | 
			
		||||
        self._portal = portal
 | 
			
		||||
 | 
			
		||||
    # delegate directly to underlying mem channel
 | 
			
		||||
    def receive_nowait(self):
 | 
			
		||||
        return self._rx_chan.receive_nowait()
 | 
			
		||||
 | 
			
		||||
    async def receive(self):
 | 
			
		||||
        try:
 | 
			
		||||
            msg = await self._rx_chan.receive()
 | 
			
		||||
            return msg['yield']
 | 
			
		||||
        except trio.ClosedResourceError:
 | 
			
		||||
            # when the send is closed we assume the stream has
 | 
			
		||||
            # terminated and signal this local iterator to stop
 | 
			
		||||
            await self.aclose()
 | 
			
		||||
            raise StopAsyncIteration
 | 
			
		||||
        except trio.Cancelled:
 | 
			
		||||
            await self.aclose()
 | 
			
		||||
            raise
 | 
			
		||||
        except KeyError:
 | 
			
		||||
            # if 'stop' in msg:
 | 
			
		||||
            #     break  # far end async gen terminated
 | 
			
		||||
            # else:
 | 
			
		||||
            # internal error should never get here
 | 
			
		||||
            assert msg.get('cid'), (
 | 
			
		||||
                "Received internal error at portal?")
 | 
			
		||||
            raise unpack_error(msg, self._portal.channel)
 | 
			
		||||
 | 
			
		||||
    async def aclose(self):
 | 
			
		||||
        if self._rx_chan._closed:
 | 
			
		||||
            log.warning(f"{self} is already closed")
 | 
			
		||||
            return
 | 
			
		||||
        cid = self._cid
 | 
			
		||||
        # XXX: cancel remote task on close
 | 
			
		||||
        log.warning(
 | 
			
		||||
            f"Cancelling stream {cid} to "
 | 
			
		||||
            f"{self._portal.channel.uid}")
 | 
			
		||||
        with trio.move_on_after(0.5) as cs:
 | 
			
		||||
            cs.shield = True
 | 
			
		||||
            # TODO: yeah.. it'd be nice if this was just an
 | 
			
		||||
            # async func on the far end. Gotta figure out a
 | 
			
		||||
            # better way then implicitly feeding the ctx
 | 
			
		||||
            # to declaring functions; likely a decorator
 | 
			
		||||
            # system.
 | 
			
		||||
            rchan = await self._portal.run(
 | 
			
		||||
                'self', 'cancel_task', cid=cid)
 | 
			
		||||
            async for _ in rchan:
 | 
			
		||||
                pass
 | 
			
		||||
 | 
			
		||||
        if cs.cancelled_caught:
 | 
			
		||||
            if not self._portal.channel.connected():
 | 
			
		||||
                log.warning(
 | 
			
		||||
                    "May have failed to cancel remote task "
 | 
			
		||||
                    f"{cid} for {self._portal.channel.uid}")
 | 
			
		||||
 | 
			
		||||
        with trio.open_cancel_scope(shield=True):
 | 
			
		||||
            await self._rx_chan.aclose()
 | 
			
		||||
 | 
			
		||||
    def clone(self):
 | 
			
		||||
        return self
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Portal:
 | 
			
		||||
    """A 'portal' to a(n) (remote) ``Actor``.
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -67,13 +148,7 @@ class Portal:
 | 
			
		|||
        self._expect_result: Optional[
 | 
			
		||||
            Tuple[str, Any, str, Dict[str, Any]]
 | 
			
		||||
        ] = None
 | 
			
		||||
        self._agens: Set[typing.AsyncGenerator] = set()
 | 
			
		||||
 | 
			
		||||
    async def aclose(self) -> None:
 | 
			
		||||
        log.debug(f"Closing {self}")
 | 
			
		||||
        # XXX: won't work until https://github.com/python-trio/trio/pull/460
 | 
			
		||||
        # gets in!
 | 
			
		||||
        await self.channel.aclose()
 | 
			
		||||
        self._streams: Set[StreamReceiveChannel] = set()
 | 
			
		||||
 | 
			
		||||
    async def _submit(
 | 
			
		||||
        self,
 | 
			
		||||
| 
						 | 
				
			
			@ -135,50 +210,9 @@ class Portal:
 | 
			
		|||
        # to make async-generators the fundamental IPC API over channels!
 | 
			
		||||
        # (think `yield from`, `gen.send()`, and functional reactive stuff)
 | 
			
		||||
        if resptype == 'yield':  # stream response
 | 
			
		||||
 | 
			
		||||
            async def yield_from_recvchan():
 | 
			
		||||
                async with recv_chan:
 | 
			
		||||
                    try:
 | 
			
		||||
                        async for msg in recv_chan:
 | 
			
		||||
                            try:
 | 
			
		||||
                                yield msg['yield']
 | 
			
		||||
                            except KeyError:
 | 
			
		||||
                                if 'stop' in msg:
 | 
			
		||||
                                    break  # far end async gen terminated
 | 
			
		||||
                                else:
 | 
			
		||||
                                    # internal error should never get here
 | 
			
		||||
                                    assert msg.get('cid'), (
 | 
			
		||||
                                        "Received internal error at portal?")
 | 
			
		||||
                                    raise unpack_error(msg, self.channel)
 | 
			
		||||
 | 
			
		||||
                    except (GeneratorExit, trio.Cancelled) as err:
 | 
			
		||||
                        log.warning(
 | 
			
		||||
                            f"Cancelling async gen call {cid} to "
 | 
			
		||||
                            f"{self.channel.uid}")
 | 
			
		||||
                        with trio.move_on_after(0.5) as cs:
 | 
			
		||||
                            cs.shield = True
 | 
			
		||||
                            # TODO: yeah.. it'd be nice if this was just an
 | 
			
		||||
                            # async func on the far end. Gotta figure out a
 | 
			
		||||
                            # better way then implicitly feeding the ctx
 | 
			
		||||
                            # to declaring functions; likely a decorator
 | 
			
		||||
                            # system.
 | 
			
		||||
                            agen = await self.run(
 | 
			
		||||
                                'self', 'cancel_task', cid=cid)
 | 
			
		||||
                            async with aclosing(agen) as agen:
 | 
			
		||||
                                async for _ in agen:
 | 
			
		||||
                                    pass
 | 
			
		||||
 | 
			
		||||
                        if cs.cancelled_caught:
 | 
			
		||||
                            if not self.channel.connected():
 | 
			
		||||
                                log.warning(
 | 
			
		||||
                                    "May have failed to cancel remote task "
 | 
			
		||||
                                    f"{cid} for {self.channel.uid}")
 | 
			
		||||
 | 
			
		||||
            # TODO: use AsyncExitStack to aclose() all agens
 | 
			
		||||
            # on teardown
 | 
			
		||||
            agen = yield_from_recvchan()
 | 
			
		||||
            self._agens.add(agen)
 | 
			
		||||
            return agen
 | 
			
		||||
            rchan = StreamReceiveChannel(cid, recv_chan, self)
 | 
			
		||||
            self._streams.add(rchan)
 | 
			
		||||
            return rchan
 | 
			
		||||
 | 
			
		||||
        elif resptype == 'return':  # single response
 | 
			
		||||
            msg = await recv_chan.receive()
 | 
			
		||||
| 
						 | 
				
			
			@ -224,20 +258,21 @@ class Portal:
 | 
			
		|||
 | 
			
		||||
        return self._result
 | 
			
		||||
 | 
			
		||||
    # async def _cancel_streams(self):
 | 
			
		||||
    #     # terminate all locally running async generator
 | 
			
		||||
    #     # IPC calls
 | 
			
		||||
    #     if self._agens:
 | 
			
		||||
    #         log.warning(
 | 
			
		||||
    #             f"Cancelling all streams with {self.channel.uid}")
 | 
			
		||||
    #         for agen in self._agens:
 | 
			
		||||
    #             await agen.aclose()
 | 
			
		||||
    async def _cancel_streams(self):
 | 
			
		||||
        # terminate all locally running async generator
 | 
			
		||||
        # IPC calls
 | 
			
		||||
        if self._streams:
 | 
			
		||||
            log.warning(
 | 
			
		||||
                f"Cancelling all streams with {self.channel.uid}")
 | 
			
		||||
            for stream in self._streams.copy():
 | 
			
		||||
                await stream.aclose()
 | 
			
		||||
 | 
			
		||||
    async def aclose(self) -> None:
 | 
			
		||||
        log.debug(f"Closing {self}")
 | 
			
		||||
        # TODO: once we move to implementing our own `ReceiveChannel`
 | 
			
		||||
        # (including remote task cancellation inside its `.aclose()`)
 | 
			
		||||
        # we'll need to .aclose all those channels here
 | 
			
		||||
        pass
 | 
			
		||||
        await self._cancel_streams()
 | 
			
		||||
 | 
			
		||||
    async def cancel_actor(self) -> bool:
 | 
			
		||||
        """Cancel the actor on the other end of this portal.
 | 
			
		||||
| 
						 | 
				
			
			@ -246,7 +281,7 @@ class Portal:
 | 
			
		|||
            log.warning("This portal is already closed can't cancel")
 | 
			
		||||
            return False
 | 
			
		||||
 | 
			
		||||
        # await self._cancel_streams()
 | 
			
		||||
        await self._cancel_streams()
 | 
			
		||||
 | 
			
		||||
        log.warning(
 | 
			
		||||
            f"Sending actor cancel request to {self.channel.uid} on "
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue