Add streaming decode support for `msgspec`

Add a `tractor._ipc.MsgspecStream` type which can be swapped in for
`msgspec` serialization transparently. A small msg-length-prefix framing
is implemented as part of the type and we use
`tricycle.BufferedReceieveStream` to handle buffering logic for the
underlying transport.

Notes:
- had to force cast a few more list  -> tuple spots due to no native
  `tuple`decode-by-default in `msgspec`: https://github.com/jcrist/msgspec/issues/30
- the framing can be understood by this protobuf walkthrough:
  https://eli.thegreenplace.net/2011/08/02/length-prefix-framing-for-protocol-buffers
- `tricycle` becomes a new dependency
msgspec_infect_asyncio
Tyler Goodlet 2021-06-11 16:38:25 -04:00
parent 5e03108211
commit bc6af2219e
1 changed files with 78 additions and 20 deletions

View File

@ -3,10 +3,11 @@ Inter-process comms abstractions
""" """
import platform import platform
import struct
import typing import typing
from typing import Any, Tuple, Optional, Callable from typing import Any, Tuple, Optional
from functools import partial
from tricycle import BufferedReceiveStream
import msgpack import msgpack
import msgspec import msgspec
import trio import trio
@ -18,20 +19,11 @@ log = get_logger(__name__)
_is_windows = platform.system() == 'Windows' _is_windows = platform.system() == 'Windows'
log = get_logger(__name__)
# :eyeroll:
try:
import msgpack_numpy
Unpacker = msgpack_numpy.Unpacker
except ImportError:
# just plain ``msgpack`` requires tweaking key settings
Unpacker = partial(msgpack.Unpacker, strict_map_key=False)
ms_decode = msgspec.Encoder().encode ms_decode = msgspec.Encoder().encode
class MsgpackTCPStream: class MsgpackStream:
'''A ``trio.SocketStream`` delivering ``msgpack`` formatted data '''A ``trio.SocketStream`` delivering ``msgpack`` formatted data
using ``msgpack-python``. using ``msgpack-python``.
@ -39,6 +31,7 @@ class MsgpackTCPStream:
def __init__( def __init__(
self, self,
stream: trio.SocketStream, stream: trio.SocketStream,
) -> None: ) -> None:
self.stream = stream self.stream = stream
@ -62,11 +55,10 @@ class MsgpackTCPStream:
async def _iter_packets(self) -> typing.AsyncGenerator[dict, None]: async def _iter_packets(self) -> typing.AsyncGenerator[dict, None]:
"""Yield packets from the underlying stream. """Yield packets from the underlying stream.
""" """
unpacker = Unpacker( unpacker = msgpack.Unpacker(
raw=False, raw=False,
use_list=False, use_list=False,
) )
# decoder = msgspec.Decoder() #dict[str, Any])
while True: while True:
try: try:
data = await self.stream.receive_some(2**10) data = await self.stream.receive_some(2**10)
@ -101,7 +93,6 @@ class MsgpackTCPStream:
f'transport {self} was already closed prior ro read' f'transport {self} was already closed prior ro read'
) )
# yield decoder.decode(data)
unpacker.feed(data) unpacker.feed(data)
for packet in unpacker: for packet in unpacker:
yield packet yield packet
@ -118,8 +109,7 @@ class MsgpackTCPStream:
async def send(self, data: Any) -> None: async def send(self, data: Any) -> None:
async with self._send_lock: async with self._send_lock:
return await self.stream.send_all( return await self.stream.send_all(
# msgpack.dumps(data, use_bin_type=True)) msgpack.dumps(data, use_bin_type=True)
ms_decode(data)
) )
async def recv(self) -> Any: async def recv(self) -> Any:
@ -132,27 +122,95 @@ class MsgpackTCPStream:
return self.stream.socket.fileno() != -1 return self.stream.socket.fileno() != -1
class MsgspecStream(MsgpackStream):
'''A ``trio.SocketStream`` delivering ``msgpack`` formatted data
using ``msgspec``.
'''
ms_encode = msgspec.Encoder().encode
def __init__(
self,
stream: trio.SocketStream,
prefix_size: int = 4,
) -> None:
super().__init__(stream)
self.recv_stream = BufferedReceiveStream(transport_stream=stream)
self.prefix_size = prefix_size
async def _iter_packets(self) -> typing.AsyncGenerator[dict, None]:
"""Yield packets from the underlying stream.
"""
decoder = msgspec.Decoder() # dict[str, Any])
while True:
try:
header = await self.recv_stream.receive_exactly(4)
if header is None:
continue
if header == b'':
log.debug(f"Stream connection {self.raddr} was closed")
return
size, = struct.unpack("<I", header)
log.trace(f'received header {size}')
msg_bytes = await self.recv_stream.receive_exactly(size)
# the value error here is to catch a connect with immediate
# disconnect that will cause an EOF error inside `tricycle`.
except (ValueError, trio.BrokenResourceError):
log.warning(f"Stream connection {self.raddr} broke")
return
log.trace(f"received {msg_bytes}") # type: ignore
yield decoder.decode(msg_bytes)
async def send(self, data: Any) -> None:
async with self._send_lock:
bytes_data = self.ms_encode(data)
# supposedly the fastest says,
# https://stackoverflow.com/a/54027962
size: int = struct.pack("<I", len(bytes_data))
return await self.stream.send_all(size + bytes_data)
class Channel: class Channel:
"""An inter-process channel for communication between (remote) actors. """An inter-process channel for communication between (remote) actors.
Currently the only supported transport is a ``trio.SocketStream``. Currently the only supported transport is a ``trio.SocketStream``.
""" """
def __init__( def __init__(
self, self,
destaddr: Optional[Tuple[str, int]] = None, destaddr: Optional[Tuple[str, int]] = None,
on_reconnect: typing.Callable[..., typing.Awaitable] = None, on_reconnect: typing.Callable[..., typing.Awaitable] = None,
auto_reconnect: bool = False, auto_reconnect: bool = False,
stream: trio.SocketStream = None, # expected to be active stream: trio.SocketStream = None, # expected to be active
# stream_serializer: type = MsgpackStream,
stream_serializer_type: type = MsgspecStream,
) -> None: ) -> None:
self._recon_seq = on_reconnect self._recon_seq = on_reconnect
self._autorecon = auto_reconnect self._autorecon = auto_reconnect
self.msgstream: Optional[MsgpackTCPStream] = MsgpackTCPStream(
self.stream_serializer_type = stream_serializer_type
self.msgstream: Optional[type] = stream_serializer_type(
stream) if stream else None stream) if stream else None
if self.msgstream and destaddr: if self.msgstream and destaddr:
raise ValueError( raise ValueError(
f"A stream was provided with local addr {self.laddr}" f"A stream was provided with local addr {self.laddr}"
) )
self._destaddr = self.msgstream.raddr if self.msgstream else destaddr self._destaddr = self.msgstream.raddr if self.msgstream else destaddr
# set after handshake - always uid of far end # set after handshake - always uid of far end
self.uid: Optional[Tuple[str, str]] = None self.uid: Optional[Tuple[str, str]] = None
@ -195,7 +253,7 @@ class Channel:
*destaddr, *destaddr,
**kwargs **kwargs
) )
self.msgstream = MsgpackTCPStream(stream) self.msgstream = self.stream_serializer_type(stream)
log.transport( log.transport(
f'Opened channel to peer {self.laddr} -> {self.raddr}' f'Opened channel to peer {self.laddr} -> {self.raddr}'