Add "message transport" structured sub-typing
In an effort to have some kind of more formal interface around the transport layer, add a `MsgTransport` protocol type and use with the channel composition of message streams. Start a little "key map" of `(<codec>, <protocol>)` to `MsgTransport` types which can be dynamically loaded. Add a `Channel.from_stream()` constructor thus cleaning up the mangled logic that was in the constructor based on inputs. Drop all the "auto reconnect" channel logic for now since nothing is using it (internally) and it's likely it will need rework once we bring in a protocol besides TCP.optional_msgspec_support
parent
135459ca25
commit
c6dc96b08c
248
tractor/_ipc.py
248
tractor/_ipc.py
|
@ -2,10 +2,14 @@
|
|||
Inter-process comms abstractions
|
||||
|
||||
"""
|
||||
from __future__ import annotations
|
||||
import platform
|
||||
import struct
|
||||
import typing
|
||||
from typing import Any, Tuple, Optional, Type
|
||||
from typing import (
|
||||
Any, Tuple, Optional,
|
||||
Type, Protocol, TypeVar
|
||||
)
|
||||
|
||||
from tricycle import BufferedReceiveStream
|
||||
import msgpack
|
||||
|
@ -21,6 +25,53 @@ _is_windows = platform.system() == 'Windows'
|
|||
log = get_logger(__name__)
|
||||
|
||||
|
||||
def get_stream_addrs(stream: trio.SocketStream) -> Tuple:
|
||||
# should both be IP sockets
|
||||
lsockname = stream.socket.getsockname()
|
||||
rsockname = stream.socket.getpeername()
|
||||
return (
|
||||
tuple(lsockname[:2]),
|
||||
tuple(rsockname[:2]),
|
||||
)
|
||||
|
||||
|
||||
MsgType = TypeVar("MsgType")
|
||||
|
||||
# TODO: consider using a generic def and indexing with our eventual
|
||||
# msg definition/types?
|
||||
# - https://docs.python.org/3/library/typing.html#typing.Protocol
|
||||
# - https://jcristharif.com/msgspec/usage.html#structs
|
||||
|
||||
|
||||
class MsgTransport(Protocol[MsgType]):
|
||||
|
||||
stream: trio.SocketStream
|
||||
|
||||
def __init__(self, stream: trio.SocketStream) -> None:
|
||||
...
|
||||
|
||||
# XXX: should this instead be called `.sendall()`?
|
||||
async def send(self, msg: MsgType) -> None:
|
||||
...
|
||||
|
||||
async def recv(self) -> MsgType:
|
||||
...
|
||||
|
||||
def __aiter__(self) -> MsgType:
|
||||
...
|
||||
|
||||
def connected(self) -> bool:
|
||||
...
|
||||
|
||||
@property
|
||||
def laddr(self) -> Tuple[str, int]:
|
||||
...
|
||||
|
||||
@property
|
||||
def raddr(self) -> Tuple[str, int]:
|
||||
...
|
||||
|
||||
|
||||
class MsgpackTCPStream:
|
||||
'''A ``trio.SocketStream`` delivering ``msgpack`` formatted data
|
||||
using ``msgpack-python``.
|
||||
|
@ -36,17 +87,10 @@ class MsgpackTCPStream:
|
|||
assert self.stream.socket
|
||||
|
||||
# should both be IP sockets
|
||||
lsockname = stream.socket.getsockname()
|
||||
assert isinstance(lsockname, tuple)
|
||||
self._laddr = lsockname[:2]
|
||||
self._laddr, self._raddr = get_stream_addrs(stream)
|
||||
|
||||
rsockname = stream.socket.getpeername()
|
||||
assert isinstance(rsockname, tuple)
|
||||
self._raddr = rsockname[:2]
|
||||
|
||||
# start first entry to read loop
|
||||
# create read loop instance
|
||||
self._agen = self._iter_packets()
|
||||
|
||||
self._send_lock = trio.StrictFIFOLock()
|
||||
|
||||
async def _iter_packets(self) -> typing.AsyncGenerator[dict, None]:
|
||||
|
@ -103,11 +147,10 @@ class MsgpackTCPStream:
|
|||
def raddr(self) -> Tuple[Any, ...]:
|
||||
return self._raddr
|
||||
|
||||
# XXX: should this instead be called `.sendall()`?
|
||||
async def send(self, data: Any) -> None:
|
||||
async def send(self, msg: Any) -> None:
|
||||
async with self._send_lock:
|
||||
return await self.stream.send_all(
|
||||
msgpack.dumps(data, use_bin_type=True)
|
||||
msgpack.dumps(msg, use_bin_type=True)
|
||||
)
|
||||
|
||||
async def recv(self) -> Any:
|
||||
|
@ -191,10 +234,10 @@ class MsgspecTCPStream(MsgpackTCPStream):
|
|||
else:
|
||||
raise
|
||||
|
||||
async def send(self, data: Any) -> None:
|
||||
async def send(self, msg: Any) -> None:
|
||||
async with self._send_lock:
|
||||
|
||||
bytes_data: bytes = self.encode(data)
|
||||
bytes_data: bytes = self.encode(msg)
|
||||
|
||||
# supposedly the fastest says,
|
||||
# https://stackoverflow.com/a/54027962
|
||||
|
@ -203,13 +246,16 @@ class MsgspecTCPStream(MsgpackTCPStream):
|
|||
return await self.stream.send_all(size + bytes_data)
|
||||
|
||||
|
||||
def get_serializer_stream_type(
|
||||
name: str,
|
||||
) -> Type:
|
||||
def get_msg_transport(
|
||||
|
||||
key: Tuple[str, str],
|
||||
|
||||
) -> Type[MsgTransport]:
|
||||
|
||||
return {
|
||||
'msgpack': MsgpackTCPStream,
|
||||
'msgspec': MsgspecTCPStream,
|
||||
}[name]
|
||||
('msgpack', 'tcp'): MsgpackTCPStream,
|
||||
('msgspec', 'tcp'): MsgspecTCPStream,
|
||||
}[key]
|
||||
|
||||
|
||||
class Channel:
|
||||
|
@ -221,34 +267,34 @@ class Channel:
|
|||
def __init__(
|
||||
|
||||
self,
|
||||
destaddr: Optional[Tuple[str, int]] = None,
|
||||
on_reconnect: typing.Callable[..., typing.Awaitable] = None,
|
||||
auto_reconnect: bool = False,
|
||||
stream: trio.SocketStream = None, # expected to be active
|
||||
destaddr: Optional[Tuple[str, int]],
|
||||
|
||||
msg_transport_type_key: Tuple[str, str] = ('msgpack', 'tcp'),
|
||||
|
||||
# TODO: optional reconnection support?
|
||||
# auto_reconnect: bool = False,
|
||||
# on_reconnect: typing.Callable[..., typing.Awaitable] = None,
|
||||
|
||||
) -> None:
|
||||
|
||||
self._recon_seq = on_reconnect
|
||||
self._autorecon = auto_reconnect
|
||||
# self._recon_seq = on_reconnect
|
||||
# self._autorecon = auto_reconnect
|
||||
|
||||
# TODO: maybe expose this through the nursery api?
|
||||
try:
|
||||
# if installed load the msgspec transport since it's faster
|
||||
import msgspec # noqa
|
||||
serializer = 'msgspec'
|
||||
msg_transport_type_key = ('msgspec', 'tcp')
|
||||
except ImportError:
|
||||
serializer = 'msgpack'
|
||||
pass
|
||||
|
||||
self.stream_serializer_type = get_serializer_stream_type(serializer)
|
||||
self.msgstream = self.stream_serializer_type(
|
||||
stream) if stream else None
|
||||
self._destaddr = destaddr
|
||||
self._transport_key = msg_transport_type_key
|
||||
|
||||
if self.msgstream and destaddr:
|
||||
raise ValueError(
|
||||
f"A stream was provided with local addr {self.laddr}"
|
||||
)
|
||||
|
||||
self._destaddr = self.msgstream.raddr if self.msgstream else destaddr
|
||||
# Either created in ``.connect()`` or passed in by
|
||||
# user in ``.from_stream()``.
|
||||
self._stream: Optional[trio.SocketStream] = None
|
||||
self.msgstream: Optional[MsgTransport] = None
|
||||
|
||||
# set after handshake - always uid of far end
|
||||
self.uid: Optional[Tuple[str, str]] = None
|
||||
|
@ -256,9 +302,34 @@ class Channel:
|
|||
# set if far end actor errors internally
|
||||
self._exc: Optional[Exception] = None
|
||||
self._agen = self._aiter_recv()
|
||||
|
||||
self._closed: bool = False
|
||||
|
||||
@classmethod
|
||||
def from_stream(
|
||||
cls,
|
||||
stream: trio.SocketStream,
|
||||
**kwargs,
|
||||
|
||||
) -> Channel:
|
||||
|
||||
src, dst = get_stream_addrs(stream)
|
||||
chan = Channel(destaddr=dst, **kwargs)
|
||||
|
||||
# set immediately here from provided instance
|
||||
chan._stream = stream
|
||||
chan.set_msg_transport(stream)
|
||||
return chan
|
||||
|
||||
def set_msg_transport(
|
||||
self,
|
||||
stream: trio.SocketStream,
|
||||
type_key: Optional[Tuple[str, str]] = None,
|
||||
|
||||
) -> MsgTransport:
|
||||
type_key = type_key or self._transport_key
|
||||
self.msgstream = get_msg_transport(type_key)(stream)
|
||||
return self.msgstream
|
||||
|
||||
def __repr__(self) -> str:
|
||||
if self.msgstream:
|
||||
return repr(
|
||||
|
@ -267,11 +338,11 @@ class Channel:
|
|||
return object.__repr__(self)
|
||||
|
||||
@property
|
||||
def laddr(self) -> Optional[Tuple[Any, ...]]:
|
||||
def laddr(self) -> Optional[Tuple[str, int]]:
|
||||
return self.msgstream.laddr if self.msgstream else None
|
||||
|
||||
@property
|
||||
def raddr(self) -> Optional[Tuple[Any, ...]]:
|
||||
def raddr(self) -> Optional[Tuple[str, int]]:
|
||||
return self.msgstream.raddr if self.msgstream else None
|
||||
|
||||
async def connect(
|
||||
|
@ -279,7 +350,7 @@ class Channel:
|
|||
destaddr: Tuple[Any, ...] = None,
|
||||
**kwargs
|
||||
|
||||
) -> trio.SocketStream:
|
||||
) -> MsgTransport:
|
||||
|
||||
if self.connected():
|
||||
raise RuntimeError("channel is already connected?")
|
||||
|
@ -291,12 +362,12 @@ class Channel:
|
|||
*destaddr,
|
||||
**kwargs
|
||||
)
|
||||
self.msgstream = self.stream_serializer_type(stream)
|
||||
msgstream = self.set_msg_transport(stream)
|
||||
|
||||
log.transport(
|
||||
f'Opened channel to peer {self.laddr} -> {self.raddr}'
|
||||
f'Opened channel[{type(msgstream)}]: {self.laddr} -> {self.raddr}'
|
||||
)
|
||||
return stream
|
||||
return msgstream
|
||||
|
||||
async def send(self, item: Any) -> None:
|
||||
|
||||
|
@ -307,16 +378,15 @@ class Channel:
|
|||
|
||||
async def recv(self) -> Any:
|
||||
assert self.msgstream
|
||||
return await self.msgstream.recv()
|
||||
|
||||
try:
|
||||
return await self.msgstream.recv()
|
||||
|
||||
except trio.BrokenResourceError:
|
||||
if self._autorecon:
|
||||
await self._reconnect()
|
||||
return await self.recv()
|
||||
|
||||
raise
|
||||
# try:
|
||||
# return await self.msgstream.recv()
|
||||
# except trio.BrokenResourceError:
|
||||
# if self._autorecon:
|
||||
# await self._reconnect()
|
||||
# return await self.recv()
|
||||
# raise
|
||||
|
||||
async def aclose(self) -> None:
|
||||
|
||||
|
@ -338,34 +408,36 @@ class Channel:
|
|||
def __aiter__(self):
|
||||
return self._agen
|
||||
|
||||
async def _reconnect(self) -> None:
|
||||
"""Handle connection failures by polling until a reconnect can be
|
||||
established.
|
||||
"""
|
||||
down = False
|
||||
while True:
|
||||
try:
|
||||
with trio.move_on_after(3) as cancel_scope:
|
||||
await self.connect()
|
||||
cancelled = cancel_scope.cancelled_caught
|
||||
if cancelled:
|
||||
log.transport(
|
||||
"Reconnect timed out after 3 seconds, retrying...")
|
||||
continue
|
||||
else:
|
||||
log.transport("Stream connection re-established!")
|
||||
# run any reconnection sequence
|
||||
on_recon = self._recon_seq
|
||||
if on_recon:
|
||||
await on_recon(self)
|
||||
break
|
||||
except (OSError, ConnectionRefusedError):
|
||||
if not down:
|
||||
down = True
|
||||
log.transport(
|
||||
f"Connection to {self.raddr} went down, waiting"
|
||||
" for re-establishment")
|
||||
await trio.sleep(1)
|
||||
# async def _reconnect(self) -> None:
|
||||
# """Handle connection failures by polling until a reconnect can be
|
||||
# established.
|
||||
# """
|
||||
# down = False
|
||||
# while True:
|
||||
# try:
|
||||
# with trio.move_on_after(3) as cancel_scope:
|
||||
# await self.connect()
|
||||
# cancelled = cancel_scope.cancelled_caught
|
||||
# if cancelled:
|
||||
# log.transport(
|
||||
# "Reconnect timed out after 3 seconds, retrying...")
|
||||
# continue
|
||||
# else:
|
||||
# log.transport("Stream connection re-established!")
|
||||
|
||||
# # TODO: run any reconnection sequence
|
||||
# # on_recon = self._recon_seq
|
||||
# # if on_recon:
|
||||
# # await on_recon(self)
|
||||
|
||||
# break
|
||||
# except (OSError, ConnectionRefusedError):
|
||||
# if not down:
|
||||
# down = True
|
||||
# log.transport(
|
||||
# f"Connection to {self.raddr} went down, waiting"
|
||||
# " for re-establishment")
|
||||
# await trio.sleep(1)
|
||||
|
||||
async def _aiter_recv(
|
||||
self
|
||||
|
@ -384,16 +456,14 @@ class Channel:
|
|||
# await self.msgstream.send(sent)
|
||||
except trio.BrokenResourceError:
|
||||
|
||||
if not self._autorecon:
|
||||
raise
|
||||
# if not self._autorecon:
|
||||
raise
|
||||
|
||||
await self.aclose()
|
||||
|
||||
if self._autorecon: # attempt reconnect
|
||||
await self._reconnect()
|
||||
continue
|
||||
else:
|
||||
return
|
||||
# if self._autorecon: # attempt reconnect
|
||||
# await self._reconnect()
|
||||
# continue
|
||||
|
||||
def connected(self) -> bool:
|
||||
return self.msgstream.connected() if self.msgstream else False
|
||||
|
|
Loading…
Reference in New Issue