Starting to make tractor.ipc.Channel work with multiple MsgTransports
parent
32b5210648
commit
2907719cbe
|
@ -75,7 +75,7 @@ async def get_registry(
|
|||
# TODO: try to look pre-existing connection from
|
||||
# `Actor._peers` and use it instead?
|
||||
async with (
|
||||
_connect_chan(host, port) as chan,
|
||||
_connect_chan((host, port)) as chan,
|
||||
open_portal(chan) as regstr_ptl,
|
||||
):
|
||||
yield regstr_ptl
|
||||
|
@ -93,7 +93,7 @@ async def get_root(
|
|||
assert host is not None
|
||||
|
||||
async with (
|
||||
_connect_chan(host, port) as chan,
|
||||
_connect_chan((host, port)) as chan,
|
||||
open_portal(chan, **kwargs) as portal,
|
||||
):
|
||||
yield portal
|
||||
|
@ -187,7 +187,7 @@ async def maybe_open_portal(
|
|||
pass
|
||||
|
||||
if sockaddr:
|
||||
async with _connect_chan(*sockaddr) as chan:
|
||||
async with _connect_chan(sockaddr) as chan:
|
||||
async with open_portal(chan) as portal:
|
||||
yield portal
|
||||
else:
|
||||
|
@ -310,6 +310,6 @@ async def wait_for_actor(
|
|||
# TODO: offer multi-portal yields in multi-homed case?
|
||||
sockaddr: tuple[str, int] = sockaddrs[-1]
|
||||
|
||||
async with _connect_chan(*sockaddr) as chan:
|
||||
async with _connect_chan(sockaddr) as chan:
|
||||
async with open_portal(chan) as portal:
|
||||
yield portal
|
||||
|
|
|
@ -271,7 +271,7 @@ async def open_root_actor(
|
|||
# be better to eventually have a "discovery" protocol
|
||||
# with basic handshake instead?
|
||||
with trio.move_on_after(timeout):
|
||||
async with _connect_chan(*addr):
|
||||
async with _connect_chan(addr):
|
||||
ponged_addrs.append(addr)
|
||||
|
||||
except OSError:
|
||||
|
|
|
@ -1040,10 +1040,7 @@ class Actor:
|
|||
# Connect back to the parent actor and conduct initial
|
||||
# handshake. From this point on if we error, we
|
||||
# attempt to ship the exception back to the parent.
|
||||
chan = Channel(
|
||||
destaddr=parent_addr,
|
||||
)
|
||||
await chan.connect()
|
||||
chan = await Channel.from_destaddr(parent_addr)
|
||||
|
||||
# TODO: move this into a `Channel.handshake()`?
|
||||
# Initial handshake: swap names.
|
||||
|
|
|
@ -13,20 +13,26 @@
|
|||
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
|
||||
import platform
|
||||
|
||||
from ._transport import MsgTransport as MsgTransport
|
||||
from ._transport import (
|
||||
AddressType as AddressType,
|
||||
MsgType as MsgType,
|
||||
MsgTransport as MsgTransport,
|
||||
MsgpackTransport as MsgpackTransport
|
||||
)
|
||||
|
||||
from ._tcp import (
|
||||
get_stream_addrs as get_stream_addrs,
|
||||
MsgpackTCPStream as MsgpackTCPStream
|
||||
from ._tcp import MsgpackTCPStream as MsgpackTCPStream
|
||||
from ._uds import MsgpackUDSStream as MsgpackUDSStream
|
||||
|
||||
from ._types import (
|
||||
transport_from_destaddr as transport_from_destaddr,
|
||||
transport_from_stream as transport_from_stream,
|
||||
AddressTypes as AddressTypes
|
||||
)
|
||||
|
||||
from ._chan import (
|
||||
_connect_chan as _connect_chan,
|
||||
get_msg_transport as get_msg_transport,
|
||||
Channel as Channel
|
||||
)
|
||||
|
||||
|
|
|
@ -29,15 +29,15 @@ from pprint import pformat
|
|||
import typing
|
||||
from typing import (
|
||||
Any,
|
||||
Type
|
||||
)
|
||||
|
||||
import trio
|
||||
|
||||
from tractor.ipc._transport import MsgTransport
|
||||
from tractor.ipc._tcp import (
|
||||
MsgpackTCPStream,
|
||||
get_stream_addrs
|
||||
from tractor.ipc._types import (
|
||||
transport_from_destaddr,
|
||||
transport_from_stream,
|
||||
AddressTypes
|
||||
)
|
||||
from tractor.log import get_logger
|
||||
from tractor._exceptions import (
|
||||
|
@ -52,17 +52,6 @@ log = get_logger(__name__)
|
|||
_is_windows = platform.system() == 'Windows'
|
||||
|
||||
|
||||
def get_msg_transport(
|
||||
|
||||
key: tuple[str, str],
|
||||
|
||||
) -> Type[MsgTransport]:
|
||||
|
||||
return {
|
||||
('msgpack', 'tcp'): MsgpackTCPStream,
|
||||
}[key]
|
||||
|
||||
|
||||
class Channel:
|
||||
'''
|
||||
An inter-process channel for communication between (remote) actors.
|
||||
|
@ -77,10 +66,8 @@ class Channel:
|
|||
def __init__(
|
||||
|
||||
self,
|
||||
destaddr: tuple[str, int]|None,
|
||||
|
||||
msg_transport_type_key: tuple[str, str] = ('msgpack', 'tcp'),
|
||||
|
||||
destaddr: AddressTypes|None = None,
|
||||
transport: MsgTransport|None = None,
|
||||
# TODO: optional reconnection support?
|
||||
# auto_reconnect: bool = False,
|
||||
# on_reconnect: typing.Callable[..., typing.Awaitable] = None,
|
||||
|
@ -90,13 +77,11 @@ class Channel:
|
|||
# self._recon_seq = on_reconnect
|
||||
# self._autorecon = auto_reconnect
|
||||
|
||||
self._destaddr = destaddr
|
||||
self._transport_key = msg_transport_type_key
|
||||
|
||||
# Either created in ``.connect()`` or passed in by
|
||||
# user in ``.from_stream()``.
|
||||
self._stream: trio.SocketStream|None = None
|
||||
self._transport: MsgTransport|None = None
|
||||
self._transport: MsgTransport|None = transport
|
||||
|
||||
self._destaddr = destaddr if destaddr else self._transport.raddr
|
||||
|
||||
# set after handshake - always uid of far end
|
||||
self.uid: tuple[str, str]|None = None
|
||||
|
@ -110,6 +95,10 @@ class Channel:
|
|||
# runtime.
|
||||
self._cancel_called: bool = False
|
||||
|
||||
@property
|
||||
def stream(self) -> trio.abc.Stream | None:
|
||||
return self._transport.stream if self._transport else None
|
||||
|
||||
@property
|
||||
def msgstream(self) -> MsgTransport:
|
||||
log.info(
|
||||
|
@ -124,52 +113,31 @@ class Channel:
|
|||
@classmethod
|
||||
def from_stream(
|
||||
cls,
|
||||
stream: trio.SocketStream,
|
||||
**kwargs,
|
||||
|
||||
stream: trio.abc.Stream,
|
||||
) -> Channel:
|
||||
|
||||
src, dst = get_stream_addrs(stream)
|
||||
chan = Channel(
|
||||
destaddr=dst,
|
||||
**kwargs,
|
||||
transport_cls = transport_from_stream(stream)
|
||||
return Channel(
|
||||
transport=transport_cls(stream)
|
||||
)
|
||||
|
||||
# set immediately here from provided instance
|
||||
chan._stream: trio.SocketStream = stream
|
||||
chan.set_msg_transport(stream)
|
||||
return chan
|
||||
@classmethod
|
||||
async def from_destaddr(
|
||||
cls,
|
||||
destaddr: AddressTypes,
|
||||
**kwargs
|
||||
) -> Channel:
|
||||
transport_cls = transport_from_destaddr(destaddr)
|
||||
transport = await transport_cls.connect_to(destaddr, **kwargs)
|
||||
|
||||
def set_msg_transport(
|
||||
self,
|
||||
stream: trio.SocketStream,
|
||||
type_key: tuple[str, str]|None = None,
|
||||
|
||||
# XXX optionally provided codec pair for `msgspec`:
|
||||
# https://jcristharif.com/msgspec/extending.html#mapping-to-from-native-types
|
||||
codec: MsgCodec|None = None,
|
||||
|
||||
) -> MsgTransport:
|
||||
type_key = (
|
||||
type_key
|
||||
or
|
||||
self._transport_key
|
||||
log.transport(
|
||||
f'Opened channel[{type(transport)}]: {transport.laddr} -> {transport.raddr}'
|
||||
)
|
||||
# get transport type, then
|
||||
self._transport = get_msg_transport(
|
||||
type_key
|
||||
# instantiate an instance of the msg-transport
|
||||
)(
|
||||
stream,
|
||||
codec=codec,
|
||||
)
|
||||
return self._transport
|
||||
return Channel(transport=transport)
|
||||
|
||||
@cm
|
||||
def apply_codec(
|
||||
self,
|
||||
codec: MsgCodec,
|
||||
|
||||
) -> None:
|
||||
'''
|
||||
Temporarily override the underlying IPC msg codec for
|
||||
|
@ -189,7 +157,7 @@ class Channel:
|
|||
return '<Channel with inactive transport?>'
|
||||
|
||||
return repr(
|
||||
self._transport.stream.socket._sock
|
||||
self._transport
|
||||
).replace( # type: ignore
|
||||
"socket.socket",
|
||||
"Channel",
|
||||
|
@ -203,30 +171,6 @@ class Channel:
|
|||
def raddr(self) -> tuple[str, int]|None:
|
||||
return self._transport.raddr if self._transport else None
|
||||
|
||||
async def connect(
|
||||
self,
|
||||
destaddr: tuple[Any, ...] | None = None,
|
||||
**kwargs
|
||||
|
||||
) -> MsgTransport:
|
||||
|
||||
if self.connected():
|
||||
raise RuntimeError("channel is already connected?")
|
||||
|
||||
destaddr = destaddr or self._destaddr
|
||||
assert isinstance(destaddr, tuple)
|
||||
|
||||
stream = await trio.open_tcp_stream(
|
||||
*destaddr,
|
||||
**kwargs
|
||||
)
|
||||
transport = self.set_msg_transport(stream)
|
||||
|
||||
log.transport(
|
||||
f'Opened channel[{type(transport)}]: {self.laddr} -> {self.raddr}'
|
||||
)
|
||||
return transport
|
||||
|
||||
# TODO: something like,
|
||||
# `pdbp.hideframe_on(errors=[MsgTypeError])`
|
||||
# instead of the `try/except` hack we have rn..
|
||||
|
@ -388,17 +332,14 @@ class Channel:
|
|||
|
||||
@acm
|
||||
async def _connect_chan(
|
||||
host: str,
|
||||
port: int
|
||||
|
||||
destaddr: AddressTypes
|
||||
) -> typing.AsyncGenerator[Channel, None]:
|
||||
'''
|
||||
Create and connect a channel with disconnect on context manager
|
||||
teardown.
|
||||
|
||||
'''
|
||||
chan = Channel((host, port))
|
||||
await chan.connect()
|
||||
chan = await Channel.from_destaddr(destaddr)
|
||||
yield chan
|
||||
with trio.CancelScope(shield=True):
|
||||
await chan.aclose()
|
||||
|
|
|
@ -18,388 +18,75 @@ TCP implementation of tractor.ipc._transport.MsgTransport protocol
|
|||
|
||||
'''
|
||||
from __future__ import annotations
|
||||
from collections.abc import (
|
||||
AsyncGenerator,
|
||||
AsyncIterator,
|
||||
)
|
||||
import struct
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
)
|
||||
|
||||
import msgspec
|
||||
from tricycle import BufferedReceiveStream
|
||||
import trio
|
||||
|
||||
from tractor.msg import MsgCodec
|
||||
from tractor.log import get_logger
|
||||
from tractor._exceptions import (
|
||||
MsgTypeError,
|
||||
TransportClosed,
|
||||
_mk_send_mte,
|
||||
_mk_recv_mte,
|
||||
)
|
||||
from tractor.msg import (
|
||||
_ctxvar_MsgCodec,
|
||||
# _codec, XXX see `self._codec` sanity/debug checks
|
||||
MsgCodec,
|
||||
types as msgtypes,
|
||||
pretty_struct,
|
||||
)
|
||||
from tractor.ipc import MsgTransport
|
||||
from tractor.ipc._transport import MsgpackTransport
|
||||
|
||||
|
||||
log = get_logger(__name__)
|
||||
|
||||
|
||||
def get_stream_addrs(
|
||||
stream: trio.SocketStream
|
||||
) -> tuple[
|
||||
tuple[str, int], # local
|
||||
tuple[str, int], # remote
|
||||
]:
|
||||
# TODO: typing oddity.. not sure why we have to inherit here, but it
|
||||
# seems to be an issue with `get_msg_transport()` returning
|
||||
# a `Type[Protocol]`; probably should make a `mypy` issue?
|
||||
class MsgpackTCPStream(MsgpackTransport):
|
||||
'''
|
||||
Return the `trio` streaming transport prot's socket-addrs for
|
||||
both the local and remote sides as a pair.
|
||||
A ``trio.SocketStream`` delivering ``msgpack`` formatted data
|
||||
using the ``msgspec`` codec lib.
|
||||
|
||||
'''
|
||||
# rn, should both be IP sockets
|
||||
address_type = tuple[str, int]
|
||||
layer_key: int = 4
|
||||
name_key: str = 'tcp'
|
||||
|
||||
# def __init__(
|
||||
# self,
|
||||
# stream: trio.SocketStream,
|
||||
# prefix_size: int = 4,
|
||||
# codec: CodecType = None,
|
||||
|
||||
# ) -> None:
|
||||
# super().__init__(
|
||||
# stream,
|
||||
# prefix_size=prefix_size,
|
||||
# codec=codec
|
||||
# )
|
||||
|
||||
def connected(self) -> bool:
|
||||
return self.stream.socket.fileno() != -1
|
||||
|
||||
@classmethod
|
||||
async def connect_to(
|
||||
cls,
|
||||
destaddr: tuple[str, int],
|
||||
prefix_size: int = 4,
|
||||
codec: MsgCodec|None = None,
|
||||
**kwargs
|
||||
) -> MsgpackTCPStream:
|
||||
stream = await trio.open_tcp_stream(
|
||||
*destaddr,
|
||||
**kwargs
|
||||
)
|
||||
return MsgpackTCPStream(
|
||||
stream,
|
||||
prefix_size=prefix_size,
|
||||
codec=codec
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_stream_addrs(
|
||||
cls,
|
||||
stream: trio.SocketStream
|
||||
) -> tuple[
|
||||
tuple[str, int],
|
||||
tuple[str, int]
|
||||
]:
|
||||
lsockname = stream.socket.getsockname()
|
||||
rsockname = stream.socket.getpeername()
|
||||
return (
|
||||
tuple(lsockname[:2]),
|
||||
tuple(rsockname[:2]),
|
||||
)
|
||||
|
||||
|
||||
# TODO: typing oddity.. not sure why we have to inherit here, but it
|
||||
# seems to be an issue with `get_msg_transport()` returning
|
||||
# a `Type[Protocol]`; probably should make a `mypy` issue?
|
||||
class MsgpackTCPStream(MsgTransport):
|
||||
'''
|
||||
A ``trio.SocketStream`` delivering ``msgpack`` formatted data
|
||||
using the ``msgspec`` codec lib.
|
||||
|
||||
'''
|
||||
layer_key: int = 4
|
||||
name_key: str = 'tcp'
|
||||
|
||||
# TODO: better naming for this?
|
||||
# -[ ] check how libp2p does naming for such things?
|
||||
codec_key: str = 'msgpack'
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
stream: trio.SocketStream,
|
||||
prefix_size: int = 4,
|
||||
|
||||
# XXX optionally provided codec pair for `msgspec`:
|
||||
# https://jcristharif.com/msgspec/extending.html#mapping-to-from-native-types
|
||||
#
|
||||
# TODO: define this as a `Codec` struct which can be
|
||||
# overriden dynamically by the application/runtime?
|
||||
codec: tuple[
|
||||
Callable[[Any], Any]|None, # coder
|
||||
Callable[[type, Any], Any]|None, # decoder
|
||||
]|None = None,
|
||||
|
||||
) -> None:
|
||||
|
||||
self.stream = stream
|
||||
assert self.stream.socket
|
||||
|
||||
# should both be IP sockets
|
||||
self._laddr, self._raddr = get_stream_addrs(stream)
|
||||
|
||||
# create read loop instance
|
||||
self._aiter_pkts = self._iter_packets()
|
||||
self._send_lock = trio.StrictFIFOLock()
|
||||
|
||||
# public i guess?
|
||||
self.drained: list[dict] = []
|
||||
|
||||
self.recv_stream = BufferedReceiveStream(
|
||||
transport_stream=stream
|
||||
)
|
||||
self.prefix_size = prefix_size
|
||||
|
||||
# allow for custom IPC msg interchange format
|
||||
# dynamic override Bo
|
||||
self._task = trio.lowlevel.current_task()
|
||||
|
||||
# XXX for ctxvar debug only!
|
||||
# self._codec: MsgCodec = (
|
||||
# codec
|
||||
# or
|
||||
# _codec._ctxvar_MsgCodec.get()
|
||||
# )
|
||||
|
||||
async def _iter_packets(self) -> AsyncGenerator[dict, None]:
|
||||
'''
|
||||
Yield `bytes`-blob decoded packets from the underlying TCP
|
||||
stream using the current task's `MsgCodec`.
|
||||
|
||||
This is a streaming routine implemented as an async generator
|
||||
func (which was the original design, but could be changed?)
|
||||
and is allocated by a `.__call__()` inside `.__init__()` where
|
||||
it is assigned to the `._aiter_pkts` attr.
|
||||
|
||||
'''
|
||||
decodes_failed: int = 0
|
||||
|
||||
while True:
|
||||
try:
|
||||
header: bytes = await self.recv_stream.receive_exactly(4)
|
||||
except (
|
||||
ValueError,
|
||||
ConnectionResetError,
|
||||
|
||||
# not sure entirely why we need this but without it we
|
||||
# seem to be getting racy failures here on
|
||||
# arbiter/registry name subs..
|
||||
trio.BrokenResourceError,
|
||||
|
||||
) as trans_err:
|
||||
|
||||
loglevel = 'transport'
|
||||
match trans_err:
|
||||
# case (
|
||||
# ConnectionResetError()
|
||||
# ):
|
||||
# loglevel = 'transport'
|
||||
|
||||
# peer actor (graceful??) TCP EOF but `tricycle`
|
||||
# seems to raise a 0-bytes-read?
|
||||
case ValueError() if (
|
||||
'unclean EOF' in trans_err.args[0]
|
||||
):
|
||||
pass
|
||||
|
||||
# peer actor (task) prolly shutdown quickly due
|
||||
# to cancellation
|
||||
case trio.BrokenResourceError() if (
|
||||
'Connection reset by peer' in trans_err.args[0]
|
||||
):
|
||||
pass
|
||||
|
||||
# unless the disconnect condition falls under "a
|
||||
# normal operation breakage" we usualy console warn
|
||||
# about it.
|
||||
case _:
|
||||
loglevel: str = 'warning'
|
||||
|
||||
|
||||
raise TransportClosed(
|
||||
message=(
|
||||
f'IPC transport already closed by peer\n'
|
||||
f'x)> {type(trans_err)}\n'
|
||||
f' |_{self}\n'
|
||||
),
|
||||
loglevel=loglevel,
|
||||
) from trans_err
|
||||
|
||||
# XXX definitely can happen if transport is closed
|
||||
# manually by another `trio.lowlevel.Task` in the
|
||||
# same actor; we use this in some simulated fault
|
||||
# testing for ex, but generally should never happen
|
||||
# under normal operation!
|
||||
#
|
||||
# NOTE: as such we always re-raise this error from the
|
||||
# RPC msg loop!
|
||||
except trio.ClosedResourceError as closure_err:
|
||||
raise TransportClosed(
|
||||
message=(
|
||||
f'IPC transport already manually closed locally?\n'
|
||||
f'x)> {type(closure_err)} \n'
|
||||
f' |_{self}\n'
|
||||
),
|
||||
loglevel='error',
|
||||
raise_on_report=(
|
||||
closure_err.args[0] == 'another task closed this fd'
|
||||
or
|
||||
closure_err.args[0] in ['another task closed this fd']
|
||||
),
|
||||
) from closure_err
|
||||
|
||||
# graceful TCP EOF disconnect
|
||||
if header == b'':
|
||||
raise TransportClosed(
|
||||
message=(
|
||||
f'IPC transport already gracefully closed\n'
|
||||
f')>\n'
|
||||
f'|_{self}\n'
|
||||
),
|
||||
loglevel='transport',
|
||||
# cause=??? # handy or no?
|
||||
)
|
||||
|
||||
size: int
|
||||
size, = struct.unpack("<I", header)
|
||||
|
||||
log.transport(f'received header {size}') # type: ignore
|
||||
msg_bytes: bytes = await self.recv_stream.receive_exactly(size)
|
||||
|
||||
log.transport(f"received {msg_bytes}") # type: ignore
|
||||
try:
|
||||
# NOTE: lookup the `trio.Task.context`'s var for
|
||||
# the current `MsgCodec`.
|
||||
codec: MsgCodec = _ctxvar_MsgCodec.get()
|
||||
|
||||
# XXX for ctxvar debug only!
|
||||
# if self._codec.pld_spec != codec.pld_spec:
|
||||
# assert (
|
||||
# task := trio.lowlevel.current_task()
|
||||
# ) is not self._task
|
||||
# self._task = task
|
||||
# self._codec = codec
|
||||
# log.runtime(
|
||||
# f'Using new codec in {self}.recv()\n'
|
||||
# f'codec: {self._codec}\n\n'
|
||||
# f'msg_bytes: {msg_bytes}\n'
|
||||
# )
|
||||
yield codec.decode(msg_bytes)
|
||||
|
||||
# XXX NOTE: since the below error derives from
|
||||
# `DecodeError` we need to catch is specially
|
||||
# and always raise such that spec violations
|
||||
# are never allowed to be caught silently!
|
||||
except msgspec.ValidationError as verr:
|
||||
msgtyperr: MsgTypeError = _mk_recv_mte(
|
||||
msg=msg_bytes,
|
||||
codec=codec,
|
||||
src_validation_error=verr,
|
||||
)
|
||||
# XXX deliver up to `Channel.recv()` where
|
||||
# a re-raise and `Error`-pack can inject the far
|
||||
# end actor `.uid`.
|
||||
yield msgtyperr
|
||||
|
||||
except (
|
||||
msgspec.DecodeError,
|
||||
UnicodeDecodeError,
|
||||
):
|
||||
if decodes_failed < 4:
|
||||
# ignore decoding errors for now and assume they have to
|
||||
# do with a channel drop - hope that receiving from the
|
||||
# channel will raise an expected error and bubble up.
|
||||
try:
|
||||
msg_str: str|bytes = msg_bytes.decode()
|
||||
except UnicodeDecodeError:
|
||||
msg_str = msg_bytes
|
||||
|
||||
log.exception(
|
||||
'Failed to decode msg?\n'
|
||||
f'{codec}\n\n'
|
||||
'Rxed bytes from wire:\n\n'
|
||||
f'{msg_str!r}\n'
|
||||
)
|
||||
decodes_failed += 1
|
||||
else:
|
||||
raise
|
||||
|
||||
async def send(
|
||||
self,
|
||||
msg: msgtypes.MsgType,
|
||||
|
||||
strict_types: bool = True,
|
||||
hide_tb: bool = False,
|
||||
|
||||
) -> None:
|
||||
'''
|
||||
Send a msgpack encoded py-object-blob-as-msg over TCP.
|
||||
|
||||
If `strict_types == True` then a `MsgTypeError` will be raised on any
|
||||
invalid msg type
|
||||
|
||||
'''
|
||||
__tracebackhide__: bool = hide_tb
|
||||
|
||||
# XXX see `trio._sync.AsyncContextManagerMixin` for details
|
||||
# on the `.acquire()`/`.release()` sequencing..
|
||||
async with self._send_lock:
|
||||
|
||||
# NOTE: lookup the `trio.Task.context`'s var for
|
||||
# the current `MsgCodec`.
|
||||
codec: MsgCodec = _ctxvar_MsgCodec.get()
|
||||
|
||||
# XXX for ctxvar debug only!
|
||||
# if self._codec.pld_spec != codec.pld_spec:
|
||||
# self._codec = codec
|
||||
# log.runtime(
|
||||
# f'Using new codec in {self}.send()\n'
|
||||
# f'codec: {self._codec}\n\n'
|
||||
# f'msg: {msg}\n'
|
||||
# )
|
||||
|
||||
if type(msg) not in msgtypes.__msg_types__:
|
||||
if strict_types:
|
||||
raise _mk_send_mte(
|
||||
msg,
|
||||
codec=codec,
|
||||
)
|
||||
else:
|
||||
log.warning(
|
||||
'Sending non-`Msg`-spec msg?\n\n'
|
||||
f'{msg}\n'
|
||||
)
|
||||
|
||||
try:
|
||||
bytes_data: bytes = codec.encode(msg)
|
||||
except TypeError as _err:
|
||||
typerr = _err
|
||||
msgtyperr: MsgTypeError = _mk_send_mte(
|
||||
msg,
|
||||
codec=codec,
|
||||
message=(
|
||||
f'IPC-msg-spec violation in\n\n'
|
||||
f'{pretty_struct.Struct.pformat(msg)}'
|
||||
),
|
||||
src_type_error=typerr,
|
||||
)
|
||||
raise msgtyperr from typerr
|
||||
|
||||
# supposedly the fastest says,
|
||||
# https://stackoverflow.com/a/54027962
|
||||
size: bytes = struct.pack("<I", len(bytes_data))
|
||||
return await self.stream.send_all(size + bytes_data)
|
||||
|
||||
# ?TODO? does it help ever to dynamically show this
|
||||
# frame?
|
||||
# try:
|
||||
# <the-above_code>
|
||||
# except BaseException as _err:
|
||||
# err = _err
|
||||
# if not isinstance(err, MsgTypeError):
|
||||
# __tracebackhide__: bool = False
|
||||
# raise
|
||||
|
||||
@property
|
||||
def laddr(self) -> tuple[str, int]:
|
||||
return self._laddr
|
||||
|
||||
@property
|
||||
def raddr(self) -> tuple[str, int]:
|
||||
return self._raddr
|
||||
|
||||
async def recv(self) -> Any:
|
||||
return await self._aiter_pkts.asend(None)
|
||||
|
||||
async def drain(self) -> AsyncIterator[dict]:
|
||||
'''
|
||||
Drain the stream's remaining messages sent from
|
||||
the far end until the connection is closed by
|
||||
the peer.
|
||||
|
||||
'''
|
||||
try:
|
||||
async for msg in self._iter_packets():
|
||||
self.drained.append(msg)
|
||||
except TransportClosed:
|
||||
for msg in self.drained:
|
||||
yield msg
|
||||
|
||||
def __aiter__(self):
|
||||
return self._aiter_pkts
|
||||
|
||||
def connected(self) -> bool:
|
||||
return self.stream.socket.fileno() != -1
|
||||
|
|
|
@ -18,24 +18,56 @@ typing.Protocol based generic msg API, implement this class to add backends for
|
|||
tractor.ipc.Channel
|
||||
|
||||
'''
|
||||
import trio
|
||||
from __future__ import annotations
|
||||
from typing import (
|
||||
runtime_checkable,
|
||||
Type,
|
||||
Protocol,
|
||||
TypeVar,
|
||||
ClassVar
|
||||
)
|
||||
from collections.abc import AsyncIterator
|
||||
from collections.abc import (
|
||||
AsyncGenerator,
|
||||
AsyncIterator,
|
||||
)
|
||||
import struct
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
)
|
||||
|
||||
import trio
|
||||
import msgspec
|
||||
from tricycle import BufferedReceiveStream
|
||||
|
||||
from tractor.log import get_logger
|
||||
from tractor._exceptions import (
|
||||
MsgTypeError,
|
||||
TransportClosed,
|
||||
_mk_send_mte,
|
||||
_mk_recv_mte,
|
||||
)
|
||||
from tractor.msg import (
|
||||
_ctxvar_MsgCodec,
|
||||
# _codec, XXX see `self._codec` sanity/debug checks
|
||||
MsgCodec,
|
||||
types as msgtypes,
|
||||
pretty_struct,
|
||||
)
|
||||
|
||||
log = get_logger(__name__)
|
||||
|
||||
|
||||
# from tractor.msg.types import MsgType
|
||||
# ?TODO? this should be our `Union[*msgtypes.__spec__]` alias now right..?
|
||||
# => BLEH, except can't bc prots must inherit typevar or param-spec
|
||||
# vars..
|
||||
AddressType = TypeVar('AddressType')
|
||||
MsgType = TypeVar('MsgType')
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class MsgTransport(Protocol[MsgType]):
|
||||
class MsgTransport(Protocol[AddressType, MsgType]):
|
||||
#
|
||||
# ^-TODO-^ consider using a generic def and indexing with our
|
||||
# eventual msg definition/types?
|
||||
|
@ -43,9 +75,7 @@ class MsgTransport(Protocol[MsgType]):
|
|||
|
||||
stream: trio.abc.Stream
|
||||
drained: list[MsgType]
|
||||
|
||||
def __init__(self, stream: trio.abc.Stream) -> None:
|
||||
...
|
||||
address_type: ClassVar[Type[AddressType]]
|
||||
|
||||
# XXX: should this instead be called `.sendall()`?
|
||||
async def send(self, msg: MsgType) -> None:
|
||||
|
@ -66,9 +96,345 @@ class MsgTransport(Protocol[MsgType]):
|
|||
...
|
||||
|
||||
@property
|
||||
def laddr(self) -> tuple[str, int]:
|
||||
def laddr(self) -> AddressType:
|
||||
...
|
||||
|
||||
@property
|
||||
def raddr(self) -> tuple[str, int]:
|
||||
def raddr(self) -> AddressType:
|
||||
...
|
||||
|
||||
@classmethod
|
||||
async def connect_to(
|
||||
cls,
|
||||
destaddr: AddressType,
|
||||
**kwargs
|
||||
) -> MsgTransport:
|
||||
...
|
||||
|
||||
@classmethod
|
||||
def get_stream_addrs(
|
||||
cls,
|
||||
stream: trio.abc.Stream
|
||||
) -> tuple[
|
||||
AddressType, # local
|
||||
AddressType # remote
|
||||
]:
|
||||
'''
|
||||
Return the `trio` streaming transport prot's addrs for both
|
||||
the local and remote sides as a pair.
|
||||
|
||||
'''
|
||||
...
|
||||
|
||||
|
||||
class MsgpackTransport(MsgTransport):
|
||||
|
||||
# TODO: better naming for this?
|
||||
# -[ ] check how libp2p does naming for such things?
|
||||
codec_key: str = 'msgpack'
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
stream: trio.abc.Stream,
|
||||
prefix_size: int = 4,
|
||||
|
||||
# XXX optionally provided codec pair for `msgspec`:
|
||||
# https://jcristharif.com/msgspec/extending.html#mapping-to-from-native-types
|
||||
#
|
||||
# TODO: define this as a `Codec` struct which can be
|
||||
# overriden dynamically by the application/runtime?
|
||||
codec: MsgCodec = None,
|
||||
|
||||
) -> None:
|
||||
self.stream = stream
|
||||
self._laddr, self._raddr = self.get_stream_addrs(stream)
|
||||
|
||||
# create read loop instance
|
||||
self._aiter_pkts = self._iter_packets()
|
||||
self._send_lock = trio.StrictFIFOLock()
|
||||
|
||||
# public i guess?
|
||||
self.drained: list[dict] = []
|
||||
|
||||
self.recv_stream = BufferedReceiveStream(
|
||||
transport_stream=stream
|
||||
)
|
||||
self.prefix_size = prefix_size
|
||||
|
||||
# allow for custom IPC msg interchange format
|
||||
# dynamic override Bo
|
||||
self._task = trio.lowlevel.current_task()
|
||||
|
||||
# XXX for ctxvar debug only!
|
||||
# self._codec: MsgCodec = (
|
||||
# codec
|
||||
# or
|
||||
# _codec._ctxvar_MsgCodec.get()
|
||||
# )
|
||||
|
||||
async def _iter_packets(self) -> AsyncGenerator[dict, None]:
|
||||
'''
|
||||
Yield `bytes`-blob decoded packets from the underlying TCP
|
||||
stream using the current task's `MsgCodec`.
|
||||
|
||||
This is a streaming routine implemented as an async generator
|
||||
func (which was the original design, but could be changed?)
|
||||
and is allocated by a `.__call__()` inside `.__init__()` where
|
||||
it is assigned to the `._aiter_pkts` attr.
|
||||
|
||||
'''
|
||||
decodes_failed: int = 0
|
||||
|
||||
while True:
|
||||
try:
|
||||
header: bytes = await self.recv_stream.receive_exactly(4)
|
||||
except (
|
||||
ValueError,
|
||||
ConnectionResetError,
|
||||
|
||||
# not sure entirely why we need this but without it we
|
||||
# seem to be getting racy failures here on
|
||||
# arbiter/registry name subs..
|
||||
trio.BrokenResourceError,
|
||||
|
||||
) as trans_err:
|
||||
|
||||
loglevel = 'transport'
|
||||
match trans_err:
|
||||
# case (
|
||||
# ConnectionResetError()
|
||||
# ):
|
||||
# loglevel = 'transport'
|
||||
|
||||
# peer actor (graceful??) TCP EOF but `tricycle`
|
||||
# seems to raise a 0-bytes-read?
|
||||
case ValueError() if (
|
||||
'unclean EOF' in trans_err.args[0]
|
||||
):
|
||||
pass
|
||||
|
||||
# peer actor (task) prolly shutdown quickly due
|
||||
# to cancellation
|
||||
case trio.BrokenResourceError() if (
|
||||
'Connection reset by peer' in trans_err.args[0]
|
||||
):
|
||||
pass
|
||||
|
||||
# unless the disconnect condition falls under "a
|
||||
# normal operation breakage" we usualy console warn
|
||||
# about it.
|
||||
case _:
|
||||
loglevel: str = 'warning'
|
||||
|
||||
|
||||
raise TransportClosed(
|
||||
message=(
|
||||
f'IPC transport already closed by peer\n'
|
||||
f'x)> {type(trans_err)}\n'
|
||||
f' |_{self}\n'
|
||||
),
|
||||
loglevel=loglevel,
|
||||
) from trans_err
|
||||
|
||||
# XXX definitely can happen if transport is closed
|
||||
# manually by another `trio.lowlevel.Task` in the
|
||||
# same actor; we use this in some simulated fault
|
||||
# testing for ex, but generally should never happen
|
||||
# under normal operation!
|
||||
#
|
||||
# NOTE: as such we always re-raise this error from the
|
||||
# RPC msg loop!
|
||||
except trio.ClosedResourceError as closure_err:
|
||||
raise TransportClosed(
|
||||
message=(
|
||||
f'IPC transport already manually closed locally?\n'
|
||||
f'x)> {type(closure_err)} \n'
|
||||
f' |_{self}\n'
|
||||
),
|
||||
loglevel='error',
|
||||
raise_on_report=(
|
||||
closure_err.args[0] == 'another task closed this fd'
|
||||
or
|
||||
closure_err.args[0] in ['another task closed this fd']
|
||||
),
|
||||
) from closure_err
|
||||
|
||||
# graceful TCP EOF disconnect
|
||||
if header == b'':
|
||||
raise TransportClosed(
|
||||
message=(
|
||||
f'IPC transport already gracefully closed\n'
|
||||
f')>\n'
|
||||
f'|_{self}\n'
|
||||
),
|
||||
loglevel='transport',
|
||||
# cause=??? # handy or no?
|
||||
)
|
||||
|
||||
size: int
|
||||
size, = struct.unpack("<I", header)
|
||||
|
||||
log.transport(f'received header {size}') # type: ignore
|
||||
msg_bytes: bytes = await self.recv_stream.receive_exactly(size)
|
||||
|
||||
log.transport(f"received {msg_bytes}") # type: ignore
|
||||
try:
|
||||
# NOTE: lookup the `trio.Task.context`'s var for
|
||||
# the current `MsgCodec`.
|
||||
codec: MsgCodec = _ctxvar_MsgCodec.get()
|
||||
|
||||
# XXX for ctxvar debug only!
|
||||
# if self._codec.pld_spec != codec.pld_spec:
|
||||
# assert (
|
||||
# task := trio.lowlevel.current_task()
|
||||
# ) is not self._task
|
||||
# self._task = task
|
||||
# self._codec = codec
|
||||
# log.runtime(
|
||||
# f'Using new codec in {self}.recv()\n'
|
||||
# f'codec: {self._codec}\n\n'
|
||||
# f'msg_bytes: {msg_bytes}\n'
|
||||
# )
|
||||
yield codec.decode(msg_bytes)
|
||||
|
||||
# XXX NOTE: since the below error derives from
|
||||
# `DecodeError` we need to catch is specially
|
||||
# and always raise such that spec violations
|
||||
# are never allowed to be caught silently!
|
||||
except msgspec.ValidationError as verr:
|
||||
msgtyperr: MsgTypeError = _mk_recv_mte(
|
||||
msg=msg_bytes,
|
||||
codec=codec,
|
||||
src_validation_error=verr,
|
||||
)
|
||||
# XXX deliver up to `Channel.recv()` where
|
||||
# a re-raise and `Error`-pack can inject the far
|
||||
# end actor `.uid`.
|
||||
yield msgtyperr
|
||||
|
||||
except (
|
||||
msgspec.DecodeError,
|
||||
UnicodeDecodeError,
|
||||
):
|
||||
if decodes_failed < 4:
|
||||
# ignore decoding errors for now and assume they have to
|
||||
# do with a channel drop - hope that receiving from the
|
||||
# channel will raise an expected error and bubble up.
|
||||
try:
|
||||
msg_str: str|bytes = msg_bytes.decode()
|
||||
except UnicodeDecodeError:
|
||||
msg_str = msg_bytes
|
||||
|
||||
log.exception(
|
||||
'Failed to decode msg?\n'
|
||||
f'{codec}\n\n'
|
||||
'Rxed bytes from wire:\n\n'
|
||||
f'{msg_str!r}\n'
|
||||
)
|
||||
decodes_failed += 1
|
||||
else:
|
||||
raise
|
||||
|
||||
async def send(
|
||||
self,
|
||||
msg: msgtypes.MsgType,
|
||||
|
||||
strict_types: bool = True,
|
||||
hide_tb: bool = False,
|
||||
|
||||
) -> None:
|
||||
'''
|
||||
Send a msgpack encoded py-object-blob-as-msg over TCP.
|
||||
|
||||
If `strict_types == True` then a `MsgTypeError` will be raised on any
|
||||
invalid msg type
|
||||
|
||||
'''
|
||||
__tracebackhide__: bool = hide_tb
|
||||
|
||||
# XXX see `trio._sync.AsyncContextManagerMixin` for details
|
||||
# on the `.acquire()`/`.release()` sequencing..
|
||||
async with self._send_lock:
|
||||
|
||||
# NOTE: lookup the `trio.Task.context`'s var for
|
||||
# the current `MsgCodec`.
|
||||
codec: MsgCodec = _ctxvar_MsgCodec.get()
|
||||
|
||||
# XXX for ctxvar debug only!
|
||||
# if self._codec.pld_spec != codec.pld_spec:
|
||||
# self._codec = codec
|
||||
# log.runtime(
|
||||
# f'Using new codec in {self}.send()\n'
|
||||
# f'codec: {self._codec}\n\n'
|
||||
# f'msg: {msg}\n'
|
||||
# )
|
||||
|
||||
if type(msg) not in msgtypes.__msg_types__:
|
||||
if strict_types:
|
||||
raise _mk_send_mte(
|
||||
msg,
|
||||
codec=codec,
|
||||
)
|
||||
else:
|
||||
log.warning(
|
||||
'Sending non-`Msg`-spec msg?\n\n'
|
||||
f'{msg}\n'
|
||||
)
|
||||
|
||||
try:
|
||||
bytes_data: bytes = codec.encode(msg)
|
||||
except TypeError as _err:
|
||||
typerr = _err
|
||||
msgtyperr: MsgTypeError = _mk_send_mte(
|
||||
msg,
|
||||
codec=codec,
|
||||
message=(
|
||||
f'IPC-msg-spec violation in\n\n'
|
||||
f'{pretty_struct.Struct.pformat(msg)}'
|
||||
),
|
||||
src_type_error=typerr,
|
||||
)
|
||||
raise msgtyperr from typerr
|
||||
|
||||
# supposedly the fastest says,
|
||||
# https://stackoverflow.com/a/54027962
|
||||
size: bytes = struct.pack("<I", len(bytes_data))
|
||||
return await self.stream.send_all(size + bytes_data)
|
||||
|
||||
# ?TODO? does it help ever to dynamically show this
|
||||
# frame?
|
||||
# try:
|
||||
# <the-above_code>
|
||||
# except BaseException as _err:
|
||||
# err = _err
|
||||
# if not isinstance(err, MsgTypeError):
|
||||
# __tracebackhide__: bool = False
|
||||
# raise
|
||||
|
||||
async def recv(self) -> Any:
|
||||
return await self._aiter_pkts.asend(None)
|
||||
|
||||
async def drain(self) -> AsyncIterator[dict]:
|
||||
'''
|
||||
Drain the stream's remaining messages sent from
|
||||
the far end until the connection is closed by
|
||||
the peer.
|
||||
|
||||
'''
|
||||
try:
|
||||
async for msg in self._iter_packets():
|
||||
self.drained.append(msg)
|
||||
except TransportClosed:
|
||||
for msg in self.drained:
|
||||
yield msg
|
||||
|
||||
def __aiter__(self):
|
||||
return self._aiter_pkts
|
||||
|
||||
@property
|
||||
def laddr(self) -> AddressType:
|
||||
return self._laddr
|
||||
|
||||
@property
|
||||
def raddr(self) -> AddressType:
|
||||
return self._raddr
|
||||
|
|
|
@ -0,0 +1,101 @@
|
|||
# tractor: structured concurrent "actors".
|
||||
# Copyright 2018-eternity Tyler Goodlet.
|
||||
|
||||
# This program is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Affero General Public License as published by
|
||||
# the Free Software Foundation, either version 3 of the License, or
|
||||
# (at your option) any later version.
|
||||
|
||||
# This program is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# GNU Affero General Public License for more details.
|
||||
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from typing import Type, Union
|
||||
|
||||
import trio
|
||||
import socket
|
||||
|
||||
from ._transport import MsgTransport
|
||||
from ._tcp import MsgpackTCPStream
|
||||
from ._uds import MsgpackUDSStream
|
||||
|
||||
|
||||
# manually updated list of all supported codec+transport types
|
||||
_msg_transports = {
|
||||
('msgpack', 'tcp'): MsgpackTCPStream,
|
||||
('msgpack', 'uds'): MsgpackUDSStream
|
||||
}
|
||||
|
||||
|
||||
# all different address py types we use
|
||||
AddressTypes = Union[
|
||||
tuple([
|
||||
cls.address_type
|
||||
for key, cls in _msg_transports.items()
|
||||
])
|
||||
]
|
||||
|
||||
|
||||
def transport_from_destaddr(
|
||||
destaddr: AddressTypes,
|
||||
codec_key: str = 'msgpack',
|
||||
) -> Type[MsgTransport]:
|
||||
'''
|
||||
Given a destination address and a desired codec, find the
|
||||
corresponding `MsgTransport` type.
|
||||
|
||||
'''
|
||||
match destaddr:
|
||||
case str():
|
||||
return MsgpackUDSStream
|
||||
|
||||
case tuple():
|
||||
if (
|
||||
len(destaddr) == 2
|
||||
and
|
||||
isinstance(destaddr[0], str)
|
||||
and
|
||||
isinstance(destaddr[1], int)
|
||||
):
|
||||
return MsgpackTCPStream
|
||||
|
||||
raise NotImplementedError(
|
||||
f'No known transport for address {destaddr}'
|
||||
)
|
||||
|
||||
|
||||
def transport_from_stream(
|
||||
stream: trio.abc.Stream,
|
||||
codec_key: str = 'msgpack'
|
||||
) -> Type[MsgTransport]:
|
||||
'''
|
||||
Given an arbitrary `trio.abc.Stream` and a desired codec,
|
||||
find the corresponding `MsgTransport` type.
|
||||
|
||||
'''
|
||||
transport = None
|
||||
if isinstance(stream, trio.SocketStream):
|
||||
sock = stream.socket
|
||||
match sock.family:
|
||||
case socket.AF_INET | socket.AF_INET6:
|
||||
transport = 'tcp'
|
||||
|
||||
case socket.AF_UNIX:
|
||||
transport = 'uds'
|
||||
|
||||
case _:
|
||||
raise NotImplementedError(
|
||||
f'Unsupported socket family: {sock.family}'
|
||||
)
|
||||
|
||||
if not transport:
|
||||
raise NotImplementedError(
|
||||
f'Could not figure out transport type for stream type {type(stream)}'
|
||||
)
|
||||
|
||||
key = (codec_key, transport)
|
||||
|
||||
return _msg_transports[key]
|
|
@ -0,0 +1,84 @@
|
|||
# tractor: structured concurrent "actors".
|
||||
# Copyright 2018-eternity Tyler Goodlet.
|
||||
|
||||
# This program is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Affero General Public License as published by
|
||||
# the Free Software Foundation, either version 3 of the License, or
|
||||
# (at your option) any later version.
|
||||
|
||||
# This program is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# GNU Affero General Public License for more details.
|
||||
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
'''
|
||||
Unix Domain Socket implementation of tractor.ipc._transport.MsgTransport protocol
|
||||
|
||||
'''
|
||||
from __future__ import annotations
|
||||
|
||||
import trio
|
||||
|
||||
from tractor.msg import MsgCodec
|
||||
from tractor.log import get_logger
|
||||
from tractor.ipc._transport import MsgpackTransport
|
||||
|
||||
|
||||
log = get_logger(__name__)
|
||||
|
||||
|
||||
class MsgpackUDSStream(MsgpackTransport):
|
||||
'''
|
||||
A ``trio.SocketStream`` delivering ``msgpack`` formatted data
|
||||
using the ``msgspec`` codec lib.
|
||||
|
||||
'''
|
||||
address_type = str
|
||||
layer_key: int = 7
|
||||
name_key: str = 'uds'
|
||||
|
||||
# def __init__(
|
||||
# self,
|
||||
# stream: trio.SocketStream,
|
||||
# prefix_size: int = 4,
|
||||
# codec: CodecType = None,
|
||||
|
||||
# ) -> None:
|
||||
# super().__init__(
|
||||
# stream,
|
||||
# prefix_size=prefix_size,
|
||||
# codec=codec
|
||||
# )
|
||||
|
||||
def connected(self) -> bool:
|
||||
return self.stream.socket.fileno() != -1
|
||||
|
||||
@classmethod
|
||||
async def connect_to(
|
||||
cls,
|
||||
filename: str,
|
||||
prefix_size: int = 4,
|
||||
codec: MsgCodec|None = None,
|
||||
**kwargs
|
||||
) -> MsgpackUDSStream:
|
||||
stream = await trio.open_unix_socket(
|
||||
filename,
|
||||
**kwargs
|
||||
)
|
||||
return MsgpackUDSStream(
|
||||
stream,
|
||||
prefix_size=prefix_size,
|
||||
codec=codec
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_stream_addrs(
|
||||
cls,
|
||||
stream: trio.SocketStream
|
||||
) -> tuple[str, str]:
|
||||
return (
|
||||
stream.socket.getsockname(),
|
||||
stream.socket.getpeername(),
|
||||
)
|
Loading…
Reference in New Issue