Add streaming decode support for `msgspec`

Add a `tractor._ipc.MsgspecStream` type which can be swapped in for
`msgspec` serialization transparently. A small msg-length-prefix framing
is implemented as part of the type and we use
`tricycle.BufferedReceieveStream` to handle buffering logic for the
underlying transport.

Notes:
- had to force cast a few more list  -> tuple spots due to no native
  `tuple`decode-by-default in `msgspec`: https://github.com/jcrist/msgspec/issues/30
- the framing can be understood by this protobuf walkthrough:
  https://eli.thegreenplace.net/2011/08/02/length-prefix-framing-for-protocol-buffers
- `tricycle` becomes a new dependency
msgspec_infect_asyncio
Tyler Goodlet 2021-06-11 16:38:25 -04:00
parent 5e03108211
commit bc6af2219e
1 changed files with 78 additions and 20 deletions

View File

@ -3,10 +3,11 @@ Inter-process comms abstractions
"""
import platform
import struct
import typing
from typing import Any, Tuple, Optional, Callable
from functools import partial
from typing import Any, Tuple, Optional
from tricycle import BufferedReceiveStream
import msgpack
import msgspec
import trio
@ -18,20 +19,11 @@ log = get_logger(__name__)
_is_windows = platform.system() == 'Windows'
# :eyeroll:
try:
import msgpack_numpy
Unpacker = msgpack_numpy.Unpacker
except ImportError:
# just plain ``msgpack`` requires tweaking key settings
Unpacker = partial(msgpack.Unpacker, strict_map_key=False)
log = get_logger(__name__)
ms_decode = msgspec.Encoder().encode
class MsgpackTCPStream:
class MsgpackStream:
'''A ``trio.SocketStream`` delivering ``msgpack`` formatted data
using ``msgpack-python``.
@ -39,6 +31,7 @@ class MsgpackTCPStream:
def __init__(
self,
stream: trio.SocketStream,
) -> None:
self.stream = stream
@ -62,11 +55,10 @@ class MsgpackTCPStream:
async def _iter_packets(self) -> typing.AsyncGenerator[dict, None]:
"""Yield packets from the underlying stream.
"""
unpacker = Unpacker(
unpacker = msgpack.Unpacker(
raw=False,
use_list=False,
)
# decoder = msgspec.Decoder() #dict[str, Any])
while True:
try:
data = await self.stream.receive_some(2**10)
@ -101,7 +93,6 @@ class MsgpackTCPStream:
f'transport {self} was already closed prior ro read'
)
# yield decoder.decode(data)
unpacker.feed(data)
for packet in unpacker:
yield packet
@ -118,8 +109,7 @@ class MsgpackTCPStream:
async def send(self, data: Any) -> None:
async with self._send_lock:
return await self.stream.send_all(
# msgpack.dumps(data, use_bin_type=True))
ms_decode(data)
msgpack.dumps(data, use_bin_type=True)
)
async def recv(self) -> Any:
@ -132,27 +122,95 @@ class MsgpackTCPStream:
return self.stream.socket.fileno() != -1
class MsgspecStream(MsgpackStream):
'''A ``trio.SocketStream`` delivering ``msgpack`` formatted data
using ``msgspec``.
'''
ms_encode = msgspec.Encoder().encode
def __init__(
self,
stream: trio.SocketStream,
prefix_size: int = 4,
) -> None:
super().__init__(stream)
self.recv_stream = BufferedReceiveStream(transport_stream=stream)
self.prefix_size = prefix_size
async def _iter_packets(self) -> typing.AsyncGenerator[dict, None]:
"""Yield packets from the underlying stream.
"""
decoder = msgspec.Decoder() # dict[str, Any])
while True:
try:
header = await self.recv_stream.receive_exactly(4)
if header is None:
continue
if header == b'':
log.debug(f"Stream connection {self.raddr} was closed")
return
size, = struct.unpack("<I", header)
log.trace(f'received header {size}')
msg_bytes = await self.recv_stream.receive_exactly(size)
# the value error here is to catch a connect with immediate
# disconnect that will cause an EOF error inside `tricycle`.
except (ValueError, trio.BrokenResourceError):
log.warning(f"Stream connection {self.raddr} broke")
return
log.trace(f"received {msg_bytes}") # type: ignore
yield decoder.decode(msg_bytes)
async def send(self, data: Any) -> None:
async with self._send_lock:
bytes_data = self.ms_encode(data)
# supposedly the fastest says,
# https://stackoverflow.com/a/54027962
size: int = struct.pack("<I", len(bytes_data))
return await self.stream.send_all(size + bytes_data)
class Channel:
"""An inter-process channel for communication between (remote) actors.
Currently the only supported transport is a ``trio.SocketStream``.
"""
def __init__(
self,
destaddr: Optional[Tuple[str, int]] = None,
on_reconnect: typing.Callable[..., typing.Awaitable] = None,
auto_reconnect: bool = False,
stream: trio.SocketStream = None, # expected to be active
# stream_serializer: type = MsgpackStream,
stream_serializer_type: type = MsgspecStream,
) -> None:
self._recon_seq = on_reconnect
self._autorecon = auto_reconnect
self.msgstream: Optional[MsgpackTCPStream] = MsgpackTCPStream(
self.stream_serializer_type = stream_serializer_type
self.msgstream: Optional[type] = stream_serializer_type(
stream) if stream else None
if self.msgstream and destaddr:
raise ValueError(
f"A stream was provided with local addr {self.laddr}"
)
self._destaddr = self.msgstream.raddr if self.msgstream else destaddr
# set after handshake - always uid of far end
self.uid: Optional[Tuple[str, str]] = None
@ -195,7 +253,7 @@ class Channel:
*destaddr,
**kwargs
)
self.msgstream = MsgpackTCPStream(stream)
self.msgstream = self.stream_serializer_type(stream)
log.transport(
f'Opened channel to peer {self.laddr} -> {self.raddr}'