From 5afe0a02643a803e24a5e5ecaeeae4d87f62957c Mon Sep 17 00:00:00 2001
From: Guillermo Rodriguez <guillermo@telos.net>
Date: Thu, 13 Mar 2025 20:17:04 -0300
Subject: [PATCH] General improvements

EventFD class now expects the fd to already be initialized with open_eventfd
RingBuffSender and RingBuffReceiver fully manage SharedMemory and EventFD lifecycles, no additional context managers needed
Separate ring buffer tests into their own test bed
Add parametrization to the ring buffer test and add a cancellation test
Add docstrings
Add simple testing data gen module .samples
---
 tests/test_ringbuf.py       | 212 ++++++++++++++++++++++++++++++++++++
 tests/test_shm.py           |  80 --------------
 tractor/_shm.py             | 195 +++++++++++++++++++--------------
 tractor/_testing/samples.py |  35 ++++++
 4 files changed, 360 insertions(+), 162 deletions(-)
 create mode 100644 tests/test_ringbuf.py
 create mode 100644 tractor/_testing/samples.py

diff --git a/tests/test_ringbuf.py b/tests/test_ringbuf.py
new file mode 100644
index 00000000..b81ea5f9
--- /dev/null
+++ b/tests/test_ringbuf.py
@@ -0,0 +1,212 @@
+import time
+
+import trio
+import pytest
+import tractor
+from tractor._shm import (
+    EFD_NONBLOCK,
+    open_eventfd,
+    RingBuffSender,
+    RingBuffReceiver
+)
+from tractor._testing.samples import generate_sample_messages
+
+
+@tractor.context
+async def child_read_shm(
+    ctx: tractor.Context,
+    msg_amount: int,
+    shm_key: str,
+    write_eventfd: int,
+    wrap_eventfd: int,
+    buf_size: int,
+    total_bytes: int,
+    flags: int = 0,
+) -> None:
+    recvd_bytes = 0
+    await ctx.started()
+    start_ts = time.time()
+    async with RingBuffReceiver(
+        shm_key,
+        write_eventfd,
+        wrap_eventfd,
+        buf_size=buf_size,
+        flags=flags
+    ) as receiver:
+        while recvd_bytes < total_bytes:
+            msg = await receiver.receive_some()
+            recvd_bytes += len(msg)
+
+        # make sure we dont hold any memoryviews
+        # before the ctx manager aclose()
+        msg = None
+
+    end_ts = time.time()
+    elapsed = end_ts - start_ts
+    elapsed_ms = int(elapsed * 1000)
+
+    print(f'\n\telapsed ms: {elapsed_ms}')
+    print(f'\tmsg/sec: {int(msg_amount / elapsed):,}')
+    print(f'\tbytes/sec: {int(recvd_bytes / elapsed):,}')
+
+
+@tractor.context
+async def child_write_shm(
+    ctx: tractor.Context,
+    msg_amount: int,
+    rand_min: int,
+    rand_max: int,
+    shm_key: str,
+    write_eventfd: int,
+    wrap_eventfd: int,
+    buf_size: int,
+) -> None:
+    msgs, total_bytes = generate_sample_messages(
+        msg_amount,
+        rand_min=rand_min,
+        rand_max=rand_max,
+    )
+    await ctx.started(total_bytes)
+    async with RingBuffSender(
+        shm_key,
+        write_eventfd,
+        wrap_eventfd,
+        buf_size=buf_size
+    ) as sender:
+        for msg in msgs:
+            await sender.send_all(msg)
+
+
+@pytest.mark.parametrize(
+    'msg_amount,rand_min,rand_max,buf_size',
+    [
+        # simple case, fixed payloads, large buffer
+        (100_000, 0, 0, 10 * 1024),
+
+        # guaranteed wrap around on every write
+        (100, 10 * 1024, 20 * 1024, 10 * 1024),
+
+        # large payload size, but large buffer
+        (10_000, 256 * 1024, 512 * 1024, 10 * 1024 * 1024)
+    ],
+    ids=[
+        'fixed_payloads_large_buffer',
+        'wrap_around_every_write',
+        'large_payloads_large_buffer',
+    ]
+)
+def test_ring_buff(
+    msg_amount: int,
+    rand_min: int,
+    rand_max: int,
+    buf_size: int
+):
+    write_eventfd = open_eventfd()
+    wrap_eventfd = open_eventfd()
+
+    proc_kwargs = {
+        'pass_fds': (write_eventfd, wrap_eventfd)
+    }
+
+    shm_key = 'test_ring_buff'
+
+    common_kwargs = {
+        'msg_amount': msg_amount,
+        'shm_key': shm_key,
+        'write_eventfd': write_eventfd,
+        'wrap_eventfd': wrap_eventfd,
+        'buf_size': buf_size
+    }
+
+    async def main():
+        async with tractor.open_nursery() as an:
+            send_p = await an.start_actor(
+                'ring_sender',
+                enable_modules=[__name__],
+                proc_kwargs=proc_kwargs
+            )
+            recv_p = await an.start_actor(
+                'ring_receiver',
+                enable_modules=[__name__],
+                proc_kwargs=proc_kwargs
+            )
+            async with (
+                send_p.open_context(
+                    child_write_shm,
+                    rand_min=rand_min,
+                    rand_max=rand_max,
+                    **common_kwargs
+                ) as (sctx, total_bytes),
+                recv_p.open_context(
+                    child_read_shm,
+                    **common_kwargs,
+                    total_bytes=total_bytes,
+                ) as (sctx, _sent),
+            ):
+                await recv_p.result()
+
+            await send_p.cancel_actor()
+            await recv_p.cancel_actor()
+
+
+    trio.run(main)
+
+
+@tractor.context
+async def child_blocked_receiver(
+    ctx: tractor.Context,
+    shm_key: str,
+    write_eventfd: int,
+    wrap_eventfd: int,
+    flags: int = 0
+):
+    async with RingBuffReceiver(
+        shm_key,
+        write_eventfd,
+        wrap_eventfd,
+        flags=flags
+    ) as receiver:
+        await ctx.started()
+        await receiver.receive_some()
+
+
+def test_ring_reader_cancel():
+    flags = EFD_NONBLOCK
+    write_eventfd = open_eventfd(flags=flags)
+    wrap_eventfd = open_eventfd()
+
+    proc_kwargs = {
+        'pass_fds': (write_eventfd, wrap_eventfd)
+    }
+
+    shm_key = 'test_ring_cancel'
+
+    async def main():
+        async with (
+            tractor.open_nursery() as an,
+            RingBuffSender(
+                shm_key,
+                write_eventfd,
+                wrap_eventfd,
+            ) as _sender,
+        ):
+            recv_p = await an.start_actor(
+                'ring_blocked_receiver',
+                enable_modules=[__name__],
+                proc_kwargs=proc_kwargs
+            )
+            async with (
+                recv_p.open_context(
+                    child_blocked_receiver,
+                    write_eventfd=write_eventfd,
+                    wrap_eventfd=wrap_eventfd,
+                    shm_key=shm_key,
+                    flags=flags
+                ) as (sctx, _sent),
+            ):
+                await trio.sleep(1)
+                await an.cancel()
+
+
+    with pytest.raises(tractor._exceptions.ContextCancelled):
+        trio.run(main)
diff --git a/tests/test_shm.py b/tests/test_shm.py
index db0b1818..2b7a382f 100644
--- a/tests/test_shm.py
+++ b/tests/test_shm.py
@@ -2,10 +2,7 @@
 Shared mem primitives and APIs.
 
 """
-import time
 import uuid
-import string
-import random
 
 # import numpy
 import pytest
@@ -14,7 +11,6 @@ import tractor
 from tractor._shm import (
     open_shm_list,
     attach_shm_list,
-    EventFD, open_ringbuffer_sender, open_ringbuffer_receiver,
 )
 
 
@@ -169,79 +165,3 @@ def test_parent_writer_child_reader(
             await portal.cancel_actor()
 
     trio.run(main)
-
-
-def random_string(size=256):
-    return ''.join(random.choice(string.ascii_lowercase) for i in range(size))
-
-
-async def child_read_shm(
-    msg_amount: int,
-    key: str,
-    write_event_fd: int,
-    wrap_event_fd: int,
-    max_bytes: int,
-) -> None:
-    log = tractor.log.get_console_log(level='info')
-    recvd_msgs = 0
-    start_ts = time.time()
-    async with open_ringbuffer_receiver(
-        write_event_fd,
-        wrap_event_fd,
-        key,
-        max_bytes=max_bytes
-    ) as receiver:
-        while recvd_msgs < msg_amount:
-            msg = await receiver.receive_some()
-            msgs = bytes(msg).split(b'\n')
-            first = msgs[0]
-            last = msgs[-2]
-            log.info((receiver.ptr - len(msg), receiver.ptr, first[:10], last[:10]))
-            recvd_msgs += len(msgs)
-
-    end_ts = time.time()
-    elapsed = end_ts - start_ts
-    elapsed_ms = int(elapsed * 1000)
-    log.info(f'elapsed ms: {elapsed_ms}')
-    log.info(f'msg/sec: {int(msg_amount / elapsed):,}')
-    log.info(f'bytes/sec: {int(max_bytes / elapsed):,}')
-
-def test_ring_buff():
-    log = tractor.log.get_console_log(level='info')
-    msg_amount = 100_000
-    log.info(f'generating {msg_amount} messages...')
-    msgs = [
-        f'[{i:08}]: {random_string()}\n'.encode('utf-8')
-        for i in range(msg_amount)
-    ]
-    buf_size = sum((len(m) for m in msgs))
-    log.info(f'done! buffer size: {buf_size}')
-    async def main():
-        with (
-            EventFD(initval=0) as write_event,
-            EventFD(initval=0) as wrap_event,
-        ):
-            async with (
-                tractor.open_nursery() as an,
-                open_ringbuffer_sender(
-                    write_event.fd,
-                    wrap_event.fd,
-                    max_bytes=buf_size
-                ) as sender
-            ):
-                await an.run_in_actor(
-                    child_read_shm,
-                    msg_amount=msg_amount,
-                    key=sender.key,
-                    write_event_fd=write_event.fd,
-                    wrap_event_fd=wrap_event.fd,
-                    max_bytes=buf_size,
-                    proc_kwargs={
-                        'pass_fds': (write_event.fd, wrap_event.fd)
-                    }
-                )
-                for msg in msgs:
-                    await sender.send_all(msg)
-
-
-    trio.run(main)
diff --git a/tractor/_shm.py b/tractor/_shm.py
index 7c177bc5..5038e77a 100644
--- a/tractor/_shm.py
+++ b/tractor/_shm.py
@@ -837,8 +837,6 @@ def attach_shm_list(
 if platform.system() == 'Linux':
     import os
     import errno
-    import string
-    import random
     from contextlib import asynccontextmanager as acm
 
     import cffi
@@ -862,19 +860,21 @@ if platform.system() == 'Linux':
         '''
     )
 
+
     # Open the default dynamic library (essentially 'libc' in most cases)
     C = ffi.dlopen(None)
 
     # Constants from <sys/eventfd.h>, if needed.
-    EFD_SEMAPHORE = 1 << 0  # 0x1
-    EFD_CLOEXEC   = 1 << 1  # 0x2
-    EFD_NONBLOCK  = 1 << 2  # 0x4
+    EFD_SEMAPHORE = 1
+    EFD_CLOEXEC = 0o2000000
+    EFD_NONBLOCK = 0o4000
 
 
     def open_eventfd(initval: int = 0, flags: int = 0) -> int:
         '''
         Open an eventfd with the given initial value and flags.
         Returns the file descriptor on success, otherwise raises OSError.
+
         '''
         fd = C.eventfd(initval, flags)
         if fd < 0:
@@ -884,6 +884,7 @@ if platform.system() == 'Linux':
     def write_eventfd(fd: int, value: int) -> int:
         '''
         Write a 64-bit integer (uint64_t) to the eventfd's counter.
+
         '''
         # Create a uint64_t* in C, store `value`
         data_ptr = ffi.new('uint64_t *', value)
@@ -899,6 +900,7 @@ if platform.system() == 'Linux':
         '''
         Read a 64-bit integer (uint64_t) from the eventfd, returning the value.
         Reading resets the counter to 0 (unless using EFD_SEMAPHORE).
+
         '''
         # Allocate an 8-byte buffer in C for reading
         buf = ffi.new('char[]', 8)
@@ -914,6 +916,7 @@ if platform.system() == 'Linux':
     def close_eventfd(fd: int) -> int:
         '''
         Close the eventfd.
+
         '''
         ret = C.close(fd)
         if ret < 0:
@@ -921,17 +924,19 @@ if platform.system() == 'Linux':
 
 
     class EventFD:
+        '''
+        Use a previously opened eventfd(2), meant to be used in
+        sub-actors after root actor opens the eventfds then passes
+        them through pass_fds
+
+        '''
 
         def __init__(
             self,
-            initval: int = 0,
-            flags: int = 0,
-            fd: int | None = None,
-            omode: str = 'r'
+            fd: int,
+            omode: str
         ):
-            self._initval: int = initval
-            self._flags: int = flags
-            self._fd: int | None = fd
+            self._fd: int = fd
             self._omode: str = omode
             self._fobj = None
 
@@ -943,23 +948,15 @@ if platform.system() == 'Linux':
             return write_eventfd(self._fd, value)
 
         async def read(self) -> int:
+            #TODO: how to handle signals?
             return await trio.to_thread.run_sync(read_eventfd, self._fd)
 
         def open(self):
-            if not self._fd:
-                self._fd = open_eventfd(
-                    initval=self._initval, flags=self._flags)
-
-            else:
-                self._fobj = os.fdopen(self._fd, self._omode)
+            self._fobj = os.fdopen(self._fd, self._omode)
 
         def close(self):
             if self._fobj:
                 self._fobj.close()
-                return
-
-            if self._fd:
-                close_eventfd(self._fd)
 
         def __enter__(self):
             self.open()
@@ -970,18 +967,34 @@ if platform.system() == 'Linux':
 
 
     class RingBuffSender(trio.abc.SendStream):
+        '''
+        IPC Reliable Ring Buffer sender side implementation
+
+        `eventfd(2)` is used for wrap around sync, and also to signal
+        writes to the reader.
+
+        TODO: if blocked on wrap around event wait it will not respond
+        to signals, fix soon TM
+        '''
 
         def __init__(
             self,
-            shm: SharedMemory,
-            write_event: EventFD,
-            wrap_event: EventFD,
-            start_ptr: int = 0
+            shm_key: str,
+            write_eventfd: int,
+            wrap_eventfd: int,
+            start_ptr: int = 0,
+            buf_size: int = 10 * 1024,
+            clean_shm_on_exit: bool = True
         ):
-            self._shm: SharedMemory = shm
-            self._write_event = write_event
-            self._wrap_event = wrap_event
+            self._shm = SharedMemory(
+                name=shm_key,
+                size=buf_size,
+                create=True
+            )
+            self._write_event = EventFD(write_eventfd, 'w')
+            self._wrap_event = EventFD(wrap_eventfd, 'r')
             self._ptr = start_ptr
+            self.clean_shm_on_exit = clean_shm_on_exit
 
         @property
         def key(self) -> str:
@@ -1004,25 +1017,37 @@ if platform.system() == 'Linux':
             return self._wrap_event.fd
 
         async def send_all(self, data: bytes | bytearray | memoryview):
+            # while data is larger than the remaining buf
             target_ptr = self.ptr + len(data)
-            if target_ptr > self.size:
+            while target_ptr > self.size:
+                # write all bytes that fit
                 remaining = self.size - self.ptr
                 self._shm.buf[self.ptr:] = data[:remaining]
+                # signal write and wait for reader wrap around
                 self._write_event.write(remaining)
                 await self._wrap_event.read()
+
+                # wrap around and trim already written bytes
                 self._ptr = 0
                 data = data[remaining:]
                 target_ptr = self._ptr + len(data)
 
+            # remaining data fits on buffer
             self._shm.buf[self.ptr:target_ptr] = data
             self._write_event.write(len(data))
             self._ptr = target_ptr
 
         async def wait_send_all_might_not_block(self):
-            ...
+            raise NotImplementedError
 
         async def aclose(self):
-            ...
+            self._write_event.close()
+            self._wrap_event.close()
+            # stdlib `SharedMemory.unlink()` does not `close()`: always drop
+            # our own mapping first, then optionally destroy the segment.
+            self._shm.close()
+            if self.clean_shm_on_exit:
+                self._shm.unlink()
 
         async def __aenter__(self):
             self._write_event.open()
@@ -1034,18 +1059,37 @@ if platform.system() == 'Linux':
 
 
     class RingBuffReceiver(trio.abc.ReceiveStream):
+        '''
+        IPC Reliable Ring Buffer receiver side implementation
+
+        `eventfd(2)` is used for wrap around sync, and also to signal
+        writes to the reader.
+
+        Unless eventfd(2) object is opened with EFD_NONBLOCK flag,
+        calls to `receive_some` will block the signal handling,
+        on the main thread, for now solution is using polling,
+        working on a way to unblock GIL during read(2) to allow
+        signal processing on the main thread.
+        '''
 
         def __init__(
             self,
-            shm: SharedMemory,
-            write_event: EventFD,
-            wrap_event: EventFD,
-            start_ptr: int = 0
+            shm_key: str,
+            write_eventfd: int,
+            wrap_eventfd: int,
+            start_ptr: int = 0,
+            buf_size: int = 10 * 1024,
+            flags: int = 0
         ):
-            self._shm: SharedMemory = shm
-            self._write_event = write_event
-            self._wrap_event = wrap_event
+            self._shm = SharedMemory(
+                name=shm_key,
+                size=buf_size,
+                create=False
+            )
+            self._write_event = EventFD(write_eventfd, 'r')  # read here
+            self._wrap_event = EventFD(wrap_eventfd, 'w')  # written on wrap
             self._ptr = start_ptr
+            self._flags = flags
 
         @property
         def key(self) -> str:
@@ -1067,18 +1111,44 @@ if platform.system() == 'Linux':
         def wrap_fd(self) -> int:
             return self._wrap_event.fd
 
-        async def receive_some(self, max_bytes: int | None = None) -> bytes:
-            delta = await self._write_event.read()
+        async def receive_some(
+            self,
+            max_bytes: int | None = None,
+            nb_timeout: float = 0.1
+        ) -> memoryview:
+            # only if the eventfd was opened with EFD_NONBLOCK: poll
+            # until next write, this allows signal handling
+            if self._flags & EFD_NONBLOCK:
+                delta = None
+                while delta is None:
+                    try:
+                        delta = await self._write_event.read()
+
+                    except OSError as e:
+                        if e.errno in ('EAGAIN', errno.EAGAIN):
+                            continue
+
+                        raise e
+
+            else:
+                delta = await self._write_event.read()
+
+            # fetch next segment and advance ptr
             next_ptr = self._ptr + delta
-            segment = bytes(self._shm.buf[self._ptr:next_ptr])
+            segment = self._shm.buf[self._ptr:next_ptr]
             self._ptr = next_ptr
+
             if self.ptr == self.size:
+                # reached the end, signal wrap around
                 self._ptr = 0
                 self._wrap_event.write(1)
+
             return segment
 
         async def aclose(self):
-            ...
+            self._write_event.close()
+            self._wrap_event.close()
+            self._shm.close()
 
         async def __aenter__(self):
             self._write_event.open()
@@ -1087,42 +1157,3 @@ if platform.system() == 'Linux':
 
         async def __aexit__(self, exc_type, exc_value, traceback):
             await self.aclose()
-
-    @acm
-    async def open_ringbuffer_sender(
-        write_event_fd: int,
-        wrap_event_fd: int,
-        key: str | None = None,
-        max_bytes: int = 10 * 1024,
-        start_ptr: int = 0,
-    ) -> RingBuffSender:
-        if not key:
-            key: str = ''.join(random.choice(string.ascii_lowercase) for i in range(32))
-
-        shm = SharedMemory(
-            name=key,
-            size=max_bytes,
-            create=True
-        )
-        async with RingBuffSender(
-            shm, EventFD(fd=write_event_fd, omode='w'), EventFD(fd=wrap_event_fd), start_ptr=start_ptr
-        ) as s:
-            yield s
-
-    @acm
-    async def open_ringbuffer_receiver(
-        write_event_fd: int,
-        wrap_event_fd: int,
-        key: str,
-        max_bytes: int = 10 * 1024,
-        start_ptr: int = 0,
-    ) -> RingBuffSender:
-        shm = SharedMemory(
-            name=key,
-            size=max_bytes,
-            create=False
-        )
-        async with RingBuffReceiver(
-            shm, EventFD(fd=write_event_fd), EventFD(fd=wrap_event_fd, omode='w'), start_ptr=start_ptr
-        ) as r:
-            yield r
diff --git a/tractor/_testing/samples.py b/tractor/_testing/samples.py
new file mode 100644
index 00000000..a87a22c4
--- /dev/null
+++ b/tractor/_testing/samples.py
@@ -0,0 +1,35 @@
+import os
+import random
+
+
+def generate_sample_messages(
+    amount: int,
+    rand_min: int = 0,
+    rand_max: int = 0,
+    silent: bool = False
+) -> tuple[list[bytes], int]:
+
+    msgs = []
+    size = 0
+
+    if not silent:
+        print(f'\ngenerating {amount} messages...')
+
+    for i in range(amount):
+        msg = f'[{i:08}]'.encode('utf-8')
+
+        if rand_max > 0:
+            msg += os.urandom(
+                random.randint(rand_min, rand_max))
+
+        size += len(msg)
+
+        msgs.append(msg)
+
+        if not silent and i and i % 10_000 == 0:
+            print(f'{i} generated')
+
+    if not silent:
+        print(f'done, {size:,} bytes in total')
+
+    return msgs, size