Add common state delegate type for all consumers
For every set of broadcast receivers which pull from the same producer, we need a singleton state for all of:
- the subscriptions
- the sender ready event
- the queue

Add a `BroadcastState` dataclass for this and pass it to all subscriptions. This makes the design much more like the built-in memory channels, which do something very similar with `MemoryChannelState`.

Use a `filter()` on the subs list in the sequence update step, plus some other commented approaches we can try for speed.

tokio_backup
parent
9d12cc80dd
commit
b9863fc4ab
|
@ -4,23 +4,42 @@ https://tokio-rs.github.io/tokio/doc/tokio/sync/broadcast/index.html
|
||||||
|
|
||||||
'''
|
'''
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
from itertools import cycle
|
|
||||||
from collections import deque
|
from collections import deque
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
|
from dataclasses import dataclass
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
from itertools import cycle
|
||||||
|
from operator import ne
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import trio
|
import trio
|
||||||
import tractor
|
|
||||||
from trio.lowlevel import current_task
|
|
||||||
from trio.abc import ReceiveChannel
|
|
||||||
from trio._core._run import Task
|
from trio._core._run import Task
|
||||||
|
from trio.abc import ReceiveChannel
|
||||||
|
from trio.lowlevel import current_task
|
||||||
|
import tractor
|
||||||
|
|
||||||
|
|
||||||
class Lagged(trio.TooSlowError):
|
class Lagged(trio.TooSlowError):
|
||||||
'''Subscribed consumer task was too slow'''
|
'''Subscribed consumer task was too slow'''
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class BroadcastState:
|
||||||
|
'''Common state to all receivers of a broadcast.
|
||||||
|
|
||||||
|
'''
|
||||||
|
queue: deque
|
||||||
|
|
||||||
|
# map of underlying clones to receiver wrappers
|
||||||
|
# which must be provided as a singleton per broadcaster
|
||||||
|
# clone-subscription set.
|
||||||
|
subs: dict[trio.ReceiveChannel, BroadcastReceiver]
|
||||||
|
|
||||||
|
# broadcast event to wakeup all sleeping consumer tasks
|
||||||
|
# on a newly produced value from the sender.
|
||||||
|
sender_ready: Optional[trio.Event] = None
|
||||||
|
|
||||||
|
|
||||||
class BroadcastReceiver(ReceiveChannel):
|
class BroadcastReceiver(ReceiveChannel):
|
||||||
'''A memory receive channel broadcaster which is non-lossy for the
|
'''A memory receive channel broadcaster which is non-lossy for the
|
||||||
fastest consumer.
|
fastest consumer.
|
||||||
|
@ -33,28 +52,21 @@ class BroadcastReceiver(ReceiveChannel):
|
||||||
self,
|
self,
|
||||||
|
|
||||||
rx_chan: ReceiveChannel,
|
rx_chan: ReceiveChannel,
|
||||||
queue: deque,
|
state: BroadcastState,
|
||||||
_subs: dict[trio.ReceiveChannel, BroadcastReceiver],
|
|
||||||
|
|
||||||
) -> None:
|
) -> None:
|
||||||
|
|
||||||
# map of underlying clones to receiver wrappers
|
# register the original underlying (clone)
|
||||||
# which must be provided as a singleton per broadcaster
|
self._state = state
|
||||||
# clone-subscription set.
|
state.subs[rx_chan] = -1
|
||||||
self._subs = _subs
|
|
||||||
|
|
||||||
# underlying for this receiver
|
# underlying for this receiver
|
||||||
self._rx = rx_chan
|
self._rx = rx_chan
|
||||||
|
|
||||||
# register the original underlying (clone)
|
|
||||||
self._subs[rx_chan] = -1
|
|
||||||
|
|
||||||
self._queue = queue
|
|
||||||
self._value_received: Optional[trio.Event] = None
|
|
||||||
|
|
||||||
async def receive(self):
|
async def receive(self):
|
||||||
|
|
||||||
key = self._rx
|
key = self._rx
|
||||||
|
state = self._state
|
||||||
|
|
||||||
# TODO: ideally we can make some way to "lock out" the
|
# TODO: ideally we can make some way to "lock out" the
|
||||||
# underlying receive channel in some way such that if some task
|
# underlying receive channel in some way such that if some task
|
||||||
|
@ -64,7 +76,7 @@ class BroadcastReceiver(ReceiveChannel):
|
||||||
# only tasks which have entered ``.subscribe()`` can
|
# only tasks which have entered ``.subscribe()`` can
|
||||||
# receive on this broadcaster.
|
# receive on this broadcaster.
|
||||||
try:
|
try:
|
||||||
seq = self._subs[key]
|
seq = state.subs[key]
|
||||||
except KeyError:
|
except KeyError:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
f'{self} is not registerd as subscriber')
|
f'{self} is not registerd as subscriber')
|
||||||
|
@ -74,7 +86,7 @@ class BroadcastReceiver(ReceiveChannel):
|
||||||
if seq > -1:
|
if seq > -1:
|
||||||
# get the oldest value we haven't received immediately
|
# get the oldest value we haven't received immediately
|
||||||
try:
|
try:
|
||||||
value = self._queue[seq]
|
value = state.queue[seq]
|
||||||
except IndexError:
|
except IndexError:
|
||||||
|
|
||||||
# adhere to ``tokio`` style "lagging":
|
# adhere to ``tokio`` style "lagging":
|
||||||
|
@ -87,51 +99,61 @@ class BroadcastReceiver(ReceiveChannel):
|
||||||
# decrement to the last value and expect
|
# decrement to the last value and expect
|
||||||
# consumer to either handle the ``Lagged`` and come back
|
# consumer to either handle the ``Lagged`` and come back
|
||||||
# or bail out on its own (thus un-subscribing)
|
# or bail out on its own (thus un-subscribing)
|
||||||
self._subs[key] = self._queue.maxlen - 1
|
state.subs[key] = state.queue.maxlen - 1
|
||||||
|
|
||||||
# this task was overrun by the producer side
|
# this task was overrun by the producer side
|
||||||
task: Task = current_task()
|
task: Task = current_task()
|
||||||
raise Lagged(f'Task {task.name} was overrun')
|
raise Lagged(f'Task {task.name} was overrun')
|
||||||
|
|
||||||
self._subs[key] -= 1
|
state.subs[key] -= 1
|
||||||
return value
|
return value
|
||||||
|
|
||||||
# current task already has the latest value **and** is the
|
# current task already has the latest value **and** is the
|
||||||
# first task to begin waiting for a new one
|
# first task to begin waiting for a new one
|
||||||
if self._value_received is None:
|
if state.sender_ready is None:
|
||||||
|
|
||||||
event = self._value_received = trio.Event()
|
event = state.sender_ready = trio.Event()
|
||||||
value = await self._rx.receive()
|
value = await self._rx.receive()
|
||||||
|
|
||||||
# items with lower indices are "newer"
|
# items with lower indices are "newer"
|
||||||
self._queue.appendleft(value)
|
state.queue.appendleft(value)
|
||||||
|
|
||||||
# broadcast new value to all subscribers by increasing
|
# broadcast new value to all subscribers by increasing
|
||||||
# all sequence numbers that will point in the queue to
|
# all sequence numbers that will point in the queue to
|
||||||
# their latest available value.
|
# their latest available value.
|
||||||
|
|
||||||
subs = self._subs.copy()
|
# don't decrement the sequence for this task since we
|
||||||
# don't decrement the sequence # for this task since we
|
|
||||||
# already retreived the last value
|
# already retreived the last value
|
||||||
subs.pop(key)
|
|
||||||
for sub_key, seq in subs.items():
|
# XXX: which of these impls is fastest?
|
||||||
self._subs[sub_key] += 1
|
|
||||||
|
# subs = state.subs.copy()
|
||||||
|
# subs.pop(key)
|
||||||
|
|
||||||
|
for sub_key in filter(
|
||||||
|
# lambda k: k != key, state.subs,
|
||||||
|
partial(ne, key), state.subs,
|
||||||
|
):
|
||||||
|
state.subs[sub_key] += 1
|
||||||
|
|
||||||
# reset receiver waiter task event for next blocking condition
|
# reset receiver waiter task event for next blocking condition
|
||||||
self._value_received = None
|
|
||||||
event.set()
|
event.set()
|
||||||
|
state.sender_ready = None
|
||||||
return value
|
return value
|
||||||
|
|
||||||
# This task is all caught up and ready to receive the latest
|
# This task is all caught up and ready to receive the latest
|
||||||
# value, so queue sched it on the internal event.
|
# value, so queue sched it on the internal event.
|
||||||
else:
|
else:
|
||||||
await self._value_received.wait()
|
await state.sender_ready.wait()
|
||||||
|
|
||||||
seq = self._subs[key]
|
# TODO: optimization: if this is always true can't we just
|
||||||
assert seq > -1, 'Internal error?'
|
# skip iterating these sequence numbers on the fastest
|
||||||
|
# task's wakeup and always read from state.queue[0]?
|
||||||
|
seq = state.subs[key]
|
||||||
|
assert seq == 0, 'Internal error?'
|
||||||
|
|
||||||
self._subs[key] -= 1
|
state.subs[key] -= 1
|
||||||
return self._queue[0]
|
return state.queue[seq]
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def subscribe(
|
async def subscribe(
|
||||||
|
@ -145,12 +167,12 @@ class BroadcastReceiver(ReceiveChannel):
|
||||||
|
|
||||||
'''
|
'''
|
||||||
clone = self._rx.clone()
|
clone = self._rx.clone()
|
||||||
|
state = self._state
|
||||||
br = BroadcastReceiver(
|
br = BroadcastReceiver(
|
||||||
clone,
|
rx_chan=clone,
|
||||||
self._queue,
|
state=state,
|
||||||
_subs=self._subs,
|
|
||||||
)
|
)
|
||||||
assert clone in self._subs
|
assert clone in state.subs
|
||||||
|
|
||||||
try:
|
try:
|
||||||
yield br
|
yield br
|
||||||
|
@ -159,7 +181,7 @@ class BroadcastReceiver(ReceiveChannel):
|
||||||
# ``AsyncResource`` api.
|
# ``AsyncResource`` api.
|
||||||
await clone.aclose()
|
await clone.aclose()
|
||||||
# drop from subscribers and close
|
# drop from subscribers and close
|
||||||
self._subs.pop(clone)
|
state.subs.pop(clone)
|
||||||
|
|
||||||
# TODO:
|
# TODO:
|
||||||
# - should there be some ._closed flag that causes
|
# - should there be some ._closed flag that causes
|
||||||
|
@ -186,8 +208,10 @@ def broadcast_receiver(
|
||||||
|
|
||||||
return BroadcastReceiver(
|
return BroadcastReceiver(
|
||||||
recv_chan,
|
recv_chan,
|
||||||
|
state=BroadcastState(
|
||||||
queue=deque(maxlen=max_buffer_size),
|
queue=deque(maxlen=max_buffer_size),
|
||||||
_subs={}, # this is singleton over all subscriptions
|
subs={},
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -210,7 +234,7 @@ if __name__ == '__main__':
|
||||||
) -> None:
|
) -> None:
|
||||||
|
|
||||||
task = current_task()
|
task = current_task()
|
||||||
count = 0
|
lags = 0
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
async with rx.subscribe() as brx:
|
async with rx.subscribe() as brx:
|
||||||
|
@ -218,22 +242,22 @@ if __name__ == '__main__':
|
||||||
async for value in brx:
|
async for value in brx:
|
||||||
print(f'{task.name}: {value}')
|
print(f'{task.name}: {value}')
|
||||||
await trio.sleep(delay)
|
await trio.sleep(delay)
|
||||||
count += 1
|
|
||||||
|
|
||||||
except Lagged:
|
except Lagged:
|
||||||
print(
|
print(
|
||||||
f'restarting slow ass {task.name}'
|
f'restarting slow ass {task.name}'
|
||||||
f'that bailed out on {count}:{value}')
|
f'that bailed out on {lags}:{value}')
|
||||||
if count <= retries:
|
if lags <= retries:
|
||||||
|
lags += 1
|
||||||
continue
|
continue
|
||||||
else:
|
else:
|
||||||
print(
|
print(
|
||||||
f'{task.name} was too slow and terminated '
|
f'{task.name} was too slow and terminated '
|
||||||
f'on {count}:{value}')
|
f'on {lags}:{value}')
|
||||||
return
|
return
|
||||||
|
|
||||||
async with trio.open_nursery() as n:
|
async with trio.open_nursery() as n:
|
||||||
for i in range(1, size):
|
for i in range(1, 10):
|
||||||
n.start_soon(
|
n.start_soon(
|
||||||
partial(
|
partial(
|
||||||
sub_and_print,
|
sub_and_print,
|
||||||
|
|
Loading…
Reference in New Issue