diff --git a/tractor/_runtime.py b/tractor/_runtime.py index f8356582..4daa4742 100644 --- a/tractor/_runtime.py +++ b/tractor/_runtime.py @@ -74,11 +74,14 @@ from tractor.msg import ( pretty_struct, types as msgtypes, ) -from .ipc import Channel +from .ipc import ( + Channel, + _server, +) from ._addr import ( UnwrappedAddress, Address, - default_lo_addrs, + # default_lo_addrs, get_address_cls, wrap_address, ) @@ -157,16 +160,24 @@ class Actor: # nursery placeholders filled in by `async_main()` after fork _root_n: Nursery|None = None _service_n: Nursery|None = None - _server_n: Nursery|None = None + + # XXX moving to IPCServer! + _ipc_server: _server.IPCServer|None = None + + @property + def ipc_server(self) -> _server.IPCServer: + ''' + The IPC transport-server for this actor; normally + a process-singleton. + + ''' + return self._ipc_server # Information about `__main__` from parent _parent_main_data: dict[str, str] _parent_chan_cs: CancelScope|None = None _spawn_spec: msgtypes.SpawnSpec|None = None - # syncs for setup/teardown sequences - _server_down: trio.Event|None = None - # if started on ``asycio`` running ``trio`` in guest mode _infected_aio: bool = False @@ -266,8 +277,6 @@ class Actor: Context ] = {} - self._listeners: list[trio.abc.Listener] = [] - self._listen_addrs: list[Address] = [] self._parent_chan: Channel|None = None self._forkserver_info: tuple|None = None @@ -335,7 +344,6 @@ class Actor: if rent_chan := self._parent_chan: parent_uid = rent_chan.uid peers: list[tuple] = list(self._peer_connected) - listen_addrs: str = pformat(self._listen_addrs) fmtstr: str = ( f' |_id: {self.aid!r}\n' # f" aid{ds}{self.aid!r}\n" @@ -343,8 +351,7 @@ class Actor: f'\n' f' |_ipc: {len(peers)!r} connected peers\n' f" peers{ds}{peers!r}\n" - f" _listen_addrs{ds}'{listen_addrs}'\n" - f" _listeners{ds}'{self._listeners}'\n" + f" ipc_server{ds}{self._ipc_server}\n" f'\n' f' |_rpc: {len(self._rpc_tasks)} tasks\n' f" ctxs{ds}{len(self._contexts)}\n" @@ -499,6 +506,9 @@ class Actor: ''' self._no_more_peers = trio.Event() # unset by making new + # with _debug.maybe_open_crash_handler( + # pdb=True, + # ) as boxerr: chan = Channel.from_stream(stream) con_status: str = ( 'New inbound IPC connection <=\n' @@ -1303,85 +1313,6 @@ class Actor: await self.cancel(req_chan=None) # self cancel raise - async def _serve_forever( - self, - handler_nursery: Nursery, - *, - listen_addrs: list[UnwrappedAddress]|None = None, - - task_status: TaskStatus[Nursery] = trio.TASK_STATUS_IGNORED, - ) -> None: - ''' - Start the IPC transport server, begin listening/accepting new - `trio.SocketStream` connections. - - This will cause an actor to continue living (and thus - blocking at the process/OS-thread level) until - `.cancel_server()` is called. - - ''' - if listen_addrs is None: - listen_addrs = default_lo_addrs([ - _state._def_tpt_proto - ]) - - else: - listen_addrs: list[Address] = [ - wrap_address(a) for a in listen_addrs - ] - - self._server_down = trio.Event() - try: - async with trio.open_nursery() as server_n: - - listeners: list[trio.abc.Listener] = [] - for addr in listen_addrs: - try: - listener: trio.abc.Listener = await addr.open_listener() - except OSError as oserr: - if ( - '[Errno 98] Address already in use' - in - oserr.args#[0] - ): - log.exception( - f'Address already in use?\n' - f'{addr}\n' - ) - raise - listeners.append(listener) - - await server_n.start( - partial( - trio.serve_listeners, - handler=self._stream_handler, - listeners=listeners, - - # NOTE: configured such that new - # connections will stay alive even if - # this server is cancelled! - handler_nursery=handler_nursery - ) - ) - # TODO, wow make this message better! XD - log.info( - 'Started server(s)\n' - + - '\n'.join([f'|_{addr}' for addr in listen_addrs]) - ) - self._listen_addrs.extend(listen_addrs) - self._listeners.extend(listeners) - - task_status.started(server_n) - - finally: - addr: Address - for addr in listen_addrs: - addr.close_listener() - - # signal the server is down since nursery above terminated - self._server_down.set() - def cancel_soon(self) -> None: ''' Cancel this actor asap; can be called from a sync context. @@ -1481,18 +1412,9 @@ class Actor: ) # stop channel server - self.cancel_server() - if self._server_down is not None: - await self._server_down.wait() - else: - tpt_protos: list[str] = [] - addr: Address - for addr in self._listen_addrs: - tpt_protos.append(addr.proto_key) - log.warning( - 'Transport server(s) may have been cancelled before started?\n' - f'protos: {tpt_protos!r}\n' - ) + if ipc_server := self.ipc_server: + ipc_server.cancel() + await ipc_server.wait_for_shutdown() # cancel all rpc tasks permanently if self._service_n: @@ -1723,24 +1645,6 @@ class Actor: ) await self._ongoing_rpc_tasks.wait() - def cancel_server(self) -> bool: - ''' - Cancel the internal IPC transport server nursery thereby - preventing any new inbound IPC connections establishing. - - ''' - if self._server_n: - # TODO: obvi a different server type when we eventually - # support some others XD - server_prot: str = 'TCP' - log.runtime( - f'Cancelling {server_prot} server' - ) - self._server_n.cancel_scope.cancel() - return True - - return False - @property def accept_addrs(self) -> list[UnwrappedAddress]: ''' @@ -1748,7 +1652,7 @@ class Actor: and listens for new connections. ''' - return [a.unwrap() for a in self._listen_addrs] + return self._ipc_server.accept_addrs @property def accept_addr(self) -> UnwrappedAddress: @@ -1856,6 +1760,7 @@ async def async_main( addr: Address = transport_cls.get_random() accept_addrs.append(addr.unwrap()) + assert accept_addrs # The "root" nursery ensures the channel with the immediate # parent is kept alive as a resilient service until # cancellation steps have (mostly) occurred in @@ -1866,15 +1771,37 @@ async def async_main( actor._root_n = root_nursery assert actor._root_n - async with trio.open_nursery( - strict_exception_groups=False, - ) as service_nursery: + ipc_server: _server.IPCServer + async with ( + trio.open_nursery( + strict_exception_groups=False, + ) as service_nursery, + + _server.open_ipc_server( + actor=actor, + parent_tn=service_nursery, + stream_handler_tn=service_nursery, + ) as ipc_server, + # ) as actor._ipc_server, + # ^TODO? prettier? + + ): # This nursery is used to handle all inbound # connections to us such that if the TCP server # is killed, connections can continue to process # in the background until this nursery is cancelled. actor._service_n = service_nursery - assert actor._service_n + actor._ipc_server = ipc_server + assert ( + actor._service_n + and ( + actor._service_n + is + actor._ipc_server._parent_tn + is + ipc_server._stream_handler_tn + ) + ) # load exposed/allowed RPC modules # XXX: do this **after** establishing a channel to the parent @@ -1898,30 +1825,42 @@ async def async_main( # - subactor: the bind address is sent by our parent # over our established channel # - root actor: the ``accept_addr`` passed to this method - assert accept_addrs + # TODO: why is this not with the root nursery? try: - # TODO: why is this not with the root nursery? - actor._server_n = await service_nursery.start( - partial( - actor._serve_forever, - service_nursery, - listen_addrs=accept_addrs, - ) + log.runtime( + 'Booting IPC server' ) + eps: list = await ipc_server.listen_on( + actor=actor, + accept_addrs=accept_addrs, + stream_handler_nursery=service_nursery, + ) + log.runtime( + f'Booted IPC server\n' + f'{ipc_server}\n' + ) + assert ( + (eps[0].listen_tn) + is not service_nursery + ) + except OSError as oserr: # NOTE: always allow runtime hackers to debug # tranport address bind errors - normally it's # something silly like the wrong socket-address # passed via a config or CLI Bo - entered_debug: bool = await _debug._maybe_enter_pm(oserr) + entered_debug: bool = await _debug._maybe_enter_pm( + oserr, + ) if not entered_debug: - log.exception('Failed to init IPC channel server !?\n') + log.exception('Failed to init IPC server !?\n') else: log.runtime('Exited debug REPL..') raise + # TODO, just read direct from ipc_server? accept_addrs: list[UnwrappedAddress] = actor.accept_addrs # NOTE: only set the loopback addr for the @@ -1954,7 +1893,9 @@ async def async_main( async with get_registry(addr) as reg_portal: for accept_addr in accept_addrs: accept_addr = wrap_address(accept_addr) - assert accept_addr.is_valid + + if not accept_addr.is_valid: + breakpoint() await reg_portal.run_from_ns( 'self', diff --git a/tractor/ipc/_server.py b/tractor/ipc/_server.py new file mode 100644 index 00000000..f23cf697 --- /dev/null +++ b/tractor/ipc/_server.py @@ -0,0 +1,467 @@ +# tractor: structured concurrent "actors". +# Copyright 2018-eternity Tyler Goodlet. + +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. + +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. + +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . +''' +High-level "IPC server" encapsulation for all your +multi-transport-protcol needs! + +''' +from __future__ import annotations +from contextlib import ( + asynccontextmanager as acm, +) +from functools import partial +import inspect +from types import ( + ModuleType, +) +from typing import ( + Callable, + TYPE_CHECKING, +) + +import trio +from trio import ( + EventStatistics, + Nursery, + TaskStatus, + SocketListener, +) + +from ..msg import Struct +from ..trionics import maybe_open_nursery +from .. import ( + _state, + log, +) +from .._addr import Address +from ._transport import MsgTransport +from ._uds import UDSAddress +from ._tcp import TCPAddress + +if TYPE_CHECKING: + from .._runtime import Actor + + +log = log.get_logger(__name__) + + +class IPCEndpoint(Struct): + ''' + An instance of an IPC "bound" address where the lifetime of the + "ability to accept connections" (from clients) and then handle + those inbound sessions or sequences-of-packets is determined by + a (maybe pair of) nurser(y/ies). + + ''' + addr: Address + listen_tn: Nursery + stream_handler_tn: Nursery|None = None + + # NOTE, normally filled in by calling `.start_listener()` + _listener: SocketListener|None = None + + # ?TODO, mk stream_handler hook into this ep instance so that we + # always keep track of all `SocketStream` instances per + # listener/ep? + peer_tpts: dict[ + UDSAddress|TCPAddress, # peer addr + MsgTransport, # handle to encoded-msg transport stream + ] = {} + + async def start_listener(self) -> SocketListener: + tpt_mod: ModuleType = inspect.getmodule(self.addr) + lstnr: SocketListener = await tpt_mod.start_listener( + addr=self.addr, + ) + + # NOTE, for handling the resolved non-0 port for + # TCP/UDP network sockets. + if ( + (unwrapped := lstnr.socket.getsockname()) + != + self.addr.unwrap() + ): + self.addr=self.addr.from_addr(unwrapped) + + self._listener = lstnr + return lstnr + + def close_listener( + self, + ) -> bool: + tpt_mod: ModuleType = inspect.getmodule(self.addr) + closer: Callable = getattr( + tpt_mod, + 'close_listener', + False, + ) + # when no defined closing is implicit! + if not closer: + return True + return closer( + addr=self.addr, + lstnr=self._listener, + ) + + +class IPCServer(Struct): + _parent_tn: Nursery + _stream_handler_tn: Nursery + _endpoints: list[IPCEndpoint] = [] + + # syncs for setup/teardown sequences + _shutdown: trio.Event|None = None + + # TODO, maybe just make `._endpoints: list[IPCEndpoint]` and + # provide dict-views onto it? + # @property + # def addrs2eps(self) -> dict[Address, IPCEndpoint]: + # ... + + @property + def proto_keys(self) -> list[str]: + return [ + ep.addr.proto_key + for ep in self._endpoints + ] + + # def cancel_server(self) -> bool: + def cancel( + self, + + # !TODO, suport just shutting down accepting new clients, + # not existing ones! + # only_listeners: str|None = None + + ) -> bool: + ''' + Cancel this IPC transport server nursery thereby + preventing any new inbound IPC connections establishing. + + ''' + if self._parent_tn: + # TODO: obvi a different server type when we eventually + # support some others XD + log.runtime( + f'Cancelling server(s) for\n' + f'{self.proto_keys!r}\n' + ) + self._parent_tn.cancel_scope.cancel() + return True + + log.warning( + 'No IPC server started before cancelling ?' + ) + return False + + async def wait_for_shutdown( + self, + ) -> bool: + if self._shutdown is not None: + await self._shutdown.wait() + else: + tpt_protos: list[str] = [] + ep: IPCEndpoint + for ep in self._endpoints: + tpt_protos.append(ep.addr.proto_key) + + log.warning( + 'Transport server(s) may have been cancelled before started?\n' + f'protos: {tpt_protos!r}\n' + ) + + @property + def addrs(self) -> list[Address]: + return [ep.addr for ep in self._endpoints] + + @property + def accept_addrs(self) -> list[str, str|int]: + ''' + The `list` of `Address.unwrap()`-ed active IPC endpoint addrs. + + ''' + return [ep.addr.unwrap() for ep in self._endpoints] + + def epsdict(self) -> dict[ + Address, + IPCEndpoint, + ]: + return { + ep.addr: ep + for ep in self._endpoints + } + + def is_shutdown(self) -> bool: + if (ev := self._shutdown) is None: + return False + + return ev.is_set() + + def pformat(self) -> str: + + fmtstr: str = ( + f' |_endpoints: {self._endpoints}\n' + ) + if self._shutdown is not None: + shutdown_stats: EventStatistics = self._shutdown.statistics() + fmtstr += ( + f'\n' + f' |_shutdown: {shutdown_stats}\n' + ) + + return ( + f'\n' + ) + + __repr__ = pformat + + # TODO? maybe allow shutting down a `.listen_on()`s worth of + # listeners by cancelling the corresponding + # `IPCEndpoint._listen_tn` only ? + # -[ ] in theory you could use this to + # "boot-and-wait-for-reconnect" of all current and connecting + # peers? + # |_ would require that the stream-handler is intercepted so we + # can intercept every `MsgTransport` (stream) and track per + # `IPCEndpoint` likely? + # + # async def unlisten( + # self, + # listener: SocketListener, + # ) -> bool: + # ... + + async def listen_on( + self, + *, + actor: Actor, + accept_addrs: list[tuple[str, int|str]]|None = None, + stream_handler_nursery: Nursery|None = None, + ) -> list[IPCEndpoint]: + ''' + Start `SocketListeners` (i.e. bind and call `socket.listen()`) + for all IPC-transport-protocol specific `Address`-types + in `accept_addrs`. + + ''' + from .._addr import ( + default_lo_addrs, + wrap_address, + ) + if accept_addrs is None: + accept_addrs = default_lo_addrs([ + _state._def_tpt_proto + ]) + + else: + accept_addrs: list[Address] = [ + wrap_address(a) for a in accept_addrs + ] + + if self._shutdown is None: + self._shutdown = trio.Event() + + elif self.is_shutdown(): + raise RuntimeError( + f'IPC server has already terminated ?\n' + f'{self}\n' + ) + + log.info( + f'Binding to endpoints for,\n' + f'{accept_addrs}\n' + ) + eps: list[IPCEndpoint] = await self._parent_tn.start( + partial( + _serve_ipc_eps, + actor=actor, + server=self, + stream_handler_tn=stream_handler_nursery, + listen_addrs=accept_addrs, + ) + ) + log.info( + f'Started IPC endpoints\n' + f'{eps}\n' + ) + + self._endpoints.extend(eps) + # XXX, just a little bit of sanity + group_tn: Nursery|None = None + ep: IPCEndpoint + for ep in eps: + if ep.addr not in self.addrs: + breakpoint() + + if group_tn is None: + group_tn = ep.listen_tn + else: + assert group_tn is ep.listen_tn + + return eps + + +async def _serve_ipc_eps( + *, + actor: Actor, + server: IPCServer, + stream_handler_tn: Nursery, + listen_addrs: list[tuple[str, int|str]], + + task_status: TaskStatus[ + Nursery, + ] = trio.TASK_STATUS_IGNORED, +) -> None: + ''' + Start IPC transport server(s) for the actor, begin + listening/accepting new `trio.SocketStream` connections + from peer actors via a `SocketListener`. + + This will cause an actor to continue living (and thus + blocking at the process/OS-thread level) until + `.cancel_server()` is called. + + ''' + try: + listen_tn: Nursery + async with trio.open_nursery() as listen_tn: + + eps: list[IPCEndpoint] = [] + # XXX NOTE, required to call `serve_listeners()` below. + # ?TODO, maybe just pass `list(eps.values()` tho? + listeners: list[trio.abc.Listener] = [] + for addr in listen_addrs: + ep = IPCEndpoint( + addr=addr, + listen_tn=listen_tn, + stream_handler_tn=stream_handler_tn, + ) + try: + log.info( + f'Starting new endpoint listener\n' + f'{ep}\n' + ) + listener: trio.abc.Listener = await ep.start_listener() + assert listener is ep._listener + # if actor.is_registry: + # import pdbp; pdbp.set_trace() + + except OSError as oserr: + if ( + '[Errno 98] Address already in use' + in + oserr.args#[0] + ): + log.exception( + f'Address already in use?\n' + f'{addr}\n' + ) + raise + + listeners.append(listener) + eps.append(ep) + + _listeners: list[SocketListener] = await listen_tn.start( + partial( + trio.serve_listeners, + handler=actor._stream_handler, + listeners=listeners, + + # NOTE: configured such that new + # connections will stay alive even if + # this server is cancelled! + handler_nursery=stream_handler_tn + ) + ) + # TODO, wow make this message better! XD + log.info( + 'Started server(s)\n' + + + '\n'.join([f'|_{addr}' for addr in listen_addrs]) + ) + + log.info( + f'Started IPC endpoints\n' + f'{eps}\n' + ) + task_status.started( + eps, + ) + + finally: + if eps: + addr: Address + ep: IPCEndpoint + for addr, ep in server.epsdict().items(): + ep.close_listener() + server._endpoints.remove(ep) + + # if actor.is_arbiter: + # import pdbp; pdbp.set_trace() + + # signal the server is "shutdown"/"terminated" + # since no more active endpoints are active. + if not server._endpoints: + server._shutdown.set() + +@acm +async def open_ipc_server( + actor: Actor, + parent_tn: Nursery|None = None, + stream_handler_tn: Nursery|None = None, + +) -> IPCServer: + + async with maybe_open_nursery( + nursery=parent_tn, + ) as rent_tn: + ipc_server = IPCServer( + _parent_tn=rent_tn, + _stream_handler_tn=stream_handler_tn or rent_tn, + ) + try: + yield ipc_server + + # except BaseException as berr: + # log.exception( + # 'IPC server crashed on exit ?' + # ) + # raise berr + + finally: + # ?TODO, maybe we can ensure the endpoints are torndown + # (and thus their managed listeners) beforehand to ensure + # super graceful RPC mechanics? + # + # -[ ] but aren't we doing that already per-`listen_tn` + # inside `_serve_ipc_eps()` above? + # + # if not ipc_server.is_shutdown(): + # ipc_server.cancel() + # await ipc_server.wait_for_shutdown() + # assert ipc_server.is_shutdown() + pass + + # !XXX TODO! lol so classic, the below code is rekt! + # + # XXX here is a perfect example of suppressing errors with + # `trio.Cancelled` as per our demonstrating example, + # `test_trioisms::test_acm_embedded_nursery_propagates_enter_err + # + # with trio.CancelScope(shield=True): + # await ipc_server.wait_for_shutdown() diff --git a/tractor/ipc/_tcp.py b/tractor/ipc/_tcp.py index dbecdf5e..b534b143 100644 --- a/tractor/ipc/_tcp.py +++ b/tractor/ipc/_tcp.py @@ -18,7 +18,14 @@ TCP implementation of tractor.ipc._transport.MsgTransport protocol ''' from __future__ import annotations +from typing import ( + ClassVar, +) +# from contextlib import ( +# asynccontextmanager as acm, +# ) +import msgspec import trio from trio import ( SocketListener, @@ -27,33 +34,25 @@ from trio import ( from tractor.msg import MsgCodec from tractor.log import get_logger -from tractor.ipc._transport import MsgpackTransport +from tractor.ipc._transport import ( + MsgTransport, + MsgpackTransport, +) log = get_logger(__name__) -class TCPAddress: - proto_key: str = 'tcp' - unwrapped_type: type = tuple[str, int] - def_bindspace: str = '127.0.0.1' +class TCPAddress( + msgspec.Struct, + frozen=True, +): + _host: str + _port: int - def __init__( - self, - host: str, - port: int - ): - if ( - not isinstance(host, str) - or - not isinstance(port, int) - ): - raise TypeError( - f'Expected host {host!r} to be str and port {port!r} to be int' - ) - - self._host: str = host - self._port: int = port + proto_key: ClassVar[str] = 'tcp' + unwrapped_type: ClassVar[type] = tuple[str, int] + def_bindspace: ClassVar[str] = '127.0.0.1' @property def is_valid(self) -> bool: @@ -106,34 +105,42 @@ class TCPAddress: f'{type(self).__name__}[{self.unwrap()}]' ) - def __eq__(self, other) -> bool: - if not isinstance(other, TCPAddress): - raise TypeError( - f'Can not compare {type(other)} with {type(self)}' - ) + @classmethod + def get_transport( + cls, + codec: str = 'msgpack', + ) -> MsgTransport: + match codec: + case 'msgspack': + return MsgpackTCPStream + case _: + raise ValueError( + f'No IPC transport with {codec!r} supported !' + ) - return ( - self._host == other._host - and - self._port == other._port - ) - async def open_listener( - self, - **kwargs, - ) -> SocketListener: - listeners: list[SocketListener] = await open_tcp_listeners( - host=self._host, - port=self._port, - **kwargs - ) - assert len(listeners) == 1 - listener = listeners[0] - self._host, self._port = listener.socket.getsockname()[:2] - return listener +async def start_listener( + addr: TCPAddress, + **kwargs, +) -> SocketListener: + ''' + Start a TCP socket listener on the given `TCPAddress`. - async def close_listener(self): - ... + ''' + # ?TODO, maybe we should just change the lower-level call this is + # using internall per-listener? + listeners: list[SocketListener] = await open_tcp_listeners( + host=addr._host, + port=addr._port, + **kwargs + ) + # NOTE, for now we don't expect non-singleton-resolving + # domain-addresses/multi-homed-hosts. + # (though it is supported by `open_tcp_listeners()`) + assert len(listeners) == 1 + listener = listeners[0] + host, port = listener.socket.getsockname()[:2] + return listener # TODO: typing oddity.. not sure why we have to inherit here, but it diff --git a/tractor/ipc/_transport.py b/tractor/ipc/_transport.py index 2a9926f9..ec3c442c 100644 --- a/tractor/ipc/_transport.py +++ b/tractor/ipc/_transport.py @@ -104,7 +104,10 @@ class MsgTransport(Protocol): @classmethod def key(cls) -> MsgTransportKey: - return cls.codec_key, cls.address_type.proto_key + return ( + cls.codec_key, + cls.address_type.proto_key, + ) @property def laddr(self) -> Address: @@ -135,8 +138,8 @@ class MsgTransport(Protocol): Address # remote ]: ''' - Return the `trio` streaming transport prot's addrs for both - the local and remote sides as a pair. + Return the transport protocol's address pair for the local + and remote-peer side. ''' ... diff --git a/tractor/ipc/_types.py b/tractor/ipc/_types.py index 8d543d9d..59653b17 100644 --- a/tractor/ipc/_types.py +++ b/tractor/ipc/_types.py @@ -53,9 +53,12 @@ _msg_transports = [ # convert a MsgTransportKey to the corresponding transport type -_key_to_transport: dict[MsgTransportKey, Type[MsgTransport]] = { - cls.key(): cls - for cls in _msg_transports +_key_to_transport: dict[ + MsgTransportKey, + Type[MsgTransport], +] = { + ('msgpack', 'tcp'): MsgpackTCPStream, + ('msgpack', 'uds'): MsgpackUDSStream, } # convert an Address wrapper to its corresponding transport type @@ -63,8 +66,8 @@ _addr_to_transport: dict[ Type[TCPAddress|UDSAddress], Type[MsgTransport] ] = { - cls.address_type: cls - for cls in _msg_transports + TCPAddress: MsgpackTCPStream, + UDSAddress: MsgpackUDSStream, } diff --git a/tractor/ipc/_uds.py b/tractor/ipc/_uds.py index 33843f6a..604802f3 100644 --- a/tractor/ipc/_uds.py +++ b/tractor/ipc/_uds.py @@ -21,7 +21,6 @@ from __future__ import annotations from pathlib import Path import os from socket import ( - # socket, AF_UNIX, SOCK_STREAM, SO_PASSCRED, @@ -31,8 +30,10 @@ from socket import ( import struct from typing import ( TYPE_CHECKING, + ClassVar, ) +import msgspec import trio from trio import ( socket, @@ -70,56 +71,22 @@ def unwrap_sockpath( ) -class UDSAddress: +class UDSAddress( + msgspec.Struct, + frozen=True, +): + filedir: str|Path|None + filename: str|Path + maybe_pid: int|None = None + # TODO, maybe we should use better field and value # -[x] really this is a `.protocol_key` not a "name" of anything. # -[ ] consider a 'unix' proto-key instead? # -[ ] need to check what other mult-transport frameworks do # like zmq, nng, uri-spec et al! - proto_key: str = 'uds' - unwrapped_type: type = tuple[str, int] - def_bindspace: Path = get_rt_dir() - - def __init__( - self, - filedir: Path|str|None, - # TODO, i think i want `.filename` here? - filename: str|Path, - - # XXX, in the sense you can also pass - # a "non-real-world-process-id" such as is handy to represent - # our host-local default "port-like" key for the very first - # root actor to create a registry address. - maybe_pid: int|None = None, - ): - fdir = self._filedir = Path( - filedir - or - self.def_bindspace - ).absolute() - fpath = self._filename = Path(filename) - fp: Path = fdir / fpath - assert ( - fp.is_absolute() - and - fp == self.sockpath - ) - - # to track which "side" is the peer process by reading socket - # credentials-info. - self._pid: int = maybe_pid - - @property - def sockpath(self) -> Path: - return self._filedir / self._filename - - @property - def is_valid(self) -> bool: - ''' - We block socket files not allocated under the runtime subdir. - - ''' - return self.bindspace in self.sockpath.parents + proto_key: ClassVar[str] = 'uds' + unwrapped_type: ClassVar[type] = tuple[str, int] + def_bindspace: ClassVar[Path] = get_rt_dir() @property def bindspace(self) -> Path: @@ -128,7 +95,25 @@ class UDSAddress: just the sub-directory in which we allocate socket files. ''' - return self._filedir or self.def_bindspace + return ( + self.filedir + or + self.def_bindspace + # or + # get_rt_dir() + ) + + @property + def sockpath(self) -> Path: + return self.bindspace / self.filename + + @property + def is_valid(self) -> bool: + ''' + We block socket files not allocated under the runtime subdir. + + ''' + return self.bindspace in self.sockpath.parents @classmethod def from_addr( @@ -141,9 +126,6 @@ class UDSAddress: case tuple()|list(): filedir = Path(addr[0]) filename = Path(addr[1]) - # sockpath: Path = Path(addr[0]) - # filedir, filename = unwrap_sockpath(sockpath) - # pid: int = addr[1] return UDSAddress( filedir=filedir, filename=filename, @@ -165,8 +147,8 @@ class UDSAddress: # XXX NOTE, since this gets passed DIRECTLY to # `.ipc._uds.open_unix_socket_w_passcred()` return ( - str(self._filedir), - str(self._filename), + str(self.filedir), + str(self.filename), ) @classmethod @@ -199,55 +181,77 @@ class UDSAddress: def get_root(cls) -> UDSAddress: def_uds_filename: Path = 'registry@1616.sock' return UDSAddress( - filedir=None, + filedir=cls.def_bindspace, filename=def_uds_filename, # maybe_pid=1616, ) + # ?TODO, maybe we should just our .msg.pretty_struct.Struct` for + # this instead? + # -[ ] is it too "multi-line"y tho? + # the compact tuple/.unwrapped() form is simple enough? + # def __repr__(self) -> str: + if not (pid := self.maybe_pid): + pid: str = '' + + body: str = ( + f'({self.filedir}, {self.filename}, {pid})' + ) return ( f'{type(self).__name__}' f'[' - f'({self._filedir}, {self._filename})' + f'{body}' f']' ) - def __eq__(self, other) -> bool: - if not isinstance(other, UDSAddress): - raise TypeError( - f'Can not compare {type(other)} with {type(self)}' - ) - return self.sockpath == other.sockpath +async def start_listener( + addr: UDSAddress, + **kwargs, +) -> SocketListener: + # sock = addr._sock = socket.socket( + sock = socket.socket( + socket.AF_UNIX, + socket.SOCK_STREAM + ) + log.info( + f'Attempting to bind UDS socket\n' + f'>[\n' + f'|_{addr}\n' + ) - # async def open_listener(self, **kwargs) -> SocketListener: - async def open_listener( - self, - **kwargs, - ) -> SocketListener: - sock = self._sock = socket.socket( - socket.AF_UNIX, - socket.SOCK_STREAM - ) - log.info( - f'Attempting to bind UDS socket\n' - f'>[\n' - f'|_{self}\n' - ) - - bindpath: Path = self.sockpath + bindpath: Path = addr.sockpath + try: await sock.bind(str(bindpath)) - sock.listen(1) - log.info( - f'Listening on UDS socket\n' - f'[>\n' - f' |_{self}\n' - ) - return SocketListener(self._sock) + except ( + FileNotFoundError, + ) as fdne: + raise ConnectionError( + f'Bad UDS socket-filepath-as-address ??\n' + f'{addr}\n' + f' |_sockpath: {addr.sockpath}\n' + ) from fdne - def close_listener(self): - self._sock.close() - os.unlink(self.sockpath) + sock.listen(1) + log.info( + f'Listening on UDS socket\n' + f'[>\n' + f' |_{addr}\n' + ) + return SocketListener(sock) + + +def close_listener( + addr: UDSAddress, + lstnr: SocketListener, +) -> None: + ''' + Close and remove the listening unix socket's path. + + ''' + lstnr.socket.close() + os.unlink(addr.sockpath) async def open_unix_socket_w_passcred( @@ -416,5 +420,3 @@ class MsgpackUDSStream(MsgpackTransport): maybe_pid=peer_pid ) return (laddr, raddr) - -