'''
``tokio`` style broadcast channel.
https://docs.rs/tokio/1.11.0/tokio/sync/broadcast/index.html

'''
from __future__ import annotations
from abc import abstractmethod
from collections import deque
from contextlib import asynccontextmanager
from dataclasses import dataclass
from functools import partial
from operator import ne
from typing import Optional, Callable, Awaitable, Any, AsyncIterator, Protocol
from typing import Generic, TypeVar

import trio
from trio._core._run import Task
from trio.abc import ReceiveChannel
from trio.lowlevel import current_task


# A regular invariant generic type
T = TypeVar("T")

# covariant because AsyncReceiver[Derived] can be passed to someone
# expecting AsyncReceiver[Base]
ReceiveType = TypeVar("ReceiveType", covariant=True)


class AsyncReceiver(
    Protocol,
    Generic[ReceiveType],
):
    '''An async receivable duck-type that quacks much like trio's
    ``trio.abc.ReceiveChannel``.

    '''
    @abstractmethod
    async def receive(self) -> ReceiveType:
        ...

    @abstractmethod
    def __aiter__(self) -> AsyncIterator[ReceiveType]:
        ...

    @abstractmethod
    async def __anext__(self) -> ReceiveType:
        ...

    # ``trio.abc.AsyncResource`` methods
    @abstractmethod
    async def aclose(self):
        ...

    @abstractmethod
    async def __aenter__(self) -> AsyncReceiver[ReceiveType]:
        ...

    @abstractmethod
    async def __aexit__(self, *args) -> None:
        ...


class Lagged(trio.TooSlowError):
    '''Subscribed consumer task was too slow and was overrun
    by the fastest consumer-producer pair.

    '''


@dataclass
class BroadcastState:
    '''Common state shared by all receivers of a broadcast.

    '''
    queue: deque
    maxlen: int

    # map of receiver key (the instance's ``id()``) to its current
    # sequence number: the index into ``queue`` of the next value that
    # receiver should read, where -1 means it is fully caught up.
    subs: dict[int, int]

    # broadcast event used to wake up all sleeping consumer tasks
    # on a newly produced value from the sender.
    recv_ready: Optional[tuple[int, trio.Event]] = None


class BroadcastReceiver(ReceiveChannel):
    '''A memory receive channel broadcaster which is non-lossy for the
    fastest consumer.

    Additional consumer tasks can receive all produced values by
    registering with ``.subscribe()`` and receiving from the new
    instance it delivers.

    '''
    def __init__(
        self,
        rx_chan: AsyncReceiver,
        state: BroadcastState,
        receive_afunc: Optional[Callable[[], Awaitable[Any]]] = None,
    ) -> None:

        # register this receiver (the original or a ``.subscribe()``
        # clone) by its instance ``id()``
        self.key = id(self)
        self._state = state
        state.subs[self.key] = -1

        # underlying channel for this receiver
        self._rx = rx_chan
        self._recv = receive_afunc or rx_chan.receive
        self._closed: bool = False

    async def receive(self) -> ReceiveType:

        key = self.key
        state = self._state

        # TODO: ideally we can make some way to "lock out" the
        # underlying receive channel in some way such that if some task
        # tries to pull from it directly (i.e. one we're unaware of)
        # then it errors out.

        # only tasks which have entered ``.subscribe()`` can
        # receive on this broadcaster.
        try:
            seq = state.subs[key]
        except KeyError:
            if self._closed:
                raise trio.ClosedResourceError

            raise RuntimeError(
                f'{self} is not registered as a subscriber')

        # check whether this task already has a value it can receive
        # immediately and/or whether it has lagged.
        if seq > -1:
            # get the oldest value we haven't yet received
            try:
                value = state.queue[seq]
            except IndexError:

                # adhere to ``tokio`` style "lagging":

                # "Once RecvError::Lagged is returned, the lagging
                # receiver's position is updated to the oldest value
                # contained by the channel. The next call to recv will
                # return this value."
                # https://docs.rs/tokio/1.11.0/tokio/sync/broadcast/index.html#lagging
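
                # e.g. with ``maxlen == 3``, a consumer whose sequence
                # has grown past index 2 (more unread values than the
                # queue retains) is reset to sequence 2, so the
                # ``.receive()`` call made after handling the ``Lagged``
                # error returns ``queue[2]``, the oldest value still
                # buffered.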

                # decrement to the last value and expect the
                # consumer to either handle the ``Lagged`` and come back
                # or bail out on its own (thus un-subscribing)
                state.subs[key] = state.maxlen - 1

                # this task was overrun by the producer side
                task: Task = current_task()
                raise Lagged(f'Task {task.name} was overrun')

            state.subs[key] -= 1
            return value

        # current task already has the latest value **and** is the
        # first task to begin waiting for a new one
        if state.recv_ready is None:

            if self._closed:
                raise trio.ClosedResourceError

            event = trio.Event()
            state.recv_ready = key, event

            # if we're cancelled here it should be
            # fine to bail without affecting any other consumers
            # right?
            try:
                value = await self._recv()

                # items with lower indices are "newer"
                # NOTE: ``collections.deque`` implicitly takes care of
                # truncating values outside our ``state.maxlen``. In the
                # alt-backend-array-case we'll need to make sure this is
                # implemented in a similar ring-buffer-ish style.
                state.queue.appendleft(value)

                # broadcast the new value to all subscribers by
                # incrementing all sequence numbers so they point in the
                # queue to their latest available value.

                # don't bump the sequence for this task since it
                # already retrieved the latest value

                # XXX: which of these impls is fastest?
                # subs = state.subs.copy()
                # subs.pop(key)

                for sub_key in filter(
                    # lambda k: k != key, state.subs,
                    partial(ne, key), state.subs,
                ):
                    state.subs[sub_key] += 1

                # NOTE: this should ONLY be set if the above task was
                # *NOT* cancelled on the ``._recv()`` call.
                event.set()
                return value

            except trio.Cancelled:
                # handle cancellation specially otherwise sibling
                # consumers will be awoken with a sequence of -1
                # state.recv_ready = trio.Cancelled
                if event.statistics().tasks_waiting:
                    event.set()
                raise

            finally:

                # Reset the receiver-waiter task event for the next
                # blocking condition. This MUST be reset even if the
                # above ``._recv()`` call was cancelled to avoid the
                # next consumer blocking on an event that won't be set!
                state.recv_ready = None

        # This task is all caught up and ready to receive the latest
        # value, so block on the internal event until a new one is
        # produced.
        else:
            seq = state.subs[key]
            assert seq == -1  # sanity
            _, ev = state.recv_ready
            await ev.wait()

            # NOTE: if we ever would like the behaviour where, if the
            # first task to recv on the underlying is cancelled but it
            # still DOES trigger the ``.recv_ready`` event, we'll likely
            # need this logic:

            if seq > -1:
                # stuff from above..
                seq = state.subs[key]

                value = state.queue[seq]
                state.subs[key] -= 1
                return value

            elif seq == -1:
                # XXX: In the case where the first task to allocate the
                # ``.recv_ready`` event is cancelled we will be woken
                # with a non-incremented sequence number and thus would
                # read the oldest value if we used that. Instead we need
                # to detect that we have not been incremented and then
                # receive again.
                return await self.receive()

            else:
                raise ValueError(f'Invalid sequence {seq}!?')

    @asynccontextmanager
    async def subscribe(
        self,
    ) -> AsyncIterator[BroadcastReceiver]:
        '''Subscribe for values from this broadcast receiver.

        Returns a new ``BroadcastReceiver`` which is registered for and
        pulls data from a clone of the original
        ``trio.abc.ReceiveChannel`` provided at creation.
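
        A rough usage sketch (the receiver names and ``process()`` call
        are illustrative only)::

            async with bcast_rx.subscribe() as sub_rx:
                async for value in sub_rx:
                    process(value)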

        '''
        if self._closed:
            raise trio.ClosedResourceError

        state = self._state
        br = BroadcastReceiver(
            rx_chan=self._rx,
            state=state,
            receive_afunc=self._recv,
        )
        # assert clone in state.subs
        assert br.key in state.subs

        try:
            yield br
        finally:
            await br.aclose()

    async def aclose(
        self,
    ) -> None:

        if self._closed:
            return

        # XXX: leaving it like this, consumers can still get values
        # up to the last one received which still reside in the queue.
        self._state.subs.pop(self.key)

        self._closed = True


def broadcast_receiver(

    recv_chan: AsyncReceiver,
    max_buffer_size: int,
    **kwargs,

) -> BroadcastReceiver:

    return BroadcastReceiver(
        recv_chan,
        state=BroadcastState(
            queue=deque(maxlen=max_buffer_size),
            maxlen=max_buffer_size,
            subs={},
        ),
        **kwargs,
    )
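

# ---------------------------------------------------------------------
# A minimal, self-contained usage sketch (illustrative only, not part of
# the broadcaster API above): a single producer pushes values into a
# ``trio.open_memory_channel()`` pair whose receive side is wrapped by
# ``broadcast_receiver()``, and two subscriber tasks each observe every
# produced value. All ``_demo_*`` names are hypothetical.
# ---------------------------------------------------------------------
if __name__ == '__main__':

    async def _demo_main() -> None:

        tx, rx = trio.open_memory_channel(1)

        # fan-out wrapper around the single underlying receive channel
        bcaster = broadcast_receiver(rx, max_buffer_size=10)

        n_values = 5

        async def _demo_producer() -> None:
            # send a fixed number of values then close the send side
            async with tx:
                for i in range(n_values):
                    await tx.send(i)

        async def _demo_consumer(
            name: str,
            sub: BroadcastReceiver,
        ) -> None:
            # pull a fixed number of values so no task ends up blocked
            # on the underlying channel after the producer closes it
            for _ in range(n_values):
                value = await sub.receive()
                print(f'{name} received {value}')

        # subscribe both consumers *before* any values are produced so
        # each one is guaranteed to see the full sequence
        async with bcaster.subscribe() as sub_a, \
                bcaster.subscribe() as sub_b:

            async with trio.open_nursery() as nursery:
                nursery.start_soon(_demo_consumer, 'sub-a', sub_a)
                nursery.start_soon(_demo_consumer, 'sub-b', sub_b)
                nursery.start_soon(_demo_producer)

    trio.run(_demo_main)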