Pkg `msgpec` as optional dep, load transport type if importable

msgspec_not_fucked
Tyler Goodlet 2021-07-01 09:41:23 -04:00
parent 700f09ce9b
commit 2bd6bbc1b7
2 changed files with 25 additions and 10 deletions

View File

@ -45,14 +45,21 @@ setup(
'tricycle', 'tricycle',
'trio_typing', 'trio_typing',
# tooling
'colorlog', 'colorlog',
'wrapt', 'wrapt',
'pdbpp', 'pdbpp',
# serialization # serialization
'msgpack', 'msgpack',
'msgspec',
], ],
extras_require={
# serialization
'msgspec': ['msgspec; python_version >= 3.9'],
},
tests_require=['pytest'], tests_require=['pytest'],
python_requires=">=3.7", python_requires=">=3.7",
keywords=[ keywords=[

View File

@ -8,7 +8,6 @@ from typing import Any, Tuple, Optional
from tricycle import BufferedReceiveStream from tricycle import BufferedReceiveStream
import msgpack import msgpack
import msgspec
import trio import trio
from async_generator import asynccontextmanager from async_generator import asynccontextmanager
@ -103,8 +102,6 @@ class MsgspecTCPStream(MsgpackTCPStream):
using ``msgspec``. using ``msgspec``.
''' '''
ms_encode = msgspec.Encoder().encode
def __init__( def __init__(
self, self,
stream: trio.SocketStream, stream: trio.SocketStream,
@ -115,10 +112,15 @@ class MsgspecTCPStream(MsgpackTCPStream):
self.recv_stream = BufferedReceiveStream(transport_stream=stream) self.recv_stream = BufferedReceiveStream(transport_stream=stream)
self.prefix_size = prefix_size self.prefix_size = prefix_size
import msgspec
# TODO: struct aware messaging coders
self.encode = msgspec.Encoder().encode
self.decode = msgspec.Decoder().decode # dict[str, Any])
async def _iter_packets(self) -> typing.AsyncGenerator[dict, None]: async def _iter_packets(self) -> typing.AsyncGenerator[dict, None]:
"""Yield packets from the underlying stream. """Yield packets from the underlying stream.
""" """
decoder = msgspec.Decoder() # dict[str, Any])
while True: while True:
try: try:
@ -141,12 +143,12 @@ class MsgspecTCPStream(MsgpackTCPStream):
msg_bytes = await self.recv_stream.receive_exactly(size) msg_bytes = await self.recv_stream.receive_exactly(size)
log.trace(f"received {msg_bytes}") # type: ignore log.trace(f"received {msg_bytes}") # type: ignore
yield decoder.decode(msg_bytes) yield self.decode(msg_bytes)
async def send(self, data: Any) -> None: async def send(self, data: Any) -> None:
async with self._send_lock: async with self._send_lock:
bytes_data = self.ms_encode(data) bytes_data = self.encode(data)
# supposedly the fastest says, # supposedly the fastest says,
# https://stackoverflow.com/a/54027962 # https://stackoverflow.com/a/54027962
@ -168,13 +170,19 @@ class Channel:
auto_reconnect: bool = False, auto_reconnect: bool = False,
stream: trio.SocketStream = None, # expected to be active stream: trio.SocketStream = None, # expected to be active
# stream_serializer_type: type = MsgspecTCPStream,
stream_serializer_type: type = MsgpackTCPStream,
) -> None: ) -> None:
self._recon_seq = on_reconnect self._recon_seq = on_reconnect
self._autorecon = auto_reconnect self._autorecon = auto_reconnect
try:
# if installed load the msgspec transport since it's faster
import msgspec # noqa
stream_serializer_type: type = MsgspecTCPStream
except ImportError:
stream_serializer_type: type = MsgpackTCPStream
self.stream_serializer_type = stream_serializer_type self.stream_serializer_type = stream_serializer_type
self.msgstream: Optional[type] = stream_serializer_type( self.msgstream: Optional[type] = stream_serializer_type(
stream) if stream else None stream) if stream else None