Improve error handling in fdshare functions, add comments

one_ring_to_rule_them_all
Guillermo Rodriguez 2025-04-03 11:48:07 -03:00
parent 28b86cb880
commit 4b9d6b9276
No known key found for this signature in database
GPG Key ID: 002CC5F1E6BDA53E
1 changed files with 102 additions and 39 deletions

View File

@ -14,81 +14,144 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
'''
Reimplementation of multiprocessing.reduction.sendfds & recvfds, using async
context managers (acms) and trio.
cpython impl:
https://github.com/python/cpython/blob/275056a7fdcbe36aaac494b4183ae59943a338eb/Lib/multiprocessing/reduction.py#L138
'''
import array
from typing import AsyncContextManager
from contextlib import asynccontextmanager as acm
import trio
from trio import socket
class FDSharingError(Exception):
    '''
    Raised when passing file descriptors over a unix socket fails, either
    because the peer did not acknowledge the transfer or because the
    received message did not pass validation.

    '''
@acm
async def send_fds(fds: list[int], sock_path: str) -> AsyncContextManager[None]:
    '''
    Async trio reimplementation of `multiprocessing.reduction.sendfds`

    https://github.com/python/cpython/blob/275056a7fdcbe36aaac494b4183ae59943a338eb/Lib/multiprocessing/reduction.py#L142

    It's implemented using an async context manager in order to simplify usage
    with `tractor.context`s: we can open a context in a remote actor that uses
    this acm inside of it, and uses `ctx.started()` to signal the original
    caller actor to perform the `recv_fds` call.

    On enter: a unix stream socket is bound to `sock_path` and starts
    listening. On exit: the receiver connection is accepted, `fds` are sent
    in a single `SCM_RIGHTS` message and the receiver ack is awaited.

    Raises `FDSharingError` if the ack byte received is not `b'A'`.

    See `tractor.ipc._ringbuf._ringd._attach_to_ring` for an example.

    '''
    sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
    try:
        await sock.bind(sock_path)
        sock.listen(1)

        yield  # socket is setup, ready for receiver connect

        # wait until receiver connects
        conn, _ = await sock.accept()
        try:
            # setup int array for fds
            fd_array = array.array('i', fds)

            # first byte of msg will be len of fds to send % 256, acting as
            # a fd amount verification on `recv_fds`, where we refer to it
            # as `check_byte`
            msg = bytes([len(fd_array) % 256])

            # send msg with custom SCM_RIGHTS type
            await conn.sendmsg(
                [msg],
                [(socket.SOL_SOCKET, socket.SCM_RIGHTS, fd_array)]
            )

            # finally wait receiver ack
            if await conn.recv(1) != b'A':
                raise FDSharingError('did not receive acknowledgement of fd')

        finally:
            # always release the accepted connection, even if sending or
            # the ack check failed
            conn.close()

    finally:
        # always release the listening socket, including when the body of
        # the caller's `async with` raised at the `yield` above
        sock.close()
async def recv_fds(sock_path: str, amount: int) -> tuple:
    '''
    Async trio reimplementation of `multiprocessing.reduction.recvfds`

    https://github.com/python/cpython/blob/275056a7fdcbe36aaac494b4183ae59943a338eb/Lib/multiprocessing/reduction.py#L150

    It's equivalent to std just using `trio.open_unix_socket` for connecting
    and changes on error handling.

    Connects to the unix socket at `sock_path`, receives one message carrying
    `amount` file descriptors as `SCM_RIGHTS` ancillary data, acks it and
    returns the received fds as a tuple of ints.

    Raises `FDSharingError` on EOF or when the received message fails any of
    the validation steps below.

    See `tractor.ipc._ringbuf._ringd._attach_to_ring` for an example.

    '''
    stream = await trio.open_unix_socket(sock_path)

    # the stream is only needed for this single exchange, make sure it is
    # closed on every exit path
    async with stream:
        sock = stream.socket

        # prepare int array for fds
        a = array.array('i')
        bytes_size = a.itemsize * amount

        # receive 1 byte + space necessary for SCM_RIGHTS msg for {amount} fds
        msg, ancdata, flags, addr = await sock.recvmsg(
            1, socket.CMSG_SPACE(bytes_size)
        )

        # maybe failed to receive msg?
        if not msg and not ancdata:
            raise FDSharingError(
                f'Expected to receive {amount} fds from {sock_path}, '
                'but got EOF'
            )

        # send ack, std comment mentions this ack pattern was to get around
        # an old macosx bug, but they are not sure if its necessary any more,
        # in any case its not a bad pattern to keep
        await sock.send(b'A')  # Ack

        # expect to receive only one `ancdata` item
        if len(ancdata) != 1:
            raise FDSharingError(
                'Expected to receive exactly one "ancdata" but got '
                f'{len(ancdata)}: {ancdata}'
            )

        # unpack SCM_RIGHTS msg
        cmsg_level, cmsg_type, cmsg_data = ancdata[0]

        # check proper msg type
        if cmsg_level != socket.SOL_SOCKET:
            raise FDSharingError(
                f'Expected CMSG level to be SOL_SOCKET({socket.SOL_SOCKET}) '
                f'but got {cmsg_level}'
            )

        if cmsg_type != socket.SCM_RIGHTS:
            raise FDSharingError(
                f'Expected CMSG type to be SCM_RIGHTS({socket.SCM_RIGHTS}) '
                f'but got {cmsg_type}'
            )

        # check proper data alignment
        length = len(cmsg_data)
        if length % a.itemsize != 0:
            raise FDSharingError(
                f'CMSG data alignment error: len of {length} is not '
                f'divisible by int size {a.itemsize}'
            )

        # attempt to cast as int array
        a.frombytes(cmsg_data)

        # ancillary data without the check byte would make `msg[0]` below
        # raise a bare IndexError, report it as a sharing error instead
        if not msg:
            raise FDSharingError(
                'Validation failed: received ancillary data but no check byte'
            )

        # validate length check byte
        valid_check_byte = amount % 256  # check byte according to `recv_fds` caller
        recvd_check_byte = msg[0]  # actual received check byte
        payload_check_byte = len(a) % 256  # check byte according to received fd int array

        if recvd_check_byte != payload_check_byte:
            raise FDSharingError(
                'Validation failed: received check byte '
                f'({recvd_check_byte}) does not match fd int array len % 256 '
                f'({payload_check_byte})'
            )

        if valid_check_byte != recvd_check_byte:
            raise FDSharingError(
                'Validation failed: received check byte '
                f'({recvd_check_byte}) does not match expected fd amount % 256 '
                f'({valid_check_byte})'
            )

        return tuple(a)