Refactor generate_sample_messages to be a generator and use numpy
parent
171545e4fb
commit
b2f6c298f5
|
@ -14,7 +14,7 @@ from tractor.ipc._ringbuf import (
|
||||||
)
|
)
|
||||||
from tractor._testing.samples import (
|
from tractor._testing.samples import (
|
||||||
generate_single_byte_msgs,
|
generate_single_byte_msgs,
|
||||||
generate_sample_messages
|
RandomBytesGenerator
|
||||||
)
|
)
|
||||||
|
|
||||||
# in case you don't want to melt your cores, uncomment dis!
|
# in case you don't want to melt your cores, uncomment dis!
|
||||||
|
@ -80,18 +80,22 @@ async def child_write_shm(
|
||||||
Attach to ringbuf and send all generated messages.
|
Attach to ringbuf and send all generated messages.
|
||||||
|
|
||||||
'''
|
'''
|
||||||
sent_hash, msgs, _total_bytes = generate_sample_messages(
|
rng = RandomBytesGenerator(
|
||||||
msg_amount,
|
msg_amount,
|
||||||
rand_min=rand_min,
|
rand_min=rand_min,
|
||||||
rand_max=rand_max,
|
rand_max=rand_max,
|
||||||
)
|
)
|
||||||
await ctx.started(sent_hash)
|
await ctx.started()
|
||||||
print('writer started')
|
print('writer started')
|
||||||
async with attach_to_ringbuf_sender(token, cleanup=False) as sender:
|
async with attach_to_ringbuf_sender(token, cleanup=False) as sender:
|
||||||
for msg in msgs:
|
for msg in rng:
|
||||||
await sender.send(msg)
|
await sender.send(msg)
|
||||||
|
|
||||||
|
if rng.msgs_generated % rng.recommended_log_interval == 0:
|
||||||
|
print(f'wrote {rng.total_msgs} msgs')
|
||||||
|
|
||||||
print('writer exit')
|
print('writer exit')
|
||||||
|
return rng.hexdigest
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
|
@ -153,12 +157,14 @@ def test_ringbuf(
|
||||||
msg_amount=msg_amount,
|
msg_amount=msg_amount,
|
||||||
rand_min=rand_min,
|
rand_min=rand_min,
|
||||||
rand_max=rand_max,
|
rand_max=rand_max,
|
||||||
) as (_sctx, sent_hash),
|
) as (sctx, _),
|
||||||
|
|
||||||
recv_p.open_context(
|
recv_p.open_context(
|
||||||
child_read_shm,
|
child_read_shm,
|
||||||
token=token,
|
token=token,
|
||||||
) as (rctx, _sent),
|
) as (rctx, _),
|
||||||
):
|
):
|
||||||
|
sent_hash = await sctx.result()
|
||||||
recvd_hash = await rctx.result()
|
recvd_hash = await rctx.result()
|
||||||
|
|
||||||
assert sent_hash == recvd_hash
|
assert sent_hash == recvd_hash
|
||||||
|
@ -300,7 +306,7 @@ async def child_channel_sender(
|
||||||
token_out: RBToken
|
token_out: RBToken
|
||||||
):
|
):
|
||||||
import random
|
import random
|
||||||
_hash, msgs, _total_bytes = generate_sample_messages(
|
rng = RandomBytesGenerator(
|
||||||
random.randint(msg_amount_min, msg_amount_max),
|
random.randint(msg_amount_min, msg_amount_max),
|
||||||
rand_min=256,
|
rand_min=256,
|
||||||
rand_max=1024,
|
rand_max=1024,
|
||||||
|
@ -309,10 +315,14 @@ async def child_channel_sender(
|
||||||
token_in,
|
token_in,
|
||||||
token_out
|
token_out
|
||||||
) as chan:
|
) as chan:
|
||||||
await ctx.started(msgs)
|
await ctx.started()
|
||||||
for msg in msgs:
|
for msg in rng:
|
||||||
await chan.send(msg)
|
await chan.send(msg)
|
||||||
|
|
||||||
|
await chan.send(b'bye')
|
||||||
|
await chan.receive()
|
||||||
|
return rng.hexdigest
|
||||||
|
|
||||||
|
|
||||||
def test_channel():
|
def test_channel():
|
||||||
|
|
||||||
|
@ -327,7 +337,7 @@ def test_channel():
|
||||||
attach_to_ringbuf_channel(send_token, recv_token) as chan,
|
attach_to_ringbuf_channel(send_token, recv_token) as chan,
|
||||||
tractor.open_nursery() as an
|
tractor.open_nursery() as an
|
||||||
):
|
):
|
||||||
recv_p = await an.start_actor(
|
sender = await an.start_actor(
|
||||||
'test_ringbuf_transport_sender',
|
'test_ringbuf_transport_sender',
|
||||||
enable_modules=[__name__],
|
enable_modules=[__name__],
|
||||||
proc_kwargs={
|
proc_kwargs={
|
||||||
|
@ -335,19 +345,26 @@ def test_channel():
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
async with (
|
async with (
|
||||||
recv_p.open_context(
|
sender.open_context(
|
||||||
child_channel_sender,
|
child_channel_sender,
|
||||||
msg_amount_min=msg_amount_min,
|
msg_amount_min=msg_amount_min,
|
||||||
msg_amount_max=msg_amount_max,
|
msg_amount_max=msg_amount_max,
|
||||||
token_in=recv_token,
|
token_in=recv_token,
|
||||||
token_out=send_token
|
token_out=send_token
|
||||||
) as (ctx, msgs),
|
) as (ctx, _),
|
||||||
):
|
):
|
||||||
recv_msgs = []
|
recvd_hash = hashlib.sha256()
|
||||||
async for msg in chan:
|
async for msg in chan:
|
||||||
recv_msgs.append(msg)
|
if msg == b'bye':
|
||||||
|
await chan.send(b'bye')
|
||||||
|
break
|
||||||
|
|
||||||
await recv_p.cancel_actor()
|
recvd_hash.update(msg)
|
||||||
assert recv_msgs == msgs
|
|
||||||
|
sent_hash = await ctx.result()
|
||||||
|
|
||||||
|
assert recvd_hash.hexdigest() == sent_hash
|
||||||
|
|
||||||
|
await an.cancel()
|
||||||
|
|
||||||
trio.run(main)
|
trio.run(main)
|
||||||
|
|
|
@ -1,84 +1,99 @@
|
||||||
import os
|
|
||||||
import random
|
|
||||||
import hashlib
|
import hashlib
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
def generate_single_byte_msgs(amount: int) -> bytes:
    '''
    Generate a byte instance of length `amount` with repeating ASCII
    digits 0..9.

    ```
    byte_at_index(i) = str(i % 10).encode()
    ```

    this results in constantly repeating sequences of:

    b'0123456789'

    An `amount` of 0 yields `b''`.

    '''
    # Build the index range with a wide integer dtype: a `uint8` range
    # cannot represent indices above 255, which would break the
    # period-10 pattern for any `amount > 256`.
    digits = np.arange(amount, dtype=np.int64) % 10

    # move into ascii space: 0 -> '0' (48) up to 9 -> '9' (57); the
    # result always fits in a single byte so narrowing is lossless.
    return (digits + 48).astype(np.uint8).tobytes()
|
||||||
|
|
||||||
|
|
||||||
def generate_sample_messages(
|
class RandomBytesGenerator:
    '''
    Generate bytes msgs for tests.

    messages will have the following format:

        b'[{i:08}]' + random_bytes

    so for message index 25:

        b'[00000025]' + random_bytes

    also generates sha256 hash of msgs.

    Instances are single-pass iterators: each `next()` call builds one
    message, feeds it to the running sha256 and updates the counters
    exposed through the read-only properties below.

    '''
    # length in bytes of the `[{i:08}]` header for indices below 10**8
    _HEADER_SIZE: int = 10

    def __init__(
        self,
        amount: int,
        rand_min: int = 0,
        rand_max: int = 0
    ):
        '''
        Prepare a generator of `amount` messages whose random payload
        length is drawn from the inclusive range `[rand_min, rand_max]`.

        Raises `ValueError` when `amount` or `rand_min` is negative or
        when `rand_max < rand_min`.

        '''
        if amount < 0:
            raise ValueError('amount must be >= 0')

        if rand_min < 0:
            raise ValueError('rand_min must be >= 0')

        if rand_max < rand_min:
            raise ValueError('rand_max must be >= rand_min')

        self._amount = amount
        self._rand_min = rand_min
        self._rand_max = rand_max
        self._index = 0
        self._hasher = hashlib.sha256()
        self._total_bytes = 0

        # draw every payload length up front; the upper bound of
        # `randint` is exclusive, hence the `+ 1`.
        self._lengths = np.random.randint(
            rand_min,
            rand_max + 1,
            size=amount,
            dtype=np.int32
        )

    def __iter__(self) -> 'RandomBytesGenerator':
        return self

    def __next__(self) -> bytes:
        '''
        Build, hash and return the next message, raising
        `StopIteration` once `amount` messages have been produced.

        '''
        if self._index >= self._amount:
            raise StopIteration

        header = f'[{self._index:08}]'.encode('utf-8')

        length = int(self._lengths[self._index])
        msg = header + np.random.bytes(length)

        self._hasher.update(msg)
        # count the whole message (header + payload), i.e. exactly
        # the bytes that were fed to the hasher.
        self._total_bytes += len(msg)
        self._index += 1

        return msg

    @property
    def hexdigest(self) -> str:
        '''
        Hex sha256 digest of all messages generated so far.

        '''
        return self._hasher.hexdigest()

    @property
    def total_bytes(self) -> int:
        '''
        Total size in bytes of all messages generated so far.

        '''
        return self._total_bytes

    @property
    def total_msgs(self) -> int:
        '''
        Amount of messages this generator will produce overall.

        '''
        return self._amount

    @property
    def msgs_generated(self) -> int:
        '''
        Amount of messages produced so far.

        '''
        return self._index

    @property
    def recommended_log_interval(self) -> int:
        '''
        Suggested number of messages between progress logs, based on
        the maximum message size: less frequent logging for small
        messages, more frequent for large ones.

        '''
        max_msg_size = self._HEADER_SIZE + self._rand_max

        return 10_000 if max_msg_size <= 32 * 1024 else 1000
|
||||||
|
|
||||||
payload_hash = hashlib.sha256()
|
|
||||||
for i in range(amount):
|
|
||||||
msg = f'[{i:08}]'.encode('utf-8')
|
|
||||||
|
|
||||||
if rand_max > 0:
|
|
||||||
msg += os.urandom(
|
|
||||||
random.randint(rand_min, rand_max))
|
|
||||||
|
|
||||||
size += len(msg)
|
|
||||||
|
|
||||||
payload_hash.update(msg)
|
|
||||||
msgs.append(msg)
|
|
||||||
|
|
||||||
if (
|
|
||||||
not silent
|
|
||||||
and
|
|
||||||
i > 0
|
|
||||||
and
|
|
||||||
i % log_interval == 0
|
|
||||||
):
|
|
||||||
print(f'{i} generated')
|
|
||||||
|
|
||||||
if not silent:
|
|
||||||
print(f'done, {size:,} bytes in total')
|
|
||||||
|
|
||||||
return payload_hash.hexdigest(), msgs, size
|
|
||||||
|
|
Loading…
Reference in New Issue