From b2f6c298f5ca23452cb688fca847ec7f6ca046c1 Mon Sep 17 00:00:00 2001 From: Guillermo Rodriguez Date: Sun, 6 Apr 2025 21:59:14 -0300 Subject: [PATCH] Refactor generate_sample_messages to be a generator and use numpy --- tests/test_ringbuf.py | 49 ++++++++----- tractor/_testing/samples.py | 135 ++++++++++++++++++++---------------- 2 files changed, 108 insertions(+), 76 deletions(-) diff --git a/tests/test_ringbuf.py b/tests/test_ringbuf.py index 6dfeae72..8f549b7c 100644 --- a/tests/test_ringbuf.py +++ b/tests/test_ringbuf.py @@ -14,7 +14,7 @@ from tractor.ipc._ringbuf import ( ) from tractor._testing.samples import ( generate_single_byte_msgs, - generate_sample_messages + RandomBytesGenerator ) # in case you don't want to melt your cores, uncomment dis! @@ -80,18 +80,22 @@ async def child_write_shm( Attach to ringbuf and send all generated messages. ''' - sent_hash, msgs, _total_bytes = generate_sample_messages( + rng = RandomBytesGenerator( msg_amount, rand_min=rand_min, rand_max=rand_max, ) - await ctx.started(sent_hash) + await ctx.started() print('writer started') async with attach_to_ringbuf_sender(token, cleanup=False) as sender: - for msg in msgs: + for msg in rng: await sender.send(msg) + if rng.msgs_generated % rng.recommended_log_interval == 0: + print(f'wrote {rng.total_msgs} msgs') + print('writer exit') + return rng.hexdigest @pytest.mark.parametrize( @@ -153,12 +157,14 @@ def test_ringbuf( msg_amount=msg_amount, rand_min=rand_min, rand_max=rand_max, - ) as (_sctx, sent_hash), + ) as (sctx, _), + recv_p.open_context( child_read_shm, token=token, - ) as (rctx, _sent), + ) as (rctx, _), ): + sent_hash = await sctx.result() recvd_hash = await rctx.result() assert sent_hash == recvd_hash @@ -300,7 +306,7 @@ async def child_channel_sender( token_out: RBToken ): import random - _hash, msgs, _total_bytes = generate_sample_messages( + rng = RandomBytesGenerator( random.randint(msg_amount_min, msg_amount_max), rand_min=256, rand_max=1024, @@ -309,10 +315,14 @@ async def child_channel_sender( token_in, token_out ) as chan: - await ctx.started(msgs) - for msg in msgs: + await ctx.started() + for msg in rng: await chan.send(msg) + await chan.send(b'bye') + await chan.receive() + return rng.hexdigest + def test_channel(): @@ -327,7 +337,7 @@ def test_channel(): attach_to_ringbuf_channel(send_token, recv_token) as chan, tractor.open_nursery() as an ): - recv_p = await an.start_actor( + sender = await an.start_actor( 'test_ringbuf_transport_sender', enable_modules=[__name__], proc_kwargs={ @@ -335,19 +345,26 @@ def test_channel(): } ) async with ( - recv_p.open_context( + sender.open_context( child_channel_sender, msg_amount_min=msg_amount_min, msg_amount_max=msg_amount_max, token_in=recv_token, token_out=send_token - ) as (ctx, msgs), + ) as (ctx, _), ): - recv_msgs = [] + recvd_hash = hashlib.sha256() async for msg in chan: - recv_msgs.append(msg) + if msg == b'bye': + await chan.send(b'bye') + break - await recv_p.cancel_actor() - assert recv_msgs == msgs + recvd_hash.update(msg) + + sent_hash = await ctx.result() + + assert recvd_hash.hexdigest() == sent_hash + + await an.cancel() trio.run(main) diff --git a/tractor/_testing/samples.py b/tractor/_testing/samples.py index f8671332..fcf41dfa 100644 --- a/tractor/_testing/samples.py +++ b/tractor/_testing/samples.py @@ -1,84 +1,99 @@ -import os -import random import hashlib +import numpy as np def generate_single_byte_msgs(amount: int) -> bytes: ''' - Generate a byte instance of len `amount` with: - - ``` - byte_at_index(i) = (i % 10).encode() - ``` - - this results in constantly repeating sequences of: - - b'0123456789' + Generate a byte instance of length `amount` with repeating ASCII digits 0..9. ''' - return b''.join(str(i % 10).encode() for i in range(amount)) + # array [0, 1, 2, ..., amount-1], take mod 10 => [0..9], and map 0->'0'(48) + # up to 9->'9'(57). + arr = np.arange(amount, dtype=np.uint8) % 10 + # move into ascii space + arr += 48 + return arr.tobytes() -def generate_sample_messages( - amount: int, - rand_min: int = 0, - rand_max: int = 0, - silent: bool = False, -) -> tuple[str, list[bytes], int]: +class RandomBytesGenerator: ''' Generate bytes msgs for tests. - Messages will have the following format: + messages will have the following format: - ``` - b'[{i:08}]' + os.urandom(random.randint(rand_min, rand_max)) - ``` + b'[{i:08}]' + random_bytes so for message index 25: - b'[00000025]' + random_bytes + b'[00000025]' + random_bytes + + also generates sha256 hash of msgs. ''' - msgs = [] - size = 0 - log_interval = None - if not silent: - print(f'\ngenerating {amount} messages...') + def __init__( + self, + amount: int, + rand_min: int = 0, + rand_max: int = 0 + ): + if rand_max < rand_min: + raise ValueError('rand_max must be >= rand_min') - # calculate an apropiate log interval based on - # max message size - max_msg_size = 10 + rand_max + self._amount = amount + self._rand_min = rand_min + self._rand_max = rand_max + self._index = 0 + self._hasher = hashlib.sha256() + self._total_bytes = 0 + + self._lengths = np.random.randint( + rand_min, + rand_max + 1, + size=amount, + dtype=np.int32 + ) + + def __iter__(self): + return self + + def __next__(self) -> bytes: + if self._index == self._amount: + raise StopIteration + + header = f'[{self._index:08}]'.encode('utf-8') + + length = int(self._lengths[self._index]) + msg = header + np.random.bytes(length) + + self._hasher.update(msg) + self._total_bytes += length + self._index += 1 + + return msg + + @property + def hexdigest(self) -> str: + return self._hasher.hexdigest() + + @property + def total_bytes(self) -> int: + return self._total_bytes + + @property + def total_msgs(self) -> int: + return self._amount + + @property + def msgs_generated(self) -> int: + return self._index + + @property + def recommended_log_interval(self) -> int: + max_msg_size = 10 + self._rand_max if max_msg_size <= 32 * 1024: - log_interval = 10_000 + return 10_000 else: - log_interval = 1000 - - payload_hash = hashlib.sha256() - for i in range(amount): - msg = f'[{i:08}]'.encode('utf-8') - - if rand_max > 0: - msg += os.urandom( - random.randint(rand_min, rand_max)) - - size += len(msg) - - payload_hash.update(msg) - msgs.append(msg) - - if ( - not silent - and - i > 0 - and - i % log_interval == 0 - ): - print(f'{i} generated') - - if not silent: - print(f'done, {size:,} bytes in total') - - return payload_hash.hexdigest(), msgs, size + return 1000