Merge pull request #278 from goodboy/end_of_channel_fixes

End of channel fixes for streams and broadcasting
new_mypy
goodboy 2021-12-16 18:01:04 -05:00 committed by GitHub
commit 2d6fbd5437
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 189 additions and 42 deletions

View File

@ -0,0 +1,12 @@
Repair inter-actor stream closure semantics to work correctly with
``tractor.trionics.BroadcastReceiver`` task fan out usage.
A set of previously unknown bugs discovered in `257
<https://github.com/goodboy/tractor/pull/257>`_ let graceful stream
closure result in hanging consumer tasks that use the broadcast APIs.
This adds better internal closure state tracking to the broadcast
receiver and message stream APIs and in particular ensures that when an
underlying stream/receive-channel (a broadcast receiver is receiving
from) is closed, all consumer tasks waiting on that underlying channel
are woken so they can receive the ``trio.EndOfChannel`` signal and
promptly terminate.

View File

@ -1,7 +1,7 @@
pytest pytest
pytest-trio pytest-trio
pdbpp pdbpp
mypy mypy<0.920
trio_typing trio_typing
pexpect pexpect
towncrier towncrier

View File

@ -1,7 +1,8 @@
""" '''
Advanced streaming patterns using bidirectional streams and contexts. Advanced streaming patterns using bidirectional streams and contexts.
""" '''
from collections import Counter
import itertools import itertools
from typing import Set, Dict, List from typing import Set, Dict, List
@ -269,3 +270,98 @@ def test_sigint_both_stream_types():
assert 0, "Didn't receive KBI!?" assert 0, "Didn't receive KBI!?"
except KeyboardInterrupt: except KeyboardInterrupt:
pass pass
@tractor.context
async def inf_streamer(
ctx: tractor.Context,
) -> None:
'''
Stream increasing ints until terminated with a 'done' msg.
'''
await ctx.started()
async with (
ctx.open_stream() as stream,
trio.open_nursery() as n,
):
async def bail_on_sentinel():
async for msg in stream:
if msg == 'done':
await stream.aclose()
else:
print(f'streamer received {msg}')
# start termination detector
n.start_soon(bail_on_sentinel)
for val in itertools.count():
try:
await stream.send(val)
except trio.ClosedResourceError:
# close out the stream gracefully
break
print('terminating streamer')
def test_local_task_fanout_from_stream():
'''
Single stream with multiple local consumer tasks using the
``MsgStream.subscribe()` api.
Ensure all tasks receive all values after stream completes sending.
'''
consumers = 22
async def main():
counts = Counter()
async with tractor.open_nursery() as tn:
p = await tn.start_actor(
'inf_streamer',
enable_modules=[__name__],
)
async with (
p.open_context(inf_streamer) as (ctx, _),
ctx.open_stream() as stream,
):
async def pull_and_count(name: str):
# name = trio.lowlevel.current_task().name
async with stream.subscribe() as recver:
assert isinstance(
recver,
tractor.trionics.BroadcastReceiver
)
async for val in recver:
# print(f'{name}: {val}')
counts[name] += 1
print(f'{name} bcaster ended')
print(f'{name} completed')
with trio.fail_after(3):
async with trio.open_nursery() as nurse:
for i in range(consumers):
nurse.start_soon(pull_and_count, i)
await trio.sleep(0.5)
print('\nterminating')
await stream.send('done')
print('closed stream connection')
assert len(counts) == consumers
mx = max(counts.values())
# make sure each task received all stream values
assert all(val == mx for val in counts.values())
await p.cancel_actor()
trio.run(main)

View File

@ -79,33 +79,36 @@ async def stream_from_single_subactor(
seq = range(10) seq = range(10)
async with portal.open_stream_from( with trio.fail_after(5):
stream_func, async with portal.open_stream_from(
sequence=list(seq), # has to be msgpack serializable stream_func,
) as stream: sequence=list(seq), # has to be msgpack serializable
) as stream:
# it'd sure be nice to have an asyncitertools here... # it'd sure be nice to have an asyncitertools here...
iseq = iter(seq) iseq = iter(seq)
ival = next(iseq) ival = next(iseq)
async for val in stream: async for val in stream:
assert val == ival assert val == ival
try:
ival = next(iseq)
except StopIteration:
# should cancel far end task which will be
# caught and no error is raised
await stream.aclose()
await trio.sleep(0.3)
# ensure EOC signalled-state translates
# XXX: not really sure this is correct,
# shouldn't it be a `ClosedResourceError`?
try: try:
ival = next(iseq) await stream.__anext__()
except StopIteration: except StopAsyncIteration:
# should cancel far end task which will be # stop all spawned subactors
# caught and no error is raised await portal.cancel_actor()
await stream.aclose()
await trio.sleep(0.3)
try:
await stream.__anext__()
except StopAsyncIteration:
# stop all spawned subactors
await portal.cancel_actor()
# await nursery.cancel()
@pytest.mark.parametrize( @pytest.mark.parametrize(

View File

@ -78,6 +78,7 @@ class ReceiveMsgStream(trio.abc.ReceiveChannel):
# flag to denote end of stream # flag to denote end of stream
self._eoc: bool = False self._eoc: bool = False
self._closed: bool = False
# delegate directly to underlying mem channel # delegate directly to underlying mem channel
def receive_nowait(self): def receive_nowait(self):
@ -98,7 +99,7 @@ class ReceiveMsgStream(trio.abc.ReceiveChannel):
msg = await self._rx_chan.receive() msg = await self._rx_chan.receive()
return msg['yield'] return msg['yield']
except KeyError: except KeyError as err:
# internal error should never get here # internal error should never get here
assert msg.get('cid'), ("Received internal error at portal?") assert msg.get('cid'), ("Received internal error at portal?")
@ -107,9 +108,15 @@ class ReceiveMsgStream(trio.abc.ReceiveChannel):
# - 'error' # - 'error'
# possibly just handle msg['stop'] here! # possibly just handle msg['stop'] here!
if msg.get('stop'): if msg.get('stop') or self._eoc:
log.debug(f"{self} was stopped at remote end") log.debug(f"{self} was stopped at remote end")
# XXX: important to set so that a new ``.receive()``
# call (likely by another task using a broadcast receiver)
# doesn't accidentally pull the ``return`` message
# value out of the underlying feed mem chan!
self._eoc = True
# # when the send is closed we assume the stream has # # when the send is closed we assume the stream has
# # terminated and signal this local iterator to stop # # terminated and signal this local iterator to stop
# await self.aclose() # await self.aclose()
@ -117,7 +124,7 @@ class ReceiveMsgStream(trio.abc.ReceiveChannel):
# XXX: this causes ``ReceiveChannel.__anext__()`` to # XXX: this causes ``ReceiveChannel.__anext__()`` to
# raise a ``StopAsyncIteration`` **and** in our catch # raise a ``StopAsyncIteration`` **and** in our catch
# block below it will trigger ``.aclose()``. # block below it will trigger ``.aclose()``.
raise trio.EndOfChannel raise trio.EndOfChannel from err
# TODO: test that shows stream raising an expected error!!! # TODO: test that shows stream raising an expected error!!!
elif msg.get('error'): elif msg.get('error'):
@ -162,10 +169,11 @@ class ReceiveMsgStream(trio.abc.ReceiveChannel):
raise # propagate raise # propagate
async def aclose(self): async def aclose(self):
"""Cancel associated remote actor task and local memory channel '''
on close. Cancel associated remote actor task and local memory channel on
close.
""" '''
# XXX: keep proper adherance to trio's `.aclose()` semantics: # XXX: keep proper adherance to trio's `.aclose()` semantics:
# https://trio.readthedocs.io/en/stable/reference-io.html#trio.abc.AsyncResource.aclose # https://trio.readthedocs.io/en/stable/reference-io.html#trio.abc.AsyncResource.aclose
rx_chan = self._rx_chan rx_chan = self._rx_chan
@ -179,6 +187,7 @@ class ReceiveMsgStream(trio.abc.ReceiveChannel):
return return
self._eoc = True self._eoc = True
self._closed = True
# NOTE: this is super subtle IPC messaging stuff: # NOTE: this is super subtle IPC messaging stuff:
# Relay stop iteration to far end **iff** we're # Relay stop iteration to far end **iff** we're
@ -310,15 +319,16 @@ class MsgStream(ReceiveMsgStream, trio.abc.Channel):
self, self,
data: Any data: Any
) -> None: ) -> None:
'''Send a message over this stream to the far end. '''
Send a message over this stream to the far end.
''' '''
# if self._eoc:
# raise trio.ClosedResourceError('This stream is already ded')
if self._ctx._error: if self._ctx._error:
raise self._ctx._error # from None raise self._ctx._error # from None
if self._closed:
raise trio.ClosedResourceError('This stream was already closed')
await self._ctx.chan.send({'yield': data, 'cid': self._ctx.cid}) await self._ctx.chan.send({'yield': data, 'cid': self._ctx.cid})

View File

@ -100,6 +100,15 @@ class BroadcastState:
# on a newly produced value from the sender. # on a newly produced value from the sender.
recv_ready: Optional[tuple[int, trio.Event]] = None recv_ready: Optional[tuple[int, trio.Event]] = None
# if a ``trio.EndOfChannel`` is received on any
# consumer all consumers should be placed in this state
# such that the group is notified of the end-of-broadcast.
# For now, this is solely for testing/debugging purposes.
eoc: bool = False
# If the broadcaster was cancelled, we might as well track it
cancelled: bool = False
class BroadcastReceiver(ReceiveChannel): class BroadcastReceiver(ReceiveChannel):
'''A memory receive channel broadcaster which is non-lossy for the '''A memory receive channel broadcaster which is non-lossy for the
@ -222,10 +231,23 @@ class BroadcastReceiver(ReceiveChannel):
event.set() event.set()
return value return value
except trio.Cancelled: except trio.EndOfChannel:
# if any one consumer gets an EOC from the underlying
# receiver we need to unblock and send that signal to
# all other consumers.
self._state.eoc = True
if event.statistics().tasks_waiting:
event.set()
raise
except (
trio.Cancelled,
):
# handle cancelled specially otherwise sibling # handle cancelled specially otherwise sibling
# consumers will be awoken with a sequence of -1 # consumers will be awoken with a sequence of -1
# state.recv_ready = trio.Cancelled # and will potentially try to rewait the underlying
# receiver instead of just cancelling immediately.
self._state.cancelled = True
if event.statistics().tasks_waiting: if event.statistics().tasks_waiting:
event.set() event.set()
raise raise
@ -274,11 +296,12 @@ class BroadcastReceiver(ReceiveChannel):
async def subscribe( async def subscribe(
self, self,
) -> AsyncIterator[BroadcastReceiver]: ) -> AsyncIterator[BroadcastReceiver]:
'''Subscribe for values from this broadcast receiver. '''
Subscribe for values from this broadcast receiver.
Returns a new ``BroadCastReceiver`` which is registered for and Returns a new ``BroadCastReceiver`` which is registered for and
pulls data from a clone of the original ``trio.abc.ReceiveChannel`` pulls data from a clone of the original
provided at creation. ``trio.abc.ReceiveChannel`` provided at creation.
''' '''
if self._closed: if self._closed:
@ -301,7 +324,10 @@ class BroadcastReceiver(ReceiveChannel):
async def aclose( async def aclose(
self, self,
) -> None: ) -> None:
'''
Close this receiver without affecting other consumers.
'''
if self._closed: if self._closed:
return return