Better encapsulate RingBuff ctx managment methods and support non ipc usage

Add trio.StrictFIFOLock on sender.send_all
Support max_bytes argument on receive_some, keep track of write_ptr on receiver
Add max_bytes receive test test_ringbuf_max_bytes
Add docstrings to all ringbuf tests
Remove EFD_NONBLOCK support, not necesary anymore since we can use abandon_on_cancel=True on trio.to_thread.run_sync
Close eventfd's after usage on open_ringbuf
Guillermo Rodriguez 2025-03-16 17:50:13 -03:00
parent 3c5420f4c9
commit d6721f06df
No known key found for this signature in database
GPG Key ID: 002CC5F1E6BDA53E
2 changed files with 153 additions and 81 deletions

View File

@ -58,6 +58,8 @@ async def child_write_shm(
for msg in msgs:
await sender.send_all(msg)
print('writer exit')
@pytest.mark.parametrize(
'msg_amount,rand_min,rand_max,buf_size',
@ -83,6 +85,15 @@ def test_ringbuf(
rand_max: int,
buf_size: int
):
'''
- Open a new ring buf on root actor
- Create a sender subactor and generate {msg_amount} messages
optionally with a random amount of bytes at the end of each,
return total_bytes on `ctx.started`, then send all messages
- Create a receiver subactor and receive until total_bytes are
read, print simple perf stats.
'''
async def main():
with open_ringbuf(
'test_ringbuf',
@ -140,6 +151,11 @@ async def child_blocked_receiver(
def test_ring_reader_cancel():
'''
Test that a receiver blocked on eventfd(2) read responds to
cancellation.
'''
async def main():
with open_ringbuf('test_ring_cancel_reader') as token:
async with (
@ -178,6 +194,11 @@ async def child_blocked_sender(
def test_ring_sender_cancel():
'''
Test that a sender blocked on eventfd(2) read responds to
cancellation.
'''
async def main():
with open_ringbuf(
'test_ring_cancel_sender',
@ -203,3 +224,36 @@ def test_ring_sender_cancel():
with pytest.raises(tractor._exceptions.ContextCancelled):
trio.run(main)
def test_ringbuf_max_bytes():
'''
Test that RingBuffReceiver.receive_some's max_bytes optional
argument works correctly, send a msg of size 100, then
force receive of messages with max_bytes == 1, wait until
100 of these messages are received, then compare join of
msgs with original message
'''
msg = b''.join(str(i % 10).encode() for i in range(100))
msgs = []
async def main():
with open_ringbuf(
'test_ringbuf_max_bytes',
buf_size=10
) as token:
async with (
trio.open_nursery() as n,
RingBuffSender(token, is_ipc=False) as sender,
RingBuffReceiver(token, is_ipc=False) as receiver
):
n.start_soon(sender.send_all, msg)
while len(msgs) < len(msg):
msg_part = await receiver.receive_some(max_bytes=1)
msg_part = bytes(msg_part)
assert len(msg_part) == 1
msgs.append(msg_part)
trio.run(main)
assert msg == b''.join(msgs)

View File

@ -28,11 +28,15 @@ from msgspec import (
)
from ._linux import (
EFD_NONBLOCK,
open_eventfd,
close_eventfd,
EventFD
)
from ._mp_bs import disable_mantracker
from tractor.log import get_logger
log = get_logger(__name__)
disable_mantracker()
@ -64,8 +68,6 @@ class RBToken(Struct, frozen=True):
def open_ringbuf(
shm_name: str,
buf_size: int = 10 * 1024,
write_efd_flags: int = 0,
wrap_efd_flags: int = 0
) -> RBToken:
shm = SharedMemory(
name=shm_name,
@ -75,16 +77,21 @@ def open_ringbuf(
try:
token = RBToken(
shm_name=shm_name,
write_eventfd=open_eventfd(flags=write_efd_flags),
wrap_eventfd=open_eventfd(flags=wrap_efd_flags),
write_eventfd=open_eventfd(),
wrap_eventfd=open_eventfd(),
buf_size=buf_size
)
yield token
close_eventfd(token.write_eventfd)
close_eventfd(token.wrap_eventfd)
finally:
shm.unlink()
Buffer = bytes | bytearray | memoryview
class RingBuffSender(trio.abc.SendStream):
'''
IPC Reliable Ring Buffer sender side implementation
@ -97,24 +104,26 @@ class RingBuffSender(trio.abc.SendStream):
self,
token: RBToken,
start_ptr: int = 0,
is_ipc: bool = True
):
token = RBToken.from_msg(token)
self._shm = SharedMemory(
name=token.shm_name,
size=token.buf_size,
create=False
)
self._write_event = EventFD(token.write_eventfd, 'w')
self._wrap_event = EventFD(token.wrap_eventfd, 'r')
self._token = RBToken.from_msg(token)
self._shm: SharedMemory | None = None
self._write_event = EventFD(self._token.write_eventfd, 'w')
self._wrap_event = EventFD(self._token.wrap_eventfd, 'r')
self._ptr = start_ptr
self._is_ipc = is_ipc
self._send_lock = trio.StrictFIFOLock()
@property
def key(self) -> str:
def name(self) -> str:
if not self._shm:
raise ValueError('shared memory not initialized yet!')
return self._shm.name
@property
def size(self) -> int:
return self._shm.size
return self._token.buf_size
@property
def ptr(self) -> int:
@ -128,38 +137,48 @@ class RingBuffSender(trio.abc.SendStream):
def wrap_fd(self) -> int:
return self._wrap_event.fd
async def send_all(self, data: bytes | bytearray | memoryview):
# while data is larger than the remaining buf
target_ptr = self.ptr + len(data)
while target_ptr > self.size:
# write all bytes that fit
remaining = self.size - self.ptr
self._shm.buf[self.ptr:] = data[:remaining]
# signal write and wait for reader wrap around
self._write_event.write(remaining)
await self._wrap_event.read()
async def send_all(self, data: Buffer):
async with self._send_lock:
# while data is larger than the remaining buf
target_ptr = self.ptr + len(data)
while target_ptr > self.size:
# write all bytes that fit
remaining = self.size - self.ptr
self._shm.buf[self.ptr:] = data[:remaining]
# signal write and wait for reader wrap around
self._write_event.write(remaining)
await self._wrap_event.read()
# wrap around and trim already written bytes
self._ptr = 0
data = data[remaining:]
target_ptr = self._ptr + len(data)
# wrap around and trim already written bytes
self._ptr = 0
data = data[remaining:]
target_ptr = self._ptr + len(data)
# remaining data fits on buffer
self._shm.buf[self.ptr:target_ptr] = data
self._write_event.write(len(data))
self._ptr = target_ptr
# remaining data fits on buffer
self._shm.buf[self.ptr:target_ptr] = data
self._write_event.write(len(data))
self._ptr = target_ptr
async def wait_send_all_might_not_block(self):
raise NotImplementedError
async def aclose(self):
self._write_event.close()
self._wrap_event.close()
self._shm.close()
async def __aenter__(self):
def open(self):
self._shm = SharedMemory(
name=self._token.shm_name,
size=self._token.buf_size,
create=False
)
self._write_event.open()
self._wrap_event.open()
async def aclose(self):
if self._is_ipc:
self._write_event.close()
self._wrap_event.close()
self._shm.close()
async def __aenter__(self):
self.open()
return self
@ -175,26 +194,25 @@ class RingBuffReceiver(trio.abc.ReceiveStream):
self,
token: RBToken,
start_ptr: int = 0,
flags: int = 0
is_ipc: bool = True
):
token = RBToken.from_msg(token)
self._shm = SharedMemory(
name=token.shm_name,
size=token.buf_size,
create=False
)
self._write_event = EventFD(token.write_eventfd, 'w')
self._wrap_event = EventFD(token.wrap_eventfd, 'r')
self._token = RBToken.from_msg(token)
self._shm: SharedMemory | None = None
self._write_event = EventFD(self._token.write_eventfd, 'w')
self._wrap_event = EventFD(self._token.wrap_eventfd, 'r')
self._ptr = start_ptr
self._flags = flags
self._write_ptr = start_ptr
self._is_ipc = is_ipc
@property
def key(self) -> str:
def name(self) -> str:
if not self._shm:
raise ValueError('shared memory not initialized yet!')
return self._shm.name
@property
def size(self) -> int:
return self._shm.size
return self._token.buf_size
@property
def ptr(self) -> int:
@ -208,46 +226,46 @@ class RingBuffReceiver(trio.abc.ReceiveStream):
def wrap_fd(self) -> int:
return self._wrap_event.fd
async def receive_some(
self,
max_bytes: int | None = None,
nb_timeout: float = 0.1
) -> memoryview:
# if non blocking eventfd enabled, do polling
# until next write, this allows signal handling
if self._flags | EFD_NONBLOCK:
delta = None
while delta is None:
try:
delta = await self._write_event.read()
except OSError as e:
if e.errno == 'EAGAIN':
continue
raise e
else:
async def receive_some(self, max_bytes: int | None = None) -> memoryview:
delta = self._write_ptr - self._ptr
if delta == 0:
delta = await self._write_event.read()
self._write_ptr += delta
if isinstance(max_bytes, int):
if max_bytes == 0:
raise ValueError('if set, max_bytes must be > 0')
delta = min(delta, max_bytes)
target_ptr = self._ptr + delta
# fetch next segment and advance ptr
next_ptr = self._ptr + delta
segment = self._shm.buf[self._ptr:next_ptr]
self._ptr = next_ptr
segment = self._shm.buf[self._ptr:target_ptr]
self._ptr = target_ptr
if self.ptr == self.size:
if self._ptr == self.size:
# reached the end, signal wrap around
self._ptr = 0
self._write_ptr = 0
self._wrap_event.write(1)
return segment
async def aclose(self):
self._write_event.close()
self._wrap_event.close()
self._shm.close()
async def __aenter__(self):
def open(self):
self._shm = SharedMemory(
name=self._token.shm_name,
size=self._token.buf_size,
create=False
)
self._write_event.open()
self._wrap_event.open()
async def aclose(self):
if self._is_ipc:
self._write_event.close()
self._wrap_event.close()
self._shm.close()
async def __aenter__(self):
self.open()
return self