Add a stream type factory
parent
6cf4a80fe4
commit
0d41f1410f
|
@ -5,7 +5,7 @@ Inter-process comms abstractions
|
|||
import platform
|
||||
import struct
|
||||
import typing
|
||||
from typing import Any, Tuple, Optional
|
||||
from typing import Any, Tuple, Optional, Type
|
||||
|
||||
from tricycle import BufferedReceiveStream
|
||||
import msgpack
|
||||
|
@ -55,6 +55,7 @@ class MsgpackTCPStream:
|
|||
unpacker = msgpack.Unpacker(
|
||||
raw=False,
|
||||
use_list=False,
|
||||
strict_map_key=False
|
||||
)
|
||||
while True:
|
||||
try:
|
||||
|
@ -130,12 +131,12 @@ class MsgspecTCPStream(MsgpackTCPStream):
|
|||
prefix_size: int = 4,
|
||||
|
||||
) -> None:
|
||||
import msgspec
|
||||
|
||||
super().__init__(stream)
|
||||
self.recv_stream = BufferedReceiveStream(transport_stream=stream)
|
||||
self.prefix_size = prefix_size
|
||||
|
||||
import msgspec
|
||||
|
||||
# TODO: struct aware messaging coders
|
||||
self.encode = msgspec.Encoder().encode
|
||||
self.decode = msgspec.Decoder().decode # dict[str, Any])
|
||||
|
@ -185,7 +186,7 @@ class MsgspecTCPStream(MsgpackTCPStream):
|
|||
# ignore decoding errors for now and assume they have to
|
||||
# do with a channel drop - hope that receiving from the
|
||||
# channel will raise an expected error and bubble up.
|
||||
log.error(f'`msgspec` failed to decode!?')
|
||||
log.error('`msgspec` failed to decode!?')
|
||||
last_decode_failed = True
|
||||
|
||||
async def send(self, data: Any) -> None:
|
||||
|
@ -200,11 +201,21 @@ class MsgspecTCPStream(MsgpackTCPStream):
|
|||
return await self.stream.send_all(size + bytes_data)
|
||||
|
||||
|
||||
def get_serializer_stream_type(
|
||||
name: str,
|
||||
) -> Type:
|
||||
return {
|
||||
'msgpack': MsgpackTCPStream,
|
||||
'msgspec': MsgspecTCPStream,
|
||||
}[name]
|
||||
|
||||
|
||||
class Channel:
|
||||
"""An inter-process channel for communication between (remote) actors.
|
||||
'''An inter-process channel for communication between (remote) actors.
|
||||
|
||||
Currently the only supported transport is a ``trio.SocketStream``.
|
||||
"""
|
||||
|
||||
'''
|
||||
def __init__(
|
||||
|
||||
self,
|
||||
|
@ -218,17 +229,17 @@ class Channel:
|
|||
self._recon_seq = on_reconnect
|
||||
self._autorecon = auto_reconnect
|
||||
|
||||
stream_serializer_type = MsgpackTCPStream
|
||||
|
||||
# TODO: maybe expose this through the nursery api?
|
||||
try:
|
||||
# if installed load the msgspec transport since it's faster
|
||||
import msgspec # noqa
|
||||
stream_serializer_type = MsgspecTCPStream
|
||||
serializer = 'msgspec'
|
||||
except ImportError:
|
||||
pass
|
||||
serializer = 'msgpack'
|
||||
|
||||
self.stream_serializer_type = stream_serializer_type
|
||||
self.msgstream = stream_serializer_type(stream) if stream else None
|
||||
self.stream_serializer_type = get_serializer_stream_type(serializer)
|
||||
self.msgstream = self.stream_serializer_type(
|
||||
stream) if stream else None
|
||||
|
||||
if self.msgstream and destaddr:
|
||||
raise ValueError(
|
||||
|
|
Loading…
Reference in New Issue