General improvements

EventFD class now expects the fd to already be init with open_eventfd
RingBuff Sender and Receiver fully manage SharedMemory and EventFD lifecycles, no aditional ctx mngrs needed
Separate ring buf tests into its own test bed
Add parametrization to test and cancellation
Add docstrings
Add simple testing data gen module .samples
Guillermo Rodriguez 2025-03-13 20:17:04 -03:00
parent bf416ea26f
commit ab1a60bc97
No known key found for this signature in database
GPG Key ID: 002CC5F1E6BDA53E
4 changed files with 356 additions and 162 deletions

View File

@ -0,0 +1,208 @@
import time
import trio
import pytest
import tractor
from tractor._shm import (
EFD_NONBLOCK,
open_eventfd,
RingBuffSender,
RingBuffReceiver
)
from tractor._testing.samples import generate_sample_messages
@tractor.context
async def child_read_shm(
ctx: tractor.Context,
msg_amount: int,
shm_key: str,
write_eventfd: int,
wrap_eventfd: int,
buf_size: int,
total_bytes: int,
flags: int = 0,
) -> None:
recvd_bytes = 0
await ctx.started()
start_ts = time.time()
async with RingBuffReceiver(
shm_key,
write_eventfd,
wrap_eventfd,
buf_size=buf_size,
flags=flags
) as receiver:
while recvd_bytes < total_bytes:
msg = await receiver.receive_some()
recvd_bytes += len(msg)
end_ts = time.time()
elapsed = end_ts - start_ts
elapsed_ms = int(elapsed * 1000)
print(f'\n\telapsed ms: {elapsed_ms}')
print(f'\tmsg/sec: {int(msg_amount / elapsed):,}')
print(f'\tbytes/sec: {int(recvd_bytes / elapsed):,}')
@tractor.context
async def child_write_shm(
ctx: tractor.Context,
msg_amount: int,
rand_min: int,
rand_max: int,
shm_key: str,
write_eventfd: int,
wrap_eventfd: int,
buf_size: int,
) -> None:
msgs, total_bytes = generate_sample_messages(
msg_amount,
rand_min=rand_min,
rand_max=rand_max,
)
await ctx.started(total_bytes)
async with RingBuffSender(
shm_key,
write_eventfd,
wrap_eventfd,
buf_size=buf_size
) as sender:
for msg in msgs:
await sender.send_all(msg)
@pytest.mark.parametrize(
'msg_amount,rand_min,rand_max,buf_size',
[
# simple case, fixed payloads, large buffer
(100_000, 0, 0, 10 * 1024),
# guaranteed wrap around on every write
(100, 10 * 1024, 20 * 1024, 10 * 1024),
# large payload size, but large buffer
(10_000, 256 * 1024, 512 * 1024, 10 * 1024 * 1024)
],
ids=[
'fixed_payloads_large_buffer',
'wrap_around_every_write',
'large_payloads_large_buffer',
]
)
def test_ring_buff(
msg_amount: int,
rand_min: int,
rand_max: int,
buf_size: int
):
write_eventfd = open_eventfd()
wrap_eventfd = open_eventfd()
proc_kwargs = {
'pass_fds': (write_eventfd, wrap_eventfd)
}
shm_key = 'test_ring_buff'
common_kwargs = {
'msg_amount': msg_amount,
'shm_key': shm_key,
'write_eventfd': write_eventfd,
'wrap_eventfd': wrap_eventfd,
'buf_size': buf_size
}
async def main():
async with tractor.open_nursery() as an:
send_p = await an.start_actor(
'ring_sender',
enable_modules=[__name__],
proc_kwargs=proc_kwargs
)
recv_p = await an.start_actor(
'ring_receiver',
enable_modules=[__name__],
proc_kwargs=proc_kwargs
)
async with (
send_p.open_context(
child_write_shm,
rand_min=rand_min,
rand_max=rand_max,
**common_kwargs
) as (sctx, total_bytes),
recv_p.open_context(
child_read_shm,
**common_kwargs,
total_bytes=total_bytes,
) as (sctx, _sent),
):
await recv_p.result()
await send_p.cancel_actor()
await recv_p.cancel_actor()
trio.run(main)
@tractor.context
async def child_blocked_receiver(
ctx: tractor.Context,
shm_key: str,
write_eventfd: int,
wrap_eventfd: int,
flags: int = 0
):
async with RingBuffReceiver(
shm_key,
write_eventfd,
wrap_eventfd,
flags=flags
) as receiver:
await ctx.started()
await receiver.receive_some()
def test_ring_reader_cancel():
flags = EFD_NONBLOCK
write_eventfd = open_eventfd(flags=flags)
wrap_eventfd = open_eventfd()
proc_kwargs = {
'pass_fds': (write_eventfd, wrap_eventfd)
}
shm_key = 'test_ring_cancel'
async def main():
async with (
tractor.open_nursery() as an,
RingBuffSender(
shm_key,
write_eventfd,
wrap_eventfd,
) as _sender,
):
recv_p = await an.start_actor(
'ring_blocked_receiver',
enable_modules=[__name__],
proc_kwargs=proc_kwargs
)
async with (
recv_p.open_context(
child_blocked_receiver,
write_eventfd=write_eventfd,
wrap_eventfd=wrap_eventfd,
shm_key=shm_key,
flags=flags
) as (sctx, _sent),
):
await trio.sleep(1)
await an.cancel()
with pytest.raises(tractor._exceptions.ContextCancelled):
trio.run(main)

View File

@ -2,10 +2,7 @@
Shared mem primitives and APIs.
"""
import time
import uuid
import string
import random
# import numpy
import pytest
@ -14,7 +11,6 @@ import tractor
from tractor._shm import (
open_shm_list,
attach_shm_list,
EventFD, open_ringbuffer_sender, open_ringbuffer_receiver,
)
@ -169,79 +165,3 @@ def test_parent_writer_child_reader(
await portal.cancel_actor()
trio.run(main)
def random_string(size=256):
return ''.join(random.choice(string.ascii_lowercase) for i in range(size))
async def child_read_shm(
msg_amount: int,
key: str,
write_event_fd: int,
wrap_event_fd: int,
max_bytes: int,
) -> None:
log = tractor.log.get_console_log(level='info')
recvd_msgs = 0
start_ts = time.time()
async with open_ringbuffer_receiver(
write_event_fd,
wrap_event_fd,
key,
max_bytes=max_bytes
) as receiver:
while recvd_msgs < msg_amount:
msg = await receiver.receive_some()
msgs = bytes(msg).split(b'\n')
first = msgs[0]
last = msgs[-2]
log.info((receiver.ptr - len(msg), receiver.ptr, first[:10], last[:10]))
recvd_msgs += len(msgs)
end_ts = time.time()
elapsed = end_ts - start_ts
elapsed_ms = int(elapsed * 1000)
log.info(f'elapsed ms: {elapsed_ms}')
log.info(f'msg/sec: {int(msg_amount / elapsed):,}')
log.info(f'bytes/sec: {int(max_bytes / elapsed):,}')
def test_ring_buff():
log = tractor.log.get_console_log(level='info')
msg_amount = 100_000
log.info(f'generating {msg_amount} messages...')
msgs = [
f'[{i:08}]: {random_string()}\n'.encode('utf-8')
for i in range(msg_amount)
]
buf_size = sum((len(m) for m in msgs))
log.info(f'done! buffer size: {buf_size}')
async def main():
with (
EventFD(initval=0) as write_event,
EventFD(initval=0) as wrap_event,
):
async with (
tractor.open_nursery() as an,
open_ringbuffer_sender(
write_event.fd,
wrap_event.fd,
max_bytes=buf_size
) as sender
):
await an.run_in_actor(
child_read_shm,
msg_amount=msg_amount,
key=sender.key,
write_event_fd=write_event.fd,
wrap_event_fd=wrap_event.fd,
max_bytes=buf_size,
proc_kwargs={
'pass_fds': (write_event.fd, wrap_event.fd)
}
)
for msg in msgs:
await sender.send_all(msg)
trio.run(main)

View File

@ -837,8 +837,6 @@ def attach_shm_list(
if platform.system() == 'Linux':
import os
import errno
import string
import random
from contextlib import asynccontextmanager as acm
import cffi
@ -862,19 +860,21 @@ if platform.system() == 'Linux':
'''
)
# Open the default dynamic library (essentially 'libc' in most cases)
C = ffi.dlopen(None)
# Constants from <sys/eventfd.h>, if needed.
EFD_SEMAPHORE = 1 << 0 # 0x1
EFD_CLOEXEC = 1 << 1 # 0x2
EFD_NONBLOCK = 1 << 2 # 0x4
EFD_SEMAPHORE = 1
EFD_CLOEXEC = 0o2000000
EFD_NONBLOCK = 0o4000
def open_eventfd(initval: int = 0, flags: int = 0) -> int:
'''
Open an eventfd with the given initial value and flags.
Returns the file descriptor on success, otherwise raises OSError.
'''
fd = C.eventfd(initval, flags)
if fd < 0:
@ -884,6 +884,7 @@ if platform.system() == 'Linux':
def write_eventfd(fd: int, value: int) -> int:
'''
Write a 64-bit integer (uint64_t) to the eventfd's counter.
'''
# Create a uint64_t* in C, store `value`
data_ptr = ffi.new('uint64_t *', value)
@ -899,6 +900,7 @@ if platform.system() == 'Linux':
'''
Read a 64-bit integer (uint64_t) from the eventfd, returning the value.
Reading resets the counter to 0 (unless using EFD_SEMAPHORE).
'''
# Allocate an 8-byte buffer in C for reading
buf = ffi.new('char[]', 8)
@ -914,6 +916,7 @@ if platform.system() == 'Linux':
def close_eventfd(fd: int) -> int:
'''
Close the eventfd.
'''
ret = C.close(fd)
if ret < 0:
@ -921,17 +924,19 @@ if platform.system() == 'Linux':
class EventFD:
'''
Use a previously opened eventfd(2), meant to be used in
sub-actors after root actor opens the eventfds then passes
them through pass_fds
'''
def __init__(
self,
initval: int = 0,
flags: int = 0,
fd: int | None = None,
omode: str = 'r'
fd: int,
omode: str
):
self._initval: int = initval
self._flags: int = flags
self._fd: int | None = fd
self._fd: int = fd
self._omode: str = omode
self._fobj = None
@ -943,23 +948,15 @@ if platform.system() == 'Linux':
return write_eventfd(self._fd, value)
async def read(self) -> int:
#TODO: how to handle signals?
return await trio.to_thread.run_sync(read_eventfd, self._fd)
def open(self):
if not self._fd:
self._fd = open_eventfd(
initval=self._initval, flags=self._flags)
else:
self._fobj = os.fdopen(self._fd, self._omode)
self._fobj = os.fdopen(self._fd, self._omode)
def close(self):
if self._fobj:
self._fobj.close()
return
if self._fd:
close_eventfd(self._fd)
def __enter__(self):
self.open()
@ -970,18 +967,34 @@ if platform.system() == 'Linux':
class RingBuffSender(trio.abc.SendStream):
'''
IPC Reliable Ring Buffer sender side implementation
`eventfd(2)` is used for wrap around sync, and also to signal
writes to the reader.
TODO: if blocked on wrap around event wait it will not respond
to signals, fix soon TM
'''
def __init__(
self,
shm: SharedMemory,
write_event: EventFD,
wrap_event: EventFD,
start_ptr: int = 0
shm_key: str,
write_eventfd: int,
wrap_eventfd: int,
start_ptr: int = 0,
buf_size: int = 10 * 1024,
clean_shm_on_exit: bool = True
):
self._shm: SharedMemory = shm
self._write_event = write_event
self._wrap_event = wrap_event
self._shm = SharedMemory(
name=shm_key,
size=buf_size,
create=True
)
self._write_event = EventFD(write_eventfd, 'w')
self._wrap_event = EventFD(wrap_eventfd, 'r')
self._ptr = start_ptr
self.clean_shm_on_exit = clean_shm_on_exit
@property
def key(self) -> str:
@ -1004,25 +1017,37 @@ if platform.system() == 'Linux':
return self._wrap_event.fd
async def send_all(self, data: bytes | bytearray | memoryview):
# while data is larger than the remaining buf
target_ptr = self.ptr + len(data)
if target_ptr > self.size:
while target_ptr > self.size:
# write all bytes that fit
remaining = self.size - self.ptr
self._shm.buf[self.ptr:] = data[:remaining]
# signal write and wait for reader wrap around
self._write_event.write(remaining)
await self._wrap_event.read()
# wrap around and trim already written bytes
self._ptr = 0
data = data[remaining:]
target_ptr = self._ptr + len(data)
# remaining data fits on buffer
self._shm.buf[self.ptr:target_ptr] = data
self._write_event.write(len(data))
self._ptr = target_ptr
async def wait_send_all_might_not_block(self):
...
raise NotImplementedError
async def aclose(self):
...
self._write_event.close()
self._wrap_event.close()
if self.clean_shm_on_exit:
self._shm.unlink()
else:
self._shm.close()
async def __aenter__(self):
self._write_event.open()
@ -1034,18 +1059,37 @@ if platform.system() == 'Linux':
class RingBuffReceiver(trio.abc.ReceiveStream):
'''
IPC Reliable Ring Buffer receiver side implementation
`eventfd(2)` is used for wrap around sync, and also to signal
writes to the reader.
Unless eventfd(2) object is opened with EFD_NONBLOCK flag,
calls to `receive_some` will block the signal handling,
on the main thread, for now solution is using polling,
working on a way to unblock GIL during read(2) to allow
signal processing on the main thread.
'''
def __init__(
self,
shm: SharedMemory,
write_event: EventFD,
wrap_event: EventFD,
start_ptr: int = 0
shm_key: str,
write_eventfd: int,
wrap_eventfd: int,
start_ptr: int = 0,
buf_size: int = 10 * 1024,
flags: int = 0
):
self._shm: SharedMemory = shm
self._write_event = write_event
self._wrap_event = wrap_event
self._shm = SharedMemory(
name=shm_key,
size=buf_size,
create=False
)
self._write_event = EventFD(write_eventfd, 'w')
self._wrap_event = EventFD(wrap_eventfd, 'r')
self._ptr = start_ptr
self._flags = flags
@property
def key(self) -> str:
@ -1067,18 +1111,44 @@ if platform.system() == 'Linux':
def wrap_fd(self) -> int:
return self._wrap_event.fd
async def receive_some(self, max_bytes: int | None = None) -> bytes:
delta = await self._write_event.read()
async def receive_some(
self,
max_bytes: int | None = None,
nb_timeout: float = 0.1
) -> memoryview:
# if non blocking eventfd enabled, do polling
# until next write, this allows signal handling
if self._flags | EFD_NONBLOCK:
delta = None
while delta is None:
try:
delta = await self._write_event.read()
except OSError as e:
if e.errno == 'EAGAIN':
continue
raise e
else:
delta = await self._write_event.read()
# fetch next segment and advance ptr
next_ptr = self._ptr + delta
segment = bytes(self._shm.buf[self._ptr:next_ptr])
segment = self._shm.buf[self._ptr:next_ptr]
self._ptr = next_ptr
if self.ptr == self.size:
# reached the end, signal wrap around
self._ptr = 0
self._wrap_event.write(1)
return segment
async def aclose(self):
...
self._write_event.close()
self._wrap_event.close()
self._shm.close()
async def __aenter__(self):
self._write_event.open()
@ -1087,42 +1157,3 @@ if platform.system() == 'Linux':
async def __aexit__(self, exc_type, exc_value, traceback):
await self.aclose()
@acm
async def open_ringbuffer_sender(
write_event_fd: int,
wrap_event_fd: int,
key: str | None = None,
max_bytes: int = 10 * 1024,
start_ptr: int = 0,
) -> RingBuffSender:
if not key:
key: str = ''.join(random.choice(string.ascii_lowercase) for i in range(32))
shm = SharedMemory(
name=key,
size=max_bytes,
create=True
)
async with RingBuffSender(
shm, EventFD(fd=write_event_fd, omode='w'), EventFD(fd=wrap_event_fd), start_ptr=start_ptr
) as s:
yield s
@acm
async def open_ringbuffer_receiver(
write_event_fd: int,
wrap_event_fd: int,
key: str,
max_bytes: int = 10 * 1024,
start_ptr: int = 0,
) -> RingBuffSender:
shm = SharedMemory(
name=key,
size=max_bytes,
create=False
)
async with RingBuffReceiver(
shm, EventFD(fd=write_event_fd), EventFD(fd=wrap_event_fd, omode='w'), start_ptr=start_ptr
) as r:
yield r

View File

@ -0,0 +1,35 @@
import os
import random
def generate_sample_messages(
amount: int,
rand_min: int = 0,
rand_max: int = 0,
silent: bool = False
) -> tuple[list[bytes], int]:
msgs = []
size = 0
if not silent:
print(f'\ngenerating {amount} messages...')
for i in range(amount):
msg = f'[{i:08}]'.encode('utf-8')
if rand_max > 0:
msg += os.urandom(
random.randint(rand_min, rand_max))
size += len(msg)
msgs.append(msg)
if not silent and i and i % 10_000 == 0:
print(f'{i} generated')
if not silent:
print(f'done, {size:,} bytes in total')
return msgs, size