Better `trio`-ize `BroadcastReceiver` internals

Driven by a bug found in `piker` where we'd get an inf recursion error
due to `BroadcastReceiver.receive()` being called when consumer tasks
are awoken but no value is ready to `.nowait_receive()`.

This new rework takes an approach closer to the interface and internals
of `trio.MemoryReceiveChannel` particularly in terms of,

- implementing a `BroadcastReceiver.receive_nowait()` and using it
  within the async `.receive()`.
- failing over to an internal `._receive_from_underlying()` when the
  `_nowait()` call raises `trio.WouldBlock`.
- adding `BroadcastState.statistics()` for debugging and testing
  dropping recursion from `.receive()`.
breceiver_internals
Tyler Goodlet 2022-11-14 16:10:43 -05:00
parent a777217674
commit c2367c1c5e
1 changed files with 248 additions and 85 deletions

View File

@ -23,7 +23,6 @@ from __future__ import annotations
from abc import abstractmethod from abc import abstractmethod
from collections import deque from collections import deque
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from dataclasses import dataclass
from functools import partial from functools import partial
from operator import ne from operator import ne
from typing import Optional, Callable, Awaitable, Any, AsyncIterator, Protocol from typing import Optional, Callable, Awaitable, Any, AsyncIterator, Protocol
@ -33,6 +32,7 @@ import trio
from trio._core._run import Task from trio._core._run import Task
from trio.abc import ReceiveChannel from trio.abc import ReceiveChannel
from trio.lowlevel import current_task from trio.lowlevel import current_task
from msgspec import Struct
# A regular invariant generic type # A regular invariant generic type
@ -86,8 +86,7 @@ class Lagged(trio.TooSlowError):
''' '''
@dataclass class BroadcastState(Struct):
class BroadcastState:
''' '''
Common state to all receivers of a broadcast. Common state to all receivers of a broadcast.
@ -110,7 +109,32 @@ class BroadcastState:
eoc: bool = False eoc: bool = False
# If the broadcaster was cancelled, we might as well track it # If the broadcaster was cancelled, we might as well track it
cancelled: bool = False cancelled: dict[int, Task] = {}
def statistics(self) -> dict[str, str | int | float]:
'''
Return broadcast receiver group "statistics" like many of
``trio``'s internal task-sync primitives.
'''
subs = self.subs
if self.recv_ready is not None:
key, ev = self.recv_ready
else:
key = ev = None
qlens = {}
for tid, sz in subs.items():
qlens[tid] = sz if sz != -1 else 0
return {
'open_consumers': len(subs),
'queued_len_by_task': qlens,
'max_buffer_size': self.maxlen,
'tasks_waiting': ev.statistics().tasks_waiting if ev else 0,
'tasks_cancelled': self.cancelled,
'next_value_receiver_id': key,
}
class BroadcastReceiver(ReceiveChannel): class BroadcastReceiver(ReceiveChannel):
@ -134,6 +158,12 @@ class BroadcastReceiver(ReceiveChannel):
# register the original underlying (clone) # register the original underlying (clone)
self.key = id(self) self.key = id(self)
self._state = state self._state = state
# each consumer has an int count which indicates
# which index contains the next value that the task has not yet
# consumed and thus should read. In the "up-to-date" case the
# consumer task must wait for a new value from the underlying
# receiver and we use ``-1`` as the sentinel for this state.
state.subs[self.key] = -1 state.subs[self.key] = -1
# underlying for this receiver # underlying for this receiver
@ -141,10 +171,14 @@ class BroadcastReceiver(ReceiveChannel):
self._recv = receive_afunc or rx_chan.receive self._recv = receive_afunc or rx_chan.receive
self._closed: bool = False self._closed: bool = False
async def receive(self) -> ReceiveType: def receive_nowait(
self,
_key: int | None = None,
_state: BroadcastState | None = None,
key = self.key ) -> ReceiveType:
state = self._state key = _key or self.key
state = _state or self._state
# TODO: ideally we can make some way to "lock out" the # TODO: ideally we can make some way to "lock out" the
# underlying receive channel in some way such that if some task # underlying receive channel in some way such that if some task
@ -189,112 +223,241 @@ class BroadcastReceiver(ReceiveChannel):
state.subs[key] -= 1 state.subs[key] -= 1
return value return value
# current task already has the latest value **and** is the raise trio.WouldBlock
# first task to begin waiting for a new one
if state.recv_ready is None:
if self._closed: async def _receive_from_underlying(
raise trio.ClosedResourceError self,
key: int,
state: BroadcastState,
event = trio.Event() ) -> ReceiveType:
state.recv_ready = key, event
if self._closed:
raise trio.ClosedResourceError
event = trio.Event()
assert state.recv_ready is None
state.recv_ready = key, event
try:
# if we're cancelled here it should be # if we're cancelled here it should be
# fine to bail without affecting any other consumers # fine to bail without affecting any other consumers
# right? # right?
try: value = await self._recv()
value = await self._recv()
# items with lower indices are "newer" # items with lower indices are "newer"
# NOTE: ``collections.deque`` implicitly takes care of # NOTE: ``collections.deque`` implicitly takes care of
# trucating values outside our ``state.maxlen``. In the # trucating values outside our ``state.maxlen``. In the
# alt-backend-array-case we'll need to make sure this is # alt-backend-array-case we'll need to make sure this is
# implemented in similar ringer-buffer-ish style. # implemented in similar ringer-buffer-ish style.
state.queue.appendleft(value) state.queue.appendleft(value)
# broadcast new value to all subscribers by increasing # broadcast new value to all subscribers by increasing
# all sequence numbers that will point in the queue to # all sequence numbers that will point in the queue to
# their latest available value. # their latest available value.
# don't decrement the sequence for this task since we # don't decrement the sequence for this task since we
# already retreived the last value # already retreived the last value
# XXX: which of these impls is fastest? # XXX: which of these impls is fastest?
# subs = state.subs.copy() # subs = state.subs.copy()
# subs.pop(key) # subs.pop(key)
for sub_key in filter( for sub_key in filter(
# lambda k: k != key, state.subs, # lambda k: k != key, state.subs,
partial(ne, key), state.subs, partial(ne, key), state.subs,
):
state.subs[sub_key] += 1
# NOTE: this should ONLY be set if the above task was *NOT*
# cancelled on the `._recv()` call.
event.set()
return value
except trio.EndOfChannel:
# if any one consumer gets an EOC from the underlying
# receiver we need to unblock and send that signal to
# all other consumers.
self._state.eoc = True
if event.statistics().tasks_waiting:
event.set()
raise
except (
trio.Cancelled,
): ):
# handle cancelled specially otherwise sibling state.subs[sub_key] += 1
# consumers will be awoken with a sequence of -1
# and will potentially try to rewait the underlying
# receiver instead of just cancelling immediately.
self._state.cancelled = True
if event.statistics().tasks_waiting:
event.set()
raise
finally: # NOTE: this should ONLY be set if the above task was *NOT*
# cancelled on the `._recv()` call.
event.set()
return value
# Reset receiver waiter task event for next blocking condition. except trio.EndOfChannel:
# this MUST be reset even if the above ``.recv()`` call # if any one consumer gets an EOC from the underlying
# was cancelled to avoid the next consumer from blocking on # receiver we need to unblock and send that signal to
# an event that won't be set! # all other consumers.
state.recv_ready = None self._state.eoc = True
if event.statistics().tasks_waiting:
event.set()
raise
except (
trio.Cancelled,
):
# handle cancelled specially otherwise sibling
# consumers will be awoken with a sequence of -1
# and will potentially try to rewait the underlying
# receiver instead of just cancelling immediately.
self._state.cancelled[key] = current_task()
if event.statistics().tasks_waiting:
event.set()
raise
finally:
# Reset receiver waiter task event for next blocking condition.
# this MUST be reset even if the above ``.recv()`` call
# was cancelled to avoid the next consumer from blocking on
# an event that won't be set!
state.recv_ready = None
async def receive(self) -> ReceiveType:
key = self.key
state = self._state
try:
return self.receive_nowait(
_key=key,
_state=state,
)
except trio.WouldBlock:
pass
# current task already has the latest value **and** is the
# first task to begin waiting for a new one
if state.recv_ready is None:
return await self._receive_from_underlying(key, state)
# if self._closed:
# raise trio.ClosedResourceError
# event = trio.Event()
# state.recv_ready = key, event
# try:
# # if we're cancelled here it should be
# # fine to bail without affecting any other consumers
# # right?
# value = await self._recv()
# # items with lower indices are "newer"
# # NOTE: ``collections.deque`` implicitly takes care of
# # trucating values outside our ``state.maxlen``. In the
# # alt-backend-array-case we'll need to make sure this is
# # implemented in similar ringer-buffer-ish style.
# state.queue.appendleft(value)
# # broadcast new value to all subscribers by increasing
# # all sequence numbers that will point in the queue to
# # their latest available value.
# # don't decrement the sequence for this task since we
# # already retreived the last value
# # XXX: which of these impls is fastest?
# # subs = state.subs.copy()
# # subs.pop(key)
# for sub_key in filter(
# # lambda k: k != key, state.subs,
# partial(ne, key), state.subs,
# ):
# state.subs[sub_key] += 1
# # NOTE: this should ONLY be set if the above task was *NOT*
# # cancelled on the `._recv()` call.
# event.set()
# return value
# except trio.EndOfChannel:
# # if any one consumer gets an EOC from the underlying
# # receiver we need to unblock and send that signal to
# # all other consumers.
# self._state.eoc = True
# if event.statistics().tasks_waiting:
# event.set()
# raise
# except (
# trio.Cancelled,
# ):
# # handle cancelled specially otherwise sibling
# # consumers will be awoken with a sequence of -1
# # and will potentially try to rewait the underlying
# # receiver instead of just cancelling immediately.
# self._state.cancelled[key] = current_task()
# if event.statistics().tasks_waiting:
# event.set()
# raise
# finally:
# # Reset receiver waiter task event for next blocking condition.
# # this MUST be reset even if the above ``.recv()`` call
# # was cancelled to avoid the next consumer from blocking on
# # an event that won't be set!
# state.recv_ready = None
# This task is all caught up and ready to receive the latest # This task is all caught up and ready to receive the latest
# value, so queue sched it on the internal event. # value, so queue sched it on the internal event.
else: else:
seq = state.subs[key] while state.recv_ready is not None:
assert seq == -1 # sanity # seq = state.subs[key]
_, ev = state.recv_ready # assert seq == -1 # sanity
await ev.wait() _, ev = state.recv_ready
await ev.wait()
try:
return self.receive_nowait(
_key=key,
_state=state,
)
except trio.WouldBlock:
if (
self._closed
):
raise trio.ClosedResourceError
subs = state.subs
if (
len(subs) == 1
and key in subs
# or cancelled
):
# XXX: we are the last and only user of this BR so
# likely it makes sense to unwind back to the
# underlying?
import tractor
await tractor.breakpoint()
# XXX: In the case where the first task to allocate the
# ``.recv_ready`` event is cancelled we will be woken
# with a non-incremented sequence number (the ``-1``
# sentinel) and thus will read the oldest value if we
# use that. Instead we need to detect if we have not
# been incremented and then receive again.
# return await self.receive()
# if state.recv_ready is None:
print(f'{key}: {state.statistics()}')
return await self._receive_from_underlying(key, state)
# seq = state.subs[key]
# NOTE: if we ever would like the behaviour where if the # NOTE: if we ever would like the behaviour where if the
# first task to recv on the underlying is cancelled but it # first task to recv on the underlying is cancelled but it
# still DOES trigger the ``.recv_ready``, event we'll likely need # still DOES trigger the ``.recv_ready``, event we'll likely need
# this logic: # this logic:
if seq > -1: # if seq > -1:
# stuff from above.. # # stuff from above..
seq = state.subs[key] # seq = state.subs[key]
value = state.queue[seq] # value = state.queue[seq]
state.subs[key] -= 1 # state.subs[key] -= 1
return value # return value
elif seq == -1: # elif (
# XXX: In the case where the first task to allocate the # seq == -1
# ``.recv_ready`` event is cancelled we will be woken with # ):
# a non-incremented sequence number and thus will read the
# oldest value if we use that. Instead we need to detect if
# we have not been incremented and then receive again.
return await self.receive()
else: # else:
raise ValueError(f'Invalid sequence {seq}!?') raise RuntimeError(f'Unable to receive {key}:\n{state.statistics()}')
@asynccontextmanager @asynccontextmanager
async def subscribe( async def subscribe(