Add common state delegate type for all consumers

For every set of broadcast receivers which pull from the same producer we
need a singleton state shared by all of them:

- the subscriptions
- the sender ready event
- the queue

Add a `BroadcastState` dataclass for this and pass it to all
subscriptions. This makes the design much more like the built-in memory
channels, which do something very similar with `MemoryChannelState`.

Use a `filter()` on the subs list in the sequence update step, plus keep
some other commented approaches around that we can try for speed (a rough
comparison sketch follows below).
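As a rough way to sanity check which of those approaches is actually faster,
something like the following stand-alone micro-benchmark could be used. This
is only a sketch and not part of this commit; `key` and `subs` here are
made-up stand-ins for the real clone -> sequence-number map:

import timeit
from functools import partial
from operator import ne

key = 3
subs = {i: 0 for i in range(100)}  # fake clone -> sequence number map

def with_copy_and_pop():
    # copy the map and drop the current task's key before iterating
    local = subs.copy()
    local.pop(key)
    for k in local:
        subs[k] += 1

def with_lambda_filter():
    # filter() with a lambda predicate
    for k in filter(lambda k: k != key, subs):
        subs[k] += 1

def with_partial_ne():
    # filter() with operator.ne partially applied, as used in the diff below
    for k in filter(partial(ne, key), subs):
        subs[k] += 1

for fn in (with_copy_and_pop, with_lambda_filter, with_partial_ne):
    print(fn.__name__, timeit.timeit(fn, number=10_000))

For a handful of subscribers the differences are likely negligible; the
`filter()` + `partial(ne, key)` form mostly wins by not needing a copy of the
subs map for every received value.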
branch: tokio_backup
commit b9863fc4ab (parent 9d12cc80dd)
@@ -4,23 +4,42 @@ https://tokio-rs.github.io/tokio/doc/tokio/sync/broadcast/index.html

 '''
 from __future__ import annotations
-from itertools import cycle
 from collections import deque
 from contextlib import asynccontextmanager
+from dataclasses import dataclass
+from functools import partial
+from itertools import cycle
+from operator import ne
 from typing import Optional

 import trio
-import tractor
-from trio.lowlevel import current_task
-from trio.abc import ReceiveChannel
 from trio._core._run import Task
+from trio.abc import ReceiveChannel
+from trio.lowlevel import current_task
+import tractor


 class Lagged(trio.TooSlowError):
     '''Subscribed consumer task was too slow'''


+@dataclass
+class BroadcastState:
+    '''Common state to all receivers of a broadcast.
+
+    '''
+    queue: deque
+
+    # map of underlying clones to receiver wrappers
+    # which must be provided as a singleton per broadcaster
+    # clone-subscription set.
+    subs: dict[trio.ReceiveChannel, BroadcastReceiver]
+
+    # broadcast event to wakeup all sleeping consumer tasks
+    # on a newly produced value from the sender.
+    sender_ready: Optional[trio.Event] = None
+
+
 class BroadcastReceiver(ReceiveChannel):
     '''A memory receive channel broadcaster which is non-lossy for the
     fastest consumer.
@@ -33,28 +52,21 @@ class BroadcastReceiver(ReceiveChannel):
         self,

         rx_chan: ReceiveChannel,
-        queue: deque,
-        _subs: dict[trio.ReceiveChannel, BroadcastReceiver],
+        state: BroadcastState,

     ) -> None:

-        # map of underlying clones to receiver wrappers
-        # which must be provided as a singleton per broadcaster
-        # clone-subscription set.
-        self._subs = _subs
+        # register the original underlying (clone)
+        self._state = state
+        state.subs[rx_chan] = -1

         # underlying for this receiver
         self._rx = rx_chan

-        # register the original underlying (clone)
-        self._subs[rx_chan] = -1
-
-        self._queue = queue
-        self._value_received: Optional[trio.Event] = None
-
     async def receive(self):

         key = self._rx
+        state = self._state

         # TODO: ideally we can make some way to "lock out" the
         # underlying receive channel in some way such that if some task
@@ -64,7 +76,7 @@ class BroadcastReceiver(ReceiveChannel):
         # only tasks which have entered ``.subscribe()`` can
         # receive on this broadcaster.
         try:
-            seq = self._subs[key]
+            seq = state.subs[key]
         except KeyError:
             raise RuntimeError(
                 f'{self} is not registerd as subscriber')
@@ -74,7 +86,7 @@ class BroadcastReceiver(ReceiveChannel):
         if seq > -1:
             # get the oldest value we haven't received immediately
             try:
-                value = self._queue[seq]
+                value = state.queue[seq]
             except IndexError:

                 # adhere to ``tokio`` style "lagging":
@@ -87,51 +99,61 @@ class BroadcastReceiver(ReceiveChannel):
                 # decrement to the last value and expect
                 # consumer to either handle the ``Lagged`` and come back
                 # or bail out on its own (thus un-subscribing)
-                self._subs[key] = self._queue.maxlen - 1
+                state.subs[key] = state.queue.maxlen - 1

                 # this task was overrun by the producer side
                 task: Task = current_task()
                 raise Lagged(f'Task {task.name} was overrun')

-            self._subs[key] -= 1
+            state.subs[key] -= 1
             return value

         # current task already has the latest value **and** is the
         # first task to begin waiting for a new one
-        if self._value_received is None:
+        if state.sender_ready is None:

-            event = self._value_received = trio.Event()
+            event = state.sender_ready = trio.Event()
             value = await self._rx.receive()

             # items with lower indices are "newer"
-            self._queue.appendleft(value)
+            state.queue.appendleft(value)

             # broadcast new value to all subscribers by increasing
             # all sequence numbers that will point in the queue to
             # their latest available value.

-            subs = self._subs.copy()
-            # don't decrement the sequence # for this task since we
+            # don't decrement the sequence for this task since we
             # already retreived the last value
-            subs.pop(key)
-            for sub_key, seq in subs.items():
-                self._subs[sub_key] += 1
+
+            # XXX: which of these impls is fastest?
+
+            # subs = state.subs.copy()
+            # subs.pop(key)
+
+            for sub_key in filter(
+                # lambda k: k != key, state.subs,
+                partial(ne, key), state.subs,
+            ):
+                state.subs[sub_key] += 1

             # reset receiver waiter task event for next blocking condition
-            self._value_received = None
             event.set()
+            state.sender_ready = None
             return value

         # This task is all caught up and ready to receive the latest
         # value, so queue sched it on the internal event.
         else:
-            await self._value_received.wait()
+            await state.sender_ready.wait()

-            seq = self._subs[key]
-            assert seq > -1, 'Internal error?'
+            # TODO: optimization: if this is always true can't we just
+            # skip iterating these sequence numbers on the fastest
+            # task's wakeup and always read from state.queue[0]?
+            seq = state.subs[key]
+            assert seq == 0, 'Internal error?'

-            self._subs[key] -= 1
-            return self._queue[0]
+            state.subs[key] -= 1
+            return state.queue[seq]

     @asynccontextmanager
     async def subscribe(
@@ -145,12 +167,12 @@ class BroadcastReceiver(ReceiveChannel):

         '''
         clone = self._rx.clone()
+        state = self._state
         br = BroadcastReceiver(
-            clone,
-            self._queue,
-            _subs=self._subs,
+            rx_chan=clone,
+            state=state,
         )
-        assert clone in self._subs
+        assert clone in state.subs

         try:
             yield br
@@ -159,7 +181,7 @@ class BroadcastReceiver(ReceiveChannel):
             # ``AsyncResource`` api.
             await clone.aclose()
             # drop from subscribers and close
-            self._subs.pop(clone)
+            state.subs.pop(clone)

     # TODO:
     # - should there be some ._closed flag that causes
@@ -186,8 +208,10 @@ def broadcast_receiver(

     return BroadcastReceiver(
         recv_chan,
-        queue=deque(maxlen=max_buffer_size),
-        _subs={},  # this is singleton over all subscriptions
+        state=BroadcastState(
+            queue=deque(maxlen=max_buffer_size),
+            subs={},
+        ),
     )


@@ -210,7 +234,7 @@ if __name__ == '__main__':
             ) -> None:

                 task = current_task()
                 count = 0
+                lags = 0

                 while True:
                     async with rx.subscribe() as brx:
@@ -218,22 +242,22 @@ if __name__ == '__main__':
                             async for value in brx:
                                 print(f'{task.name}: {value}')
                                 await trio.sleep(delay)
                                 count += 1

                         except Lagged:
                             print(
                                 f'restarting slow ass {task.name}'
-                                f'that bailed out on {count}:{value}')
-                            if count <= retries:
+                                f'that bailed out on {lags}:{value}')
+                            if lags <= retries:
+                                lags += 1
                                 continue
                             else:
                                 print(
                                     f'{task.name} was too slow and terminated '
-                                    f'on {count}:{value}')
+                                    f'on {lags}:{value}')
                                 return

             async with trio.open_nursery() as n:
-                for i in range(1, size):
+                for i in range(1, 10):
                     n.start_soon(
                         partial(
                             sub_and_print,