diff --git a/tractor/linux/_fdshare.py b/tractor/linux/_fdshare.py index e3817c34..c632f532 100644 --- a/tractor/linux/_fdshare.py +++ b/tractor/linux/_fdshare.py @@ -14,81 +14,144 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . ''' -Re-Impl of multiprocessing.reduction.sendfds & recvfds, -using acms and trio +Reimplementation of multiprocessing.reduction.sendfds & recvfds, using acms and trio. + +cpython impl: +https://github.com/python/cpython/blob/275056a7fdcbe36aaac494b4183ae59943a338eb/Lib/multiprocessing/reduction.py#L138 ''' import array +from typing import AsyncContextManager from contextlib import asynccontextmanager as acm import trio from trio import socket +class FDSharingError(Exception): + ... + + @acm -async def send_fds(fds: list[int], sock_path: str): +async def send_fds(fds: list[int], sock_path: str) -> AsyncContextManager[None]: + ''' + Async trio reimplementation of `multiprocessing.reduction.sendfds` + + https://github.com/python/cpython/blob/275056a7fdcbe36aaac494b4183ae59943a338eb/Lib/multiprocessing/reduction.py#L142 + + It's implemented using an async context manager in order to simplyfy usage + with `tractor.context`s, we can open a context in a remote actor that uses + this acm inside of it, and uses `ctx.started()` to signal the original + caller actor to perform the `recv_fds` call. + + See `tractor.ipc._ringbuf._ringd._attach_to_ring` for an example. + ''' sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) await sock.bind(sock_path) sock.listen(1) - yield - fds = array.array('i', fds) - # first byte of msg will be len of fds to send % 256 - msg = bytes([len(fds) % 256]) + + yield # socket is setup, ready for receiver connect + + # wait until receiver connects conn, _ = await sock.accept() + + # setup int array for fds + fds = array.array('i', fds) + + # first byte of msg will be len of fds to send % 256, acting as a fd amount + # verification on `recv_fds` we refer to it as `check_byte` + msg = bytes([len(fds) % 256]) + + # send msg with custom SCM_RIGHTS type await conn.sendmsg( [msg], [(socket.SOL_SOCKET, socket.SCM_RIGHTS, fds)] ) - # wait ack + + # finally wait receiver ack if await conn.recv(1) != b'A': - raise RuntimeError('did not receive acknowledgement of fd') + raise FDSharingError('did not receive acknowledgement of fd') conn.close() sock.close() async def recv_fds(sock_path: str, amount: int) -> tuple: + ''' + Async trio reimplementation of `multiprocessing.reduction.recvfds` + + https://github.com/python/cpython/blob/275056a7fdcbe36aaac494b4183ae59943a338eb/Lib/multiprocessing/reduction.py#L150 + + It's equivalent to std just using `trio.open_unix_socket` for connecting and + changes on error handling. + + See `tractor.ipc._ringbuf._ringd._attach_to_ring` for an example. + ''' stream = await trio.open_unix_socket(sock_path) sock = stream.socket + + # prepare int array for fds a = array.array('i') bytes_size = a.itemsize * amount + + # receive 1 byte + space necesary for SCM_RIGHTS msg for {amount} fds msg, ancdata, flags, addr = await sock.recvmsg( 1, socket.CMSG_SPACE(bytes_size) ) + + # maybe failed to receive msg? if not msg and not ancdata: - raise EOFError - try: - await sock.send(b'A') # Ack + raise FDSharingError(f'Expected to receive {amount} fds from {sock_path}, but got EOF') - if len(ancdata) != 1: - raise RuntimeError( - f'received {len(ancdata)} items of ancdata' - ) + # send ack, std comment mentions this ack pattern was to get around an + # old macosx bug, but they are not sure if its necesary any more, in + # any case its not a bad pattern to keep + await sock.send(b'A') # Ack - cmsg_level, cmsg_type, cmsg_data = ancdata[0] - # check proper msg type - if ( - cmsg_level == socket.SOL_SOCKET - and - cmsg_type == socket.SCM_RIGHTS - ): - # check proper data alignment - if len(cmsg_data) % a.itemsize != 0: - raise ValueError + # expect to receive only one `ancdata` item + if len(ancdata) != 1: + raise FDSharingError( + f'Expected to receive exactly one \"ancdata\" but got {len(ancdata)}: {ancdata}' + ) - # attempt to cast as int array - a.frombytes(cmsg_data) + # unpack SCM_RIGHTS msg + cmsg_level, cmsg_type, cmsg_data = ancdata[0] - # check first byte of message is amount % 256 - if len(a) % 256 != msg[0]: - raise AssertionError( - 'Len is {0:n} but msg[0] is {1!r}'.format( - len(a), msg[0] - ) - ) + # check proper msg type + if cmsg_level != socket.SOL_SOCKET: + raise FDSharingError( + f'Expected CMSG level to be SOL_SOCKET({socket.SOL_SOCKET}) but got {cmsg_level}' + ) - return tuple(a) + if cmsg_type != socket.SCM_RIGHTS: + raise FDSharingError( + f'Expected CMSG type to be SCM_RIGHTS({socket.SCM_RIGHTS}) but got {cmsg_type}' + ) - except (ValueError, IndexError): - pass + # check proper data alignment + length = len(cmsg_data) + if length % a.itemsize != 0: + raise FDSharingError( + f'CMSG data alignment error: len of {length} is not divisible by int size {a.itemsize}' + ) - raise RuntimeError('Invalid data received') + # attempt to cast as int array + a.frombytes(cmsg_data) + + # validate length check byte + valid_check_byte = amount % 256 # check byte acording to `recv_fds` caller + recvd_check_byte = msg[0] # actual received check byte + payload_check_byte = len(a) % 256 # check byte acording to received fd int array + + if recvd_check_byte != payload_check_byte: + raise FDSharingError( + 'Validation failed: received check byte ' + f'({recvd_check_byte}) does not match fd int array len % 256 ({payload_check_byte})' + ) + + if valid_check_byte != recvd_check_byte: + raise FDSharingError( + 'Validation failed: received check byte ' + f'({recvd_check_byte}) does not match expected fd amount % 256 ({valid_check_byte})' + ) + + return tuple(a)