diff --git a/tests/test_ringd.py b/tests/test_ringd.py index 3eda428a..9c6bec69 100644 --- a/tests/test_ringd.py +++ b/tests/test_ringd.py @@ -1,3 +1,5 @@ +from contextlib import asynccontextmanager as acm + import trio import tractor import msgspec @@ -107,8 +109,7 @@ class RemoveChannelMsg(Struct, frozen=True, tag=True): class RangeMsg(Struct, frozen=True, tag=True): - start: int - end: int + size: int ControlMessages = AddChannelMsg | RemoveChannelMsg | RangeMsg @@ -151,17 +152,17 @@ async def subscriber_child(ctx: tractor.Context): while True: await range_event.wait() range_event = trio.Event() - for i in range(range_msg.start, range_msg.end): + for i in range(range_msg.size): recv = int.from_bytes(await subs.receive()) - # if recv != i: - # raise AssertionError( - # f'received: {recv} expected: {i}' - # ) + if recv != i: + raise AssertionError( + f'received: {recv} expected: {i}' + ) - log.info(f'received: {recv} expected: {i}') + log.info(f'received: {recv}') await stream.send(b'valid range') - log.info('FINISHED RANGE') + log.info('finished range') log.info('subscriber exit') @@ -170,10 +171,9 @@ async def subscriber_child(ctx: tractor.Context): async def publisher_child(ctx: tractor.Context): await ctx.started() async with ( - open_ringbuf_publisher(batch_size=1, guarantee_order=True) as pub, + open_ringbuf_publisher(batch_size=100, guarantee_order=True) as pub, ctx.open_stream() as stream ): - abs_index = 0 async for msg in stream: msg = msgspec.msgpack.decode(msg, type=ControlMessages) match msg: @@ -184,10 +184,9 @@ async def publisher_child(ctx: tractor.Context): await pub.remove_channel(msg.name) case RangeMsg(): - for i in range(msg.start, msg.end): + for i in range(msg.size): await pub.send(i.to_bytes(4)) - log.info(f'sent {i}, index: {abs_index}') - abs_index += 1 + log.info(f'sent {i}') await stream.send(b'ack') @@ -243,21 +242,26 @@ def test_pubsub(): async def remove_channel(name: str): await send_wait_ack(RemoveChannelMsg(name=name).encode()) - async def send_range(start: int, end: int): - await send_wait_ack(RangeMsg(start=start, end=end).encode()) + async def send_range(size: int): + await send_wait_ack(RangeMsg(size=size).encode()) range_ack = await recv_stream.receive() assert range_ack == b'valid range' # simple test, open one channel and send 0..100 range ring_name = 'ring-first' await add_channel(ring_name) - await send_range(0, 100) + await send_range(100) await remove_channel(ring_name) # redo ring_name = 'ring-redo' await add_channel(ring_name) - await send_range(0, 100) + await send_range(100) + await remove_channel(ring_name) + + # try using same ring name + await add_channel(ring_name) + await send_range(100) await remove_channel(ring_name) # multi chan test @@ -268,7 +272,7 @@ def test_pubsub(): for name in ring_names: await add_channel(name) - await send_range(0, 300) + await send_range(1000) for name in ring_names: await remove_channel(name)