Compare commits


10 Commits

Author SHA1 Message Date
Tyler Goodlet a7e7c9d1c0 Store array `maxlen` in state singleton
The `collections.deque` takes care of truncating values to the array
length for us implicitly, but in the future we'll likely want this value
exposed to alternate array implementations. This patch provides for that
as well as making `mypy` happy, since `deque.maxlen` can also be `None`
(a short sketch of the typing issue follows the commit list).
2021-09-01 06:49:09 -04:00
Tyler Goodlet c3665801a5 Don't wake sibling bcast consumers on a cancelled call 2021-09-01 06:49:09 -04:00
Tyler Goodlet 71a4f8aaa9 Shorten sequence length for test speedup 2021-09-01 06:49:09 -04:00
Tyler Goodlet 7296d171be Shorten default feeder mem chan size to 64 2021-09-01 06:49:09 -04:00
Tyler Goodlet a053a18f53 Can't use built-in generics till 3.9... 2021-09-01 06:49:09 -04:00
Tyler Goodlet db86409369 Add `shield: bool` kwarg to `Portal.open_stream_from()` 2021-09-01 06:49:09 -04:00
Tyler Goodlet 2c96e85981 Add a "faster task is cancelled" test 2021-09-01 06:49:09 -04:00
Tyler Goodlet a0b69fd64b Rename test module 2021-09-01 06:49:09 -04:00
Tyler Goodlet 727d666cb4 Add some bcaster ref sanity asserts around subscriptions 2021-09-01 06:49:09 -04:00
Tyler Goodlet c82ca67263 Add laggy parent stream tests
Add a couple more tests to check that a parent and sub-task stream can
be lagged and recovered (depending on who's slower). Factor some of the
test machinery into a new ctx mngr to make it all happen.
2021-09-01 06:49:09 -04:00
6 changed files with 481 additions and 255 deletions
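
As referenced in the `maxlen` commit message above, here is a minimal, stand-alone sketch of the typing wrinkle that commit works around (`RingState` is an assumed name, not code from this changeset): `deque.maxlen` is typed `Optional[int]`, so arithmetic on it trips `mypy` unless the value is also stored as a plain `int`.

# Hypothetical illustration, not code from this changeset.
from collections import deque
from dataclasses import dataclass


@dataclass
class RingState:
    queue: deque
    maxlen: int  # stored as a plain ``int`` so ``mypy`` can do arithmetic on it


size = 3
state = RingState(queue=deque(maxlen=size), maxlen=size)

# ``deque`` silently drops the oldest item once ``maxlen`` is reached ...
for i in range(5):
    state.queue.appendleft(i)
print(list(state.queue))  # -> [4, 3, 2]

# ... but ``state.queue.maxlen`` is typed ``Optional[int]``, so an expression
# like ``state.queue.maxlen - 1`` is flagged by ``mypy``; the explicit
# ``state.maxlen`` sidesteps that and also frees alternate (non-deque) array
# backends from having to expose the same attribute.
last_index = state.maxlen - 1
assert last_index == 2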

View File

@@ -1,245 +0,0 @@
"""
Broadcast channels for fan-out to local tasks.
"""
from functools import partial
from itertools import cycle
import time

import trio
from trio.lowlevel import current_task
import tractor
from tractor._broadcast import broadcast_receiver, Lagged


@tractor.context
async def echo_sequences(
    ctx: tractor.Context,
) -> None:
    '''Bidir streaming endpoint which will stream
    back any sequence it is sent item-wise.
    '''
    await ctx.started()

    async with ctx.open_stream() as stream:
        async for sequence in stream:
            seq = list(sequence)
            for value in seq:
                print(f'sending {value}')
                await stream.send(value)


async def ensure_sequence(
    stream: tractor.ReceiveMsgStream,
    sequence: list,
) -> None:
    name = current_task().name
    async with stream.subscribe() as bcaster:
        assert not isinstance(bcaster, type(stream))
        async for value in bcaster:
            print(f'{name} rx: {value}')
            assert value == sequence[0]
            sequence.remove(value)

            if not sequence:
                # fully consumed
                break


def test_stream_fan_out_to_local_subscriptions(
    arb_addr,
    start_method,
):
    sequence = list(range(1000))

    async def main():
        async with tractor.open_nursery(
            arbiter_addr=arb_addr,
            start_method=start_method,
        ) as tn:
            portal = await tn.start_actor(
                'sequence_echoer',
                enable_modules=[__name__],
            )

            async with portal.open_context(
                echo_sequences,
            ) as (ctx, first):
                assert first is None

                async with ctx.open_stream() as stream:
                    async with trio.open_nursery() as n:
                        for i in range(10):
                            n.start_soon(
                                ensure_sequence,
                                stream,
                                sequence.copy(),
                                name=f'consumer_{i}',
                            )

                        await stream.send(tuple(sequence))

                        async for value in stream:
                            print(f'source stream rx: {value}')
                            assert value == sequence[0]
                            sequence.remove(value)

                            if not sequence:
                                # fully consumed
                                break

            await portal.cancel_actor()

    trio.run(main)


def test_subscribe_errors_after_close():

    async def main():
        size = 1
        tx, rx = trio.open_memory_channel(size)
        async with broadcast_receiver(rx, size) as brx:
            pass

        try:
            # open and close
            async with brx.subscribe():
                pass

        except trio.ClosedResourceError:
            assert brx.key not in brx._state.subs

        else:
            assert 0

    trio.run(main)


def test_ensure_slow_consumers_lag_out(
    arb_addr,
    start_method,
):
    '''This is a pure local task test; no tractor
    machinery is really required.
    '''
    async def main():
        # make sure it all works within the runtime
        async with tractor.open_root_actor():

            num_laggers = 4
            laggers: dict[str, int] = {}
            retries = 3
            size = 100
            tx, rx = trio.open_memory_channel(size)
            brx = broadcast_receiver(rx, size)

            async def sub_and_print(
                delay: float,
            ) -> None:
                task = current_task()
                start = time.time()

                async with brx.subscribe() as lbrx:
                    while True:
                        print(f'{task.name}: starting consume loop')
                        try:
                            async for value in lbrx:
                                print(f'{task.name}: {value}')
                                await trio.sleep(delay)

                            if task.name == 'sub_1':
                                # the non-lagger got
                                # a ``trio.EndOfChannel``
                                # because the ``tx`` below was closed
                                assert len(lbrx._state.subs) == 1

                                await lbrx.aclose()

                                assert len(lbrx._state.subs) == 0

                        except trio.ClosedResourceError:
                            # only the fast sub will try to re-enter
                            # iteration on the now closed bcaster
                            assert task.name == 'sub_1'
                            return

                        except Lagged:
                            lag_time = time.time() - start
                            lags = laggers[task.name]
                            print(
                                f'restarting slow task {task.name} '
                                f'that bailed out on {lags}:{value} '
                                f'after {lag_time:.3f}')

                            if lags <= retries:
                                laggers[task.name] += 1
                                continue
                            else:
                                print(
                                    f'{task.name} was too slow and terminated '
                                    f'on {lags}:{value}')
                                return

            async with trio.open_nursery() as nursery:
                for i in range(1, num_laggers):
                    task_name = f'sub_{i}'
                    laggers[task_name] = 0
                    nursery.start_soon(
                        partial(
                            sub_and_print,
                            delay=i*0.001,
                        ),
                        name=task_name,
                    )

                # allow subs to sched
                await trio.sleep(0.1)

                async with tx:
                    for i in cycle(range(size)):
                        await tx.send(i)
                        if len(brx._state.subs) == 2:
                            # only one, the non lagger, sub is left
                            break

            # the non-lagger
            assert laggers.pop('sub_1') == 0

            for n, v in laggers.items():
                assert v == 4

            assert tx._closed
            assert not tx._state.open_send_channels

            # check that "first" bcaster that we created
            # above, never wass iterated and is thus overrun
            try:
                await brx.receive()
            except Lagged:
                # expect tokio style index truncation
                seq = brx._state.subs[brx.key]
                assert seq == len(brx._state.queue) - 1

            # all backpressured entries in the underlying
            # channel should have been copied into the caster
            # queue trailing-window
            async for i in rx:
                print(f'bped: {i}')
                assert i in brx._state.queue

            # should be noop
            await brx.aclose()

    trio.run(main)

View File

@@ -0,0 +1,427 @@
"""
Broadcast channels for fan-out to local tasks.
"""
from contextlib import asynccontextmanager
from functools import partial
from itertools import cycle
import time
from typing import Optional, List, Tuple

import pytest
import trio
from trio.lowlevel import current_task
import tractor
from tractor._broadcast import broadcast_receiver, Lagged


@tractor.context
async def echo_sequences(
    ctx: tractor.Context,
) -> None:
    '''Bidir streaming endpoint which will stream
    back any sequence it is sent item-wise.
    '''
    await ctx.started()

    async with ctx.open_stream() as stream:
        async for sequence in stream:
            seq = list(sequence)
            for value in seq:
                await stream.send(value)
                print(f'producer sent {value}')


async def ensure_sequence(
    stream: tractor.ReceiveMsgStream,
    sequence: list,
    delay: Optional[float] = None,
) -> None:
    name = current_task().name
    async with stream.subscribe() as bcaster:
        assert not isinstance(bcaster, type(stream))
        async for value in bcaster:
            print(f'{name} rx: {value}')
            assert value == sequence[0]
            sequence.remove(value)

            if delay:
                await trio.sleep(delay)

            if not sequence:
                # fully consumed
                break


@asynccontextmanager
async def open_sequence_streamer(
    sequence: List[int],
    arb_addr: Tuple[str, int],
    start_method: str,
    shield: bool = False,
) -> tractor.MsgStream:

    async with tractor.open_nursery(
        arbiter_addr=arb_addr,
        start_method=start_method,
    ) as tn:
        portal = await tn.start_actor(
            'sequence_echoer',
            enable_modules=[__name__],
        )

        async with portal.open_context(
            echo_sequences,
        ) as (ctx, first):
            assert first is None

            async with ctx.open_stream(shield=shield) as stream:
                yield stream

        await portal.cancel_actor()


def test_stream_fan_out_to_local_subscriptions(
    arb_addr,
    start_method,
):
    sequence = list(range(1000))

    async def main():
        async with open_sequence_streamer(
            sequence,
            arb_addr,
            start_method,
        ) as stream:

            async with trio.open_nursery() as n:
                for i in range(10):
                    n.start_soon(
                        ensure_sequence,
                        stream,
                        sequence.copy(),
                        name=f'consumer_{i}',
                    )

                await stream.send(tuple(sequence))

                async for value in stream:
                    print(f'source stream rx: {value}')
                    assert value == sequence[0]
                    sequence.remove(value)

                    if not sequence:
                        # fully consumed
                        break

    trio.run(main)


@pytest.mark.parametrize(
    'task_delays',
    [
        (0.01, 0.001),
        (0.001, 0.01),
    ]
)
def test_consumer_and_parent_maybe_lag(
    arb_addr,
    start_method,
    task_delays,
):
    async def main():
        sequence = list(range(300))
        parent_delay, sub_delay = task_delays

        async with open_sequence_streamer(
            sequence,
            arb_addr,
            start_method,
        ) as stream:

            try:
                async with trio.open_nursery() as n:

                    n.start_soon(
                        ensure_sequence,
                        stream,
                        sequence.copy(),
                        sub_delay,
                        name='consumer_task',
                    )

                    await stream.send(tuple(sequence))

                    # async for value in stream:
                    lagged = False
                    lag_count = 0

                    while True:
                        try:
                            value = await stream.receive()
                            print(f'source stream rx: {value}')

                            if lagged:
                                # re set the sequence starting at our last
                                # value
                                sequence = sequence[sequence.index(value) + 1:]
                            else:
                                assert value == sequence[0]
                                sequence.remove(value)

                            lagged = False

                        except Lagged:
                            lagged = True
                            print(f'source stream lagged after {value}')
                            lag_count += 1
                            continue

                        # lag the parent
                        await trio.sleep(parent_delay)

                        if not sequence:
                            # fully consumed
                            break

                    print(f'parent + source stream lagged: {lag_count}')

                    if parent_delay > sub_delay:
                        assert lag_count > 0

            except Lagged:
                # child was lagged
                assert parent_delay < sub_delay

    trio.run(main)


def test_faster_task_to_recv_is_cancelled_by_slower(
    arb_addr,
    start_method,
):
    '''Ensure that if a faster task consuming from a stream is cancelled
    the slower task can continue to receive all expected values.
    '''
    async def main():
        sequence = list(range(1000))

        async with open_sequence_streamer(
            sequence,
            arb_addr,
            start_method,

            # NOTE: this MUST be set to avoid the stream terminating
            # early when the faster subtask is cancelled by the slower
            # parent task.
            shield=True,

        ) as stream:

            # alt to passing kwarg above.
            # with stream.shield():

            async with trio.open_nursery() as n:
                n.start_soon(
                    ensure_sequence,
                    stream,
                    sequence.copy(),
                    0,
                    name='consumer_task',
                )

                await stream.send(tuple(sequence))

                # pull 3 values, cancel the subtask, then
                # expect to be able to pull all values still
                for i in range(20):
                    try:
                        value = await stream.receive()
                        print(f'source stream rx: {value}')
                        await trio.sleep(0.01)
                    except Lagged:
                        print(f'parent overrun after {value}')
                        continue

                print('cancelling faster subtask')
                n.cancel_scope.cancel()

                try:
                    value = await stream.receive()
                    print(f'source stream after cancel: {value}')
                except Lagged:
                    print(f'parent overrun after {value}')

            # expect to see all remaining values
            with trio.fail_after(0.5):
                async for value in stream:
                    assert stream._broadcaster._state.recv_ready is None
                    print(f'source stream rx: {value}')
                    if value == 999:
                        # fully consumed and we missed no values once
                        # the faster subtask was cancelled
                        break

                # await tractor.breakpoint()
                # await stream.receive()
                print(f'final value: {value}')

    trio.run(main)


def test_subscribe_errors_after_close():

    async def main():
        size = 1
        tx, rx = trio.open_memory_channel(size)
        async with broadcast_receiver(rx, size) as brx:
            pass

        try:
            # open and close
            async with brx.subscribe():
                pass

        except trio.ClosedResourceError:
            assert brx.key not in brx._state.subs

        else:
            assert 0

    trio.run(main)


def test_ensure_slow_consumers_lag_out(
    arb_addr,
    start_method,
):
    '''This is a pure local task test; no tractor
    machinery is really required.
    '''
    async def main():
        # make sure it all works within the runtime
        async with tractor.open_root_actor():

            num_laggers = 4
            laggers: dict[str, int] = {}
            retries = 3
            size = 100
            tx, rx = trio.open_memory_channel(size)
            brx = broadcast_receiver(rx, size)

            async def sub_and_print(
                delay: float,
            ) -> None:
                task = current_task()
                start = time.time()

                async with brx.subscribe() as lbrx:
                    while True:
                        print(f'{task.name}: starting consume loop')
                        try:
                            async for value in lbrx:
                                print(f'{task.name}: {value}')
                                await trio.sleep(delay)

                            if task.name == 'sub_1':
                                # the non-lagger got
                                # a ``trio.EndOfChannel``
                                # because the ``tx`` below was closed
                                assert len(lbrx._state.subs) == 1

                                await lbrx.aclose()

                                assert len(lbrx._state.subs) == 0

                        except trio.ClosedResourceError:
                            # only the fast sub will try to re-enter
                            # iteration on the now closed bcaster
                            assert task.name == 'sub_1'
                            return

                        except Lagged:
                            lag_time = time.time() - start
                            lags = laggers[task.name]
                            print(
                                f'restarting slow task {task.name} '
                                f'that bailed out on {lags}:{value} '
                                f'after {lag_time:.3f}')

                            if lags <= retries:
                                laggers[task.name] += 1
                                continue
                            else:
                                print(
                                    f'{task.name} was too slow and terminated '
                                    f'on {lags}:{value}')
                                return

            async with trio.open_nursery() as nursery:
                for i in range(1, num_laggers):
                    task_name = f'sub_{i}'
                    laggers[task_name] = 0
                    nursery.start_soon(
                        partial(
                            sub_and_print,
                            delay=i*0.001,
                        ),
                        name=task_name,
                    )

                # allow subs to sched
                await trio.sleep(0.1)

                async with tx:
                    for i in cycle(range(size)):
                        await tx.send(i)
                        if len(brx._state.subs) == 2:
                            # only one, the non lagger, sub is left
                            break

            # the non-lagger
            assert laggers.pop('sub_1') == 0

            for n, v in laggers.items():
                assert v == 4

            assert tx._closed
            assert not tx._state.open_send_channels

            # check that "first" bcaster that we created
            # above, never wass iterated and is thus overrun
            try:
                await brx.receive()
            except Lagged:
                # expect tokio style index truncation
                seq = brx._state.subs[brx.key]
                assert seq == len(brx._state.queue) - 1

            # all backpressured entries in the underlying
            # channel should have been copied into the caster
            # queue trailing-window
            async for i in rx:
                print(f'bped: {i}')
                assert i in brx._state.queue

            # should be noop
            await brx.aclose()

    trio.run(main)

View File

@@ -567,7 +567,7 @@ class Actor:
         try:
             send_chan, recv_chan = self._cids2qs[(actorid, cid)]
         except KeyError:
-            send_chan, recv_chan = trio.open_memory_channel(1000)
+            send_chan, recv_chan = trio.open_memory_channel(2**6)
             send_chan.cid = cid  # type: ignore
             recv_chan.cid = cid  # type: ignore
             self._cids2qs[(actorid, cid)] = send_chan, recv_chan
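
For context on the constant swap above: 2**6 is 64, i.e. the per-context feeder channel now buffers 64 messages before the sender blocks. A small illustration in plain `trio` (no tractor involved) of what that bound means in practice:

import trio


async def main() -> None:
    tx, rx = trio.open_memory_channel(2**6)  # 64-slot buffer, as in the hunk

    for i in range(64):
        tx.send_nowait(i)  # fills the buffer without blocking

    try:
        tx.send_nowait(64)  # a 65th item would need a blocking ``send()``
    except trio.WouldBlock:
        print('buffer full: producer is now backpressured')


trio.run(main)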

View File

@@ -74,6 +74,7 @@ class BroadcastState:
     '''
     queue: deque
+    maxlen: int

     # map of underlying instance id keys to receiver instances which
     # must be provided as a singleton per broadcaster set.
@@ -81,7 +82,7 @@ class BroadcastState:
     # broadcast event to wake up all sleeping consumer tasks
     # on a newly produced value from the sender.
-    recv_ready: Optional[tuple[str, trio.Event]] = None
+    recv_ready: Optional[tuple[int, trio.Event]] = None


 class BroadcastReceiver(ReceiveChannel):
@@ -111,7 +112,7 @@ class BroadcastReceiver(ReceiveChannel):
         self._recv = receive_afunc or rx_chan.receive
         self._closed: bool = False

-    async def receive(self):
+    async def receive(self) -> ReceiveType:

         key = self.key
         state = self._state
@@ -150,7 +151,7 @@ class BroadcastReceiver(ReceiveChannel):
             # decrement to the last value and expect
             # consumer to either handle the ``Lagged`` and come back
             # or bail out on its own (thus un-subscribing)
-            state.subs[key] = state.queue.maxlen - 1
+            state.subs[key] = state.maxlen - 1

             # this task was overrun by the producer side
             task: Task = current_task()
@@ -169,10 +170,17 @@ class BroadcastReceiver(ReceiveChannel):
             event = trio.Event()
             state.recv_ready = key, event

+            # if we're cancelled here it should be
+            # fine to bail without affecting any other consumers
+            # right?
             try:
                 value = await self._recv()

                 # items with lower indices are "newer"
+                # NOTE: ``collections.deque`` implicitly takes care of
+                # trucating values outside our ``state.maxlen``. In the
+                # alt-backend-array-case we'll need to make sure this is
+                # implemented in similar ringer-buffer-ish style.
                 state.queue.appendleft(value)

                 # broadcast new value to all subscribers by increasing
@@ -193,21 +201,51 @@ class BroadcastReceiver(ReceiveChannel):
                 ):
                     state.subs[sub_key] += 1

+                # NOTE: this should ONLY be set if the above task was *NOT*
+                # cancelled on the `._recv()` call otherwise sibling
+                # consumers will be awoken with a sequence of -1
+                event.set()
                 return value

             finally:
-                # reset receiver waiter task event for next blocking condition
-                event.set()
+                # Reset receiver waiter task event for next blocking condition.
+                # this MUST be reset even if the above ``.recv()`` call
+                # was cancelled to avoid the next consumer from blocking on
+                # an event that won't be set!
                 state.recv_ready = None

         # This task is all caught up and ready to receive the latest
         # value, so queue sched it on the internal event.
         else:
+            seq = state.subs[key]
+            assert seq == -1  # sanity
             _, ev = state.recv_ready
             await ev.wait()

             seq = state.subs[key]
+            assert seq > -1, f'Invalid sequence {seq}!?'
+
+            value = state.queue[seq]
             state.subs[key] -= 1
-            return state.queue[seq]
+            return value
+
+            # NOTE: if we ever would like the behaviour where if the
+            # first task to recv on the underlying is cancelled but it
+            # still DOES trigger the ``.recv_ready``, event we'll likely need
+            # this logic:
+            #
+            # if seq > -1:
+            #     # stuff from above..
+            # elif seq == -1:
+            #     # XXX: In the case where the first task to allocate the
+            #     # ``.recv_ready`` event is cancelled we will be woken with
+            #     # a non-incremented sequence number and thus will read the
+            #     # oldest value if we use that. Instead we need to detect if
+            #     # we have not been incremented and then receive again.
+            #     return await self.receive()
+            # else:
+            #     raise ValueError(f'Invalid sequence {seq}!?')

     @asynccontextmanager
     async def subscribe(
@@ -263,6 +301,7 @@ def broadcast_receiver(
         recv_chan,
         state=BroadcastState(
             queue=deque(maxlen=max_buffer_size),
+            maxlen=max_buffer_size,
             subs={},
         ),
         **kwargs,
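
The hunk above is the heart of the "don't wake sibling bcast consumers on a cancelled call" commit: the event is only set after a successful receive, while the "someone is receiving" slot is always cleared. A stand-alone sketch of that discipline using assumed names (`LeaderState`, `leader_recv`); this is not the library code, just the shape of the fix:

from typing import Awaitable, Callable, Optional, Tuple

import trio


class LeaderState:
    def __init__(self) -> None:
        # (key, event) for the single task currently doing the underlying recv
        self.recv_ready: Optional[Tuple[int, trio.Event]] = None


async def leader_recv(
    state: LeaderState,
    key: int,
    recv: Callable[[], Awaitable[int]],
) -> int:
    event = trio.Event()
    state.recv_ready = key, event
    try:
        value = await recv()
        # only wake sleeping siblings if we were *not* cancelled inside
        # ``recv()``; otherwise they would be woken against a sequence
        # number that was never advanced
        event.set()
        return value
    finally:
        # always clear the slot, even on cancellation, so the next caller
        # allocates a fresh event instead of waiting on one that will
        # never be set
        state.recv_ready = None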

View File

@@ -294,6 +294,7 @@ class Portal:
     async def open_stream_from(
         self,
         async_gen_func: Callable,  # typing: ignore
+        shield: bool = False,
         **kwargs,

     ) -> AsyncGenerator[ReceiveMsgStream, None]:
@@ -320,7 +321,9 @@ class Portal:
             ctx = Context(self.channel, cid, _portal=self)
             try:
                 # deliver receive only stream
-                async with ReceiveMsgStream(ctx, recv_chan) as rchan:
+                async with ReceiveMsgStream(
+                    ctx, recv_chan, shield=shield
+                ) as rchan:
                     self._streams.add(rchan)
                     yield rchan
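
A hedged usage sketch of the new kwarg: `stream_ints` is a made-up child-side async generator (not from this changeset), and its `n` argument is simply forwarded through the existing `**kwargs`. Shielding keeps the receive stream usable even if a sibling consumer's cancel scope is torn down, which is what the new "faster task is cancelled" test relies on.

import tractor
import trio


async def stream_ints(n: int = 10):
    # plain async generator executed in the child actor
    for i in range(n):
        yield i
        await trio.sleep(0.01)


async def main() -> None:
    async with tractor.open_nursery() as an:
        portal = await an.start_actor(
            'streamer',
            enable_modules=[__name__],
        )

        # pass ``shield=True`` so the stream survives sibling cancellation
        async with portal.open_stream_from(
            stream_ints,
            shield=True,
            n=10,
        ) as stream:
            async for value in stream:
                print(f'rx: {value}')

        await portal.cancel_actor()


if __name__ == '__main__':
    trio.run(main)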

View File

@@ -9,7 +9,7 @@ from dataclasses import dataclass
 from typing import (
     Any, Iterator, Optional, Callable,
     AsyncGenerator, Dict,
-    AsyncIterator, Awaitable
+    AsyncIterator
 )

 import warnings
@@ -263,7 +263,7 @@ class ReceiveMsgStream(trio.abc.ReceiveChannel):
                 self,

                 # use memory channel size by default
                 self._rx_chan._state.max_buffer_size,  # type: ignore
+                receive_afunc=self.receive,
             )

             # NOTE: we override the original stream instance's receive
@@ -276,6 +276,8 @@ class ReceiveMsgStream(trio.abc.ReceiveChannel):
         # https://github.com/python/mypy/issues/708
         async with self._broadcaster.subscribe() as bstream:
+            assert bstream.key != self._broadcaster.key
+            assert bstream._recv == self._broadcaster._recv
             yield bstream