commit
						3f1bc37143
					
				|  | @ -0,0 +1,12 @@ | ||||||
|  | Add `tokio-style broadcast channels | ||||||
|  | <https://docs.rs/tokio/1.11.0/tokio/sync/broadcast/index.html>`_ as | ||||||
|  | a solution for `#204 <https://github.com/goodboy/tractor/pull/204>`_ and | ||||||
|  | discussed thoroughly in `trio/#987 | ||||||
|  | <https://github.com/python-trio/trio/issues/987>`_. | ||||||
|  | 
 | ||||||
|  | This gives us local task broadcast functionality using a new | ||||||
|  | ``BroadcastReceiver`` type which can wrap ``trio.ReceiveChannel`` and | ||||||
|  | provide fan-out copies of a stream of data to every subscribed consumer. | ||||||
|  | We use this new machinery to provide a ``ReceiveMsgStream.subscribe()`` | ||||||
|  | async context manager which can be used by actor-local consumer tasks | ||||||
|  | to easily pull from a shared and dynamic IPC stream. | ||||||
|  | @ -0,0 +1,459 @@ | ||||||
|  | """ | ||||||
|  | Broadcast channels for fan-out to local tasks. | ||||||
|  | """ | ||||||
|  | from contextlib import asynccontextmanager | ||||||
|  | from functools import partial | ||||||
|  | from itertools import cycle | ||||||
|  | import time | ||||||
|  | from typing import Optional, List, Tuple | ||||||
|  | 
 | ||||||
|  | import pytest | ||||||
|  | import trio | ||||||
|  | from trio.lowlevel import current_task | ||||||
|  | import tractor | ||||||
|  | from tractor._broadcast import broadcast_receiver, Lagged | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | @tractor.context | ||||||
|  | async def echo_sequences( | ||||||
|  | 
 | ||||||
|  |     ctx:  tractor.Context, | ||||||
|  | 
 | ||||||
|  | ) -> None: | ||||||
|  |     '''Bidir streaming endpoint which will stream | ||||||
|  |     back any sequence it is sent item-wise. | ||||||
|  | 
 | ||||||
|  |     ''' | ||||||
|  |     await ctx.started() | ||||||
|  | 
 | ||||||
|  |     async with ctx.open_stream() as stream: | ||||||
|  |         async for sequence in stream: | ||||||
|  |             seq = list(sequence) | ||||||
|  |             for value in seq: | ||||||
|  |                 await stream.send(value) | ||||||
|  |                 print(f'producer sent {value}') | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | async def ensure_sequence( | ||||||
|  | 
 | ||||||
|  |     stream: tractor.ReceiveMsgStream, | ||||||
|  |     sequence: list, | ||||||
|  |     delay: Optional[float] = None, | ||||||
|  | 
 | ||||||
|  | ) -> None: | ||||||
|  | 
 | ||||||
|  |     name = current_task().name | ||||||
|  |     async with stream.subscribe() as bcaster: | ||||||
|  |         assert not isinstance(bcaster, type(stream)) | ||||||
|  |         async for value in bcaster: | ||||||
|  |             print(f'{name} rx: {value}') | ||||||
|  |             assert value == sequence[0] | ||||||
|  |             sequence.remove(value) | ||||||
|  | 
 | ||||||
|  |             if delay: | ||||||
|  |                 await trio.sleep(delay) | ||||||
|  | 
 | ||||||
|  |             if not sequence: | ||||||
|  |                 # fully consumed | ||||||
|  |                 break | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | @asynccontextmanager | ||||||
|  | async def open_sequence_streamer( | ||||||
|  | 
 | ||||||
|  |     sequence: List[int], | ||||||
|  |     arb_addr: Tuple[str, int], | ||||||
|  |     start_method: str, | ||||||
|  | 
 | ||||||
|  | ) -> tractor.MsgStream: | ||||||
|  | 
 | ||||||
|  |     async with tractor.open_nursery( | ||||||
|  |         arbiter_addr=arb_addr, | ||||||
|  |         start_method=start_method, | ||||||
|  |     ) as tn: | ||||||
|  | 
 | ||||||
|  |         portal = await tn.start_actor( | ||||||
|  |             'sequence_echoer', | ||||||
|  |             enable_modules=[__name__], | ||||||
|  |         ) | ||||||
|  | 
 | ||||||
|  |         async with portal.open_context( | ||||||
|  |             echo_sequences, | ||||||
|  |         ) as (ctx, first): | ||||||
|  | 
 | ||||||
|  |             assert first is None | ||||||
|  |             async with ctx.open_stream() as stream: | ||||||
|  |                 yield stream | ||||||
|  | 
 | ||||||
|  |         await portal.cancel_actor() | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def test_stream_fan_out_to_local_subscriptions( | ||||||
|  |     arb_addr, | ||||||
|  |     start_method, | ||||||
|  | ): | ||||||
|  | 
 | ||||||
|  |     sequence = list(range(1000)) | ||||||
|  | 
 | ||||||
|  |     async def main(): | ||||||
|  | 
 | ||||||
|  |         async with open_sequence_streamer( | ||||||
|  |             sequence, | ||||||
|  |             arb_addr, | ||||||
|  |             start_method, | ||||||
|  |         ) as stream: | ||||||
|  | 
 | ||||||
|  |             async with trio.open_nursery() as n: | ||||||
|  |                 for i in range(10): | ||||||
|  |                     n.start_soon( | ||||||
|  |                         ensure_sequence, | ||||||
|  |                         stream, | ||||||
|  |                         sequence.copy(), | ||||||
|  |                         name=f'consumer_{i}', | ||||||
|  |                     ) | ||||||
|  | 
 | ||||||
|  |                 await stream.send(tuple(sequence)) | ||||||
|  | 
 | ||||||
|  |                 async for value in stream: | ||||||
|  |                     print(f'source stream rx: {value}') | ||||||
|  |                     assert value == sequence[0] | ||||||
|  |                     sequence.remove(value) | ||||||
|  | 
 | ||||||
|  |                     if not sequence: | ||||||
|  |                         # fully consumed | ||||||
|  |                         break | ||||||
|  | 
 | ||||||
|  |     trio.run(main) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | @pytest.mark.parametrize( | ||||||
|  |     'task_delays', | ||||||
|  |     [ | ||||||
|  |         (0.01, 0.001), | ||||||
|  |         (0.001, 0.01), | ||||||
|  |     ] | ||||||
|  | ) | ||||||
|  | def test_consumer_and_parent_maybe_lag( | ||||||
|  |     arb_addr, | ||||||
|  |     start_method, | ||||||
|  |     task_delays, | ||||||
|  | ): | ||||||
|  | 
 | ||||||
|  |     async def main(): | ||||||
|  | 
 | ||||||
|  |         sequence = list(range(300)) | ||||||
|  |         parent_delay, sub_delay = task_delays | ||||||
|  | 
 | ||||||
|  |         async with open_sequence_streamer( | ||||||
|  |             sequence, | ||||||
|  |             arb_addr, | ||||||
|  |             start_method, | ||||||
|  |         ) as stream: | ||||||
|  | 
 | ||||||
|  |             try: | ||||||
|  |                 async with trio.open_nursery() as n: | ||||||
|  | 
 | ||||||
|  |                     n.start_soon( | ||||||
|  |                         ensure_sequence, | ||||||
|  |                         stream, | ||||||
|  |                         sequence.copy(), | ||||||
|  |                         sub_delay, | ||||||
|  |                         name='consumer_task', | ||||||
|  |                     ) | ||||||
|  | 
 | ||||||
|  |                     await stream.send(tuple(sequence)) | ||||||
|  | 
 | ||||||
|  |                     # async for value in stream: | ||||||
|  |                     lagged = False | ||||||
|  |                     lag_count = 0 | ||||||
|  | 
 | ||||||
|  |                     while True: | ||||||
|  |                         try: | ||||||
|  |                             value = await stream.receive() | ||||||
|  |                             print(f'source stream rx: {value}') | ||||||
|  | 
 | ||||||
|  |                             if lagged: | ||||||
|  |                                 # reset the expected sequence to start just after | ||||||
|  |                                 # the value we just received | ||||||
|  |                                 sequence = sequence[sequence.index(value) + 1:] | ||||||
|  |                             else: | ||||||
|  |                                 assert value == sequence[0] | ||||||
|  |                                 sequence.remove(value) | ||||||
|  | 
 | ||||||
|  |                             lagged = False | ||||||
|  | 
 | ||||||
|  |                         except Lagged: | ||||||
|  |                             lagged = True | ||||||
|  |                             print(f'source stream lagged after {value}') | ||||||
|  |                             lag_count += 1 | ||||||
|  |                             continue | ||||||
|  | 
 | ||||||
|  |                         # lag the parent | ||||||
|  |                         await trio.sleep(parent_delay) | ||||||
|  | 
 | ||||||
|  |                         if not sequence: | ||||||
|  |                             # fully consumed | ||||||
|  |                             break | ||||||
|  |                     print(f'parent + source stream lagged: {lag_count}') | ||||||
|  | 
 | ||||||
|  |                     if parent_delay > sub_delay: | ||||||
|  |                         assert lag_count > 0 | ||||||
|  | 
 | ||||||
|  |             except Lagged: | ||||||
|  |                 # child was lagged | ||||||
|  |                 assert parent_delay < sub_delay | ||||||
|  | 
 | ||||||
|  |     trio.run(main) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def test_faster_task_to_recv_is_cancelled_by_slower( | ||||||
|  |     arb_addr, | ||||||
|  |     start_method, | ||||||
|  | ): | ||||||
|  |     '''Ensure that if a faster task consuming from a stream is cancelled | ||||||
|  |     the slower task can continue to receive all expected values. | ||||||
|  | 
 | ||||||
|  |     ''' | ||||||
|  |     async def main(): | ||||||
|  | 
 | ||||||
|  |         sequence = list(range(1000)) | ||||||
|  | 
 | ||||||
|  |         async with open_sequence_streamer( | ||||||
|  |             sequence, | ||||||
|  |             arb_addr, | ||||||
|  |             start_method, | ||||||
|  | 
 | ||||||
|  |         ) as stream: | ||||||
|  | 
 | ||||||
|  |             async with trio.open_nursery() as n: | ||||||
|  |                 n.start_soon( | ||||||
|  |                     ensure_sequence, | ||||||
|  |                     stream, | ||||||
|  |                     sequence.copy(), | ||||||
|  |                     0, | ||||||
|  |                     name='consumer_task', | ||||||
|  |                 ) | ||||||
|  | 
 | ||||||
|  |                 await stream.send(tuple(sequence)) | ||||||
|  | 
 | ||||||
|  |                 # attempt 20 receives, cancel the subtask, then | ||||||
|  |                 # expect to be able to pull all values still | ||||||
|  |                 for i in range(20): | ||||||
|  |                     try: | ||||||
|  |                         value = await stream.receive() | ||||||
|  |                         print(f'source stream rx: {value}') | ||||||
|  |                         await trio.sleep(0.01) | ||||||
|  |                     except Lagged: | ||||||
|  |                         print(f'parent overrun after {value}') | ||||||
|  |                         continue | ||||||
|  | 
 | ||||||
|  |                 print('cancelling faster subtask') | ||||||
|  |                 n.cancel_scope.cancel() | ||||||
|  | 
 | ||||||
|  |             try: | ||||||
|  |                 value = await stream.receive() | ||||||
|  |                 print(f'source stream after cancel: {value}') | ||||||
|  |             except Lagged: | ||||||
|  |                 print(f'parent overrun after {value}') | ||||||
|  | 
 | ||||||
|  |             # expect to see all remaining values | ||||||
|  |             with trio.fail_after(0.5): | ||||||
|  |                 async for value in stream: | ||||||
|  |                     assert stream._broadcaster._state.recv_ready is None | ||||||
|  |                     print(f'source stream rx: {value}') | ||||||
|  |                     if value == 999: | ||||||
|  |                         # fully consumed and we missed no values once | ||||||
|  |                         # the faster subtask was cancelled | ||||||
|  |                         break | ||||||
|  | 
 | ||||||
|  |                 # await tractor.breakpoint() | ||||||
|  |                 # await stream.receive() | ||||||
|  |                 print(f'final value: {value}') | ||||||
|  | 
 | ||||||
|  |     trio.run(main) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def test_subscribe_errors_after_close(): | ||||||
|  | 
 | ||||||
|  |     async def main(): | ||||||
|  | 
 | ||||||
|  |         size = 1 | ||||||
|  |         tx, rx = trio.open_memory_channel(size) | ||||||
|  |         async with broadcast_receiver(rx, size) as brx: | ||||||
|  |             pass | ||||||
|  | 
 | ||||||
|  |         try: | ||||||
|  |             # open and close | ||||||
|  |             async with brx.subscribe(): | ||||||
|  |                 pass | ||||||
|  | 
 | ||||||
|  |         except trio.ClosedResourceError: | ||||||
|  |             assert brx.key not in brx._state.subs | ||||||
|  | 
 | ||||||
|  |         else: | ||||||
|  |             assert 0 | ||||||
|  | 
 | ||||||
|  |     trio.run(main) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def test_ensure_slow_consumers_lag_out( | ||||||
|  |     arb_addr, | ||||||
|  |     start_method, | ||||||
|  | ): | ||||||
|  |     '''This is a pure local task test; no tractor | ||||||
|  |     machinery is really required. | ||||||
|  | 
 | ||||||
|  |     ''' | ||||||
|  |     async def main(): | ||||||
|  | 
 | ||||||
|  |         # make sure it all works within the runtime | ||||||
|  |         async with tractor.open_root_actor(): | ||||||
|  | 
 | ||||||
|  |             num_laggers = 4 | ||||||
|  |             laggers: dict[str, int] = {} | ||||||
|  |             retries = 3 | ||||||
|  |             size = 100 | ||||||
|  |             tx, rx = trio.open_memory_channel(size) | ||||||
|  |             brx = broadcast_receiver(rx, size) | ||||||
|  | 
 | ||||||
|  |             async def sub_and_print( | ||||||
|  |                 delay: float, | ||||||
|  |             ) -> None: | ||||||
|  | 
 | ||||||
|  |                 task = current_task() | ||||||
|  |                 start = time.time() | ||||||
|  | 
 | ||||||
|  |                 async with brx.subscribe() as lbrx: | ||||||
|  |                     while True: | ||||||
|  |                         print(f'{task.name}: starting consume loop') | ||||||
|  |                         try: | ||||||
|  |                             async for value in lbrx: | ||||||
|  |                                 print(f'{task.name}: {value}') | ||||||
|  |                                 await trio.sleep(delay) | ||||||
|  | 
 | ||||||
|  |                             if task.name == 'sub_1': | ||||||
|  |                                 # the non-lagger got | ||||||
|  |                                 # a ``trio.EndOfChannel`` | ||||||
|  |                                 # because the ``tx`` below was closed | ||||||
|  |                                 assert len(lbrx._state.subs) == 1 | ||||||
|  | 
 | ||||||
|  |                                 await lbrx.aclose() | ||||||
|  | 
 | ||||||
|  |                                 assert len(lbrx._state.subs) == 0 | ||||||
|  | 
 | ||||||
|  |                         except trio.ClosedResourceError: | ||||||
|  |                             # only the fast sub will try to re-enter | ||||||
|  |                             # iteration on the now closed bcaster | ||||||
|  |                             assert task.name == 'sub_1' | ||||||
|  |                             return | ||||||
|  | 
 | ||||||
|  |                         except Lagged: | ||||||
|  |                             lag_time = time.time() - start | ||||||
|  |                             lags = laggers[task.name] | ||||||
|  |                             print( | ||||||
|  |                                 f'restarting slow task {task.name} ' | ||||||
|  |                                 f'that bailed out on {lags}:{value} ' | ||||||
|  |                                 f'after {lag_time:.3f}') | ||||||
|  |                             if lags <= retries: | ||||||
|  |                                 laggers[task.name] += 1 | ||||||
|  |                                 continue | ||||||
|  |                             else: | ||||||
|  |                                 print( | ||||||
|  |                                     f'{task.name} was too slow and terminated ' | ||||||
|  |                                     f'on {lags}:{value}') | ||||||
|  |                                 return | ||||||
|  | 
 | ||||||
|  |             async with trio.open_nursery() as nursery: | ||||||
|  | 
 | ||||||
|  |                 for i in range(1, num_laggers): | ||||||
|  | 
 | ||||||
|  |                     task_name = f'sub_{i}' | ||||||
|  |                     laggers[task_name] = 0 | ||||||
|  |                     nursery.start_soon( | ||||||
|  |                         partial( | ||||||
|  |                             sub_and_print, | ||||||
|  |                             delay=i*0.001, | ||||||
|  |                         ), | ||||||
|  |                         name=task_name, | ||||||
|  |                     ) | ||||||
|  | 
 | ||||||
|  |                 # allow subs to sched | ||||||
|  |                 await trio.sleep(0.1) | ||||||
|  | 
 | ||||||
|  |                 async with tx: | ||||||
|  |                     for i in cycle(range(size)): | ||||||
|  |                         await tx.send(i) | ||||||
|  |                         if len(brx._state.subs) == 2: | ||||||
|  |                             # only the non-lagger sub is left (``brx`` itself is the other key) | ||||||
|  |                             break | ||||||
|  | 
 | ||||||
|  |                 # the non-lagger | ||||||
|  |                 assert laggers.pop('sub_1') == 0 | ||||||
|  | 
 | ||||||
|  |                 for n, v in laggers.items(): | ||||||
|  |                     assert v == 4 | ||||||
|  | 
 | ||||||
|  |                 assert tx._closed | ||||||
|  |                 assert not tx._state.open_send_channels | ||||||
|  | 
 | ||||||
|  |                 # check that the "first" bcaster that we created | ||||||
|  |                 # above was never iterated and is thus overrun | ||||||
|  |                 try: | ||||||
|  |                     await brx.receive() | ||||||
|  |                 except Lagged: | ||||||
|  |                     # expect tokio style index truncation | ||||||
|  |                     seq = brx._state.subs[brx.key] | ||||||
|  |                     assert seq == len(brx._state.queue) - 1 | ||||||
|  | 
 | ||||||
|  |                 # all backpressured entries in the underlying | ||||||
|  |                 # channel should have been copied into the caster | ||||||
|  |                 # queue trailing-window | ||||||
|  |                 async for i in rx: | ||||||
|  |                     print(f'bped: {i}') | ||||||
|  |                     assert i in brx._state.queue | ||||||
|  | 
 | ||||||
|  |                 # should be noop | ||||||
|  |                 await brx.aclose() | ||||||
|  | 
 | ||||||
|  |     trio.run(main) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def test_first_recver_is_cancelled(): | ||||||
|  | 
 | ||||||
|  |     async def main(): | ||||||
|  | 
 | ||||||
|  |         # make sure it all works within the runtime | ||||||
|  |         async with tractor.open_root_actor(): | ||||||
|  | 
 | ||||||
|  |             tx, rx = trio.open_memory_channel(1) | ||||||
|  |             brx = broadcast_receiver(rx, 1) | ||||||
|  |             cs = trio.CancelScope() | ||||||
|  |             sequence = list(range(3)) | ||||||
|  | 
 | ||||||
|  |             async def sub_and_recv(): | ||||||
|  |                 with cs: | ||||||
|  |                     async with brx.subscribe() as bc: | ||||||
|  |                         async for value in bc: | ||||||
|  |                             print(value) | ||||||
|  | 
 | ||||||
|  |             async def cancel_and_send(): | ||||||
|  |                 await trio.sleep(0.2) | ||||||
|  |                 cs.cancel() | ||||||
|  |                 await tx.send(1) | ||||||
|  | 
 | ||||||
|  |             async with trio.open_nursery() as n: | ||||||
|  | 
 | ||||||
|  |                 n.start_soon(sub_and_recv) | ||||||
|  |                 await trio.sleep(0.1) | ||||||
|  |                 assert brx._state.recv_ready | ||||||
|  | 
 | ||||||
|  |                 n.start_soon(cancel_and_send) | ||||||
|  | 
 | ||||||
|  |                 # ensure that we don't hang because no task is now | ||||||
|  |                 # waiting on the underlying receive.. | ||||||
|  |                 with trio.fail_after(0.5): | ||||||
|  |                     value = await brx.receive() | ||||||
|  |                     print(f'parent: {value}') | ||||||
|  |                     assert value == 1 | ||||||
|  | 
 | ||||||
|  |     trio.run(main) | ||||||
|  | @ -567,7 +567,7 @@ class Actor: | ||||||
|         try: |         try: | ||||||
|             send_chan, recv_chan = self._cids2qs[(actorid, cid)] |             send_chan, recv_chan = self._cids2qs[(actorid, cid)] | ||||||
|         except KeyError: |         except KeyError: | ||||||
|             send_chan, recv_chan = trio.open_memory_channel(1000) |             send_chan, recv_chan = trio.open_memory_channel(2**6) | ||||||
|             send_chan.cid = cid  # type: ignore |             send_chan.cid = cid  # type: ignore | ||||||
|             recv_chan.cid = cid  # type: ignore |             recv_chan.cid = cid  # type: ignore | ||||||
|             self._cids2qs[(actorid, cid)] = send_chan, recv_chan |             self._cids2qs[(actorid, cid)] = send_chan, recv_chan | ||||||
|  |  | ||||||
|  | @ -0,0 +1,315 @@ | ||||||
|  | ''' | ||||||
|  | ``tokio`` style broadcast channel. | ||||||
|  | https://docs.rs/tokio/1.11.0/tokio/sync/broadcast/index.html | ||||||
|  | 
 | ||||||
|  | ''' | ||||||
|  | from __future__ import annotations | ||||||
|  | from abc import abstractmethod | ||||||
|  | from collections import deque | ||||||
|  | from contextlib import asynccontextmanager | ||||||
|  | from dataclasses import dataclass | ||||||
|  | from functools import partial | ||||||
|  | from operator import ne | ||||||
|  | from typing import Optional, Callable, Awaitable, Any, AsyncIterator, Protocol | ||||||
|  | from typing import Generic, TypeVar | ||||||
|  | 
 | ||||||
|  | import trio | ||||||
|  | from trio._core._run import Task | ||||||
|  | from trio.abc import ReceiveChannel | ||||||
|  | from trio.lowlevel import current_task | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | # A regular invariant generic type | ||||||
|  | T = TypeVar("T") | ||||||
|  | 
 | ||||||
|  | # covariant because AsyncReceiver[Derived] can be passed to someone | ||||||
|  | # expecting AsyncReceiver[Base] | ||||||
|  | ReceiveType = TypeVar("ReceiveType", covariant=True) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class AsyncReceiver( | ||||||
|  |     Protocol, | ||||||
|  |     Generic[ReceiveType], | ||||||
|  | ): | ||||||
|  |     '''An async receivable duck-type that quacks much like trio's | ||||||
|  |     ``trio.abc.ReceiveChannel``. | ||||||
|  | 
 | ||||||
|  |     ''' | ||||||
|  |     @abstractmethod | ||||||
|  |     async def receive(self) -> ReceiveType: | ||||||
|  |         ... | ||||||
|  | 
 | ||||||
|  |     @abstractmethod | ||||||
|  |     def __aiter__(self) -> AsyncIterator[ReceiveType]: | ||||||
|  |         ... | ||||||
|  | 
 | ||||||
|  |     @abstractmethod | ||||||
|  |     async def __anext__(self) -> ReceiveType: | ||||||
|  |         ... | ||||||
|  | 
 | ||||||
|  |     # ``trio.abc.AsyncResource`` methods | ||||||
|  |     @abstractmethod | ||||||
|  |     async def aclose(self): | ||||||
|  |         ... | ||||||
|  | 
 | ||||||
|  |     @abstractmethod | ||||||
|  |     async def __aenter__(self) -> AsyncReceiver[ReceiveType]: | ||||||
|  |         ... | ||||||
|  | 
 | ||||||
|  |     @abstractmethod | ||||||
|  |     async def __aexit__(self, *args) -> None: | ||||||
|  |         ... | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class Lagged(trio.TooSlowError): | ||||||
|  |     '''Subscribed consumer task was too slow and was overrun | ||||||
|  |     by the fastest consumer-producer pair. | ||||||
|  | 
 | ||||||
|  |     ''' | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | @dataclass | ||||||
|  | class BroadcastState: | ||||||
|  |     '''Common state to all receivers of a broadcast. | ||||||
|  | 
 | ||||||
|  |     ''' | ||||||
|  |     queue: deque | ||||||
|  |     maxlen: int | ||||||
|  | 
 | ||||||
|  |     # map of receiver instance ``id()`` keys to that receiver's current | ||||||
|  |     # sequence index into ``queue`` (-1 means no unread value); one state | ||||||
|  |     # instance is shared by every receiver in a broadcaster set. | ||||||
|  | 
 | ||||||
|  |     # broadcast event to wake up all sleeping consumer tasks | ||||||
|  |     # on a newly produced value from the sender. | ||||||
|  |     recv_ready: Optional[tuple[int, trio.Event]] = None | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class BroadcastReceiver(ReceiveChannel): | ||||||
|  |     '''A memory receive channel broadcaster which is non-lossy for the | ||||||
|  |     fastest consumer. | ||||||
|  | 
 | ||||||
|  |     Additional consumer tasks can receive all produced values by registering | ||||||
|  |     with ``.subscribe()`` and receiving from the new instance it delivers. | ||||||
|  | 
 | ||||||
|  |     ''' | ||||||
|  |     def __init__( | ||||||
|  |         self, | ||||||
|  | 
 | ||||||
|  |         rx_chan: AsyncReceiver, | ||||||
|  |         state: BroadcastState, | ||||||
|  |         receive_afunc: Optional[Callable[[], Awaitable[Any]]] = None, | ||||||
|  | 
 | ||||||
|  |     ) -> None: | ||||||
|  | 
 | ||||||
|  |         # register the original underlying (clone) | ||||||
|  |         self.key = id(self) | ||||||
|  |         self._state = state | ||||||
|  |         state.subs[self.key] = -1 | ||||||
|  | 
 | ||||||
|  |         # underlying for this receiver | ||||||
|  |         self._rx = rx_chan | ||||||
|  |         self._recv = receive_afunc or rx_chan.receive | ||||||
|  |         self._closed: bool = False | ||||||
|  | 
 | ||||||
|  |     async def receive(self) -> ReceiveType: | ||||||
|  | 
 | ||||||
|  |         key = self.key | ||||||
|  |         state = self._state | ||||||
|  | 
 | ||||||
|  |         # TODO: ideally we can make some way to "lock out" the | ||||||
|  |         # underlying receive channel in some way such that if some task | ||||||
|  |         # tries to pull from it directly (i.e. one we're unaware of) | ||||||
|  |         # then it errors out. | ||||||
|  | 
 | ||||||
|  |         # only tasks which have entered ``.subscribe()`` can | ||||||
|  |         # receive on this broadcaster. | ||||||
|  |         try: | ||||||
|  |             seq = state.subs[key] | ||||||
|  |         except KeyError: | ||||||
|  |             if self._closed: | ||||||
|  |                 raise trio.ClosedResourceError | ||||||
|  | 
 | ||||||
|  |             raise RuntimeError( | ||||||
|  |                 f'{self} is not registerd as subscriber') | ||||||
|  | 
 | ||||||
|  |         # check that task does not already have a value it can receive | ||||||
|  |         # immediately and/or that it has lagged. | ||||||
|  |         if seq > -1: | ||||||
|  |             # get the oldest value we haven't received immediately | ||||||
|  |             try: | ||||||
|  |                 value = state.queue[seq] | ||||||
|  |             except IndexError: | ||||||
|  | 
 | ||||||
|  |                 # adhere to ``tokio`` style "lagging": | ||||||
|  |                 # "Once RecvError::Lagged is returned, the lagging | ||||||
|  |                 # receiver's position is updated to the oldest value | ||||||
|  |                 # contained by the channel. The next call to recv will | ||||||
|  |                 # return this value." | ||||||
|  |                 # https://docs.rs/tokio/1.11.0/tokio/sync/broadcast/index.html#lagging | ||||||
|  | 
 | ||||||
|  |                 # decrement to the last value and expect | ||||||
|  |                 # consumer to either handle the ``Lagged`` and come back | ||||||
|  |                 # or bail out on its own (thus un-subscribing) | ||||||
|  |                 state.subs[key] = state.maxlen - 1 | ||||||
|  | 
 | ||||||
|  |                 # this task was overrun by the producer side | ||||||
|  |                 task: Task = current_task() | ||||||
|  |                 raise Lagged(f'Task {task.name} was overrun') | ||||||
|  | 
 | ||||||
|  |             state.subs[key] -= 1 | ||||||
|  |             return value | ||||||
|  | 
 | ||||||
|  |         # current task already has the latest value **and** is the | ||||||
|  |         # first task to begin waiting for a new one | ||||||
|  |         if state.recv_ready is None: | ||||||
|  | 
 | ||||||
|  |             if self._closed: | ||||||
|  |                 raise trio.ClosedResourceError | ||||||
|  | 
 | ||||||
|  |             event = trio.Event() | ||||||
|  |             state.recv_ready = key, event | ||||||
|  | 
 | ||||||
|  |             # if we're cancelled here it should be | ||||||
|  |             # fine to bail without affecting any other consumers | ||||||
|  |             # right? | ||||||
|  |             try: | ||||||
|  |                 value = await self._recv() | ||||||
|  | 
 | ||||||
|  |                 # items with lower indices are "newer" | ||||||
|  |                 # NOTE: ``collections.deque`` implicitly takes care of | ||||||
|  |                 # truncating values outside our ``state.maxlen``. In the | ||||||
|  |                 # alt-backend-array-case we'll need to make sure this is | ||||||
|  |                 # implemented in similar ring-buffer-ish style. | ||||||
|  |                 state.queue.appendleft(value) | ||||||
|  | 
 | ||||||
|  |                 # broadcast new value to all subscribers by increasing | ||||||
|  |                 # all sequence numbers that will point in the queue to | ||||||
|  |                 # their latest available value. | ||||||
|  | 
 | ||||||
|  |                 # don't increment the sequence for this task since we | ||||||
|  |                 # already retrieved the latest value | ||||||
|  | 
 | ||||||
|  |                 # XXX: which of these impls is fastest? | ||||||
|  | 
 | ||||||
|  |                 # subs = state.subs.copy() | ||||||
|  |                 # subs.pop(key) | ||||||
|  | 
 | ||||||
|  |                 for sub_key in filter( | ||||||
|  |                     # lambda k: k != key, state.subs, | ||||||
|  |                     partial(ne, key), state.subs, | ||||||
|  |                 ): | ||||||
|  |                     state.subs[sub_key] += 1 | ||||||
|  | 
 | ||||||
|  |                 # NOTE: this should ONLY be set if the above task was *NOT* | ||||||
|  |                 # cancelled on the `._recv()` call. | ||||||
|  |                 event.set() | ||||||
|  |                 return value | ||||||
|  | 
 | ||||||
|  |             except trio.Cancelled: | ||||||
|  |                 # handle cancelled specially otherwise sibling | ||||||
|  |                 # consumers will be awoken with a sequence of -1 | ||||||
|  |                 # state.recv_ready = trio.Cancelled | ||||||
|  |                 if event.statistics().tasks_waiting: | ||||||
|  |                     event.set() | ||||||
|  |                 raise | ||||||
|  | 
 | ||||||
|  |             finally: | ||||||
|  | 
 | ||||||
|  |                 # Reset the receiver waiter task event for the next blocking | ||||||
|  |                 # condition. This MUST be reset even if the above ``._recv()`` | ||||||
|  |                 # call was cancelled, to prevent the next consumer from | ||||||
|  |                 # blocking on an event that won't be set! | ||||||
|  |                 state.recv_ready = None | ||||||
|  | 
 | ||||||
|  |         # This task is all caught up and ready to receive the latest | ||||||
|  |         # value, so schedule it to wait on the internal event. | ||||||
|  |         else: | ||||||
|  |             seq = state.subs[key] | ||||||
|  |             assert seq == -1  # sanity | ||||||
|  |             _, ev = state.recv_ready | ||||||
|  |             await ev.wait() | ||||||
|  | 
 | ||||||
|  |             # NOTE: if we ever would like the behaviour where if the | ||||||
|  |             # first task to recv on the underlying is cancelled but it | ||||||
|  |             # still DOES trigger the ``.recv_ready`` event, we'll likely need | ||||||
|  |             # this logic: | ||||||
|  | 
 | ||||||
|  |             if seq > -1: | ||||||
|  |                 # stuff from above.. | ||||||
|  |                 seq = state.subs[key] | ||||||
|  | 
 | ||||||
|  |                 value = state.queue[seq] | ||||||
|  |                 state.subs[key] -= 1 | ||||||
|  |                 return value | ||||||
|  | 
 | ||||||
|  |             elif seq == -1: | ||||||
|  |                 # XXX: In the case where the first task to allocate the | ||||||
|  |                 # ``.recv_ready`` event is cancelled we will be woken with | ||||||
|  |                 # a non-incremented sequence number and thus will read the | ||||||
|  |                 # oldest value if we use that. Instead we need to detect if | ||||||
|  |                 # we have not been incremented and then receive again. | ||||||
|  |                 return await self.receive() | ||||||
|  | 
 | ||||||
|  |             else: | ||||||
|  |                 raise ValueError(f'Invalid sequence {seq}!?') | ||||||
|  | 
 | ||||||
|  |     @asynccontextmanager | ||||||
|  |     async def subscribe( | ||||||
|  |         self, | ||||||
|  |     ) -> AsyncIterator[BroadcastReceiver]: | ||||||
|  |         '''Subscribe for values from this broadcast receiver. | ||||||
|  | 
 | ||||||
|  |         Returns a new ``BroadCastReceiver`` which is registered for and | ||||||
|  |         pulls data from a clone of the original ``trio.abc.ReceiveChannel`` | ||||||
|  |         provided at creation. | ||||||
|  | 
 | ||||||
|  |         ''' | ||||||
|  |         if self._closed: | ||||||
|  |             raise trio.ClosedResourceError | ||||||
|  | 
 | ||||||
|  |         state = self._state | ||||||
|  |         br = BroadcastReceiver( | ||||||
|  |             rx_chan=self._rx, | ||||||
|  |             state=state, | ||||||
|  |             receive_afunc=self._recv, | ||||||
|  |         ) | ||||||
|  |         # assert clone in state.subs | ||||||
|  |         assert br.key in state.subs | ||||||
|  | 
 | ||||||
|  |         try: | ||||||
|  |             yield br | ||||||
|  |         finally: | ||||||
|  |             await br.aclose() | ||||||
|  | 
 | ||||||
|  |     async def aclose( | ||||||
|  |         self, | ||||||
|  |     ) -> None: | ||||||
|  | 
 | ||||||
|  |         if self._closed: | ||||||
|  |             return | ||||||
|  | 
 | ||||||
|  |         # XXX: leaving it like this consumers can still get values | ||||||
|  |         # up to the last received that still reside in the queue. | ||||||
|  |         self._state.subs.pop(self.key) | ||||||
|  | 
 | ||||||
|  |         self._closed = True | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
def broadcast_receiver(

    recv_chan: AsyncReceiver,
    max_buffer_size: int,
    **kwargs,

) -> BroadcastReceiver:
    '''Wrap ``recv_chan`` in a new ``BroadcastReceiver``.

    A fresh ``BroadcastState`` is allocated for the receiver: its queue
    is a ``deque`` bounded to ``max_buffer_size`` entries, ``maxlen`` is
    set to the same value and the subscriptions table starts out empty.
    Any extra keyword arguments are forwarded unchanged to the
    ``BroadcastReceiver`` constructor.

    '''
    fresh_state = BroadcastState(
        queue=deque(maxlen=max_buffer_size),
        maxlen=max_buffer_size,
        subs={},
    )
    return BroadcastReceiver(recv_chan, state=fresh_state, **kwargs)
|  | @ -294,6 +294,7 @@ class Portal: | ||||||
|     async def open_stream_from( |     async def open_stream_from( | ||||||
|         self, |         self, | ||||||
|         async_gen_func: Callable,  # typing: ignore |         async_gen_func: Callable,  # typing: ignore | ||||||
|  |         shield: bool = False, | ||||||
|         **kwargs, |         **kwargs, | ||||||
| 
 | 
 | ||||||
|     ) -> AsyncGenerator[ReceiveMsgStream, None]: |     ) -> AsyncGenerator[ReceiveMsgStream, None]: | ||||||
|  | @ -320,7 +321,9 @@ class Portal: | ||||||
|         ctx = Context(self.channel, cid, _portal=self) |         ctx = Context(self.channel, cid, _portal=self) | ||||||
|         try: |         try: | ||||||
|             # deliver receive only stream |             # deliver receive only stream | ||||||
|             async with ReceiveMsgStream(ctx, recv_chan) as rchan: |             async with ReceiveMsgStream( | ||||||
|  |                 ctx, recv_chan, shield=shield | ||||||
|  |             ) as rchan: | ||||||
|                 self._streams.add(rchan) |                 self._streams.add(rchan) | ||||||
|                 yield rchan |                 yield rchan | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -2,12 +2,14 @@ | ||||||
| Message stream types and APIs. | Message stream types and APIs. | ||||||
| 
 | 
 | ||||||
| """ | """ | ||||||
|  | from __future__ import annotations | ||||||
| import inspect | import inspect | ||||||
| from contextlib import contextmanager, asynccontextmanager | from contextlib import contextmanager, asynccontextmanager | ||||||
| from dataclasses import dataclass | from dataclasses import dataclass | ||||||
| from typing import ( | from typing import ( | ||||||
|     Any, Iterator, Optional, Callable, |     Any, Iterator, Optional, Callable, | ||||||
|     AsyncGenerator, Dict, |     AsyncGenerator, Dict, | ||||||
|  |     AsyncIterator | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| import warnings | import warnings | ||||||
|  | @ -17,6 +19,7 @@ import trio | ||||||
| from ._ipc import Channel | from ._ipc import Channel | ||||||
| from ._exceptions import unpack_error, ContextCancelled | from ._exceptions import unpack_error, ContextCancelled | ||||||
| from ._state import current_actor | from ._state import current_actor | ||||||
|  | from ._broadcast import broadcast_receiver, BroadcastReceiver | ||||||
| from .log import get_logger | from .log import get_logger | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | @ -47,11 +50,14 @@ class ReceiveMsgStream(trio.abc.ReceiveChannel): | ||||||
|     def __init__( |     def __init__( | ||||||
|         self, |         self, | ||||||
|         ctx: 'Context',  # typing: ignore # noqa |         ctx: 'Context',  # typing: ignore # noqa | ||||||
|         rx_chan: trio.abc.ReceiveChannel, |         rx_chan: trio.MemoryReceiveChannel, | ||||||
|  |         shield: bool = False, | ||||||
|  |         _broadcaster: Optional[BroadcastReceiver] = None, | ||||||
| 
 | 
 | ||||||
|     ) -> None: |     ) -> None: | ||||||
|         self._ctx = ctx |         self._ctx = ctx | ||||||
|         self._rx_chan = rx_chan |         self._rx_chan = rx_chan | ||||||
|  |         self._broadcaster = _broadcaster | ||||||
| 
 | 
 | ||||||
|         # flag to denote end of stream |         # flag to denote end of stream | ||||||
|         self._eoc: bool = False |         self._eoc: bool = False | ||||||
|  | @ -231,6 +237,50 @@ class ReceiveMsgStream(trio.abc.ReceiveChannel): | ||||||
|         # still need to consume msgs that are "in transit" from the far |         # still need to consume msgs that are "in transit" from the far | ||||||
|         # end (eg. for ``Context.result()``). |         # end (eg. for ``Context.result()``). | ||||||
| 
 | 
 | ||||||
|  |     @asynccontextmanager | ||||||
|  |     async def subscribe( | ||||||
|  |         self, | ||||||
|  | 
 | ||||||
|  |     ) -> AsyncIterator[BroadcastReceiver]: | ||||||
|  |         '''Allocate and return a ``BroadcastReceiver`` which delegates | ||||||
|  |         to this message stream. | ||||||
|  | 
 | ||||||
|  |         This allows multiple local tasks to receive each their own copy | ||||||
|  |         of this message stream. | ||||||
|  | 
 | ||||||
|  |         This operation is indempotent and and mutates this stream's | ||||||
|  |         receive machinery to copy and window-length-store each received | ||||||
|  |         value from the far end via the internally created broudcast | ||||||
|  |         receiver wrapper. | ||||||
|  | 
 | ||||||
|  |         ''' | ||||||
|  |         # NOTE: This operation is indempotent and non-reversible, so be | ||||||
|  |         # sure you can deal with any (theoretical) overhead of the the | ||||||
|  |         # allocated ``BroadcastReceiver`` before calling this method for | ||||||
|  |         # the first time. | ||||||
|  |         if self._broadcaster is None: | ||||||
|  | 
 | ||||||
|  |             bcast = self._broadcaster = broadcast_receiver( | ||||||
|  |                 self, | ||||||
|  |                 # use memory channel size by default | ||||||
|  |                 self._rx_chan._state.max_buffer_size,  # type: ignore | ||||||
|  |                 receive_afunc=self.receive, | ||||||
|  |             ) | ||||||
|  | 
 | ||||||
|  |             # NOTE: we override the original stream instance's receive | ||||||
|  |             # method to now delegate to the broadcaster's ``.receive()`` | ||||||
|  |             # such that new subscribers will be copied received values | ||||||
|  |             # and this stream doesn't have to expect it's original | ||||||
|  |             # consumer(s) to get a new broadcast rx handle. | ||||||
|  |             self.receive = bcast.receive  # type: ignore | ||||||
|  |             # seems there's no graceful way to type this with ``mypy``? | ||||||
|  |             # https://github.com/python/mypy/issues/708 | ||||||
|  | 
 | ||||||
|  |         async with self._broadcaster.subscribe() as bstream: | ||||||
|  |             assert bstream.key != self._broadcaster.key | ||||||
|  |             assert bstream._recv == self._broadcaster._recv | ||||||
|  |             yield bstream | ||||||
|  | 
 | ||||||
| 
 | 
 | ||||||
| class MsgStream(ReceiveMsgStream, trio.abc.Channel): | class MsgStream(ReceiveMsgStream, trio.abc.Channel): | ||||||
|     """ |     """ | ||||||
|  | @ -247,17 +297,6 @@ class MsgStream(ReceiveMsgStream, trio.abc.Channel): | ||||||
|         ''' |         ''' | ||||||
|         await self._ctx.chan.send({'yield': data, 'cid': self._ctx.cid}) |         await self._ctx.chan.send({'yield': data, 'cid': self._ctx.cid}) | ||||||
| 
 | 
 | ||||||
|     # TODO: but make it broadcasting to consumers |  | ||||||
|     def clone(self): |  | ||||||
|         """Clone this receive channel allowing for multi-task |  | ||||||
|         consumption from the same channel. |  | ||||||
| 
 |  | ||||||
|         """ |  | ||||||
|         return MsgStream( |  | ||||||
|             self._ctx, |  | ||||||
|             self._rx_chan.clone(), |  | ||||||
|         ) |  | ||||||
| 
 |  | ||||||
| 
 | 
 | ||||||
| @dataclass | @dataclass | ||||||
| class Context: | class Context: | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue