Finally switch to using address protocol in all runtime

Guillermo Rodriguez 2025-03-23 00:14:04 -03:00
parent 7400f89753
commit 34a2f0c1f3
No known key found for this signature in database
GPG Key ID: 002CC5F1E6BDA53E
23 changed files with 590 additions and 304 deletions

View File

@ -9,7 +9,7 @@ async def main(service_name):
async with tractor.open_nursery() as an: async with tractor.open_nursery() as an:
await an.start_actor(service_name) await an.start_actor(service_name)
async with tractor.get_registry('127.0.0.1', 1616) as portal: async with tractor.get_registry(('127.0.0.1', 1616)) as portal:
print(f"Arbiter is listening on {portal.channel}") print(f"Arbiter is listening on {portal.channel}")
async with tractor.wait_for_actor(service_name) as sockaddr: async with tractor.wait_for_actor(service_name) as sockaddr:

View File

@ -26,7 +26,7 @@ async def test_reg_then_unreg(reg_addr):
portal = await n.start_actor('actor', enable_modules=[__name__]) portal = await n.start_actor('actor', enable_modules=[__name__])
uid = portal.channel.uid uid = portal.channel.uid
async with tractor.get_registry(*reg_addr) as aportal: async with tractor.get_registry(reg_addr) as aportal:
# this local actor should be the arbiter # this local actor should be the arbiter
assert actor is aportal.actor assert actor is aportal.actor
@ -160,7 +160,7 @@ async def spawn_and_check_registry(
async with tractor.open_root_actor( async with tractor.open_root_actor(
registry_addrs=[reg_addr], registry_addrs=[reg_addr],
): ):
async with tractor.get_registry(*reg_addr) as portal: async with tractor.get_registry(reg_addr) as portal:
# runtime needs to be up to call this # runtime needs to be up to call this
actor = tractor.current_actor() actor = tractor.current_actor()
@ -300,7 +300,7 @@ async def close_chans_before_nursery(
async with tractor.open_root_actor( async with tractor.open_root_actor(
registry_addrs=[reg_addr], registry_addrs=[reg_addr],
): ):
async with tractor.get_registry(*reg_addr) as aportal: async with tractor.get_registry(reg_addr) as aportal:
try: try:
get_reg = partial(unpack_reg, aportal) get_reg = partial(unpack_reg, aportal)

View File

@ -871,7 +871,7 @@ async def serve_subactors(
) )
await ipc.send(( await ipc.send((
peer.chan.uid, peer.chan.uid,
peer.chan.raddr, peer.chan.raddr.unwrap(),
)) ))
print('Spawner exiting spawn serve loop!') print('Spawner exiting spawn serve loop!')

View File

@ -38,7 +38,7 @@ async def test_self_is_registered_localportal(reg_addr):
"Verify waiting on the arbiter to register itself using a local portal." "Verify waiting on the arbiter to register itself using a local portal."
actor = tractor.current_actor() actor = tractor.current_actor()
assert actor.is_arbiter assert actor.is_arbiter
async with tractor.get_registry(*reg_addr) as portal: async with tractor.get_registry(reg_addr) as portal:
assert isinstance(portal, tractor._portal.LocalPortal) assert isinstance(portal, tractor._portal.LocalPortal)
with trio.fail_after(0.2): with trio.fail_after(0.2):

View File

@ -32,7 +32,7 @@ def test_abort_on_sigint(daemon):
@tractor_test @tractor_test
async def test_cancel_remote_arbiter(daemon, reg_addr): async def test_cancel_remote_arbiter(daemon, reg_addr):
assert not tractor.current_actor().is_arbiter assert not tractor.current_actor().is_arbiter
async with tractor.get_registry(*reg_addr) as portal: async with tractor.get_registry(reg_addr) as portal:
await portal.cancel_actor() await portal.cancel_actor()
time.sleep(0.1) time.sleep(0.1)
@ -41,7 +41,7 @@ async def test_cancel_remote_arbiter(daemon, reg_addr):
# no arbiter socket should exist # no arbiter socket should exist
with pytest.raises(OSError): with pytest.raises(OSError):
async with tractor.get_registry(*reg_addr) as portal: async with tractor.get_registry(reg_addr) as portal:
pass pass

View File

@ -77,7 +77,7 @@ async def movie_theatre_question():
async def test_movie_theatre_convo(start_method): async def test_movie_theatre_convo(start_method):
"""The main ``tractor`` routine. """The main ``tractor`` routine.
""" """
async with tractor.open_nursery() as n: async with tractor.open_nursery(debug_mode=True) as n:
portal = await n.start_actor( portal = await n.start_actor(
'frank', 'frank',

301
tractor/_addr.py 100644
View File

@ -0,0 +1,301 @@
# tractor: structured concurrent "actors".
# Copyright 2018-eternity Tyler Goodlet.
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations
import tempfile
from uuid import uuid4
from typing import (
Protocol,
ClassVar,
TypeVar,
Union,
Type
)
import trio
from trio import socket
NamespaceType = TypeVar('NamespaceType')
AddressType = TypeVar('AddressType')
StreamType = TypeVar('StreamType')
ListenerType = TypeVar('ListenerType')
class Address(Protocol[
NamespaceType,
AddressType,
StreamType,
ListenerType
]):
name_key: ClassVar[str]
address_type: ClassVar[Type[AddressType]]
@property
def is_valid(self) -> bool:
...
@property
def namespace(self) -> NamespaceType|None:
...
@classmethod
def from_addr(cls, addr: AddressType) -> Address:
...
def unwrap(self) -> AddressType:
...
@classmethod
def get_random(cls, namespace: NamespaceType | None = None) -> Address:
...
@classmethod
def get_root(cls) -> Address:
...
def __repr__(self) -> str:
...
def __eq__(self, other) -> bool:
...
async def open_stream(self, **kwargs) -> StreamType:
...
async def open_listener(self, **kwargs) -> ListenerType:
...
class TCPAddress(Address[
str,
tuple[str, int],
trio.SocketStream,
trio.SocketListener
]):
name_key: str = 'tcp'
address_type: type = tuple[str, int]
def __init__(
self,
host: str,
port: int
):
if (
not isinstance(host, str)
or
not isinstance(port, int)
):
raise TypeError(f'Expected host {host} to be str and port {port} to be int')
self._host = host
self._port = port
@property
def is_valid(self) -> bool:
return self._port != 0
@property
def namespace(self) -> str:
return self._host
@classmethod
def from_addr(cls, addr: tuple[str, int]) -> TCPAddress:
return TCPAddress(addr[0], addr[1])
def unwrap(self) -> tuple[str, int]:
return self._host, self._port
@classmethod
def get_random(cls, namespace: str = '127.0.0.1') -> TCPAddress:
return TCPAddress(namespace, 0)
@classmethod
def get_root(cls) -> Address:
return TCPAddress('127.0.0.1', 1616)
def __repr__(self) -> str:
return f'{type(self)} @ {self.unwrap()}'
def __eq__(self, other) -> bool:
if not isinstance(other, TCPAddress):
raise TypeError(
f'Can not compare {type(other)} with {type(self)}'
)
return (
self._host == other._host
and
self._port == other._port
)
async def open_stream(self, **kwargs) -> trio.SocketStream:
stream = await trio.open_tcp_stream(
self._host,
self._port,
**kwargs
)
self._host, self._port = stream.socket.getsockname()[:2]
return stream
async def open_listener(self, **kwargs) -> trio.SocketListener:
listeners = await trio.open_tcp_listeners(
host=self._host,
port=self._port,
**kwargs
)
assert len(listeners) == 1
listener = listeners[0]
self._host, self._port = listener.socket.getsockname()[:2]
return listener
class UDSAddress(Address[
None,
str,
trio.SocketStream,
trio.SocketListener
]):
name_key: str = 'uds'
address_type: type = str
def __init__(
self,
filepath: str
):
self._filepath = filepath
@property
def is_valid(self) -> bool:
return True
@property
def namespace(self) -> None:
return
@classmethod
def from_addr(cls, filepath: str) -> UDSAddress:
return UDSAddress(filepath)
def unwrap(self) -> str:
return self._filepath
@classmethod
def get_random(cls, _ns: None = None) -> UDSAddress:
return UDSAddress(f'{tempfile.gettempdir()}/{uuid4().sock}')
@classmethod
def get_root(cls) -> Address:
return UDSAddress('tractor.sock')
def __repr__(self) -> str:
return f'{type(self)} @ {self._filepath}'
def __eq__(self, other) -> bool:
if not isinstance(other, UDSAddress):
raise TypeError(
f'Can not compare {type(other)} with {type(self)}'
)
return self._filepath == other._filepath
async def open_stream(self, **kwargs) -> trio.SocketStream:
stream = await trio.open_tcp_stream(
self._filepath,
**kwargs
)
self._binded = True
return stream
async def open_listener(self, **kwargs) -> trio.SocketListener:
sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
sock.bind(self._filepath)
sock.listen()
self._binded = True
return trio.SocketListener(sock)
preferred_transport = 'tcp'
_address_types = (
TCPAddress,
UDSAddress
)
_default_addrs: dict[str, Type[Address]] = {
cls.name_key: cls
for cls in _address_types
}
AddressTypes = Union[
tuple([
cls.address_type
for cls in _address_types
])
]
_default_lo_addrs: dict[
str,
AddressTypes
] = {
cls.name_key: cls.get_root().unwrap()
for cls in _address_types
}
def get_address_cls(name: str) -> Type[Address]:
return _default_addrs[name]
def is_wrapped_addr(addr: any) -> bool:
return type(addr) in _address_types
def wrap_address(addr: AddressTypes) -> Address:
if is_wrapped_addr(addr):
return addr
cls = None
match addr:
case str():
cls = UDSAddress
case tuple() | list():
cls = TCPAddress
case None:
cls = get_address_cls(preferred_transport)
addr = cls.get_root().unwrap()
case _:
raise TypeError(
f'Can not wrap addr {addr} of type {type(addr)}'
)
return cls.from_addr(addr)
def default_lo_addrs(transports: list[str]) -> list[AddressTypes]:
return [
_default_lo_addrs[transport]
for transport in transports
]

View File

@ -31,8 +31,7 @@ def parse_uid(arg):
return str(name), str(uuid) # ensures str encoding return str(name), str(uuid) # ensures str encoding
def parse_ipaddr(arg): def parse_ipaddr(arg):
host, port = literal_eval(arg) return literal_eval(arg)
return (str(host), int(port))
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -859,19 +859,10 @@ class Context:
@property @property
def dst_maddr(self) -> str: def dst_maddr(self) -> str:
chan: Channel = self.chan chan: Channel = self.chan
dst_addr, dst_port = chan.raddr
trans: MsgTransport = chan.transport trans: MsgTransport = chan.transport
# cid: str = self.cid # cid: str = self.cid
# cid_head, cid_tail = cid[:6], cid[-6:] # cid_head, cid_tail = cid[:6], cid[-6:]
return ( return trans.maddr
f'/ipv4/{dst_addr}'
f'/{trans.name_key}/{dst_port}'
# f'/{self.chan.uid[0]}'
# f'/{self.cid}'
# f'/cid={cid_head}..{cid_tail}'
# TODO: ? not use this ^ right ?
)
dmaddr = dst_maddr dmaddr = dst_maddr

View File

@ -30,6 +30,12 @@ from contextlib import asynccontextmanager as acm
from tractor.log import get_logger from tractor.log import get_logger
from .trionics import gather_contexts from .trionics import gather_contexts
from .ipc import _connect_chan, Channel from .ipc import _connect_chan, Channel
from ._addr import (
AddressTypes,
Address,
preferred_transport,
wrap_address
)
from ._portal import ( from ._portal import (
Portal, Portal,
open_portal, open_portal,
@ -48,11 +54,7 @@ log = get_logger(__name__)
@acm @acm
async def get_registry( async def get_registry(addr: AddressTypes) -> AsyncGenerator[
host: str,
port: int,
) -> AsyncGenerator[
Portal | LocalPortal | None, Portal | LocalPortal | None,
None, None,
]: ]:
@ -69,13 +71,13 @@ async def get_registry(
# (likely a re-entrant call from the arbiter actor) # (likely a re-entrant call from the arbiter actor)
yield LocalPortal( yield LocalPortal(
actor, actor,
Channel((host, port)) await Channel.from_addr(addr)
) )
else: else:
# TODO: try to look pre-existing connection from # TODO: try to look pre-existing connection from
# `Actor._peers` and use it instead? # `Actor._peers` and use it instead?
async with ( async with (
_connect_chan((host, port)) as chan, _connect_chan(addr) as chan,
open_portal(chan) as regstr_ptl, open_portal(chan) as regstr_ptl,
): ):
yield regstr_ptl yield regstr_ptl
@ -89,11 +91,10 @@ async def get_root(
# TODO: rename mailbox to `_root_maddr` when we finally # TODO: rename mailbox to `_root_maddr` when we finally
# add and impl libp2p multi-addrs? # add and impl libp2p multi-addrs?
host, port = _runtime_vars['_root_mailbox'] addr = _runtime_vars['_root_mailbox']
assert host is not None
async with ( async with (
_connect_chan((host, port)) as chan, _connect_chan(addr) as chan,
open_portal(chan, **kwargs) as portal, open_portal(chan, **kwargs) as portal,
): ):
yield portal yield portal
@ -134,10 +135,10 @@ def get_peer_by_name(
@acm @acm
async def query_actor( async def query_actor(
name: str, name: str,
regaddr: tuple[str, int]|None = None, regaddr: AddressTypes|None = None,
) -> AsyncGenerator[ ) -> AsyncGenerator[
tuple[str, int]|None, AddressTypes|None,
None, None,
]: ]:
''' '''
@ -163,31 +164,31 @@ async def query_actor(
return return
reg_portal: Portal reg_portal: Portal
regaddr: tuple[str, int] = regaddr or actor.reg_addrs[0] regaddr: Address = wrap_address(regaddr) or actor.reg_addrs[0]
async with get_registry(*regaddr) as reg_portal: async with get_registry(regaddr) as reg_portal:
# TODO: return portals to all available actors - for now # TODO: return portals to all available actors - for now
# just the last one that registered # just the last one that registered
sockaddr: tuple[str, int] = await reg_portal.run_from_ns( addr: AddressTypes = await reg_portal.run_from_ns(
'self', 'self',
'find_actor', 'find_actor',
name=name, name=name,
) )
yield sockaddr yield addr
@acm @acm
async def maybe_open_portal( async def maybe_open_portal(
addr: tuple[str, int], addr: AddressTypes,
name: str, name: str,
): ):
async with query_actor( async with query_actor(
name=name, name=name,
regaddr=addr, regaddr=addr,
) as sockaddr: ) as addr:
pass pass
if sockaddr: if addr:
async with _connect_chan(sockaddr) as chan: async with _connect_chan(addr) as chan:
async with open_portal(chan) as portal: async with open_portal(chan) as portal:
yield portal yield portal
else: else:
@ -197,7 +198,8 @@ async def maybe_open_portal(
@acm @acm
async def find_actor( async def find_actor(
name: str, name: str,
registry_addrs: list[tuple[str, int]]|None = None, registry_addrs: list[AddressTypes]|None = None,
enable_transports: list[str] = [preferred_transport],
only_first: bool = True, only_first: bool = True,
raise_on_none: bool = False, raise_on_none: bool = False,
@ -224,15 +226,15 @@ async def find_actor(
# XXX NOTE: make sure to dynamically read the value on # XXX NOTE: make sure to dynamically read the value on
# every call since something may change it globally (eg. # every call since something may change it globally (eg.
# like in our discovery test suite)! # like in our discovery test suite)!
from . import _root from ._addr import default_lo_addrs
registry_addrs = ( registry_addrs = (
_runtime_vars['_registry_addrs'] _runtime_vars['_registry_addrs']
or or
_root._default_lo_addrs default_lo_addrs(enable_transports)
) )
maybe_portals: list[ maybe_portals: list[
AsyncContextManager[tuple[str, int]] AsyncContextManager[AddressTypes]
] = list( ] = list(
maybe_open_portal( maybe_open_portal(
addr=addr, addr=addr,
@ -274,7 +276,7 @@ async def find_actor(
@acm @acm
async def wait_for_actor( async def wait_for_actor(
name: str, name: str,
registry_addr: tuple[str, int] | None = None, registry_addr: AddressTypes | None = None,
) -> AsyncGenerator[Portal, None]: ) -> AsyncGenerator[Portal, None]:
''' '''
@ -291,7 +293,7 @@ async def wait_for_actor(
yield peer_portal yield peer_portal
return return
regaddr: tuple[str, int] = ( regaddr: AddressTypes = (
registry_addr registry_addr
or or
actor.reg_addrs[0] actor.reg_addrs[0]
@ -299,8 +301,8 @@ async def wait_for_actor(
# TODO: use `.trionics.gather_contexts()` like # TODO: use `.trionics.gather_contexts()` like
# above in `find_actor()` as well? # above in `find_actor()` as well?
reg_portal: Portal reg_portal: Portal
async with get_registry(*regaddr) as reg_portal: async with get_registry(regaddr) as reg_portal:
sockaddrs = await reg_portal.run_from_ns( addrs = await reg_portal.run_from_ns(
'self', 'self',
'wait_for_actor', 'wait_for_actor',
name=name, name=name,
@ -308,8 +310,8 @@ async def wait_for_actor(
# get latest registered addr by default? # get latest registered addr by default?
# TODO: offer multi-portal yields in multi-homed case? # TODO: offer multi-portal yields in multi-homed case?
sockaddr: tuple[str, int] = sockaddrs[-1] addr: AddressTypes = addrs[-1]
async with _connect_chan(sockaddr) as chan: async with _connect_chan(addr) as chan:
async with open_portal(chan) as portal: async with open_portal(chan) as portal:
yield portal yield portal

View File

@ -37,6 +37,7 @@ from .log import (
from . import _state from . import _state
from .devx import _debug from .devx import _debug
from .to_asyncio import run_as_asyncio_guest from .to_asyncio import run_as_asyncio_guest
from ._addr import AddressTypes
from ._runtime import ( from ._runtime import (
async_main, async_main,
Actor, Actor,
@ -52,10 +53,10 @@ log = get_logger(__name__)
def _mp_main( def _mp_main(
actor: Actor, actor: Actor,
accept_addrs: list[tuple[str, int]], accept_addrs: list[AddressTypes],
forkserver_info: tuple[Any, Any, Any, Any, Any], forkserver_info: tuple[Any, Any, Any, Any, Any],
start_method: SpawnMethodKey, start_method: SpawnMethodKey,
parent_addr: tuple[str, int] | None = None, parent_addr: AddressTypes | None = None,
infect_asyncio: bool = False, infect_asyncio: bool = False,
) -> None: ) -> None:
@ -206,7 +207,7 @@ def nest_from_op(
def _trio_main( def _trio_main(
actor: Actor, actor: Actor,
*, *,
parent_addr: tuple[str, int] | None = None, parent_addr: AddressTypes | None = None,
infect_asyncio: bool = False, infect_asyncio: bool = False,
) -> None: ) -> None:

View File

@ -43,21 +43,18 @@ from .devx import _debug
from . import _spawn from . import _spawn
from . import _state from . import _state
from . import log from . import log
from .ipc import _connect_chan from .ipc import (
_connect_chan,
)
from ._addr import (
AddressTypes,
wrap_address,
preferred_transport,
default_lo_addrs
)
from ._exceptions import is_multi_cancelled from ._exceptions import is_multi_cancelled
# set at startup and after forks
_default_host: str = '127.0.0.1'
_default_port: int = 1616
# default registry always on localhost
_default_lo_addrs: list[tuple[str, int]] = [(
_default_host,
_default_port,
)]
logger = log.get_logger('tractor') logger = log.get_logger('tractor')
@ -66,10 +63,12 @@ async def open_root_actor(
*, *,
# defaults are above # defaults are above
registry_addrs: list[tuple[str, int]]|None = None, registry_addrs: list[AddressTypes]|None = None,
# defaults are above # defaults are above
arbiter_addr: tuple[str, int]|None = None, arbiter_addr: tuple[AddressTypes]|None = None,
enable_transports: list[str] = [preferred_transport],
name: str|None = 'root', name: str|None = 'root',
@ -195,11 +194,9 @@ async def open_root_actor(
) )
registry_addrs = [arbiter_addr] registry_addrs = [arbiter_addr]
registry_addrs: list[tuple[str, int]] = ( if not registry_addrs:
registry_addrs registry_addrs: list[AddressTypes] = default_lo_addrs(enable_transports)
or
_default_lo_addrs
)
assert registry_addrs assert registry_addrs
loglevel = ( loglevel = (
@ -248,10 +245,10 @@ async def open_root_actor(
enable_stack_on_sig() enable_stack_on_sig()
# closed into below ping task-func # closed into below ping task-func
ponged_addrs: list[tuple[str, int]] = [] ponged_addrs: list[AddressTypes] = []
async def ping_tpt_socket( async def ping_tpt_socket(
addr: tuple[str, int], addr: AddressTypes,
timeout: float = 1, timeout: float = 1,
) -> None: ) -> None:
''' '''
@ -284,10 +281,10 @@ async def open_root_actor(
for addr in registry_addrs: for addr in registry_addrs:
tn.start_soon( tn.start_soon(
ping_tpt_socket, ping_tpt_socket,
tuple(addr), # TODO: just drop this requirement? addr,
) )
trans_bind_addrs: list[tuple[str, int]] = [] trans_bind_addrs: list[AddressTypes] = []
# Create a new local root-actor instance which IS NOT THE # Create a new local root-actor instance which IS NOT THE
# REGISTRAR # REGISTRAR
@ -311,9 +308,12 @@ async def open_root_actor(
) )
# DO NOT use the registry_addrs as the transport server # DO NOT use the registry_addrs as the transport server
# addrs for this new non-registar, root-actor. # addrs for this new non-registar, root-actor.
for host, port in ponged_addrs: for addr in ponged_addrs:
# NOTE: zero triggers dynamic OS port allocation waddr = wrap_address(addr)
trans_bind_addrs.append((host, 0)) print(waddr)
trans_bind_addrs.append(
waddr.get_random(namespace=waddr.namespace)
)
# Start this local actor as the "registrar", aka a regular # Start this local actor as the "registrar", aka a regular
# actor who manages the local registry of "mailboxes" of # actor who manages the local registry of "mailboxes" of
@ -322,7 +322,7 @@ async def open_root_actor(
# NOTE that if the current actor IS THE REGISTAR, the # NOTE that if the current actor IS THE REGISTAR, the
# following init steps are taken: # following init steps are taken:
# - the tranport layer server is bound to each (host, port) # - the tranport layer server is bound to each addr
# pair defined in provided registry_addrs, or the default. # pair defined in provided registry_addrs, or the default.
trans_bind_addrs = registry_addrs trans_bind_addrs = registry_addrs
@ -462,7 +462,7 @@ def run_daemon(
# runtime kwargs # runtime kwargs
name: str | None = 'root', name: str | None = 'root',
registry_addrs: list[tuple[str, int]] = _default_lo_addrs, registry_addrs: list[AddressTypes]|None = None,
start_method: str | None = None, start_method: str | None = None,
debug_mode: bool = False, debug_mode: bool = False,

View File

@ -74,6 +74,12 @@ from tractor.msg import (
types as msgtypes, types as msgtypes,
) )
from .ipc import Channel from .ipc import Channel
from ._addr import (
AddressTypes,
Address,
TCPAddress,
wrap_address,
)
from ._context import ( from ._context import (
mk_context, mk_context,
Context, Context,
@ -179,11 +185,11 @@ class Actor:
enable_modules: list[str] = [], enable_modules: list[str] = [],
uid: str|None = None, uid: str|None = None,
loglevel: str|None = None, loglevel: str|None = None,
registry_addrs: list[tuple[str, int]]|None = None, registry_addrs: list[AddressTypes]|None = None,
spawn_method: str|None = None, spawn_method: str|None = None,
# TODO: remove! # TODO: remove!
arbiter_addr: tuple[str, int]|None = None, arbiter_addr: AddressTypes|None = None,
) -> None: ) -> None:
''' '''
@ -223,7 +229,7 @@ class Actor:
DeprecationWarning, DeprecationWarning,
stacklevel=2, stacklevel=2,
) )
registry_addrs: list[tuple[str, int]] = [arbiter_addr] registry_addrs: list[AddressTypes] = [arbiter_addr]
# marked by the process spawning backend at startup # marked by the process spawning backend at startup
# will be None for the parent most process started manually # will be None for the parent most process started manually
@ -257,6 +263,7 @@ class Actor:
] = {} ] = {}
self._listeners: list[trio.abc.Listener] = [] self._listeners: list[trio.abc.Listener] = []
self._listen_addrs: list[Address] = []
self._parent_chan: Channel|None = None self._parent_chan: Channel|None = None
self._forkserver_info: tuple|None = None self._forkserver_info: tuple|None = None
@ -269,13 +276,13 @@ class Actor:
# when provided, init the registry addresses property from # when provided, init the registry addresses property from
# input via the validator. # input via the validator.
self._reg_addrs: list[tuple[str, int]] = [] self._reg_addrs: list[AddressTypes] = []
if registry_addrs: if registry_addrs:
self.reg_addrs: list[tuple[str, int]] = registry_addrs self.reg_addrs: list[AddressTypes] = registry_addrs
_state._runtime_vars['_registry_addrs'] = registry_addrs _state._runtime_vars['_registry_addrs'] = registry_addrs
@property @property
def reg_addrs(self) -> list[tuple[str, int]]: def reg_addrs(self) -> list[AddressTypes]:
''' '''
List of (socket) addresses for all known (and contactable) List of (socket) addresses for all known (and contactable)
registry actors. registry actors.
@ -286,7 +293,7 @@ class Actor:
@reg_addrs.setter @reg_addrs.setter
def reg_addrs( def reg_addrs(
self, self,
addrs: list[tuple[str, int]], addrs: list[AddressTypes],
) -> None: ) -> None:
if not addrs: if not addrs:
log.warning( log.warning(
@ -295,16 +302,7 @@ class Actor:
) )
return return
# always sanity check the input list since it's critical self._reg_addrs = addrs
# that addrs are correct for discovery sys operation.
for addr in addrs:
if not isinstance(addr, tuple):
raise ValueError(
'Expected `Actor.reg_addrs: list[tuple[str, int]]`\n'
f'Got {addrs}'
)
self._reg_addrs = addrs
async def wait_for_peer( async def wait_for_peer(
self, self,
@ -1024,11 +1022,11 @@ class Actor:
async def _from_parent( async def _from_parent(
self, self,
parent_addr: tuple[str, int]|None, parent_addr: AddressTypes|None,
) -> tuple[ ) -> tuple[
Channel, Channel,
list[tuple[str, int]]|None, list[AddressTypes]|None,
]: ]:
''' '''
Bootstrap this local actor's runtime config from its parent by Bootstrap this local actor's runtime config from its parent by
@ -1040,13 +1038,13 @@ class Actor:
# Connect back to the parent actor and conduct initial # Connect back to the parent actor and conduct initial
# handshake. From this point on if we error, we # handshake. From this point on if we error, we
# attempt to ship the exception back to the parent. # attempt to ship the exception back to the parent.
chan = await Channel.from_destaddr(parent_addr) chan = await Channel.from_addr(wrap_address(parent_addr))
# TODO: move this into a `Channel.handshake()`? # TODO: move this into a `Channel.handshake()`?
# Initial handshake: swap names. # Initial handshake: swap names.
await self._do_handshake(chan) await self._do_handshake(chan)
accept_addrs: list[tuple[str, int]]|None = None accept_addrs: list[AddressTypes]|None = None
if self._spawn_method == "trio": if self._spawn_method == "trio":
@ -1063,7 +1061,7 @@ class Actor:
# if "trace"/"util" mode is enabled? # if "trace"/"util" mode is enabled?
f'{pretty_struct.pformat(spawnspec)}\n' f'{pretty_struct.pformat(spawnspec)}\n'
) )
accept_addrs: list[tuple[str, int]] = spawnspec.bind_addrs accept_addrs: list[AddressTypes] = spawnspec.bind_addrs
# TODO: another `Struct` for rtvs.. # TODO: another `Struct` for rtvs..
rvs: dict[str, Any] = spawnspec._runtime_vars rvs: dict[str, Any] = spawnspec._runtime_vars
@ -1170,8 +1168,7 @@ class Actor:
self, self,
handler_nursery: Nursery, handler_nursery: Nursery,
*, *,
# (host, port) to bind for channel server listen_addrs: list[AddressTypes]|None = None,
listen_sockaddrs: list[tuple[str, int]]|None = None,
task_status: TaskStatus[Nursery] = trio.TASK_STATUS_IGNORED, task_status: TaskStatus[Nursery] = trio.TASK_STATUS_IGNORED,
) -> None: ) -> None:
@ -1183,37 +1180,39 @@ class Actor:
`.cancel_server()` is called. `.cancel_server()` is called.
''' '''
if listen_sockaddrs is None: if listen_addrs is None:
listen_sockaddrs = [(None, 0)] listen_addrs = [TCPAddress.get_random()]
else:
listen_addrs: list[Address] = [
wrap_address(a) for a in listen_addrs
]
self._server_down = trio.Event() self._server_down = trio.Event()
try: try:
async with trio.open_nursery() as server_n: async with trio.open_nursery() as server_n:
listeners: list[trio.abc.Listener] = [
await addr.open_listener()
for addr in listen_addrs
]
await server_n.start(
partial(
trio.serve_listeners,
handler=self._stream_handler,
listeners=listeners,
for host, port in listen_sockaddrs: # NOTE: configured such that new
listeners: list[trio.abc.Listener] = await server_n.start( # connections will stay alive even if
partial( # this server is cancelled!
trio.serve_tcp, handler_nursery=handler_nursery
handler=self._stream_handler,
port=port,
host=host,
# NOTE: configured such that new
# connections will stay alive even if
# this server is cancelled!
handler_nursery=handler_nursery,
)
) )
sockets: list[trio.socket] = [ )
getattr(listener, 'socket', 'unknown socket') log.runtime(
for listener in listeners 'Started server(s)\n'
] '\n'.join([f'|_{addr}' for addr in listen_addrs])
log.runtime( )
'Started TCP server(s)\n' self._listen_addrs.extend(listen_addrs)
f'|_{sockets}\n' self._listeners.extend(listeners)
)
self._listeners.extend(listeners)
task_status.started(server_n) task_status.started(server_n)
@ -1576,26 +1575,21 @@ class Actor:
return False return False
@property @property
def accept_addrs(self) -> list[tuple[str, int]]: def accept_addrs(self) -> list[AddressTypes]:
''' '''
All addresses to which the transport-channel server binds All addresses to which the transport-channel server binds
and listens for new connections. and listens for new connections.
''' '''
# throws OSError on failure return [a.unwrap() for a in self._listen_addrs]
return [
listener.socket.getsockname()
for listener in self._listeners
] # type: ignore
@property @property
def accept_addr(self) -> tuple[str, int]: def accept_addr(self) -> AddressTypes:
''' '''
Primary address to which the IPC transport server is Primary address to which the IPC transport server is
bound and listening for new connections. bound and listening for new connections.
''' '''
# throws OSError on failure
return self.accept_addrs[0] return self.accept_addrs[0]
def get_parent(self) -> Portal: def get_parent(self) -> Portal:
@ -1667,7 +1661,7 @@ class Actor:
async def async_main( async def async_main(
actor: Actor, actor: Actor,
accept_addrs: tuple[str, int]|None = None, accept_addrs: AddressTypes|None = None,
# XXX: currently ``parent_addr`` is only needed for the # XXX: currently ``parent_addr`` is only needed for the
# ``multiprocessing`` backend (which pickles state sent to # ``multiprocessing`` backend (which pickles state sent to
@ -1676,7 +1670,7 @@ async def async_main(
# change this to a simple ``is_subactor: bool`` which will # change this to a simple ``is_subactor: bool`` which will
# be False when running as root actor and True when as # be False when running as root actor and True when as
# a subactor. # a subactor.
parent_addr: tuple[str, int]|None = None, parent_addr: AddressTypes|None = None,
task_status: TaskStatus[None] = trio.TASK_STATUS_IGNORED, task_status: TaskStatus[None] = trio.TASK_STATUS_IGNORED,
) -> None: ) -> None:
@ -1766,7 +1760,7 @@ async def async_main(
partial( partial(
actor._serve_forever, actor._serve_forever,
service_nursery, service_nursery,
listen_sockaddrs=accept_addrs, listen_addrs=accept_addrs,
) )
) )
except OSError as oserr: except OSError as oserr:
@ -1782,7 +1776,7 @@ async def async_main(
raise raise
accept_addrs: list[tuple[str, int]] = actor.accept_addrs accept_addrs: list[AddressTypes] = actor.accept_addrs
# NOTE: only set the loopback addr for the # NOTE: only set the loopback addr for the
# process-tree-global "root" mailbox since # process-tree-global "root" mailbox since
@ -1790,9 +1784,8 @@ async def async_main(
# their root actor over that channel. # their root actor over that channel.
if _state._runtime_vars['_is_root']: if _state._runtime_vars['_is_root']:
for addr in accept_addrs: for addr in accept_addrs:
host, _ = addr waddr = wrap_address(addr)
# TODO: generic 'lo' detector predicate if waddr == waddr.get_root():
if '127.0.0.1' in host:
_state._runtime_vars['_root_mailbox'] = addr _state._runtime_vars['_root_mailbox'] = addr
# Register with the arbiter if we're told its addr # Register with the arbiter if we're told its addr
@ -1807,24 +1800,21 @@ async def async_main(
# only on unique actor uids? # only on unique actor uids?
for addr in actor.reg_addrs: for addr in actor.reg_addrs:
try: try:
assert isinstance(addr, tuple) waddr = wrap_address(addr)
assert addr[1] # non-zero after bind assert waddr.is_valid
except AssertionError: except AssertionError:
await _debug.pause() await _debug.pause()
async with get_registry(*addr) as reg_portal: async with get_registry(addr) as reg_portal:
for accept_addr in accept_addrs: for accept_addr in accept_addrs:
accept_addr = wrap_address(accept_addr)
if not accept_addr[1]: assert accept_addr.is_valid
await _debug.pause()
assert accept_addr[1]
await reg_portal.run_from_ns( await reg_portal.run_from_ns(
'self', 'self',
'register_actor', 'register_actor',
uid=actor.uid, uid=actor.uid,
sockaddr=accept_addr, addr=accept_addr.unwrap(),
) )
is_registered: bool = True is_registered: bool = True
@ -1951,12 +1941,13 @@ async def async_main(
): ):
failed: bool = False failed: bool = False
for addr in actor.reg_addrs: for addr in actor.reg_addrs:
assert isinstance(addr, tuple) waddr = wrap_address(addr)
assert waddr.is_valid
with trio.move_on_after(0.5) as cs: with trio.move_on_after(0.5) as cs:
cs.shield = True cs.shield = True
try: try:
async with get_registry( async with get_registry(
*addr, addr,
) as reg_portal: ) as reg_portal:
await reg_portal.run_from_ns( await reg_portal.run_from_ns(
'self', 'self',
@ -2034,7 +2025,7 @@ class Arbiter(Actor):
self._registry: dict[ self._registry: dict[
tuple[str, str], tuple[str, str],
tuple[str, int], AddressTypes,
] = {} ] = {}
self._waiters: dict[ self._waiters: dict[
str, str,
@ -2050,18 +2041,18 @@ class Arbiter(Actor):
self, self,
name: str, name: str,
) -> tuple[str, int]|None: ) -> AddressTypes|None:
for uid, sockaddr in self._registry.items(): for uid, addr in self._registry.items():
if name in uid: if name in uid:
return sockaddr return addr
return None return None
async def get_registry( async def get_registry(
self self
) -> dict[str, tuple[str, int]]: ) -> dict[str, AddressTypes]:
''' '''
Return current name registry. Return current name registry.
@ -2081,7 +2072,7 @@ class Arbiter(Actor):
self, self,
name: str, name: str,
) -> list[tuple[str, int]]: ) -> list[AddressTypes]:
''' '''
Wait for a particular actor to register. Wait for a particular actor to register.
@ -2089,44 +2080,41 @@ class Arbiter(Actor):
registered. registered.
''' '''
sockaddrs: list[tuple[str, int]] = [] addrs: list[AddressTypes] = []
sockaddr: tuple[str, int] addr: AddressTypes
mailbox_info: str = 'Actor registry contact infos:\n' mailbox_info: str = 'Actor registry contact infos:\n'
for uid, sockaddr in self._registry.items(): for uid, addr in self._registry.items():
mailbox_info += ( mailbox_info += (
f'|_uid: {uid}\n' f'|_uid: {uid}\n'
f'|_sockaddr: {sockaddr}\n\n' f'|_addr: {addr}\n\n'
) )
if name == uid[0]: if name == uid[0]:
sockaddrs.append(sockaddr) addrs.append(addr)
if not sockaddrs: if not addrs:
waiter = trio.Event() waiter = trio.Event()
self._waiters.setdefault(name, []).append(waiter) self._waiters.setdefault(name, []).append(waiter)
await waiter.wait() await waiter.wait()
for uid in self._waiters[name]: for uid in self._waiters[name]:
if not isinstance(uid, trio.Event): if not isinstance(uid, trio.Event):
sockaddrs.append(self._registry[uid]) addrs.append(self._registry[uid])
log.runtime(mailbox_info) log.runtime(mailbox_info)
return sockaddrs return addrs
async def register_actor( async def register_actor(
self, self,
uid: tuple[str, str], uid: tuple[str, str],
sockaddr: tuple[str, int] addr: AddressTypes
) -> None: ) -> None:
uid = name, hash = (str(uid[0]), str(uid[1])) uid = name, hash = (str(uid[0]), str(uid[1]))
addr = (host, port) = ( waddr: Address = wrap_address(addr)
str(sockaddr[0]), if not waddr.is_valid:
int(sockaddr[1]), # should never be 0-dynamic-os-alloc
)
if port == 0:
await _debug.pause() await _debug.pause()
assert port # should never be 0-dynamic-os-alloc
self._registry[uid] = addr self._registry[uid] = addr
# pop and signal all waiter events # pop and signal all waiter events

View File

@ -46,6 +46,7 @@ from tractor._state import (
_runtime_vars, _runtime_vars,
) )
from tractor.log import get_logger from tractor.log import get_logger
from tractor._addr import AddressTypes
from tractor._portal import Portal from tractor._portal import Portal
from tractor._runtime import Actor from tractor._runtime import Actor
from tractor._entry import _mp_main from tractor._entry import _mp_main
@ -392,8 +393,8 @@ async def new_proc(
errors: dict[tuple[str, str], Exception], errors: dict[tuple[str, str], Exception],
# passed through to actor main # passed through to actor main
bind_addrs: list[tuple[str, int]], bind_addrs: list[AddressTypes],
parent_addr: tuple[str, int], parent_addr: AddressTypes,
_runtime_vars: dict[str, Any], # serialized and sent to _child _runtime_vars: dict[str, Any], # serialized and sent to _child
*, *,
@ -431,8 +432,8 @@ async def trio_proc(
errors: dict[tuple[str, str], Exception], errors: dict[tuple[str, str], Exception],
# passed through to actor main # passed through to actor main
bind_addrs: list[tuple[str, int]], bind_addrs: list[AddressTypes],
parent_addr: tuple[str, int], parent_addr: AddressTypes,
_runtime_vars: dict[str, Any], # serialized and sent to _child _runtime_vars: dict[str, Any], # serialized and sent to _child
*, *,
infect_asyncio: bool = False, infect_asyncio: bool = False,
@ -520,15 +521,15 @@ async def trio_proc(
# send a "spawning specification" which configures the # send a "spawning specification" which configures the
# initial runtime state of the child. # initial runtime state of the child.
await chan.send( sspec = SpawnSpec(
SpawnSpec( _parent_main_data=subactor._parent_main_data,
_parent_main_data=subactor._parent_main_data, enable_modules=subactor.enable_modules,
enable_modules=subactor.enable_modules, reg_addrs=subactor.reg_addrs,
reg_addrs=subactor.reg_addrs, bind_addrs=bind_addrs,
bind_addrs=bind_addrs, _runtime_vars=_runtime_vars,
_runtime_vars=_runtime_vars,
)
) )
log.runtime(f'Sending spawn spec: {str(sspec)}')
await chan.send(sspec)
# track subactor in current nursery # track subactor in current nursery
curr_actor: Actor = current_actor() curr_actor: Actor = current_actor()
@ -638,8 +639,8 @@ async def mp_proc(
subactor: Actor, subactor: Actor,
errors: dict[tuple[str, str], Exception], errors: dict[tuple[str, str], Exception],
# passed through to actor main # passed through to actor main
bind_addrs: list[tuple[str, int]], bind_addrs: list[AddressTypes],
parent_addr: tuple[str, int], parent_addr: AddressTypes,
_runtime_vars: dict[str, Any], # serialized and sent to _child _runtime_vars: dict[str, Any], # serialized and sent to _child
*, *,
infect_asyncio: bool = False, infect_asyncio: bool = False,

View File

@ -28,7 +28,13 @@ import warnings
import trio import trio
from .devx._debug import maybe_wait_for_debugger from .devx._debug import maybe_wait_for_debugger
from ._addr import (
AddressTypes,
preferred_transport,
get_address_cls
)
from ._state import current_actor, is_main_process from ._state import current_actor, is_main_process
from .log import get_logger, get_loglevel from .log import get_logger, get_loglevel
from ._runtime import Actor from ._runtime import Actor
@ -47,8 +53,6 @@ if TYPE_CHECKING:
log = get_logger(__name__) log = get_logger(__name__)
_default_bind_addr: tuple[str, int] = ('127.0.0.1', 0)
class ActorNursery: class ActorNursery:
''' '''
@ -130,8 +134,9 @@ class ActorNursery:
*, *,
bind_addrs: list[tuple[str, int]] = [_default_bind_addr], bind_addrs: list[AddressTypes]|None = None,
rpc_module_paths: list[str]|None = None, rpc_module_paths: list[str]|None = None,
enable_transports: list[str] = [preferred_transport],
enable_modules: list[str]|None = None, enable_modules: list[str]|None = None,
loglevel: str|None = None, # set log level per subactor loglevel: str|None = None, # set log level per subactor
debug_mode: bool|None = None, debug_mode: bool|None = None,
@ -156,6 +161,12 @@ class ActorNursery:
or get_loglevel() or get_loglevel()
) )
if not bind_addrs:
bind_addrs: list[AddressTypes] = [
get_address_cls(transport).get_random().unwrap()
for transport in enable_transports
]
# configure and pass runtime state # configure and pass runtime state
_rtv = _state._runtime_vars.copy() _rtv = _state._runtime_vars.copy()
_rtv['_is_root'] = False _rtv['_is_root'] = False
@ -224,7 +235,7 @@ class ActorNursery:
*, *,
name: str | None = None, name: str | None = None,
bind_addrs: tuple[str, int] = [_default_bind_addr], bind_addrs: AddressTypes|None = None,
rpc_module_paths: list[str] | None = None, rpc_module_paths: list[str] | None = None,
enable_modules: list[str] | None = None, enable_modules: list[str] | None = None,
loglevel: str | None = None, # set log level per subactor loglevel: str | None = None, # set log level per subactor

View File

@ -17,7 +17,6 @@ import platform
from ._transport import ( from ._transport import (
MsgTransportKey as MsgTransportKey, MsgTransportKey as MsgTransportKey,
AddressType as AddressType,
MsgType as MsgType, MsgType as MsgType,
MsgTransport as MsgTransport, MsgTransport as MsgTransport,
MsgpackTransport as MsgpackTransport MsgpackTransport as MsgpackTransport
@ -27,10 +26,8 @@ from ._tcp import MsgpackTCPStream as MsgpackTCPStream
from ._uds import MsgpackUDSStream as MsgpackUDSStream from ._uds import MsgpackUDSStream as MsgpackUDSStream
from ._types import ( from ._types import (
default_lo_addrs as default_lo_addrs, transport_from_addr as transport_from_addr,
transport_from_destaddr as transport_from_destaddr,
transport_from_stream as transport_from_stream, transport_from_stream as transport_from_stream,
AddressTypes as AddressTypes
) )
from ._chan import ( from ._chan import (

View File

@ -35,8 +35,12 @@ import trio
from tractor.ipc._transport import MsgTransport from tractor.ipc._transport import MsgTransport
from tractor.ipc._types import ( from tractor.ipc._types import (
transport_from_destaddr, transport_from_addr,
transport_from_stream, transport_from_stream,
)
from tractor._addr import (
wrap_address,
Address,
AddressTypes AddressTypes
) )
from tractor.log import get_logger from tractor.log import get_logger
@ -66,7 +70,6 @@ class Channel:
def __init__( def __init__(
self, self,
destaddr: AddressTypes|None = None,
transport: MsgTransport|None = None, transport: MsgTransport|None = None,
# TODO: optional reconnection support? # TODO: optional reconnection support?
# auto_reconnect: bool = False, # auto_reconnect: bool = False,
@ -81,8 +84,6 @@ class Channel:
# user in ``.from_stream()``. # user in ``.from_stream()``.
self._transport: MsgTransport|None = transport self._transport: MsgTransport|None = transport
self._destaddr = destaddr if destaddr else self._transport.raddr
# set after handshake - always uid of far end # set after handshake - always uid of far end
self.uid: tuple[str, str]|None = None self.uid: tuple[str, str]|None = None
@ -121,13 +122,14 @@ class Channel:
) )
@classmethod @classmethod
async def from_destaddr( async def from_addr(
cls, cls,
destaddr: AddressTypes, addr: AddressTypes,
**kwargs **kwargs
) -> Channel: ) -> Channel:
transport_cls = transport_from_destaddr(destaddr) addr: Address = wrap_address(addr)
transport = await transport_cls.connect_to(destaddr, **kwargs) transport_cls = transport_from_addr(addr)
transport = await transport_cls.connect_to(addr, **kwargs)
log.transport( log.transport(
f'Opened channel[{type(transport)}]: {transport.laddr} -> {transport.raddr}' f'Opened channel[{type(transport)}]: {transport.laddr} -> {transport.raddr}'
@ -164,11 +166,11 @@ class Channel:
) )
@property @property
def laddr(self) -> tuple[str, int]|None: def laddr(self) -> Address|None:
return self._transport.laddr if self._transport else None return self._transport.laddr if self._transport else None
@property @property
def raddr(self) -> tuple[str, int]|None: def raddr(self) -> Address|None:
return self._transport.raddr if self._transport else None return self._transport.raddr if self._transport else None
# TODO: something like, # TODO: something like,
@ -205,7 +207,11 @@ class Channel:
# assert err # assert err
__tracebackhide__: bool = False __tracebackhide__: bool = False
else: else:
assert err.cid try:
assert err.cid
except KeyError:
raise err
raise raise
@ -332,14 +338,14 @@ class Channel:
@acm @acm
async def _connect_chan( async def _connect_chan(
destaddr: AddressTypes addr: AddressTypes
) -> typing.AsyncGenerator[Channel, None]: ) -> typing.AsyncGenerator[Channel, None]:
''' '''
Create and connect a channel with disconnect on context manager Create and connect a channel with disconnect on context manager
teardown. teardown.
''' '''
chan = await Channel.from_destaddr(destaddr) chan = await Channel.from_addr(addr)
yield chan yield chan
with trio.CancelScope(shield=True): with trio.CancelScope(shield=True):
await chan.aclose() await chan.aclose()

View File

@ -183,6 +183,9 @@ class RingBuffSender(trio.abc.SendStream):
def wrap_fd(self) -> int: def wrap_fd(self) -> int:
return self._wrap_event.fd return self._wrap_event.fd
async def _wait_wrap(self):
await self._wrap_event.read()
async def send_all(self, data: Buffer): async def send_all(self, data: Buffer):
async with self._send_lock: async with self._send_lock:
# while data is larger than the remaining buf # while data is larger than the remaining buf
@ -193,7 +196,7 @@ class RingBuffSender(trio.abc.SendStream):
self._shm.buf[self.ptr:] = data[:remaining] self._shm.buf[self.ptr:] = data[:remaining]
# signal write and wait for reader wrap around # signal write and wait for reader wrap around
self._write_event.write(remaining) self._write_event.write(remaining)
await self._wrap_event.read() await self._wait_wrap()
# wrap around and trim already written bytes # wrap around and trim already written bytes
self._ptr = 0 self._ptr = 0

View File

@ -23,6 +23,7 @@ import trio
from tractor.msg import MsgCodec from tractor.msg import MsgCodec
from tractor.log import get_logger from tractor.log import get_logger
from tractor._addr import TCPAddress
from tractor.ipc._transport import MsgpackTransport from tractor.ipc._transport import MsgpackTransport
@ -38,9 +39,8 @@ class MsgpackTCPStream(MsgpackTransport):
using the ``msgspec`` codec lib. using the ``msgspec`` codec lib.
''' '''
address_type = tuple[str, int] address_type = TCPAddress
layer_key: int = 4 layer_key: int = 4
name_key: str = 'tcp'
# def __init__( # def __init__(
# self, # self,
@ -55,19 +55,32 @@ class MsgpackTCPStream(MsgpackTransport):
# codec=codec # codec=codec
# ) # )
@property
def maddr(self) -> str:
host, port = self.raddr.unwrap()
return (
f'/ipv4/{host}'
f'/{self.address_type.name_key}/{port}'
# f'/{self.chan.uid[0]}'
# f'/{self.cid}'
# f'/cid={cid_head}..{cid_tail}'
# TODO: ? not use this ^ right ?
)
def connected(self) -> bool: def connected(self) -> bool:
return self.stream.socket.fileno() != -1 return self.stream.socket.fileno() != -1
@classmethod @classmethod
async def connect_to( async def connect_to(
cls, cls,
destaddr: tuple[str, int], destaddr: TCPAddress,
prefix_size: int = 4, prefix_size: int = 4,
codec: MsgCodec|None = None, codec: MsgCodec|None = None,
**kwargs **kwargs
) -> MsgpackTCPStream: ) -> MsgpackTCPStream:
stream = await trio.open_tcp_stream( stream = await trio.open_tcp_stream(
*destaddr, *destaddr.unwrap(),
**kwargs **kwargs
) )
return MsgpackTCPStream( return MsgpackTCPStream(
@ -87,14 +100,6 @@ class MsgpackTCPStream(MsgpackTransport):
lsockname = stream.socket.getsockname() lsockname = stream.socket.getsockname()
rsockname = stream.socket.getpeername() rsockname = stream.socket.getpeername()
return ( return (
tuple(lsockname[:2]), TCPAddress.from_addr(tuple(lsockname[:2])),
tuple(rsockname[:2]), TCPAddress.from_addr(tuple(rsockname[:2])),
) )
@classmethod
def get_random_addr(self) -> tuple[str, int]:
return (None, 0)
@classmethod
def get_root_addr(self) -> tuple[str, int]:
return ('127.0.0.1', 1616)

View File

@ -50,6 +50,7 @@ from tractor.msg import (
types as msgtypes, types as msgtypes,
pretty_struct, pretty_struct,
) )
from tractor._addr import Address
log = get_logger(__name__) log = get_logger(__name__)
@ -62,12 +63,11 @@ MsgTransportKey = tuple[str, str]
# ?TODO? this should be our `Union[*msgtypes.__spec__]` alias now right..? # ?TODO? this should be our `Union[*msgtypes.__spec__]` alias now right..?
# => BLEH, except can't bc prots must inherit typevar or param-spec # => BLEH, except can't bc prots must inherit typevar or param-spec
# vars.. # vars..
AddressType = TypeVar('AddressType')
MsgType = TypeVar('MsgType') MsgType = TypeVar('MsgType')
@runtime_checkable @runtime_checkable
class MsgTransport(Protocol[AddressType, MsgType]): class MsgTransport(Protocol[MsgType]):
# #
# ^-TODO-^ consider using a generic def and indexing with our # ^-TODO-^ consider using a generic def and indexing with our
# eventual msg definition/types? # eventual msg definition/types?
@ -75,10 +75,9 @@ class MsgTransport(Protocol[AddressType, MsgType]):
stream: trio.abc.Stream stream: trio.abc.Stream
drained: list[MsgType] drained: list[MsgType]
address_type: ClassVar[Type[AddressType]]
address_type: ClassVar[Type[Address]]
codec_key: ClassVar[str] codec_key: ClassVar[str]
name_key: ClassVar[str]
# XXX: should this instead be called `.sendall()`? # XXX: should this instead be called `.sendall()`?
async def send(self, msg: MsgType) -> None: async def send(self, msg: MsgType) -> None:
@ -100,20 +99,24 @@ class MsgTransport(Protocol[AddressType, MsgType]):
@classmethod @classmethod
def key(cls) -> MsgTransportKey: def key(cls) -> MsgTransportKey:
return cls.codec_key, cls.name_key return cls.codec_key, cls.address_type.name_key
@property @property
def laddr(self) -> AddressType: def laddr(self) -> Address:
... ...
@property @property
def raddr(self) -> AddressType: def raddr(self) -> Address:
...
@property
def maddr(self) -> str:
... ...
@classmethod @classmethod
async def connect_to( async def connect_to(
cls, cls,
destaddr: AddressType, addr: Address,
**kwargs **kwargs
) -> MsgTransport: ) -> MsgTransport:
... ...
@ -123,8 +126,8 @@ class MsgTransport(Protocol[AddressType, MsgType]):
cls, cls,
stream: trio.abc.Stream stream: trio.abc.Stream
) -> tuple[ ) -> tuple[
AddressType, # local Address, # local
AddressType # remote Address # remote
]: ]:
''' '''
Return the `trio` streaming transport prot's addrs for both Return the `trio` streaming transport prot's addrs for both
@ -133,14 +136,6 @@ class MsgTransport(Protocol[AddressType, MsgType]):
''' '''
... ...
@classmethod
def get_random_addr(self) -> AddressType:
...
@classmethod
def get_root_addr(self) -> AddressType:
...
class MsgpackTransport(MsgTransport): class MsgpackTransport(MsgTransport):
@ -447,9 +442,9 @@ class MsgpackTransport(MsgTransport):
return self._aiter_pkts return self._aiter_pkts
@property @property
def laddr(self) -> AddressType: def laddr(self) -> Address:
return self._laddr return self._laddr
@property @property
def raddr(self) -> AddressType: def raddr(self) -> Address:
return self._raddr return self._raddr

View File

@ -13,49 +13,42 @@
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Type, Union from typing import Type
import trio import trio
import socket import socket
from ._transport import ( from tractor._addr import Address
from tractor.ipc._transport import (
MsgTransportKey, MsgTransportKey,
MsgTransport MsgTransport
) )
from ._tcp import MsgpackTCPStream from tractor.ipc._tcp import MsgpackTCPStream
from ._uds import MsgpackUDSStream from tractor.ipc._uds import MsgpackUDSStream
# manually updated list of all supported msg transport types
_msg_transports = [ _msg_transports = [
MsgpackTCPStream, MsgpackTCPStream,
MsgpackUDSStream MsgpackUDSStream
] ]
# manually updated list of all supported codec+transport types # convert a MsgTransportKey to the corresponding transport type
key_to_transport: dict[MsgTransportKey, Type[MsgTransport]] = { _key_to_transport: dict[MsgTransportKey, Type[MsgTransport]] = {
cls.key(): cls cls.key(): cls
for cls in _msg_transports for cls in _msg_transports
} }
# convert an Address wrapper to its corresponding transport type
# all different address py types we use _addr_to_transport: dict[Type[Address], Type[MsgTransport]] = {
AddressTypes = Union[ cls.address_type: cls
tuple([
cls.address_type
for cls in _msg_transports
])
]
default_lo_addrs: dict[MsgTransportKey, AddressTypes] = {
cls.key(): cls.get_root_addr()
for cls in _msg_transports for cls in _msg_transports
} }
def transport_from_destaddr( def transport_from_addr(
destaddr: AddressTypes, addr: Address,
codec_key: str = 'msgpack', codec_key: str = 'msgpack',
) -> Type[MsgTransport]: ) -> Type[MsgTransport]:
''' '''
@ -63,23 +56,13 @@ def transport_from_destaddr(
corresponding `MsgTransport` type. corresponding `MsgTransport` type.
''' '''
match destaddr: try:
case str(): return _addr_to_transport[type(addr)]
return MsgpackUDSStream
case tuple(): except KeyError:
if ( raise NotImplementedError(
len(destaddr) == 2 f'No known transport for address {repr(addr)}'
and )
isinstance(destaddr[0], str)
and
isinstance(destaddr[1], int)
):
return MsgpackTCPStream
raise NotImplementedError(
f'No known transport for address {destaddr}'
)
def transport_from_stream( def transport_from_stream(
@ -113,4 +96,4 @@ def transport_from_stream(
key = (codec_key, transport) key = (codec_key, transport)
return _msg_transports[key] return _key_to_transport[key]

View File

@ -18,13 +18,12 @@ Unix Domain Socket implementation of tractor.ipc._transport.MsgTransport protoco
''' '''
from __future__ import annotations from __future__ import annotations
import tempfile
from uuid import uuid4
import trio import trio
from tractor.msg import MsgCodec from tractor.msg import MsgCodec
from tractor.log import get_logger from tractor.log import get_logger
from tractor._addr import UDSAddress
from tractor.ipc._transport import MsgpackTransport from tractor.ipc._transport import MsgpackTransport
@ -37,9 +36,8 @@ class MsgpackUDSStream(MsgpackTransport):
using the ``msgspec`` codec lib. using the ``msgspec`` codec lib.
''' '''
address_type = str address_type = UDSAddress
layer_key: int = 7 layer_key: int = 7
name_key: str = 'uds'
# def __init__( # def __init__(
# self, # self,
@ -54,19 +52,32 @@ class MsgpackUDSStream(MsgpackTransport):
# codec=codec # codec=codec
# ) # )
@property
def maddr(self) -> str:
filepath = self.raddr.unwrap()
return (
f'/ipv4/localhost'
f'/{self.address_type.name_key}/{filepath}'
# f'/{self.chan.uid[0]}'
# f'/{self.cid}'
# f'/cid={cid_head}..{cid_tail}'
# TODO: ? not use this ^ right ?
)
def connected(self) -> bool: def connected(self) -> bool:
return self.stream.socket.fileno() != -1 return self.stream.socket.fileno() != -1
@classmethod @classmethod
async def connect_to( async def connect_to(
cls, cls,
filename: str, addr: UDSAddress,
prefix_size: int = 4, prefix_size: int = 4,
codec: MsgCodec|None = None, codec: MsgCodec|None = None,
**kwargs **kwargs
) -> MsgpackUDSStream: ) -> MsgpackUDSStream:
stream = await trio.open_unix_socket( stream = await trio.open_unix_socket(
filename, addr.unwrap(),
**kwargs **kwargs
) )
return MsgpackUDSStream( return MsgpackUDSStream(
@ -79,16 +90,8 @@ class MsgpackUDSStream(MsgpackTransport):
def get_stream_addrs( def get_stream_addrs(
cls, cls,
stream: trio.SocketStream stream: trio.SocketStream
) -> tuple[str, str]: ) -> tuple[UDSAddress, UDSAddress]:
return ( return (
stream.socket.getsockname(), UDSAddress.from_addr(stream.socket.getsockname()),
stream.socket.getpeername(), UDSAddress.from_addr(stream.socket.getsockname()),
) )
@classmethod
def get_random_addr(self) -> str:
return f'{tempfile.gettempdir()}/{uuid4()}.sock'
@classmethod
def get_root_addr(self) -> str:
return 'tractor.sock'

View File

@ -46,8 +46,8 @@ from msgspec import (
from tractor.msg import ( from tractor.msg import (
pretty_struct, pretty_struct,
) )
from tractor.ipc import AddressTypes
from tractor.log import get_logger from tractor.log import get_logger
from tractor._addr import AddressTypes
log = get_logger('tractor.msgspec') log = get_logger('tractor.msgspec')