diff --git a/tractor/_live_from_tokio.py b/tractor/_live_from_tokio.py index 88e3652..5aab368 100644 --- a/tractor/_live_from_tokio.py +++ b/tractor/_live_from_tokio.py @@ -1,25 +1,22 @@ ''' -``tokio`` style broadcast channels. +``tokio`` style broadcast channel. +https://tokio-rs.github.io/tokio/doc/tokio/sync/broadcast/index.html ''' from __future__ import annotations -# from math import inf from itertools import cycle from collections import deque -from contextlib import contextmanager # , asynccontextmanager +from contextlib import contextmanager from functools import partial from typing import Optional import trio import tractor from trio.lowlevel import current_task -from trio.abc import ReceiveChannel # , SendChannel -# from trio._core import enable_ki_protection +from trio.abc import ReceiveChannel from trio._core._run import Task from trio._channel import ( - MemorySendChannel, MemoryReceiveChannel, - # MemoryChannelState, ) @@ -28,20 +25,25 @@ class Lagged(trio.TooSlowError): class BroadcastReceiver(ReceiveChannel): - '''This isn't Paris, not Berlin, nor Honk Kong.. + '''A memory receive channel broadcaster which is non-lossy for the + fastest consumer. + + Additional consumer tasks can receive all produced values by registering + with ``.subscribe()``. ''' def __init__( self, + rx_chan: MemoryReceiveChannel, - buffer_size: int = 100, + queue: deque, ) -> None: self._rx = rx_chan - self._len = buffer_size - self._queue = deque(maxlen=buffer_size) - self._subs = {id(current_task()): -1} + self._queue = queue + self._subs: dict[Task, int] = {} # {id(current_task()): -1} + self._clones: dict[Task, MemoryReceiveChannel] = {} self._value_received: Optional[trio.Event] = None async def receive(self): @@ -56,26 +58,30 @@ class BroadcastReceiver(ReceiveChannel): try: seq = self._subs[key] except KeyError: - self._subs.pop(key) raise RuntimeError( f'Task {task.name} is not registerd as subscriber') if seq > -1: # get the oldest value we haven't received immediately - try: value = self._queue[seq] except IndexError: + # decrement to the last value and expect + # consumer to either handle the ``Lagged`` and come back + # or bail out on it's own (thus un-subscribing) + self._subs[key] = self._queue.maxlen - 1 + + # this task was overrun by the producer side raise Lagged(f'Task {task.name} was overrun') self._subs[key] -= 1 return value if self._value_received is None: - # we already have the latest value **and** are the first - # task to begin waiting for a new one + # current task already has the latest value **and** is the + # first task to begin waiting for a new one - # sanity checks with underlying chan ? + # what sanity checks might we use for the underlying chan ? # assert not self._rx._state.data event = self._value_received = trio.Event() @@ -87,20 +93,15 @@ class BroadcastReceiver(ReceiveChannel): # broadcast new value to all subscribers by increasing # all sequence numbers that will point in the queue to # their latest available value. - for sub_key, seq in self._subs.items(): - - if key == sub_key: - # we don't need to increase **this** task's - # sequence number since we just consumed the latest - # value - continue - - # # except TypeError: - # # # already lagged - # # seq = Lagged + subs = self._subs.copy() + # don't decerement the sequence # for this task since we + # already retreived the last value + subs.pop(key) + for sub_key, seq in subs.items(): self._subs[sub_key] += 1 + # reset receiver waiter task event for next blocking condition self._value_received = None event.set() return value @@ -109,7 +110,7 @@ class BroadcastReceiver(ReceiveChannel): await self._value_received.wait() seq = self._subs[key] - assert seq > -1, 'Uhhhh' + assert seq > -1, 'Internal error?' self._subs[key] -= 1 return self._queue[0] @@ -118,30 +119,37 @@ class BroadcastReceiver(ReceiveChannel): @contextmanager def subscribe( self, - ) -> BroadcastReceiver: key = id(current_task()) self._subs[key] = -1 + # XXX: we only use this clone for closure tracking + clone = self._clones[key] = self._rx.clone() try: yield self finally: self._subs.pop(key) + clone.close() + # TODO: do we need anything here? + # if we're the last sub to close then close + # the underlying rx channel, but couldn't we just + # use ``.clone()``s trackign then? async def aclose(self) -> None: - # TODO: wtf should we do here? - # if we're the last sub to close then close - # the underlying rx channel - pass + key = id(current_task()) + await self._clones[key].aclose() -def broadcast_channel( +def broadcast_receiver( + recv_chan: MemoryReceiveChannel, max_buffer_size: int, -) -> (MemorySendChannel, BroadcastReceiver): +) -> BroadcastReceiver: - tx, rx = trio.open_memory_channel(max_buffer_size) - return tx, BroadcastReceiver(rx) + return BroadcastReceiver( + recv_chan, + queue=deque(maxlen=max_buffer_size), + ) if __name__ == '__main__': @@ -153,7 +161,9 @@ if __name__ == '__main__': # loglevel='info', ): - tx, rx = broadcast_channel(100) + size = 100 + tx, rx = trio.open_memory_channel(size) + rx = broadcast_receiver(rx, size) async def sub_and_print( delay: float,