forked from goodboy/tractor

Compare commits


30 Commits

Author SHA1 Message Date
Tyler Goodlet a7e7c9d1c0 Store array `maxlen` in state singleton
The `collections.deque` takes care of array length truncation of values
for us implicitly, but in the future we'll likely want this value exposed
to alternate array implementations. This patch provides for that as well
as making `mypy` happy since `deque.maxlen` can also be `None`.
2021-09-01 06:49:09 -04:00
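
Not part of the commit, just a stdlib illustration of the implicit truncation
(and the `maxlen is None` case) the message refers to:

    from collections import deque

    q = deque(maxlen=3)
    for i in range(5):
        q.appendleft(i)          # newer items sit at the left, like the broadcaster's queue
    assert list(q) == [4, 3, 2]  # older values were silently truncated
    assert q.maxlen == 3
    assert deque().maxlen is None  # the case `mypy` complains about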
Tyler Goodlet c3665801a5 Don't wake sibling bcast consumers on a cancelled call 2021-09-01 06:49:09 -04:00
Tyler Goodlet 71a4f8aaa9 Shorten sequence length for test speedup 2021-09-01 06:49:09 -04:00
Tyler Goodlet 7296d171be Shorten default feeder mem chan size to 64 2021-09-01 06:49:09 -04:00
Tyler Goodlet a053a18f53 Can't use built-in generics till 3.9... 2021-09-01 06:49:09 -04:00
Tyler Goodlet db86409369 Add `shield: bool` kwarg to `Portal.open_stream_from()` 2021-09-01 06:49:09 -04:00
Tyler Goodlet 2c96e85981 Add a "faster task is cancelled" test 2021-09-01 06:49:09 -04:00
Tyler Goodlet a0b69fd64b Rename test module 2021-09-01 06:49:09 -04:00
Tyler Goodlet 727d666cb4 Add some bcaster ref sanity asserts around subscriptions 2021-09-01 06:49:09 -04:00
Tyler Goodlet c82ca67263 Add laggy parent stream tests
Add a couple more tests to check that a parent and sub-task stream can
be lagged and recovered (depending on who's slower). Factor some of the
test machinery into a new ctx mngr to make it all happen.
2021-09-01 06:49:09 -04:00
Tyler Goodlet 45f334b9c2 Instance ids are ints 2021-09-01 06:49:09 -04:00
Tyler Goodlet 29e0b8f67d Add subscribe after close test 2021-09-01 06:49:09 -04:00
Tyler Goodlet aad6cf9070 Drop uuid4 keys, raise closed error on subscription after close 2021-09-01 06:49:09 -04:00
Tyler Goodlet ac14f611b2 Lol, guess windows needs the extra minutes 2021-09-01 06:49:09 -04:00
Tyler Goodlet 4461e3e34f Don't enable debug mode..it borks CI 2021-09-01 06:49:09 -04:00
Tyler Goodlet a27aca070e Drop py3.7 from CI; cut run to 5mins 2021-09-01 06:49:09 -04:00
Tyler Goodlet 3ba01e7e40 Fix `.receive()` re-assignment, drop `.clone()` 2021-09-01 06:49:07 -04:00
Tyler Goodlet 843a713f5a Initial broadcaster tests including one to test our `MsgStream.subscribe()` api 2021-09-01 06:46:43 -04:00
Tyler Goodlet e9b038e87d Blade runner it
Get rid of all the (requirements for) clones of the underlying
receivable. We can just use a uuid-generated key for each instance
(thinking now this can probably just be `id(self)`). I'm fully convinced
now that channel cloning is only a source of confusion and anti-patterns
when we already have nurseries to define resource lifetimes. There is no
particular benefit when you allocate subscriptions using a context
manager (not sure why `trio.open_memory_channel()` doesn't enforce
this).

Further refinements:
- add a `._closed` state that will error the receiver on reuse
- drop the module script section; it's been moved to a real test
- give the "receiver" duck-type stub a new name
2021-09-01 06:46:43 -04:00
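
Illustrative sketch only (hypothetical names, not the real class) of the
`id(self)` keying and the `._closed` reuse-error described above:

    import trio

    class ReceiverSketch:
        '''Hypothetical, stripped-down illustration; not the real type.'''

        def __init__(self, subs: dict) -> None:
            self.key = id(self)   # per-instance key, no uuid needed
            self._subs = subs     # shared {key: sequence-index} table
            self._subs[self.key] = -1
            self._closed = False

        async def aclose(self) -> None:
            self._subs.pop(self.key, None)
            self._closed = True

        async def receive(self):
            if self._closed:
                # reuse after close errors out instead of silently hanging
                raise trio.ClosedResourceError
            ...  # normal receive path elided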
Tyler Goodlet 43820e194e Store handle to underlying channel's `.receive()`
This allows wrapping an existing stream by re-assigning its receive
method to the allocated broadcaster's `.receive()` so as to avoid
expecting any original consumer(s) of the stream to know about the
broadcaster; instead the stream is mutated to delegate to the new
receive call behind the scenes any time `.subscribe()` is called.

Add a `typing.Protocol` for so-called "cloneable channels" until we
decide/figure out a better keying system for each subscription, and
mask all undesired typing failures.
2021-09-01 06:46:40 -04:00
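
A rough sketch (hypothetical names, not the actual implementation) of the
receive-method swap this message describes:

    class StreamSketch:
        '''Hypothetical; shows only the method swap.'''

        def __init__(self, rx_chan) -> None:
            self._rx_chan = rx_chan
            self._broadcaster = None

        async def receive(self):
            return await self._rx_chan.receive()

        def _install_broadcaster(self, make_broadcaster) -> None:
            if self._broadcaster is None:
                # hand the broadcaster our *original* bound receive so it
                # stays the single real consumer of the underlying channel
                bcast = self._broadcaster = make_broadcaster(
                    receive_afunc=self.receive,
                )
                # existing consumers keep calling ``stream.receive()`` but
                # now get values through the broadcaster without changes
                self.receive = bcast.receive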
Tyler Goodlet eaa761b0c7 Add subscription support to message streams
Add `ReceiveMsgStream.subscribe()` which allows allocating a broadcast
receiver around the stream for use by multiple actor-local consumer
tasks. Entering this context manager idempotently mutates the stream's
receive machinery, which for now cannot be undone. Move `.clone()` to
the receive stream type.

Resolves #204
2021-09-01 06:40:25 -04:00
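
Usage sketch (not from the commit), assuming a stream obtained via
`Context.open_stream()` or `Portal.open_stream_from()`; fan a single
inter-actor stream out to two local tasks:

    import trio

    async def consume(name: str, stream) -> None:
        # each task gets its own ``BroadcastReceiver`` view of the stream
        async with stream.subscribe() as bstream:
            async for value in bstream:
                print(f'{name} rx: {value}')

    async def fan_out(stream) -> None:
        async with trio.open_nursery() as n:
            n.start_soon(consume, 'task_a', stream)
            n.start_soon(consume, 'task_b', stream)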
Tyler Goodlet db2f3f787a Drop optimization check, binance made its point 2021-09-01 06:39:21 -04:00
Tyler Goodlet b9863fc4ab Add common state delegate type for all consumers
For every set of broadcast receivers which pull from the same producer,
we need a singleton state holding:
- subscriptions
- the sender-ready event
- the queue

Add a `BroadcastState` dataclass for this and pass it to all
subscriptions. This makes the design much more like the built-in memory
channels which do something very similar with `MemoryChannelState`.

Use a `filter()` on the subs list in the sequence update step, and leave
some other approaches commented out for future speed comparisons.
2021-09-01 06:39:21 -04:00
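
Condensed sketch (illustrative names, not the real module) of the shared
state shape and the `filter()`-based sequence bump described above:

    from __future__ import annotations
    from collections import deque
    from dataclasses import dataclass
    from functools import partial
    from operator import ne
    from typing import Optional
    import trio

    @dataclass
    class StateSketch:
        queue: deque                                         # shared value window
        subs: dict[int, int]                                 # receiver key -> sequence index
        recv_ready: Optional[tuple[int, trio.Event]] = None  # waiter for the next value

    def bump_other_subs(state: StateSketch, key: int, value) -> None:
        state.queue.appendleft(value)
        # advance every *other* subscriber's sequence; the woken receiver
        # already holds the freshest value
        for sub_key in filter(partial(ne, key), state.subs):
            state.subs[sub_key] += 1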
Tyler Goodlet 9d12cc80dd Facepalm: use single `_subs` per clone set 2021-09-01 06:39:21 -04:00
Tyler Goodlet 3f9b860210 Obviously keying on tasks isn't going to work
Using the current task as a subscription key fails horribly as soon as
you hand off a new subscription receiver to another task you've spawned..

Instead use the underlying ``trio.abc.ReceiveChannel.clone()`` as a key
(so I guess we're assuming cloning is supported by the underlying?)
which makes this all work just like default mem chans. As a bonus, now
we can just close the underlying rx (which may be a clone) on
`.aclose()` and everything should just work in terms of the underlying
channel's lifetime (I think?).

Change `.subscribe()` to be async since the receive channel type
interface only expects `.aclose()` and it actually ends up being
nicer for 3.9+ parenthesized `async with` style anyway.
2021-09-01 06:39:21 -04:00
Tyler Goodlet eeca3d0d50 Rename to broadcast mod, don't expect mem chan specifically 2021-09-01 06:39:21 -04:00
Tyler Goodlet e1e3e6918c `Task` is hashable, so key on it 2021-09-01 06:39:21 -04:00
Tyler Goodlet dfc4082ad2 Simplify api around receive channel
Buncha improvements:
- pass in the queue via constructor
- track overall closure of the underlying memory channel using cloning
- do it like `tokio` and set lagged consumers to the last sequence
  before raising (see the sketch after this commit entry)
- copy the subs on first receiver wakeup for iteration instead of
  iterating the table directly (and being forced to skip the current
  task's sequence increment)
- implement `.aclose()` to close the underlying clone for this task
- make `broadcast_receiver()` just take the recv chan since it doesn't
  need anything on the send side.
2021-09-01 06:39:21 -04:00
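
Illustrative sketch (simplified, not the exact implementation) of the
`tokio`-style lag handling referenced in the bullet above:

    import trio

    class Lagged(trio.TooSlowError):
        '''Consumer fell behind the retained window.'''

    def read_slot(queue, subs: dict, key: int, maxlen: int):
        seq = subs[key]
        try:
            value = queue[seq]
        except IndexError:
            # consumer fell outside the window: park it at the oldest
            # retained value so the *next* receive returns that, then bail
            subs[key] = maxlen - 1
            raise Lagged(f'subscriber {key} was overrun')
        subs[key] -= 1
        return value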
Tyler Goodlet af6e8a64ad Ultra naive broadcast channel prototype 2021-09-01 06:39:21 -04:00
Tyler Goodlet 0c6e7ca351 Drop stream shielding; it was from a legacy design
This whole mechanism originated from not having an explicit open/close
semantic for streams. We have that now, so this internal machinery isn't
needed; furthermore our streams become more correct by having `.aclose()`
be independent of cancellation.
2021-09-01 06:37:53 -04:00
7 changed files with 804 additions and 52 deletions

View File

@@ -24,13 +24,13 @@ jobs:
   testing:
     name: '${{ matrix.os }} Python ${{ matrix.python }} - ${{ matrix.spawn_backend }}'
-    timeout-minutes: 10
+    timeout-minutes: 9
     runs-on: ${{ matrix.os }}
     strategy:
       fail-fast: false
       matrix:
         os: [ubuntu-latest, windows-latest]
-        python: ['3.7', '3.8', '3.9']
+        python: ['3.8', '3.9']
         spawn_backend: ['trio', 'mp']
     steps:
       - name: Checkout

View File

@@ -313,7 +313,7 @@ async def test_respawn_consumer_task(
         task_status.started(cs)

         # shield stream's underlying channel from cancellation
-        with stream.shield():
-            async for v in stream:
-                print(f'from stream: {v}')
+        # with stream.shield():
+        async for v in stream:
+            print(f'from stream: {v}')

View File

@ -0,0 +1,427 @@
"""
Broadcast channels for fan-out to local tasks.
"""
from contextlib import asynccontextmanager
from functools import partial
from itertools import cycle
import time
from typing import Optional, List, Tuple
import pytest
import trio
from trio.lowlevel import current_task
import tractor
from tractor._broadcast import broadcast_receiver, Lagged
@tractor.context
async def echo_sequences(
ctx: tractor.Context,
) -> None:
'''Bidir streaming endpoint which will stream
back any sequence it is sent item-wise.
'''
await ctx.started()
async with ctx.open_stream() as stream:
async for sequence in stream:
seq = list(sequence)
for value in seq:
await stream.send(value)
print(f'producer sent {value}')
async def ensure_sequence(
stream: tractor.ReceiveMsgStream,
sequence: list,
delay: Optional[float] = None,
) -> None:
name = current_task().name
async with stream.subscribe() as bcaster:
assert not isinstance(bcaster, type(stream))
async for value in bcaster:
print(f'{name} rx: {value}')
assert value == sequence[0]
sequence.remove(value)
if delay:
await trio.sleep(delay)
if not sequence:
# fully consumed
break
@asynccontextmanager
async def open_sequence_streamer(
sequence: List[int],
arb_addr: Tuple[str, int],
start_method: str,
shield: bool = False,
) -> tractor.MsgStream:
async with tractor.open_nursery(
arbiter_addr=arb_addr,
start_method=start_method,
) as tn:
portal = await tn.start_actor(
'sequence_echoer',
enable_modules=[__name__],
)
async with portal.open_context(
echo_sequences,
) as (ctx, first):
assert first is None
async with ctx.open_stream(shield=shield) as stream:
yield stream
await portal.cancel_actor()
def test_stream_fan_out_to_local_subscriptions(
arb_addr,
start_method,
):
sequence = list(range(1000))
async def main():
async with open_sequence_streamer(
sequence,
arb_addr,
start_method,
) as stream:
async with trio.open_nursery() as n:
for i in range(10):
n.start_soon(
ensure_sequence,
stream,
sequence.copy(),
name=f'consumer_{i}',
)
await stream.send(tuple(sequence))
async for value in stream:
print(f'source stream rx: {value}')
assert value == sequence[0]
sequence.remove(value)
if not sequence:
# fully consumed
break
trio.run(main)
@pytest.mark.parametrize(
'task_delays',
[
(0.01, 0.001),
(0.001, 0.01),
]
)
def test_consumer_and_parent_maybe_lag(
arb_addr,
start_method,
task_delays,
):
async def main():
sequence = list(range(300))
parent_delay, sub_delay = task_delays
async with open_sequence_streamer(
sequence,
arb_addr,
start_method,
) as stream:
try:
async with trio.open_nursery() as n:
n.start_soon(
ensure_sequence,
stream,
sequence.copy(),
sub_delay,
name='consumer_task',
)
await stream.send(tuple(sequence))
# async for value in stream:
lagged = False
lag_count = 0
while True:
try:
value = await stream.receive()
print(f'source stream rx: {value}')
if lagged:
# re set the sequence starting at our last
# value
sequence = sequence[sequence.index(value) + 1:]
else:
assert value == sequence[0]
sequence.remove(value)
lagged = False
except Lagged:
lagged = True
print(f'source stream lagged after {value}')
lag_count += 1
continue
# lag the parent
await trio.sleep(parent_delay)
if not sequence:
# fully consumed
break
print(f'parent + source stream lagged: {lag_count}')
if parent_delay > sub_delay:
assert lag_count > 0
except Lagged:
# child was lagged
assert parent_delay < sub_delay
trio.run(main)
def test_faster_task_to_recv_is_cancelled_by_slower(
arb_addr,
start_method,
):
'''Ensure that if a faster task consuming from a stream is cancelled
the slower task can continue to receive all expected values.
'''
async def main():
sequence = list(range(1000))
async with open_sequence_streamer(
sequence,
arb_addr,
start_method,
# NOTE: this MUST be set to avoid the stream terminating
# early when the faster subtask is cancelled by the slower
# parent task.
shield=True,
) as stream:
# alt to passing kwarg above.
# with stream.shield():
async with trio.open_nursery() as n:
n.start_soon(
ensure_sequence,
stream,
sequence.copy(),
0,
name='consumer_task',
)
await stream.send(tuple(sequence))
# pull 3 values, cancel the subtask, then
# expect to be able to pull all values still
for i in range(20):
try:
value = await stream.receive()
print(f'source stream rx: {value}')
await trio.sleep(0.01)
except Lagged:
print(f'parent overrun after {value}')
continue
print('cancelling faster subtask')
n.cancel_scope.cancel()
try:
value = await stream.receive()
print(f'source stream after cancel: {value}')
except Lagged:
print(f'parent overrun after {value}')
# expect to see all remaining values
with trio.fail_after(0.5):
async for value in stream:
assert stream._broadcaster._state.recv_ready is None
print(f'source stream rx: {value}')
if value == 999:
# fully consumed and we missed no values once
# the faster subtask was cancelled
break
# await tractor.breakpoint()
# await stream.receive()
print(f'final value: {value}')
trio.run(main)
def test_subscribe_errors_after_close():
async def main():
size = 1
tx, rx = trio.open_memory_channel(size)
async with broadcast_receiver(rx, size) as brx:
pass
try:
# open and close
async with brx.subscribe():
pass
except trio.ClosedResourceError:
assert brx.key not in brx._state.subs
else:
assert 0
trio.run(main)
def test_ensure_slow_consumers_lag_out(
arb_addr,
start_method,
):
'''This is a pure local task test; no tractor
machinery is really required.
'''
async def main():
# make sure it all works within the runtime
async with tractor.open_root_actor():
num_laggers = 4
laggers: dict[str, int] = {}
retries = 3
size = 100
tx, rx = trio.open_memory_channel(size)
brx = broadcast_receiver(rx, size)
async def sub_and_print(
delay: float,
) -> None:
task = current_task()
start = time.time()
async with brx.subscribe() as lbrx:
while True:
print(f'{task.name}: starting consume loop')
try:
async for value in lbrx:
print(f'{task.name}: {value}')
await trio.sleep(delay)
if task.name == 'sub_1':
# the non-lagger got
# a ``trio.EndOfChannel``
# because the ``tx`` below was closed
assert len(lbrx._state.subs) == 1
await lbrx.aclose()
assert len(lbrx._state.subs) == 0
except trio.ClosedResourceError:
# only the fast sub will try to re-enter
# iteration on the now closed bcaster
assert task.name == 'sub_1'
return
except Lagged:
lag_time = time.time() - start
lags = laggers[task.name]
print(
f'restarting slow task {task.name} '
f'that bailed out on {lags}:{value} '
f'after {lag_time:.3f}')
if lags <= retries:
laggers[task.name] += 1
continue
else:
print(
f'{task.name} was too slow and terminated '
f'on {lags}:{value}')
return
async with trio.open_nursery() as nursery:
for i in range(1, num_laggers):
task_name = f'sub_{i}'
laggers[task_name] = 0
nursery.start_soon(
partial(
sub_and_print,
delay=i*0.001,
),
name=task_name,
)
# allow subs to sched
await trio.sleep(0.1)
async with tx:
for i in cycle(range(size)):
await tx.send(i)
if len(brx._state.subs) == 2:
# only one, the non lagger, sub is left
break
# the non-lagger
assert laggers.pop('sub_1') == 0
for n, v in laggers.items():
assert v == 4
assert tx._closed
assert not tx._state.open_send_channels
# check that "first" bcaster that we created
# above, never wass iterated and is thus overrun
try:
await brx.receive()
except Lagged:
# expect tokio style index truncation
seq = brx._state.subs[brx.key]
assert seq == len(brx._state.queue) - 1
# all backpressured entries in the underlying
# channel should have been copied into the caster
# queue trailing-window
async for i in rx:
print(f'bped: {i}')
assert i in brx._state.queue
# should be noop
await brx.aclose()
trio.run(main)

View File

@@ -567,7 +567,7 @@ class Actor:
         try:
             send_chan, recv_chan = self._cids2qs[(actorid, cid)]
         except KeyError:
-            send_chan, recv_chan = trio.open_memory_channel(1000)
+            send_chan, recv_chan = trio.open_memory_channel(2**6)
             send_chan.cid = cid  # type: ignore
             recv_chan.cid = cid  # type: ignore
             self._cids2qs[(actorid, cid)] = send_chan, recv_chan

View File

@ -0,0 +1,308 @@
'''
``tokio`` style broadcast channel.
https://tokio-rs.github.io/tokio/doc/tokio/sync/broadcast/index.html
'''
from __future__ import annotations
from abc import abstractmethod
from collections import deque
from contextlib import asynccontextmanager
from dataclasses import dataclass
from functools import partial
from operator import ne
from typing import Optional, Callable, Awaitable, Any, AsyncIterator, Protocol
from typing import Generic, TypeVar
import trio
from trio._core._run import Task
from trio.abc import ReceiveChannel
from trio.lowlevel import current_task
# A regular invariant generic type
T = TypeVar("T")
# covariant because AsyncReceiver[Derived] can be passed to someone
# expecting AsyncReceiver[Base])
ReceiveType = TypeVar("ReceiveType", covariant=True)
class AsyncReceiver(
Protocol,
Generic[ReceiveType],
):
'''An async receivable duck-type that quacks much like trio's
``trio.abc.ReceiveChannel``.
'''
@abstractmethod
async def receive(self) -> ReceiveType:
...
@abstractmethod
def __aiter__(self) -> AsyncIterator[ReceiveType]:
...
@abstractmethod
async def __anext__(self) -> ReceiveType:
...
# ``trio.abc.AsyncResource`` methods
@abstractmethod
async def aclose(self):
...
@abstractmethod
async def __aenter__(self) -> AsyncReceiver[ReceiveType]:
...
@abstractmethod
async def __aexit__(self, *args) -> None:
...
class Lagged(trio.TooSlowError):
'''Subscribed consumer task was too slow and was overrun
by the fastest consumer-producer pair.
'''
@dataclass
class BroadcastState:
'''Common state to all receivers of a broadcast.
'''
queue: deque
maxlen: int
# map of underlying instance id keys to per-receiver sequence
# numbers; must be provided as a singleton per broadcaster set.
subs: dict[int, int]
# broadcast event to wake up all sleeping consumer tasks
# on a newly produced value from the sender.
recv_ready: Optional[tuple[int, trio.Event]] = None
class BroadcastReceiver(ReceiveChannel):
'''A memory receive channel broadcaster which is non-lossy for the
fastest consumer.
Additional consumer tasks can receive all produced values by registering
with ``.subscribe()`` and receiving from the new instance it delivers.
'''
def __init__(
self,
rx_chan: AsyncReceiver,
state: BroadcastState,
receive_afunc: Optional[Callable[[], Awaitable[Any]]] = None,
) -> None:
# register the original underlying (clone)
self.key = id(self)
self._state = state
state.subs[self.key] = -1
# underlying for this receiver
self._rx = rx_chan
self._recv = receive_afunc or rx_chan.receive
self._closed: bool = False
async def receive(self) -> ReceiveType:
key = self.key
state = self._state
# TODO: ideally we can make some way to "lock out" the
# underlying receive channel in some way such that if some task
# tries to pull from it directly (i.e. one we're unaware of)
# then it errors out.
# only tasks which have entered ``.subscribe()`` can
# receive on this broadcaster.
try:
seq = state.subs[key]
except KeyError:
if self._closed:
raise trio.ClosedResourceError
raise RuntimeError(
f'{self} is not registered as a subscriber')
# check that task does not already have a value it can receive
# immediately and/or that it has lagged.
if seq > -1:
# get the oldest value we haven't received immediately
try:
value = state.queue[seq]
except IndexError:
# adhere to ``tokio`` style "lagging":
# "Once RecvError::Lagged is returned, the lagging
# receiver's position is updated to the oldest value
# contained by the channel. The next call to recv will
# return this value."
# https://tokio-rs.github.io/tokio/doc/tokio/sync/broadcast/index.html#lagging
# decrement to the last value and expect
# consumer to either handle the ``Lagged`` and come back
# or bail out on its own (thus un-subscribing)
state.subs[key] = state.maxlen - 1
# this task was overrun by the producer side
task: Task = current_task()
raise Lagged(f'Task {task.name} was overrun')
state.subs[key] -= 1
return value
# current task already has the latest value **and** is the
# first task to begin waiting for a new one
if state.recv_ready is None:
if self._closed:
raise trio.ClosedResourceError
event = trio.Event()
state.recv_ready = key, event
# if we're cancelled here it should be
# fine to bail without affecting any other consumers
# right?
try:
value = await self._recv()
# items with lower indices are "newer"
# NOTE: ``collections.deque`` implicitly takes care of
# truncating values outside our ``state.maxlen``. In the
# alt-backend-array-case we'll need to make sure this is
# implemented in a similar ring-buffer-ish style.
state.queue.appendleft(value)
# broadcast new value to all subscribers by increasing
# all sequence numbers that will point in the queue to
# their latest available value.
# don't decrement the sequence for this task since we
# already retrieved the last value
# XXX: which of these impls is fastest?
# subs = state.subs.copy()
# subs.pop(key)
for sub_key in filter(
# lambda k: k != key, state.subs,
partial(ne, key), state.subs,
):
state.subs[sub_key] += 1
# NOTE: this should ONLY be set if the above task was *NOT*
# cancelled on the `._recv()` call otherwise sibling
# consumers will be awoken with a sequence of -1
event.set()
return value
finally:
# Reset receiver waiter task event for next blocking condition.
# this MUST be reset even if the above ``.recv()`` call
# was cancelled to avoid the next consumer from blocking on
# an event that won't be set!
state.recv_ready = None
# This task is all caught up and ready to receive the latest
# value, so schedule it to wake on the internal event.
else:
seq = state.subs[key]
assert seq == -1 # sanity
_, ev = state.recv_ready
await ev.wait()
seq = state.subs[key]
assert seq > -1, f'Invalid sequence {seq}!?'
value = state.queue[seq]
state.subs[key] -= 1
return value
# NOTE: if we ever would like the behaviour where if the
# first task to recv on the underlying is cancelled but it
# still DOES trigger the ``.recv_ready``, event we'll likely need
# this logic:
# if seq > -1:
# # stuff from above..
# elif seq == -1:
# # XXX: In the case where the first task to allocate the
# # ``.recv_ready`` event is cancelled we will be woken with
# # a non-incremented sequence number and thus will read the
# # oldest value if we use that. Instead we need to detect if
# # we have not been incremented and then receive again.
# return await self.receive()
# else:
# raise ValueError(f'Invalid sequence {seq}!?')
@asynccontextmanager
async def subscribe(
self,
) -> AsyncIterator[BroadcastReceiver]:
'''Subscribe for values from this broadcast receiver.
Returns a new ``BroadcastReceiver`` which is registered for and
pulls data from a clone of the original ``trio.abc.ReceiveChannel``
provided at creation.
'''
if self._closed:
raise trio.ClosedResourceError
state = self._state
br = BroadcastReceiver(
rx_chan=self._rx,
state=state,
receive_afunc=self._recv,
)
# assert clone in state.subs
assert br.key in state.subs
try:
yield br
finally:
await br.aclose()
async def aclose(
self,
) -> None:
if self._closed:
return
# XXX: leaving it like this consumers can still get values
# up to the last received that still reside in the queue.
self._state.subs.pop(self.key)
self._closed = True
def broadcast_receiver(
recv_chan: AsyncReceiver,
max_buffer_size: int,
**kwargs,
) -> BroadcastReceiver:
return BroadcastReceiver(
recv_chan,
state=BroadcastState(
queue=deque(maxlen=max_buffer_size),
maxlen=max_buffer_size,
subs={},
),
**kwargs,
)

View File

@@ -294,6 +294,7 @@ class Portal:
     async def open_stream_from(
         self,
         async_gen_func: Callable,  # typing: ignore
+        shield: bool = False,
         **kwargs,
     ) -> AsyncGenerator[ReceiveMsgStream, None]:

@@ -320,7 +321,9 @@ class Portal:
         ctx = Context(self.channel, cid, _portal=self)
         try:
             # deliver receive only stream
-            async with ReceiveMsgStream(ctx, recv_chan) as rchan:
+            async with ReceiveMsgStream(
+                ctx, recv_chan, shield=shield
+            ) as rchan:
                 self._streams.add(rchan)
                 yield rchan
View File

@@ -2,12 +2,14 @@
 Message stream types and APIs.
 """
+from __future__ import annotations
 import inspect
 from contextlib import contextmanager, asynccontextmanager
 from dataclasses import dataclass
 from typing import (
     Any, Iterator, Optional, Callable,
     AsyncGenerator, Dict,
+    AsyncIterator
 )

 import warnings

@@ -17,6 +19,7 @@ import trio
 from ._ipc import Channel
 from ._exceptions import unpack_error, ContextCancelled
 from ._state import current_actor
+from ._broadcast import broadcast_receiver, BroadcastReceiver
 from .log import get_logger

@@ -45,12 +48,14 @@ class ReceiveMsgStream(trio.abc.ReceiveChannel):
     def __init__(
         self,
         ctx: 'Context',  # typing: ignore # noqa
-        rx_chan: trio.abc.ReceiveChannel,
+        rx_chan: trio.MemoryReceiveChannel,
         shield: bool = False,
+        _broadcaster: Optional[BroadcastReceiver] = None,

     ) -> None:
         self._ctx = ctx
         self._rx_chan = rx_chan
-        self._shielded = shield
+        self._broadcaster = _broadcaster

         # flag to denote end of stream
         self._eoc: bool = False

@@ -103,7 +108,10 @@ class ReceiveMsgStream(trio.abc.ReceiveChannel):
         except (
             trio.ClosedResourceError,  # by self._rx_chan
             trio.EndOfChannel,  # by self._rx_chan or `stop` msg from far end
-            trio.Cancelled,  # by local cancellation
+
+            # Wait why would we do an implicit close on cancel? THAT'S
+            # NOT HOW MEM CHANS WORK!!?!?!?!?
+            # trio.Cancelled,  # by local cancellation
         ):
             # XXX: we close the stream on any of these error conditions:

@@ -135,23 +143,6 @@ class ReceiveMsgStream(trio.abc.ReceiveChannel):
             raise  # propagate

-    @contextmanager
-    def shield(
-        self
-    ) -> Iterator['ReceiveMsgStream']:  # noqa
-        """Shield this stream's underlying channel such that a local consumer task
-        can be cancelled (and possibly restarted) using ``trio.Cancelled``.
-
-        Note that here, "shielding" here guards against relaying
-        a ``'stop'`` message to the far end of the stream thus keeping
-        the stream machinery active and ready for further use, it does
-        not have anything to do with an internal ``trio.CancelScope``.
-        """
-        self._shielded = True
-        yield self
-        self._shielded = False
-
     async def aclose(self):
         """Cancel associated remote actor task and local memory channel
         on close.

@@ -169,18 +160,10 @@ class ReceiveMsgStream(trio.abc.ReceiveChannel):
             # https://trio.readthedocs.io/en/stable/reference-io.html#trio.abc.AsyncResource.aclose
             return

-        # TODO: broadcasting to multiple consumers
-        # stats = rx_chan.statistics()
-        # if stats.open_receive_channels > 1:
-        #     # if we've been cloned don't kill the stream
-        #     log.debug(
-        #         "there are still consumers running keeping stream alive")
-        #     return
-
-        if self._shielded:
-            log.warning(f"{self} is shielded, portal channel being kept alive")
-            return
+        # if self._shielded:
+        #     log.warning(f"{self} is shielded, portal channel being kept alive")
+        #     return

         # XXX: This must be set **AFTER** the shielded test above!
         self._eoc = True

@@ -253,6 +236,50 @@ class ReceiveMsgStream(trio.abc.ReceiveChannel):
         # still need to consume msgs that are "in transit" from the far
         # end (eg. for ``Context.result()``).

+    @asynccontextmanager
+    async def subscribe(
+        self,
+    ) -> AsyncIterator[BroadcastReceiver]:
+        '''Allocate and return a ``BroadcastReceiver`` which delegates
+        to this message stream.
+
+        This allows multiple local tasks to receive each their own copy
+        of this message stream.
+
+        This operation is idempotent and mutates this stream's receive
+        machinery to copy and window-length-store each received value
+        from the far end via the internally created broadcast receiver
+        wrapper.
+        '''
+        # NOTE: This operation is idempotent and non-reversible, so be
+        # sure you can deal with any (theoretical) overhead of the
+        # allocated ``BroadcastReceiver`` before calling this method for
+        # the first time.
+        if self._broadcaster is None:
+            bcast = self._broadcaster = broadcast_receiver(
+                self,
+                # use memory channel size by default
+                self._rx_chan._state.max_buffer_size,  # type: ignore
+                receive_afunc=self.receive,
+            )
+
+            # NOTE: we override the original stream instance's receive
+            # method to now delegate to the broadcaster's ``.receive()``
+            # such that new subscribers will be copied the received values
+            # and this stream doesn't have to expect its original
+            # consumer(s) to get a new broadcast rx handle.
+            self.receive = bcast.receive  # type: ignore
+            # seems there's no graceful way to type this with ``mypy``?
+            # https://github.com/python/mypy/issues/708
+
+        async with self._broadcaster.subscribe() as bstream:
+            assert bstream.key != self._broadcaster.key
+            assert bstream._recv == self._broadcaster._recv
+            yield bstream
+

 class MsgStream(ReceiveMsgStream, trio.abc.Channel):
     """

@@ -269,17 +296,6 @@ class MsgStream(ReceiveMsgStream, trio.abc.Channel):
         '''
         await self._ctx.chan.send({'yield': data, 'cid': self._ctx.cid})

-    # TODO: but make it broadcasting to consumers
-    def clone(self):
-        """Clone this receive channel allowing for multi-task
-        consumption from the same channel.
-
-        """
-        return MsgStream(
-            self._ctx,
-            self._rx_chan.clone(),
-        )
-

 @dataclass
 class Context:

@@ -397,7 +413,6 @@ class Context:
     async def open_stream(
         self,
-        shield: bool = False,

     ) -> AsyncGenerator[MsgStream, None]:
         '''Open a ``MsgStream``, a bi-directional stream connected to the

@@ -455,7 +470,6 @@ class Context:
             async with MsgStream(
                 ctx=self,
                 rx_chan=recv_chan,
-                shield=shield,
             ) as rchan:

                 if self._portal: