tractor/tractor/linux/_fdshare.py

95 lines
2.8 KiB
Python

# tractor: structured concurrent "actors".
# Copyright 2018-eternity Tyler Goodlet.
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
'''
Re-Impl of multiprocessing.reduction.sendfds & recvfds,
using acms and trio
'''
import array
from contextlib import asynccontextmanager as acm
import trio
from trio import socket
@acm
async def send_fds(fds: list[int], sock_path: str):
sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
await sock.bind(sock_path)
sock.listen(1)
yield
fds = array.array('i', fds)
# first byte of msg will be len of fds to send % 256
msg = bytes([len(fds) % 256])
conn, _ = await sock.accept()
await conn.sendmsg(
[msg],
[(socket.SOL_SOCKET, socket.SCM_RIGHTS, fds)]
)
# wait ack
if await conn.recv(1) != b'A':
raise RuntimeError('did not receive acknowledgement of fd')
conn.close()
sock.close()
async def recv_fds(sock_path: str, amount: int) -> tuple:
stream = await trio.open_unix_socket(sock_path)
sock = stream.socket
a = array.array('i')
bytes_size = a.itemsize * amount
msg, ancdata, flags, addr = await sock.recvmsg(
1, socket.CMSG_SPACE(bytes_size)
)
if not msg and not ancdata:
raise EOFError
try:
await sock.send(b'A') # Ack
if len(ancdata) != 1:
raise RuntimeError(
f'received {len(ancdata)} items of ancdata'
)
cmsg_level, cmsg_type, cmsg_data = ancdata[0]
# check proper msg type
if (
cmsg_level == socket.SOL_SOCKET
and
cmsg_type == socket.SCM_RIGHTS
):
# check proper data alignment
if len(cmsg_data) % a.itemsize != 0:
raise ValueError
# attempt to cast as int array
a.frombytes(cmsg_data)
# check first byte of message is amount % 256
if len(a) % 256 != msg[0]:
raise AssertionError(
'Len is {0:n} but msg[0] is {1!r}'.format(
len(a), msg[0]
)
)
return tuple(a)
except (ValueError, IndexError):
pass
raise RuntimeError('Invalid data received')