forked from goodboy/tractor
				
			
						commit
						3f1bc37143
					
				| 
						 | 
				
			
			@ -0,0 +1,12 @@
 | 
			
		|||
Add `tokio-style broadcast channels
 | 
			
		||||
<https://docs.rs/tokio/1.11.0/tokio/sync/broadcast/index.html>`_ as
 | 
			
		||||
a solution for `#204 <https://github.com/goodboy/tractor/pull/204>`_ and
 | 
			
		||||
discussed thoroughly in `trio/#987
 | 
			
		||||
<https://github.com/python-trio/trio/issues/987>`_.
 | 
			
		||||
 | 
			
		||||
This gives us local task broadcast functionality using a new
 | 
			
		||||
``BroadcastReceiver`` type which can wrap ``trio.ReceiveChannel``  and
 | 
			
		||||
provide fan-out copies of a stream of data to every subscribed consumer.
 | 
			
		||||
We use this new machinery to provide a ``ReceiveMsgStream.subscribe()``
 | 
			
		||||
async context manager which can be used by actor-local concumers tasks
 | 
			
		||||
to easily pull from a shared and dynamic IPC stream.
 | 
			
		||||
| 
						 | 
				
			
			@ -0,0 +1,459 @@
 | 
			
		|||
"""
 | 
			
		||||
Broadcast channels for fan-out to local tasks.
 | 
			
		||||
"""
 | 
			
		||||
from contextlib import asynccontextmanager
 | 
			
		||||
from functools import partial
 | 
			
		||||
from itertools import cycle
 | 
			
		||||
import time
 | 
			
		||||
from typing import Optional, List, Tuple
 | 
			
		||||
 | 
			
		||||
import pytest
 | 
			
		||||
import trio
 | 
			
		||||
from trio.lowlevel import current_task
 | 
			
		||||
import tractor
 | 
			
		||||
from tractor._broadcast import broadcast_receiver, Lagged
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@tractor.context
 | 
			
		||||
async def echo_sequences(
 | 
			
		||||
 | 
			
		||||
    ctx:  tractor.Context,
 | 
			
		||||
 | 
			
		||||
) -> None:
 | 
			
		||||
    '''Bidir streaming endpoint which will stream
 | 
			
		||||
    back any sequence it is sent item-wise.
 | 
			
		||||
 | 
			
		||||
    '''
 | 
			
		||||
    await ctx.started()
 | 
			
		||||
 | 
			
		||||
    async with ctx.open_stream() as stream:
 | 
			
		||||
        async for sequence in stream:
 | 
			
		||||
            seq = list(sequence)
 | 
			
		||||
            for value in seq:
 | 
			
		||||
                await stream.send(value)
 | 
			
		||||
                print(f'producer sent {value}')
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
async def ensure_sequence(
 | 
			
		||||
 | 
			
		||||
    stream: tractor.ReceiveMsgStream,
 | 
			
		||||
    sequence: list,
 | 
			
		||||
    delay: Optional[float] = None,
 | 
			
		||||
 | 
			
		||||
) -> None:
 | 
			
		||||
 | 
			
		||||
    name = current_task().name
 | 
			
		||||
    async with stream.subscribe() as bcaster:
 | 
			
		||||
        assert not isinstance(bcaster, type(stream))
 | 
			
		||||
        async for value in bcaster:
 | 
			
		||||
            print(f'{name} rx: {value}')
 | 
			
		||||
            assert value == sequence[0]
 | 
			
		||||
            sequence.remove(value)
 | 
			
		||||
 | 
			
		||||
            if delay:
 | 
			
		||||
                await trio.sleep(delay)
 | 
			
		||||
 | 
			
		||||
            if not sequence:
 | 
			
		||||
                # fully consumed
 | 
			
		||||
                break
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@asynccontextmanager
 | 
			
		||||
async def open_sequence_streamer(
 | 
			
		||||
 | 
			
		||||
    sequence: List[int],
 | 
			
		||||
    arb_addr: Tuple[str, int],
 | 
			
		||||
    start_method: str,
 | 
			
		||||
 | 
			
		||||
) -> tractor.MsgStream:
 | 
			
		||||
 | 
			
		||||
    async with tractor.open_nursery(
 | 
			
		||||
        arbiter_addr=arb_addr,
 | 
			
		||||
        start_method=start_method,
 | 
			
		||||
    ) as tn:
 | 
			
		||||
 | 
			
		||||
        portal = await tn.start_actor(
 | 
			
		||||
            'sequence_echoer',
 | 
			
		||||
            enable_modules=[__name__],
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        async with portal.open_context(
 | 
			
		||||
            echo_sequences,
 | 
			
		||||
        ) as (ctx, first):
 | 
			
		||||
 | 
			
		||||
            assert first is None
 | 
			
		||||
            async with ctx.open_stream() as stream:
 | 
			
		||||
                yield stream
 | 
			
		||||
 | 
			
		||||
        await portal.cancel_actor()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_stream_fan_out_to_local_subscriptions(
 | 
			
		||||
    arb_addr,
 | 
			
		||||
    start_method,
 | 
			
		||||
):
 | 
			
		||||
 | 
			
		||||
    sequence = list(range(1000))
 | 
			
		||||
 | 
			
		||||
    async def main():
 | 
			
		||||
 | 
			
		||||
        async with open_sequence_streamer(
 | 
			
		||||
            sequence,
 | 
			
		||||
            arb_addr,
 | 
			
		||||
            start_method,
 | 
			
		||||
        ) as stream:
 | 
			
		||||
 | 
			
		||||
            async with trio.open_nursery() as n:
 | 
			
		||||
                for i in range(10):
 | 
			
		||||
                    n.start_soon(
 | 
			
		||||
                        ensure_sequence,
 | 
			
		||||
                        stream,
 | 
			
		||||
                        sequence.copy(),
 | 
			
		||||
                        name=f'consumer_{i}',
 | 
			
		||||
                    )
 | 
			
		||||
 | 
			
		||||
                await stream.send(tuple(sequence))
 | 
			
		||||
 | 
			
		||||
                async for value in stream:
 | 
			
		||||
                    print(f'source stream rx: {value}')
 | 
			
		||||
                    assert value == sequence[0]
 | 
			
		||||
                    sequence.remove(value)
 | 
			
		||||
 | 
			
		||||
                    if not sequence:
 | 
			
		||||
                        # fully consumed
 | 
			
		||||
                        break
 | 
			
		||||
 | 
			
		||||
    trio.run(main)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@pytest.mark.parametrize(
 | 
			
		||||
    'task_delays',
 | 
			
		||||
    [
 | 
			
		||||
        (0.01, 0.001),
 | 
			
		||||
        (0.001, 0.01),
 | 
			
		||||
    ]
 | 
			
		||||
)
 | 
			
		||||
def test_consumer_and_parent_maybe_lag(
 | 
			
		||||
    arb_addr,
 | 
			
		||||
    start_method,
 | 
			
		||||
    task_delays,
 | 
			
		||||
):
 | 
			
		||||
 | 
			
		||||
    async def main():
 | 
			
		||||
 | 
			
		||||
        sequence = list(range(300))
 | 
			
		||||
        parent_delay, sub_delay = task_delays
 | 
			
		||||
 | 
			
		||||
        async with open_sequence_streamer(
 | 
			
		||||
            sequence,
 | 
			
		||||
            arb_addr,
 | 
			
		||||
            start_method,
 | 
			
		||||
        ) as stream:
 | 
			
		||||
 | 
			
		||||
            try:
 | 
			
		||||
                async with trio.open_nursery() as n:
 | 
			
		||||
 | 
			
		||||
                    n.start_soon(
 | 
			
		||||
                        ensure_sequence,
 | 
			
		||||
                        stream,
 | 
			
		||||
                        sequence.copy(),
 | 
			
		||||
                        sub_delay,
 | 
			
		||||
                        name='consumer_task',
 | 
			
		||||
                    )
 | 
			
		||||
 | 
			
		||||
                    await stream.send(tuple(sequence))
 | 
			
		||||
 | 
			
		||||
                    # async for value in stream:
 | 
			
		||||
                    lagged = False
 | 
			
		||||
                    lag_count = 0
 | 
			
		||||
 | 
			
		||||
                    while True:
 | 
			
		||||
                        try:
 | 
			
		||||
                            value = await stream.receive()
 | 
			
		||||
                            print(f'source stream rx: {value}')
 | 
			
		||||
 | 
			
		||||
                            if lagged:
 | 
			
		||||
                                # re set the sequence starting at our last
 | 
			
		||||
                                # value
 | 
			
		||||
                                sequence = sequence[sequence.index(value) + 1:]
 | 
			
		||||
                            else:
 | 
			
		||||
                                assert value == sequence[0]
 | 
			
		||||
                                sequence.remove(value)
 | 
			
		||||
 | 
			
		||||
                            lagged = False
 | 
			
		||||
 | 
			
		||||
                        except Lagged:
 | 
			
		||||
                            lagged = True
 | 
			
		||||
                            print(f'source stream lagged after {value}')
 | 
			
		||||
                            lag_count += 1
 | 
			
		||||
                            continue
 | 
			
		||||
 | 
			
		||||
                        # lag the parent
 | 
			
		||||
                        await trio.sleep(parent_delay)
 | 
			
		||||
 | 
			
		||||
                        if not sequence:
 | 
			
		||||
                            # fully consumed
 | 
			
		||||
                            break
 | 
			
		||||
                    print(f'parent + source stream lagged: {lag_count}')
 | 
			
		||||
 | 
			
		||||
                    if parent_delay > sub_delay:
 | 
			
		||||
                        assert lag_count > 0
 | 
			
		||||
 | 
			
		||||
            except Lagged:
 | 
			
		||||
                # child was lagged
 | 
			
		||||
                assert parent_delay < sub_delay
 | 
			
		||||
 | 
			
		||||
    trio.run(main)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_faster_task_to_recv_is_cancelled_by_slower(
 | 
			
		||||
    arb_addr,
 | 
			
		||||
    start_method,
 | 
			
		||||
):
 | 
			
		||||
    '''Ensure that if a faster task consuming from a stream is cancelled
 | 
			
		||||
    the slower task can continue to receive all expected values.
 | 
			
		||||
 | 
			
		||||
    '''
 | 
			
		||||
    async def main():
 | 
			
		||||
 | 
			
		||||
        sequence = list(range(1000))
 | 
			
		||||
 | 
			
		||||
        async with open_sequence_streamer(
 | 
			
		||||
            sequence,
 | 
			
		||||
            arb_addr,
 | 
			
		||||
            start_method,
 | 
			
		||||
 | 
			
		||||
        ) as stream:
 | 
			
		||||
 | 
			
		||||
            async with trio.open_nursery() as n:
 | 
			
		||||
                n.start_soon(
 | 
			
		||||
                    ensure_sequence,
 | 
			
		||||
                    stream,
 | 
			
		||||
                    sequence.copy(),
 | 
			
		||||
                    0,
 | 
			
		||||
                    name='consumer_task',
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
                await stream.send(tuple(sequence))
 | 
			
		||||
 | 
			
		||||
                # pull 3 values, cancel the subtask, then
 | 
			
		||||
                # expect to be able to pull all values still
 | 
			
		||||
                for i in range(20):
 | 
			
		||||
                    try:
 | 
			
		||||
                        value = await stream.receive()
 | 
			
		||||
                        print(f'source stream rx: {value}')
 | 
			
		||||
                        await trio.sleep(0.01)
 | 
			
		||||
                    except Lagged:
 | 
			
		||||
                        print(f'parent overrun after {value}')
 | 
			
		||||
                        continue
 | 
			
		||||
 | 
			
		||||
                print('cancelling faster subtask')
 | 
			
		||||
                n.cancel_scope.cancel()
 | 
			
		||||
 | 
			
		||||
            try:
 | 
			
		||||
                value = await stream.receive()
 | 
			
		||||
                print(f'source stream after cancel: {value}')
 | 
			
		||||
            except Lagged:
 | 
			
		||||
                print(f'parent overrun after {value}')
 | 
			
		||||
 | 
			
		||||
            # expect to see all remaining values
 | 
			
		||||
            with trio.fail_after(0.5):
 | 
			
		||||
                async for value in stream:
 | 
			
		||||
                    assert stream._broadcaster._state.recv_ready is None
 | 
			
		||||
                    print(f'source stream rx: {value}')
 | 
			
		||||
                    if value == 999:
 | 
			
		||||
                        # fully consumed and we missed no values once
 | 
			
		||||
                        # the faster subtask was cancelled
 | 
			
		||||
                        break
 | 
			
		||||
 | 
			
		||||
                # await tractor.breakpoint()
 | 
			
		||||
                # await stream.receive()
 | 
			
		||||
                print(f'final value: {value}')
 | 
			
		||||
 | 
			
		||||
    trio.run(main)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_subscribe_errors_after_close():
 | 
			
		||||
 | 
			
		||||
    async def main():
 | 
			
		||||
 | 
			
		||||
        size = 1
 | 
			
		||||
        tx, rx = trio.open_memory_channel(size)
 | 
			
		||||
        async with broadcast_receiver(rx, size) as brx:
 | 
			
		||||
            pass
 | 
			
		||||
 | 
			
		||||
        try:
 | 
			
		||||
            # open and close
 | 
			
		||||
            async with brx.subscribe():
 | 
			
		||||
                pass
 | 
			
		||||
 | 
			
		||||
        except trio.ClosedResourceError:
 | 
			
		||||
            assert brx.key not in brx._state.subs
 | 
			
		||||
 | 
			
		||||
        else:
 | 
			
		||||
            assert 0
 | 
			
		||||
 | 
			
		||||
    trio.run(main)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_ensure_slow_consumers_lag_out(
 | 
			
		||||
    arb_addr,
 | 
			
		||||
    start_method,
 | 
			
		||||
):
 | 
			
		||||
    '''This is a pure local task test; no tractor
 | 
			
		||||
    machinery is really required.
 | 
			
		||||
 | 
			
		||||
    '''
 | 
			
		||||
    async def main():
 | 
			
		||||
 | 
			
		||||
        # make sure it all works within the runtime
 | 
			
		||||
        async with tractor.open_root_actor():
 | 
			
		||||
 | 
			
		||||
            num_laggers = 4
 | 
			
		||||
            laggers: dict[str, int] = {}
 | 
			
		||||
            retries = 3
 | 
			
		||||
            size = 100
 | 
			
		||||
            tx, rx = trio.open_memory_channel(size)
 | 
			
		||||
            brx = broadcast_receiver(rx, size)
 | 
			
		||||
 | 
			
		||||
            async def sub_and_print(
 | 
			
		||||
                delay: float,
 | 
			
		||||
            ) -> None:
 | 
			
		||||
 | 
			
		||||
                task = current_task()
 | 
			
		||||
                start = time.time()
 | 
			
		||||
 | 
			
		||||
                async with brx.subscribe() as lbrx:
 | 
			
		||||
                    while True:
 | 
			
		||||
                        print(f'{task.name}: starting consume loop')
 | 
			
		||||
                        try:
 | 
			
		||||
                            async for value in lbrx:
 | 
			
		||||
                                print(f'{task.name}: {value}')
 | 
			
		||||
                                await trio.sleep(delay)
 | 
			
		||||
 | 
			
		||||
                            if task.name == 'sub_1':
 | 
			
		||||
                                # the non-lagger got
 | 
			
		||||
                                # a ``trio.EndOfChannel``
 | 
			
		||||
                                # because the ``tx`` below was closed
 | 
			
		||||
                                assert len(lbrx._state.subs) == 1
 | 
			
		||||
 | 
			
		||||
                                await lbrx.aclose()
 | 
			
		||||
 | 
			
		||||
                                assert len(lbrx._state.subs) == 0
 | 
			
		||||
 | 
			
		||||
                        except trio.ClosedResourceError:
 | 
			
		||||
                            # only the fast sub will try to re-enter
 | 
			
		||||
                            # iteration on the now closed bcaster
 | 
			
		||||
                            assert task.name == 'sub_1'
 | 
			
		||||
                            return
 | 
			
		||||
 | 
			
		||||
                        except Lagged:
 | 
			
		||||
                            lag_time = time.time() - start
 | 
			
		||||
                            lags = laggers[task.name]
 | 
			
		||||
                            print(
 | 
			
		||||
                                f'restarting slow task {task.name} '
 | 
			
		||||
                                f'that bailed out on {lags}:{value} '
 | 
			
		||||
                                f'after {lag_time:.3f}')
 | 
			
		||||
                            if lags <= retries:
 | 
			
		||||
                                laggers[task.name] += 1
 | 
			
		||||
                                continue
 | 
			
		||||
                            else:
 | 
			
		||||
                                print(
 | 
			
		||||
                                    f'{task.name} was too slow and terminated '
 | 
			
		||||
                                    f'on {lags}:{value}')
 | 
			
		||||
                                return
 | 
			
		||||
 | 
			
		||||
            async with trio.open_nursery() as nursery:
 | 
			
		||||
 | 
			
		||||
                for i in range(1, num_laggers):
 | 
			
		||||
 | 
			
		||||
                    task_name = f'sub_{i}'
 | 
			
		||||
                    laggers[task_name] = 0
 | 
			
		||||
                    nursery.start_soon(
 | 
			
		||||
                        partial(
 | 
			
		||||
                            sub_and_print,
 | 
			
		||||
                            delay=i*0.001,
 | 
			
		||||
                        ),
 | 
			
		||||
                        name=task_name,
 | 
			
		||||
                    )
 | 
			
		||||
 | 
			
		||||
                # allow subs to sched
 | 
			
		||||
                await trio.sleep(0.1)
 | 
			
		||||
 | 
			
		||||
                async with tx:
 | 
			
		||||
                    for i in cycle(range(size)):
 | 
			
		||||
                        await tx.send(i)
 | 
			
		||||
                        if len(brx._state.subs) == 2:
 | 
			
		||||
                            # only one, the non lagger, sub is left
 | 
			
		||||
                            break
 | 
			
		||||
 | 
			
		||||
                # the non-lagger
 | 
			
		||||
                assert laggers.pop('sub_1') == 0
 | 
			
		||||
 | 
			
		||||
                for n, v in laggers.items():
 | 
			
		||||
                    assert v == 4
 | 
			
		||||
 | 
			
		||||
                assert tx._closed
 | 
			
		||||
                assert not tx._state.open_send_channels
 | 
			
		||||
 | 
			
		||||
                # check that "first" bcaster that we created
 | 
			
		||||
                # above, never wass iterated and is thus overrun
 | 
			
		||||
                try:
 | 
			
		||||
                    await brx.receive()
 | 
			
		||||
                except Lagged:
 | 
			
		||||
                    # expect tokio style index truncation
 | 
			
		||||
                    seq = brx._state.subs[brx.key]
 | 
			
		||||
                    assert seq == len(brx._state.queue) - 1
 | 
			
		||||
 | 
			
		||||
                # all backpressured entries in the underlying
 | 
			
		||||
                # channel should have been copied into the caster
 | 
			
		||||
                # queue trailing-window
 | 
			
		||||
                async for i in rx:
 | 
			
		||||
                    print(f'bped: {i}')
 | 
			
		||||
                    assert i in brx._state.queue
 | 
			
		||||
 | 
			
		||||
                # should be noop
 | 
			
		||||
                await brx.aclose()
 | 
			
		||||
 | 
			
		||||
    trio.run(main)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_first_recver_is_cancelled():
 | 
			
		||||
 | 
			
		||||
    async def main():
 | 
			
		||||
 | 
			
		||||
        # make sure it all works within the runtime
 | 
			
		||||
        async with tractor.open_root_actor():
 | 
			
		||||
 | 
			
		||||
            tx, rx = trio.open_memory_channel(1)
 | 
			
		||||
            brx = broadcast_receiver(rx, 1)
 | 
			
		||||
            cs = trio.CancelScope()
 | 
			
		||||
            sequence = list(range(3))
 | 
			
		||||
 | 
			
		||||
            async def sub_and_recv():
 | 
			
		||||
                with cs:
 | 
			
		||||
                    async with brx.subscribe() as bc:
 | 
			
		||||
                        async for value in bc:
 | 
			
		||||
                            print(value)
 | 
			
		||||
 | 
			
		||||
            async def cancel_and_send():
 | 
			
		||||
                await trio.sleep(0.2)
 | 
			
		||||
                cs.cancel()
 | 
			
		||||
                await tx.send(1)
 | 
			
		||||
 | 
			
		||||
            async with trio.open_nursery() as n:
 | 
			
		||||
 | 
			
		||||
                n.start_soon(sub_and_recv)
 | 
			
		||||
                await trio.sleep(0.1)
 | 
			
		||||
                assert brx._state.recv_ready
 | 
			
		||||
 | 
			
		||||
                n.start_soon(cancel_and_send)
 | 
			
		||||
 | 
			
		||||
                # ensure that we don't hang because no-task is now
 | 
			
		||||
                # waiting on the underlying receive..
 | 
			
		||||
                with trio.fail_after(0.5):
 | 
			
		||||
                    value = await brx.receive()
 | 
			
		||||
                    print(f'parent: {value}')
 | 
			
		||||
                    assert value == 1
 | 
			
		||||
 | 
			
		||||
    trio.run(main)
 | 
			
		||||
| 
						 | 
				
			
			@ -567,7 +567,7 @@ class Actor:
 | 
			
		|||
        try:
 | 
			
		||||
            send_chan, recv_chan = self._cids2qs[(actorid, cid)]
 | 
			
		||||
        except KeyError:
 | 
			
		||||
            send_chan, recv_chan = trio.open_memory_channel(1000)
 | 
			
		||||
            send_chan, recv_chan = trio.open_memory_channel(2**6)
 | 
			
		||||
            send_chan.cid = cid  # type: ignore
 | 
			
		||||
            recv_chan.cid = cid  # type: ignore
 | 
			
		||||
            self._cids2qs[(actorid, cid)] = send_chan, recv_chan
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -0,0 +1,315 @@
 | 
			
		|||
'''
 | 
			
		||||
``tokio`` style broadcast channel.
 | 
			
		||||
https://docs.rs/tokio/1.11.0/tokio/sync/broadcast/index.html
 | 
			
		||||
 | 
			
		||||
'''
 | 
			
		||||
from __future__ import annotations
 | 
			
		||||
from abc import abstractmethod
 | 
			
		||||
from collections import deque
 | 
			
		||||
from contextlib import asynccontextmanager
 | 
			
		||||
from dataclasses import dataclass
 | 
			
		||||
from functools import partial
 | 
			
		||||
from operator import ne
 | 
			
		||||
from typing import Optional, Callable, Awaitable, Any, AsyncIterator, Protocol
 | 
			
		||||
from typing import Generic, TypeVar
 | 
			
		||||
 | 
			
		||||
import trio
 | 
			
		||||
from trio._core._run import Task
 | 
			
		||||
from trio.abc import ReceiveChannel
 | 
			
		||||
from trio.lowlevel import current_task
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# A regular invariant generic type
 | 
			
		||||
T = TypeVar("T")
 | 
			
		||||
 | 
			
		||||
# covariant because AsyncReceiver[Derived] can be passed to someone
 | 
			
		||||
# expecting AsyncReceiver[Base])
 | 
			
		||||
ReceiveType = TypeVar("ReceiveType", covariant=True)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class AsyncReceiver(
 | 
			
		||||
    Protocol,
 | 
			
		||||
    Generic[ReceiveType],
 | 
			
		||||
):
 | 
			
		||||
    '''An async receivable duck-type that quacks much like trio's
 | 
			
		||||
    ``trio.abc.ReceieveChannel``.
 | 
			
		||||
 | 
			
		||||
    '''
 | 
			
		||||
    @abstractmethod
 | 
			
		||||
    async def receive(self) -> ReceiveType:
 | 
			
		||||
        ...
 | 
			
		||||
 | 
			
		||||
    @abstractmethod
 | 
			
		||||
    def __aiter__(self) -> AsyncIterator[ReceiveType]:
 | 
			
		||||
        ...
 | 
			
		||||
 | 
			
		||||
    @abstractmethod
 | 
			
		||||
    async def __anext__(self) -> ReceiveType:
 | 
			
		||||
        ...
 | 
			
		||||
 | 
			
		||||
    # ``trio.abc.AsyncResource`` methods
 | 
			
		||||
    @abstractmethod
 | 
			
		||||
    async def aclose(self):
 | 
			
		||||
        ...
 | 
			
		||||
 | 
			
		||||
    @abstractmethod
 | 
			
		||||
    async def __aenter__(self) -> AsyncReceiver[ReceiveType]:
 | 
			
		||||
        ...
 | 
			
		||||
 | 
			
		||||
    @abstractmethod
 | 
			
		||||
    async def __aexit__(self, *args) -> None:
 | 
			
		||||
        ...
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Lagged(trio.TooSlowError):
 | 
			
		||||
    '''Subscribed consumer task was too slow and was overrun
 | 
			
		||||
    by the fastest consumer-producer pair.
 | 
			
		||||
 | 
			
		||||
    '''
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@dataclass
 | 
			
		||||
class BroadcastState:
 | 
			
		||||
    '''Common state to all receivers of a broadcast.
 | 
			
		||||
 | 
			
		||||
    '''
 | 
			
		||||
    queue: deque
 | 
			
		||||
    maxlen: int
 | 
			
		||||
 | 
			
		||||
    # map of underlying instance id keys to receiver instances which
 | 
			
		||||
    # must be provided as a singleton per broadcaster set.
 | 
			
		||||
    subs: dict[int, int]
 | 
			
		||||
 | 
			
		||||
    # broadcast event to wake up all sleeping consumer tasks
 | 
			
		||||
    # on a newly produced value from the sender.
 | 
			
		||||
    recv_ready: Optional[tuple[int, trio.Event]] = None
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class BroadcastReceiver(ReceiveChannel):
 | 
			
		||||
    '''A memory receive channel broadcaster which is non-lossy for the
 | 
			
		||||
    fastest consumer.
 | 
			
		||||
 | 
			
		||||
    Additional consumer tasks can receive all produced values by registering
 | 
			
		||||
    with ``.subscribe()`` and receiving from the new instance it delivers.
 | 
			
		||||
 | 
			
		||||
    '''
 | 
			
		||||
    def __init__(
 | 
			
		||||
        self,
 | 
			
		||||
 | 
			
		||||
        rx_chan: AsyncReceiver,
 | 
			
		||||
        state: BroadcastState,
 | 
			
		||||
        receive_afunc: Optional[Callable[[], Awaitable[Any]]] = None,
 | 
			
		||||
 | 
			
		||||
    ) -> None:
 | 
			
		||||
 | 
			
		||||
        # register the original underlying (clone)
 | 
			
		||||
        self.key = id(self)
 | 
			
		||||
        self._state = state
 | 
			
		||||
        state.subs[self.key] = -1
 | 
			
		||||
 | 
			
		||||
        # underlying for this receiver
 | 
			
		||||
        self._rx = rx_chan
 | 
			
		||||
        self._recv = receive_afunc or rx_chan.receive
 | 
			
		||||
        self._closed: bool = False
 | 
			
		||||
 | 
			
		||||
    async def receive(self) -> ReceiveType:
 | 
			
		||||
 | 
			
		||||
        key = self.key
 | 
			
		||||
        state = self._state
 | 
			
		||||
 | 
			
		||||
        # TODO: ideally we can make some way to "lock out" the
 | 
			
		||||
        # underlying receive channel in some way such that if some task
 | 
			
		||||
        # tries to pull from it directly (i.e. one we're unaware of)
 | 
			
		||||
        # then it errors out.
 | 
			
		||||
 | 
			
		||||
        # only tasks which have entered ``.subscribe()`` can
 | 
			
		||||
        # receive on this broadcaster.
 | 
			
		||||
        try:
 | 
			
		||||
            seq = state.subs[key]
 | 
			
		||||
        except KeyError:
 | 
			
		||||
            if self._closed:
 | 
			
		||||
                raise trio.ClosedResourceError
 | 
			
		||||
 | 
			
		||||
            raise RuntimeError(
 | 
			
		||||
                f'{self} is not registerd as subscriber')
 | 
			
		||||
 | 
			
		||||
        # check that task does not already have a value it can receive
 | 
			
		||||
        # immediately and/or that it has lagged.
 | 
			
		||||
        if seq > -1:
 | 
			
		||||
            # get the oldest value we haven't received immediately
 | 
			
		||||
            try:
 | 
			
		||||
                value = state.queue[seq]
 | 
			
		||||
            except IndexError:
 | 
			
		||||
 | 
			
		||||
                # adhere to ``tokio`` style "lagging":
 | 
			
		||||
                # "Once RecvError::Lagged is returned, the lagging
 | 
			
		||||
                # receiver's position is updated to the oldest value
 | 
			
		||||
                # contained by the channel. The next call to recv will
 | 
			
		||||
                # return this value."
 | 
			
		||||
                # https://docs.rs/tokio/1.11.0/tokio/sync/broadcast/index.html#lagging
 | 
			
		||||
 | 
			
		||||
                # decrement to the last value and expect
 | 
			
		||||
                # consumer to either handle the ``Lagged`` and come back
 | 
			
		||||
                # or bail out on its own (thus un-subscribing)
 | 
			
		||||
                state.subs[key] = state.maxlen - 1
 | 
			
		||||
 | 
			
		||||
                # this task was overrun by the producer side
 | 
			
		||||
                task: Task = current_task()
 | 
			
		||||
                raise Lagged(f'Task {task.name} was overrun')
 | 
			
		||||
 | 
			
		||||
            state.subs[key] -= 1
 | 
			
		||||
            return value
 | 
			
		||||
 | 
			
		||||
        # current task already has the latest value **and** is the
 | 
			
		||||
        # first task to begin waiting for a new one
 | 
			
		||||
        if state.recv_ready is None:
 | 
			
		||||
 | 
			
		||||
            if self._closed:
 | 
			
		||||
                raise trio.ClosedResourceError
 | 
			
		||||
 | 
			
		||||
            event = trio.Event()
 | 
			
		||||
            state.recv_ready = key, event
 | 
			
		||||
 | 
			
		||||
            # if we're cancelled here it should be
 | 
			
		||||
            # fine to bail without affecting any other consumers
 | 
			
		||||
            # right?
 | 
			
		||||
            try:
 | 
			
		||||
                value = await self._recv()
 | 
			
		||||
 | 
			
		||||
                # items with lower indices are "newer"
 | 
			
		||||
                # NOTE: ``collections.deque`` implicitly takes care of
 | 
			
		||||
                # trucating values outside our ``state.maxlen``. In the
 | 
			
		||||
                # alt-backend-array-case we'll need to make sure this is
 | 
			
		||||
                # implemented in similar ringer-buffer-ish style.
 | 
			
		||||
                state.queue.appendleft(value)
 | 
			
		||||
 | 
			
		||||
                # broadcast new value to all subscribers by increasing
 | 
			
		||||
                # all sequence numbers that will point in the queue to
 | 
			
		||||
                # their latest available value.
 | 
			
		||||
 | 
			
		||||
                # don't decrement the sequence for this task since we
 | 
			
		||||
                # already retreived the last value
 | 
			
		||||
 | 
			
		||||
                # XXX: which of these impls is fastest?
 | 
			
		||||
 | 
			
		||||
                # subs = state.subs.copy()
 | 
			
		||||
                # subs.pop(key)
 | 
			
		||||
 | 
			
		||||
                for sub_key in filter(
 | 
			
		||||
                    # lambda k: k != key, state.subs,
 | 
			
		||||
                    partial(ne, key), state.subs,
 | 
			
		||||
                ):
 | 
			
		||||
                    state.subs[sub_key] += 1
 | 
			
		||||
 | 
			
		||||
                # NOTE: this should ONLY be set if the above task was *NOT*
 | 
			
		||||
                # cancelled on the `._recv()` call.
 | 
			
		||||
                event.set()
 | 
			
		||||
                return value
 | 
			
		||||
 | 
			
		||||
            except trio.Cancelled:
 | 
			
		||||
                # handle cancelled specially otherwise sibling
 | 
			
		||||
                # consumers will be awoken with a sequence of -1
 | 
			
		||||
                # state.recv_ready = trio.Cancelled
 | 
			
		||||
                if event.statistics().tasks_waiting:
 | 
			
		||||
                    event.set()
 | 
			
		||||
                raise
 | 
			
		||||
 | 
			
		||||
            finally:
 | 
			
		||||
 | 
			
		||||
                # Reset receiver waiter task event for next blocking condition.
 | 
			
		||||
                # this MUST be reset even if the above ``.recv()`` call
 | 
			
		||||
                # was cancelled to avoid the next consumer from blocking on
 | 
			
		||||
                # an event that won't be set!
 | 
			
		||||
                state.recv_ready = None
 | 
			
		||||
 | 
			
		||||
        # This task is all caught up and ready to receive the latest
 | 
			
		||||
        # value, so queue sched it on the internal event.
 | 
			
		||||
        else:
 | 
			
		||||
            seq = state.subs[key]
 | 
			
		||||
            assert seq == -1  # sanity
 | 
			
		||||
            _, ev = state.recv_ready
 | 
			
		||||
            await ev.wait()
 | 
			
		||||
 | 
			
		||||
            # NOTE: if we ever would like the behaviour where if the
 | 
			
		||||
            # first task to recv on the underlying is cancelled but it
 | 
			
		||||
            # still DOES trigger the ``.recv_ready``, event we'll likely need
 | 
			
		||||
            # this logic:
 | 
			
		||||
 | 
			
		||||
            if seq > -1:
 | 
			
		||||
                # stuff from above..
 | 
			
		||||
                seq = state.subs[key]
 | 
			
		||||
 | 
			
		||||
                value = state.queue[seq]
 | 
			
		||||
                state.subs[key] -= 1
 | 
			
		||||
                return value
 | 
			
		||||
 | 
			
		||||
            elif seq == -1:
 | 
			
		||||
                # XXX: In the case where the first task to allocate the
 | 
			
		||||
                # ``.recv_ready`` event is cancelled we will be woken with
 | 
			
		||||
                # a non-incremented sequence number and thus will read the
 | 
			
		||||
                # oldest value if we use that. Instead we need to detect if
 | 
			
		||||
                # we have not been incremented and then receive again.
 | 
			
		||||
                return await self.receive()
 | 
			
		||||
 | 
			
		||||
            else:
 | 
			
		||||
                raise ValueError(f'Invalid sequence {seq}!?')
 | 
			
		||||
 | 
			
		||||
    @asynccontextmanager
 | 
			
		||||
    async def subscribe(
 | 
			
		||||
        self,
 | 
			
		||||
    ) -> AsyncIterator[BroadcastReceiver]:
 | 
			
		||||
        '''Subscribe for values from this broadcast receiver.
 | 
			
		||||
 | 
			
		||||
        Returns a new ``BroadCastReceiver`` which is registered for and
 | 
			
		||||
        pulls data from a clone of the original ``trio.abc.ReceiveChannel``
 | 
			
		||||
        provided at creation.
 | 
			
		||||
 | 
			
		||||
        '''
 | 
			
		||||
        if self._closed:
 | 
			
		||||
            raise trio.ClosedResourceError
 | 
			
		||||
 | 
			
		||||
        state = self._state
 | 
			
		||||
        br = BroadcastReceiver(
 | 
			
		||||
            rx_chan=self._rx,
 | 
			
		||||
            state=state,
 | 
			
		||||
            receive_afunc=self._recv,
 | 
			
		||||
        )
 | 
			
		||||
        # assert clone in state.subs
 | 
			
		||||
        assert br.key in state.subs
 | 
			
		||||
 | 
			
		||||
        try:
 | 
			
		||||
            yield br
 | 
			
		||||
        finally:
 | 
			
		||||
            await br.aclose()
 | 
			
		||||
 | 
			
		||||
    async def aclose(
 | 
			
		||||
        self,
 | 
			
		||||
    ) -> None:
 | 
			
		||||
 | 
			
		||||
        if self._closed:
 | 
			
		||||
            return
 | 
			
		||||
 | 
			
		||||
        # XXX: leaving it like this consumers can still get values
 | 
			
		||||
        # up to the last received that still reside in the queue.
 | 
			
		||||
        self._state.subs.pop(self.key)
 | 
			
		||||
 | 
			
		||||
        self._closed = True
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def broadcast_receiver(
 | 
			
		||||
 | 
			
		||||
    recv_chan: AsyncReceiver,
 | 
			
		||||
    max_buffer_size: int,
 | 
			
		||||
    **kwargs,
 | 
			
		||||
 | 
			
		||||
) -> BroadcastReceiver:
 | 
			
		||||
 | 
			
		||||
    return BroadcastReceiver(
 | 
			
		||||
        recv_chan,
 | 
			
		||||
        state=BroadcastState(
 | 
			
		||||
            queue=deque(maxlen=max_buffer_size),
 | 
			
		||||
            maxlen=max_buffer_size,
 | 
			
		||||
            subs={},
 | 
			
		||||
        ),
 | 
			
		||||
        **kwargs,
 | 
			
		||||
    )
 | 
			
		||||
| 
						 | 
				
			
			@ -294,6 +294,7 @@ class Portal:
 | 
			
		|||
    async def open_stream_from(
 | 
			
		||||
        self,
 | 
			
		||||
        async_gen_func: Callable,  # typing: ignore
 | 
			
		||||
        shield: bool = False,
 | 
			
		||||
        **kwargs,
 | 
			
		||||
 | 
			
		||||
    ) -> AsyncGenerator[ReceiveMsgStream, None]:
 | 
			
		||||
| 
						 | 
				
			
			@ -320,7 +321,9 @@ class Portal:
 | 
			
		|||
        ctx = Context(self.channel, cid, _portal=self)
 | 
			
		||||
        try:
 | 
			
		||||
            # deliver receive only stream
 | 
			
		||||
            async with ReceiveMsgStream(ctx, recv_chan) as rchan:
 | 
			
		||||
            async with ReceiveMsgStream(
 | 
			
		||||
                ctx, recv_chan, shield=shield
 | 
			
		||||
            ) as rchan:
 | 
			
		||||
                self._streams.add(rchan)
 | 
			
		||||
                yield rchan
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -2,12 +2,14 @@
 | 
			
		|||
Message stream types and APIs.
 | 
			
		||||
 | 
			
		||||
"""
 | 
			
		||||
from __future__ import annotations
 | 
			
		||||
import inspect
 | 
			
		||||
from contextlib import contextmanager, asynccontextmanager
 | 
			
		||||
from dataclasses import dataclass
 | 
			
		||||
from typing import (
 | 
			
		||||
    Any, Iterator, Optional, Callable,
 | 
			
		||||
    AsyncGenerator, Dict,
 | 
			
		||||
    AsyncIterator
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
import warnings
 | 
			
		||||
| 
						 | 
				
			
			@ -17,6 +19,7 @@ import trio
 | 
			
		|||
from ._ipc import Channel
 | 
			
		||||
from ._exceptions import unpack_error, ContextCancelled
 | 
			
		||||
from ._state import current_actor
 | 
			
		||||
from ._broadcast import broadcast_receiver, BroadcastReceiver
 | 
			
		||||
from .log import get_logger
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -47,11 +50,14 @@ class ReceiveMsgStream(trio.abc.ReceiveChannel):
 | 
			
		|||
    def __init__(
 | 
			
		||||
        self,
 | 
			
		||||
        ctx: 'Context',  # typing: ignore # noqa
 | 
			
		||||
        rx_chan: trio.abc.ReceiveChannel,
 | 
			
		||||
        rx_chan: trio.MemoryReceiveChannel,
 | 
			
		||||
        shield: bool = False,
 | 
			
		||||
        _broadcaster: Optional[BroadcastReceiver] = None,
 | 
			
		||||
 | 
			
		||||
    ) -> None:
 | 
			
		||||
        self._ctx = ctx
 | 
			
		||||
        self._rx_chan = rx_chan
 | 
			
		||||
        self._broadcaster = _broadcaster
 | 
			
		||||
 | 
			
		||||
        # flag to denote end of stream
 | 
			
		||||
        self._eoc: bool = False
 | 
			
		||||
| 
						 | 
				
			
			@ -231,6 +237,50 @@ class ReceiveMsgStream(trio.abc.ReceiveChannel):
 | 
			
		|||
        # still need to consume msgs that are "in transit" from the far
 | 
			
		||||
        # end (eg. for ``Context.result()``).
 | 
			
		||||
 | 
			
		||||
    @asynccontextmanager
 | 
			
		||||
    async def subscribe(
 | 
			
		||||
        self,
 | 
			
		||||
 | 
			
		||||
    ) -> AsyncIterator[BroadcastReceiver]:
 | 
			
		||||
        '''Allocate and return a ``BroadcastReceiver`` which delegates
 | 
			
		||||
        to this message stream.
 | 
			
		||||
 | 
			
		||||
        This allows multiple local tasks to receive each their own copy
 | 
			
		||||
        of this message stream.
 | 
			
		||||
 | 
			
		||||
        This operation is indempotent and and mutates this stream's
 | 
			
		||||
        receive machinery to copy and window-length-store each received
 | 
			
		||||
        value from the far end via the internally created broudcast
 | 
			
		||||
        receiver wrapper.
 | 
			
		||||
 | 
			
		||||
        '''
 | 
			
		||||
        # NOTE: This operation is indempotent and non-reversible, so be
 | 
			
		||||
        # sure you can deal with any (theoretical) overhead of the the
 | 
			
		||||
        # allocated ``BroadcastReceiver`` before calling this method for
 | 
			
		||||
        # the first time.
 | 
			
		||||
        if self._broadcaster is None:
 | 
			
		||||
 | 
			
		||||
            bcast = self._broadcaster = broadcast_receiver(
 | 
			
		||||
                self,
 | 
			
		||||
                # use memory channel size by default
 | 
			
		||||
                self._rx_chan._state.max_buffer_size,  # type: ignore
 | 
			
		||||
                receive_afunc=self.receive,
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
            # NOTE: we override the original stream instance's receive
 | 
			
		||||
            # method to now delegate to the broadcaster's ``.receive()``
 | 
			
		||||
            # such that new subscribers will be copied received values
 | 
			
		||||
            # and this stream doesn't have to expect it's original
 | 
			
		||||
            # consumer(s) to get a new broadcast rx handle.
 | 
			
		||||
            self.receive = bcast.receive  # type: ignore
 | 
			
		||||
            # seems there's no graceful way to type this with ``mypy``?
 | 
			
		||||
            # https://github.com/python/mypy/issues/708
 | 
			
		||||
 | 
			
		||||
        async with self._broadcaster.subscribe() as bstream:
 | 
			
		||||
            assert bstream.key != self._broadcaster.key
 | 
			
		||||
            assert bstream._recv == self._broadcaster._recv
 | 
			
		||||
            yield bstream
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class MsgStream(ReceiveMsgStream, trio.abc.Channel):
 | 
			
		||||
    """
 | 
			
		||||
| 
						 | 
				
			
			@ -247,17 +297,6 @@ class MsgStream(ReceiveMsgStream, trio.abc.Channel):
 | 
			
		|||
        '''
 | 
			
		||||
        await self._ctx.chan.send({'yield': data, 'cid': self._ctx.cid})
 | 
			
		||||
 | 
			
		||||
    # TODO: but make it broadcasting to consumers
 | 
			
		||||
    def clone(self):
 | 
			
		||||
        """Clone this receive channel allowing for multi-task
 | 
			
		||||
        consumption from the same channel.
 | 
			
		||||
 | 
			
		||||
        """
 | 
			
		||||
        return MsgStream(
 | 
			
		||||
            self._ctx,
 | 
			
		||||
            self._rx_chan.clone(),
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@dataclass
 | 
			
		||||
class Context:
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue