Add test for `LinkedTaskChannel.subscribe()` fanout feature
parent
9c43bb28f1
commit
90593611bb
|
@ -2,9 +2,10 @@
|
||||||
The hipster way to force SC onto the stdlib's "async": 'infection mode'.
|
The hipster way to force SC onto the stdlib's "async": 'infection mode'.
|
||||||
|
|
||||||
'''
|
'''
|
||||||
from typing import Optional, Iterable
|
from typing import Optional, Iterable, Union
|
||||||
import asyncio
|
import asyncio
|
||||||
import builtins
|
import builtins
|
||||||
|
import itertools
|
||||||
import importlib
|
import importlib
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
@ -12,6 +13,7 @@ import trio
|
||||||
import tractor
|
import tractor
|
||||||
from tractor import to_asyncio
|
from tractor import to_asyncio
|
||||||
from tractor import RemoteActorError
|
from tractor import RemoteActorError
|
||||||
|
from tractor.trionics import BroadcastReceiver
|
||||||
|
|
||||||
|
|
||||||
async def sleep_and_err(sleep_for: float = 0.1):
|
async def sleep_and_err(sleep_for: float = 0.1):
|
||||||
|
@ -217,6 +219,7 @@ async def stream_from_aio(
|
||||||
exit_early: bool = False,
|
exit_early: bool = False,
|
||||||
raise_err: bool = False,
|
raise_err: bool = False,
|
||||||
aio_raise_err: bool = False,
|
aio_raise_err: bool = False,
|
||||||
|
fan_out: bool = False,
|
||||||
|
|
||||||
) -> None:
|
) -> None:
|
||||||
seq = range(100)
|
seq = range(100)
|
||||||
|
@ -234,15 +237,33 @@ async def stream_from_aio(
|
||||||
|
|
||||||
assert first is True
|
assert first is True
|
||||||
|
|
||||||
async for value in chan:
|
async def consume(
|
||||||
print(f'trio received {value}')
|
chan: Union[
|
||||||
pulled.append(value)
|
to_asyncio.LinkedTaskChannel,
|
||||||
|
BroadcastReceiver,
|
||||||
|
],
|
||||||
|
):
|
||||||
|
async for value in chan:
|
||||||
|
print(f'trio received {value}')
|
||||||
|
pulled.append(value)
|
||||||
|
|
||||||
if value == 50:
|
if value == 50:
|
||||||
if raise_err:
|
if raise_err:
|
||||||
raise Exception
|
raise Exception
|
||||||
elif exit_early:
|
elif exit_early:
|
||||||
break
|
break
|
||||||
|
|
||||||
|
if fan_out:
|
||||||
|
# start second task that get's the same stream value set.
|
||||||
|
async with (
|
||||||
|
trio.open_nursery() as n,
|
||||||
|
chan.subscribe() as br,
|
||||||
|
):
|
||||||
|
n.start_soon(consume, br)
|
||||||
|
await consume(chan)
|
||||||
|
|
||||||
|
else:
|
||||||
|
await consume(chan)
|
||||||
finally:
|
finally:
|
||||||
|
|
||||||
if (
|
if (
|
||||||
|
@ -250,19 +271,32 @@ async def stream_from_aio(
|
||||||
not exit_early and
|
not exit_early and
|
||||||
not aio_raise_err
|
not aio_raise_err
|
||||||
):
|
):
|
||||||
assert pulled == expect
|
if fan_out:
|
||||||
|
# we get double the pulled values in the
|
||||||
|
# ``.subscribe()`` fan out case.
|
||||||
|
doubled = list(itertools.chain(*zip(expect, expect)))
|
||||||
|
assert pulled == doubled
|
||||||
|
|
||||||
|
else:
|
||||||
|
assert pulled == expect
|
||||||
else:
|
else:
|
||||||
|
# if fan_out:
|
||||||
assert pulled == expect[:51]
|
assert pulled == expect[:51]
|
||||||
|
|
||||||
print('trio guest mode task completed!')
|
print('trio guest mode task completed!')
|
||||||
|
|
||||||
|
|
||||||
def test_basic_interloop_channel_stream(arb_addr):
|
@pytest.mark.parametrize(
|
||||||
|
'fan_out', [False, True],
|
||||||
|
ids='fan_out_w_chan_subscribe={}'.format
|
||||||
|
)
|
||||||
|
def test_basic_interloop_channel_stream(arb_addr, fan_out):
|
||||||
async def main():
|
async def main():
|
||||||
async with tractor.open_nursery() as n:
|
async with tractor.open_nursery() as n:
|
||||||
portal = await n.run_in_actor(
|
portal = await n.run_in_actor(
|
||||||
stream_from_aio,
|
stream_from_aio,
|
||||||
infect_asyncio=True,
|
infect_asyncio=True,
|
||||||
|
fan_out=fan_out,
|
||||||
)
|
)
|
||||||
await portal.result()
|
await portal.result()
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue