Drop `msgpack` lib and use `msgspec` for transport

drop_msgpack
Tyler Goodlet 2022-07-11 20:34:10 -04:00
parent f6af5c7bf8
commit f94b7cd991
1 changed files with 34 additions and 110 deletions

View File

@ -29,7 +29,7 @@ from typing import (
) )
from tricycle import BufferedReceiveStream from tricycle import BufferedReceiveStream
import msgpack import msgspec
import trio import trio
from async_generator import asynccontextmanager from async_generator import asynccontextmanager
@ -98,12 +98,13 @@ class MsgTransport(Protocol[MsgType]):
class MsgpackTCPStream: class MsgpackTCPStream:
''' '''
A ``trio.SocketStream`` delivering ``msgpack`` formatted data A ``trio.SocketStream`` delivering ``msgpack`` formatted data
using ``msgpack-python``. using the ``msgspec`` codec lib.
''' '''
def __init__( def __init__(
self, self,
stream: trio.SocketStream, stream: trio.SocketStream,
prefix_size: int = 4,
) -> None: ) -> None:
@ -120,105 +121,6 @@ class MsgpackTCPStream:
# public i guess? # public i guess?
self.drained: list[dict] = [] self.drained: list[dict] = []
async def _iter_packets(self) -> AsyncGenerator[dict, None]:
'''
Yield packets from the underlying stream.
'''
unpacker = msgpack.Unpacker(
raw=False,
)
while True:
try:
data = await self.stream.receive_some(2**10)
except trio.BrokenResourceError as err:
msg = err.args[0]
# XXX: handle connection-reset-by-peer the same as a EOF.
# we're currently remapping this since we allow
# a quick connect then drop for root actors when
# checking to see if there exists an "arbiter"
# on the chosen sockaddr (``_root.py:108`` or thereabouts)
if (
# nix
'[Errno 104]' in msg or
# on windows it seems there are a variety of errors
# to handle..
_is_windows
):
raise TransportClosed(
f'{self} was broken with {msg}'
)
else:
raise
log.transport(f"received {data}") # type: ignore
if data == b'':
raise TransportClosed(
f'transport {self} was already closed prior to read'
)
unpacker.feed(data)
for packet in unpacker:
yield packet
@property
def laddr(self) -> Tuple[Any, ...]:
return self._laddr
@property
def raddr(self) -> Tuple[Any, ...]:
return self._raddr
async def send(self, msg: Any) -> None:
async with self._send_lock:
return await self.stream.send_all(
msgpack.dumps(msg, use_bin_type=True)
)
async def recv(self) -> Any:
return await self._agen.asend(None)
async def drain(self) -> AsyncIterator[dict]:
'''
Drain the stream's remaining messages sent from
the far end until the connection is closed by
the peer.
'''
try:
async for msg in self._iter_packets():
self.drained.append(msg)
except TransportClosed:
for msg in self.drained:
yield msg
def __aiter__(self):
return self._agen
def connected(self) -> bool:
return self.stream.socket.fileno() != -1
class MsgspecTCPStream(MsgpackTCPStream):
'''
A ``trio.SocketStream`` delivering ``msgpack`` formatted data
using ``msgspec``.
'''
def __init__(
self,
stream: trio.SocketStream,
prefix_size: int = 4,
) -> None:
import msgspec
super().__init__(stream)
self.recv_stream = BufferedReceiveStream(transport_stream=stream) self.recv_stream = BufferedReceiveStream(transport_stream=stream)
self.prefix_size = prefix_size self.prefix_size = prefix_size
@ -287,6 +189,37 @@ class MsgspecTCPStream(MsgpackTCPStream):
return await self.stream.send_all(size + bytes_data) return await self.stream.send_all(size + bytes_data)
@property
def laddr(self) -> Tuple[Any, ...]:
return self._laddr
@property
def raddr(self) -> Tuple[Any, ...]:
return self._raddr
async def recv(self) -> Any:
return await self._agen.asend(None)
async def drain(self) -> AsyncIterator[dict]:
'''
Drain the stream's remaining messages sent from
the far end until the connection is closed by
the peer.
'''
try:
async for msg in self._iter_packets():
self.drained.append(msg)
except TransportClosed:
for msg in self.drained:
yield msg
def __aiter__(self):
return self._agen
def connected(self) -> bool:
return self.stream.socket.fileno() != -1
def get_msg_transport( def get_msg_transport(
@ -296,7 +229,6 @@ def get_msg_transport(
return { return {
('msgpack', 'tcp'): MsgpackTCPStream, ('msgpack', 'tcp'): MsgpackTCPStream,
('msgspec', 'tcp'): MsgspecTCPStream,
}[key] }[key]
@ -325,14 +257,6 @@ class Channel:
# self._recon_seq = on_reconnect # self._recon_seq = on_reconnect
# self._autorecon = auto_reconnect # self._autorecon = auto_reconnect
# TODO: maybe expose this through the nursery api?
try:
# if installed load the msgspec transport since it's faster
import msgspec # noqa
msg_transport_type_key = ('msgspec', 'tcp')
except ImportError:
pass
self._destaddr = destaddr self._destaddr = destaddr
self._transport_key = msg_transport_type_key self._transport_key = msg_transport_type_key