"""
Tests for broadcast channels used for fan-out to local tasks.

"""
from contextlib import (
    asynccontextmanager as acm,
)
from functools import partial
from itertools import cycle
import time
from typing import Optional

import pytest
import trio
from trio.lowlevel import current_task
import tractor
from tractor.trionics import (
    broadcast_receiver,
    Lagged,
    collapse_eg,
)


@tractor.context
async def echo_sequences(
    ctx: tractor.Context,
) -> None:
    '''
    Bidirectional streaming endpoint: every sequence received on the
    stream is sent back over the same stream one item at a time.

    '''
    # no value is passed to `started()`
    await ctx.started()

    async with ctx.open_stream() as stream:
        async for batch in stream:
            # materialize the received msg as a list before echoing it
            for item in list(batch):
                await stream.send(item)
                print(f'producer sent {item}')


async def ensure_sequence(
    stream: tractor.MsgStream,
    sequence: list,
    delay: Optional[float] = None,
) -> None:
    '''
    Subscribe to ``stream`` and check that values arrive in exactly
    the order given by ``sequence``.

    Each received value must equal the current head of ``sequence``
    and is then removed from it, so the passed list is mutated in
    place; the task returns as soon as the list is empty. When
    ``delay`` is truthy the task sleeps that many seconds after
    each value.

    '''
    task_name = current_task().name

    async with stream.subscribe() as bcaster:
        # the subscription must not be of the same type as the source stream
        assert not isinstance(bcaster, type(stream))

        async for received in bcaster:
            print(f'{task_name} rx: {received}')
            assert received == sequence[0]
            sequence.remove(received)

            if delay:
                await trio.sleep(delay)

            if sequence:
                continue

            # fully consumed
            break


@acm
async def open_sequence_streamer(
    sequence: list[int],
    reg_addr: tuple[str, int],
    start_method: str,
) -> tractor.MsgStream:
    '''
    Spawn a ``'sequence_echoer'`` subactor running ``echo_sequences()``
    and yield the caller side of its stream (opened with
    ``allow_overruns=True``).

    ``sequence`` is accepted but not used by this helper. Once the
    context block has exited the subactor is cancelled.

    '''
    async with tractor.open_nursery(
        arbiter_addr=reg_addr,
        start_method=start_method,
    ) as actor_nursery:

        echoer = await actor_nursery.start_actor(
            'sequence_echoer',
            enable_modules=[__name__],
        )

        async with echoer.open_context(echo_sequences) as (ctx, first):

            # the endpoint calls `ctx.started()` without a value
            assert first is None

            async with ctx.open_stream(allow_overruns=True) as stream:
                yield stream

        await echoer.cancel_actor()


def test_stream_fan_out_to_local_subscriptions(
    reg_addr,
    start_method,
):
    '''
    Fan a single echoed stream out to 10 local subscriber tasks and
    verify that each of them, as well as the parent task, sees every
    value of the sequence in order.

    '''
    sequence = list(range(1000))

    async def main():

        async with open_sequence_streamer(
            sequence,
            reg_addr,
            start_method,
        ) as stream:

            async with trio.open_nursery() as tn:

                # every consumer checks against its own copy
                for idx in range(10):
                    tn.start_soon(
                        ensure_sequence,
                        stream,
                        sequence.copy(),
                        name=f'consumer_{idx}',
                    )

                await stream.send(tuple(sequence))

                # the parent consumes directly from the source stream
                async for rxed in stream:
                    print(f'source stream rx: {rxed}')
                    assert rxed == sequence[0]
                    sequence.remove(rxed)

                    if sequence:
                        continue

                    # fully consumed
                    break

    trio.run(main)


@pytest.mark.parametrize(
    'task_delays',
    [
        (0.01, 0.001),
        (0.001, 0.01),
    ]
)
def test_consumer_and_parent_maybe_lag(
    reg_addr,
    start_method,
    task_delays,
):
    '''
    Run the parent task and one subscriber task against the same
    stream with different per-message delays and check which side
    lags.

    ``task_delays`` is ``(parent_delay, sub_delay)`` in seconds.

    - When the parent is the slower side it must see at least one
      ``Lagged`` while receiving.
    - A ``Lagged`` escaping the nursery block is only accepted when
      the subscriber is the slower side.

    '''
    async def main():

        sequence = list(range(300))
        parent_delay, sub_delay = task_delays

        async with open_sequence_streamer(
            sequence,
            reg_addr,
            start_method,
        ) as stream:

            try:
                async with (
                    collapse_eg(),
                    trio.open_nursery() as tn,
                ):

                    tn.start_soon(
                        ensure_sequence,
                        stream,
                        sequence.copy(),
                        sub_delay,
                        name='consumer_task',
                    )

                    await stream.send(tuple(sequence))

                    # NOTE: an explicit `receive()` loop is used instead
                    # of `async for value in stream:` so that each
                    # `Lagged` can be handled and the loop continued.
                    lagged = False
                    lag_count = 0

                    # last value received by the parent; pre-bound so
                    # the `Lagged` handler below can always report it,
                    # even when the very first receive lags.
                    value = None

                    while True:
                        try:
                            value = await stream.receive()
                            print(f'source stream rx: {value}')

                            if lagged:
                                # re set the sequence starting at our last
                                # value
                                sequence = sequence[sequence.index(value) + 1:]
                            else:
                                assert value == sequence[0]
                                sequence.remove(value)

                            lagged = False

                        except Lagged:
                            lagged = True
                            print(f'source stream lagged after {value}')
                            lag_count += 1
                            continue

                        # lag the parent
                        await trio.sleep(parent_delay)

                        if not sequence:
                            # fully consumed
                            break
                    print(f'parent + source stream lagged: {lag_count}')

                    if parent_delay > sub_delay:
                        assert lag_count > 0

            except Lagged:
                # child was lagged
                assert parent_delay < sub_delay

    trio.run(main)


def test_faster_task_to_recv_is_cancelled_by_slower(
    reg_addr,
    start_method,
):
    '''
    Ensure that if a faster task consuming from a stream is cancelled
    the slower task can continue to receive all expected values.

    The faster task is a zero-delay ``ensure_sequence()`` subscriber;
    the slower one is the parent, which sleeps after each receive.

    '''
    async def main():

        sequence = list(range(1000))
        final_value = sequence[-1]

        async with open_sequence_streamer(
            sequence,
            reg_addr,
            start_method,

        ) as stream:

            # last value received by the parent; pre-bound so the
            # `Lagged` handlers below can always report it, even when
            # the very first receive lags.
            value = None

            async with trio.open_nursery() as tn:
                tn.start_soon(
                    ensure_sequence,
                    stream,
                    sequence.copy(),
                    0,
                    name='consumer_task',
                )

                await stream.send(tuple(sequence))

                # make 20 (slow) receive attempts, cancel the subtask,
                # then expect to be able to pull all values still
                for _ in range(20):
                    try:
                        value = await stream.receive()
                        print(f'source stream rx: {value}')
                        await trio.sleep(0.01)
                    except Lagged:
                        print(f'parent overrun after {value}')
                        continue

                print('cancelling faster subtask')
                tn.cancel_scope.cancel()

            try:
                value = await stream.receive()
                print(f'source stream after cancel: {value}')
            except Lagged:
                print(f'parent overrun after {value}')

            # expect to see all remaining values
            with trio.fail_after(0.5):
                async for value in stream:
                    # no receiver task is registered as waiting on the
                    # underlying stream at this point
                    assert stream._broadcaster._state.recv_ready is None
                    print(f'source stream rx: {value}')
                    if value == final_value:
                        # fully consumed and we missed no values once
                        # the faster subtask was cancelled
                        break

                print(f'final value: {value}')

    trio.run(main)


def test_subscribe_errors_after_close():
    '''
    Subscribing via a broadcast receiver that has already been closed
    must raise ``trio.ClosedResourceError``, and the closed receiver's
    key must not be present in the shared subscriptions state.

    '''
    async def main():

        size = 1
        # the send side is never used; it is only kept referenced for
        # the duration of the test
        _tx, rx = trio.open_memory_channel(size)

        # open and close
        async with broadcast_receiver(rx, size) as brx:
            pass

        # `pytest.raises` fails the test with a clear message if the
        # subscription unexpectedly succeeds
        with pytest.raises(trio.ClosedResourceError):
            async with brx.subscribe():
                pass

        assert brx.key not in brx._state.subs

    trio.run(main)


def test_ensure_slow_consumers_lag_out(
    reg_addr,
    start_method,
):
    '''
    Run one fast and several slower subscribers against a local
    broadcast receiver and verify that the slower ones repeatedly hit
    ``Lagged`` and eventually give up, leaving only the fast one.

    This is a pure local task test; no tractor
    machinery is really required.

    '''
    async def main():

        # make sure it all works within the runtime
        async with tractor.open_root_actor():

            num_laggers = 4
            # per-task count of `Lagged` errors that were retried
            laggers: dict[str, int] = {}
            retries = 3
            size = 100
            tx, rx = trio.open_memory_channel(size)
            brx = broadcast_receiver(rx, size)

            async def sub_and_print(
                delay: float,
            ) -> None:
                '''
                Consume from a fresh subscription, sleeping ``delay``
                seconds per value and re-entering the consume loop on
                ``Lagged`` until the retry budget is used up.

                '''
                task = current_task()
                start = time.time()

                # last value received by this task; pre-bound so the
                # `Lagged` handler below can always report it, even if
                # the task lags before receiving anything.
                value = None

                async with brx.subscribe() as lbrx:
                    while True:
                        print(f'{task.name}: starting consume loop')
                        try:
                            async for value in lbrx:
                                print(f'{task.name}: {value}')
                                await trio.sleep(delay)

                            if task.name == 'sub_1':
                                # trigger checkpoint to clean out other subs
                                await trio.sleep(0.01)

                                # the non-lagger got
                                # a ``trio.EndOfChannel``
                                # because the ``tx`` below was closed
                                assert len(lbrx._state.subs) == 1

                                await lbrx.aclose()

                                assert len(lbrx._state.subs) == 0

                        except trio.ClosedResourceError:
                            # only the fast sub will try to re-enter
                            # iteration on the now closed bcaster
                            assert task.name == 'sub_1'
                            return

                        except Lagged:
                            lag_time = time.time() - start
                            lags = laggers[task.name]
                            print(
                                f'restarting slow task {task.name} '
                                f'that bailed out on {lags}:{value} '
                                f'after {lag_time:.3f}')
                            if lags <= retries:
                                laggers[task.name] += 1
                                continue
                            else:
                                print(
                                    f'{task.name} was too slow and terminated '
                                    f'on {lags}:{value}')
                                return

            async with trio.open_nursery() as tn:

                # spawns `sub_1` .. `sub_{num_laggers - 1}` with
                # increasing per-value delays; `sub_1` is the fastest
                for i in range(1, num_laggers):

                    task_name = f'sub_{i}'
                    laggers[task_name] = 0
                    tn.start_soon(
                        partial(
                            sub_and_print,
                            delay=i*0.001,
                        ),
                        name=task_name,
                    )

                # allow subs to sched
                await trio.sleep(0.1)

                async with tx:
                    for i in cycle(range(size)):
                        await tx.send(i)
                        if len(brx._state.subs) == 2:
                            # only two subscriptions are left: the
                            # never-iterated `brx` itself plus the one,
                            # non lagger, sub
                            break

                # the non-lagger
                assert laggers.pop('sub_1') == 0

                # every remaining (slow) task retried until its count
                # exceeded the budget, i.e. ended at `retries + 1`
                for lag_count in laggers.values():
                    assert lag_count == retries + 1

                assert tx._closed
                assert not tx._state.open_send_channels

                # check that "first" bcaster that we created
                # above, never was iterated and is thus overrun
                try:
                    await brx.receive()
                except Lagged:
                    # expect tokio style index truncation
                    seq = brx._state.subs[brx.key]
                    assert seq == len(brx._state.queue) - 1

                # all entries still left in the underlying
                # channel should have been copied into the bcaster
                # queue trailing-window
                async for i in rx:
                    print(f'bped: {i}')
                    assert i in brx._state.queue

                # should be noop
                await brx.aclose()

    trio.run(main)


def test_first_recver_is_cancelled():
    '''
    Cancel the subscriber task that is first to wait on the broadcast
    receiver and verify that a second receiver (the parent) still
    gets the next sent value instead of hanging.

    '''
    async def main():

        # make sure it all works within the runtime
        async with tractor.open_root_actor():

            tx, rx = trio.open_memory_channel(1)
            brx = broadcast_receiver(rx, 1)
            cancel_scope = trio.CancelScope()

            async def sub_and_recv():
                # consume until the shared scope gets cancelled
                with cancel_scope:
                    async with brx.subscribe() as sub:
                        async for msg in sub:
                            print(msg)

            async def cancel_and_send():
                # cancel the subscriber first, only then send a value
                await trio.sleep(0.2)
                cancel_scope.cancel()
                await tx.send(1)

            async with trio.open_nursery() as tn:

                tn.start_soon(sub_and_recv)
                await trio.sleep(0.1)

                # the subscriber task should be waiting on a receive by now
                assert brx._state.recv_ready

                tn.start_soon(cancel_and_send)

                # ensure that we don't hang because no-task is now
                # waiting on the underlying receive..
                with trio.fail_after(0.5):
                    received = await brx.receive()
                    print(f'parent: {received}')
                    assert received == 1

    trio.run(main)


def test_no_raise_on_lag():
    '''
    Run a simple 2-task broadcast where one task is slow but is
    subscribed with `raise_on_lag=False`, so that it does not raise
    `Lagged` on overruns, and verify that the only error to come out
    of the run is the simulated `KeyboardInterrupt`.

    '''
    size = 100
    tx, rx = trio.open_memory_channel(size)
    brx = broadcast_receiver(rx, size)

    async def slow():
        # sleeps per msg and so gets overrun, but must not raise
        async with brx.subscribe(raise_on_lag=False) as sub:
            async for item in sub:
                print(f'slow task got: {item}')
                await trio.sleep(0.1)

    async def fast():
        async with brx.subscribe() as sub:
            async for item in sub:
                print(f'fast task got: {item}')

    async def main():
        async with (
            tractor.open_root_actor(
                # NOTE: so we see the warning msg emitted by the bcaster
                # internals when the no raise flag is set.
                loglevel='warning',
            ),
            collapse_eg(),
            trio.open_nursery() as tn,
        ):
            tn.start_soon(slow)
            tn.start_soon(fast)

            for idx in range(1000):
                await tx.send(idx)

            # simulate user nailing ctl-c after realizing
            # there's a lag in the slow task.
            await trio.sleep(1)
            raise KeyboardInterrupt

    with pytest.raises(KeyboardInterrupt):
        trio.run(main)