Simplify api around receive channel
Buncha improvements: - pass in the queue via constructor - tracking over all underlying memory channel closure using cloning - do it like `tokio` and set lagged consumers to the last sequence before raising - copy the subs on first receiver wakeup for iteration instead of iterating the table directly (and being forced to skip the current tasks sequence increment) - implement `.aclose()` to close the underlying clone for this task - make `broadcast_receiver()` just take the recv chan since it doesn't need anything on the send side.tokio_backup
parent
af6e8a64ad
commit
dfc4082ad2
|
@ -1,25 +1,22 @@
|
||||||
'''
|
'''
|
||||||
``tokio`` style broadcast channels.
|
``tokio`` style broadcast channel.
|
||||||
|
https://tokio-rs.github.io/tokio/doc/tokio/sync/broadcast/index.html
|
||||||
|
|
||||||
'''
|
'''
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
# from math import inf
|
|
||||||
from itertools import cycle
|
from itertools import cycle
|
||||||
from collections import deque
|
from collections import deque
|
||||||
from contextlib import contextmanager # , asynccontextmanager
|
from contextlib import contextmanager
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import trio
|
import trio
|
||||||
import tractor
|
import tractor
|
||||||
from trio.lowlevel import current_task
|
from trio.lowlevel import current_task
|
||||||
from trio.abc import ReceiveChannel # , SendChannel
|
from trio.abc import ReceiveChannel
|
||||||
# from trio._core import enable_ki_protection
|
|
||||||
from trio._core._run import Task
|
from trio._core._run import Task
|
||||||
from trio._channel import (
|
from trio._channel import (
|
||||||
MemorySendChannel,
|
|
||||||
MemoryReceiveChannel,
|
MemoryReceiveChannel,
|
||||||
# MemoryChannelState,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -28,20 +25,25 @@ class Lagged(trio.TooSlowError):
|
||||||
|
|
||||||
|
|
||||||
class BroadcastReceiver(ReceiveChannel):
|
class BroadcastReceiver(ReceiveChannel):
|
||||||
'''This isn't Paris, not Berlin, nor Honk Kong..
|
'''A memory receive channel broadcaster which is non-lossy for the
|
||||||
|
fastest consumer.
|
||||||
|
|
||||||
|
Additional consumer tasks can receive all produced values by registering
|
||||||
|
with ``.subscribe()``.
|
||||||
|
|
||||||
'''
|
'''
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|
||||||
rx_chan: MemoryReceiveChannel,
|
rx_chan: MemoryReceiveChannel,
|
||||||
buffer_size: int = 100,
|
queue: deque,
|
||||||
|
|
||||||
) -> None:
|
) -> None:
|
||||||
|
|
||||||
self._rx = rx_chan
|
self._rx = rx_chan
|
||||||
self._len = buffer_size
|
self._queue = queue
|
||||||
self._queue = deque(maxlen=buffer_size)
|
self._subs: dict[Task, int] = {} # {id(current_task()): -1}
|
||||||
self._subs = {id(current_task()): -1}
|
self._clones: dict[Task, MemoryReceiveChannel] = {}
|
||||||
self._value_received: Optional[trio.Event] = None
|
self._value_received: Optional[trio.Event] = None
|
||||||
|
|
||||||
async def receive(self):
|
async def receive(self):
|
||||||
|
@ -56,26 +58,30 @@ class BroadcastReceiver(ReceiveChannel):
|
||||||
try:
|
try:
|
||||||
seq = self._subs[key]
|
seq = self._subs[key]
|
||||||
except KeyError:
|
except KeyError:
|
||||||
self._subs.pop(key)
|
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
f'Task {task.name} is not registerd as subscriber')
|
f'Task {task.name} is not registerd as subscriber')
|
||||||
|
|
||||||
if seq > -1:
|
if seq > -1:
|
||||||
# get the oldest value we haven't received immediately
|
# get the oldest value we haven't received immediately
|
||||||
|
|
||||||
try:
|
try:
|
||||||
value = self._queue[seq]
|
value = self._queue[seq]
|
||||||
except IndexError:
|
except IndexError:
|
||||||
|
# decrement to the last value and expect
|
||||||
|
# consumer to either handle the ``Lagged`` and come back
|
||||||
|
# or bail out on it's own (thus un-subscribing)
|
||||||
|
self._subs[key] = self._queue.maxlen - 1
|
||||||
|
|
||||||
|
# this task was overrun by the producer side
|
||||||
raise Lagged(f'Task {task.name} was overrun')
|
raise Lagged(f'Task {task.name} was overrun')
|
||||||
|
|
||||||
self._subs[key] -= 1
|
self._subs[key] -= 1
|
||||||
return value
|
return value
|
||||||
|
|
||||||
if self._value_received is None:
|
if self._value_received is None:
|
||||||
# we already have the latest value **and** are the first
|
# current task already has the latest value **and** is the
|
||||||
# task to begin waiting for a new one
|
# first task to begin waiting for a new one
|
||||||
|
|
||||||
# sanity checks with underlying chan ?
|
# what sanity checks might we use for the underlying chan ?
|
||||||
# assert not self._rx._state.data
|
# assert not self._rx._state.data
|
||||||
|
|
||||||
event = self._value_received = trio.Event()
|
event = self._value_received = trio.Event()
|
||||||
|
@ -87,20 +93,15 @@ class BroadcastReceiver(ReceiveChannel):
|
||||||
# broadcast new value to all subscribers by increasing
|
# broadcast new value to all subscribers by increasing
|
||||||
# all sequence numbers that will point in the queue to
|
# all sequence numbers that will point in the queue to
|
||||||
# their latest available value.
|
# their latest available value.
|
||||||
for sub_key, seq in self._subs.items():
|
|
||||||
|
|
||||||
if key == sub_key:
|
|
||||||
# we don't need to increase **this** task's
|
|
||||||
# sequence number since we just consumed the latest
|
|
||||||
# value
|
|
||||||
continue
|
|
||||||
|
|
||||||
# # except TypeError:
|
|
||||||
# # # already lagged
|
|
||||||
# # seq = Lagged
|
|
||||||
|
|
||||||
|
subs = self._subs.copy()
|
||||||
|
# don't decerement the sequence # for this task since we
|
||||||
|
# already retreived the last value
|
||||||
|
subs.pop(key)
|
||||||
|
for sub_key, seq in subs.items():
|
||||||
self._subs[sub_key] += 1
|
self._subs[sub_key] += 1
|
||||||
|
|
||||||
|
# reset receiver waiter task event for next blocking condition
|
||||||
self._value_received = None
|
self._value_received = None
|
||||||
event.set()
|
event.set()
|
||||||
return value
|
return value
|
||||||
|
@ -109,7 +110,7 @@ class BroadcastReceiver(ReceiveChannel):
|
||||||
await self._value_received.wait()
|
await self._value_received.wait()
|
||||||
|
|
||||||
seq = self._subs[key]
|
seq = self._subs[key]
|
||||||
assert seq > -1, 'Uhhhh'
|
assert seq > -1, 'Internal error?'
|
||||||
|
|
||||||
self._subs[key] -= 1
|
self._subs[key] -= 1
|
||||||
return self._queue[0]
|
return self._queue[0]
|
||||||
|
@ -118,30 +119,37 @@ class BroadcastReceiver(ReceiveChannel):
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def subscribe(
|
def subscribe(
|
||||||
self,
|
self,
|
||||||
|
|
||||||
) -> BroadcastReceiver:
|
) -> BroadcastReceiver:
|
||||||
key = id(current_task())
|
key = id(current_task())
|
||||||
self._subs[key] = -1
|
self._subs[key] = -1
|
||||||
|
# XXX: we only use this clone for closure tracking
|
||||||
|
clone = self._clones[key] = self._rx.clone()
|
||||||
try:
|
try:
|
||||||
yield self
|
yield self
|
||||||
finally:
|
finally:
|
||||||
self._subs.pop(key)
|
self._subs.pop(key)
|
||||||
|
clone.close()
|
||||||
|
|
||||||
|
# TODO: do we need anything here?
|
||||||
|
# if we're the last sub to close then close
|
||||||
|
# the underlying rx channel, but couldn't we just
|
||||||
|
# use ``.clone()``s trackign then?
|
||||||
async def aclose(self) -> None:
|
async def aclose(self) -> None:
|
||||||
# TODO: wtf should we do here?
|
key = id(current_task())
|
||||||
# if we're the last sub to close then close
|
await self._clones[key].aclose()
|
||||||
# the underlying rx channel
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
def broadcast_channel(
|
def broadcast_receiver(
|
||||||
|
|
||||||
|
recv_chan: MemoryReceiveChannel,
|
||||||
max_buffer_size: int,
|
max_buffer_size: int,
|
||||||
|
|
||||||
) -> (MemorySendChannel, BroadcastReceiver):
|
) -> BroadcastReceiver:
|
||||||
|
|
||||||
tx, rx = trio.open_memory_channel(max_buffer_size)
|
return BroadcastReceiver(
|
||||||
return tx, BroadcastReceiver(rx)
|
recv_chan,
|
||||||
|
queue=deque(maxlen=max_buffer_size),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
@ -153,7 +161,9 @@ if __name__ == '__main__':
|
||||||
# loglevel='info',
|
# loglevel='info',
|
||||||
):
|
):
|
||||||
|
|
||||||
tx, rx = broadcast_channel(100)
|
size = 100
|
||||||
|
tx, rx = trio.open_memory_channel(size)
|
||||||
|
rx = broadcast_receiver(rx, size)
|
||||||
|
|
||||||
async def sub_and_print(
|
async def sub_and_print(
|
||||||
delay: float,
|
delay: float,
|
||||||
|
|
Loading…
Reference in New Issue