forked from goodboy/tractor
Add streaming decode support for `msgspec`
Add a `tractor._ipc.MsgspecStream` type which can be swapped in for `msgspec` serialization transparently. A small msg-length-prefix framing is implemented as part of the type and we use `tricycle.BufferedReceiveStream` to handle buffering logic for the underlying transport. Notes: - had to force cast a few more list -> tuple spots due to no native `tuple` decode-by-default in `msgspec`: https://github.com/jcrist/msgspec/issues/30 - the framing can be understood by this protobuf walkthrough: https://eli.thegreenplace.net/2011/08/02/length-prefix-framing-for-protocol-buffers - `tricycle` becomes a new dependency. Branch: optional_msgspec_support
parent
e39ee3a9cc
commit
95e35f3d60
|
@ -3,10 +3,11 @@ Inter-process comms abstractions
|
||||||
|
|
||||||
"""
|
"""
|
||||||
import platform
|
import platform
|
||||||
|
import struct
|
||||||
import typing
|
import typing
|
||||||
from typing import Any, Tuple, Optional, Callable
|
from typing import Any, Tuple, Optional
|
||||||
from functools import partial
|
|
||||||
|
|
||||||
|
from tricycle import BufferedReceiveStream
|
||||||
import msgpack
|
import msgpack
|
||||||
import msgspec
|
import msgspec
|
||||||
import trio
|
import trio
|
||||||
|
@ -18,20 +19,11 @@ log = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
_is_windows = platform.system() == 'Windows'
|
_is_windows = platform.system() == 'Windows'
|
||||||
|
log = get_logger(__name__)
|
||||||
# :eyeroll:
|
|
||||||
try:
|
|
||||||
import msgpack_numpy
|
|
||||||
Unpacker = msgpack_numpy.Unpacker
|
|
||||||
except ImportError:
|
|
||||||
# just plain ``msgpack`` requires tweaking key settings
|
|
||||||
Unpacker = partial(msgpack.Unpacker, strict_map_key=False)
|
|
||||||
|
|
||||||
|
|
||||||
ms_decode = msgspec.Encoder().encode
|
ms_decode = msgspec.Encoder().encode
|
||||||
|
|
||||||
|
|
||||||
class MsgpackTCPStream:
|
class MsgpackStream:
|
||||||
'''A ``trio.SocketStream`` delivering ``msgpack`` formatted data
|
'''A ``trio.SocketStream`` delivering ``msgpack`` formatted data
|
||||||
using ``msgpack-python``.
|
using ``msgpack-python``.
|
||||||
|
|
||||||
|
@ -39,6 +31,7 @@ class MsgpackTCPStream:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
stream: trio.SocketStream,
|
stream: trio.SocketStream,
|
||||||
|
|
||||||
) -> None:
|
) -> None:
|
||||||
|
|
||||||
self.stream = stream
|
self.stream = stream
|
||||||
|
@ -62,11 +55,10 @@ class MsgpackTCPStream:
|
||||||
async def _iter_packets(self) -> typing.AsyncGenerator[dict, None]:
|
async def _iter_packets(self) -> typing.AsyncGenerator[dict, None]:
|
||||||
"""Yield packets from the underlying stream.
|
"""Yield packets from the underlying stream.
|
||||||
"""
|
"""
|
||||||
unpacker = Unpacker(
|
unpacker = msgpack.Unpacker(
|
||||||
raw=False,
|
raw=False,
|
||||||
use_list=False,
|
use_list=False,
|
||||||
)
|
)
|
||||||
# decoder = msgspec.Decoder() #dict[str, Any])
|
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
data = await self.stream.receive_some(2**10)
|
data = await self.stream.receive_some(2**10)
|
||||||
|
@ -101,7 +93,6 @@ class MsgpackTCPStream:
|
||||||
f'transport {self} was already closed prior ro read'
|
f'transport {self} was already closed prior ro read'
|
||||||
)
|
)
|
||||||
|
|
||||||
# yield decoder.decode(data)
|
|
||||||
unpacker.feed(data)
|
unpacker.feed(data)
|
||||||
for packet in unpacker:
|
for packet in unpacker:
|
||||||
yield packet
|
yield packet
|
||||||
|
@ -118,8 +109,7 @@ class MsgpackTCPStream:
|
||||||
async def send(self, data: Any) -> None:
|
async def send(self, data: Any) -> None:
|
||||||
async with self._send_lock:
|
async with self._send_lock:
|
||||||
return await self.stream.send_all(
|
return await self.stream.send_all(
|
||||||
# msgpack.dumps(data, use_bin_type=True))
|
msgpack.dumps(data, use_bin_type=True)
|
||||||
ms_decode(data)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
async def recv(self) -> Any:
|
async def recv(self) -> Any:
|
||||||
|
@ -132,27 +122,95 @@ class MsgpackTCPStream:
|
||||||
return self.stream.socket.fileno() != -1
|
return self.stream.socket.fileno() != -1
|
||||||
|
|
||||||
|
|
||||||
|
class MsgspecStream(MsgpackStream):
    '''A ``trio.SocketStream`` delivering ``msgpack`` formatted data
    using ``msgspec``.

    Every message on the wire is framed as an unsigned little-endian
    length prefix of ``prefix_size`` bytes followed by exactly that many
    bytes of ``msgspec``-encoded payload.

    '''
    # one encoder shared by every instance; ``encode`` returns ``bytes``
    ms_encode = msgspec.Encoder().encode

    def __init__(
        self,
        stream: trio.SocketStream,
        prefix_size: int = 4,

    ) -> None:
        '''Wrap ``stream`` for length-prefixed ``msgspec`` messaging.

        ``prefix_size`` is the width in bytes of the length header used
        for both sending and receiving; both ends of a connection must
        agree on it. The default of 4 matches a ``"<I"`` struct header.

        '''
        super().__init__(stream)
        # buffering wrapper so an exact number of bytes can be awaited
        self.recv_stream = BufferedReceiveStream(transport_stream=stream)
        self.prefix_size = prefix_size

    async def _iter_packets(self) -> typing.AsyncGenerator[dict, None]:
        """Yield decoded packets from the underlying stream.

        Reads a ``prefix_size``-byte header, interprets it as the
        little-endian payload length, reads that many bytes and yields
        the decoded message. The generator returns (ends) when the
        connection is closed or breaks.

        """
        decoder = msgspec.Decoder()  # dict[str, Any])

        while True:
            try:
                # honour the configured header width instead of a
                # hard-coded 4 so ``prefix_size`` actually takes effect
                header = await self.recv_stream.receive_exactly(
                    self.prefix_size
                )
                if header is None:
                    continue

                if header == b'':
                    log.debug(f"Stream connection {self.raddr} was closed")
                    return

                # unsigned little-endian; identical to ``"<I"`` when the
                # prefix is 4 bytes wide
                size = int.from_bytes(header, 'little')

                log.trace(f'received header {size}')

                msg_bytes = await self.recv_stream.receive_exactly(size)

            # the value error here is to catch a connect with immediate
            # disconnect that will cause an EOF error inside `tricycle`.
            except (ValueError, trio.BrokenResourceError):
                log.warning(f"Stream connection {self.raddr} broke")
                return

            log.trace(f"received {msg_bytes}")  # type: ignore
            yield decoder.decode(msg_bytes)

    async def send(self, data: Any) -> None:
        '''Encode ``data`` and send it as one length-prefixed frame.

        Raises ``OverflowError`` if the encoded payload is too large for
        its length to fit in ``prefix_size`` bytes.

        '''
        async with self._send_lock:

            bytes_data = self.ms_encode(data)

            # length header: unsigned little-endian, ``prefix_size`` wide
            size: bytes = len(bytes_data).to_bytes(
                self.prefix_size, 'little'
            )

            # header and payload go out in a single ``send_all`` call
            return await self.stream.send_all(size + bytes_data)
|
||||||
|
|
||||||
|
|
||||||
class Channel:
|
class Channel:
|
||||||
"""An inter-process channel for communication between (remote) actors.
|
"""An inter-process channel for communication between (remote) actors.
|
||||||
|
|
||||||
Currently the only supported transport is a ``trio.SocketStream``.
|
Currently the only supported transport is a ``trio.SocketStream``.
|
||||||
"""
|
"""
|
||||||
def __init__(
|
def __init__(
|
||||||
|
|
||||||
self,
|
self,
|
||||||
destaddr: Optional[Tuple[str, int]] = None,
|
destaddr: Optional[Tuple[str, int]] = None,
|
||||||
on_reconnect: typing.Callable[..., typing.Awaitable] = None,
|
on_reconnect: typing.Callable[..., typing.Awaitable] = None,
|
||||||
auto_reconnect: bool = False,
|
auto_reconnect: bool = False,
|
||||||
stream: trio.SocketStream = None, # expected to be active
|
stream: trio.SocketStream = None, # expected to be active
|
||||||
|
|
||||||
|
# stream_serializer: type = MsgpackStream,
|
||||||
|
stream_serializer_type: type = MsgspecStream,
|
||||||
|
|
||||||
) -> None:
|
) -> None:
|
||||||
|
|
||||||
self._recon_seq = on_reconnect
|
self._recon_seq = on_reconnect
|
||||||
self._autorecon = auto_reconnect
|
self._autorecon = auto_reconnect
|
||||||
self.msgstream: Optional[MsgpackTCPStream] = MsgpackTCPStream(
|
|
||||||
|
self.stream_serializer_type = stream_serializer_type
|
||||||
|
self.msgstream: Optional[type] = stream_serializer_type(
|
||||||
stream) if stream else None
|
stream) if stream else None
|
||||||
|
|
||||||
if self.msgstream and destaddr:
|
if self.msgstream and destaddr:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"A stream was provided with local addr {self.laddr}"
|
f"A stream was provided with local addr {self.laddr}"
|
||||||
)
|
)
|
||||||
|
|
||||||
self._destaddr = self.msgstream.raddr if self.msgstream else destaddr
|
self._destaddr = self.msgstream.raddr if self.msgstream else destaddr
|
||||||
# set after handshake - always uid of far end
|
# set after handshake - always uid of far end
|
||||||
self.uid: Optional[Tuple[str, str]] = None
|
self.uid: Optional[Tuple[str, str]] = None
|
||||||
|
@ -195,7 +253,7 @@ class Channel:
|
||||||
*destaddr,
|
*destaddr,
|
||||||
**kwargs
|
**kwargs
|
||||||
)
|
)
|
||||||
self.msgstream = MsgpackTCPStream(stream)
|
self.msgstream = self.stream_serializer_type(stream)
|
||||||
|
|
||||||
log.transport(
|
log.transport(
|
||||||
f'Opened channel to peer {self.laddr} -> {self.raddr}'
|
f'Opened channel to peer {self.laddr} -> {self.raddr}'
|
||||||
|
|
Loading…
Reference in New Issue