forked from goodboy/tractor
Add a stream type factory
parent
5b23a3bc35
commit
07e8821cd5
|
@ -5,7 +5,7 @@ Inter-process comms abstractions
|
||||||
import platform
|
import platform
|
||||||
import struct
|
import struct
|
||||||
import typing
|
import typing
|
||||||
from typing import Any, Tuple, Optional
|
from typing import Any, Tuple, Optional, Type
|
||||||
|
|
||||||
from tricycle import BufferedReceiveStream
|
from tricycle import BufferedReceiveStream
|
||||||
import msgpack
|
import msgpack
|
||||||
|
@ -55,6 +55,7 @@ class MsgpackTCPStream:
|
||||||
unpacker = msgpack.Unpacker(
|
unpacker = msgpack.Unpacker(
|
||||||
raw=False,
|
raw=False,
|
||||||
use_list=False,
|
use_list=False,
|
||||||
|
strict_map_key=False
|
||||||
)
|
)
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
|
@ -130,12 +131,12 @@ class MsgspecTCPStream(MsgpackTCPStream):
|
||||||
prefix_size: int = 4,
|
prefix_size: int = 4,
|
||||||
|
|
||||||
) -> None:
|
) -> None:
|
||||||
|
import msgspec
|
||||||
|
|
||||||
super().__init__(stream)
|
super().__init__(stream)
|
||||||
self.recv_stream = BufferedReceiveStream(transport_stream=stream)
|
self.recv_stream = BufferedReceiveStream(transport_stream=stream)
|
||||||
self.prefix_size = prefix_size
|
self.prefix_size = prefix_size
|
||||||
|
|
||||||
import msgspec
|
|
||||||
|
|
||||||
# TODO: struct aware messaging coders
|
# TODO: struct aware messaging coders
|
||||||
self.encode = msgspec.Encoder().encode
|
self.encode = msgspec.Encoder().encode
|
||||||
self.decode = msgspec.Decoder().decode # dict[str, Any])
|
self.decode = msgspec.Decoder().decode # dict[str, Any])
|
||||||
|
@ -185,7 +186,7 @@ class MsgspecTCPStream(MsgpackTCPStream):
|
||||||
# ignore decoding errors for now and assume they have to
|
# ignore decoding errors for now and assume they have to
|
||||||
# do with a channel drop - hope that receiving from the
|
# do with a channel drop - hope that receiving from the
|
||||||
# channel will raise an expected error and bubble up.
|
# channel will raise an expected error and bubble up.
|
||||||
log.error(f'`msgspec` failed to decode!?')
|
log.error('`msgspec` failed to decode!?')
|
||||||
last_decode_failed = True
|
last_decode_failed = True
|
||||||
|
|
||||||
async def send(self, data: Any) -> None:
|
async def send(self, data: Any) -> None:
|
||||||
|
@ -200,11 +201,21 @@ class MsgspecTCPStream(MsgpackTCPStream):
|
||||||
return await self.stream.send_all(size + bytes_data)
|
return await self.stream.send_all(size + bytes_data)
|
||||||
|
|
||||||
|
|
||||||
|
def get_serializer_stream_type(
|
||||||
|
name: str,
|
||||||
|
) -> Type:
|
||||||
|
return {
|
||||||
|
'msgpack': MsgpackTCPStream,
|
||||||
|
'msgspec': MsgspecTCPStream,
|
||||||
|
}[name]
|
||||||
|
|
||||||
|
|
||||||
class Channel:
|
class Channel:
|
||||||
"""An inter-process channel for communication between (remote) actors.
|
'''An inter-process channel for communication between (remote) actors.
|
||||||
|
|
||||||
Currently the only supported transport is a ``trio.SocketStream``.
|
Currently the only supported transport is a ``trio.SocketStream``.
|
||||||
"""
|
|
||||||
|
'''
|
||||||
def __init__(
|
def __init__(
|
||||||
|
|
||||||
self,
|
self,
|
||||||
|
@ -218,17 +229,17 @@ class Channel:
|
||||||
self._recon_seq = on_reconnect
|
self._recon_seq = on_reconnect
|
||||||
self._autorecon = auto_reconnect
|
self._autorecon = auto_reconnect
|
||||||
|
|
||||||
stream_serializer_type = MsgpackTCPStream
|
# TODO: maybe expose this through the nursery api?
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# if installed load the msgspec transport since it's faster
|
# if installed load the msgspec transport since it's faster
|
||||||
import msgspec # noqa
|
import msgspec # noqa
|
||||||
stream_serializer_type = MsgspecTCPStream
|
serializer = 'msgspec'
|
||||||
except ImportError:
|
except ImportError:
|
||||||
pass
|
serializer = 'msgpack'
|
||||||
|
|
||||||
self.stream_serializer_type = stream_serializer_type
|
self.stream_serializer_type = get_serializer_stream_type(serializer)
|
||||||
self.msgstream = stream_serializer_type(stream) if stream else None
|
self.msgstream = self.stream_serializer_type(
|
||||||
|
stream) if stream else None
|
||||||
|
|
||||||
if self.msgstream and destaddr:
|
if self.msgstream and destaddr:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
|
Loading…
Reference in New Issue