Pkg `msgpec` as optional dep, load transport type if importable
parent
96b3f94c72
commit
b64396f708
9
setup.py
9
setup.py
|
@ -48,14 +48,21 @@ setup(
|
||||||
'tricycle',
|
'tricycle',
|
||||||
'trio_typing',
|
'trio_typing',
|
||||||
|
|
||||||
|
# tooling
|
||||||
'colorlog',
|
'colorlog',
|
||||||
'wrapt',
|
'wrapt',
|
||||||
'pdbpp',
|
'pdbpp',
|
||||||
|
|
||||||
# serialization
|
# serialization
|
||||||
'msgpack',
|
'msgpack',
|
||||||
'msgspec',
|
|
||||||
],
|
],
|
||||||
|
extras_require={
|
||||||
|
|
||||||
|
# serialization
|
||||||
|
'msgspec': ['msgspec; python_version >= 3.9'],
|
||||||
|
|
||||||
|
},
|
||||||
tests_require=['pytest'],
|
tests_require=['pytest'],
|
||||||
python_requires=">=3.8",
|
python_requires=">=3.8",
|
||||||
keywords=[
|
keywords=[
|
||||||
|
|
|
@ -9,7 +9,6 @@ from typing import Any, Tuple, Optional
|
||||||
|
|
||||||
from tricycle import BufferedReceiveStream
|
from tricycle import BufferedReceiveStream
|
||||||
import msgpack
|
import msgpack
|
||||||
import msgspec
|
|
||||||
import trio
|
import trio
|
||||||
from async_generator import asynccontextmanager
|
from async_generator import asynccontextmanager
|
||||||
|
|
||||||
|
@ -20,7 +19,6 @@ log = get_logger(__name__)
|
||||||
|
|
||||||
_is_windows = platform.system() == 'Windows'
|
_is_windows = platform.system() == 'Windows'
|
||||||
log = get_logger(__name__)
|
log = get_logger(__name__)
|
||||||
ms_decode = msgspec.Encoder().encode
|
|
||||||
|
|
||||||
|
|
||||||
class MsgpackTCPStream:
|
class MsgpackTCPStream:
|
||||||
|
@ -126,8 +124,6 @@ class MsgspecTCPStream(MsgpackTCPStream):
|
||||||
using ``msgspec``.
|
using ``msgspec``.
|
||||||
|
|
||||||
'''
|
'''
|
||||||
ms_encode = msgspec.Encoder().encode
|
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
stream: trio.SocketStream,
|
stream: trio.SocketStream,
|
||||||
|
@ -138,10 +134,15 @@ class MsgspecTCPStream(MsgpackTCPStream):
|
||||||
self.recv_stream = BufferedReceiveStream(transport_stream=stream)
|
self.recv_stream = BufferedReceiveStream(transport_stream=stream)
|
||||||
self.prefix_size = prefix_size
|
self.prefix_size = prefix_size
|
||||||
|
|
||||||
|
import msgspec
|
||||||
|
|
||||||
|
# TODO: struct aware messaging coders
|
||||||
|
self.encode = msgspec.Encoder().encode
|
||||||
|
self.decode = msgspec.Decoder().decode # dict[str, Any])
|
||||||
|
|
||||||
async def _iter_packets(self) -> typing.AsyncGenerator[dict, None]:
|
async def _iter_packets(self) -> typing.AsyncGenerator[dict, None]:
|
||||||
"""Yield packets from the underlying stream.
|
"""Yield packets from the underlying stream.
|
||||||
"""
|
"""
|
||||||
decoder = msgspec.Decoder() # dict[str, Any])
|
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
|
@ -164,12 +165,12 @@ class MsgspecTCPStream(MsgpackTCPStream):
|
||||||
msg_bytes = await self.recv_stream.receive_exactly(size)
|
msg_bytes = await self.recv_stream.receive_exactly(size)
|
||||||
|
|
||||||
log.trace(f"received {msg_bytes}") # type: ignore
|
log.trace(f"received {msg_bytes}") # type: ignore
|
||||||
yield decoder.decode(msg_bytes)
|
yield self.decode(msg_bytes)
|
||||||
|
|
||||||
async def send(self, data: Any) -> None:
|
async def send(self, data: Any) -> None:
|
||||||
async with self._send_lock:
|
async with self._send_lock:
|
||||||
|
|
||||||
bytes_data = self.ms_encode(data)
|
bytes_data = self.encode(data)
|
||||||
|
|
||||||
# supposedly the fastest says,
|
# supposedly the fastest says,
|
||||||
# https://stackoverflow.com/a/54027962
|
# https://stackoverflow.com/a/54027962
|
||||||
|
@ -191,14 +192,19 @@ class Channel:
|
||||||
auto_reconnect: bool = False,
|
auto_reconnect: bool = False,
|
||||||
stream: trio.SocketStream = None, # expected to be active
|
stream: trio.SocketStream = None, # expected to be active
|
||||||
|
|
||||||
# stream_serializer_type: type = MsgspecTCPStream,
|
|
||||||
stream_serializer_type: type = MsgpackTCPStream,
|
|
||||||
|
|
||||||
) -> None:
|
) -> None:
|
||||||
|
|
||||||
self._recon_seq = on_reconnect
|
self._recon_seq = on_reconnect
|
||||||
self._autorecon = auto_reconnect
|
self._autorecon = auto_reconnect
|
||||||
|
|
||||||
|
try:
|
||||||
|
# if installed load the msgspec transport since it's faster
|
||||||
|
import msgspec # noqa
|
||||||
|
stream_serializer_type: type = MsgspecTCPStream
|
||||||
|
|
||||||
|
except ImportError:
|
||||||
|
stream_serializer_type: type = MsgpackTCPStream
|
||||||
|
|
||||||
self.stream_serializer_type = stream_serializer_type
|
self.stream_serializer_type = stream_serializer_type
|
||||||
self.msgstream: Optional[type] = stream_serializer_type(
|
self.msgstream: Optional[type] = stream_serializer_type(
|
||||||
stream) if stream else None
|
stream) if stream else None
|
||||||
|
|
Loading…
Reference in New Issue