Starting to make tractor.ipc.Channel work with multiple MsgTransports
parent
32b5210648
commit
2907719cbe
|
@ -75,7 +75,7 @@ async def get_registry(
|
||||||
# TODO: try to look pre-existing connection from
|
# TODO: try to look pre-existing connection from
|
||||||
# `Actor._peers` and use it instead?
|
# `Actor._peers` and use it instead?
|
||||||
async with (
|
async with (
|
||||||
_connect_chan(host, port) as chan,
|
_connect_chan((host, port)) as chan,
|
||||||
open_portal(chan) as regstr_ptl,
|
open_portal(chan) as regstr_ptl,
|
||||||
):
|
):
|
||||||
yield regstr_ptl
|
yield regstr_ptl
|
||||||
|
@ -93,7 +93,7 @@ async def get_root(
|
||||||
assert host is not None
|
assert host is not None
|
||||||
|
|
||||||
async with (
|
async with (
|
||||||
_connect_chan(host, port) as chan,
|
_connect_chan((host, port)) as chan,
|
||||||
open_portal(chan, **kwargs) as portal,
|
open_portal(chan, **kwargs) as portal,
|
||||||
):
|
):
|
||||||
yield portal
|
yield portal
|
||||||
|
@ -187,7 +187,7 @@ async def maybe_open_portal(
|
||||||
pass
|
pass
|
||||||
|
|
||||||
if sockaddr:
|
if sockaddr:
|
||||||
async with _connect_chan(*sockaddr) as chan:
|
async with _connect_chan(sockaddr) as chan:
|
||||||
async with open_portal(chan) as portal:
|
async with open_portal(chan) as portal:
|
||||||
yield portal
|
yield portal
|
||||||
else:
|
else:
|
||||||
|
@ -310,6 +310,6 @@ async def wait_for_actor(
|
||||||
# TODO: offer multi-portal yields in multi-homed case?
|
# TODO: offer multi-portal yields in multi-homed case?
|
||||||
sockaddr: tuple[str, int] = sockaddrs[-1]
|
sockaddr: tuple[str, int] = sockaddrs[-1]
|
||||||
|
|
||||||
async with _connect_chan(*sockaddr) as chan:
|
async with _connect_chan(sockaddr) as chan:
|
||||||
async with open_portal(chan) as portal:
|
async with open_portal(chan) as portal:
|
||||||
yield portal
|
yield portal
|
||||||
|
|
|
@ -271,7 +271,7 @@ async def open_root_actor(
|
||||||
# be better to eventually have a "discovery" protocol
|
# be better to eventually have a "discovery" protocol
|
||||||
# with basic handshake instead?
|
# with basic handshake instead?
|
||||||
with trio.move_on_after(timeout):
|
with trio.move_on_after(timeout):
|
||||||
async with _connect_chan(*addr):
|
async with _connect_chan(addr):
|
||||||
ponged_addrs.append(addr)
|
ponged_addrs.append(addr)
|
||||||
|
|
||||||
except OSError:
|
except OSError:
|
||||||
|
|
|
@ -1040,10 +1040,7 @@ class Actor:
|
||||||
# Connect back to the parent actor and conduct initial
|
# Connect back to the parent actor and conduct initial
|
||||||
# handshake. From this point on if we error, we
|
# handshake. From this point on if we error, we
|
||||||
# attempt to ship the exception back to the parent.
|
# attempt to ship the exception back to the parent.
|
||||||
chan = Channel(
|
chan = await Channel.from_destaddr(parent_addr)
|
||||||
destaddr=parent_addr,
|
|
||||||
)
|
|
||||||
await chan.connect()
|
|
||||||
|
|
||||||
# TODO: move this into a `Channel.handshake()`?
|
# TODO: move this into a `Channel.handshake()`?
|
||||||
# Initial handshake: swap names.
|
# Initial handshake: swap names.
|
||||||
|
|
|
@ -13,20 +13,26 @@
|
||||||
|
|
||||||
# You should have received a copy of the GNU Affero General Public License
|
# You should have received a copy of the GNU Affero General Public License
|
||||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
|
|
||||||
import platform
|
import platform
|
||||||
|
|
||||||
from ._transport import MsgTransport as MsgTransport
|
from ._transport import (
|
||||||
|
AddressType as AddressType,
|
||||||
|
MsgType as MsgType,
|
||||||
|
MsgTransport as MsgTransport,
|
||||||
|
MsgpackTransport as MsgpackTransport
|
||||||
|
)
|
||||||
|
|
||||||
from ._tcp import (
|
from ._tcp import MsgpackTCPStream as MsgpackTCPStream
|
||||||
get_stream_addrs as get_stream_addrs,
|
from ._uds import MsgpackUDSStream as MsgpackUDSStream
|
||||||
MsgpackTCPStream as MsgpackTCPStream
|
|
||||||
|
from ._types import (
|
||||||
|
transport_from_destaddr as transport_from_destaddr,
|
||||||
|
transport_from_stream as transport_from_stream,
|
||||||
|
AddressTypes as AddressTypes
|
||||||
)
|
)
|
||||||
|
|
||||||
from ._chan import (
|
from ._chan import (
|
||||||
_connect_chan as _connect_chan,
|
_connect_chan as _connect_chan,
|
||||||
get_msg_transport as get_msg_transport,
|
|
||||||
Channel as Channel
|
Channel as Channel
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -29,15 +29,15 @@ from pprint import pformat
|
||||||
import typing
|
import typing
|
||||||
from typing import (
|
from typing import (
|
||||||
Any,
|
Any,
|
||||||
Type
|
|
||||||
)
|
)
|
||||||
|
|
||||||
import trio
|
import trio
|
||||||
|
|
||||||
from tractor.ipc._transport import MsgTransport
|
from tractor.ipc._transport import MsgTransport
|
||||||
from tractor.ipc._tcp import (
|
from tractor.ipc._types import (
|
||||||
MsgpackTCPStream,
|
transport_from_destaddr,
|
||||||
get_stream_addrs
|
transport_from_stream,
|
||||||
|
AddressTypes
|
||||||
)
|
)
|
||||||
from tractor.log import get_logger
|
from tractor.log import get_logger
|
||||||
from tractor._exceptions import (
|
from tractor._exceptions import (
|
||||||
|
@ -52,17 +52,6 @@ log = get_logger(__name__)
|
||||||
_is_windows = platform.system() == 'Windows'
|
_is_windows = platform.system() == 'Windows'
|
||||||
|
|
||||||
|
|
||||||
def get_msg_transport(
|
|
||||||
|
|
||||||
key: tuple[str, str],
|
|
||||||
|
|
||||||
) -> Type[MsgTransport]:
|
|
||||||
|
|
||||||
return {
|
|
||||||
('msgpack', 'tcp'): MsgpackTCPStream,
|
|
||||||
}[key]
|
|
||||||
|
|
||||||
|
|
||||||
class Channel:
|
class Channel:
|
||||||
'''
|
'''
|
||||||
An inter-process channel for communication between (remote) actors.
|
An inter-process channel for communication between (remote) actors.
|
||||||
|
@ -77,10 +66,8 @@ class Channel:
|
||||||
def __init__(
|
def __init__(
|
||||||
|
|
||||||
self,
|
self,
|
||||||
destaddr: tuple[str, int]|None,
|
destaddr: AddressTypes|None = None,
|
||||||
|
transport: MsgTransport|None = None,
|
||||||
msg_transport_type_key: tuple[str, str] = ('msgpack', 'tcp'),
|
|
||||||
|
|
||||||
# TODO: optional reconnection support?
|
# TODO: optional reconnection support?
|
||||||
# auto_reconnect: bool = False,
|
# auto_reconnect: bool = False,
|
||||||
# on_reconnect: typing.Callable[..., typing.Awaitable] = None,
|
# on_reconnect: typing.Callable[..., typing.Awaitable] = None,
|
||||||
|
@ -90,13 +77,11 @@ class Channel:
|
||||||
# self._recon_seq = on_reconnect
|
# self._recon_seq = on_reconnect
|
||||||
# self._autorecon = auto_reconnect
|
# self._autorecon = auto_reconnect
|
||||||
|
|
||||||
self._destaddr = destaddr
|
|
||||||
self._transport_key = msg_transport_type_key
|
|
||||||
|
|
||||||
# Either created in ``.connect()`` or passed in by
|
# Either created in ``.connect()`` or passed in by
|
||||||
# user in ``.from_stream()``.
|
# user in ``.from_stream()``.
|
||||||
self._stream: trio.SocketStream|None = None
|
self._transport: MsgTransport|None = transport
|
||||||
self._transport: MsgTransport|None = None
|
|
||||||
|
self._destaddr = destaddr if destaddr else self._transport.raddr
|
||||||
|
|
||||||
# set after handshake - always uid of far end
|
# set after handshake - always uid of far end
|
||||||
self.uid: tuple[str, str]|None = None
|
self.uid: tuple[str, str]|None = None
|
||||||
|
@ -110,6 +95,10 @@ class Channel:
|
||||||
# runtime.
|
# runtime.
|
||||||
self._cancel_called: bool = False
|
self._cancel_called: bool = False
|
||||||
|
|
||||||
|
@property
|
||||||
|
def stream(self) -> trio.abc.Stream | None:
|
||||||
|
return self._transport.stream if self._transport else None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def msgstream(self) -> MsgTransport:
|
def msgstream(self) -> MsgTransport:
|
||||||
log.info(
|
log.info(
|
||||||
|
@ -124,52 +113,31 @@ class Channel:
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_stream(
|
def from_stream(
|
||||||
cls,
|
cls,
|
||||||
stream: trio.SocketStream,
|
stream: trio.abc.Stream,
|
||||||
**kwargs,
|
|
||||||
|
|
||||||
) -> Channel:
|
) -> Channel:
|
||||||
|
transport_cls = transport_from_stream(stream)
|
||||||
src, dst = get_stream_addrs(stream)
|
return Channel(
|
||||||
chan = Channel(
|
transport=transport_cls(stream)
|
||||||
destaddr=dst,
|
|
||||||
**kwargs,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# set immediately here from provided instance
|
@classmethod
|
||||||
chan._stream: trio.SocketStream = stream
|
async def from_destaddr(
|
||||||
chan.set_msg_transport(stream)
|
cls,
|
||||||
return chan
|
destaddr: AddressTypes,
|
||||||
|
**kwargs
|
||||||
|
) -> Channel:
|
||||||
|
transport_cls = transport_from_destaddr(destaddr)
|
||||||
|
transport = await transport_cls.connect_to(destaddr, **kwargs)
|
||||||
|
|
||||||
def set_msg_transport(
|
log.transport(
|
||||||
self,
|
f'Opened channel[{type(transport)}]: {transport.laddr} -> {transport.raddr}'
|
||||||
stream: trio.SocketStream,
|
|
||||||
type_key: tuple[str, str]|None = None,
|
|
||||||
|
|
||||||
# XXX optionally provided codec pair for `msgspec`:
|
|
||||||
# https://jcristharif.com/msgspec/extending.html#mapping-to-from-native-types
|
|
||||||
codec: MsgCodec|None = None,
|
|
||||||
|
|
||||||
) -> MsgTransport:
|
|
||||||
type_key = (
|
|
||||||
type_key
|
|
||||||
or
|
|
||||||
self._transport_key
|
|
||||||
)
|
)
|
||||||
# get transport type, then
|
return Channel(transport=transport)
|
||||||
self._transport = get_msg_transport(
|
|
||||||
type_key
|
|
||||||
# instantiate an instance of the msg-transport
|
|
||||||
)(
|
|
||||||
stream,
|
|
||||||
codec=codec,
|
|
||||||
)
|
|
||||||
return self._transport
|
|
||||||
|
|
||||||
@cm
|
@cm
|
||||||
def apply_codec(
|
def apply_codec(
|
||||||
self,
|
self,
|
||||||
codec: MsgCodec,
|
codec: MsgCodec,
|
||||||
|
|
||||||
) -> None:
|
) -> None:
|
||||||
'''
|
'''
|
||||||
Temporarily override the underlying IPC msg codec for
|
Temporarily override the underlying IPC msg codec for
|
||||||
|
@ -189,7 +157,7 @@ class Channel:
|
||||||
return '<Channel with inactive transport?>'
|
return '<Channel with inactive transport?>'
|
||||||
|
|
||||||
return repr(
|
return repr(
|
||||||
self._transport.stream.socket._sock
|
self._transport
|
||||||
).replace( # type: ignore
|
).replace( # type: ignore
|
||||||
"socket.socket",
|
"socket.socket",
|
||||||
"Channel",
|
"Channel",
|
||||||
|
@ -203,30 +171,6 @@ class Channel:
|
||||||
def raddr(self) -> tuple[str, int]|None:
|
def raddr(self) -> tuple[str, int]|None:
|
||||||
return self._transport.raddr if self._transport else None
|
return self._transport.raddr if self._transport else None
|
||||||
|
|
||||||
async def connect(
|
|
||||||
self,
|
|
||||||
destaddr: tuple[Any, ...] | None = None,
|
|
||||||
**kwargs
|
|
||||||
|
|
||||||
) -> MsgTransport:
|
|
||||||
|
|
||||||
if self.connected():
|
|
||||||
raise RuntimeError("channel is already connected?")
|
|
||||||
|
|
||||||
destaddr = destaddr or self._destaddr
|
|
||||||
assert isinstance(destaddr, tuple)
|
|
||||||
|
|
||||||
stream = await trio.open_tcp_stream(
|
|
||||||
*destaddr,
|
|
||||||
**kwargs
|
|
||||||
)
|
|
||||||
transport = self.set_msg_transport(stream)
|
|
||||||
|
|
||||||
log.transport(
|
|
||||||
f'Opened channel[{type(transport)}]: {self.laddr} -> {self.raddr}'
|
|
||||||
)
|
|
||||||
return transport
|
|
||||||
|
|
||||||
# TODO: something like,
|
# TODO: something like,
|
||||||
# `pdbp.hideframe_on(errors=[MsgTypeError])`
|
# `pdbp.hideframe_on(errors=[MsgTypeError])`
|
||||||
# instead of the `try/except` hack we have rn..
|
# instead of the `try/except` hack we have rn..
|
||||||
|
@ -388,17 +332,14 @@ class Channel:
|
||||||
|
|
||||||
@acm
|
@acm
|
||||||
async def _connect_chan(
|
async def _connect_chan(
|
||||||
host: str,
|
destaddr: AddressTypes
|
||||||
port: int
|
|
||||||
|
|
||||||
) -> typing.AsyncGenerator[Channel, None]:
|
) -> typing.AsyncGenerator[Channel, None]:
|
||||||
'''
|
'''
|
||||||
Create and connect a channel with disconnect on context manager
|
Create and connect a channel with disconnect on context manager
|
||||||
teardown.
|
teardown.
|
||||||
|
|
||||||
'''
|
'''
|
||||||
chan = Channel((host, port))
|
chan = await Channel.from_destaddr(destaddr)
|
||||||
await chan.connect()
|
|
||||||
yield chan
|
yield chan
|
||||||
with trio.CancelScope(shield=True):
|
with trio.CancelScope(shield=True):
|
||||||
await chan.aclose()
|
await chan.aclose()
|
||||||
|
|
|
@ -18,388 +18,75 @@ TCP implementation of tractor.ipc._transport.MsgTransport protocol
|
||||||
|
|
||||||
'''
|
'''
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
from collections.abc import (
|
|
||||||
AsyncGenerator,
|
|
||||||
AsyncIterator,
|
|
||||||
)
|
|
||||||
import struct
|
|
||||||
from typing import (
|
|
||||||
Any,
|
|
||||||
Callable,
|
|
||||||
)
|
|
||||||
|
|
||||||
import msgspec
|
|
||||||
from tricycle import BufferedReceiveStream
|
|
||||||
import trio
|
import trio
|
||||||
|
|
||||||
|
from tractor.msg import MsgCodec
|
||||||
from tractor.log import get_logger
|
from tractor.log import get_logger
|
||||||
from tractor._exceptions import (
|
from tractor.ipc._transport import MsgpackTransport
|
||||||
MsgTypeError,
|
|
||||||
TransportClosed,
|
|
||||||
_mk_send_mte,
|
|
||||||
_mk_recv_mte,
|
|
||||||
)
|
|
||||||
from tractor.msg import (
|
|
||||||
_ctxvar_MsgCodec,
|
|
||||||
# _codec, XXX see `self._codec` sanity/debug checks
|
|
||||||
MsgCodec,
|
|
||||||
types as msgtypes,
|
|
||||||
pretty_struct,
|
|
||||||
)
|
|
||||||
from tractor.ipc import MsgTransport
|
|
||||||
|
|
||||||
|
|
||||||
log = get_logger(__name__)
|
log = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def get_stream_addrs(
|
|
||||||
stream: trio.SocketStream
|
|
||||||
) -> tuple[
|
|
||||||
tuple[str, int], # local
|
|
||||||
tuple[str, int], # remote
|
|
||||||
]:
|
|
||||||
'''
|
|
||||||
Return the `trio` streaming transport prot's socket-addrs for
|
|
||||||
both the local and remote sides as a pair.
|
|
||||||
|
|
||||||
'''
|
|
||||||
# rn, should both be IP sockets
|
|
||||||
lsockname = stream.socket.getsockname()
|
|
||||||
rsockname = stream.socket.getpeername()
|
|
||||||
return (
|
|
||||||
tuple(lsockname[:2]),
|
|
||||||
tuple(rsockname[:2]),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# TODO: typing oddity.. not sure why we have to inherit here, but it
|
# TODO: typing oddity.. not sure why we have to inherit here, but it
|
||||||
# seems to be an issue with `get_msg_transport()` returning
|
# seems to be an issue with `get_msg_transport()` returning
|
||||||
# a `Type[Protocol]`; probably should make a `mypy` issue?
|
# a `Type[Protocol]`; probably should make a `mypy` issue?
|
||||||
class MsgpackTCPStream(MsgTransport):
|
class MsgpackTCPStream(MsgpackTransport):
|
||||||
'''
|
'''
|
||||||
A ``trio.SocketStream`` delivering ``msgpack`` formatted data
|
A ``trio.SocketStream`` delivering ``msgpack`` formatted data
|
||||||
using the ``msgspec`` codec lib.
|
using the ``msgspec`` codec lib.
|
||||||
|
|
||||||
'''
|
'''
|
||||||
|
address_type = tuple[str, int]
|
||||||
layer_key: int = 4
|
layer_key: int = 4
|
||||||
name_key: str = 'tcp'
|
name_key: str = 'tcp'
|
||||||
|
|
||||||
# TODO: better naming for this?
|
# def __init__(
|
||||||
# -[ ] check how libp2p does naming for such things?
|
# self,
|
||||||
codec_key: str = 'msgpack'
|
# stream: trio.SocketStream,
|
||||||
|
# prefix_size: int = 4,
|
||||||
|
# codec: CodecType = None,
|
||||||
|
|
||||||
def __init__(
|
# ) -> None:
|
||||||
self,
|
# super().__init__(
|
||||||
stream: trio.SocketStream,
|
# stream,
|
||||||
prefix_size: int = 4,
|
# prefix_size=prefix_size,
|
||||||
|
# codec=codec
|
||||||
# XXX optionally provided codec pair for `msgspec`:
|
# )
|
||||||
# https://jcristharif.com/msgspec/extending.html#mapping-to-from-native-types
|
|
||||||
#
|
|
||||||
# TODO: define this as a `Codec` struct which can be
|
|
||||||
# overriden dynamically by the application/runtime?
|
|
||||||
codec: tuple[
|
|
||||||
Callable[[Any], Any]|None, # coder
|
|
||||||
Callable[[type, Any], Any]|None, # decoder
|
|
||||||
]|None = None,
|
|
||||||
|
|
||||||
) -> None:
|
|
||||||
|
|
||||||
self.stream = stream
|
|
||||||
assert self.stream.socket
|
|
||||||
|
|
||||||
# should both be IP sockets
|
|
||||||
self._laddr, self._raddr = get_stream_addrs(stream)
|
|
||||||
|
|
||||||
# create read loop instance
|
|
||||||
self._aiter_pkts = self._iter_packets()
|
|
||||||
self._send_lock = trio.StrictFIFOLock()
|
|
||||||
|
|
||||||
# public i guess?
|
|
||||||
self.drained: list[dict] = []
|
|
||||||
|
|
||||||
self.recv_stream = BufferedReceiveStream(
|
|
||||||
transport_stream=stream
|
|
||||||
)
|
|
||||||
self.prefix_size = prefix_size
|
|
||||||
|
|
||||||
# allow for custom IPC msg interchange format
|
|
||||||
# dynamic override Bo
|
|
||||||
self._task = trio.lowlevel.current_task()
|
|
||||||
|
|
||||||
# XXX for ctxvar debug only!
|
|
||||||
# self._codec: MsgCodec = (
|
|
||||||
# codec
|
|
||||||
# or
|
|
||||||
# _codec._ctxvar_MsgCodec.get()
|
|
||||||
# )
|
|
||||||
|
|
||||||
async def _iter_packets(self) -> AsyncGenerator[dict, None]:
|
|
||||||
'''
|
|
||||||
Yield `bytes`-blob decoded packets from the underlying TCP
|
|
||||||
stream using the current task's `MsgCodec`.
|
|
||||||
|
|
||||||
This is a streaming routine implemented as an async generator
|
|
||||||
func (which was the original design, but could be changed?)
|
|
||||||
and is allocated by a `.__call__()` inside `.__init__()` where
|
|
||||||
it is assigned to the `._aiter_pkts` attr.
|
|
||||||
|
|
||||||
'''
|
|
||||||
decodes_failed: int = 0
|
|
||||||
|
|
||||||
while True:
|
|
||||||
try:
|
|
||||||
header: bytes = await self.recv_stream.receive_exactly(4)
|
|
||||||
except (
|
|
||||||
ValueError,
|
|
||||||
ConnectionResetError,
|
|
||||||
|
|
||||||
# not sure entirely why we need this but without it we
|
|
||||||
# seem to be getting racy failures here on
|
|
||||||
# arbiter/registry name subs..
|
|
||||||
trio.BrokenResourceError,
|
|
||||||
|
|
||||||
) as trans_err:
|
|
||||||
|
|
||||||
loglevel = 'transport'
|
|
||||||
match trans_err:
|
|
||||||
# case (
|
|
||||||
# ConnectionResetError()
|
|
||||||
# ):
|
|
||||||
# loglevel = 'transport'
|
|
||||||
|
|
||||||
# peer actor (graceful??) TCP EOF but `tricycle`
|
|
||||||
# seems to raise a 0-bytes-read?
|
|
||||||
case ValueError() if (
|
|
||||||
'unclean EOF' in trans_err.args[0]
|
|
||||||
):
|
|
||||||
pass
|
|
||||||
|
|
||||||
# peer actor (task) prolly shutdown quickly due
|
|
||||||
# to cancellation
|
|
||||||
case trio.BrokenResourceError() if (
|
|
||||||
'Connection reset by peer' in trans_err.args[0]
|
|
||||||
):
|
|
||||||
pass
|
|
||||||
|
|
||||||
# unless the disconnect condition falls under "a
|
|
||||||
# normal operation breakage" we usualy console warn
|
|
||||||
# about it.
|
|
||||||
case _:
|
|
||||||
loglevel: str = 'warning'
|
|
||||||
|
|
||||||
|
|
||||||
raise TransportClosed(
|
|
||||||
message=(
|
|
||||||
f'IPC transport already closed by peer\n'
|
|
||||||
f'x)> {type(trans_err)}\n'
|
|
||||||
f' |_{self}\n'
|
|
||||||
),
|
|
||||||
loglevel=loglevel,
|
|
||||||
) from trans_err
|
|
||||||
|
|
||||||
# XXX definitely can happen if transport is closed
|
|
||||||
# manually by another `trio.lowlevel.Task` in the
|
|
||||||
# same actor; we use this in some simulated fault
|
|
||||||
# testing for ex, but generally should never happen
|
|
||||||
# under normal operation!
|
|
||||||
#
|
|
||||||
# NOTE: as such we always re-raise this error from the
|
|
||||||
# RPC msg loop!
|
|
||||||
except trio.ClosedResourceError as closure_err:
|
|
||||||
raise TransportClosed(
|
|
||||||
message=(
|
|
||||||
f'IPC transport already manually closed locally?\n'
|
|
||||||
f'x)> {type(closure_err)} \n'
|
|
||||||
f' |_{self}\n'
|
|
||||||
),
|
|
||||||
loglevel='error',
|
|
||||||
raise_on_report=(
|
|
||||||
closure_err.args[0] == 'another task closed this fd'
|
|
||||||
or
|
|
||||||
closure_err.args[0] in ['another task closed this fd']
|
|
||||||
),
|
|
||||||
) from closure_err
|
|
||||||
|
|
||||||
# graceful TCP EOF disconnect
|
|
||||||
if header == b'':
|
|
||||||
raise TransportClosed(
|
|
||||||
message=(
|
|
||||||
f'IPC transport already gracefully closed\n'
|
|
||||||
f')>\n'
|
|
||||||
f'|_{self}\n'
|
|
||||||
),
|
|
||||||
loglevel='transport',
|
|
||||||
# cause=??? # handy or no?
|
|
||||||
)
|
|
||||||
|
|
||||||
size: int
|
|
||||||
size, = struct.unpack("<I", header)
|
|
||||||
|
|
||||||
log.transport(f'received header {size}') # type: ignore
|
|
||||||
msg_bytes: bytes = await self.recv_stream.receive_exactly(size)
|
|
||||||
|
|
||||||
log.transport(f"received {msg_bytes}") # type: ignore
|
|
||||||
try:
|
|
||||||
# NOTE: lookup the `trio.Task.context`'s var for
|
|
||||||
# the current `MsgCodec`.
|
|
||||||
codec: MsgCodec = _ctxvar_MsgCodec.get()
|
|
||||||
|
|
||||||
# XXX for ctxvar debug only!
|
|
||||||
# if self._codec.pld_spec != codec.pld_spec:
|
|
||||||
# assert (
|
|
||||||
# task := trio.lowlevel.current_task()
|
|
||||||
# ) is not self._task
|
|
||||||
# self._task = task
|
|
||||||
# self._codec = codec
|
|
||||||
# log.runtime(
|
|
||||||
# f'Using new codec in {self}.recv()\n'
|
|
||||||
# f'codec: {self._codec}\n\n'
|
|
||||||
# f'msg_bytes: {msg_bytes}\n'
|
|
||||||
# )
|
|
||||||
yield codec.decode(msg_bytes)
|
|
||||||
|
|
||||||
# XXX NOTE: since the below error derives from
|
|
||||||
# `DecodeError` we need to catch is specially
|
|
||||||
# and always raise such that spec violations
|
|
||||||
# are never allowed to be caught silently!
|
|
||||||
except msgspec.ValidationError as verr:
|
|
||||||
msgtyperr: MsgTypeError = _mk_recv_mte(
|
|
||||||
msg=msg_bytes,
|
|
||||||
codec=codec,
|
|
||||||
src_validation_error=verr,
|
|
||||||
)
|
|
||||||
# XXX deliver up to `Channel.recv()` where
|
|
||||||
# a re-raise and `Error`-pack can inject the far
|
|
||||||
# end actor `.uid`.
|
|
||||||
yield msgtyperr
|
|
||||||
|
|
||||||
except (
|
|
||||||
msgspec.DecodeError,
|
|
||||||
UnicodeDecodeError,
|
|
||||||
):
|
|
||||||
if decodes_failed < 4:
|
|
||||||
# ignore decoding errors for now and assume they have to
|
|
||||||
# do with a channel drop - hope that receiving from the
|
|
||||||
# channel will raise an expected error and bubble up.
|
|
||||||
try:
|
|
||||||
msg_str: str|bytes = msg_bytes.decode()
|
|
||||||
except UnicodeDecodeError:
|
|
||||||
msg_str = msg_bytes
|
|
||||||
|
|
||||||
log.exception(
|
|
||||||
'Failed to decode msg?\n'
|
|
||||||
f'{codec}\n\n'
|
|
||||||
'Rxed bytes from wire:\n\n'
|
|
||||||
f'{msg_str!r}\n'
|
|
||||||
)
|
|
||||||
decodes_failed += 1
|
|
||||||
else:
|
|
||||||
raise
|
|
||||||
|
|
||||||
async def send(
|
|
||||||
self,
|
|
||||||
msg: msgtypes.MsgType,
|
|
||||||
|
|
||||||
strict_types: bool = True,
|
|
||||||
hide_tb: bool = False,
|
|
||||||
|
|
||||||
) -> None:
|
|
||||||
'''
|
|
||||||
Send a msgpack encoded py-object-blob-as-msg over TCP.
|
|
||||||
|
|
||||||
If `strict_types == True` then a `MsgTypeError` will be raised on any
|
|
||||||
invalid msg type
|
|
||||||
|
|
||||||
'''
|
|
||||||
__tracebackhide__: bool = hide_tb
|
|
||||||
|
|
||||||
# XXX see `trio._sync.AsyncContextManagerMixin` for details
|
|
||||||
# on the `.acquire()`/`.release()` sequencing..
|
|
||||||
async with self._send_lock:
|
|
||||||
|
|
||||||
# NOTE: lookup the `trio.Task.context`'s var for
|
|
||||||
# the current `MsgCodec`.
|
|
||||||
codec: MsgCodec = _ctxvar_MsgCodec.get()
|
|
||||||
|
|
||||||
# XXX for ctxvar debug only!
|
|
||||||
# if self._codec.pld_spec != codec.pld_spec:
|
|
||||||
# self._codec = codec
|
|
||||||
# log.runtime(
|
|
||||||
# f'Using new codec in {self}.send()\n'
|
|
||||||
# f'codec: {self._codec}\n\n'
|
|
||||||
# f'msg: {msg}\n'
|
|
||||||
# )
|
|
||||||
|
|
||||||
if type(msg) not in msgtypes.__msg_types__:
|
|
||||||
if strict_types:
|
|
||||||
raise _mk_send_mte(
|
|
||||||
msg,
|
|
||||||
codec=codec,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
log.warning(
|
|
||||||
'Sending non-`Msg`-spec msg?\n\n'
|
|
||||||
f'{msg}\n'
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
bytes_data: bytes = codec.encode(msg)
|
|
||||||
except TypeError as _err:
|
|
||||||
typerr = _err
|
|
||||||
msgtyperr: MsgTypeError = _mk_send_mte(
|
|
||||||
msg,
|
|
||||||
codec=codec,
|
|
||||||
message=(
|
|
||||||
f'IPC-msg-spec violation in\n\n'
|
|
||||||
f'{pretty_struct.Struct.pformat(msg)}'
|
|
||||||
),
|
|
||||||
src_type_error=typerr,
|
|
||||||
)
|
|
||||||
raise msgtyperr from typerr
|
|
||||||
|
|
||||||
# supposedly the fastest says,
|
|
||||||
# https://stackoverflow.com/a/54027962
|
|
||||||
size: bytes = struct.pack("<I", len(bytes_data))
|
|
||||||
return await self.stream.send_all(size + bytes_data)
|
|
||||||
|
|
||||||
# ?TODO? does it help ever to dynamically show this
|
|
||||||
# frame?
|
|
||||||
# try:
|
|
||||||
# <the-above_code>
|
|
||||||
# except BaseException as _err:
|
|
||||||
# err = _err
|
|
||||||
# if not isinstance(err, MsgTypeError):
|
|
||||||
# __tracebackhide__: bool = False
|
|
||||||
# raise
|
|
||||||
|
|
||||||
@property
|
|
||||||
def laddr(self) -> tuple[str, int]:
|
|
||||||
return self._laddr
|
|
||||||
|
|
||||||
@property
|
|
||||||
def raddr(self) -> tuple[str, int]:
|
|
||||||
return self._raddr
|
|
||||||
|
|
||||||
async def recv(self) -> Any:
|
|
||||||
return await self._aiter_pkts.asend(None)
|
|
||||||
|
|
||||||
async def drain(self) -> AsyncIterator[dict]:
|
|
||||||
'''
|
|
||||||
Drain the stream's remaining messages sent from
|
|
||||||
the far end until the connection is closed by
|
|
||||||
the peer.
|
|
||||||
|
|
||||||
'''
|
|
||||||
try:
|
|
||||||
async for msg in self._iter_packets():
|
|
||||||
self.drained.append(msg)
|
|
||||||
except TransportClosed:
|
|
||||||
for msg in self.drained:
|
|
||||||
yield msg
|
|
||||||
|
|
||||||
def __aiter__(self):
|
|
||||||
return self._aiter_pkts
|
|
||||||
|
|
||||||
def connected(self) -> bool:
|
def connected(self) -> bool:
|
||||||
return self.stream.socket.fileno() != -1
|
return self.stream.socket.fileno() != -1
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def connect_to(
|
||||||
|
cls,
|
||||||
|
destaddr: tuple[str, int],
|
||||||
|
prefix_size: int = 4,
|
||||||
|
codec: MsgCodec|None = None,
|
||||||
|
**kwargs
|
||||||
|
) -> MsgpackTCPStream:
|
||||||
|
stream = await trio.open_tcp_stream(
|
||||||
|
*destaddr,
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
return MsgpackTCPStream(
|
||||||
|
stream,
|
||||||
|
prefix_size=prefix_size,
|
||||||
|
codec=codec
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_stream_addrs(
|
||||||
|
cls,
|
||||||
|
stream: trio.SocketStream
|
||||||
|
) -> tuple[
|
||||||
|
tuple[str, int],
|
||||||
|
tuple[str, int]
|
||||||
|
]:
|
||||||
|
lsockname = stream.socket.getsockname()
|
||||||
|
rsockname = stream.socket.getpeername()
|
||||||
|
return (
|
||||||
|
tuple(lsockname[:2]),
|
||||||
|
tuple(rsockname[:2]),
|
||||||
|
)
|
||||||
|
|
|
@ -18,24 +18,56 @@ typing.Protocol based generic msg API, implement this class to add backends for
|
||||||
tractor.ipc.Channel
|
tractor.ipc.Channel
|
||||||
|
|
||||||
'''
|
'''
|
||||||
import trio
|
from __future__ import annotations
|
||||||
from typing import (
|
from typing import (
|
||||||
runtime_checkable,
|
runtime_checkable,
|
||||||
|
Type,
|
||||||
Protocol,
|
Protocol,
|
||||||
TypeVar,
|
TypeVar,
|
||||||
|
ClassVar
|
||||||
)
|
)
|
||||||
from collections.abc import AsyncIterator
|
from collections.abc import (
|
||||||
|
AsyncGenerator,
|
||||||
|
AsyncIterator,
|
||||||
|
)
|
||||||
|
import struct
|
||||||
|
from typing import (
|
||||||
|
Any,
|
||||||
|
Callable,
|
||||||
|
)
|
||||||
|
|
||||||
|
import trio
|
||||||
|
import msgspec
|
||||||
|
from tricycle import BufferedReceiveStream
|
||||||
|
|
||||||
|
from tractor.log import get_logger
|
||||||
|
from tractor._exceptions import (
|
||||||
|
MsgTypeError,
|
||||||
|
TransportClosed,
|
||||||
|
_mk_send_mte,
|
||||||
|
_mk_recv_mte,
|
||||||
|
)
|
||||||
|
from tractor.msg import (
|
||||||
|
_ctxvar_MsgCodec,
|
||||||
|
# _codec, XXX see `self._codec` sanity/debug checks
|
||||||
|
MsgCodec,
|
||||||
|
types as msgtypes,
|
||||||
|
pretty_struct,
|
||||||
|
)
|
||||||
|
|
||||||
|
log = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
# from tractor.msg.types import MsgType
|
# from tractor.msg.types import MsgType
|
||||||
# ?TODO? this should be our `Union[*msgtypes.__spec__]` alias now right..?
|
# ?TODO? this should be our `Union[*msgtypes.__spec__]` alias now right..?
|
||||||
# => BLEH, except can't bc prots must inherit typevar or param-spec
|
# => BLEH, except can't bc prots must inherit typevar or param-spec
|
||||||
# vars..
|
# vars..
|
||||||
|
AddressType = TypeVar('AddressType')
|
||||||
MsgType = TypeVar('MsgType')
|
MsgType = TypeVar('MsgType')
|
||||||
|
|
||||||
|
|
||||||
@runtime_checkable
|
@runtime_checkable
|
||||||
class MsgTransport(Protocol[MsgType]):
|
class MsgTransport(Protocol[AddressType, MsgType]):
|
||||||
#
|
#
|
||||||
# ^-TODO-^ consider using a generic def and indexing with our
|
# ^-TODO-^ consider using a generic def and indexing with our
|
||||||
# eventual msg definition/types?
|
# eventual msg definition/types?
|
||||||
|
@ -43,9 +75,7 @@ class MsgTransport(Protocol[MsgType]):
|
||||||
|
|
||||||
stream: trio.abc.Stream
|
stream: trio.abc.Stream
|
||||||
drained: list[MsgType]
|
drained: list[MsgType]
|
||||||
|
address_type: ClassVar[Type[AddressType]]
|
||||||
def __init__(self, stream: trio.abc.Stream) -> None:
|
|
||||||
...
|
|
||||||
|
|
||||||
# XXX: should this instead be called `.sendall()`?
|
# XXX: should this instead be called `.sendall()`?
|
||||||
async def send(self, msg: MsgType) -> None:
|
async def send(self, msg: MsgType) -> None:
|
||||||
|
@ -66,9 +96,345 @@ class MsgTransport(Protocol[MsgType]):
|
||||||
...
|
...
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def laddr(self) -> tuple[str, int]:
|
def laddr(self) -> AddressType:
|
||||||
...
|
...
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def raddr(self) -> tuple[str, int]:
|
def raddr(self) -> AddressType:
|
||||||
...
|
...
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def connect_to(
|
||||||
|
cls,
|
||||||
|
destaddr: AddressType,
|
||||||
|
**kwargs
|
||||||
|
) -> MsgTransport:
|
||||||
|
...
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_stream_addrs(
|
||||||
|
cls,
|
||||||
|
stream: trio.abc.Stream
|
||||||
|
) -> tuple[
|
||||||
|
AddressType, # local
|
||||||
|
AddressType # remote
|
||||||
|
]:
|
||||||
|
'''
|
||||||
|
Return the `trio` streaming transport prot's addrs for both
|
||||||
|
the local and remote sides as a pair.
|
||||||
|
|
||||||
|
'''
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
class MsgpackTransport(MsgTransport):
|
||||||
|
|
||||||
|
# TODO: better naming for this?
|
||||||
|
# -[ ] check how libp2p does naming for such things?
|
||||||
|
codec_key: str = 'msgpack'
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
stream: trio.abc.Stream,
|
||||||
|
prefix_size: int = 4,
|
||||||
|
|
||||||
|
# XXX optionally provided codec pair for `msgspec`:
|
||||||
|
# https://jcristharif.com/msgspec/extending.html#mapping-to-from-native-types
|
||||||
|
#
|
||||||
|
# TODO: define this as a `Codec` struct which can be
|
||||||
|
# overriden dynamically by the application/runtime?
|
||||||
|
codec: MsgCodec = None,
|
||||||
|
|
||||||
|
) -> None:
|
||||||
|
self.stream = stream
|
||||||
|
self._laddr, self._raddr = self.get_stream_addrs(stream)
|
||||||
|
|
||||||
|
# create read loop instance
|
||||||
|
self._aiter_pkts = self._iter_packets()
|
||||||
|
self._send_lock = trio.StrictFIFOLock()
|
||||||
|
|
||||||
|
# public i guess?
|
||||||
|
self.drained: list[dict] = []
|
||||||
|
|
||||||
|
self.recv_stream = BufferedReceiveStream(
|
||||||
|
transport_stream=stream
|
||||||
|
)
|
||||||
|
self.prefix_size = prefix_size
|
||||||
|
|
||||||
|
# allow for custom IPC msg interchange format
|
||||||
|
# dynamic override Bo
|
||||||
|
self._task = trio.lowlevel.current_task()
|
||||||
|
|
||||||
|
# XXX for ctxvar debug only!
|
||||||
|
# self._codec: MsgCodec = (
|
||||||
|
# codec
|
||||||
|
# or
|
||||||
|
# _codec._ctxvar_MsgCodec.get()
|
||||||
|
# )
|
||||||
|
|
||||||
|
async def _iter_packets(self) -> AsyncGenerator[dict, None]:
|
||||||
|
'''
|
||||||
|
Yield `bytes`-blob decoded packets from the underlying TCP
|
||||||
|
stream using the current task's `MsgCodec`.
|
||||||
|
|
||||||
|
This is a streaming routine implemented as an async generator
|
||||||
|
func (which was the original design, but could be changed?)
|
||||||
|
and is allocated by a `.__call__()` inside `.__init__()` where
|
||||||
|
it is assigned to the `._aiter_pkts` attr.
|
||||||
|
|
||||||
|
'''
|
||||||
|
decodes_failed: int = 0
|
||||||
|
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
header: bytes = await self.recv_stream.receive_exactly(4)
|
||||||
|
except (
|
||||||
|
ValueError,
|
||||||
|
ConnectionResetError,
|
||||||
|
|
||||||
|
# not sure entirely why we need this but without it we
|
||||||
|
# seem to be getting racy failures here on
|
||||||
|
# arbiter/registry name subs..
|
||||||
|
trio.BrokenResourceError,
|
||||||
|
|
||||||
|
) as trans_err:
|
||||||
|
|
||||||
|
loglevel = 'transport'
|
||||||
|
match trans_err:
|
||||||
|
# case (
|
||||||
|
# ConnectionResetError()
|
||||||
|
# ):
|
||||||
|
# loglevel = 'transport'
|
||||||
|
|
||||||
|
# peer actor (graceful??) TCP EOF but `tricycle`
|
||||||
|
# seems to raise a 0-bytes-read?
|
||||||
|
case ValueError() if (
|
||||||
|
'unclean EOF' in trans_err.args[0]
|
||||||
|
):
|
||||||
|
pass
|
||||||
|
|
||||||
|
# peer actor (task) prolly shutdown quickly due
|
||||||
|
# to cancellation
|
||||||
|
case trio.BrokenResourceError() if (
|
||||||
|
'Connection reset by peer' in trans_err.args[0]
|
||||||
|
):
|
||||||
|
pass
|
||||||
|
|
||||||
|
# unless the disconnect condition falls under "a
|
||||||
|
# normal operation breakage" we usualy console warn
|
||||||
|
# about it.
|
||||||
|
case _:
|
||||||
|
loglevel: str = 'warning'
|
||||||
|
|
||||||
|
|
||||||
|
raise TransportClosed(
|
||||||
|
message=(
|
||||||
|
f'IPC transport already closed by peer\n'
|
||||||
|
f'x)> {type(trans_err)}\n'
|
||||||
|
f' |_{self}\n'
|
||||||
|
),
|
||||||
|
loglevel=loglevel,
|
||||||
|
) from trans_err
|
||||||
|
|
||||||
|
# XXX definitely can happen if transport is closed
|
||||||
|
# manually by another `trio.lowlevel.Task` in the
|
||||||
|
# same actor; we use this in some simulated fault
|
||||||
|
# testing for ex, but generally should never happen
|
||||||
|
# under normal operation!
|
||||||
|
#
|
||||||
|
# NOTE: as such we always re-raise this error from the
|
||||||
|
# RPC msg loop!
|
||||||
|
except trio.ClosedResourceError as closure_err:
|
||||||
|
raise TransportClosed(
|
||||||
|
message=(
|
||||||
|
f'IPC transport already manually closed locally?\n'
|
||||||
|
f'x)> {type(closure_err)} \n'
|
||||||
|
f' |_{self}\n'
|
||||||
|
),
|
||||||
|
loglevel='error',
|
||||||
|
raise_on_report=(
|
||||||
|
closure_err.args[0] == 'another task closed this fd'
|
||||||
|
or
|
||||||
|
closure_err.args[0] in ['another task closed this fd']
|
||||||
|
),
|
||||||
|
) from closure_err
|
||||||
|
|
||||||
|
# graceful TCP EOF disconnect
|
||||||
|
if header == b'':
|
||||||
|
raise TransportClosed(
|
||||||
|
message=(
|
||||||
|
f'IPC transport already gracefully closed\n'
|
||||||
|
f')>\n'
|
||||||
|
f'|_{self}\n'
|
||||||
|
),
|
||||||
|
loglevel='transport',
|
||||||
|
# cause=??? # handy or no?
|
||||||
|
)
|
||||||
|
|
||||||
|
size: int
|
||||||
|
size, = struct.unpack("<I", header)
|
||||||
|
|
||||||
|
log.transport(f'received header {size}') # type: ignore
|
||||||
|
msg_bytes: bytes = await self.recv_stream.receive_exactly(size)
|
||||||
|
|
||||||
|
log.transport(f"received {msg_bytes}") # type: ignore
|
||||||
|
try:
|
||||||
|
# NOTE: lookup the `trio.Task.context`'s var for
|
||||||
|
# the current `MsgCodec`.
|
||||||
|
codec: MsgCodec = _ctxvar_MsgCodec.get()
|
||||||
|
|
||||||
|
# XXX for ctxvar debug only!
|
||||||
|
# if self._codec.pld_spec != codec.pld_spec:
|
||||||
|
# assert (
|
||||||
|
# task := trio.lowlevel.current_task()
|
||||||
|
# ) is not self._task
|
||||||
|
# self._task = task
|
||||||
|
# self._codec = codec
|
||||||
|
# log.runtime(
|
||||||
|
# f'Using new codec in {self}.recv()\n'
|
||||||
|
# f'codec: {self._codec}\n\n'
|
||||||
|
# f'msg_bytes: {msg_bytes}\n'
|
||||||
|
# )
|
||||||
|
yield codec.decode(msg_bytes)
|
||||||
|
|
||||||
|
# XXX NOTE: since the below error derives from
|
||||||
|
# `DecodeError` we need to catch is specially
|
||||||
|
# and always raise such that spec violations
|
||||||
|
# are never allowed to be caught silently!
|
||||||
|
except msgspec.ValidationError as verr:
|
||||||
|
msgtyperr: MsgTypeError = _mk_recv_mte(
|
||||||
|
msg=msg_bytes,
|
||||||
|
codec=codec,
|
||||||
|
src_validation_error=verr,
|
||||||
|
)
|
||||||
|
# XXX deliver up to `Channel.recv()` where
|
||||||
|
# a re-raise and `Error`-pack can inject the far
|
||||||
|
# end actor `.uid`.
|
||||||
|
yield msgtyperr
|
||||||
|
|
||||||
|
except (
|
||||||
|
msgspec.DecodeError,
|
||||||
|
UnicodeDecodeError,
|
||||||
|
):
|
||||||
|
if decodes_failed < 4:
|
||||||
|
# ignore decoding errors for now and assume they have to
|
||||||
|
# do with a channel drop - hope that receiving from the
|
||||||
|
# channel will raise an expected error and bubble up.
|
||||||
|
try:
|
||||||
|
msg_str: str|bytes = msg_bytes.decode()
|
||||||
|
except UnicodeDecodeError:
|
||||||
|
msg_str = msg_bytes
|
||||||
|
|
||||||
|
log.exception(
|
||||||
|
'Failed to decode msg?\n'
|
||||||
|
f'{codec}\n\n'
|
||||||
|
'Rxed bytes from wire:\n\n'
|
||||||
|
f'{msg_str!r}\n'
|
||||||
|
)
|
||||||
|
decodes_failed += 1
|
||||||
|
else:
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def send(
|
||||||
|
self,
|
||||||
|
msg: msgtypes.MsgType,
|
||||||
|
|
||||||
|
strict_types: bool = True,
|
||||||
|
hide_tb: bool = False,
|
||||||
|
|
||||||
|
) -> None:
|
||||||
|
'''
|
||||||
|
Send a msgpack encoded py-object-blob-as-msg over TCP.
|
||||||
|
|
||||||
|
If `strict_types == True` then a `MsgTypeError` will be raised on any
|
||||||
|
invalid msg type
|
||||||
|
|
||||||
|
'''
|
||||||
|
__tracebackhide__: bool = hide_tb
|
||||||
|
|
||||||
|
# XXX see `trio._sync.AsyncContextManagerMixin` for details
|
||||||
|
# on the `.acquire()`/`.release()` sequencing..
|
||||||
|
async with self._send_lock:
|
||||||
|
|
||||||
|
# NOTE: lookup the `trio.Task.context`'s var for
|
||||||
|
# the current `MsgCodec`.
|
||||||
|
codec: MsgCodec = _ctxvar_MsgCodec.get()
|
||||||
|
|
||||||
|
# XXX for ctxvar debug only!
|
||||||
|
# if self._codec.pld_spec != codec.pld_spec:
|
||||||
|
# self._codec = codec
|
||||||
|
# log.runtime(
|
||||||
|
# f'Using new codec in {self}.send()\n'
|
||||||
|
# f'codec: {self._codec}\n\n'
|
||||||
|
# f'msg: {msg}\n'
|
||||||
|
# )
|
||||||
|
|
||||||
|
if type(msg) not in msgtypes.__msg_types__:
|
||||||
|
if strict_types:
|
||||||
|
raise _mk_send_mte(
|
||||||
|
msg,
|
||||||
|
codec=codec,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
log.warning(
|
||||||
|
'Sending non-`Msg`-spec msg?\n\n'
|
||||||
|
f'{msg}\n'
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
bytes_data: bytes = codec.encode(msg)
|
||||||
|
except TypeError as _err:
|
||||||
|
typerr = _err
|
||||||
|
msgtyperr: MsgTypeError = _mk_send_mte(
|
||||||
|
msg,
|
||||||
|
codec=codec,
|
||||||
|
message=(
|
||||||
|
f'IPC-msg-spec violation in\n\n'
|
||||||
|
f'{pretty_struct.Struct.pformat(msg)}'
|
||||||
|
),
|
||||||
|
src_type_error=typerr,
|
||||||
|
)
|
||||||
|
raise msgtyperr from typerr
|
||||||
|
|
||||||
|
# supposedly the fastest says,
|
||||||
|
# https://stackoverflow.com/a/54027962
|
||||||
|
size: bytes = struct.pack("<I", len(bytes_data))
|
||||||
|
return await self.stream.send_all(size + bytes_data)
|
||||||
|
|
||||||
|
# ?TODO? does it help ever to dynamically show this
|
||||||
|
# frame?
|
||||||
|
# try:
|
||||||
|
# <the-above_code>
|
||||||
|
# except BaseException as _err:
|
||||||
|
# err = _err
|
||||||
|
# if not isinstance(err, MsgTypeError):
|
||||||
|
# __tracebackhide__: bool = False
|
||||||
|
# raise
|
||||||
|
|
||||||
|
async def recv(self) -> Any:
|
||||||
|
return await self._aiter_pkts.asend(None)
|
||||||
|
|
||||||
|
async def drain(self) -> AsyncIterator[dict]:
|
||||||
|
'''
|
||||||
|
Drain the stream's remaining messages sent from
|
||||||
|
the far end until the connection is closed by
|
||||||
|
the peer.
|
||||||
|
|
||||||
|
'''
|
||||||
|
try:
|
||||||
|
async for msg in self._iter_packets():
|
||||||
|
self.drained.append(msg)
|
||||||
|
except TransportClosed:
|
||||||
|
for msg in self.drained:
|
||||||
|
yield msg
|
||||||
|
|
||||||
|
def __aiter__(self):
|
||||||
|
return self._aiter_pkts
|
||||||
|
|
||||||
|
@property
|
||||||
|
def laddr(self) -> AddressType:
|
||||||
|
return self._laddr
|
||||||
|
|
||||||
|
@property
|
||||||
|
def raddr(self) -> AddressType:
|
||||||
|
return self._raddr
|
||||||
|
|
|
@ -0,0 +1,101 @@
|
||||||
|
# tractor: structured concurrent "actors".
|
||||||
|
# Copyright 2018-eternity Tyler Goodlet.
|
||||||
|
|
||||||
|
# This program is free software: you can redistribute it and/or modify
|
||||||
|
# it under the terms of the GNU Affero General Public License as published by
|
||||||
|
# the Free Software Foundation, either version 3 of the License, or
|
||||||
|
# (at your option) any later version.
|
||||||
|
|
||||||
|
# This program is distributed in the hope that it will be useful,
|
||||||
|
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||||
|
# GNU Affero General Public License for more details.
|
||||||
|
|
||||||
|
# You should have received a copy of the GNU Affero General Public License
|
||||||
|
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
|
from typing import Type, Union
|
||||||
|
|
||||||
|
import trio
|
||||||
|
import socket
|
||||||
|
|
||||||
|
from ._transport import MsgTransport
|
||||||
|
from ._tcp import MsgpackTCPStream
|
||||||
|
from ._uds import MsgpackUDSStream
|
||||||
|
|
||||||
|
|
||||||
|
# manually updated list of all supported codec+transport types
|
||||||
|
_msg_transports = {
|
||||||
|
('msgpack', 'tcp'): MsgpackTCPStream,
|
||||||
|
('msgpack', 'uds'): MsgpackUDSStream
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# all different address py types we use
|
||||||
|
AddressTypes = Union[
|
||||||
|
tuple([
|
||||||
|
cls.address_type
|
||||||
|
for key, cls in _msg_transports.items()
|
||||||
|
])
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def transport_from_destaddr(
|
||||||
|
destaddr: AddressTypes,
|
||||||
|
codec_key: str = 'msgpack',
|
||||||
|
) -> Type[MsgTransport]:
|
||||||
|
'''
|
||||||
|
Given a destination address and a desired codec, find the
|
||||||
|
corresponding `MsgTransport` type.
|
||||||
|
|
||||||
|
'''
|
||||||
|
match destaddr:
|
||||||
|
case str():
|
||||||
|
return MsgpackUDSStream
|
||||||
|
|
||||||
|
case tuple():
|
||||||
|
if (
|
||||||
|
len(destaddr) == 2
|
||||||
|
and
|
||||||
|
isinstance(destaddr[0], str)
|
||||||
|
and
|
||||||
|
isinstance(destaddr[1], int)
|
||||||
|
):
|
||||||
|
return MsgpackTCPStream
|
||||||
|
|
||||||
|
raise NotImplementedError(
|
||||||
|
f'No known transport for address {destaddr}'
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def transport_from_stream(
|
||||||
|
stream: trio.abc.Stream,
|
||||||
|
codec_key: str = 'msgpack'
|
||||||
|
) -> Type[MsgTransport]:
|
||||||
|
'''
|
||||||
|
Given an arbitrary `trio.abc.Stream` and a desired codec,
|
||||||
|
find the corresponding `MsgTransport` type.
|
||||||
|
|
||||||
|
'''
|
||||||
|
transport = None
|
||||||
|
if isinstance(stream, trio.SocketStream):
|
||||||
|
sock = stream.socket
|
||||||
|
match sock.family:
|
||||||
|
case socket.AF_INET | socket.AF_INET6:
|
||||||
|
transport = 'tcp'
|
||||||
|
|
||||||
|
case socket.AF_UNIX:
|
||||||
|
transport = 'uds'
|
||||||
|
|
||||||
|
case _:
|
||||||
|
raise NotImplementedError(
|
||||||
|
f'Unsupported socket family: {sock.family}'
|
||||||
|
)
|
||||||
|
|
||||||
|
if not transport:
|
||||||
|
raise NotImplementedError(
|
||||||
|
f'Could not figure out transport type for stream type {type(stream)}'
|
||||||
|
)
|
||||||
|
|
||||||
|
key = (codec_key, transport)
|
||||||
|
|
||||||
|
return _msg_transports[key]
|
|
@ -0,0 +1,84 @@
|
||||||
|
# tractor: structured concurrent "actors".
|
||||||
|
# Copyright 2018-eternity Tyler Goodlet.
|
||||||
|
|
||||||
|
# This program is free software: you can redistribute it and/or modify
|
||||||
|
# it under the terms of the GNU Affero General Public License as published by
|
||||||
|
# the Free Software Foundation, either version 3 of the License, or
|
||||||
|
# (at your option) any later version.
|
||||||
|
|
||||||
|
# This program is distributed in the hope that it will be useful,
|
||||||
|
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||||
|
# GNU Affero General Public License for more details.
|
||||||
|
|
||||||
|
# You should have received a copy of the GNU Affero General Public License
|
||||||
|
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
|
'''
|
||||||
|
Unix Domain Socket implementation of tractor.ipc._transport.MsgTransport protocol
|
||||||
|
|
||||||
|
'''
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import trio
|
||||||
|
|
||||||
|
from tractor.msg import MsgCodec
|
||||||
|
from tractor.log import get_logger
|
||||||
|
from tractor.ipc._transport import MsgpackTransport
|
||||||
|
|
||||||
|
|
||||||
|
log = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class MsgpackUDSStream(MsgpackTransport):
|
||||||
|
'''
|
||||||
|
A ``trio.SocketStream`` delivering ``msgpack`` formatted data
|
||||||
|
using the ``msgspec`` codec lib.
|
||||||
|
|
||||||
|
'''
|
||||||
|
address_type = str
|
||||||
|
layer_key: int = 7
|
||||||
|
name_key: str = 'uds'
|
||||||
|
|
||||||
|
# def __init__(
|
||||||
|
# self,
|
||||||
|
# stream: trio.SocketStream,
|
||||||
|
# prefix_size: int = 4,
|
||||||
|
# codec: CodecType = None,
|
||||||
|
|
||||||
|
# ) -> None:
|
||||||
|
# super().__init__(
|
||||||
|
# stream,
|
||||||
|
# prefix_size=prefix_size,
|
||||||
|
# codec=codec
|
||||||
|
# )
|
||||||
|
|
||||||
|
def connected(self) -> bool:
|
||||||
|
return self.stream.socket.fileno() != -1
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def connect_to(
|
||||||
|
cls,
|
||||||
|
filename: str,
|
||||||
|
prefix_size: int = 4,
|
||||||
|
codec: MsgCodec|None = None,
|
||||||
|
**kwargs
|
||||||
|
) -> MsgpackUDSStream:
|
||||||
|
stream = await trio.open_unix_socket(
|
||||||
|
filename,
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
return MsgpackUDSStream(
|
||||||
|
stream,
|
||||||
|
prefix_size=prefix_size,
|
||||||
|
codec=codec
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_stream_addrs(
|
||||||
|
cls,
|
||||||
|
stream: trio.SocketStream
|
||||||
|
) -> tuple[str, str]:
|
||||||
|
return (
|
||||||
|
stream.socket.getsockname(),
|
||||||
|
stream.socket.getpeername(),
|
||||||
|
)
|
Loading…
Reference in New Issue