Enable ordering assertion & simplify some parts of test
parent
d8d01e8b3c
commit
d942f073e0
|
@ -1,3 +1,5 @@
|
|||
from contextlib import asynccontextmanager as acm
|
||||
|
||||
import trio
|
||||
import tractor
|
||||
import msgspec
|
||||
|
@ -107,8 +109,7 @@ class RemoveChannelMsg(Struct, frozen=True, tag=True):
|
|||
|
||||
|
||||
class RangeMsg(Struct, frozen=True, tag=True):
|
||||
start: int
|
||||
end: int
|
||||
size: int
|
||||
|
||||
|
||||
ControlMessages = AddChannelMsg | RemoveChannelMsg | RangeMsg
|
||||
|
@ -151,17 +152,17 @@ async def subscriber_child(ctx: tractor.Context):
|
|||
while True:
|
||||
await range_event.wait()
|
||||
range_event = trio.Event()
|
||||
for i in range(range_msg.start, range_msg.end):
|
||||
for i in range(range_msg.size):
|
||||
recv = int.from_bytes(await subs.receive())
|
||||
# if recv != i:
|
||||
# raise AssertionError(
|
||||
# f'received: {recv} expected: {i}'
|
||||
# )
|
||||
if recv != i:
|
||||
raise AssertionError(
|
||||
f'received: {recv} expected: {i}'
|
||||
)
|
||||
|
||||
log.info(f'received: {recv} expected: {i}')
|
||||
log.info(f'received: {recv}')
|
||||
|
||||
await stream.send(b'valid range')
|
||||
log.info('FINISHED RANGE')
|
||||
log.info('finished range')
|
||||
|
||||
log.info('subscriber exit')
|
||||
|
||||
|
@ -170,10 +171,9 @@ async def subscriber_child(ctx: tractor.Context):
|
|||
async def publisher_child(ctx: tractor.Context):
|
||||
await ctx.started()
|
||||
async with (
|
||||
open_ringbuf_publisher(batch_size=1, guarantee_order=True) as pub,
|
||||
open_ringbuf_publisher(batch_size=100, guarantee_order=True) as pub,
|
||||
ctx.open_stream() as stream
|
||||
):
|
||||
abs_index = 0
|
||||
async for msg in stream:
|
||||
msg = msgspec.msgpack.decode(msg, type=ControlMessages)
|
||||
match msg:
|
||||
|
@ -184,10 +184,9 @@ async def publisher_child(ctx: tractor.Context):
|
|||
await pub.remove_channel(msg.name)
|
||||
|
||||
case RangeMsg():
|
||||
for i in range(msg.start, msg.end):
|
||||
for i in range(msg.size):
|
||||
await pub.send(i.to_bytes(4))
|
||||
log.info(f'sent {i}, index: {abs_index}')
|
||||
abs_index += 1
|
||||
log.info(f'sent {i}')
|
||||
|
||||
await stream.send(b'ack')
|
||||
|
||||
|
@ -243,21 +242,26 @@ def test_pubsub():
|
|||
async def remove_channel(name: str):
|
||||
await send_wait_ack(RemoveChannelMsg(name=name).encode())
|
||||
|
||||
async def send_range(start: int, end: int):
|
||||
await send_wait_ack(RangeMsg(start=start, end=end).encode())
|
||||
async def send_range(size: int):
|
||||
await send_wait_ack(RangeMsg(size=size).encode())
|
||||
range_ack = await recv_stream.receive()
|
||||
assert range_ack == b'valid range'
|
||||
|
||||
# simple test, open one channel and send 0..100 range
|
||||
ring_name = 'ring-first'
|
||||
await add_channel(ring_name)
|
||||
await send_range(0, 100)
|
||||
await send_range(100)
|
||||
await remove_channel(ring_name)
|
||||
|
||||
# redo
|
||||
ring_name = 'ring-redo'
|
||||
await add_channel(ring_name)
|
||||
await send_range(0, 100)
|
||||
await send_range(100)
|
||||
await remove_channel(ring_name)
|
||||
|
||||
# try using same ring name
|
||||
await add_channel(ring_name)
|
||||
await send_range(100)
|
||||
await remove_channel(ring_name)
|
||||
|
||||
# multi chan test
|
||||
|
@ -268,7 +272,7 @@ def test_pubsub():
|
|||
for name in ring_names:
|
||||
await add_channel(name)
|
||||
|
||||
await send_range(0, 300)
|
||||
await send_range(1000)
|
||||
|
||||
for name in ring_names:
|
||||
await remove_channel(name)
|
||||
|
|
Loading…
Reference in New Issue