diff --git a/tractor/_broadcast.py b/tractor/_broadcast.py index bfd70ce..984aae9 100644 --- a/tractor/_broadcast.py +++ b/tractor/_broadcast.py @@ -4,23 +4,42 @@ https://tokio-rs.github.io/tokio/doc/tokio/sync/broadcast/index.html ''' from __future__ import annotations -from itertools import cycle from collections import deque from contextlib import asynccontextmanager +from dataclasses import dataclass from functools import partial +from itertools import cycle +from operator import ne from typing import Optional import trio -import tractor -from trio.lowlevel import current_task -from trio.abc import ReceiveChannel from trio._core._run import Task +from trio.abc import ReceiveChannel +from trio.lowlevel import current_task +import tractor class Lagged(trio.TooSlowError): '''Subscribed consumer task was too slow''' +@dataclass +class BroadcastState: + '''Common state to all receivers of a broadcast. + + ''' + queue: deque + + # map of underlying clones to receiver wrappers + # which must be provided as a singleton per broadcaster + # clone-subscription set. + subs: dict[trio.ReceiveChannel, BroadcastReceiver] + + # broadcast event to wakeup all sleeping consumer tasks + # on a newly produced value from the sender. + sender_ready: Optional[trio.Event] = None + + class BroadcastReceiver(ReceiveChannel): '''A memory receive channel broadcaster which is non-lossy for the fastest consumer. @@ -33,28 +52,21 @@ class BroadcastReceiver(ReceiveChannel): self, rx_chan: ReceiveChannel, - queue: deque, - _subs: dict[trio.ReceiveChannel, BroadcastReceiver], + state: BroadcastState, ) -> None: - # map of underlying clones to receiver wrappers - # which must be provided as a singleton per broadcaster - # clone-subscription set. - self._subs = _subs + # register the original underlying (clone) + self._state = state + state.subs[rx_chan] = -1 # underlying for this receiver self._rx = rx_chan - # register the original underlying (clone) - self._subs[rx_chan] = -1 - - self._queue = queue - self._value_received: Optional[trio.Event] = None - async def receive(self): key = self._rx + state = self._state # TODO: ideally we can make some way to "lock out" the # underlying receive channel in some way such that if some task @@ -64,7 +76,7 @@ class BroadcastReceiver(ReceiveChannel): # only tasks which have entered ``.subscribe()`` can # receive on this broadcaster. try: - seq = self._subs[key] + seq = state.subs[key] except KeyError: raise RuntimeError( f'{self} is not registerd as subscriber') @@ -74,7 +86,7 @@ class BroadcastReceiver(ReceiveChannel): if seq > -1: # get the oldest value we haven't received immediately try: - value = self._queue[seq] + value = state.queue[seq] except IndexError: # adhere to ``tokio`` style "lagging": @@ -87,51 +99,61 @@ class BroadcastReceiver(ReceiveChannel): # decrement to the last value and expect # consumer to either handle the ``Lagged`` and come back # or bail out on its own (thus un-subscribing) - self._subs[key] = self._queue.maxlen - 1 + state.subs[key] = state.queue.maxlen - 1 # this task was overrun by the producer side task: Task = current_task() raise Lagged(f'Task {task.name} was overrun') - self._subs[key] -= 1 + state.subs[key] -= 1 return value # current task already has the latest value **and** is the # first task to begin waiting for a new one - if self._value_received is None: + if state.sender_ready is None: - event = self._value_received = trio.Event() + event = state.sender_ready = trio.Event() value = await self._rx.receive() # items with lower indices are "newer" - self._queue.appendleft(value) + state.queue.appendleft(value) # broadcast new value to all subscribers by increasing # all sequence numbers that will point in the queue to # their latest available value. - subs = self._subs.copy() - # don't decrement the sequence # for this task since we + # don't decrement the sequence for this task since we # already retreived the last value - subs.pop(key) - for sub_key, seq in subs.items(): - self._subs[sub_key] += 1 + + # XXX: which of these impls is fastest? + + # subs = state.subs.copy() + # subs.pop(key) + + for sub_key in filter( + # lambda k: k != key, state.subs, + partial(ne, key), state.subs, + ): + state.subs[sub_key] += 1 # reset receiver waiter task event for next blocking condition - self._value_received = None event.set() + state.sender_ready = None return value # This task is all caught up and ready to receive the latest # value, so queue sched it on the internal event. else: - await self._value_received.wait() + await state.sender_ready.wait() - seq = self._subs[key] - assert seq > -1, 'Internal error?' + # TODO: optimization: if this is always true can't we just + # skip iterating these sequence numbers on the fastest + # task's wakeup and always read from state.queue[0]? + seq = state.subs[key] + assert seq == 0, 'Internal error?' - self._subs[key] -= 1 - return self._queue[0] + state.subs[key] -= 1 + return state.queue[seq] @asynccontextmanager async def subscribe( @@ -145,12 +167,12 @@ class BroadcastReceiver(ReceiveChannel): ''' clone = self._rx.clone() + state = self._state br = BroadcastReceiver( - clone, - self._queue, - _subs=self._subs, + rx_chan=clone, + state=state, ) - assert clone in self._subs + assert clone in state.subs try: yield br @@ -159,7 +181,7 @@ class BroadcastReceiver(ReceiveChannel): # ``AsyncResource`` api. await clone.aclose() # drop from subscribers and close - self._subs.pop(clone) + state.subs.pop(clone) # TODO: # - should there be some ._closed flag that causes @@ -186,8 +208,10 @@ def broadcast_receiver( return BroadcastReceiver( recv_chan, - queue=deque(maxlen=max_buffer_size), - _subs={}, # this is singleton over all subscriptions + state=BroadcastState( + queue=deque(maxlen=max_buffer_size), + subs={}, + ), ) @@ -210,7 +234,7 @@ if __name__ == '__main__': ) -> None: task = current_task() - count = 0 + lags = 0 while True: async with rx.subscribe() as brx: @@ -218,22 +242,22 @@ if __name__ == '__main__': async for value in brx: print(f'{task.name}: {value}') await trio.sleep(delay) - count += 1 except Lagged: print( f'restarting slow ass {task.name}' - f'that bailed out on {count}:{value}') - if count <= retries: + f'that bailed out on {lags}:{value}') + if lags <= retries: + lags += 1 continue else: print( f'{task.name} was too slow and terminated ' - f'on {count}:{value}') + f'on {lags}:{value}') return async with trio.open_nursery() as n: - for i in range(1, size): + for i in range(1, 10): n.start_soon( partial( sub_and_print,