From c6dc96b08cb38078b2eea4504f0e7b0e6b789abe Mon Sep 17 00:00:00 2001 From: Tyler Goodlet Date: Wed, 6 Oct 2021 14:52:12 -0400 Subject: [PATCH] Add "message transport" structured sub-typing In an effort to have some kind of more formal interface around the transport layer, add a `MsgTransport` protocol type and use with the channel composition of message streams. Start a little "key map" of `(<codec>, <protocol>)` to `MsgTransport` types which can be dynamically loaded. Add a `Channel.from_stream()` constructor thus cleaning up the mangled logic that was in the constructor based on inputs. Drop all the "auto reconnect" channel logic for now since nothing is using it (internally) and it's likely it will need rework once we bring in a protocol besides TCP. --- tractor/_ipc.py | 248 +++++++++++++++++++++++++++++++----------------- 1 file changed, 159 insertions(+), 89 deletions(-) diff --git a/tractor/_ipc.py b/tractor/_ipc.py index b8d0437..28bef97 100644 --- a/tractor/_ipc.py +++ b/tractor/_ipc.py @@ -2,10 +2,14 @@ Inter-process comms abstractions """ +from __future__ import annotations import platform import struct import typing -from typing import Any, Tuple, Optional, Type +from typing import ( + Any, Tuple, Optional, + Type, Protocol, TypeVar +) from tricycle import BufferedReceiveStream import msgpack @@ -21,6 +25,53 @@ _is_windows = platform.system() == 'Windows' log = get_logger(__name__) +def get_stream_addrs(stream: trio.SocketStream) -> Tuple: + # should both be IP sockets + lsockname = stream.socket.getsockname() + rsockname = stream.socket.getpeername() + return ( + tuple(lsockname[:2]), + tuple(rsockname[:2]), + ) + + +MsgType = TypeVar("MsgType") + +# TODO: consider using a generic def and indexing with our eventual +# msg definition/types? 
+# - https://docs.python.org/3/library/typing.html#typing.Protocol +# - https://jcristharif.com/msgspec/usage.html#structs + + +class MsgTransport(Protocol[MsgType]): + + stream: trio.SocketStream + + def __init__(self, stream: trio.SocketStream) -> None: + ... + + # XXX: should this instead be called `.sendall()`? + async def send(self, msg: MsgType) -> None: + ... + + async def recv(self) -> MsgType: + ... + + def __aiter__(self) -> MsgType: + ... + + def connected(self) -> bool: + ... + + @property + def laddr(self) -> Tuple[str, int]: + ... + + @property + def raddr(self) -> Tuple[str, int]: + ... + + class MsgpackTCPStream: '''A ``trio.SocketStream`` delivering ``msgpack`` formatted data using ``msgpack-python``. @@ -36,17 +87,10 @@ class MsgpackTCPStream: assert self.stream.socket # should both be IP sockets - lsockname = stream.socket.getsockname() - assert isinstance(lsockname, tuple) - self._laddr = lsockname[:2] + self._laddr, self._raddr = get_stream_addrs(stream) - rsockname = stream.socket.getpeername() - assert isinstance(rsockname, tuple) - self._raddr = rsockname[:2] - - # start first entry to read loop + # create read loop instance self._agen = self._iter_packets() - self._send_lock = trio.StrictFIFOLock() async def _iter_packets(self) -> typing.AsyncGenerator[dict, None]: @@ -103,11 +147,10 @@ class MsgpackTCPStream: def raddr(self) -> Tuple[Any, ...]: return self._raddr - # XXX: should this instead be called `.sendall()`? 
- async def send(self, data: Any) -> None: + async def send(self, msg: Any) -> None: async with self._send_lock: return await self.stream.send_all( - msgpack.dumps(data, use_bin_type=True) + msgpack.dumps(msg, use_bin_type=True) ) async def recv(self) -> Any: @@ -191,10 +234,10 @@ class MsgspecTCPStream(MsgpackTCPStream): else: raise - async def send(self, data: Any) -> None: + async def send(self, msg: Any) -> None: async with self._send_lock: - bytes_data: bytes = self.encode(data) + bytes_data: bytes = self.encode(msg) # supposedly the fastest says, # https://stackoverflow.com/a/54027962 @@ -203,13 +246,16 @@ class MsgspecTCPStream(MsgpackTCPStream): return await self.stream.send_all(size + bytes_data) -def get_serializer_stream_type( - name: str, -) -> Type: +def get_msg_transport( + + key: Tuple[str, str], + +) -> Type[MsgTransport]: + return { - 'msgpack': MsgpackTCPStream, - 'msgspec': MsgspecTCPStream, - }[name] + ('msgpack', 'tcp'): MsgpackTCPStream, + ('msgspec', 'tcp'): MsgspecTCPStream, + }[key] class Channel: @@ -221,34 +267,34 @@ class Channel: def __init__( self, - destaddr: Optional[Tuple[str, int]] = None, - on_reconnect: typing.Callable[..., typing.Awaitable] = None, - auto_reconnect: bool = False, - stream: trio.SocketStream = None, # expected to be active + destaddr: Optional[Tuple[str, int]], + + msg_transport_type_key: Tuple[str, str] = ('msgpack', 'tcp'), + + # TODO: optional reconnection support? + # auto_reconnect: bool = False, + # on_reconnect: typing.Callable[..., typing.Awaitable] = None, ) -> None: - self._recon_seq = on_reconnect - self._autorecon = auto_reconnect + # self._recon_seq = on_reconnect + # self._autorecon = auto_reconnect # TODO: maybe expose this through the nursery api? 
try: # if installed load the msgspec transport since it's faster import msgspec # noqa - serializer = 'msgspec' + msg_transport_type_key = ('msgspec', 'tcp') except ImportError: - serializer = 'msgpack' + pass - self.stream_serializer_type = get_serializer_stream_type(serializer) - self.msgstream = self.stream_serializer_type( - stream) if stream else None + self._destaddr = destaddr + self._transport_key = msg_transport_type_key - if self.msgstream and destaddr: - raise ValueError( - f"A stream was provided with local addr {self.laddr}" - ) - - self._destaddr = self.msgstream.raddr if self.msgstream else destaddr + # Either created in ``.connect()`` or passed in by + # user in ``.from_stream()``. + self._stream: Optional[trio.SocketStream] = None + self.msgstream: Optional[MsgTransport] = None # set after handshake - always uid of far end self.uid: Optional[Tuple[str, str]] = None @@ -256,9 +302,34 @@ class Channel: # set if far end actor errors internally self._exc: Optional[Exception] = None self._agen = self._aiter_recv() - self._closed: bool = False + @classmethod + def from_stream( + cls, + stream: trio.SocketStream, + **kwargs, + + ) -> Channel: + + src, dst = get_stream_addrs(stream) + chan = Channel(destaddr=dst, **kwargs) + + # set immediately here from provided instance + chan._stream = stream + chan.set_msg_transport(stream) + return chan + + def set_msg_transport( + self, + stream: trio.SocketStream, + type_key: Optional[Tuple[str, str]] = None, + + ) -> MsgTransport: + type_key = type_key or self._transport_key + self.msgstream = get_msg_transport(type_key)(stream) + return self.msgstream + def __repr__(self) -> str: if self.msgstream: return repr( @@ -267,11 +338,11 @@ class Channel: return object.__repr__(self) @property - def laddr(self) -> Optional[Tuple[Any, ...]]: + def laddr(self) -> Optional[Tuple[str, int]]: return self.msgstream.laddr if self.msgstream else None @property - def raddr(self) -> Optional[Tuple[Any, ...]]: + def raddr(self) -> 
Optional[Tuple[str, int]]: return self.msgstream.raddr if self.msgstream else None async def connect( @@ -279,7 +350,7 @@ class Channel: destaddr: Tuple[Any, ...] = None, **kwargs - ) -> trio.SocketStream: + ) -> MsgTransport: if self.connected(): raise RuntimeError("channel is already connected?") @@ -291,12 +362,12 @@ class Channel: *destaddr, **kwargs ) - self.msgstream = self.stream_serializer_type(stream) + msgstream = self.set_msg_transport(stream) log.transport( - f'Opened channel to peer {self.laddr} -> {self.raddr}' + f'Opened channel[{type(msgstream)}]: {self.laddr} -> {self.raddr}' ) - return stream + return msgstream async def send(self, item: Any) -> None: @@ -307,16 +378,15 @@ class Channel: async def recv(self) -> Any: assert self.msgstream + return await self.msgstream.recv() - try: - return await self.msgstream.recv() - - except trio.BrokenResourceError: - if self._autorecon: - await self._reconnect() - return await self.recv() - - raise + # try: + # return await self.msgstream.recv() + # except trio.BrokenResourceError: + # if self._autorecon: + # await self._reconnect() + # return await self.recv() + # raise async def aclose(self) -> None: @@ -338,34 +408,36 @@ class Channel: def __aiter__(self): return self._agen - async def _reconnect(self) -> None: - """Handle connection failures by polling until a reconnect can be - established. 
- """ - down = False - while True: - try: - with trio.move_on_after(3) as cancel_scope: - await self.connect() - cancelled = cancel_scope.cancelled_caught - if cancelled: - log.transport( - "Reconnect timed out after 3 seconds, retrying...") - continue - else: - log.transport("Stream connection re-established!") - # run any reconnection sequence - on_recon = self._recon_seq - if on_recon: - await on_recon(self) - break - except (OSError, ConnectionRefusedError): - if not down: - down = True - log.transport( - f"Connection to {self.raddr} went down, waiting" - " for re-establishment") - await trio.sleep(1) + # async def _reconnect(self) -> None: + # """Handle connection failures by polling until a reconnect can be + # established. + # """ + # down = False + # while True: + # try: + # with trio.move_on_after(3) as cancel_scope: + # await self.connect() + # cancelled = cancel_scope.cancelled_caught + # if cancelled: + # log.transport( + # "Reconnect timed out after 3 seconds, retrying...") + # continue + # else: + # log.transport("Stream connection re-established!") + + # # TODO: run any reconnection sequence + # # on_recon = self._recon_seq + # # if on_recon: + # # await on_recon(self) + + # break + # except (OSError, ConnectionRefusedError): + # if not down: + # down = True + # log.transport( + # f"Connection to {self.raddr} went down, waiting" + # " for re-establishment") + # await trio.sleep(1) async def _aiter_recv( self @@ -384,16 +456,14 @@ class Channel: # await self.msgstream.send(sent) except trio.BrokenResourceError: - if not self._autorecon: - raise + # if not self._autorecon: + raise await self.aclose() - if self._autorecon: # attempt reconnect - await self._reconnect() - continue - else: - return + # if self._autorecon: # attempt reconnect + # await self._reconnect() + # continue def connected(self) -> bool: return self.msgstream.connected() if self.msgstream else False