Add test for `LinkedTaskChannel.subscribe()` fanout feature

aio_explicit_task_cancels
Tyler Goodlet 2022-04-12 14:24:30 -04:00
parent 9c43bb28f1
commit 90593611bb
1 changed files with 45 additions and 11 deletions

View File

@ -2,9 +2,10 @@
The hipster way to force SC onto the stdlib's "async": 'infection mode'. The hipster way to force SC onto the stdlib's "async": 'infection mode'.
''' '''
from typing import Optional, Iterable from typing import Optional, Iterable, Union
import asyncio import asyncio
import builtins import builtins
import itertools
import importlib import importlib
import pytest import pytest
@ -12,6 +13,7 @@ import trio
import tractor import tractor
from tractor import to_asyncio from tractor import to_asyncio
from tractor import RemoteActorError from tractor import RemoteActorError
from tractor.trionics import BroadcastReceiver
async def sleep_and_err(sleep_for: float = 0.1): async def sleep_and_err(sleep_for: float = 0.1):
@ -217,6 +219,7 @@ async def stream_from_aio(
exit_early: bool = False, exit_early: bool = False,
raise_err: bool = False, raise_err: bool = False,
aio_raise_err: bool = False, aio_raise_err: bool = False,
fan_out: bool = False,
) -> None: ) -> None:
seq = range(100) seq = range(100)
@ -234,15 +237,33 @@ async def stream_from_aio(
assert first is True assert first is True
async for value in chan: async def consume(
print(f'trio received {value}') chan: Union[
pulled.append(value) to_asyncio.LinkedTaskChannel,
BroadcastReceiver,
],
):
async for value in chan:
print(f'trio received {value}')
pulled.append(value)
if value == 50: if value == 50:
if raise_err: if raise_err:
raise Exception raise Exception
elif exit_early: elif exit_early:
break break
if fan_out:
# start second task that get's the same stream value set.
async with (
trio.open_nursery() as n,
chan.subscribe() as br,
):
n.start_soon(consume, br)
await consume(chan)
else:
await consume(chan)
finally: finally:
if ( if (
@ -250,19 +271,32 @@ async def stream_from_aio(
not exit_early and not exit_early and
not aio_raise_err not aio_raise_err
): ):
assert pulled == expect if fan_out:
# we get double the pulled values in the
# ``.subscribe()`` fan out case.
doubled = list(itertools.chain(*zip(expect, expect)))
assert pulled == doubled
else:
assert pulled == expect
else: else:
# if fan_out:
assert pulled == expect[:51] assert pulled == expect[:51]
print('trio guest mode task completed!') print('trio guest mode task completed!')
def test_basic_interloop_channel_stream(arb_addr): @pytest.mark.parametrize(
'fan_out', [False, True],
ids='fan_out_w_chan_subscribe={}'.format
)
def test_basic_interloop_channel_stream(arb_addr, fan_out):
async def main(): async def main():
async with tractor.open_nursery() as n: async with tractor.open_nursery() as n:
portal = await n.run_in_actor( portal = await n.run_in_actor(
stream_from_aio, stream_from_aio,
infect_asyncio=True, infect_asyncio=True,
fan_out=fan_out,
) )
await portal.result() await portal.result()