diff --git a/tractor/_ipc.py b/tractor/_ipc.py index 6051a15..5989a2e 100644 --- a/tractor/_ipc.py +++ b/tractor/_ipc.py @@ -3,10 +3,11 @@ Inter-process comms abstractions """ import platform +import struct import typing -from typing import Any, Tuple, Optional, Callable -from functools import partial +from typing import Any, Tuple, Optional +from tricycle import BufferedReceiveStream import msgpack import msgspec import trio @@ -18,20 +19,11 @@ log = get_logger(__name__) _is_windows = platform.system() == 'Windows' - -# :eyeroll: -try: - import msgpack_numpy - Unpacker = msgpack_numpy.Unpacker -except ImportError: - # just plain ``msgpack`` requires tweaking key settings - Unpacker = partial(msgpack.Unpacker, strict_map_key=False) - - +log = get_logger(__name__) ms_decode = msgspec.Encoder().encode -class MsgpackTCPStream: +class MsgpackStream: '''A ``trio.SocketStream`` delivering ``msgpack`` formatted data using ``msgpack-python``. @@ -39,6 +31,7 @@ class MsgpackTCPStream: def __init__( self, stream: trio.SocketStream, + ) -> None: self.stream = stream @@ -62,11 +55,10 @@ class MsgpackTCPStream: async def _iter_packets(self) -> typing.AsyncGenerator[dict, None]: """Yield packets from the underlying stream. """ - unpacker = Unpacker( + unpacker = msgpack.Unpacker( raw=False, use_list=False, ) - # decoder = msgspec.Decoder() #dict[str, Any]) while True: try: data = await self.stream.receive_some(2**10) @@ -101,7 +93,6 @@ class MsgpackTCPStream: f'transport {self} was already closed prior ro read' ) - # yield decoder.decode(data) unpacker.feed(data) for packet in unpacker: yield packet @@ -118,8 +109,7 @@ class MsgpackTCPStream: async def send(self, data: Any) -> None: async with self._send_lock: return await self.stream.send_all( - # msgpack.dumps(data, use_bin_type=True)) - ms_decode(data) + msgpack.dumps(data, use_bin_type=True) ) async def recv(self) -> Any: @@ -132,27 +122,95 @@ class MsgpackTCPStream: return self.stream.socket.fileno() != -1 +class MsgspecStream(MsgpackStream): + '''A ``trio.SocketStream`` delivering ``msgpack`` formatted data + using ``msgspec``. + + ''' + ms_encode = msgspec.Encoder().encode + + def __init__( + self, + stream: trio.SocketStream, + prefix_size: int = 4, + + ) -> None: + super().__init__(stream) + self.recv_stream = BufferedReceiveStream(transport_stream=stream) + self.prefix_size = prefix_size + + async def _iter_packets(self) -> typing.AsyncGenerator[dict, None]: + """Yield packets from the underlying stream. + """ + decoder = msgspec.Decoder() # dict[str, Any]) + + while True: + try: + header = await self.recv_stream.receive_exactly(4) + if header is None: + continue + + if header == b'': + log.debug(f"Stream connection {self.raddr} was closed") + return + + size, = struct.unpack(" None: + async with self._send_lock: + + bytes_data = self.ms_encode(data) + + # supposedly the fastest says, + # https://stackoverflow.com/a/54027962 + size: int = struct.pack(" None: + self._recon_seq = on_reconnect self._autorecon = auto_reconnect - self.msgstream: Optional[MsgpackTCPStream] = MsgpackTCPStream( + + self.stream_serializer_type = stream_serializer_type + self.msgstream: Optional[type] = stream_serializer_type( stream) if stream else None + if self.msgstream and destaddr: raise ValueError( f"A stream was provided with local addr {self.laddr}" ) + self._destaddr = self.msgstream.raddr if self.msgstream else destaddr # set after handshake - always uid of far end self.uid: Optional[Tuple[str, str]] = None @@ -195,7 +253,7 @@ class Channel: *destaddr, **kwargs ) - self.msgstream = MsgpackTCPStream(stream) + self.msgstream = self.stream_serializer_type(stream) log.transport( f'Opened channel to peer {self.laddr} -> {self.raddr}'