Add streaming decode support for `msgspec`

Add a `tractor._ipc.MsgspecStream` type which can be swapped in for
`msgspec` serialization transparently. A small msg-length-prefix framing
is implemented as part of the type and we use
`tricycle.BufferedReceiveStream` to handle buffering logic for the
underlying transport.

Notes:
- had to force cast a few more list -> tuple spots due to no native
  `tuple` decode-by-default in `msgspec`: https://github.com/jcrist/msgspec/issues/30
- the framing can be understood by this protobuf walkthrough:
  https://eli.thegreenplace.net/2011/08/02/length-prefix-framing-for-protocol-buffers
- `tricycle` becomes a new dependency
prehardkill
Tyler Goodlet 2021-06-11 16:38:25 -04:00
parent 2d36cf478d
commit 5b65dd8871
1 changed files with 78 additions and 19 deletions

View File

@ -1,17 +1,19 @@
"""
Inter-process comms abstractions
"""
import typing
from typing import Any, Tuple, Optional, Callable
from functools import partial
import struct
import typing
from typing import Any, Tuple, Optional
from tricycle import BufferedReceiveStream
import msgpack
import msgspec
import trio
from async_generator import asynccontextmanager
from .log import get_logger
log = get_logger('ipc')
log = get_logger(__name__)
# :eyeroll:
try:
@ -22,21 +24,14 @@ except ImportError:
Unpacker = partial(msgpack.Unpacker, strict_map_key=False)
ms_decode = msgspec.Encoder().encode
class MsgpackStream:
"""A ``trio.SocketStream`` delivering ``msgpack`` formatted data.
'''A ``trio.SocketStream`` delivering ``msgpack`` formatted data
using ``msgpack-python``.
"""
'''
def __init__(
self,
stream: trio.SocketStream,
serialize: Callable = Unpacker(
raw=False,
use_list=False,
).feed,
deserialize: Callable = msgpack.dumps,
) -> None:
@ -62,7 +57,6 @@ class MsgpackStream:
raw=False,
use_list=False,
)
# decoder = msgspec.Decoder() #dict[str, Any])
while True:
try:
data = await self.stream.receive_some(2**10)
@ -75,7 +69,6 @@ class MsgpackStream:
log.debug(f"Stream connection {self.raddr} was closed")
return
# yield decoder.decode(data)
unpacker.feed(data)
for packet in unpacker:
yield packet
@ -92,8 +85,7 @@ class MsgpackStream:
async def send(self, data: Any) -> None:
async with self._send_lock:
return await self.stream.send_all(
# msgpack.dumps(data, use_bin_type=True))
ms_decode(data)
msgpack.dumps(data, use_bin_type=True)
)
async def recv(self) -> Any:
@ -106,26 +98,93 @@ class MsgpackStream:
return self.stream.socket.fileno() != -1
class MsgspecStream(MsgpackStream):
    '''A ``trio.SocketStream`` delivering ``msgpack`` formatted data
    using ``msgspec``.

    Every message is framed on the wire as an unsigned little-endian
    length prefix of ``prefix_size`` bytes followed by exactly that many
    bytes of ``msgspec``-encoded payload. Both ends of a connection must
    use the same ``prefix_size``; the default of 4 produces the same
    bytes as ``struct.pack("<I", ...)``.

    '''
    # a single shared encoder; ``encode`` is already bound to it so
    # looking it up through an instance does not re-bind ``self``.
    ms_encode = msgspec.Encoder().encode

    def __init__(
        self,
        stream: trio.SocketStream,
        prefix_size: int = 4,
    ) -> None:
        """Wrap ``stream`` for length-prefixed ``msgspec`` messaging.

        ``prefix_size`` is the number of bytes used for the length
        header that precedes every message.
        """
        super().__init__(stream)
        # buffers partial reads so that whole frames can be pulled off
        # the transport with ``receive_exactly()``.
        self.recv_stream = BufferedReceiveStream(transport_stream=stream)
        self.prefix_size = prefix_size

    async def _iter_packets(self) -> typing.AsyncGenerator[dict, None]:
        """Yield packets from the underlying stream.

        Reads a ``prefix_size`` byte header, decodes the payload length
        from it, reads that many payload bytes and yields the decoded
        message. The generator returns (ending iteration) when the
        stream is closed or breaks.
        """
        decoder = msgspec.Decoder()

        while True:
            try:
                header = await self.recv_stream.receive_exactly(
                    self.prefix_size
                )
                # NOTE(review): ``receive_exactly()`` is not expected to
                # return ``None`` - confirm against ``tricycle``; kept
                # as a defensive guard.
                if header is None:
                    continue

                if header == b'':
                    log.debug(f"Stream connection {self.raddr} was closed")
                    return

                # unsigned little-endian payload length; honours the
                # configured prefix width instead of a hard-coded 4.
                size = int.from_bytes(header, 'little')
                log.trace(f'received header {size}')

                msg_bytes = await self.recv_stream.receive_exactly(size)

            # the value error here is to catch a connect with immediate
            # disconnect that will cause an EOF error inside `tricycle`.
            except (ValueError, trio.BrokenResourceError):
                log.warning(f"Stream connection {self.raddr} broke")
                return

            log.trace(f"received {msg_bytes}")  # type: ignore
            yield decoder.decode(msg_bytes)

    async def send(self, data: Any) -> None:
        """Encode ``data`` and send it as one length-prefixed frame.

        Raises ``OverflowError`` if the encoded payload length does not
        fit in ``prefix_size`` bytes.
        """
        async with self._send_lock:
            bytes_data: bytes = self.ms_encode(data)

            # header and payload go out in a single ``send_all()`` call
            # while the send lock is held.
            size: bytes = len(bytes_data).to_bytes(
                self.prefix_size,
                'little',
            )
            return await self.stream.send_all(size + bytes_data)
class Channel:
"""An inter-process channel for communication between (remote) actors.
Currently the only supported transport is a ``trio.SocketStream``.
"""
def __init__(
self,
destaddr: Optional[Tuple[str, int]] = None,
on_reconnect: typing.Callable[..., typing.Awaitable] = None,
auto_reconnect: bool = False,
stream: trio.SocketStream = None, # expected to be active
# stream_serializer: type = MsgpackStream,
stream_serializer_type: type = MsgspecStream,
) -> None:
self._recon_seq = on_reconnect
self._autorecon = auto_reconnect
self.msgstream: Optional[MsgpackStream] = MsgpackStream(
self.stream_serializer_type = stream_serializer_type
self.msgstream: Optional[type] = stream_serializer_type(
stream) if stream else None
if self.msgstream and destaddr:
raise ValueError(
f"A stream was provided with local addr {self.laddr}"
)
self._destaddr = self.msgstream.raddr if self.msgstream else destaddr
# set after handshake - always uid of far end
self.uid: Optional[Tuple[str, str]] = None
@ -157,7 +216,7 @@ class Channel:
destaddr = destaddr or self._destaddr
assert isinstance(destaddr, tuple)
stream = await trio.open_tcp_stream(*destaddr, **kwargs)
self.msgstream = MsgpackStream(stream)
self.msgstream = self.stream_serializer_type(stream)
return stream
async def send(self, item: Any) -> None: