Improve error handling in fdshare functions, add comments

one_ring_to_rule_them_all
Guillermo Rodriguez 2025-04-03 11:48:07 -03:00
parent 28b86cb880
commit 4b9d6b9276
No known key found for this signature in database
GPG Key ID: 002CC5F1E6BDA53E
1 changed files with 102 additions and 39 deletions

View File

@ -14,81 +14,144 @@
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
''' '''
Re-Impl of multiprocessing.reduction.sendfds & recvfds, Reimplementation of multiprocessing.reduction.sendfds & recvfds, using acms and trio.
using acms and trio
cpython impl:
https://github.com/python/cpython/blob/275056a7fdcbe36aaac494b4183ae59943a338eb/Lib/multiprocessing/reduction.py#L138
''' '''
import array import array
from typing import AsyncContextManager
from contextlib import asynccontextmanager as acm from contextlib import asynccontextmanager as acm
import trio import trio
from trio import socket from trio import socket
class FDSharingError(Exception):
...
@acm @acm
async def send_fds(fds: list[int], sock_path: str): async def send_fds(fds: list[int], sock_path: str) -> AsyncContextManager[None]:
'''
Async trio reimplementation of `multiprocessing.reduction.sendfds`
https://github.com/python/cpython/blob/275056a7fdcbe36aaac494b4183ae59943a338eb/Lib/multiprocessing/reduction.py#L142
It's implemented using an async context manager in order to simplyfy usage
with `tractor.context`s, we can open a context in a remote actor that uses
this acm inside of it, and uses `ctx.started()` to signal the original
caller actor to perform the `recv_fds` call.
See `tractor.ipc._ringbuf._ringd._attach_to_ring` for an example.
'''
sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
await sock.bind(sock_path) await sock.bind(sock_path)
sock.listen(1) sock.listen(1)
yield
fds = array.array('i', fds) yield # socket is setup, ready for receiver connect
# first byte of msg will be len of fds to send % 256
msg = bytes([len(fds) % 256]) # wait until receiver connects
conn, _ = await sock.accept() conn, _ = await sock.accept()
# setup int array for fds
fds = array.array('i', fds)
# first byte of msg will be len of fds to send % 256, acting as a fd amount
# verification on `recv_fds` we refer to it as `check_byte`
msg = bytes([len(fds) % 256])
# send msg with custom SCM_RIGHTS type
await conn.sendmsg( await conn.sendmsg(
[msg], [msg],
[(socket.SOL_SOCKET, socket.SCM_RIGHTS, fds)] [(socket.SOL_SOCKET, socket.SCM_RIGHTS, fds)]
) )
# wait ack
# finally wait receiver ack
if await conn.recv(1) != b'A': if await conn.recv(1) != b'A':
raise RuntimeError('did not receive acknowledgement of fd') raise FDSharingError('did not receive acknowledgement of fd')
conn.close() conn.close()
sock.close() sock.close()
async def recv_fds(sock_path: str, amount: int) -> tuple: async def recv_fds(sock_path: str, amount: int) -> tuple:
'''
Async trio reimplementation of `multiprocessing.reduction.recvfds`
https://github.com/python/cpython/blob/275056a7fdcbe36aaac494b4183ae59943a338eb/Lib/multiprocessing/reduction.py#L150
It's equivalent to std just using `trio.open_unix_socket` for connecting and
changes on error handling.
See `tractor.ipc._ringbuf._ringd._attach_to_ring` for an example.
'''
stream = await trio.open_unix_socket(sock_path) stream = await trio.open_unix_socket(sock_path)
sock = stream.socket sock = stream.socket
# prepare int array for fds
a = array.array('i') a = array.array('i')
bytes_size = a.itemsize * amount bytes_size = a.itemsize * amount
# receive 1 byte + space necesary for SCM_RIGHTS msg for {amount} fds
msg, ancdata, flags, addr = await sock.recvmsg( msg, ancdata, flags, addr = await sock.recvmsg(
1, socket.CMSG_SPACE(bytes_size) 1, socket.CMSG_SPACE(bytes_size)
) )
# maybe failed to receive msg?
if not msg and not ancdata: if not msg and not ancdata:
raise EOFError raise FDSharingError(f'Expected to receive {amount} fds from {sock_path}, but got EOF')
try:
# send ack, std comment mentions this ack pattern was to get around an
# old macosx bug, but they are not sure if its necesary any more, in
# any case its not a bad pattern to keep
await sock.send(b'A') # Ack await sock.send(b'A') # Ack
# expect to receive only one `ancdata` item
if len(ancdata) != 1: if len(ancdata) != 1:
raise RuntimeError( raise FDSharingError(
f'received {len(ancdata)} items of ancdata' f'Expected to receive exactly one \"ancdata\" but got {len(ancdata)}: {ancdata}'
) )
# unpack SCM_RIGHTS msg
cmsg_level, cmsg_type, cmsg_data = ancdata[0] cmsg_level, cmsg_type, cmsg_data = ancdata[0]
# check proper msg type # check proper msg type
if ( if cmsg_level != socket.SOL_SOCKET:
cmsg_level == socket.SOL_SOCKET raise FDSharingError(
and f'Expected CMSG level to be SOL_SOCKET({socket.SOL_SOCKET}) but got {cmsg_level}'
cmsg_type == socket.SCM_RIGHTS )
):
if cmsg_type != socket.SCM_RIGHTS:
raise FDSharingError(
f'Expected CMSG type to be SCM_RIGHTS({socket.SCM_RIGHTS}) but got {cmsg_type}'
)
# check proper data alignment # check proper data alignment
if len(cmsg_data) % a.itemsize != 0: length = len(cmsg_data)
raise ValueError if length % a.itemsize != 0:
raise FDSharingError(
f'CMSG data alignment error: len of {length} is not divisible by int size {a.itemsize}'
)
# attempt to cast as int array # attempt to cast as int array
a.frombytes(cmsg_data) a.frombytes(cmsg_data)
# check first byte of message is amount % 256 # validate length check byte
if len(a) % 256 != msg[0]: valid_check_byte = amount % 256 # check byte acording to `recv_fds` caller
raise AssertionError( recvd_check_byte = msg[0] # actual received check byte
'Len is {0:n} but msg[0] is {1!r}'.format( payload_check_byte = len(a) % 256 # check byte acording to received fd int array
len(a), msg[0]
if recvd_check_byte != payload_check_byte:
raise FDSharingError(
'Validation failed: received check byte '
f'({recvd_check_byte}) does not match fd int array len % 256 ({payload_check_byte})'
) )
if valid_check_byte != recvd_check_byte:
raise FDSharingError(
'Validation failed: received check byte '
f'({recvd_check_byte}) does not match expected fd amount % 256 ({valid_check_byte})'
) )
return tuple(a) return tuple(a)
except (ValueError, IndexError):
pass
raise RuntimeError('Invalid data received')