diff --git a/tractor/_ipc.py b/tractor/_ipc.py index b8d0437..28bef97 100644 --- a/tractor/_ipc.py +++ b/tractor/_ipc.py @@ -2,10 +2,14 @@ Inter-process comms abstractions """ +from __future__ import annotations import platform import struct import typing -from typing import Any, Tuple, Optional, Type +from typing import ( + Any, Tuple, Optional, + Type, Protocol, TypeVar +) from tricycle import BufferedReceiveStream import msgpack @@ -21,6 +25,53 @@ _is_windows = platform.system() == 'Windows' log = get_logger(__name__) +def get_stream_addrs(stream: trio.SocketStream) -> Tuple: + # should both be IP sockets + lsockname = stream.socket.getsockname() + rsockname = stream.socket.getpeername() + return ( + tuple(lsockname[:2]), + tuple(rsockname[:2]), + ) + + +MsgType = TypeVar("MsgType") + +# TODO: consider using a generic def and indexing with our eventual +# msg definition/types? +# - https://docs.python.org/3/library/typing.html#typing.Protocol +# - https://jcristharif.com/msgspec/usage.html#structs + + +class MsgTransport(Protocol[MsgType]): + + stream: trio.SocketStream + + def __init__(self, stream: trio.SocketStream) -> None: + ... + + # XXX: should this instead be called `.sendall()`? + async def send(self, msg: MsgType) -> None: + ... + + async def recv(self) -> MsgType: + ... + + def __aiter__(self) -> MsgType: + ... + + def connected(self) -> bool: + ... + + @property + def laddr(self) -> Tuple[str, int]: + ... + + @property + def raddr(self) -> Tuple[str, int]: + ... + + class MsgpackTCPStream: '''A ``trio.SocketStream`` delivering ``msgpack`` formatted data using ``msgpack-python``. @@ -36,17 +87,10 @@ class MsgpackTCPStream: assert self.stream.socket # should both be IP sockets - lsockname = stream.socket.getsockname() - assert isinstance(lsockname, tuple) - self._laddr = lsockname[:2] + self._laddr, self._raddr = get_stream_addrs(stream) - rsockname = stream.socket.getpeername() - assert isinstance(rsockname, tuple) - self._raddr = rsockname[:2] - - # start first entry to read loop + # create read loop instance self._agen = self._iter_packets() - self._send_lock = trio.StrictFIFOLock() async def _iter_packets(self) -> typing.AsyncGenerator[dict, None]: @@ -103,11 +147,10 @@ class MsgpackTCPStream: def raddr(self) -> Tuple[Any, ...]: return self._raddr - # XXX: should this instead be called `.sendall()`? - async def send(self, data: Any) -> None: + async def send(self, msg: Any) -> None: async with self._send_lock: return await self.stream.send_all( - msgpack.dumps(data, use_bin_type=True) + msgpack.dumps(msg, use_bin_type=True) ) async def recv(self) -> Any: @@ -191,10 +234,10 @@ class MsgspecTCPStream(MsgpackTCPStream): else: raise - async def send(self, data: Any) -> None: + async def send(self, msg: Any) -> None: async with self._send_lock: - bytes_data: bytes = self.encode(data) + bytes_data: bytes = self.encode(msg) # supposedly the fastest says, # https://stackoverflow.com/a/54027962 @@ -203,13 +246,16 @@ class MsgspecTCPStream(MsgpackTCPStream): return await self.stream.send_all(size + bytes_data) -def get_serializer_stream_type( - name: str, -) -> Type: +def get_msg_transport( + + key: Tuple[str, str], + +) -> Type[MsgTransport]: + return { - 'msgpack': MsgpackTCPStream, - 'msgspec': MsgspecTCPStream, - }[name] + ('msgpack', 'tcp'): MsgpackTCPStream, + ('msgspec', 'tcp'): MsgspecTCPStream, + }[key] class Channel: @@ -221,34 +267,34 @@ class Channel: def __init__( self, - destaddr: Optional[Tuple[str, int]] = None, - on_reconnect: typing.Callable[..., typing.Awaitable] = None, - auto_reconnect: bool = False, - stream: trio.SocketStream = None, # expected to be active + destaddr: Optional[Tuple[str, int]], + + msg_transport_type_key: Tuple[str, str] = ('msgpack', 'tcp'), + + # TODO: optional reconnection support? + # auto_reconnect: bool = False, + # on_reconnect: typing.Callable[..., typing.Awaitable] = None, ) -> None: - self._recon_seq = on_reconnect - self._autorecon = auto_reconnect + # self._recon_seq = on_reconnect + # self._autorecon = auto_reconnect # TODO: maybe expose this through the nursery api? try: # if installed load the msgspec transport since it's faster import msgspec # noqa - serializer = 'msgspec' + msg_transport_type_key = ('msgspec', 'tcp') except ImportError: - serializer = 'msgpack' + pass - self.stream_serializer_type = get_serializer_stream_type(serializer) - self.msgstream = self.stream_serializer_type( - stream) if stream else None + self._destaddr = destaddr + self._transport_key = msg_transport_type_key - if self.msgstream and destaddr: - raise ValueError( - f"A stream was provided with local addr {self.laddr}" - ) - - self._destaddr = self.msgstream.raddr if self.msgstream else destaddr + # Either created in ``.connect()`` or passed in by + # user in ``.from_stream()``. + self._stream: Optional[trio.SocketStream] = None + self.msgstream: Optional[MsgTransport] = None # set after handshake - always uid of far end self.uid: Optional[Tuple[str, str]] = None @@ -256,9 +302,34 @@ class Channel: # set if far end actor errors internally self._exc: Optional[Exception] = None self._agen = self._aiter_recv() - self._closed: bool = False + @classmethod + def from_stream( + cls, + stream: trio.SocketStream, + **kwargs, + + ) -> Channel: + + src, dst = get_stream_addrs(stream) + chan = Channel(destaddr=dst, **kwargs) + + # set immediately here from provided instance + chan._stream = stream + chan.set_msg_transport(stream) + return chan + + def set_msg_transport( + self, + stream: trio.SocketStream, + type_key: Optional[Tuple[str, str]] = None, + + ) -> MsgTransport: + type_key = type_key or self._transport_key + self.msgstream = get_msg_transport(type_key)(stream) + return self.msgstream + def __repr__(self) -> str: if self.msgstream: return repr( @@ -267,11 +338,11 @@ class Channel: return object.__repr__(self) @property - def laddr(self) -> Optional[Tuple[Any, ...]]: + def laddr(self) -> Optional[Tuple[str, int]]: return self.msgstream.laddr if self.msgstream else None @property - def raddr(self) -> Optional[Tuple[Any, ...]]: + def raddr(self) -> Optional[Tuple[str, int]]: return self.msgstream.raddr if self.msgstream else None async def connect( @@ -279,7 +350,7 @@ class Channel: destaddr: Tuple[Any, ...] = None, **kwargs - ) -> trio.SocketStream: + ) -> MsgTransport: if self.connected(): raise RuntimeError("channel is already connected?") @@ -291,12 +362,12 @@ class Channel: *destaddr, **kwargs ) - self.msgstream = self.stream_serializer_type(stream) + msgstream = self.set_msg_transport(stream) log.transport( - f'Opened channel to peer {self.laddr} -> {self.raddr}' + f'Opened channel[{type(msgstream)}]: {self.laddr} -> {self.raddr}' ) - return stream + return msgstream async def send(self, item: Any) -> None: @@ -307,16 +378,15 @@ class Channel: async def recv(self) -> Any: assert self.msgstream + return await self.msgstream.recv() - try: - return await self.msgstream.recv() - - except trio.BrokenResourceError: - if self._autorecon: - await self._reconnect() - return await self.recv() - - raise + # try: + # return await self.msgstream.recv() + # except trio.BrokenResourceError: + # if self._autorecon: + # await self._reconnect() + # return await self.recv() + # raise async def aclose(self) -> None: @@ -338,34 +408,36 @@ class Channel: def __aiter__(self): return self._agen - async def _reconnect(self) -> None: - """Handle connection failures by polling until a reconnect can be - established. - """ - down = False - while True: - try: - with trio.move_on_after(3) as cancel_scope: - await self.connect() - cancelled = cancel_scope.cancelled_caught - if cancelled: - log.transport( - "Reconnect timed out after 3 seconds, retrying...") - continue - else: - log.transport("Stream connection re-established!") - # run any reconnection sequence - on_recon = self._recon_seq - if on_recon: - await on_recon(self) - break - except (OSError, ConnectionRefusedError): - if not down: - down = True - log.transport( - f"Connection to {self.raddr} went down, waiting" - " for re-establishment") - await trio.sleep(1) + # async def _reconnect(self) -> None: + # """Handle connection failures by polling until a reconnect can be + # established. + # """ + # down = False + # while True: + # try: + # with trio.move_on_after(3) as cancel_scope: + # await self.connect() + # cancelled = cancel_scope.cancelled_caught + # if cancelled: + # log.transport( + # "Reconnect timed out after 3 seconds, retrying...") + # continue + # else: + # log.transport("Stream connection re-established!") + + # # TODO: run any reconnection sequence + # # on_recon = self._recon_seq + # # if on_recon: + # # await on_recon(self) + + # break + # except (OSError, ConnectionRefusedError): + # if not down: + # down = True + # log.transport( + # f"Connection to {self.raddr} went down, waiting" + # " for re-establishment") + # await trio.sleep(1) async def _aiter_recv( self @@ -384,16 +456,14 @@ class Channel: # await self.msgstream.send(sent) except trio.BrokenResourceError: - if not self._autorecon: - raise + # if not self._autorecon: + raise await self.aclose() - if self._autorecon: # attempt reconnect - await self._reconnect() - continue - else: - return + # if self._autorecon: # attempt reconnect + # await self._reconnect() + # continue def connected(self) -> bool: return self.msgstream.connected() if self.msgstream else False