Add root and random addr getters on MsgTransport type

Guillermo Rodriguez 2025-03-22 16:17:50 -03:00
parent 2907719cbe
commit 7400f89753
No known key found for this signature in database
GPG Key ID: 002CC5F1E6BDA53E
6 changed files with 63 additions and 12 deletions

View File

@ -16,6 +16,7 @@
import platform
from ._transport import (
MsgTransportKey as MsgTransportKey,
AddressType as AddressType,
MsgType as MsgType,
MsgTransport as MsgTransport,
@ -26,6 +27,7 @@ from ._tcp import MsgpackTCPStream as MsgpackTCPStream
from ._uds import MsgpackUDSStream as MsgpackUDSStream
from ._types import (
default_lo_addrs as default_lo_addrs,
transport_from_destaddr as transport_from_destaddr,
transport_from_stream as transport_from_stream,
AddressTypes as AddressTypes

View File

@ -90,3 +90,11 @@ class MsgpackTCPStream(MsgpackTransport):
tuple(lsockname[:2]),
tuple(rsockname[:2]),
)
@classmethod
def get_random_addr(self) -> tuple[str, int]:
return (None, 0)
@classmethod
def get_root_addr(self) -> tuple[str, int]:
return ('127.0.0.1', 1616)

View File

@ -31,10 +31,6 @@ from collections.abc import (
AsyncIterator,
)
import struct
from typing import (
Any,
Callable,
)
import trio
import msgspec
@ -58,6 +54,10 @@ from tractor.msg import (
log = get_logger(__name__)
# (codec, transport)
MsgTransportKey = tuple[str, str]
# from tractor.msg.types import MsgType
# ?TODO? this should be our `Union[*msgtypes.__spec__]` alias now right..?
# => BLEH, except can't bc prots must inherit typevar or param-spec
@ -77,6 +77,9 @@ class MsgTransport(Protocol[AddressType, MsgType]):
drained: list[MsgType]
address_type: ClassVar[Type[AddressType]]
codec_key: ClassVar[str]
name_key: ClassVar[str]
# XXX: should this instead be called `.sendall()`?
async def send(self, msg: MsgType) -> None:
...
@ -95,6 +98,10 @@ class MsgTransport(Protocol[AddressType, MsgType]):
def drain(self) -> AsyncIterator[dict]:
...
@classmethod
def key(cls) -> MsgTransportKey:
return cls.codec_key, cls.name_key
@property
def laddr(self) -> AddressType:
...
@ -126,6 +133,14 @@ class MsgTransport(Protocol[AddressType, MsgType]):
'''
...
@classmethod
def get_random_addr(self) -> AddressType:
...
@classmethod
def get_root_addr(self) -> AddressType:
...
class MsgpackTransport(MsgTransport):
@ -411,7 +426,7 @@ class MsgpackTransport(MsgTransport):
# __tracebackhide__: bool = False
# raise
async def recv(self) -> Any:
async def recv(self) -> msgtypes.MsgType:
return await self._aiter_pkts.asend(None)
async def drain(self) -> AsyncIterator[dict]:

View File

@ -18,15 +18,24 @@ from typing import Type, Union
import trio
import socket
from ._transport import MsgTransport
from ._transport import (
MsgTransportKey,
MsgTransport
)
from ._tcp import MsgpackTCPStream
from ._uds import MsgpackUDSStream
_msg_transports = [
MsgpackTCPStream,
MsgpackUDSStream
]
# manually updated list of all supported codec+transport types
_msg_transports = {
('msgpack', 'tcp'): MsgpackTCPStream,
('msgpack', 'uds'): MsgpackUDSStream
key_to_transport: dict[MsgTransportKey, Type[MsgTransport]] = {
cls.key(): cls
for cls in _msg_transports
}
@ -34,11 +43,17 @@ _msg_transports = {
AddressTypes = Union[
tuple([
cls.address_type
for key, cls in _msg_transports.items()
for cls in _msg_transports
])
]
default_lo_addrs: dict[MsgTransportKey, AddressTypes] = {
cls.key(): cls.get_root_addr()
for cls in _msg_transports
}
def transport_from_destaddr(
destaddr: AddressTypes,
codec_key: str = 'msgpack',

View File

@ -18,6 +18,8 @@ Unix Domain Socket implementation of tractor.ipc._transport.MsgTransport protoco
'''
from __future__ import annotations
import tempfile
from uuid import uuid4
import trio
@ -82,3 +84,11 @@ class MsgpackUDSStream(MsgpackTransport):
stream.socket.getsockname(),
stream.socket.getpeername(),
)
@classmethod
def get_random_addr(self) -> str:
return f'{tempfile.gettempdir()}/{uuid4()}.sock'
@classmethod
def get_root_addr(self) -> str:
return 'tractor.sock'

View File

@ -46,6 +46,7 @@ from msgspec import (
from tractor.msg import (
pretty_struct,
)
from tractor.ipc import AddressTypes
from tractor.log import get_logger
@ -167,8 +168,8 @@ class SpawnSpec(
# TODO: not just sockaddr pairs?
# -[ ] abstract into a `TransportAddr` type?
reg_addrs: list[tuple[str, int]]
bind_addrs: list[tuple[str, int]]
reg_addrs: list[AddressTypes]
bind_addrs: list[AddressTypes]
# TODO: caps based RPC support in the payload?