Make ring buf api use pickle-able RBToken
parent
9d25cce945
commit
3127db8502
|
@ -4,7 +4,8 @@ import trio
|
|||
import pytest
|
||||
import tractor
|
||||
from tractor.ipc import (
|
||||
open_eventfd,
|
||||
open_ringbuf,
|
||||
RBToken,
|
||||
RingBuffSender,
|
||||
RingBuffReceiver
|
||||
)
|
||||
|
@ -15,22 +16,16 @@ from tractor._testing.samples import generate_sample_messages
|
|||
async def child_read_shm(
|
||||
ctx: tractor.Context,
|
||||
msg_amount: int,
|
||||
shm_key: str,
|
||||
write_eventfd: int,
|
||||
wrap_eventfd: int,
|
||||
token: RBToken,
|
||||
buf_size: int,
|
||||
total_bytes: int,
|
||||
flags: int = 0,
|
||||
) -> None:
|
||||
recvd_bytes = 0
|
||||
await ctx.started()
|
||||
start_ts = time.time()
|
||||
async with RingBuffReceiver(
|
||||
shm_key,
|
||||
write_eventfd,
|
||||
wrap_eventfd,
|
||||
token,
|
||||
buf_size=buf_size,
|
||||
flags=flags
|
||||
) as receiver:
|
||||
while recvd_bytes < total_bytes:
|
||||
msg = await receiver.receive_some()
|
||||
|
@ -55,9 +50,7 @@ async def child_write_shm(
|
|||
msg_amount: int,
|
||||
rand_min: int,
|
||||
rand_max: int,
|
||||
shm_key: str,
|
||||
write_eventfd: int,
|
||||
wrap_eventfd: int,
|
||||
token: RBToken,
|
||||
buf_size: int,
|
||||
) -> None:
|
||||
msgs, total_bytes = generate_sample_messages(
|
||||
|
@ -67,9 +60,7 @@ async def child_write_shm(
|
|||
)
|
||||
await ctx.started(total_bytes)
|
||||
async with RingBuffSender(
|
||||
shm_key,
|
||||
write_eventfd,
|
||||
wrap_eventfd,
|
||||
token,
|
||||
buf_size=buf_size
|
||||
) as sender:
|
||||
for msg in msgs:
|
||||
|
@ -100,52 +91,46 @@ def test_ringbuf(
|
|||
rand_max: int,
|
||||
buf_size: int
|
||||
):
|
||||
write_eventfd = open_eventfd()
|
||||
wrap_eventfd = open_eventfd()
|
||||
|
||||
proc_kwargs = {
|
||||
'pass_fds': (write_eventfd, wrap_eventfd)
|
||||
}
|
||||
|
||||
shm_key = 'test_ring_buff'
|
||||
|
||||
common_kwargs = {
|
||||
'msg_amount': msg_amount,
|
||||
'shm_key': shm_key,
|
||||
'write_eventfd': write_eventfd,
|
||||
'wrap_eventfd': wrap_eventfd,
|
||||
'buf_size': buf_size
|
||||
}
|
||||
|
||||
async def main():
|
||||
async with tractor.open_nursery() as an:
|
||||
send_p = await an.start_actor(
|
||||
'ring_sender',
|
||||
enable_modules=[__name__],
|
||||
proc_kwargs=proc_kwargs
|
||||
)
|
||||
recv_p = await an.start_actor(
|
||||
'ring_receiver',
|
||||
enable_modules=[__name__],
|
||||
proc_kwargs=proc_kwargs
|
||||
)
|
||||
async with (
|
||||
send_p.open_context(
|
||||
child_write_shm,
|
||||
rand_min=rand_min,
|
||||
rand_max=rand_max,
|
||||
**common_kwargs
|
||||
) as (sctx, total_bytes),
|
||||
recv_p.open_context(
|
||||
child_read_shm,
|
||||
**common_kwargs,
|
||||
total_bytes=total_bytes,
|
||||
) as (sctx, _sent),
|
||||
):
|
||||
await recv_p.result()
|
||||
|
||||
await send_p.cancel_actor()
|
||||
await recv_p.cancel_actor()
|
||||
with open_ringbuf('test_ringbuf') as token:
|
||||
proc_kwargs = {
|
||||
'pass_fds': (token.write_eventfd, token.wrap_eventfd)
|
||||
}
|
||||
|
||||
common_kwargs = {
|
||||
'msg_amount': msg_amount,
|
||||
'token': token,
|
||||
'buf_size': buf_size
|
||||
}
|
||||
async with tractor.open_nursery() as an:
|
||||
send_p = await an.start_actor(
|
||||
'ring_sender',
|
||||
enable_modules=[__name__],
|
||||
proc_kwargs=proc_kwargs
|
||||
)
|
||||
recv_p = await an.start_actor(
|
||||
'ring_receiver',
|
||||
enable_modules=[__name__],
|
||||
proc_kwargs=proc_kwargs
|
||||
)
|
||||
async with (
|
||||
send_p.open_context(
|
||||
child_write_shm,
|
||||
rand_min=rand_min,
|
||||
rand_max=rand_max,
|
||||
**common_kwargs
|
||||
) as (sctx, total_bytes),
|
||||
recv_p.open_context(
|
||||
child_read_shm,
|
||||
**common_kwargs,
|
||||
total_bytes=total_bytes,
|
||||
) as (sctx, _sent),
|
||||
):
|
||||
await recv_p.result()
|
||||
|
||||
await send_p.cancel_actor()
|
||||
await recv_p.cancel_actor()
|
||||
|
||||
|
||||
trio.run(main)
|
||||
|
@ -154,55 +139,35 @@ def test_ringbuf(
|
|||
@tractor.context
|
||||
async def child_blocked_receiver(
|
||||
ctx: tractor.Context,
|
||||
shm_key: str,
|
||||
write_eventfd: int,
|
||||
wrap_eventfd: int,
|
||||
flags: int = 0
|
||||
token: RBToken
|
||||
):
|
||||
async with RingBuffReceiver(
|
||||
shm_key,
|
||||
write_eventfd,
|
||||
wrap_eventfd,
|
||||
flags=flags
|
||||
) as receiver:
|
||||
async with RingBuffReceiver(token) as receiver:
|
||||
await ctx.started()
|
||||
await receiver.receive_some()
|
||||
|
||||
|
||||
def test_ring_reader_cancel():
|
||||
write_eventfd = open_eventfd()
|
||||
wrap_eventfd = open_eventfd()
|
||||
|
||||
proc_kwargs = {
|
||||
'pass_fds': (write_eventfd, wrap_eventfd)
|
||||
}
|
||||
|
||||
shm_key = 'test_ring_cancel'
|
||||
|
||||
async def main():
|
||||
async with (
|
||||
tractor.open_nursery() as an,
|
||||
RingBuffSender(
|
||||
shm_key,
|
||||
write_eventfd,
|
||||
wrap_eventfd,
|
||||
) as _sender,
|
||||
):
|
||||
recv_p = await an.start_actor(
|
||||
'ring_blocked_receiver',
|
||||
enable_modules=[__name__],
|
||||
proc_kwargs=proc_kwargs
|
||||
)
|
||||
with open_ringbuf('test_ring_cancel') as token:
|
||||
async with (
|
||||
recv_p.open_context(
|
||||
child_blocked_receiver,
|
||||
write_eventfd=write_eventfd,
|
||||
wrap_eventfd=wrap_eventfd,
|
||||
shm_key=shm_key,
|
||||
) as (sctx, _sent),
|
||||
tractor.open_nursery() as an,
|
||||
RingBuffSender(token) as _sender,
|
||||
):
|
||||
await trio.sleep(1)
|
||||
await an.cancel()
|
||||
recv_p = await an.start_actor(
|
||||
'ring_blocked_receiver',
|
||||
enable_modules=[__name__],
|
||||
proc_kwargs={
|
||||
'pass_fds': (token.write_eventfd, token.wrap_eventfd)
|
||||
}
|
||||
)
|
||||
async with (
|
||||
recv_p.open_context(
|
||||
child_blocked_receiver,
|
||||
token=token
|
||||
) as (sctx, _sent),
|
||||
):
|
||||
await trio.sleep(1)
|
||||
await an.cancel()
|
||||
|
||||
|
||||
with pytest.raises(tractor._exceptions.ContextCancelled):
|
||||
|
|
|
@ -36,6 +36,8 @@ if platform.system() == 'Linux':
|
|||
)
|
||||
|
||||
from ._ringbuf import (
|
||||
RBToken as RBToken,
|
||||
RingBuffSender as RingBuffSender,
|
||||
RingBuffReceiver as RingBuffReceiver
|
||||
RingBuffReceiver as RingBuffReceiver,
|
||||
open_ringbuf
|
||||
)
|
||||
|
|
|
@ -17,16 +17,58 @@
|
|||
IPC Reliable RingBuffer implementation
|
||||
|
||||
'''
|
||||
from __future__ import annotations
|
||||
from contextlib import contextmanager as cm
|
||||
from multiprocessing.shared_memory import SharedMemory
|
||||
|
||||
import trio
|
||||
from msgspec import (
|
||||
Struct,
|
||||
to_builtins
|
||||
)
|
||||
|
||||
from ._linux import (
|
||||
EFD_NONBLOCK,
|
||||
open_eventfd,
|
||||
EventFD
|
||||
)
|
||||
|
||||
|
||||
class RBToken(Struct, frozen=True):
|
||||
'''
|
||||
RingBuffer token contains necesary info to open the two
|
||||
eventfds and the shared memory
|
||||
|
||||
'''
|
||||
shm_name: str
|
||||
write_eventfd: int
|
||||
wrap_eventfd: int
|
||||
|
||||
def as_msg(self):
|
||||
return to_builtins(self)
|
||||
|
||||
@classmethod
|
||||
def from_msg(cls, msg: dict) -> RBToken:
|
||||
if isinstance(msg, RBToken):
|
||||
return msg
|
||||
|
||||
return RBToken(**msg)
|
||||
|
||||
|
||||
@cm
|
||||
def open_ringbuf(
|
||||
shm_name: str,
|
||||
write_efd_flags: int = 0,
|
||||
wrap_efd_flags: int = 0
|
||||
) -> RBToken:
|
||||
token = RBToken(
|
||||
shm_name=shm_name,
|
||||
write_eventfd=open_eventfd(flags=write_efd_flags),
|
||||
wrap_eventfd=open_eventfd(flags=wrap_efd_flags)
|
||||
)
|
||||
yield token
|
||||
|
||||
|
||||
class RingBuffSender(trio.abc.SendStream):
|
||||
'''
|
||||
IPC Reliable Ring Buffer sender side implementation
|
||||
|
@ -34,26 +76,22 @@ class RingBuffSender(trio.abc.SendStream):
|
|||
`eventfd(2)` is used for wrap around sync, and also to signal
|
||||
writes to the reader.
|
||||
|
||||
TODO: if blocked on wrap around event wait it will not respond
|
||||
to signals, fix soon TM
|
||||
'''
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
shm_key: str,
|
||||
write_eventfd: int,
|
||||
wrap_eventfd: int,
|
||||
token: RBToken,
|
||||
start_ptr: int = 0,
|
||||
buf_size: int = 10 * 1024,
|
||||
unlink_on_exit: bool = True
|
||||
unlink_on_exit: bool = False
|
||||
):
|
||||
token = RBToken.from_msg(token)
|
||||
self._shm = SharedMemory(
|
||||
name=shm_key,
|
||||
name=token.shm_name,
|
||||
size=buf_size,
|
||||
create=True
|
||||
)
|
||||
self._write_event = EventFD(write_eventfd, 'w')
|
||||
self._wrap_event = EventFD(wrap_eventfd, 'r')
|
||||
self._write_event = EventFD(token.write_eventfd, 'w')
|
||||
self._wrap_event = EventFD(token.wrap_eventfd, 'r')
|
||||
self._ptr = start_ptr
|
||||
self.unlink_on_exit = unlink_on_exit
|
||||
|
||||
|
@ -123,29 +161,22 @@ class RingBuffReceiver(trio.abc.ReceiveStream):
|
|||
`eventfd(2)` is used for wrap around sync, and also to signal
|
||||
writes to the reader.
|
||||
|
||||
Unless eventfd(2) object is opened with EFD_NONBLOCK flag,
|
||||
calls to `receive_some` will block the signal handling,
|
||||
on the main thread, for now solution is using polling,
|
||||
working on a way to unblock GIL during read(2) to allow
|
||||
signal processing on the main thread.
|
||||
'''
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
shm_key: str,
|
||||
write_eventfd: int,
|
||||
wrap_eventfd: int,
|
||||
token: RBToken,
|
||||
start_ptr: int = 0,
|
||||
buf_size: int = 10 * 1024,
|
||||
flags: int = 0
|
||||
):
|
||||
token = RBToken.from_msg(token)
|
||||
self._shm = SharedMemory(
|
||||
name=shm_key,
|
||||
name=token.shm_name,
|
||||
size=buf_size,
|
||||
create=False
|
||||
)
|
||||
self._write_event = EventFD(write_eventfd, 'w')
|
||||
self._wrap_event = EventFD(wrap_eventfd, 'r')
|
||||
self._write_event = EventFD(token.write_eventfd, 'w')
|
||||
self._wrap_event = EventFD(token.wrap_eventfd, 'r')
|
||||
self._ptr = start_ptr
|
||||
self._flags = flags
|
||||
|
||||
|
|
Loading…
Reference in New Issue