Enable ordering assertion & simplify some parts of test
							parent
							
								
									d8d01e8b3c
								
							
						
					
					
						commit
						d942f073e0
					
				| 
						 | 
				
			
			@ -1,3 +1,5 @@
 | 
			
		|||
from contextlib import asynccontextmanager as acm
 | 
			
		||||
 | 
			
		||||
import trio
 | 
			
		||||
import tractor
 | 
			
		||||
import msgspec
 | 
			
		||||
| 
						 | 
				
			
			@ -107,8 +109,7 @@ class RemoveChannelMsg(Struct, frozen=True, tag=True):
 | 
			
		|||
 | 
			
		||||
 | 
			
		||||
class RangeMsg(Struct, frozen=True, tag=True):
 | 
			
		||||
    start: int
 | 
			
		||||
    end: int
 | 
			
		||||
    size: int
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
ControlMessages = AddChannelMsg | RemoveChannelMsg | RangeMsg
 | 
			
		||||
| 
						 | 
				
			
			@ -151,17 +152,17 @@ async def subscriber_child(ctx: tractor.Context):
 | 
			
		|||
            while True:
 | 
			
		||||
                await range_event.wait()
 | 
			
		||||
                range_event = trio.Event()
 | 
			
		||||
                for i in range(range_msg.start, range_msg.end):
 | 
			
		||||
                for i in range(range_msg.size):
 | 
			
		||||
                    recv = int.from_bytes(await subs.receive())
 | 
			
		||||
                    # if recv != i:
 | 
			
		||||
                    #     raise AssertionError(
 | 
			
		||||
                    #         f'received: {recv} expected: {i}'
 | 
			
		||||
                    #     )
 | 
			
		||||
                    if recv != i:
 | 
			
		||||
                        raise AssertionError(
 | 
			
		||||
                            f'received: {recv} expected: {i}'
 | 
			
		||||
                        )
 | 
			
		||||
 | 
			
		||||
                    log.info(f'received: {recv} expected: {i}')
 | 
			
		||||
                    log.info(f'received: {recv}')
 | 
			
		||||
 | 
			
		||||
                await stream.send(b'valid range')
 | 
			
		||||
                log.info('FINISHED RANGE')
 | 
			
		||||
                log.info('finished range')
 | 
			
		||||
 | 
			
		||||
    log.info('subscriber exit')
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -170,10 +171,9 @@ async def subscriber_child(ctx: tractor.Context):
 | 
			
		|||
async def publisher_child(ctx: tractor.Context):
 | 
			
		||||
    await ctx.started()
 | 
			
		||||
    async with (
 | 
			
		||||
        open_ringbuf_publisher(batch_size=1, guarantee_order=True) as pub,
 | 
			
		||||
        open_ringbuf_publisher(batch_size=100, guarantee_order=True) as pub,
 | 
			
		||||
        ctx.open_stream() as stream
 | 
			
		||||
    ):
 | 
			
		||||
        abs_index = 0
 | 
			
		||||
        async for msg in stream:
 | 
			
		||||
            msg = msgspec.msgpack.decode(msg, type=ControlMessages)
 | 
			
		||||
            match msg:
 | 
			
		||||
| 
						 | 
				
			
			@ -184,10 +184,9 @@ async def publisher_child(ctx: tractor.Context):
 | 
			
		|||
                    await pub.remove_channel(msg.name)
 | 
			
		||||
 | 
			
		||||
                case RangeMsg():
 | 
			
		||||
                    for i in range(msg.start, msg.end):
 | 
			
		||||
                    for i in range(msg.size):
 | 
			
		||||
                        await pub.send(i.to_bytes(4))
 | 
			
		||||
                        log.info(f'sent {i}, index: {abs_index}')
 | 
			
		||||
                        abs_index += 1
 | 
			
		||||
                        log.info(f'sent {i}')
 | 
			
		||||
 | 
			
		||||
            await stream.send(b'ack')
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -243,21 +242,26 @@ def test_pubsub():
 | 
			
		|||
                async def remove_channel(name: str):
 | 
			
		||||
                    await send_wait_ack(RemoveChannelMsg(name=name).encode())
 | 
			
		||||
 | 
			
		||||
                async def send_range(start: int, end: int):
 | 
			
		||||
                    await send_wait_ack(RangeMsg(start=start, end=end).encode())
 | 
			
		||||
                async def send_range(size: int):
 | 
			
		||||
                    await send_wait_ack(RangeMsg(size=size).encode())
 | 
			
		||||
                    range_ack = await recv_stream.receive()
 | 
			
		||||
                    assert range_ack == b'valid range'
 | 
			
		||||
 | 
			
		||||
                # simple test, open one channel and send 0..100 range
 | 
			
		||||
                ring_name = 'ring-first'
 | 
			
		||||
                await add_channel(ring_name)
 | 
			
		||||
                await send_range(0, 100)
 | 
			
		||||
                await send_range(100)
 | 
			
		||||
                await remove_channel(ring_name)
 | 
			
		||||
 | 
			
		||||
                # redo
 | 
			
		||||
                ring_name = 'ring-redo'
 | 
			
		||||
                await add_channel(ring_name)
 | 
			
		||||
                await send_range(0, 100)
 | 
			
		||||
                await send_range(100)
 | 
			
		||||
                await remove_channel(ring_name)
 | 
			
		||||
 | 
			
		||||
                # try using same ring name
 | 
			
		||||
                await add_channel(ring_name)
 | 
			
		||||
                await send_range(100)
 | 
			
		||||
                await remove_channel(ring_name)
 | 
			
		||||
 | 
			
		||||
                # multi chan test
 | 
			
		||||
| 
						 | 
				
			
			@ -268,7 +272,7 @@ def test_pubsub():
 | 
			
		|||
                for name in ring_names:
 | 
			
		||||
                    await add_channel(name)
 | 
			
		||||
 | 
			
		||||
                await send_range(0, 300)
 | 
			
		||||
                await send_range(1000)
 | 
			
		||||
 | 
			
		||||
                for name in ring_names:
 | 
			
		||||
                    await remove_channel(name)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue