diff --git a/tests/test_ringbuf.py b/tests/test_ringbuf.py index 9e457b2a..64fb37e9 100644 --- a/tests/test_ringbuf.py +++ b/tests/test_ringbuf.py @@ -4,7 +4,8 @@ import trio import pytest import tractor from tractor.ipc import ( - open_eventfd, + open_ringbuf, + RBToken, RingBuffSender, RingBuffReceiver ) @@ -15,22 +16,16 @@ from tractor._testing.samples import generate_sample_messages async def child_read_shm( ctx: tractor.Context, msg_amount: int, - shm_key: str, - write_eventfd: int, - wrap_eventfd: int, + token: RBToken, buf_size: int, total_bytes: int, - flags: int = 0, ) -> None: recvd_bytes = 0 await ctx.started() start_ts = time.time() async with RingBuffReceiver( - shm_key, - write_eventfd, - wrap_eventfd, + token, buf_size=buf_size, - flags=flags ) as receiver: while recvd_bytes < total_bytes: msg = await receiver.receive_some() @@ -55,9 +50,7 @@ async def child_write_shm( msg_amount: int, rand_min: int, rand_max: int, - shm_key: str, - write_eventfd: int, - wrap_eventfd: int, + token: RBToken, buf_size: int, ) -> None: msgs, total_bytes = generate_sample_messages( @@ -67,9 +60,7 @@ async def child_write_shm( ) await ctx.started(total_bytes) async with RingBuffSender( - shm_key, - write_eventfd, - wrap_eventfd, + token, buf_size=buf_size ) as sender: for msg in msgs: @@ -100,52 +91,48 @@ def test_ringbuf( rand_max: int, buf_size: int ): - write_eventfd = open_eventfd() - wrap_eventfd = open_eventfd() - - proc_kwargs = { - 'pass_fds': (write_eventfd, wrap_eventfd) - } - - shm_key = 'test_ring_buff' - - common_kwargs = { - 'msg_amount': msg_amount, - 'shm_key': shm_key, - 'write_eventfd': write_eventfd, - 'wrap_eventfd': wrap_eventfd, - 'buf_size': buf_size - } - async def main(): - async with tractor.open_nursery() as an: - send_p = await an.start_actor( - 'ring_sender', - enable_modules=[__name__], - proc_kwargs=proc_kwargs - ) - recv_p = await an.start_actor( - 'ring_receiver', - enable_modules=[__name__], - proc_kwargs=proc_kwargs - ) - async with ( - send_p.open_context( - child_write_shm, - rand_min=rand_min, - rand_max=rand_max, - **common_kwargs - ) as (sctx, total_bytes), - recv_p.open_context( - child_read_shm, - **common_kwargs, - total_bytes=total_bytes, - ) as (sctx, _sent), - ): - await recv_p.result() + with open_ringbuf( + 'test_ringbuf', + buf_size=buf_size + ) as token: + proc_kwargs = { + 'pass_fds': (token.write_eventfd, token.wrap_eventfd) + } - await send_p.cancel_actor() - await recv_p.cancel_actor() + common_kwargs = { + 'msg_amount': msg_amount, + 'token': token, + 'buf_size': buf_size + } + async with tractor.open_nursery() as an: + send_p = await an.start_actor( + 'ring_sender', + enable_modules=[__name__], + proc_kwargs=proc_kwargs + ) + recv_p = await an.start_actor( + 'ring_receiver', + enable_modules=[__name__], + proc_kwargs=proc_kwargs + ) + async with ( + send_p.open_context( + child_write_shm, + rand_min=rand_min, + rand_max=rand_max, + **common_kwargs + ) as (sctx, total_bytes), + recv_p.open_context( + child_read_shm, + **common_kwargs, + total_bytes=total_bytes, + ) as (sctx, _sent), + ): + await recv_p.result() + + await send_p.cancel_actor() + await recv_p.cancel_actor() trio.run(main) @@ -154,55 +141,35 @@ def test_ringbuf( @tractor.context async def child_blocked_receiver( ctx: tractor.Context, - shm_key: str, - write_eventfd: int, - wrap_eventfd: int, - flags: int = 0 + token: RBToken ): - async with RingBuffReceiver( - shm_key, - write_eventfd, - wrap_eventfd, - flags=flags - ) as receiver: + async with RingBuffReceiver(token) as receiver: await ctx.started() await receiver.receive_some() def test_ring_reader_cancel(): - write_eventfd = open_eventfd() - wrap_eventfd = open_eventfd() - - proc_kwargs = { - 'pass_fds': (write_eventfd, wrap_eventfd) - } - - shm_key = 'test_ring_cancel' - async def main(): - async with ( - tractor.open_nursery() as an, - RingBuffSender( - shm_key, - write_eventfd, - wrap_eventfd, - ) as _sender, - ): - recv_p = await an.start_actor( - 'ring_blocked_receiver', - enable_modules=[__name__], - proc_kwargs=proc_kwargs - ) + with open_ringbuf('test_ring_cancel') as token: async with ( - recv_p.open_context( - child_blocked_receiver, - write_eventfd=write_eventfd, - wrap_eventfd=wrap_eventfd, - shm_key=shm_key, - ) as (sctx, _sent), + tractor.open_nursery() as an, + RingBuffSender(token) as _sender, ): - await trio.sleep(1) - await an.cancel() + recv_p = await an.start_actor( + 'ring_blocked_receiver', + enable_modules=[__name__], + proc_kwargs={ + 'pass_fds': (token.write_eventfd, token.wrap_eventfd) + } + ) + async with ( + recv_p.open_context( + child_blocked_receiver, + token=token + ) as (sctx, _sent), + ): + await trio.sleep(1) + await an.cancel() with pytest.raises(tractor._exceptions.ContextCancelled): diff --git a/tractor/ipc/__init__.py b/tractor/ipc/__init__.py index 59fc1e16..ec6217a1 100644 --- a/tractor/ipc/__init__.py +++ b/tractor/ipc/__init__.py @@ -36,6 +36,8 @@ if platform.system() == 'Linux': ) from ._ringbuf import ( + RBToken as RBToken, RingBuffSender as RingBuffSender, - RingBuffReceiver as RingBuffReceiver + RingBuffReceiver as RingBuffReceiver, + open_ringbuf ) diff --git a/tractor/ipc/_ringbuf.py b/tractor/ipc/_ringbuf.py index 0a4f3819..c590e8e2 100644 --- a/tractor/ipc/_ringbuf.py +++ b/tractor/ipc/_ringbuf.py @@ -17,16 +17,65 @@ IPC Reliable RingBuffer implementation ''' +from __future__ import annotations +from contextlib import contextmanager as cm from multiprocessing.shared_memory import SharedMemory import trio +from msgspec import ( + Struct, + to_builtins +) from ._linux import ( EFD_NONBLOCK, + open_eventfd, EventFD ) +class RBToken(Struct, frozen=True): + ''' + RingBuffer token contains necesary info to open the two + eventfds and the shared memory + + ''' + shm_name: str + write_eventfd: int + wrap_eventfd: int + + def as_msg(self): + return to_builtins(self) + + @classmethod + def from_msg(cls, msg: dict) -> RBToken: + if isinstance(msg, RBToken): + return msg + + return RBToken(**msg) + + +@cm +def open_ringbuf( + shm_name: str, + buf_size: int = 10 * 1024, + write_efd_flags: int = 0, + wrap_efd_flags: int = 0 +) -> RBToken: + shm = SharedMemory( + name=shm_name, + size=buf_size, + create=True + ) + token = RBToken( + shm_name=shm_name, + write_eventfd=open_eventfd(flags=write_efd_flags), + wrap_eventfd=open_eventfd(flags=wrap_efd_flags) + ) + yield token + shm.close() + + class RingBuffSender(trio.abc.SendStream): ''' IPC Reliable Ring Buffer sender side implementation @@ -34,28 +83,22 @@ class RingBuffSender(trio.abc.SendStream): `eventfd(2)` is used for wrap around sync, and also to signal writes to the reader. - TODO: if blocked on wrap around event wait it will not respond - to signals, fix soon TM ''' - def __init__( self, - shm_key: str, - write_eventfd: int, - wrap_eventfd: int, + token: RBToken, start_ptr: int = 0, buf_size: int = 10 * 1024, - unlink_on_exit: bool = True ): + token = RBToken.from_msg(token) self._shm = SharedMemory( - name=shm_key, + name=token.shm_name, size=buf_size, - create=True + create=False ) - self._write_event = EventFD(write_eventfd, 'w') - self._wrap_event = EventFD(wrap_eventfd, 'r') + self._write_event = EventFD(token.write_eventfd, 'w') + self._wrap_event = EventFD(token.wrap_eventfd, 'r') self._ptr = start_ptr - self.unlink_on_exit = unlink_on_exit @property def key(self) -> str: @@ -104,11 +147,7 @@ class RingBuffSender(trio.abc.SendStream): async def aclose(self): self._write_event.close() self._wrap_event.close() - if self.unlink_on_exit: - self._shm.unlink() - - else: - self._shm.close() + self._shm.close() async def __aenter__(self): self._write_event.open() @@ -123,29 +162,22 @@ class RingBuffReceiver(trio.abc.ReceiveStream): `eventfd(2)` is used for wrap around sync, and also to signal writes to the reader. - Unless eventfd(2) object is opened with EFD_NONBLOCK flag, - calls to `receive_some` will block the signal handling, - on the main thread, for now solution is using polling, - working on a way to unblock GIL during read(2) to allow - signal processing on the main thread. ''' - def __init__( self, - shm_key: str, - write_eventfd: int, - wrap_eventfd: int, + token: RBToken, start_ptr: int = 0, buf_size: int = 10 * 1024, flags: int = 0 ): + token = RBToken.from_msg(token) self._shm = SharedMemory( - name=shm_key, + name=token.shm_name, size=buf_size, create=False ) - self._write_event = EventFD(write_eventfd, 'w') - self._wrap_event = EventFD(wrap_eventfd, 'r') + self._write_event = EventFD(token.write_eventfd, 'w') + self._wrap_event = EventFD(token.wrap_eventfd, 'r') self._ptr = start_ptr self._flags = flags