From 39453e43e09445129b782ccd0caf7e04083f018f Mon Sep 17 00:00:00 2001 From: Tyler Goodlet Date: Fri, 11 Jun 2021 16:38:25 -0400 Subject: [PATCH] Add streaming decode support for `msgspec` Add a `tractor._ipc.MsgspecStream` type which can be swapped in for `msgspec` serialization transparently. A small msg-length-prefix framing is implemented as part of the type and we use `tricycle.BufferedReceiveStream` to handle buffering logic for the underlying transport. Notes: - had to force cast a few more list -> tuple spots due to no native `tuple` decode-by-default in `msgspec`: https://github.com/jcrist/msgspec/issues/30 - the framing can be understood by this protobuf walkthrough: https://eli.thegreenplace.net/2011/08/02/length-prefix-framing-for-protocol-buffers - `tricycle` becomes a new dependency --- tractor/_ipc.py | 97 +++++++++++++++++++++++++++++++++++++++---------- 1 file changed, 78 insertions(+), 19 deletions(-) diff --git a/tractor/_ipc.py b/tractor/_ipc.py index 3d0cac2..48023ea 100644 --- a/tractor/_ipc.py +++ b/tractor/_ipc.py @@ -1,17 +1,19 @@ """ Inter-process comms abstractions """ -import typing -from typing import Any, Tuple, Optional, Callable from functools import partial +import struct +import typing +from typing import Any, Tuple, Optional +from tricycle import BufferedReceiveStream import msgpack import msgspec import trio from async_generator import asynccontextmanager from .log import get_logger -log = get_logger('ipc') +log = get_logger(__name__) # :eyeroll: try: @@ -22,21 +24,14 @@ except ImportError: Unpacker = partial(msgpack.Unpacker, strict_map_key=False) -ms_decode = msgspec.Encoder().encode - - class MsgpackStream: - """A ``trio.SocketStream`` delivering ``msgpack`` formatted data. + '''A ``trio.SocketStream`` delivering ``msgpack`` formatted data + using ``msgpack-python``. 
- """ + ''' def __init__( self, stream: trio.SocketStream, - serialize: Callable = Unpacker( - raw=False, - use_list=False, - ).feed, - deserialize: Callable = msgpack.dumps, ) -> None: @@ -62,7 +57,6 @@ class MsgpackStream: raw=False, use_list=False, ) - # decoder = msgspec.Decoder() #dict[str, Any]) while True: try: data = await self.stream.receive_some(2**10) @@ -75,7 +69,6 @@ class MsgpackStream: log.debug(f"Stream connection {self.raddr} was closed") return - # yield decoder.decode(data) unpacker.feed(data) for packet in unpacker: yield packet @@ -92,8 +85,7 @@ class MsgpackStream: async def send(self, data: Any) -> None: async with self._send_lock: return await self.stream.send_all( - # msgpack.dumps(data, use_bin_type=True)) - ms_decode(data) + msgpack.dumps(data, use_bin_type=True) ) async def recv(self) -> Any: @@ -106,26 +98,93 @@ class MsgpackStream: return self.stream.socket.fileno() != -1 +class MsgspecStream(MsgpackStream): + '''A ``trio.SocketStream`` delivering ``msgpack`` formatted data + using ``msgspec``. + + ''' + ms_encode = msgspec.Encoder().encode + + def __init__( + self, + stream: trio.SocketStream, + prefix_size: int = 4, + + ) -> None: + super().__init__(stream) + self.recv_stream = BufferedReceiveStream(transport_stream=stream) + self.prefix_size = prefix_size + + async def _iter_packets(self) -> typing.AsyncGenerator[dict, None]: + """Yield packets from the underlying stream. 
+ """ + decoder = msgspec.Decoder() # dict[str, Any]) + + while True: + try: + header = await self.recv_stream.receive_exactly(4) + if header is None: + continue + + if header == b'': + log.debug(f"Stream connection {self.raddr} was closed") + return + + size, = struct.unpack(" None: + async with self._send_lock: + + bytes_data = self.ms_encode(data) + + # supposedly the fastest says, + # https://stackoverflow.com/a/54027962 + size: int = struct.pack(" None: + self._recon_seq = on_reconnect self._autorecon = auto_reconnect - self.msgstream: Optional[MsgpackStream] = MsgpackStream( + self.stream_serializer_type = stream_serializer_type + self.msgstream: Optional[type] = stream_serializer_type( stream) if stream else None + if self.msgstream and destaddr: raise ValueError( f"A stream was provided with local addr {self.laddr}" ) + self._destaddr = self.msgstream.raddr if self.msgstream else destaddr # set after handshake - always uid of far end self.uid: Optional[Tuple[str, str]] = None @@ -157,7 +216,7 @@ class Channel: destaddr = destaddr or self._destaddr assert isinstance(destaddr, tuple) stream = await trio.open_tcp_stream(*destaddr, **kwargs) - self.msgstream = MsgpackStream(stream) + self.msgstream = self.stream_serializer_type(stream) return stream async def send(self, item: Any) -> None: