Obviously keying on tasks isn't going to work

Using the current task as a subscription key fails horribly as soon as
you hand off new subscription receiver to another task you've spawned..

Instead use the underlying ``trio.abc.ReceiveChannel.clone()`` as a key
(so i guess we're assuming cloning is supported by the underlying?)
which makes this all work just like default mem chans. As a bonus, now
we can just close the underlying rx (which may be a clone) on
`.aclose()` and everything should just work in terms of the underlying
channels lifetime (i think?).

Change `.subscribe()` to be async since the receive channel type
interface only expects `.aclose()` and it actually ends up being
nicer for 3.9+ style `async with` parentheses style anyway.
tokio_backup
Tyler Goodlet 2021-08-09 16:40:02 -04:00
parent eeca3d0d50
commit 3f9b860210
1 changed files with 88 additions and 51 deletions

View File

@ -6,7 +6,7 @@ https://tokio-rs.github.io/tokio/doc/tokio/sync/broadcast/index.html
from __future__ import annotations from __future__ import annotations
from itertools import cycle from itertools import cycle
from collections import deque from collections import deque
from contextlib import contextmanager from contextlib import asynccontextmanager
from functools import partial from functools import partial
from typing import Optional from typing import Optional
@ -15,9 +15,6 @@ import tractor
from trio.lowlevel import current_task from trio.lowlevel import current_task
from trio.abc import ReceiveChannel from trio.abc import ReceiveChannel
from trio._core._run import Task from trio._core._run import Task
# from trio._channel import (
# MemoryReceiveChannel,
# )
class Lagged(trio.TooSlowError): class Lagged(trio.TooSlowError):
@ -29,57 +26,71 @@ class BroadcastReceiver(ReceiveChannel):
fastest consumer. fastest consumer.
Additional consumer tasks can receive all produced values by registering Additional consumer tasks can receive all produced values by registering
with ``.subscribe()``. with ``.subscribe()`` and receiving from thew new instance it delivers.
''' '''
# map of underlying clones to receiver wrappers
_subs: dict[trio.ReceiveChannel, BroadcastReceiver] = {}
def __init__( def __init__(
self, self,
rx_chan: MemoryReceiveChannel, rx_chan: ReceiveChannel,
queue: deque, queue: deque,
) -> None: ) -> None:
self._rx = rx_chan self._rx = rx_chan
self._queue = queue self._queue = queue
self._subs: dict[Task, int] = {} # {id(current_task()): -1}
self._clones: dict[Task, ReceiveChannel] = {}
self._value_received: Optional[trio.Event] = None self._value_received: Optional[trio.Event] = None
async def receive(self): async def receive(self):
task: Task = current_task() key = self._rx
# TODO: ideally we can make some way to "lock out" the
# underlying receive channel in some way such that if some task
# tries to pull from it directly (i.e. one we're unaware of)
# then it errors out.
# only tasks which have entered ``.subscribe()`` can
# receive on this broadcaster.
try:
seq = self._subs[key]
except KeyError:
raise RuntimeError(
f'{self} is not registerd as subscriber')
# check that task does not already have a value it can receive # check that task does not already have a value it can receive
# immediately and/or that it has lagged. # immediately and/or that it has lagged.
try:
seq = self._subs[task]
except KeyError:
raise RuntimeError(
f'Task {task.name} is not registerd as subscriber')
if seq > -1: if seq > -1:
# get the oldest value we haven't received immediately # get the oldest value we haven't received immediately
try: try:
value = self._queue[seq] value = self._queue[seq]
except IndexError: except IndexError:
# adhere to ``tokio`` style "lagging":
# "Once RecvError::Lagged is returned, the lagging
# receiver's position is updated to the oldest value
# contained by the channel. The next call to recv will
# return this value."
# https://tokio-rs.github.io/tokio/doc/tokio/sync/broadcast/index.html#lagging
# decrement to the last value and expect # decrement to the last value and expect
# consumer to either handle the ``Lagged`` and come back # consumer to either handle the ``Lagged`` and come back
# or bail out on it's own (thus un-subscribing) # or bail out on its own (thus un-subscribing)
self._subs[task] = self._queue.maxlen - 1 self._subs[key] = self._queue.maxlen - 1
# this task was overrun by the producer side # this task was overrun by the producer side
task: Task = current_task()
raise Lagged(f'Task {task.name} was overrun') raise Lagged(f'Task {task.name} was overrun')
self._subs[task] -= 1 self._subs[key] -= 1
return value return value
if self._value_received is None:
# current task already has the latest value **and** is the # current task already has the latest value **and** is the
# first task to begin waiting for a new one # first task to begin waiting for a new one
if self._value_received is None:
# what sanity checks might we use for the underlying chan ?
# assert not self._rx._state.data
event = self._value_received = trio.Event() event = self._value_received = trio.Event()
value = await self._rx.receive() value = await self._rx.receive()
@ -92,9 +103,9 @@ class BroadcastReceiver(ReceiveChannel):
# their latest available value. # their latest available value.
subs = self._subs.copy() subs = self._subs.copy()
# don't decerement the sequence # for this task since we # don't decrement the sequence # for this task since we
# already retreived the last value # already retreived the last value
subs.pop(task) subs.pop(key)
for sub_key, seq in subs.items(): for sub_key, seq in subs.items():
self._subs[sub_key] += 1 self._subs[sub_key] += 1
@ -103,37 +114,56 @@ class BroadcastReceiver(ReceiveChannel):
event.set() event.set()
return value return value
# This task is all caught up and ready to receive the latest
# value, so queue sched it on the internal event.
else: else:
await self._value_received.wait() await self._value_received.wait()
seq = self._subs[task] seq = self._subs[key]
assert seq > -1, 'Internal error?' assert seq > -1, 'Internal error?'
self._subs[task] -= 1 self._subs[key] -= 1
return self._queue[0] return self._queue[0]
# @asynccontextmanager @asynccontextmanager
@contextmanager async def subscribe(
def subscribe(
self, self,
) -> BroadcastReceiver: ) -> BroadcastReceiver:
task: task = current_task() '''Subscribe for values from this broadcast receiver.
self._subs[task] = -1
# XXX: we only use this clone for closure tracking
clone = self._clones[task] = self._rx.clone()
try:
yield self
finally:
self._subs.pop(task)
clone.close()
# TODO: do we need anything here? Returns a new ``BroadCastReceiver`` which is registered for and
# if we're the last sub to close then close pulls data from a clone of the original ``trio.abc.ReceiveChannel``
# the underlying rx channel, but couldn't we just provided at creation.
# use ``.clone()``s trackign then?
async def aclose(self) -> None: '''
task: Task = current_task() clone = self._rx.clone()
await self._clones[task].aclose() self._subs[clone] = -1
try:
yield BroadcastReceiver(
clone,
self._queue,
)
finally:
# drop from subscribers and close
self._subs.pop(clone)
# XXX: this is the reason this function is async: the
# ``AsyncResource`` api.
await clone.aclose()
# TODO:
# - should there be some ._closed flag that causes
# consumers to die **before** they read all queued values?
# - if subs only open and close clones then the underlying
# will never be killed until the last instance closes?
# This is correct right?
async def aclose(
self,
) -> None:
# XXX: leaving it like this consumers can still get values
# up to the last received that still reside in the queue.
# Is this what we want?
await self._rx.aclose()
self._subs.pop(self._rx)
def broadcast_receiver( def broadcast_receiver(
@ -158,6 +188,7 @@ if __name__ == '__main__':
# loglevel='info', # loglevel='info',
): ):
retries = 3
size = 100 size = 100
tx, rx = trio.open_memory_channel(size) tx, rx = trio.open_memory_channel(size)
rx = broadcast_receiver(rx, size) rx = broadcast_receiver(rx, size)
@ -170,9 +201,9 @@ if __name__ == '__main__':
count = 0 count = 0
while True: while True:
with rx.subscribe(): async with rx.subscribe() as brx:
try: try:
async for value in rx: async for value in brx:
print(f'{task.name}: {value}') print(f'{task.name}: {value}')
await trio.sleep(delay) await trio.sleep(delay)
count += 1 count += 1
@ -181,10 +212,16 @@ if __name__ == '__main__':
print( print(
f'restarting slow ass {task.name}' f'restarting slow ass {task.name}'
f'that bailed out on {count}:{value}') f'that bailed out on {count}:{value}')
if count <= retries:
continue continue
else:
print(
f'{task.name} was too slow and terminated '
f'on {count}:{value}')
return
async with trio.open_nursery() as n: async with trio.open_nursery() as n:
for i in range(1, 10): for i in range(1, size):
n.start_soon( n.start_soon(
partial( partial(
sub_and_print, sub_and_print,
@ -194,7 +231,7 @@ if __name__ == '__main__':
) )
async with tx: async with tx:
for i in cycle(range(1000)): for i in cycle(range(size)):
print(f'sending: {i}') print(f'sending: {i}')
await tx.send(i) await tx.send(i)