Add root and random addr getters on MsgTransport type
parent
2907719cbe
commit
7400f89753
|
@ -16,6 +16,7 @@
|
|||
import platform
|
||||
|
||||
from ._transport import (
|
||||
MsgTransportKey as MsgTransportKey,
|
||||
AddressType as AddressType,
|
||||
MsgType as MsgType,
|
||||
MsgTransport as MsgTransport,
|
||||
|
@ -26,6 +27,7 @@ from ._tcp import MsgpackTCPStream as MsgpackTCPStream
|
|||
from ._uds import MsgpackUDSStream as MsgpackUDSStream
|
||||
|
||||
from ._types import (
|
||||
default_lo_addrs as default_lo_addrs,
|
||||
transport_from_destaddr as transport_from_destaddr,
|
||||
transport_from_stream as transport_from_stream,
|
||||
AddressTypes as AddressTypes
|
||||
|
|
|
@ -90,3 +90,11 @@ class MsgpackTCPStream(MsgpackTransport):
|
|||
tuple(lsockname[:2]),
|
||||
tuple(rsockname[:2]),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_random_addr(self) -> tuple[str, int]:
|
||||
return (None, 0)
|
||||
|
||||
@classmethod
|
||||
def get_root_addr(self) -> tuple[str, int]:
|
||||
return ('127.0.0.1', 1616)
|
||||
|
|
|
@ -31,10 +31,6 @@ from collections.abc import (
|
|||
AsyncIterator,
|
||||
)
|
||||
import struct
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
)
|
||||
|
||||
import trio
|
||||
import msgspec
|
||||
|
@ -58,6 +54,10 @@ from tractor.msg import (
|
|||
log = get_logger(__name__)
|
||||
|
||||
|
||||
# (codec, transport)
|
||||
MsgTransportKey = tuple[str, str]
|
||||
|
||||
|
||||
# from tractor.msg.types import MsgType
|
||||
# ?TODO? this should be our `Union[*msgtypes.__spec__]` alias now right..?
|
||||
# => BLEH, except can't bc prots must inherit typevar or param-spec
|
||||
|
@ -77,6 +77,9 @@ class MsgTransport(Protocol[AddressType, MsgType]):
|
|||
drained: list[MsgType]
|
||||
address_type: ClassVar[Type[AddressType]]
|
||||
|
||||
codec_key: ClassVar[str]
|
||||
name_key: ClassVar[str]
|
||||
|
||||
# XXX: should this instead be called `.sendall()`?
|
||||
async def send(self, msg: MsgType) -> None:
|
||||
...
|
||||
|
@ -95,6 +98,10 @@ class MsgTransport(Protocol[AddressType, MsgType]):
|
|||
def drain(self) -> AsyncIterator[dict]:
|
||||
...
|
||||
|
||||
@classmethod
|
||||
def key(cls) -> MsgTransportKey:
|
||||
return cls.codec_key, cls.name_key
|
||||
|
||||
@property
|
||||
def laddr(self) -> AddressType:
|
||||
...
|
||||
|
@ -126,6 +133,14 @@ class MsgTransport(Protocol[AddressType, MsgType]):
|
|||
'''
|
||||
...
|
||||
|
||||
@classmethod
|
||||
def get_random_addr(self) -> AddressType:
|
||||
...
|
||||
|
||||
@classmethod
|
||||
def get_root_addr(self) -> AddressType:
|
||||
...
|
||||
|
||||
|
||||
class MsgpackTransport(MsgTransport):
|
||||
|
||||
|
@ -411,7 +426,7 @@ class MsgpackTransport(MsgTransport):
|
|||
# __tracebackhide__: bool = False
|
||||
# raise
|
||||
|
||||
async def recv(self) -> Any:
|
||||
async def recv(self) -> msgtypes.MsgType:
|
||||
return await self._aiter_pkts.asend(None)
|
||||
|
||||
async def drain(self) -> AsyncIterator[dict]:
|
||||
|
|
|
@ -18,15 +18,24 @@ from typing import Type, Union
|
|||
import trio
|
||||
import socket
|
||||
|
||||
from ._transport import MsgTransport
|
||||
from ._transport import (
|
||||
MsgTransportKey,
|
||||
MsgTransport
|
||||
)
|
||||
from ._tcp import MsgpackTCPStream
|
||||
from ._uds import MsgpackUDSStream
|
||||
|
||||
|
||||
_msg_transports = [
|
||||
MsgpackTCPStream,
|
||||
MsgpackUDSStream
|
||||
]
|
||||
|
||||
|
||||
# manually updated list of all supported codec+transport types
|
||||
_msg_transports = {
|
||||
('msgpack', 'tcp'): MsgpackTCPStream,
|
||||
('msgpack', 'uds'): MsgpackUDSStream
|
||||
key_to_transport: dict[MsgTransportKey, Type[MsgTransport]] = {
|
||||
cls.key(): cls
|
||||
for cls in _msg_transports
|
||||
}
|
||||
|
||||
|
||||
|
@ -34,11 +43,17 @@ _msg_transports = {
|
|||
AddressTypes = Union[
|
||||
tuple([
|
||||
cls.address_type
|
||||
for key, cls in _msg_transports.items()
|
||||
for cls in _msg_transports
|
||||
])
|
||||
]
|
||||
|
||||
|
||||
default_lo_addrs: dict[MsgTransportKey, AddressTypes] = {
|
||||
cls.key(): cls.get_root_addr()
|
||||
for cls in _msg_transports
|
||||
}
|
||||
|
||||
|
||||
def transport_from_destaddr(
|
||||
destaddr: AddressTypes,
|
||||
codec_key: str = 'msgpack',
|
||||
|
|
|
@ -18,6 +18,8 @@ Unix Domain Socket implementation of tractor.ipc._transport.MsgTransport protoco
|
|||
|
||||
'''
|
||||
from __future__ import annotations
|
||||
import tempfile
|
||||
from uuid import uuid4
|
||||
|
||||
import trio
|
||||
|
||||
|
@ -82,3 +84,11 @@ class MsgpackUDSStream(MsgpackTransport):
|
|||
stream.socket.getsockname(),
|
||||
stream.socket.getpeername(),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_random_addr(self) -> str:
|
||||
return f'{tempfile.gettempdir()}/{uuid4()}.sock'
|
||||
|
||||
@classmethod
|
||||
def get_root_addr(self) -> str:
|
||||
return 'tractor.sock'
|
||||
|
|
|
@ -46,6 +46,7 @@ from msgspec import (
|
|||
from tractor.msg import (
|
||||
pretty_struct,
|
||||
)
|
||||
from tractor.ipc import AddressTypes
|
||||
from tractor.log import get_logger
|
||||
|
||||
|
||||
|
@ -167,8 +168,8 @@ class SpawnSpec(
|
|||
|
||||
# TODO: not just sockaddr pairs?
|
||||
# -[ ] abstract into a `TransportAddr` type?
|
||||
reg_addrs: list[tuple[str, int]]
|
||||
bind_addrs: list[tuple[str, int]]
|
||||
reg_addrs: list[AddressTypes]
|
||||
bind_addrs: list[AddressTypes]
|
||||
|
||||
|
||||
# TODO: caps based RPC support in the payload?
|
||||
|
|
Loading…
Reference in New Issue