Compare commits

..

4 Commits

8 changed files with 142 additions and 884 deletions

View File

@ -14,6 +14,6 @@ pkgs.mkShell {
shellHook = '' shellHook = ''
set -e set -e
uv venv .venv --python=3.11 uv venv .venv --python=3.12
''; '';
} }

View File

@ -1,32 +0,0 @@
import trio
import pytest
from tractor.ipc import (
open_eventfd,
EFDReadCancelled,
EventFD
)
def test_eventfd_read_cancellation():
'''
Ensure EventFD.read raises EFDReadCancelled if EventFD.close()
is called.
'''
fd = open_eventfd()
async def _read(event: EventFD):
with pytest.raises(EFDReadCancelled):
await event.read()
async def main():
async with trio.open_nursery() as n:
with (
EventFD(fd, 'w') as event,
trio.fail_after(3)
):
n.start_soon(_read, event)
await trio.sleep(0.2)
event.close()
trio.run(main)

View File

@ -1,21 +1,15 @@
import time import time
import hashlib
import trio import trio
import pytest import pytest
import tractor import tractor
from tractor.ipc import ( from tractor.ipc import (
open_ringbuf, open_ringbuf,
attach_to_ringbuf_receiver,
attach_to_ringbuf_sender,
attach_to_ringbuf_stream,
attach_to_ringbuf_channel,
RBToken, RBToken,
RingBuffSender,
RingBuffReceiver
) )
from tractor._testing.samples import ( from tractor._testing.samples import generate_sample_messages
generate_single_byte_msgs,
generate_sample_messages
)
@tractor.context @tractor.context
@ -23,28 +17,20 @@ async def child_read_shm(
ctx: tractor.Context, ctx: tractor.Context,
msg_amount: int, msg_amount: int,
token: RBToken, token: RBToken,
) -> str: total_bytes: int,
''' ) -> None:
Sub-actor used in `test_ringbuf`.
Attach to a ringbuf and receive all messages until end of stream.
Keep track of how many bytes received and also calculate
sha256 of the whole byte stream.
Calculate and print performance stats, finally return calculated
hash.
'''
await ctx.started()
print('reader started')
recvd_bytes = 0 recvd_bytes = 0
recvd_hash = hashlib.sha256() await ctx.started()
start_ts = time.time() start_ts = time.time()
async with attach_to_ringbuf_receiver(token) as receiver: async with RingBuffReceiver(token) as receiver:
async for msg in receiver: while recvd_bytes < total_bytes:
recvd_hash.update(msg) msg = await receiver.receive_some()
recvd_bytes += len(msg) recvd_bytes += len(msg)
# make sure we dont hold any memoryviews
# before the ctx manager aclose()
msg = None
end_ts = time.time() end_ts = time.time()
elapsed = end_ts - start_ts elapsed = end_ts - start_ts
elapsed_ms = int(elapsed * 1000) elapsed_ms = int(elapsed * 1000)
@ -52,9 +38,6 @@ async def child_read_shm(
print(f'\n\telapsed ms: {elapsed_ms}') print(f'\n\telapsed ms: {elapsed_ms}')
print(f'\tmsg/sec: {int(msg_amount / elapsed):,}') print(f'\tmsg/sec: {int(msg_amount / elapsed):,}')
print(f'\tbytes/sec: {int(recvd_bytes / elapsed):,}') print(f'\tbytes/sec: {int(recvd_bytes / elapsed):,}')
print(f'\treceived bytes: {recvd_bytes:,}')
return recvd_hash.hexdigest()
@tractor.context @tractor.context
@ -65,32 +48,16 @@ async def child_write_shm(
rand_max: int, rand_max: int,
token: RBToken, token: RBToken,
) -> None: ) -> None:
''' msgs, total_bytes = generate_sample_messages(
Sub-actor used in `test_ringbuf`
Generate `msg_amount` payloads with
`random.randint(rand_min, rand_max)` random bytes at the end,
Calculate sha256 hash and send it to parent on `ctx.started`.
Attach to ringbuf and send all generated messages.
'''
msgs, _total_bytes = generate_sample_messages(
msg_amount, msg_amount,
rand_min=rand_min, rand_min=rand_min,
rand_max=rand_max, rand_max=rand_max,
) )
print('writer hashing payload...') await ctx.started(total_bytes)
sent_hash = hashlib.sha256(b''.join(msgs)).hexdigest() async with RingBuffSender(token) as sender:
print('writer done hashing.')
await ctx.started(sent_hash)
print('writer started')
async with attach_to_ringbuf_sender(token, cleanup=False) as sender:
for msg in msgs: for msg in msgs:
await sender.send_all(msg) await sender.send_all(msg)
print('writer exit')
@pytest.mark.parametrize( @pytest.mark.parametrize(
'msg_amount,rand_min,rand_max,buf_size', 'msg_amount,rand_min,rand_max,buf_size',
@ -116,23 +83,19 @@ def test_ringbuf(
rand_max: int, rand_max: int,
buf_size: int buf_size: int
): ):
'''
- Open a new ring buf on root actor
- Open `child_write_shm` ctx in sub-actor which will generate a
random payload and send its hash on `ctx.started`, finally sending
the payload through the stream.
- Open `child_read_shm` ctx in sub-actor which will receive the
payload, calculate perf stats and return the hash.
- Compare both hashes
'''
async def main(): async def main():
with open_ringbuf( with open_ringbuf(
'test_ringbuf', 'test_ringbuf',
buf_size=buf_size buf_size=buf_size
) as token: ) as token:
proc_kwargs = {'pass_fds': token.fds} proc_kwargs = {
'pass_fds': (token.write_eventfd, token.wrap_eventfd)
}
common_kwargs = {
'msg_amount': msg_amount,
'token': token,
}
async with tractor.open_nursery() as an: async with tractor.open_nursery() as an:
send_p = await an.start_actor( send_p = await an.start_actor(
'ring_sender', 'ring_sender',
@ -147,20 +110,17 @@ def test_ringbuf(
async with ( async with (
send_p.open_context( send_p.open_context(
child_write_shm, child_write_shm,
token=token,
msg_amount=msg_amount,
rand_min=rand_min, rand_min=rand_min,
rand_max=rand_max, rand_max=rand_max,
) as (_sctx, sent_hash), **common_kwargs
) as (sctx, total_bytes),
recv_p.open_context( recv_p.open_context(
child_read_shm, child_read_shm,
token=token, **common_kwargs,
msg_amount=msg_amount total_bytes=total_bytes,
) as (rctx, _sent), ) as (sctx, _sent),
): ):
recvd_hash = await rctx.result() await recv_p.result()
assert sent_hash == recvd_hash
await send_p.cancel_actor() await send_p.cancel_actor()
await recv_p.cancel_actor() await recv_p.cancel_actor()
@ -174,28 +134,23 @@ async def child_blocked_receiver(
ctx: tractor.Context, ctx: tractor.Context,
token: RBToken token: RBToken
): ):
async with attach_to_ringbuf_receiver(token) as receiver: async with RingBuffReceiver(token) as receiver:
await ctx.started() await ctx.started()
await receiver.receive_some() await receiver.receive_some()
def test_reader_cancel(): def test_ring_reader_cancel():
'''
Test that a receiver blocked on eventfd(2) read responds to
cancellation.
'''
async def main(): async def main():
with open_ringbuf('test_ring_cancel_reader') as token: with open_ringbuf('test_ring_cancel_reader') as token:
async with ( async with (
tractor.open_nursery() as an, tractor.open_nursery() as an,
attach_to_ringbuf_sender(token) as _sender, RingBuffSender(token) as _sender,
): ):
recv_p = await an.start_actor( recv_p = await an.start_actor(
'ring_blocked_receiver', 'ring_blocked_receiver',
enable_modules=[__name__], enable_modules=[__name__],
proc_kwargs={ proc_kwargs={
'pass_fds': token.fds 'pass_fds': (token.write_eventfd, token.wrap_eventfd)
} }
) )
async with ( async with (
@ -217,17 +172,12 @@ async def child_blocked_sender(
ctx: tractor.Context, ctx: tractor.Context,
token: RBToken token: RBToken
): ):
async with attach_to_ringbuf_sender(token) as sender: async with RingBuffSender(token) as sender:
await ctx.started() await ctx.started()
await sender.send_all(b'this will wrap') await sender.send_all(b'this will wrap')
def test_sender_cancel(): def test_ring_sender_cancel():
'''
Test that a sender blocked on eventfd(2) read responds to
cancellation.
'''
async def main(): async def main():
with open_ringbuf( with open_ringbuf(
'test_ring_cancel_sender', 'test_ring_cancel_sender',
@ -238,7 +188,7 @@ def test_sender_cancel():
'ring_blocked_sender', 'ring_blocked_sender',
enable_modules=[__name__], enable_modules=[__name__],
proc_kwargs={ proc_kwargs={
'pass_fds': token.fds 'pass_fds': (token.write_eventfd, token.wrap_eventfd)
} }
) )
async with ( async with (
@ -253,171 +203,3 @@ def test_sender_cancel():
with pytest.raises(tractor._exceptions.ContextCancelled): with pytest.raises(tractor._exceptions.ContextCancelled):
trio.run(main) trio.run(main)
def test_receiver_max_bytes():
'''
Test that RingBuffReceiver.receive_some's max_bytes optional
argument works correctly, send a msg of size 100, then
force receive of messages with max_bytes == 1, wait until
100 of these messages are received, then compare join of
msgs with original message
'''
msg = generate_single_byte_msgs(100)
msgs = []
async def main():
with open_ringbuf(
'test_ringbuf_max_bytes',
buf_size=10
) as token:
async with (
trio.open_nursery() as n,
attach_to_ringbuf_sender(token, cleanup=False) as sender,
attach_to_ringbuf_receiver(token, cleanup=False) as receiver
):
async def _send_and_close():
await sender.send_all(msg)
await sender.aclose()
n.start_soon(_send_and_close)
while len(msgs) < len(msg):
msg_part = await receiver.receive_some(max_bytes=1)
assert len(msg_part) == 1
msgs.append(msg_part)
trio.run(main)
assert msg == b''.join(msgs)
def test_stapled_ringbuf():
'''
Open two ringbufs and give tokens to tasks (swap them such that in/out tokens
are inversed on each task) which will open the streams and use trio.StapledStream
to have a single bidirectional stream.
Then take turns to send and receive messages.
'''
msg = generate_single_byte_msgs(100)
pair_0_msgs = []
pair_1_msgs = []
pair_0_done = trio.Event()
pair_1_done = trio.Event()
async def pair_0(token_in: RBToken, token_out: RBToken):
async with attach_to_ringbuf_stream(
token_in,
token_out,
cleanup_in=False,
cleanup_out=False
) as stream:
# first turn to send
await stream.send_all(msg)
# second turn to receive
while len(pair_0_msgs) != len(msg):
_msg = await stream.receive_some(max_bytes=1)
pair_0_msgs.append(_msg)
pair_0_done.set()
await pair_1_done.wait()
async def pair_1(token_in: RBToken, token_out: RBToken):
async with attach_to_ringbuf_stream(
token_in,
token_out,
cleanup_in=False,
cleanup_out=False
) as stream:
# first turn to receive
while len(pair_1_msgs) != len(msg):
_msg = await stream.receive_some(max_bytes=1)
pair_1_msgs.append(_msg)
# second turn to send
await stream.send_all(msg)
pair_1_done.set()
await pair_0_done.wait()
async def main():
with tractor.ipc.open_ringbuf_pair(
'test_stapled_ringbuf'
) as (token_0, token_1):
async with trio.open_nursery() as n:
n.start_soon(pair_0, token_0, token_1)
n.start_soon(pair_1, token_1, token_0)
trio.run(main)
assert msg == b''.join(pair_0_msgs)
assert msg == b''.join(pair_1_msgs)
@tractor.context
async def child_channel_sender(
ctx: tractor.Context,
msg_amount_min: int,
msg_amount_max: int,
token_in: RBToken,
token_out: RBToken
):
import random
msgs, _total_bytes = generate_sample_messages(
random.randint(msg_amount_min, msg_amount_max),
rand_min=256,
rand_max=1024,
)
async with attach_to_ringbuf_channel(
token_in,
token_out
) as chan:
await ctx.started(msgs)
for msg in msgs:
await chan.send(msg)
def test_channel():
msg_amount_min = 100
msg_amount_max = 1000
async def main():
with tractor.ipc.open_ringbuf_pair(
'test_ringbuf_transport'
) as (token_0, token_1):
async with (
attach_to_ringbuf_channel(token_0, token_1) as chan,
tractor.open_nursery() as an
):
recv_p = await an.start_actor(
'test_ringbuf_transport_sender',
enable_modules=[__name__],
proc_kwargs={
'pass_fds': token_0.fds + token_1.fds
}
)
async with (
recv_p.open_context(
child_channel_sender,
msg_amount_min=msg_amount_min,
msg_amount_max=msg_amount_max,
token_in=token_1,
token_out=token_0
) as (ctx, msgs),
):
recv_msgs = []
async for msg in chan:
recv_msgs.append(msg)
await recv_p.cancel_actor()
assert recv_msgs == msgs
trio.run(main)

View File

@ -2,59 +2,19 @@ import os
import random import random
def generate_single_byte_msgs(amount: int) -> bytes:
'''
Generate a byte instance of len `amount` with:
```
byte_at_index(i) = (i % 10).encode()
```
this results in constantly repeating sequences of:
b'0123456789'
'''
return b''.join(str(i % 10).encode() for i in range(amount))
def generate_sample_messages( def generate_sample_messages(
amount: int, amount: int,
rand_min: int = 0, rand_min: int = 0,
rand_max: int = 0, rand_max: int = 0,
silent: bool = False, silent: bool = False
) -> tuple[list[bytes], int]: ) -> tuple[list[bytes], int]:
'''
Generate bytes msgs for tests.
Messages will have the following format:
```
b'[{i:08}]' + os.urandom(random.randint(rand_min, rand_max))
```
so for message index 25:
b'[00000025]' + random_bytes
'''
msgs = [] msgs = []
size = 0 size = 0
log_interval = None
if not silent: if not silent:
print(f'\ngenerating {amount} messages...') print(f'\ngenerating {amount} messages...')
# calculate an apropiate log interval based on
# max message size
max_msg_size = 10 + rand_max
if max_msg_size <= 32 * 1024:
log_interval = 10_000
else:
log_interval = 1000
for i in range(amount): for i in range(amount):
msg = f'[{i:08}]'.encode('utf-8') msg = f'[{i:08}]'.encode('utf-8')
@ -66,13 +26,7 @@ def generate_sample_messages(
msgs.append(msg) msgs.append(msg)
if ( if not silent and i and i % 10_000 == 0:
not silent
and
i > 0
and
i % log_interval == 0
):
print(f'{i} generated') print(f'{i} generated')
if not silent: if not silent:

View File

@ -44,23 +44,12 @@ if platform.system() == 'Linux':
write_eventfd as write_eventfd, write_eventfd as write_eventfd,
read_eventfd as read_eventfd, read_eventfd as read_eventfd,
close_eventfd as close_eventfd, close_eventfd as close_eventfd,
EFDReadCancelled as EFDReadCancelled,
EventFD as EventFD, EventFD as EventFD,
) )
from ._ringbuf import ( from ._ringbuf import (
RBToken as RBToken, RBToken as RBToken,
open_ringbuf as open_ringbuf,
RingBuffSender as RingBuffSender, RingBuffSender as RingBuffSender,
RingBuffReceiver as RingBuffReceiver, RingBuffReceiver as RingBuffReceiver,
open_ringbuf_pair as open_ringbuf_pair, open_ringbuf as open_ringbuf
attach_to_ringbuf_receiver as attach_to_ringbuf_receiver,
attach_to_ringbuf_sender as attach_to_ringbuf_sender,
attach_to_ringbuf_stream as attach_to_ringbuf_stream,
RingBuffBytesSender as RingBuffBytesSender,
RingBuffBytesReceiver as RingBuffBytesReceiver,
RingBuffChannel as RingBuffChannel,
attach_to_ringbuf_schannel as attach_to_ringbuf_schannel,
attach_to_ringbuf_rchannel as attach_to_ringbuf_rchannel,
attach_to_ringbuf_channel as attach_to_ringbuf_channel,
) )

View File

@ -108,10 +108,6 @@ def close_eventfd(fd: int) -> int:
raise OSError(errno.errorcode[ffi.errno], 'close failed') raise OSError(errno.errorcode[ffi.errno], 'close failed')
class EFDReadCancelled(Exception):
...
class EventFD: class EventFD:
''' '''
Use a previously opened eventfd(2), meant to be used in Use a previously opened eventfd(2), meant to be used in
@ -128,7 +124,6 @@ class EventFD:
self._fd: int = fd self._fd: int = fd
self._omode: str = omode self._omode: str = omode
self._fobj = None self._fobj = None
self._cscope: trio.CancelScope | None = None
@property @property
def fd(self) -> int | None: def fd(self) -> int | None:
@ -138,47 +133,18 @@ class EventFD:
return write_eventfd(self._fd, value) return write_eventfd(self._fd, value)
async def read(self) -> int: async def read(self) -> int:
'''
Async wrapper for `read_eventfd(self.fd)`
`trio.to_thread.run_sync` is used, need to use a `trio.CancelScope`
in order to make it cancellable when `self.close()` is called.
'''
self._cscope = trio.CancelScope()
with self._cscope:
return await trio.to_thread.run_sync( return await trio.to_thread.run_sync(
read_eventfd, self._fd, read_eventfd, self._fd,
abandon_on_cancel=True abandon_on_cancel=True
) )
if self._cscope.cancelled_caught:
raise EFDReadCancelled
self._cscope = None
def read_direct(self) -> int:
'''
Direct call to `read_eventfd(self.fd)`, unless `eventfd` was
opened with `EFD_NONBLOCK` its gonna block the thread.
'''
return read_eventfd(self._fd)
def open(self): def open(self):
self._fobj = os.fdopen(self._fd, self._omode) self._fobj = os.fdopen(self._fd, self._omode)
def close(self): def close(self):
if self._fobj: if self._fobj:
try:
self._fobj.close() self._fobj.close()
except OSError:
...
if self._cscope:
self._cscope.cancel()
def __enter__(self): def __enter__(self):
self.open() self.open()
return self return self

View File

@ -18,15 +18,7 @@ IPC Reliable RingBuffer implementation
''' '''
from __future__ import annotations from __future__ import annotations
import struct from contextlib import contextmanager as cm
from typing import (
ContextManager,
AsyncContextManager
)
from contextlib import (
contextmanager as cm,
asynccontextmanager as acm
)
from multiprocessing.shared_memory import SharedMemory from multiprocessing.shared_memory import SharedMemory
import trio import trio
@ -36,37 +28,25 @@ from msgspec import (
) )
from ._linux import ( from ._linux import (
EFD_NONBLOCK,
open_eventfd, open_eventfd,
EFDReadCancelled,
EventFD EventFD
) )
from ._mp_bs import disable_mantracker from ._mp_bs import disable_mantracker
from tractor.log import get_logger
from tractor._exceptions import (
InternalError
)
log = get_logger(__name__)
disable_mantracker() disable_mantracker()
_DEFAULT_RB_SIZE = 10 * 1024
class RBToken(Struct, frozen=True): class RBToken(Struct, frozen=True):
''' '''
RingBuffer token contains necesary info to open the three RingBuffer token contains necesary info to open the two
eventfds and the shared memory eventfds and the shared memory
''' '''
shm_name: str shm_name: str
write_eventfd: int
write_eventfd: int # used to signal writer ptr advance wrap_eventfd: int
wrap_eventfd: int # used to signal reader ready after wrap around
eof_eventfd: int # used to signal writer closed
buf_size: int buf_size: int
def as_msg(self): def as_msg(self):
@ -79,45 +59,24 @@ class RBToken(Struct, frozen=True):
return RBToken(**msg) return RBToken(**msg)
@property
def fds(self) -> tuple[int, int, int]:
'''
Useful for `pass_fds` params
'''
return (
self.write_eventfd,
self.wrap_eventfd,
self.eof_eventfd
)
@cm @cm
def open_ringbuf( def open_ringbuf(
shm_name: str, shm_name: str,
buf_size: int = _DEFAULT_RB_SIZE, buf_size: int = 10 * 1024,
) -> ContextManager[RBToken]: write_efd_flags: int = 0,
''' wrap_efd_flags: int = 0
Handle resources for a ringbuf (shm, eventfd), yield `RBToken` to ) -> RBToken:
be used with `attach_to_ringbuf_sender` and `attach_to_ringbuf_receiver`
'''
shm = SharedMemory( shm = SharedMemory(
name=shm_name, name=shm_name,
size=buf_size, size=buf_size,
create=True create=True
) )
try: try:
with (
EventFD(open_eventfd(), 'r') as write_event,
EventFD(open_eventfd(), 'r') as wrap_event,
EventFD(open_eventfd(), 'r') as eof_event,
):
token = RBToken( token = RBToken(
shm_name=shm_name, shm_name=shm_name,
write_eventfd=write_event.fd, write_eventfd=open_eventfd(flags=write_efd_flags),
wrap_eventfd=wrap_event.fd, wrap_eventfd=open_eventfd(flags=wrap_efd_flags),
eof_eventfd=eof_event.fd,
buf_size=buf_size buf_size=buf_size
) )
yield token yield token
@ -126,50 +85,36 @@ def open_ringbuf(
shm.unlink() shm.unlink()
Buffer = bytes | bytearray | memoryview
'''
IPC Reliable Ring Buffer
`eventfd(2)` is used for wrap around sync, to signal writes to
the reader and end of stream.
'''
class RingBuffSender(trio.abc.SendStream): class RingBuffSender(trio.abc.SendStream):
''' '''
Ring Buffer sender side implementation IPC Reliable Ring Buffer sender side implementation
Do not use directly! manage with `attach_to_ringbuf_sender` `eventfd(2)` is used for wrap around sync, and also to signal
after having opened a ringbuf context with `open_ringbuf`. writes to the reader.
''' '''
def __init__( def __init__(
self, self,
token: RBToken, token: RBToken,
cleanup: bool = False start_ptr: int = 0,
): ):
self._token = RBToken.from_msg(token) token = RBToken.from_msg(token)
self._shm: SharedMemory | None = None self._shm = SharedMemory(
self._write_event = EventFD(self._token.write_eventfd, 'w') name=token.shm_name,
self._wrap_event = EventFD(self._token.wrap_eventfd, 'r') size=token.buf_size,
self._eof_event = EventFD(self._token.eof_eventfd, 'w') create=False
self._ptr = 0 )
self._write_event = EventFD(token.write_eventfd, 'w')
self._cleanup = cleanup self._wrap_event = EventFD(token.wrap_eventfd, 'r')
self._send_lock = trio.StrictFIFOLock() self._ptr = start_ptr
@property @property
def name(self) -> str: def key(self) -> str:
if not self._shm:
raise ValueError('shared memory not initialized yet!')
return self._shm.name return self._shm.name
@property @property
def size(self) -> int: def size(self) -> int:
return self._token.buf_size return self._shm.size
@property @property
def ptr(self) -> int: def ptr(self) -> int:
@ -183,11 +128,7 @@ class RingBuffSender(trio.abc.SendStream):
def wrap_fd(self) -> int: def wrap_fd(self) -> int:
return self._wrap_event.fd return self._wrap_event.fd
async def _wait_wrap(self): async def send_all(self, data: bytes | bytearray | memoryview):
await self._wrap_event.read()
async def send_all(self, data: Buffer):
async with self._send_lock:
# while data is larger than the remaining buf # while data is larger than the remaining buf
target_ptr = self.ptr + len(data) target_ptr = self.ptr + len(data)
while target_ptr > self.size: while target_ptr > self.size:
@ -196,7 +137,7 @@ class RingBuffSender(trio.abc.SendStream):
self._shm.buf[self.ptr:] = data[:remaining] self._shm.buf[self.ptr:] = data[:remaining]
# signal write and wait for reader wrap around # signal write and wait for reader wrap around
self._write_event.write(remaining) self._write_event.write(remaining)
await self._wait_wrap() await self._wrap_event.read()
# wrap around and trim already written bytes # wrap around and trim already written bytes
self._ptr = 0 self._ptr = 0
@ -211,69 +152,49 @@ class RingBuffSender(trio.abc.SendStream):
async def wait_send_all_might_not_block(self): async def wait_send_all_might_not_block(self):
raise NotImplementedError raise NotImplementedError
def open(self): async def aclose(self):
self._shm = SharedMemory(
name=self._token.shm_name,
size=self._token.buf_size,
create=False
)
self._write_event.open()
self._wrap_event.open()
self._eof_event.open()
def close(self):
self._eof_event.write(
self._ptr if self._ptr > 0 else self.size
)
if self._cleanup:
self._write_event.close() self._write_event.close()
self._wrap_event.close() self._wrap_event.close()
self._eof_event.close()
self._shm.close() self._shm.close()
async def aclose(self):
async with self._send_lock:
self.close()
async def __aenter__(self): async def __aenter__(self):
self.open() self._write_event.open()
self._wrap_event.open()
return self return self
class RingBuffReceiver(trio.abc.ReceiveStream): class RingBuffReceiver(trio.abc.ReceiveStream):
''' '''
Ring Buffer receiver side implementation IPC Reliable Ring Buffer receiver side implementation
Do not use directly! manage with `attach_to_ringbuf_receiver` `eventfd(2)` is used for wrap around sync, and also to signal
after having opened a ringbuf context with `open_ringbuf`. writes to the reader.
''' '''
def __init__( def __init__(
self, self,
token: RBToken, token: RBToken,
cleanup: bool = True, start_ptr: int = 0,
flags: int = 0
): ):
self._token = RBToken.from_msg(token) token = RBToken.from_msg(token)
self._shm: SharedMemory | None = None self._shm = SharedMemory(
self._write_event = EventFD(self._token.write_eventfd, 'w') name=token.shm_name,
self._wrap_event = EventFD(self._token.wrap_eventfd, 'r') size=token.buf_size,
self._eof_event = EventFD(self._token.eof_eventfd, 'r') create=False
self._ptr: int = 0 )
self._write_ptr: int = 0 self._write_event = EventFD(token.write_eventfd, 'w')
self._end_ptr: int = -1 self._wrap_event = EventFD(token.wrap_eventfd, 'r')
self._ptr = start_ptr
self._cleanup: bool = cleanup self._flags = flags
@property @property
def name(self) -> str: def key(self) -> str:
if not self._shm:
raise ValueError('shared memory not initialized yet!')
return self._shm.name return self._shm.name
@property @property
def size(self) -> int: def size(self) -> int:
return self._token.buf_size return self._shm.size
@property @property
def ptr(self) -> int: def ptr(self) -> int:
@ -287,368 +208,46 @@ class RingBuffReceiver(trio.abc.ReceiveStream):
def wrap_fd(self) -> int: def wrap_fd(self) -> int:
return self._wrap_event.fd return self._wrap_event.fd
async def _eof_monitor_task(self): async def receive_some(
''' self,
Long running EOF event monitor, automatically run in bg by max_bytes: int | None = None,
`attach_to_ringbuf_receiver` context manager, if EOF event nb_timeout: float = 0.1
is set its value will be the end pointer (highest valid ) -> memoryview:
index to be read from buf, after setting the `self._end_ptr` # if non blocking eventfd enabled, do polling
we close the write event which should cancel any blocked # until next write, this allows signal handling
`self._write_event.read()`s on it. if self._flags | EFD_NONBLOCK:
delta = None
''' while delta is None:
try:
self._end_ptr = await self._eof_event.read()
self._write_event.close()
except EFDReadCancelled:
...
except trio.Cancelled:
...
async def receive_some(self, max_bytes: int | None = None) -> bytes:
'''
Receive up to `max_bytes`, if no `max_bytes` is provided
a reasonable default is used.
'''
if max_bytes is None:
max_bytes: int = _DEFAULT_RB_SIZE
if max_bytes < 1:
raise ValueError("max_bytes must be >= 1")
# delta is remaining bytes we havent read
delta = self._write_ptr - self._ptr
if delta == 0:
# we have read all we can, see if new data is available
if self._end_ptr < 0:
# if we havent been signaled about EOF yet
try: try:
delta = await self._write_event.read() delta = await self._write_event.read()
self._write_ptr += delta
except EFDReadCancelled: except OSError as e:
# while waiting for new data `self._write_event` was closed if e.errno == 'EAGAIN':
# this means writer signaled EOF continue
if self._end_ptr > 0:
# final self._write_ptr modification and recalculate delta raise e
self._write_ptr = self._end_ptr
delta = self._end_ptr - self._ptr
else: else:
# shouldnt happen cause self._eof_monitor_task always sets delta = await self._write_event.read()
# self._end_ptr before closing self._write_event
raise InternalError(
'self._write_event.read cancelled but self._end_ptr is not set'
)
else:
# no more bytes to read and self._end_ptr set, EOF reached
return b''
# dont overflow caller
delta = min(delta, max_bytes)
target_ptr = self._ptr + delta
# fetch next segment and advance ptr # fetch next segment and advance ptr
segment = bytes(self._shm.buf[self._ptr:target_ptr]) next_ptr = self._ptr + delta
self._ptr = target_ptr segment = self._shm.buf[self._ptr:next_ptr]
self._ptr = next_ptr
if self._ptr == self.size: if self.ptr == self.size:
# reached the end, signal wrap around # reached the end, signal wrap around
self._ptr = 0 self._ptr = 0
self._write_ptr = 0
self._wrap_event.write(1) self._wrap_event.write(1)
return segment return segment
def open(self): async def aclose(self):
self._shm = SharedMemory(
name=self._token.shm_name,
size=self._token.buf_size,
create=False
)
self._write_event.open()
self._wrap_event.open()
self._eof_event.open()
def close(self):
if self._cleanup:
self._write_event.close() self._write_event.close()
self._wrap_event.close() self._wrap_event.close()
self._eof_event.close()
self._shm.close() self._shm.close()
async def aclose(self):
self.close()
async def __aenter__(self): async def __aenter__(self):
self.open() self._write_event.open()
self._wrap_event.open()
return self return self
@acm
async def attach_to_ringbuf_receiver(
token: RBToken,
cleanup: bool = True
) -> AsyncContextManager[RingBuffReceiver]:
'''
Attach a RingBuffReceiver from a previously opened
RBToken.
Launches `receiver._eof_monitor_task` in a `trio.Nursery`.
'''
async with (
trio.open_nursery() as n,
RingBuffReceiver(
token,
cleanup=cleanup
) as receiver
):
n.start_soon(receiver._eof_monitor_task)
yield receiver
@acm
async def attach_to_ringbuf_sender(
token: RBToken,
cleanup: bool = True
) -> AsyncContextManager[RingBuffSender]:
'''
Attach a RingBuffSender from a previously opened
RBToken.
'''
async with RingBuffSender(
token,
cleanup=cleanup
) as sender:
yield sender
@cm
def open_ringbuf_pair(
name: str,
buf_size: int = _DEFAULT_RB_SIZE
) -> ContextManager[tuple(RBToken, RBToken)]:
'''
Handle resources for a ringbuf pair to be used for
bidirectional messaging.
'''
with (
open_ringbuf(
name + '.pair0',
buf_size=buf_size
) as token_0,
open_ringbuf(
name + '.pair1',
buf_size=buf_size
) as token_1
):
yield token_0, token_1
@acm
async def attach_to_ringbuf_stream(
token_in: RBToken,
token_out: RBToken,
cleanup_in: bool = True,
cleanup_out: bool = True
) -> AsyncContextManager[trio.StapledStream]:
'''
Attach a trio.StapledStream from a previously opened
ringbuf pair.
'''
async with (
attach_to_ringbuf_receiver(
token_in,
cleanup=cleanup_in
) as receiver,
attach_to_ringbuf_sender(
token_out,
cleanup=cleanup_out
) as sender,
):
yield trio.StapledStream(sender, receiver)
class RingBuffBytesSender(trio.abc.SendChannel[bytes]):
'''
In order to guarantee full messages are received, all bytes
sent by `RingBuffBytesSender` are preceded with a 4 byte header
which decodes into a uint32 indicating the actual size of the
next payload.
Optional batch mode:
If `batch_size` > 1 messages wont get sent immediately but will be
stored until `batch_size` messages are pending, then it will send
them all at once.
`batch_size` can be changed dynamically but always call, `flush()`
right before.
'''
def __init__(
self,
sender: RingBuffSender,
batch_size: int = 1
):
self._sender = sender
self.batch_size = batch_size
self._batch_msg_len = 0
self._batch: bytes = b''
async def flush(self) -> None:
await self._sender.send_all(self._batch)
self._batch = b''
self._batch_msg_len = 0
async def send(self, value: bytes) -> None:
msg: bytes = struct.pack("<I", len(value)) + value
if self.batch_size == 1:
await self._sender.send_all(msg)
return
self._batch += msg
self._batch_msg_len += 1
if self._batch_msg_len == self.batch_size:
await self.flush()
async def aclose(self) -> None:
await self._sender.aclose()
class RingBuffBytesReceiver(trio.abc.ReceiveChannel[bytes]):
'''
See `RingBuffBytesSender` docstring.
A `tricycle.BufferedReceiveStream` is used for the
`receive_exactly` API.
'''
def __init__(
self,
receiver: RingBuffReceiver
):
self._receiver = receiver
async def _receive_exactly(self, num_bytes: int) -> bytes:
'''
Fetch bytes from receiver until we read exactly `num_bytes`
or end of stream is signaled.
'''
payload = b''
while len(payload) < num_bytes:
remaining = num_bytes - len(payload)
new_bytes = await self._receiver.receive_some(
max_bytes=remaining
)
if new_bytes == b'':
raise trio.EndOfChannel
payload += new_bytes
return payload
async def receive(self) -> bytes:
header: bytes = await self._receive_exactly(4)
size: int
size, = struct.unpack("<I", header)
if size == 0:
raise trio.EndOfChannel
return await self._receive_exactly(size)
async def aclose(self) -> None:
await self._receiver.aclose()
@acm
async def attach_to_ringbuf_rchannel(
token: RBToken,
cleanup: bool = True
) -> AsyncContextManager[RingBuffBytesReceiver]:
'''
Attach a RingBuffBytesReceiver from a previously opened
RBToken.
'''
async with attach_to_ringbuf_receiver(
token, cleanup=cleanup
) as receiver:
yield RingBuffBytesReceiver(receiver)
@acm
async def attach_to_ringbuf_schannel(
token: RBToken,
cleanup: bool = True,
batch_size: int = 1,
) -> AsyncContextManager[RingBuffBytesSender]:
'''
Attach a RingBuffBytesSender from a previously opened
RBToken.
'''
async with attach_to_ringbuf_sender(
token, cleanup=cleanup
) as sender:
yield RingBuffBytesSender(sender, batch_size=batch_size)
class RingBuffChannel(trio.abc.Channel[bytes]):
'''
Combine `RingBuffBytesSender` and `RingBuffBytesReceiver`
in order to expose the bidirectional `trio.abc.Channel` API.
'''
def __init__(
self,
sender: RingBuffBytesSender,
receiver: RingBuffBytesReceiver
):
self._sender = sender
self._receiver = receiver
async def send(self, value: bytes):
await self._sender.send(value)
async def receive(self) -> bytes:
return await self._receiver.receive()
async def aclose(self):
await self._receiver.aclose()
await self._sender.aclose()
@acm
async def attach_to_ringbuf_channel(
token_in: RBToken,
token_out: RBToken,
cleanup_in: bool = True,
cleanup_out: bool = True
) -> AsyncContextManager[RingBuffChannel]:
'''
Attach to an already opened ringbuf pair and return
a `RingBuffChannel`.
'''
async with (
attach_to_ringbuf_rchannel(
token_in,
cleanup=cleanup_in
) as receiver,
attach_to_ringbuf_schannel(
token_out,
cleanup=cleanup_out
) as sender,
):
yield RingBuffChannel(sender, receiver)

View File

@ -73,7 +73,7 @@ class MsgTransport(Protocol[MsgType]):
# eventual msg definition/types? # eventual msg definition/types?
# - https://docs.python.org/3/library/typing.html#typing.Protocol # - https://docs.python.org/3/library/typing.html#typing.Protocol
stream: trio.abc.Stream stream: trio.SocketStream
drained: list[MsgType] drained: list[MsgType]
address_type: ClassVar[Type[Address]] address_type: ClassVar[Type[Address]]