Improve error handling in fdshare functions, add comments
parent
28b86cb880
commit
4b9d6b9276
|
@ -14,81 +14,144 @@
|
|||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
'''
|
||||
Re-Impl of multiprocessing.reduction.sendfds & recvfds,
|
||||
using acms and trio
|
||||
Reimplementation of multiprocessing.reduction.sendfds & recvfds, using acms and trio.
|
||||
|
||||
cpython impl:
|
||||
https://github.com/python/cpython/blob/275056a7fdcbe36aaac494b4183ae59943a338eb/Lib/multiprocessing/reduction.py#L138
|
||||
'''
|
||||
import array
|
||||
from typing import AsyncContextManager
|
||||
from contextlib import asynccontextmanager as acm
|
||||
|
||||
import trio
|
||||
from trio import socket
|
||||
|
||||
|
||||
class FDSharingError(Exception):
|
||||
...
|
||||
|
||||
|
||||
@acm
|
||||
async def send_fds(fds: list[int], sock_path: str):
|
||||
async def send_fds(fds: list[int], sock_path: str) -> AsyncContextManager[None]:
|
||||
'''
|
||||
Async trio reimplementation of `multiprocessing.reduction.sendfds`
|
||||
|
||||
https://github.com/python/cpython/blob/275056a7fdcbe36aaac494b4183ae59943a338eb/Lib/multiprocessing/reduction.py#L142
|
||||
|
||||
It's implemented using an async context manager in order to simplyfy usage
|
||||
with `tractor.context`s, we can open a context in a remote actor that uses
|
||||
this acm inside of it, and uses `ctx.started()` to signal the original
|
||||
caller actor to perform the `recv_fds` call.
|
||||
|
||||
See `tractor.ipc._ringbuf._ringd._attach_to_ring` for an example.
|
||||
'''
|
||||
sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
|
||||
await sock.bind(sock_path)
|
||||
sock.listen(1)
|
||||
yield
|
||||
fds = array.array('i', fds)
|
||||
# first byte of msg will be len of fds to send % 256
|
||||
msg = bytes([len(fds) % 256])
|
||||
|
||||
yield # socket is setup, ready for receiver connect
|
||||
|
||||
# wait until receiver connects
|
||||
conn, _ = await sock.accept()
|
||||
|
||||
# setup int array for fds
|
||||
fds = array.array('i', fds)
|
||||
|
||||
# first byte of msg will be len of fds to send % 256, acting as a fd amount
|
||||
# verification on `recv_fds` we refer to it as `check_byte`
|
||||
msg = bytes([len(fds) % 256])
|
||||
|
||||
# send msg with custom SCM_RIGHTS type
|
||||
await conn.sendmsg(
|
||||
[msg],
|
||||
[(socket.SOL_SOCKET, socket.SCM_RIGHTS, fds)]
|
||||
)
|
||||
# wait ack
|
||||
|
||||
# finally wait receiver ack
|
||||
if await conn.recv(1) != b'A':
|
||||
raise RuntimeError('did not receive acknowledgement of fd')
|
||||
raise FDSharingError('did not receive acknowledgement of fd')
|
||||
|
||||
conn.close()
|
||||
sock.close()
|
||||
|
||||
|
||||
async def recv_fds(sock_path: str, amount: int) -> tuple:
|
||||
'''
|
||||
Async trio reimplementation of `multiprocessing.reduction.recvfds`
|
||||
|
||||
https://github.com/python/cpython/blob/275056a7fdcbe36aaac494b4183ae59943a338eb/Lib/multiprocessing/reduction.py#L150
|
||||
|
||||
It's equivalent to std just using `trio.open_unix_socket` for connecting and
|
||||
changes on error handling.
|
||||
|
||||
See `tractor.ipc._ringbuf._ringd._attach_to_ring` for an example.
|
||||
'''
|
||||
stream = await trio.open_unix_socket(sock_path)
|
||||
sock = stream.socket
|
||||
|
||||
# prepare int array for fds
|
||||
a = array.array('i')
|
||||
bytes_size = a.itemsize * amount
|
||||
|
||||
# receive 1 byte + space necesary for SCM_RIGHTS msg for {amount} fds
|
||||
msg, ancdata, flags, addr = await sock.recvmsg(
|
||||
1, socket.CMSG_SPACE(bytes_size)
|
||||
)
|
||||
|
||||
# maybe failed to receive msg?
|
||||
if not msg and not ancdata:
|
||||
raise EOFError
|
||||
try:
|
||||
raise FDSharingError(f'Expected to receive {amount} fds from {sock_path}, but got EOF')
|
||||
|
||||
# send ack, std comment mentions this ack pattern was to get around an
|
||||
# old macosx bug, but they are not sure if its necesary any more, in
|
||||
# any case its not a bad pattern to keep
|
||||
await sock.send(b'A') # Ack
|
||||
|
||||
# expect to receive only one `ancdata` item
|
||||
if len(ancdata) != 1:
|
||||
raise RuntimeError(
|
||||
f'received {len(ancdata)} items of ancdata'
|
||||
raise FDSharingError(
|
||||
f'Expected to receive exactly one \"ancdata\" but got {len(ancdata)}: {ancdata}'
|
||||
)
|
||||
|
||||
# unpack SCM_RIGHTS msg
|
||||
cmsg_level, cmsg_type, cmsg_data = ancdata[0]
|
||||
|
||||
# check proper msg type
|
||||
if (
|
||||
cmsg_level == socket.SOL_SOCKET
|
||||
and
|
||||
cmsg_type == socket.SCM_RIGHTS
|
||||
):
|
||||
if cmsg_level != socket.SOL_SOCKET:
|
||||
raise FDSharingError(
|
||||
f'Expected CMSG level to be SOL_SOCKET({socket.SOL_SOCKET}) but got {cmsg_level}'
|
||||
)
|
||||
|
||||
if cmsg_type != socket.SCM_RIGHTS:
|
||||
raise FDSharingError(
|
||||
f'Expected CMSG type to be SCM_RIGHTS({socket.SCM_RIGHTS}) but got {cmsg_type}'
|
||||
)
|
||||
|
||||
# check proper data alignment
|
||||
if len(cmsg_data) % a.itemsize != 0:
|
||||
raise ValueError
|
||||
length = len(cmsg_data)
|
||||
if length % a.itemsize != 0:
|
||||
raise FDSharingError(
|
||||
f'CMSG data alignment error: len of {length} is not divisible by int size {a.itemsize}'
|
||||
)
|
||||
|
||||
# attempt to cast as int array
|
||||
a.frombytes(cmsg_data)
|
||||
|
||||
# check first byte of message is amount % 256
|
||||
if len(a) % 256 != msg[0]:
|
||||
raise AssertionError(
|
||||
'Len is {0:n} but msg[0] is {1!r}'.format(
|
||||
len(a), msg[0]
|
||||
# validate length check byte
|
||||
valid_check_byte = amount % 256 # check byte acording to `recv_fds` caller
|
||||
recvd_check_byte = msg[0] # actual received check byte
|
||||
payload_check_byte = len(a) % 256 # check byte acording to received fd int array
|
||||
|
||||
if recvd_check_byte != payload_check_byte:
|
||||
raise FDSharingError(
|
||||
'Validation failed: received check byte '
|
||||
f'({recvd_check_byte}) does not match fd int array len % 256 ({payload_check_byte})'
|
||||
)
|
||||
|
||||
if valid_check_byte != recvd_check_byte:
|
||||
raise FDSharingError(
|
||||
'Validation failed: received check byte '
|
||||
f'({recvd_check_byte}) does not match expected fd amount % 256 ({valid_check_byte})'
|
||||
)
|
||||
|
||||
return tuple(a)
|
||||
|
||||
except (ValueError, IndexError):
|
||||
pass
|
||||
|
||||
raise RuntimeError('Invalid data received')
|
||||
|
|
Loading…
Reference in New Issue