Obviously keying on tasks isn't going to work
Using the current task as a subscription key fails horribly as soon as you hand off new subscription receiver to another task you've spawned.. Instead use the underlying ``trio.abc.ReceiveChannel.clone()`` as a key (so i guess we're assuming cloning is supported by the underlying?) which makes this all work just like default mem chans. As a bonus, now we can just close the underlying rx (which may be a clone) on `.aclose()` and everything should just work in terms of the underlying channels lifetime (i think?). Change `.subscribe()` to be async since the receive channel type interface only expects `.aclose()` and it actually ends up being nicer for 3.9+ style `async with` parentheses style anyway.live_on_air_from_tokio
parent
64358f6525
commit
4ad75a3287
|
@ -6,7 +6,7 @@ https://tokio-rs.github.io/tokio/doc/tokio/sync/broadcast/index.html
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
from itertools import cycle
|
from itertools import cycle
|
||||||
from collections import deque
|
from collections import deque
|
||||||
from contextlib import contextmanager
|
from contextlib import asynccontextmanager
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
@ -15,9 +15,6 @@ import tractor
|
||||||
from trio.lowlevel import current_task
|
from trio.lowlevel import current_task
|
||||||
from trio.abc import ReceiveChannel
|
from trio.abc import ReceiveChannel
|
||||||
from trio._core._run import Task
|
from trio._core._run import Task
|
||||||
# from trio._channel import (
|
|
||||||
# MemoryReceiveChannel,
|
|
||||||
# )
|
|
||||||
|
|
||||||
|
|
||||||
class Lagged(trio.TooSlowError):
|
class Lagged(trio.TooSlowError):
|
||||||
|
@ -29,57 +26,71 @@ class BroadcastReceiver(ReceiveChannel):
|
||||||
fastest consumer.
|
fastest consumer.
|
||||||
|
|
||||||
Additional consumer tasks can receive all produced values by registering
|
Additional consumer tasks can receive all produced values by registering
|
||||||
with ``.subscribe()``.
|
with ``.subscribe()`` and receiving from thew new instance it delivers.
|
||||||
|
|
||||||
'''
|
'''
|
||||||
|
# map of underlying clones to receiver wrappers
|
||||||
|
_subs: dict[trio.ReceiveChannel, BroadcastReceiver] = {}
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|
||||||
rx_chan: MemoryReceiveChannel,
|
rx_chan: ReceiveChannel,
|
||||||
queue: deque,
|
queue: deque,
|
||||||
|
|
||||||
) -> None:
|
) -> None:
|
||||||
|
|
||||||
self._rx = rx_chan
|
self._rx = rx_chan
|
||||||
self._queue = queue
|
self._queue = queue
|
||||||
self._subs: dict[Task, int] = {} # {id(current_task()): -1}
|
|
||||||
self._clones: dict[Task, ReceiveChannel] = {}
|
|
||||||
self._value_received: Optional[trio.Event] = None
|
self._value_received: Optional[trio.Event] = None
|
||||||
|
|
||||||
async def receive(self):
|
async def receive(self):
|
||||||
|
|
||||||
task: Task = current_task()
|
key = self._rx
|
||||||
|
|
||||||
|
# TODO: ideally we can make some way to "lock out" the
|
||||||
|
# underlying receive channel in some way such that if some task
|
||||||
|
# tries to pull from it directly (i.e. one we're unaware of)
|
||||||
|
# then it errors out.
|
||||||
|
|
||||||
|
# only tasks which have entered ``.subscribe()`` can
|
||||||
|
# receive on this broadcaster.
|
||||||
|
try:
|
||||||
|
seq = self._subs[key]
|
||||||
|
except KeyError:
|
||||||
|
raise RuntimeError(
|
||||||
|
f'{self} is not registerd as subscriber')
|
||||||
|
|
||||||
# check that task does not already have a value it can receive
|
# check that task does not already have a value it can receive
|
||||||
# immediately and/or that it has lagged.
|
# immediately and/or that it has lagged.
|
||||||
try:
|
|
||||||
seq = self._subs[task]
|
|
||||||
except KeyError:
|
|
||||||
raise RuntimeError(
|
|
||||||
f'Task {task.name} is not registerd as subscriber')
|
|
||||||
|
|
||||||
if seq > -1:
|
if seq > -1:
|
||||||
# get the oldest value we haven't received immediately
|
# get the oldest value we haven't received immediately
|
||||||
try:
|
try:
|
||||||
value = self._queue[seq]
|
value = self._queue[seq]
|
||||||
except IndexError:
|
except IndexError:
|
||||||
|
|
||||||
|
# adhere to ``tokio`` style "lagging":
|
||||||
|
# "Once RecvError::Lagged is returned, the lagging
|
||||||
|
# receiver's position is updated to the oldest value
|
||||||
|
# contained by the channel. The next call to recv will
|
||||||
|
# return this value."
|
||||||
|
# https://tokio-rs.github.io/tokio/doc/tokio/sync/broadcast/index.html#lagging
|
||||||
|
|
||||||
# decrement to the last value and expect
|
# decrement to the last value and expect
|
||||||
# consumer to either handle the ``Lagged`` and come back
|
# consumer to either handle the ``Lagged`` and come back
|
||||||
# or bail out on it's own (thus un-subscribing)
|
# or bail out on its own (thus un-subscribing)
|
||||||
self._subs[task] = self._queue.maxlen - 1
|
self._subs[key] = self._queue.maxlen - 1
|
||||||
|
|
||||||
# this task was overrun by the producer side
|
# this task was overrun by the producer side
|
||||||
|
task: Task = current_task()
|
||||||
raise Lagged(f'Task {task.name} was overrun')
|
raise Lagged(f'Task {task.name} was overrun')
|
||||||
|
|
||||||
self._subs[task] -= 1
|
self._subs[key] -= 1
|
||||||
return value
|
return value
|
||||||
|
|
||||||
|
# current task already has the latest value **and** is the
|
||||||
|
# first task to begin waiting for a new one
|
||||||
if self._value_received is None:
|
if self._value_received is None:
|
||||||
# current task already has the latest value **and** is the
|
|
||||||
# first task to begin waiting for a new one
|
|
||||||
|
|
||||||
# what sanity checks might we use for the underlying chan ?
|
|
||||||
# assert not self._rx._state.data
|
|
||||||
|
|
||||||
event = self._value_received = trio.Event()
|
event = self._value_received = trio.Event()
|
||||||
value = await self._rx.receive()
|
value = await self._rx.receive()
|
||||||
|
@ -92,9 +103,9 @@ class BroadcastReceiver(ReceiveChannel):
|
||||||
# their latest available value.
|
# their latest available value.
|
||||||
|
|
||||||
subs = self._subs.copy()
|
subs = self._subs.copy()
|
||||||
# don't decerement the sequence # for this task since we
|
# don't decrement the sequence # for this task since we
|
||||||
# already retreived the last value
|
# already retreived the last value
|
||||||
subs.pop(task)
|
subs.pop(key)
|
||||||
for sub_key, seq in subs.items():
|
for sub_key, seq in subs.items():
|
||||||
self._subs[sub_key] += 1
|
self._subs[sub_key] += 1
|
||||||
|
|
||||||
|
@ -103,37 +114,56 @@ class BroadcastReceiver(ReceiveChannel):
|
||||||
event.set()
|
event.set()
|
||||||
return value
|
return value
|
||||||
|
|
||||||
|
# This task is all caught up and ready to receive the latest
|
||||||
|
# value, so queue sched it on the internal event.
|
||||||
else:
|
else:
|
||||||
await self._value_received.wait()
|
await self._value_received.wait()
|
||||||
|
|
||||||
seq = self._subs[task]
|
seq = self._subs[key]
|
||||||
assert seq > -1, 'Internal error?'
|
assert seq > -1, 'Internal error?'
|
||||||
|
|
||||||
self._subs[task] -= 1
|
self._subs[key] -= 1
|
||||||
return self._queue[0]
|
return self._queue[0]
|
||||||
|
|
||||||
# @asynccontextmanager
|
@asynccontextmanager
|
||||||
@contextmanager
|
async def subscribe(
|
||||||
def subscribe(
|
|
||||||
self,
|
self,
|
||||||
) -> BroadcastReceiver:
|
) -> BroadcastReceiver:
|
||||||
task: task = current_task()
|
'''Subscribe for values from this broadcast receiver.
|
||||||
self._subs[task] = -1
|
|
||||||
# XXX: we only use this clone for closure tracking
|
|
||||||
clone = self._clones[task] = self._rx.clone()
|
|
||||||
try:
|
|
||||||
yield self
|
|
||||||
finally:
|
|
||||||
self._subs.pop(task)
|
|
||||||
clone.close()
|
|
||||||
|
|
||||||
# TODO: do we need anything here?
|
Returns a new ``BroadCastReceiver`` which is registered for and
|
||||||
# if we're the last sub to close then close
|
pulls data from a clone of the original ``trio.abc.ReceiveChannel``
|
||||||
# the underlying rx channel, but couldn't we just
|
provided at creation.
|
||||||
# use ``.clone()``s trackign then?
|
|
||||||
async def aclose(self) -> None:
|
'''
|
||||||
task: Task = current_task()
|
clone = self._rx.clone()
|
||||||
await self._clones[task].aclose()
|
self._subs[clone] = -1
|
||||||
|
try:
|
||||||
|
yield BroadcastReceiver(
|
||||||
|
clone,
|
||||||
|
self._queue,
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
# drop from subscribers and close
|
||||||
|
self._subs.pop(clone)
|
||||||
|
# XXX: this is the reason this function is async: the
|
||||||
|
# ``AsyncResource`` api.
|
||||||
|
await clone.aclose()
|
||||||
|
|
||||||
|
# TODO:
|
||||||
|
# - should there be some ._closed flag that causes
|
||||||
|
# consumers to die **before** they read all queued values?
|
||||||
|
# - if subs only open and close clones then the underlying
|
||||||
|
# will never be killed until the last instance closes?
|
||||||
|
# This is correct right?
|
||||||
|
async def aclose(
|
||||||
|
self,
|
||||||
|
) -> None:
|
||||||
|
# XXX: leaving it like this consumers can still get values
|
||||||
|
# up to the last received that still reside in the queue.
|
||||||
|
# Is this what we want?
|
||||||
|
await self._rx.aclose()
|
||||||
|
self._subs.pop(self._rx)
|
||||||
|
|
||||||
|
|
||||||
def broadcast_receiver(
|
def broadcast_receiver(
|
||||||
|
@ -158,6 +188,7 @@ if __name__ == '__main__':
|
||||||
# loglevel='info',
|
# loglevel='info',
|
||||||
):
|
):
|
||||||
|
|
||||||
|
retries = 3
|
||||||
size = 100
|
size = 100
|
||||||
tx, rx = trio.open_memory_channel(size)
|
tx, rx = trio.open_memory_channel(size)
|
||||||
rx = broadcast_receiver(rx, size)
|
rx = broadcast_receiver(rx, size)
|
||||||
|
@ -170,9 +201,9 @@ if __name__ == '__main__':
|
||||||
count = 0
|
count = 0
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
with rx.subscribe():
|
async with rx.subscribe() as brx:
|
||||||
try:
|
try:
|
||||||
async for value in rx:
|
async for value in brx:
|
||||||
print(f'{task.name}: {value}')
|
print(f'{task.name}: {value}')
|
||||||
await trio.sleep(delay)
|
await trio.sleep(delay)
|
||||||
count += 1
|
count += 1
|
||||||
|
@ -181,10 +212,16 @@ if __name__ == '__main__':
|
||||||
print(
|
print(
|
||||||
f'restarting slow ass {task.name}'
|
f'restarting slow ass {task.name}'
|
||||||
f'that bailed out on {count}:{value}')
|
f'that bailed out on {count}:{value}')
|
||||||
continue
|
if count <= retries:
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
print(
|
||||||
|
f'{task.name} was too slow and terminated '
|
||||||
|
f'on {count}:{value}')
|
||||||
|
return
|
||||||
|
|
||||||
async with trio.open_nursery() as n:
|
async with trio.open_nursery() as n:
|
||||||
for i in range(1, 10):
|
for i in range(1, size):
|
||||||
n.start_soon(
|
n.start_soon(
|
||||||
partial(
|
partial(
|
||||||
sub_and_print,
|
sub_and_print,
|
||||||
|
@ -194,7 +231,7 @@ if __name__ == '__main__':
|
||||||
)
|
)
|
||||||
|
|
||||||
async with tx:
|
async with tx:
|
||||||
for i in cycle(range(1000)):
|
for i in cycle(range(size)):
|
||||||
print(f'sending: {i}')
|
print(f'sending: {i}')
|
||||||
await tx.send(i)
|
await tx.send(i)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue