forked from goodboy/tractor
commit
3f1bc37143
|
@ -0,0 +1,12 @@
|
||||||
|
Add `tokio-style broadcast channels
|
||||||
|
<https://docs.rs/tokio/1.11.0/tokio/sync/broadcast/index.html>`_ as
|
||||||
|
a solution for `#204 <https://github.com/goodboy/tractor/pull/204>`_ and
|
||||||
|
discussed thoroughly in `trio/#987
|
||||||
|
<https://github.com/python-trio/trio/issues/987>`_.
|
||||||
|
|
||||||
|
This gives us local task broadcast functionality using a new
|
||||||
|
``BroadcastReceiver`` type which can wrap ``trio.ReceiveChannel`` and
|
||||||
|
provide fan-out copies of a stream of data to every subscribed consumer.
|
||||||
|
We use this new machinery to provide a ``ReceiveMsgStream.subscribe()``
|
||||||
|
async context manager which can be used by actor-local consumer tasks
|
||||||
|
to easily pull from a shared and dynamic IPC stream.
|
|
@ -0,0 +1,459 @@
|
||||||
|
"""
|
||||||
|
Broadcast channels for fan-out to local tasks.
|
||||||
|
"""
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
from functools import partial
|
||||||
|
from itertools import cycle
|
||||||
|
import time
|
||||||
|
from typing import Optional, List, Tuple
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import trio
|
||||||
|
from trio.lowlevel import current_task
|
||||||
|
import tractor
|
||||||
|
from tractor._broadcast import broadcast_receiver, Lagged
|
||||||
|
|
||||||
|
|
||||||
|
@tractor.context
async def echo_sequences(

    ctx: tractor.Context,

) -> None:
    '''Bidirectional streaming endpoint: every sequence received from
    the caller is sent back one element at a time, in order.

    '''
    # signal the parent that this context is up before streaming
    await ctx.started()

    async with ctx.open_stream() as stream:
        async for batch in stream:
            # materialize first so the echo loop works on a plain list
            items = list(batch)
            for item in items:
                await stream.send(item)
                print(f'producer sent {item}')
|
||||||
|
|
||||||
|
|
||||||
|
async def ensure_sequence(

    stream: tractor.ReceiveMsgStream,
    sequence: list,
    delay: Optional[float] = None,

) -> None:
    '''Subscribe to ``stream`` and assert that the values delivered
    arrive in exactly the order given by ``sequence``.

    ``sequence`` is consumed in place (each matched value is removed)
    and the task returns once it is empty. When ``delay`` is truthy the
    task sleeps that many seconds after each value.

    '''
    task_name = current_task().name

    async with stream.subscribe() as bcaster:
        # the subscription must be a distinct (broadcast) receiver type
        assert not isinstance(bcaster, type(stream))

        async for received in bcaster:
            print(f'{task_name} rx: {received}')

            # next value must be the head of what we still expect
            assert received == sequence[0]
            sequence.remove(received)

            if delay:
                await trio.sleep(delay)

            if not sequence:
                # fully consumed
                break
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
async def open_sequence_streamer(

    sequence: List[int],
    arb_addr: Tuple[str, int],
    start_method: str,

) -> tractor.MsgStream:
    '''Spawn a ``'sequence_echoer'`` actor running ``echo_sequences()``
    and yield the parent side of its bidirectional message stream.

    The actor is cancelled once the caller's block exits.
    NOTE(review): ``sequence`` is accepted but not used in this body.

    '''
    async with tractor.open_nursery(
        arbiter_addr=arb_addr,
        start_method=start_method,
    ) as actor_nursery:

        portal = await actor_nursery.start_actor(
            'sequence_echoer',
            enable_modules=[__name__],
        )

        async with portal.open_context(echo_sequences) as (ctx, first):

            # ``echo_sequences()`` calls ``ctx.started()`` with no value
            assert first is None

            async with ctx.open_stream() as stream:
                yield stream

        await portal.cancel_actor()
|
||||||
|
|
||||||
|
|
||||||
|
def test_stream_fan_out_to_local_subscriptions(
    arb_addr,
    start_method,
):
    '''Ten local consumer tasks plus the parent task must each observe
    the full echoed sequence, in order, from a single IPC stream.

    '''
    sequence = list(range(1000))

    async def main():

        async with open_sequence_streamer(
            sequence,
            arb_addr,
            start_method,
        ) as stream:

            async with trio.open_nursery() as nursery:

                # every consumer verifies against its own copy
                for idx in range(10):
                    nursery.start_soon(
                        ensure_sequence,
                        stream,
                        sequence.copy(),
                        name=f'consumer_{idx}',
                    )

                # kick off the echo from the far end
                await stream.send(tuple(sequence))

                # the parent consumes from the source stream directly
                async for received in stream:
                    print(f'source stream rx: {received}')
                    assert received == sequence[0]
                    sequence.remove(received)

                    if not sequence:
                        # fully consumed
                        break

    trio.run(main)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
    'task_delays',
    [
        (0.01, 0.001),
        (0.001, 0.01),
    ]
)
def test_consumer_and_parent_maybe_lag(
    arb_addr,
    start_method,
    task_delays,
):
    '''Run a parent and a subscribed consumer at different speeds and
    check that whichever is slower is the one that gets ``Lagged``.

    ``task_delays`` is ``(parent_delay, sub_delay)`` in seconds.

    '''
    async def main():

        sequence = list(range(300))
        parent_delay, sub_delay = task_delays

        async with open_sequence_streamer(
            sequence,
            arb_addr,
            start_method,
        ) as stream:

            try:
                async with trio.open_nursery() as n:

                    n.start_soon(
                        ensure_sequence,
                        stream,
                        sequence.copy(),
                        sub_delay,
                        name='consumer_task',
                    )

                    await stream.send(tuple(sequence))

                    lagged = False
                    lag_count = 0

                    # pre-bind so the ``Lagged`` handler below can
                    # always report it, even if the very first
                    # ``receive()`` is the one that lags.
                    value = None

                    while True:
                        try:
                            value = await stream.receive()
                            print(f'source stream rx: {value}')

                            if lagged:
                                # re set the sequence starting at our last
                                # value
                                sequence = sequence[sequence.index(value) + 1:]
                            else:
                                assert value == sequence[0]
                                sequence.remove(value)

                            lagged = False

                        except Lagged:
                            lagged = True
                            print(f'source stream lagged after {value}')
                            lag_count += 1
                            continue

                        # lag the parent
                        await trio.sleep(parent_delay)

                        if not sequence:
                            # fully consumed
                            break

                    print(f'parent + source stream lagged: {lag_count}')

                    if parent_delay > sub_delay:
                        assert lag_count > 0

            except Lagged:
                # child was lagged
                assert parent_delay < sub_delay

    trio.run(main)
|
||||||
|
|
||||||
|
|
||||||
|
def test_faster_task_to_recv_is_cancelled_by_slower(
    arb_addr,
    start_method,
):
    '''Ensure that if a faster task consuming from a stream is cancelled
    by a slower sibling, the slower task can continue to receive all
    remaining values.

    '''
    async def main():

        sequence = list(range(1000))

        async with open_sequence_streamer(
            sequence,
            arb_addr,
            start_method,

        ) as stream:

            # pre-bind so the ``Lagged`` handlers below can always
            # report it, even if the first ``receive()`` lags.
            value = None

            async with trio.open_nursery() as n:
                n.start_soon(
                    ensure_sequence,
                    stream,
                    sequence.copy(),
                    0,
                    name='consumer_task',
                )

                await stream.send(tuple(sequence))

                # make 20 (slow) receive attempts, cancel the subtask,
                # then expect to be able to pull all values still
                for _ in range(20):
                    try:
                        value = await stream.receive()
                        print(f'source stream rx: {value}')
                        await trio.sleep(0.01)
                    except Lagged:
                        print(f'parent overrun after {value}')
                        continue

                print('cancelling faster subtask')
                n.cancel_scope.cancel()

            try:
                value = await stream.receive()
                print(f'source stream after cancel: {value}')
            except Lagged:
                print(f'parent overrun after {value}')

            # expect to see all remaining values
            with trio.fail_after(0.5):
                async for value in stream:
                    # no task is left holding the underlying receive
                    assert stream._broadcaster._state.recv_ready is None
                    print(f'source stream rx: {value}')
                    if value == 999:
                        # fully consumed and we missed no values once
                        # the faster subtask was cancelled
                        break

                print(f'final value: {value}')

    trio.run(main)
|
||||||
|
|
||||||
|
|
||||||
|
def test_subscribe_errors_after_close():
    '''Subscribing on an already closed broadcast receiver must raise
    ``trio.ClosedResourceError`` and leave it unregistered.

    '''
    async def main():

        size = 1
        tx, rx = trio.open_memory_channel(size)

        # open and immediately close the root broadcaster
        async with broadcast_receiver(rx, size) as brx:
            pass

        # ``pytest.raises`` fails with a descriptive message if the
        # error is *not* raised (instead of a bare ``assert 0``).
        with pytest.raises(trio.ClosedResourceError):
            # open and close
            async with brx.subscribe():
                pass

        assert brx.key not in brx._state.subs

    trio.run(main)
|
||||||
|
|
||||||
|
|
||||||
|
def test_ensure_slow_consumers_lag_out(
    arb_addr,
    start_method,
):
    '''This is a pure local task test; no tractor
    machinery is really required.

    Three subscribers with increasing per-value delays consume from one
    broadcaster; the slower ones are expected to hit ``Lagged``
    repeatedly and give up after exceeding ``retries``.

    '''
    async def main():

        # make sure it all works within the runtime
        async with tractor.open_root_actor():

            num_laggers = 4
            laggers: dict[str, int] = {}
            retries = 3
            size = 100
            tx, rx = trio.open_memory_channel(size)
            brx = broadcast_receiver(rx, size)

            async def sub_and_print(
                delay: float,
            ) -> None:

                task = current_task()
                start = time.time()

                # pre-bind so the ``Lagged`` handler can always report
                # it, even if the first receive is the one that lags.
                value = None

                async with brx.subscribe() as lbrx:
                    while True:
                        print(f'{task.name}: starting consume loop')
                        try:
                            async for value in lbrx:
                                print(f'{task.name}: {value}')
                                await trio.sleep(delay)

                            if task.name == 'sub_1':
                                # the non-lagger got
                                # a ``trio.EndOfChannel``
                                # because the ``tx`` below was closed
                                assert len(lbrx._state.subs) == 1

                                await lbrx.aclose()

                                assert len(lbrx._state.subs) == 0

                        except trio.ClosedResourceError:
                            # only the fast sub will try to re-enter
                            # iteration on the now closed bcaster
                            assert task.name == 'sub_1'
                            return

                        except Lagged:
                            lag_time = time.time() - start
                            lags = laggers[task.name]
                            print(
                                f'restarting slow task {task.name} '
                                f'that bailed out on {lags}:{value} '
                                f'after {lag_time:.3f}')
                            if lags <= retries:
                                laggers[task.name] += 1
                                continue
                            else:
                                print(
                                    f'{task.name} was too slow and terminated '
                                    f'on {lags}:{value}')
                                return

            async with trio.open_nursery() as nursery:

                for i in range(1, num_laggers):

                    task_name = f'sub_{i}'
                    laggers[task_name] = 0
                    nursery.start_soon(
                        partial(
                            sub_and_print,
                            delay=i*0.001,
                        ),
                        name=task_name,
                    )

                # allow subs to sched
                await trio.sleep(0.1)

                async with tx:
                    for i in cycle(range(size)):
                        await tx.send(i)
                        if len(brx._state.subs) == 2:
                            # only one, the non lagger, sub is left
                            break

                # the non-lagger
                assert laggers.pop('sub_1') == 0

                # every remaining (lagging) task used up all its retries
                for lag_total in laggers.values():
                    assert lag_total == 4

                assert tx._closed
                assert not tx._state.open_send_channels

                # check that "first" bcaster that we created
                # above, never was iterated and is thus overrun
                try:
                    await brx.receive()
                except Lagged:
                    # expect tokio style index truncation
                    seq = brx._state.subs[brx.key]
                    assert seq == len(brx._state.queue) - 1

                # all backpressured entries in the underlying
                # channel should have been copied into the caster
                # queue trailing-window
                async for i in rx:
                    print(f'bped: {i}')
                    assert i in brx._state.queue

                # should be noop
                await brx.aclose()

    trio.run(main)
|
||||||
|
|
||||||
|
|
||||||
|
def test_first_recver_is_cancelled():
    '''If the task currently blocked on the underlying receive is
    cancelled, another receiver must still be able to get the next
    value instead of hanging.

    '''
    async def main():

        # make sure it all works within the runtime
        async with tractor.open_root_actor():

            tx, rx = trio.open_memory_channel(1)
            brx = broadcast_receiver(rx, 1)
            cs = trio.CancelScope()

            async def sub_and_recv():
                # becomes the "first receiver" parked on ``rx``
                with cs:
                    async with brx.subscribe() as bc:
                        async for value in bc:
                            print(value)

            async def cancel_and_send():
                await trio.sleep(0.2)
                cs.cancel()
                await tx.send(1)

            async with trio.open_nursery() as n:

                n.start_soon(sub_and_recv)
                await trio.sleep(0.1)

                # the subtask should now hold the receive-ready slot
                assert brx._state.recv_ready

                n.start_soon(cancel_and_send)

                # ensure that we don't hang because no-task is now
                # waiting on the underlying receive..
                with trio.fail_after(0.5):
                    value = await brx.receive()
                    print(f'parent: {value}')
                    assert value == 1

    trio.run(main)
|
|
@ -567,7 +567,7 @@ class Actor:
|
||||||
try:
|
try:
|
||||||
send_chan, recv_chan = self._cids2qs[(actorid, cid)]
|
send_chan, recv_chan = self._cids2qs[(actorid, cid)]
|
||||||
except KeyError:
|
except KeyError:
|
||||||
send_chan, recv_chan = trio.open_memory_channel(1000)
|
send_chan, recv_chan = trio.open_memory_channel(2**6)
|
||||||
send_chan.cid = cid # type: ignore
|
send_chan.cid = cid # type: ignore
|
||||||
recv_chan.cid = cid # type: ignore
|
recv_chan.cid = cid # type: ignore
|
||||||
self._cids2qs[(actorid, cid)] = send_chan, recv_chan
|
self._cids2qs[(actorid, cid)] = send_chan, recv_chan
|
||||||
|
|
|
@ -0,0 +1,315 @@
|
||||||
|
'''
|
||||||
|
``tokio`` style broadcast channel.
|
||||||
|
https://docs.rs/tokio/1.11.0/tokio/sync/broadcast/index.html
|
||||||
|
|
||||||
|
'''
|
||||||
|
from __future__ import annotations
|
||||||
|
from abc import abstractmethod
|
||||||
|
from collections import deque
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from functools import partial
|
||||||
|
from operator import ne
|
||||||
|
from typing import Optional, Callable, Awaitable, Any, AsyncIterator, Protocol
|
||||||
|
from typing import Generic, TypeVar
|
||||||
|
|
||||||
|
import trio
|
||||||
|
from trio._core._run import Task
|
||||||
|
from trio.abc import ReceiveChannel
|
||||||
|
from trio.lowlevel import current_task
|
||||||
|
|
||||||
|
|
||||||
|
# A regular invariant generic type
|
||||||
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
# covariant because AsyncReceiver[Derived] can be passed to someone
|
||||||
|
# expecting AsyncReceiver[Base].
|
||||||
|
ReceiveType = TypeVar("ReceiveType", covariant=True)
|
||||||
|
|
||||||
|
|
||||||
|
class AsyncReceiver(
    Protocol,
    Generic[ReceiveType],
):
    '''Structural (duck) type for an object that can be asynchronously
    received from; it mirrors the interface of
    ``trio.abc.ReceiveChannel``.

    '''
    @abstractmethod
    async def receive(self) -> ReceiveType:
        '''Wait for and return the next value.'''

    @abstractmethod
    def __aiter__(self) -> AsyncIterator[ReceiveType]:
        '''Return an async iterator over received values.'''

    @abstractmethod
    async def __anext__(self) -> ReceiveType:
        '''Return the next value for ``async for`` iteration.'''

    # ``trio.abc.AsyncResource`` methods
    @abstractmethod
    async def aclose(self):
        '''Close this receiver.'''

    @abstractmethod
    async def __aenter__(self) -> AsyncReceiver[ReceiveType]:
        '''Enter the ``async with`` block.'''

    @abstractmethod
    async def __aexit__(self, *args) -> None:
        '''Exit the ``async with`` block.'''
|
||||||
|
|
||||||
|
|
||||||
|
class Lagged(trio.TooSlowError):
    '''Raised to a subscribed consumer task that fell too far behind:
    the value it should have read next is no longer held in the
    broadcast queue.

    '''
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class BroadcastState:
    '''State shared by every receiver belonging to one broadcast set.

    '''
    # bounded buffer of received values (receivers push new values
    # on the left side).
    queue: deque

    # the buffer bound, kept alongside the queue
    maxlen: int

    # per-receiver read position: maps a receiver's ``.key`` to the
    # index in ``queue`` of the next value it should read, or ``-1``
    # when it has nothing pending. One mapping is shared per
    # broadcaster set.
    subs: dict[int, int]

    # ``(key, event)`` of the receiver currently waiting on the
    # underlying channel; other caught-up receivers wait on the event.
    # ``None`` when no receiver is waiting.
    recv_ready: Optional[tuple[int, trio.Event]] = None
|
||||||
|
|
||||||
|
|
||||||
|
class BroadcastReceiver(ReceiveChannel):
    '''A memory receive channel broadcaster which is non-lossy for the
    fastest consumer.

    Additional consumer tasks can receive all produced values by registering
    with ``.subscribe()`` and receiving from the new instance it delivers.

    All instances created from one another share a single
    ``BroadcastState``; each instance is identified in it by ``.key``.

    '''
    def __init__(
        self,

        rx_chan: AsyncReceiver,
        state: BroadcastState,
        receive_afunc: Optional[Callable[[], Awaitable[Any]]] = None,

    ) -> None:
        '''Register this instance in ``state`` with nothing pending.

        ``receive_afunc``, when given, is awaited instead of
        ``rx_chan.receive`` to pull from the underlying channel.

        '''
        # register this instance as a subscriber; ``-1`` means "no
        # value pending in the queue".
        self.key = id(self)
        self._state = state
        state.subs[self.key] = -1

        # underlying for this receiver
        self._rx = rx_chan
        self._recv = receive_afunc or rx_chan.receive
        self._closed: bool = False

    async def receive(self) -> ReceiveType:
        '''Return the next value for this subscriber.

        Raises ``Lagged`` when the value this subscriber should read
        next has already been pushed out of the queue,
        ``trio.ClosedResourceError`` when this instance was closed, and
        ``RuntimeError`` when it is not registered in the shared state.
        Errors from the underlying receive call propagate unchanged.

        '''
        key = self.key
        state = self._state

        # TODO: ideally we can make some way to "lock out" the
        # underlying receive channel in some way such that if some task
        # tries to pull from it directly (i.e. one we're unaware of)
        # then it errors out.

        # only registered instances can receive on this broadcaster.
        try:
            seq = state.subs[key]
        except KeyError:
            if self._closed:
                raise trio.ClosedResourceError

            raise RuntimeError(
                f'{self} is not registered as subscriber')

        # check that task does not already have a value it can receive
        # immediately and/or that it has lagged.
        if seq > -1:
            # get the oldest value we haven't received immediately
            try:
                value = state.queue[seq]
            except IndexError:

                # adhere to ``tokio`` style "lagging":
                # "Once RecvError::Lagged is returned, the lagging
                # receiver's position is updated to the oldest value
                # contained by the channel. The next call to recv will
                # return this value."
                # https://docs.rs/tokio/1.11.0/tokio/sync/broadcast/index.html#lagging

                # reset to the last (oldest) slot and expect the
                # consumer to either handle the ``Lagged`` and come back
                # or bail out on its own (thus un-subscribing)
                state.subs[key] = state.maxlen - 1

                # this task was overrun by the producer side
                task: Task = current_task()
                raise Lagged(f'Task {task.name} was overrun')

            state.subs[key] -= 1
            return value

        # current task already has the latest value **and** is the
        # first task to begin waiting for a new one
        if state.recv_ready is None:

            if self._closed:
                raise trio.ClosedResourceError

            event = trio.Event()
            state.recv_ready = key, event

            try:
                value = await self._recv()

                # items with lower indices are "newer"
                # NOTE: ``collections.deque`` implicitly takes care of
                # truncating values outside our ``state.maxlen``. In the
                # alt-backend-array-case we'll need to make sure this is
                # implemented in similar ring-buffer-ish style.
                state.queue.appendleft(value)

                # broadcast the new value to all *other* subscribers by
                # bumping their sequence numbers; this task's own
                # sequence is left alone since it returns the value
                # directly.
                for sub_key in state.subs:
                    if sub_key != key:
                        state.subs[sub_key] += 1

                return value

            finally:
                # Reset the waiter slot for the next blocking
                # condition. This MUST happen even if the ``._recv()``
                # call above failed, to avoid the next consumer
                # blocking on an event that won't be set!
                state.recv_ready = None

                # Always wake sibling tasks parked on the event. On
                # success they find their bumped sequence number; on
                # cancellation or any error from the underlying
                # receive (e.g. end-of-channel) they wake with a
                # sequence of ``-1`` and re-enter ``.receive()``
                # themselves rather than sleeping forever.
                event.set()

        # This task is all caught up and another task is already
        # waiting on the underlying channel, so sleep on its event.
        else:
            assert seq == -1  # sanity
            _, ev = state.recv_ready
            await ev.wait()

            # Re-enter to pick up the result: if our sequence number
            # was bumped, the fast path at the top returns the new
            # value (or raises ``Lagged``); if the first receiver
            # failed we wake with a non-incremented sequence and this
            # call takes over receiving from the underlying channel.
            return await self.receive()

    @asynccontextmanager
    async def subscribe(
        self,
    ) -> AsyncIterator[BroadcastReceiver]:
        '''Subscribe for values from this broadcast receiver.

        Yields a new ``BroadcastReceiver`` which shares this instance's
        state, underlying channel and receive function; it is closed
        (un-registered) when the block exits. Raises
        ``trio.ClosedResourceError`` if this instance is closed.

        '''
        if self._closed:
            raise trio.ClosedResourceError

        state = self._state
        br = BroadcastReceiver(
            rx_chan=self._rx,
            state=state,
            receive_afunc=self._recv,
        )
        # the constructor registers the new instance
        assert br.key in state.subs

        try:
            yield br
        finally:
            await br.aclose()

    async def aclose(
        self,
    ) -> None:
        '''Un-register this instance from the shared state; calling it
        again is a no-op.

        '''
        if self._closed:
            return

        # XXX: leaving it like this consumers can still get values
        # up to the last received that still reside in the queue.
        self._state.subs.pop(self.key)

        self._closed = True
|
||||||
|
|
||||||
|
|
||||||
|
def broadcast_receiver(

    recv_chan: AsyncReceiver,
    max_buffer_size: int,
    **kwargs,

) -> BroadcastReceiver:
    '''Create the root ``BroadcastReceiver`` wrapping ``recv_chan``
    with a fresh shared state whose queue holds at most
    ``max_buffer_size`` values.

    Extra keyword arguments are forwarded to ``BroadcastReceiver``.

    '''
    shared_state = BroadcastState(
        queue=deque(maxlen=max_buffer_size),
        maxlen=max_buffer_size,
        subs={},
    )
    return BroadcastReceiver(recv_chan, state=shared_state, **kwargs)
|
|
@ -294,6 +294,7 @@ class Portal:
|
||||||
async def open_stream_from(
|
async def open_stream_from(
|
||||||
self,
|
self,
|
||||||
async_gen_func: Callable, # typing: ignore
|
async_gen_func: Callable, # typing: ignore
|
||||||
|
shield: bool = False,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
|
|
||||||
) -> AsyncGenerator[ReceiveMsgStream, None]:
|
) -> AsyncGenerator[ReceiveMsgStream, None]:
|
||||||
|
@ -320,7 +321,9 @@ class Portal:
|
||||||
ctx = Context(self.channel, cid, _portal=self)
|
ctx = Context(self.channel, cid, _portal=self)
|
||||||
try:
|
try:
|
||||||
# deliver receive only stream
|
# deliver receive only stream
|
||||||
async with ReceiveMsgStream(ctx, recv_chan) as rchan:
|
async with ReceiveMsgStream(
|
||||||
|
ctx, recv_chan, shield=shield
|
||||||
|
) as rchan:
|
||||||
self._streams.add(rchan)
|
self._streams.add(rchan)
|
||||||
yield rchan
|
yield rchan
|
||||||
|
|
||||||
|
|
|
@ -2,12 +2,14 @@
|
||||||
Message stream types and APIs.
|
Message stream types and APIs.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
import inspect
|
import inspect
|
||||||
from contextlib import contextmanager, asynccontextmanager
|
from contextlib import contextmanager, asynccontextmanager
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import (
|
from typing import (
|
||||||
Any, Iterator, Optional, Callable,
|
Any, Iterator, Optional, Callable,
|
||||||
AsyncGenerator, Dict,
|
AsyncGenerator, Dict,
|
||||||
|
AsyncIterator
|
||||||
)
|
)
|
||||||
|
|
||||||
import warnings
|
import warnings
|
||||||
|
@ -17,6 +19,7 @@ import trio
|
||||||
from ._ipc import Channel
|
from ._ipc import Channel
|
||||||
from ._exceptions import unpack_error, ContextCancelled
|
from ._exceptions import unpack_error, ContextCancelled
|
||||||
from ._state import current_actor
|
from ._state import current_actor
|
||||||
|
from ._broadcast import broadcast_receiver, BroadcastReceiver
|
||||||
from .log import get_logger
|
from .log import get_logger
|
||||||
|
|
||||||
|
|
||||||
|
@ -47,11 +50,14 @@ class ReceiveMsgStream(trio.abc.ReceiveChannel):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
ctx: 'Context', # typing: ignore # noqa
|
ctx: 'Context', # typing: ignore # noqa
|
||||||
rx_chan: trio.abc.ReceiveChannel,
|
rx_chan: trio.MemoryReceiveChannel,
|
||||||
|
shield: bool = False,
|
||||||
|
_broadcaster: Optional[BroadcastReceiver] = None,
|
||||||
|
|
||||||
) -> None:
|
) -> None:
|
||||||
self._ctx = ctx
|
self._ctx = ctx
|
||||||
self._rx_chan = rx_chan
|
self._rx_chan = rx_chan
|
||||||
|
self._broadcaster = _broadcaster
|
||||||
|
|
||||||
# flag to denote end of stream
|
# flag to denote end of stream
|
||||||
self._eoc: bool = False
|
self._eoc: bool = False
|
||||||
|
@ -231,6 +237,50 @@ class ReceiveMsgStream(trio.abc.ReceiveChannel):
|
||||||
# still need to consume msgs that are "in transit" from the far
|
# still need to consume msgs that are "in transit" from the far
|
||||||
# end (eg. for ``Context.result()``).
|
# end (eg. for ``Context.result()``).
|
||||||
|
|
||||||
|
@asynccontextmanager
async def subscribe(
    self,

) -> AsyncIterator[BroadcastReceiver]:
    '''Allocate and return a ``BroadcastReceiver`` which delegates
    to this message stream.

    This allows multiple local tasks to each receive their own copy
    of this message stream.

    The first call creates the internal broadcast receiver wrapper and
    rebinds this stream's ``.receive`` to it, so that every value
    received from the far end is also copied (up to the window length)
    for subscribers. Later calls reuse that same wrapper.

    '''
    # NOTE: the first call is non-reversible, so be sure you can deal
    # with any (theoretical) overhead of the allocated
    # ``BroadcastReceiver`` before calling this method for the first
    # time.
    broadcaster = self._broadcaster
    if broadcaster is None:

        broadcaster = broadcast_receiver(
            self,
            # use memory channel size by default
            self._rx_chan._state.max_buffer_size,  # type: ignore
            receive_afunc=self.receive,
        )
        self._broadcaster = broadcaster

        # NOTE: we override the original stream instance's receive
        # method to now delegate to the broadcaster's ``.receive()``
        # such that new subscribers will be copied received values
        # and this stream doesn't have to expect its original
        # consumer(s) to get a new broadcast rx handle.
        self.receive = broadcaster.receive  # type: ignore
        # seems there's no graceful way to type this with ``mypy``?
        # https://github.com/python/mypy/issues/708

    async with broadcaster.subscribe() as bstream:
        assert bstream.key != broadcaster.key
        assert bstream._recv == broadcaster._recv
        yield bstream
|
||||||
|
|
||||||
|
|
||||||
class MsgStream(ReceiveMsgStream, trio.abc.Channel):
|
class MsgStream(ReceiveMsgStream, trio.abc.Channel):
|
||||||
"""
|
"""
|
||||||
|
@ -247,17 +297,6 @@ class MsgStream(ReceiveMsgStream, trio.abc.Channel):
|
||||||
'''
|
'''
|
||||||
await self._ctx.chan.send({'yield': data, 'cid': self._ctx.cid})
|
await self._ctx.chan.send({'yield': data, 'cid': self._ctx.cid})
|
||||||
|
|
||||||
# TODO: but make it broadcasting to consumers
|
|
||||||
def clone(self):
|
|
||||||
"""Clone this receive channel allowing for multi-task
|
|
||||||
consumption from the same channel.
|
|
||||||
|
|
||||||
"""
|
|
||||||
return MsgStream(
|
|
||||||
self._ctx,
|
|
||||||
self._rx_chan.clone(),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Context:
|
class Context:
|
||||||
|
|
Loading…
Reference in New Issue