Add streaming decode support for `msgspec`

Add a `tractor._ipc.MsgspecStream` type which can be swapped in for
`msgspec` serialization transparently. A small msg-length-prefix framing
is implemented as part of the type and we use
`tricycle.BufferedReceieveStream` to handle buffering logic for the
underlying transport.

Notes:
- had to force cast a few more list  -> tuple spots due to no native
  `tuple`decode-by-default in `msgspec`: https://github.com/jcrist/msgspec/issues/30
- the framing can be understood by this protobuf walkthrough:
  https://eli.thegreenplace.net/2011/08/02/length-prefix-framing-for-protocol-buffers
- `tricycle` becomes a new dependency
msgspec_not_fucked
Tyler Goodlet 2021-06-11 16:38:25 -04:00
parent d89e632a16
commit 39453e43e0
1 changed files with 78 additions and 19 deletions

View File

@ -1,17 +1,19 @@
""" """
Inter-process comms abstractions Inter-process comms abstractions
""" """
import typing
from typing import Any, Tuple, Optional, Callable
from functools import partial from functools import partial
import struct
import typing
from typing import Any, Tuple, Optional
from tricycle import BufferedReceiveStream
import msgpack import msgpack
import msgspec import msgspec
import trio import trio
from async_generator import asynccontextmanager from async_generator import asynccontextmanager
from .log import get_logger from .log import get_logger
log = get_logger('ipc') log = get_logger(__name__)
# :eyeroll: # :eyeroll:
try: try:
@ -22,21 +24,14 @@ except ImportError:
Unpacker = partial(msgpack.Unpacker, strict_map_key=False) Unpacker = partial(msgpack.Unpacker, strict_map_key=False)
ms_decode = msgspec.Encoder().encode
class MsgpackStream: class MsgpackStream:
"""A ``trio.SocketStream`` delivering ``msgpack`` formatted data. '''A ``trio.SocketStream`` delivering ``msgpack`` formatted data
using ``msgpack-python``.
""" '''
def __init__( def __init__(
self, self,
stream: trio.SocketStream, stream: trio.SocketStream,
serialize: Callable = Unpacker(
raw=False,
use_list=False,
).feed,
deserialize: Callable = msgpack.dumps,
) -> None: ) -> None:
@ -62,7 +57,6 @@ class MsgpackStream:
raw=False, raw=False,
use_list=False, use_list=False,
) )
# decoder = msgspec.Decoder() #dict[str, Any])
while True: while True:
try: try:
data = await self.stream.receive_some(2**10) data = await self.stream.receive_some(2**10)
@ -75,7 +69,6 @@ class MsgpackStream:
log.debug(f"Stream connection {self.raddr} was closed") log.debug(f"Stream connection {self.raddr} was closed")
return return
# yield decoder.decode(data)
unpacker.feed(data) unpacker.feed(data)
for packet in unpacker: for packet in unpacker:
yield packet yield packet
@ -92,8 +85,7 @@ class MsgpackStream:
async def send(self, data: Any) -> None: async def send(self, data: Any) -> None:
async with self._send_lock: async with self._send_lock:
return await self.stream.send_all( return await self.stream.send_all(
# msgpack.dumps(data, use_bin_type=True)) msgpack.dumps(data, use_bin_type=True)
ms_decode(data)
) )
async def recv(self) -> Any: async def recv(self) -> Any:
@ -106,26 +98,93 @@ class MsgpackStream:
return self.stream.socket.fileno() != -1 return self.stream.socket.fileno() != -1
class MsgspecStream(MsgpackStream):
'''A ``trio.SocketStream`` delivering ``msgpack`` formatted data
using ``msgspec``.
'''
ms_encode = msgspec.Encoder().encode
def __init__(
self,
stream: trio.SocketStream,
prefix_size: int = 4,
) -> None:
super().__init__(stream)
self.recv_stream = BufferedReceiveStream(transport_stream=stream)
self.prefix_size = prefix_size
async def _iter_packets(self) -> typing.AsyncGenerator[dict, None]:
"""Yield packets from the underlying stream.
"""
decoder = msgspec.Decoder() # dict[str, Any])
while True:
try:
header = await self.recv_stream.receive_exactly(4)
if header is None:
continue
if header == b'':
log.debug(f"Stream connection {self.raddr} was closed")
return
size, = struct.unpack("<I", header)
log.trace(f'received header {size}')
msg_bytes = await self.recv_stream.receive_exactly(size)
# the value error here is to catch a connect with immediate
# disconnect that will cause an EOF error inside `tricycle`.
except (ValueError, trio.BrokenResourceError):
log.warning(f"Stream connection {self.raddr} broke")
return
log.trace(f"received {msg_bytes}") # type: ignore
yield decoder.decode(msg_bytes)
async def send(self, data: Any) -> None:
async with self._send_lock:
bytes_data = self.ms_encode(data)
# supposedly the fastest says,
# https://stackoverflow.com/a/54027962
size: int = struct.pack("<I", len(bytes_data))
return await self.stream.send_all(size + bytes_data)
class Channel: class Channel:
"""An inter-process channel for communication between (remote) actors. """An inter-process channel for communication between (remote) actors.
Currently the only supported transport is a ``trio.SocketStream``. Currently the only supported transport is a ``trio.SocketStream``.
""" """
def __init__( def __init__(
self, self,
destaddr: Optional[Tuple[str, int]] = None, destaddr: Optional[Tuple[str, int]] = None,
on_reconnect: typing.Callable[..., typing.Awaitable] = None, on_reconnect: typing.Callable[..., typing.Awaitable] = None,
auto_reconnect: bool = False, auto_reconnect: bool = False,
stream: trio.SocketStream = None, # expected to be active stream: trio.SocketStream = None, # expected to be active
# stream_serializer: type = MsgpackStream,
stream_serializer_type: type = MsgspecStream,
) -> None: ) -> None:
self._recon_seq = on_reconnect self._recon_seq = on_reconnect
self._autorecon = auto_reconnect self._autorecon = auto_reconnect
self.msgstream: Optional[MsgpackStream] = MsgpackStream( self.stream_serializer_type = stream_serializer_type
self.msgstream: Optional[type] = stream_serializer_type(
stream) if stream else None stream) if stream else None
if self.msgstream and destaddr: if self.msgstream and destaddr:
raise ValueError( raise ValueError(
f"A stream was provided with local addr {self.laddr}" f"A stream was provided with local addr {self.laddr}"
) )
self._destaddr = self.msgstream.raddr if self.msgstream else destaddr self._destaddr = self.msgstream.raddr if self.msgstream else destaddr
# set after handshake - always uid of far end # set after handshake - always uid of far end
self.uid: Optional[Tuple[str, str]] = None self.uid: Optional[Tuple[str, str]] = None
@ -157,7 +216,7 @@ class Channel:
destaddr = destaddr or self._destaddr destaddr = destaddr or self._destaddr
assert isinstance(destaddr, tuple) assert isinstance(destaddr, tuple)
stream = await trio.open_tcp_stream(*destaddr, **kwargs) stream = await trio.open_tcp_stream(*destaddr, **kwargs)
self.msgstream = MsgpackStream(stream) self.msgstream = self.stream_serializer_type(stream)
return stream return stream
async def send(self, item: Any) -> None: async def send(self, item: Any) -> None: