Try out `msgspec` in our msgpack stream channel

Can only really use an encoder currently since there is no streaming api
in `msgspec` as of currently. See jcrist/msgspec#27.

Not sure if any encoding speedups are currently noticeable especially
without any validation going on yet XD.

First experiments toward #196
prehardkill
Tyler Goodlet 2021-05-30 17:19:20 -04:00
parent cc2b9d20a4
commit c83b9cc940
1 changed files with 24 additions and 3 deletions

View File

@ -2,10 +2,11 @@
Inter-process comms abstractions Inter-process comms abstractions
""" """
import typing import typing
from typing import Any, Tuple, Optional from typing import Any, Tuple, Optional, Callable
from functools import partial from functools import partial
import msgpack import msgpack
import msgspec
import trio import trio
from async_generator import asynccontextmanager from async_generator import asynccontextmanager
@ -21,16 +22,32 @@ except ImportError:
Unpacker = partial(msgpack.Unpacker, strict_map_key=False) Unpacker = partial(msgpack.Unpacker, strict_map_key=False)
ms_decode = msgspec.Encoder().encode
class MsgpackStream: class MsgpackStream:
"""A ``trio.SocketStream`` delivering ``msgpack`` formatted data. """A ``trio.SocketStream`` delivering ``msgpack`` formatted data.
""" """
def __init__(self, stream: trio.SocketStream) -> None: def __init__(
self,
stream: trio.SocketStream,
serialize: Callable = Unpacker(
raw=False,
use_list=False,
).feed,
deserialize: Callable = msgpack.dumps,
) -> None:
self.stream = stream self.stream = stream
assert self.stream.socket assert self.stream.socket
# should both be IP sockets # should both be IP sockets
lsockname = stream.socket.getsockname() lsockname = stream.socket.getsockname()
assert isinstance(lsockname, tuple) assert isinstance(lsockname, tuple)
self._laddr = lsockname[:2] self._laddr = lsockname[:2]
rsockname = stream.socket.getpeername() rsockname = stream.socket.getpeername()
assert isinstance(rsockname, tuple) assert isinstance(rsockname, tuple)
self._raddr = rsockname[:2] self._raddr = rsockname[:2]
@ -45,6 +62,7 @@ class MsgpackStream:
raw=False, raw=False,
use_list=False, use_list=False,
) )
# decoder = msgspec.Decoder() #dict[str, Any])
while True: while True:
try: try:
data = await self.stream.receive_some(2**10) data = await self.stream.receive_some(2**10)
@ -57,6 +75,7 @@ class MsgpackStream:
log.debug(f"Stream connection {self.raddr} was closed") log.debug(f"Stream connection {self.raddr} was closed")
return return
# yield decoder.decode(data)
unpacker.feed(data) unpacker.feed(data)
for packet in unpacker: for packet in unpacker:
yield packet yield packet
@ -73,7 +92,9 @@ class MsgpackStream:
async def send(self, data: Any) -> None: async def send(self, data: Any) -> None:
async with self._send_lock: async with self._send_lock:
return await self.stream.send_all( return await self.stream.send_all(
msgpack.dumps(data, use_bin_type=True)) # msgpack.dumps(data, use_bin_type=True))
ms_decode(data)
)
async def recv(self) -> Any: async def recv(self) -> Any:
return await self._agen.asend(None) return await self._agen.asend(None)