Compare commits
30 Commits
master
...
tokio_back
Author | SHA1 | Date |
---|---|---|
Tyler Goodlet | a7e7c9d1c0 | |
Tyler Goodlet | c3665801a5 | |
Tyler Goodlet | 71a4f8aaa9 | |
Tyler Goodlet | 7296d171be | |
Tyler Goodlet | a053a18f53 | |
Tyler Goodlet | db86409369 | |
Tyler Goodlet | 2c96e85981 | |
Tyler Goodlet | a0b69fd64b | |
Tyler Goodlet | 727d666cb4 | |
Tyler Goodlet | c82ca67263 | |
Tyler Goodlet | 45f334b9c2 | |
Tyler Goodlet | 29e0b8f67d | |
Tyler Goodlet | aad6cf9070 | |
Tyler Goodlet | ac14f611b2 | |
Tyler Goodlet | 4461e3e34f | |
Tyler Goodlet | a27aca070e | |
Tyler Goodlet | 3ba01e7e40 | |
Tyler Goodlet | 843a713f5a | |
Tyler Goodlet | e9b038e87d | |
Tyler Goodlet | 43820e194e | |
Tyler Goodlet | eaa761b0c7 | |
Tyler Goodlet | db2f3f787a | |
Tyler Goodlet | b9863fc4ab | |
Tyler Goodlet | 9d12cc80dd | |
Tyler Goodlet | 3f9b860210 | |
Tyler Goodlet | eeca3d0d50 | |
Tyler Goodlet | e1e3e6918c | |
Tyler Goodlet | dfc4082ad2 | |
Tyler Goodlet | af6e8a64ad | |
Tyler Goodlet | 0c6e7ca351 |
|
@ -24,13 +24,13 @@ jobs:
|
|||
|
||||
testing:
|
||||
name: '${{ matrix.os }} Python ${{ matrix.python }} - ${{ matrix.spawn_backend }}'
|
||||
timeout-minutes: 10
|
||||
timeout-minutes: 9
|
||||
runs-on: ${{ matrix.os }}
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
os: [ubuntu-latest, windows-latest]
|
||||
python: ['3.7', '3.8', '3.9']
|
||||
python: ['3.8', '3.9']
|
||||
spawn_backend: ['trio', 'mp']
|
||||
steps:
|
||||
- name: Checkout
|
||||
|
|
|
@ -313,12 +313,12 @@ async def test_respawn_consumer_task(
|
|||
task_status.started(cs)
|
||||
|
||||
# shield stream's underlying channel from cancellation
|
||||
with stream.shield():
|
||||
# with stream.shield():
|
||||
|
||||
async for v in stream:
|
||||
print(f'from stream: {v}')
|
||||
expect.remove(v)
|
||||
received.append(v)
|
||||
async for v in stream:
|
||||
print(f'from stream: {v}')
|
||||
expect.remove(v)
|
||||
received.append(v)
|
||||
|
||||
print('exited consume')
|
||||
|
||||
|
|
|
@ -0,0 +1,427 @@
|
|||
"""
|
||||
Broadcast channels for fan-out to local tasks.
|
||||
"""
|
||||
from contextlib import asynccontextmanager
|
||||
from functools import partial
|
||||
from itertools import cycle
|
||||
import time
|
||||
from typing import Optional, List, Tuple
|
||||
|
||||
import pytest
|
||||
import trio
|
||||
from trio.lowlevel import current_task
|
||||
import tractor
|
||||
from tractor._broadcast import broadcast_receiver, Lagged
|
||||
|
||||
|
||||
@tractor.context
|
||||
async def echo_sequences(
|
||||
|
||||
ctx: tractor.Context,
|
||||
|
||||
) -> None:
|
||||
'''Bidir streaming endpoint which will stream
|
||||
back any sequence it is sent item-wise.
|
||||
|
||||
'''
|
||||
await ctx.started()
|
||||
|
||||
async with ctx.open_stream() as stream:
|
||||
async for sequence in stream:
|
||||
seq = list(sequence)
|
||||
for value in seq:
|
||||
await stream.send(value)
|
||||
print(f'producer sent {value}')
|
||||
|
||||
|
||||
async def ensure_sequence(
|
||||
|
||||
stream: tractor.ReceiveMsgStream,
|
||||
sequence: list,
|
||||
delay: Optional[float] = None,
|
||||
|
||||
) -> None:
|
||||
|
||||
name = current_task().name
|
||||
async with stream.subscribe() as bcaster:
|
||||
assert not isinstance(bcaster, type(stream))
|
||||
async for value in bcaster:
|
||||
print(f'{name} rx: {value}')
|
||||
assert value == sequence[0]
|
||||
sequence.remove(value)
|
||||
|
||||
if delay:
|
||||
await trio.sleep(delay)
|
||||
|
||||
if not sequence:
|
||||
# fully consumed
|
||||
break
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def open_sequence_streamer(
|
||||
|
||||
sequence: List[int],
|
||||
arb_addr: Tuple[str, int],
|
||||
start_method: str,
|
||||
shield: bool = False,
|
||||
|
||||
) -> tractor.MsgStream:
|
||||
|
||||
async with tractor.open_nursery(
|
||||
arbiter_addr=arb_addr,
|
||||
start_method=start_method,
|
||||
) as tn:
|
||||
|
||||
portal = await tn.start_actor(
|
||||
'sequence_echoer',
|
||||
enable_modules=[__name__],
|
||||
)
|
||||
|
||||
async with portal.open_context(
|
||||
echo_sequences,
|
||||
) as (ctx, first):
|
||||
|
||||
assert first is None
|
||||
async with ctx.open_stream(shield=shield) as stream:
|
||||
yield stream
|
||||
|
||||
await portal.cancel_actor()
|
||||
|
||||
|
||||
def test_stream_fan_out_to_local_subscriptions(
|
||||
arb_addr,
|
||||
start_method,
|
||||
):
|
||||
|
||||
sequence = list(range(1000))
|
||||
|
||||
async def main():
|
||||
|
||||
async with open_sequence_streamer(
|
||||
sequence,
|
||||
arb_addr,
|
||||
start_method,
|
||||
) as stream:
|
||||
|
||||
async with trio.open_nursery() as n:
|
||||
for i in range(10):
|
||||
n.start_soon(
|
||||
ensure_sequence,
|
||||
stream,
|
||||
sequence.copy(),
|
||||
name=f'consumer_{i}',
|
||||
)
|
||||
|
||||
await stream.send(tuple(sequence))
|
||||
|
||||
async for value in stream:
|
||||
print(f'source stream rx: {value}')
|
||||
assert value == sequence[0]
|
||||
sequence.remove(value)
|
||||
|
||||
if not sequence:
|
||||
# fully consumed
|
||||
break
|
||||
|
||||
trio.run(main)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'task_delays',
|
||||
[
|
||||
(0.01, 0.001),
|
||||
(0.001, 0.01),
|
||||
]
|
||||
)
|
||||
def test_consumer_and_parent_maybe_lag(
|
||||
arb_addr,
|
||||
start_method,
|
||||
task_delays,
|
||||
):
|
||||
|
||||
async def main():
|
||||
|
||||
sequence = list(range(300))
|
||||
parent_delay, sub_delay = task_delays
|
||||
|
||||
async with open_sequence_streamer(
|
||||
sequence,
|
||||
arb_addr,
|
||||
start_method,
|
||||
) as stream:
|
||||
|
||||
try:
|
||||
async with trio.open_nursery() as n:
|
||||
|
||||
n.start_soon(
|
||||
ensure_sequence,
|
||||
stream,
|
||||
sequence.copy(),
|
||||
sub_delay,
|
||||
name='consumer_task',
|
||||
)
|
||||
|
||||
await stream.send(tuple(sequence))
|
||||
|
||||
# async for value in stream:
|
||||
lagged = False
|
||||
lag_count = 0
|
||||
|
||||
while True:
|
||||
try:
|
||||
value = await stream.receive()
|
||||
print(f'source stream rx: {value}')
|
||||
|
||||
if lagged:
|
||||
# re set the sequence starting at our last
|
||||
# value
|
||||
sequence = sequence[sequence.index(value) + 1:]
|
||||
else:
|
||||
assert value == sequence[0]
|
||||
sequence.remove(value)
|
||||
|
||||
lagged = False
|
||||
|
||||
except Lagged:
|
||||
lagged = True
|
||||
print(f'source stream lagged after {value}')
|
||||
lag_count += 1
|
||||
continue
|
||||
|
||||
# lag the parent
|
||||
await trio.sleep(parent_delay)
|
||||
|
||||
if not sequence:
|
||||
# fully consumed
|
||||
break
|
||||
print(f'parent + source stream lagged: {lag_count}')
|
||||
|
||||
if parent_delay > sub_delay:
|
||||
assert lag_count > 0
|
||||
|
||||
except Lagged:
|
||||
# child was lagged
|
||||
assert parent_delay < sub_delay
|
||||
|
||||
trio.run(main)
|
||||
|
||||
|
||||
def test_faster_task_to_recv_is_cancelled_by_slower(
|
||||
arb_addr,
|
||||
start_method,
|
||||
):
|
||||
'''Ensure that if a faster task consuming from a stream is cancelled
|
||||
the slower task can continue to receive all expected values.
|
||||
|
||||
'''
|
||||
async def main():
|
||||
|
||||
sequence = list(range(1000))
|
||||
|
||||
async with open_sequence_streamer(
|
||||
sequence,
|
||||
arb_addr,
|
||||
start_method,
|
||||
|
||||
# NOTE: this MUST be set to avoid the stream terminating
|
||||
# early when the faster subtask is cancelled by the slower
|
||||
# parent task.
|
||||
shield=True,
|
||||
|
||||
) as stream:
|
||||
|
||||
# alt to passing kwarg above.
|
||||
# with stream.shield():
|
||||
|
||||
async with trio.open_nursery() as n:
|
||||
n.start_soon(
|
||||
ensure_sequence,
|
||||
stream,
|
||||
sequence.copy(),
|
||||
0,
|
||||
name='consumer_task',
|
||||
)
|
||||
|
||||
await stream.send(tuple(sequence))
|
||||
|
||||
# pull 3 values, cancel the subtask, then
|
||||
# expect to be able to pull all values still
|
||||
for i in range(20):
|
||||
try:
|
||||
value = await stream.receive()
|
||||
print(f'source stream rx: {value}')
|
||||
await trio.sleep(0.01)
|
||||
except Lagged:
|
||||
print(f'parent overrun after {value}')
|
||||
continue
|
||||
|
||||
print('cancelling faster subtask')
|
||||
n.cancel_scope.cancel()
|
||||
|
||||
try:
|
||||
value = await stream.receive()
|
||||
print(f'source stream after cancel: {value}')
|
||||
except Lagged:
|
||||
print(f'parent overrun after {value}')
|
||||
|
||||
# expect to see all remaining values
|
||||
with trio.fail_after(0.5):
|
||||
async for value in stream:
|
||||
assert stream._broadcaster._state.recv_ready is None
|
||||
print(f'source stream rx: {value}')
|
||||
if value == 999:
|
||||
# fully consumed and we missed no values once
|
||||
# the faster subtask was cancelled
|
||||
break
|
||||
|
||||
# await tractor.breakpoint()
|
||||
# await stream.receive()
|
||||
print(f'final value: {value}')
|
||||
|
||||
trio.run(main)
|
||||
|
||||
|
||||
def test_subscribe_errors_after_close():
|
||||
|
||||
async def main():
|
||||
|
||||
size = 1
|
||||
tx, rx = trio.open_memory_channel(size)
|
||||
async with broadcast_receiver(rx, size) as brx:
|
||||
pass
|
||||
|
||||
try:
|
||||
# open and close
|
||||
async with brx.subscribe():
|
||||
pass
|
||||
|
||||
except trio.ClosedResourceError:
|
||||
assert brx.key not in brx._state.subs
|
||||
|
||||
else:
|
||||
assert 0
|
||||
|
||||
trio.run(main)
|
||||
|
||||
|
||||
def test_ensure_slow_consumers_lag_out(
|
||||
arb_addr,
|
||||
start_method,
|
||||
):
|
||||
'''This is a pure local task test; no tractor
|
||||
machinery is really required.
|
||||
|
||||
'''
|
||||
async def main():
|
||||
|
||||
# make sure it all works within the runtime
|
||||
async with tractor.open_root_actor():
|
||||
|
||||
num_laggers = 4
|
||||
laggers: dict[str, int] = {}
|
||||
retries = 3
|
||||
size = 100
|
||||
tx, rx = trio.open_memory_channel(size)
|
||||
brx = broadcast_receiver(rx, size)
|
||||
|
||||
async def sub_and_print(
|
||||
delay: float,
|
||||
) -> None:
|
||||
|
||||
task = current_task()
|
||||
start = time.time()
|
||||
|
||||
async with brx.subscribe() as lbrx:
|
||||
while True:
|
||||
print(f'{task.name}: starting consume loop')
|
||||
try:
|
||||
async for value in lbrx:
|
||||
print(f'{task.name}: {value}')
|
||||
await trio.sleep(delay)
|
||||
|
||||
if task.name == 'sub_1':
|
||||
# the non-lagger got
|
||||
# a ``trio.EndOfChannel``
|
||||
# because the ``tx`` below was closed
|
||||
assert len(lbrx._state.subs) == 1
|
||||
|
||||
await lbrx.aclose()
|
||||
|
||||
assert len(lbrx._state.subs) == 0
|
||||
|
||||
except trio.ClosedResourceError:
|
||||
# only the fast sub will try to re-enter
|
||||
# iteration on the now closed bcaster
|
||||
assert task.name == 'sub_1'
|
||||
return
|
||||
|
||||
except Lagged:
|
||||
lag_time = time.time() - start
|
||||
lags = laggers[task.name]
|
||||
print(
|
||||
f'restarting slow task {task.name} '
|
||||
f'that bailed out on {lags}:{value} '
|
||||
f'after {lag_time:.3f}')
|
||||
if lags <= retries:
|
||||
laggers[task.name] += 1
|
||||
continue
|
||||
else:
|
||||
print(
|
||||
f'{task.name} was too slow and terminated '
|
||||
f'on {lags}:{value}')
|
||||
return
|
||||
|
||||
async with trio.open_nursery() as nursery:
|
||||
|
||||
for i in range(1, num_laggers):
|
||||
|
||||
task_name = f'sub_{i}'
|
||||
laggers[task_name] = 0
|
||||
nursery.start_soon(
|
||||
partial(
|
||||
sub_and_print,
|
||||
delay=i*0.001,
|
||||
),
|
||||
name=task_name,
|
||||
)
|
||||
|
||||
# allow subs to sched
|
||||
await trio.sleep(0.1)
|
||||
|
||||
async with tx:
|
||||
for i in cycle(range(size)):
|
||||
await tx.send(i)
|
||||
if len(brx._state.subs) == 2:
|
||||
# only one, the non lagger, sub is left
|
||||
break
|
||||
|
||||
# the non-lagger
|
||||
assert laggers.pop('sub_1') == 0
|
||||
|
||||
for n, v in laggers.items():
|
||||
assert v == 4
|
||||
|
||||
assert tx._closed
|
||||
assert not tx._state.open_send_channels
|
||||
|
||||
# check that "first" bcaster that we created
|
||||
# above, never wass iterated and is thus overrun
|
||||
try:
|
||||
await brx.receive()
|
||||
except Lagged:
|
||||
# expect tokio style index truncation
|
||||
seq = brx._state.subs[brx.key]
|
||||
assert seq == len(brx._state.queue) - 1
|
||||
|
||||
# all backpressured entries in the underlying
|
||||
# channel should have been copied into the caster
|
||||
# queue trailing-window
|
||||
async for i in rx:
|
||||
print(f'bped: {i}')
|
||||
assert i in brx._state.queue
|
||||
|
||||
# should be noop
|
||||
await brx.aclose()
|
||||
|
||||
trio.run(main)
|
|
@ -567,7 +567,7 @@ class Actor:
|
|||
try:
|
||||
send_chan, recv_chan = self._cids2qs[(actorid, cid)]
|
||||
except KeyError:
|
||||
send_chan, recv_chan = trio.open_memory_channel(1000)
|
||||
send_chan, recv_chan = trio.open_memory_channel(2**6)
|
||||
send_chan.cid = cid # type: ignore
|
||||
recv_chan.cid = cid # type: ignore
|
||||
self._cids2qs[(actorid, cid)] = send_chan, recv_chan
|
||||
|
|
|
@ -0,0 +1,308 @@
|
|||
'''
|
||||
``tokio`` style broadcast channel.
|
||||
https://tokio-rs.github.io/tokio/doc/tokio/sync/broadcast/index.html
|
||||
|
||||
'''
|
||||
from __future__ import annotations
|
||||
from abc import abstractmethod
|
||||
from collections import deque
|
||||
from contextlib import asynccontextmanager
|
||||
from dataclasses import dataclass
|
||||
from functools import partial
|
||||
from operator import ne
|
||||
from typing import Optional, Callable, Awaitable, Any, AsyncIterator, Protocol
|
||||
from typing import Generic, TypeVar
|
||||
|
||||
import trio
|
||||
from trio._core._run import Task
|
||||
from trio.abc import ReceiveChannel
|
||||
from trio.lowlevel import current_task
|
||||
|
||||
|
||||
# A regular invariant generic type
|
||||
T = TypeVar("T")
|
||||
|
||||
# covariant because AsyncReceiver[Derived] can be passed to someone
|
||||
# expecting AsyncReceiver[Base])
|
||||
ReceiveType = TypeVar("ReceiveType", covariant=True)
|
||||
|
||||
|
||||
class AsyncReceiver(
|
||||
Protocol,
|
||||
Generic[ReceiveType],
|
||||
):
|
||||
'''An async receivable duck-type that quacks much like trio's
|
||||
``trio.abc.ReceieveChannel``.
|
||||
|
||||
'''
|
||||
@abstractmethod
|
||||
async def receive(self) -> ReceiveType:
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def __aiter__(self) -> AsyncIterator[ReceiveType]:
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def __anext__(self) -> ReceiveType:
|
||||
...
|
||||
|
||||
# ``trio.abc.AsyncResource`` methods
|
||||
@abstractmethod
|
||||
async def aclose(self):
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def __aenter__(self) -> AsyncReceiver[ReceiveType]:
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def __aexit__(self, *args) -> None:
|
||||
...
|
||||
|
||||
|
||||
class Lagged(trio.TooSlowError):
|
||||
'''Subscribed consumer task was too slow and was overrun
|
||||
by the fastest consumer-producer pair.
|
||||
|
||||
'''
|
||||
|
||||
|
||||
@dataclass
|
||||
class BroadcastState:
|
||||
'''Common state to all receivers of a broadcast.
|
||||
|
||||
'''
|
||||
queue: deque
|
||||
maxlen: int
|
||||
|
||||
# map of underlying instance id keys to receiver instances which
|
||||
# must be provided as a singleton per broadcaster set.
|
||||
subs: dict[int, int]
|
||||
|
||||
# broadcast event to wake up all sleeping consumer tasks
|
||||
# on a newly produced value from the sender.
|
||||
recv_ready: Optional[tuple[int, trio.Event]] = None
|
||||
|
||||
|
||||
class BroadcastReceiver(ReceiveChannel):
|
||||
'''A memory receive channel broadcaster which is non-lossy for the
|
||||
fastest consumer.
|
||||
|
||||
Additional consumer tasks can receive all produced values by registering
|
||||
with ``.subscribe()`` and receiving from the new instance it delivers.
|
||||
|
||||
'''
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
rx_chan: AsyncReceiver,
|
||||
state: BroadcastState,
|
||||
receive_afunc: Optional[Callable[[], Awaitable[Any]]] = None,
|
||||
|
||||
) -> None:
|
||||
|
||||
# register the original underlying (clone)
|
||||
self.key = id(self)
|
||||
self._state = state
|
||||
state.subs[self.key] = -1
|
||||
|
||||
# underlying for this receiver
|
||||
self._rx = rx_chan
|
||||
self._recv = receive_afunc or rx_chan.receive
|
||||
self._closed: bool = False
|
||||
|
||||
async def receive(self) -> ReceiveType:
|
||||
|
||||
key = self.key
|
||||
state = self._state
|
||||
|
||||
# TODO: ideally we can make some way to "lock out" the
|
||||
# underlying receive channel in some way such that if some task
|
||||
# tries to pull from it directly (i.e. one we're unaware of)
|
||||
# then it errors out.
|
||||
|
||||
# only tasks which have entered ``.subscribe()`` can
|
||||
# receive on this broadcaster.
|
||||
try:
|
||||
seq = state.subs[key]
|
||||
except KeyError:
|
||||
if self._closed:
|
||||
raise trio.ClosedResourceError
|
||||
|
||||
raise RuntimeError(
|
||||
f'{self} is not registerd as subscriber')
|
||||
|
||||
# check that task does not already have a value it can receive
|
||||
# immediately and/or that it has lagged.
|
||||
if seq > -1:
|
||||
# get the oldest value we haven't received immediately
|
||||
try:
|
||||
value = state.queue[seq]
|
||||
except IndexError:
|
||||
|
||||
# adhere to ``tokio`` style "lagging":
|
||||
# "Once RecvError::Lagged is returned, the lagging
|
||||
# receiver's position is updated to the oldest value
|
||||
# contained by the channel. The next call to recv will
|
||||
# return this value."
|
||||
# https://tokio-rs.github.io/tokio/doc/tokio/sync/broadcast/index.html#lagging
|
||||
|
||||
# decrement to the last value and expect
|
||||
# consumer to either handle the ``Lagged`` and come back
|
||||
# or bail out on its own (thus un-subscribing)
|
||||
state.subs[key] = state.maxlen - 1
|
||||
|
||||
# this task was overrun by the producer side
|
||||
task: Task = current_task()
|
||||
raise Lagged(f'Task {task.name} was overrun')
|
||||
|
||||
state.subs[key] -= 1
|
||||
return value
|
||||
|
||||
# current task already has the latest value **and** is the
|
||||
# first task to begin waiting for a new one
|
||||
if state.recv_ready is None:
|
||||
|
||||
if self._closed:
|
||||
raise trio.ClosedResourceError
|
||||
|
||||
event = trio.Event()
|
||||
state.recv_ready = key, event
|
||||
|
||||
# if we're cancelled here it should be
|
||||
# fine to bail without affecting any other consumers
|
||||
# right?
|
||||
try:
|
||||
value = await self._recv()
|
||||
|
||||
# items with lower indices are "newer"
|
||||
# NOTE: ``collections.deque`` implicitly takes care of
|
||||
# trucating values outside our ``state.maxlen``. In the
|
||||
# alt-backend-array-case we'll need to make sure this is
|
||||
# implemented in similar ringer-buffer-ish style.
|
||||
state.queue.appendleft(value)
|
||||
|
||||
# broadcast new value to all subscribers by increasing
|
||||
# all sequence numbers that will point in the queue to
|
||||
# their latest available value.
|
||||
|
||||
# don't decrement the sequence for this task since we
|
||||
# already retreived the last value
|
||||
|
||||
# XXX: which of these impls is fastest?
|
||||
|
||||
# subs = state.subs.copy()
|
||||
# subs.pop(key)
|
||||
|
||||
for sub_key in filter(
|
||||
# lambda k: k != key, state.subs,
|
||||
partial(ne, key), state.subs,
|
||||
):
|
||||
state.subs[sub_key] += 1
|
||||
|
||||
# NOTE: this should ONLY be set if the above task was *NOT*
|
||||
# cancelled on the `._recv()` call otherwise sibling
|
||||
# consumers will be awoken with a sequence of -1
|
||||
event.set()
|
||||
|
||||
return value
|
||||
|
||||
finally:
|
||||
# Reset receiver waiter task event for next blocking condition.
|
||||
# this MUST be reset even if the above ``.recv()`` call
|
||||
# was cancelled to avoid the next consumer from blocking on
|
||||
# an event that won't be set!
|
||||
state.recv_ready = None
|
||||
|
||||
# This task is all caught up and ready to receive the latest
|
||||
# value, so queue sched it on the internal event.
|
||||
else:
|
||||
seq = state.subs[key]
|
||||
assert seq == -1 # sanity
|
||||
_, ev = state.recv_ready
|
||||
await ev.wait()
|
||||
|
||||
seq = state.subs[key]
|
||||
assert seq > -1, f'Invalid sequence {seq}!?'
|
||||
|
||||
value = state.queue[seq]
|
||||
state.subs[key] -= 1
|
||||
return value
|
||||
|
||||
# NOTE: if we ever would like the behaviour where if the
|
||||
# first task to recv on the underlying is cancelled but it
|
||||
# still DOES trigger the ``.recv_ready``, event we'll likely need
|
||||
# this logic:
|
||||
|
||||
# if seq > -1:
|
||||
# # stuff from above..
|
||||
# elif seq == -1:
|
||||
# # XXX: In the case where the first task to allocate the
|
||||
# # ``.recv_ready`` event is cancelled we will be woken with
|
||||
# # a non-incremented sequence number and thus will read the
|
||||
# # oldest value if we use that. Instead we need to detect if
|
||||
# # we have not been incremented and then receive again.
|
||||
# return await self.receive()
|
||||
# else:
|
||||
# raise ValueError(f'Invalid sequence {seq}!?')
|
||||
|
||||
@asynccontextmanager
|
||||
async def subscribe(
|
||||
self,
|
||||
) -> AsyncIterator[BroadcastReceiver]:
|
||||
'''Subscribe for values from this broadcast receiver.
|
||||
|
||||
Returns a new ``BroadCastReceiver`` which is registered for and
|
||||
pulls data from a clone of the original ``trio.abc.ReceiveChannel``
|
||||
provided at creation.
|
||||
|
||||
'''
|
||||
if self._closed:
|
||||
raise trio.ClosedResourceError
|
||||
|
||||
state = self._state
|
||||
br = BroadcastReceiver(
|
||||
rx_chan=self._rx,
|
||||
state=state,
|
||||
receive_afunc=self._recv,
|
||||
)
|
||||
# assert clone in state.subs
|
||||
assert br.key in state.subs
|
||||
|
||||
try:
|
||||
yield br
|
||||
finally:
|
||||
await br.aclose()
|
||||
|
||||
async def aclose(
|
||||
self,
|
||||
) -> None:
|
||||
|
||||
if self._closed:
|
||||
return
|
||||
|
||||
# XXX: leaving it like this consumers can still get values
|
||||
# up to the last received that still reside in the queue.
|
||||
self._state.subs.pop(self.key)
|
||||
|
||||
self._closed = True
|
||||
|
||||
|
||||
def broadcast_receiver(
|
||||
|
||||
recv_chan: AsyncReceiver,
|
||||
max_buffer_size: int,
|
||||
**kwargs,
|
||||
|
||||
) -> BroadcastReceiver:
|
||||
|
||||
return BroadcastReceiver(
|
||||
recv_chan,
|
||||
state=BroadcastState(
|
||||
queue=deque(maxlen=max_buffer_size),
|
||||
maxlen=max_buffer_size,
|
||||
subs={},
|
||||
),
|
||||
**kwargs,
|
||||
)
|
|
@ -294,6 +294,7 @@ class Portal:
|
|||
async def open_stream_from(
|
||||
self,
|
||||
async_gen_func: Callable, # typing: ignore
|
||||
shield: bool = False,
|
||||
**kwargs,
|
||||
|
||||
) -> AsyncGenerator[ReceiveMsgStream, None]:
|
||||
|
@ -320,7 +321,9 @@ class Portal:
|
|||
ctx = Context(self.channel, cid, _portal=self)
|
||||
try:
|
||||
# deliver receive only stream
|
||||
async with ReceiveMsgStream(ctx, recv_chan) as rchan:
|
||||
async with ReceiveMsgStream(
|
||||
ctx, recv_chan, shield=shield
|
||||
) as rchan:
|
||||
self._streams.add(rchan)
|
||||
yield rchan
|
||||
|
||||
|
|
|
@ -2,12 +2,14 @@
|
|||
Message stream types and APIs.
|
||||
|
||||
"""
|
||||
from __future__ import annotations
|
||||
import inspect
|
||||
from contextlib import contextmanager, asynccontextmanager
|
||||
from dataclasses import dataclass
|
||||
from typing import (
|
||||
Any, Iterator, Optional, Callable,
|
||||
AsyncGenerator, Dict,
|
||||
AsyncIterator
|
||||
)
|
||||
|
||||
import warnings
|
||||
|
@ -17,6 +19,7 @@ import trio
|
|||
from ._ipc import Channel
|
||||
from ._exceptions import unpack_error, ContextCancelled
|
||||
from ._state import current_actor
|
||||
from ._broadcast import broadcast_receiver, BroadcastReceiver
|
||||
from .log import get_logger
|
||||
|
||||
|
||||
|
@ -45,12 +48,14 @@ class ReceiveMsgStream(trio.abc.ReceiveChannel):
|
|||
def __init__(
|
||||
self,
|
||||
ctx: 'Context', # typing: ignore # noqa
|
||||
rx_chan: trio.abc.ReceiveChannel,
|
||||
rx_chan: trio.MemoryReceiveChannel,
|
||||
shield: bool = False,
|
||||
_broadcaster: Optional[BroadcastReceiver] = None,
|
||||
|
||||
) -> None:
|
||||
self._ctx = ctx
|
||||
self._rx_chan = rx_chan
|
||||
self._shielded = shield
|
||||
self._broadcaster = _broadcaster
|
||||
|
||||
# flag to denote end of stream
|
||||
self._eoc: bool = False
|
||||
|
@ -103,7 +108,10 @@ class ReceiveMsgStream(trio.abc.ReceiveChannel):
|
|||
except (
|
||||
trio.ClosedResourceError, # by self._rx_chan
|
||||
trio.EndOfChannel, # by self._rx_chan or `stop` msg from far end
|
||||
trio.Cancelled, # by local cancellation
|
||||
|
||||
# Wait why would we do an implicit close on cancel? THAT'S
|
||||
# NOT HOW MEM CHANS WORK!!?!?!?!?
|
||||
# trio.Cancelled, # by local cancellation
|
||||
):
|
||||
# XXX: we close the stream on any of these error conditions:
|
||||
|
||||
|
@ -135,23 +143,6 @@ class ReceiveMsgStream(trio.abc.ReceiveChannel):
|
|||
|
||||
raise # propagate
|
||||
|
||||
@contextmanager
|
||||
def shield(
|
||||
self
|
||||
) -> Iterator['ReceiveMsgStream']: # noqa
|
||||
"""Shield this stream's underlying channel such that a local consumer task
|
||||
can be cancelled (and possibly restarted) using ``trio.Cancelled``.
|
||||
|
||||
Note that here, "shielding" here guards against relaying
|
||||
a ``'stop'`` message to the far end of the stream thus keeping
|
||||
the stream machinery active and ready for further use, it does
|
||||
not have anything to do with an internal ``trio.CancelScope``.
|
||||
|
||||
"""
|
||||
self._shielded = True
|
||||
yield self
|
||||
self._shielded = False
|
||||
|
||||
async def aclose(self):
|
||||
"""Cancel associated remote actor task and local memory channel
|
||||
on close.
|
||||
|
@ -169,18 +160,10 @@ class ReceiveMsgStream(trio.abc.ReceiveChannel):
|
|||
# https://trio.readthedocs.io/en/stable/reference-io.html#trio.abc.AsyncResource.aclose
|
||||
return
|
||||
|
||||
# TODO: broadcasting to multiple consumers
|
||||
# stats = rx_chan.statistics()
|
||||
# if stats.open_receive_channels > 1:
|
||||
# # if we've been cloned don't kill the stream
|
||||
# log.debug(
|
||||
# "there are still consumers running keeping stream alive")
|
||||
# if self._shielded:
|
||||
# log.warning(f"{self} is shielded, portal channel being kept alive")
|
||||
# return
|
||||
|
||||
if self._shielded:
|
||||
log.warning(f"{self} is shielded, portal channel being kept alive")
|
||||
return
|
||||
|
||||
# XXX: This must be set **AFTER** the shielded test above!
|
||||
self._eoc = True
|
||||
|
||||
|
@ -253,6 +236,50 @@ class ReceiveMsgStream(trio.abc.ReceiveChannel):
|
|||
# still need to consume msgs that are "in transit" from the far
|
||||
# end (eg. for ``Context.result()``).
|
||||
|
||||
@asynccontextmanager
|
||||
async def subscribe(
|
||||
self,
|
||||
|
||||
) -> AsyncIterator[BroadcastReceiver]:
|
||||
'''Allocate and return a ``BroadcastReceiver`` which delegates
|
||||
to this message stream.
|
||||
|
||||
This allows multiple local tasks to receive each their own copy
|
||||
of this message stream.
|
||||
|
||||
This operation is indempotent and and mutates this stream's
|
||||
receive machinery to copy and window-length-store each received
|
||||
value from the far end via the internally created broudcast
|
||||
receiver wrapper.
|
||||
|
||||
'''
|
||||
# NOTE: This operation is indempotent and non-reversible, so be
|
||||
# sure you can deal with any (theoretical) overhead of the the
|
||||
# allocated ``BroadcastReceiver`` before calling this method for
|
||||
# the first time.
|
||||
if self._broadcaster is None:
|
||||
|
||||
bcast = self._broadcaster = broadcast_receiver(
|
||||
self,
|
||||
# use memory channel size by default
|
||||
self._rx_chan._state.max_buffer_size, # type: ignore
|
||||
receive_afunc=self.receive,
|
||||
)
|
||||
|
||||
# NOTE: we override the original stream instance's receive
|
||||
# method to now delegate to the broadcaster's ``.receive()``
|
||||
# such that new subscribers will be copied received values
|
||||
# and this stream doesn't have to expect it's original
|
||||
# consumer(s) to get a new broadcast rx handle.
|
||||
self.receive = bcast.receive # type: ignore
|
||||
# seems there's no graceful way to type this with ``mypy``?
|
||||
# https://github.com/python/mypy/issues/708
|
||||
|
||||
async with self._broadcaster.subscribe() as bstream:
|
||||
assert bstream.key != self._broadcaster.key
|
||||
assert bstream._recv == self._broadcaster._recv
|
||||
yield bstream
|
||||
|
||||
|
||||
class MsgStream(ReceiveMsgStream, trio.abc.Channel):
|
||||
"""
|
||||
|
@ -269,17 +296,6 @@ class MsgStream(ReceiveMsgStream, trio.abc.Channel):
|
|||
'''
|
||||
await self._ctx.chan.send({'yield': data, 'cid': self._ctx.cid})
|
||||
|
||||
# TODO: but make it broadcasting to consumers
|
||||
def clone(self):
|
||||
"""Clone this receive channel allowing for multi-task
|
||||
consumption from the same channel.
|
||||
|
||||
"""
|
||||
return MsgStream(
|
||||
self._ctx,
|
||||
self._rx_chan.clone(),
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Context:
|
||||
|
@ -397,7 +413,6 @@ class Context:
|
|||
async def open_stream(
|
||||
|
||||
self,
|
||||
shield: bool = False,
|
||||
|
||||
) -> AsyncGenerator[MsgStream, None]:
|
||||
'''Open a ``MsgStream``, a bi-directional stream connected to the
|
||||
|
@ -455,7 +470,6 @@ class Context:
|
|||
async with MsgStream(
|
||||
ctx=self,
|
||||
rx_chan=recv_chan,
|
||||
shield=shield,
|
||||
) as rchan:
|
||||
|
||||
if self._portal:
|
||||
|
|
Loading…
Reference in New Issue