forked from goodboy/tractor
Better `trio`-ize `BroadcastReceiver` internals
Driven by a bug found in `piker` where we'd get an inf recursion error due to `BroadcastReceiver.receive()` being called when consumer tasks are awoken but no value is ready to `.nowait_receive()`. This new rework takes an approach closer to the interface and internals of `trio.MemoryReceiveChannel` particularly in terms of, - implementing a `BroadcastReceiver.receive_nowait()` and using it within the async `.receive()`. - failing over to an internal `._receive_from_underlying()` when the `_nowait()` call raises `trio.WouldBlock`. - adding `BroadcastState.statistics()` for debugging and testing dropping recursion from `.receive()`.breceiver_internals
parent
a777217674
commit
c2367c1c5e
|
@ -23,7 +23,6 @@ from __future__ import annotations
|
||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
from collections import deque
|
from collections import deque
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
from dataclasses import dataclass
|
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from operator import ne
|
from operator import ne
|
||||||
from typing import Optional, Callable, Awaitable, Any, AsyncIterator, Protocol
|
from typing import Optional, Callable, Awaitable, Any, AsyncIterator, Protocol
|
||||||
|
@ -33,6 +32,7 @@ import trio
|
||||||
from trio._core._run import Task
|
from trio._core._run import Task
|
||||||
from trio.abc import ReceiveChannel
|
from trio.abc import ReceiveChannel
|
||||||
from trio.lowlevel import current_task
|
from trio.lowlevel import current_task
|
||||||
|
from msgspec import Struct
|
||||||
|
|
||||||
|
|
||||||
# A regular invariant generic type
|
# A regular invariant generic type
|
||||||
|
@ -86,8 +86,7 @@ class Lagged(trio.TooSlowError):
|
||||||
'''
|
'''
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
class BroadcastState(Struct):
|
||||||
class BroadcastState:
|
|
||||||
'''
|
'''
|
||||||
Common state to all receivers of a broadcast.
|
Common state to all receivers of a broadcast.
|
||||||
|
|
||||||
|
@ -110,7 +109,32 @@ class BroadcastState:
|
||||||
eoc: bool = False
|
eoc: bool = False
|
||||||
|
|
||||||
# If the broadcaster was cancelled, we might as well track it
|
# If the broadcaster was cancelled, we might as well track it
|
||||||
cancelled: bool = False
|
cancelled: dict[int, Task] = {}
|
||||||
|
|
||||||
|
def statistics(self) -> dict[str, str | int | float]:
|
||||||
|
'''
|
||||||
|
Return broadcast receiver group "statistics" like many of
|
||||||
|
``trio``'s internal task-sync primitives.
|
||||||
|
|
||||||
|
'''
|
||||||
|
subs = self.subs
|
||||||
|
if self.recv_ready is not None:
|
||||||
|
key, ev = self.recv_ready
|
||||||
|
else:
|
||||||
|
key = ev = None
|
||||||
|
|
||||||
|
qlens = {}
|
||||||
|
for tid, sz in subs.items():
|
||||||
|
qlens[tid] = sz if sz != -1 else 0
|
||||||
|
|
||||||
|
return {
|
||||||
|
'open_consumers': len(subs),
|
||||||
|
'queued_len_by_task': qlens,
|
||||||
|
'max_buffer_size': self.maxlen,
|
||||||
|
'tasks_waiting': ev.statistics().tasks_waiting if ev else 0,
|
||||||
|
'tasks_cancelled': self.cancelled,
|
||||||
|
'next_value_receiver_id': key,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
class BroadcastReceiver(ReceiveChannel):
|
class BroadcastReceiver(ReceiveChannel):
|
||||||
|
@ -134,6 +158,12 @@ class BroadcastReceiver(ReceiveChannel):
|
||||||
# register the original underlying (clone)
|
# register the original underlying (clone)
|
||||||
self.key = id(self)
|
self.key = id(self)
|
||||||
self._state = state
|
self._state = state
|
||||||
|
|
||||||
|
# each consumer has an int count which indicates
|
||||||
|
# which index contains the next value that the task has not yet
|
||||||
|
# consumed and thus should read. In the "up-to-date" case the
|
||||||
|
# consumer task must wait for a new value from the underlying
|
||||||
|
# receiver and we use ``-1`` as the sentinel for this state.
|
||||||
state.subs[self.key] = -1
|
state.subs[self.key] = -1
|
||||||
|
|
||||||
# underlying for this receiver
|
# underlying for this receiver
|
||||||
|
@ -141,10 +171,14 @@ class BroadcastReceiver(ReceiveChannel):
|
||||||
self._recv = receive_afunc or rx_chan.receive
|
self._recv = receive_afunc or rx_chan.receive
|
||||||
self._closed: bool = False
|
self._closed: bool = False
|
||||||
|
|
||||||
async def receive(self) -> ReceiveType:
|
def receive_nowait(
|
||||||
|
self,
|
||||||
|
_key: int | None = None,
|
||||||
|
_state: BroadcastState | None = None,
|
||||||
|
|
||||||
key = self.key
|
) -> ReceiveType:
|
||||||
state = self._state
|
key = _key or self.key
|
||||||
|
state = _state or self._state
|
||||||
|
|
||||||
# TODO: ideally we can make some way to "lock out" the
|
# TODO: ideally we can make some way to "lock out" the
|
||||||
# underlying receive channel in some way such that if some task
|
# underlying receive channel in some way such that if some task
|
||||||
|
@ -189,112 +223,241 @@ class BroadcastReceiver(ReceiveChannel):
|
||||||
state.subs[key] -= 1
|
state.subs[key] -= 1
|
||||||
return value
|
return value
|
||||||
|
|
||||||
# current task already has the latest value **and** is the
|
raise trio.WouldBlock
|
||||||
# first task to begin waiting for a new one
|
|
||||||
if state.recv_ready is None:
|
|
||||||
|
|
||||||
if self._closed:
|
async def _receive_from_underlying(
|
||||||
raise trio.ClosedResourceError
|
self,
|
||||||
|
key: int,
|
||||||
|
state: BroadcastState,
|
||||||
|
|
||||||
event = trio.Event()
|
) -> ReceiveType:
|
||||||
state.recv_ready = key, event
|
|
||||||
|
|
||||||
|
if self._closed:
|
||||||
|
raise trio.ClosedResourceError
|
||||||
|
|
||||||
|
event = trio.Event()
|
||||||
|
assert state.recv_ready is None
|
||||||
|
state.recv_ready = key, event
|
||||||
|
|
||||||
|
try:
|
||||||
# if we're cancelled here it should be
|
# if we're cancelled here it should be
|
||||||
# fine to bail without affecting any other consumers
|
# fine to bail without affecting any other consumers
|
||||||
# right?
|
# right?
|
||||||
try:
|
value = await self._recv()
|
||||||
value = await self._recv()
|
|
||||||
|
|
||||||
# items with lower indices are "newer"
|
# items with lower indices are "newer"
|
||||||
# NOTE: ``collections.deque`` implicitly takes care of
|
# NOTE: ``collections.deque`` implicitly takes care of
|
||||||
# trucating values outside our ``state.maxlen``. In the
|
# trucating values outside our ``state.maxlen``. In the
|
||||||
# alt-backend-array-case we'll need to make sure this is
|
# alt-backend-array-case we'll need to make sure this is
|
||||||
# implemented in similar ringer-buffer-ish style.
|
# implemented in similar ringer-buffer-ish style.
|
||||||
state.queue.appendleft(value)
|
state.queue.appendleft(value)
|
||||||
|
|
||||||
# broadcast new value to all subscribers by increasing
|
# broadcast new value to all subscribers by increasing
|
||||||
# all sequence numbers that will point in the queue to
|
# all sequence numbers that will point in the queue to
|
||||||
# their latest available value.
|
# their latest available value.
|
||||||
|
|
||||||
# don't decrement the sequence for this task since we
|
# don't decrement the sequence for this task since we
|
||||||
# already retreived the last value
|
# already retreived the last value
|
||||||
|
|
||||||
# XXX: which of these impls is fastest?
|
# XXX: which of these impls is fastest?
|
||||||
|
|
||||||
# subs = state.subs.copy()
|
# subs = state.subs.copy()
|
||||||
# subs.pop(key)
|
# subs.pop(key)
|
||||||
|
|
||||||
for sub_key in filter(
|
for sub_key in filter(
|
||||||
# lambda k: k != key, state.subs,
|
# lambda k: k != key, state.subs,
|
||||||
partial(ne, key), state.subs,
|
partial(ne, key), state.subs,
|
||||||
):
|
|
||||||
state.subs[sub_key] += 1
|
|
||||||
|
|
||||||
# NOTE: this should ONLY be set if the above task was *NOT*
|
|
||||||
# cancelled on the `._recv()` call.
|
|
||||||
event.set()
|
|
||||||
return value
|
|
||||||
|
|
||||||
except trio.EndOfChannel:
|
|
||||||
# if any one consumer gets an EOC from the underlying
|
|
||||||
# receiver we need to unblock and send that signal to
|
|
||||||
# all other consumers.
|
|
||||||
self._state.eoc = True
|
|
||||||
if event.statistics().tasks_waiting:
|
|
||||||
event.set()
|
|
||||||
raise
|
|
||||||
|
|
||||||
except (
|
|
||||||
trio.Cancelled,
|
|
||||||
):
|
):
|
||||||
# handle cancelled specially otherwise sibling
|
state.subs[sub_key] += 1
|
||||||
# consumers will be awoken with a sequence of -1
|
|
||||||
# and will potentially try to rewait the underlying
|
|
||||||
# receiver instead of just cancelling immediately.
|
|
||||||
self._state.cancelled = True
|
|
||||||
if event.statistics().tasks_waiting:
|
|
||||||
event.set()
|
|
||||||
raise
|
|
||||||
|
|
||||||
finally:
|
# NOTE: this should ONLY be set if the above task was *NOT*
|
||||||
|
# cancelled on the `._recv()` call.
|
||||||
|
event.set()
|
||||||
|
return value
|
||||||
|
|
||||||
# Reset receiver waiter task event for next blocking condition.
|
except trio.EndOfChannel:
|
||||||
# this MUST be reset even if the above ``.recv()`` call
|
# if any one consumer gets an EOC from the underlying
|
||||||
# was cancelled to avoid the next consumer from blocking on
|
# receiver we need to unblock and send that signal to
|
||||||
# an event that won't be set!
|
# all other consumers.
|
||||||
state.recv_ready = None
|
self._state.eoc = True
|
||||||
|
if event.statistics().tasks_waiting:
|
||||||
|
event.set()
|
||||||
|
raise
|
||||||
|
|
||||||
|
except (
|
||||||
|
trio.Cancelled,
|
||||||
|
):
|
||||||
|
# handle cancelled specially otherwise sibling
|
||||||
|
# consumers will be awoken with a sequence of -1
|
||||||
|
# and will potentially try to rewait the underlying
|
||||||
|
# receiver instead of just cancelling immediately.
|
||||||
|
self._state.cancelled[key] = current_task()
|
||||||
|
if event.statistics().tasks_waiting:
|
||||||
|
event.set()
|
||||||
|
raise
|
||||||
|
|
||||||
|
finally:
|
||||||
|
|
||||||
|
# Reset receiver waiter task event for next blocking condition.
|
||||||
|
# this MUST be reset even if the above ``.recv()`` call
|
||||||
|
# was cancelled to avoid the next consumer from blocking on
|
||||||
|
# an event that won't be set!
|
||||||
|
state.recv_ready = None
|
||||||
|
|
||||||
|
async def receive(self) -> ReceiveType:
|
||||||
|
key = self.key
|
||||||
|
state = self._state
|
||||||
|
|
||||||
|
try:
|
||||||
|
return self.receive_nowait(
|
||||||
|
_key=key,
|
||||||
|
_state=state,
|
||||||
|
)
|
||||||
|
except trio.WouldBlock:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# current task already has the latest value **and** is the
|
||||||
|
# first task to begin waiting for a new one
|
||||||
|
if state.recv_ready is None:
|
||||||
|
return await self._receive_from_underlying(key, state)
|
||||||
|
|
||||||
|
# if self._closed:
|
||||||
|
# raise trio.ClosedResourceError
|
||||||
|
|
||||||
|
# event = trio.Event()
|
||||||
|
# state.recv_ready = key, event
|
||||||
|
|
||||||
|
# try:
|
||||||
|
# # if we're cancelled here it should be
|
||||||
|
# # fine to bail without affecting any other consumers
|
||||||
|
# # right?
|
||||||
|
# value = await self._recv()
|
||||||
|
|
||||||
|
# # items with lower indices are "newer"
|
||||||
|
# # NOTE: ``collections.deque`` implicitly takes care of
|
||||||
|
# # trucating values outside our ``state.maxlen``. In the
|
||||||
|
# # alt-backend-array-case we'll need to make sure this is
|
||||||
|
# # implemented in similar ringer-buffer-ish style.
|
||||||
|
# state.queue.appendleft(value)
|
||||||
|
|
||||||
|
# # broadcast new value to all subscribers by increasing
|
||||||
|
# # all sequence numbers that will point in the queue to
|
||||||
|
# # their latest available value.
|
||||||
|
|
||||||
|
# # don't decrement the sequence for this task since we
|
||||||
|
# # already retreived the last value
|
||||||
|
|
||||||
|
# # XXX: which of these impls is fastest?
|
||||||
|
|
||||||
|
# # subs = state.subs.copy()
|
||||||
|
# # subs.pop(key)
|
||||||
|
|
||||||
|
# for sub_key in filter(
|
||||||
|
# # lambda k: k != key, state.subs,
|
||||||
|
# partial(ne, key), state.subs,
|
||||||
|
# ):
|
||||||
|
# state.subs[sub_key] += 1
|
||||||
|
|
||||||
|
# # NOTE: this should ONLY be set if the above task was *NOT*
|
||||||
|
# # cancelled on the `._recv()` call.
|
||||||
|
# event.set()
|
||||||
|
# return value
|
||||||
|
|
||||||
|
# except trio.EndOfChannel:
|
||||||
|
# # if any one consumer gets an EOC from the underlying
|
||||||
|
# # receiver we need to unblock and send that signal to
|
||||||
|
# # all other consumers.
|
||||||
|
# self._state.eoc = True
|
||||||
|
# if event.statistics().tasks_waiting:
|
||||||
|
# event.set()
|
||||||
|
# raise
|
||||||
|
|
||||||
|
# except (
|
||||||
|
# trio.Cancelled,
|
||||||
|
# ):
|
||||||
|
# # handle cancelled specially otherwise sibling
|
||||||
|
# # consumers will be awoken with a sequence of -1
|
||||||
|
# # and will potentially try to rewait the underlying
|
||||||
|
# # receiver instead of just cancelling immediately.
|
||||||
|
# self._state.cancelled[key] = current_task()
|
||||||
|
# if event.statistics().tasks_waiting:
|
||||||
|
# event.set()
|
||||||
|
# raise
|
||||||
|
|
||||||
|
# finally:
|
||||||
|
|
||||||
|
# # Reset receiver waiter task event for next blocking condition.
|
||||||
|
# # this MUST be reset even if the above ``.recv()`` call
|
||||||
|
# # was cancelled to avoid the next consumer from blocking on
|
||||||
|
# # an event that won't be set!
|
||||||
|
# state.recv_ready = None
|
||||||
|
|
||||||
# This task is all caught up and ready to receive the latest
|
# This task is all caught up and ready to receive the latest
|
||||||
# value, so queue sched it on the internal event.
|
# value, so queue sched it on the internal event.
|
||||||
else:
|
else:
|
||||||
seq = state.subs[key]
|
while state.recv_ready is not None:
|
||||||
assert seq == -1 # sanity
|
# seq = state.subs[key]
|
||||||
_, ev = state.recv_ready
|
# assert seq == -1 # sanity
|
||||||
await ev.wait()
|
_, ev = state.recv_ready
|
||||||
|
await ev.wait()
|
||||||
|
try:
|
||||||
|
return self.receive_nowait(
|
||||||
|
_key=key,
|
||||||
|
_state=state,
|
||||||
|
)
|
||||||
|
except trio.WouldBlock:
|
||||||
|
if (
|
||||||
|
self._closed
|
||||||
|
):
|
||||||
|
raise trio.ClosedResourceError
|
||||||
|
|
||||||
|
subs = state.subs
|
||||||
|
if (
|
||||||
|
len(subs) == 1
|
||||||
|
and key in subs
|
||||||
|
# or cancelled
|
||||||
|
):
|
||||||
|
# XXX: we are the last and only user of this BR so
|
||||||
|
# likely it makes sense to unwind back to the
|
||||||
|
# underlying?
|
||||||
|
import tractor
|
||||||
|
await tractor.breakpoint()
|
||||||
|
|
||||||
|
|
||||||
|
# XXX: In the case where the first task to allocate the
|
||||||
|
# ``.recv_ready`` event is cancelled we will be woken
|
||||||
|
# with a non-incremented sequence number (the ``-1``
|
||||||
|
# sentinel) and thus will read the oldest value if we
|
||||||
|
# use that. Instead we need to detect if we have not
|
||||||
|
# been incremented and then receive again.
|
||||||
|
# return await self.receive()
|
||||||
|
|
||||||
|
# if state.recv_ready is None:
|
||||||
|
|
||||||
|
print(f'{key}: {state.statistics()}')
|
||||||
|
return await self._receive_from_underlying(key, state)
|
||||||
|
|
||||||
|
# seq = state.subs[key]
|
||||||
|
|
||||||
# NOTE: if we ever would like the behaviour where if the
|
# NOTE: if we ever would like the behaviour where if the
|
||||||
# first task to recv on the underlying is cancelled but it
|
# first task to recv on the underlying is cancelled but it
|
||||||
# still DOES trigger the ``.recv_ready``, event we'll likely need
|
# still DOES trigger the ``.recv_ready``, event we'll likely need
|
||||||
# this logic:
|
# this logic:
|
||||||
|
|
||||||
if seq > -1:
|
# if seq > -1:
|
||||||
# stuff from above..
|
# # stuff from above..
|
||||||
seq = state.subs[key]
|
# seq = state.subs[key]
|
||||||
|
|
||||||
value = state.queue[seq]
|
# value = state.queue[seq]
|
||||||
state.subs[key] -= 1
|
# state.subs[key] -= 1
|
||||||
return value
|
# return value
|
||||||
|
|
||||||
elif seq == -1:
|
# elif (
|
||||||
# XXX: In the case where the first task to allocate the
|
# seq == -1
|
||||||
# ``.recv_ready`` event is cancelled we will be woken with
|
# ):
|
||||||
# a non-incremented sequence number and thus will read the
|
|
||||||
# oldest value if we use that. Instead we need to detect if
|
|
||||||
# we have not been incremented and then receive again.
|
|
||||||
return await self.receive()
|
|
||||||
|
|
||||||
else:
|
# else:
|
||||||
raise ValueError(f'Invalid sequence {seq}!?')
|
raise RuntimeError(f'Unable to receive {key}:\n{state.statistics()}')
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def subscribe(
|
async def subscribe(
|
||||||
|
|
Loading…
Reference in New Issue