diff --git a/tractor/_broadcast.py b/tractor/_broadcast.py index bba8021..2b326f9 100644 --- a/tractor/_broadcast.py +++ b/tractor/_broadcast.py @@ -9,16 +9,15 @@ from collections import deque from contextlib import asynccontextmanager from dataclasses import dataclass from functools import partial -from itertools import cycle from operator import ne from typing import Optional, Callable, Awaitable, Any, AsyncIterator, Protocol from typing import Generic, TypeVar +from uuid import uuid4 import trio from trio._core._run import Task from trio.abc import ReceiveChannel from trio.lowlevel import current_task -import tractor # A regular invariant generic type @@ -30,14 +29,10 @@ T = TypeVar("T") ReceiveType = TypeVar("ReceiveType", covariant=True) -class CloneableReceiveChannel( +class AsyncReceiver( Protocol, Generic[ReceiveType], ): - @abstractmethod - def clone(self) -> CloneableReceiveChannel[ReceiveType]: - '''Clone this receiver usually by making a copy.''' - @abstractmethod async def receive(self) -> ReceiveType: '''Same as in ``trio``.''' @@ -56,7 +51,7 @@ class CloneableReceiveChannel( ... @abstractmethod - async def __aenter__(self) -> CloneableReceiveChannel[ReceiveType]: + async def __aenter__(self) -> AsyncReceiver[ReceiveType]: ... @abstractmethod @@ -75,14 +70,13 @@ class BroadcastState: ''' queue: deque - # map of underlying clones to receiver wrappers - # which must be provided as a singleton per broadcaster - # clone-subscription set. - subs: dict[CloneableReceiveChannel, int] + # map of underlying uuid keys to receiver instances which must be + # provided as a singleton per broadcaster set. + subs: dict[str, int] # broadcast event to wakeup all sleeping consumer tasks # on a newly produced value from the sender. - sender_ready: Optional[trio.Event] = None + recv_ready: Optional[tuple[str, trio.Event]] = None class BroadcastReceiver(ReceiveChannel): @@ -96,23 +90,26 @@ class BroadcastReceiver(ReceiveChannel): def __init__( self, - rx_chan: CloneableReceiveChannel, + key: str, + rx_chan: AsyncReceiver, state: BroadcastState, receive_afunc: Optional[Callable[[], Awaitable[Any]]] = None, ) -> None: # register the original underlying (clone) + self.key = key self._state = state - state.subs[rx_chan] = -1 + state.subs[key] = -1 # underlying for this receiver self._rx = rx_chan self._recv = receive_afunc or rx_chan.receive + self._closed: bool = False async def receive(self): - key = self._rx + key = self.key state = self._state # TODO: ideally we can make some way to "lock out" the @@ -125,6 +122,9 @@ class BroadcastReceiver(ReceiveChannel): try: seq = state.subs[key] except KeyError: + if self._closed: + raise trio.ClosedResourceError + raise RuntimeError( f'{self} is not registerd as subscriber') @@ -157,41 +157,50 @@ class BroadcastReceiver(ReceiveChannel): # current task already has the latest value **and** is the # first task to begin waiting for a new one - if state.sender_ready is None: + if state.recv_ready is None: - event = state.sender_ready = trio.Event() - value = await self._recv() + if self._closed: + raise trio.ClosedResourceError - # items with lower indices are "newer" - state.queue.appendleft(value) + event = trio.Event() + state.recv_ready = key, event - # broadcast new value to all subscribers by increasing - # all sequence numbers that will point in the queue to - # their latest available value. + try: + value = await self._recv() - # don't decrement the sequence for this task since we - # already retreived the last value + # items with lower indices are "newer" + state.queue.appendleft(value) - # XXX: which of these impls is fastest? + # broadcast new value to all subscribers by increasing + # all sequence numbers that will point in the queue to + # their latest available value. - # subs = state.subs.copy() - # subs.pop(key) + # don't decrement the sequence for this task since we + # already retreived the last value - for sub_key in filter( - # lambda k: k != key, state.subs, - partial(ne, key), state.subs, - ): - state.subs[sub_key] += 1 + # XXX: which of these impls is fastest? - # reset receiver waiter task event for next blocking condition - event.set() - state.sender_ready = None - return value + # subs = state.subs.copy() + # subs.pop(key) + + for sub_key in filter( + # lambda k: k != key, state.subs, + partial(ne, key), state.subs, + ): + state.subs[sub_key] += 1 + + return value + + finally: + # reset receiver waiter task event for next blocking condition + event.set() + state.recv_ready = None # This task is all caught up and ready to receive the latest # value, so queue sched it on the internal event. else: - await state.sender_ready.wait() + _, ev = state.recv_ready + await ev.wait() seq = state.subs[key] state.subs[key] -= 1 return state.queue[seq] @@ -207,24 +216,22 @@ class BroadcastReceiver(ReceiveChannel): provided at creation. ''' - # if we didn't want to enforce "clone-ability" how would - # we key arbitrary subscriptions? Use a token system? - clone = self._rx.clone() + # use a uuid4 for a tee-instance token + key = str(uuid4()) state = self._state br = BroadcastReceiver( - rx_chan=clone, + key=key, + rx_chan=self._rx, state=state, + receive_afunc=self._recv, ) - assert clone in state.subs + # assert clone in state.subs + assert key in state.subs try: yield br finally: - # XXX: this is the reason this function is async: the - # ``AsyncResource`` api. - await clone.aclose() - # drop from subscribers and close - state.subs.pop(clone) + await br.aclose() # TODO: # - should there be some ._closed flag that causes @@ -235,22 +242,30 @@ class BroadcastReceiver(ReceiveChannel): async def aclose( self, ) -> None: + + if self._closed: + return + # XXX: leaving it like this consumers can still get values # up to the last received that still reside in the queue. # Is this what we want? - await self._rx.aclose() - self._state.subs.pop(self._rx) + self._state.subs.pop(self.key) + # if not self._state.subs: + # await self._rx.aclose() + + self._closed = True def broadcast_receiver( - recv_chan: CloneableReceiveChannel, + recv_chan: AsyncReceiver, max_buffer_size: int, **kwargs, ) -> BroadcastReceiver: return BroadcastReceiver( + str(uuid4()), recv_chan, state=BroadcastState( queue=deque(maxlen=max_buffer_size), @@ -258,62 +273,3 @@ def broadcast_receiver( ), **kwargs, ) - - -if __name__ == '__main__': - - async def main(): - - async with tractor.open_root_actor( - debug_mode=True, - # loglevel='info', - ): - - retries = 3 - size = 100 - tx, rx = trio.open_memory_channel(size) - rx = broadcast_receiver(rx, size) - - async def sub_and_print( - delay: float, - ) -> None: - - task = current_task() - lags = 0 - - while True: - async with rx.subscribe() as brx: - try: - async for value in brx: - print(f'{task.name}: {value}') - await trio.sleep(delay) - - except Lagged: - print( - f'restarting slow ass {task.name}' - f'that bailed out on {lags}:{value}') - if lags <= retries: - lags += 1 - continue - else: - print( - f'{task.name} was too slow and terminated ' - f'on {lags}:{value}') - return - - async with trio.open_nursery() as n: - for i in range(1, 10): - n.start_soon( - partial( - sub_and_print, - delay=i*0.01, - ), - name=f'sub_{i}', - ) - - async with tx: - for i in cycle(range(size)): - print(f'sending: {i}') - await tx.send(i) - - trio.run(main)