Add "message transport" structured sub-typing

In an effort to have some kind of more formal interface around the
transport layer, add a `MsgTransport` protocol type and use with
the channel composition of message streams. Start a little "key map"
of `(<codec>, <protocol>)` to `MsgTransport` types which can be
dynamically loaded. Add a `Channel.from_stream()` constructor thus
cleaning up the mangled logic that was in the constructor based on
inputs. Drop all the "auto reconnect" channel logic for now since
nothing is using it (internally) and it's likely it will need rework
once we bring in a protocol besides TCP.
optional_msgspec_support
Tyler Goodlet 2021-10-06 14:52:12 -04:00
parent 135459ca25
commit c6dc96b08c
1 changed files with 159 additions and 89 deletions

View File

@ -2,10 +2,14 @@
Inter-process comms abstractions
"""
from __future__ import annotations
import platform
import struct
import typing
from typing import Any, Tuple, Optional, Type
from typing import (
Any, Tuple, Optional,
Type, Protocol, TypeVar
)
from tricycle import BufferedReceiveStream
import msgpack
@ -21,6 +25,53 @@ _is_windows = platform.system() == 'Windows'
log = get_logger(__name__)
def get_stream_addrs(stream: trio.SocketStream) -> Tuple:
# should both be IP sockets
lsockname = stream.socket.getsockname()
rsockname = stream.socket.getpeername()
return (
tuple(lsockname[:2]),
tuple(rsockname[:2]),
)
MsgType = TypeVar("MsgType")
# TODO: consider using a generic def and indexing with our eventual
# msg definition/types?
# - https://docs.python.org/3/library/typing.html#typing.Protocol
# - https://jcristharif.com/msgspec/usage.html#structs
class MsgTransport(Protocol[MsgType]):
stream: trio.SocketStream
def __init__(self, stream: trio.SocketStream) -> None:
...
# XXX: should this instead be called `.sendall()`?
async def send(self, msg: MsgType) -> None:
...
async def recv(self) -> MsgType:
...
def __aiter__(self) -> MsgType:
...
def connected(self) -> bool:
...
@property
def laddr(self) -> Tuple[str, int]:
...
@property
def raddr(self) -> Tuple[str, int]:
...
class MsgpackTCPStream:
'''A ``trio.SocketStream`` delivering ``msgpack`` formatted data
using ``msgpack-python``.
@ -36,17 +87,10 @@ class MsgpackTCPStream:
assert self.stream.socket
# should both be IP sockets
lsockname = stream.socket.getsockname()
assert isinstance(lsockname, tuple)
self._laddr = lsockname[:2]
self._laddr, self._raddr = get_stream_addrs(stream)
rsockname = stream.socket.getpeername()
assert isinstance(rsockname, tuple)
self._raddr = rsockname[:2]
# start first entry to read loop
# create read loop instance
self._agen = self._iter_packets()
self._send_lock = trio.StrictFIFOLock()
async def _iter_packets(self) -> typing.AsyncGenerator[dict, None]:
@ -103,11 +147,10 @@ class MsgpackTCPStream:
def raddr(self) -> Tuple[Any, ...]:
return self._raddr
# XXX: should this instead be called `.sendall()`?
async def send(self, data: Any) -> None:
async def send(self, msg: Any) -> None:
async with self._send_lock:
return await self.stream.send_all(
msgpack.dumps(data, use_bin_type=True)
msgpack.dumps(msg, use_bin_type=True)
)
async def recv(self) -> Any:
@ -191,10 +234,10 @@ class MsgspecTCPStream(MsgpackTCPStream):
else:
raise
async def send(self, data: Any) -> None:
async def send(self, msg: Any) -> None:
async with self._send_lock:
bytes_data: bytes = self.encode(data)
bytes_data: bytes = self.encode(msg)
# supposedly the fastest says,
# https://stackoverflow.com/a/54027962
@ -203,13 +246,16 @@ class MsgspecTCPStream(MsgpackTCPStream):
return await self.stream.send_all(size + bytes_data)
def get_serializer_stream_type(
name: str,
) -> Type:
def get_msg_transport(
key: Tuple[str, str],
) -> Type[MsgTransport]:
return {
'msgpack': MsgpackTCPStream,
'msgspec': MsgspecTCPStream,
}[name]
('msgpack', 'tcp'): MsgpackTCPStream,
('msgspec', 'tcp'): MsgspecTCPStream,
}[key]
class Channel:
@ -221,34 +267,34 @@ class Channel:
def __init__(
self,
destaddr: Optional[Tuple[str, int]] = None,
on_reconnect: typing.Callable[..., typing.Awaitable] = None,
auto_reconnect: bool = False,
stream: trio.SocketStream = None, # expected to be active
destaddr: Optional[Tuple[str, int]],
msg_transport_type_key: Tuple[str, str] = ('msgpack', 'tcp'),
# TODO: optional reconnection support?
# auto_reconnect: bool = False,
# on_reconnect: typing.Callable[..., typing.Awaitable] = None,
) -> None:
self._recon_seq = on_reconnect
self._autorecon = auto_reconnect
# self._recon_seq = on_reconnect
# self._autorecon = auto_reconnect
# TODO: maybe expose this through the nursery api?
try:
# if installed load the msgspec transport since it's faster
import msgspec # noqa
serializer = 'msgspec'
msg_transport_type_key = ('msgspec', 'tcp')
except ImportError:
serializer = 'msgpack'
pass
self.stream_serializer_type = get_serializer_stream_type(serializer)
self.msgstream = self.stream_serializer_type(
stream) if stream else None
self._destaddr = destaddr
self._transport_key = msg_transport_type_key
if self.msgstream and destaddr:
raise ValueError(
f"A stream was provided with local addr {self.laddr}"
)
self._destaddr = self.msgstream.raddr if self.msgstream else destaddr
# Either created in ``.connect()`` or passed in by
# user in ``.from_stream()``.
self._stream: Optional[trio.SocketStream] = None
self.msgstream: Optional[MsgTransport] = None
# set after handshake - always uid of far end
self.uid: Optional[Tuple[str, str]] = None
@ -256,9 +302,34 @@ class Channel:
# set if far end actor errors internally
self._exc: Optional[Exception] = None
self._agen = self._aiter_recv()
self._closed: bool = False
@classmethod
def from_stream(
cls,
stream: trio.SocketStream,
**kwargs,
) -> Channel:
src, dst = get_stream_addrs(stream)
chan = Channel(destaddr=dst, **kwargs)
# set immediately here from provided instance
chan._stream = stream
chan.set_msg_transport(stream)
return chan
def set_msg_transport(
self,
stream: trio.SocketStream,
type_key: Optional[Tuple[str, str]] = None,
) -> MsgTransport:
type_key = type_key or self._transport_key
self.msgstream = get_msg_transport(type_key)(stream)
return self.msgstream
def __repr__(self) -> str:
if self.msgstream:
return repr(
@ -267,11 +338,11 @@ class Channel:
return object.__repr__(self)
@property
def laddr(self) -> Optional[Tuple[Any, ...]]:
def laddr(self) -> Optional[Tuple[str, int]]:
return self.msgstream.laddr if self.msgstream else None
@property
def raddr(self) -> Optional[Tuple[Any, ...]]:
def raddr(self) -> Optional[Tuple[str, int]]:
return self.msgstream.raddr if self.msgstream else None
async def connect(
@ -279,7 +350,7 @@ class Channel:
destaddr: Tuple[Any, ...] = None,
**kwargs
) -> trio.SocketStream:
) -> MsgTransport:
if self.connected():
raise RuntimeError("channel is already connected?")
@ -291,12 +362,12 @@ class Channel:
*destaddr,
**kwargs
)
self.msgstream = self.stream_serializer_type(stream)
msgstream = self.set_msg_transport(stream)
log.transport(
f'Opened channel to peer {self.laddr} -> {self.raddr}'
f'Opened channel[{type(msgstream)}]: {self.laddr} -> {self.raddr}'
)
return stream
return msgstream
async def send(self, item: Any) -> None:
@ -307,16 +378,15 @@ class Channel:
async def recv(self) -> Any:
assert self.msgstream
try:
return await self.msgstream.recv()
except trio.BrokenResourceError:
if self._autorecon:
await self._reconnect()
return await self.recv()
raise
# try:
# return await self.msgstream.recv()
# except trio.BrokenResourceError:
# if self._autorecon:
# await self._reconnect()
# return await self.recv()
# raise
async def aclose(self) -> None:
@ -338,34 +408,36 @@ class Channel:
def __aiter__(self):
return self._agen
async def _reconnect(self) -> None:
"""Handle connection failures by polling until a reconnect can be
established.
"""
down = False
while True:
try:
with trio.move_on_after(3) as cancel_scope:
await self.connect()
cancelled = cancel_scope.cancelled_caught
if cancelled:
log.transport(
"Reconnect timed out after 3 seconds, retrying...")
continue
else:
log.transport("Stream connection re-established!")
# run any reconnection sequence
on_recon = self._recon_seq
if on_recon:
await on_recon(self)
break
except (OSError, ConnectionRefusedError):
if not down:
down = True
log.transport(
f"Connection to {self.raddr} went down, waiting"
" for re-establishment")
await trio.sleep(1)
# async def _reconnect(self) -> None:
# """Handle connection failures by polling until a reconnect can be
# established.
# """
# down = False
# while True:
# try:
# with trio.move_on_after(3) as cancel_scope:
# await self.connect()
# cancelled = cancel_scope.cancelled_caught
# if cancelled:
# log.transport(
# "Reconnect timed out after 3 seconds, retrying...")
# continue
# else:
# log.transport("Stream connection re-established!")
# # TODO: run any reconnection sequence
# # on_recon = self._recon_seq
# # if on_recon:
# # await on_recon(self)
# break
# except (OSError, ConnectionRefusedError):
# if not down:
# down = True
# log.transport(
# f"Connection to {self.raddr} went down, waiting"
# " for re-establishment")
# await trio.sleep(1)
async def _aiter_recv(
self
@ -384,16 +456,14 @@ class Channel:
# await self.msgstream.send(sent)
except trio.BrokenResourceError:
if not self._autorecon:
# if not self._autorecon:
raise
await self.aclose()
if self._autorecon: # attempt reconnect
await self._reconnect()
continue
else:
return
# if self._autorecon: # attempt reconnect
# await self._reconnect()
# continue
def connected(self) -> bool:
return self.msgstream.connected() if self.msgstream else False