diff --git a/tractor/ipc/__init__.py b/tractor/ipc/__init__.py index 4f0cd2b4..c6ad040f 100644 --- a/tractor/ipc/__init__.py +++ b/tractor/ipc/__init__.py @@ -31,18 +31,6 @@ from ._chan import ( ) if platform.system() == 'Linux': - from ._linux import ( - EFD_SEMAPHORE as EFD_SEMAPHORE, - EFD_CLOEXEC as EFD_CLOEXEC, - EFD_NONBLOCK as EFD_NONBLOCK, - open_eventfd as open_eventfd, - write_eventfd as write_eventfd, - read_eventfd as read_eventfd, - close_eventfd as close_eventfd, - EFDReadCancelled as EFDReadCancelled, - EventFD as EventFD, - ) - from ._ringbuf import ( RBToken as RBToken, open_ringbuf as open_ringbuf, diff --git a/tractor/ipc/_ringbuf.py b/tractor/ipc/_ringbuf/__init__.py similarity index 84% rename from tractor/ipc/_ringbuf.py rename to tractor/ipc/_ringbuf/__init__.py index 10975b7a..f9b770a1 100644 --- a/tractor/ipc/_ringbuf.py +++ b/tractor/ipc/_ringbuf/__init__.py @@ -35,16 +35,16 @@ from msgspec import ( to_builtins ) -from ._linux import ( +from ...log import get_logger +from ..._exceptions import ( + InternalError +) +from .._mp_bs import disable_mantracker +from ...linux.eventfd import ( open_eventfd, EFDReadCancelled, EventFD ) -from ._mp_bs import disable_mantracker -from tractor.log import get_logger -from tractor._exceptions import ( - InternalError -) log = get_logger(__name__) @@ -183,6 +183,9 @@ class RingBuffSender(trio.abc.SendStream): def wrap_fd(self) -> int: return self._wrap_event.fd + async def _wait_wrap(self): + await self._wrap_event.read() + async def send_all(self, data: Buffer): async with self._send_lock: # while data is larger than the remaining buf @@ -193,7 +196,7 @@ class RingBuffSender(trio.abc.SendStream): self._shm.buf[self.ptr:] = data[:remaining] # signal write and wait for reader wrap around self._write_event.write(remaining) - await self._wrap_event.read() + await self._wait_wrap() # wrap around and trim already written bytes self._ptr = 0 @@ -209,14 +212,19 @@ class RingBuffSender(trio.abc.SendStream): raise NotImplementedError def open(self): - self._shm = SharedMemory( - name=self._token.shm_name, - size=self._token.buf_size, - create=False - ) - self._write_event.open() - self._wrap_event.open() - self._eof_event.open() + try: + self._shm = SharedMemory( + name=self._token.shm_name, + size=self._token.buf_size, + create=False + ) + self._write_event.open() + self._wrap_event.open() + self._eof_event.open() + + except Exception as e: + e.add_note(f'while opening sender for {self._token.as_msg()}') + raise e def close(self): self._eof_event.write( @@ -363,14 +371,19 @@ class RingBuffReceiver(trio.abc.ReceiveStream): return segment def open(self): - self._shm = SharedMemory( - name=self._token.shm_name, - size=self._token.buf_size, - create=False - ) - self._write_event.open() - self._wrap_event.open() - self._eof_event.open() + try: + self._shm = SharedMemory( + name=self._token.shm_name, + size=self._token.buf_size, + create=False + ) + self._write_event.open() + self._wrap_event.open() + self._eof_event.open() + + except Exception as e: + e.add_note(f'while opening receiver for {self._token.as_msg()}') + raise e def close(self): if self._cleanup: @@ -502,26 +515,52 @@ class RingBuffBytesSender(trio.abc.SendChannel[bytes]): self.batch_size = batch_size self._batch_msg_len = 0 self._batch: bytes = b'' + self._send_lock = trio.StrictFIFOLock() - async def flush(self) -> None: + @property + def pending_msgs(self) -> int: + return self._batch_msg_len + + @property + def must_flush(self) -> bool: + return self._batch_msg_len >= self.batch_size + + async def _flush( + self, + new_batch_size: int | None = None + ) -> None: await self._sender.send_all(self._batch) self._batch = b'' self._batch_msg_len = 0 + if new_batch_size: + self.batch_size = new_batch_size + + async def flush( + self, + new_batch_size: int | None = None + ) -> None: + async with self._send_lock: + await self._flush(new_batch_size=new_batch_size) async def send(self, value: bytes) -> None: - msg: bytes = struct.pack(" None: + await self.flush(new_batch_size=1) + await self.send(b'') async def aclose(self) -> None: - await self._sender.aclose() + async with self._send_lock: + await self._sender.aclose() class RingBuffBytesReceiver(trio.abc.ReceiveChannel[bytes]): @@ -615,9 +654,30 @@ class RingBuffChannel(trio.abc.Channel[bytes]): self._sender = sender self._receiver = receiver - async def send(self, value: bytes): + @property + def batch_size(self) -> int: + return self._sender.batch_size + + @batch_size.setter + def batch_size(self, value: int) -> None: + self._sender.batch_size = value + + @property + def pending_msgs(self) -> int: + return self._sender.pending_msgs + + async def flush( + self, + new_batch_size: int | None = None + ) -> None: + await self._sender.flush(new_batch_size=new_batch_size) + + async def send(self, value: bytes) -> None: await self._sender.send(value) + async def send_eof(self) -> None: + await self._sender.send_eof() + async def receive(self) -> bytes: return await self._receiver.receive() @@ -631,7 +691,8 @@ async def attach_to_ringbuf_channel( token_in: RBToken, token_out: RBToken, cleanup_in: bool = True, - cleanup_out: bool = True + cleanup_out: bool = True, + batch_size: int = 1 ) -> AsyncContextManager[RingBuffChannel]: ''' Attach to an already opened ringbuf pair and return @@ -645,7 +706,8 @@ async def attach_to_ringbuf_channel( ) as receiver, attach_to_ringbuf_schannel( token_out, - cleanup=cleanup_out + cleanup=cleanup_out, + batch_size=batch_size ) as sender, ): yield RingBuffChannel(sender, receiver) diff --git a/tractor/ipc/_ringbuf/_pubsub.py b/tractor/ipc/_ringbuf/_pubsub.py new file mode 100644 index 00000000..a41a83dd --- /dev/null +++ b/tractor/ipc/_ringbuf/_pubsub.py @@ -0,0 +1,219 @@ +import time +from abc import ( + ABC, + abstractmethod +) +from contextlib import asynccontextmanager as acm +from dataclasses import dataclass + +import trio +import tractor + +from tractor.ipc import ( + RingBuffBytesSender, + attach_to_ringbuf_schannel, + attach_to_ringbuf_rchannel +) + +import tractor.ipc._ringbuf._ringd as ringd + + +log = tractor.log.get_logger(__name__) + + +@dataclass +class ChannelInfo: + connect_time: float + name: str + channel: RingBuffBytesSender + cancel_scope: trio.CancelScope + + +class ChannelManager(ABC): + + def __init__( + self, + n: trio.Nursery, + ): + self._n = n + self._channels: list[ChannelInfo] = [] + + @abstractmethod + async def _channel_handler_task(self, name: str): + ... + + def find_channel(self, name: str) -> tuple[int, ChannelInfo] | None: + for entry in enumerate(self._channels): + i, info = entry + if info.name == name: + return entry + + return None + + def _maybe_destroy_channel(self, name: str): + maybe_entry = self.find_channel(name) + if maybe_entry: + i, info = maybe_entry + info.cancel_scope.cancel() + del self._channels[i] + + def add_channel(self, name: str): + self._n.start_soon( + self._channel_handler_task, + name + ) + + def remove_channel(self, name: str): + self._maybe_destroy_channel(name) + + def __len__(self) -> int: + return len(self._channels) + + async def aclose(self) -> None: + for chan in self._channels: + self._maybe_destroy_channel(chan.name) + + +class RingBuffPublisher(ChannelManager, trio.abc.SendChannel[bytes]): + + def __init__( + self, + n: trio.Nursery, + buf_size: int = 10 * 1024, + batch_size: int = 1 + ): + super().__init__(n) + self._connect_event = trio.Event() + self._next_turn: int = 0 + + self._batch_size: int = batch_size + + async def _channel_handler_task( + self, + name: str + ): + async with ( + ringd.open_ringbuf( + name=name, + must_exist=True, + ) as token, + attach_to_ringbuf_schannel(token) as schan + ): + with trio.CancelScope() as cancel_scope: + self._channels.append(ChannelInfo( + connect_time=time.time(), + name=name, + channel=schan, + cancel_scope=cancel_scope + )) + self._connect_event.set() + await trio.sleep_forever() + + self._maybe_destroy_channel(name) + + async def send(self, msg: bytes): + # wait at least one decoder connected + if len(self) == 0: + await self._connect_event.wait() + self._connect_event = trio.Event() + + if self._next_turn >= len(self): + self._next_turn = 0 + + turn = self._next_turn + self._next_turn += 1 + + output = self._channels[turn] + await output.channel.send(msg) + + @property + def batch_size(self) -> int: + return self._batch_size + + @batch_size.setter + def set_batch_size(self, value: int) -> None: + for output in self._channels: + output.channel.batch_size = value + + async def flush( + self, + new_batch_size: int | None = None + ): + for output in self._channels: + await output.channel.flush( + new_batch_size=new_batch_size + ) + + async def send_eof(self): + for output in self._channels: + await output.channel.send_eof() + + +@acm +async def open_ringbuf_publisher( + buf_size: int = 10 * 1024, + batch_size: int = 1 +): + async with ( + trio.open_nursery() as n, + RingBuffPublisher( + n, + buf_size=buf_size, + batch_size=batch_size + ) as outputs + ): + yield outputs + await outputs.aclose() + + + +class RingBuffSubscriber(ChannelManager, trio.abc.ReceiveChannel[bytes]): + def __init__( + self, + n: trio.Nursery, + ): + super().__init__(n) + self._send_chan, self._recv_chan = trio.open_memory_channel(0) + + async def _channel_handler_task( + self, + name: str + ): + async with ( + ringd.open_ringbuf( + name=name, + must_exist=True + ) as token, + + attach_to_ringbuf_rchannel(token) as rchan + ): + with trio.CancelScope() as cancel_scope: + self._channels.append(ChannelInfo( + connect_time=time.time(), + name=name, + channel=rchan, + cancel_scope=cancel_scope + )) + send_chan = self._send_chan.clone() + try: + async for msg in rchan: + await send_chan.send(msg) + + except tractor._exceptions.InternalError: + ... + + self._maybe_destroy_channel(name) + + async def receive(self) -> bytes: + return await self._recv_chan.receive() + + +@acm +async def open_ringbuf_subscriber(): + async with ( + trio.open_nursery() as n, + RingBuffSubscriber(n) as inputs + ): + yield inputs + await inputs.aclose() + diff --git a/tractor/ipc/_ringbuf/_ringd.py b/tractor/ipc/_ringbuf/_ringd.py new file mode 100644 index 00000000..fb255979 --- /dev/null +++ b/tractor/ipc/_ringbuf/_ringd.py @@ -0,0 +1,172 @@ +import os +import tempfile +from pathlib import Path +from contextlib import ( + asynccontextmanager as acm +) + +import trio +import tractor +from tractor.linux import send_fds, recv_fds + +from . import ( + RBToken, + open_ringbuf as ipc_open_ringbuf +) + + +log = tractor.log.get_logger(__name__) +# log = tractor.log.get_console_log(level='info') + + +_ringd_actor_name = 'ringd' +_root_key = _ringd_actor_name + f'-{os.getpid()}' +_rings: dict[str, RBToken] = {} + + +async def _attach_to_ring( + ring_name: str +) -> RBToken: + actor = tractor.current_actor() + + fd_amount = 3 + sock_path = str( + Path(tempfile.gettempdir()) + / + f'{os.getpid()}-pass-ring-fds-{ring_name}-to-{actor.name}.sock' + ) + + log.info(f'trying to attach to ring {ring_name}...') + + async with ( + tractor.find_actor(_ringd_actor_name) as ringd, + ringd.open_context( + _pass_fds, + name=ring_name, + sock_path=sock_path + ) as (ctx, token), + recv_fds(sock_path, fd_amount) as fds, + ): + log.info( + f'received fds: {fds}' + ) + + token = RBToken.from_msg(token) + + write, wrap, eof = fds + + return RBToken( + shm_name=token.shm_name, + write_eventfd=write, + wrap_eventfd=wrap, + eof_eventfd=eof, + buf_size=token.buf_size + ) + + +@tractor.context +async def _pass_fds( + ctx: tractor.Context, + name: str, + sock_path: str +): + global _rings + + token = _rings[name] + + async with send_fds(token.fds, sock_path): + log.info(f'connected to {sock_path} for fd passing') + await ctx.started(token) + + log.info(f'fds {token.fds} sent') + + return token + + +@tractor.context +async def _open_ringbuf( + ctx: tractor.Context, + name: str, + must_exist: bool = False, + buf_size: int = 10 * 1024 +): + global _root_key, _rings + + teardown = trio.Event() + async def _teardown_listener(task_status=trio.TASK_STATUS_IGNORED): + async with ctx.open_stream() as stream: + task_status.started() + await stream.receive() + teardown.set() + + log.info(f'maybe open ring {name}, must_exist = {must_exist}') + + token = _rings.get(name, None) + + async with trio.open_nursery() as n: + if token: + log.info(f'ring {name} exists') + await ctx.started() + await n.start(_teardown_listener) + await teardown.wait() + return + + if must_exist: + raise FileNotFoundError( + f'Tried to open_ringbuf but it doesn\'t exist: {name}' + ) + + with ipc_open_ringbuf( + _root_key + name, + buf_size=buf_size + ) as token: + _rings[name] = token + log.info(f'ring {name} created') + await ctx.started() + await n.start(_teardown_listener) + await teardown.wait() + del _rings[name] + + log.info(f'ring {name} destroyed') + + +@acm +async def open_ringd(**kwargs) -> tractor.Portal: + async with tractor.open_nursery(**kwargs) as an: + portal = await an.start_actor( + _ringd_actor_name, + enable_modules=[__name__] + ) + yield portal + await an.cancel() + + +@acm +async def wait_for_ringd() -> tractor.Portal: + async with tractor.wait_for_actor( + _ringd_actor_name + ) as portal: + yield portal + + +@acm +async def open_ringbuf( + name: str, + must_exist: bool = False, + buf_size: int = 10 * 1024 +) -> RBToken: + async with ( + wait_for_ringd() as ringd, + ringd.open_context( + _open_ringbuf, + name=name, + must_exist=must_exist, + buf_size=buf_size + ) as (rd_ctx, _), + rd_ctx.open_stream() as stream, + ): + token = await _attach_to_ring(name) + log.info(f'attached to {token}') + yield token + await stream.send(b'bye') + diff --git a/tractor/linux/__init__.py b/tractor/linux/__init__.py new file mode 100644 index 00000000..dce926c8 --- /dev/null +++ b/tractor/linux/__init__.py @@ -0,0 +1,4 @@ +from ._fdshare import ( + send_fds as send_fds, + recv_fds as recv_fds +) diff --git a/tractor/linux/_fdshare.py b/tractor/linux/_fdshare.py new file mode 100644 index 00000000..a1ddceec --- /dev/null +++ b/tractor/linux/_fdshare.py @@ -0,0 +1,81 @@ +''' +Re-Impl of multiprocessing.reduction.sendfds & recvfds, +using acms and trio +''' +import array +from contextlib import asynccontextmanager as acm + +import trio +from trio import socket + + +@acm +async def send_fds(fds: list[int], sock_path: str): + sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + await sock.bind(sock_path) + sock.listen(1) + fds = array.array('i', fds) + # first byte of msg will be len of fds to send % 256 + msg = bytes([len(fds) % 256]) + yield + conn, _ = await sock.accept() + await conn.sendmsg( + [msg], + [(socket.SOL_SOCKET, socket.SCM_RIGHTS, fds)] + ) + # wait ack + if await conn.recv(1) != b'A': + raise RuntimeError('did not receive acknowledgement of fd') + + conn.close() + sock.close() + + +@acm +async def recv_fds(sock_path: str, amount: int) -> tuple: + stream = await trio.open_unix_socket(sock_path) + sock = stream.socket + a = array.array('i') + bytes_size = a.itemsize * amount + msg, ancdata, flags, addr = await sock.recvmsg( + 1, socket.CMSG_SPACE(bytes_size) + ) + if not msg and not ancdata: + raise EOFError + try: + await sock.send(b'A') # Ack + + if len(ancdata) != 1: + raise RuntimeError( + f'received {len(ancdata)} items of ancdata' + ) + + cmsg_level, cmsg_type, cmsg_data = ancdata[0] + # check proper msg type + if ( + cmsg_level == socket.SOL_SOCKET + and + cmsg_type == socket.SCM_RIGHTS + ): + # check proper data alignment + if len(cmsg_data) % a.itemsize != 0: + raise ValueError + + # attempt to cast as int array + a.frombytes(cmsg_data) + + # check first byte of message is amount % 256 + if len(a) % 256 != msg[0]: + raise AssertionError( + 'Len is {0:n} but msg[0] is {1!r}'.format( + len(a), msg[0] + ) + ) + + yield tuple(a) + return + + except (ValueError, IndexError): + pass + + raise RuntimeError('Invalid data received') diff --git a/tractor/ipc/_linux.py b/tractor/linux/eventfd.py similarity index 100% rename from tractor/ipc/_linux.py rename to tractor/linux/eventfd.py