diff --git a/nooz/343.trivial.rst b/nooz/343.trivial.rst
new file mode 100644
index 0000000..1193f3c
--- /dev/null
+++ b/nooz/343.trivial.rst
@@ -0,0 +1,19 @@
+Rework our ``.trionics.BroadcastReceiver`` internals to avoid method
+recursion and approach a design and interface closer to ``trio``'s
+``MemoryReceiveChannel``.
+
+The details of the internal changes include:
+
+- implementing a ``BroadcastReceiver.receive_nowait()`` and using it
+  within the async ``.receive()``, thus avoiding recursion from
+  ``.receive()``.
+- failing over to an internal ``._receive_from_underlying()`` when the
+  ``_nowait()`` call raises ``trio.WouldBlock``.
+- adding ``BroadcastState.statistics()`` for debugging and testing,
+  both of the internals and by users.
+- adding an internal ``BroadcastReceiver._raise_on_lag: bool`` which
+  can be set to avoid raising ``Lagged`` for use cases where a user
+  wants to opt for a `cheap or nasty pattern
+  <https://zguide.zeromq.org/docs/chapter7/#The-Cheap-or-Nasty-Pattern>`_
+  on the particular stream (we use this in ``piker``'s dark clearing
+  engine to avoid fast feeds breaking during HFT periods).
diff --git a/tests/test_advanced_streaming.py b/tests/test_advanced_streaming.py
index 99414a5..799a089 100644
--- a/tests/test_advanced_streaming.py
+++ b/tests/test_advanced_streaming.py
@@ -14,7 +14,7 @@ def is_win():
     return platform.system() == 'Windows'
 
 
-_registry: dict[str, set[tractor.ReceiveMsgStream]] = {
+_registry: dict[str, set[tractor.MsgStream]] = {
     'even': set(),
     'odd': set(),
 }
diff --git a/tests/test_task_broadcasting.py b/tests/test_task_broadcasting.py
index 1e2f6b4..9f4a1fe 100644
--- a/tests/test_task_broadcasting.py
+++ b/tests/test_task_broadcasting.py
@@ -12,7 +12,10 @@ import pytest
 import trio
 from trio.lowlevel import current_task
 import tractor
-from tractor.trionics import broadcast_receiver, Lagged
+from tractor.trionics import (
+    broadcast_receiver,
+    Lagged,
+)
 
 
 @tractor.context
@@ -37,7 +40,7 @@ async def echo_sequences(
 
 
 async def ensure_sequence(
-    stream: tractor.ReceiveMsgStream,
+    stream: tractor.MsgStream,
     sequence: list,
     delay: Optional[float] = None,
 
@@ -211,7 +214,8 @@ def test_faster_task_to_recv_is_cancelled_by_slower(
     arb_addr,
     start_method,
 ):
-    '''Ensure that if a faster task consuming from a stream is cancelled
+    '''
+    Ensure that if a faster task consuming from a stream is cancelled
     the slower task can continue to receive all expected values.
 
     '''
@@ -460,3 +464,51 @@ def test_first_recver_is_cancelled():
         assert value == 1
 
     trio.run(main)
+
+
+def test_no_raise_on_lag():
+    '''
+    Run a simple 2-task broadcast where one task is slow but is
+    configured with `raise_on_lag=False` so that it does not raise
+    `Lagged` on overruns, and verify that no error is raised.
+
+    '''
+    size = 100
+    tx, rx = trio.open_memory_channel(size)
+    brx = broadcast_receiver(rx, size)
+
+    async def slow():
+        async with brx.subscribe(
+            raise_on_lag=False,
+        ) as br:
+            async for msg in br:
+                print(f'slow task got: {msg}')
+                await trio.sleep(0.1)
+
+    async def fast():
+        async with brx.subscribe() as br:
+            async for msg in br:
+                print(f'fast task got: {msg}')
+
+    async def main():
+        async with (
+            tractor.open_root_actor(
+                # NOTE: so we see the warning msg emitted by the bcaster
+                # internals when the no raise flag is set.
+                loglevel='warning',
+            ),
+            trio.open_nursery() as n,
+        ):
+            n.start_soon(slow)
+            n.start_soon(fast)
+
+            for i in range(1000):
+                await tx.send(i)
+
+            # simulate user nailing ctrl-c after realizing
+            # there's a lag in the slow task.
+            await trio.sleep(1)
+            raise KeyboardInterrupt
+
+    with pytest.raises(KeyboardInterrupt):
+        trio.run(main)
diff --git a/tractor/__init__.py b/tractor/__init__.py
index a691df6..731f3e9 100644
--- a/tractor/__init__.py
+++ b/tractor/__init__.py
@@ -24,7 +24,6 @@ from ._clustering import open_actor_cluster
 from ._ipc import Channel
 from ._streaming import (
     Context,
-    ReceiveMsgStream,
     MsgStream,
     stream,
     context,
@@ -64,7 +63,6 @@ __all__ = [
    'MsgStream',
    'BaseExceptionGroup',
    'Portal',
-    'ReceiveMsgStream',
    'RemoteActorError',
    'breakpoint',
    'context',
diff --git a/tractor/_portal.py b/tractor/_portal.py
index 05504bd..17871aa 100644
--- a/tractor/_portal.py
+++ b/tractor/_portal.py
@@ -45,7 +45,10 @@ from ._exceptions import (
     NoResult,
     ContextCancelled,
 )
-from ._streaming import Context, ReceiveMsgStream
+from ._streaming import (
+    Context,
+    MsgStream,
+)
 
 log = get_logger(__name__)
 
@@ -101,7 +104,7 @@ class Portal:
         # it is expected that ``result()`` will be awaited at some
         # point.
         self._expect_result: Optional[Context] = None
-        self._streams: set[ReceiveMsgStream] = set()
+        self._streams: set[MsgStream] = set()
         self.actor = current_actor()
 
     async def _submit_for_result(
@@ -316,7 +319,7 @@
         async_gen_func: Callable,  # typing: ignore
         **kwargs,
 
-    ) -> AsyncGenerator[ReceiveMsgStream, None]:
+    ) -> AsyncGenerator[MsgStream, None]:
 
         if not inspect.isasyncgenfunction(async_gen_func):
             if not (
@@ -341,7 +344,7 @@
 
         try:
             # deliver receive only stream
-            async with ReceiveMsgStream(
+            async with MsgStream(
                 ctx, ctx._recv_chan,
             ) as rchan:
                 self._streams.add(rchan)
diff --git a/tractor/_streaming.py b/tractor/_streaming.py
index 699a906..b112956 100644
--- a/tractor/_streaming.py
+++ b/tractor/_streaming.py
@@ -50,12 +50,13 @@ log = get_logger(__name__)
 # - use __slots__ on ``Context``?
 
 
-class ReceiveMsgStream(trio.abc.ReceiveChannel):
+class MsgStream(trio.abc.Channel):
     '''
-    A IPC message stream for receiving logically sequenced values over
-    an inter-actor ``Channel``. This is the type returned to a local
-    task which entered either ``Portal.open_stream_from()`` or
-    ``Context.open_stream()``.
+    A bidirectional message stream for receiving logically sequenced
+    values over an inter-actor IPC ``Channel``.
+
+    This is the type returned to a local task which entered either
+    ``Portal.open_stream_from()`` or ``Context.open_stream()``.
 
     Termination rules:
 
@@ -317,15 +318,15 @@
         async with self._broadcaster.subscribe() as bstream:
             assert bstream.key != self._broadcaster.key
             assert bstream._recv == self._broadcaster._recv
+
+            # NOTE: we patch on a `.send()` to the bcaster so that the
+            # caller can still conduct 2-way streaming using this
+            # ``bstream`` handle transparently as though it was the msg
+            # stream instance.
+            bstream.send = self.send  # type: ignore
+
             yield bstream
 
-
-class MsgStream(ReceiveMsgStream, trio.abc.Channel):
-    '''
-    Bidirectional message stream for use within an inter-actor actor
-    ``Context```.
-
-    '''
     async def send(
         self,
         data: Any
diff --git a/tractor/trionics/_broadcast.py b/tractor/trionics/_broadcast.py
index 6c04895..42b1704 100644
--- a/tractor/trionics/_broadcast.py
+++ b/tractor/trionics/_broadcast.py
@@ -23,7 +23,6 @@ from __future__ import annotations
 from abc import abstractmethod
 from collections import deque
 from contextlib import asynccontextmanager
-from dataclasses import dataclass
 from functools import partial
 from operator import ne
 from typing import Optional, Callable, Awaitable, Any, AsyncIterator, Protocol
@@ -33,7 +32,10 @@ import trio
 from trio._core._run import Task
 from trio.abc import ReceiveChannel
 from trio.lowlevel import current_task
+from msgspec import Struct
 
+from tractor.log import get_logger
+log = get_logger(__name__)
 
 # A regular invariant generic type
 T = TypeVar("T")
@@ -86,8 +88,7 @@ class Lagged(trio.TooSlowError):
     '''
 
 
-@dataclass
-class BroadcastState:
+class BroadcastState(Struct):
     '''
     Common state to all receivers of a broadcast.
 
@@ -110,7 +111,35 @@ class BroadcastState:
     eoc: bool = False
 
     # If the broadcaster was cancelled, we might as well track it
-    cancelled: bool = False
+    cancelled: dict[int, Task] = {}
+
+    def statistics(self) -> dict[str, Any]:
+        '''
+        Return broadcast receiver group "statistics" like many of
+        ``trio``'s internal task-sync primitives.
+
+        '''
+        key: int | None
+        ev: trio.Event | None
+
+        subs = self.subs
+        if self.recv_ready is not None:
+            key, ev = self.recv_ready
+        else:
+            key = ev = None
+
+        qlens: dict[int, int] = {}
+        for tid, sz in subs.items():
+            qlens[tid] = sz if sz != -1 else 0
+
+        return {
+            'open_consumers': len(subs),
+            'queued_len_by_task': qlens,
+            'max_buffer_size': self.maxlen,
+            'tasks_waiting': ev.statistics().tasks_waiting if ev else 0,
+            'tasks_cancelled': self.cancelled,
+            'next_value_receiver_id': key,
+        }
 
 
 class BroadcastReceiver(ReceiveChannel):
@@ -128,23 +157,40 @@
         rx_chan: AsyncReceiver,
         state: BroadcastState,
         receive_afunc: Optional[Callable[[], Awaitable[Any]]] = None,
+        raise_on_lag: bool = True,
 
     ) -> None:
 
         # register the original underlying (clone)
         self.key = id(self)
         self._state = state
+
+        # each consumer has an int count which indicates
+        # which index contains the next value that the task has not yet
+        # consumed and thus should read. In the "up-to-date" case the
+        # consumer task must wait for a new value from the underlying
+        # receiver and we use ``-1`` as the sentinel for this state.
         state.subs[self.key] = -1
 
         # underlying for this receiver
         self._rx = rx_chan
         self._recv = receive_afunc or rx_chan.receive
         self._closed: bool = False
+        self._raise_on_lag = raise_on_lag
 
-    async def receive(self) -> ReceiveType:
+    def receive_nowait(
+        self,
+        _key: int | None = None,
+        _state: BroadcastState | None = None,
 
-        key = self.key
-        state = self._state
+    ) -> Any:
+        '''
+        Sync version of `.receive()` which does all the low level work
+        of receiving from the underlying/wrapped receive channel.
+
+        '''
+        key = _key or self.key
+        state = _state or self._state
 
         # TODO: ideally we can make some way to "lock out" the
         # underlying receive channel in some way such that if some task
@@ -177,128 +223,173 @@
                 # return this value."
                 # https://docs.rs/tokio/1.11.0/tokio/sync/broadcast/index.html#lagging
 
+                mxln = state.maxlen
+                lost = seq - mxln
+
                 # decrement to the last value and expect
                 # consumer to either handle the ``Lagged`` and come back
                 # or bail out on its own (thus un-subscribing)
-                state.subs[key] = state.maxlen - 1
+                state.subs[key] = mxln - 1
 
                 # this task was overrun by the producer side
                 task: Task = current_task()
-                raise Lagged(f'Task {task.name} was overrun')
+                msg = f'Task `{task.name}` was overrun and dropped `{lost}` values'
+
+                if self._raise_on_lag:
+                    raise Lagged(msg)
+                else:
+                    log.warning(msg)
+                    return self.receive_nowait(_key, _state)
 
             state.subs[key] -= 1
             return value
 
-        # current task already has the latest value **and** is the
-        # first task to begin waiting for a new one
-        if state.recv_ready is None:
+        raise trio.WouldBlock
 
-            if self._closed:
-                raise trio.ClosedResourceError
+    async def _receive_from_underlying(
+        self,
+        key: int,
+        state: BroadcastState,
 
-            event = trio.Event()
-            state.recv_ready = key, event
+    ) -> ReceiveType:
+        if self._closed:
+            raise trio.ClosedResourceError
+
+        event = trio.Event()
+        assert state.recv_ready is None
+        state.recv_ready = key, event
+
+        try:
             # if we're cancelled here it should be
             # fine to bail without affecting any other consumers
             # right?
-            try:
-                value = await self._recv()
+            value = await self._recv()
 
-                # items with lower indices are "newer"
-                # NOTE: ``collections.deque`` implicitly takes care of
-                # trucating values outside our ``state.maxlen``. In the
-                # alt-backend-array-case we'll need to make sure this is
-                # implemented in similar ringer-buffer-ish style.
-                state.queue.appendleft(value)
+            # items with lower indices are "newer"
+            # NOTE: ``collections.deque`` implicitly takes care of
+            # truncating values outside our ``state.maxlen``. In the
+            # alt-backend-array-case we'll need to make sure this is
+            # implemented in similar ring-buffer-ish style.
+            state.queue.appendleft(value)
 
-                # broadcast new value to all subscribers by increasing
-                # all sequence numbers that will point in the queue to
-                # their latest available value.
+            # broadcast new value to all subscribers by increasing
+            # all sequence numbers that will point in the queue to
+            # their latest available value.
 
-                # don't decrement the sequence for this task since we
-                # already retreived the last value
+            # don't decrement the sequence for this task since we
+            # already retrieved the last value
 
-                # XXX: which of these impls is fastest?
+            # XXX: which of these impls is fastest?
+            # subs = state.subs.copy()
+            # subs.pop(key)
 
-                # subs = state.subs.copy()
-                # subs.pop(key)
-
-                for sub_key in filter(
-                    # lambda k: k != key, state.subs,
-                    partial(ne, key), state.subs,
-                ):
-                    state.subs[sub_key] += 1
-
-                # NOTE: this should ONLY be set if the above task was *NOT*
-                # cancelled on the `._recv()` call.
-                event.set()
-                return value
-
-            except trio.EndOfChannel:
-                # if any one consumer gets an EOC from the underlying
-                # receiver we need to unblock and send that signal to
-                # all other consumers.
-                self._state.eoc = True
-                if event.statistics().tasks_waiting:
-                    event.set()
-                raise
-
-            except (
-                trio.Cancelled,
+            for sub_key in filter(
+                # lambda k: k != key, state.subs,
+                partial(ne, key), state.subs,
             ):
-                # handle cancelled specially otherwise sibling
-                # consumers will be awoken with a sequence of -1
-                # and will potentially try to rewait the underlying
-                # receiver instead of just cancelling immediately.
-                self._state.cancelled = True
-                if event.statistics().tasks_waiting:
-                    event.set()
-                raise
+                state.subs[sub_key] += 1
 
-            finally:
+            # NOTE: this should ONLY be set if the above task was *NOT*
+            # cancelled on the `._recv()` call.
+            event.set()
+            return value
 
-                # Reset receiver waiter task event for next blocking condition.
-                # this MUST be reset even if the above ``.recv()`` call
-                # was cancelled to avoid the next consumer from blocking on
-                # an event that won't be set!
-                state.recv_ready = None
+        except trio.EndOfChannel:
+            # if any one consumer gets an EOC from the underlying
+            # receiver we need to unblock and send that signal to
+            # all other consumers.
+            self._state.eoc = True
+            if event.statistics().tasks_waiting:
+                event.set()
+            raise
+
+        except (
+            trio.Cancelled,
+        ):
+            # handle cancelled specially otherwise sibling
+            # consumers will be awoken with a sequence of -1
+            # and will potentially try to rewait the underlying
+            # receiver instead of just cancelling immediately.
+            self._state.cancelled[key] = current_task()
+            if event.statistics().tasks_waiting:
+                event.set()
+            raise
+
+        finally:
+            # Reset receiver waiter task event for next blocking condition.
+            # this MUST be reset even if the above ``.recv()`` call
+            # was cancelled to avoid the next consumer from blocking on
+            # an event that won't be set!
+            state.recv_ready = None
+
+    async def receive(self) -> ReceiveType:
+        key = self.key
+        state = self._state
+
+        try:
+            return self.receive_nowait(
+                _key=key,
+                _state=state,
+            )
+        except trio.WouldBlock:
+            pass
+
+        # current task already has the latest value **and** is the
+        # first task to begin waiting for a new one so we begin blocking
+        # until rescheduled with a new value from the underlying.
+        if state.recv_ready is None:
+            return await self._receive_from_underlying(key, state)
 
         # This task is all caught up and ready to receive the latest
-        # value, so queue sched it on the internal event.
+        # value, so queue/schedule it to be woken on the next internal
+        # event.
         else:
-            seq = state.subs[key]
-            assert seq == -1  # sanity
-            _, ev = state.recv_ready
-            await ev.wait()
+            while state.recv_ready is not None:
+                # seq = state.subs[key]
+                # assert seq == -1  # sanity
+                _, ev = state.recv_ready
+                await ev.wait()
+                try:
+                    return self.receive_nowait(
+                        _key=key,
+                        _state=state,
+                    )
+                except trio.WouldBlock:
+                    if self._closed:
+                        raise trio.ClosedResourceError
 
-        # NOTE: if we ever would like the behaviour where if the
-        # first task to recv on the underlying is cancelled but it
-        # still DOES trigger the ``.recv_ready``, event we'll likely need
-        # this logic:
+                    subs = state.subs
+                    if (
+                        len(subs) == 1
+                        and key in subs
+                        # or cancelled
+                    ):
+                        # XXX: we are the last and only user of this BR so
+                        # likely it makes sense to unwind back to the
+                        # underlying?
+                        # import tractor
+                        # await tractor.breakpoint()
+                        log.warning(
+                            f'Only one sub left for {self}?\n'
+                            'We can probably unwind from breceiver?'
+                        )
 
-        if seq > -1:
-            # stuff from above..
-            seq = state.subs[key]
+                    # XXX: In the case where the first task to allocate the
+                    # ``.recv_ready`` event is cancelled we will be woken
+                    # with a non-incremented sequence number (the ``-1``
+                    # sentinel) and thus will read the oldest value if we
+                    # use that. Instead we need to detect if we have not
+                    # been incremented and then receive again.
+                    # return await self.receive()
 
-            value = state.queue[seq]
-            state.subs[key] -= 1
-            return value
-
-        elif seq == -1:
-            # XXX: In the case where the first task to allocate the
-            # ``.recv_ready`` event is cancelled we will be woken with
-            # a non-incremented sequence number and thus will read the
-            # oldest value if we use that. Instead we need to detect if
-            # we have not been incremented and then receive again.
-            return await self.receive()
-
-        else:
-            raise ValueError(f'Invalid sequence {seq}!?')
+        return await self._receive_from_underlying(key, state)
 
     @asynccontextmanager
     async def subscribe(
         self,
+        raise_on_lag: bool = True,
+
     ) -> AsyncIterator[BroadcastReceiver]:
         '''
         Subscribe for values from this broadcast receiver.
@@ -316,6 +407,7 @@
             rx_chan=self._rx,
             state=state,
             receive_afunc=self._recv,
+            raise_on_lag=raise_on_lag,
         )
         # assert clone in state.subs
         assert br.key in state.subs
@@ -352,7 +444,8 @@
     recv_chan: AsyncReceiver,
     max_buffer_size: int,
 
-    **kwargs,
+    receive_afunc: Optional[Callable[[], Awaitable[Any]]] = None,
+    raise_on_lag: bool = True,
 
 ) -> BroadcastReceiver:
 
@@ -363,5 +456,6 @@
             maxlen=max_buffer_size,
             subs={},
         ),
-        **kwargs,
+        receive_afunc=receive_afunc,
+        raise_on_lag=raise_on_lag,
     )
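
Example usage of the new ``raise_on_lag`` flag and the ``statistics()``
helper described in the news fragment above. This is a minimal sketch
assuming only the API as modified in this patch; ``._state.statistics()``
is the internal accessor (no public wrapper is added here) and the
buffer sizes, sleep times and task names are illustrative::

    import trio
    import tractor
    from tractor.trionics import broadcast_receiver


    async def main():
        size = 8
        tx, rx = trio.open_memory_channel(size)

        # wrap the underlying receiver; consumers subscribe for
        # per-task clones of this broadcaster.
        brx = broadcast_receiver(rx, size)

        async def slow():
            # opt out of `Lagged`: on overrun this subscription logs a
            # warning and reads the oldest retained value instead of
            # raising.
            async with brx.subscribe(raise_on_lag=False) as br:
                async for msg in br:
                    await trio.sleep(0.1)  # simulate a slow consumer

        async def fast():
            # default behaviour: an overrun would raise `Lagged`, but
            # this task keeps up so it never lags.
            async with brx.subscribe() as br:
                async for msg in br:
                    print(f'fast task got: {msg}')

        async with (
            # so the overrun warning emitted by the broadcaster
            # internals is visible on the console.
            tractor.open_root_actor(loglevel='warning'),
            trio.open_nursery() as n,
        ):
            n.start_soon(slow)
            n.start_soon(fast)

            for i in range(100):
                await tx.send(i)

            # peek at the broadcast group state (internal API).
            print(brx._state.statistics())

            await trio.sleep(1)
            n.cancel_scope.cancel()


    trio.run(main)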