From 07e8821cd5d3b71413c4e3abb46ed37e1c714dfd Mon Sep 17 00:00:00 2001 From: Tyler Goodlet Date: Tue, 7 Sep 2021 21:07:33 -0400 Subject: [PATCH] Add a stream type factory --- tractor/_ipc.py | 35 +++++++++++++++++++++++------------ 1 file changed, 23 insertions(+), 12 deletions(-) diff --git a/tractor/_ipc.py b/tractor/_ipc.py index e420509..6f3fffa 100644 --- a/tractor/_ipc.py +++ b/tractor/_ipc.py @@ -5,7 +5,7 @@ Inter-process comms abstractions import platform import struct import typing -from typing import Any, Tuple, Optional +from typing import Any, Tuple, Optional, Type from tricycle import BufferedReceiveStream import msgpack @@ -55,6 +55,7 @@ class MsgpackTCPStream: unpacker = msgpack.Unpacker( raw=False, use_list=False, + strict_map_key=False ) while True: try: @@ -130,12 +131,12 @@ class MsgspecTCPStream(MsgpackTCPStream): prefix_size: int = 4, ) -> None: + import msgspec + super().__init__(stream) self.recv_stream = BufferedReceiveStream(transport_stream=stream) self.prefix_size = prefix_size - import msgspec - # TODO: struct aware messaging coders self.encode = msgspec.Encoder().encode self.decode = msgspec.Decoder().decode # dict[str, Any]) @@ -185,7 +186,7 @@ class MsgspecTCPStream(MsgpackTCPStream): # ignore decoding errors for now and assume they have to # do with a channel drop - hope that receiving from the # channel will raise an expected error and bubble up. - log.error(f'`msgspec` failed to decode!?') + log.error('`msgspec` failed to decode!?') last_decode_failed = True async def send(self, data: Any) -> None: @@ -200,11 +201,21 @@ class MsgspecTCPStream(MsgpackTCPStream): return await self.stream.send_all(size + bytes_data) +def get_serializer_stream_type( + name: str, +) -> Type: + return { + 'msgpack': MsgpackTCPStream, + 'msgspec': MsgspecTCPStream, + }[name] + + class Channel: - """An inter-process channel for communication between (remote) actors. + '''An inter-process channel for communication between (remote) actors. Currently the only supported transport is a ``trio.SocketStream``. - """ + + ''' def __init__( self, @@ -218,17 +229,17 @@ class Channel: self._recon_seq = on_reconnect self._autorecon = auto_reconnect - stream_serializer_type = MsgpackTCPStream - + # TODO: maybe expose this through the nursery api? try: # if installed load the msgspec transport since it's faster import msgspec # noqa - stream_serializer_type = MsgspecTCPStream + serializer = 'msgspec' except ImportError: - pass + serializer = 'msgpack' - self.stream_serializer_type = stream_serializer_type - self.msgstream = stream_serializer_type(stream) if stream else None + self.stream_serializer_type = get_serializer_stream_type(serializer) + self.msgstream = self.stream_serializer_type( + stream) if stream else None if self.msgstream and destaddr: raise ValueError(