From d6721f06df2cb9aca8095c2ca59e2406141d92e3 Mon Sep 17 00:00:00 2001 From: Guillermo Rodriguez Date: Sun, 16 Mar 2025 17:50:13 -0300 Subject: [PATCH] Better encapsulate RingBuff ctx managment methods and support non ipc usage Add trio.StrictFIFOLock on sender.send_all Support max_bytes argument on receive_some, keep track of write_ptr on receiver Add max_bytes receive test test_ringbuf_max_bytes Add docstrings to all ringbuf tests Remove EFD_NONBLOCK support, not necesary anymore since we can use abandon_on_cancel=True on trio.to_thread.run_sync Close eventfd's after usage on open_ringbuf --- tests/test_ringbuf.py | 54 ++++++++++++ tractor/ipc/_ringbuf.py | 180 ++++++++++++++++++++++------------------ 2 files changed, 153 insertions(+), 81 deletions(-) diff --git a/tests/test_ringbuf.py b/tests/test_ringbuf.py index 28af7b83..52cf0836 100644 --- a/tests/test_ringbuf.py +++ b/tests/test_ringbuf.py @@ -58,6 +58,8 @@ async def child_write_shm( for msg in msgs: await sender.send_all(msg) + print('writer exit') + @pytest.mark.parametrize( 'msg_amount,rand_min,rand_max,buf_size', @@ -83,6 +85,15 @@ def test_ringbuf( rand_max: int, buf_size: int ): + ''' + - Open a new ring buf on root actor + - Create a sender subactor and generate {msg_amount} messages + optionally with a random amount of bytes at the end of each, + return total_bytes on `ctx.started`, then send all messages + - Create a receiver subactor and receive until total_bytes are + read, print simple perf stats. + + ''' async def main(): with open_ringbuf( 'test_ringbuf', @@ -140,6 +151,11 @@ async def child_blocked_receiver( def test_ring_reader_cancel(): + ''' + Test that a receiver blocked on eventfd(2) read responds to + cancellation. + + ''' async def main(): with open_ringbuf('test_ring_cancel_reader') as token: async with ( @@ -178,6 +194,11 @@ async def child_blocked_sender( def test_ring_sender_cancel(): + ''' + Test that a sender blocked on eventfd(2) read responds to + cancellation. + + ''' async def main(): with open_ringbuf( 'test_ring_cancel_sender', @@ -203,3 +224,36 @@ def test_ring_sender_cancel(): with pytest.raises(tractor._exceptions.ContextCancelled): trio.run(main) + + +def test_ringbuf_max_bytes(): + ''' + Test that RingBuffReceiver.receive_some's max_bytes optional + argument works correctly, send a msg of size 100, then + force receive of messages with max_bytes == 1, wait until + 100 of these messages are received, then compare join of + msgs with original message + + ''' + msg = b''.join(str(i % 10).encode() for i in range(100)) + msgs = [] + + async def main(): + with open_ringbuf( + 'test_ringbuf_max_bytes', + buf_size=10 + ) as token: + async with ( + trio.open_nursery() as n, + RingBuffSender(token, is_ipc=False) as sender, + RingBuffReceiver(token, is_ipc=False) as receiver + ): + n.start_soon(sender.send_all, msg) + while len(msgs) < len(msg): + msg_part = await receiver.receive_some(max_bytes=1) + msg_part = bytes(msg_part) + assert len(msg_part) == 1 + msgs.append(msg_part) + + trio.run(main) + assert msg == b''.join(msgs) diff --git a/tractor/ipc/_ringbuf.py b/tractor/ipc/_ringbuf.py index 6337eea1..304454ed 100644 --- a/tractor/ipc/_ringbuf.py +++ b/tractor/ipc/_ringbuf.py @@ -28,11 +28,15 @@ from msgspec import ( ) from ._linux import ( - EFD_NONBLOCK, open_eventfd, + close_eventfd, EventFD ) from ._mp_bs import disable_mantracker +from tractor.log import get_logger + + +log = get_logger(__name__) disable_mantracker() @@ -64,8 +68,6 @@ class RBToken(Struct, frozen=True): def open_ringbuf( shm_name: str, buf_size: int = 10 * 1024, - write_efd_flags: int = 0, - wrap_efd_flags: int = 0 ) -> RBToken: shm = SharedMemory( name=shm_name, @@ -75,16 +77,21 @@ def open_ringbuf( try: token = RBToken( shm_name=shm_name, - write_eventfd=open_eventfd(flags=write_efd_flags), - wrap_eventfd=open_eventfd(flags=wrap_efd_flags), + write_eventfd=open_eventfd(), + wrap_eventfd=open_eventfd(), buf_size=buf_size ) yield token + close_eventfd(token.write_eventfd) + close_eventfd(token.wrap_eventfd) finally: shm.unlink() +Buffer = bytes | bytearray | memoryview + + class RingBuffSender(trio.abc.SendStream): ''' IPC Reliable Ring Buffer sender side implementation @@ -97,24 +104,26 @@ class RingBuffSender(trio.abc.SendStream): self, token: RBToken, start_ptr: int = 0, + is_ipc: bool = True ): - token = RBToken.from_msg(token) - self._shm = SharedMemory( - name=token.shm_name, - size=token.buf_size, - create=False - ) - self._write_event = EventFD(token.write_eventfd, 'w') - self._wrap_event = EventFD(token.wrap_eventfd, 'r') + self._token = RBToken.from_msg(token) + self._shm: SharedMemory | None = None + self._write_event = EventFD(self._token.write_eventfd, 'w') + self._wrap_event = EventFD(self._token.wrap_eventfd, 'r') self._ptr = start_ptr + self._is_ipc = is_ipc + self._send_lock = trio.StrictFIFOLock() + @property - def key(self) -> str: + def name(self) -> str: + if not self._shm: + raise ValueError('shared memory not initialized yet!') return self._shm.name @property def size(self) -> int: - return self._shm.size + return self._token.buf_size @property def ptr(self) -> int: @@ -128,38 +137,48 @@ class RingBuffSender(trio.abc.SendStream): def wrap_fd(self) -> int: return self._wrap_event.fd - async def send_all(self, data: bytes | bytearray | memoryview): - # while data is larger than the remaining buf - target_ptr = self.ptr + len(data) - while target_ptr > self.size: - # write all bytes that fit - remaining = self.size - self.ptr - self._shm.buf[self.ptr:] = data[:remaining] - # signal write and wait for reader wrap around - self._write_event.write(remaining) - await self._wrap_event.read() + async def send_all(self, data: Buffer): + async with self._send_lock: + # while data is larger than the remaining buf + target_ptr = self.ptr + len(data) + while target_ptr > self.size: + # write all bytes that fit + remaining = self.size - self.ptr + self._shm.buf[self.ptr:] = data[:remaining] + # signal write and wait for reader wrap around + self._write_event.write(remaining) + await self._wrap_event.read() - # wrap around and trim already written bytes - self._ptr = 0 - data = data[remaining:] - target_ptr = self._ptr + len(data) + # wrap around and trim already written bytes + self._ptr = 0 + data = data[remaining:] + target_ptr = self._ptr + len(data) - # remaining data fits on buffer - self._shm.buf[self.ptr:target_ptr] = data - self._write_event.write(len(data)) - self._ptr = target_ptr + # remaining data fits on buffer + self._shm.buf[self.ptr:target_ptr] = data + self._write_event.write(len(data)) + self._ptr = target_ptr async def wait_send_all_might_not_block(self): raise NotImplementedError - async def aclose(self): - self._write_event.close() - self._wrap_event.close() - self._shm.close() - - async def __aenter__(self): + def open(self): + self._shm = SharedMemory( + name=self._token.shm_name, + size=self._token.buf_size, + create=False + ) self._write_event.open() self._wrap_event.open() + + async def aclose(self): + if self._is_ipc: + self._write_event.close() + self._wrap_event.close() + self._shm.close() + + async def __aenter__(self): + self.open() return self @@ -175,26 +194,25 @@ class RingBuffReceiver(trio.abc.ReceiveStream): self, token: RBToken, start_ptr: int = 0, - flags: int = 0 + is_ipc: bool = True ): - token = RBToken.from_msg(token) - self._shm = SharedMemory( - name=token.shm_name, - size=token.buf_size, - create=False - ) - self._write_event = EventFD(token.write_eventfd, 'w') - self._wrap_event = EventFD(token.wrap_eventfd, 'r') + self._token = RBToken.from_msg(token) + self._shm: SharedMemory | None = None + self._write_event = EventFD(self._token.write_eventfd, 'w') + self._wrap_event = EventFD(self._token.wrap_eventfd, 'r') self._ptr = start_ptr - self._flags = flags + self._write_ptr = start_ptr + self._is_ipc = is_ipc @property - def key(self) -> str: + def name(self) -> str: + if not self._shm: + raise ValueError('shared memory not initialized yet!') return self._shm.name @property def size(self) -> int: - return self._shm.size + return self._token.buf_size @property def ptr(self) -> int: @@ -208,46 +226,46 @@ class RingBuffReceiver(trio.abc.ReceiveStream): def wrap_fd(self) -> int: return self._wrap_event.fd - async def receive_some( - self, - max_bytes: int | None = None, - nb_timeout: float = 0.1 - ) -> memoryview: - # if non blocking eventfd enabled, do polling - # until next write, this allows signal handling - if self._flags | EFD_NONBLOCK: - delta = None - while delta is None: - try: - delta = await self._write_event.read() - - except OSError as e: - if e.errno == 'EAGAIN': - continue - - raise e - - else: + async def receive_some(self, max_bytes: int | None = None) -> memoryview: + delta = self._write_ptr - self._ptr + if delta == 0: delta = await self._write_event.read() + self._write_ptr += delta + + if isinstance(max_bytes, int): + if max_bytes == 0: + raise ValueError('if set, max_bytes must be > 0') + delta = min(delta, max_bytes) + + target_ptr = self._ptr + delta # fetch next segment and advance ptr - next_ptr = self._ptr + delta - segment = self._shm.buf[self._ptr:next_ptr] - self._ptr = next_ptr + segment = self._shm.buf[self._ptr:target_ptr] + self._ptr = target_ptr - if self.ptr == self.size: + if self._ptr == self.size: # reached the end, signal wrap around self._ptr = 0 + self._write_ptr = 0 self._wrap_event.write(1) return segment - async def aclose(self): - self._write_event.close() - self._wrap_event.close() - self._shm.close() - - async def __aenter__(self): + def open(self): + self._shm = SharedMemory( + name=self._token.shm_name, + size=self._token.buf_size, + create=False + ) self._write_event.open() self._wrap_event.open() + + async def aclose(self): + if self._is_ipc: + self._write_event.close() + self._wrap_event.close() + self._shm.close() + + async def __aenter__(self): + self.open() return self