Add test for `LinkedTaskChannel.subscribe()` fanout feature

aio_explicit_task_cancels
Tyler Goodlet 2022-04-12 14:24:30 -04:00
parent 9c43bb28f1
commit 90593611bb
1 changed files with 45 additions and 11 deletions

View File

@ -2,9 +2,10 @@
The hipster way to force SC onto the stdlib's "async": 'infection mode'.
'''
from typing import Optional, Iterable
from typing import Optional, Iterable, Union
import asyncio
import builtins
import itertools
import importlib
import pytest
@ -12,6 +13,7 @@ import trio
import tractor
from tractor import to_asyncio
from tractor import RemoteActorError
from tractor.trionics import BroadcastReceiver
async def sleep_and_err(sleep_for: float = 0.1):
@ -217,6 +219,7 @@ async def stream_from_aio(
exit_early: bool = False,
raise_err: bool = False,
aio_raise_err: bool = False,
fan_out: bool = False,
) -> None:
seq = range(100)
@ -234,6 +237,12 @@ async def stream_from_aio(
assert first is True
async def consume(
chan: Union[
to_asyncio.LinkedTaskChannel,
BroadcastReceiver,
],
):
async for value in chan:
print(f'trio received {value}')
pulled.append(value)
@ -243,6 +252,18 @@ async def stream_from_aio(
raise Exception
elif exit_early:
break
if fan_out:
# start second task that get's the same stream value set.
async with (
trio.open_nursery() as n,
chan.subscribe() as br,
):
n.start_soon(consume, br)
await consume(chan)
else:
await consume(chan)
finally:
if (
@ -250,19 +271,32 @@ async def stream_from_aio(
not exit_early and
not aio_raise_err
):
if fan_out:
# we get double the pulled values in the
# ``.subscribe()`` fan out case.
doubled = list(itertools.chain(*zip(expect, expect)))
assert pulled == doubled
else:
assert pulled == expect
else:
# if fan_out:
assert pulled == expect[:51]
print('trio guest mode task completed!')
def test_basic_interloop_channel_stream(arb_addr):
@pytest.mark.parametrize(
'fan_out', [False, True],
ids='fan_out_w_chan_subscribe={}'.format
)
def test_basic_interloop_channel_stream(arb_addr, fan_out):
async def main():
async with tractor.open_nursery() as n:
portal = await n.run_in_actor(
stream_from_aio,
infect_asyncio=True,
fan_out=fan_out,
)
await portal.result()