Improve error handling in fdshare functions, add comments
parent
28b86cb880
commit
4b9d6b9276
|
@ -14,81 +14,144 @@
|
||||||
# You should have received a copy of the GNU Affero General Public License
|
# You should have received a copy of the GNU Affero General Public License
|
||||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
'''
|
'''
|
||||||
Re-Impl of multiprocessing.reduction.sendfds & recvfds,
|
Reimplementation of multiprocessing.reduction.sendfds & recvfds, using acms and trio.
|
||||||
using acms and trio
|
|
||||||
|
cpython impl:
|
||||||
|
https://github.com/python/cpython/blob/275056a7fdcbe36aaac494b4183ae59943a338eb/Lib/multiprocessing/reduction.py#L138
|
||||||
'''
|
'''
|
||||||
import array
|
import array
|
||||||
|
from typing import AsyncContextManager
|
||||||
from contextlib import asynccontextmanager as acm
|
from contextlib import asynccontextmanager as acm
|
||||||
|
|
||||||
import trio
|
import trio
|
||||||
from trio import socket
|
from trio import socket
|
||||||
|
|
||||||
|
|
||||||
|
class FDSharingError(Exception):
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
@acm
|
@acm
|
||||||
async def send_fds(fds: list[int], sock_path: str):
|
async def send_fds(fds: list[int], sock_path: str) -> AsyncContextManager[None]:
|
||||||
|
'''
|
||||||
|
Async trio reimplementation of `multiprocessing.reduction.sendfds`
|
||||||
|
|
||||||
|
https://github.com/python/cpython/blob/275056a7fdcbe36aaac494b4183ae59943a338eb/Lib/multiprocessing/reduction.py#L142
|
||||||
|
|
||||||
|
It's implemented using an async context manager in order to simplyfy usage
|
||||||
|
with `tractor.context`s, we can open a context in a remote actor that uses
|
||||||
|
this acm inside of it, and uses `ctx.started()` to signal the original
|
||||||
|
caller actor to perform the `recv_fds` call.
|
||||||
|
|
||||||
|
See `tractor.ipc._ringbuf._ringd._attach_to_ring` for an example.
|
||||||
|
'''
|
||||||
sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
|
sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
|
||||||
await sock.bind(sock_path)
|
await sock.bind(sock_path)
|
||||||
sock.listen(1)
|
sock.listen(1)
|
||||||
yield
|
|
||||||
fds = array.array('i', fds)
|
yield # socket is setup, ready for receiver connect
|
||||||
# first byte of msg will be len of fds to send % 256
|
|
||||||
msg = bytes([len(fds) % 256])
|
# wait until receiver connects
|
||||||
conn, _ = await sock.accept()
|
conn, _ = await sock.accept()
|
||||||
|
|
||||||
|
# setup int array for fds
|
||||||
|
fds = array.array('i', fds)
|
||||||
|
|
||||||
|
# first byte of msg will be len of fds to send % 256, acting as a fd amount
|
||||||
|
# verification on `recv_fds` we refer to it as `check_byte`
|
||||||
|
msg = bytes([len(fds) % 256])
|
||||||
|
|
||||||
|
# send msg with custom SCM_RIGHTS type
|
||||||
await conn.sendmsg(
|
await conn.sendmsg(
|
||||||
[msg],
|
[msg],
|
||||||
[(socket.SOL_SOCKET, socket.SCM_RIGHTS, fds)]
|
[(socket.SOL_SOCKET, socket.SCM_RIGHTS, fds)]
|
||||||
)
|
)
|
||||||
# wait ack
|
|
||||||
|
# finally wait receiver ack
|
||||||
if await conn.recv(1) != b'A':
|
if await conn.recv(1) != b'A':
|
||||||
raise RuntimeError('did not receive acknowledgement of fd')
|
raise FDSharingError('did not receive acknowledgement of fd')
|
||||||
|
|
||||||
conn.close()
|
conn.close()
|
||||||
sock.close()
|
sock.close()
|
||||||
|
|
||||||
|
|
||||||
async def recv_fds(sock_path: str, amount: int) -> tuple:
|
async def recv_fds(sock_path: str, amount: int) -> tuple:
|
||||||
|
'''
|
||||||
|
Async trio reimplementation of `multiprocessing.reduction.recvfds`
|
||||||
|
|
||||||
|
https://github.com/python/cpython/blob/275056a7fdcbe36aaac494b4183ae59943a338eb/Lib/multiprocessing/reduction.py#L150
|
||||||
|
|
||||||
|
It's equivalent to std just using `trio.open_unix_socket` for connecting and
|
||||||
|
changes on error handling.
|
||||||
|
|
||||||
|
See `tractor.ipc._ringbuf._ringd._attach_to_ring` for an example.
|
||||||
|
'''
|
||||||
stream = await trio.open_unix_socket(sock_path)
|
stream = await trio.open_unix_socket(sock_path)
|
||||||
sock = stream.socket
|
sock = stream.socket
|
||||||
|
|
||||||
|
# prepare int array for fds
|
||||||
a = array.array('i')
|
a = array.array('i')
|
||||||
bytes_size = a.itemsize * amount
|
bytes_size = a.itemsize * amount
|
||||||
|
|
||||||
|
# receive 1 byte + space necesary for SCM_RIGHTS msg for {amount} fds
|
||||||
msg, ancdata, flags, addr = await sock.recvmsg(
|
msg, ancdata, flags, addr = await sock.recvmsg(
|
||||||
1, socket.CMSG_SPACE(bytes_size)
|
1, socket.CMSG_SPACE(bytes_size)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# maybe failed to receive msg?
|
||||||
if not msg and not ancdata:
|
if not msg and not ancdata:
|
||||||
raise EOFError
|
raise FDSharingError(f'Expected to receive {amount} fds from {sock_path}, but got EOF')
|
||||||
try:
|
|
||||||
|
# send ack, std comment mentions this ack pattern was to get around an
|
||||||
|
# old macosx bug, but they are not sure if its necesary any more, in
|
||||||
|
# any case its not a bad pattern to keep
|
||||||
await sock.send(b'A') # Ack
|
await sock.send(b'A') # Ack
|
||||||
|
|
||||||
|
# expect to receive only one `ancdata` item
|
||||||
if len(ancdata) != 1:
|
if len(ancdata) != 1:
|
||||||
raise RuntimeError(
|
raise FDSharingError(
|
||||||
f'received {len(ancdata)} items of ancdata'
|
f'Expected to receive exactly one \"ancdata\" but got {len(ancdata)}: {ancdata}'
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# unpack SCM_RIGHTS msg
|
||||||
cmsg_level, cmsg_type, cmsg_data = ancdata[0]
|
cmsg_level, cmsg_type, cmsg_data = ancdata[0]
|
||||||
|
|
||||||
# check proper msg type
|
# check proper msg type
|
||||||
if (
|
if cmsg_level != socket.SOL_SOCKET:
|
||||||
cmsg_level == socket.SOL_SOCKET
|
raise FDSharingError(
|
||||||
and
|
f'Expected CMSG level to be SOL_SOCKET({socket.SOL_SOCKET}) but got {cmsg_level}'
|
||||||
cmsg_type == socket.SCM_RIGHTS
|
)
|
||||||
):
|
|
||||||
|
if cmsg_type != socket.SCM_RIGHTS:
|
||||||
|
raise FDSharingError(
|
||||||
|
f'Expected CMSG type to be SCM_RIGHTS({socket.SCM_RIGHTS}) but got {cmsg_type}'
|
||||||
|
)
|
||||||
|
|
||||||
# check proper data alignment
|
# check proper data alignment
|
||||||
if len(cmsg_data) % a.itemsize != 0:
|
length = len(cmsg_data)
|
||||||
raise ValueError
|
if length % a.itemsize != 0:
|
||||||
|
raise FDSharingError(
|
||||||
|
f'CMSG data alignment error: len of {length} is not divisible by int size {a.itemsize}'
|
||||||
|
)
|
||||||
|
|
||||||
# attempt to cast as int array
|
# attempt to cast as int array
|
||||||
a.frombytes(cmsg_data)
|
a.frombytes(cmsg_data)
|
||||||
|
|
||||||
# check first byte of message is amount % 256
|
# validate length check byte
|
||||||
if len(a) % 256 != msg[0]:
|
valid_check_byte = amount % 256 # check byte acording to `recv_fds` caller
|
||||||
raise AssertionError(
|
recvd_check_byte = msg[0] # actual received check byte
|
||||||
'Len is {0:n} but msg[0] is {1!r}'.format(
|
payload_check_byte = len(a) % 256 # check byte acording to received fd int array
|
||||||
len(a), msg[0]
|
|
||||||
|
if recvd_check_byte != payload_check_byte:
|
||||||
|
raise FDSharingError(
|
||||||
|
'Validation failed: received check byte '
|
||||||
|
f'({recvd_check_byte}) does not match fd int array len % 256 ({payload_check_byte})'
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if valid_check_byte != recvd_check_byte:
|
||||||
|
raise FDSharingError(
|
||||||
|
'Validation failed: received check byte '
|
||||||
|
f'({recvd_check_byte}) does not match expected fd amount % 256 ({valid_check_byte})'
|
||||||
)
|
)
|
||||||
|
|
||||||
return tuple(a)
|
return tuple(a)
|
||||||
|
|
||||||
except (ValueError, IndexError):
|
|
||||||
pass
|
|
||||||
|
|
||||||
raise RuntimeError('Invalid data received')
|
|
||||||
|
|
Loading…
Reference in New Issue