From 90593611bbe791977f5dbb21e5e0801789f18339 Mon Sep 17 00:00:00 2001 From: Tyler Goodlet Date: Tue, 12 Apr 2022 14:24:30 -0400 Subject: [PATCH] Add test for `LinkedTaskChannel.subscribe()` fanout feature --- tests/test_infected_asyncio.py | 56 +++++++++++++++++++++++++++------- 1 file changed, 45 insertions(+), 11 deletions(-) diff --git a/tests/test_infected_asyncio.py b/tests/test_infected_asyncio.py index c0c33db..8fa6eb5 100644 --- a/tests/test_infected_asyncio.py +++ b/tests/test_infected_asyncio.py @@ -2,9 +2,10 @@ The hipster way to force SC onto the stdlib's "async": 'infection mode'. ''' -from typing import Optional, Iterable +from typing import Optional, Iterable, Union import asyncio import builtins +import itertools import importlib import pytest @@ -12,6 +13,7 @@ import trio import tractor from tractor import to_asyncio from tractor import RemoteActorError +from tractor.trionics import BroadcastReceiver async def sleep_and_err(sleep_for: float = 0.1): @@ -217,6 +219,7 @@ async def stream_from_aio( exit_early: bool = False, raise_err: bool = False, aio_raise_err: bool = False, + fan_out: bool = False, ) -> None: seq = range(100) @@ -234,15 +237,33 @@ async def stream_from_aio( assert first is True - async for value in chan: - print(f'trio received {value}') - pulled.append(value) + async def consume( + chan: Union[ + to_asyncio.LinkedTaskChannel, + BroadcastReceiver, + ], + ): + async for value in chan: + print(f'trio received {value}') + pulled.append(value) - if value == 50: - if raise_err: - raise Exception - elif exit_early: - break + if value == 50: + if raise_err: + raise Exception + elif exit_early: + break + + if fan_out: + # start second task that get's the same stream value set. + async with ( + trio.open_nursery() as n, + chan.subscribe() as br, + ): + n.start_soon(consume, br) + await consume(chan) + + else: + await consume(chan) finally: if ( @@ -250,19 +271,32 @@ async def stream_from_aio( not exit_early and not aio_raise_err ): - assert pulled == expect + if fan_out: + # we get double the pulled values in the + # ``.subscribe()`` fan out case. + doubled = list(itertools.chain(*zip(expect, expect))) + assert pulled == doubled + + else: + assert pulled == expect else: + # if fan_out: assert pulled == expect[:51] print('trio guest mode task completed!') -def test_basic_interloop_channel_stream(arb_addr): +@pytest.mark.parametrize( + 'fan_out', [False, True], + ids='fan_out_w_chan_subscribe={}'.format +) +def test_basic_interloop_channel_stream(arb_addr, fan_out): async def main(): async with tractor.open_nursery() as n: portal = await n.run_in_actor( stream_from_aio, infect_asyncio=True, + fan_out=fan_out, ) await portal.result()