diff --git a/tests/test_streaming.py b/tests/test_streaming.py index 172bb7f..1ada466 100644 --- a/tests/test_streaming.py +++ b/tests/test_streaming.py @@ -76,7 +76,8 @@ def test_stream_from_single_subactor(arb_addr, start_method): async def stream_data(seed): for i in range(seed): yield i - await trio.sleep(0) # trigger scheduler + # trigger scheduler to simulate practical usage + await trio.sleep(0) # this is the third actor; the aggregator @@ -97,30 +98,32 @@ async def aggregate(seed): send_chan, recv_chan = trio.open_memory_channel(500) - async def push_to_chan(portal): - async for value in await portal.run( - __name__, 'stream_data', seed=seed - ): - # leverage trio's built-in backpressure - await send_chan.send(value) + async def push_to_chan(portal, send_chan): + async with send_chan: + async for value in await portal.run( + __name__, 'stream_data', seed=seed + ): + # leverage trio's built-in backpressure + await send_chan.send(value) - await send_chan.send(None) print(f"FINISHED ITERATING {portal.channel.uid}") # spawn 2 trio tasks to collect streams and push to a local queue async with trio.open_nursery() as n: + for portal in portals: - n.start_soon(push_to_chan, portal) + n.start_soon(push_to_chan, portal, send_chan.clone()) + + # close this local task's reference to send side + await send_chan.aclose() unique_vals = set() - async for value in recv_chan: - if value not in unique_vals: - unique_vals.add(value) - # yield upwards to the spawning parent actor - yield value - - if value is None: - break + async with recv_chan: + async for value in recv_chan: + if value not in unique_vals: + unique_vals.add(value) + # yield upwards to the spawning parent actor + yield value assert value in unique_vals @@ -154,7 +157,7 @@ async def a_quadruple_example(): print(f"STREAM TIME = {time.time() - start}") print(f"STREAM + SPAWN TIME = {time.time() - pre_start}") - assert result_stream == list(range(seed)) + [None] + assert result_stream == list(range(seed)) return result_stream