Add common state delegate type for all consumers

For every set of broadcast receivers which pull from the same producer,
we need a singleton state for all of,
- subscriptions
- the sender ready event
- the queue

Add a `BroadcastState` dataclass for this and pass it to all
subscriptions. This makes the design much more like the built-in memory
channels which do something very similar with `MemoryChannelState`.

Use a `filter()` on the subs list in the sequence update step, plus some
other commented approaches we can try for speed.
tokio_backup
Tyler Goodlet 2021-08-10 15:32:53 -04:00
parent 9d12cc80dd
commit b9863fc4ab
1 changed files with 71 additions and 47 deletions

View File

@ -4,23 +4,42 @@ https://tokio-rs.github.io/tokio/doc/tokio/sync/broadcast/index.html
''' '''
from __future__ import annotations from __future__ import annotations
from itertools import cycle
from collections import deque from collections import deque
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from dataclasses import dataclass
from functools import partial from functools import partial
from itertools import cycle
from operator import ne
from typing import Optional from typing import Optional
import trio import trio
import tractor
from trio.lowlevel import current_task
from trio.abc import ReceiveChannel
from trio._core._run import Task from trio._core._run import Task
from trio.abc import ReceiveChannel
from trio.lowlevel import current_task
import tractor
class Lagged(trio.TooSlowError): class Lagged(trio.TooSlowError):
'''Subscribed consumer task was too slow''' '''Subscribed consumer task was too slow'''
@dataclass
class BroadcastState:
'''Common state to all receivers of a broadcast.
'''
queue: deque
# map of underlying clones to receiver wrappers
# which must be provided as a singleton per broadcaster
# clone-subscription set.
subs: dict[trio.ReceiveChannel, BroadcastReceiver]
# broadcast event to wakeup all sleeping consumer tasks
# on a newly produced value from the sender.
sender_ready: Optional[trio.Event] = None
class BroadcastReceiver(ReceiveChannel): class BroadcastReceiver(ReceiveChannel):
'''A memory receive channel broadcaster which is non-lossy for the '''A memory receive channel broadcaster which is non-lossy for the
fastest consumer. fastest consumer.
@ -33,28 +52,21 @@ class BroadcastReceiver(ReceiveChannel):
self, self,
rx_chan: ReceiveChannel, rx_chan: ReceiveChannel,
queue: deque, state: BroadcastState,
_subs: dict[trio.ReceiveChannel, BroadcastReceiver],
) -> None: ) -> None:
# map of underlying clones to receiver wrappers # register the original underlying (clone)
# which must be provided as a singleton per broadcaster self._state = state
# clone-subscription set. state.subs[rx_chan] = -1
self._subs = _subs
# underlying for this receiver # underlying for this receiver
self._rx = rx_chan self._rx = rx_chan
# register the original underlying (clone)
self._subs[rx_chan] = -1
self._queue = queue
self._value_received: Optional[trio.Event] = None
async def receive(self): async def receive(self):
key = self._rx key = self._rx
state = self._state
# TODO: ideally we can make some way to "lock out" the # TODO: ideally we can make some way to "lock out" the
# underlying receive channel in some way such that if some task # underlying receive channel in some way such that if some task
@ -64,7 +76,7 @@ class BroadcastReceiver(ReceiveChannel):
# only tasks which have entered ``.subscribe()`` can # only tasks which have entered ``.subscribe()`` can
# receive on this broadcaster. # receive on this broadcaster.
try: try:
seq = self._subs[key] seq = state.subs[key]
except KeyError: except KeyError:
raise RuntimeError( raise RuntimeError(
f'{self} is not registerd as subscriber') f'{self} is not registerd as subscriber')
@ -74,7 +86,7 @@ class BroadcastReceiver(ReceiveChannel):
if seq > -1: if seq > -1:
# get the oldest value we haven't received immediately # get the oldest value we haven't received immediately
try: try:
value = self._queue[seq] value = state.queue[seq]
except IndexError: except IndexError:
# adhere to ``tokio`` style "lagging": # adhere to ``tokio`` style "lagging":
@ -87,51 +99,61 @@ class BroadcastReceiver(ReceiveChannel):
# decrement to the last value and expect # decrement to the last value and expect
# consumer to either handle the ``Lagged`` and come back # consumer to either handle the ``Lagged`` and come back
# or bail out on its own (thus un-subscribing) # or bail out on its own (thus un-subscribing)
self._subs[key] = self._queue.maxlen - 1 state.subs[key] = state.queue.maxlen - 1
# this task was overrun by the producer side # this task was overrun by the producer side
task: Task = current_task() task: Task = current_task()
raise Lagged(f'Task {task.name} was overrun') raise Lagged(f'Task {task.name} was overrun')
self._subs[key] -= 1 state.subs[key] -= 1
return value return value
# current task already has the latest value **and** is the # current task already has the latest value **and** is the
# first task to begin waiting for a new one # first task to begin waiting for a new one
if self._value_received is None: if state.sender_ready is None:
event = self._value_received = trio.Event() event = state.sender_ready = trio.Event()
value = await self._rx.receive() value = await self._rx.receive()
# items with lower indices are "newer" # items with lower indices are "newer"
self._queue.appendleft(value) state.queue.appendleft(value)
# broadcast new value to all subscribers by increasing # broadcast new value to all subscribers by increasing
# all sequence numbers that will point in the queue to # all sequence numbers that will point in the queue to
# their latest available value. # their latest available value.
subs = self._subs.copy() # don't decrement the sequence for this task since we
# don't decrement the sequence # for this task since we
# already retreived the last value # already retreived the last value
subs.pop(key)
for sub_key, seq in subs.items(): # XXX: which of these impls is fastest?
self._subs[sub_key] += 1
# subs = state.subs.copy()
# subs.pop(key)
for sub_key in filter(
# lambda k: k != key, state.subs,
partial(ne, key), state.subs,
):
state.subs[sub_key] += 1
# reset receiver waiter task event for next blocking condition # reset receiver waiter task event for next blocking condition
self._value_received = None
event.set() event.set()
state.sender_ready = None
return value return value
# This task is all caught up and ready to receive the latest # This task is all caught up and ready to receive the latest
# value, so queue sched it on the internal event. # value, so queue sched it on the internal event.
else: else:
await self._value_received.wait() await state.sender_ready.wait()
seq = self._subs[key] # TODO: optimization: if this is always true can't we just
assert seq > -1, 'Internal error?' # skip iterating these sequence numbers on the fastest
# task's wakeup and always read from state.queue[0]?
seq = state.subs[key]
assert seq == 0, 'Internal error?'
self._subs[key] -= 1 state.subs[key] -= 1
return self._queue[0] return state.queue[seq]
@asynccontextmanager @asynccontextmanager
async def subscribe( async def subscribe(
@ -145,12 +167,12 @@ class BroadcastReceiver(ReceiveChannel):
''' '''
clone = self._rx.clone() clone = self._rx.clone()
state = self._state
br = BroadcastReceiver( br = BroadcastReceiver(
clone, rx_chan=clone,
self._queue, state=state,
_subs=self._subs,
) )
assert clone in self._subs assert clone in state.subs
try: try:
yield br yield br
@ -159,7 +181,7 @@ class BroadcastReceiver(ReceiveChannel):
# ``AsyncResource`` api. # ``AsyncResource`` api.
await clone.aclose() await clone.aclose()
# drop from subscribers and close # drop from subscribers and close
self._subs.pop(clone) state.subs.pop(clone)
# TODO: # TODO:
# - should there be some ._closed flag that causes # - should there be some ._closed flag that causes
@ -186,8 +208,10 @@ def broadcast_receiver(
return BroadcastReceiver( return BroadcastReceiver(
recv_chan, recv_chan,
state=BroadcastState(
queue=deque(maxlen=max_buffer_size), queue=deque(maxlen=max_buffer_size),
_subs={}, # this is singleton over all subscriptions subs={},
),
) )
@ -210,7 +234,7 @@ if __name__ == '__main__':
) -> None: ) -> None:
task = current_task() task = current_task()
count = 0 lags = 0
while True: while True:
async with rx.subscribe() as brx: async with rx.subscribe() as brx:
@ -218,22 +242,22 @@ if __name__ == '__main__':
async for value in brx: async for value in brx:
print(f'{task.name}: {value}') print(f'{task.name}: {value}')
await trio.sleep(delay) await trio.sleep(delay)
count += 1
except Lagged: except Lagged:
print( print(
f'restarting slow ass {task.name}' f'restarting slow ass {task.name}'
f'that bailed out on {count}:{value}') f'that bailed out on {lags}:{value}')
if count <= retries: if lags <= retries:
lags += 1
continue continue
else: else:
print( print(
f'{task.name} was too slow and terminated ' f'{task.name} was too slow and terminated '
f'on {count}:{value}') f'on {lags}:{value}')
return return
async with trio.open_nursery() as n: async with trio.open_nursery() as n:
for i in range(1, size): for i in range(1, 10):
n.start_soon( n.start_soon(
partial( partial(
sub_and_print, sub_and_print,