Simplify api around receive channel

Buncha improvements:
- pass in the queue via constructor
- tracking over all underlying memory channel closure using cloning
- do it like `tokio` and set lagged consumers to the last sequence
  before raising
- copy the subs on first receiver wakeup for iteration instead of
  iterating the table directly (and being forced to skip the current
  tasks sequence increment)
- implement `.aclose()` to close the underlying clone for this task
- make `broadcast_receiver()` just take the recv chan since it doesn't
  need anything on the send side.
live_on_air_from_tokio
Tyler Goodlet 2021-08-08 19:48:02 -04:00
parent 3817b4fb5e
commit 6a2c3da1bb
1 changed files with 49 additions and 39 deletions

View File

@ -1,25 +1,22 @@
'''
``tokio`` style broadcast channels.
``tokio`` style broadcast channel.
https://tokio-rs.github.io/tokio/doc/tokio/sync/broadcast/index.html
'''
from __future__ import annotations
# from math import inf
from itertools import cycle
from collections import deque
from contextlib import contextmanager # , asynccontextmanager
from contextlib import contextmanager
from functools import partial
from typing import Optional
import trio
import tractor
from trio.lowlevel import current_task
from trio.abc import ReceiveChannel # , SendChannel
# from trio._core import enable_ki_protection
from trio.abc import ReceiveChannel
from trio._core._run import Task
from trio._channel import (
MemorySendChannel,
MemoryReceiveChannel,
# MemoryChannelState,
)
@ -28,20 +25,25 @@ class Lagged(trio.TooSlowError):
class BroadcastReceiver(ReceiveChannel):
'''This isn't Paris, not Berlin, nor Honk Kong..
'''A memory receive channel broadcaster which is non-lossy for the
fastest consumer.
Additional consumer tasks can receive all produced values by registering
with ``.subscribe()``.
'''
def __init__(
self,
rx_chan: MemoryReceiveChannel,
buffer_size: int = 100,
queue: deque,
) -> None:
self._rx = rx_chan
self._len = buffer_size
self._queue = deque(maxlen=buffer_size)
self._subs = {id(current_task()): -1}
self._queue = queue
self._subs: dict[Task, int] = {} # {id(current_task()): -1}
self._clones: dict[Task, MemoryReceiveChannel] = {}
self._value_received: Optional[trio.Event] = None
async def receive(self):
@ -56,26 +58,30 @@ class BroadcastReceiver(ReceiveChannel):
try:
seq = self._subs[key]
except KeyError:
self._subs.pop(key)
raise RuntimeError(
f'Task {task.name} is not registerd as subscriber')
if seq > -1:
# get the oldest value we haven't received immediately
try:
value = self._queue[seq]
except IndexError:
# decrement to the last value and expect
# consumer to either handle the ``Lagged`` and come back
# or bail out on it's own (thus un-subscribing)
self._subs[key] = self._queue.maxlen - 1
# this task was overrun by the producer side
raise Lagged(f'Task {task.name} was overrun')
self._subs[key] -= 1
return value
if self._value_received is None:
# we already have the latest value **and** are the first
# task to begin waiting for a new one
# current task already has the latest value **and** is the
# first task to begin waiting for a new one
# sanity checks with underlying chan ?
# what sanity checks might we use for the underlying chan ?
# assert not self._rx._state.data
event = self._value_received = trio.Event()
@ -87,20 +93,15 @@ class BroadcastReceiver(ReceiveChannel):
# broadcast new value to all subscribers by increasing
# all sequence numbers that will point in the queue to
# their latest available value.
for sub_key, seq in self._subs.items():
if key == sub_key:
# we don't need to increase **this** task's
# sequence number since we just consumed the latest
# value
continue
# # except TypeError:
# # # already lagged
# # seq = Lagged
subs = self._subs.copy()
# don't decerement the sequence # for this task since we
# already retreived the last value
subs.pop(key)
for sub_key, seq in subs.items():
self._subs[sub_key] += 1
# reset receiver waiter task event for next blocking condition
self._value_received = None
event.set()
return value
@ -109,7 +110,7 @@ class BroadcastReceiver(ReceiveChannel):
await self._value_received.wait()
seq = self._subs[key]
assert seq > -1, 'Uhhhh'
assert seq > -1, 'Internal error?'
self._subs[key] -= 1
return self._queue[0]
@ -118,30 +119,37 @@ class BroadcastReceiver(ReceiveChannel):
@contextmanager
def subscribe(
self,
) -> BroadcastReceiver:
key = id(current_task())
self._subs[key] = -1
# XXX: we only use this clone for closure tracking
clone = self._clones[key] = self._rx.clone()
try:
yield self
finally:
self._subs.pop(key)
clone.close()
async def aclose(self) -> None:
# TODO: wtf should we do here?
# TODO: do we need anything here?
# if we're the last sub to close then close
# the underlying rx channel
pass
# the underlying rx channel, but couldn't we just
# use ``.clone()``s trackign then?
async def aclose(self) -> None:
key = id(current_task())
await self._clones[key].aclose()
def broadcast_channel(
def broadcast_receiver(
recv_chan: MemoryReceiveChannel,
max_buffer_size: int,
) -> (MemorySendChannel, BroadcastReceiver):
) -> BroadcastReceiver:
tx, rx = trio.open_memory_channel(max_buffer_size)
return tx, BroadcastReceiver(rx)
return BroadcastReceiver(
recv_chan,
queue=deque(maxlen=max_buffer_size),
)
if __name__ == '__main__':
@ -153,7 +161,9 @@ if __name__ == '__main__':
# loglevel='info',
):
tx, rx = broadcast_channel(100)
size = 100
tx, rx = trio.open_memory_channel(size)
rx = broadcast_receiver(rx, size)
async def sub_and_print(
delay: float,