Rename RingBuff -> RingBuffer

Combine RingBuffer stream and channel apis
Implement RingBufferReceiveChannel.receive_nowait
Make msg generator calculate hash
one_ring_to_rule_them_all
Guillermo Rodriguez 2025-04-04 02:36:59 -03:00
parent 95ea4647cc
commit 3568ba5d5d
No known key found for this signature in database
GPG Key ID: 002CC5F1E6BDA53E
4 changed files with 221 additions and 342 deletions

View File

@ -9,7 +9,6 @@ from tractor.ipc._ringbuf import (
open_ringbuf, open_ringbuf,
attach_to_ringbuf_receiver, attach_to_ringbuf_receiver,
attach_to_ringbuf_sender, attach_to_ringbuf_sender,
attach_to_ringbuf_stream,
attach_to_ringbuf_channel, attach_to_ringbuf_channel,
RBToken, RBToken,
) )
@ -25,7 +24,6 @@ pytestmark = pytest.mark.skip
@tractor.context @tractor.context
async def child_read_shm( async def child_read_shm(
ctx: tractor.Context, ctx: tractor.Context,
msg_amount: int,
token: RBToken, token: RBToken,
) -> str: ) -> str:
''' '''
@ -41,11 +39,13 @@ async def child_read_shm(
''' '''
await ctx.started() await ctx.started()
print('reader started') print('reader started')
msg_amount = 0
recvd_bytes = 0 recvd_bytes = 0
recvd_hash = hashlib.sha256() recvd_hash = hashlib.sha256()
start_ts = time.time() start_ts = time.time()
async with attach_to_ringbuf_receiver(token) as receiver: async with attach_to_ringbuf_receiver(token) as receiver:
async for msg in receiver: async for msg in receiver:
msg_amount += 1
recvd_hash.update(msg) recvd_hash.update(msg)
recvd_bytes += len(msg) recvd_bytes += len(msg)
@ -79,19 +79,16 @@ async def child_write_shm(
Attach to ringbuf and send all generated messages. Attach to ringbuf and send all generated messages.
''' '''
msgs, _total_bytes = generate_sample_messages( sent_hash, msgs, _total_bytes = generate_sample_messages(
msg_amount, msg_amount,
rand_min=rand_min, rand_min=rand_min,
rand_max=rand_max, rand_max=rand_max,
) )
print('writer hashing payload...')
sent_hash = hashlib.sha256(b''.join(msgs)).hexdigest()
print('writer done hashing.')
await ctx.started(sent_hash) await ctx.started(sent_hash)
print('writer started') print('writer started')
async with attach_to_ringbuf_sender(token, cleanup=False) as sender: async with attach_to_ringbuf_sender(token, cleanup=False) as sender:
for msg in msgs: for msg in msgs:
await sender.send_all(msg) await sender.send(msg)
print('writer exit') print('writer exit')
@ -159,7 +156,6 @@ def test_ringbuf(
recv_p.open_context( recv_p.open_context(
child_read_shm, child_read_shm,
token=token, token=token,
msg_amount=msg_amount
) as (rctx, _sent), ) as (rctx, _sent),
): ):
recvd_hash = await rctx.result() recvd_hash = await rctx.result()
@ -295,75 +291,6 @@ def test_receiver_max_bytes():
assert msg == b''.join(msgs) assert msg == b''.join(msgs)
def test_stapled_ringbuf():
'''
Open two ringbufs and give tokens to tasks (swap them such that in/out tokens
are inversed on each task) which will open the streams and use trio.StapledStream
to have a single bidirectional stream.
Then take turns to send and receive messages.
'''
msg = generate_single_byte_msgs(100)
pair_0_msgs = []
pair_1_msgs = []
pair_0_done = trio.Event()
pair_1_done = trio.Event()
async def pair_0(token_in: RBToken, token_out: RBToken):
async with attach_to_ringbuf_stream(
token_in,
token_out,
cleanup_in=False,
cleanup_out=False
) as stream:
# first turn to send
await stream.send_all(msg)
# second turn to receive
while len(pair_0_msgs) != len(msg):
_msg = await stream.receive_some(max_bytes=1)
pair_0_msgs.append(_msg)
pair_0_done.set()
await pair_1_done.wait()
async def pair_1(token_in: RBToken, token_out: RBToken):
async with attach_to_ringbuf_stream(
token_in,
token_out,
cleanup_in=False,
cleanup_out=False
) as stream:
# first turn to receive
while len(pair_1_msgs) != len(msg):
_msg = await stream.receive_some(max_bytes=1)
pair_1_msgs.append(_msg)
# second turn to send
await stream.send_all(msg)
pair_1_done.set()
await pair_0_done.wait()
async def main():
with tractor.ipc.open_ringbuf_pair(
'test_stapled_ringbuf'
) as (token_0, token_1):
async with trio.open_nursery() as n:
n.start_soon(pair_0, token_0, token_1)
n.start_soon(pair_1, token_1, token_0)
trio.run(main)
assert msg == b''.join(pair_0_msgs)
assert msg == b''.join(pair_1_msgs)
@tractor.context @tractor.context
async def child_channel_sender( async def child_channel_sender(
ctx: tractor.Context, ctx: tractor.Context,
@ -373,7 +300,7 @@ async def child_channel_sender(
token_out: RBToken token_out: RBToken
): ):
import random import random
msgs, _total_bytes = generate_sample_messages( _hash, msgs, _total_bytes = generate_sample_messages(
random.randint(msg_amount_min, msg_amount_max), random.randint(msg_amount_min, msg_amount_max),
rand_min=256, rand_min=256,
rand_max=1024, rand_max=1024,
@ -383,7 +310,6 @@ async def child_channel_sender(
token_out token_out
) as chan: ) as chan:
await ctx.started(msgs) await ctx.started(msgs)
for msg in msgs: for msg in msgs:
await chan.send(msg) await chan.send(msg)
@ -396,16 +322,16 @@ def test_channel():
async def main(): async def main():
with tractor.ipc.open_ringbuf_pair( with tractor.ipc.open_ringbuf_pair(
'test_ringbuf_transport' 'test_ringbuf_transport'
) as (token_0, token_1): ) as (send_token, recv_token):
async with ( async with (
attach_to_ringbuf_channel(token_0, token_1) as chan, attach_to_ringbuf_channel(send_token, recv_token) as chan,
tractor.open_nursery() as an tractor.open_nursery() as an
): ):
recv_p = await an.start_actor( recv_p = await an.start_actor(
'test_ringbuf_transport_sender', 'test_ringbuf_transport_sender',
enable_modules=[__name__], enable_modules=[__name__],
proc_kwargs={ proc_kwargs={
'pass_fds': token_0.fds + token_1.fds 'pass_fds': send_token.fds + recv_token.fds
} }
) )
async with ( async with (
@ -413,8 +339,8 @@ def test_channel():
child_channel_sender, child_channel_sender,
msg_amount_min=msg_amount_min, msg_amount_min=msg_amount_min,
msg_amount_max=msg_amount_max, msg_amount_max=msg_amount_max,
token_in=token_1, token_in=recv_token,
token_out=token_0 token_out=send_token
) as (ctx, msgs), ) as (ctx, msgs),
): ):
recv_msgs = [] recv_msgs = []

View File

@ -1,5 +1,6 @@
import os import os
import random import random
import hashlib
def generate_single_byte_msgs(amount: int) -> bytes: def generate_single_byte_msgs(amount: int) -> bytes:
@ -23,7 +24,7 @@ def generate_sample_messages(
rand_min: int = 0, rand_min: int = 0,
rand_max: int = 0, rand_max: int = 0,
silent: bool = False, silent: bool = False,
) -> tuple[list[bytes], int]: ) -> tuple[str, list[bytes], int]:
''' '''
Generate bytes msgs for tests. Generate bytes msgs for tests.
@ -55,6 +56,7 @@ def generate_sample_messages(
else: else:
log_interval = 1000 log_interval = 1000
payload_hash = hashlib.sha256()
for i in range(amount): for i in range(amount):
msg = f'[{i:08}]'.encode('utf-8') msg = f'[{i:08}]'.encode('utf-8')
@ -64,6 +66,7 @@ def generate_sample_messages(
size += len(msg) size += len(msg)
payload_hash.update(msg)
msgs.append(msg) msgs.append(msg)
if ( if (
@ -78,4 +81,4 @@ def generate_sample_messages(
if not silent: if not silent:
print(f'done, {size:,} bytes in total') print(f'done, {size:,} bytes in total')
return msgs, size return payload_hash.hexdigest(), msgs, size

View File

@ -27,17 +27,16 @@ from ._chan import (
if platform.system() == 'Linux': if platform.system() == 'Linux':
from ._ringbuf import ( from ._ringbuf import (
RBToken as RBToken, RBToken as RBToken,
open_ringbuf as open_ringbuf, open_ringbuf as open_ringbuf,
RingBuffSender as RingBuffSender,
RingBuffReceiver as RingBuffReceiver,
open_ringbuf_pair as open_ringbuf_pair, open_ringbuf_pair as open_ringbuf_pair,
attach_to_ringbuf_receiver as attach_to_ringbuf_receiver,
RingBufferSendChannel as RingBufferSendChannel,
attach_to_ringbuf_sender as attach_to_ringbuf_sender, attach_to_ringbuf_sender as attach_to_ringbuf_sender,
attach_to_ringbuf_stream as attach_to_ringbuf_stream,
RingBuffBytesSender as RingBuffBytesSender, RingBufferReceiveChannel as RingBufferReceiveChannel,
RingBuffBytesReceiver as RingBuffBytesReceiver, attach_to_ringbuf_receiver as attach_to_ringbuf_receiver,
RingBuffChannel as RingBuffChannel,
attach_to_ringbuf_schannel as attach_to_ringbuf_schannel, RingBufferChannel as RingBufferChannel,
attach_to_ringbuf_rchannel as attach_to_ringbuf_rchannel,
attach_to_ringbuf_channel as attach_to_ringbuf_channel, attach_to_ringbuf_channel as attach_to_ringbuf_channel,
) )

View File

@ -126,6 +126,30 @@ def open_ringbuf(
shm.unlink() shm.unlink()
@cm
def open_ringbuf_pair(
name: str,
buf_size: int = _DEFAULT_RB_SIZE
) -> ContextManager[tuple(RBToken, RBToken)]:
'''
Handle resources for a ringbuf pair to be used for
bidirectional messaging.
'''
with (
open_ringbuf(
name + '.send',
buf_size=buf_size
) as send_token,
open_ringbuf(
name + '.recv',
buf_size=buf_size
) as recv_token
):
yield send_token, recv_token
Buffer = bytes | bytearray | memoryview Buffer = bytes | bytearray | memoryview
@ -135,32 +159,65 @@ IPC Reliable Ring Buffer
`eventfd(2)` is used for wrap around sync, to signal writes to `eventfd(2)` is used for wrap around sync, to signal writes to
the reader and end of stream. the reader and end of stream.
In order to guarantee full messages are received, all bytes
sent by `RingBufferSendChannel` are preceded with a 4 byte header
which decodes into a uint32 indicating the actual size of the
next full payload.
''' '''
class RingBuffSender(trio.abc.SendStream): class RingBufferSendChannel(trio.abc.SendChannel[bytes]):
''' '''
Ring Buffer sender side implementation Ring Buffer sender side implementation
Do not use directly! manage with `attach_to_ringbuf_sender` Do not use directly! manage with `attach_to_ringbuf_sender`
after having opened a ringbuf context with `open_ringbuf`. after having opened a ringbuf context with `open_ringbuf`.
Optional batch mode:
If `batch_size` > 1 messages wont get sent immediately but will be
stored until `batch_size` messages are pending, then it will send
them all at once.
`batch_size` can be changed dynamically but always call, `flush()`
right before.
''' '''
def __init__( def __init__(
self, self,
token: RBToken, token: RBToken,
batch_size: int = 1,
cleanup: bool = False cleanup: bool = False
): ):
self._token = RBToken.from_msg(token) self._token = RBToken.from_msg(token)
self.batch_size = batch_size
# ringbuf os resources
self._shm: SharedMemory | None = None self._shm: SharedMemory | None = None
self._write_event = EventFD(self._token.write_eventfd, 'w') self._write_event = EventFD(self._token.write_eventfd, 'w')
self._wrap_event = EventFD(self._token.wrap_eventfd, 'r') self._wrap_event = EventFD(self._token.wrap_eventfd, 'r')
self._eof_event = EventFD(self._token.eof_eventfd, 'w') self._eof_event = EventFD(self._token.eof_eventfd, 'w')
# current write pointer
self._ptr = 0 self._ptr = 0
# when `batch_size` > 1 store messages on `self._batch` and write them
# all, once `len(self._batch) == `batch_size`
self._batch: list[bytes] = []
self._cleanup = cleanup self._cleanup = cleanup
self._send_lock = trio.StrictFIFOLock() self._send_lock = trio.StrictFIFOLock()
@acm
async def _maybe_lock(self) -> AsyncContextManager[None]:
if self._send_lock.locked():
yield
return
async with self._send_lock:
yield
@property @property
def name(self) -> str: def name(self) -> str:
if not self._shm: if not self._shm:
@ -183,11 +240,19 @@ class RingBuffSender(trio.abc.SendStream):
def wrap_fd(self) -> int: def wrap_fd(self) -> int:
return self._wrap_event.fd return self._wrap_event.fd
@property
def pending_msgs(self) -> int:
return len(self._batch)
@property
def must_flush(self) -> bool:
return self.pending_msgs >= self.batch_size
async def _wait_wrap(self): async def _wait_wrap(self):
await self._wrap_event.read() await self._wrap_event.read()
async def send_all(self, data: Buffer): async def send_all(self, data: Buffer):
async with self._send_lock: async with self._maybe_lock():
# while data is larger than the remaining buf # while data is larger than the remaining buf
target_ptr = self.ptr + len(data) target_ptr = self.ptr + len(data)
while target_ptr > self.size: while target_ptr > self.size:
@ -211,6 +276,34 @@ class RingBuffSender(trio.abc.SendStream):
async def wait_send_all_might_not_block(self): async def wait_send_all_might_not_block(self):
raise NotImplementedError raise NotImplementedError
async def flush(
self,
new_batch_size: int | None = None
) -> None:
async with self._maybe_lock():
for msg in self._batch:
await self.send_all(msg)
self._batch = []
if new_batch_size:
self.batch_size = new_batch_size
async def send(self, value: bytes) -> None:
async with self._maybe_lock():
msg: bytes = struct.pack("<I", len(value)) + value
if self.batch_size == 1:
await self.send_all(msg)
return
self._batch.append(msg)
if self.must_flush:
await self.flush()
async def send_eof(self) -> None:
async with self._send_lock:
await self.flush(new_batch_size=1)
await self.send(b'')
def open(self): def open(self):
try: try:
self._shm = SharedMemory( self._shm = SharedMemory(
@ -238,15 +331,14 @@ class RingBuffSender(trio.abc.SendStream):
self._shm.close() self._shm.close()
async def aclose(self): async def aclose(self):
async with self._send_lock: self.close()
self.close()
async def __aenter__(self): async def __aenter__(self):
self.open() self.open()
return self return self
class RingBuffReceiver(trio.abc.ReceiveStream): class RingBufferReceiveChannel(trio.abc.ReceiveChannel[bytes]):
''' '''
Ring Buffer receiver side implementation Ring Buffer receiver side implementation
@ -312,21 +404,48 @@ class RingBuffReceiver(trio.abc.ReceiveStream):
except trio.Cancelled: except trio.Cancelled:
... ...
async def receive_some(self, max_bytes: int | None = None) -> bytes: def receive_nowait(self, max_bytes: int = _DEFAULT_RB_SIZE) -> bytes:
'''
Try to receive any bytes we can without blocking or raise
`trio.WouldBlock`.
'''
if max_bytes < 1:
raise ValueError("max_bytes must be >= 1")
delta = self._write_ptr - self._ptr
if delta == 0:
raise trio.WouldBlock
# dont overflow caller
delta = min(delta, max_bytes)
target_ptr = self._ptr + delta
# fetch next segment and advance ptr
segment = bytes(self._shm.buf[self._ptr:target_ptr])
self._ptr = target_ptr
if self._ptr == self.size:
# reached the end, signal wrap around
self._ptr = 0
self._write_ptr = 0
self._wrap_event.write(1)
return segment
async def receive_some(self, max_bytes: int = _DEFAULT_RB_SIZE) -> bytes:
''' '''
Receive up to `max_bytes`, if no `max_bytes` is provided Receive up to `max_bytes`, if no `max_bytes` is provided
a reasonable default is used. a reasonable default is used.
Can return < max_bytes.
''' '''
if max_bytes is None: try:
max_bytes: int = _DEFAULT_RB_SIZE return self.receive_nowait(max_bytes=max_bytes)
if max_bytes < 1: except trio.WouldBlock:
raise ValueError("max_bytes must be >= 1")
# delta is remaining bytes we havent read
delta = self._write_ptr - self._ptr
if delta == 0:
# we have read all we can, see if new data is available # we have read all we can, see if new data is available
if self._end_ptr < 0: if self._end_ptr < 0:
# if we havent been signaled about EOF yet # if we havent been signaled about EOF yet
@ -353,22 +472,39 @@ class RingBuffReceiver(trio.abc.ReceiveStream):
# no more bytes to read and self._end_ptr set, EOF reached # no more bytes to read and self._end_ptr set, EOF reached
return b'' return b''
# dont overflow caller return await self.receive_some(max_bytes=max_bytes)
delta = min(delta, max_bytes)
target_ptr = self._ptr + delta async def receive_exactly(self, num_bytes: int) -> bytes:
'''
Fetch bytes until we read exactly `num_bytes` or EOF.
# fetch next segment and advance ptr '''
segment = bytes(self._shm.buf[self._ptr:target_ptr]) payload = b''
self._ptr = target_ptr while len(payload) < num_bytes:
remaining = num_bytes - len(payload)
if self._ptr == self.size: new_bytes = await self.receive_some(
# reached the end, signal wrap around max_bytes=remaining
self._ptr = 0 )
self._write_ptr = 0
self._wrap_event.write(1)
return segment if new_bytes == b'':
raise trio.EndOfChannel
payload += new_bytes
return payload
async def receive(self) -> bytes:
'''
Receive a complete payload
'''
header: bytes = await self.receive_exactly(4)
size: int
size, = struct.unpack("<I", header)
if size == 0:
raise trio.EndOfChannel
return await self.receive_exactly(size)
def open(self): def open(self):
try: try:
@ -402,18 +538,20 @@ class RingBuffReceiver(trio.abc.ReceiveStream):
@acm @acm
async def attach_to_ringbuf_receiver( async def attach_to_ringbuf_receiver(
token: RBToken, token: RBToken,
cleanup: bool = True cleanup: bool = True
) -> AsyncContextManager[RingBuffReceiver]:
) -> AsyncContextManager[RingBufferReceiveChannel]:
''' '''
Attach a RingBuffReceiver from a previously opened Attach a RingBufferReceiveChannel from a previously opened
RBToken. RBToken.
Launches `receiver._eof_monitor_task` in a `trio.Nursery`. Launches `receiver._eof_monitor_task` in a `trio.Nursery`.
''' '''
async with ( async with (
trio.open_nursery() as n, trio.open_nursery() as n,
RingBuffReceiver( RingBufferReceiveChannel(
token, token,
cleanup=cleanup cleanup=cleanup
) as receiver ) as receiver
@ -424,232 +562,33 @@ async def attach_to_ringbuf_receiver(
@acm @acm
async def attach_to_ringbuf_sender( async def attach_to_ringbuf_sender(
token: RBToken, token: RBToken,
cleanup: bool = True cleanup: bool = True
) -> AsyncContextManager[RingBuffSender]:
) -> AsyncContextManager[RingBufferSendChannel]:
''' '''
Attach a RingBuffSender from a previously opened Attach a RingBufferSendChannel from a previously opened
RBToken. RBToken.
''' '''
async with RingBuffSender( async with RingBufferSendChannel(
token, token,
cleanup=cleanup cleanup=cleanup
) as sender: ) as sender:
yield sender yield sender
@cm class RingBufferChannel(trio.abc.Channel[bytes]):
def open_ringbuf_pair(
name: str,
buf_size: int = _DEFAULT_RB_SIZE
) -> ContextManager[tuple(RBToken, RBToken)]:
''' '''
Handle resources for a ringbuf pair to be used for Combine `RingBufferSendChannel` and `RingBufferReceiveChannel`
bidirectional messaging.
'''
with (
open_ringbuf(
name + '.pair0',
buf_size=buf_size
) as token_0,
open_ringbuf(
name + '.pair1',
buf_size=buf_size
) as token_1
):
yield token_0, token_1
@acm
async def attach_to_ringbuf_stream(
token_in: RBToken,
token_out: RBToken,
cleanup_in: bool = True,
cleanup_out: bool = True
) -> AsyncContextManager[trio.StapledStream]:
'''
Attach a trio.StapledStream from a previously opened
ringbuf pair.
'''
async with (
attach_to_ringbuf_receiver(
token_in,
cleanup=cleanup_in
) as receiver,
attach_to_ringbuf_sender(
token_out,
cleanup=cleanup_out
) as sender,
):
yield trio.StapledStream(sender, receiver)
class RingBuffBytesSender(trio.abc.SendChannel[bytes]):
'''
In order to guarantee full messages are received, all bytes
sent by `RingBuffBytesSender` are preceded with a 4 byte header
which decodes into a uint32 indicating the actual size of the
next payload.
Optional batch mode:
If `batch_size` > 1 messages wont get sent immediately but will be
stored until `batch_size` messages are pending, then it will send
them all at once.
`batch_size` can be changed dynamically but always call, `flush()`
right before.
'''
def __init__(
self,
sender: RingBuffSender,
batch_size: int = 1
):
self._sender = sender
self.batch_size = batch_size
self._batch_msg_len = 0
self._batch: bytes = b''
self._send_lock = trio.StrictFIFOLock()
@property
def pending_msgs(self) -> int:
return self._batch_msg_len
@property
def must_flush(self) -> bool:
return self._batch_msg_len >= self.batch_size
async def _flush(
self,
new_batch_size: int | None = None
) -> None:
await self._sender.send_all(self._batch)
self._batch = b''
self._batch_msg_len = 0
if new_batch_size:
self.batch_size = new_batch_size
async def flush(
self,
new_batch_size: int | None = None
) -> None:
async with self._send_lock:
await self._flush(new_batch_size=new_batch_size)
async def send(self, value: bytes) -> None:
async with self._send_lock:
msg: bytes = struct.pack("<I", len(value)) + value
if self.batch_size == 1:
await self._sender.send_all(msg)
return
self._batch += msg
self._batch_msg_len += 1
if self.must_flush:
await self._flush()
async def send_eof(self) -> None:
await self.flush(new_batch_size=1)
await self.send(b'')
async def aclose(self) -> None:
async with self._send_lock:
await self._sender.aclose()
class RingBuffBytesReceiver(trio.abc.ReceiveChannel[bytes]):
'''
See `RingBuffBytesSender` docstring.
A `tricycle.BufferedReceiveStream` is used for the
`receive_exactly` API.
'''
def __init__(
self,
receiver: RingBuffReceiver
):
self._receiver = receiver
async def _receive_exactly(self, num_bytes: int) -> bytes:
'''
Fetch bytes from receiver until we read exactly `num_bytes`
or end of stream is signaled.
'''
payload = b''
while len(payload) < num_bytes:
remaining = num_bytes - len(payload)
new_bytes = await self._receiver.receive_some(
max_bytes=remaining
)
if new_bytes == b'':
raise trio.EndOfChannel
payload += new_bytes
return payload
async def receive(self) -> bytes:
header: bytes = await self._receive_exactly(4)
size: int
size, = struct.unpack("<I", header)
if size == 0:
raise trio.EndOfChannel
return await self._receive_exactly(size)
async def aclose(self) -> None:
await self._receiver.aclose()
@acm
async def attach_to_ringbuf_rchannel(
token: RBToken,
cleanup: bool = True
) -> AsyncContextManager[RingBuffBytesReceiver]:
'''
Attach a RingBuffBytesReceiver from a previously opened
RBToken.
'''
async with attach_to_ringbuf_receiver(
token, cleanup=cleanup
) as receiver:
yield RingBuffBytesReceiver(receiver)
@acm
async def attach_to_ringbuf_schannel(
token: RBToken,
cleanup: bool = True,
batch_size: int = 1,
) -> AsyncContextManager[RingBuffBytesSender]:
'''
Attach a RingBuffBytesSender from a previously opened
RBToken.
'''
async with attach_to_ringbuf_sender(
token, cleanup=cleanup
) as sender:
yield RingBuffBytesSender(sender, batch_size=batch_size)
class RingBuffChannel(trio.abc.Channel[bytes]):
'''
Combine `RingBuffBytesSender` and `RingBuffBytesReceiver`
in order to expose the bidirectional `trio.abc.Channel` API. in order to expose the bidirectional `trio.abc.Channel` API.
''' '''
def __init__( def __init__(
self, self,
sender: RingBuffBytesSender, sender: RingBufferSendChannel,
receiver: RingBuffBytesReceiver receiver: RingBufferReceiveChannel
): ):
self._sender = sender self._sender = sender
self._receiver = receiver self._receiver = receiver
@ -666,6 +605,12 @@ class RingBuffChannel(trio.abc.Channel[bytes]):
def pending_msgs(self) -> int: def pending_msgs(self) -> int:
return self._sender.pending_msgs return self._sender.pending_msgs
async def send_all(self, value: bytes) -> None:
await self._sender.send_all(value)
async def wait_send_all_might_not_block(self):
await self._sender.wait_send_all_might_not_block()
async def flush( async def flush(
self, self,
new_batch_size: int | None = None new_batch_size: int | None = None
@ -678,6 +623,15 @@ class RingBuffChannel(trio.abc.Channel[bytes]):
async def send_eof(self) -> None: async def send_eof(self) -> None:
await self._sender.send_eof() await self._sender.send_eof()
def receive_nowait(self, max_bytes: int = _DEFAULT_RB_SIZE) -> bytes:
return self._receiver.receive_nowait(max_bytes=max_bytes)
async def receive_some(self, max_bytes: int = _DEFAULT_RB_SIZE) -> bytes:
return await self._receiver.receive_some(max_bytes=max_bytes)
async def receive_exactly(self, num_bytes: int) -> bytes:
return await self._receiver.receive_exactly(num_bytes)
async def receive(self) -> bytes: async def receive(self) -> bytes:
return await self._receiver.receive() return await self._receiver.receive()
@ -691,23 +645,20 @@ async def attach_to_ringbuf_channel(
token_in: RBToken, token_in: RBToken,
token_out: RBToken, token_out: RBToken,
cleanup_in: bool = True, cleanup_in: bool = True,
cleanup_out: bool = True, cleanup_out: bool = True
batch_size: int = 1 ) -> AsyncContextManager[trio.StapledStream]:
) -> AsyncContextManager[RingBuffChannel]:
''' '''
Attach to an already opened ringbuf pair and return Attach to two previously opened `RBToken`s and return a `RingBufferChannel`
a `RingBuffChannel`.
''' '''
async with ( async with (
attach_to_ringbuf_rchannel( attach_to_ringbuf_receiver(
token_in, token_in,
cleanup=cleanup_in cleanup=cleanup_in
) as receiver, ) as receiver,
attach_to_ringbuf_schannel( attach_to_ringbuf_sender(
token_out, token_out,
cleanup=cleanup_out, cleanup=cleanup_out
batch_size=batch_size
) as sender, ) as sender,
): ):
yield RingBuffChannel(sender, receiver) yield RingBufferChannel(sender, receiver)