Better `trio`-ize `BroadcastReceiver` internals
Driven by a bug found in `piker` where we'd get an infinite recursion error due to `BroadcastReceiver.receive()` being called when consumer tasks are awoken but no value is ready to `.receive_nowait()`. This rework takes an approach closer to the interface and internals of `trio.MemoryReceiveChannel`, particularly in terms of:
- implementing a `BroadcastReceiver.receive_nowait()` and using it within the async `.receive()`,
- failing over to an internal `._receive_from_underlying()` when the `_nowait()` call raises `trio.WouldBlock`,
- adding `BroadcastState.statistics()` for debugging and testing,
- dropping recursion from `.receive()`.
branch: breceiver_internals
parent: a777217674
commit: c2367c1c5e
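Reviewer note: a minimal usage sketch (not part of this diff) of the fast-path/slow-path pattern the rework adopts from `trio.MemoryReceiveChannel`. The `broadcast_receiver` import path and the one-element buffer size are assumptions for illustration.

    import trio
    from tractor.trionics import broadcast_receiver  # assumed import path

    async def main() -> None:
        tx, rx = trio.open_memory_channel(1)
        # wrap the underlying receiver so multiple tasks can consume it
        brx = broadcast_receiver(rx, 1)  # hypothetical setup

        await tx.send('hello')

        try:
            # sync fast-path: raises ``trio.WouldBlock`` when nothing
            # is queued for this consumer yet
            value = brx.receive_nowait()
        except trio.WouldBlock:
            # async slow-path: pull from the underlying receiver
            value = await brx.receive()

        assert value == 'hello'

    trio.run(main)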
@@ -23,7 +23,6 @@ from __future__ import annotations
 from abc import abstractmethod
 from collections import deque
 from contextlib import asynccontextmanager
-from dataclasses import dataclass
 from functools import partial
 from operator import ne
 from typing import Optional, Callable, Awaitable, Any, AsyncIterator, Protocol
@@ -33,6 +32,7 @@ import trio
 from trio._core._run import Task
 from trio.abc import ReceiveChannel
+from trio.lowlevel import current_task
 from msgspec import Struct


 # A regular invariant generic type
@@ -86,8 +86,7 @@ class Lagged(trio.TooSlowError):
     '''


-@dataclass
-class BroadcastState:
+class BroadcastState(Struct):
     '''
     Common state to all receivers of a broadcast.

@@ -110,7 +109,32 @@ class BroadcastState:
     eoc: bool = False

     # If the broadcaster was cancelled, we might as well track it
-    cancelled: bool = False
+    cancelled: dict[int, Task] = {}

+    def statistics(self) -> dict[str, str | int | float]:
+        '''
+        Return broadcast receiver group "statistics" like many of
+        ``trio``'s internal task-sync primitives.
+
+        '''
+        subs = self.subs
+        if self.recv_ready is not None:
+            key, ev = self.recv_ready
+        else:
+            key = ev = None
+
+        qlens = {}
+        for tid, sz in subs.items():
+            qlens[tid] = sz if sz != -1 else 0
+
+        return {
+            'open_consumers': len(subs),
+            'queued_len_by_task': qlens,
+            'max_buffer_size': self.maxlen,
+            'tasks_waiting': ev.statistics().tasks_waiting if ev else 0,
+            'tasks_cancelled': self.cancelled,
+            'next_value_receiver_id': key,
+        }


 class BroadcastReceiver(ReceiveChannel):
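Reviewer note: the new `.statistics()` should make test assertions on group state straightforward. A hedged sketch of such usage, using only the keys visible in the hunk above (`_state` is an internal attribute, so this is debug/test-only usage; `brx` is an assumed receiver instance):

    stats = brx._state.statistics()
    # one queue-length entry per subscribed consumer
    assert stats['open_consumers'] == len(stats['queued_len_by_task'])
    # with no task currently blocked on the underlying receiver:
    assert stats['next_value_receiver_id'] is None
    assert stats['tasks_waiting'] == 0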
@@ -134,6 +158,12 @@ class BroadcastReceiver(ReceiveChannel):
         # register the original underlying (clone)
         self.key = id(self)
         self._state = state
+
+        # each consumer has an int count which indicates
+        # which index contains the next value that the task has not yet
+        # consumed and thus should read. In the "up-to-date" case the
+        # consumer task must wait for a new value from the underlying
+        # receiver and we use ``-1`` as the sentinel for this state.
         state.subs[self.key] = -1

         # underlying for this receiver
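Reviewer note: the `-1` sentinel and index scheme documented above can be illustrated with plain `collections.deque` (illustrative values only):

    from collections import deque

    queue: deque = deque(maxlen=3)
    for v in (1, 2, 3):
        queue.appendleft(v)  # newest value always lands at index 0

    # a consumer whose sequence is 2 still has the oldest value pending
    assert queue[2] == 1
    # sequence ``-1`` means "caught up": nothing to read until the next
    # broadcast value arrives from the underlying receiver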
@@ -141,10 +171,14 @@ class BroadcastReceiver(ReceiveChannel):
         self._recv = receive_afunc or rx_chan.receive
         self._closed: bool = False

-    async def receive(self) -> ReceiveType:
-        key = self.key
-        state = self._state
+    def receive_nowait(
+        self,
+        _key: int | None = None,
+        _state: BroadcastState | None = None,
+
+    ) -> ReceiveType:
+        key = _key or self.key
+        state = _state or self._state

         # TODO: ideally we can make some way to "lock out" the
         # underlying receive channel in some way such that if some task
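Reviewer note: for comparison, the `trio.MemoryReceiveChannel` interface being mirrored here uses exactly this sync/raise contract (all real `trio` API):

    import trio

    async def demo() -> None:
        tx, rx = trio.open_memory_channel(1)
        try:
            rx.receive_nowait()
        except trio.WouldBlock:
            pass  # nothing buffered yet; callers fall back to ``await rx.receive()``

        await tx.send(1)
        assert rx.receive_nowait() == 1

    trio.run(demo)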
@@ -189,20 +223,26 @@ class BroadcastReceiver(ReceiveChannel):
                 state.subs[key] -= 1
                 return value

         # current task already has the latest value **and** is the
         # first task to begin waiting for a new one
         if state.recv_ready is None:
+            raise trio.WouldBlock
+
+    async def _receive_from_underlying(
+        self,
+        key: int,
+        state: BroadcastState,
+
+    ) -> ReceiveType:

-            if self._closed:
-                raise trio.ClosedResourceError
+        if self._closed:
+            raise trio.ClosedResourceError

-            event = trio.Event()
-            state.recv_ready = key, event
+        event = trio.Event()
+        assert state.recv_ready is None
+        state.recv_ready = key, event

-            try:
-                # if we're cancelled here it should be
-                # fine to bail without affecting any other consumers
-                # right?
-                value = await self._recv()
+        try:
+            # if we're cancelled here it should be
+            # fine to bail without affecting any other consumers
+            # right?
+            value = await self._recv()

-                # items with lower indices are "newer"
+            # items with lower indices are "newer"
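Reviewer note: `_receive_from_underlying()` implements a "first waiter pulls, peers park on an event" scheme. A standalone sketch of that wakeup pattern (names and timing are illustrative, not from the diff):

    import trio

    async def main() -> None:
        recv_ready: trio.Event | None = None
        results: list[int] = []

        async def leader(rx) -> None:
            nonlocal recv_ready
            recv_ready = trio.Event()
            # only the leader touches the real source..
            results.append(await rx.receive())
            # ..then wakes all parked followers
            recv_ready.set()

        async def follower() -> None:
            await recv_ready.wait()
            results.append(results[0])  # read the broadcast copy

        tx, rx = trio.open_memory_channel(1)
        async with trio.open_nursery() as n:
            n.start_soon(leader, rx)
            await trio.sleep(0.1)  # let the leader allocate the event
            n.start_soon(follower)
            await tx.send(42)

        assert results == [42, 42]

    trio.run(main)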
@@ -251,7 +291,7 @@ class BroadcastReceiver(ReceiveChannel):
             # consumers will be awoken with a sequence of -1
             # and will potentially try to rewait the underlying
             # receiver instead of just cancelling immediately.
-            self._state.cancelled = True
+            self._state.cancelled[key] = current_task()
             if event.statistics().tasks_waiting:
                 event.set()
             raise
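Reviewer note: switching `cancelled` from a `bool` to a `dict[int, Task]` means `.statistics()` can report *which* consumer was cancelled mid-`._recv()`. `trio.lowlevel.current_task()` and `trio.lowlevel.Task` are real `trio` API; the rest of this sketch is illustrative:

    import trio
    from trio.lowlevel import Task, current_task

    async def main() -> None:
        cancelled: dict[int, Task] = {}

        async def consumer(key: int) -> None:
            try:
                await trio.sleep_forever()
            except trio.Cancelled:
                # record which task died, mirroring the bookkeeping above
                cancelled[key] = current_task()
                raise

        with trio.move_on_after(0.1):
            await consumer(1)

        assert 1 in cancelled

    trio.run(main)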
@@ -264,37 +304,160 @@ class BroadcastReceiver(ReceiveChannel):
             # an event that won't be set!
             state.recv_ready = None

+    async def receive(self) -> ReceiveType:
+        key = self.key
+        state = self._state
+
+        try:
+            return self.receive_nowait(
+                _key=key,
+                _state=state,
+            )
+        except trio.WouldBlock:
+            pass
+
+        # current task already has the latest value **and** is the
+        # first task to begin waiting for a new one
+        if state.recv_ready is None:
+            return await self._receive_from_underlying(key, state)
+
+        # if self._closed:
+        #     raise trio.ClosedResourceError
+
+        # event = trio.Event()
+        # state.recv_ready = key, event
+
+        # try:
+        #     # if we're cancelled here it should be
+        #     # fine to bail without affecting any other consumers
+        #     # right?
+        #     value = await self._recv()
+
+        #     # items with lower indices are "newer"
+        #     # NOTE: ``collections.deque`` implicitly takes care of
+        #     # truncating values outside our ``state.maxlen``. In the
+        #     # alt-backend-array-case we'll need to make sure this is
+        #     # implemented in similar ring-buffer-ish style.
+        #     state.queue.appendleft(value)
+
+        #     # broadcast new value to all subscribers by increasing
+        #     # all sequence numbers that will point in the queue to
+        #     # their latest available value.
+
+        #     # don't decrement the sequence for this task since we
+        #     # already retrieved the last value
+
+        #     # XXX: which of these impls is fastest?
+
+        #     # subs = state.subs.copy()
+        #     # subs.pop(key)
+
+        #     for sub_key in filter(
+        #         # lambda k: k != key, state.subs,
+        #         partial(ne, key), state.subs,
+        #     ):
+        #         state.subs[sub_key] += 1
+
+        #     # NOTE: this should ONLY be set if the above task was *NOT*
+        #     # cancelled on the `._recv()` call.
+        #     event.set()
+        #     return value
+
+        # except trio.EndOfChannel:
+        #     # if any one consumer gets an EOC from the underlying
+        #     # receiver we need to unblock and send that signal to
+        #     # all other consumers.
+        #     self._state.eoc = True
+        #     if event.statistics().tasks_waiting:
+        #         event.set()
+        #     raise
+
+        # except (
+        #     trio.Cancelled,
+        # ):
+        #     # handle cancelled specially otherwise sibling
+        #     # consumers will be awoken with a sequence of -1
+        #     # and will potentially try to rewait the underlying
+        #     # receiver instead of just cancelling immediately.
+        #     self._state.cancelled[key] = current_task()
+        #     if event.statistics().tasks_waiting:
+        #         event.set()
+        #     raise
+
+        # finally:
+
+        #     # Reset receiver waiter task event for next blocking condition.
+        #     # this MUST be reset even if the above ``.recv()`` call
+        #     # was cancelled to avoid the next consumer from blocking on
+        #     # an event that won't be set!
+        #     state.recv_ready = None
+
         # This task is all caught up and ready to receive the latest
-        # value, so queue sched it on the internal event.
+        # value, so queue-schedule it on the internal event.
         else:
-            seq = state.subs[key]
-            assert seq == -1  # sanity
-            _, ev = state.recv_ready
-            await ev.wait()
+            while state.recv_ready is not None:
+                # seq = state.subs[key]
+                # assert seq == -1  # sanity
+                _, ev = state.recv_ready
+                await ev.wait()
+                try:
+                    return self.receive_nowait(
+                        _key=key,
+                        _state=state,
+                    )
+                except trio.WouldBlock:
+                    if (
+                        self._closed
+                    ):
+                        raise trio.ClosedResourceError
+
+                    subs = state.subs
+                    if (
+                        len(subs) == 1
+                        and key in subs
+                        # or cancelled
+                    ):
+                        # XXX: we are the last and only user of this BR so
+                        # likely it makes sense to unwind back to the
+                        # underlying?
+                        import tractor
+                        await tractor.breakpoint()
+
+                    # XXX: In the case where the first task to allocate the
+                    # ``.recv_ready`` event is cancelled we will be woken
+                    # with a non-incremented sequence number (the ``-1``
+                    # sentinel) and thus will read the oldest value if we
+                    # use that. Instead we need to detect if we have not
+                    # been incremented and then receive again.
+                    # return await self.receive()
+
+                    # if state.recv_ready is None:
+
+                    print(f'{key}: {state.statistics()}')
+                    return await self._receive_from_underlying(key, state)
+
+        # seq = state.subs[key]

         # NOTE: if we ever would like the behaviour where if the
         # first task to recv on the underlying is cancelled but it
         # still DOES trigger the ``.recv_ready`` event, we'll likely
         # need this logic:

-        if seq > -1:
-            # stuff from above..
-            seq = state.subs[key]
+        # if seq > -1:
+        #     # stuff from above..
+        #     seq = state.subs[key]

-            value = state.queue[seq]
-            state.subs[key] -= 1
-            return value
+        #     value = state.queue[seq]
+        #     state.subs[key] -= 1
+        #     return value

-        elif seq == -1:
-            # XXX: In the case where the first task to allocate the
-            # ``.recv_ready`` event is cancelled we will be woken with
-            # a non-incremented sequence number and thus will read the
-            # oldest value if we use that. Instead we need to detect if
-            # we have not been incremented and then receive again.
-            return await self.receive()
+        # elif (
+        #     seq == -1
+        # ):

-        else:
-            raise ValueError(f'Invalid sequence {seq}!?')
+        # else:
+        raise RuntimeError(f'Unable to receive {key}:\n{state.statistics()}')

     @asynccontextmanager
     async def subscribe(
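Reviewer note: `Lagged` (the `trio.TooSlowError` subclass from this module's context above) remains the overrun signal, so a slow consumer's loop under this rework might look like the following sketch (the import path is an assumption):

    import trio
    from tractor.trionics import Lagged  # assumed export; defined in this module

    async def slow_consumer(brx) -> None:
        while True:
            try:
                value = await brx.receive()
            except Lagged:
                # overrun: our oldest unread entries were dropped from the
                # internal deque; continue from the newest available value
                continue
            except trio.EndOfChannel:
                break
            print(f'got {value}')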