Add test for `LinkedTaskChannel.subscribe()` fanout feature
parent
9c43bb28f1
commit
90593611bb
|
@ -2,9 +2,10 @@
|
|||
The hipster way to force SC onto the stdlib's "async": 'infection mode'.
|
||||
|
||||
'''
|
||||
from typing import Optional, Iterable
|
||||
from typing import Optional, Iterable, Union
|
||||
import asyncio
|
||||
import builtins
|
||||
import itertools
|
||||
import importlib
|
||||
|
||||
import pytest
|
||||
|
@ -12,6 +13,7 @@ import trio
|
|||
import tractor
|
||||
from tractor import to_asyncio
|
||||
from tractor import RemoteActorError
|
||||
from tractor.trionics import BroadcastReceiver
|
||||
|
||||
|
||||
async def sleep_and_err(sleep_for: float = 0.1):
|
||||
|
@ -217,6 +219,7 @@ async def stream_from_aio(
|
|||
exit_early: bool = False,
|
||||
raise_err: bool = False,
|
||||
aio_raise_err: bool = False,
|
||||
fan_out: bool = False,
|
||||
|
||||
) -> None:
|
||||
seq = range(100)
|
||||
|
@ -234,15 +237,33 @@ async def stream_from_aio(
|
|||
|
||||
assert first is True
|
||||
|
||||
async for value in chan:
|
||||
print(f'trio received {value}')
|
||||
pulled.append(value)
|
||||
async def consume(
|
||||
chan: Union[
|
||||
to_asyncio.LinkedTaskChannel,
|
||||
BroadcastReceiver,
|
||||
],
|
||||
):
|
||||
async for value in chan:
|
||||
print(f'trio received {value}')
|
||||
pulled.append(value)
|
||||
|
||||
if value == 50:
|
||||
if raise_err:
|
||||
raise Exception
|
||||
elif exit_early:
|
||||
break
|
||||
if value == 50:
|
||||
if raise_err:
|
||||
raise Exception
|
||||
elif exit_early:
|
||||
break
|
||||
|
||||
if fan_out:
|
||||
# start second task that get's the same stream value set.
|
||||
async with (
|
||||
trio.open_nursery() as n,
|
||||
chan.subscribe() as br,
|
||||
):
|
||||
n.start_soon(consume, br)
|
||||
await consume(chan)
|
||||
|
||||
else:
|
||||
await consume(chan)
|
||||
finally:
|
||||
|
||||
if (
|
||||
|
@ -250,19 +271,32 @@ async def stream_from_aio(
|
|||
not exit_early and
|
||||
not aio_raise_err
|
||||
):
|
||||
assert pulled == expect
|
||||
if fan_out:
|
||||
# we get double the pulled values in the
|
||||
# ``.subscribe()`` fan out case.
|
||||
doubled = list(itertools.chain(*zip(expect, expect)))
|
||||
assert pulled == doubled
|
||||
|
||||
else:
|
||||
assert pulled == expect
|
||||
else:
|
||||
# if fan_out:
|
||||
assert pulled == expect[:51]
|
||||
|
||||
print('trio guest mode task completed!')
|
||||
|
||||
|
||||
def test_basic_interloop_channel_stream(arb_addr):
|
||||
@pytest.mark.parametrize(
|
||||
'fan_out', [False, True],
|
||||
ids='fan_out_w_chan_subscribe={}'.format
|
||||
)
|
||||
def test_basic_interloop_channel_stream(arb_addr, fan_out):
|
||||
async def main():
|
||||
async with tractor.open_nursery() as n:
|
||||
portal = await n.run_in_actor(
|
||||
stream_from_aio,
|
||||
infect_asyncio=True,
|
||||
fan_out=fan_out,
|
||||
)
|
||||
await portal.result()
|
||||
|
||||
|
|
Loading…
Reference in New Issue