Refactor generate_sample_messages to be a generator and use numpy

one_ring_to_rule_them_all
Guillermo Rodriguez 2025-04-06 21:59:14 -03:00
parent 171545e4fb
commit b2f6c298f5
No known key found for this signature in database
GPG Key ID: 002CC5F1E6BDA53E
2 changed files with 108 additions and 76 deletions

View File

@ -14,7 +14,7 @@ from tractor.ipc._ringbuf import (
) )
from tractor._testing.samples import ( from tractor._testing.samples import (
generate_single_byte_msgs, generate_single_byte_msgs,
generate_sample_messages RandomBytesGenerator
) )
# in case you don't want to melt your cores, uncomment dis! # in case you don't want to melt your cores, uncomment dis!
@ -80,18 +80,22 @@ async def child_write_shm(
Attach to ringbuf and send all generated messages. Attach to ringbuf and send all generated messages.
''' '''
sent_hash, msgs, _total_bytes = generate_sample_messages( rng = RandomBytesGenerator(
msg_amount, msg_amount,
rand_min=rand_min, rand_min=rand_min,
rand_max=rand_max, rand_max=rand_max,
) )
await ctx.started(sent_hash) await ctx.started()
print('writer started') print('writer started')
async with attach_to_ringbuf_sender(token, cleanup=False) as sender: async with attach_to_ringbuf_sender(token, cleanup=False) as sender:
for msg in msgs: for msg in rng:
await sender.send(msg) await sender.send(msg)
if rng.msgs_generated % rng.recommended_log_interval == 0:
print(f'wrote {rng.total_msgs} msgs')
print('writer exit') print('writer exit')
return rng.hexdigest
@pytest.mark.parametrize( @pytest.mark.parametrize(
@ -153,12 +157,14 @@ def test_ringbuf(
msg_amount=msg_amount, msg_amount=msg_amount,
rand_min=rand_min, rand_min=rand_min,
rand_max=rand_max, rand_max=rand_max,
) as (_sctx, sent_hash), ) as (sctx, _),
recv_p.open_context( recv_p.open_context(
child_read_shm, child_read_shm,
token=token, token=token,
) as (rctx, _sent), ) as (rctx, _),
): ):
sent_hash = await sctx.result()
recvd_hash = await rctx.result() recvd_hash = await rctx.result()
assert sent_hash == recvd_hash assert sent_hash == recvd_hash
@ -300,7 +306,7 @@ async def child_channel_sender(
token_out: RBToken token_out: RBToken
): ):
import random import random
_hash, msgs, _total_bytes = generate_sample_messages( rng = RandomBytesGenerator(
random.randint(msg_amount_min, msg_amount_max), random.randint(msg_amount_min, msg_amount_max),
rand_min=256, rand_min=256,
rand_max=1024, rand_max=1024,
@ -309,10 +315,14 @@ async def child_channel_sender(
token_in, token_in,
token_out token_out
) as chan: ) as chan:
await ctx.started(msgs) await ctx.started()
for msg in msgs: for msg in rng:
await chan.send(msg) await chan.send(msg)
await chan.send(b'bye')
await chan.receive()
return rng.hexdigest
def test_channel(): def test_channel():
@ -327,7 +337,7 @@ def test_channel():
attach_to_ringbuf_channel(send_token, recv_token) as chan, attach_to_ringbuf_channel(send_token, recv_token) as chan,
tractor.open_nursery() as an tractor.open_nursery() as an
): ):
recv_p = await an.start_actor( sender = await an.start_actor(
'test_ringbuf_transport_sender', 'test_ringbuf_transport_sender',
enable_modules=[__name__], enable_modules=[__name__],
proc_kwargs={ proc_kwargs={
@ -335,19 +345,26 @@ def test_channel():
} }
) )
async with ( async with (
recv_p.open_context( sender.open_context(
child_channel_sender, child_channel_sender,
msg_amount_min=msg_amount_min, msg_amount_min=msg_amount_min,
msg_amount_max=msg_amount_max, msg_amount_max=msg_amount_max,
token_in=recv_token, token_in=recv_token,
token_out=send_token token_out=send_token
) as (ctx, msgs), ) as (ctx, _),
): ):
recv_msgs = [] recvd_hash = hashlib.sha256()
async for msg in chan: async for msg in chan:
recv_msgs.append(msg) if msg == b'bye':
await chan.send(b'bye')
break
await recv_p.cancel_actor() recvd_hash.update(msg)
assert recv_msgs == msgs
sent_hash = await ctx.result()
assert recvd_hash.hexdigest() == sent_hash
await an.cancel()
trio.run(main) trio.run(main)

View File

@ -1,84 +1,99 @@
import os
import random
import hashlib import hashlib
import numpy as np
def generate_single_byte_msgs(amount: int) -> bytes: def generate_single_byte_msgs(amount: int) -> bytes:
''' '''
Generate a byte instance of len `amount` with: Generate a byte instance of length `amount` with repeating ASCII digits 0..9.
```
byte_at_index(i) = (i % 10).encode()
```
this results in constantly repeating sequences of:
b'0123456789'
''' '''
return b''.join(str(i % 10).encode() for i in range(amount)) # array [0, 1, 2, ..., amount-1], take mod 10 => [0..9], and map 0->'0'(48)
# up to 9->'9'(57).
arr = np.arange(amount, dtype=np.uint8) % 10
# move into ascii space
arr += 48
return arr.tobytes()
def generate_sample_messages( class RandomBytesGenerator:
amount: int,
rand_min: int = 0,
rand_max: int = 0,
silent: bool = False,
) -> tuple[str, list[bytes], int]:
''' '''
Generate bytes msgs for tests. Generate bytes msgs for tests.
Messages will have the following format: messages will have the following format:
``` b'[{i:08}]' + random_bytes
b'[{i:08}]' + os.urandom(random.randint(rand_min, rand_max))
```
so for message index 25: so for message index 25:
b'[00000025]' + random_bytes b'[00000025]' + random_bytes
also generates sha256 hash of msgs.
''' '''
msgs = []
size = 0
log_interval = None def __init__(
if not silent: self,
print(f'\ngenerating {amount} messages...') amount: int,
rand_min: int = 0,
rand_max: int = 0
):
if rand_max < rand_min:
raise ValueError('rand_max must be >= rand_min')
# calculate an apropiate log interval based on self._amount = amount
# max message size self._rand_min = rand_min
max_msg_size = 10 + rand_max self._rand_max = rand_max
self._index = 0
self._hasher = hashlib.sha256()
self._total_bytes = 0
self._lengths = np.random.randint(
rand_min,
rand_max + 1,
size=amount,
dtype=np.int32
)
def __iter__(self):
return self
def __next__(self) -> bytes:
if self._index == self._amount:
raise StopIteration
header = f'[{self._index:08}]'.encode('utf-8')
length = int(self._lengths[self._index])
msg = header + np.random.bytes(length)
self._hasher.update(msg)
self._total_bytes += length
self._index += 1
return msg
@property
def hexdigest(self) -> str:
return self._hasher.hexdigest()
@property
def total_bytes(self) -> int:
return self._total_bytes
@property
def total_msgs(self) -> int:
return self._amount
@property
def msgs_generated(self) -> int:
return self._index
@property
def recommended_log_interval(self) -> int:
max_msg_size = 10 + self._rand_max
if max_msg_size <= 32 * 1024: if max_msg_size <= 32 * 1024:
log_interval = 10_000 return 10_000
else: else:
log_interval = 1000 return 1000
payload_hash = hashlib.sha256()
for i in range(amount):
msg = f'[{i:08}]'.encode('utf-8')
if rand_max > 0:
msg += os.urandom(
random.randint(rand_min, rand_max))
size += len(msg)
payload_hash.update(msg)
msgs.append(msg)
if (
not silent
and
i > 0
and
i % log_interval == 0
):
print(f'{i} generated')
if not silent:
print(f'done, {size:,} bytes in total')
return payload_hash.hexdigest(), msgs, size