Add "message transport" structured sub-typing

In an effort to have some kind of more formal interface around the
transport layer, add a `MsgTransport` protocol type and use with
the channel composition of message streams. Start a little "key map"
of `(<codec>, <protocol>)` to `MsgTransport` types which can be
dynamically loaded. Add a `Channel.from_stream()` constructor thus
cleaning up the mangled logic that was in the constructor based on
inputs. Drop all the "auto reconnect" channel logic for now since
nothing is using it (internally) and it's likely it will need rework
once we bring in a protocol besides TCP.
optional_msgspec_support
Tyler Goodlet 2021-10-06 14:52:12 -04:00
parent 135459ca25
commit c6dc96b08c
1 changed files with 159 additions and 89 deletions

View File

@ -2,10 +2,14 @@
Inter-process comms abstractions Inter-process comms abstractions
""" """
from __future__ import annotations
import platform import platform
import struct import struct
import typing import typing
from typing import Any, Tuple, Optional, Type from typing import (
Any, Tuple, Optional,
Type, Protocol, TypeVar
)
from tricycle import BufferedReceiveStream from tricycle import BufferedReceiveStream
import msgpack import msgpack
@ -21,6 +25,53 @@ _is_windows = platform.system() == 'Windows'
log = get_logger(__name__) log = get_logger(__name__)
def get_stream_addrs(stream: trio.SocketStream) -> Tuple:
# should both be IP sockets
lsockname = stream.socket.getsockname()
rsockname = stream.socket.getpeername()
return (
tuple(lsockname[:2]),
tuple(rsockname[:2]),
)
MsgType = TypeVar("MsgType")
# TODO: consider using a generic def and indexing with our eventual
# msg definition/types?
# - https://docs.python.org/3/library/typing.html#typing.Protocol
# - https://jcristharif.com/msgspec/usage.html#structs
class MsgTransport(Protocol[MsgType]):
stream: trio.SocketStream
def __init__(self, stream: trio.SocketStream) -> None:
...
# XXX: should this instead be called `.sendall()`?
async def send(self, msg: MsgType) -> None:
...
async def recv(self) -> MsgType:
...
def __aiter__(self) -> MsgType:
...
def connected(self) -> bool:
...
@property
def laddr(self) -> Tuple[str, int]:
...
@property
def raddr(self) -> Tuple[str, int]:
...
class MsgpackTCPStream: class MsgpackTCPStream:
'''A ``trio.SocketStream`` delivering ``msgpack`` formatted data '''A ``trio.SocketStream`` delivering ``msgpack`` formatted data
using ``msgpack-python``. using ``msgpack-python``.
@ -36,17 +87,10 @@ class MsgpackTCPStream:
assert self.stream.socket assert self.stream.socket
# should both be IP sockets # should both be IP sockets
lsockname = stream.socket.getsockname() self._laddr, self._raddr = get_stream_addrs(stream)
assert isinstance(lsockname, tuple)
self._laddr = lsockname[:2]
rsockname = stream.socket.getpeername() # create read loop instance
assert isinstance(rsockname, tuple)
self._raddr = rsockname[:2]
# start first entry to read loop
self._agen = self._iter_packets() self._agen = self._iter_packets()
self._send_lock = trio.StrictFIFOLock() self._send_lock = trio.StrictFIFOLock()
async def _iter_packets(self) -> typing.AsyncGenerator[dict, None]: async def _iter_packets(self) -> typing.AsyncGenerator[dict, None]:
@ -103,11 +147,10 @@ class MsgpackTCPStream:
def raddr(self) -> Tuple[Any, ...]: def raddr(self) -> Tuple[Any, ...]:
return self._raddr return self._raddr
# XXX: should this instead be called `.sendall()`? async def send(self, msg: Any) -> None:
async def send(self, data: Any) -> None:
async with self._send_lock: async with self._send_lock:
return await self.stream.send_all( return await self.stream.send_all(
msgpack.dumps(data, use_bin_type=True) msgpack.dumps(msg, use_bin_type=True)
) )
async def recv(self) -> Any: async def recv(self) -> Any:
@ -191,10 +234,10 @@ class MsgspecTCPStream(MsgpackTCPStream):
else: else:
raise raise
async def send(self, data: Any) -> None: async def send(self, msg: Any) -> None:
async with self._send_lock: async with self._send_lock:
bytes_data: bytes = self.encode(data) bytes_data: bytes = self.encode(msg)
# supposedly the fastest says, # supposedly the fastest says,
# https://stackoverflow.com/a/54027962 # https://stackoverflow.com/a/54027962
@ -203,13 +246,16 @@ class MsgspecTCPStream(MsgpackTCPStream):
return await self.stream.send_all(size + bytes_data) return await self.stream.send_all(size + bytes_data)
def get_serializer_stream_type( def get_msg_transport(
name: str,
) -> Type: key: Tuple[str, str],
) -> Type[MsgTransport]:
return { return {
'msgpack': MsgpackTCPStream, ('msgpack', 'tcp'): MsgpackTCPStream,
'msgspec': MsgspecTCPStream, ('msgspec', 'tcp'): MsgspecTCPStream,
}[name] }[key]
class Channel: class Channel:
@ -221,34 +267,34 @@ class Channel:
def __init__( def __init__(
self, self,
destaddr: Optional[Tuple[str, int]] = None, destaddr: Optional[Tuple[str, int]],
on_reconnect: typing.Callable[..., typing.Awaitable] = None,
auto_reconnect: bool = False, msg_transport_type_key: Tuple[str, str] = ('msgpack', 'tcp'),
stream: trio.SocketStream = None, # expected to be active
# TODO: optional reconnection support?
# auto_reconnect: bool = False,
# on_reconnect: typing.Callable[..., typing.Awaitable] = None,
) -> None: ) -> None:
self._recon_seq = on_reconnect # self._recon_seq = on_reconnect
self._autorecon = auto_reconnect # self._autorecon = auto_reconnect
# TODO: maybe expose this through the nursery api? # TODO: maybe expose this through the nursery api?
try: try:
# if installed load the msgspec transport since it's faster # if installed load the msgspec transport since it's faster
import msgspec # noqa import msgspec # noqa
serializer = 'msgspec' msg_transport_type_key = ('msgspec', 'tcp')
except ImportError: except ImportError:
serializer = 'msgpack' pass
self.stream_serializer_type = get_serializer_stream_type(serializer) self._destaddr = destaddr
self.msgstream = self.stream_serializer_type( self._transport_key = msg_transport_type_key
stream) if stream else None
if self.msgstream and destaddr: # Either created in ``.connect()`` or passed in by
raise ValueError( # user in ``.from_stream()``.
f"A stream was provided with local addr {self.laddr}" self._stream: Optional[trio.SocketStream] = None
) self.msgstream: Optional[MsgTransport] = None
self._destaddr = self.msgstream.raddr if self.msgstream else destaddr
# set after handshake - always uid of far end # set after handshake - always uid of far end
self.uid: Optional[Tuple[str, str]] = None self.uid: Optional[Tuple[str, str]] = None
@ -256,9 +302,34 @@ class Channel:
# set if far end actor errors internally # set if far end actor errors internally
self._exc: Optional[Exception] = None self._exc: Optional[Exception] = None
self._agen = self._aiter_recv() self._agen = self._aiter_recv()
self._closed: bool = False self._closed: bool = False
@classmethod
def from_stream(
cls,
stream: trio.SocketStream,
**kwargs,
) -> Channel:
src, dst = get_stream_addrs(stream)
chan = Channel(destaddr=dst, **kwargs)
# set immediately here from provided instance
chan._stream = stream
chan.set_msg_transport(stream)
return chan
def set_msg_transport(
self,
stream: trio.SocketStream,
type_key: Optional[Tuple[str, str]] = None,
) -> MsgTransport:
type_key = type_key or self._transport_key
self.msgstream = get_msg_transport(type_key)(stream)
return self.msgstream
def __repr__(self) -> str: def __repr__(self) -> str:
if self.msgstream: if self.msgstream:
return repr( return repr(
@ -267,11 +338,11 @@ class Channel:
return object.__repr__(self) return object.__repr__(self)
@property @property
def laddr(self) -> Optional[Tuple[Any, ...]]: def laddr(self) -> Optional[Tuple[str, int]]:
return self.msgstream.laddr if self.msgstream else None return self.msgstream.laddr if self.msgstream else None
@property @property
def raddr(self) -> Optional[Tuple[Any, ...]]: def raddr(self) -> Optional[Tuple[str, int]]:
return self.msgstream.raddr if self.msgstream else None return self.msgstream.raddr if self.msgstream else None
async def connect( async def connect(
@ -279,7 +350,7 @@ class Channel:
destaddr: Tuple[Any, ...] = None, destaddr: Tuple[Any, ...] = None,
**kwargs **kwargs
) -> trio.SocketStream: ) -> MsgTransport:
if self.connected(): if self.connected():
raise RuntimeError("channel is already connected?") raise RuntimeError("channel is already connected?")
@ -291,12 +362,12 @@ class Channel:
*destaddr, *destaddr,
**kwargs **kwargs
) )
self.msgstream = self.stream_serializer_type(stream) msgstream = self.set_msg_transport(stream)
log.transport( log.transport(
f'Opened channel to peer {self.laddr} -> {self.raddr}' f'Opened channel[{type(msgstream)}]: {self.laddr} -> {self.raddr}'
) )
return stream return msgstream
async def send(self, item: Any) -> None: async def send(self, item: Any) -> None:
@ -307,16 +378,15 @@ class Channel:
async def recv(self) -> Any: async def recv(self) -> Any:
assert self.msgstream assert self.msgstream
try:
return await self.msgstream.recv() return await self.msgstream.recv()
except trio.BrokenResourceError: # try:
if self._autorecon: # return await self.msgstream.recv()
await self._reconnect() # except trio.BrokenResourceError:
return await self.recv() # if self._autorecon:
# await self._reconnect()
raise # return await self.recv()
# raise
async def aclose(self) -> None: async def aclose(self) -> None:
@ -338,34 +408,36 @@ class Channel:
def __aiter__(self): def __aiter__(self):
return self._agen return self._agen
async def _reconnect(self) -> None: # async def _reconnect(self) -> None:
"""Handle connection failures by polling until a reconnect can be # """Handle connection failures by polling until a reconnect can be
established. # established.
""" # """
down = False # down = False
while True: # while True:
try: # try:
with trio.move_on_after(3) as cancel_scope: # with trio.move_on_after(3) as cancel_scope:
await self.connect() # await self.connect()
cancelled = cancel_scope.cancelled_caught # cancelled = cancel_scope.cancelled_caught
if cancelled: # if cancelled:
log.transport( # log.transport(
"Reconnect timed out after 3 seconds, retrying...") # "Reconnect timed out after 3 seconds, retrying...")
continue # continue
else: # else:
log.transport("Stream connection re-established!") # log.transport("Stream connection re-established!")
# run any reconnection sequence
on_recon = self._recon_seq # # TODO: run any reconnection sequence
if on_recon: # # on_recon = self._recon_seq
await on_recon(self) # # if on_recon:
break # # await on_recon(self)
except (OSError, ConnectionRefusedError):
if not down: # break
down = True # except (OSError, ConnectionRefusedError):
log.transport( # if not down:
f"Connection to {self.raddr} went down, waiting" # down = True
" for re-establishment") # log.transport(
await trio.sleep(1) # f"Connection to {self.raddr} went down, waiting"
# " for re-establishment")
# await trio.sleep(1)
async def _aiter_recv( async def _aiter_recv(
self self
@ -384,16 +456,14 @@ class Channel:
# await self.msgstream.send(sent) # await self.msgstream.send(sent)
except trio.BrokenResourceError: except trio.BrokenResourceError:
if not self._autorecon: # if not self._autorecon:
raise raise
await self.aclose() await self.aclose()
if self._autorecon: # attempt reconnect # if self._autorecon: # attempt reconnect
await self._reconnect() # await self._reconnect()
continue # continue
else:
return
def connected(self) -> bool: def connected(self) -> bool:
return self.msgstream.connected() if self.msgstream else False return self.msgstream.connected() if self.msgstream else False