diff --git a/newsfragments/229.feature.rst b/newsfragments/229.feature.rst new file mode 100644 index 0000000..bda005c --- /dev/null +++ b/newsfragments/229.feature.rst @@ -0,0 +1,12 @@ +Add `tokio-style broadcast channels +`_ as +a solution for `#204 `_ and +discussed thoroughly in `trio/#987 +`_. + +This gives us local task broadcast functionality using a new +``BroadcastReceiver`` type which can wrap ``trio.ReceiveChannel`` and +provide fan-out copies of a stream of data to every subscribed consumer. +We use this new machinery to provide a ``ReceiveMsgStream.subscribe()`` +async context manager which can be used by actor-local concumers tasks +to easily pull from a shared and dynamic IPC stream. diff --git a/tests/test_task_broadcasting.py b/tests/test_task_broadcasting.py new file mode 100644 index 0000000..8265197 --- /dev/null +++ b/tests/test_task_broadcasting.py @@ -0,0 +1,459 @@ +""" +Broadcast channels for fan-out to local tasks. +""" +from contextlib import asynccontextmanager +from functools import partial +from itertools import cycle +import time +from typing import Optional, List, Tuple + +import pytest +import trio +from trio.lowlevel import current_task +import tractor +from tractor._broadcast import broadcast_receiver, Lagged + + +@tractor.context +async def echo_sequences( + + ctx: tractor.Context, + +) -> None: + '''Bidir streaming endpoint which will stream + back any sequence it is sent item-wise. + + ''' + await ctx.started() + + async with ctx.open_stream() as stream: + async for sequence in stream: + seq = list(sequence) + for value in seq: + await stream.send(value) + print(f'producer sent {value}') + + +async def ensure_sequence( + + stream: tractor.ReceiveMsgStream, + sequence: list, + delay: Optional[float] = None, + +) -> None: + + name = current_task().name + async with stream.subscribe() as bcaster: + assert not isinstance(bcaster, type(stream)) + async for value in bcaster: + print(f'{name} rx: {value}') + assert value == sequence[0] + sequence.remove(value) + + if delay: + await trio.sleep(delay) + + if not sequence: + # fully consumed + break + + +@asynccontextmanager +async def open_sequence_streamer( + + sequence: List[int], + arb_addr: Tuple[str, int], + start_method: str, + +) -> tractor.MsgStream: + + async with tractor.open_nursery( + arbiter_addr=arb_addr, + start_method=start_method, + ) as tn: + + portal = await tn.start_actor( + 'sequence_echoer', + enable_modules=[__name__], + ) + + async with portal.open_context( + echo_sequences, + ) as (ctx, first): + + assert first is None + async with ctx.open_stream() as stream: + yield stream + + await portal.cancel_actor() + + +def test_stream_fan_out_to_local_subscriptions( + arb_addr, + start_method, +): + + sequence = list(range(1000)) + + async def main(): + + async with open_sequence_streamer( + sequence, + arb_addr, + start_method, + ) as stream: + + async with trio.open_nursery() as n: + for i in range(10): + n.start_soon( + ensure_sequence, + stream, + sequence.copy(), + name=f'consumer_{i}', + ) + + await stream.send(tuple(sequence)) + + async for value in stream: + print(f'source stream rx: {value}') + assert value == sequence[0] + sequence.remove(value) + + if not sequence: + # fully consumed + break + + trio.run(main) + + +@pytest.mark.parametrize( + 'task_delays', + [ + (0.01, 0.001), + (0.001, 0.01), + ] +) +def test_consumer_and_parent_maybe_lag( + arb_addr, + start_method, + task_delays, +): + + async def main(): + + sequence = list(range(300)) + parent_delay, sub_delay = task_delays + + 
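+        # NOTE: once ``ensure_sequence()`` (spawned below) calls
+        # ``stream.subscribe()``, this parent task's ``stream.receive()``
+        # is also routed through the broadcaster, so whichever side
+        # sleeps longer between receives is the one expected to be
+        # overrun and see ``Lagged``.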
async with open_sequence_streamer( + sequence, + arb_addr, + start_method, + ) as stream: + + try: + async with trio.open_nursery() as n: + + n.start_soon( + ensure_sequence, + stream, + sequence.copy(), + sub_delay, + name='consumer_task', + ) + + await stream.send(tuple(sequence)) + + # async for value in stream: + lagged = False + lag_count = 0 + + while True: + try: + value = await stream.receive() + print(f'source stream rx: {value}') + + if lagged: + # re set the sequence starting at our last + # value + sequence = sequence[sequence.index(value) + 1:] + else: + assert value == sequence[0] + sequence.remove(value) + + lagged = False + + except Lagged: + lagged = True + print(f'source stream lagged after {value}') + lag_count += 1 + continue + + # lag the parent + await trio.sleep(parent_delay) + + if not sequence: + # fully consumed + break + print(f'parent + source stream lagged: {lag_count}') + + if parent_delay > sub_delay: + assert lag_count > 0 + + except Lagged: + # child was lagged + assert parent_delay < sub_delay + + trio.run(main) + + +def test_faster_task_to_recv_is_cancelled_by_slower( + arb_addr, + start_method, +): + '''Ensure that if a faster task consuming from a stream is cancelled + the slower task can continue to receive all expected values. + + ''' + async def main(): + + sequence = list(range(1000)) + + async with open_sequence_streamer( + sequence, + arb_addr, + start_method, + + ) as stream: + + async with trio.open_nursery() as n: + n.start_soon( + ensure_sequence, + stream, + sequence.copy(), + 0, + name='consumer_task', + ) + + await stream.send(tuple(sequence)) + + # pull 3 values, cancel the subtask, then + # expect to be able to pull all values still + for i in range(20): + try: + value = await stream.receive() + print(f'source stream rx: {value}') + await trio.sleep(0.01) + except Lagged: + print(f'parent overrun after {value}') + continue + + print('cancelling faster subtask') + n.cancel_scope.cancel() + + try: + value = await stream.receive() + print(f'source stream after cancel: {value}') + except Lagged: + print(f'parent overrun after {value}') + + # expect to see all remaining values + with trio.fail_after(0.5): + async for value in stream: + assert stream._broadcaster._state.recv_ready is None + print(f'source stream rx: {value}') + if value == 999: + # fully consumed and we missed no values once + # the faster subtask was cancelled + break + + # await tractor.breakpoint() + # await stream.receive() + print(f'final value: {value}') + + trio.run(main) + + +def test_subscribe_errors_after_close(): + + async def main(): + + size = 1 + tx, rx = trio.open_memory_channel(size) + async with broadcast_receiver(rx, size) as brx: + pass + + try: + # open and close + async with brx.subscribe(): + pass + + except trio.ClosedResourceError: + assert brx.key not in brx._state.subs + + else: + assert 0 + + trio.run(main) + + +def test_ensure_slow_consumers_lag_out( + arb_addr, + start_method, +): + '''This is a pure local task test; no tractor + machinery is really required. 
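+
+    Spawn several subscriber tasks with increasing per-iteration
+    delays: the slower ones should repeatedly be overrun, see
+    ``Lagged`` up to the retry limit and then bail out, leaving only
+    the fastest subscriber consuming until the send channel is closed.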
+ + ''' + async def main(): + + # make sure it all works within the runtime + async with tractor.open_root_actor(): + + num_laggers = 4 + laggers: dict[str, int] = {} + retries = 3 + size = 100 + tx, rx = trio.open_memory_channel(size) + brx = broadcast_receiver(rx, size) + + async def sub_and_print( + delay: float, + ) -> None: + + task = current_task() + start = time.time() + + async with brx.subscribe() as lbrx: + while True: + print(f'{task.name}: starting consume loop') + try: + async for value in lbrx: + print(f'{task.name}: {value}') + await trio.sleep(delay) + + if task.name == 'sub_1': + # the non-lagger got + # a ``trio.EndOfChannel`` + # because the ``tx`` below was closed + assert len(lbrx._state.subs) == 1 + + await lbrx.aclose() + + assert len(lbrx._state.subs) == 0 + + except trio.ClosedResourceError: + # only the fast sub will try to re-enter + # iteration on the now closed bcaster + assert task.name == 'sub_1' + return + + except Lagged: + lag_time = time.time() - start + lags = laggers[task.name] + print( + f'restarting slow task {task.name} ' + f'that bailed out on {lags}:{value} ' + f'after {lag_time:.3f}') + if lags <= retries: + laggers[task.name] += 1 + continue + else: + print( + f'{task.name} was too slow and terminated ' + f'on {lags}:{value}') + return + + async with trio.open_nursery() as nursery: + + for i in range(1, num_laggers): + + task_name = f'sub_{i}' + laggers[task_name] = 0 + nursery.start_soon( + partial( + sub_and_print, + delay=i*0.001, + ), + name=task_name, + ) + + # allow subs to sched + await trio.sleep(0.1) + + async with tx: + for i in cycle(range(size)): + await tx.send(i) + if len(brx._state.subs) == 2: + # only one, the non lagger, sub is left + break + + # the non-lagger + assert laggers.pop('sub_1') == 0 + + for n, v in laggers.items(): + assert v == 4 + + assert tx._closed + assert not tx._state.open_send_channels + + # check that "first" bcaster that we created + # above, never wass iterated and is thus overrun + try: + await brx.receive() + except Lagged: + # expect tokio style index truncation + seq = brx._state.subs[brx.key] + assert seq == len(brx._state.queue) - 1 + + # all backpressured entries in the underlying + # channel should have been copied into the caster + # queue trailing-window + async for i in rx: + print(f'bped: {i}') + assert i in brx._state.queue + + # should be noop + await brx.aclose() + + trio.run(main) + + +def test_first_recver_is_cancelled(): + + async def main(): + + # make sure it all works within the runtime + async with tractor.open_root_actor(): + + tx, rx = trio.open_memory_channel(1) + brx = broadcast_receiver(rx, 1) + cs = trio.CancelScope() + sequence = list(range(3)) + + async def sub_and_recv(): + with cs: + async with brx.subscribe() as bc: + async for value in bc: + print(value) + + async def cancel_and_send(): + await trio.sleep(0.2) + cs.cancel() + await tx.send(1) + + async with trio.open_nursery() as n: + + n.start_soon(sub_and_recv) + await trio.sleep(0.1) + assert brx._state.recv_ready + + n.start_soon(cancel_and_send) + + # ensure that we don't hang because no-task is now + # waiting on the underlying receive.. 
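+            # (the cancelled ``sub_and_recv`` task was the first waiter
+            # and had allocated the shared ``recv_ready`` event; the
+            # parent should wake with an un-incremented sequence and
+            # re-enter ``receive()`` as the new waiter on the underlying
+            # channel rather than block forever)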
+ with trio.fail_after(0.5): + value = await brx.receive() + print(f'parent: {value}') + assert value == 1 + + trio.run(main) diff --git a/tractor/_actor.py b/tractor/_actor.py index 0dbaede..f84a597 100644 --- a/tractor/_actor.py +++ b/tractor/_actor.py @@ -567,7 +567,7 @@ class Actor: try: send_chan, recv_chan = self._cids2qs[(actorid, cid)] except KeyError: - send_chan, recv_chan = trio.open_memory_channel(1000) + send_chan, recv_chan = trio.open_memory_channel(2**6) send_chan.cid = cid # type: ignore recv_chan.cid = cid # type: ignore self._cids2qs[(actorid, cid)] = send_chan, recv_chan diff --git a/tractor/_broadcast.py b/tractor/_broadcast.py new file mode 100644 index 0000000..51a9be8 --- /dev/null +++ b/tractor/_broadcast.py @@ -0,0 +1,315 @@ +''' +``tokio`` style broadcast channel. +https://docs.rs/tokio/1.11.0/tokio/sync/broadcast/index.html + +''' +from __future__ import annotations +from abc import abstractmethod +from collections import deque +from contextlib import asynccontextmanager +from dataclasses import dataclass +from functools import partial +from operator import ne +from typing import Optional, Callable, Awaitable, Any, AsyncIterator, Protocol +from typing import Generic, TypeVar + +import trio +from trio._core._run import Task +from trio.abc import ReceiveChannel +from trio.lowlevel import current_task + + +# A regular invariant generic type +T = TypeVar("T") + +# covariant because AsyncReceiver[Derived] can be passed to someone +# expecting AsyncReceiver[Base]) +ReceiveType = TypeVar("ReceiveType", covariant=True) + + +class AsyncReceiver( + Protocol, + Generic[ReceiveType], +): + '''An async receivable duck-type that quacks much like trio's + ``trio.abc.ReceieveChannel``. + + ''' + @abstractmethod + async def receive(self) -> ReceiveType: + ... + + @abstractmethod + def __aiter__(self) -> AsyncIterator[ReceiveType]: + ... + + @abstractmethod + async def __anext__(self) -> ReceiveType: + ... + + # ``trio.abc.AsyncResource`` methods + @abstractmethod + async def aclose(self): + ... + + @abstractmethod + async def __aenter__(self) -> AsyncReceiver[ReceiveType]: + ... + + @abstractmethod + async def __aexit__(self, *args) -> None: + ... + + +class Lagged(trio.TooSlowError): + '''Subscribed consumer task was too slow and was overrun + by the fastest consumer-producer pair. + + ''' + + +@dataclass +class BroadcastState: + '''Common state to all receivers of a broadcast. + + ''' + queue: deque + maxlen: int + + # map of underlying instance id keys to receiver instances which + # must be provided as a singleton per broadcaster set. + subs: dict[int, int] + + # broadcast event to wake up all sleeping consumer tasks + # on a newly produced value from the sender. + recv_ready: Optional[tuple[int, trio.Event]] = None + + +class BroadcastReceiver(ReceiveChannel): + '''A memory receive channel broadcaster which is non-lossy for the + fastest consumer. + + Additional consumer tasks can receive all produced values by registering + with ``.subscribe()`` and receiving from the new instance it delivers. 
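+
+    Internally each received value is copied into a shared
+    ``collections.deque`` window of ``maxlen`` entries and every
+    subscriber tracks its own index into that window; a subscriber
+    which falls more than ``maxlen`` values behind is overrun, sees a
+    ``Lagged`` error on its next receive and has its index reset to
+    the oldest retained value, following the ``tokio`` broadcast
+    channel semantics linked above.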
+ + ''' + def __init__( + self, + + rx_chan: AsyncReceiver, + state: BroadcastState, + receive_afunc: Optional[Callable[[], Awaitable[Any]]] = None, + + ) -> None: + + # register the original underlying (clone) + self.key = id(self) + self._state = state + state.subs[self.key] = -1 + + # underlying for this receiver + self._rx = rx_chan + self._recv = receive_afunc or rx_chan.receive + self._closed: bool = False + + async def receive(self) -> ReceiveType: + + key = self.key + state = self._state + + # TODO: ideally we can make some way to "lock out" the + # underlying receive channel in some way such that if some task + # tries to pull from it directly (i.e. one we're unaware of) + # then it errors out. + + # only tasks which have entered ``.subscribe()`` can + # receive on this broadcaster. + try: + seq = state.subs[key] + except KeyError: + if self._closed: + raise trio.ClosedResourceError + + raise RuntimeError( + f'{self} is not registerd as subscriber') + + # check that task does not already have a value it can receive + # immediately and/or that it has lagged. + if seq > -1: + # get the oldest value we haven't received immediately + try: + value = state.queue[seq] + except IndexError: + + # adhere to ``tokio`` style "lagging": + # "Once RecvError::Lagged is returned, the lagging + # receiver's position is updated to the oldest value + # contained by the channel. The next call to recv will + # return this value." + # https://docs.rs/tokio/1.11.0/tokio/sync/broadcast/index.html#lagging + + # decrement to the last value and expect + # consumer to either handle the ``Lagged`` and come back + # or bail out on its own (thus un-subscribing) + state.subs[key] = state.maxlen - 1 + + # this task was overrun by the producer side + task: Task = current_task() + raise Lagged(f'Task {task.name} was overrun') + + state.subs[key] -= 1 + return value + + # current task already has the latest value **and** is the + # first task to begin waiting for a new one + if state.recv_ready is None: + + if self._closed: + raise trio.ClosedResourceError + + event = trio.Event() + state.recv_ready = key, event + + # if we're cancelled here it should be + # fine to bail without affecting any other consumers + # right? + try: + value = await self._recv() + + # items with lower indices are "newer" + # NOTE: ``collections.deque`` implicitly takes care of + # trucating values outside our ``state.maxlen``. In the + # alt-backend-array-case we'll need to make sure this is + # implemented in similar ringer-buffer-ish style. + state.queue.appendleft(value) + + # broadcast new value to all subscribers by increasing + # all sequence numbers that will point in the queue to + # their latest available value. + + # don't decrement the sequence for this task since we + # already retreived the last value + + # XXX: which of these impls is fastest? + + # subs = state.subs.copy() + # subs.pop(key) + + for sub_key in filter( + # lambda k: k != key, state.subs, + partial(ne, key), state.subs, + ): + state.subs[sub_key] += 1 + + # NOTE: this should ONLY be set if the above task was *NOT* + # cancelled on the `._recv()` call. + event.set() + return value + + except trio.Cancelled: + # handle cancelled specially otherwise sibling + # consumers will be awoken with a sequence of -1 + # state.recv_ready = trio.Cancelled + if event.statistics().tasks_waiting: + event.set() + raise + + finally: + + # Reset receiver waiter task event for next blocking condition. 
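+                # (i.e. clear ``recv_ready`` so the next caller that
+                # finds no pending waiter takes over as the task doing
+                # the actual ``._recv()`` on the wrapped channel)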
+ # this MUST be reset even if the above ``.recv()`` call + # was cancelled to avoid the next consumer from blocking on + # an event that won't be set! + state.recv_ready = None + + # This task is all caught up and ready to receive the latest + # value, so queue sched it on the internal event. + else: + seq = state.subs[key] + assert seq == -1 # sanity + _, ev = state.recv_ready + await ev.wait() + + # NOTE: if we ever would like the behaviour where if the + # first task to recv on the underlying is cancelled but it + # still DOES trigger the ``.recv_ready``, event we'll likely need + # this logic: + + if seq > -1: + # stuff from above.. + seq = state.subs[key] + + value = state.queue[seq] + state.subs[key] -= 1 + return value + + elif seq == -1: + # XXX: In the case where the first task to allocate the + # ``.recv_ready`` event is cancelled we will be woken with + # a non-incremented sequence number and thus will read the + # oldest value if we use that. Instead we need to detect if + # we have not been incremented and then receive again. + return await self.receive() + + else: + raise ValueError(f'Invalid sequence {seq}!?') + + @asynccontextmanager + async def subscribe( + self, + ) -> AsyncIterator[BroadcastReceiver]: + '''Subscribe for values from this broadcast receiver. + + Returns a new ``BroadCastReceiver`` which is registered for and + pulls data from a clone of the original ``trio.abc.ReceiveChannel`` + provided at creation. + + ''' + if self._closed: + raise trio.ClosedResourceError + + state = self._state + br = BroadcastReceiver( + rx_chan=self._rx, + state=state, + receive_afunc=self._recv, + ) + # assert clone in state.subs + assert br.key in state.subs + + try: + yield br + finally: + await br.aclose() + + async def aclose( + self, + ) -> None: + + if self._closed: + return + + # XXX: leaving it like this consumers can still get values + # up to the last received that still reside in the queue. + self._state.subs.pop(self.key) + + self._closed = True + + +def broadcast_receiver( + + recv_chan: AsyncReceiver, + max_buffer_size: int, + **kwargs, + +) -> BroadcastReceiver: + + return BroadcastReceiver( + recv_chan, + state=BroadcastState( + queue=deque(maxlen=max_buffer_size), + maxlen=max_buffer_size, + subs={}, + ), + **kwargs, + ) diff --git a/tractor/_portal.py b/tractor/_portal.py index 44e8630..63c59ed 100644 --- a/tractor/_portal.py +++ b/tractor/_portal.py @@ -294,6 +294,7 @@ class Portal: async def open_stream_from( self, async_gen_func: Callable, # typing: ignore + shield: bool = False, **kwargs, ) -> AsyncGenerator[ReceiveMsgStream, None]: @@ -320,7 +321,9 @@ class Portal: ctx = Context(self.channel, cid, _portal=self) try: # deliver receive only stream - async with ReceiveMsgStream(ctx, recv_chan) as rchan: + async with ReceiveMsgStream( + ctx, recv_chan, shield=shield + ) as rchan: self._streams.add(rchan) yield rchan diff --git a/tractor/_streaming.py b/tractor/_streaming.py index 5f04554..9d832b2 100644 --- a/tractor/_streaming.py +++ b/tractor/_streaming.py @@ -2,12 +2,14 @@ Message stream types and APIs. 
""" +from __future__ import annotations import inspect from contextlib import contextmanager, asynccontextmanager from dataclasses import dataclass from typing import ( Any, Iterator, Optional, Callable, AsyncGenerator, Dict, + AsyncIterator ) import warnings @@ -17,6 +19,7 @@ import trio from ._ipc import Channel from ._exceptions import unpack_error, ContextCancelled from ._state import current_actor +from ._broadcast import broadcast_receiver, BroadcastReceiver from .log import get_logger @@ -47,11 +50,14 @@ class ReceiveMsgStream(trio.abc.ReceiveChannel): def __init__( self, ctx: 'Context', # typing: ignore # noqa - rx_chan: trio.abc.ReceiveChannel, + rx_chan: trio.MemoryReceiveChannel, + shield: bool = False, + _broadcaster: Optional[BroadcastReceiver] = None, ) -> None: self._ctx = ctx self._rx_chan = rx_chan + self._broadcaster = _broadcaster # flag to denote end of stream self._eoc: bool = False @@ -231,6 +237,50 @@ class ReceiveMsgStream(trio.abc.ReceiveChannel): # still need to consume msgs that are "in transit" from the far # end (eg. for ``Context.result()``). + @asynccontextmanager + async def subscribe( + self, + + ) -> AsyncIterator[BroadcastReceiver]: + '''Allocate and return a ``BroadcastReceiver`` which delegates + to this message stream. + + This allows multiple local tasks to receive each their own copy + of this message stream. + + This operation is indempotent and and mutates this stream's + receive machinery to copy and window-length-store each received + value from the far end via the internally created broudcast + receiver wrapper. + + ''' + # NOTE: This operation is indempotent and non-reversible, so be + # sure you can deal with any (theoretical) overhead of the the + # allocated ``BroadcastReceiver`` before calling this method for + # the first time. + if self._broadcaster is None: + + bcast = self._broadcaster = broadcast_receiver( + self, + # use memory channel size by default + self._rx_chan._state.max_buffer_size, # type: ignore + receive_afunc=self.receive, + ) + + # NOTE: we override the original stream instance's receive + # method to now delegate to the broadcaster's ``.receive()`` + # such that new subscribers will be copied received values + # and this stream doesn't have to expect it's original + # consumer(s) to get a new broadcast rx handle. + self.receive = bcast.receive # type: ignore + # seems there's no graceful way to type this with ``mypy``? + # https://github.com/python/mypy/issues/708 + + async with self._broadcaster.subscribe() as bstream: + assert bstream.key != self._broadcaster.key + assert bstream._recv == self._broadcaster._recv + yield bstream + class MsgStream(ReceiveMsgStream, trio.abc.Channel): """ @@ -247,17 +297,6 @@ class MsgStream(ReceiveMsgStream, trio.abc.Channel): ''' await self._ctx.chan.send({'yield': data, 'cid': self._ctx.cid}) - # TODO: but make it broadcasting to consumers - def clone(self): - """Clone this receive channel allowing for multi-task - consumption from the same channel. - - """ - return MsgStream( - self._ctx, - self._rx_chan.clone(), - ) - @dataclass class Context: