Make ring buf api use pickle-able RBToken

Guillermo Rodriguez 2025-03-13 23:12:20 -03:00
parent 9d25cce945
commit dd17aa4205
No known key found for this signature in database
GPG Key ID: 002CC5F1E6BDA53E
3 changed files with 130 additions and 129 deletions

View File

@ -4,7 +4,8 @@ import trio
import pytest import pytest
import tractor import tractor
from tractor.ipc import ( from tractor.ipc import (
open_eventfd, open_ringbuf,
RBToken,
RingBuffSender, RingBuffSender,
RingBuffReceiver RingBuffReceiver
) )
@ -15,22 +16,16 @@ from tractor._testing.samples import generate_sample_messages
async def child_read_shm( async def child_read_shm(
ctx: tractor.Context, ctx: tractor.Context,
msg_amount: int, msg_amount: int,
shm_key: str, token: RBToken,
write_eventfd: int,
wrap_eventfd: int,
buf_size: int, buf_size: int,
total_bytes: int, total_bytes: int,
flags: int = 0,
) -> None: ) -> None:
recvd_bytes = 0 recvd_bytes = 0
await ctx.started() await ctx.started()
start_ts = time.time() start_ts = time.time()
async with RingBuffReceiver( async with RingBuffReceiver(
shm_key, token,
write_eventfd,
wrap_eventfd,
buf_size=buf_size, buf_size=buf_size,
flags=flags
) as receiver: ) as receiver:
while recvd_bytes < total_bytes: while recvd_bytes < total_bytes:
msg = await receiver.receive_some() msg = await receiver.receive_some()
@ -55,9 +50,7 @@ async def child_write_shm(
msg_amount: int, msg_amount: int,
rand_min: int, rand_min: int,
rand_max: int, rand_max: int,
shm_key: str, token: RBToken,
write_eventfd: int,
wrap_eventfd: int,
buf_size: int, buf_size: int,
) -> None: ) -> None:
msgs, total_bytes = generate_sample_messages( msgs, total_bytes = generate_sample_messages(
@ -67,9 +60,7 @@ async def child_write_shm(
) )
await ctx.started(total_bytes) await ctx.started(total_bytes)
async with RingBuffSender( async with RingBuffSender(
shm_key, token,
write_eventfd,
wrap_eventfd,
buf_size=buf_size buf_size=buf_size
) as sender: ) as sender:
for msg in msgs: for msg in msgs:
@ -100,52 +91,48 @@ def test_ringbuf(
rand_max: int, rand_max: int,
buf_size: int buf_size: int
): ):
write_eventfd = open_eventfd()
wrap_eventfd = open_eventfd()
proc_kwargs = {
'pass_fds': (write_eventfd, wrap_eventfd)
}
shm_key = 'test_ring_buff'
common_kwargs = {
'msg_amount': msg_amount,
'shm_key': shm_key,
'write_eventfd': write_eventfd,
'wrap_eventfd': wrap_eventfd,
'buf_size': buf_size
}
async def main(): async def main():
async with tractor.open_nursery() as an: with open_ringbuf(
send_p = await an.start_actor( 'test_ringbuf',
'ring_sender', buf_size=buf_size
enable_modules=[__name__], ) as token:
proc_kwargs=proc_kwargs proc_kwargs = {
) 'pass_fds': (token.write_eventfd, token.wrap_eventfd)
recv_p = await an.start_actor( }
'ring_receiver',
enable_modules=[__name__],
proc_kwargs=proc_kwargs
)
async with (
send_p.open_context(
child_write_shm,
rand_min=rand_min,
rand_max=rand_max,
**common_kwargs
) as (sctx, total_bytes),
recv_p.open_context(
child_read_shm,
**common_kwargs,
total_bytes=total_bytes,
) as (sctx, _sent),
):
await recv_p.result()
await send_p.cancel_actor() common_kwargs = {
await recv_p.cancel_actor() 'msg_amount': msg_amount,
'token': token,
'buf_size': buf_size
}
async with tractor.open_nursery() as an:
send_p = await an.start_actor(
'ring_sender',
enable_modules=[__name__],
proc_kwargs=proc_kwargs
)
recv_p = await an.start_actor(
'ring_receiver',
enable_modules=[__name__],
proc_kwargs=proc_kwargs
)
async with (
send_p.open_context(
child_write_shm,
rand_min=rand_min,
rand_max=rand_max,
**common_kwargs
) as (sctx, total_bytes),
recv_p.open_context(
child_read_shm,
**common_kwargs,
total_bytes=total_bytes,
) as (sctx, _sent),
):
await recv_p.result()
await send_p.cancel_actor()
await recv_p.cancel_actor()
trio.run(main) trio.run(main)
@ -154,55 +141,35 @@ def test_ringbuf(
@tractor.context @tractor.context
async def child_blocked_receiver( async def child_blocked_receiver(
ctx: tractor.Context, ctx: tractor.Context,
shm_key: str, token: RBToken
write_eventfd: int,
wrap_eventfd: int,
flags: int = 0
): ):
async with RingBuffReceiver( async with RingBuffReceiver(token) as receiver:
shm_key,
write_eventfd,
wrap_eventfd,
flags=flags
) as receiver:
await ctx.started() await ctx.started()
await receiver.receive_some() await receiver.receive_some()
def test_ring_reader_cancel(): def test_ring_reader_cancel():
write_eventfd = open_eventfd()
wrap_eventfd = open_eventfd()
proc_kwargs = {
'pass_fds': (write_eventfd, wrap_eventfd)
}
shm_key = 'test_ring_cancel'
async def main(): async def main():
async with ( with open_ringbuf('test_ring_cancel') as token:
tractor.open_nursery() as an,
RingBuffSender(
shm_key,
write_eventfd,
wrap_eventfd,
) as _sender,
):
recv_p = await an.start_actor(
'ring_blocked_receiver',
enable_modules=[__name__],
proc_kwargs=proc_kwargs
)
async with ( async with (
recv_p.open_context( tractor.open_nursery() as an,
child_blocked_receiver, RingBuffSender(token) as _sender,
write_eventfd=write_eventfd,
wrap_eventfd=wrap_eventfd,
shm_key=shm_key,
) as (sctx, _sent),
): ):
await trio.sleep(1) recv_p = await an.start_actor(
await an.cancel() 'ring_blocked_receiver',
enable_modules=[__name__],
proc_kwargs={
'pass_fds': (token.write_eventfd, token.wrap_eventfd)
}
)
async with (
recv_p.open_context(
child_blocked_receiver,
token=token
) as (sctx, _sent),
):
await trio.sleep(1)
await an.cancel()
with pytest.raises(tractor._exceptions.ContextCancelled): with pytest.raises(tractor._exceptions.ContextCancelled):

View File

@ -36,6 +36,8 @@ if platform.system() == 'Linux':
) )
from ._ringbuf import ( from ._ringbuf import (
RBToken as RBToken,
RingBuffSender as RingBuffSender, RingBuffSender as RingBuffSender,
RingBuffReceiver as RingBuffReceiver RingBuffReceiver as RingBuffReceiver,
open_ringbuf
) )

View File

@ -17,16 +17,65 @@
IPC Reliable RingBuffer implementation IPC Reliable RingBuffer implementation
''' '''
from __future__ import annotations
from contextlib import contextmanager as cm
from multiprocessing.shared_memory import SharedMemory from multiprocessing.shared_memory import SharedMemory
import trio import trio
from msgspec import (
Struct,
to_builtins
)
from ._linux import ( from ._linux import (
EFD_NONBLOCK, EFD_NONBLOCK,
open_eventfd,
EventFD EventFD
) )
class RBToken(Struct, frozen=True):
'''
RingBuffer token contains necesary info to open the two
eventfds and the shared memory
'''
shm_name: str
write_eventfd: int
wrap_eventfd: int
def as_msg(self):
return to_builtins(self)
@classmethod
def from_msg(cls, msg: dict) -> RBToken:
if isinstance(msg, RBToken):
return msg
return RBToken(**msg)
@cm
def open_ringbuf(
shm_name: str,
buf_size: int = 10 * 1024,
write_efd_flags: int = 0,
wrap_efd_flags: int = 0
) -> RBToken:
shm = SharedMemory(
name=shm_name,
size=buf_size,
create=True
)
token = RBToken(
shm_name=shm_name,
write_eventfd=open_eventfd(flags=write_efd_flags),
wrap_eventfd=open_eventfd(flags=wrap_efd_flags)
)
yield token
shm.close()
class RingBuffSender(trio.abc.SendStream): class RingBuffSender(trio.abc.SendStream):
''' '''
IPC Reliable Ring Buffer sender side implementation IPC Reliable Ring Buffer sender side implementation
@ -34,28 +83,22 @@ class RingBuffSender(trio.abc.SendStream):
`eventfd(2)` is used for wrap around sync, and also to signal `eventfd(2)` is used for wrap around sync, and also to signal
writes to the reader. writes to the reader.
TODO: if blocked on wrap around event wait it will not respond
to signals, fix soon TM
''' '''
def __init__( def __init__(
self, self,
shm_key: str, token: RBToken,
write_eventfd: int,
wrap_eventfd: int,
start_ptr: int = 0, start_ptr: int = 0,
buf_size: int = 10 * 1024, buf_size: int = 10 * 1024,
unlink_on_exit: bool = True
): ):
token = RBToken.from_msg(token)
self._shm = SharedMemory( self._shm = SharedMemory(
name=shm_key, name=token.shm_name,
size=buf_size, size=buf_size,
create=True create=False
) )
self._write_event = EventFD(write_eventfd, 'w') self._write_event = EventFD(token.write_eventfd, 'w')
self._wrap_event = EventFD(wrap_eventfd, 'r') self._wrap_event = EventFD(token.wrap_eventfd, 'r')
self._ptr = start_ptr self._ptr = start_ptr
self.unlink_on_exit = unlink_on_exit
@property @property
def key(self) -> str: def key(self) -> str:
@ -104,11 +147,7 @@ class RingBuffSender(trio.abc.SendStream):
async def aclose(self): async def aclose(self):
self._write_event.close() self._write_event.close()
self._wrap_event.close() self._wrap_event.close()
if self.unlink_on_exit: self._shm.close()
self._shm.unlink()
else:
self._shm.close()
async def __aenter__(self): async def __aenter__(self):
self._write_event.open() self._write_event.open()
@ -123,29 +162,22 @@ class RingBuffReceiver(trio.abc.ReceiveStream):
`eventfd(2)` is used for wrap around sync, and also to signal `eventfd(2)` is used for wrap around sync, and also to signal
writes to the reader. writes to the reader.
Unless eventfd(2) object is opened with EFD_NONBLOCK flag,
calls to `receive_some` will block the signal handling,
on the main thread, for now solution is using polling,
working on a way to unblock GIL during read(2) to allow
signal processing on the main thread.
''' '''
def __init__( def __init__(
self, self,
shm_key: str, token: RBToken,
write_eventfd: int,
wrap_eventfd: int,
start_ptr: int = 0, start_ptr: int = 0,
buf_size: int = 10 * 1024, buf_size: int = 10 * 1024,
flags: int = 0 flags: int = 0
): ):
token = RBToken.from_msg(token)
self._shm = SharedMemory( self._shm = SharedMemory(
name=shm_key, name=token.shm_name,
size=buf_size, size=buf_size,
create=False create=False
) )
self._write_event = EventFD(write_eventfd, 'w') self._write_event = EventFD(token.write_eventfd, 'w')
self._wrap_event = EventFD(wrap_eventfd, 'r') self._wrap_event = EventFD(token.wrap_eventfd, 'r')
self._ptr = start_ptr self._ptr = start_ptr
self._flags = flags self._flags = flags