Non TCP specific addressing everywhere #17

Open
guille wants to merge 4 commits from structural_dynamics_of_flow into one_ring_to_rule_them_all
25 changed files with 1217 additions and 673 deletions

View File

@ -10,6 +10,7 @@ pkgs.mkShell {
inherit nativeBuildInputs; inherit nativeBuildInputs;
LD_LIBRARY_PATH = pkgs.lib.makeLibraryPath nativeBuildInputs; LD_LIBRARY_PATH = pkgs.lib.makeLibraryPath nativeBuildInputs;
TMPDIR = "/tmp";
shellHook = '' shellHook = ''
set -e set -e

View File

@ -9,7 +9,7 @@ async def main(service_name):
async with tractor.open_nursery() as an: async with tractor.open_nursery() as an:
await an.start_actor(service_name) await an.start_actor(service_name)
async with tractor.get_registry('127.0.0.1', 1616) as portal: async with tractor.get_registry() as portal:
print(f"Arbiter is listening on {portal.channel}") print(f"Arbiter is listening on {portal.channel}")
async with tractor.wait_for_actor(service_name) as sockaddr: async with tractor.wait_for_actor(service_name) as sockaddr:

View File

@ -26,7 +26,7 @@ async def test_reg_then_unreg(reg_addr):
portal = await n.start_actor('actor', enable_modules=[__name__]) portal = await n.start_actor('actor', enable_modules=[__name__])
uid = portal.channel.uid uid = portal.channel.uid
async with tractor.get_registry(*reg_addr) as aportal: async with tractor.get_registry(reg_addr) as aportal:
# this local actor should be the arbiter # this local actor should be the arbiter
assert actor is aportal.actor assert actor is aportal.actor
@ -160,7 +160,7 @@ async def spawn_and_check_registry(
async with tractor.open_root_actor( async with tractor.open_root_actor(
registry_addrs=[reg_addr], registry_addrs=[reg_addr],
): ):
async with tractor.get_registry(*reg_addr) as portal: async with tractor.get_registry(reg_addr) as portal:
# runtime needs to be up to call this # runtime needs to be up to call this
actor = tractor.current_actor() actor = tractor.current_actor()
@ -300,7 +300,7 @@ async def close_chans_before_nursery(
async with tractor.open_root_actor( async with tractor.open_root_actor(
registry_addrs=[reg_addr], registry_addrs=[reg_addr],
): ):
async with tractor.get_registry(*reg_addr) as aportal: async with tractor.get_registry(reg_addr) as aportal:
try: try:
get_reg = partial(unpack_reg, aportal) get_reg = partial(unpack_reg, aportal)

View File

@ -66,6 +66,9 @@ def run_example_in_subproc(
# due to backpressure!!! # due to backpressure!!!
proc = testdir.popen( proc = testdir.popen(
cmdargs, cmdargs,
stdin=subprocess.PIPE,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
**kwargs, **kwargs,
) )
assert not proc.returncode assert not proc.returncode
@ -119,10 +122,14 @@ def test_example(
code = ex.read() code = ex.read()
with run_example_in_subproc(code) as proc: with run_example_in_subproc(code) as proc:
proc.wait() err = None
err, _ = proc.stderr.read(), proc.stdout.read() try:
# print(f'STDERR: {err}') if not proc.poll():
# print(f'STDOUT: {out}') _, err = proc.communicate(timeout=15)
except subprocess.TimeoutExpired as e:
proc.kill()
err = e.stderr
# if we get some gnarly output let's aggregate and raise # if we get some gnarly output let's aggregate and raise
if err: if err:

View File

@ -871,7 +871,7 @@ async def serve_subactors(
) )
await ipc.send(( await ipc.send((
peer.chan.uid, peer.chan.uid,
peer.chan.raddr, peer.chan.raddr.unwrap(),
)) ))
print('Spawner exiting spawn serve loop!') print('Spawner exiting spawn serve loop!')

View File

@ -38,7 +38,7 @@ async def test_self_is_registered_localportal(reg_addr):
"Verify waiting on the arbiter to register itself using a local portal." "Verify waiting on the arbiter to register itself using a local portal."
actor = tractor.current_actor() actor = tractor.current_actor()
assert actor.is_arbiter assert actor.is_arbiter
async with tractor.get_registry(*reg_addr) as portal: async with tractor.get_registry(reg_addr) as portal:
assert isinstance(portal, tractor._portal.LocalPortal) assert isinstance(portal, tractor._portal.LocalPortal)
with trio.fail_after(0.2): with trio.fail_after(0.2):

View File

@ -32,7 +32,7 @@ def test_abort_on_sigint(daemon):
@tractor_test @tractor_test
async def test_cancel_remote_arbiter(daemon, reg_addr): async def test_cancel_remote_arbiter(daemon, reg_addr):
assert not tractor.current_actor().is_arbiter assert not tractor.current_actor().is_arbiter
async with tractor.get_registry(*reg_addr) as portal: async with tractor.get_registry(reg_addr) as portal:
await portal.cancel_actor() await portal.cancel_actor()
time.sleep(0.1) time.sleep(0.1)
@ -41,7 +41,7 @@ async def test_cancel_remote_arbiter(daemon, reg_addr):
# no arbiter socket should exist # no arbiter socket should exist
with pytest.raises(OSError): with pytest.raises(OSError):
async with tractor.get_registry(*reg_addr) as portal: async with tractor.get_registry(reg_addr) as portal:
pass pass

View File

@ -77,7 +77,7 @@ async def movie_theatre_question():
async def test_movie_theatre_convo(start_method): async def test_movie_theatre_convo(start_method):
"""The main ``tractor`` routine. """The main ``tractor`` routine.
""" """
async with tractor.open_nursery() as n: async with tractor.open_nursery(debug_mode=True) as n:
portal = await n.start_actor( portal = await n.start_actor(
'frank', 'frank',

310
tractor/_addr.py 100644
View File

@ -0,0 +1,310 @@
# tractor: structured concurrent "actors".
# Copyright 2018-eternity Tyler Goodlet.
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations
import os
import tempfile
from uuid import uuid4
from typing import (
Protocol,
ClassVar,
TypeVar,
Union,
Type
)
import trio
from trio import socket
NamespaceType = TypeVar('NamespaceType')
AddressType = TypeVar('AddressType')
StreamType = TypeVar('StreamType')
ListenerType = TypeVar('ListenerType')
class Address(Protocol[
NamespaceType,
AddressType,
StreamType,
ListenerType
]):
name_key: ClassVar[str]
address_type: ClassVar[Type[AddressType]]
@property
def is_valid(self) -> bool:
...
@property
def namespace(self) -> NamespaceType|None:
...
@classmethod
def from_addr(cls, addr: AddressType) -> Address:
...
def unwrap(self) -> AddressType:
...
@classmethod
def get_random(cls, namespace: NamespaceType | None = None) -> Address:
...
@classmethod
def get_root(cls) -> Address:
...
def __repr__(self) -> str:
...
def __eq__(self, other) -> bool:
...
async def open_stream(self, **kwargs) -> StreamType:
...
async def open_listener(self, **kwargs) -> ListenerType:
...
async def close_listener(self):
...
class TCPAddress(Address[
str,
tuple[str, int],
trio.SocketStream,
trio.SocketListener
]):
name_key: str = 'tcp'
address_type: type = tuple[str, int]
def __init__(
self,
host: str,
port: int
):
if (
not isinstance(host, str)
or
not isinstance(port, int)
):
raise TypeError(f'Expected host {host} to be str and port {port} to be int')
self._host = host
self._port = port
@property
def is_valid(self) -> bool:
return self._port != 0
@property
def namespace(self) -> str:
return self._host
@classmethod
def from_addr(cls, addr: tuple[str, int]) -> TCPAddress:
return TCPAddress(addr[0], addr[1])
def unwrap(self) -> tuple[str, int]:
return self._host, self._port
@classmethod
def get_random(cls, namespace: str = '127.0.0.1') -> TCPAddress:
return TCPAddress(namespace, 0)
@classmethod
def get_root(cls) -> Address:
return TCPAddress('127.0.0.1', 1616)
def __repr__(self) -> str:
return f'{type(self)} @ {self.unwrap()}'
def __eq__(self, other) -> bool:
if not isinstance(other, TCPAddress):
raise TypeError(
f'Can not compare {type(other)} with {type(self)}'
)
return (
self._host == other._host
and
self._port == other._port
)
async def open_stream(self, **kwargs) -> trio.SocketStream:
stream = await trio.open_tcp_stream(
self._host,
self._port,
**kwargs
)
self._host, self._port = stream.socket.getsockname()[:2]
return stream
async def open_listener(self, **kwargs) -> trio.SocketListener:
listeners = await trio.open_tcp_listeners(
host=self._host,
port=self._port,
**kwargs
)
assert len(listeners) == 1
listener = listeners[0]
self._host, self._port = listener.socket.getsockname()[:2]
return listener
async def close_listener(self):
...
class UDSAddress(Address[
None,
str,
trio.SocketStream,
trio.SocketListener
]):
name_key: str = 'uds'
address_type: type = str
def __init__(
self,
filepath: str
):
self._filepath = filepath
@property
def is_valid(self) -> bool:
return True
@property
def namespace(self) -> None:
return
@classmethod
def from_addr(cls, filepath: str) -> UDSAddress:
return UDSAddress(filepath)
def unwrap(self) -> str:
return self._filepath
@classmethod
def get_random(cls, namespace: None = None) -> UDSAddress:
return UDSAddress(f'{tempfile.gettempdir()}/{uuid4()}.sock')
@classmethod
def get_root(cls) -> Address:
return UDSAddress('tractor.sock')
def __repr__(self) -> str:
return f'{type(self)} @ {self._filepath}'
def __eq__(self, other) -> bool:
if not isinstance(other, UDSAddress):
raise TypeError(
f'Can not compare {type(other)} with {type(self)}'
)
return self._filepath == other._filepath
async def open_stream(self, **kwargs) -> trio.SocketStream:
stream = await trio.open_unix_socket(
self._filepath,
**kwargs
)
return stream
async def open_listener(self, **kwargs) -> trio.SocketListener:
self._sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
await self._sock.bind(self._filepath)
self._sock.listen(1)
return trio.SocketListener(self._sock)
async def close_listener(self):
self._sock.close()
os.unlink(self._filepath)
preferred_transport = 'uds'
_address_types = (
TCPAddress,
UDSAddress
)
_default_addrs: dict[str, Type[Address]] = {
cls.name_key: cls
for cls in _address_types
}
AddressTypes = Union[
tuple([
cls.address_type
for cls in _address_types
])
]
_default_lo_addrs: dict[
str,
AddressTypes
] = {
cls.name_key: cls.get_root().unwrap()
for cls in _address_types
}
def get_address_cls(name: str) -> Type[Address]:
return _default_addrs[name]
def is_wrapped_addr(addr: any) -> bool:
return type(addr) in _address_types
def wrap_address(addr: AddressTypes) -> Address:
if is_wrapped_addr(addr):
return addr
cls = None
match addr:
case str():
cls = UDSAddress
case tuple() | list():
cls = TCPAddress
case None:
cls = get_address_cls(preferred_transport)
addr = cls.get_root().unwrap()
case _:
raise TypeError(
f'Can not wrap addr {addr} of type {type(addr)}'
)
return cls.from_addr(addr)
def default_lo_addrs(transports: list[str]) -> list[AddressTypes]:
return [
_default_lo_addrs[transport]
for transport in transports
]

View File

@ -31,8 +31,12 @@ def parse_uid(arg):
return str(name), str(uuid) # ensures str encoding return str(name), str(uuid) # ensures str encoding
def parse_ipaddr(arg): def parse_ipaddr(arg):
host, port = literal_eval(arg) try:
return (str(host), int(port)) return literal_eval(arg)
except (ValueError, SyntaxError):
# UDS: try to interpret as a straight up str
return arg
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -859,19 +859,10 @@ class Context:
@property @property
def dst_maddr(self) -> str: def dst_maddr(self) -> str:
chan: Channel = self.chan chan: Channel = self.chan
dst_addr, dst_port = chan.raddr
trans: MsgTransport = chan.transport trans: MsgTransport = chan.transport
# cid: str = self.cid # cid: str = self.cid
# cid_head, cid_tail = cid[:6], cid[-6:] # cid_head, cid_tail = cid[:6], cid[-6:]
return ( return trans.maddr
f'/ipv4/{dst_addr}'
f'/{trans.name_key}/{dst_port}'
# f'/{self.chan.uid[0]}'
# f'/{self.cid}'
# f'/cid={cid_head}..{cid_tail}'
# TODO: ? not use this ^ right ?
)
dmaddr = dst_maddr dmaddr = dst_maddr

View File

@ -30,6 +30,12 @@ from contextlib import asynccontextmanager as acm
from tractor.log import get_logger from tractor.log import get_logger
from .trionics import gather_contexts from .trionics import gather_contexts
from .ipc import _connect_chan, Channel from .ipc import _connect_chan, Channel
from ._addr import (
AddressTypes,
Address,
preferred_transport,
wrap_address
)
from ._portal import ( from ._portal import (
Portal, Portal,
open_portal, open_portal,
@ -48,11 +54,7 @@ log = get_logger(__name__)
@acm @acm
async def get_registry( async def get_registry(addr: AddressTypes | None = None) -> AsyncGenerator[
host: str,
port: int,
) -> AsyncGenerator[
Portal | LocalPortal | None, Portal | LocalPortal | None,
None, None,
]: ]:
@ -69,13 +71,13 @@ async def get_registry(
# (likely a re-entrant call from the arbiter actor) # (likely a re-entrant call from the arbiter actor)
yield LocalPortal( yield LocalPortal(
actor, actor,
Channel((host, port)) await Channel.from_addr(addr)
) )
else: else:
# TODO: try to look pre-existing connection from # TODO: try to look pre-existing connection from
# `Actor._peers` and use it instead? # `Actor._peers` and use it instead?
async with ( async with (
_connect_chan(host, port) as chan, _connect_chan(addr) as chan,
open_portal(chan) as regstr_ptl, open_portal(chan) as regstr_ptl,
): ):
yield regstr_ptl yield regstr_ptl
@ -89,11 +91,10 @@ async def get_root(
# TODO: rename mailbox to `_root_maddr` when we finally # TODO: rename mailbox to `_root_maddr` when we finally
# add and impl libp2p multi-addrs? # add and impl libp2p multi-addrs?
host, port = _runtime_vars['_root_mailbox'] addr = _runtime_vars['_root_mailbox']
assert host is not None
async with ( async with (
_connect_chan(host, port) as chan, _connect_chan(addr) as chan,
open_portal(chan, **kwargs) as portal, open_portal(chan, **kwargs) as portal,
): ):
yield portal yield portal
@ -134,10 +135,10 @@ def get_peer_by_name(
@acm @acm
async def query_actor( async def query_actor(
name: str, name: str,
regaddr: tuple[str, int]|None = None, regaddr: AddressTypes|None = None,
) -> AsyncGenerator[ ) -> AsyncGenerator[
tuple[str, int]|None, AddressTypes|None,
None, None,
]: ]:
''' '''
@ -163,31 +164,31 @@ async def query_actor(
return return
reg_portal: Portal reg_portal: Portal
regaddr: tuple[str, int] = regaddr or actor.reg_addrs[0] regaddr: Address = wrap_address(regaddr) or actor.reg_addrs[0]
async with get_registry(*regaddr) as reg_portal: async with get_registry(regaddr) as reg_portal:
# TODO: return portals to all available actors - for now # TODO: return portals to all available actors - for now
# just the last one that registered # just the last one that registered
sockaddr: tuple[str, int] = await reg_portal.run_from_ns( addr: AddressTypes = await reg_portal.run_from_ns(
'self', 'self',
'find_actor', 'find_actor',
name=name, name=name,
) )
yield sockaddr yield addr
@acm @acm
async def maybe_open_portal( async def maybe_open_portal(
addr: tuple[str, int], addr: AddressTypes,
name: str, name: str,
): ):
async with query_actor( async with query_actor(
name=name, name=name,
regaddr=addr, regaddr=addr,
) as sockaddr: ) as addr:
pass pass
if sockaddr: if addr:
async with _connect_chan(*sockaddr) as chan: async with _connect_chan(addr) as chan:
async with open_portal(chan) as portal: async with open_portal(chan) as portal:
yield portal yield portal
else: else:
@ -197,7 +198,8 @@ async def maybe_open_portal(
@acm @acm
async def find_actor( async def find_actor(
name: str, name: str,
registry_addrs: list[tuple[str, int]]|None = None, registry_addrs: list[AddressTypes]|None = None,
enable_transports: list[str] = [preferred_transport],
only_first: bool = True, only_first: bool = True,
raise_on_none: bool = False, raise_on_none: bool = False,
@ -224,15 +226,15 @@ async def find_actor(
# XXX NOTE: make sure to dynamically read the value on # XXX NOTE: make sure to dynamically read the value on
# every call since something may change it globally (eg. # every call since something may change it globally (eg.
# like in our discovery test suite)! # like in our discovery test suite)!
from . import _root from ._addr import default_lo_addrs
registry_addrs = ( registry_addrs = (
_runtime_vars['_registry_addrs'] _runtime_vars['_registry_addrs']
or or
_root._default_lo_addrs default_lo_addrs(enable_transports)
) )
maybe_portals: list[ maybe_portals: list[
AsyncContextManager[tuple[str, int]] AsyncContextManager[AddressTypes]
] = list( ] = list(
maybe_open_portal( maybe_open_portal(
addr=addr, addr=addr,
@ -274,7 +276,7 @@ async def find_actor(
@acm @acm
async def wait_for_actor( async def wait_for_actor(
name: str, name: str,
registry_addr: tuple[str, int] | None = None, registry_addr: AddressTypes | None = None,
) -> AsyncGenerator[Portal, None]: ) -> AsyncGenerator[Portal, None]:
''' '''
@ -291,7 +293,7 @@ async def wait_for_actor(
yield peer_portal yield peer_portal
return return
regaddr: tuple[str, int] = ( regaddr: AddressTypes = (
registry_addr registry_addr
or or
actor.reg_addrs[0] actor.reg_addrs[0]
@ -299,8 +301,8 @@ async def wait_for_actor(
# TODO: use `.trionics.gather_contexts()` like # TODO: use `.trionics.gather_contexts()` like
# above in `find_actor()` as well? # above in `find_actor()` as well?
reg_portal: Portal reg_portal: Portal
async with get_registry(*regaddr) as reg_portal: async with get_registry(regaddr) as reg_portal:
sockaddrs = await reg_portal.run_from_ns( addrs = await reg_portal.run_from_ns(
'self', 'self',
'wait_for_actor', 'wait_for_actor',
name=name, name=name,
@ -308,8 +310,8 @@ async def wait_for_actor(
# get latest registered addr by default? # get latest registered addr by default?
# TODO: offer multi-portal yields in multi-homed case? # TODO: offer multi-portal yields in multi-homed case?
sockaddr: tuple[str, int] = sockaddrs[-1] addr: AddressTypes = addrs[-1]
async with _connect_chan(*sockaddr) as chan: async with _connect_chan(addr) as chan:
async with open_portal(chan) as portal: async with open_portal(chan) as portal:
yield portal yield portal

View File

@ -37,6 +37,7 @@ from .log import (
from . import _state from . import _state
from .devx import _debug from .devx import _debug
from .to_asyncio import run_as_asyncio_guest from .to_asyncio import run_as_asyncio_guest
from ._addr import AddressTypes
from ._runtime import ( from ._runtime import (
async_main, async_main,
Actor, Actor,
@ -52,10 +53,10 @@ log = get_logger(__name__)
def _mp_main( def _mp_main(
actor: Actor, actor: Actor,
accept_addrs: list[tuple[str, int]], accept_addrs: list[AddressTypes],
forkserver_info: tuple[Any, Any, Any, Any, Any], forkserver_info: tuple[Any, Any, Any, Any, Any],
start_method: SpawnMethodKey, start_method: SpawnMethodKey,
parent_addr: tuple[str, int] | None = None, parent_addr: AddressTypes | None = None,
infect_asyncio: bool = False, infect_asyncio: bool = False,
) -> None: ) -> None:
@ -206,7 +207,7 @@ def nest_from_op(
def _trio_main( def _trio_main(
actor: Actor, actor: Actor,
*, *,
parent_addr: tuple[str, int] | None = None, parent_addr: AddressTypes | None = None,
infect_asyncio: bool = False, infect_asyncio: bool = False,
) -> None: ) -> None:

View File

@ -43,21 +43,18 @@ from .devx import _debug
from . import _spawn from . import _spawn
from . import _state from . import _state
from . import log from . import log
from .ipc import _connect_chan from .ipc import (
_connect_chan,
)
from ._addr import (
AddressTypes,
wrap_address,
preferred_transport,
default_lo_addrs
)
from ._exceptions import is_multi_cancelled from ._exceptions import is_multi_cancelled
# set at startup and after forks
_default_host: str = '127.0.0.1'
_default_port: int = 1616
# default registry always on localhost
_default_lo_addrs: list[tuple[str, int]] = [(
_default_host,
_default_port,
)]
logger = log.get_logger('tractor') logger = log.get_logger('tractor')
@ -66,10 +63,12 @@ async def open_root_actor(
*, *,
# defaults are above # defaults are above
registry_addrs: list[tuple[str, int]]|None = None, registry_addrs: list[AddressTypes]|None = None,
# defaults are above # defaults are above
arbiter_addr: tuple[str, int]|None = None, arbiter_addr: tuple[AddressTypes]|None = None,
enable_transports: list[str] = [preferred_transport],
name: str|None = 'root', name: str|None = 'root',
@ -195,11 +194,9 @@ async def open_root_actor(
) )
registry_addrs = [arbiter_addr] registry_addrs = [arbiter_addr]
registry_addrs: list[tuple[str, int]] = ( if not registry_addrs:
registry_addrs registry_addrs: list[AddressTypes] = default_lo_addrs(enable_transports)
or
_default_lo_addrs
)
assert registry_addrs assert registry_addrs
loglevel = ( loglevel = (
@ -248,10 +245,10 @@ async def open_root_actor(
enable_stack_on_sig() enable_stack_on_sig()
# closed into below ping task-func # closed into below ping task-func
ponged_addrs: list[tuple[str, int]] = [] ponged_addrs: list[AddressTypes] = []
async def ping_tpt_socket( async def ping_tpt_socket(
addr: tuple[str, int], addr: AddressTypes,
timeout: float = 1, timeout: float = 1,
) -> None: ) -> None:
''' '''
@ -271,7 +268,7 @@ async def open_root_actor(
# be better to eventually have a "discovery" protocol # be better to eventually have a "discovery" protocol
# with basic handshake instead? # with basic handshake instead?
with trio.move_on_after(timeout): with trio.move_on_after(timeout):
async with _connect_chan(*addr): async with _connect_chan(addr):
ponged_addrs.append(addr) ponged_addrs.append(addr)
except OSError: except OSError:
@ -284,10 +281,10 @@ async def open_root_actor(
for addr in registry_addrs: for addr in registry_addrs:
tn.start_soon( tn.start_soon(
ping_tpt_socket, ping_tpt_socket,
tuple(addr), # TODO: just drop this requirement? addr,
) )
trans_bind_addrs: list[tuple[str, int]] = [] trans_bind_addrs: list[AddressTypes] = []
# Create a new local root-actor instance which IS NOT THE # Create a new local root-actor instance which IS NOT THE
# REGISTRAR # REGISTRAR
@ -311,9 +308,12 @@ async def open_root_actor(
) )
# DO NOT use the registry_addrs as the transport server # DO NOT use the registry_addrs as the transport server
# addrs for this new non-registar, root-actor. # addrs for this new non-registar, root-actor.
for host, port in ponged_addrs: for addr in ponged_addrs:
# NOTE: zero triggers dynamic OS port allocation waddr = wrap_address(addr)
trans_bind_addrs.append((host, 0)) print(waddr)
trans_bind_addrs.append(
waddr.get_random(namespace=waddr.namespace)
)
# Start this local actor as the "registrar", aka a regular # Start this local actor as the "registrar", aka a regular
# actor who manages the local registry of "mailboxes" of # actor who manages the local registry of "mailboxes" of
@ -322,7 +322,7 @@ async def open_root_actor(
# NOTE that if the current actor IS THE REGISTAR, the # NOTE that if the current actor IS THE REGISTAR, the
# following init steps are taken: # following init steps are taken:
# - the tranport layer server is bound to each (host, port) # - the tranport layer server is bound to each addr
# pair defined in provided registry_addrs, or the default. # pair defined in provided registry_addrs, or the default.
trans_bind_addrs = registry_addrs trans_bind_addrs = registry_addrs
@ -462,7 +462,7 @@ def run_daemon(
# runtime kwargs # runtime kwargs
name: str | None = 'root', name: str | None = 'root',
registry_addrs: list[tuple[str, int]] = _default_lo_addrs, registry_addrs: list[AddressTypes]|None = None,
start_method: str | None = None, start_method: str | None = None,
debug_mode: bool = False, debug_mode: bool = False,

View File

@ -74,6 +74,13 @@ from tractor.msg import (
types as msgtypes, types as msgtypes,
) )
from .ipc import Channel from .ipc import Channel
from ._addr import (
AddressTypes,
Address,
wrap_address,
preferred_transport,
default_lo_addrs
)
from ._context import ( from ._context import (
mk_context, mk_context,
Context, Context,
@ -179,11 +186,11 @@ class Actor:
enable_modules: list[str] = [], enable_modules: list[str] = [],
uid: str|None = None, uid: str|None = None,
loglevel: str|None = None, loglevel: str|None = None,
registry_addrs: list[tuple[str, int]]|None = None, registry_addrs: list[AddressTypes]|None = None,
spawn_method: str|None = None, spawn_method: str|None = None,
# TODO: remove! # TODO: remove!
arbiter_addr: tuple[str, int]|None = None, arbiter_addr: AddressTypes|None = None,
) -> None: ) -> None:
''' '''
@ -223,7 +230,7 @@ class Actor:
DeprecationWarning, DeprecationWarning,
stacklevel=2, stacklevel=2,
) )
registry_addrs: list[tuple[str, int]] = [arbiter_addr] registry_addrs: list[AddressTypes] = [arbiter_addr]
# marked by the process spawning backend at startup # marked by the process spawning backend at startup
# will be None for the parent most process started manually # will be None for the parent most process started manually
@ -257,6 +264,7 @@ class Actor:
] = {} ] = {}
self._listeners: list[trio.abc.Listener] = [] self._listeners: list[trio.abc.Listener] = []
self._listen_addrs: list[Address] = []
self._parent_chan: Channel|None = None self._parent_chan: Channel|None = None
self._forkserver_info: tuple|None = None self._forkserver_info: tuple|None = None
@ -269,13 +277,13 @@ class Actor:
# when provided, init the registry addresses property from # when provided, init the registry addresses property from
# input via the validator. # input via the validator.
self._reg_addrs: list[tuple[str, int]] = [] self._reg_addrs: list[AddressTypes] = []
if registry_addrs: if registry_addrs:
self.reg_addrs: list[tuple[str, int]] = registry_addrs self.reg_addrs: list[AddressTypes] = registry_addrs
_state._runtime_vars['_registry_addrs'] = registry_addrs _state._runtime_vars['_registry_addrs'] = registry_addrs
@property @property
def reg_addrs(self) -> list[tuple[str, int]]: def reg_addrs(self) -> list[AddressTypes]:
''' '''
List of (socket) addresses for all known (and contactable) List of (socket) addresses for all known (and contactable)
registry actors. registry actors.
@ -286,7 +294,7 @@ class Actor:
@reg_addrs.setter @reg_addrs.setter
def reg_addrs( def reg_addrs(
self, self,
addrs: list[tuple[str, int]], addrs: list[AddressTypes],
) -> None: ) -> None:
if not addrs: if not addrs:
log.warning( log.warning(
@ -295,16 +303,7 @@ class Actor:
) )
return return
# always sanity check the input list since it's critical self._reg_addrs = addrs
# that addrs are correct for discovery sys operation.
for addr in addrs:
if not isinstance(addr, tuple):
raise ValueError(
'Expected `Actor.reg_addrs: list[tuple[str, int]]`\n'
f'Got {addrs}'
)
self._reg_addrs = addrs
async def wait_for_peer( async def wait_for_peer(
self, self,
@ -1024,11 +1023,11 @@ class Actor:
async def _from_parent( async def _from_parent(
self, self,
parent_addr: tuple[str, int]|None, parent_addr: AddressTypes|None,
) -> tuple[ ) -> tuple[
Channel, Channel,
list[tuple[str, int]]|None, list[AddressTypes]|None,
]: ]:
''' '''
Bootstrap this local actor's runtime config from its parent by Bootstrap this local actor's runtime config from its parent by
@ -1040,16 +1039,13 @@ class Actor:
# Connect back to the parent actor and conduct initial # Connect back to the parent actor and conduct initial
# handshake. From this point on if we error, we # handshake. From this point on if we error, we
# attempt to ship the exception back to the parent. # attempt to ship the exception back to the parent.
chan = Channel( chan = await Channel.from_addr(wrap_address(parent_addr))
destaddr=parent_addr,
)
await chan.connect()
# TODO: move this into a `Channel.handshake()`? # TODO: move this into a `Channel.handshake()`?
# Initial handshake: swap names. # Initial handshake: swap names.
await self._do_handshake(chan) await self._do_handshake(chan)
accept_addrs: list[tuple[str, int]]|None = None accept_addrs: list[AddressTypes]|None = None
if self._spawn_method == "trio": if self._spawn_method == "trio":
@ -1066,7 +1062,7 @@ class Actor:
# if "trace"/"util" mode is enabled? # if "trace"/"util" mode is enabled?
f'{pretty_struct.pformat(spawnspec)}\n' f'{pretty_struct.pformat(spawnspec)}\n'
) )
accept_addrs: list[tuple[str, int]] = spawnspec.bind_addrs accept_addrs: list[AddressTypes] = spawnspec.bind_addrs
# TODO: another `Struct` for rtvs.. # TODO: another `Struct` for rtvs..
rvs: dict[str, Any] = spawnspec._runtime_vars rvs: dict[str, Any] = spawnspec._runtime_vars
@ -1173,8 +1169,7 @@ class Actor:
self, self,
handler_nursery: Nursery, handler_nursery: Nursery,
*, *,
# (host, port) to bind for channel server listen_addrs: list[AddressTypes]|None = None,
listen_sockaddrs: list[tuple[str, int]]|None = None,
task_status: TaskStatus[Nursery] = trio.TASK_STATUS_IGNORED, task_status: TaskStatus[Nursery] = trio.TASK_STATUS_IGNORED,
) -> None: ) -> None:
@ -1186,41 +1181,45 @@ class Actor:
`.cancel_server()` is called. `.cancel_server()` is called.
''' '''
if listen_sockaddrs is None: if listen_addrs is None:
listen_sockaddrs = [(None, 0)] listen_addrs = default_lo_addrs([preferred_transport])
else:
listen_addrs: list[Address] = [
wrap_address(a) for a in listen_addrs
]
self._server_down = trio.Event() self._server_down = trio.Event()
try: try:
async with trio.open_nursery() as server_n: async with trio.open_nursery() as server_n:
listeners: list[trio.abc.Listener] = [
await addr.open_listener()
for addr in listen_addrs
]
await server_n.start(
partial(
trio.serve_listeners,
handler=self._stream_handler,
listeners=listeners,
for host, port in listen_sockaddrs: # NOTE: configured such that new
listeners: list[trio.abc.Listener] = await server_n.start( # connections will stay alive even if
partial( # this server is cancelled!
trio.serve_tcp, handler_nursery=handler_nursery
handler=self._stream_handler,
port=port,
host=host,
# NOTE: configured such that new
# connections will stay alive even if
# this server is cancelled!
handler_nursery=handler_nursery,
)
) )
sockets: list[trio.socket] = [ )
getattr(listener, 'socket', 'unknown socket') log.runtime(
for listener in listeners 'Started server(s)\n'
] '\n'.join([f'|_{addr}' for addr in listen_addrs])
log.runtime( )
'Started TCP server(s)\n' self._listen_addrs.extend(listen_addrs)
f'|_{sockets}\n' self._listeners.extend(listeners)
)
self._listeners.extend(listeners)
task_status.started(server_n) task_status.started(server_n)
finally: finally:
for addr in listen_addrs:
await addr.close_listener()
# signal the server is down since nursery above terminated # signal the server is down since nursery above terminated
self._server_down.set() self._server_down.set()
@ -1579,26 +1578,21 @@ class Actor:
return False return False
@property @property
def accept_addrs(self) -> list[tuple[str, int]]: def accept_addrs(self) -> list[AddressTypes]:
''' '''
All addresses to which the transport-channel server binds All addresses to which the transport-channel server binds
and listens for new connections. and listens for new connections.
''' '''
# throws OSError on failure return [a.unwrap() for a in self._listen_addrs]
return [
listener.socket.getsockname()
for listener in self._listeners
] # type: ignore
@property @property
def accept_addr(self) -> tuple[str, int]: def accept_addr(self) -> AddressTypes:
''' '''
Primary address to which the IPC transport server is Primary address to which the IPC transport server is
bound and listening for new connections. bound and listening for new connections.
''' '''
# throws OSError on failure
return self.accept_addrs[0] return self.accept_addrs[0]
def get_parent(self) -> Portal: def get_parent(self) -> Portal:
@ -1670,7 +1664,7 @@ class Actor:
async def async_main( async def async_main(
actor: Actor, actor: Actor,
accept_addrs: tuple[str, int]|None = None, accept_addrs: AddressTypes|None = None,
# XXX: currently ``parent_addr`` is only needed for the # XXX: currently ``parent_addr`` is only needed for the
# ``multiprocessing`` backend (which pickles state sent to # ``multiprocessing`` backend (which pickles state sent to
@ -1679,7 +1673,7 @@ async def async_main(
# change this to a simple ``is_subactor: bool`` which will # change this to a simple ``is_subactor: bool`` which will
# be False when running as root actor and True when as # be False when running as root actor and True when as
# a subactor. # a subactor.
parent_addr: tuple[str, int]|None = None, parent_addr: AddressTypes|None = None,
task_status: TaskStatus[None] = trio.TASK_STATUS_IGNORED, task_status: TaskStatus[None] = trio.TASK_STATUS_IGNORED,
) -> None: ) -> None:
@ -1769,7 +1763,7 @@ async def async_main(
partial( partial(
actor._serve_forever, actor._serve_forever,
service_nursery, service_nursery,
listen_sockaddrs=accept_addrs, listen_addrs=accept_addrs,
) )
) )
except OSError as oserr: except OSError as oserr:
@ -1785,7 +1779,7 @@ async def async_main(
raise raise
accept_addrs: list[tuple[str, int]] = actor.accept_addrs accept_addrs: list[AddressTypes] = actor.accept_addrs
# NOTE: only set the loopback addr for the # NOTE: only set the loopback addr for the
# process-tree-global "root" mailbox since # process-tree-global "root" mailbox since
@ -1793,9 +1787,8 @@ async def async_main(
# their root actor over that channel. # their root actor over that channel.
if _state._runtime_vars['_is_root']: if _state._runtime_vars['_is_root']:
for addr in accept_addrs: for addr in accept_addrs:
host, _ = addr waddr = wrap_address(addr)
# TODO: generic 'lo' detector predicate if waddr == waddr.get_root():
if '127.0.0.1' in host:
_state._runtime_vars['_root_mailbox'] = addr _state._runtime_vars['_root_mailbox'] = addr
# Register with the arbiter if we're told its addr # Register with the arbiter if we're told its addr
@ -1810,24 +1803,21 @@ async def async_main(
# only on unique actor uids? # only on unique actor uids?
for addr in actor.reg_addrs: for addr in actor.reg_addrs:
try: try:
assert isinstance(addr, tuple) waddr = wrap_address(addr)
assert addr[1] # non-zero after bind assert waddr.is_valid
except AssertionError: except AssertionError:
await _debug.pause() await _debug.pause()
async with get_registry(*addr) as reg_portal: async with get_registry(addr) as reg_portal:
for accept_addr in accept_addrs: for accept_addr in accept_addrs:
accept_addr = wrap_address(accept_addr)
if not accept_addr[1]: assert accept_addr.is_valid
await _debug.pause()
assert accept_addr[1]
await reg_portal.run_from_ns( await reg_portal.run_from_ns(
'self', 'self',
'register_actor', 'register_actor',
uid=actor.uid, uid=actor.uid,
sockaddr=accept_addr, addr=accept_addr.unwrap(),
) )
is_registered: bool = True is_registered: bool = True
@ -1954,12 +1944,13 @@ async def async_main(
): ):
failed: bool = False failed: bool = False
for addr in actor.reg_addrs: for addr in actor.reg_addrs:
assert isinstance(addr, tuple) waddr = wrap_address(addr)
assert waddr.is_valid
with trio.move_on_after(0.5) as cs: with trio.move_on_after(0.5) as cs:
cs.shield = True cs.shield = True
try: try:
async with get_registry( async with get_registry(
*addr, addr,
) as reg_portal: ) as reg_portal:
await reg_portal.run_from_ns( await reg_portal.run_from_ns(
'self', 'self',
@ -2037,7 +2028,7 @@ class Arbiter(Actor):
self._registry: dict[ self._registry: dict[
tuple[str, str], tuple[str, str],
tuple[str, int], AddressTypes,
] = {} ] = {}
self._waiters: dict[ self._waiters: dict[
str, str,
@ -2053,18 +2044,18 @@ class Arbiter(Actor):
self, self,
name: str, name: str,
) -> tuple[str, int]|None: ) -> AddressTypes|None:
for uid, sockaddr in self._registry.items(): for uid, addr in self._registry.items():
if name in uid: if name in uid:
return sockaddr return addr
return None return None
async def get_registry( async def get_registry(
self self
) -> dict[str, tuple[str, int]]: ) -> dict[str, AddressTypes]:
''' '''
Return current name registry. Return current name registry.
@ -2084,7 +2075,7 @@ class Arbiter(Actor):
self, self,
name: str, name: str,
) -> list[tuple[str, int]]: ) -> list[AddressTypes]:
''' '''
Wait for a particular actor to register. Wait for a particular actor to register.
@ -2092,44 +2083,41 @@ class Arbiter(Actor):
registered. registered.
''' '''
sockaddrs: list[tuple[str, int]] = [] addrs: list[AddressTypes] = []
sockaddr: tuple[str, int] addr: AddressTypes
mailbox_info: str = 'Actor registry contact infos:\n' mailbox_info: str = 'Actor registry contact infos:\n'
for uid, sockaddr in self._registry.items(): for uid, addr in self._registry.items():
mailbox_info += ( mailbox_info += (
f'|_uid: {uid}\n' f'|_uid: {uid}\n'
f'|_sockaddr: {sockaddr}\n\n' f'|_addr: {addr}\n\n'
) )
if name == uid[0]: if name == uid[0]:
sockaddrs.append(sockaddr) addrs.append(addr)
if not sockaddrs: if not addrs:
waiter = trio.Event() waiter = trio.Event()
self._waiters.setdefault(name, []).append(waiter) self._waiters.setdefault(name, []).append(waiter)
await waiter.wait() await waiter.wait()
for uid in self._waiters[name]: for uid in self._waiters[name]:
if not isinstance(uid, trio.Event): if not isinstance(uid, trio.Event):
sockaddrs.append(self._registry[uid]) addrs.append(self._registry[uid])
log.runtime(mailbox_info) log.runtime(mailbox_info)
return sockaddrs return addrs
async def register_actor( async def register_actor(
self, self,
uid: tuple[str, str], uid: tuple[str, str],
sockaddr: tuple[str, int] addr: AddressTypes
) -> None: ) -> None:
uid = name, hash = (str(uid[0]), str(uid[1])) uid = name, hash = (str(uid[0]), str(uid[1]))
addr = (host, port) = ( waddr: Address = wrap_address(addr)
str(sockaddr[0]), if not waddr.is_valid:
int(sockaddr[1]), # should never be 0-dynamic-os-alloc
)
if port == 0:
await _debug.pause() await _debug.pause()
assert port # should never be 0-dynamic-os-alloc
self._registry[uid] = addr self._registry[uid] = addr
# pop and signal all waiter events # pop and signal all waiter events

View File

@ -46,6 +46,7 @@ from tractor._state import (
_runtime_vars, _runtime_vars,
) )
from tractor.log import get_logger from tractor.log import get_logger
from tractor._addr import AddressTypes
from tractor._portal import Portal from tractor._portal import Portal
from tractor._runtime import Actor from tractor._runtime import Actor
from tractor._entry import _mp_main from tractor._entry import _mp_main
@ -392,8 +393,8 @@ async def new_proc(
errors: dict[tuple[str, str], Exception], errors: dict[tuple[str, str], Exception],
# passed through to actor main # passed through to actor main
bind_addrs: list[tuple[str, int]], bind_addrs: list[AddressTypes],
parent_addr: tuple[str, int], parent_addr: AddressTypes,
_runtime_vars: dict[str, Any], # serialized and sent to _child _runtime_vars: dict[str, Any], # serialized and sent to _child
*, *,
@ -431,8 +432,8 @@ async def trio_proc(
errors: dict[tuple[str, str], Exception], errors: dict[tuple[str, str], Exception],
# passed through to actor main # passed through to actor main
bind_addrs: list[tuple[str, int]], bind_addrs: list[AddressTypes],
parent_addr: tuple[str, int], parent_addr: AddressTypes,
_runtime_vars: dict[str, Any], # serialized and sent to _child _runtime_vars: dict[str, Any], # serialized and sent to _child
*, *,
infect_asyncio: bool = False, infect_asyncio: bool = False,
@ -520,15 +521,15 @@ async def trio_proc(
# send a "spawning specification" which configures the # send a "spawning specification" which configures the
# initial runtime state of the child. # initial runtime state of the child.
await chan.send( sspec = SpawnSpec(
SpawnSpec( _parent_main_data=subactor._parent_main_data,
_parent_main_data=subactor._parent_main_data, enable_modules=subactor.enable_modules,
enable_modules=subactor.enable_modules, reg_addrs=subactor.reg_addrs,
reg_addrs=subactor.reg_addrs, bind_addrs=bind_addrs,
bind_addrs=bind_addrs, _runtime_vars=_runtime_vars,
_runtime_vars=_runtime_vars,
)
) )
log.runtime(f'Sending spawn spec: {str(sspec)}')
await chan.send(sspec)
# track subactor in current nursery # track subactor in current nursery
curr_actor: Actor = current_actor() curr_actor: Actor = current_actor()
@ -638,8 +639,8 @@ async def mp_proc(
subactor: Actor, subactor: Actor,
errors: dict[tuple[str, str], Exception], errors: dict[tuple[str, str], Exception],
# passed through to actor main # passed through to actor main
bind_addrs: list[tuple[str, int]], bind_addrs: list[AddressTypes],
parent_addr: tuple[str, int], parent_addr: AddressTypes,
_runtime_vars: dict[str, Any], # serialized and sent to _child _runtime_vars: dict[str, Any], # serialized and sent to _child
*, *,
infect_asyncio: bool = False, infect_asyncio: bool = False,

View File

@ -28,7 +28,13 @@ import warnings
import trio import trio
from .devx._debug import maybe_wait_for_debugger from .devx._debug import maybe_wait_for_debugger
from ._addr import (
AddressTypes,
preferred_transport,
get_address_cls
)
from ._state import current_actor, is_main_process from ._state import current_actor, is_main_process
from .log import get_logger, get_loglevel from .log import get_logger, get_loglevel
from ._runtime import Actor from ._runtime import Actor
@ -47,8 +53,6 @@ if TYPE_CHECKING:
log = get_logger(__name__) log = get_logger(__name__)
_default_bind_addr: tuple[str, int] = ('127.0.0.1', 0)
class ActorNursery: class ActorNursery:
''' '''
@ -130,8 +134,9 @@ class ActorNursery:
*, *,
bind_addrs: list[tuple[str, int]] = [_default_bind_addr], bind_addrs: list[AddressTypes]|None = None,
rpc_module_paths: list[str]|None = None, rpc_module_paths: list[str]|None = None,
enable_transports: list[str] = [preferred_transport],
enable_modules: list[str]|None = None, enable_modules: list[str]|None = None,
loglevel: str|None = None, # set log level per subactor loglevel: str|None = None, # set log level per subactor
debug_mode: bool|None = None, debug_mode: bool|None = None,
@ -156,6 +161,12 @@ class ActorNursery:
or get_loglevel() or get_loglevel()
) )
if not bind_addrs:
bind_addrs: list[AddressTypes] = [
get_address_cls(transport).get_random().unwrap()
for transport in enable_transports
]
# configure and pass runtime state # configure and pass runtime state
_rtv = _state._runtime_vars.copy() _rtv = _state._runtime_vars.copy()
_rtv['_is_root'] = False _rtv['_is_root'] = False
@ -224,7 +235,7 @@ class ActorNursery:
*, *,
name: str | None = None, name: str | None = None,
bind_addrs: tuple[str, int] = [_default_bind_addr], bind_addrs: AddressTypes|None = None,
rpc_module_paths: list[str] | None = None, rpc_module_paths: list[str] | None = None,
enable_modules: list[str] | None = None, enable_modules: list[str] | None = None,
loglevel: str | None = None, # set log level per subactor loglevel: str | None = None, # set log level per subactor

View File

@ -13,20 +13,25 @@
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
import platform import platform
from ._transport import MsgTransport as MsgTransport from ._transport import (
MsgTransportKey as MsgTransportKey,
MsgType as MsgType,
MsgTransport as MsgTransport,
MsgpackTransport as MsgpackTransport
)
from ._tcp import ( from ._tcp import MsgpackTCPStream as MsgpackTCPStream
get_stream_addrs as get_stream_addrs, from ._uds import MsgpackUDSStream as MsgpackUDSStream
MsgpackTCPStream as MsgpackTCPStream
from ._types import (
transport_from_addr as transport_from_addr,
transport_from_stream as transport_from_stream,
) )
from ._chan import ( from ._chan import (
_connect_chan as _connect_chan, _connect_chan as _connect_chan,
get_msg_transport as get_msg_transport,
Channel as Channel Channel as Channel
) )

View File

@ -29,15 +29,19 @@ from pprint import pformat
import typing import typing
from typing import ( from typing import (
Any, Any,
Type
) )
import trio import trio
from tractor.ipc._transport import MsgTransport from tractor.ipc._transport import MsgTransport
from tractor.ipc._tcp import ( from tractor.ipc._types import (
MsgpackTCPStream, transport_from_addr,
get_stream_addrs transport_from_stream,
)
from tractor._addr import (
wrap_address,
Address,
AddressTypes
) )
from tractor.log import get_logger from tractor.log import get_logger
from tractor._exceptions import ( from tractor._exceptions import (
@ -52,17 +56,6 @@ log = get_logger(__name__)
_is_windows = platform.system() == 'Windows' _is_windows = platform.system() == 'Windows'
def get_msg_transport(
key: tuple[str, str],
) -> Type[MsgTransport]:
return {
('msgpack', 'tcp'): MsgpackTCPStream,
}[key]
class Channel: class Channel:
''' '''
An inter-process channel for communication between (remote) actors. An inter-process channel for communication between (remote) actors.
@ -77,10 +70,7 @@ class Channel:
def __init__( def __init__(
self, self,
destaddr: tuple[str, int]|None, transport: MsgTransport|None = None,
msg_transport_type_key: tuple[str, str] = ('msgpack', 'tcp'),
# TODO: optional reconnection support? # TODO: optional reconnection support?
# auto_reconnect: bool = False, # auto_reconnect: bool = False,
# on_reconnect: typing.Callable[..., typing.Awaitable] = None, # on_reconnect: typing.Callable[..., typing.Awaitable] = None,
@ -90,13 +80,9 @@ class Channel:
# self._recon_seq = on_reconnect # self._recon_seq = on_reconnect
# self._autorecon = auto_reconnect # self._autorecon = auto_reconnect
self._destaddr = destaddr
self._transport_key = msg_transport_type_key
# Either created in ``.connect()`` or passed in by # Either created in ``.connect()`` or passed in by
# user in ``.from_stream()``. # user in ``.from_stream()``.
self._stream: trio.SocketStream|None = None self._transport: MsgTransport|None = transport
self._transport: MsgTransport|None = None
# set after handshake - always uid of far end # set after handshake - always uid of far end
self.uid: tuple[str, str]|None = None self.uid: tuple[str, str]|None = None
@ -110,6 +96,10 @@ class Channel:
# runtime. # runtime.
self._cancel_called: bool = False self._cancel_called: bool = False
@property
def stream(self) -> trio.abc.Stream | None:
return self._transport.stream if self._transport else None
@property @property
def msgstream(self) -> MsgTransport: def msgstream(self) -> MsgTransport:
log.info( log.info(
@ -124,52 +114,32 @@ class Channel:
@classmethod @classmethod
def from_stream( def from_stream(
cls, cls,
stream: trio.SocketStream, stream: trio.abc.Stream,
**kwargs,
) -> Channel: ) -> Channel:
transport_cls = transport_from_stream(stream)
src, dst = get_stream_addrs(stream) return Channel(
chan = Channel( transport=transport_cls(stream)
destaddr=dst,
**kwargs,
) )
# set immediately here from provided instance @classmethod
chan._stream: trio.SocketStream = stream async def from_addr(
chan.set_msg_transport(stream) cls,
return chan addr: AddressTypes,
**kwargs
) -> Channel:
addr: Address = wrap_address(addr)
transport_cls = transport_from_addr(addr)
transport = await transport_cls.connect_to(addr, **kwargs)
def set_msg_transport( log.transport(
self, f'Opened channel[{type(transport)}]: {transport.laddr} -> {transport.raddr}'
stream: trio.SocketStream,
type_key: tuple[str, str]|None = None,
# XXX optionally provided codec pair for `msgspec`:
# https://jcristharif.com/msgspec/extending.html#mapping-to-from-native-types
codec: MsgCodec|None = None,
) -> MsgTransport:
type_key = (
type_key
or
self._transport_key
) )
# get transport type, then return Channel(transport=transport)
self._transport = get_msg_transport(
type_key
# instantiate an instance of the msg-transport
)(
stream,
codec=codec,
)
return self._transport
@cm @cm
def apply_codec( def apply_codec(
self, self,
codec: MsgCodec, codec: MsgCodec,
) -> None: ) -> None:
''' '''
Temporarily override the underlying IPC msg codec for Temporarily override the underlying IPC msg codec for
@ -189,44 +159,20 @@ class Channel:
return '<Channel with inactive transport?>' return '<Channel with inactive transport?>'
return repr( return repr(
self._transport.stream.socket._sock self._transport
).replace( # type: ignore ).replace( # type: ignore
"socket.socket", "socket.socket",
"Channel", "Channel",
) )
@property @property
def laddr(self) -> tuple[str, int]|None: def laddr(self) -> Address|None:
return self._transport.laddr if self._transport else None return self._transport.laddr if self._transport else None
@property @property
def raddr(self) -> tuple[str, int]|None: def raddr(self) -> Address|None:
return self._transport.raddr if self._transport else None return self._transport.raddr if self._transport else None
async def connect(
self,
destaddr: tuple[Any, ...] | None = None,
**kwargs
) -> MsgTransport:
if self.connected():
raise RuntimeError("channel is already connected?")
destaddr = destaddr or self._destaddr
assert isinstance(destaddr, tuple)
stream = await trio.open_tcp_stream(
*destaddr,
**kwargs
)
transport = self.set_msg_transport(stream)
log.transport(
f'Opened channel[{type(transport)}]: {self.laddr} -> {self.raddr}'
)
return transport
# TODO: something like, # TODO: something like,
# `pdbp.hideframe_on(errors=[MsgTypeError])` # `pdbp.hideframe_on(errors=[MsgTypeError])`
# instead of the `try/except` hack we have rn.. # instead of the `try/except` hack we have rn..
@ -261,7 +207,11 @@ class Channel:
# assert err # assert err
__tracebackhide__: bool = False __tracebackhide__: bool = False
else: else:
assert err.cid try:
assert err.cid
except KeyError:
raise err
raise raise
@ -388,17 +338,14 @@ class Channel:
@acm @acm
async def _connect_chan( async def _connect_chan(
host: str, addr: AddressTypes
port: int
) -> typing.AsyncGenerator[Channel, None]: ) -> typing.AsyncGenerator[Channel, None]:
''' '''
Create and connect a channel with disconnect on context manager Create and connect a channel with disconnect on context manager
teardown. teardown.
''' '''
chan = Channel((host, port)) chan = await Channel.from_addr(addr)
await chan.connect()
yield chan yield chan
with trio.CancelScope(shield=True): with trio.CancelScope(shield=True):
await chan.aclose() await chan.aclose()

View File

@ -183,6 +183,9 @@ class RingBuffSender(trio.abc.SendStream):
def wrap_fd(self) -> int: def wrap_fd(self) -> int:
return self._wrap_event.fd return self._wrap_event.fd
async def _wait_wrap(self):
await self._wrap_event.read()
async def send_all(self, data: Buffer): async def send_all(self, data: Buffer):
async with self._send_lock: async with self._send_lock:
# while data is larger than the remaining buf # while data is larger than the remaining buf
@ -193,7 +196,7 @@ class RingBuffSender(trio.abc.SendStream):
self._shm.buf[self.ptr:] = data[:remaining] self._shm.buf[self.ptr:] = data[:remaining]
# signal write and wait for reader wrap around # signal write and wait for reader wrap around
self._write_event.write(remaining) self._write_event.write(remaining)
await self._wrap_event.read() await self._wait_wrap()
# wrap around and trim already written bytes # wrap around and trim already written bytes
self._ptr = 0 self._ptr = 0

View File

@ -18,388 +18,88 @@ TCP implementation of tractor.ipc._transport.MsgTransport protocol
''' '''
from __future__ import annotations from __future__ import annotations
from collections.abc import (
AsyncGenerator,
AsyncIterator,
)
import struct
from typing import (
Any,
Callable,
)
import msgspec
from tricycle import BufferedReceiveStream
import trio import trio
from tractor.msg import MsgCodec
from tractor.log import get_logger from tractor.log import get_logger
from tractor._exceptions import ( from tractor._addr import TCPAddress
MsgTypeError, from tractor.ipc._transport import MsgpackTransport
TransportClosed,
_mk_send_mte,
_mk_recv_mte,
)
from tractor.msg import (
_ctxvar_MsgCodec,
# _codec, XXX see `self._codec` sanity/debug checks
MsgCodec,
types as msgtypes,
pretty_struct,
)
from tractor.ipc import MsgTransport
log = get_logger(__name__) log = get_logger(__name__)
def get_stream_addrs(
stream: trio.SocketStream
) -> tuple[
tuple[str, int], # local
tuple[str, int], # remote
]:
'''
Return the `trio` streaming transport prot's socket-addrs for
both the local and remote sides as a pair.
'''
# rn, should both be IP sockets
lsockname = stream.socket.getsockname()
rsockname = stream.socket.getpeername()
return (
tuple(lsockname[:2]),
tuple(rsockname[:2]),
)
# TODO: typing oddity.. not sure why we have to inherit here, but it # TODO: typing oddity.. not sure why we have to inherit here, but it
# seems to be an issue with `get_msg_transport()` returning # seems to be an issue with `get_msg_transport()` returning
# a `Type[Protocol]`; probably should make a `mypy` issue? # a `Type[Protocol]`; probably should make a `mypy` issue?
class MsgpackTCPStream(MsgTransport): class MsgpackTCPStream(MsgpackTransport):
''' '''
A ``trio.SocketStream`` delivering ``msgpack`` formatted data A ``trio.SocketStream`` delivering ``msgpack`` formatted data
using the ``msgspec`` codec lib. using the ``msgspec`` codec lib.
''' '''
address_type = TCPAddress
layer_key: int = 4 layer_key: int = 4
name_key: str = 'tcp'
# TODO: better naming for this? # def __init__(
# -[ ] check how libp2p does naming for such things? # self,
codec_key: str = 'msgpack' # stream: trio.SocketStream,
# prefix_size: int = 4,
# codec: CodecType = None,
def __init__( # ) -> None:
self, # super().__init__(
stream: trio.SocketStream, # stream,
prefix_size: int = 4, # prefix_size=prefix_size,
# codec=codec
# )
# XXX optionally provided codec pair for `msgspec`: @property
# https://jcristharif.com/msgspec/extending.html#mapping-to-from-native-types def maddr(self) -> str:
# host, port = self.raddr.unwrap()
# TODO: define this as a `Codec` struct which can be return (
# overriden dynamically by the application/runtime? f'/ipv4/{host}'
codec: tuple[ f'/{self.address_type.name_key}/{port}'
Callable[[Any], Any]|None, # coder # f'/{self.chan.uid[0]}'
Callable[[type, Any], Any]|None, # decoder # f'/{self.cid}'
]|None = None,
) -> None: # f'/cid={cid_head}..{cid_tail}'
# TODO: ? not use this ^ right ?
self.stream = stream
assert self.stream.socket
# should both be IP sockets
self._laddr, self._raddr = get_stream_addrs(stream)
# create read loop instance
self._aiter_pkts = self._iter_packets()
self._send_lock = trio.StrictFIFOLock()
# public i guess?
self.drained: list[dict] = []
self.recv_stream = BufferedReceiveStream(
transport_stream=stream
) )
self.prefix_size = prefix_size
# allow for custom IPC msg interchange format
# dynamic override Bo
self._task = trio.lowlevel.current_task()
# XXX for ctxvar debug only!
# self._codec: MsgCodec = (
# codec
# or
# _codec._ctxvar_MsgCodec.get()
# )
async def _iter_packets(self) -> AsyncGenerator[dict, None]:
'''
Yield `bytes`-blob decoded packets from the underlying TCP
stream using the current task's `MsgCodec`.
This is a streaming routine implemented as an async generator
func (which was the original design, but could be changed?)
and is allocated by a `.__call__()` inside `.__init__()` where
it is assigned to the `._aiter_pkts` attr.
'''
decodes_failed: int = 0
while True:
try:
header: bytes = await self.recv_stream.receive_exactly(4)
except (
ValueError,
ConnectionResetError,
# not sure entirely why we need this but without it we
# seem to be getting racy failures here on
# arbiter/registry name subs..
trio.BrokenResourceError,
) as trans_err:
loglevel = 'transport'
match trans_err:
# case (
# ConnectionResetError()
# ):
# loglevel = 'transport'
# peer actor (graceful??) TCP EOF but `tricycle`
# seems to raise a 0-bytes-read?
case ValueError() if (
'unclean EOF' in trans_err.args[0]
):
pass
# peer actor (task) prolly shutdown quickly due
# to cancellation
case trio.BrokenResourceError() if (
'Connection reset by peer' in trans_err.args[0]
):
pass
# unless the disconnect condition falls under "a
# normal operation breakage" we usualy console warn
# about it.
case _:
loglevel: str = 'warning'
raise TransportClosed(
message=(
f'IPC transport already closed by peer\n'
f'x)> {type(trans_err)}\n'
f' |_{self}\n'
),
loglevel=loglevel,
) from trans_err
# XXX definitely can happen if transport is closed
# manually by another `trio.lowlevel.Task` in the
# same actor; we use this in some simulated fault
# testing for ex, but generally should never happen
# under normal operation!
#
# NOTE: as such we always re-raise this error from the
# RPC msg loop!
except trio.ClosedResourceError as closure_err:
raise TransportClosed(
message=(
f'IPC transport already manually closed locally?\n'
f'x)> {type(closure_err)} \n'
f' |_{self}\n'
),
loglevel='error',
raise_on_report=(
closure_err.args[0] == 'another task closed this fd'
or
closure_err.args[0] in ['another task closed this fd']
),
) from closure_err
# graceful TCP EOF disconnect
if header == b'':
raise TransportClosed(
message=(
f'IPC transport already gracefully closed\n'
f')>\n'
f'|_{self}\n'
),
loglevel='transport',
# cause=??? # handy or no?
)
size: int
size, = struct.unpack("<I", header)
log.transport(f'received header {size}') # type: ignore
msg_bytes: bytes = await self.recv_stream.receive_exactly(size)
log.transport(f"received {msg_bytes}") # type: ignore
try:
# NOTE: lookup the `trio.Task.context`'s var for
# the current `MsgCodec`.
codec: MsgCodec = _ctxvar_MsgCodec.get()
# XXX for ctxvar debug only!
# if self._codec.pld_spec != codec.pld_spec:
# assert (
# task := trio.lowlevel.current_task()
# ) is not self._task
# self._task = task
# self._codec = codec
# log.runtime(
# f'Using new codec in {self}.recv()\n'
# f'codec: {self._codec}\n\n'
# f'msg_bytes: {msg_bytes}\n'
# )
yield codec.decode(msg_bytes)
# XXX NOTE: since the below error derives from
# `DecodeError` we need to catch is specially
# and always raise such that spec violations
# are never allowed to be caught silently!
except msgspec.ValidationError as verr:
msgtyperr: MsgTypeError = _mk_recv_mte(
msg=msg_bytes,
codec=codec,
src_validation_error=verr,
)
# XXX deliver up to `Channel.recv()` where
# a re-raise and `Error`-pack can inject the far
# end actor `.uid`.
yield msgtyperr
except (
msgspec.DecodeError,
UnicodeDecodeError,
):
if decodes_failed < 4:
# ignore decoding errors for now and assume they have to
# do with a channel drop - hope that receiving from the
# channel will raise an expected error and bubble up.
try:
msg_str: str|bytes = msg_bytes.decode()
except UnicodeDecodeError:
msg_str = msg_bytes
log.exception(
'Failed to decode msg?\n'
f'{codec}\n\n'
'Rxed bytes from wire:\n\n'
f'{msg_str!r}\n'
)
decodes_failed += 1
else:
raise
async def send(
self,
msg: msgtypes.MsgType,
strict_types: bool = True,
hide_tb: bool = False,
) -> None:
'''
Send a msgpack encoded py-object-blob-as-msg over TCP.
If `strict_types == True` then a `MsgTypeError` will be raised on any
invalid msg type
'''
__tracebackhide__: bool = hide_tb
# XXX see `trio._sync.AsyncContextManagerMixin` for details
# on the `.acquire()`/`.release()` sequencing..
async with self._send_lock:
# NOTE: lookup the `trio.Task.context`'s var for
# the current `MsgCodec`.
codec: MsgCodec = _ctxvar_MsgCodec.get()
# XXX for ctxvar debug only!
# if self._codec.pld_spec != codec.pld_spec:
# self._codec = codec
# log.runtime(
# f'Using new codec in {self}.send()\n'
# f'codec: {self._codec}\n\n'
# f'msg: {msg}\n'
# )
if type(msg) not in msgtypes.__msg_types__:
if strict_types:
raise _mk_send_mte(
msg,
codec=codec,
)
else:
log.warning(
'Sending non-`Msg`-spec msg?\n\n'
f'{msg}\n'
)
try:
bytes_data: bytes = codec.encode(msg)
except TypeError as _err:
typerr = _err
msgtyperr: MsgTypeError = _mk_send_mte(
msg,
codec=codec,
message=(
f'IPC-msg-spec violation in\n\n'
f'{pretty_struct.Struct.pformat(msg)}'
),
src_type_error=typerr,
)
raise msgtyperr from typerr
# supposedly the fastest says,
# https://stackoverflow.com/a/54027962
size: bytes = struct.pack("<I", len(bytes_data))
return await self.stream.send_all(size + bytes_data)
# ?TODO? does it help ever to dynamically show this
# frame?
# try:
# <the-above_code>
# except BaseException as _err:
# err = _err
# if not isinstance(err, MsgTypeError):
# __tracebackhide__: bool = False
# raise
@property
def laddr(self) -> tuple[str, int]:
return self._laddr
@property
def raddr(self) -> tuple[str, int]:
return self._raddr
async def recv(self) -> Any:
return await self._aiter_pkts.asend(None)
async def drain(self) -> AsyncIterator[dict]:
'''
Drain the stream's remaining messages sent from
the far end until the connection is closed by
the peer.
'''
try:
async for msg in self._iter_packets():
self.drained.append(msg)
except TransportClosed:
for msg in self.drained:
yield msg
def __aiter__(self):
return self._aiter_pkts
def connected(self) -> bool: def connected(self) -> bool:
return self.stream.socket.fileno() != -1 return self.stream.socket.fileno() != -1
@classmethod
async def connect_to(
cls,
destaddr: TCPAddress,
prefix_size: int = 4,
codec: MsgCodec|None = None,
**kwargs
) -> MsgpackTCPStream:
stream = await trio.open_tcp_stream(
*destaddr.unwrap(),
**kwargs
)
return MsgpackTCPStream(
stream,
prefix_size=prefix_size,
codec=codec
)
@classmethod
def get_stream_addrs(
cls,
stream: trio.SocketStream
) -> tuple[
tuple[str, int],
tuple[str, int]
]:
lsockname = stream.socket.getsockname()
rsockname = stream.socket.getpeername()
return (
TCPAddress.from_addr(tuple(lsockname[:2])),
TCPAddress.from_addr(tuple(rsockname[:2])),
)

View File

@ -18,13 +18,45 @@ typing.Protocol based generic msg API, implement this class to add backends for
tractor.ipc.Channel tractor.ipc.Channel
''' '''
import trio from __future__ import annotations
from typing import ( from typing import (
runtime_checkable, runtime_checkable,
Type,
Protocol, Protocol,
TypeVar, TypeVar,
ClassVar
) )
from collections.abc import AsyncIterator from collections.abc import (
AsyncGenerator,
AsyncIterator,
)
import struct
import trio
import msgspec
from tricycle import BufferedReceiveStream
from tractor.log import get_logger
from tractor._exceptions import (
MsgTypeError,
TransportClosed,
_mk_send_mte,
_mk_recv_mte,
)
from tractor.msg import (
_ctxvar_MsgCodec,
# _codec, XXX see `self._codec` sanity/debug checks
MsgCodec,
types as msgtypes,
pretty_struct,
)
from tractor._addr import Address
log = get_logger(__name__)
# (codec, transport)
MsgTransportKey = tuple[str, str]
# from tractor.msg.types import MsgType # from tractor.msg.types import MsgType
@ -44,8 +76,8 @@ class MsgTransport(Protocol[MsgType]):
stream: trio.abc.Stream stream: trio.abc.Stream
drained: list[MsgType] drained: list[MsgType]
def __init__(self, stream: trio.abc.Stream) -> None: address_type: ClassVar[Type[Address]]
... codec_key: ClassVar[str]
# XXX: should this instead be called `.sendall()`? # XXX: should this instead be called `.sendall()`?
async def send(self, msg: MsgType) -> None: async def send(self, msg: MsgType) -> None:
@ -65,10 +97,354 @@ class MsgTransport(Protocol[MsgType]):
def drain(self) -> AsyncIterator[dict]: def drain(self) -> AsyncIterator[dict]:
... ...
@classmethod
def key(cls) -> MsgTransportKey:
return cls.codec_key, cls.address_type.name_key
@property @property
def laddr(self) -> tuple[str, int]: def laddr(self) -> Address:
... ...
@property @property
def raddr(self) -> tuple[str, int]: def raddr(self) -> Address:
... ...
@property
def maddr(self) -> str:
...
@classmethod
async def connect_to(
cls,
addr: Address,
**kwargs
) -> MsgTransport:
...
@classmethod
def get_stream_addrs(
cls,
stream: trio.abc.Stream
) -> tuple[
Address, # local
Address # remote
]:
'''
Return the `trio` streaming transport prot's addrs for both
the local and remote sides as a pair.
'''
...
class MsgpackTransport(MsgTransport):
# TODO: better naming for this?
# -[ ] check how libp2p does naming for such things?
codec_key: str = 'msgpack'
def __init__(
self,
stream: trio.abc.Stream,
prefix_size: int = 4,
# XXX optionally provided codec pair for `msgspec`:
# https://jcristharif.com/msgspec/extending.html#mapping-to-from-native-types
#
# TODO: define this as a `Codec` struct which can be
# overriden dynamically by the application/runtime?
codec: MsgCodec = None,
) -> None:
self.stream = stream
self._laddr, self._raddr = self.get_stream_addrs(stream)
# create read loop instance
self._aiter_pkts = self._iter_packets()
self._send_lock = trio.StrictFIFOLock()
# public i guess?
self.drained: list[dict] = []
self.recv_stream = BufferedReceiveStream(
transport_stream=stream
)
self.prefix_size = prefix_size
# allow for custom IPC msg interchange format
# dynamic override Bo
self._task = trio.lowlevel.current_task()
# XXX for ctxvar debug only!
# self._codec: MsgCodec = (
# codec
# or
# _codec._ctxvar_MsgCodec.get()
# )
async def _iter_packets(self) -> AsyncGenerator[dict, None]:
'''
Yield `bytes`-blob decoded packets from the underlying TCP
stream using the current task's `MsgCodec`.
This is a streaming routine implemented as an async generator
func (which was the original design, but could be changed?)
and is allocated by a `.__call__()` inside `.__init__()` where
it is assigned to the `._aiter_pkts` attr.
'''
decodes_failed: int = 0
while True:
try:
header: bytes = await self.recv_stream.receive_exactly(4)
except (
ValueError,
ConnectionResetError,
# not sure entirely why we need this but without it we
# seem to be getting racy failures here on
# arbiter/registry name subs..
trio.BrokenResourceError,
) as trans_err:
loglevel = 'transport'
match trans_err:
# case (
# ConnectionResetError()
# ):
# loglevel = 'transport'
# peer actor (graceful??) TCP EOF but `tricycle`
# seems to raise a 0-bytes-read?
case ValueError() if (
'unclean EOF' in trans_err.args[0]
):
pass
# peer actor (task) prolly shutdown quickly due
# to cancellation
case trio.BrokenResourceError() if (
'Connection reset by peer' in trans_err.args[0]
):
pass
# unless the disconnect condition falls under "a
# normal operation breakage" we usualy console warn
# about it.
case _:
loglevel: str = 'warning'
raise TransportClosed(
message=(
f'IPC transport already closed by peer\n'
f'x)> {type(trans_err)}\n'
f' |_{self}\n'
),
loglevel=loglevel,
) from trans_err
# XXX definitely can happen if transport is closed
# manually by another `trio.lowlevel.Task` in the
# same actor; we use this in some simulated fault
# testing for ex, but generally should never happen
# under normal operation!
#
# NOTE: as such we always re-raise this error from the
# RPC msg loop!
except trio.ClosedResourceError as closure_err:
raise TransportClosed(
message=(
f'IPC transport already manually closed locally?\n'
f'x)> {type(closure_err)} \n'
f' |_{self}\n'
),
loglevel='error',
raise_on_report=(
closure_err.args[0] == 'another task closed this fd'
or
closure_err.args[0] in ['another task closed this fd']
),
) from closure_err
# graceful TCP EOF disconnect
if header == b'':
raise TransportClosed(
message=(
f'IPC transport already gracefully closed\n'
f')>\n'
f'|_{self}\n'
),
loglevel='transport',
# cause=??? # handy or no?
)
size: int
size, = struct.unpack("<I", header)
log.transport(f'received header {size}') # type: ignore
msg_bytes: bytes = await self.recv_stream.receive_exactly(size)
log.transport(f"received {msg_bytes}") # type: ignore
try:
# NOTE: lookup the `trio.Task.context`'s var for
# the current `MsgCodec`.
codec: MsgCodec = _ctxvar_MsgCodec.get()
# XXX for ctxvar debug only!
# if self._codec.pld_spec != codec.pld_spec:
# assert (
# task := trio.lowlevel.current_task()
# ) is not self._task
# self._task = task
# self._codec = codec
# log.runtime(
# f'Using new codec in {self}.recv()\n'
# f'codec: {self._codec}\n\n'
# f'msg_bytes: {msg_bytes}\n'
# )
yield codec.decode(msg_bytes)
# XXX NOTE: since the below error derives from
# `DecodeError` we need to catch is specially
# and always raise such that spec violations
# are never allowed to be caught silently!
except msgspec.ValidationError as verr:
msgtyperr: MsgTypeError = _mk_recv_mte(
msg=msg_bytes,
codec=codec,
src_validation_error=verr,
)
# XXX deliver up to `Channel.recv()` where
# a re-raise and `Error`-pack can inject the far
# end actor `.uid`.
yield msgtyperr
except (
msgspec.DecodeError,
UnicodeDecodeError,
):
if decodes_failed < 4:
# ignore decoding errors for now and assume they have to
# do with a channel drop - hope that receiving from the
# channel will raise an expected error and bubble up.
try:
msg_str: str|bytes = msg_bytes.decode()
except UnicodeDecodeError:
msg_str = msg_bytes
log.exception(
'Failed to decode msg?\n'
f'{codec}\n\n'
'Rxed bytes from wire:\n\n'
f'{msg_str!r}\n'
)
decodes_failed += 1
else:
raise
async def send(
self,
msg: msgtypes.MsgType,
strict_types: bool = True,
hide_tb: bool = False,
) -> None:
'''
Send a msgpack encoded py-object-blob-as-msg over TCP.
If `strict_types == True` then a `MsgTypeError` will be raised on any
invalid msg type
'''
__tracebackhide__: bool = hide_tb
# XXX see `trio._sync.AsyncContextManagerMixin` for details
# on the `.acquire()`/`.release()` sequencing..
async with self._send_lock:
# NOTE: lookup the `trio.Task.context`'s var for
# the current `MsgCodec`.
codec: MsgCodec = _ctxvar_MsgCodec.get()
# XXX for ctxvar debug only!
# if self._codec.pld_spec != codec.pld_spec:
# self._codec = codec
# log.runtime(
# f'Using new codec in {self}.send()\n'
# f'codec: {self._codec}\n\n'
# f'msg: {msg}\n'
# )
if type(msg) not in msgtypes.__msg_types__:
if strict_types:
raise _mk_send_mte(
msg,
codec=codec,
)
else:
log.warning(
'Sending non-`Msg`-spec msg?\n\n'
f'{msg}\n'
)
try:
bytes_data: bytes = codec.encode(msg)
except TypeError as _err:
typerr = _err
msgtyperr: MsgTypeError = _mk_send_mte(
msg,
codec=codec,
message=(
f'IPC-msg-spec violation in\n\n'
f'{pretty_struct.Struct.pformat(msg)}'
),
src_type_error=typerr,
)
raise msgtyperr from typerr
# supposedly the fastest says,
# https://stackoverflow.com/a/54027962
size: bytes = struct.pack("<I", len(bytes_data))
return await self.stream.send_all(size + bytes_data)
# ?TODO? does it help ever to dynamically show this
# frame?
# try:
# <the-above_code>
# except BaseException as _err:
# err = _err
# if not isinstance(err, MsgTypeError):
# __tracebackhide__: bool = False
# raise
async def recv(self) -> msgtypes.MsgType:
return await self._aiter_pkts.asend(None)
async def drain(self) -> AsyncIterator[dict]:
'''
Drain the stream's remaining messages sent from
the far end until the connection is closed by
the peer.
'''
try:
async for msg in self._iter_packets():
self.drained.append(msg)
except TransportClosed:
for msg in self.drained:
yield msg
def __aiter__(self):
return self._aiter_pkts
@property
def laddr(self) -> Address:
return self._laddr
@property
def raddr(self) -> Address:
return self._raddr

View File

@ -0,0 +1,99 @@
# tractor: structured concurrent "actors".
# Copyright 2018-eternity Tyler Goodlet.
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Type
import trio
import socket
from tractor._addr import Address
from tractor.ipc._transport import (
MsgTransportKey,
MsgTransport
)
from tractor.ipc._tcp import MsgpackTCPStream
from tractor.ipc._uds import MsgpackUDSStream
# manually updated list of all supported msg transport types
_msg_transports = [
MsgpackTCPStream,
MsgpackUDSStream
]
# convert a MsgTransportKey to the corresponding transport type
_key_to_transport: dict[MsgTransportKey, Type[MsgTransport]] = {
cls.key(): cls
for cls in _msg_transports
}
# convert an Address wrapper to its corresponding transport type
_addr_to_transport: dict[Type[Address], Type[MsgTransport]] = {
cls.address_type: cls
for cls in _msg_transports
}
def transport_from_addr(
addr: Address,
codec_key: str = 'msgpack',
) -> Type[MsgTransport]:
'''
Given a destination address and a desired codec, find the
corresponding `MsgTransport` type.
'''
try:
return _addr_to_transport[type(addr)]
except KeyError:
raise NotImplementedError(
f'No known transport for address {repr(addr)}'
)
def transport_from_stream(
stream: trio.abc.Stream,
codec_key: str = 'msgpack'
) -> Type[MsgTransport]:
'''
Given an arbitrary `trio.abc.Stream` and a desired codec,
find the corresponding `MsgTransport` type.
'''
transport = None
if isinstance(stream, trio.SocketStream):
sock = stream.socket
match sock.family:
case socket.AF_INET | socket.AF_INET6:
transport = 'tcp'
case socket.AF_UNIX:
transport = 'uds'
case _:
raise NotImplementedError(
f'Unsupported socket family: {sock.family}'
)
if not transport:
raise NotImplementedError(
f'Could not figure out transport type for stream type {type(stream)}'
)
key = (codec_key, transport)
return _key_to_transport[key]

View File

@ -0,0 +1,97 @@
# tractor: structured concurrent "actors".
# Copyright 2018-eternity Tyler Goodlet.
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
'''
Unix Domain Socket implementation of tractor.ipc._transport.MsgTransport protocol
'''
from __future__ import annotations
import trio
from tractor.msg import MsgCodec
from tractor.log import get_logger
from tractor._addr import UDSAddress
from tractor.ipc._transport import MsgpackTransport
log = get_logger(__name__)
class MsgpackUDSStream(MsgpackTransport):
'''
A ``trio.SocketStream`` delivering ``msgpack`` formatted data
using the ``msgspec`` codec lib.
'''
address_type = UDSAddress
layer_key: int = 7
# def __init__(
# self,
# stream: trio.SocketStream,
# prefix_size: int = 4,
# codec: CodecType = None,
# ) -> None:
# super().__init__(
# stream,
# prefix_size=prefix_size,
# codec=codec
# )
@property
def maddr(self) -> str:
filepath = self.raddr.unwrap()
return (
f'/ipv4/localhost'
f'/{self.address_type.name_key}/{filepath}'
# f'/{self.chan.uid[0]}'
# f'/{self.cid}'
# f'/cid={cid_head}..{cid_tail}'
# TODO: ? not use this ^ right ?
)
def connected(self) -> bool:
return self.stream.socket.fileno() != -1
@classmethod
async def connect_to(
cls,
addr: UDSAddress,
prefix_size: int = 4,
codec: MsgCodec|None = None,
**kwargs
) -> MsgpackUDSStream:
stream = await trio.open_unix_socket(
addr.unwrap(),
**kwargs
)
return MsgpackUDSStream(
stream,
prefix_size=prefix_size,
codec=codec
)
@classmethod
def get_stream_addrs(
cls,
stream: trio.SocketStream
) -> tuple[UDSAddress, UDSAddress]:
return (
UDSAddress.from_addr(stream.socket.getsockname()),
UDSAddress.from_addr(stream.socket.getsockname()),
)

View File

@ -47,6 +47,7 @@ from tractor.msg import (
pretty_struct, pretty_struct,
) )
from tractor.log import get_logger from tractor.log import get_logger
from tractor._addr import AddressTypes
log = get_logger('tractor.msgspec') log = get_logger('tractor.msgspec')
@ -167,8 +168,8 @@ class SpawnSpec(
# TODO: not just sockaddr pairs? # TODO: not just sockaddr pairs?
# -[ ] abstract into a `TransportAddr` type? # -[ ] abstract into a `TransportAddr` type?
reg_addrs: list[tuple[str, int]] reg_addrs: list[AddressTypes]
bind_addrs: list[tuple[str, int]] bind_addrs: list[AddressTypes]
# TODO: caps based RPC support in the payload? # TODO: caps based RPC support in the payload?