Add our own "transport closed" signal
This change some super old (and bad) code from the project's very early days. For some redic reason i must have thought masking `trio`'s internal stream / transport errors and a TCP EOF as `StopAsyncIteration` somehow a good idea. The reality is you probably want to know the difference between an unexpected transport error and a simple EOF lol. This begins to resolve that by adding our own special `TransportClosed` error to signal the "graceful" termination of a channel's underlying transport. Oh, and this builds on the `msgspec` integration which helped shed light on the core issues here B)transport_cleaning
							parent
							
								
									44d7988204
								
							
						
					
					
						commit
						80e100f818
					
				| 
						 | 
				
			
			@ -38,6 +38,10 @@ class InternalActorError(RemoteActorError):
 | 
			
		|||
    """
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TransportClosed(trio.ClosedResourceError):
 | 
			
		||||
    "Underlying channel transport was closed prior to use"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class NoResult(RuntimeError):
 | 
			
		||||
    "No final result is expected for this actor"
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -63,12 +67,15 @@ def pack_error(exc: BaseException) -> Dict[str, Any]:
 | 
			
		|||
 | 
			
		||||
 | 
			
		||||
def unpack_error(
 | 
			
		||||
 | 
			
		||||
    msg: Dict[str, Any],
 | 
			
		||||
    chan=None,
 | 
			
		||||
    err_type=RemoteActorError
 | 
			
		||||
 | 
			
		||||
) -> Exception:
 | 
			
		||||
    """Unpack an 'error' message from the wire
 | 
			
		||||
    into a local ``RemoteActorError``.
 | 
			
		||||
 | 
			
		||||
    """
 | 
			
		||||
    tb_str = msg['error'].get('tb_str', '')
 | 
			
		||||
    return err_type(
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -10,7 +10,8 @@ import trio
 | 
			
		|||
from async_generator import asynccontextmanager
 | 
			
		||||
 | 
			
		||||
from .log import get_logger
 | 
			
		||||
log = get_logger('ipc')
 | 
			
		||||
from ._exceptions import TransportClosed
 | 
			
		||||
log = get_logger(__name__)
 | 
			
		||||
 | 
			
		||||
# :eyeroll:
 | 
			
		||||
try:
 | 
			
		||||
| 
						 | 
				
			
			@ -21,10 +22,17 @@ except ImportError:
 | 
			
		|||
    Unpacker = partial(msgpack.Unpacker, strict_map_key=False)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class MsgpackStream:
 | 
			
		||||
    """A ``trio.SocketStream`` delivering ``msgpack`` formatted data.
 | 
			
		||||
    """
 | 
			
		||||
    def __init__(self, stream: trio.SocketStream) -> None:
 | 
			
		||||
class MsgpackTCPStream:
 | 
			
		||||
    '''A ``trio.SocketStream`` delivering ``msgpack`` formatted data
 | 
			
		||||
    using ``msgpack-python``.
 | 
			
		||||
 | 
			
		||||
    '''
 | 
			
		||||
    def __init__(
 | 
			
		||||
        self,
 | 
			
		||||
        stream: trio.SocketStream,
 | 
			
		||||
 | 
			
		||||
    ) -> None:
 | 
			
		||||
 | 
			
		||||
        self.stream = stream
 | 
			
		||||
        assert self.stream.socket
 | 
			
		||||
        # should both be IP sockets
 | 
			
		||||
| 
						 | 
				
			
			@ -35,7 +43,10 @@ class MsgpackStream:
 | 
			
		|||
        assert isinstance(rsockname, tuple)
 | 
			
		||||
        self._raddr = rsockname[:2]
 | 
			
		||||
 | 
			
		||||
        # start and seed first entry to read loop
 | 
			
		||||
        self._agen = self._iter_packets()
 | 
			
		||||
        # self._agen.asend(None) is None
 | 
			
		||||
 | 
			
		||||
        self._send_lock = trio.StrictFIFOLock()
 | 
			
		||||
 | 
			
		||||
    async def _iter_packets(self) -> typing.AsyncGenerator[dict, None]:
 | 
			
		||||
| 
						 | 
				
			
			@ -46,16 +57,13 @@ class MsgpackStream:
 | 
			
		|||
            use_list=False,
 | 
			
		||||
        )
 | 
			
		||||
        while True:
 | 
			
		||||
            try:
 | 
			
		||||
            data = await self.stream.receive_some(2**10)
 | 
			
		||||
            log.trace(f"received {data}")  # type: ignore
 | 
			
		||||
            except trio.BrokenResourceError:
 | 
			
		||||
                log.warning(f"Stream connection {self.raddr} broke")
 | 
			
		||||
                return
 | 
			
		||||
 | 
			
		||||
            if data == b'':
 | 
			
		||||
                log.debug(f"Stream connection {self.raddr} was closed")
 | 
			
		||||
                return
 | 
			
		||||
                raise TransportClosed(
 | 
			
		||||
                    f'transport {self} was already closed prior ro read'
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
            unpacker.feed(data)
 | 
			
		||||
            for packet in unpacker:
 | 
			
		||||
| 
						 | 
				
			
			@ -96,10 +104,11 @@ class Channel:
 | 
			
		|||
        on_reconnect: typing.Callable[..., typing.Awaitable] = None,
 | 
			
		||||
        auto_reconnect: bool = False,
 | 
			
		||||
        stream: trio.SocketStream = None,  # expected to be active
 | 
			
		||||
 | 
			
		||||
    ) -> None:
 | 
			
		||||
        self._recon_seq = on_reconnect
 | 
			
		||||
        self._autorecon = auto_reconnect
 | 
			
		||||
        self.msgstream: Optional[MsgpackStream] = MsgpackStream(
 | 
			
		||||
        self.msgstream: Optional[MsgpackTCPStream] = MsgpackTCPStream(
 | 
			
		||||
            stream) if stream else None
 | 
			
		||||
        if self.msgstream and destaddr:
 | 
			
		||||
            raise ValueError(
 | 
			
		||||
| 
						 | 
				
			
			@ -112,6 +121,8 @@ class Channel:
 | 
			
		|||
        self._exc: Optional[Exception] = None
 | 
			
		||||
        self._agen = self._aiter_recv()
 | 
			
		||||
 | 
			
		||||
        self._closed: bool = False
 | 
			
		||||
 | 
			
		||||
    def __repr__(self) -> str:
 | 
			
		||||
        if self.msgstream:
 | 
			
		||||
            return repr(
 | 
			
		||||
| 
						 | 
				
			
			@ -128,35 +139,47 @@ class Channel:
 | 
			
		|||
        return self.msgstream.raddr if self.msgstream else None
 | 
			
		||||
 | 
			
		||||
    async def connect(
 | 
			
		||||
        self, destaddr: Tuple[Any, ...] = None,
 | 
			
		||||
        self,
 | 
			
		||||
        destaddr: Tuple[Any, ...] = None,
 | 
			
		||||
        **kwargs
 | 
			
		||||
 | 
			
		||||
    ) -> trio.SocketStream:
 | 
			
		||||
 | 
			
		||||
        if self.connected():
 | 
			
		||||
            raise RuntimeError("channel is already connected?")
 | 
			
		||||
 | 
			
		||||
        destaddr = destaddr or self._destaddr
 | 
			
		||||
        assert isinstance(destaddr, tuple)
 | 
			
		||||
        stream = await trio.open_tcp_stream(*destaddr, **kwargs)
 | 
			
		||||
        self.msgstream = MsgpackStream(stream)
 | 
			
		||||
        self.msgstream = MsgpackTCPStream(stream)
 | 
			
		||||
        return stream
 | 
			
		||||
 | 
			
		||||
    async def send(self, item: Any) -> None:
 | 
			
		||||
 | 
			
		||||
        log.trace(f"send `{item}`")  # type: ignore
 | 
			
		||||
        assert self.msgstream
 | 
			
		||||
 | 
			
		||||
        await self.msgstream.send(item)
 | 
			
		||||
 | 
			
		||||
    async def recv(self) -> Any:
 | 
			
		||||
        assert self.msgstream
 | 
			
		||||
 | 
			
		||||
        try:
 | 
			
		||||
            return await self.msgstream.recv()
 | 
			
		||||
 | 
			
		||||
        except trio.BrokenResourceError:
 | 
			
		||||
            if self._autorecon:
 | 
			
		||||
                await self._reconnect()
 | 
			
		||||
                return await self.recv()
 | 
			
		||||
 | 
			
		||||
            raise
 | 
			
		||||
 | 
			
		||||
    async def aclose(self) -> None:
 | 
			
		||||
        log.debug(f"Closing {self}")
 | 
			
		||||
        assert self.msgstream
 | 
			
		||||
        await self.msgstream.stream.aclose()
 | 
			
		||||
        self._closed = True
 | 
			
		||||
        log.error(f'CLOSING CHAN {self}')
 | 
			
		||||
 | 
			
		||||
    async def __aenter__(self):
 | 
			
		||||
        await self.connect()
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue