forked from goodboy/tractor
				
			Simplify api around receive channel
Buncha improvements: - pass in the queue via constructor - tracking over all underlying memory channel closure using cloning - do it like `tokio` and set lagged consumers to the last sequence before raising - copy the subs on first receiver wakeup for iteration instead of iterating the table directly (and being forced to skip the current tasks sequence increment) - implement `.aclose()` to close the underlying clone for this task - make `broadcast_receiver()` just take the recv chan since it doesn't need anything on the send side.live_on_air_from_tokio
							parent
							
								
									3817b4fb5e
								
							
						
					
					
						commit
						6a2c3da1bb
					
				| 
						 | 
				
			
			@ -1,25 +1,22 @@
 | 
			
		|||
'''
 | 
			
		||||
``tokio`` style broadcast channels.
 | 
			
		||||
``tokio`` style broadcast channel.
 | 
			
		||||
https://tokio-rs.github.io/tokio/doc/tokio/sync/broadcast/index.html
 | 
			
		||||
 | 
			
		||||
'''
 | 
			
		||||
from __future__ import annotations
 | 
			
		||||
# from math import inf
 | 
			
		||||
from itertools import cycle
 | 
			
		||||
from collections import deque
 | 
			
		||||
from contextlib import contextmanager  # , asynccontextmanager
 | 
			
		||||
from contextlib import contextmanager
 | 
			
		||||
from functools import partial
 | 
			
		||||
from typing import Optional
 | 
			
		||||
 | 
			
		||||
import trio
 | 
			
		||||
import tractor
 | 
			
		||||
from trio.lowlevel import current_task
 | 
			
		||||
from trio.abc import ReceiveChannel  # , SendChannel
 | 
			
		||||
# from trio._core import enable_ki_protection
 | 
			
		||||
from trio.abc import ReceiveChannel
 | 
			
		||||
from trio._core._run import Task
 | 
			
		||||
from trio._channel import (
 | 
			
		||||
    MemorySendChannel,
 | 
			
		||||
    MemoryReceiveChannel,
 | 
			
		||||
    # MemoryChannelState,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -28,20 +25,25 @@ class Lagged(trio.TooSlowError):
 | 
			
		|||
 | 
			
		||||
 | 
			
		||||
class BroadcastReceiver(ReceiveChannel):
 | 
			
		||||
    '''This isn't Paris, not Berlin, nor Honk Kong..
 | 
			
		||||
    '''A memory receive channel broadcaster which is non-lossy for the
 | 
			
		||||
    fastest consumer.
 | 
			
		||||
 | 
			
		||||
    Additional consumer tasks can receive all produced values by registering
 | 
			
		||||
    with ``.subscribe()``.
 | 
			
		||||
 | 
			
		||||
    '''
 | 
			
		||||
    def __init__(
 | 
			
		||||
        self,
 | 
			
		||||
 | 
			
		||||
        rx_chan: MemoryReceiveChannel,
 | 
			
		||||
        buffer_size: int = 100,
 | 
			
		||||
        queue: deque,
 | 
			
		||||
 | 
			
		||||
    ) -> None:
 | 
			
		||||
 | 
			
		||||
        self._rx = rx_chan
 | 
			
		||||
        self._len = buffer_size
 | 
			
		||||
        self._queue = deque(maxlen=buffer_size)
 | 
			
		||||
        self._subs = {id(current_task()): -1}
 | 
			
		||||
        self._queue = queue
 | 
			
		||||
        self._subs: dict[Task, int] = {}  # {id(current_task()): -1}
 | 
			
		||||
        self._clones: dict[Task, MemoryReceiveChannel] = {}
 | 
			
		||||
        self._value_received: Optional[trio.Event] = None
 | 
			
		||||
 | 
			
		||||
    async def receive(self):
 | 
			
		||||
| 
						 | 
				
			
			@ -56,26 +58,30 @@ class BroadcastReceiver(ReceiveChannel):
 | 
			
		|||
        try:
 | 
			
		||||
            seq = self._subs[key]
 | 
			
		||||
        except KeyError:
 | 
			
		||||
            self._subs.pop(key)
 | 
			
		||||
            raise RuntimeError(
 | 
			
		||||
                f'Task {task.name} is not registerd as subscriber')
 | 
			
		||||
 | 
			
		||||
        if seq > -1:
 | 
			
		||||
            # get the oldest value we haven't received immediately
 | 
			
		||||
 | 
			
		||||
            try:
 | 
			
		||||
                value = self._queue[seq]
 | 
			
		||||
            except IndexError:
 | 
			
		||||
                # decrement to the last value and expect
 | 
			
		||||
                # consumer to either handle the ``Lagged`` and come back
 | 
			
		||||
                # or bail out on it's own (thus un-subscribing)
 | 
			
		||||
                self._subs[key] = self._queue.maxlen - 1
 | 
			
		||||
 | 
			
		||||
                # this task was overrun by the producer side
 | 
			
		||||
                raise Lagged(f'Task {task.name} was overrun')
 | 
			
		||||
 | 
			
		||||
            self._subs[key] -= 1
 | 
			
		||||
            return value
 | 
			
		||||
 | 
			
		||||
        if self._value_received is None:
 | 
			
		||||
            # we already have the latest value **and** are the first
 | 
			
		||||
            # task to begin waiting for a new one
 | 
			
		||||
            # current task already has the latest value **and** is the
 | 
			
		||||
            # first task to begin waiting for a new one
 | 
			
		||||
 | 
			
		||||
            # sanity checks with underlying chan ?
 | 
			
		||||
            # what sanity checks might we use for the underlying chan ?
 | 
			
		||||
            # assert not self._rx._state.data
 | 
			
		||||
 | 
			
		||||
            event = self._value_received = trio.Event()
 | 
			
		||||
| 
						 | 
				
			
			@ -87,20 +93,15 @@ class BroadcastReceiver(ReceiveChannel):
 | 
			
		|||
            # broadcast new value to all subscribers by increasing
 | 
			
		||||
            # all sequence numbers that will point in the queue to
 | 
			
		||||
            # their latest available value.
 | 
			
		||||
            for sub_key, seq in self._subs.items():
 | 
			
		||||
 | 
			
		||||
                if key == sub_key:
 | 
			
		||||
                    # we don't need to increase **this** task's
 | 
			
		||||
                    # sequence number since we just consumed the latest
 | 
			
		||||
                    # value
 | 
			
		||||
                    continue
 | 
			
		||||
 | 
			
		||||
                # # except TypeError:
 | 
			
		||||
                # #     # already lagged
 | 
			
		||||
                # #     seq = Lagged
 | 
			
		||||
 | 
			
		||||
            subs = self._subs.copy()
 | 
			
		||||
            # don't decerement the sequence # for this task since we
 | 
			
		||||
            # already retreived the last value
 | 
			
		||||
            subs.pop(key)
 | 
			
		||||
            for sub_key, seq in subs.items():
 | 
			
		||||
                self._subs[sub_key] += 1
 | 
			
		||||
 | 
			
		||||
            # reset receiver waiter task event for next blocking condition
 | 
			
		||||
            self._value_received = None
 | 
			
		||||
            event.set()
 | 
			
		||||
            return value
 | 
			
		||||
| 
						 | 
				
			
			@ -109,7 +110,7 @@ class BroadcastReceiver(ReceiveChannel):
 | 
			
		|||
            await self._value_received.wait()
 | 
			
		||||
 | 
			
		||||
            seq = self._subs[key]
 | 
			
		||||
            assert seq > -1, 'Uhhhh'
 | 
			
		||||
            assert seq > -1, 'Internal error?'
 | 
			
		||||
 | 
			
		||||
            self._subs[key] -= 1
 | 
			
		||||
            return self._queue[0]
 | 
			
		||||
| 
						 | 
				
			
			@ -118,30 +119,37 @@ class BroadcastReceiver(ReceiveChannel):
 | 
			
		|||
    @contextmanager
 | 
			
		||||
    def subscribe(
 | 
			
		||||
        self,
 | 
			
		||||
 | 
			
		||||
    ) -> BroadcastReceiver:
 | 
			
		||||
        key = id(current_task())
 | 
			
		||||
        self._subs[key] = -1
 | 
			
		||||
        # XXX: we only use this clone for closure tracking
 | 
			
		||||
        clone = self._clones[key] = self._rx.clone()
 | 
			
		||||
        try:
 | 
			
		||||
            yield self
 | 
			
		||||
        finally:
 | 
			
		||||
            self._subs.pop(key)
 | 
			
		||||
            clone.close()
 | 
			
		||||
 | 
			
		||||
    # TODO: do we need anything here?
 | 
			
		||||
    # if we're the last sub to close then close
 | 
			
		||||
    # the underlying rx channel, but couldn't we just
 | 
			
		||||
    # use ``.clone()``s trackign then?
 | 
			
		||||
    async def aclose(self) -> None:
 | 
			
		||||
        # TODO: wtf should we do here?
 | 
			
		||||
        # if we're the last sub to close then close
 | 
			
		||||
        # the underlying rx channel
 | 
			
		||||
        pass
 | 
			
		||||
        key = id(current_task())
 | 
			
		||||
        await self._clones[key].aclose()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def broadcast_channel(
 | 
			
		||||
def broadcast_receiver(
 | 
			
		||||
 | 
			
		||||
    recv_chan: MemoryReceiveChannel,
 | 
			
		||||
    max_buffer_size: int,
 | 
			
		||||
 | 
			
		||||
) -> (MemorySendChannel, BroadcastReceiver):
 | 
			
		||||
) -> BroadcastReceiver:
 | 
			
		||||
 | 
			
		||||
    tx, rx = trio.open_memory_channel(max_buffer_size)
 | 
			
		||||
    return tx, BroadcastReceiver(rx)
 | 
			
		||||
    return BroadcastReceiver(
 | 
			
		||||
        recv_chan,
 | 
			
		||||
        queue=deque(maxlen=max_buffer_size),
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == '__main__':
 | 
			
		||||
| 
						 | 
				
			
			@ -153,7 +161,9 @@ if __name__ == '__main__':
 | 
			
		|||
            # loglevel='info',
 | 
			
		||||
        ):
 | 
			
		||||
 | 
			
		||||
            tx, rx = broadcast_channel(100)
 | 
			
		||||
            size = 100
 | 
			
		||||
            tx, rx = trio.open_memory_channel(size)
 | 
			
		||||
            rx = broadcast_receiver(rx, size)
 | 
			
		||||
 | 
			
		||||
            async def sub_and_print(
 | 
			
		||||
                delay: float,
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue