Add a stream type factory

optional_msgspec_support
Tyler Goodlet 2021-09-07 21:07:33 -04:00
parent 5b23a3bc35
commit 07e8821cd5
1 changed files with 23 additions and 12 deletions

View File

@ -5,7 +5,7 @@ Inter-process comms abstractions
import platform import platform
import struct import struct
import typing import typing
from typing import Any, Tuple, Optional from typing import Any, Tuple, Optional, Type
from tricycle import BufferedReceiveStream from tricycle import BufferedReceiveStream
import msgpack import msgpack
@ -55,6 +55,7 @@ class MsgpackTCPStream:
unpacker = msgpack.Unpacker( unpacker = msgpack.Unpacker(
raw=False, raw=False,
use_list=False, use_list=False,
strict_map_key=False
) )
while True: while True:
try: try:
@ -130,12 +131,12 @@ class MsgspecTCPStream(MsgpackTCPStream):
prefix_size: int = 4, prefix_size: int = 4,
) -> None: ) -> None:
import msgspec
super().__init__(stream) super().__init__(stream)
self.recv_stream = BufferedReceiveStream(transport_stream=stream) self.recv_stream = BufferedReceiveStream(transport_stream=stream)
self.prefix_size = prefix_size self.prefix_size = prefix_size
import msgspec
# TODO: struct aware messaging coders # TODO: struct aware messaging coders
self.encode = msgspec.Encoder().encode self.encode = msgspec.Encoder().encode
self.decode = msgspec.Decoder().decode # dict[str, Any]) self.decode = msgspec.Decoder().decode # dict[str, Any])
@ -185,7 +186,7 @@ class MsgspecTCPStream(MsgpackTCPStream):
# ignore decoding errors for now and assume they have to # ignore decoding errors for now and assume they have to
# do with a channel drop - hope that receiving from the # do with a channel drop - hope that receiving from the
# channel will raise an expected error and bubble up. # channel will raise an expected error and bubble up.
log.error(f'`msgspec` failed to decode!?') log.error('`msgspec` failed to decode!?')
last_decode_failed = True last_decode_failed = True
async def send(self, data: Any) -> None: async def send(self, data: Any) -> None:
@ -200,11 +201,21 @@ class MsgspecTCPStream(MsgpackTCPStream):
return await self.stream.send_all(size + bytes_data) return await self.stream.send_all(size + bytes_data)
def get_serializer_stream_type(
name: str,
) -> Type:
return {
'msgpack': MsgpackTCPStream,
'msgspec': MsgspecTCPStream,
}[name]
class Channel: class Channel:
"""An inter-process channel for communication between (remote) actors. '''An inter-process channel for communication between (remote) actors.
Currently the only supported transport is a ``trio.SocketStream``. Currently the only supported transport is a ``trio.SocketStream``.
"""
'''
def __init__( def __init__(
self, self,
@ -218,17 +229,17 @@ class Channel:
self._recon_seq = on_reconnect self._recon_seq = on_reconnect
self._autorecon = auto_reconnect self._autorecon = auto_reconnect
stream_serializer_type = MsgpackTCPStream # TODO: maybe expose this through the nursery api?
try: try:
# if installed load the msgspec transport since it's faster # if installed load the msgspec transport since it's faster
import msgspec # noqa import msgspec # noqa
stream_serializer_type = MsgspecTCPStream serializer = 'msgspec'
except ImportError: except ImportError:
pass serializer = 'msgpack'
self.stream_serializer_type = stream_serializer_type self.stream_serializer_type = get_serializer_stream_type(serializer)
self.msgstream = stream_serializer_type(stream) if stream else None self.msgstream = self.stream_serializer_type(
stream) if stream else None
if self.msgstream and destaddr: if self.msgstream and destaddr:
raise ValueError( raise ValueError(