Compare commits
10 Commits
45f334b9c2
...
a7e7c9d1c0
Author | SHA1 | Date |
---|---|---|
|
a7e7c9d1c0 | |
|
c3665801a5 | |
|
71a4f8aaa9 | |
|
7296d171be | |
|
a053a18f53 | |
|
db86409369 | |
|
2c96e85981 | |
|
a0b69fd64b | |
|
727d666cb4 | |
|
c82ca67263 |
|
@ -1,245 +0,0 @@
|
|||
"""
|
||||
Broadcast channels for fan-out to local tasks.
|
||||
"""
|
||||
from functools import partial
|
||||
from itertools import cycle
|
||||
import time
|
||||
|
||||
import trio
|
||||
from trio.lowlevel import current_task
|
||||
import tractor
|
||||
from tractor._broadcast import broadcast_receiver, Lagged
|
||||
|
||||
|
||||
@tractor.context
|
||||
async def echo_sequences(
|
||||
|
||||
ctx: tractor.Context,
|
||||
|
||||
) -> None:
|
||||
'''Bidir streaming endpoint which will stream
|
||||
back any sequence it is sent item-wise.
|
||||
|
||||
'''
|
||||
await ctx.started()
|
||||
|
||||
async with ctx.open_stream() as stream:
|
||||
async for sequence in stream:
|
||||
seq = list(sequence)
|
||||
for value in seq:
|
||||
print(f'sending {value}')
|
||||
await stream.send(value)
|
||||
|
||||
|
||||
async def ensure_sequence(
|
||||
stream: tractor.ReceiveMsgStream,
|
||||
sequence: list,
|
||||
) -> None:
|
||||
|
||||
name = current_task().name
|
||||
async with stream.subscribe() as bcaster:
|
||||
assert not isinstance(bcaster, type(stream))
|
||||
async for value in bcaster:
|
||||
print(f'{name} rx: {value}')
|
||||
assert value == sequence[0]
|
||||
sequence.remove(value)
|
||||
|
||||
if not sequence:
|
||||
# fully consumed
|
||||
break
|
||||
|
||||
|
||||
def test_stream_fan_out_to_local_subscriptions(
|
||||
arb_addr,
|
||||
start_method,
|
||||
):
|
||||
|
||||
sequence = list(range(1000))
|
||||
|
||||
async def main():
|
||||
|
||||
async with tractor.open_nursery(
|
||||
arbiter_addr=arb_addr,
|
||||
start_method=start_method,
|
||||
) as tn:
|
||||
|
||||
portal = await tn.start_actor(
|
||||
'sequence_echoer',
|
||||
enable_modules=[__name__],
|
||||
)
|
||||
|
||||
async with portal.open_context(
|
||||
echo_sequences,
|
||||
) as (ctx, first):
|
||||
|
||||
assert first is None
|
||||
async with ctx.open_stream() as stream:
|
||||
|
||||
async with trio.open_nursery() as n:
|
||||
for i in range(10):
|
||||
n.start_soon(
|
||||
ensure_sequence,
|
||||
stream,
|
||||
sequence.copy(),
|
||||
name=f'consumer_{i}',
|
||||
)
|
||||
|
||||
await stream.send(tuple(sequence))
|
||||
|
||||
async for value in stream:
|
||||
print(f'source stream rx: {value}')
|
||||
assert value == sequence[0]
|
||||
sequence.remove(value)
|
||||
|
||||
if not sequence:
|
||||
# fully consumed
|
||||
break
|
||||
|
||||
await portal.cancel_actor()
|
||||
|
||||
trio.run(main)
|
||||
|
||||
|
||||
def test_subscribe_errors_after_close():
|
||||
|
||||
async def main():
|
||||
|
||||
size = 1
|
||||
tx, rx = trio.open_memory_channel(size)
|
||||
async with broadcast_receiver(rx, size) as brx:
|
||||
pass
|
||||
|
||||
try:
|
||||
# open and close
|
||||
async with brx.subscribe():
|
||||
pass
|
||||
|
||||
except trio.ClosedResourceError:
|
||||
assert brx.key not in brx._state.subs
|
||||
|
||||
else:
|
||||
assert 0
|
||||
|
||||
trio.run(main)
|
||||
|
||||
|
||||
def test_ensure_slow_consumers_lag_out(
|
||||
arb_addr,
|
||||
start_method,
|
||||
):
|
||||
'''This is a pure local task test; no tractor
|
||||
machinery is really required.
|
||||
|
||||
'''
|
||||
async def main():
|
||||
|
||||
# make sure it all works within the runtime
|
||||
async with tractor.open_root_actor():
|
||||
|
||||
num_laggers = 4
|
||||
laggers: dict[str, int] = {}
|
||||
retries = 3
|
||||
size = 100
|
||||
tx, rx = trio.open_memory_channel(size)
|
||||
brx = broadcast_receiver(rx, size)
|
||||
|
||||
async def sub_and_print(
|
||||
delay: float,
|
||||
) -> None:
|
||||
|
||||
task = current_task()
|
||||
start = time.time()
|
||||
|
||||
async with brx.subscribe() as lbrx:
|
||||
while True:
|
||||
print(f'{task.name}: starting consume loop')
|
||||
try:
|
||||
async for value in lbrx:
|
||||
print(f'{task.name}: {value}')
|
||||
await trio.sleep(delay)
|
||||
|
||||
if task.name == 'sub_1':
|
||||
# the non-lagger got
|
||||
# a ``trio.EndOfChannel``
|
||||
# because the ``tx`` below was closed
|
||||
assert len(lbrx._state.subs) == 1
|
||||
|
||||
await lbrx.aclose()
|
||||
|
||||
assert len(lbrx._state.subs) == 0
|
||||
|
||||
except trio.ClosedResourceError:
|
||||
# only the fast sub will try to re-enter
|
||||
# iteration on the now closed bcaster
|
||||
assert task.name == 'sub_1'
|
||||
return
|
||||
|
||||
except Lagged:
|
||||
lag_time = time.time() - start
|
||||
lags = laggers[task.name]
|
||||
print(
|
||||
f'restarting slow task {task.name} '
|
||||
f'that bailed out on {lags}:{value} '
|
||||
f'after {lag_time:.3f}')
|
||||
if lags <= retries:
|
||||
laggers[task.name] += 1
|
||||
continue
|
||||
else:
|
||||
print(
|
||||
f'{task.name} was too slow and terminated '
|
||||
f'on {lags}:{value}')
|
||||
return
|
||||
|
||||
async with trio.open_nursery() as nursery:
|
||||
|
||||
for i in range(1, num_laggers):
|
||||
|
||||
task_name = f'sub_{i}'
|
||||
laggers[task_name] = 0
|
||||
nursery.start_soon(
|
||||
partial(
|
||||
sub_and_print,
|
||||
delay=i*0.001,
|
||||
),
|
||||
name=task_name,
|
||||
)
|
||||
|
||||
# allow subs to sched
|
||||
await trio.sleep(0.1)
|
||||
|
||||
async with tx:
|
||||
for i in cycle(range(size)):
|
||||
await tx.send(i)
|
||||
if len(brx._state.subs) == 2:
|
||||
# only one, the non lagger, sub is left
|
||||
break
|
||||
|
||||
# the non-lagger
|
||||
assert laggers.pop('sub_1') == 0
|
||||
|
||||
for n, v in laggers.items():
|
||||
assert v == 4
|
||||
|
||||
assert tx._closed
|
||||
assert not tx._state.open_send_channels
|
||||
|
||||
# check that "first" bcaster that we created
|
||||
# above, never wass iterated and is thus overrun
|
||||
try:
|
||||
await brx.receive()
|
||||
except Lagged:
|
||||
# expect tokio style index truncation
|
||||
seq = brx._state.subs[brx.key]
|
||||
assert seq == len(brx._state.queue) - 1
|
||||
|
||||
# all backpressured entries in the underlying
|
||||
# channel should have been copied into the caster
|
||||
# queue trailing-window
|
||||
async for i in rx:
|
||||
print(f'bped: {i}')
|
||||
assert i in brx._state.queue
|
||||
|
||||
# should be noop
|
||||
await brx.aclose()
|
||||
|
||||
trio.run(main)
|
|
@ -0,0 +1,427 @@
|
|||
"""
|
||||
Broadcast channels for fan-out to local tasks.
|
||||
"""
|
||||
from contextlib import asynccontextmanager
|
||||
from functools import partial
|
||||
from itertools import cycle
|
||||
import time
|
||||
from typing import Optional, List, Tuple
|
||||
|
||||
import pytest
|
||||
import trio
|
||||
from trio.lowlevel import current_task
|
||||
import tractor
|
||||
from tractor._broadcast import broadcast_receiver, Lagged
|
||||
|
||||
|
||||
@tractor.context
|
||||
async def echo_sequences(
|
||||
|
||||
ctx: tractor.Context,
|
||||
|
||||
) -> None:
|
||||
'''Bidir streaming endpoint which will stream
|
||||
back any sequence it is sent item-wise.
|
||||
|
||||
'''
|
||||
await ctx.started()
|
||||
|
||||
async with ctx.open_stream() as stream:
|
||||
async for sequence in stream:
|
||||
seq = list(sequence)
|
||||
for value in seq:
|
||||
await stream.send(value)
|
||||
print(f'producer sent {value}')
|
||||
|
||||
|
||||
async def ensure_sequence(
|
||||
|
||||
stream: tractor.ReceiveMsgStream,
|
||||
sequence: list,
|
||||
delay: Optional[float] = None,
|
||||
|
||||
) -> None:
|
||||
|
||||
name = current_task().name
|
||||
async with stream.subscribe() as bcaster:
|
||||
assert not isinstance(bcaster, type(stream))
|
||||
async for value in bcaster:
|
||||
print(f'{name} rx: {value}')
|
||||
assert value == sequence[0]
|
||||
sequence.remove(value)
|
||||
|
||||
if delay:
|
||||
await trio.sleep(delay)
|
||||
|
||||
if not sequence:
|
||||
# fully consumed
|
||||
break
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def open_sequence_streamer(
|
||||
|
||||
sequence: List[int],
|
||||
arb_addr: Tuple[str, int],
|
||||
start_method: str,
|
||||
shield: bool = False,
|
||||
|
||||
) -> tractor.MsgStream:
|
||||
|
||||
async with tractor.open_nursery(
|
||||
arbiter_addr=arb_addr,
|
||||
start_method=start_method,
|
||||
) as tn:
|
||||
|
||||
portal = await tn.start_actor(
|
||||
'sequence_echoer',
|
||||
enable_modules=[__name__],
|
||||
)
|
||||
|
||||
async with portal.open_context(
|
||||
echo_sequences,
|
||||
) as (ctx, first):
|
||||
|
||||
assert first is None
|
||||
async with ctx.open_stream(shield=shield) as stream:
|
||||
yield stream
|
||||
|
||||
await portal.cancel_actor()
|
||||
|
||||
|
||||
def test_stream_fan_out_to_local_subscriptions(
|
||||
arb_addr,
|
||||
start_method,
|
||||
):
|
||||
|
||||
sequence = list(range(1000))
|
||||
|
||||
async def main():
|
||||
|
||||
async with open_sequence_streamer(
|
||||
sequence,
|
||||
arb_addr,
|
||||
start_method,
|
||||
) as stream:
|
||||
|
||||
async with trio.open_nursery() as n:
|
||||
for i in range(10):
|
||||
n.start_soon(
|
||||
ensure_sequence,
|
||||
stream,
|
||||
sequence.copy(),
|
||||
name=f'consumer_{i}',
|
||||
)
|
||||
|
||||
await stream.send(tuple(sequence))
|
||||
|
||||
async for value in stream:
|
||||
print(f'source stream rx: {value}')
|
||||
assert value == sequence[0]
|
||||
sequence.remove(value)
|
||||
|
||||
if not sequence:
|
||||
# fully consumed
|
||||
break
|
||||
|
||||
trio.run(main)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'task_delays',
|
||||
[
|
||||
(0.01, 0.001),
|
||||
(0.001, 0.01),
|
||||
]
|
||||
)
|
||||
def test_consumer_and_parent_maybe_lag(
|
||||
arb_addr,
|
||||
start_method,
|
||||
task_delays,
|
||||
):
|
||||
|
||||
async def main():
|
||||
|
||||
sequence = list(range(300))
|
||||
parent_delay, sub_delay = task_delays
|
||||
|
||||
async with open_sequence_streamer(
|
||||
sequence,
|
||||
arb_addr,
|
||||
start_method,
|
||||
) as stream:
|
||||
|
||||
try:
|
||||
async with trio.open_nursery() as n:
|
||||
|
||||
n.start_soon(
|
||||
ensure_sequence,
|
||||
stream,
|
||||
sequence.copy(),
|
||||
sub_delay,
|
||||
name='consumer_task',
|
||||
)
|
||||
|
||||
await stream.send(tuple(sequence))
|
||||
|
||||
# async for value in stream:
|
||||
lagged = False
|
||||
lag_count = 0
|
||||
|
||||
while True:
|
||||
try:
|
||||
value = await stream.receive()
|
||||
print(f'source stream rx: {value}')
|
||||
|
||||
if lagged:
|
||||
# re set the sequence starting at our last
|
||||
# value
|
||||
sequence = sequence[sequence.index(value) + 1:]
|
||||
else:
|
||||
assert value == sequence[0]
|
||||
sequence.remove(value)
|
||||
|
||||
lagged = False
|
||||
|
||||
except Lagged:
|
||||
lagged = True
|
||||
print(f'source stream lagged after {value}')
|
||||
lag_count += 1
|
||||
continue
|
||||
|
||||
# lag the parent
|
||||
await trio.sleep(parent_delay)
|
||||
|
||||
if not sequence:
|
||||
# fully consumed
|
||||
break
|
||||
print(f'parent + source stream lagged: {lag_count}')
|
||||
|
||||
if parent_delay > sub_delay:
|
||||
assert lag_count > 0
|
||||
|
||||
except Lagged:
|
||||
# child was lagged
|
||||
assert parent_delay < sub_delay
|
||||
|
||||
trio.run(main)
|
||||
|
||||
|
||||
def test_faster_task_to_recv_is_cancelled_by_slower(
|
||||
arb_addr,
|
||||
start_method,
|
||||
):
|
||||
'''Ensure that if a faster task consuming from a stream is cancelled
|
||||
the slower task can continue to receive all expected values.
|
||||
|
||||
'''
|
||||
async def main():
|
||||
|
||||
sequence = list(range(1000))
|
||||
|
||||
async with open_sequence_streamer(
|
||||
sequence,
|
||||
arb_addr,
|
||||
start_method,
|
||||
|
||||
# NOTE: this MUST be set to avoid the stream terminating
|
||||
# early when the faster subtask is cancelled by the slower
|
||||
# parent task.
|
||||
shield=True,
|
||||
|
||||
) as stream:
|
||||
|
||||
# alt to passing kwarg above.
|
||||
# with stream.shield():
|
||||
|
||||
async with trio.open_nursery() as n:
|
||||
n.start_soon(
|
||||
ensure_sequence,
|
||||
stream,
|
||||
sequence.copy(),
|
||||
0,
|
||||
name='consumer_task',
|
||||
)
|
||||
|
||||
await stream.send(tuple(sequence))
|
||||
|
||||
# pull 3 values, cancel the subtask, then
|
||||
# expect to be able to pull all values still
|
||||
for i in range(20):
|
||||
try:
|
||||
value = await stream.receive()
|
||||
print(f'source stream rx: {value}')
|
||||
await trio.sleep(0.01)
|
||||
except Lagged:
|
||||
print(f'parent overrun after {value}')
|
||||
continue
|
||||
|
||||
print('cancelling faster subtask')
|
||||
n.cancel_scope.cancel()
|
||||
|
||||
try:
|
||||
value = await stream.receive()
|
||||
print(f'source stream after cancel: {value}')
|
||||
except Lagged:
|
||||
print(f'parent overrun after {value}')
|
||||
|
||||
# expect to see all remaining values
|
||||
with trio.fail_after(0.5):
|
||||
async for value in stream:
|
||||
assert stream._broadcaster._state.recv_ready is None
|
||||
print(f'source stream rx: {value}')
|
||||
if value == 999:
|
||||
# fully consumed and we missed no values once
|
||||
# the faster subtask was cancelled
|
||||
break
|
||||
|
||||
# await tractor.breakpoint()
|
||||
# await stream.receive()
|
||||
print(f'final value: {value}')
|
||||
|
||||
trio.run(main)
|
||||
|
||||
|
||||
def test_subscribe_errors_after_close():
|
||||
|
||||
async def main():
|
||||
|
||||
size = 1
|
||||
tx, rx = trio.open_memory_channel(size)
|
||||
async with broadcast_receiver(rx, size) as brx:
|
||||
pass
|
||||
|
||||
try:
|
||||
# open and close
|
||||
async with brx.subscribe():
|
||||
pass
|
||||
|
||||
except trio.ClosedResourceError:
|
||||
assert brx.key not in brx._state.subs
|
||||
|
||||
else:
|
||||
assert 0
|
||||
|
||||
trio.run(main)
|
||||
|
||||
|
||||
def test_ensure_slow_consumers_lag_out(
|
||||
arb_addr,
|
||||
start_method,
|
||||
):
|
||||
'''This is a pure local task test; no tractor
|
||||
machinery is really required.
|
||||
|
||||
'''
|
||||
async def main():
|
||||
|
||||
# make sure it all works within the runtime
|
||||
async with tractor.open_root_actor():
|
||||
|
||||
num_laggers = 4
|
||||
laggers: dict[str, int] = {}
|
||||
retries = 3
|
||||
size = 100
|
||||
tx, rx = trio.open_memory_channel(size)
|
||||
brx = broadcast_receiver(rx, size)
|
||||
|
||||
async def sub_and_print(
|
||||
delay: float,
|
||||
) -> None:
|
||||
|
||||
task = current_task()
|
||||
start = time.time()
|
||||
|
||||
async with brx.subscribe() as lbrx:
|
||||
while True:
|
||||
print(f'{task.name}: starting consume loop')
|
||||
try:
|
||||
async for value in lbrx:
|
||||
print(f'{task.name}: {value}')
|
||||
await trio.sleep(delay)
|
||||
|
||||
if task.name == 'sub_1':
|
||||
# the non-lagger got
|
||||
# a ``trio.EndOfChannel``
|
||||
# because the ``tx`` below was closed
|
||||
assert len(lbrx._state.subs) == 1
|
||||
|
||||
await lbrx.aclose()
|
||||
|
||||
assert len(lbrx._state.subs) == 0
|
||||
|
||||
except trio.ClosedResourceError:
|
||||
# only the fast sub will try to re-enter
|
||||
# iteration on the now closed bcaster
|
||||
assert task.name == 'sub_1'
|
||||
return
|
||||
|
||||
except Lagged:
|
||||
lag_time = time.time() - start
|
||||
lags = laggers[task.name]
|
||||
print(
|
||||
f'restarting slow task {task.name} '
|
||||
f'that bailed out on {lags}:{value} '
|
||||
f'after {lag_time:.3f}')
|
||||
if lags <= retries:
|
||||
laggers[task.name] += 1
|
||||
continue
|
||||
else:
|
||||
print(
|
||||
f'{task.name} was too slow and terminated '
|
||||
f'on {lags}:{value}')
|
||||
return
|
||||
|
||||
async with trio.open_nursery() as nursery:
|
||||
|
||||
for i in range(1, num_laggers):
|
||||
|
||||
task_name = f'sub_{i}'
|
||||
laggers[task_name] = 0
|
||||
nursery.start_soon(
|
||||
partial(
|
||||
sub_and_print,
|
||||
delay=i*0.001,
|
||||
),
|
||||
name=task_name,
|
||||
)
|
||||
|
||||
# allow subs to sched
|
||||
await trio.sleep(0.1)
|
||||
|
||||
async with tx:
|
||||
for i in cycle(range(size)):
|
||||
await tx.send(i)
|
||||
if len(brx._state.subs) == 2:
|
||||
# only one, the non lagger, sub is left
|
||||
break
|
||||
|
||||
# the non-lagger
|
||||
assert laggers.pop('sub_1') == 0
|
||||
|
||||
for n, v in laggers.items():
|
||||
assert v == 4
|
||||
|
||||
assert tx._closed
|
||||
assert not tx._state.open_send_channels
|
||||
|
||||
# check that "first" bcaster that we created
|
||||
# above, never wass iterated and is thus overrun
|
||||
try:
|
||||
await brx.receive()
|
||||
except Lagged:
|
||||
# expect tokio style index truncation
|
||||
seq = brx._state.subs[brx.key]
|
||||
assert seq == len(brx._state.queue) - 1
|
||||
|
||||
# all backpressured entries in the underlying
|
||||
# channel should have been copied into the caster
|
||||
# queue trailing-window
|
||||
async for i in rx:
|
||||
print(f'bped: {i}')
|
||||
assert i in brx._state.queue
|
||||
|
||||
# should be noop
|
||||
await brx.aclose()
|
||||
|
||||
trio.run(main)
|
|
@ -567,7 +567,7 @@ class Actor:
|
|||
try:
|
||||
send_chan, recv_chan = self._cids2qs[(actorid, cid)]
|
||||
except KeyError:
|
||||
send_chan, recv_chan = trio.open_memory_channel(1000)
|
||||
send_chan, recv_chan = trio.open_memory_channel(2**6)
|
||||
send_chan.cid = cid # type: ignore
|
||||
recv_chan.cid = cid # type: ignore
|
||||
self._cids2qs[(actorid, cid)] = send_chan, recv_chan
|
||||
|
|
|
@ -74,6 +74,7 @@ class BroadcastState:
|
|||
|
||||
'''
|
||||
queue: deque
|
||||
maxlen: int
|
||||
|
||||
# map of underlying instance id keys to receiver instances which
|
||||
# must be provided as a singleton per broadcaster set.
|
||||
|
@ -81,7 +82,7 @@ class BroadcastState:
|
|||
|
||||
# broadcast event to wake up all sleeping consumer tasks
|
||||
# on a newly produced value from the sender.
|
||||
recv_ready: Optional[tuple[str, trio.Event]] = None
|
||||
recv_ready: Optional[tuple[int, trio.Event]] = None
|
||||
|
||||
|
||||
class BroadcastReceiver(ReceiveChannel):
|
||||
|
@ -111,7 +112,7 @@ class BroadcastReceiver(ReceiveChannel):
|
|||
self._recv = receive_afunc or rx_chan.receive
|
||||
self._closed: bool = False
|
||||
|
||||
async def receive(self):
|
||||
async def receive(self) -> ReceiveType:
|
||||
|
||||
key = self.key
|
||||
state = self._state
|
||||
|
@ -150,7 +151,7 @@ class BroadcastReceiver(ReceiveChannel):
|
|||
# decrement to the last value and expect
|
||||
# consumer to either handle the ``Lagged`` and come back
|
||||
# or bail out on its own (thus un-subscribing)
|
||||
state.subs[key] = state.queue.maxlen - 1
|
||||
state.subs[key] = state.maxlen - 1
|
||||
|
||||
# this task was overrun by the producer side
|
||||
task: Task = current_task()
|
||||
|
@ -169,10 +170,17 @@ class BroadcastReceiver(ReceiveChannel):
|
|||
event = trio.Event()
|
||||
state.recv_ready = key, event
|
||||
|
||||
# if we're cancelled here it should be
|
||||
# fine to bail without affecting any other consumers
|
||||
# right?
|
||||
try:
|
||||
value = await self._recv()
|
||||
|
||||
# items with lower indices are "newer"
|
||||
# NOTE: ``collections.deque`` implicitly takes care of
|
||||
# trucating values outside our ``state.maxlen``. In the
|
||||
# alt-backend-array-case we'll need to make sure this is
|
||||
# implemented in similar ringer-buffer-ish style.
|
||||
state.queue.appendleft(value)
|
||||
|
||||
# broadcast new value to all subscribers by increasing
|
||||
|
@ -193,21 +201,51 @@ class BroadcastReceiver(ReceiveChannel):
|
|||
):
|
||||
state.subs[sub_key] += 1
|
||||
|
||||
# NOTE: this should ONLY be set if the above task was *NOT*
|
||||
# cancelled on the `._recv()` call otherwise sibling
|
||||
# consumers will be awoken with a sequence of -1
|
||||
event.set()
|
||||
|
||||
return value
|
||||
|
||||
finally:
|
||||
# reset receiver waiter task event for next blocking condition
|
||||
event.set()
|
||||
# Reset receiver waiter task event for next blocking condition.
|
||||
# this MUST be reset even if the above ``.recv()`` call
|
||||
# was cancelled to avoid the next consumer from blocking on
|
||||
# an event that won't be set!
|
||||
state.recv_ready = None
|
||||
|
||||
# This task is all caught up and ready to receive the latest
|
||||
# value, so queue sched it on the internal event.
|
||||
else:
|
||||
seq = state.subs[key]
|
||||
assert seq == -1 # sanity
|
||||
_, ev = state.recv_ready
|
||||
await ev.wait()
|
||||
|
||||
seq = state.subs[key]
|
||||
assert seq > -1, f'Invalid sequence {seq}!?'
|
||||
|
||||
value = state.queue[seq]
|
||||
state.subs[key] -= 1
|
||||
return state.queue[seq]
|
||||
return value
|
||||
|
||||
# NOTE: if we ever would like the behaviour where if the
|
||||
# first task to recv on the underlying is cancelled but it
|
||||
# still DOES trigger the ``.recv_ready``, event we'll likely need
|
||||
# this logic:
|
||||
|
||||
# if seq > -1:
|
||||
# # stuff from above..
|
||||
# elif seq == -1:
|
||||
# # XXX: In the case where the first task to allocate the
|
||||
# # ``.recv_ready`` event is cancelled we will be woken with
|
||||
# # a non-incremented sequence number and thus will read the
|
||||
# # oldest value if we use that. Instead we need to detect if
|
||||
# # we have not been incremented and then receive again.
|
||||
# return await self.receive()
|
||||
# else:
|
||||
# raise ValueError(f'Invalid sequence {seq}!?')
|
||||
|
||||
@asynccontextmanager
|
||||
async def subscribe(
|
||||
|
@ -263,6 +301,7 @@ def broadcast_receiver(
|
|||
recv_chan,
|
||||
state=BroadcastState(
|
||||
queue=deque(maxlen=max_buffer_size),
|
||||
maxlen=max_buffer_size,
|
||||
subs={},
|
||||
),
|
||||
**kwargs,
|
||||
|
|
|
@ -294,6 +294,7 @@ class Portal:
|
|||
async def open_stream_from(
|
||||
self,
|
||||
async_gen_func: Callable, # typing: ignore
|
||||
shield: bool = False,
|
||||
**kwargs,
|
||||
|
||||
) -> AsyncGenerator[ReceiveMsgStream, None]:
|
||||
|
@ -320,7 +321,9 @@ class Portal:
|
|||
ctx = Context(self.channel, cid, _portal=self)
|
||||
try:
|
||||
# deliver receive only stream
|
||||
async with ReceiveMsgStream(ctx, recv_chan) as rchan:
|
||||
async with ReceiveMsgStream(
|
||||
ctx, recv_chan, shield=shield
|
||||
) as rchan:
|
||||
self._streams.add(rchan)
|
||||
yield rchan
|
||||
|
||||
|
|
|
@ -9,7 +9,7 @@ from dataclasses import dataclass
|
|||
from typing import (
|
||||
Any, Iterator, Optional, Callable,
|
||||
AsyncGenerator, Dict,
|
||||
AsyncIterator, Awaitable
|
||||
AsyncIterator
|
||||
)
|
||||
|
||||
import warnings
|
||||
|
@ -276,6 +276,8 @@ class ReceiveMsgStream(trio.abc.ReceiveChannel):
|
|||
# https://github.com/python/mypy/issues/708
|
||||
|
||||
async with self._broadcaster.subscribe() as bstream:
|
||||
assert bstream.key != self._broadcaster.key
|
||||
assert bstream._recv == self._broadcaster._recv
|
||||
yield bstream
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue