Add streaming decode support for `msgspec`
Add a `tractor._ipc.MsgspecStream` type which can be swapped in for `msgspec` serialization transparently. A small msg-length-prefix framing is implemented as part of the type and we use `tricycle.BufferedReceieveStream` to handle buffering logic for the underlying transport. Notes: - had to force cast a few more list -> tuple spots due to no native `tuple`decode-by-default in `msgspec`: https://github.com/jcrist/msgspec/issues/30 - the framing can be understood by this protobuf walkthrough: https://eli.thegreenplace.net/2011/08/02/length-prefix-framing-for-protocol-buffers - `tricycle` becomes a new dependencymsgspec_not_fucked
							parent
							
								
									d89e632a16
								
							
						
					
					
						commit
						39453e43e0
					
				|  | @ -1,17 +1,19 @@ | ||||||
| """ | """ | ||||||
| Inter-process comms abstractions | Inter-process comms abstractions | ||||||
| """ | """ | ||||||
| import typing |  | ||||||
| from typing import Any, Tuple, Optional, Callable |  | ||||||
| from functools import partial | from functools import partial | ||||||
|  | import struct | ||||||
|  | import typing | ||||||
|  | from typing import Any, Tuple, Optional | ||||||
| 
 | 
 | ||||||
|  | from tricycle import BufferedReceiveStream | ||||||
| import msgpack | import msgpack | ||||||
| import msgspec | import msgspec | ||||||
| import trio | import trio | ||||||
| from async_generator import asynccontextmanager | from async_generator import asynccontextmanager | ||||||
| 
 | 
 | ||||||
| from .log import get_logger | from .log import get_logger | ||||||
| log = get_logger('ipc') | log = get_logger(__name__) | ||||||
| 
 | 
 | ||||||
| # :eyeroll: | # :eyeroll: | ||||||
| try: | try: | ||||||
|  | @ -22,21 +24,14 @@ except ImportError: | ||||||
|     Unpacker = partial(msgpack.Unpacker, strict_map_key=False) |     Unpacker = partial(msgpack.Unpacker, strict_map_key=False) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| ms_decode = msgspec.Encoder().encode |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| class MsgpackStream: | class MsgpackStream: | ||||||
|     """A ``trio.SocketStream`` delivering ``msgpack`` formatted data. |     '''A ``trio.SocketStream`` delivering ``msgpack`` formatted data | ||||||
|  |     using ``msgpack-python``. | ||||||
| 
 | 
 | ||||||
|     """ |     ''' | ||||||
|     def __init__( |     def __init__( | ||||||
|         self, |         self, | ||||||
|         stream: trio.SocketStream, |         stream: trio.SocketStream, | ||||||
|         serialize: Callable = Unpacker( |  | ||||||
|             raw=False, |  | ||||||
|             use_list=False, |  | ||||||
|         ).feed, |  | ||||||
|         deserialize: Callable = msgpack.dumps, |  | ||||||
| 
 | 
 | ||||||
|     ) -> None: |     ) -> None: | ||||||
| 
 | 
 | ||||||
|  | @ -62,7 +57,6 @@ class MsgpackStream: | ||||||
|             raw=False, |             raw=False, | ||||||
|             use_list=False, |             use_list=False, | ||||||
|         ) |         ) | ||||||
|         # decoder = msgspec.Decoder()  #dict[str, Any]) |  | ||||||
|         while True: |         while True: | ||||||
|             try: |             try: | ||||||
|                 data = await self.stream.receive_some(2**10) |                 data = await self.stream.receive_some(2**10) | ||||||
|  | @ -75,7 +69,6 @@ class MsgpackStream: | ||||||
|                 log.debug(f"Stream connection {self.raddr} was closed") |                 log.debug(f"Stream connection {self.raddr} was closed") | ||||||
|                 return |                 return | ||||||
| 
 | 
 | ||||||
|             # yield decoder.decode(data) |  | ||||||
|             unpacker.feed(data) |             unpacker.feed(data) | ||||||
|             for packet in unpacker: |             for packet in unpacker: | ||||||
|                 yield packet |                 yield packet | ||||||
|  | @ -92,8 +85,7 @@ class MsgpackStream: | ||||||
|     async def send(self, data: Any) -> None: |     async def send(self, data: Any) -> None: | ||||||
|         async with self._send_lock: |         async with self._send_lock: | ||||||
|             return await self.stream.send_all( |             return await self.stream.send_all( | ||||||
|                 # msgpack.dumps(data, use_bin_type=True)) |                 msgpack.dumps(data, use_bin_type=True) | ||||||
|                 ms_decode(data) |  | ||||||
|             ) |             ) | ||||||
| 
 | 
 | ||||||
|     async def recv(self) -> Any: |     async def recv(self) -> Any: | ||||||
|  | @ -106,26 +98,93 @@ class MsgpackStream: | ||||||
|         return self.stream.socket.fileno() != -1 |         return self.stream.socket.fileno() != -1 | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | class MsgspecStream(MsgpackStream): | ||||||
|  |     '''A ``trio.SocketStream`` delivering ``msgpack`` formatted data | ||||||
|  |     using ``msgspec``. | ||||||
|  | 
 | ||||||
|  |     ''' | ||||||
|  |     ms_encode = msgspec.Encoder().encode | ||||||
|  | 
 | ||||||
|  |     def __init__( | ||||||
|  |         self, | ||||||
|  |         stream: trio.SocketStream, | ||||||
|  |         prefix_size: int = 4, | ||||||
|  | 
 | ||||||
|  |     ) -> None: | ||||||
|  |         super().__init__(stream) | ||||||
|  |         self.recv_stream = BufferedReceiveStream(transport_stream=stream) | ||||||
|  |         self.prefix_size = prefix_size | ||||||
|  | 
 | ||||||
|  |     async def _iter_packets(self) -> typing.AsyncGenerator[dict, None]: | ||||||
|  |         """Yield packets from the underlying stream. | ||||||
|  |         """ | ||||||
|  |         decoder = msgspec.Decoder()  # dict[str, Any]) | ||||||
|  | 
 | ||||||
|  |         while True: | ||||||
|  |             try: | ||||||
|  |                 header = await self.recv_stream.receive_exactly(4) | ||||||
|  |                 if header is None: | ||||||
|  |                     continue | ||||||
|  | 
 | ||||||
|  |                 if header == b'': | ||||||
|  |                     log.debug(f"Stream connection {self.raddr} was closed") | ||||||
|  |                     return | ||||||
|  | 
 | ||||||
|  |                 size, = struct.unpack("<I", header) | ||||||
|  | 
 | ||||||
|  |                 log.trace(f'received header {size}') | ||||||
|  | 
 | ||||||
|  |                 msg_bytes = await self.recv_stream.receive_exactly(size) | ||||||
|  | 
 | ||||||
|  |             # the value error here is to catch a connect with immediate | ||||||
|  |             # disconnect that will cause an EOF error inside `tricycle`. | ||||||
|  |             except (ValueError, trio.BrokenResourceError): | ||||||
|  |                 log.warning(f"Stream connection {self.raddr} broke") | ||||||
|  |                 return | ||||||
|  | 
 | ||||||
|  |             log.trace(f"received {msg_bytes}")  # type: ignore | ||||||
|  |             yield decoder.decode(msg_bytes) | ||||||
|  | 
 | ||||||
|  |     async def send(self, data: Any) -> None: | ||||||
|  |         async with self._send_lock: | ||||||
|  | 
 | ||||||
|  |             bytes_data = self.ms_encode(data) | ||||||
|  | 
 | ||||||
|  |             # supposedly the fastest says, | ||||||
|  |             # https://stackoverflow.com/a/54027962 | ||||||
|  |             size: int = struct.pack("<I", len(bytes_data)) | ||||||
|  | 
 | ||||||
|  |             return await self.stream.send_all(size + bytes_data) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
| class Channel: | class Channel: | ||||||
|     """An inter-process channel for communication between (remote) actors. |     """An inter-process channel for communication between (remote) actors. | ||||||
| 
 | 
 | ||||||
|     Currently the only supported transport is a ``trio.SocketStream``. |     Currently the only supported transport is a ``trio.SocketStream``. | ||||||
|     """ |     """ | ||||||
|     def __init__( |     def __init__( | ||||||
|  | 
 | ||||||
|         self, |         self, | ||||||
|         destaddr: Optional[Tuple[str, int]] = None, |         destaddr: Optional[Tuple[str, int]] = None, | ||||||
|         on_reconnect: typing.Callable[..., typing.Awaitable] = None, |         on_reconnect: typing.Callable[..., typing.Awaitable] = None, | ||||||
|         auto_reconnect: bool = False, |         auto_reconnect: bool = False, | ||||||
|         stream: trio.SocketStream = None,  # expected to be active |         stream: trio.SocketStream = None,  # expected to be active | ||||||
|  |         # stream_serializer: type = MsgpackStream, | ||||||
|  |         stream_serializer_type: type = MsgspecStream, | ||||||
|  | 
 | ||||||
|     ) -> None: |     ) -> None: | ||||||
|  | 
 | ||||||
|         self._recon_seq = on_reconnect |         self._recon_seq = on_reconnect | ||||||
|         self._autorecon = auto_reconnect |         self._autorecon = auto_reconnect | ||||||
|         self.msgstream: Optional[MsgpackStream] = MsgpackStream( |         self.stream_serializer_type = stream_serializer_type | ||||||
|  |         self.msgstream: Optional[type] = stream_serializer_type( | ||||||
|             stream) if stream else None |             stream) if stream else None | ||||||
|  | 
 | ||||||
|         if self.msgstream and destaddr: |         if self.msgstream and destaddr: | ||||||
|             raise ValueError( |             raise ValueError( | ||||||
|                 f"A stream was provided with local addr {self.laddr}" |                 f"A stream was provided with local addr {self.laddr}" | ||||||
|             ) |             ) | ||||||
|  | 
 | ||||||
|         self._destaddr = self.msgstream.raddr if self.msgstream else destaddr |         self._destaddr = self.msgstream.raddr if self.msgstream else destaddr | ||||||
|         # set after handshake - always uid of far end |         # set after handshake - always uid of far end | ||||||
|         self.uid: Optional[Tuple[str, str]] = None |         self.uid: Optional[Tuple[str, str]] = None | ||||||
|  | @ -157,7 +216,7 @@ class Channel: | ||||||
|         destaddr = destaddr or self._destaddr |         destaddr = destaddr or self._destaddr | ||||||
|         assert isinstance(destaddr, tuple) |         assert isinstance(destaddr, tuple) | ||||||
|         stream = await trio.open_tcp_stream(*destaddr, **kwargs) |         stream = await trio.open_tcp_stream(*destaddr, **kwargs) | ||||||
|         self.msgstream = MsgpackStream(stream) |         self.msgstream = self.stream_serializer_type(stream) | ||||||
|         return stream |         return stream | ||||||
| 
 | 
 | ||||||
|     async def send(self, item: Any) -> None: |     async def send(self, item: Any) -> None: | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue