224 lines
5.6 KiB
Python
224 lines
5.6 KiB
Python
# tractor: structured concurrent "actors".
|
|
# Copyright 2018-eternity Tyler Goodlet.
|
|
|
|
# This program is free software: you can redistribute it and/or modify
|
|
# it under the terms of the GNU Affero General Public License as published by
|
|
# the Free Software Foundation, either version 3 of the License, or
|
|
# (at your option) any later version.
|
|
|
|
# This program is distributed in the hope that it will be useful,
|
|
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
# GNU Affero General Public License for more details.
|
|
|
|
# You should have received a copy of the GNU Affero General Public License
|
|
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
|
'''
|
|
TCP implementation of tractor.ipc._transport.MsgTransport protocol
|
|
|
|
'''
|
|
from __future__ import annotations
|
|
from typing import (
|
|
ClassVar,
|
|
)
|
|
# from contextlib import (
|
|
# asynccontextmanager as acm,
|
|
# )
|
|
|
|
import msgspec
|
|
import trio
|
|
from trio import (
|
|
SocketListener,
|
|
open_tcp_listeners,
|
|
)
|
|
|
|
from tractor.msg import MsgCodec
|
|
from tractor.log import get_logger
|
|
from tractor.ipc._transport import (
|
|
MsgTransport,
|
|
MsgpackTransport,
|
|
)
|
|
|
|
|
|
# module-level logger for this TCP transport, named after this module
log = get_logger(__name__)
|
|
|
|
|
|
class TCPAddress(
|
|
msgspec.Struct,
|
|
frozen=True,
|
|
):
|
|
_host: str
|
|
_port: int
|
|
|
|
proto_key: ClassVar[str] = 'tcp'
|
|
unwrapped_type: ClassVar[type] = tuple[str, int]
|
|
def_bindspace: ClassVar[str] = '127.0.0.1'
|
|
|
|
@property
|
|
def is_valid(self) -> bool:
|
|
return self._port != 0
|
|
|
|
@property
|
|
def bindspace(self) -> str:
|
|
return self._host
|
|
|
|
@property
|
|
def domain(self) -> str:
|
|
return self._host
|
|
|
|
@classmethod
|
|
def from_addr(
|
|
cls,
|
|
addr: tuple[str, int]
|
|
) -> TCPAddress:
|
|
match addr:
|
|
case (str(), int()):
|
|
return TCPAddress(addr[0], addr[1])
|
|
case _:
|
|
raise ValueError(
|
|
f'Invalid unwrapped address for {cls}\n'
|
|
f'{addr}\n'
|
|
)
|
|
|
|
def unwrap(self) -> tuple[str, int]:
|
|
return (
|
|
self._host,
|
|
self._port,
|
|
)
|
|
|
|
@classmethod
|
|
def get_random(
|
|
cls,
|
|
bindspace: str = def_bindspace,
|
|
) -> TCPAddress:
|
|
return TCPAddress(bindspace, 0)
|
|
|
|
@classmethod
|
|
def get_root(cls) -> TCPAddress:
|
|
return TCPAddress(
|
|
'127.0.0.1',
|
|
1616,
|
|
)
|
|
|
|
def __repr__(self) -> str:
|
|
return (
|
|
f'{type(self).__name__}[{self.unwrap()}]'
|
|
)
|
|
|
|
@classmethod
|
|
def get_transport(
|
|
cls,
|
|
codec: str = 'msgpack',
|
|
) -> MsgTransport:
|
|
match codec:
|
|
case 'msgspack':
|
|
return MsgpackTCPStream
|
|
case _:
|
|
raise ValueError(
|
|
f'No IPC transport with {codec!r} supported !'
|
|
)
|
|
|
|
|
|
async def start_listener(
    addr: TCPAddress,
    **kwargs,
) -> SocketListener:
    '''
    Start a TCP socket listener on the given `TCPAddress`.

    Extra `kwargs` are passed through to `trio.open_tcp_listeners()`.

    Raises `ValueError` (after closing every socket that was opened)
    if `addr` does not resolve to exactly one listener.

    '''
    log.info(
        f'Attempting to bind TCP socket\n'
        f'>[\n'
        f'|_{addr}\n'
    )
    # ?TODO, maybe we should just change the lower-level call this is
    # using internally per-listener?
    listeners: list[SocketListener] = await open_tcp_listeners(
        host=addr._host,
        port=addr._port,
        **kwargs
    )
    # NOTE, for now we don't expect non-singleton-resolving
    # domain-addresses/multi-homed-hosts.
    # (though it is supported by `open_tcp_listeners()`)
    #
    # An explicit check (instead of `assert`, which is stripped under
    # `python -O`) so that every opened socket is closed rather than
    # leaked. The synchronous `.socket.close()` is used so cancellation
    # cannot interrupt the cleanup half-way.
    if len(listeners) != 1:
        for lstnr in listeners:
            lstnr.socket.close()
        raise ValueError(
            f'Expected exactly 1 TCP listener for {addr}, '
            f'got {len(listeners)}\n'
        )

    listener = listeners[0]

    # The actually bound address; this differs from `addr` when,
    # e.g., port `0` was requested.
    host, port = listener.socket.getsockname()[:2]

    log.info(
        f'Listening on TCP socket\n'
        f'[>\n'
        f' |_{TCPAddress(host, port)}\n'
    )
    return listener
|
|
|
|
|
|
# TODO: typing oddity.. not sure why we have to inherit here, but it
|
|
# seems to be an issue with `get_msg_transport()` returning
|
|
# a `Type[Protocol]`; probably should make a `mypy` issue?
|
|
class MsgpackTCPStream(MsgpackTransport):
    '''
    A ``trio.SocketStream`` delivering ``msgpack`` formatted data
    using the ``msgspec`` codec lib.

    '''
    address_type = TCPAddress
    # presumably the OSI layer number (transport layer = 4) -- confirm
    layer_key: int = 4

    @property
    def maddr(self) -> str:
        '''
        Multiaddr-style string for the remote end of this stream,
        e.g. ``/ipv4/127.0.0.1/tcp/1616``.

        '''
        host, port = self.raddr.unwrap()
        return (
            # TODO, use `ipaddress` from stdlib to handle
            # first detecting which of `ipv4/6` before
            # choosing the routing prefix part.
            f'/ipv4/{host}'

            f'/{self.address_type.proto_key}/{port}'
            # f'/{self.chan.uid[0]}'
            # f'/{self.cid}'

            # f'/cid={cid_head}..{cid_tail}'
            # TODO: ? not use this ^ right ?
        )

    def connected(self) -> bool:
        '''
        Whether the underlying socket still has a valid file
        descriptor (i.e. its `fileno()` is not `-1`).

        '''
        return self.stream.socket.fileno() != -1

    @classmethod
    async def connect_to(
        cls,
        destaddr: TCPAddress,
        prefix_size: int = 4,
        codec: MsgCodec|None = None,
        **kwargs
    ) -> MsgpackTCPStream:
        '''
        Open a TCP connection to `destaddr` and wrap it in a new
        transport instance.

        Extra `kwargs` are passed through to `trio.open_tcp_stream()`.

        '''
        stream = await trio.open_tcp_stream(
            *destaddr.unwrap(),
            **kwargs
        )
        # use `cls` so subclasses get an instance of their own type
        return cls(
            stream,
            prefix_size=prefix_size,
            codec=codec
        )

    @classmethod
    def get_stream_addrs(
        cls,
        stream: trio.SocketStream
    ) -> tuple[
        TCPAddress,
        TCPAddress,
    ]:
        '''
        Return the `(local, remote)` addresses of `stream`.

        '''
        # `getsockname()`/`getpeername()` return `(host, port)` for
        # AF_INET and `(host, port, flowinfo, scope_id)` for AF_INET6,
        # so only the first two fields are kept.
        lsockname = stream.socket.getsockname()
        l_sockaddr: tuple[str, int] = tuple(lsockname[:2])
        rsockname = stream.socket.getpeername()
        r_sockaddr: tuple[str, int] = tuple(rsockname[:2])
        return (
            cls.address_type.from_addr(l_sockaddr),
            cls.address_type.from_addr(r_sockaddr),
        )
|