Add root and random addr getters on MsgTransport type
parent
2907719cbe
commit
7400f89753
|
@ -16,6 +16,7 @@
|
||||||
import platform
|
import platform
|
||||||
|
|
||||||
from ._transport import (
|
from ._transport import (
|
||||||
|
MsgTransportKey as MsgTransportKey,
|
||||||
AddressType as AddressType,
|
AddressType as AddressType,
|
||||||
MsgType as MsgType,
|
MsgType as MsgType,
|
||||||
MsgTransport as MsgTransport,
|
MsgTransport as MsgTransport,
|
||||||
|
@ -26,6 +27,7 @@ from ._tcp import MsgpackTCPStream as MsgpackTCPStream
|
||||||
from ._uds import MsgpackUDSStream as MsgpackUDSStream
|
from ._uds import MsgpackUDSStream as MsgpackUDSStream
|
||||||
|
|
||||||
from ._types import (
|
from ._types import (
|
||||||
|
default_lo_addrs as default_lo_addrs,
|
||||||
transport_from_destaddr as transport_from_destaddr,
|
transport_from_destaddr as transport_from_destaddr,
|
||||||
transport_from_stream as transport_from_stream,
|
transport_from_stream as transport_from_stream,
|
||||||
AddressTypes as AddressTypes
|
AddressTypes as AddressTypes
|
||||||
|
|
|
@ -90,3 +90,11 @@ class MsgpackTCPStream(MsgpackTransport):
|
||||||
tuple(lsockname[:2]),
|
tuple(lsockname[:2]),
|
||||||
tuple(rsockname[:2]),
|
tuple(rsockname[:2]),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_random_addr(self) -> tuple[str, int]:
|
||||||
|
return (None, 0)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_root_addr(self) -> tuple[str, int]:
|
||||||
|
return ('127.0.0.1', 1616)
|
||||||
|
|
|
@ -31,10 +31,6 @@ from collections.abc import (
|
||||||
AsyncIterator,
|
AsyncIterator,
|
||||||
)
|
)
|
||||||
import struct
|
import struct
|
||||||
from typing import (
|
|
||||||
Any,
|
|
||||||
Callable,
|
|
||||||
)
|
|
||||||
|
|
||||||
import trio
|
import trio
|
||||||
import msgspec
|
import msgspec
|
||||||
|
@ -58,6 +54,10 @@ from tractor.msg import (
|
||||||
log = get_logger(__name__)
|
log = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# (codec, transport)
|
||||||
|
MsgTransportKey = tuple[str, str]
|
||||||
|
|
||||||
|
|
||||||
# from tractor.msg.types import MsgType
|
# from tractor.msg.types import MsgType
|
||||||
# ?TODO? this should be our `Union[*msgtypes.__spec__]` alias now right..?
|
# ?TODO? this should be our `Union[*msgtypes.__spec__]` alias now right..?
|
||||||
# => BLEH, except can't bc prots must inherit typevar or param-spec
|
# => BLEH, except can't bc prots must inherit typevar or param-spec
|
||||||
|
@ -77,6 +77,9 @@ class MsgTransport(Protocol[AddressType, MsgType]):
|
||||||
drained: list[MsgType]
|
drained: list[MsgType]
|
||||||
address_type: ClassVar[Type[AddressType]]
|
address_type: ClassVar[Type[AddressType]]
|
||||||
|
|
||||||
|
codec_key: ClassVar[str]
|
||||||
|
name_key: ClassVar[str]
|
||||||
|
|
||||||
# XXX: should this instead be called `.sendall()`?
|
# XXX: should this instead be called `.sendall()`?
|
||||||
async def send(self, msg: MsgType) -> None:
|
async def send(self, msg: MsgType) -> None:
|
||||||
...
|
...
|
||||||
|
@ -95,6 +98,10 @@ class MsgTransport(Protocol[AddressType, MsgType]):
|
||||||
def drain(self) -> AsyncIterator[dict]:
|
def drain(self) -> AsyncIterator[dict]:
|
||||||
...
|
...
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def key(cls) -> MsgTransportKey:
|
||||||
|
return cls.codec_key, cls.name_key
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def laddr(self) -> AddressType:
|
def laddr(self) -> AddressType:
|
||||||
...
|
...
|
||||||
|
@ -126,6 +133,14 @@ class MsgTransport(Protocol[AddressType, MsgType]):
|
||||||
'''
|
'''
|
||||||
...
|
...
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_random_addr(self) -> AddressType:
|
||||||
|
...
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_root_addr(self) -> AddressType:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
class MsgpackTransport(MsgTransport):
|
class MsgpackTransport(MsgTransport):
|
||||||
|
|
||||||
|
@ -411,7 +426,7 @@ class MsgpackTransport(MsgTransport):
|
||||||
# __tracebackhide__: bool = False
|
# __tracebackhide__: bool = False
|
||||||
# raise
|
# raise
|
||||||
|
|
||||||
async def recv(self) -> Any:
|
async def recv(self) -> msgtypes.MsgType:
|
||||||
return await self._aiter_pkts.asend(None)
|
return await self._aiter_pkts.asend(None)
|
||||||
|
|
||||||
async def drain(self) -> AsyncIterator[dict]:
|
async def drain(self) -> AsyncIterator[dict]:
|
||||||
|
|
|
@ -18,15 +18,24 @@ from typing import Type, Union
|
||||||
import trio
|
import trio
|
||||||
import socket
|
import socket
|
||||||
|
|
||||||
from ._transport import MsgTransport
|
from ._transport import (
|
||||||
|
MsgTransportKey,
|
||||||
|
MsgTransport
|
||||||
|
)
|
||||||
from ._tcp import MsgpackTCPStream
|
from ._tcp import MsgpackTCPStream
|
||||||
from ._uds import MsgpackUDSStream
|
from ._uds import MsgpackUDSStream
|
||||||
|
|
||||||
|
|
||||||
|
_msg_transports = [
|
||||||
|
MsgpackTCPStream,
|
||||||
|
MsgpackUDSStream
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
# manually updated list of all supported codec+transport types
|
# manually updated list of all supported codec+transport types
|
||||||
_msg_transports = {
|
key_to_transport: dict[MsgTransportKey, Type[MsgTransport]] = {
|
||||||
('msgpack', 'tcp'): MsgpackTCPStream,
|
cls.key(): cls
|
||||||
('msgpack', 'uds'): MsgpackUDSStream
|
for cls in _msg_transports
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -34,11 +43,17 @@ _msg_transports = {
|
||||||
AddressTypes = Union[
|
AddressTypes = Union[
|
||||||
tuple([
|
tuple([
|
||||||
cls.address_type
|
cls.address_type
|
||||||
for key, cls in _msg_transports.items()
|
for cls in _msg_transports
|
||||||
])
|
])
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
default_lo_addrs: dict[MsgTransportKey, AddressTypes] = {
|
||||||
|
cls.key(): cls.get_root_addr()
|
||||||
|
for cls in _msg_transports
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def transport_from_destaddr(
|
def transport_from_destaddr(
|
||||||
destaddr: AddressTypes,
|
destaddr: AddressTypes,
|
||||||
codec_key: str = 'msgpack',
|
codec_key: str = 'msgpack',
|
||||||
|
|
|
@ -18,6 +18,8 @@ Unix Domain Socket implementation of tractor.ipc._transport.MsgTransport protoco
|
||||||
|
|
||||||
'''
|
'''
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
import tempfile
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
import trio
|
import trio
|
||||||
|
|
||||||
|
@ -82,3 +84,11 @@ class MsgpackUDSStream(MsgpackTransport):
|
||||||
stream.socket.getsockname(),
|
stream.socket.getsockname(),
|
||||||
stream.socket.getpeername(),
|
stream.socket.getpeername(),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_random_addr(self) -> str:
|
||||||
|
return f'{tempfile.gettempdir()}/{uuid4()}.sock'
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_root_addr(self) -> str:
|
||||||
|
return 'tractor.sock'
|
||||||
|
|
|
@ -46,6 +46,7 @@ from msgspec import (
|
||||||
from tractor.msg import (
|
from tractor.msg import (
|
||||||
pretty_struct,
|
pretty_struct,
|
||||||
)
|
)
|
||||||
|
from tractor.ipc import AddressTypes
|
||||||
from tractor.log import get_logger
|
from tractor.log import get_logger
|
||||||
|
|
||||||
|
|
||||||
|
@ -167,8 +168,8 @@ class SpawnSpec(
|
||||||
|
|
||||||
# TODO: not just sockaddr pairs?
|
# TODO: not just sockaddr pairs?
|
||||||
# -[ ] abstract into a `TransportAddr` type?
|
# -[ ] abstract into a `TransportAddr` type?
|
||||||
reg_addrs: list[tuple[str, int]]
|
reg_addrs: list[AddressTypes]
|
||||||
bind_addrs: list[tuple[str, int]]
|
bind_addrs: list[AddressTypes]
|
||||||
|
|
||||||
|
|
||||||
# TODO: caps based RPC support in the payload?
|
# TODO: caps based RPC support in the payload?
|
||||||
|
|
Loading…
Reference in New Issue