1
0
Fork 0
tractor/tractor/_broadcast.py

316 lines
9.5 KiB
Python

'''
``tokio`` style broadcast channel.
https://docs.rs/tokio/1.11.0/tokio/sync/broadcast/index.html
'''
from __future__ import annotations
from abc import abstractmethod
from collections import deque
from contextlib import asynccontextmanager
from dataclasses import dataclass
from functools import partial
from operator import ne
from typing import Optional, Callable, Awaitable, Any, AsyncIterator, Protocol
from typing import Generic, TypeVar
import trio
from trio._core._run import Task
from trio.abc import ReceiveChannel
from trio.lowlevel import current_task
# Plain (invariant) type variable for general generic use.
T = TypeVar('T')

# Covariant so that an ``AsyncReceiver[Derived]`` is accepted wherever
# an ``AsyncReceiver[Base]`` is expected.
ReceiveType = TypeVar('ReceiveType', covariant=True)
class AsyncReceiver(
    Protocol,
    Generic[ReceiveType],
):
    '''Structural (duck) type for an async source of values, shaped
    much like ``trio.abc.ReceiveChannel``: it can be received from,
    iterated asynchronously, closed, and used as an async context
    manager.

    '''
    @abstractmethod
    async def receive(self) -> ReceiveType:
        '''Wait for and return the next value.'''

    @abstractmethod
    def __aiter__(self) -> AsyncIterator[ReceiveType]:
        '''Return an async iterator over received values.'''

    @abstractmethod
    async def __anext__(self) -> ReceiveType:
        '''Return the next value of the async iteration.'''

    # the remaining methods mirror ``trio.abc.AsyncResource``

    @abstractmethod
    async def aclose(self):
        '''Close this receiver.'''

    @abstractmethod
    async def __aenter__(self) -> AsyncReceiver[ReceiveType]:
        '''Enter the async context.'''

    @abstractmethod
    async def __aexit__(self, *args) -> None:
        '''Exit the async context.'''
class Lagged(trio.TooSlowError):
    '''Raised to a subscribed consumer task which fell too far behind:
    the value it was due to read next is no longer held in the shared
    queue because faster consumers kept pulling new values in.

    '''
@dataclass
class BroadcastState:
    '''State shared by every receiver belonging to one broadcast set.

    '''
    # most-recent-first buffer of values pulled from the underlying
    # receiver (new values are pushed on the left).
    queue: deque

    # capacity of ``queue``; also used to reposition a lagging
    # subscriber onto the oldest retained value.
    maxlen: int

    # subscriber key (the ``id()`` of its receiver instance) -> that
    # subscriber's sequence, i.e. the index into ``queue`` of the next
    # value it has yet to read, or -1 when it is fully caught up.
    subs: dict[int, int]

    # ``(key, event)`` of the single subscriber currently blocked on the
    # underlying receive; other caught-up consumers sleep on the event
    # until it is set. ``None`` while nobody is receiving.
    recv_ready: Optional[tuple[int, trio.Event]] = None
class BroadcastReceiver(ReceiveChannel):
    '''A memory receive channel broadcaster which is non-lossy for the
    fastest consumer.

    Additional consumer tasks can receive all produced values by
    registering with ``.subscribe()`` and receiving from the new
    instance it delivers. All instances of one broadcast set share
    a single ``BroadcastState`` and the same underlying receive
    function.

    '''
    def __init__(
        self,
        rx_chan: AsyncReceiver,
        state: BroadcastState,
        receive_afunc: Optional[Callable[[], Awaitable[Any]]] = None,

    ) -> None:
        '''Register this instance as a subscriber in ``state``.

        ``receive_afunc``, when given, is awaited instead of
        ``rx_chan.receive`` to pull new values from the underlying.

        '''
        # subscribers are keyed by instance identity; a sequence of -1
        # means "caught up, nothing buffered to read".
        self.key = id(self)
        self._state = state
        state.subs[self.key] = -1

        # underlying for this receiver
        self._rx = rx_chan
        self._recv = receive_afunc or rx_chan.receive
        self._closed: bool = False

    async def receive(self) -> ReceiveType:
        '''Return the next value this subscriber has not yet seen.

        If a value is already buffered for this subscriber it is
        returned immediately. Otherwise the first caught-up task pulls
        from the underlying receiver and publishes the value to all
        other subscribers, while any further caught-up tasks sleep
        until that pull finishes and then retry.

        Raises:
            trio.ClosedResourceError: this receiver was closed.
            RuntimeError: this receiver is not registered in the shared
                state (and was not closed).
            Lagged: the value this subscriber should read next has been
                pushed out of the queue; its position is reset to the
                oldest retained value for the next call.

        Any exception raised by the underlying receive function
        propagates to the task that was pulling from it.

        '''
        key = self.key
        state = self._state

        # TODO: ideally we can make some way to "lock out" the
        # underlying receive channel in some way such that if some task
        # tries to pull from it directly (i.e. one we're unaware of)
        # then it errors out.

        # only registered subscribers can receive on this broadcaster.
        try:
            seq = state.subs[key]
        except KeyError:
            if self._closed:
                raise trio.ClosedResourceError

            raise RuntimeError(
                f'{self} is not registerd as subscriber')

        # check that task does not already have a value it can receive
        # immediately and/or that it has lagged.
        if seq > -1:
            # get the oldest value we haven't received immediately
            try:
                value = state.queue[seq]
            except IndexError:
                # adhere to ``tokio`` style "lagging":
                # "Once RecvError::Lagged is returned, the lagging
                # receiver's position is updated to the oldest value
                # contained by the channel. The next call to recv will
                # return this value."
                # https://docs.rs/tokio/1.11.0/tokio/sync/broadcast/index.html#lagging

                # point at the oldest retained value and expect the
                # consumer to either handle the ``Lagged`` and come back
                # or bail out on its own (thus un-subscribing)
                state.subs[key] = state.maxlen - 1

                # this task was overrun by the producer side
                task: Task = current_task()
                raise Lagged(f'Task {task.name} was overrun')

            state.subs[key] -= 1
            return value

        # current task already has the latest value **and** is the
        # first task to begin waiting for a new one
        if state.recv_ready is None:

            if self._closed:
                raise trio.ClosedResourceError

            event = trio.Event()
            state.recv_ready = key, event

            try:
                value = await self._recv()

                # items with lower indices are "newer"
                # NOTE: ``collections.deque`` implicitly takes care of
                # truncating values outside our ``state.maxlen``. In the
                # alt-backend-array-case we'll need to make sure this is
                # implemented in similar ring-buffer-ish style.
                state.queue.appendleft(value)

                # broadcast new value to all subscribers by increasing
                # all sequence numbers that will point in the queue to
                # their latest available value. The sequence for this
                # task is left alone since it is handed the value
                # directly.
                for sub_key in filter(
                    partial(ne, key), state.subs,
                ):
                    state.subs[sub_key] += 1

                return value

            finally:
                # ALWAYS wake sibling consumers, whether the pull
                # succeeded, was cancelled or raised (eg. the underlying
                # hit end-of-channel or was closed). A sibling woken
                # without a new value still has a sequence of -1 and
                # simply re-enters ``.receive()``, where one of them
                # becomes the next task to pull from the underlying and
                # so observes any persistent error itself. Not setting
                # the event on error would leave siblings blocked
                # forever, since the event is discarded just below.
                event.set()

                # Reset for the next blocking condition so that the
                # next caught-up consumer allocates a fresh event
                # instead of waiting on this spent one.
                state.recv_ready = None

        # This task is all caught up and another task is already pulling
        # from the underlying, so sleep on that task's event.
        else:
            assert seq == -1  # sanity
            _, ev = state.recv_ready
            await ev.wait()

            # Re-run the full receive logic with fresh state rather than
            # reading the queue directly: if the pulling task delivered
            # a value our sequence was bumped and the buffered-read path
            # above returns it (including lag handling); if it was
            # cancelled or errored our sequence is still -1 and we wait
            # or pull again instead of reading a stale value.
            return await self.receive()

    @asynccontextmanager
    async def subscribe(
        self,
    ) -> AsyncIterator[BroadcastReceiver]:
        '''Subscribe for values from this broadcast receiver.

        Yields a new ``BroadcastReceiver`` registered in the same shared
        state and pulling from the same underlying receiver as this
        instance; it is closed (un-subscribed) when the context exits.

        Raises:
            trio.ClosedResourceError: this receiver is already closed.

        '''
        if self._closed:
            raise trio.ClosedResourceError

        state = self._state
        br = BroadcastReceiver(
            rx_chan=self._rx,
            state=state,
            receive_afunc=self._recv,
        )
        # the constructor must have registered the new subscriber
        assert br.key in state.subs

        try:
            yield br
        finally:
            await br.aclose()

    async def aclose(
        self,
    ) -> None:
        '''Un-subscribe this receiver from the shared state.

        Closing is idempotent. Only this instance's subscription is
        removed; the shared queue is left untouched.

        '''
        if self._closed:
            return

        # XXX: leaving it like this consumers can still get values
        # up to the last received that still reside in the queue.
        self._state.subs.pop(self.key)
        self._closed = True
def broadcast_receiver(
    recv_chan: AsyncReceiver,
    max_buffer_size: int,
    **kwargs,

) -> BroadcastReceiver:
    '''Wrap ``recv_chan`` in a new ``BroadcastReceiver`` backed by
    a fresh shared state whose queue holds at most ``max_buffer_size``
    values.

    Any extra keyword arguments are forwarded to the
    ``BroadcastReceiver`` constructor.

    '''
    # brand new broadcast set: empty bounded queue, no subscribers yet
    # (the receiver registers itself on construction).
    shared_state = BroadcastState(
        queue=deque(maxlen=max_buffer_size),
        maxlen=max_buffer_size,
        subs={},
    )
    return BroadcastReceiver(
        recv_chan,
        state=shared_state,
        **kwargs,
    )