diff --git a/newsfragments/278.bug.rst b/newsfragments/278.bug.rst new file mode 100644 index 0000000..6fef955 --- /dev/null +++ b/newsfragments/278.bug.rst @@ -0,0 +1,12 @@ +Repair inter-actor stream closure semantics to work correctly with +``tractor.trionics.BroadcastReceiver`` task fan out usage. + +A set of previously unknown bugs discovered in `257 +`_ let graceful stream +closure result in hanging consumer tasks that use the broadcast APIs. +This adds better internal closure state tracking to the broadcast +receiver and message stream APIs and in particular ensures that when an +underlying stream/receive-channel (a broadcast receiver is receiving +from) is closed, all consumer tasks waiting on that underlying channel +are woken so they can receive the ``trio.EndOfChannel`` signal and +promptly terminate. diff --git a/requirements-test.txt b/requirements-test.txt index a46c4f3..5ad6c45 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -1,7 +1,7 @@ pytest pytest-trio pdbpp -mypy +mypy<0.920 trio_typing pexpect -towncrier \ No newline at end of file +towncrier diff --git a/tests/test_advanced_streaming.py b/tests/test_advanced_streaming.py index 74c06ca..b3e2bc1 100644 --- a/tests/test_advanced_streaming.py +++ b/tests/test_advanced_streaming.py @@ -1,7 +1,8 @@ -""" +''' Advanced streaming patterns using bidirectional streams and contexts. -""" +''' +from collections import Counter import itertools from typing import Set, Dict, List @@ -269,3 +270,98 @@ def test_sigint_both_stream_types(): assert 0, "Didn't receive KBI!?" except KeyboardInterrupt: pass + + +@tractor.context +async def inf_streamer( + ctx: tractor.Context, + +) -> None: + ''' + Stream increasing ints until terminated with a 'done' msg. + + ''' + await ctx.started() + + async with ( + ctx.open_stream() as stream, + trio.open_nursery() as n, + ): + async def bail_on_sentinel(): + async for msg in stream: + if msg == 'done': + await stream.aclose() + else: + print(f'streamer received {msg}') + + # start termination detector + n.start_soon(bail_on_sentinel) + + for val in itertools.count(): + try: + await stream.send(val) + except trio.ClosedResourceError: + # close out the stream gracefully + break + + print('terminating streamer') + + +def test_local_task_fanout_from_stream(): + ''' + Single stream with multiple local consumer tasks using the + ``MsgStream.subscribe()` api. + + Ensure all tasks receive all values after stream completes sending. + + ''' + consumers = 22 + + async def main(): + + counts = Counter() + + async with tractor.open_nursery() as tn: + p = await tn.start_actor( + 'inf_streamer', + enable_modules=[__name__], + ) + async with ( + p.open_context(inf_streamer) as (ctx, _), + ctx.open_stream() as stream, + ): + + async def pull_and_count(name: str): + # name = trio.lowlevel.current_task().name + async with stream.subscribe() as recver: + assert isinstance( + recver, + tractor.trionics.BroadcastReceiver + ) + async for val in recver: + # print(f'{name}: {val}') + counts[name] += 1 + + print(f'{name} bcaster ended') + + print(f'{name} completed') + + with trio.fail_after(3): + async with trio.open_nursery() as nurse: + for i in range(consumers): + nurse.start_soon(pull_and_count, i) + + await trio.sleep(0.5) + print('\nterminating') + await stream.send('done') + + print('closed stream connection') + + assert len(counts) == consumers + mx = max(counts.values()) + # make sure each task received all stream values + assert all(val == mx for val in counts.values()) + + await p.cancel_actor() + + trio.run(main) diff --git a/tests/test_streaming.py b/tests/test_streaming.py index baee54e..6f9d18c 100644 --- a/tests/test_streaming.py +++ b/tests/test_streaming.py @@ -79,33 +79,36 @@ async def stream_from_single_subactor( seq = range(10) - async with portal.open_stream_from( - stream_func, - sequence=list(seq), # has to be msgpack serializable - ) as stream: + with trio.fail_after(5): + async with portal.open_stream_from( + stream_func, + sequence=list(seq), # has to be msgpack serializable + ) as stream: - # it'd sure be nice to have an asyncitertools here... - iseq = iter(seq) - ival = next(iseq) + # it'd sure be nice to have an asyncitertools here... + iseq = iter(seq) + ival = next(iseq) - async for val in stream: - assert val == ival + async for val in stream: + assert val == ival + try: + ival = next(iseq) + except StopIteration: + # should cancel far end task which will be + # caught and no error is raised + await stream.aclose() + + await trio.sleep(0.3) + + # ensure EOC signalled-state translates + # XXX: not really sure this is correct, + # shouldn't it be a `ClosedResourceError`? try: - ival = next(iseq) - except StopIteration: - # should cancel far end task which will be - # caught and no error is raised - await stream.aclose() - - await trio.sleep(0.3) - - try: - await stream.__anext__() - except StopAsyncIteration: - # stop all spawned subactors - await portal.cancel_actor() - # await nursery.cancel() + await stream.__anext__() + except StopAsyncIteration: + # stop all spawned subactors + await portal.cancel_actor() @pytest.mark.parametrize( diff --git a/tractor/_streaming.py b/tractor/_streaming.py index 171ca4b..05932b8 100644 --- a/tractor/_streaming.py +++ b/tractor/_streaming.py @@ -78,6 +78,7 @@ class ReceiveMsgStream(trio.abc.ReceiveChannel): # flag to denote end of stream self._eoc: bool = False + self._closed: bool = False # delegate directly to underlying mem channel def receive_nowait(self): @@ -98,7 +99,7 @@ class ReceiveMsgStream(trio.abc.ReceiveChannel): msg = await self._rx_chan.receive() return msg['yield'] - except KeyError: + except KeyError as err: # internal error should never get here assert msg.get('cid'), ("Received internal error at portal?") @@ -107,9 +108,15 @@ class ReceiveMsgStream(trio.abc.ReceiveChannel): # - 'error' # possibly just handle msg['stop'] here! - if msg.get('stop'): + if msg.get('stop') or self._eoc: log.debug(f"{self} was stopped at remote end") + # XXX: important to set so that a new ``.receive()`` + # call (likely by another task using a broadcast receiver) + # doesn't accidentally pull the ``return`` message + # value out of the underlying feed mem chan! + self._eoc = True + # # when the send is closed we assume the stream has # # terminated and signal this local iterator to stop # await self.aclose() @@ -117,7 +124,7 @@ class ReceiveMsgStream(trio.abc.ReceiveChannel): # XXX: this causes ``ReceiveChannel.__anext__()`` to # raise a ``StopAsyncIteration`` **and** in our catch # block below it will trigger ``.aclose()``. - raise trio.EndOfChannel + raise trio.EndOfChannel from err # TODO: test that shows stream raising an expected error!!! elif msg.get('error'): @@ -162,10 +169,11 @@ class ReceiveMsgStream(trio.abc.ReceiveChannel): raise # propagate async def aclose(self): - """Cancel associated remote actor task and local memory channel - on close. + ''' + Cancel associated remote actor task and local memory channel on + close. - """ + ''' # XXX: keep proper adherance to trio's `.aclose()` semantics: # https://trio.readthedocs.io/en/stable/reference-io.html#trio.abc.AsyncResource.aclose rx_chan = self._rx_chan @@ -179,6 +187,7 @@ class ReceiveMsgStream(trio.abc.ReceiveChannel): return self._eoc = True + self._closed = True # NOTE: this is super subtle IPC messaging stuff: # Relay stop iteration to far end **iff** we're @@ -310,15 +319,16 @@ class MsgStream(ReceiveMsgStream, trio.abc.Channel): self, data: Any ) -> None: - '''Send a message over this stream to the far end. + ''' + Send a message over this stream to the far end. ''' - # if self._eoc: - # raise trio.ClosedResourceError('This stream is already ded') - if self._ctx._error: raise self._ctx._error # from None + if self._closed: + raise trio.ClosedResourceError('This stream was already closed') + await self._ctx.chan.send({'yield': data, 'cid': self._ctx.cid}) diff --git a/tractor/trionics/_broadcast.py b/tractor/trionics/_broadcast.py index 3a2e1e4..35711b2 100644 --- a/tractor/trionics/_broadcast.py +++ b/tractor/trionics/_broadcast.py @@ -100,6 +100,15 @@ class BroadcastState: # on a newly produced value from the sender. recv_ready: Optional[tuple[int, trio.Event]] = None + # if a ``trio.EndOfChannel`` is received on any + # consumer all consumers should be placed in this state + # such that the group is notified of the end-of-broadcast. + # For now, this is solely for testing/debugging purposes. + eoc: bool = False + + # If the broadcaster was cancelled, we might as well track it + cancelled: bool = False + class BroadcastReceiver(ReceiveChannel): '''A memory receive channel broadcaster which is non-lossy for the @@ -222,10 +231,23 @@ class BroadcastReceiver(ReceiveChannel): event.set() return value - except trio.Cancelled: + except trio.EndOfChannel: + # if any one consumer gets an EOC from the underlying + # receiver we need to unblock and send that signal to + # all other consumers. + self._state.eoc = True + if event.statistics().tasks_waiting: + event.set() + raise + + except ( + trio.Cancelled, + ): # handle cancelled specially otherwise sibling # consumers will be awoken with a sequence of -1 - # state.recv_ready = trio.Cancelled + # and will potentially try to rewait the underlying + # receiver instead of just cancelling immediately. + self._state.cancelled = True if event.statistics().tasks_waiting: event.set() raise @@ -274,11 +296,12 @@ class BroadcastReceiver(ReceiveChannel): async def subscribe( self, ) -> AsyncIterator[BroadcastReceiver]: - '''Subscribe for values from this broadcast receiver. + ''' + Subscribe for values from this broadcast receiver. Returns a new ``BroadCastReceiver`` which is registered for and - pulls data from a clone of the original ``trio.abc.ReceiveChannel`` - provided at creation. + pulls data from a clone of the original + ``trio.abc.ReceiveChannel`` provided at creation. ''' if self._closed: @@ -301,7 +324,10 @@ class BroadcastReceiver(ReceiveChannel): async def aclose( self, ) -> None: + ''' + Close this receiver without affecting other consumers. + ''' if self._closed: return