Add common state delegate type for all consumers

For every set of broadcast receivers which pull from the same producer,
we need a singleton state for all of,
- subscriptions
- the sender ready event
- the queue

Add a `BroadcastState` dataclass for this and pass it to all
subscriptions. This makes the design much more like the built-in memory
channels which do something very similar with `MemoryChannelState`.

Use a `filter()` on the subs list in the sequence update step, plus some
other commented approaches we can try for speed.
tokio_backup
Tyler Goodlet 2021-08-10 15:32:53 -04:00
parent 9d12cc80dd
commit b9863fc4ab
1 changed files with 71 additions and 47 deletions

View File

@ -4,23 +4,42 @@ https://tokio-rs.github.io/tokio/doc/tokio/sync/broadcast/index.html
'''
from __future__ import annotations
from itertools import cycle
from collections import deque
from contextlib import asynccontextmanager
from dataclasses import dataclass
from functools import partial
from itertools import cycle
from operator import ne
from typing import Optional
import trio
import tractor
from trio.lowlevel import current_task
from trio.abc import ReceiveChannel
from trio._core._run import Task
from trio.abc import ReceiveChannel
from trio.lowlevel import current_task
import tractor
class Lagged(trio.TooSlowError):
'''Subscribed consumer task was too slow'''
@dataclass
class BroadcastState:
'''Common state to all receivers of a broadcast.
'''
queue: deque
# map of underlying clones to receiver wrappers
# which must be provided as a singleton per broadcaster
# clone-subscription set.
subs: dict[trio.ReceiveChannel, BroadcastReceiver]
# broadcast event to wakeup all sleeping consumer tasks
# on a newly produced value from the sender.
sender_ready: Optional[trio.Event] = None
class BroadcastReceiver(ReceiveChannel):
'''A memory receive channel broadcaster which is non-lossy for the
fastest consumer.
@ -33,28 +52,21 @@ class BroadcastReceiver(ReceiveChannel):
self,
rx_chan: ReceiveChannel,
queue: deque,
_subs: dict[trio.ReceiveChannel, BroadcastReceiver],
state: BroadcastState,
) -> None:
# map of underlying clones to receiver wrappers
# which must be provided as a singleton per broadcaster
# clone-subscription set.
self._subs = _subs
# register the original underlying (clone)
self._state = state
state.subs[rx_chan] = -1
# underlying for this receiver
self._rx = rx_chan
# register the original underlying (clone)
self._subs[rx_chan] = -1
self._queue = queue
self._value_received: Optional[trio.Event] = None
async def receive(self):
key = self._rx
state = self._state
# TODO: ideally we can make some way to "lock out" the
# underlying receive channel in some way such that if some task
@ -64,7 +76,7 @@ class BroadcastReceiver(ReceiveChannel):
# only tasks which have entered ``.subscribe()`` can
# receive on this broadcaster.
try:
seq = self._subs[key]
seq = state.subs[key]
except KeyError:
raise RuntimeError(
f'{self} is not registerd as subscriber')
@ -74,7 +86,7 @@ class BroadcastReceiver(ReceiveChannel):
if seq > -1:
# get the oldest value we haven't received immediately
try:
value = self._queue[seq]
value = state.queue[seq]
except IndexError:
# adhere to ``tokio`` style "lagging":
@ -87,51 +99,61 @@ class BroadcastReceiver(ReceiveChannel):
# decrement to the last value and expect
# consumer to either handle the ``Lagged`` and come back
# or bail out on its own (thus un-subscribing)
self._subs[key] = self._queue.maxlen - 1
state.subs[key] = state.queue.maxlen - 1
# this task was overrun by the producer side
task: Task = current_task()
raise Lagged(f'Task {task.name} was overrun')
self._subs[key] -= 1
state.subs[key] -= 1
return value
# current task already has the latest value **and** is the
# first task to begin waiting for a new one
if self._value_received is None:
if state.sender_ready is None:
event = self._value_received = trio.Event()
event = state.sender_ready = trio.Event()
value = await self._rx.receive()
# items with lower indices are "newer"
self._queue.appendleft(value)
state.queue.appendleft(value)
# broadcast new value to all subscribers by increasing
# all sequence numbers that will point in the queue to
# their latest available value.
subs = self._subs.copy()
# don't decrement the sequence # for this task since we
# don't decrement the sequence for this task since we
# already retreived the last value
subs.pop(key)
for sub_key, seq in subs.items():
self._subs[sub_key] += 1
# XXX: which of these impls is fastest?
# subs = state.subs.copy()
# subs.pop(key)
for sub_key in filter(
# lambda k: k != key, state.subs,
partial(ne, key), state.subs,
):
state.subs[sub_key] += 1
# reset receiver waiter task event for next blocking condition
self._value_received = None
event.set()
state.sender_ready = None
return value
# This task is all caught up and ready to receive the latest
# value, so queue sched it on the internal event.
else:
await self._value_received.wait()
await state.sender_ready.wait()
seq = self._subs[key]
assert seq > -1, 'Internal error?'
# TODO: optimization: if this is always true can't we just
# skip iterating these sequence numbers on the fastest
# task's wakeup and always read from state.queue[0]?
seq = state.subs[key]
assert seq == 0, 'Internal error?'
self._subs[key] -= 1
return self._queue[0]
state.subs[key] -= 1
return state.queue[seq]
@asynccontextmanager
async def subscribe(
@ -145,12 +167,12 @@ class BroadcastReceiver(ReceiveChannel):
'''
clone = self._rx.clone()
state = self._state
br = BroadcastReceiver(
clone,
self._queue,
_subs=self._subs,
rx_chan=clone,
state=state,
)
assert clone in self._subs
assert clone in state.subs
try:
yield br
@ -159,7 +181,7 @@ class BroadcastReceiver(ReceiveChannel):
# ``AsyncResource`` api.
await clone.aclose()
# drop from subscribers and close
self._subs.pop(clone)
state.subs.pop(clone)
# TODO:
# - should there be some ._closed flag that causes
@ -186,8 +208,10 @@ def broadcast_receiver(
return BroadcastReceiver(
recv_chan,
queue=deque(maxlen=max_buffer_size),
_subs={}, # this is singleton over all subscriptions
state=BroadcastState(
queue=deque(maxlen=max_buffer_size),
subs={},
),
)
@ -210,7 +234,7 @@ if __name__ == '__main__':
) -> None:
task = current_task()
count = 0
lags = 0
while True:
async with rx.subscribe() as brx:
@ -218,22 +242,22 @@ if __name__ == '__main__':
async for value in brx:
print(f'{task.name}: {value}')
await trio.sleep(delay)
count += 1
except Lagged:
print(
f'restarting slow ass {task.name}'
f'that bailed out on {count}:{value}')
if count <= retries:
f'that bailed out on {lags}:{value}')
if lags <= retries:
lags += 1
continue
else:
print(
f'{task.name} was too slow and terminated '
f'on {count}:{value}')
f'on {lags}:{value}')
return
async with trio.open_nursery() as n:
for i in range(1, size):
for i in range(1, 10):
n.start_soon(
partial(
sub_and_print,