Add a stream type factory

optional_msgspec_support
Tyler Goodlet 2021-09-07 21:07:33 -04:00
parent 5b23a3bc35
commit 07e8821cd5
1 changed files with 23 additions and 12 deletions

View File

@ -5,7 +5,7 @@ Inter-process comms abstractions
import platform
import struct
import typing
from typing import Any, Tuple, Optional
from typing import Any, Tuple, Optional, Type
from tricycle import BufferedReceiveStream
import msgpack
@ -55,6 +55,7 @@ class MsgpackTCPStream:
unpacker = msgpack.Unpacker(
raw=False,
use_list=False,
strict_map_key=False
)
while True:
try:
@ -130,12 +131,12 @@ class MsgspecTCPStream(MsgpackTCPStream):
prefix_size: int = 4,
) -> None:
import msgspec
super().__init__(stream)
self.recv_stream = BufferedReceiveStream(transport_stream=stream)
self.prefix_size = prefix_size
import msgspec
# TODO: struct aware messaging coders
self.encode = msgspec.Encoder().encode
self.decode = msgspec.Decoder().decode # dict[str, Any])
@ -185,7 +186,7 @@ class MsgspecTCPStream(MsgpackTCPStream):
# ignore decoding errors for now and assume they have to
# do with a channel drop - hope that receiving from the
# channel will raise an expected error and bubble up.
log.error(f'`msgspec` failed to decode!?')
log.error('`msgspec` failed to decode!?')
last_decode_failed = True
async def send(self, data: Any) -> None:
@ -200,11 +201,21 @@ class MsgspecTCPStream(MsgpackTCPStream):
return await self.stream.send_all(size + bytes_data)
def get_serializer_stream_type(
name: str,
) -> Type:
return {
'msgpack': MsgpackTCPStream,
'msgspec': MsgspecTCPStream,
}[name]
class Channel:
"""An inter-process channel for communication between (remote) actors.
'''An inter-process channel for communication between (remote) actors.
Currently the only supported transport is a ``trio.SocketStream``.
"""
'''
def __init__(
self,
@ -218,17 +229,17 @@ class Channel:
self._recon_seq = on_reconnect
self._autorecon = auto_reconnect
stream_serializer_type = MsgpackTCPStream
# TODO: maybe expose this through the nursery api?
try:
# if installed load the msgspec transport since it's faster
import msgspec # noqa
stream_serializer_type = MsgspecTCPStream
serializer = 'msgspec'
except ImportError:
pass
serializer = 'msgpack'
self.stream_serializer_type = stream_serializer_type
self.msgstream = stream_serializer_type(stream) if stream else None
self.stream_serializer_type = get_serializer_stream_type(serializer)
self.msgstream = self.stream_serializer_type(
stream) if stream else None
if self.msgstream and destaddr:
raise ValueError(