forked from goodboy/tractor
Add "message transport" structured sub-typing
In an effort to have some kind of more formal interface around the transport layer, add a `MsgTransport` protocol type and use with the channel composition of message streams. Start a little "key map" of `(<codec>, <protocol>)` to `MsgTransport` types which can be dynamically loaded. Add a `Channel.from_stream()` constructor thus cleaning up the mangled logic that was in the constructor based on inputs. Drop all the "auto reconnect" channel logic for now since nothing is using it (internally) and it's likely it will need rework once we bring in a protocol besides TCP.optional_msgspec_support
parent
135459ca25
commit
c6dc96b08c
244
tractor/_ipc.py
244
tractor/_ipc.py
|
@ -2,10 +2,14 @@
|
||||||
Inter-process comms abstractions
|
Inter-process comms abstractions
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
import platform
|
import platform
|
||||||
import struct
|
import struct
|
||||||
import typing
|
import typing
|
||||||
from typing import Any, Tuple, Optional, Type
|
from typing import (
|
||||||
|
Any, Tuple, Optional,
|
||||||
|
Type, Protocol, TypeVar
|
||||||
|
)
|
||||||
|
|
||||||
from tricycle import BufferedReceiveStream
|
from tricycle import BufferedReceiveStream
|
||||||
import msgpack
|
import msgpack
|
||||||
|
@ -21,6 +25,53 @@ _is_windows = platform.system() == 'Windows'
|
||||||
log = get_logger(__name__)
|
log = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def get_stream_addrs(stream: trio.SocketStream) -> Tuple:
|
||||||
|
# should both be IP sockets
|
||||||
|
lsockname = stream.socket.getsockname()
|
||||||
|
rsockname = stream.socket.getpeername()
|
||||||
|
return (
|
||||||
|
tuple(lsockname[:2]),
|
||||||
|
tuple(rsockname[:2]),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
MsgType = TypeVar("MsgType")
|
||||||
|
|
||||||
|
# TODO: consider using a generic def and indexing with our eventual
|
||||||
|
# msg definition/types?
|
||||||
|
# - https://docs.python.org/3/library/typing.html#typing.Protocol
|
||||||
|
# - https://jcristharif.com/msgspec/usage.html#structs
|
||||||
|
|
||||||
|
|
||||||
|
class MsgTransport(Protocol[MsgType]):
|
||||||
|
|
||||||
|
stream: trio.SocketStream
|
||||||
|
|
||||||
|
def __init__(self, stream: trio.SocketStream) -> None:
|
||||||
|
...
|
||||||
|
|
||||||
|
# XXX: should this instead be called `.sendall()`?
|
||||||
|
async def send(self, msg: MsgType) -> None:
|
||||||
|
...
|
||||||
|
|
||||||
|
async def recv(self) -> MsgType:
|
||||||
|
...
|
||||||
|
|
||||||
|
def __aiter__(self) -> MsgType:
|
||||||
|
...
|
||||||
|
|
||||||
|
def connected(self) -> bool:
|
||||||
|
...
|
||||||
|
|
||||||
|
@property
|
||||||
|
def laddr(self) -> Tuple[str, int]:
|
||||||
|
...
|
||||||
|
|
||||||
|
@property
|
||||||
|
def raddr(self) -> Tuple[str, int]:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
class MsgpackTCPStream:
|
class MsgpackTCPStream:
|
||||||
'''A ``trio.SocketStream`` delivering ``msgpack`` formatted data
|
'''A ``trio.SocketStream`` delivering ``msgpack`` formatted data
|
||||||
using ``msgpack-python``.
|
using ``msgpack-python``.
|
||||||
|
@ -36,17 +87,10 @@ class MsgpackTCPStream:
|
||||||
assert self.stream.socket
|
assert self.stream.socket
|
||||||
|
|
||||||
# should both be IP sockets
|
# should both be IP sockets
|
||||||
lsockname = stream.socket.getsockname()
|
self._laddr, self._raddr = get_stream_addrs(stream)
|
||||||
assert isinstance(lsockname, tuple)
|
|
||||||
self._laddr = lsockname[:2]
|
|
||||||
|
|
||||||
rsockname = stream.socket.getpeername()
|
# create read loop instance
|
||||||
assert isinstance(rsockname, tuple)
|
|
||||||
self._raddr = rsockname[:2]
|
|
||||||
|
|
||||||
# start first entry to read loop
|
|
||||||
self._agen = self._iter_packets()
|
self._agen = self._iter_packets()
|
||||||
|
|
||||||
self._send_lock = trio.StrictFIFOLock()
|
self._send_lock = trio.StrictFIFOLock()
|
||||||
|
|
||||||
async def _iter_packets(self) -> typing.AsyncGenerator[dict, None]:
|
async def _iter_packets(self) -> typing.AsyncGenerator[dict, None]:
|
||||||
|
@ -103,11 +147,10 @@ class MsgpackTCPStream:
|
||||||
def raddr(self) -> Tuple[Any, ...]:
|
def raddr(self) -> Tuple[Any, ...]:
|
||||||
return self._raddr
|
return self._raddr
|
||||||
|
|
||||||
# XXX: should this instead be called `.sendall()`?
|
async def send(self, msg: Any) -> None:
|
||||||
async def send(self, data: Any) -> None:
|
|
||||||
async with self._send_lock:
|
async with self._send_lock:
|
||||||
return await self.stream.send_all(
|
return await self.stream.send_all(
|
||||||
msgpack.dumps(data, use_bin_type=True)
|
msgpack.dumps(msg, use_bin_type=True)
|
||||||
)
|
)
|
||||||
|
|
||||||
async def recv(self) -> Any:
|
async def recv(self) -> Any:
|
||||||
|
@ -191,10 +234,10 @@ class MsgspecTCPStream(MsgpackTCPStream):
|
||||||
else:
|
else:
|
||||||
raise
|
raise
|
||||||
|
|
||||||
async def send(self, data: Any) -> None:
|
async def send(self, msg: Any) -> None:
|
||||||
async with self._send_lock:
|
async with self._send_lock:
|
||||||
|
|
||||||
bytes_data: bytes = self.encode(data)
|
bytes_data: bytes = self.encode(msg)
|
||||||
|
|
||||||
# supposedly the fastest says,
|
# supposedly the fastest says,
|
||||||
# https://stackoverflow.com/a/54027962
|
# https://stackoverflow.com/a/54027962
|
||||||
|
@ -203,13 +246,16 @@ class MsgspecTCPStream(MsgpackTCPStream):
|
||||||
return await self.stream.send_all(size + bytes_data)
|
return await self.stream.send_all(size + bytes_data)
|
||||||
|
|
||||||
|
|
||||||
def get_serializer_stream_type(
|
def get_msg_transport(
|
||||||
name: str,
|
|
||||||
) -> Type:
|
key: Tuple[str, str],
|
||||||
|
|
||||||
|
) -> Type[MsgTransport]:
|
||||||
|
|
||||||
return {
|
return {
|
||||||
'msgpack': MsgpackTCPStream,
|
('msgpack', 'tcp'): MsgpackTCPStream,
|
||||||
'msgspec': MsgspecTCPStream,
|
('msgspec', 'tcp'): MsgspecTCPStream,
|
||||||
}[name]
|
}[key]
|
||||||
|
|
||||||
|
|
||||||
class Channel:
|
class Channel:
|
||||||
|
@ -221,34 +267,34 @@ class Channel:
|
||||||
def __init__(
|
def __init__(
|
||||||
|
|
||||||
self,
|
self,
|
||||||
destaddr: Optional[Tuple[str, int]] = None,
|
destaddr: Optional[Tuple[str, int]],
|
||||||
on_reconnect: typing.Callable[..., typing.Awaitable] = None,
|
|
||||||
auto_reconnect: bool = False,
|
msg_transport_type_key: Tuple[str, str] = ('msgpack', 'tcp'),
|
||||||
stream: trio.SocketStream = None, # expected to be active
|
|
||||||
|
# TODO: optional reconnection support?
|
||||||
|
# auto_reconnect: bool = False,
|
||||||
|
# on_reconnect: typing.Callable[..., typing.Awaitable] = None,
|
||||||
|
|
||||||
) -> None:
|
) -> None:
|
||||||
|
|
||||||
self._recon_seq = on_reconnect
|
# self._recon_seq = on_reconnect
|
||||||
self._autorecon = auto_reconnect
|
# self._autorecon = auto_reconnect
|
||||||
|
|
||||||
# TODO: maybe expose this through the nursery api?
|
# TODO: maybe expose this through the nursery api?
|
||||||
try:
|
try:
|
||||||
# if installed load the msgspec transport since it's faster
|
# if installed load the msgspec transport since it's faster
|
||||||
import msgspec # noqa
|
import msgspec # noqa
|
||||||
serializer = 'msgspec'
|
msg_transport_type_key = ('msgspec', 'tcp')
|
||||||
except ImportError:
|
except ImportError:
|
||||||
serializer = 'msgpack'
|
pass
|
||||||
|
|
||||||
self.stream_serializer_type = get_serializer_stream_type(serializer)
|
self._destaddr = destaddr
|
||||||
self.msgstream = self.stream_serializer_type(
|
self._transport_key = msg_transport_type_key
|
||||||
stream) if stream else None
|
|
||||||
|
|
||||||
if self.msgstream and destaddr:
|
# Either created in ``.connect()`` or passed in by
|
||||||
raise ValueError(
|
# user in ``.from_stream()``.
|
||||||
f"A stream was provided with local addr {self.laddr}"
|
self._stream: Optional[trio.SocketStream] = None
|
||||||
)
|
self.msgstream: Optional[MsgTransport] = None
|
||||||
|
|
||||||
self._destaddr = self.msgstream.raddr if self.msgstream else destaddr
|
|
||||||
|
|
||||||
# set after handshake - always uid of far end
|
# set after handshake - always uid of far end
|
||||||
self.uid: Optional[Tuple[str, str]] = None
|
self.uid: Optional[Tuple[str, str]] = None
|
||||||
|
@ -256,9 +302,34 @@ class Channel:
|
||||||
# set if far end actor errors internally
|
# set if far end actor errors internally
|
||||||
self._exc: Optional[Exception] = None
|
self._exc: Optional[Exception] = None
|
||||||
self._agen = self._aiter_recv()
|
self._agen = self._aiter_recv()
|
||||||
|
|
||||||
self._closed: bool = False
|
self._closed: bool = False
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_stream(
|
||||||
|
cls,
|
||||||
|
stream: trio.SocketStream,
|
||||||
|
**kwargs,
|
||||||
|
|
||||||
|
) -> Channel:
|
||||||
|
|
||||||
|
src, dst = get_stream_addrs(stream)
|
||||||
|
chan = Channel(destaddr=dst, **kwargs)
|
||||||
|
|
||||||
|
# set immediately here from provided instance
|
||||||
|
chan._stream = stream
|
||||||
|
chan.set_msg_transport(stream)
|
||||||
|
return chan
|
||||||
|
|
||||||
|
def set_msg_transport(
|
||||||
|
self,
|
||||||
|
stream: trio.SocketStream,
|
||||||
|
type_key: Optional[Tuple[str, str]] = None,
|
||||||
|
|
||||||
|
) -> MsgTransport:
|
||||||
|
type_key = type_key or self._transport_key
|
||||||
|
self.msgstream = get_msg_transport(type_key)(stream)
|
||||||
|
return self.msgstream
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
if self.msgstream:
|
if self.msgstream:
|
||||||
return repr(
|
return repr(
|
||||||
|
@ -267,11 +338,11 @@ class Channel:
|
||||||
return object.__repr__(self)
|
return object.__repr__(self)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def laddr(self) -> Optional[Tuple[Any, ...]]:
|
def laddr(self) -> Optional[Tuple[str, int]]:
|
||||||
return self.msgstream.laddr if self.msgstream else None
|
return self.msgstream.laddr if self.msgstream else None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def raddr(self) -> Optional[Tuple[Any, ...]]:
|
def raddr(self) -> Optional[Tuple[str, int]]:
|
||||||
return self.msgstream.raddr if self.msgstream else None
|
return self.msgstream.raddr if self.msgstream else None
|
||||||
|
|
||||||
async def connect(
|
async def connect(
|
||||||
|
@ -279,7 +350,7 @@ class Channel:
|
||||||
destaddr: Tuple[Any, ...] = None,
|
destaddr: Tuple[Any, ...] = None,
|
||||||
**kwargs
|
**kwargs
|
||||||
|
|
||||||
) -> trio.SocketStream:
|
) -> MsgTransport:
|
||||||
|
|
||||||
if self.connected():
|
if self.connected():
|
||||||
raise RuntimeError("channel is already connected?")
|
raise RuntimeError("channel is already connected?")
|
||||||
|
@ -291,12 +362,12 @@ class Channel:
|
||||||
*destaddr,
|
*destaddr,
|
||||||
**kwargs
|
**kwargs
|
||||||
)
|
)
|
||||||
self.msgstream = self.stream_serializer_type(stream)
|
msgstream = self.set_msg_transport(stream)
|
||||||
|
|
||||||
log.transport(
|
log.transport(
|
||||||
f'Opened channel to peer {self.laddr} -> {self.raddr}'
|
f'Opened channel[{type(msgstream)}]: {self.laddr} -> {self.raddr}'
|
||||||
)
|
)
|
||||||
return stream
|
return msgstream
|
||||||
|
|
||||||
async def send(self, item: Any) -> None:
|
async def send(self, item: Any) -> None:
|
||||||
|
|
||||||
|
@ -307,16 +378,15 @@ class Channel:
|
||||||
|
|
||||||
async def recv(self) -> Any:
|
async def recv(self) -> Any:
|
||||||
assert self.msgstream
|
assert self.msgstream
|
||||||
|
|
||||||
try:
|
|
||||||
return await self.msgstream.recv()
|
return await self.msgstream.recv()
|
||||||
|
|
||||||
except trio.BrokenResourceError:
|
# try:
|
||||||
if self._autorecon:
|
# return await self.msgstream.recv()
|
||||||
await self._reconnect()
|
# except trio.BrokenResourceError:
|
||||||
return await self.recv()
|
# if self._autorecon:
|
||||||
|
# await self._reconnect()
|
||||||
raise
|
# return await self.recv()
|
||||||
|
# raise
|
||||||
|
|
||||||
async def aclose(self) -> None:
|
async def aclose(self) -> None:
|
||||||
|
|
||||||
|
@ -338,34 +408,36 @@ class Channel:
|
||||||
def __aiter__(self):
|
def __aiter__(self):
|
||||||
return self._agen
|
return self._agen
|
||||||
|
|
||||||
async def _reconnect(self) -> None:
|
# async def _reconnect(self) -> None:
|
||||||
"""Handle connection failures by polling until a reconnect can be
|
# """Handle connection failures by polling until a reconnect can be
|
||||||
established.
|
# established.
|
||||||
"""
|
# """
|
||||||
down = False
|
# down = False
|
||||||
while True:
|
# while True:
|
||||||
try:
|
# try:
|
||||||
with trio.move_on_after(3) as cancel_scope:
|
# with trio.move_on_after(3) as cancel_scope:
|
||||||
await self.connect()
|
# await self.connect()
|
||||||
cancelled = cancel_scope.cancelled_caught
|
# cancelled = cancel_scope.cancelled_caught
|
||||||
if cancelled:
|
# if cancelled:
|
||||||
log.transport(
|
# log.transport(
|
||||||
"Reconnect timed out after 3 seconds, retrying...")
|
# "Reconnect timed out after 3 seconds, retrying...")
|
||||||
continue
|
# continue
|
||||||
else:
|
# else:
|
||||||
log.transport("Stream connection re-established!")
|
# log.transport("Stream connection re-established!")
|
||||||
# run any reconnection sequence
|
|
||||||
on_recon = self._recon_seq
|
# # TODO: run any reconnection sequence
|
||||||
if on_recon:
|
# # on_recon = self._recon_seq
|
||||||
await on_recon(self)
|
# # if on_recon:
|
||||||
break
|
# # await on_recon(self)
|
||||||
except (OSError, ConnectionRefusedError):
|
|
||||||
if not down:
|
# break
|
||||||
down = True
|
# except (OSError, ConnectionRefusedError):
|
||||||
log.transport(
|
# if not down:
|
||||||
f"Connection to {self.raddr} went down, waiting"
|
# down = True
|
||||||
" for re-establishment")
|
# log.transport(
|
||||||
await trio.sleep(1)
|
# f"Connection to {self.raddr} went down, waiting"
|
||||||
|
# " for re-establishment")
|
||||||
|
# await trio.sleep(1)
|
||||||
|
|
||||||
async def _aiter_recv(
|
async def _aiter_recv(
|
||||||
self
|
self
|
||||||
|
@ -384,16 +456,14 @@ class Channel:
|
||||||
# await self.msgstream.send(sent)
|
# await self.msgstream.send(sent)
|
||||||
except trio.BrokenResourceError:
|
except trio.BrokenResourceError:
|
||||||
|
|
||||||
if not self._autorecon:
|
# if not self._autorecon:
|
||||||
raise
|
raise
|
||||||
|
|
||||||
await self.aclose()
|
await self.aclose()
|
||||||
|
|
||||||
if self._autorecon: # attempt reconnect
|
# if self._autorecon: # attempt reconnect
|
||||||
await self._reconnect()
|
# await self._reconnect()
|
||||||
continue
|
# continue
|
||||||
else:
|
|
||||||
return
|
|
||||||
|
|
||||||
def connected(self) -> bool:
|
def connected(self) -> bool:
|
||||||
return self.msgstream.connected() if self.msgstream else False
|
return self.msgstream.connected() if self.msgstream else False
|
||||||
|
|
Loading…
Reference in New Issue