Starting to make tractor.ipc.Channel work with multiple MsgTransports

Guillermo Rodriguez 2025-03-22 15:29:48 -03:00
parent 32b5210648
commit 2907719cbe
No known key found for this signature in database
GPG Key ID: 002CC5F1E6BDA53E
9 changed files with 657 additions and 475 deletions

View File

@ -75,7 +75,7 @@ async def get_registry(
# TODO: try to look pre-existing connection from # TODO: try to look pre-existing connection from
# `Actor._peers` and use it instead? # `Actor._peers` and use it instead?
async with ( async with (
_connect_chan(host, port) as chan, _connect_chan((host, port)) as chan,
open_portal(chan) as regstr_ptl, open_portal(chan) as regstr_ptl,
): ):
yield regstr_ptl yield regstr_ptl
@ -93,7 +93,7 @@ async def get_root(
assert host is not None assert host is not None
async with ( async with (
_connect_chan(host, port) as chan, _connect_chan((host, port)) as chan,
open_portal(chan, **kwargs) as portal, open_portal(chan, **kwargs) as portal,
): ):
yield portal yield portal
@ -187,7 +187,7 @@ async def maybe_open_portal(
pass pass
if sockaddr: if sockaddr:
async with _connect_chan(*sockaddr) as chan: async with _connect_chan(sockaddr) as chan:
async with open_portal(chan) as portal: async with open_portal(chan) as portal:
yield portal yield portal
else: else:
@ -310,6 +310,6 @@ async def wait_for_actor(
# TODO: offer multi-portal yields in multi-homed case? # TODO: offer multi-portal yields in multi-homed case?
sockaddr: tuple[str, int] = sockaddrs[-1] sockaddr: tuple[str, int] = sockaddrs[-1]
async with _connect_chan(*sockaddr) as chan: async with _connect_chan(sockaddr) as chan:
async with open_portal(chan) as portal: async with open_portal(chan) as portal:
yield portal yield portal

View File

@ -271,7 +271,7 @@ async def open_root_actor(
# be better to eventually have a "discovery" protocol # be better to eventually have a "discovery" protocol
# with basic handshake instead? # with basic handshake instead?
with trio.move_on_after(timeout): with trio.move_on_after(timeout):
async with _connect_chan(*addr): async with _connect_chan(addr):
ponged_addrs.append(addr) ponged_addrs.append(addr)
except OSError: except OSError:

View File

@ -1040,10 +1040,7 @@ class Actor:
# Connect back to the parent actor and conduct initial # Connect back to the parent actor and conduct initial
# handshake. From this point on if we error, we # handshake. From this point on if we error, we
# attempt to ship the exception back to the parent. # attempt to ship the exception back to the parent.
chan = Channel( chan = await Channel.from_destaddr(parent_addr)
destaddr=parent_addr,
)
await chan.connect()
# TODO: move this into a `Channel.handshake()`? # TODO: move this into a `Channel.handshake()`?
# Initial handshake: swap names. # Initial handshake: swap names.

View File

@ -13,20 +13,26 @@
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
import platform import platform
from ._transport import MsgTransport as MsgTransport from ._transport import (
AddressType as AddressType,
MsgType as MsgType,
MsgTransport as MsgTransport,
MsgpackTransport as MsgpackTransport
)
from ._tcp import ( from ._tcp import MsgpackTCPStream as MsgpackTCPStream
get_stream_addrs as get_stream_addrs, from ._uds import MsgpackUDSStream as MsgpackUDSStream
MsgpackTCPStream as MsgpackTCPStream
from ._types import (
transport_from_destaddr as transport_from_destaddr,
transport_from_stream as transport_from_stream,
AddressTypes as AddressTypes
) )
from ._chan import ( from ._chan import (
_connect_chan as _connect_chan, _connect_chan as _connect_chan,
get_msg_transport as get_msg_transport,
Channel as Channel Channel as Channel
) )

View File

@ -29,15 +29,15 @@ from pprint import pformat
import typing import typing
from typing import ( from typing import (
Any, Any,
Type
) )
import trio import trio
from tractor.ipc._transport import MsgTransport from tractor.ipc._transport import MsgTransport
from tractor.ipc._tcp import ( from tractor.ipc._types import (
MsgpackTCPStream, transport_from_destaddr,
get_stream_addrs transport_from_stream,
AddressTypes
) )
from tractor.log import get_logger from tractor.log import get_logger
from tractor._exceptions import ( from tractor._exceptions import (
@ -52,17 +52,6 @@ log = get_logger(__name__)
_is_windows = platform.system() == 'Windows' _is_windows = platform.system() == 'Windows'
def get_msg_transport(
key: tuple[str, str],
) -> Type[MsgTransport]:
return {
('msgpack', 'tcp'): MsgpackTCPStream,
}[key]
class Channel: class Channel:
''' '''
An inter-process channel for communication between (remote) actors. An inter-process channel for communication between (remote) actors.
@ -77,10 +66,8 @@ class Channel:
def __init__( def __init__(
self, self,
destaddr: tuple[str, int]|None, destaddr: AddressTypes|None = None,
transport: MsgTransport|None = None,
msg_transport_type_key: tuple[str, str] = ('msgpack', 'tcp'),
# TODO: optional reconnection support? # TODO: optional reconnection support?
# auto_reconnect: bool = False, # auto_reconnect: bool = False,
# on_reconnect: typing.Callable[..., typing.Awaitable] = None, # on_reconnect: typing.Callable[..., typing.Awaitable] = None,
@ -90,13 +77,11 @@ class Channel:
# self._recon_seq = on_reconnect # self._recon_seq = on_reconnect
# self._autorecon = auto_reconnect # self._autorecon = auto_reconnect
self._destaddr = destaddr
self._transport_key = msg_transport_type_key
# Either created in ``.connect()`` or passed in by # Either created in ``.connect()`` or passed in by
# user in ``.from_stream()``. # user in ``.from_stream()``.
self._stream: trio.SocketStream|None = None self._transport: MsgTransport|None = transport
self._transport: MsgTransport|None = None
self._destaddr = destaddr if destaddr else self._transport.raddr
# set after handshake - always uid of far end # set after handshake - always uid of far end
self.uid: tuple[str, str]|None = None self.uid: tuple[str, str]|None = None
@ -110,6 +95,10 @@ class Channel:
# runtime. # runtime.
self._cancel_called: bool = False self._cancel_called: bool = False
@property
def stream(self) -> trio.abc.Stream | None:
return self._transport.stream if self._transport else None
@property @property
def msgstream(self) -> MsgTransport: def msgstream(self) -> MsgTransport:
log.info( log.info(
@ -124,52 +113,31 @@ class Channel:
@classmethod @classmethod
def from_stream( def from_stream(
cls, cls,
stream: trio.SocketStream, stream: trio.abc.Stream,
**kwargs,
) -> Channel: ) -> Channel:
transport_cls = transport_from_stream(stream)
src, dst = get_stream_addrs(stream) return Channel(
chan = Channel( transport=transport_cls(stream)
destaddr=dst,
**kwargs,
) )
# set immediately here from provided instance @classmethod
chan._stream: trio.SocketStream = stream async def from_destaddr(
chan.set_msg_transport(stream) cls,
return chan destaddr: AddressTypes,
**kwargs
) -> Channel:
transport_cls = transport_from_destaddr(destaddr)
transport = await transport_cls.connect_to(destaddr, **kwargs)
def set_msg_transport( log.transport(
self, f'Opened channel[{type(transport)}]: {transport.laddr} -> {transport.raddr}'
stream: trio.SocketStream,
type_key: tuple[str, str]|None = None,
# XXX optionally provided codec pair for `msgspec`:
# https://jcristharif.com/msgspec/extending.html#mapping-to-from-native-types
codec: MsgCodec|None = None,
) -> MsgTransport:
type_key = (
type_key
or
self._transport_key
) )
# get transport type, then return Channel(transport=transport)
self._transport = get_msg_transport(
type_key
# instantiate an instance of the msg-transport
)(
stream,
codec=codec,
)
return self._transport
@cm @cm
def apply_codec( def apply_codec(
self, self,
codec: MsgCodec, codec: MsgCodec,
) -> None: ) -> None:
''' '''
Temporarily override the underlying IPC msg codec for Temporarily override the underlying IPC msg codec for
@ -189,7 +157,7 @@ class Channel:
return '<Channel with inactive transport?>' return '<Channel with inactive transport?>'
return repr( return repr(
self._transport.stream.socket._sock self._transport
).replace( # type: ignore ).replace( # type: ignore
"socket.socket", "socket.socket",
"Channel", "Channel",
@ -203,30 +171,6 @@ class Channel:
def raddr(self) -> tuple[str, int]|None: def raddr(self) -> tuple[str, int]|None:
return self._transport.raddr if self._transport else None return self._transport.raddr if self._transport else None
async def connect(
self,
destaddr: tuple[Any, ...] | None = None,
**kwargs
) -> MsgTransport:
if self.connected():
raise RuntimeError("channel is already connected?")
destaddr = destaddr or self._destaddr
assert isinstance(destaddr, tuple)
stream = await trio.open_tcp_stream(
*destaddr,
**kwargs
)
transport = self.set_msg_transport(stream)
log.transport(
f'Opened channel[{type(transport)}]: {self.laddr} -> {self.raddr}'
)
return transport
# TODO: something like, # TODO: something like,
# `pdbp.hideframe_on(errors=[MsgTypeError])` # `pdbp.hideframe_on(errors=[MsgTypeError])`
# instead of the `try/except` hack we have rn.. # instead of the `try/except` hack we have rn..
@ -388,17 +332,14 @@ class Channel:
@acm @acm
async def _connect_chan( async def _connect_chan(
host: str, destaddr: AddressTypes
port: int
) -> typing.AsyncGenerator[Channel, None]: ) -> typing.AsyncGenerator[Channel, None]:
''' '''
Create and connect a channel with disconnect on context manager Create and connect a channel with disconnect on context manager
teardown. teardown.
''' '''
chan = Channel((host, port)) chan = await Channel.from_destaddr(destaddr)
await chan.connect()
yield chan yield chan
with trio.CancelScope(shield=True): with trio.CancelScope(shield=True):
await chan.aclose() await chan.aclose()

View File

@ -18,388 +18,75 @@ TCP implementation of tractor.ipc._transport.MsgTransport protocol
''' '''
from __future__ import annotations from __future__ import annotations
from collections.abc import (
AsyncGenerator,
AsyncIterator,
)
import struct
from typing import (
Any,
Callable,
)
import msgspec
from tricycle import BufferedReceiveStream
import trio import trio
from tractor.msg import MsgCodec
from tractor.log import get_logger from tractor.log import get_logger
from tractor._exceptions import ( from tractor.ipc._transport import MsgpackTransport
MsgTypeError,
TransportClosed,
_mk_send_mte,
_mk_recv_mte,
)
from tractor.msg import (
_ctxvar_MsgCodec,
# _codec, XXX see `self._codec` sanity/debug checks
MsgCodec,
types as msgtypes,
pretty_struct,
)
from tractor.ipc import MsgTransport
log = get_logger(__name__) log = get_logger(__name__)
def get_stream_addrs(
stream: trio.SocketStream
) -> tuple[
tuple[str, int], # local
tuple[str, int], # remote
]:
'''
Return the `trio` streaming transport prot's socket-addrs for
both the local and remote sides as a pair.
'''
# rn, should both be IP sockets
lsockname = stream.socket.getsockname()
rsockname = stream.socket.getpeername()
return (
tuple(lsockname[:2]),
tuple(rsockname[:2]),
)
# TODO: typing oddity.. not sure why we have to inherit here, but it # TODO: typing oddity.. not sure why we have to inherit here, but it
# seems to be an issue with `get_msg_transport()` returning # seems to be an issue with `get_msg_transport()` returning
# a `Type[Protocol]`; probably should make a `mypy` issue? # a `Type[Protocol]`; probably should make a `mypy` issue?
class MsgpackTCPStream(MsgTransport): class MsgpackTCPStream(MsgpackTransport):
''' '''
A ``trio.SocketStream`` delivering ``msgpack`` formatted data A ``trio.SocketStream`` delivering ``msgpack`` formatted data
using the ``msgspec`` codec lib. using the ``msgspec`` codec lib.
''' '''
address_type = tuple[str, int]
layer_key: int = 4 layer_key: int = 4
name_key: str = 'tcp' name_key: str = 'tcp'
# TODO: better naming for this? # def __init__(
# -[ ] check how libp2p does naming for such things? # self,
codec_key: str = 'msgpack' # stream: trio.SocketStream,
# prefix_size: int = 4,
# codec: CodecType = None,
def __init__( # ) -> None:
self, # super().__init__(
stream: trio.SocketStream, # stream,
prefix_size: int = 4, # prefix_size=prefix_size,
# codec=codec
# XXX optionally provided codec pair for `msgspec`: # )
# https://jcristharif.com/msgspec/extending.html#mapping-to-from-native-types
#
# TODO: define this as a `Codec` struct which can be
# overriden dynamically by the application/runtime?
codec: tuple[
Callable[[Any], Any]|None, # coder
Callable[[type, Any], Any]|None, # decoder
]|None = None,
) -> None:
self.stream = stream
assert self.stream.socket
# should both be IP sockets
self._laddr, self._raddr = get_stream_addrs(stream)
# create read loop instance
self._aiter_pkts = self._iter_packets()
self._send_lock = trio.StrictFIFOLock()
# public i guess?
self.drained: list[dict] = []
self.recv_stream = BufferedReceiveStream(
transport_stream=stream
)
self.prefix_size = prefix_size
# allow for custom IPC msg interchange format
# dynamic override Bo
self._task = trio.lowlevel.current_task()
# XXX for ctxvar debug only!
# self._codec: MsgCodec = (
# codec
# or
# _codec._ctxvar_MsgCodec.get()
# )
async def _iter_packets(self) -> AsyncGenerator[dict, None]:
'''
Yield `bytes`-blob decoded packets from the underlying TCP
stream using the current task's `MsgCodec`.
This is a streaming routine implemented as an async generator
func (which was the original design, but could be changed?)
and is allocated by a `.__call__()` inside `.__init__()` where
it is assigned to the `._aiter_pkts` attr.
'''
decodes_failed: int = 0
while True:
try:
header: bytes = await self.recv_stream.receive_exactly(4)
except (
ValueError,
ConnectionResetError,
# not sure entirely why we need this but without it we
# seem to be getting racy failures here on
# arbiter/registry name subs..
trio.BrokenResourceError,
) as trans_err:
loglevel = 'transport'
match trans_err:
# case (
# ConnectionResetError()
# ):
# loglevel = 'transport'
# peer actor (graceful??) TCP EOF but `tricycle`
# seems to raise a 0-bytes-read?
case ValueError() if (
'unclean EOF' in trans_err.args[0]
):
pass
# peer actor (task) prolly shutdown quickly due
# to cancellation
case trio.BrokenResourceError() if (
'Connection reset by peer' in trans_err.args[0]
):
pass
# unless the disconnect condition falls under "a
# normal operation breakage" we usualy console warn
# about it.
case _:
loglevel: str = 'warning'
raise TransportClosed(
message=(
f'IPC transport already closed by peer\n'
f'x)> {type(trans_err)}\n'
f' |_{self}\n'
),
loglevel=loglevel,
) from trans_err
# XXX definitely can happen if transport is closed
# manually by another `trio.lowlevel.Task` in the
# same actor; we use this in some simulated fault
# testing for ex, but generally should never happen
# under normal operation!
#
# NOTE: as such we always re-raise this error from the
# RPC msg loop!
except trio.ClosedResourceError as closure_err:
raise TransportClosed(
message=(
f'IPC transport already manually closed locally?\n'
f'x)> {type(closure_err)} \n'
f' |_{self}\n'
),
loglevel='error',
raise_on_report=(
closure_err.args[0] == 'another task closed this fd'
or
closure_err.args[0] in ['another task closed this fd']
),
) from closure_err
# graceful TCP EOF disconnect
if header == b'':
raise TransportClosed(
message=(
f'IPC transport already gracefully closed\n'
f')>\n'
f'|_{self}\n'
),
loglevel='transport',
# cause=??? # handy or no?
)
size: int
size, = struct.unpack("<I", header)
log.transport(f'received header {size}') # type: ignore
msg_bytes: bytes = await self.recv_stream.receive_exactly(size)
log.transport(f"received {msg_bytes}") # type: ignore
try:
# NOTE: lookup the `trio.Task.context`'s var for
# the current `MsgCodec`.
codec: MsgCodec = _ctxvar_MsgCodec.get()
# XXX for ctxvar debug only!
# if self._codec.pld_spec != codec.pld_spec:
# assert (
# task := trio.lowlevel.current_task()
# ) is not self._task
# self._task = task
# self._codec = codec
# log.runtime(
# f'Using new codec in {self}.recv()\n'
# f'codec: {self._codec}\n\n'
# f'msg_bytes: {msg_bytes}\n'
# )
yield codec.decode(msg_bytes)
# XXX NOTE: since the below error derives from
# `DecodeError` we need to catch is specially
# and always raise such that spec violations
# are never allowed to be caught silently!
except msgspec.ValidationError as verr:
msgtyperr: MsgTypeError = _mk_recv_mte(
msg=msg_bytes,
codec=codec,
src_validation_error=verr,
)
# XXX deliver up to `Channel.recv()` where
# a re-raise and `Error`-pack can inject the far
# end actor `.uid`.
yield msgtyperr
except (
msgspec.DecodeError,
UnicodeDecodeError,
):
if decodes_failed < 4:
# ignore decoding errors for now and assume they have to
# do with a channel drop - hope that receiving from the
# channel will raise an expected error and bubble up.
try:
msg_str: str|bytes = msg_bytes.decode()
except UnicodeDecodeError:
msg_str = msg_bytes
log.exception(
'Failed to decode msg?\n'
f'{codec}\n\n'
'Rxed bytes from wire:\n\n'
f'{msg_str!r}\n'
)
decodes_failed += 1
else:
raise
async def send(
self,
msg: msgtypes.MsgType,
strict_types: bool = True,
hide_tb: bool = False,
) -> None:
'''
Send a msgpack encoded py-object-blob-as-msg over TCP.
If `strict_types == True` then a `MsgTypeError` will be raised on any
invalid msg type
'''
__tracebackhide__: bool = hide_tb
# XXX see `trio._sync.AsyncContextManagerMixin` for details
# on the `.acquire()`/`.release()` sequencing..
async with self._send_lock:
# NOTE: lookup the `trio.Task.context`'s var for
# the current `MsgCodec`.
codec: MsgCodec = _ctxvar_MsgCodec.get()
# XXX for ctxvar debug only!
# if self._codec.pld_spec != codec.pld_spec:
# self._codec = codec
# log.runtime(
# f'Using new codec in {self}.send()\n'
# f'codec: {self._codec}\n\n'
# f'msg: {msg}\n'
# )
if type(msg) not in msgtypes.__msg_types__:
if strict_types:
raise _mk_send_mte(
msg,
codec=codec,
)
else:
log.warning(
'Sending non-`Msg`-spec msg?\n\n'
f'{msg}\n'
)
try:
bytes_data: bytes = codec.encode(msg)
except TypeError as _err:
typerr = _err
msgtyperr: MsgTypeError = _mk_send_mte(
msg,
codec=codec,
message=(
f'IPC-msg-spec violation in\n\n'
f'{pretty_struct.Struct.pformat(msg)}'
),
src_type_error=typerr,
)
raise msgtyperr from typerr
# supposedly the fastest says,
# https://stackoverflow.com/a/54027962
size: bytes = struct.pack("<I", len(bytes_data))
return await self.stream.send_all(size + bytes_data)
# ?TODO? does it help ever to dynamically show this
# frame?
# try:
# <the-above_code>
# except BaseException as _err:
# err = _err
# if not isinstance(err, MsgTypeError):
# __tracebackhide__: bool = False
# raise
@property
def laddr(self) -> tuple[str, int]:
return self._laddr
@property
def raddr(self) -> tuple[str, int]:
return self._raddr
async def recv(self) -> Any:
return await self._aiter_pkts.asend(None)
async def drain(self) -> AsyncIterator[dict]:
'''
Drain the stream's remaining messages sent from
the far end until the connection is closed by
the peer.
'''
try:
async for msg in self._iter_packets():
self.drained.append(msg)
except TransportClosed:
for msg in self.drained:
yield msg
def __aiter__(self):
return self._aiter_pkts
def connected(self) -> bool: def connected(self) -> bool:
return self.stream.socket.fileno() != -1 return self.stream.socket.fileno() != -1
@classmethod
async def connect_to(
cls,
destaddr: tuple[str, int],
prefix_size: int = 4,
codec: MsgCodec|None = None,
**kwargs
) -> MsgpackTCPStream:
stream = await trio.open_tcp_stream(
*destaddr,
**kwargs
)
return MsgpackTCPStream(
stream,
prefix_size=prefix_size,
codec=codec
)
@classmethod
def get_stream_addrs(
cls,
stream: trio.SocketStream
) -> tuple[
tuple[str, int],
tuple[str, int]
]:
lsockname = stream.socket.getsockname()
rsockname = stream.socket.getpeername()
return (
tuple(lsockname[:2]),
tuple(rsockname[:2]),
)

View File

@ -18,24 +18,56 @@ typing.Protocol based generic msg API, implement this class to add backends for
tractor.ipc.Channel tractor.ipc.Channel
''' '''
import trio from __future__ import annotations
from typing import ( from typing import (
runtime_checkable, runtime_checkable,
Type,
Protocol, Protocol,
TypeVar, TypeVar,
ClassVar
) )
from collections.abc import AsyncIterator from collections.abc import (
AsyncGenerator,
AsyncIterator,
)
import struct
from typing import (
Any,
Callable,
)
import trio
import msgspec
from tricycle import BufferedReceiveStream
from tractor.log import get_logger
from tractor._exceptions import (
MsgTypeError,
TransportClosed,
_mk_send_mte,
_mk_recv_mte,
)
from tractor.msg import (
_ctxvar_MsgCodec,
# _codec, XXX see `self._codec` sanity/debug checks
MsgCodec,
types as msgtypes,
pretty_struct,
)
log = get_logger(__name__)
# from tractor.msg.types import MsgType # from tractor.msg.types import MsgType
# ?TODO? this should be our `Union[*msgtypes.__spec__]` alias now right..? # ?TODO? this should be our `Union[*msgtypes.__spec__]` alias now right..?
# => BLEH, except can't bc prots must inherit typevar or param-spec # => BLEH, except can't bc prots must inherit typevar or param-spec
# vars.. # vars..
AddressType = TypeVar('AddressType')
MsgType = TypeVar('MsgType') MsgType = TypeVar('MsgType')
@runtime_checkable @runtime_checkable
class MsgTransport(Protocol[MsgType]): class MsgTransport(Protocol[AddressType, MsgType]):
# #
# ^-TODO-^ consider using a generic def and indexing with our # ^-TODO-^ consider using a generic def and indexing with our
# eventual msg definition/types? # eventual msg definition/types?
@ -43,9 +75,7 @@ class MsgTransport(Protocol[MsgType]):
stream: trio.abc.Stream stream: trio.abc.Stream
drained: list[MsgType] drained: list[MsgType]
address_type: ClassVar[Type[AddressType]]
def __init__(self, stream: trio.abc.Stream) -> None:
...
# XXX: should this instead be called `.sendall()`? # XXX: should this instead be called `.sendall()`?
async def send(self, msg: MsgType) -> None: async def send(self, msg: MsgType) -> None:
@ -66,9 +96,345 @@ class MsgTransport(Protocol[MsgType]):
... ...
@property @property
def laddr(self) -> tuple[str, int]: def laddr(self) -> AddressType:
... ...
@property @property
def raddr(self) -> tuple[str, int]: def raddr(self) -> AddressType:
... ...
@classmethod
async def connect_to(
cls,
destaddr: AddressType,
**kwargs
) -> MsgTransport:
...
@classmethod
def get_stream_addrs(
cls,
stream: trio.abc.Stream
) -> tuple[
AddressType, # local
AddressType # remote
]:
'''
Return the `trio` streaming transport prot's addrs for both
the local and remote sides as a pair.
'''
...
class MsgpackTransport(MsgTransport):
# TODO: better naming for this?
# -[ ] check how libp2p does naming for such things?
codec_key: str = 'msgpack'
def __init__(
self,
stream: trio.abc.Stream,
prefix_size: int = 4,
# XXX optionally provided codec pair for `msgspec`:
# https://jcristharif.com/msgspec/extending.html#mapping-to-from-native-types
#
# TODO: define this as a `Codec` struct which can be
# overriden dynamically by the application/runtime?
codec: MsgCodec = None,
) -> None:
self.stream = stream
self._laddr, self._raddr = self.get_stream_addrs(stream)
# create read loop instance
self._aiter_pkts = self._iter_packets()
self._send_lock = trio.StrictFIFOLock()
# public i guess?
self.drained: list[dict] = []
self.recv_stream = BufferedReceiveStream(
transport_stream=stream
)
self.prefix_size = prefix_size
# allow for custom IPC msg interchange format
# dynamic override Bo
self._task = trio.lowlevel.current_task()
# XXX for ctxvar debug only!
# self._codec: MsgCodec = (
# codec
# or
# _codec._ctxvar_MsgCodec.get()
# )
async def _iter_packets(self) -> AsyncGenerator[dict, None]:
'''
Yield `bytes`-blob decoded packets from the underlying TCP
stream using the current task's `MsgCodec`.
This is a streaming routine implemented as an async generator
func (which was the original design, but could be changed?)
and is allocated by a `.__call__()` inside `.__init__()` where
it is assigned to the `._aiter_pkts` attr.
'''
decodes_failed: int = 0
while True:
try:
header: bytes = await self.recv_stream.receive_exactly(4)
except (
ValueError,
ConnectionResetError,
# not sure entirely why we need this but without it we
# seem to be getting racy failures here on
# arbiter/registry name subs..
trio.BrokenResourceError,
) as trans_err:
loglevel = 'transport'
match trans_err:
# case (
# ConnectionResetError()
# ):
# loglevel = 'transport'
# peer actor (graceful??) TCP EOF but `tricycle`
# seems to raise a 0-bytes-read?
case ValueError() if (
'unclean EOF' in trans_err.args[0]
):
pass
# peer actor (task) prolly shutdown quickly due
# to cancellation
case trio.BrokenResourceError() if (
'Connection reset by peer' in trans_err.args[0]
):
pass
# unless the disconnect condition falls under "a
# normal operation breakage" we usualy console warn
# about it.
case _:
loglevel: str = 'warning'
raise TransportClosed(
message=(
f'IPC transport already closed by peer\n'
f'x)> {type(trans_err)}\n'
f' |_{self}\n'
),
loglevel=loglevel,
) from trans_err
# XXX definitely can happen if transport is closed
# manually by another `trio.lowlevel.Task` in the
# same actor; we use this in some simulated fault
# testing for ex, but generally should never happen
# under normal operation!
#
# NOTE: as such we always re-raise this error from the
# RPC msg loop!
except trio.ClosedResourceError as closure_err:
raise TransportClosed(
message=(
f'IPC transport already manually closed locally?\n'
f'x)> {type(closure_err)} \n'
f' |_{self}\n'
),
loglevel='error',
raise_on_report=(
closure_err.args[0] == 'another task closed this fd'
or
closure_err.args[0] in ['another task closed this fd']
),
) from closure_err
# graceful TCP EOF disconnect
if header == b'':
raise TransportClosed(
message=(
f'IPC transport already gracefully closed\n'
f')>\n'
f'|_{self}\n'
),
loglevel='transport',
# cause=??? # handy or no?
)
size: int
size, = struct.unpack("<I", header)
log.transport(f'received header {size}') # type: ignore
msg_bytes: bytes = await self.recv_stream.receive_exactly(size)
log.transport(f"received {msg_bytes}") # type: ignore
try:
# NOTE: lookup the `trio.Task.context`'s var for
# the current `MsgCodec`.
codec: MsgCodec = _ctxvar_MsgCodec.get()
# XXX for ctxvar debug only!
# if self._codec.pld_spec != codec.pld_spec:
# assert (
# task := trio.lowlevel.current_task()
# ) is not self._task
# self._task = task
# self._codec = codec
# log.runtime(
# f'Using new codec in {self}.recv()\n'
# f'codec: {self._codec}\n\n'
# f'msg_bytes: {msg_bytes}\n'
# )
yield codec.decode(msg_bytes)
# XXX NOTE: since the below error derives from
# `DecodeError` we need to catch is specially
# and always raise such that spec violations
# are never allowed to be caught silently!
except msgspec.ValidationError as verr:
msgtyperr: MsgTypeError = _mk_recv_mte(
msg=msg_bytes,
codec=codec,
src_validation_error=verr,
)
# XXX deliver up to `Channel.recv()` where
# a re-raise and `Error`-pack can inject the far
# end actor `.uid`.
yield msgtyperr
except (
msgspec.DecodeError,
UnicodeDecodeError,
):
if decodes_failed < 4:
# ignore decoding errors for now and assume they have to
# do with a channel drop - hope that receiving from the
# channel will raise an expected error and bubble up.
try:
msg_str: str|bytes = msg_bytes.decode()
except UnicodeDecodeError:
msg_str = msg_bytes
log.exception(
'Failed to decode msg?\n'
f'{codec}\n\n'
'Rxed bytes from wire:\n\n'
f'{msg_str!r}\n'
)
decodes_failed += 1
else:
raise
async def send(
self,
msg: msgtypes.MsgType,
strict_types: bool = True,
hide_tb: bool = False,
) -> None:
'''
Send a msgpack encoded py-object-blob-as-msg over TCP.
If `strict_types == True` then a `MsgTypeError` will be raised on any
invalid msg type
'''
__tracebackhide__: bool = hide_tb
# XXX see `trio._sync.AsyncContextManagerMixin` for details
# on the `.acquire()`/`.release()` sequencing..
async with self._send_lock:
# NOTE: lookup the `trio.Task.context`'s var for
# the current `MsgCodec`.
codec: MsgCodec = _ctxvar_MsgCodec.get()
# XXX for ctxvar debug only!
# if self._codec.pld_spec != codec.pld_spec:
# self._codec = codec
# log.runtime(
# f'Using new codec in {self}.send()\n'
# f'codec: {self._codec}\n\n'
# f'msg: {msg}\n'
# )
if type(msg) not in msgtypes.__msg_types__:
if strict_types:
raise _mk_send_mte(
msg,
codec=codec,
)
else:
log.warning(
'Sending non-`Msg`-spec msg?\n\n'
f'{msg}\n'
)
try:
bytes_data: bytes = codec.encode(msg)
except TypeError as _err:
typerr = _err
msgtyperr: MsgTypeError = _mk_send_mte(
msg,
codec=codec,
message=(
f'IPC-msg-spec violation in\n\n'
f'{pretty_struct.Struct.pformat(msg)}'
),
src_type_error=typerr,
)
raise msgtyperr from typerr
# supposedly the fastest says,
# https://stackoverflow.com/a/54027962
size: bytes = struct.pack("<I", len(bytes_data))
return await self.stream.send_all(size + bytes_data)
# ?TODO? does it help ever to dynamically show this
# frame?
# try:
# <the-above_code>
# except BaseException as _err:
# err = _err
# if not isinstance(err, MsgTypeError):
# __tracebackhide__: bool = False
# raise
async def recv(self) -> Any:
return await self._aiter_pkts.asend(None)
async def drain(self) -> AsyncIterator[dict]:
'''
Drain the stream's remaining messages sent from
the far end until the connection is closed by
the peer.
'''
try:
async for msg in self._iter_packets():
self.drained.append(msg)
except TransportClosed:
for msg in self.drained:
yield msg
def __aiter__(self):
return self._aiter_pkts
@property
def laddr(self) -> AddressType:
return self._laddr
@property
def raddr(self) -> AddressType:
return self._raddr

View File

@ -0,0 +1,101 @@
# tractor: structured concurrent "actors".
# Copyright 2018-eternity Tyler Goodlet.
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Type, Union
import trio
import socket
from ._transport import MsgTransport
from ._tcp import MsgpackTCPStream
from ._uds import MsgpackUDSStream
# manually updated list of all supported codec+transport types
_msg_transports = {
('msgpack', 'tcp'): MsgpackTCPStream,
('msgpack', 'uds'): MsgpackUDSStream
}
# all different address py types we use
AddressTypes = Union[
tuple([
cls.address_type
for key, cls in _msg_transports.items()
])
]
def transport_from_destaddr(
destaddr: AddressTypes,
codec_key: str = 'msgpack',
) -> Type[MsgTransport]:
'''
Given a destination address and a desired codec, find the
corresponding `MsgTransport` type.
'''
match destaddr:
case str():
return MsgpackUDSStream
case tuple():
if (
len(destaddr) == 2
and
isinstance(destaddr[0], str)
and
isinstance(destaddr[1], int)
):
return MsgpackTCPStream
raise NotImplementedError(
f'No known transport for address {destaddr}'
)
def transport_from_stream(
stream: trio.abc.Stream,
codec_key: str = 'msgpack'
) -> Type[MsgTransport]:
'''
Given an arbitrary `trio.abc.Stream` and a desired codec,
find the corresponding `MsgTransport` type.
'''
transport = None
if isinstance(stream, trio.SocketStream):
sock = stream.socket
match sock.family:
case socket.AF_INET | socket.AF_INET6:
transport = 'tcp'
case socket.AF_UNIX:
transport = 'uds'
case _:
raise NotImplementedError(
f'Unsupported socket family: {sock.family}'
)
if not transport:
raise NotImplementedError(
f'Could not figure out transport type for stream type {type(stream)}'
)
key = (codec_key, transport)
return _msg_transports[key]

View File

@ -0,0 +1,84 @@
# tractor: structured concurrent "actors".
# Copyright 2018-eternity Tyler Goodlet.
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
'''
Unix Domain Socket implementation of tractor.ipc._transport.MsgTransport protocol
'''
from __future__ import annotations
import trio
from tractor.msg import MsgCodec
from tractor.log import get_logger
from tractor.ipc._transport import MsgpackTransport
log = get_logger(__name__)
class MsgpackUDSStream(MsgpackTransport):
'''
A ``trio.SocketStream`` delivering ``msgpack`` formatted data
using the ``msgspec`` codec lib.
'''
address_type = str
layer_key: int = 7
name_key: str = 'uds'
# def __init__(
# self,
# stream: trio.SocketStream,
# prefix_size: int = 4,
# codec: CodecType = None,
# ) -> None:
# super().__init__(
# stream,
# prefix_size=prefix_size,
# codec=codec
# )
def connected(self) -> bool:
return self.stream.socket.fileno() != -1
@classmethod
async def connect_to(
cls,
filename: str,
prefix_size: int = 4,
codec: MsgCodec|None = None,
**kwargs
) -> MsgpackUDSStream:
stream = await trio.open_unix_socket(
filename,
**kwargs
)
return MsgpackUDSStream(
stream,
prefix_size=prefix_size,
codec=codec
)
@classmethod
def get_stream_addrs(
cls,
stream: trio.SocketStream
) -> tuple[str, str]:
return (
stream.socket.getsockname(),
stream.socket.getpeername(),
)