Rename RingBuff -> RingBuffer
Combine RingBuffer stream and channel apis Implement RingBufferReceiveChannel.receive_nowait Make msg generator calculate hashone_ring_to_rule_them_all
parent
95ea4647cc
commit
3568ba5d5d
|
@ -9,7 +9,6 @@ from tractor.ipc._ringbuf import (
|
||||||
open_ringbuf,
|
open_ringbuf,
|
||||||
attach_to_ringbuf_receiver,
|
attach_to_ringbuf_receiver,
|
||||||
attach_to_ringbuf_sender,
|
attach_to_ringbuf_sender,
|
||||||
attach_to_ringbuf_stream,
|
|
||||||
attach_to_ringbuf_channel,
|
attach_to_ringbuf_channel,
|
||||||
RBToken,
|
RBToken,
|
||||||
)
|
)
|
||||||
|
@ -25,7 +24,6 @@ pytestmark = pytest.mark.skip
|
||||||
@tractor.context
|
@tractor.context
|
||||||
async def child_read_shm(
|
async def child_read_shm(
|
||||||
ctx: tractor.Context,
|
ctx: tractor.Context,
|
||||||
msg_amount: int,
|
|
||||||
token: RBToken,
|
token: RBToken,
|
||||||
) -> str:
|
) -> str:
|
||||||
'''
|
'''
|
||||||
|
@ -41,11 +39,13 @@ async def child_read_shm(
|
||||||
'''
|
'''
|
||||||
await ctx.started()
|
await ctx.started()
|
||||||
print('reader started')
|
print('reader started')
|
||||||
|
msg_amount = 0
|
||||||
recvd_bytes = 0
|
recvd_bytes = 0
|
||||||
recvd_hash = hashlib.sha256()
|
recvd_hash = hashlib.sha256()
|
||||||
start_ts = time.time()
|
start_ts = time.time()
|
||||||
async with attach_to_ringbuf_receiver(token) as receiver:
|
async with attach_to_ringbuf_receiver(token) as receiver:
|
||||||
async for msg in receiver:
|
async for msg in receiver:
|
||||||
|
msg_amount += 1
|
||||||
recvd_hash.update(msg)
|
recvd_hash.update(msg)
|
||||||
recvd_bytes += len(msg)
|
recvd_bytes += len(msg)
|
||||||
|
|
||||||
|
@ -79,19 +79,16 @@ async def child_write_shm(
|
||||||
Attach to ringbuf and send all generated messages.
|
Attach to ringbuf and send all generated messages.
|
||||||
|
|
||||||
'''
|
'''
|
||||||
msgs, _total_bytes = generate_sample_messages(
|
sent_hash, msgs, _total_bytes = generate_sample_messages(
|
||||||
msg_amount,
|
msg_amount,
|
||||||
rand_min=rand_min,
|
rand_min=rand_min,
|
||||||
rand_max=rand_max,
|
rand_max=rand_max,
|
||||||
)
|
)
|
||||||
print('writer hashing payload...')
|
|
||||||
sent_hash = hashlib.sha256(b''.join(msgs)).hexdigest()
|
|
||||||
print('writer done hashing.')
|
|
||||||
await ctx.started(sent_hash)
|
await ctx.started(sent_hash)
|
||||||
print('writer started')
|
print('writer started')
|
||||||
async with attach_to_ringbuf_sender(token, cleanup=False) as sender:
|
async with attach_to_ringbuf_sender(token, cleanup=False) as sender:
|
||||||
for msg in msgs:
|
for msg in msgs:
|
||||||
await sender.send_all(msg)
|
await sender.send(msg)
|
||||||
|
|
||||||
print('writer exit')
|
print('writer exit')
|
||||||
|
|
||||||
|
@ -159,7 +156,6 @@ def test_ringbuf(
|
||||||
recv_p.open_context(
|
recv_p.open_context(
|
||||||
child_read_shm,
|
child_read_shm,
|
||||||
token=token,
|
token=token,
|
||||||
msg_amount=msg_amount
|
|
||||||
) as (rctx, _sent),
|
) as (rctx, _sent),
|
||||||
):
|
):
|
||||||
recvd_hash = await rctx.result()
|
recvd_hash = await rctx.result()
|
||||||
|
@ -295,75 +291,6 @@ def test_receiver_max_bytes():
|
||||||
assert msg == b''.join(msgs)
|
assert msg == b''.join(msgs)
|
||||||
|
|
||||||
|
|
||||||
def test_stapled_ringbuf():
|
|
||||||
'''
|
|
||||||
Open two ringbufs and give tokens to tasks (swap them such that in/out tokens
|
|
||||||
are inversed on each task) which will open the streams and use trio.StapledStream
|
|
||||||
to have a single bidirectional stream.
|
|
||||||
|
|
||||||
Then take turns to send and receive messages.
|
|
||||||
|
|
||||||
'''
|
|
||||||
msg = generate_single_byte_msgs(100)
|
|
||||||
pair_0_msgs = []
|
|
||||||
pair_1_msgs = []
|
|
||||||
|
|
||||||
pair_0_done = trio.Event()
|
|
||||||
pair_1_done = trio.Event()
|
|
||||||
|
|
||||||
async def pair_0(token_in: RBToken, token_out: RBToken):
|
|
||||||
async with attach_to_ringbuf_stream(
|
|
||||||
token_in,
|
|
||||||
token_out,
|
|
||||||
cleanup_in=False,
|
|
||||||
cleanup_out=False
|
|
||||||
) as stream:
|
|
||||||
# first turn to send
|
|
||||||
await stream.send_all(msg)
|
|
||||||
|
|
||||||
# second turn to receive
|
|
||||||
while len(pair_0_msgs) != len(msg):
|
|
||||||
_msg = await stream.receive_some(max_bytes=1)
|
|
||||||
pair_0_msgs.append(_msg)
|
|
||||||
|
|
||||||
pair_0_done.set()
|
|
||||||
await pair_1_done.wait()
|
|
||||||
|
|
||||||
|
|
||||||
async def pair_1(token_in: RBToken, token_out: RBToken):
|
|
||||||
async with attach_to_ringbuf_stream(
|
|
||||||
token_in,
|
|
||||||
token_out,
|
|
||||||
cleanup_in=False,
|
|
||||||
cleanup_out=False
|
|
||||||
) as stream:
|
|
||||||
# first turn to receive
|
|
||||||
while len(pair_1_msgs) != len(msg):
|
|
||||||
_msg = await stream.receive_some(max_bytes=1)
|
|
||||||
pair_1_msgs.append(_msg)
|
|
||||||
|
|
||||||
# second turn to send
|
|
||||||
await stream.send_all(msg)
|
|
||||||
|
|
||||||
pair_1_done.set()
|
|
||||||
await pair_0_done.wait()
|
|
||||||
|
|
||||||
|
|
||||||
async def main():
|
|
||||||
with tractor.ipc.open_ringbuf_pair(
|
|
||||||
'test_stapled_ringbuf'
|
|
||||||
) as (token_0, token_1):
|
|
||||||
async with trio.open_nursery() as n:
|
|
||||||
n.start_soon(pair_0, token_0, token_1)
|
|
||||||
n.start_soon(pair_1, token_1, token_0)
|
|
||||||
|
|
||||||
|
|
||||||
trio.run(main)
|
|
||||||
|
|
||||||
assert msg == b''.join(pair_0_msgs)
|
|
||||||
assert msg == b''.join(pair_1_msgs)
|
|
||||||
|
|
||||||
|
|
||||||
@tractor.context
|
@tractor.context
|
||||||
async def child_channel_sender(
|
async def child_channel_sender(
|
||||||
ctx: tractor.Context,
|
ctx: tractor.Context,
|
||||||
|
@ -373,7 +300,7 @@ async def child_channel_sender(
|
||||||
token_out: RBToken
|
token_out: RBToken
|
||||||
):
|
):
|
||||||
import random
|
import random
|
||||||
msgs, _total_bytes = generate_sample_messages(
|
_hash, msgs, _total_bytes = generate_sample_messages(
|
||||||
random.randint(msg_amount_min, msg_amount_max),
|
random.randint(msg_amount_min, msg_amount_max),
|
||||||
rand_min=256,
|
rand_min=256,
|
||||||
rand_max=1024,
|
rand_max=1024,
|
||||||
|
@ -383,7 +310,6 @@ async def child_channel_sender(
|
||||||
token_out
|
token_out
|
||||||
) as chan:
|
) as chan:
|
||||||
await ctx.started(msgs)
|
await ctx.started(msgs)
|
||||||
|
|
||||||
for msg in msgs:
|
for msg in msgs:
|
||||||
await chan.send(msg)
|
await chan.send(msg)
|
||||||
|
|
||||||
|
@ -396,16 +322,16 @@ def test_channel():
|
||||||
async def main():
|
async def main():
|
||||||
with tractor.ipc.open_ringbuf_pair(
|
with tractor.ipc.open_ringbuf_pair(
|
||||||
'test_ringbuf_transport'
|
'test_ringbuf_transport'
|
||||||
) as (token_0, token_1):
|
) as (send_token, recv_token):
|
||||||
async with (
|
async with (
|
||||||
attach_to_ringbuf_channel(token_0, token_1) as chan,
|
attach_to_ringbuf_channel(send_token, recv_token) as chan,
|
||||||
tractor.open_nursery() as an
|
tractor.open_nursery() as an
|
||||||
):
|
):
|
||||||
recv_p = await an.start_actor(
|
recv_p = await an.start_actor(
|
||||||
'test_ringbuf_transport_sender',
|
'test_ringbuf_transport_sender',
|
||||||
enable_modules=[__name__],
|
enable_modules=[__name__],
|
||||||
proc_kwargs={
|
proc_kwargs={
|
||||||
'pass_fds': token_0.fds + token_1.fds
|
'pass_fds': send_token.fds + recv_token.fds
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
async with (
|
async with (
|
||||||
|
@ -413,8 +339,8 @@ def test_channel():
|
||||||
child_channel_sender,
|
child_channel_sender,
|
||||||
msg_amount_min=msg_amount_min,
|
msg_amount_min=msg_amount_min,
|
||||||
msg_amount_max=msg_amount_max,
|
msg_amount_max=msg_amount_max,
|
||||||
token_in=token_1,
|
token_in=recv_token,
|
||||||
token_out=token_0
|
token_out=send_token
|
||||||
) as (ctx, msgs),
|
) as (ctx, msgs),
|
||||||
):
|
):
|
||||||
recv_msgs = []
|
recv_msgs = []
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
|
import hashlib
|
||||||
|
|
||||||
|
|
||||||
def generate_single_byte_msgs(amount: int) -> bytes:
|
def generate_single_byte_msgs(amount: int) -> bytes:
|
||||||
|
@ -23,7 +24,7 @@ def generate_sample_messages(
|
||||||
rand_min: int = 0,
|
rand_min: int = 0,
|
||||||
rand_max: int = 0,
|
rand_max: int = 0,
|
||||||
silent: bool = False,
|
silent: bool = False,
|
||||||
) -> tuple[list[bytes], int]:
|
) -> tuple[str, list[bytes], int]:
|
||||||
'''
|
'''
|
||||||
Generate bytes msgs for tests.
|
Generate bytes msgs for tests.
|
||||||
|
|
||||||
|
@ -55,6 +56,7 @@ def generate_sample_messages(
|
||||||
else:
|
else:
|
||||||
log_interval = 1000
|
log_interval = 1000
|
||||||
|
|
||||||
|
payload_hash = hashlib.sha256()
|
||||||
for i in range(amount):
|
for i in range(amount):
|
||||||
msg = f'[{i:08}]'.encode('utf-8')
|
msg = f'[{i:08}]'.encode('utf-8')
|
||||||
|
|
||||||
|
@ -64,6 +66,7 @@ def generate_sample_messages(
|
||||||
|
|
||||||
size += len(msg)
|
size += len(msg)
|
||||||
|
|
||||||
|
payload_hash.update(msg)
|
||||||
msgs.append(msg)
|
msgs.append(msg)
|
||||||
|
|
||||||
if (
|
if (
|
||||||
|
@ -78,4 +81,4 @@ def generate_sample_messages(
|
||||||
if not silent:
|
if not silent:
|
||||||
print(f'done, {size:,} bytes in total')
|
print(f'done, {size:,} bytes in total')
|
||||||
|
|
||||||
return msgs, size
|
return payload_hash.hexdigest(), msgs, size
|
||||||
|
|
|
@ -27,17 +27,16 @@ from ._chan import (
|
||||||
if platform.system() == 'Linux':
|
if platform.system() == 'Linux':
|
||||||
from ._ringbuf import (
|
from ._ringbuf import (
|
||||||
RBToken as RBToken,
|
RBToken as RBToken,
|
||||||
|
|
||||||
open_ringbuf as open_ringbuf,
|
open_ringbuf as open_ringbuf,
|
||||||
RingBuffSender as RingBuffSender,
|
|
||||||
RingBuffReceiver as RingBuffReceiver,
|
|
||||||
open_ringbuf_pair as open_ringbuf_pair,
|
open_ringbuf_pair as open_ringbuf_pair,
|
||||||
attach_to_ringbuf_receiver as attach_to_ringbuf_receiver,
|
|
||||||
|
RingBufferSendChannel as RingBufferSendChannel,
|
||||||
attach_to_ringbuf_sender as attach_to_ringbuf_sender,
|
attach_to_ringbuf_sender as attach_to_ringbuf_sender,
|
||||||
attach_to_ringbuf_stream as attach_to_ringbuf_stream,
|
|
||||||
RingBuffBytesSender as RingBuffBytesSender,
|
RingBufferReceiveChannel as RingBufferReceiveChannel,
|
||||||
RingBuffBytesReceiver as RingBuffBytesReceiver,
|
attach_to_ringbuf_receiver as attach_to_ringbuf_receiver,
|
||||||
RingBuffChannel as RingBuffChannel,
|
|
||||||
attach_to_ringbuf_schannel as attach_to_ringbuf_schannel,
|
RingBufferChannel as RingBufferChannel,
|
||||||
attach_to_ringbuf_rchannel as attach_to_ringbuf_rchannel,
|
|
||||||
attach_to_ringbuf_channel as attach_to_ringbuf_channel,
|
attach_to_ringbuf_channel as attach_to_ringbuf_channel,
|
||||||
)
|
)
|
||||||
|
|
|
@ -126,6 +126,30 @@ def open_ringbuf(
|
||||||
shm.unlink()
|
shm.unlink()
|
||||||
|
|
||||||
|
|
||||||
|
@cm
|
||||||
|
def open_ringbuf_pair(
|
||||||
|
name: str,
|
||||||
|
buf_size: int = _DEFAULT_RB_SIZE
|
||||||
|
) -> ContextManager[tuple(RBToken, RBToken)]:
|
||||||
|
'''
|
||||||
|
Handle resources for a ringbuf pair to be used for
|
||||||
|
bidirectional messaging.
|
||||||
|
|
||||||
|
'''
|
||||||
|
with (
|
||||||
|
open_ringbuf(
|
||||||
|
name + '.send',
|
||||||
|
buf_size=buf_size
|
||||||
|
) as send_token,
|
||||||
|
|
||||||
|
open_ringbuf(
|
||||||
|
name + '.recv',
|
||||||
|
buf_size=buf_size
|
||||||
|
) as recv_token
|
||||||
|
):
|
||||||
|
yield send_token, recv_token
|
||||||
|
|
||||||
|
|
||||||
Buffer = bytes | bytearray | memoryview
|
Buffer = bytes | bytearray | memoryview
|
||||||
|
|
||||||
|
|
||||||
|
@ -135,32 +159,65 @@ IPC Reliable Ring Buffer
|
||||||
`eventfd(2)` is used for wrap around sync, to signal writes to
|
`eventfd(2)` is used for wrap around sync, to signal writes to
|
||||||
the reader and end of stream.
|
the reader and end of stream.
|
||||||
|
|
||||||
|
In order to guarantee full messages are received, all bytes
|
||||||
|
sent by `RingBufferSendChannel` are preceded with a 4 byte header
|
||||||
|
which decodes into a uint32 indicating the actual size of the
|
||||||
|
next full payload.
|
||||||
|
|
||||||
'''
|
'''
|
||||||
|
|
||||||
|
|
||||||
class RingBuffSender(trio.abc.SendStream):
|
class RingBufferSendChannel(trio.abc.SendChannel[bytes]):
|
||||||
'''
|
'''
|
||||||
Ring Buffer sender side implementation
|
Ring Buffer sender side implementation
|
||||||
|
|
||||||
Do not use directly! manage with `attach_to_ringbuf_sender`
|
Do not use directly! manage with `attach_to_ringbuf_sender`
|
||||||
after having opened a ringbuf context with `open_ringbuf`.
|
after having opened a ringbuf context with `open_ringbuf`.
|
||||||
|
|
||||||
|
Optional batch mode:
|
||||||
|
|
||||||
|
If `batch_size` > 1 messages wont get sent immediately but will be
|
||||||
|
stored until `batch_size` messages are pending, then it will send
|
||||||
|
them all at once.
|
||||||
|
|
||||||
|
`batch_size` can be changed dynamically but always call, `flush()`
|
||||||
|
right before.
|
||||||
|
|
||||||
'''
|
'''
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
token: RBToken,
|
token: RBToken,
|
||||||
|
batch_size: int = 1,
|
||||||
cleanup: bool = False
|
cleanup: bool = False
|
||||||
):
|
):
|
||||||
self._token = RBToken.from_msg(token)
|
self._token = RBToken.from_msg(token)
|
||||||
|
self.batch_size = batch_size
|
||||||
|
|
||||||
|
# ringbuf os resources
|
||||||
self._shm: SharedMemory | None = None
|
self._shm: SharedMemory | None = None
|
||||||
self._write_event = EventFD(self._token.write_eventfd, 'w')
|
self._write_event = EventFD(self._token.write_eventfd, 'w')
|
||||||
self._wrap_event = EventFD(self._token.wrap_eventfd, 'r')
|
self._wrap_event = EventFD(self._token.wrap_eventfd, 'r')
|
||||||
self._eof_event = EventFD(self._token.eof_eventfd, 'w')
|
self._eof_event = EventFD(self._token.eof_eventfd, 'w')
|
||||||
|
|
||||||
|
# current write pointer
|
||||||
self._ptr = 0
|
self._ptr = 0
|
||||||
|
|
||||||
|
# when `batch_size` > 1 store messages on `self._batch` and write them
|
||||||
|
# all, once `len(self._batch) == `batch_size`
|
||||||
|
self._batch: list[bytes] = []
|
||||||
|
|
||||||
self._cleanup = cleanup
|
self._cleanup = cleanup
|
||||||
self._send_lock = trio.StrictFIFOLock()
|
self._send_lock = trio.StrictFIFOLock()
|
||||||
|
|
||||||
|
@acm
|
||||||
|
async def _maybe_lock(self) -> AsyncContextManager[None]:
|
||||||
|
if self._send_lock.locked():
|
||||||
|
yield
|
||||||
|
return
|
||||||
|
|
||||||
|
async with self._send_lock:
|
||||||
|
yield
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def name(self) -> str:
|
def name(self) -> str:
|
||||||
if not self._shm:
|
if not self._shm:
|
||||||
|
@ -183,11 +240,19 @@ class RingBuffSender(trio.abc.SendStream):
|
||||||
def wrap_fd(self) -> int:
|
def wrap_fd(self) -> int:
|
||||||
return self._wrap_event.fd
|
return self._wrap_event.fd
|
||||||
|
|
||||||
|
@property
|
||||||
|
def pending_msgs(self) -> int:
|
||||||
|
return len(self._batch)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def must_flush(self) -> bool:
|
||||||
|
return self.pending_msgs >= self.batch_size
|
||||||
|
|
||||||
async def _wait_wrap(self):
|
async def _wait_wrap(self):
|
||||||
await self._wrap_event.read()
|
await self._wrap_event.read()
|
||||||
|
|
||||||
async def send_all(self, data: Buffer):
|
async def send_all(self, data: Buffer):
|
||||||
async with self._send_lock:
|
async with self._maybe_lock():
|
||||||
# while data is larger than the remaining buf
|
# while data is larger than the remaining buf
|
||||||
target_ptr = self.ptr + len(data)
|
target_ptr = self.ptr + len(data)
|
||||||
while target_ptr > self.size:
|
while target_ptr > self.size:
|
||||||
|
@ -211,6 +276,34 @@ class RingBuffSender(trio.abc.SendStream):
|
||||||
async def wait_send_all_might_not_block(self):
|
async def wait_send_all_might_not_block(self):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
async def flush(
|
||||||
|
self,
|
||||||
|
new_batch_size: int | None = None
|
||||||
|
) -> None:
|
||||||
|
async with self._maybe_lock():
|
||||||
|
for msg in self._batch:
|
||||||
|
await self.send_all(msg)
|
||||||
|
|
||||||
|
self._batch = []
|
||||||
|
if new_batch_size:
|
||||||
|
self.batch_size = new_batch_size
|
||||||
|
|
||||||
|
async def send(self, value: bytes) -> None:
|
||||||
|
async with self._maybe_lock():
|
||||||
|
msg: bytes = struct.pack("<I", len(value)) + value
|
||||||
|
if self.batch_size == 1:
|
||||||
|
await self.send_all(msg)
|
||||||
|
return
|
||||||
|
|
||||||
|
self._batch.append(msg)
|
||||||
|
if self.must_flush:
|
||||||
|
await self.flush()
|
||||||
|
|
||||||
|
async def send_eof(self) -> None:
|
||||||
|
async with self._send_lock:
|
||||||
|
await self.flush(new_batch_size=1)
|
||||||
|
await self.send(b'')
|
||||||
|
|
||||||
def open(self):
|
def open(self):
|
||||||
try:
|
try:
|
||||||
self._shm = SharedMemory(
|
self._shm = SharedMemory(
|
||||||
|
@ -238,15 +331,14 @@ class RingBuffSender(trio.abc.SendStream):
|
||||||
self._shm.close()
|
self._shm.close()
|
||||||
|
|
||||||
async def aclose(self):
|
async def aclose(self):
|
||||||
async with self._send_lock:
|
self.close()
|
||||||
self.close()
|
|
||||||
|
|
||||||
async def __aenter__(self):
|
async def __aenter__(self):
|
||||||
self.open()
|
self.open()
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
|
||||||
class RingBuffReceiver(trio.abc.ReceiveStream):
|
class RingBufferReceiveChannel(trio.abc.ReceiveChannel[bytes]):
|
||||||
'''
|
'''
|
||||||
Ring Buffer receiver side implementation
|
Ring Buffer receiver side implementation
|
||||||
|
|
||||||
|
@ -312,21 +404,48 @@ class RingBuffReceiver(trio.abc.ReceiveStream):
|
||||||
except trio.Cancelled:
|
except trio.Cancelled:
|
||||||
...
|
...
|
||||||
|
|
||||||
async def receive_some(self, max_bytes: int | None = None) -> bytes:
|
def receive_nowait(self, max_bytes: int = _DEFAULT_RB_SIZE) -> bytes:
|
||||||
|
'''
|
||||||
|
Try to receive any bytes we can without blocking or raise
|
||||||
|
`trio.WouldBlock`.
|
||||||
|
|
||||||
|
'''
|
||||||
|
if max_bytes < 1:
|
||||||
|
raise ValueError("max_bytes must be >= 1")
|
||||||
|
|
||||||
|
delta = self._write_ptr - self._ptr
|
||||||
|
if delta == 0:
|
||||||
|
raise trio.WouldBlock
|
||||||
|
|
||||||
|
# dont overflow caller
|
||||||
|
delta = min(delta, max_bytes)
|
||||||
|
|
||||||
|
target_ptr = self._ptr + delta
|
||||||
|
|
||||||
|
# fetch next segment and advance ptr
|
||||||
|
segment = bytes(self._shm.buf[self._ptr:target_ptr])
|
||||||
|
self._ptr = target_ptr
|
||||||
|
|
||||||
|
if self._ptr == self.size:
|
||||||
|
# reached the end, signal wrap around
|
||||||
|
self._ptr = 0
|
||||||
|
self._write_ptr = 0
|
||||||
|
self._wrap_event.write(1)
|
||||||
|
|
||||||
|
return segment
|
||||||
|
|
||||||
|
async def receive_some(self, max_bytes: int = _DEFAULT_RB_SIZE) -> bytes:
|
||||||
'''
|
'''
|
||||||
Receive up to `max_bytes`, if no `max_bytes` is provided
|
Receive up to `max_bytes`, if no `max_bytes` is provided
|
||||||
a reasonable default is used.
|
a reasonable default is used.
|
||||||
|
|
||||||
|
Can return < max_bytes.
|
||||||
|
|
||||||
'''
|
'''
|
||||||
if max_bytes is None:
|
try:
|
||||||
max_bytes: int = _DEFAULT_RB_SIZE
|
return self.receive_nowait(max_bytes=max_bytes)
|
||||||
|
|
||||||
if max_bytes < 1:
|
except trio.WouldBlock:
|
||||||
raise ValueError("max_bytes must be >= 1")
|
|
||||||
|
|
||||||
# delta is remaining bytes we havent read
|
|
||||||
delta = self._write_ptr - self._ptr
|
|
||||||
if delta == 0:
|
|
||||||
# we have read all we can, see if new data is available
|
# we have read all we can, see if new data is available
|
||||||
if self._end_ptr < 0:
|
if self._end_ptr < 0:
|
||||||
# if we havent been signaled about EOF yet
|
# if we havent been signaled about EOF yet
|
||||||
|
@ -353,22 +472,39 @@ class RingBuffReceiver(trio.abc.ReceiveStream):
|
||||||
# no more bytes to read and self._end_ptr set, EOF reached
|
# no more bytes to read and self._end_ptr set, EOF reached
|
||||||
return b''
|
return b''
|
||||||
|
|
||||||
# dont overflow caller
|
return await self.receive_some(max_bytes=max_bytes)
|
||||||
delta = min(delta, max_bytes)
|
|
||||||
|
|
||||||
target_ptr = self._ptr + delta
|
async def receive_exactly(self, num_bytes: int) -> bytes:
|
||||||
|
'''
|
||||||
|
Fetch bytes until we read exactly `num_bytes` or EOF.
|
||||||
|
|
||||||
# fetch next segment and advance ptr
|
'''
|
||||||
segment = bytes(self._shm.buf[self._ptr:target_ptr])
|
payload = b''
|
||||||
self._ptr = target_ptr
|
while len(payload) < num_bytes:
|
||||||
|
remaining = num_bytes - len(payload)
|
||||||
|
|
||||||
if self._ptr == self.size:
|
new_bytes = await self.receive_some(
|
||||||
# reached the end, signal wrap around
|
max_bytes=remaining
|
||||||
self._ptr = 0
|
)
|
||||||
self._write_ptr = 0
|
|
||||||
self._wrap_event.write(1)
|
|
||||||
|
|
||||||
return segment
|
if new_bytes == b'':
|
||||||
|
raise trio.EndOfChannel
|
||||||
|
|
||||||
|
payload += new_bytes
|
||||||
|
|
||||||
|
return payload
|
||||||
|
|
||||||
|
async def receive(self) -> bytes:
|
||||||
|
'''
|
||||||
|
Receive a complete payload
|
||||||
|
|
||||||
|
'''
|
||||||
|
header: bytes = await self.receive_exactly(4)
|
||||||
|
size: int
|
||||||
|
size, = struct.unpack("<I", header)
|
||||||
|
if size == 0:
|
||||||
|
raise trio.EndOfChannel
|
||||||
|
return await self.receive_exactly(size)
|
||||||
|
|
||||||
def open(self):
|
def open(self):
|
||||||
try:
|
try:
|
||||||
|
@ -402,18 +538,20 @@ class RingBuffReceiver(trio.abc.ReceiveStream):
|
||||||
|
|
||||||
@acm
|
@acm
|
||||||
async def attach_to_ringbuf_receiver(
|
async def attach_to_ringbuf_receiver(
|
||||||
|
|
||||||
token: RBToken,
|
token: RBToken,
|
||||||
cleanup: bool = True
|
cleanup: bool = True
|
||||||
) -> AsyncContextManager[RingBuffReceiver]:
|
|
||||||
|
) -> AsyncContextManager[RingBufferReceiveChannel]:
|
||||||
'''
|
'''
|
||||||
Attach a RingBuffReceiver from a previously opened
|
Attach a RingBufferReceiveChannel from a previously opened
|
||||||
RBToken.
|
RBToken.
|
||||||
|
|
||||||
Launches `receiver._eof_monitor_task` in a `trio.Nursery`.
|
Launches `receiver._eof_monitor_task` in a `trio.Nursery`.
|
||||||
'''
|
'''
|
||||||
async with (
|
async with (
|
||||||
trio.open_nursery() as n,
|
trio.open_nursery() as n,
|
||||||
RingBuffReceiver(
|
RingBufferReceiveChannel(
|
||||||
token,
|
token,
|
||||||
cleanup=cleanup
|
cleanup=cleanup
|
||||||
) as receiver
|
) as receiver
|
||||||
|
@ -424,232 +562,33 @@ async def attach_to_ringbuf_receiver(
|
||||||
|
|
||||||
@acm
|
@acm
|
||||||
async def attach_to_ringbuf_sender(
|
async def attach_to_ringbuf_sender(
|
||||||
|
|
||||||
token: RBToken,
|
token: RBToken,
|
||||||
cleanup: bool = True
|
cleanup: bool = True
|
||||||
) -> AsyncContextManager[RingBuffSender]:
|
|
||||||
|
) -> AsyncContextManager[RingBufferSendChannel]:
|
||||||
'''
|
'''
|
||||||
Attach a RingBuffSender from a previously opened
|
Attach a RingBufferSendChannel from a previously opened
|
||||||
RBToken.
|
RBToken.
|
||||||
|
|
||||||
'''
|
'''
|
||||||
async with RingBuffSender(
|
async with RingBufferSendChannel(
|
||||||
token,
|
token,
|
||||||
cleanup=cleanup
|
cleanup=cleanup
|
||||||
) as sender:
|
) as sender:
|
||||||
yield sender
|
yield sender
|
||||||
|
|
||||||
|
|
||||||
@cm
|
class RingBufferChannel(trio.abc.Channel[bytes]):
|
||||||
def open_ringbuf_pair(
|
|
||||||
name: str,
|
|
||||||
buf_size: int = _DEFAULT_RB_SIZE
|
|
||||||
) -> ContextManager[tuple(RBToken, RBToken)]:
|
|
||||||
'''
|
'''
|
||||||
Handle resources for a ringbuf pair to be used for
|
Combine `RingBufferSendChannel` and `RingBufferReceiveChannel`
|
||||||
bidirectional messaging.
|
|
||||||
|
|
||||||
'''
|
|
||||||
with (
|
|
||||||
open_ringbuf(
|
|
||||||
name + '.pair0',
|
|
||||||
buf_size=buf_size
|
|
||||||
) as token_0,
|
|
||||||
|
|
||||||
open_ringbuf(
|
|
||||||
name + '.pair1',
|
|
||||||
buf_size=buf_size
|
|
||||||
) as token_1
|
|
||||||
):
|
|
||||||
yield token_0, token_1
|
|
||||||
|
|
||||||
|
|
||||||
@acm
|
|
||||||
async def attach_to_ringbuf_stream(
|
|
||||||
token_in: RBToken,
|
|
||||||
token_out: RBToken,
|
|
||||||
cleanup_in: bool = True,
|
|
||||||
cleanup_out: bool = True
|
|
||||||
) -> AsyncContextManager[trio.StapledStream]:
|
|
||||||
'''
|
|
||||||
Attach a trio.StapledStream from a previously opened
|
|
||||||
ringbuf pair.
|
|
||||||
|
|
||||||
'''
|
|
||||||
async with (
|
|
||||||
attach_to_ringbuf_receiver(
|
|
||||||
token_in,
|
|
||||||
cleanup=cleanup_in
|
|
||||||
) as receiver,
|
|
||||||
attach_to_ringbuf_sender(
|
|
||||||
token_out,
|
|
||||||
cleanup=cleanup_out
|
|
||||||
) as sender,
|
|
||||||
):
|
|
||||||
yield trio.StapledStream(sender, receiver)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class RingBuffBytesSender(trio.abc.SendChannel[bytes]):
|
|
||||||
'''
|
|
||||||
In order to guarantee full messages are received, all bytes
|
|
||||||
sent by `RingBuffBytesSender` are preceded with a 4 byte header
|
|
||||||
which decodes into a uint32 indicating the actual size of the
|
|
||||||
next payload.
|
|
||||||
|
|
||||||
Optional batch mode:
|
|
||||||
|
|
||||||
If `batch_size` > 1 messages wont get sent immediately but will be
|
|
||||||
stored until `batch_size` messages are pending, then it will send
|
|
||||||
them all at once.
|
|
||||||
|
|
||||||
`batch_size` can be changed dynamically but always call, `flush()`
|
|
||||||
right before.
|
|
||||||
|
|
||||||
'''
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
sender: RingBuffSender,
|
|
||||||
batch_size: int = 1
|
|
||||||
):
|
|
||||||
self._sender = sender
|
|
||||||
self.batch_size = batch_size
|
|
||||||
self._batch_msg_len = 0
|
|
||||||
self._batch: bytes = b''
|
|
||||||
self._send_lock = trio.StrictFIFOLock()
|
|
||||||
|
|
||||||
@property
|
|
||||||
def pending_msgs(self) -> int:
|
|
||||||
return self._batch_msg_len
|
|
||||||
|
|
||||||
@property
|
|
||||||
def must_flush(self) -> bool:
|
|
||||||
return self._batch_msg_len >= self.batch_size
|
|
||||||
|
|
||||||
async def _flush(
|
|
||||||
self,
|
|
||||||
new_batch_size: int | None = None
|
|
||||||
) -> None:
|
|
||||||
await self._sender.send_all(self._batch)
|
|
||||||
self._batch = b''
|
|
||||||
self._batch_msg_len = 0
|
|
||||||
if new_batch_size:
|
|
||||||
self.batch_size = new_batch_size
|
|
||||||
|
|
||||||
async def flush(
|
|
||||||
self,
|
|
||||||
new_batch_size: int | None = None
|
|
||||||
) -> None:
|
|
||||||
async with self._send_lock:
|
|
||||||
await self._flush(new_batch_size=new_batch_size)
|
|
||||||
|
|
||||||
async def send(self, value: bytes) -> None:
|
|
||||||
async with self._send_lock:
|
|
||||||
msg: bytes = struct.pack("<I", len(value)) + value
|
|
||||||
if self.batch_size == 1:
|
|
||||||
await self._sender.send_all(msg)
|
|
||||||
return
|
|
||||||
|
|
||||||
self._batch += msg
|
|
||||||
self._batch_msg_len += 1
|
|
||||||
if self.must_flush:
|
|
||||||
await self._flush()
|
|
||||||
|
|
||||||
async def send_eof(self) -> None:
|
|
||||||
await self.flush(new_batch_size=1)
|
|
||||||
await self.send(b'')
|
|
||||||
|
|
||||||
async def aclose(self) -> None:
|
|
||||||
async with self._send_lock:
|
|
||||||
await self._sender.aclose()
|
|
||||||
|
|
||||||
|
|
||||||
class RingBuffBytesReceiver(trio.abc.ReceiveChannel[bytes]):
|
|
||||||
'''
|
|
||||||
See `RingBuffBytesSender` docstring.
|
|
||||||
|
|
||||||
A `tricycle.BufferedReceiveStream` is used for the
|
|
||||||
`receive_exactly` API.
|
|
||||||
'''
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
receiver: RingBuffReceiver
|
|
||||||
):
|
|
||||||
self._receiver = receiver
|
|
||||||
|
|
||||||
async def _receive_exactly(self, num_bytes: int) -> bytes:
|
|
||||||
'''
|
|
||||||
Fetch bytes from receiver until we read exactly `num_bytes`
|
|
||||||
or end of stream is signaled.
|
|
||||||
|
|
||||||
'''
|
|
||||||
payload = b''
|
|
||||||
while len(payload) < num_bytes:
|
|
||||||
remaining = num_bytes - len(payload)
|
|
||||||
|
|
||||||
new_bytes = await self._receiver.receive_some(
|
|
||||||
max_bytes=remaining
|
|
||||||
)
|
|
||||||
|
|
||||||
if new_bytes == b'':
|
|
||||||
raise trio.EndOfChannel
|
|
||||||
|
|
||||||
payload += new_bytes
|
|
||||||
|
|
||||||
return payload
|
|
||||||
|
|
||||||
async def receive(self) -> bytes:
|
|
||||||
header: bytes = await self._receive_exactly(4)
|
|
||||||
size: int
|
|
||||||
size, = struct.unpack("<I", header)
|
|
||||||
if size == 0:
|
|
||||||
raise trio.EndOfChannel
|
|
||||||
return await self._receive_exactly(size)
|
|
||||||
|
|
||||||
async def aclose(self) -> None:
|
|
||||||
await self._receiver.aclose()
|
|
||||||
|
|
||||||
|
|
||||||
@acm
|
|
||||||
async def attach_to_ringbuf_rchannel(
|
|
||||||
token: RBToken,
|
|
||||||
cleanup: bool = True
|
|
||||||
) -> AsyncContextManager[RingBuffBytesReceiver]:
|
|
||||||
'''
|
|
||||||
Attach a RingBuffBytesReceiver from a previously opened
|
|
||||||
RBToken.
|
|
||||||
'''
|
|
||||||
async with attach_to_ringbuf_receiver(
|
|
||||||
token, cleanup=cleanup
|
|
||||||
) as receiver:
|
|
||||||
yield RingBuffBytesReceiver(receiver)
|
|
||||||
|
|
||||||
|
|
||||||
@acm
|
|
||||||
async def attach_to_ringbuf_schannel(
|
|
||||||
token: RBToken,
|
|
||||||
cleanup: bool = True,
|
|
||||||
batch_size: int = 1,
|
|
||||||
) -> AsyncContextManager[RingBuffBytesSender]:
|
|
||||||
'''
|
|
||||||
Attach a RingBuffBytesSender from a previously opened
|
|
||||||
RBToken.
|
|
||||||
'''
|
|
||||||
async with attach_to_ringbuf_sender(
|
|
||||||
token, cleanup=cleanup
|
|
||||||
) as sender:
|
|
||||||
yield RingBuffBytesSender(sender, batch_size=batch_size)
|
|
||||||
|
|
||||||
|
|
||||||
class RingBuffChannel(trio.abc.Channel[bytes]):
|
|
||||||
'''
|
|
||||||
Combine `RingBuffBytesSender` and `RingBuffBytesReceiver`
|
|
||||||
in order to expose the bidirectional `trio.abc.Channel` API.
|
in order to expose the bidirectional `trio.abc.Channel` API.
|
||||||
|
|
||||||
'''
|
'''
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
sender: RingBuffBytesSender,
|
sender: RingBufferSendChannel,
|
||||||
receiver: RingBuffBytesReceiver
|
receiver: RingBufferReceiveChannel
|
||||||
):
|
):
|
||||||
self._sender = sender
|
self._sender = sender
|
||||||
self._receiver = receiver
|
self._receiver = receiver
|
||||||
|
@ -666,6 +605,12 @@ class RingBuffChannel(trio.abc.Channel[bytes]):
|
||||||
def pending_msgs(self) -> int:
|
def pending_msgs(self) -> int:
|
||||||
return self._sender.pending_msgs
|
return self._sender.pending_msgs
|
||||||
|
|
||||||
|
async def send_all(self, value: bytes) -> None:
|
||||||
|
await self._sender.send_all(value)
|
||||||
|
|
||||||
|
async def wait_send_all_might_not_block(self):
|
||||||
|
await self._sender.wait_send_all_might_not_block()
|
||||||
|
|
||||||
async def flush(
|
async def flush(
|
||||||
self,
|
self,
|
||||||
new_batch_size: int | None = None
|
new_batch_size: int | None = None
|
||||||
|
@ -678,6 +623,15 @@ class RingBuffChannel(trio.abc.Channel[bytes]):
|
||||||
async def send_eof(self) -> None:
|
async def send_eof(self) -> None:
|
||||||
await self._sender.send_eof()
|
await self._sender.send_eof()
|
||||||
|
|
||||||
|
def receive_nowait(self, max_bytes: int = _DEFAULT_RB_SIZE) -> bytes:
|
||||||
|
return self._receiver.receive_nowait(max_bytes=max_bytes)
|
||||||
|
|
||||||
|
async def receive_some(self, max_bytes: int = _DEFAULT_RB_SIZE) -> bytes:
|
||||||
|
return await self._receiver.receive_some(max_bytes=max_bytes)
|
||||||
|
|
||||||
|
async def receive_exactly(self, num_bytes: int) -> bytes:
|
||||||
|
return await self._receiver.receive_exactly(num_bytes)
|
||||||
|
|
||||||
async def receive(self) -> bytes:
|
async def receive(self) -> bytes:
|
||||||
return await self._receiver.receive()
|
return await self._receiver.receive()
|
||||||
|
|
||||||
|
@ -691,23 +645,20 @@ async def attach_to_ringbuf_channel(
|
||||||
token_in: RBToken,
|
token_in: RBToken,
|
||||||
token_out: RBToken,
|
token_out: RBToken,
|
||||||
cleanup_in: bool = True,
|
cleanup_in: bool = True,
|
||||||
cleanup_out: bool = True,
|
cleanup_out: bool = True
|
||||||
batch_size: int = 1
|
) -> AsyncContextManager[trio.StapledStream]:
|
||||||
) -> AsyncContextManager[RingBuffChannel]:
|
|
||||||
'''
|
'''
|
||||||
Attach to an already opened ringbuf pair and return
|
Attach to two previously opened `RBToken`s and return a `RingBufferChannel`
|
||||||
a `RingBuffChannel`.
|
|
||||||
|
|
||||||
'''
|
'''
|
||||||
async with (
|
async with (
|
||||||
attach_to_ringbuf_rchannel(
|
attach_to_ringbuf_receiver(
|
||||||
token_in,
|
token_in,
|
||||||
cleanup=cleanup_in
|
cleanup=cleanup_in
|
||||||
) as receiver,
|
) as receiver,
|
||||||
attach_to_ringbuf_schannel(
|
attach_to_ringbuf_sender(
|
||||||
token_out,
|
token_out,
|
||||||
cleanup=cleanup_out,
|
cleanup=cleanup_out
|
||||||
batch_size=batch_size
|
|
||||||
) as sender,
|
) as sender,
|
||||||
):
|
):
|
||||||
yield RingBuffChannel(sender, receiver)
|
yield RingBufferChannel(sender, receiver)
|
||||||
|
|
Loading…
Reference in New Issue