Implement peer-info tracking for UDS streams

Such that any UDS socket pair is represented (and with the recent
updates to) a `USDAddress` via a similar pair-`tuple[str, int]` as TCP
sockets, a pair of the `.filepath: Path` & the peer proc's `.pid: int`
which we read from the underlying `socket.socket` using
`.set/getsockopt()` calls

Impl deats,
- using the Linux specific APIs, we add a `get_peer_info()` which reads
  the `(pid, uid, gid)` using the `SOL_SOCKET` and `SOL_PEECRED` opts to
  `sock.getsockopt()`.
  |_ this presumes the client has been correspondingly configured to
     deliver the creds via a `sock.setsockopt(SOL_SOCKET, SO_PASSCRED,
     1)` call - this required us to override `trio.open_unix_socket()`.
- override `trio.open_unix_socket()` as per the above bullet to ensure
  connecting peers always transmit "credentials" options info to the
  listener.
- update `.get_stream_addrs()` to always call `get_peer_info()` and
  extract the peer's pid for the `raddr` and use `os.getpid()` for
  `laddr` (obvi).
  |_ as part of the new impl also `log.info()` the creds-info deats and
    socket-file path.
  |_ handle the oddity where it depends which of `.getpeername()` or
    `.getsockname()` will return the file-path; i think it's to do with
    who is client vs. server?

Related refinements,
- set `.layer_key: int = 4` for the "transport layer" ;)
- tweak some typing and multi-line unpacking in `.ipc/_tcp`.
structural_dynamics_of_flow
Tyler Goodlet 2025-03-30 21:00:36 -04:00
parent 4a8a555bdf
commit bf9d7ba074
2 changed files with 148 additions and 42 deletions

View File

@ -42,24 +42,15 @@ class MsgpackTCPStream(MsgpackTransport):
address_type = TCPAddress address_type = TCPAddress
layer_key: int = 4 layer_key: int = 4
# def __init__(
# self,
# stream: trio.SocketStream,
# prefix_size: int = 4,
# codec: CodecType = None,
# ) -> None:
# super().__init__(
# stream,
# prefix_size=prefix_size,
# codec=codec
# )
@property @property
def maddr(self) -> str: def maddr(self) -> str:
host, port = self.raddr.unwrap() host, port = self.raddr.unwrap()
return ( return (
# TODO, use `ipaddress` from stdlib to handle
# first detecting which of `ipv4/6` before
# choosing the routing prefix part.
f'/ipv4/{host}' f'/ipv4/{host}'
f'/{self.address_type.name_key}/{port}' f'/{self.address_type.name_key}/{port}'
# f'/{self.chan.uid[0]}' # f'/{self.chan.uid[0]}'
# f'/{self.cid}' # f'/{self.cid}'
@ -94,12 +85,15 @@ class MsgpackTCPStream(MsgpackTransport):
cls, cls,
stream: trio.SocketStream stream: trio.SocketStream
) -> tuple[ ) -> tuple[
tuple[str, int], TCPAddress,
tuple[str, int] TCPAddress,
]: ]:
# TODO, what types are these?
lsockname = stream.socket.getsockname() lsockname = stream.socket.getsockname()
l_sockaddr: tuple[str, int] = tuple(lsockname[:2])
rsockname = stream.socket.getpeername() rsockname = stream.socket.getpeername()
r_sockaddr: tuple[str, int] = tuple(rsockname[:2])
return ( return (
TCPAddress.from_addr(tuple(lsockname[:2])), TCPAddress.from_addr(l_sockaddr),
TCPAddress.from_addr(tuple(rsockname[:2])), TCPAddress.from_addr(r_sockaddr),
) )

View File

@ -18,8 +18,23 @@ Unix Domain Socket implementation of tractor.ipc._transport.MsgTransport protoco
''' '''
from __future__ import annotations from __future__ import annotations
from pathlib import Path
import os
from socket import (
# socket,
AF_UNIX,
SOCK_STREAM,
SO_PASSCRED,
SO_PEERCRED,
SOL_SOCKET,
)
import struct
import trio import trio
from trio._highlevel_open_unix_stream import (
close_on_error,
has_unix,
)
from tractor.msg import MsgCodec from tractor.msg import MsgCodec
from tractor.log import get_logger from tractor.log import get_logger
@ -30,33 +45,80 @@ from tractor.ipc._transport import MsgpackTransport
log = get_logger(__name__) log = get_logger(__name__)
async def open_unix_socket_w_passcred(
filename: str|bytes|os.PathLike[str]|os.PathLike[bytes],
) -> trio.SocketStream:
'''
Literally the exact same as `trio.open_unix_socket()` except we set the additiona
`socket.SO_PASSCRED` option to ensure the server side (the process calling `accept()`)
can extract the connecting peer's credentials, namely OS specific process
related IDs.
See this SO for "why" the extra opts,
- https://stackoverflow.com/a/7982749
'''
if not has_unix:
raise RuntimeError("Unix sockets are not supported on this platform")
# much more simplified logic vs tcp sockets - one socket type and only one
# possible location to connect to
sock = trio.socket.socket(AF_UNIX, SOCK_STREAM)
sock.setsockopt(SOL_SOCKET, SO_PASSCRED, 1)
with close_on_error(sock):
await sock.connect(os.fspath(filename))
return trio.SocketStream(sock)
def get_peer_info(sock: trio.socket.socket) -> tuple[
int, # pid
int, # uid
int, # guid
]:
'''
Deliver the connecting peer's "credentials"-info as defined in
a very Linux specific way..
For more deats see,
- `man accept`,
- `man unix`,
this great online guide to all things sockets,
- https://beej.us/guide/bgnet/html/split-wide/man-pages.html#setsockoptman
AND this **wonderful SO answer**
- https://stackoverflow.com/a/7982749
'''
creds: bytes = sock.getsockopt(
SOL_SOCKET,
SO_PEERCRED,
struct.calcsize('3i')
)
# i.e a tuple of the fields,
# pid: int, "process"
# uid: int, "user"
# gid: int, "group"
return struct.unpack('3i', creds)
class MsgpackUDSStream(MsgpackTransport): class MsgpackUDSStream(MsgpackTransport):
''' '''
A ``trio.SocketStream`` delivering ``msgpack`` formatted data A `trio.SocketStream` around a Unix-Domain-Socket transport
using the ``msgspec`` codec lib. delivering `msgpack` encoded msgs using the `msgspec` codec lib.
''' '''
address_type = UDSAddress address_type = UDSAddress
layer_key: int = 7 layer_key: int = 4
# def __init__(
# self,
# stream: trio.SocketStream,
# prefix_size: int = 4,
# codec: CodecType = None,
# ) -> None:
# super().__init__(
# stream,
# prefix_size=prefix_size,
# codec=codec
# )
@property @property
def maddr(self) -> str: def maddr(self) -> str:
filepath = self.raddr.unwrap() if not self.raddr:
return '<unknown-peer>'
filepath: Path = Path(self.raddr.unwrap()[0])
return ( return (
f'/ipv4/localhost'
f'/{self.address_type.name_key}/{filepath}' f'/{self.address_type.name_key}/{filepath}'
# f'/{self.chan.uid[0]}' # f'/{self.chan.uid[0]}'
# f'/{self.cid}' # f'/{self.cid}'
@ -76,22 +138,72 @@ class MsgpackUDSStream(MsgpackTransport):
codec: MsgCodec|None = None, codec: MsgCodec|None = None,
**kwargs **kwargs
) -> MsgpackUDSStream: ) -> MsgpackUDSStream:
stream = await trio.open_unix_socket(
addr.unwrap(), filepath: Path
pid: int
(
filepath,
pid,
) = addr.unwrap()
# XXX NOTE, we don't need to provide the `.pid` part from
# the addr since the OS does this implicitly! .. lel
# stream = await trio.open_unix_socket(
stream = await open_unix_socket_w_passcred(
str(filepath),
**kwargs **kwargs
) )
return MsgpackUDSStream( stream = MsgpackUDSStream(
stream, stream,
prefix_size=prefix_size, prefix_size=prefix_size,
codec=codec codec=codec
) )
stream._raddr = addr
return stream
@classmethod @classmethod
def get_stream_addrs( def get_stream_addrs(
cls, cls,
stream: trio.SocketStream stream: trio.SocketStream
) -> tuple[UDSAddress, UDSAddress]: ) -> tuple[
return ( Path,
UDSAddress.from_addr(stream.socket.getsockname()), int,
UDSAddress.from_addr(stream.socket.getsockname()), ]:
sock: trio.socket.socket = stream.socket
# NOTE XXX, it's unclear why one or the other ends up being
# `bytes` versus the socket-file-path, i presume it's
# something to do with who is the server (called `.listen()`)?
# maybe could be better implemented using another info-query
# on the socket like,
# https://beej.us/guide/bgnet/html/split-wide/system-calls-or-bust.html#gethostnamewho-am-i
sockname: str|bytes = sock.getsockname()
# https://beej.us/guide/bgnet/html/split-wide/system-calls-or-bust.html#getpeernamewho-are-you
peername: str|bytes = sock.getpeername()
match (peername, sockname):
case (str(), bytes()):
sock_path: Path = Path(peername)
case (bytes(), str()):
sock_path: Path = Path(sockname)
(
pid,
uid,
gid,
) = get_peer_info(sock)
log.info(
f'UDS connection from process {pid!r}\n'
f'>[\n'
f'|_{sock_path}\n'
f' |_pid: {pid}\n'
f' |_uid: {uid}\n'
f' |_gid: {gid}\n'
) )
laddr = UDSAddress.from_addr((
sock_path,
os.getpid(),
))
raddr = UDSAddress.from_addr((
sock_path,
pid
))
return (laddr, raddr)