Improve error handling in fdshare functions, add comments

one_ring_to_rule_them_all
Guillermo Rodriguez 2025-04-03 11:48:07 -03:00
parent 28b86cb880
commit 4b9d6b9276
No known key found for this signature in database
GPG Key ID: 002CC5F1E6BDA53E
1 changed files with 102 additions and 39 deletions

View File

@ -14,81 +14,144 @@
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
''' '''
Re-Impl of multiprocessing.reduction.sendfds & recvfds, Reimplementation of multiprocessing.reduction.sendfds & recvfds, using acms and trio.
using acms and trio
cpython impl:
https://github.com/python/cpython/blob/275056a7fdcbe36aaac494b4183ae59943a338eb/Lib/multiprocessing/reduction.py#L138
''' '''
import array import array
from typing import AsyncContextManager
from contextlib import asynccontextmanager as acm from contextlib import asynccontextmanager as acm
import trio import trio
from trio import socket from trio import socket
class FDSharingError(Exception):
...
@acm @acm
async def send_fds(fds: list[int], sock_path: str): async def send_fds(fds: list[int], sock_path: str) -> AsyncContextManager[None]:
'''
Async trio reimplementation of `multiprocessing.reduction.sendfds`
https://github.com/python/cpython/blob/275056a7fdcbe36aaac494b4183ae59943a338eb/Lib/multiprocessing/reduction.py#L142
It's implemented using an async context manager in order to simplyfy usage
with `tractor.context`s, we can open a context in a remote actor that uses
this acm inside of it, and uses `ctx.started()` to signal the original
caller actor to perform the `recv_fds` call.
See `tractor.ipc._ringbuf._ringd._attach_to_ring` for an example.
'''
sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
await sock.bind(sock_path) await sock.bind(sock_path)
sock.listen(1) sock.listen(1)
yield
fds = array.array('i', fds) yield # socket is setup, ready for receiver connect
# first byte of msg will be len of fds to send % 256
msg = bytes([len(fds) % 256]) # wait until receiver connects
conn, _ = await sock.accept() conn, _ = await sock.accept()
# setup int array for fds
fds = array.array('i', fds)
# first byte of msg will be len of fds to send % 256, acting as a fd amount
# verification on `recv_fds` we refer to it as `check_byte`
msg = bytes([len(fds) % 256])
# send msg with custom SCM_RIGHTS type
await conn.sendmsg( await conn.sendmsg(
[msg], [msg],
[(socket.SOL_SOCKET, socket.SCM_RIGHTS, fds)] [(socket.SOL_SOCKET, socket.SCM_RIGHTS, fds)]
) )
# wait ack
# finally wait receiver ack
if await conn.recv(1) != b'A': if await conn.recv(1) != b'A':
raise RuntimeError('did not receive acknowledgement of fd') raise FDSharingError('did not receive acknowledgement of fd')
conn.close() conn.close()
sock.close() sock.close()
async def recv_fds(sock_path: str, amount: int) -> tuple: async def recv_fds(sock_path: str, amount: int) -> tuple:
'''
Async trio reimplementation of `multiprocessing.reduction.recvfds`
https://github.com/python/cpython/blob/275056a7fdcbe36aaac494b4183ae59943a338eb/Lib/multiprocessing/reduction.py#L150
It's equivalent to std just using `trio.open_unix_socket` for connecting and
changes on error handling.
See `tractor.ipc._ringbuf._ringd._attach_to_ring` for an example.
'''
stream = await trio.open_unix_socket(sock_path) stream = await trio.open_unix_socket(sock_path)
sock = stream.socket sock = stream.socket
# prepare int array for fds
a = array.array('i') a = array.array('i')
bytes_size = a.itemsize * amount bytes_size = a.itemsize * amount
# receive 1 byte + space necesary for SCM_RIGHTS msg for {amount} fds
msg, ancdata, flags, addr = await sock.recvmsg( msg, ancdata, flags, addr = await sock.recvmsg(
1, socket.CMSG_SPACE(bytes_size) 1, socket.CMSG_SPACE(bytes_size)
) )
# maybe failed to receive msg?
if not msg and not ancdata: if not msg and not ancdata:
raise EOFError raise FDSharingError(f'Expected to receive {amount} fds from {sock_path}, but got EOF')
try:
await sock.send(b'A') # Ack
if len(ancdata) != 1: # send ack, std comment mentions this ack pattern was to get around an
raise RuntimeError( # old macosx bug, but they are not sure if its necesary any more, in
f'received {len(ancdata)} items of ancdata' # any case its not a bad pattern to keep
) await sock.send(b'A') # Ack
cmsg_level, cmsg_type, cmsg_data = ancdata[0] # expect to receive only one `ancdata` item
# check proper msg type if len(ancdata) != 1:
if ( raise FDSharingError(
cmsg_level == socket.SOL_SOCKET f'Expected to receive exactly one \"ancdata\" but got {len(ancdata)}: {ancdata}'
and )
cmsg_type == socket.SCM_RIGHTS
):
# check proper data alignment
if len(cmsg_data) % a.itemsize != 0:
raise ValueError
# attempt to cast as int array # unpack SCM_RIGHTS msg
a.frombytes(cmsg_data) cmsg_level, cmsg_type, cmsg_data = ancdata[0]
# check first byte of message is amount % 256 # check proper msg type
if len(a) % 256 != msg[0]: if cmsg_level != socket.SOL_SOCKET:
raise AssertionError( raise FDSharingError(
'Len is {0:n} but msg[0] is {1!r}'.format( f'Expected CMSG level to be SOL_SOCKET({socket.SOL_SOCKET}) but got {cmsg_level}'
len(a), msg[0] )
)
)
return tuple(a) if cmsg_type != socket.SCM_RIGHTS:
raise FDSharingError(
f'Expected CMSG type to be SCM_RIGHTS({socket.SCM_RIGHTS}) but got {cmsg_type}'
)
except (ValueError, IndexError): # check proper data alignment
pass length = len(cmsg_data)
if length % a.itemsize != 0:
raise FDSharingError(
f'CMSG data alignment error: len of {length} is not divisible by int size {a.itemsize}'
)
raise RuntimeError('Invalid data received') # attempt to cast as int array
a.frombytes(cmsg_data)
# validate length check byte
valid_check_byte = amount % 256 # check byte acording to `recv_fds` caller
recvd_check_byte = msg[0] # actual received check byte
payload_check_byte = len(a) % 256 # check byte acording to received fd int array
if recvd_check_byte != payload_check_byte:
raise FDSharingError(
'Validation failed: received check byte '
f'({recvd_check_byte}) does not match fd int array len % 256 ({payload_check_byte})'
)
if valid_check_byte != recvd_check_byte:
raise FDSharingError(
'Validation failed: received check byte '
f'({recvd_check_byte}) does not match expected fd amount % 256 ({valid_check_byte})'
)
return tuple(a)