diff --git a/tests/test_task_broadcasting.py b/tests/test_task_broadcasting.py index c2200b3..e9e5500 100644 --- a/tests/test_task_broadcasting.py +++ b/tests/test_task_broadcasting.py @@ -30,8 +30,8 @@ async def echo_sequences( async for sequence in stream: seq = list(sequence) for value in seq: - print(f'sending {value}') await stream.send(value) + print(f'producer sent {value}') async def ensure_sequence( @@ -64,6 +64,7 @@ async def open_sequence_streamer( sequence: list[int], arb_addr: tuple[str, int], start_method: str, + shield: bool = False, ) -> tractor.MsgStream: @@ -82,7 +83,7 @@ async def open_sequence_streamer( ) as (ctx, first): assert first is None - async with ctx.open_stream() as stream: + async with ctx.open_stream(shield=shield) as stream: yield stream await portal.cancel_actor() @@ -206,9 +207,79 @@ def test_consumer_and_parent_maybe_lag( trio.run(main) -# TODO: -# def test_first_task_to_recv_is_cancelled(): -# ... +def test_faster_task_to_recv_is_cancelled_by_slower( + arb_addr, + start_method, +): + '''Ensure that if a faster task consuming from a stream is cancelled + the slower task can continue to receive all expected values. + + ''' + async def main(): + + sequence = list(range(1000)) + + async with open_sequence_streamer( + sequence, + arb_addr, + start_method, + + # NOTE: this MUST be set to avoid the stream terminating + # early when the faster subtask is cancelled by the slower + # parent task. + shield=True, + + ) as stream: + + # alt to passing kwarg above. + # with stream.shield(): + + async with trio.open_nursery() as n: + n.start_soon( + ensure_sequence, + stream, + sequence.copy(), + 0, + name='consumer_task', + ) + + await stream.send(tuple(sequence)) + + # pull 3 values, cancel the subtask, then + # expect to be able to pull all values still + for i in range(20): + try: + value = await stream.receive() + print(f'source stream rx: {value}') + await trio.sleep(0.01) + except Lagged: + print(f'parent overrun after {value}') + continue + + print('cancelling faster subtask') + n.cancel_scope.cancel() + + try: + value = await stream.receive() + print(f'source stream after cancel: {value}') + except Lagged: + print(f'parent overrun after {value}') + + # expect to see all remaining values + with trio.fail_after(0.5): + async for value in stream: + assert stream._broadcaster._state.recv_ready is None + print(f'source stream rx: {value}') + if value == 999: + # fully consumed and we missed no values once + # the faster subtask was cancelled + break + + # await tractor.breakpoint() + # await stream.receive() + print(f'final value: {value}') + + trio.run(main) def test_subscribe_errors_after_close():