forked from goodboy/tractor

Obviously keying on tasks isn't going to work

Using the current task as a subscription key fails horribly as soon as
you hand off a new subscription receiver to another task you've spawned.
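
For example (a rough sketch of the failure with a hypothetical
``consume()`` helper, not the actual test code), under the old task-keyed
scheme the spawned task was never registered, so its first receive blows
up:

    with rx.subscribe():                # registers key = current_task() (task A)
        async with trio.open_nursery() as n:
            # task B calls ``rx.receive()``, but ``current_task()`` (B) was
            # never added to ``._subs`` -> RuntimeError instead of a broadcast
            n.start_soon(consume, rx)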

Instead, use the underlying ``trio.abc.ReceiveChannel.clone()`` as the key
(so I guess we're assuming cloning is supported by the underlying channel?),
which makes this all work just like default mem chans. As a bonus, we can
now just close the underlying rx (which may itself be a clone) on
`.aclose()` and everything should just work in terms of the underlying
channel's lifetime (I think?).
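
The same clone-keyed registration, sketched with plain ``trio`` memory
channels (illustration only, not the patch itself):

    import trio

    async def demo():
        tx, rx = trio.open_memory_channel(8)
        subs = {}  # clone -> last-seen sequence index, like ``._subs`` below

        clone = rx.clone()  # the clone object itself is the registry key,
        subs[clone] = -1    # so the subscription can be handed to any task
        try:
            await tx.send('hello')
            print(await clone.receive())
        finally:
            subs.pop(clone)
            await clone.aclose()  # closes just this clone; ``rx`` stays open

    trio.run(demo)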

Change `.subscribe()` to be async since the receive channel type
interface only expects an (async) `.aclose()`, and it actually ends up
reading nicer with the 3.9+ parenthesized `async with` style anyway.
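
With that, consumer code looks like the ``__main__`` demo further down in
this diff:

    async with rx.subscribe() as brx:
        try:
            async for value in brx:
                print(value)
        except Lagged:
            # this consumer was overrun by the producer; either re-subscribe
            # (loop around) or bail out and thus un-subscribe
            ...
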
live_on_air_from_tokio
Tyler Goodlet 2021-08-09 16:40:02 -04:00
parent 64358f6525
commit 4ad75a3287
1 changed file with 88 additions and 51 deletions


@@ -6,7 +6,7 @@ https://tokio-rs.github.io/tokio/doc/tokio/sync/broadcast/index.html
from __future__ import annotations
from itertools import cycle
from collections import deque
from contextlib import contextmanager
from contextlib import asynccontextmanager
from functools import partial
from typing import Optional
@@ -15,9 +15,6 @@ import tractor
from trio.lowlevel import current_task
from trio.abc import ReceiveChannel
from trio._core._run import Task
# from trio._channel import (
# MemoryReceiveChannel,
# )
class Lagged(trio.TooSlowError):
@@ -29,57 +26,71 @@ class BroadcastReceiver(ReceiveChannel):
fastest consumer.
Additional consumer tasks can receive all produced values by registering
with ``.subscribe()``.
with ``.subscribe()`` and receiving from the new instance it delivers.
'''
# map of underlying clones to receiver wrappers
_subs: dict[trio.ReceiveChannel, BroadcastReceiver] = {}
def __init__(
self,
rx_chan: MemoryReceiveChannel,
rx_chan: ReceiveChannel,
queue: deque,
) -> None:
self._rx = rx_chan
self._queue = queue
self._subs: dict[Task, int] = {} # {id(current_task()): -1}
self._clones: dict[Task, ReceiveChannel] = {}
self._value_received: Optional[trio.Event] = None
async def receive(self):
task: Task = current_task()
key = self._rx
# TODO: ideally we can make some way to "lock out" the
# underlying receive channel in some way such that if some task
# tries to pull from it directly (i.e. one we're unaware of)
# then it errors out.
# only tasks which have entered ``.subscribe()`` can
# receive on this broadcaster.
try:
seq = self._subs[key]
except KeyError:
raise RuntimeError(
f'{self} is not registered as subscriber')
# check that task does not already have a value it can receive
# immediately and/or that it has lagged.
try:
seq = self._subs[task]
except KeyError:
raise RuntimeError(
f'Task {task.name} is not registered as subscriber')
if seq > -1:
# get the oldest value we haven't received immediately
try:
value = self._queue[seq]
except IndexError:
# adhere to ``tokio`` style "lagging":
# "Once RecvError::Lagged is returned, the lagging
# receiver's position is updated to the oldest value
# contained by the channel. The next call to recv will
# return this value."
# https://tokio-rs.github.io/tokio/doc/tokio/sync/broadcast/index.html#lagging
# decrement to the last value and expect
# consumer to either handle the ``Lagged`` and come back
# or bail out on it's own (thus un-subscribing)
self._subs[task] = self._queue.maxlen - 1
# or bail out on its own (thus un-subscribing)
self._subs[key] = self._queue.maxlen - 1
# this task was overrun by the producer side
task: Task = current_task()
raise Lagged(f'Task {task.name} was overrun')
self._subs[task] -= 1
self._subs[key] -= 1
return value
# current task already has the latest value **and** is the
# first task to begin waiting for a new one
if self._value_received is None:
# current task already has the latest value **and** is the
# first task to begin waiting for a new one
# what sanity checks might we use for the underlying chan ?
# assert not self._rx._state.data
event = self._value_received = trio.Event()
value = await self._rx.receive()
@@ -92,9 +103,9 @@ class BroadcastReceiver(ReceiveChannel):
# their latest available value.
subs = self._subs.copy()
# don't decerement the sequence # for this task since we
# don't decrement the sequence # for this task since we
# already retrieved the last value
subs.pop(task)
subs.pop(key)
for sub_key, seq in subs.items():
self._subs[sub_key] += 1
@@ -103,37 +114,56 @@ class BroadcastReceiver(ReceiveChannel):
event.set()
return value
# This task is all caught up and ready to receive the latest
# value, so schedule it to wait on the internal event.
else:
await self._value_received.wait()
seq = self._subs[task]
seq = self._subs[key]
assert seq > -1, 'Internal error?'
self._subs[task] -= 1
self._subs[key] -= 1
return self._queue[0]
# @asynccontextmanager
@contextmanager
def subscribe(
@asynccontextmanager
async def subscribe(
self,
) -> BroadcastReceiver:
task: Task = current_task()
self._subs[task] = -1
# XXX: we only use this clone for closure tracking
clone = self._clones[task] = self._rx.clone()
try:
yield self
finally:
self._subs.pop(task)
clone.close()
'''Subscribe for values from this broadcast receiver.
# TODO: do we need anything here?
# if we're the last sub to close then close
# the underlying rx channel, but couldn't we just
# use ``.clone()``'s tracking then?
async def aclose(self) -> None:
task: Task = current_task()
await self._clones[task].aclose()
Returns a new ``BroadcastReceiver`` which is registered for and
pulls data from a clone of the original ``trio.abc.ReceiveChannel``
provided at creation.
'''
clone = self._rx.clone()
self._subs[clone] = -1
try:
yield BroadcastReceiver(
clone,
self._queue,
)
finally:
# drop from subscribers and close
self._subs.pop(clone)
# XXX: this is the reason this function is async: the
# ``AsyncResource`` api.
await clone.aclose()
# TODO:
# - should there be some ._closed flag that causes
# consumers to die **before** they read all queued values?
# - if subs only open and close clones then the underlying
# will never be killed until the last instance closes?
# This is correct, right?
async def aclose(
self,
) -> None:
# XXX: leaving it like this, consumers can still get values
# up to the last received that still reside in the queue.
# Is this what we want?
await self._rx.aclose()
self._subs.pop(self._rx)
def broadcast_receiver(
@@ -158,6 +188,7 @@ if __name__ == '__main__':
# loglevel='info',
):
retries = 3
size = 100
tx, rx = trio.open_memory_channel(size)
rx = broadcast_receiver(rx, size)
@@ -170,9 +201,9 @@ if __name__ == '__main__':
count = 0
while True:
with rx.subscribe():
async with rx.subscribe() as brx:
try:
async for value in rx:
async for value in brx:
print(f'{task.name}: {value}')
await trio.sleep(delay)
count += 1
@@ -181,10 +212,16 @@ if __name__ == '__main__':
print(
f'restarting slow ass {task.name} '
f'that bailed out on {count}:{value}')
continue
if count <= retries:
continue
else:
print(
f'{task.name} was too slow and terminated '
f'on {count}:{value}')
return
async with trio.open_nursery() as n:
for i in range(1, 10):
for i in range(1, size):
n.start_soon(
partial(
sub_and_print,
@@ -194,7 +231,7 @@ if __name__ == '__main__':
)
async with tx:
for i in cycle(range(1000)):
for i in cycle(range(size)):
print(f'sending: {i}')
await tx.send(i)