Refactor generate_sample_messages to be a generator and use numpy
parent
171545e4fb
commit
b2f6c298f5
|
@ -14,7 +14,7 @@ from tractor.ipc._ringbuf import (
|
|||
)
|
||||
from tractor._testing.samples import (
|
||||
generate_single_byte_msgs,
|
||||
generate_sample_messages
|
||||
RandomBytesGenerator
|
||||
)
|
||||
|
||||
# in case you don't want to melt your cores, uncomment dis!
|
||||
|
@ -80,18 +80,22 @@ async def child_write_shm(
|
|||
Attach to ringbuf and send all generated messages.
|
||||
|
||||
'''
|
||||
sent_hash, msgs, _total_bytes = generate_sample_messages(
|
||||
rng = RandomBytesGenerator(
|
||||
msg_amount,
|
||||
rand_min=rand_min,
|
||||
rand_max=rand_max,
|
||||
)
|
||||
await ctx.started(sent_hash)
|
||||
await ctx.started()
|
||||
print('writer started')
|
||||
async with attach_to_ringbuf_sender(token, cleanup=False) as sender:
|
||||
for msg in msgs:
|
||||
for msg in rng:
|
||||
await sender.send(msg)
|
||||
|
||||
if rng.msgs_generated % rng.recommended_log_interval == 0:
|
||||
print(f'wrote {rng.total_msgs} msgs')
|
||||
|
||||
print('writer exit')
|
||||
return rng.hexdigest
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
@ -153,12 +157,14 @@ def test_ringbuf(
|
|||
msg_amount=msg_amount,
|
||||
rand_min=rand_min,
|
||||
rand_max=rand_max,
|
||||
) as (_sctx, sent_hash),
|
||||
) as (sctx, _),
|
||||
|
||||
recv_p.open_context(
|
||||
child_read_shm,
|
||||
token=token,
|
||||
) as (rctx, _sent),
|
||||
) as (rctx, _),
|
||||
):
|
||||
sent_hash = await sctx.result()
|
||||
recvd_hash = await rctx.result()
|
||||
|
||||
assert sent_hash == recvd_hash
|
||||
|
@ -300,7 +306,7 @@ async def child_channel_sender(
|
|||
token_out: RBToken
|
||||
):
|
||||
import random
|
||||
_hash, msgs, _total_bytes = generate_sample_messages(
|
||||
rng = RandomBytesGenerator(
|
||||
random.randint(msg_amount_min, msg_amount_max),
|
||||
rand_min=256,
|
||||
rand_max=1024,
|
||||
|
@ -309,10 +315,14 @@ async def child_channel_sender(
|
|||
token_in,
|
||||
token_out
|
||||
) as chan:
|
||||
await ctx.started(msgs)
|
||||
for msg in msgs:
|
||||
await ctx.started()
|
||||
for msg in rng:
|
||||
await chan.send(msg)
|
||||
|
||||
await chan.send(b'bye')
|
||||
await chan.receive()
|
||||
return rng.hexdigest
|
||||
|
||||
|
||||
def test_channel():
|
||||
|
||||
|
@ -327,7 +337,7 @@ def test_channel():
|
|||
attach_to_ringbuf_channel(send_token, recv_token) as chan,
|
||||
tractor.open_nursery() as an
|
||||
):
|
||||
recv_p = await an.start_actor(
|
||||
sender = await an.start_actor(
|
||||
'test_ringbuf_transport_sender',
|
||||
enable_modules=[__name__],
|
||||
proc_kwargs={
|
||||
|
@ -335,19 +345,26 @@ def test_channel():
|
|||
}
|
||||
)
|
||||
async with (
|
||||
recv_p.open_context(
|
||||
sender.open_context(
|
||||
child_channel_sender,
|
||||
msg_amount_min=msg_amount_min,
|
||||
msg_amount_max=msg_amount_max,
|
||||
token_in=recv_token,
|
||||
token_out=send_token
|
||||
) as (ctx, msgs),
|
||||
) as (ctx, _),
|
||||
):
|
||||
recv_msgs = []
|
||||
recvd_hash = hashlib.sha256()
|
||||
async for msg in chan:
|
||||
recv_msgs.append(msg)
|
||||
if msg == b'bye':
|
||||
await chan.send(b'bye')
|
||||
break
|
||||
|
||||
await recv_p.cancel_actor()
|
||||
assert recv_msgs == msgs
|
||||
recvd_hash.update(msg)
|
||||
|
||||
sent_hash = await ctx.result()
|
||||
|
||||
assert recvd_hash.hexdigest() == sent_hash
|
||||
|
||||
await an.cancel()
|
||||
|
||||
trio.run(main)
|
||||
|
|
|
@ -1,84 +1,99 @@
|
|||
import os
|
||||
import random
|
||||
import hashlib
|
||||
import numpy as np
|
||||
|
||||
|
||||
def generate_single_byte_msgs(amount: int) -> bytes:
|
||||
'''
|
||||
Generate a byte instance of len `amount` with:
|
||||
|
||||
```
|
||||
byte_at_index(i) = (i % 10).encode()
|
||||
```
|
||||
|
||||
this results in constantly repeating sequences of:
|
||||
|
||||
b'0123456789'
|
||||
Generate a byte instance of length `amount` with repeating ASCII digits 0..9.
|
||||
|
||||
'''
|
||||
return b''.join(str(i % 10).encode() for i in range(amount))
|
||||
# array [0, 1, 2, ..., amount-1], take mod 10 => [0..9], and map 0->'0'(48)
|
||||
# up to 9->'9'(57).
|
||||
arr = np.arange(amount, dtype=np.uint8) % 10
|
||||
# move into ascii space
|
||||
arr += 48
|
||||
return arr.tobytes()
|
||||
|
||||
|
||||
def generate_sample_messages(
|
||||
amount: int,
|
||||
rand_min: int = 0,
|
||||
rand_max: int = 0,
|
||||
silent: bool = False,
|
||||
) -> tuple[str, list[bytes], int]:
|
||||
class RandomBytesGenerator:
|
||||
'''
|
||||
Generate bytes msgs for tests.
|
||||
|
||||
Messages will have the following format:
|
||||
messages will have the following format:
|
||||
|
||||
```
|
||||
b'[{i:08}]' + os.urandom(random.randint(rand_min, rand_max))
|
||||
```
|
||||
b'[{i:08}]' + random_bytes
|
||||
|
||||
so for message index 25:
|
||||
|
||||
b'[00000025]' + random_bytes
|
||||
b'[00000025]' + random_bytes
|
||||
|
||||
also generates sha256 hash of msgs.
|
||||
|
||||
'''
|
||||
msgs = []
|
||||
size = 0
|
||||
|
||||
log_interval = None
|
||||
if not silent:
|
||||
print(f'\ngenerating {amount} messages...')
|
||||
def __init__(
|
||||
self,
|
||||
amount: int,
|
||||
rand_min: int = 0,
|
||||
rand_max: int = 0
|
||||
):
|
||||
if rand_max < rand_min:
|
||||
raise ValueError('rand_max must be >= rand_min')
|
||||
|
||||
# calculate an apropiate log interval based on
|
||||
# max message size
|
||||
max_msg_size = 10 + rand_max
|
||||
self._amount = amount
|
||||
self._rand_min = rand_min
|
||||
self._rand_max = rand_max
|
||||
self._index = 0
|
||||
self._hasher = hashlib.sha256()
|
||||
self._total_bytes = 0
|
||||
|
||||
self._lengths = np.random.randint(
|
||||
rand_min,
|
||||
rand_max + 1,
|
||||
size=amount,
|
||||
dtype=np.int32
|
||||
)
|
||||
|
||||
def __iter__(self):
|
||||
return self
|
||||
|
||||
def __next__(self) -> bytes:
|
||||
if self._index == self._amount:
|
||||
raise StopIteration
|
||||
|
||||
header = f'[{self._index:08}]'.encode('utf-8')
|
||||
|
||||
length = int(self._lengths[self._index])
|
||||
msg = header + np.random.bytes(length)
|
||||
|
||||
self._hasher.update(msg)
|
||||
self._total_bytes += length
|
||||
self._index += 1
|
||||
|
||||
return msg
|
||||
|
||||
@property
|
||||
def hexdigest(self) -> str:
|
||||
return self._hasher.hexdigest()
|
||||
|
||||
@property
|
||||
def total_bytes(self) -> int:
|
||||
return self._total_bytes
|
||||
|
||||
@property
|
||||
def total_msgs(self) -> int:
|
||||
return self._amount
|
||||
|
||||
@property
|
||||
def msgs_generated(self) -> int:
|
||||
return self._index
|
||||
|
||||
@property
|
||||
def recommended_log_interval(self) -> int:
|
||||
max_msg_size = 10 + self._rand_max
|
||||
|
||||
if max_msg_size <= 32 * 1024:
|
||||
log_interval = 10_000
|
||||
return 10_000
|
||||
|
||||
else:
|
||||
log_interval = 1000
|
||||
|
||||
payload_hash = hashlib.sha256()
|
||||
for i in range(amount):
|
||||
msg = f'[{i:08}]'.encode('utf-8')
|
||||
|
||||
if rand_max > 0:
|
||||
msg += os.urandom(
|
||||
random.randint(rand_min, rand_max))
|
||||
|
||||
size += len(msg)
|
||||
|
||||
payload_hash.update(msg)
|
||||
msgs.append(msg)
|
||||
|
||||
if (
|
||||
not silent
|
||||
and
|
||||
i > 0
|
||||
and
|
||||
i % log_interval == 0
|
||||
):
|
||||
print(f'{i} generated')
|
||||
|
||||
if not silent:
|
||||
print(f'done, {size:,} bytes in total')
|
||||
|
||||
return payload_hash.hexdigest(), msgs, size
|
||||
return 1000
|
||||
|
|
Loading…
Reference in New Issue