Simplify api around receive channel

Buncha improvements:
- pass in the queue via constructor
- tracking over all underlying memory channel closure using cloning
- do it like `tokio` and set lagged consumers to the last sequence
  before raising
- copy the subs on first receiver wakeup for iteration instead of
  iterating the table directly (and being forced to skip the current
  tasks sequence increment)
- implement `.aclose()` to close the underlying clone for this task
- make `broadcast_receiver()` just take the recv chan since it doesn't
  need anything on the send side.
tokio_backup
Tyler Goodlet 2021-08-08 19:48:02 -04:00
parent af6e8a64ad
commit dfc4082ad2
1 changed files with 49 additions and 39 deletions

View File

@ -1,25 +1,22 @@
''' '''
``tokio`` style broadcast channels. ``tokio`` style broadcast channel.
https://tokio-rs.github.io/tokio/doc/tokio/sync/broadcast/index.html
''' '''
from __future__ import annotations from __future__ import annotations
# from math import inf
from itertools import cycle from itertools import cycle
from collections import deque from collections import deque
from contextlib import contextmanager # , asynccontextmanager from contextlib import contextmanager
from functools import partial from functools import partial
from typing import Optional from typing import Optional
import trio import trio
import tractor import tractor
from trio.lowlevel import current_task from trio.lowlevel import current_task
from trio.abc import ReceiveChannel # , SendChannel from trio.abc import ReceiveChannel
# from trio._core import enable_ki_protection
from trio._core._run import Task from trio._core._run import Task
from trio._channel import ( from trio._channel import (
MemorySendChannel,
MemoryReceiveChannel, MemoryReceiveChannel,
# MemoryChannelState,
) )
@ -28,20 +25,25 @@ class Lagged(trio.TooSlowError):
class BroadcastReceiver(ReceiveChannel): class BroadcastReceiver(ReceiveChannel):
'''This isn't Paris, not Berlin, nor Honk Kong.. '''A memory receive channel broadcaster which is non-lossy for the
fastest consumer.
Additional consumer tasks can receive all produced values by registering
with ``.subscribe()``.
''' '''
def __init__( def __init__(
self, self,
rx_chan: MemoryReceiveChannel, rx_chan: MemoryReceiveChannel,
buffer_size: int = 100, queue: deque,
) -> None: ) -> None:
self._rx = rx_chan self._rx = rx_chan
self._len = buffer_size self._queue = queue
self._queue = deque(maxlen=buffer_size) self._subs: dict[Task, int] = {} # {id(current_task()): -1}
self._subs = {id(current_task()): -1} self._clones: dict[Task, MemoryReceiveChannel] = {}
self._value_received: Optional[trio.Event] = None self._value_received: Optional[trio.Event] = None
async def receive(self): async def receive(self):
@ -56,26 +58,30 @@ class BroadcastReceiver(ReceiveChannel):
try: try:
seq = self._subs[key] seq = self._subs[key]
except KeyError: except KeyError:
self._subs.pop(key)
raise RuntimeError( raise RuntimeError(
f'Task {task.name} is not registerd as subscriber') f'Task {task.name} is not registerd as subscriber')
if seq > -1: if seq > -1:
# get the oldest value we haven't received immediately # get the oldest value we haven't received immediately
try: try:
value = self._queue[seq] value = self._queue[seq]
except IndexError: except IndexError:
# decrement to the last value and expect
# consumer to either handle the ``Lagged`` and come back
# or bail out on it's own (thus un-subscribing)
self._subs[key] = self._queue.maxlen - 1
# this task was overrun by the producer side
raise Lagged(f'Task {task.name} was overrun') raise Lagged(f'Task {task.name} was overrun')
self._subs[key] -= 1 self._subs[key] -= 1
return value return value
if self._value_received is None: if self._value_received is None:
# we already have the latest value **and** are the first # current task already has the latest value **and** is the
# task to begin waiting for a new one # first task to begin waiting for a new one
# sanity checks with underlying chan ? # what sanity checks might we use for the underlying chan ?
# assert not self._rx._state.data # assert not self._rx._state.data
event = self._value_received = trio.Event() event = self._value_received = trio.Event()
@ -87,20 +93,15 @@ class BroadcastReceiver(ReceiveChannel):
# broadcast new value to all subscribers by increasing # broadcast new value to all subscribers by increasing
# all sequence numbers that will point in the queue to # all sequence numbers that will point in the queue to
# their latest available value. # their latest available value.
for sub_key, seq in self._subs.items():
if key == sub_key:
# we don't need to increase **this** task's
# sequence number since we just consumed the latest
# value
continue
# # except TypeError:
# # # already lagged
# # seq = Lagged
subs = self._subs.copy()
# don't decerement the sequence # for this task since we
# already retreived the last value
subs.pop(key)
for sub_key, seq in subs.items():
self._subs[sub_key] += 1 self._subs[sub_key] += 1
# reset receiver waiter task event for next blocking condition
self._value_received = None self._value_received = None
event.set() event.set()
return value return value
@ -109,7 +110,7 @@ class BroadcastReceiver(ReceiveChannel):
await self._value_received.wait() await self._value_received.wait()
seq = self._subs[key] seq = self._subs[key]
assert seq > -1, 'Uhhhh' assert seq > -1, 'Internal error?'
self._subs[key] -= 1 self._subs[key] -= 1
return self._queue[0] return self._queue[0]
@ -118,30 +119,37 @@ class BroadcastReceiver(ReceiveChannel):
@contextmanager @contextmanager
def subscribe( def subscribe(
self, self,
) -> BroadcastReceiver: ) -> BroadcastReceiver:
key = id(current_task()) key = id(current_task())
self._subs[key] = -1 self._subs[key] = -1
# XXX: we only use this clone for closure tracking
clone = self._clones[key] = self._rx.clone()
try: try:
yield self yield self
finally: finally:
self._subs.pop(key) self._subs.pop(key)
clone.close()
async def aclose(self) -> None: # TODO: do we need anything here?
# TODO: wtf should we do here?
# if we're the last sub to close then close # if we're the last sub to close then close
# the underlying rx channel # the underlying rx channel, but couldn't we just
pass # use ``.clone()``s trackign then?
async def aclose(self) -> None:
key = id(current_task())
await self._clones[key].aclose()
def broadcast_channel( def broadcast_receiver(
recv_chan: MemoryReceiveChannel,
max_buffer_size: int, max_buffer_size: int,
) -> (MemorySendChannel, BroadcastReceiver): ) -> BroadcastReceiver:
tx, rx = trio.open_memory_channel(max_buffer_size) return BroadcastReceiver(
return tx, BroadcastReceiver(rx) recv_chan,
queue=deque(maxlen=max_buffer_size),
)
if __name__ == '__main__': if __name__ == '__main__':
@ -153,7 +161,9 @@ if __name__ == '__main__':
# loglevel='info', # loglevel='info',
): ):
tx, rx = broadcast_channel(100) size = 100
tx, rx = trio.open_memory_channel(size)
rx = broadcast_receiver(rx, size)
async def sub_and_print( async def sub_and_print(
delay: float, delay: float,