""" Inter-process comms abstractions """ from functools import partial import math import struct import typing from typing import Any, Tuple, Optional from tricycle import BufferedReceiveStream import msgpack import msgspec import trio from async_generator import asynccontextmanager from .log import get_logger from ._exceptions import TransportClosed log = get_logger(__name__) # :eyeroll: try: import msgpack_numpy Unpacker = msgpack_numpy.Unpacker except ImportError: # just plain ``msgpack`` requires tweaking key settings Unpacker = partial(msgpack.Unpacker, strict_map_key=False) class MsgpackTCPStream: '''A ``trio.SocketStream`` delivering ``msgpack`` formatted data using ``msgpack-python``. ''' def __init__( self, stream: trio.SocketStream, ) -> None: self.stream = stream assert self.stream.socket # should both be IP sockets lsockname = stream.socket.getsockname() assert isinstance(lsockname, tuple) self._laddr = lsockname[:2] rsockname = stream.socket.getpeername() assert isinstance(rsockname, tuple) self._raddr = rsockname[:2] # start and seed first entry to read loop self._agen = self._iter_packets() # self._agen.asend(None) is None self._send_lock = trio.StrictFIFOLock() async def _iter_packets(self) -> typing.AsyncGenerator[dict, None]: """Yield packets from the underlying stream. """ unpacker = Unpacker( raw=False, use_list=False, ) while True: data = await self.stream.receive_some(2**10) log.trace(f"received {data}") # type: ignore if data == b'': raise TransportClosed( f'transport {self} was already closed prior ro read' ) unpacker.feed(data) for packet in unpacker: yield packet @property def laddr(self) -> Tuple[Any, ...]: return self._laddr @property def raddr(self) -> Tuple[Any, ...]: return self._raddr # XXX: should this instead be called `.sendall()`? async def send(self, data: Any) -> None: async with self._send_lock: return await self.stream.send_all( msgpack.dumps(data, use_bin_type=True) ) async def recv(self) -> Any: return await self._agen.asend(None) def __aiter__(self): return self._agen def connected(self) -> bool: return self.stream.socket.fileno() != -1 class MsgspecTCPStream(MsgpackTCPStream): '''A ``trio.SocketStream`` delivering ``msgpack`` formatted data using ``msgspec``. ''' ms_encode = msgspec.Encoder().encode def __init__( self, stream: trio.SocketStream, prefix_size: int = 4, ) -> None: super().__init__(stream) self.recv_stream = BufferedReceiveStream(transport_stream=stream) self.prefix_size = prefix_size async def _iter_packets(self) -> typing.AsyncGenerator[dict, None]: """Yield packets from the underlying stream. """ decoder = msgspec.Decoder() # dict[str, Any]) while True: try: header = await self.recv_stream.receive_exactly(4) except (ValueError): raise TransportClosed( f'transport {self} was already closed prior ro read' ) if header == b'': raise TransportClosed( f'transport {self} was already closed prior ro read' ) size, = struct.unpack(" None: async with self._send_lock: bytes_data = self.ms_encode(data) # supposedly the fastest says, # https://stackoverflow.com/a/54027962 size: int = struct.pack(" None: self._recon_seq = on_reconnect self._autorecon = auto_reconnect self.stream_serializer_type = stream_serializer_type self.msgstream: Optional[type] = stream_serializer_type( stream) if stream else None if self.msgstream and destaddr: raise ValueError( f"A stream was provided with local addr {self.laddr}" ) self._destaddr = self.msgstream.raddr if self.msgstream else destaddr # set after handshake - always uid of far end self.uid: Optional[Tuple[str, str]] = None # set if far end actor errors internally self._exc: Optional[Exception] = None self._agen = self._aiter_recv() self._closed: bool = False def __repr__(self) -> str: if self.msgstream: return repr( self.msgstream.stream.socket._sock).replace( # type: ignore "socket.socket", "Channel") return object.__repr__(self) @property def laddr(self) -> Optional[Tuple[Any, ...]]: return self.msgstream.laddr if self.msgstream else None @property def raddr(self) -> Optional[Tuple[Any, ...]]: return self.msgstream.raddr if self.msgstream else None async def connect( self, destaddr: Tuple[Any, ...] = None, **kwargs ) -> trio.SocketStream: if self.connected(): raise RuntimeError("channel is already connected?") destaddr = destaddr or self._destaddr assert isinstance(destaddr, tuple) stream = await trio.open_tcp_stream( *destaddr, happy_eyeballs_delay=math.inf, **kwargs ) self.msgstream = self.stream_serializer_type(stream) return stream async def send(self, item: Any) -> None: log.trace(f"send `{item}`") # type: ignore assert self.msgstream await self.msgstream.send(item) async def recv(self) -> Any: assert self.msgstream try: return await self.msgstream.recv() except trio.BrokenResourceError: if self._autorecon: await self._reconnect() return await self.recv() raise async def aclose(self) -> None: log.debug(f"Closing {self}") assert self.msgstream await self.msgstream.stream.aclose() self._closed = True log.error(f'CLOSING CHAN {self}') async def __aenter__(self): await self.connect() return self async def __aexit__(self, *args): await self.aclose(*args) def __aiter__(self): return self._agen async def _reconnect(self) -> None: """Handle connection failures by polling until a reconnect can be established. """ down = False while True: try: with trio.move_on_after(3) as cancel_scope: await self.connect() cancelled = cancel_scope.cancelled_caught if cancelled: log.warning( "Reconnect timed out after 3 seconds, retrying...") continue else: log.warning("Stream connection re-established!") # run any reconnection sequence on_recon = self._recon_seq if on_recon: await on_recon(self) break except (OSError, ConnectionRefusedError): if not down: down = True log.warning( f"Connection to {self.raddr} went down, waiting" " for re-establishment") await trio.sleep(1) async def _aiter_recv( self ) -> typing.AsyncGenerator[Any, None]: """Async iterate items from underlying stream. """ assert self.msgstream while True: try: async for item in self.msgstream: yield item # sent = yield item # if sent is not None: # # optimization, passing None through all the # # time is pointless # await self.msgstream.send(sent) except trio.BrokenResourceError: if not self._autorecon: raise await self.aclose() if self._autorecon: # attempt reconnect await self._reconnect() continue else: return def connected(self) -> bool: return self.msgstream.connected() if self.msgstream else False @asynccontextmanager async def _connect_chan( host: str, port: int ) -> typing.AsyncGenerator[Channel, None]: """Create and connect a channel with disconnect on context manager teardown. """ chan = Channel((host, port)) await chan.connect() yield chan await chan.aclose()