Improve test_ringbuf test, drop MsgTransport ring buf impl for now in favour of a trio.abc.Channel[bytes] impl, add docstrings

Guillermo Rodriguez 2025-03-18 13:19:40 -03:00
parent 5cec4ee943
commit 2901049b5b
No known key found for this signature in database
GPG Key ID: 002CC5F1E6BDA53E
4 changed files with 210 additions and 198 deletions
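
For orientation before the per-file diffs: the new ringbuf API is consumed roughly as sketched below. This is an illustrative sketch only, pieced together from the test changes in this commit; the buffer name and payload are made up, the peer-side wiring is only hinted at in comments, and `open_ringbuf_pair` may take additional arguments not shown here.

```python
from tractor.ipc import (
    open_ringbuf_pair,
    attach_to_ringbuf_channel,
)

async def demo_one_side():
    # create two shared-memory ring buffers, one per direction
    # (runs under trio, normally inside a tractor actor)
    with open_ringbuf_pair('demo_ringbuf') as (token_0, token_1):
        # expose them as a single bidirectional trio.abc.Channel[bytes]
        async with attach_to_ringbuf_channel(token_0, token_1) as chan:
            await chan.send(b'hello')

            # the peer (normally a sub-actor) attaches with the tokens
            # swapped and iterates: `async for msg in chan: ...`
```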

View File

@@ -1,4 +1,5 @@
 import time
+import hashlib
 import trio
 import pytest

@@ -7,8 +8,8 @@ from tractor.ipc import (
     open_ringbuf,
     attach_to_ringbuf_receiver,
     attach_to_ringbuf_sender,
-    attach_to_ringbuf_pair,
     attach_to_ringbuf_stream,
+    attach_to_ringbuf_channel,
     RBToken,
 )
 from tractor._testing.samples import (
@@ -22,12 +23,26 @@ async def child_read_shm(
     ctx: tractor.Context,
     msg_amount: int,
     token: RBToken,
-) -> None:
-    recvd_bytes = 0
+) -> str:
+    '''
+    Sub-actor used in `test_ringbuf`.
+
+    Attach to a ringbuf and receive all messages until end of stream.
+    Keep track of how many bytes received and also calculate
+    sha256 of the whole byte stream.
+
+    Calculate and print performance stats, finally return calculated
+    hash.
+
+    '''
     await ctx.started()
+    print('reader started')
+    recvd_bytes = 0
+    recvd_hash = hashlib.sha256()
     start_ts = time.time()
     async with attach_to_ringbuf_receiver(token) as receiver:
         async for msg in receiver:
+            recvd_hash.update(msg)
             recvd_bytes += len(msg)

         end_ts = time.time()
@@ -37,7 +52,9 @@ async def child_read_shm(
     print(f'\n\telapsed ms: {elapsed_ms}')
     print(f'\tmsg/sec: {int(msg_amount / elapsed):,}')
     print(f'\tbytes/sec: {int(recvd_bytes / elapsed):,}')
-    print(f'\treceived bytes: {recvd_bytes}')
+    print(f'\treceived bytes: {recvd_bytes:,}')
+
+    return recvd_hash.hexdigest()


 @tractor.context
@@ -48,12 +65,26 @@ async def child_write_shm(
     rand_max: int,
     token: RBToken,
 ) -> None:
-    msgs, total_bytes = generate_sample_messages(
+    '''
+    Sub-actor used in `test_ringbuf`
+
+    Generate `msg_amount` payloads with
+    `random.randint(rand_min, rand_max)` random bytes at the end,
+    Calculate sha256 hash and send it to parent on `ctx.started`.
+
+    Attach to ringbuf and send all generated messages.
+
+    '''
+    msgs, _total_bytes = generate_sample_messages(
         msg_amount,
         rand_min=rand_min,
         rand_max=rand_max,
     )
-    await ctx.started(total_bytes)
+    print('writer hashing payload...')
+    sent_hash = hashlib.sha256(b''.join(msgs)).hexdigest()
+    print('writer done hashing.')
+    await ctx.started(sent_hash)
+    print('writer started')
     async with attach_to_ringbuf_sender(token, cleanup=False) as sender:
         for msg in msgs:
             await sender.send_all(msg)
@@ -87,11 +118,12 @@ def test_ringbuf(
 ):
     '''
     - Open a new ring buf on root actor
-    - Create a sender subactor and generate {msg_amount} messages
-    optionally with a random amount of bytes at the end of each,
-    return total_bytes on `ctx.started`, then send all messages
-    - Create a receiver subactor and receive until total_bytes are
-    read, print simple perf stats.
+    - Open `child_write_shm` ctx in sub-actor which will generate a
+    random payload and send its hash on `ctx.started`, finally sending
+    the payload through the stream.
+    - Open `child_read_shm` ctx in sub-actor which will receive the
+    payload, calculate perf stats and return the hash.
+    - Compare both hashes

     '''
     async def main():
@@ -119,14 +151,16 @@ def test_ringbuf(
                         msg_amount=msg_amount,
                         rand_min=rand_min,
                         rand_max=rand_max,
-                    ) as (sctx, total_bytes),
+                    ) as (_sctx, sent_hash),
+
                     recv_p.open_context(
                         child_read_shm,
                         token=token,
                         msg_amount=msg_amount
-                    ) as (sctx, _sent),
+                    ) as (rctx, _sent),
                 ):
-                    await recv_p.result()
+                    recvd_hash = await rctx.result()
+                    assert sent_hash == recvd_hash

                     await send_p.cancel_actor()
                     await recv_p.cancel_actor()
@@ -274,7 +308,7 @@ def test_stapled_ringbuf():
     pair_1_done = trio.Event()

     async def pair_0(token_in: RBToken, token_out: RBToken):
-        async with attach_to_ringbuf_pair(
+        async with attach_to_ringbuf_stream(
             token_in,
             token_out,
             cleanup_in=False,
@@ -293,7 +327,7 @@ def test_stapled_ringbuf():

     async def pair_1(token_in: RBToken, token_out: RBToken):
-        async with attach_to_ringbuf_pair(
+        async with attach_to_ringbuf_stream(
             token_in,
             token_out,
             cleanup_in=False,
@@ -327,7 +361,7 @@ def test_stapled_ringbuf():

 @tractor.context
-async def child_transport_sender(
+async def child_channel_sender(
     ctx: tractor.Context,
     msg_amount_min: int,
     msg_amount_max: int,
@@ -340,19 +374,17 @@ async def child_transport_sender(
         rand_min=256,
         rand_max=1024,
     )
-    async with attach_to_ringbuf_stream(
+    async with attach_to_ringbuf_channel(
         token_in,
         token_out
-    ) as transport:
+    ) as chan:
         await ctx.started(msgs)

         for msg in msgs:
-            await transport.send(msg)
-
-        await transport.recv()
+            await chan.send(msg)


-def test_ringbuf_transport():
+def test_ringbuf_channel():

     msg_amount_min = 100
     msg_amount_max = 1000
@@ -362,7 +394,7 @@ def test_ringbuf_transport():
             'test_ringbuf_transport'
         ) as (token_0, token_1):
             async with (
-                attach_to_ringbuf_stream(token_0, token_1) as transport,
+                attach_to_ringbuf_channel(token_0, token_1) as chan,
                 tractor.open_nursery() as an
             ):
                 recv_p = await an.start_actor(
@@ -374,7 +406,7 @@ def test_ringbuf_transport():
                 )
                 async with (
                     recv_p.open_context(
-                        child_transport_sender,
+                        child_channel_sender,
                         msg_amount_min=msg_amount_min,
                         msg_amount_max=msg_amount_max,
                         token_in=token_1,
@@ -382,10 +414,9 @@ def test_ringbuf_transport():
                     ) as (ctx, msgs),
                 ):
                     recv_msgs = []
-                    while len(recv_msgs) < len(msgs):
-                        recv_msgs.append(await transport.recv())
+                    async for msg in chan:
+                        recv_msgs.append(msg)

-                    await transport.send(b'end')
                     await recv_p.cancel_actor()
                     assert recv_msgs == msgs
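
Side note on the verification pattern the reworked tests use: instead of comparing byte counts, both ends hash the full payload and the hashes are compared at the end. Stripped of the ringbuf and actor machinery, the idea reduces to the following sketch (the payloads are hypothetical):

```python
import hashlib

msgs = [b'[00000000]spam', b'[00000001]eggs']  # hypothetical payloads

# writer side: hash the whole payload up front, report it via ctx.started()
sent_hash = hashlib.sha256(b''.join(msgs)).hexdigest()

# reader side: feed every received message into a running hash
recvd_hash = hashlib.sha256()
for msg in msgs:
    recvd_hash.update(msg)

# the parent asserts both sides agree
assert sent_hash == recvd_hash.hexdigest()
```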

View File

@@ -3,6 +3,18 @@ import random


 def generate_single_byte_msgs(amount: int) -> bytes:
+    '''
+    Generate a byte instance of len `amount` with:
+
+    ```
+    byte_at_index(i) = (i % 10).encode()
+    ```
+
+    this results in constantly repeating sequences of:
+
+    b'0123456789'
+
+    '''
     return b''.join(str(i % 10).encode() for i in range(amount))
@@ -10,15 +22,39 @@ def generate_sample_messages(
     amount: int,
     rand_min: int = 0,
     rand_max: int = 0,
-    silent: bool = False
+    silent: bool = False,
 ) -> tuple[list[bytes], int]:
+    '''
+    Generate bytes msgs for tests.
+
+    Messages will have the following format:
+
+    ```
+    b'[{i:08}]' + os.urandom(random.randint(rand_min, rand_max))
+    ```
+
+    so for message index 25:
+
+    b'[00000025]' + random_bytes
+
+    '''
     msgs = []
     size = 0
+    log_interval = None
     if not silent:
         print(f'\ngenerating {amount} messages...')

+        # calculate an appropriate log interval based on
+        # max message size
+        max_msg_size = 10 + rand_max
+
+        if max_msg_size <= 32 * 1024:
+            log_interval = 10_000
+
+        else:
+            log_interval = 1000
+
     for i in range(amount):
         msg = f'[{i:08}]'.encode('utf-8')
@@ -30,7 +66,13 @@ def generate_sample_messages(
         msgs.append(msg)

-        if not silent and i and i % 10_000 == 0:
+        if (
+            not silent
+            and
+            i > 0
+            and
+            i % log_interval == 0
+        ):
             print(f'{i} generated')

     if not silent:

View File

@@ -51,7 +51,9 @@ if platform.system() == 'Linux':
         open_ringbuf_pair as open_ringbuf_pair,
         attach_to_ringbuf_receiver as attach_to_ringbuf_receiver,
         attach_to_ringbuf_sender as attach_to_ringbuf_sender,
-        attach_to_ringbuf_pair as attach_to_ringbuf_pair,
         attach_to_ringbuf_stream as attach_to_ringbuf_stream,
-        MsgpackRBStream as MsgpackRBStream
+        RingBuffBytesSender as RingBuffBytesSender,
+        RingBuffBytesReceiver as RingBuffBytesReceiver,
+        RingBuffChannel as RingBuffChannel,
+        attach_to_ringbuf_channel as attach_to_ringbuf_channel
     )

View File

@@ -19,17 +19,10 @@ IPC Reliable RingBuffer implementation
 '''
 from __future__ import annotations
 import struct
-from collections.abc import (
-    AsyncGenerator,
-    AsyncIterator
-)
 from contextlib import (
     contextmanager as cm,
     asynccontextmanager as acm
 )
-from typing import (
-    Any
-)
 from multiprocessing.shared_memory import SharedMemory

 import trio
@@ -48,10 +41,8 @@ from ._linux import (
 from ._mp_bs import disable_mantracker
 from tractor.log import get_logger
 from tractor._exceptions import (
-    TransportClosed,
     InternalError
 )
-from tractor.ipc import MsgTransport


 log = get_logger(__name__)
@@ -147,6 +138,7 @@ def open_ringbuf(
 Buffer = bytes | bytearray | memoryview

+
 '''
 IPC Reliable Ring Buffer
@@ -406,7 +398,7 @@ async def attach_to_ringbuf_receiver(
     cleanup: bool = True
 ):
     '''
-    Instantiate a RingBuffReceiver from a previously opened
+    Attach a RingBuffReceiver from a previously opened
     RBToken.

     Launches `receiver._eof_monitor_task` in a `trio.Nursery`.
@@ -421,13 +413,14 @@ async def attach_to_ringbuf_receiver(
         n.start_soon(receiver._eof_monitor_task)
         yield receiver

+
 @acm
 async def attach_to_ringbuf_sender(
     token: RBToken,
     cleanup: bool = True
 ):
     '''
-    Instantiate a RingBuffSender from a previously opened
+    Attach a RingBuffSender from a previously opened
     RBToken.

     '''
@@ -463,14 +456,14 @@ def open_ringbuf_pair(

 @acm
-async def attach_to_ringbuf_pair(
+async def attach_to_ringbuf_stream(
     token_in: RBToken,
     token_out: RBToken,
     cleanup_in: bool = True,
     cleanup_out: bool = True
 ):
     '''
-    Instantiate a trio.StapledStream from a previously opened
+    Attach a trio.StapledStream from a previously opened
     ringbuf pair.

     '''
@@ -487,180 +480,124 @@ async def attach_to_ringbuf_pair(
         yield trio.StapledStream(sender, receiver)


-class MsgpackRBStream(MsgTransport):
-
-    def __init__(
-        self,
-        stream: trio.StapledStream
-    ):
-        self.stream = stream
-
-        # create read loop intance
-        self._aiter_pkts = self._iter_packets()
-        self._send_lock = trio.StrictFIFOLock()
-
-        self.drained: list[dict] = []
-
-        self.recv_stream = BufferedReceiveStream(
-            transport_stream=stream
-        )
-
-    async def _iter_packets(self) -> AsyncGenerator[dict, None]:
-        '''
-        Yield `bytes`-blob decoded packets from the underlying TCP
-        stream using the current task's `MsgCodec`.
-
-        This is a streaming routine implemented as an async generator
-        func (which was the original design, but could be changed?)
-        and is allocated by a `.__call__()` inside `.__init__()` where
-        it is assigned to the `._aiter_pkts` attr.
-
-        '''
-        while True:
-            try:
-                header: bytes = await self.recv_stream.receive_exactly(4)
-
-            except (
-                ValueError,
-                ConnectionResetError,
-
-                # not sure entirely why we need this but without it we
-                # seem to be getting racy failures here on
-                # arbiter/registry name subs..
-                trio.BrokenResourceError,
-
-            ) as trans_err:
-
-                loglevel = 'transport'
-                match trans_err:
-                    # case (
-                    #     ConnectionResetError()
-                    # ):
-                    #     loglevel = 'transport'
-
-                    # peer actor (graceful??) TCP EOF but `tricycle`
-                    # seems to raise a 0-bytes-read?
-                    case ValueError() if (
-                        'unclean EOF' in trans_err.args[0]
-                    ):
-                        pass
-
-                    # peer actor (task) prolly shutdown quickly due
-                    # to cancellation
-                    case trio.BrokenResourceError() if (
-                        'Connection reset by peer' in trans_err.args[0]
-                    ):
-                        pass
-
-                    # unless the disconnect condition falls under "a
-                    # normal operation breakage" we usualy console warn
-                    # about it.
-                    case _:
-                        loglevel: str = 'warning'
-
-                raise TransportClosed(
-                    message=(
-                        f'IPC transport already closed by peer\n'
-                        f'x)> {type(trans_err)}\n'
-                        f' |_{self}\n'
-                    ),
-                    loglevel=loglevel,
-                ) from trans_err
-
-            # XXX definitely can happen if transport is closed
-            # manually by another `trio.lowlevel.Task` in the
-            # same actor; we use this in some simulated fault
-            # testing for ex, but generally should never happen
-            # under normal operation!
-            #
-            # NOTE: as such we always re-raise this error from the
-            #       RPC msg loop!
-            except trio.ClosedResourceError as closure_err:
-                raise TransportClosed(
-                    message=(
-                        f'IPC transport already manually closed locally?\n'
-                        f'x)> {type(closure_err)} \n'
-                        f' |_{self}\n'
-                    ),
-                    loglevel='error',
-                    raise_on_report=(
-                        closure_err.args[0] == 'another task closed this fd'
-                        or
-                        closure_err.args[0] in ['another task closed this fd']
-                    ),
-                ) from closure_err
-
-            # graceful EOF disconnect
-            if header == b'':
-                raise TransportClosed(
-                    message=(
-                        f'IPC transport already gracefully closed\n'
-                        f')>\n'
-                        f'|_{self}\n'
-                    ),
-                    loglevel='transport',
-                    # cause=???  # handy or no?
-                )
-
-            size: int
-            size, = struct.unpack("<I", header)
-
-            log.transport(f'received header {size}')  # type: ignore
-            msg_bytes: bytes = await self.recv_stream.receive_exactly(size)
-
-            log.transport(f"received {msg_bytes}")  # type: ignore
-            yield msg_bytes
-
-    async def send(
-        self,
-        msg: bytes,
-    ) -> None:
-        '''
-        Send a msgpack encoded py-object-blob-as-msg.
-
-        '''
-        async with self._send_lock:
-            size: bytes = struct.pack("<I", len(msg))
-            return await self.stream.send_all(size + msg)
-
-    async def recv(self) -> Any:
-        return await self._aiter_pkts.asend(None)
-
-    async def drain(self) -> AsyncIterator[dict]:
-        '''
-        Drain the stream's remaining messages sent from
-        the far end until the connection is closed by
-        the peer.
-
-        '''
-        try:
-            async for msg in self._iter_packets():
-                self.drained.append(msg)
-        except TransportClosed:
-            for msg in self.drained:
-                yield msg
-
-    def __aiter__(self):
-        return self._aiter_pkts
+class RingBuffBytesSender(trio.abc.SendChannel[bytes]):
+    '''
+    In order to guarantee full messages are received, all bytes
+    sent by `RingBuffBytesSender` are preceded with a 4 byte header
+    which decodes into a uint32 indicating the actual size of the
+    next payload.
+
+    '''
+    def __init__(
+        self,
+        sender: RingBuffSender
+    ):
+        self._sender = sender
+        self._send_lock = trio.StrictFIFOLock()
+
+    async def send(self, value: bytes) -> None:
+        async with self._send_lock:
+            size: bytes = struct.pack("<I", len(value))
+            return await self._sender.send_all(size + value)
+
+    async def aclose(self) -> None:
+        async with self._send_lock:
+            await self._sender.aclose()
+
+
+class RingBuffBytesReceiver(trio.abc.ReceiveChannel[bytes]):
+    '''
+    See `RingBuffBytesSender` docstring.
+
+    A `tricycle.BufferedReceiveStream` is used for the
+    `receive_exactly` API.
+
+    '''
+    def __init__(
+        self,
+        receiver: RingBuffReceiver
+    ):
+        self._receiver = receiver
+
+    async def _receive_exactly(self, num_bytes: int) -> bytes:
+        '''
+        Fetch bytes from receiver until we read exactly `num_bytes`
+        or end of stream is signaled.
+
+        '''
+        payload = b''
+        while len(payload) < num_bytes:
+            remaining = num_bytes - len(payload)
+
+            new_bytes = await self._receiver.receive_some(
+                max_bytes=remaining
+            )
+            if new_bytes == b'':
+                raise trio.EndOfChannel
+
+            payload += new_bytes
+
+        return payload
+
+    async def receive(self) -> bytes:
+        header: bytes = await self._receive_exactly(4)
+        size: int
+        size, = struct.unpack("<I", header)
+        return await self._receive_exactly(size)
+
+    async def aclose(self) -> None:
+        await self._receiver.aclose()
+
+
+class RingBuffChannel(trio.abc.Channel[bytes]):
+    '''
+    Combine `RingBuffBytesSender` and `RingBuffBytesReceiver`
+    in order to expose the bidirectional `trio.abc.Channel` API.
+
+    '''
+    def __init__(
+        self,
+        sender: RingBuffBytesSender,
+        receiver: RingBuffBytesReceiver
+    ):
+        self._sender = sender
+        self._receiver = receiver
+
+    async def send(self, value: bytes):
+        await self._sender.send(value)
+
+    async def receive(self) -> bytes:
+        return await self._receiver.receive()
+
+    async def aclose(self):
+        await self._receiver.aclose()
+        await self._sender.aclose()


 @acm
-async def attach_to_ringbuf_stream(
+async def attach_to_ringbuf_channel(
     token_in: RBToken,
     token_out: RBToken,
     cleanup_in: bool = True,
     cleanup_out: bool = True
 ):
     '''
-    Wrap a ringbuf trio.StapledStream in a MsgpackRBStream
+    Attach to an already opened ringbuf pair and return
+    a `RingBuffChannel`.

     '''
-    async with attach_to_ringbuf_pair(
-        token_in,
-        token_out,
-        cleanup_in=cleanup_in,
-        cleanup_out=cleanup_out
-    ) as stream:
-        yield MsgpackRBStream(stream)
+    async with (
+        attach_to_ringbuf_receiver(
+            token_in,
+            cleanup=cleanup_in
+        ) as receiver,
+        attach_to_ringbuf_sender(
+            token_out,
+            cleanup=cleanup_out
+        ) as sender,
+    ):
+        yield RingBuffChannel(
+            RingBuffBytesSender(sender),
+            RingBuffBytesReceiver(receiver)
+        )
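
As a standalone illustration of the framing used by `RingBuffBytesSender`/`RingBuffBytesReceiver` above: every payload is prefixed with a 4 byte little-endian uint32 holding its length, which is what lets the receiver recover message boundaries from the raw byte stream. The sketch below is illustrative only, with no ring buffer involved; `frame`/`unframe` are hypothetical helper names.

```python
import struct

def frame(payload: bytes) -> bytes:
    # prepend a little-endian uint32 header with the payload size
    return struct.pack("<I", len(payload)) + payload

def unframe(buf: bytes) -> tuple[bytes, bytes]:
    # read the 4 byte header, then exactly that many payload bytes;
    # return (payload, remaining bytes) so it can be applied in a loop
    size, = struct.unpack("<I", buf[:4])
    return buf[4:4 + size], buf[4 + size:]

wire = frame(b'hello') + frame(b'ringbuf')
first, rest = unframe(wire)
second, _ = unframe(rest)
assert (first, second) == (b'hello', b'ringbuf')
```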