diff --git a/tractor/ipc/_tcp.py b/tractor/ipc/_tcp.py index a8008519..eb2003ec 100644 --- a/tractor/ipc/_tcp.py +++ b/tractor/ipc/_tcp.py @@ -42,24 +42,15 @@ class MsgpackTCPStream(MsgpackTransport): address_type = TCPAddress layer_key: int = 4 - # def __init__( - # self, - # stream: trio.SocketStream, - # prefix_size: int = 4, - # codec: CodecType = None, - - # ) -> None: - # super().__init__( - # stream, - # prefix_size=prefix_size, - # codec=codec - # ) - @property def maddr(self) -> str: host, port = self.raddr.unwrap() return ( + # TODO, use `ipaddress` from stdlib to handle + # first detecting which of `ipv4/6` before + # choosing the routing prefix part. f'/ipv4/{host}' + f'/{self.address_type.name_key}/{port}' # f'/{self.chan.uid[0]}' # f'/{self.cid}' @@ -94,12 +85,15 @@ class MsgpackTCPStream(MsgpackTransport): cls, stream: trio.SocketStream ) -> tuple[ - tuple[str, int], - tuple[str, int] + TCPAddress, + TCPAddress, ]: + # TODO, what types are these? lsockname = stream.socket.getsockname() + l_sockaddr: tuple[str, int] = tuple(lsockname[:2]) rsockname = stream.socket.getpeername() + r_sockaddr: tuple[str, int] = tuple(rsockname[:2]) return ( - TCPAddress.from_addr(tuple(lsockname[:2])), - TCPAddress.from_addr(tuple(rsockname[:2])), + TCPAddress.from_addr(l_sockaddr), + TCPAddress.from_addr(r_sockaddr), ) diff --git a/tractor/ipc/_uds.py b/tractor/ipc/_uds.py index ee147d42..894e3fbc 100644 --- a/tractor/ipc/_uds.py +++ b/tractor/ipc/_uds.py @@ -18,8 +18,23 @@ Unix Domain Socket implementation of tractor.ipc._transport.MsgTransport protoco ''' from __future__ import annotations +from pathlib import Path +import os +from socket import ( + # socket, + AF_UNIX, + SOCK_STREAM, + SO_PASSCRED, + SO_PEERCRED, + SOL_SOCKET, +) +import struct import trio +from trio._highlevel_open_unix_stream import ( + close_on_error, + has_unix, +) from tractor.msg import MsgCodec from tractor.log import get_logger @@ -30,33 +45,80 @@ from tractor.ipc._transport import 
MsgpackTransport log = get_logger(__name__) +async def open_unix_socket_w_passcred( + filename: str|bytes|os.PathLike[str]|os.PathLike[bytes], +) -> trio.SocketStream: + ''' + Literally the exact same as `trio.open_unix_socket()` except we set the additional + `socket.SO_PASSCRED` option to ensure the server side (the process calling `accept()`) + can extract the connecting peer's credentials, namely OS specific process + related IDs. + + See this SO for "why" the extra opts, + - https://stackoverflow.com/a/7982749 + + ''' + if not has_unix: + raise RuntimeError("Unix sockets are not supported on this platform") + + # much more simplified logic vs tcp sockets - one socket type and only one + # possible location to connect to + sock = trio.socket.socket(AF_UNIX, SOCK_STREAM) + sock.setsockopt(SOL_SOCKET, SO_PASSCRED, 1) + with close_on_error(sock): + await sock.connect(os.fspath(filename)) + + return trio.SocketStream(sock) + + +def get_peer_info(sock: trio.socket.socket) -> tuple[ + int, # pid + int, # uid + int, # gid +]: + ''' + Deliver the connecting peer's "credentials"-info as defined in + a very Linux specific way.. + + For more details see, + - `man accept`, + - `man unix`, + + this great online guide to all things sockets, + - https://beej.us/guide/bgnet/html/split-wide/man-pages.html#setsockoptman + + AND this **wonderful SO answer** + - https://stackoverflow.com/a/7982749 + + ''' + creds: bytes = sock.getsockopt( + SOL_SOCKET, + SO_PEERCRED, + struct.calcsize('3i') + ) + # i.e. a tuple of the fields, + # pid: int, "process" + # uid: int, "user" + # gid: int, "group" + return struct.unpack('3i', creds) + + class MsgpackUDSStream(MsgpackTransport): ''' - A ``trio.SocketStream`` delivering ``msgpack`` formatted data - using the ``msgspec`` codec lib. + A `trio.SocketStream` around a Unix-Domain-Socket transport + delivering `msgpack` encoded msgs using the `msgspec` codec lib. 
''' address_type = UDSAddress - layer_key: int = 7 - - # def __init__( - # self, - # stream: trio.SocketStream, - # prefix_size: int = 4, - # codec: CodecType = None, - - # ) -> None: - # super().__init__( - # stream, - # prefix_size=prefix_size, - # codec=codec - # ) + layer_key: int = 4 @property def maddr(self) -> str: - filepath = self.raddr.unwrap() + if not self.raddr: + return '' + + filepath: Path = Path(self.raddr.unwrap()[0]) return ( - f'/ipv4/localhost' f'/{self.address_type.name_key}/{filepath}' # f'/{self.chan.uid[0]}' # f'/{self.cid}' @@ -76,22 +138,72 @@ class MsgpackUDSStream(MsgpackTransport): codec: MsgCodec|None = None, **kwargs ) -> MsgpackUDSStream: - stream = await trio.open_unix_socket( - addr.unwrap(), + + filepath: Path + pid: int + ( + filepath, + pid, + ) = addr.unwrap() + + # XXX NOTE, we don't need to provide the `.pid` part from + # the addr since the OS does this implicitly! .. lel + # stream = await trio.open_unix_socket( + stream = await open_unix_socket_w_passcred( + str(filepath), **kwargs ) - return MsgpackUDSStream( + stream = MsgpackUDSStream( stream, prefix_size=prefix_size, codec=codec ) + stream._raddr = addr + return stream @classmethod def get_stream_addrs( cls, stream: trio.SocketStream - ) -> tuple[UDSAddress, UDSAddress]: - return ( - UDSAddress.from_addr(stream.socket.getsockname()), - UDSAddress.from_addr(stream.socket.getsockname()), + ) -> tuple[ + Path, + int, + ]: + sock: trio.socket.socket = stream.socket + + # NOTE XXX, it's unclear why one or the other ends up being + # `bytes` versus the socket-file-path, i presume it's + # something to do with who is the server (called `.listen()`)? 
+ # maybe could be better implemented using another info-query + # on the socket like, + # https://beej.us/guide/bgnet/html/split-wide/system-calls-or-bust.html#gethostnamewho-am-i + sockname: str|bytes = sock.getsockname() + # https://beej.us/guide/bgnet/html/split-wide/system-calls-or-bust.html#getpeernamewho-are-you + peername: str|bytes = sock.getpeername() + match (peername, sockname): + case (str(), bytes()): + sock_path: Path = Path(peername) + case (bytes(), str()): + sock_path: Path = Path(sockname) + ( + pid, + uid, + gid, + ) = get_peer_info(sock) + log.info( + f'UDS connection from process {pid!r}\n' + f'>[\n' + f'|_{sock_path}\n' + f' |_pid: {pid}\n' + f' |_uid: {uid}\n' + f' |_gid: {gid}\n' ) + laddr = UDSAddress.from_addr(( + sock_path, + os.getpid(), + )) + raddr = UDSAddress.from_addr(( + sock_path, + pid + )) + return (laddr, raddr)