2021-08-08 21:23:48 +00:00
|
|
|
'''
|
2021-08-08 23:48:02 +00:00
|
|
|
``tokio`` style broadcast channel.
|
|
|
|
https://tokio-rs.github.io/tokio/doc/tokio/sync/broadcast/index.html
|
2021-08-08 21:23:48 +00:00
|
|
|
|
|
|
|
'''
|
|
|
|
from __future__ import annotations
|
2021-08-16 16:47:49 +00:00
|
|
|
from abc import abstractmethod
|
2021-08-08 21:23:48 +00:00
|
|
|
from collections import deque
|
2021-08-09 20:40:02 +00:00
|
|
|
from contextlib import asynccontextmanager
|
2021-08-10 19:32:53 +00:00
|
|
|
from dataclasses import dataclass
|
2021-08-08 21:23:48 +00:00
|
|
|
from functools import partial
|
2021-08-10 19:32:53 +00:00
|
|
|
from itertools import cycle
|
|
|
|
from operator import ne
|
2021-08-16 16:47:49 +00:00
|
|
|
from typing import Optional, Callable, Awaitable, Any, AsyncIterator, Protocol
|
|
|
|
from typing import Generic, TypeVar
|
2021-08-08 21:23:48 +00:00
|
|
|
|
|
|
|
import trio
|
|
|
|
from trio._core._run import Task
|
2021-08-10 19:32:53 +00:00
|
|
|
from trio.abc import ReceiveChannel
|
|
|
|
from trio.lowlevel import current_task
|
|
|
|
import tractor
|
2021-08-08 21:23:48 +00:00
|
|
|
|
|
|
|
|
2021-08-16 16:47:49 +00:00
|
|
|
# A regular invariant generic type.
# NOTE(review): ``T`` does not appear to be referenced anywhere else in
# the visible source - confirm before removing.
T = TypeVar("T")

# The type of object produced by a ReceiveChannel (covariant because
# ReceiveChannel[Derived] can be passed to someone expecting
# ReceiveChannel[Base])
ReceiveType = TypeVar("ReceiveType", covariant=True)
|
|
|
|
|
|
|
|
|
|
|
|
class CloneableReceiveChannel(Protocol[ReceiveType]):
    '''Structural type for a receive channel which can be cloned.

    Describes the subset of the ``trio`` receive channel api (receive,
    async iteration and the ``AsyncResource`` methods) plus ``.clone()``.

    '''
    @abstractmethod
    def clone(self) -> CloneableReceiveChannel[ReceiveType]:
        '''Clone this receiver usually by making a copy.'''

    @abstractmethod
    async def receive(self) -> ReceiveType:
        '''Same as in ``trio``.'''

    @abstractmethod
    def __aiter__(self) -> AsyncIterator[ReceiveType]:
        '''Return an async iterator over the received values.'''

    @abstractmethod
    async def __anext__(self) -> ReceiveType:
        '''Return the next received value when async-iterating.'''

    # ``trio.abc.AsyncResource`` methods
    @abstractmethod
    async def aclose(self) -> None:
        '''Close this channel.'''

    @abstractmethod
    async def __aenter__(self) -> CloneableReceiveChannel[ReceiveType]:
        '''Enter the ``async with`` block for this channel.'''

    @abstractmethod
    async def __aexit__(self, *args) -> None:
        '''Exit the ``async with`` block for this channel.'''
|
|
|
|
|
|
|
|
|
2021-08-08 21:23:48 +00:00
|
|
|
class Lagged(trio.TooSlowError):
    '''Subscribed consumer task was too slow.

    Raised from ``BroadcastReceiver.receive()`` when the next value for
    a subscriber is no longer available in the shared queue, i.e. the
    consumer was overrun by the producer side.

    '''
|
|
|
|
|
|
|
|
|
2021-08-10 19:32:53 +00:00
|
|
|
@dataclass
class BroadcastState:
    '''Common state to all receivers of a broadcast.

    One instance is shared between a ``BroadcastReceiver`` and every
    receiver it creates via ``.subscribe()``.

    '''
    # buffer of received values; receivers push new values on the left
    # so items with lower indices are "newer".
    queue: deque

    # map of each subscriber's underlying (clone) channel to its
    # sequence number: the index into ``queue`` of the oldest value that
    # subscriber has not yet received, or ``-1`` when it is caught up.
    # Must be provided as a singleton per broadcaster
    # clone-subscription set.
    subs: dict[CloneableReceiveChannel, int]

    # broadcast event to wakeup all sleeping consumer tasks
    # on a newly produced value from the sender; ``None`` while no task
    # is currently receiving from the underlying channel.
    sender_ready: Optional[trio.Event] = None
|
|
|
|
|
|
|
|
|
2021-08-08 21:23:48 +00:00
|
|
|
class BroadcastReceiver(ReceiveChannel):
    '''A memory receive channel broadcaster which is non-lossy for the
    fastest consumer.

    Additional consumer tasks can receive all produced values by registering
    with ``.subscribe()`` and receiving from the new instance it delivers.

    '''
    def __init__(
        self,

        rx_chan: CloneableReceiveChannel,
        state: BroadcastState,
        receive_afunc: Optional[Callable[[], Awaitable[Any]]] = None,

    ) -> None:
        '''Register ``rx_chan`` as a subscriber in the shared ``state``.

        ``receive_afunc``, when provided, is awaited instead of
        ``rx_chan.receive`` to pull new values from the underlying.

        '''
        # register the original underlying (clone); ``-1`` means
        # "caught up", i.e. nothing queued for this subscriber yet.
        self._state = state
        state.subs[rx_chan] = -1

        # underlying for this receiver
        self._rx = rx_chan
        self._recv = receive_afunc or rx_chan.receive

    async def receive(self):
        '''Return the next value for this subscriber.

        Raises ``RuntimeError`` if this receiver is not (or no longer)
        registered as a subscriber and ``Lagged`` if the next value for
        it is no longer available in the queue. Any error raised by the
        underlying receive call is propagated to the task which made
        that call.

        '''
        key = self._rx
        state = self._state

        # TODO: ideally we can make some way to "lock out" the
        # underlying receive channel in some way such that if some task
        # tries to pull from it directly (i.e. one we're unaware of)
        # then it errors out.

        # only tasks which have entered ``.subscribe()`` can
        # receive on this broadcaster.
        try:
            seq = state.subs[key]
        except KeyError:
            raise RuntimeError(
                f'{self} is not registerd as subscriber')

        # check that task does not already have a value it can receive
        # immediately and/or that it has lagged.
        if seq > -1:
            # get the oldest value we haven't received immediately
            try:
                value = state.queue[seq]
            except IndexError:

                # adhere to ``tokio`` style "lagging":
                # "Once RecvError::Lagged is returned, the lagging
                # receiver's position is updated to the oldest value
                # contained by the channel. The next call to recv will
                # return this value."
                # https://tokio-rs.github.io/tokio/doc/tokio/sync/broadcast/index.html#lagging

                # decrement to the last value and expect
                # consumer to either handle the ``Lagged`` and come back
                # or bail out on its own (thus un-subscribing)
                state.subs[key] = state.queue.maxlen - 1

                # this task was overrun by the producer side
                task: Task = current_task()
                raise Lagged(f'Task {task.name} was overrun')

            state.subs[key] -= 1
            return value

        # current task already has the latest value **and** is the
        # first task to begin waiting for a new one
        if state.sender_ready is None:

            event = state.sender_ready = trio.Event()
            try:
                value = await self._recv()

                # items with lower indices are "newer"
                state.queue.appendleft(value)

                # broadcast new value to all subscribers by increasing
                # all sequence numbers that will point in the queue to
                # their latest available value.

                # don't increment the sequence for this task since we
                # already retrieved the last value
                for sub_key in filter(partial(ne, key), state.subs):
                    state.subs[sub_key] += 1

                return value

            finally:
                # reset receiver waiter task event for next blocking
                # condition. This MUST happen even when the receive
                # call above raised or was cancelled, otherwise every
                # task waiting on the event (and every later caller)
                # would block on an event which is never set.
                event.set()
                state.sender_ready = None

        # This task is all caught up and ready to receive the latest
        # value, so queue sched it on the internal event.
        await state.sender_ready.wait()

        # Re-enter instead of indexing the queue directly: if the
        # receiving task delivered a value our sequence is now > -1 and
        # the value is returned by the branch above; if it failed or was
        # cancelled before delivering one, our sequence is still -1 and
        # this task gets a chance to pull from the underlying itself.
        return await self.receive()

    @asynccontextmanager
    async def subscribe(
        self,
    ) -> AsyncIterator[BroadcastReceiver]:
        '''Subscribe for values from this broadcast receiver.

        Returns a new ``BroadCastReceiver`` which is registered for and
        pulls data from a clone of the original ``trio.abc.ReceiveChannel``
        provided at creation.

        On exit the clone is closed and dropped from the subscribers.

        '''
        # if we didn't want to enforce "clone-ability" how would
        # we key arbitrary subscriptions? Use a token system?
        clone = self._rx.clone()
        state = self._state
        br = BroadcastReceiver(
            rx_chan=clone,
            state=state,
        )
        assert clone in state.subs

        try:
            yield br
        finally:
            # XXX: this is the reason this function is async: the
            # ``AsyncResource`` api.
            # Closes the clone and drops it from the subscribers; safe
            # even if the consumer already closed ``br`` itself.
            await br.aclose()

    # TODO:
    # - should there be some ._closed flag that causes
    #   consumers to die **before** they read all queued values?
    # - if subs only open and close clones then the underlying
    #   will never be killed until the last instance closes?
    #   This is correct right?
    async def aclose(
        self,
    ) -> None:
        '''Close the underlying channel and un-subscribe this receiver.

        Calling this more than once does not raise for the missing
        subscription.

        '''
        # XXX: leaving it like this consumers can still get values
        # up to the last received that still reside in the queue.
        # Is this what we want?
        try:
            await self._rx.aclose()
        finally:
            # always drop the subscription, even when closing the
            # underlying raised or was cancelled, so a dead subscriber
            # does not linger in the shared state.
            self._state.subs.pop(self._rx, None)
|
2021-08-08 21:23:48 +00:00
|
|
|
|
|
|
|
|
2021-08-08 23:48:02 +00:00
|
|
|
def broadcast_receiver(

    recv_chan: CloneableReceiveChannel,
    max_buffer_size: int,
    **kwargs,

) -> BroadcastReceiver:
    '''Wrap ``recv_chan`` in a new ``BroadcastReceiver``.

    A fresh ``BroadcastState`` is created whose queue holds at most
    ``max_buffer_size`` values; any extra ``kwargs`` are passed through
    to the ``BroadcastReceiver`` constructor.

    '''
    shared_state = BroadcastState(
        queue=deque(maxlen=max_buffer_size),
        subs={},
    )
    return BroadcastReceiver(recv_chan, state=shared_state, **kwargs)
|
2021-08-08 21:23:48 +00:00
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':

    async def main():
        '''Demo: one sender broadcasting to subscribers of varying speed.'''

        async with tractor.open_root_actor(
            debug_mode=True,
            # loglevel='info',
        ):

            retries = 3
            size = 100
            tx, rx = trio.open_memory_channel(size)
            rx = broadcast_receiver(rx, size)

            async def sub_and_print(
                delay: float,
            ) -> None:
                '''Subscribe and print each value, sleeping ``delay``
                seconds between values; re-subscribe after a ``Lagged``
                until the retry budget is used up.

                '''
                task = current_task()
                lags = 0

                # last value seen by this task; bound up front so the
                # lag reports below never reference an unbound name.
                value = None

                while True:
                    async with rx.subscribe() as brx:
                        try:
                            async for value in brx:
                                print(f'{task.name}: {value}')
                                await trio.sleep(delay)

                        except Lagged:
                            print(
                                f'restarting slow ass {task.name} '
                                f'that bailed out on {lags}:{value}')
                            if lags <= retries:
                                lags += 1
                                continue
                            else:
                                print(
                                    f'{task.name} was too slow and terminated '
                                    f'on {lags}:{value}')
                                return

            async with trio.open_nursery() as n:
                for i in range(1, 10):
                    n.start_soon(
                        partial(
                            sub_and_print,
                            delay=i*0.01,
                        ),
                        name=f'sub_{i}',
                    )

                async with tx:
                    for i in cycle(range(size)):
                        print(f'sending: {i}')
                        await tx.send(i)

    trio.run(main)
|