diff --git a/tests/test_ringbuf.py b/tests/test_ringbuf.py
new file mode 100644
index 00000000..48035898
--- /dev/null
+++ b/tests/test_ringbuf.py
@@ -0,0 +1,233 @@
+import time
+
+import trio
+import pytest
+import tractor
+from tractor._shm import (
+    EFD_NONBLOCK,
+    open_eventfd,
+    RingBuffSender,
+    RingBuffReceiver
+)
+from tractor._testing.samples import generate_sample_messages
+
+
+@tractor.context
+async def child_read_shm(
+    ctx: tractor.Context,
+    msg_amount: int,
+    shm_key: str,
+    write_eventfd: int,
+    wrap_eventfd: int,
+    buf_size: int,
+    total_bytes: int,
+    flags: int = 0,
+) -> None:
+    '''
+    Attach to the ring buffer and drain it until `total_bytes`
+    have been received, then print throughput stats.
+
+    '''
+    recvd_bytes = 0
+    await ctx.started()
+    start_ts = time.time()
+    async with RingBuffReceiver(
+        shm_key,
+        write_eventfd,
+        wrap_eventfd,
+        buf_size=buf_size,
+        flags=flags
+    ) as receiver:
+        while recvd_bytes < total_bytes:
+            msg = await receiver.receive_some()
+            recvd_bytes += len(msg)
+
+    end_ts = time.time()
+    elapsed = end_ts - start_ts
+    elapsed_ms = int(elapsed * 1000)
+
+    print(f'\n\telapsed ms: {elapsed_ms}')
+    print(f'\tmsg/sec: {int(msg_amount / elapsed):,}')
+    print(f'\tbytes/sec: {int(recvd_bytes / elapsed):,}')
+
+
+@tractor.context
+async def child_write_shm(
+    ctx: tractor.Context,
+    msg_amount: int,
+    rand_min: int,
+    rand_max: int,
+    shm_key: str,
+    write_eventfd: int,
+    wrap_eventfd: int,
+    buf_size: int,
+) -> None:
+    '''
+    Generate the sample messages, report their total size through
+    `ctx.started()` and then send every message over the ring buffer.
+
+    '''
+    msgs, total_bytes = generate_sample_messages(
+        msg_amount,
+        rand_min=rand_min,
+        rand_max=rand_max,
+    )
+    await ctx.started(total_bytes)
+    async with RingBuffSender(
+        shm_key,
+        write_eventfd,
+        wrap_eventfd,
+        buf_size=buf_size
+    ) as sender:
+        for msg in msgs:
+            await sender.send_all(msg)
+
+
+@pytest.mark.parametrize(
+    'msg_amount,rand_min,rand_max,buf_size',
+    [
+        # simple case, fixed payloads, large buffer
+        (100_000, 0, 0, 10 * 1024),
+
+        # guaranteed wrap around on every write
+        (100, 10 * 1024, 20 * 1024, 10 * 1024),
+
+        # large payload size, but large buffer
+        (10_000, 256 * 1024, 512 * 1024, 10 * 1024 * 1024)
+    ],
+    ids=[
+        'fixed_payloads_large_buffer',
+        'wrap_around_every_write',
+        'large_payloads_large_buffer',
+    ]
+)
+def test_ring_buff(
+    msg_amount: int,
+    rand_min: int,
+    rand_max: int,
+    buf_size: int
+):
+    '''
+    Stream generated messages from a sender actor to a receiver
+    actor through a shared-memory ring buffer.
+
+    '''
+    write_eventfd = open_eventfd()
+    wrap_eventfd = open_eventfd()
+
+    proc_kwargs = {
+        'pass_fds': (write_eventfd, wrap_eventfd)
+    }
+
+    shm_key = 'test_ring_buff'
+
+    common_kwargs = {
+        'msg_amount': msg_amount,
+        'shm_key': shm_key,
+        'write_eventfd': write_eventfd,
+        'wrap_eventfd': wrap_eventfd,
+        'buf_size': buf_size
+    }
+
+    async def main():
+        async with tractor.open_nursery() as an:
+            send_p = await an.start_actor(
+                'ring_sender',
+                enable_modules=[__name__],
+                proc_kwargs=proc_kwargs
+            )
+            recv_p = await an.start_actor(
+                'ring_receiver',
+                enable_modules=[__name__],
+                proc_kwargs=proc_kwargs
+            )
+            async with (
+                send_p.open_context(
+                    child_write_shm,
+                    rand_min=rand_min,
+                    rand_max=rand_max,
+                    **common_kwargs
+                ) as (sctx, total_bytes),
+                recv_p.open_context(
+                    child_read_shm,
+                    **common_kwargs,
+                    total_bytes=total_bytes,
+                ) as (rctx, _sent),
+            ):
+                await recv_p.result()
+
+            await send_p.cancel_actor()
+            await recv_p.cancel_actor()
+
+
+    trio.run(main)
+
+
+@tractor.context
+async def child_blocked_receiver(
+    ctx: tractor.Context,
+    shm_key: str,
+    write_eventfd: int,
+    wrap_eventfd: int,
+    flags: int = 0
+):
+    '''
+    Attach to the ring buffer, signal readiness and then wait on a
+    single `receive_some()` call.
+
+    '''
+    async with RingBuffReceiver(
+        shm_key,
+        write_eventfd,
+        wrap_eventfd,
+        flags=flags
+    ) as receiver:
+        await ctx.started()
+        await receiver.receive_some()
+
+
+def test_ring_reader_cancel():
+    '''
+    Cancel the actor nursery while a receiver actor is waiting in
+    `receive_some()` on a non-blocking eventfd.
+
+    '''
+    flags = EFD_NONBLOCK
+    write_eventfd = open_eventfd(flags=flags)
+    wrap_eventfd = open_eventfd()
+
+    proc_kwargs = {
+        'pass_fds': (write_eventfd, wrap_eventfd)
+    }
+
+    shm_key = 'test_ring_cancel'
+
+    async def main():
+        async with (
+            tractor.open_nursery() as an,
+            RingBuffSender(
+                shm_key,
+                write_eventfd,
+                wrap_eventfd,
+            ) as _sender,
+        ):
+            recv_p = await an.start_actor(
+                'ring_blocked_receiver',
+                enable_modules=[__name__],
+                proc_kwargs=proc_kwargs
+            )
+            async with (
+                recv_p.open_context(
+                    child_blocked_receiver,
+                    write_eventfd=write_eventfd,
+                    wrap_eventfd=wrap_eventfd,
+                    shm_key=shm_key,
+                    flags=flags
+                ) as (sctx, _sent),
+            ):
+                await trio.sleep(1)
+                await an.cancel()
+
+
+    with pytest.raises(tractor._exceptions.ContextCancelled):
+        trio.run(main)
diff --git a/tests/test_shm.py b/tests/test_shm.py
index db0b1818..2b7a382f 100644
--- a/tests/test_shm.py
+++ b/tests/test_shm.py
@@ -2,10 +2,7 @@
 Shared mem primitives and APIs.
 
 """
-import time
 import uuid
-import string
-import random
 
 # import numpy
 import pytest
@@ -14,7 +11,4 @@ import tractor
 from tractor._shm import (
     open_shm_list,
     attach_shm_list,
-    EventFD,
-    open_ringbuffer_sender,
-    open_ringbuffer_receiver,
 )
@@ -169,79 +163,3 @@ def test_parent_writer_child_reader(
             await portal.cancel_actor()
 
     trio.run(main)
-
-
-def random_string(size=256):
-    return ''.join(random.choice(string.ascii_lowercase) for i in range(size))
-
-
-async def child_read_shm(
-    msg_amount: int,
-    key: str,
-    write_event_fd: int,
-    wrap_event_fd: int,
-    max_bytes: int,
-) -> None:
-    log = tractor.log.get_console_log(level='info')
-    recvd_msgs = 0
-    start_ts = time.time()
-    async with open_ringbuffer_receiver(
-        write_event_fd,
-        wrap_event_fd,
-        key,
-        max_bytes=max_bytes
-    ) as receiver:
-        while recvd_msgs < msg_amount:
-            msg = await receiver.receive_some()
-            msgs = bytes(msg).split(b'\n')
-            first = msgs[0]
-            last = msgs[-2]
-            log.info((receiver.ptr - len(msg), receiver.ptr, first[:10], last[:10]))
-            recvd_msgs += len(msgs)
-
-    end_ts = time.time()
-    elapsed = end_ts - start_ts
-    elapsed_ms = int(elapsed * 1000)
-    log.info(f'elapsed ms: {elapsed_ms}')
-    log.info(f'msg/sec: {int(msg_amount / elapsed):,}')
-    log.info(f'bytes/sec: {int(max_bytes / elapsed):,}')
-
-def test_ring_buff():
-    log = tractor.log.get_console_log(level='info')
-    msg_amount = 100_000
-    log.info(f'generating {msg_amount} messages...')
-    msgs = [
-        f'[{i:08}]: {random_string()}\n'.encode('utf-8')
-        for i in range(msg_amount)
-    ]
-    buf_size = sum((len(m) for m in msgs))
-    log.info(f'done! buffer size: {buf_size}')
-    async def main():
-        with (
-            EventFD(initval=0) as write_event,
-            EventFD(initval=0) as wrap_event,
-        ):
-            async with (
-                tractor.open_nursery() as an,
-                open_ringbuffer_sender(
-                    write_event.fd,
-                    wrap_event.fd,
-                    max_bytes=buf_size
-                ) as sender
-            ):
-                await an.run_in_actor(
-                    child_read_shm,
-                    msg_amount=msg_amount,
-                    key=sender.key,
-                    write_event_fd=write_event.fd,
-                    wrap_event_fd=wrap_event.fd,
-                    max_bytes=buf_size,
-                    proc_kwargs={
-                        'pass_fds': (write_event.fd, wrap_event.fd)
-                    }
-                )
-                for msg in msgs:
-                    await sender.send_all(msg)
-
-
-    trio.run(main)
diff --git a/tractor/_shm.py b/tractor/_shm.py
index 9c12e934..547cb2dd 100644
--- a/tractor/_shm.py
+++ b/tractor/_shm.py
@@ -837,8 +837,6 @@ def attach_shm_list(
 if platform.system() == 'Linux':
     import os
     import errno
-    import string
-    import random
     from contextlib import asynccontextmanager as acm
 
     import cffi
@@ -862,19 +860,21 @@ if platform.system() == 'Linux':
         '''
     )
 
+
     # Open the default dynamic library (essentially 'libc' in most cases)
     C = ffi.dlopen(None)
 
     # Constants from <sys/eventfd.h>, if needed.
-    EFD_SEMAPHORE = 1 << 0  # 0x1
-    EFD_CLOEXEC = 1 << 1  # 0x2
-    EFD_NONBLOCK = 1 << 2  # 0x4
+    EFD_SEMAPHORE = 1
+    EFD_CLOEXEC = 0o2000000
+    EFD_NONBLOCK = 0o4000
 
 
     def open_eventfd(initval: int = 0, flags: int = 0) -> int:
         '''
         Open an eventfd with the given initial value and flags.
         Returns the file descriptor on success, otherwise raises OSError.
+
         '''
         fd = C.eventfd(initval, flags)
         if fd < 0:
@@ -884,6 +884,7 @@ if platform.system() == 'Linux':
     def write_eventfd(fd: int, value: int) -> int:
         '''
         Write a 64-bit integer (uint64_t) to the eventfd's counter.
+
         '''
         # Create a uint64_t* in C, store `value`
         data_ptr = ffi.new('uint64_t *', value)
@@ -899,6 +900,7 @@ if platform.system() == 'Linux':
         '''
         Read a 64-bit integer (uint64_t) from the eventfd, returning the value.
         Reading resets the counter to 0 (unless using EFD_SEMAPHORE).
+
         '''
         # Allocate an 8-byte buffer in C for reading
         buf = ffi.new('char[]', 8)
@@ -914,6 +916,7 @@ if platform.system() == 'Linux':
     def close_eventfd(fd: int) -> int:
         '''
         Close the eventfd.
+
         '''
         ret = C.close(fd)
         if ret < 0:
@@ -921,17 +924,19 @@ if platform.system() == 'Linux':
 
 
     class EventFD:
+        '''
+        Use a previously opened eventfd(2), meant to be used in
+        sub-actors after root actor opens the eventfds then passes
+        them through pass_fds
+
+        '''
 
         def __init__(
             self,
-            initval: int = 0,
-            flags: int = 0,
-            fd: int | None = None,
-            omode: str = 'r'
+            fd: int,
+            omode: str
         ):
-            self._initval: int = initval
-            self._flags: int = flags
-            self._fd: int | None = fd
+            self._fd: int = fd
             self._omode: str = omode
             self._fobj = None
 
@@ -943,23 +948,15 @@ if platform.system() == 'Linux':
             return write_eventfd(self._fd, value)
 
         async def read(self) -> int:
+            #TODO: how to handle signals?
             return await trio.to_thread.run_sync(read_eventfd, self._fd)
 
         def open(self):
-            if not self._fd:
-                self._fd = open_eventfd(
-                    initval=self._initval, flags=self._flags)
-
-            else:
-                self._fobj = os.fdopen(self._fd, self._omode)
+            self._fobj = os.fdopen(self._fd, self._omode)
 
         def close(self):
             if self._fobj:
                 self._fobj.close()
-                return
-
-            if self._fd:
-                close_eventfd(self._fd)
 
         def __enter__(self):
             self.open()
@@ -970,18 +967,34 @@ if platform.system() == 'Linux':
 
 
     class RingBuffSender(trio.abc.SendStream):
+        '''
+        IPC Reliable Ring Buffer sender side implementation
+
+        `eventfd(2)` is used for wrap around sync, and also to signal
+        writes to the reader.
+
+        TODO: if blocked on wrap around event wait it will not respond
+        to signals, fix soon TM
+        '''
 
         def __init__(
             self,
-            shm: SharedMemory,
-            write_event: EventFD,
-            wrap_event: EventFD,
-            start_ptr: int = 0
+            shm_key: str,
+            write_eventfd: int,
+            wrap_eventfd: int,
+            start_ptr: int = 0,
+            buf_size: int = 10 * 1024,
+            clean_shm_on_exit: bool = True
         ):
-            self._shm: SharedMemory = shm
-            self._write_event = write_event
-            self._wrap_event = wrap_event
+            self._shm = SharedMemory(
+                name=shm_key,
+                size=buf_size,
+                create=True
+            )
+            self._write_event = EventFD(write_eventfd, 'w')
+            self._wrap_event = EventFD(wrap_eventfd, 'r')
             self._ptr = start_ptr
+            self.clean_shm_on_exit = clean_shm_on_exit
 
         @property
         def key(self) -> str:
@@ -1004,25 +1017,43 @@ if platform.system() == 'Linux':
             return self._wrap_event.fd
 
         async def send_all(self, data: bytes | bytearray | memoryview):
+            '''
+            Copy `data` into the shared buffer, signalling every write
+            to the reader and waiting for its wrap around signal each
+            time the end of the buffer is reached.
+
+            '''
+            # while data is larger than the remaining buf
             target_ptr = self.ptr + len(data)
-            if target_ptr > self.size:
+            while target_ptr > self.size:
+                # write all bytes that fit
                 remaining = self.size - self.ptr
                 self._shm.buf[self.ptr:] = data[:remaining]
+                # signal write and wait for reader wrap around
                 self._write_event.write(remaining)
                 await self._wrap_event.read()
+
+                # wrap around and trim already written bytes
                 self._ptr = 0
                 data = data[remaining:]
                 target_ptr = self._ptr + len(data)
 
+            # remaining data fits on buffer
             self._shm.buf[self.ptr:target_ptr] = data
             self._write_event.write(len(data))
             self._ptr = target_ptr
 
         async def wait_send_all_might_not_block(self):
-            ...
+            raise NotImplementedError
 
         async def aclose(self):
-            ...
+            self._write_event.close()
+            self._wrap_event.close()
+            # always drop this process' mapping: `unlink()` only removes
+            # the segment name and leaves the local mmap/fd open.
+            self._shm.close()
+            if self.clean_shm_on_exit:
+                self._shm.unlink()
 
         async def __aenter__(self):
             self._write_event.open()
@@ -1034,18 +1065,39 @@ if platform.system() == 'Linux':
 
 
     class RingBuffReceiver(trio.abc.ReceiveStream):
+        '''
+        IPC Reliable Ring Buffer receiver side implementation
+
+        `eventfd(2)` is used for wrap around sync, and also to signal
+        writes to the reader.
+
+        Unless eventfd(2) object is opened with EFD_NONBLOCK flag,
+        calls to `receive_some` will block the signal handling,
+        on the main thread, for now solution is using polling,
+        working on a way to unblock GIL during read(2) to allow
+        signal processing on the main thread.
+        '''
 
         def __init__(
             self,
-            shm: SharedMemory,
-            write_event: EventFD,
-            wrap_event: EventFD,
-            start_ptr: int = 0
+            shm_key: str,
+            write_eventfd: int,
+            wrap_eventfd: int,
+            start_ptr: int = 0,
+            buf_size: int = 10 * 1024,
+            flags: int = 0
         ):
-            self._shm: SharedMemory = shm
-            self._write_event = write_event
-            self._wrap_event = wrap_event
+            self._shm = SharedMemory(
+                name=shm_key,
+                size=buf_size,
+                create=False
+            )
+            # the receiver waits on write notifications and emits wrap
+            # notifications, i.e. the mirror image of the sender
+            self._write_event = EventFD(write_eventfd, 'r')
+            self._wrap_event = EventFD(wrap_eventfd, 'w')
             self._ptr = start_ptr
+            self._flags = flags
 
         @property
         def key(self) -> str:
@@ -1067,18 +1119,60 @@ if platform.system() == 'Linux':
         def wrap_fd(self) -> int:
             return self._wrap_event.fd
 
-        async def receive_some(self, max_bytes: int | None = None) -> bytes:
-            delta = await self._write_event.read()
+        async def receive_some(
+            self,
+            max_bytes: int | None = None,
+            nb_timeout: float = 0.1
+        ) -> memoryview:
+            '''
+            Wait for the next write notification and return a view of
+            the newly written segment of the shared buffer.
+
+            `max_bytes` is currently ignored. When the receiver was
+            created with `EFD_NONBLOCK` in `flags` the eventfd is
+            polled every `nb_timeout` seconds instead of blocked on.
+
+            NOTE(review): the returned `memoryview` aliases the shared
+            buffer (no copy), so the sender may overwrite it after a
+            wrap around; copy it if it must outlive the next call.
+
+            '''
+            # if non blocking eventfd enabled, do polling
+            # until next write, this allows signal handling
+            if self._flags & EFD_NONBLOCK:
+                delta = None
+                while delta is None:
+                    try:
+                        delta = await self._write_event.read()
+
+                    except OSError as e:
+                        # NOTE(review): assumes `read_eventfd()` reports
+                        # the real errno of the failed read(2); confirm.
+                        if e.errno != errno.EAGAIN:
+                            raise
+
+                        # nothing written yet, back off before retrying
+                        await trio.sleep(nb_timeout)
+
+            else:
+                delta = await self._write_event.read()
+
+            # fetch next segment and advance ptr
             next_ptr = self._ptr + delta
-            segment = bytes(self._shm.buf[self._ptr:next_ptr])
+            segment = self._shm.buf[self._ptr:next_ptr]
             self._ptr = next_ptr
+
             if self.ptr == self.size:
+                # reached the end, signal wrap around
                 self._ptr = 0
                 self._wrap_event.write(1)
+
             return segment
 
         async def aclose(self):
-            ...
+            self._write_event.close()
+            self._wrap_event.close()
+            self._shm.close()
 
         async def __aenter__(self):
             self._write_event.open()
@@ -1087,42 +1181,3 @@ if platform.system() == 'Linux':
 
         async def __aexit__(self, exc_type, exc_value, traceback):
             await self.aclose()
-
-    @acm
-    async def open_ringbuffer_sender(
-        write_event_fd: int,
-        wrap_event_fd: int,
-        key: str | None = None,
-        max_bytes: int = 10 * 1024,
-        start_ptr: int = 0,
-    ) -> RingBuffSender:
-        if not key:
-            key: str = ''.join(random.choice(string.ascii_lowercase) for i in range(32))
-
-        shm = SharedMemory(
-            name=key,
-            size=max_bytes,
-            create=True
-        )
-        async with RingBuffSender(
-            shm, EventFD(fd=write_event_fd, omode='w'), EventFD(fd=wrap_event_fd), start_ptr=start_ptr
-        ) as s:
-            yield s
-
-    @acm
-    async def open_ringbuffer_receiver(
-        write_event_fd: int,
-        wrap_event_fd: int,
-        key: str,
-        max_bytes: int = 10 * 1024,
-        start_ptr: int = 0,
-    ) -> RingBuffSender:
-        shm = SharedMemory(
-            name=key,
-            size=max_bytes,
-            create=False
-        )
-        async with RingBuffReceiver(
-            shm, EventFD(fd=write_event_fd), EventFD(fd=wrap_event_fd, omode='w'), start_ptr=start_ptr
-        ) as r:
-            yield r
diff --git a/tractor/_testing/samples.py b/tractor/_testing/samples.py
new file mode 100644
index 00000000..a87a22c4
--- /dev/null
+++ b/tractor/_testing/samples.py
@@ -0,0 +1,42 @@
+import os
+import random
+
+
+def generate_sample_messages(
+    amount: int,
+    rand_min: int = 0,
+    rand_max: int = 0,
+    silent: bool = False
+) -> tuple[list[bytes], int]:
+    '''
+    Generate `amount` byte messages, each a zero-padded index in
+    brackets followed (when `rand_max > 0`) by between `rand_min`
+    and `rand_max` random bytes.
+
+    Returns the list of messages and their total size in bytes.
+
+    '''
+    msgs = []
+    size = 0
+
+    if not silent:
+        print(f'\ngenerating {amount} messages...')
+
+    for i in range(amount):
+        msg = f'[{i:08}]'.encode('utf-8')
+
+        if rand_max > 0:
+            msg += os.urandom(
+                random.randint(rand_min, rand_max))
+
+        size += len(msg)
+
+        msgs.append(msg)
+
+        if not silent and i and i % 10_000 == 0:
+            print(f'{i} generated')
+
+    if not silent:
+        print(f'done, {size:,} bytes in total')
+
+    return msgs, size