forked from goodboy/tractor
Pkg `msgpec` as optional dep, load transport type if importable
parent
96b3f94c72
commit
b64396f708
9
setup.py
9
setup.py
|
@ -48,14 +48,21 @@ setup(
|
|||
'tricycle',
|
||||
'trio_typing',
|
||||
|
||||
# tooling
|
||||
'colorlog',
|
||||
'wrapt',
|
||||
'pdbpp',
|
||||
|
||||
# serialization
|
||||
'msgpack',
|
||||
'msgspec',
|
||||
|
||||
],
|
||||
extras_require={
|
||||
|
||||
# serialization
|
||||
'msgspec': ['msgspec; python_version >= 3.9'],
|
||||
|
||||
},
|
||||
tests_require=['pytest'],
|
||||
python_requires=">=3.8",
|
||||
keywords=[
|
||||
|
|
|
@ -9,7 +9,6 @@ from typing import Any, Tuple, Optional
|
|||
|
||||
from tricycle import BufferedReceiveStream
|
||||
import msgpack
|
||||
import msgspec
|
||||
import trio
|
||||
from async_generator import asynccontextmanager
|
||||
|
||||
|
@ -20,7 +19,6 @@ log = get_logger(__name__)
|
|||
|
||||
_is_windows = platform.system() == 'Windows'
|
||||
log = get_logger(__name__)
|
||||
ms_decode = msgspec.Encoder().encode
|
||||
|
||||
|
||||
class MsgpackTCPStream:
|
||||
|
@ -126,8 +124,6 @@ class MsgspecTCPStream(MsgpackTCPStream):
|
|||
using ``msgspec``.
|
||||
|
||||
'''
|
||||
ms_encode = msgspec.Encoder().encode
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
stream: trio.SocketStream,
|
||||
|
@ -138,10 +134,15 @@ class MsgspecTCPStream(MsgpackTCPStream):
|
|||
self.recv_stream = BufferedReceiveStream(transport_stream=stream)
|
||||
self.prefix_size = prefix_size
|
||||
|
||||
import msgspec
|
||||
|
||||
# TODO: struct aware messaging coders
|
||||
self.encode = msgspec.Encoder().encode
|
||||
self.decode = msgspec.Decoder().decode # dict[str, Any])
|
||||
|
||||
async def _iter_packets(self) -> typing.AsyncGenerator[dict, None]:
|
||||
"""Yield packets from the underlying stream.
|
||||
"""
|
||||
decoder = msgspec.Decoder() # dict[str, Any])
|
||||
|
||||
while True:
|
||||
try:
|
||||
|
@ -164,12 +165,12 @@ class MsgspecTCPStream(MsgpackTCPStream):
|
|||
msg_bytes = await self.recv_stream.receive_exactly(size)
|
||||
|
||||
log.trace(f"received {msg_bytes}") # type: ignore
|
||||
yield decoder.decode(msg_bytes)
|
||||
yield self.decode(msg_bytes)
|
||||
|
||||
async def send(self, data: Any) -> None:
|
||||
async with self._send_lock:
|
||||
|
||||
bytes_data = self.ms_encode(data)
|
||||
bytes_data = self.encode(data)
|
||||
|
||||
# supposedly the fastest says,
|
||||
# https://stackoverflow.com/a/54027962
|
||||
|
@ -191,14 +192,19 @@ class Channel:
|
|||
auto_reconnect: bool = False,
|
||||
stream: trio.SocketStream = None, # expected to be active
|
||||
|
||||
# stream_serializer_type: type = MsgspecTCPStream,
|
||||
stream_serializer_type: type = MsgpackTCPStream,
|
||||
|
||||
) -> None:
|
||||
|
||||
self._recon_seq = on_reconnect
|
||||
self._autorecon = auto_reconnect
|
||||
|
||||
try:
|
||||
# if installed load the msgspec transport since it's faster
|
||||
import msgspec # noqa
|
||||
stream_serializer_type: type = MsgspecTCPStream
|
||||
|
||||
except ImportError:
|
||||
stream_serializer_type: type = MsgpackTCPStream
|
||||
|
||||
self.stream_serializer_type = stream_serializer_type
|
||||
self.msgstream: Optional[type] = stream_serializer_type(
|
||||
stream) if stream else None
|
||||
|
|
Loading…
Reference in New Issue