Pkg `msgpec` as optional dep, load transport type if importable

optional_msgspec_support
Tyler Goodlet 2021-07-01 09:41:23 -04:00
parent 96b3f94c72
commit b64396f708
2 changed files with 24 additions and 11 deletions

View File

@ -48,14 +48,21 @@ setup(
'tricycle', 'tricycle',
'trio_typing', 'trio_typing',
# tooling
'colorlog', 'colorlog',
'wrapt', 'wrapt',
'pdbpp', 'pdbpp',
# serialization # serialization
'msgpack', 'msgpack',
'msgspec',
], ],
extras_require={
# serialization
'msgspec': ['msgspec; python_version >= 3.9'],
},
tests_require=['pytest'], tests_require=['pytest'],
python_requires=">=3.8", python_requires=">=3.8",
keywords=[ keywords=[

View File

@ -9,7 +9,6 @@ from typing import Any, Tuple, Optional
from tricycle import BufferedReceiveStream from tricycle import BufferedReceiveStream
import msgpack import msgpack
import msgspec
import trio import trio
from async_generator import asynccontextmanager from async_generator import asynccontextmanager
@ -20,7 +19,6 @@ log = get_logger(__name__)
_is_windows = platform.system() == 'Windows' _is_windows = platform.system() == 'Windows'
log = get_logger(__name__) log = get_logger(__name__)
ms_decode = msgspec.Encoder().encode
class MsgpackTCPStream: class MsgpackTCPStream:
@ -126,8 +124,6 @@ class MsgspecTCPStream(MsgpackTCPStream):
using ``msgspec``. using ``msgspec``.
''' '''
ms_encode = msgspec.Encoder().encode
def __init__( def __init__(
self, self,
stream: trio.SocketStream, stream: trio.SocketStream,
@ -138,10 +134,15 @@ class MsgspecTCPStream(MsgpackTCPStream):
self.recv_stream = BufferedReceiveStream(transport_stream=stream) self.recv_stream = BufferedReceiveStream(transport_stream=stream)
self.prefix_size = prefix_size self.prefix_size = prefix_size
import msgspec
# TODO: struct aware messaging coders
self.encode = msgspec.Encoder().encode
self.decode = msgspec.Decoder().decode # dict[str, Any])
async def _iter_packets(self) -> typing.AsyncGenerator[dict, None]: async def _iter_packets(self) -> typing.AsyncGenerator[dict, None]:
"""Yield packets from the underlying stream. """Yield packets from the underlying stream.
""" """
decoder = msgspec.Decoder() # dict[str, Any])
while True: while True:
try: try:
@ -164,12 +165,12 @@ class MsgspecTCPStream(MsgpackTCPStream):
msg_bytes = await self.recv_stream.receive_exactly(size) msg_bytes = await self.recv_stream.receive_exactly(size)
log.trace(f"received {msg_bytes}") # type: ignore log.trace(f"received {msg_bytes}") # type: ignore
yield decoder.decode(msg_bytes) yield self.decode(msg_bytes)
async def send(self, data: Any) -> None: async def send(self, data: Any) -> None:
async with self._send_lock: async with self._send_lock:
bytes_data = self.ms_encode(data) bytes_data = self.encode(data)
# supposedly the fastest says, # supposedly the fastest says,
# https://stackoverflow.com/a/54027962 # https://stackoverflow.com/a/54027962
@ -191,14 +192,19 @@ class Channel:
auto_reconnect: bool = False, auto_reconnect: bool = False,
stream: trio.SocketStream = None, # expected to be active stream: trio.SocketStream = None, # expected to be active
# stream_serializer_type: type = MsgspecTCPStream,
stream_serializer_type: type = MsgpackTCPStream,
) -> None: ) -> None:
self._recon_seq = on_reconnect self._recon_seq = on_reconnect
self._autorecon = auto_reconnect self._autorecon = auto_reconnect
try:
# if installed load the msgspec transport since it's faster
import msgspec # noqa
stream_serializer_type: type = MsgspecTCPStream
except ImportError:
stream_serializer_type: type = MsgpackTCPStream
self.stream_serializer_type = stream_serializer_type self.stream_serializer_type = stream_serializer_type
self.msgstream: Optional[type] = stream_serializer_type( self.msgstream: Optional[type] = stream_serializer_type(
stream) if stream else None stream) if stream else None