Merge pull request #278 from goodboy/end_of_channel_fixes

End of channel fixes for streams and broadcasting

commit 2d6fbd5437
@@ -0,0 +1,12 @@
Repair inter-actor stream closure semantics to work correctly with
``tractor.trionics.BroadcastReceiver`` task fan-out usage.

A set of previously unknown bugs discovered in `257
<https://github.com/goodboy/tractor/pull/257>`_ allowed graceful stream
closure to result in hanging consumer tasks that use the broadcast APIs.
This adds better internal closure state tracking to the broadcast
receiver and message stream APIs and in particular ensures that when an
underlying stream/receive-channel (which a broadcast receiver is
receiving from) is closed, all consumer tasks waiting on that underlying
channel are woken so they can receive the ``trio.EndOfChannel`` signal
and promptly terminate.
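
A minimal sketch of the fan-out scenario this note describes (not part of the change set itself): it assumes the ``broadcast_receiver()`` factory exported by ``tractor.trionics``, and the task and variable names are illustrative only. Each consumer iterates its own subscription and, with this fix, is woken with ``trio.EndOfChannel`` once the underlying channel closes instead of hanging.

import trio
import tractor


async def consumer(name: str, bcaster) -> None:
    # each task iterates its own subscription; when the underlying
    # channel is closed the loop ends via ``trio.EndOfChannel``
    # instead of hanging forever.
    async with bcaster.subscribe() as rx:
        async for value in rx:
            print(f'{name} got {value}')
    print(f'{name} done')


async def main() -> None:
    tx, rx = trio.open_memory_channel(1)
    bcaster = tractor.trionics.broadcast_receiver(rx, 16)

    async with trio.open_nursery() as n:
        for i in range(3):
            n.start_soon(consumer, f'task{i}', bcaster)

        # closing the sender ends the underlying channel; every
        # consumer waiting on it is woken and terminates promptly.
        async with tx:
            for value in range(5):
                await tx.send(value)


trio.run(main)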
@@ -1,7 +1,7 @@
pytest
pytest-trio
pdbpp
mypy
mypy<0.920
trio_typing
pexpect
towncrier
@@ -1,7 +1,8 @@
"""
'''
Advanced streaming patterns using bidirectional streams and contexts.

"""
'''
from collections import Counter
import itertools
from typing import Set, Dict, List

@@ -269,3 +270,98 @@ def test_sigint_both_stream_types():
        assert 0, "Didn't receive KBI!?"
    except KeyboardInterrupt:
        pass


@tractor.context
async def inf_streamer(
    ctx: tractor.Context,

) -> None:
    '''
    Stream increasing ints until terminated with a 'done' msg.

    '''
    await ctx.started()

    async with (
        ctx.open_stream() as stream,
        trio.open_nursery() as n,
    ):
        async def bail_on_sentinel():
            async for msg in stream:
                if msg == 'done':
                    await stream.aclose()
                else:
                    print(f'streamer received {msg}')

        # start termination detector
        n.start_soon(bail_on_sentinel)

        for val in itertools.count():
            try:
                await stream.send(val)
            except trio.ClosedResourceError:
                # close out the stream gracefully
                break

    print('terminating streamer')

def test_local_task_fanout_from_stream():
    '''
    Single stream with multiple local consumer tasks using the
    ``MsgStream.subscribe()`` API.

    Ensure all tasks receive all values after stream completes sending.

    '''
    consumers = 22

    async def main():

        counts = Counter()

        async with tractor.open_nursery() as tn:
            p = await tn.start_actor(
                'inf_streamer',
                enable_modules=[__name__],
            )
            async with (
                p.open_context(inf_streamer) as (ctx, _),
                ctx.open_stream() as stream,
            ):

                async def pull_and_count(name: str):
                    # name = trio.lowlevel.current_task().name
                    async with stream.subscribe() as recver:
                        assert isinstance(
                            recver,
                            tractor.trionics.BroadcastReceiver
                        )
                        async for val in recver:
                            # print(f'{name}: {val}')
                            counts[name] += 1

                        print(f'{name} bcaster ended')

                    print(f'{name} completed')

                with trio.fail_after(3):
                    async with trio.open_nursery() as nurse:
                        for i in range(consumers):
                            nurse.start_soon(pull_and_count, i)

                        await trio.sleep(0.5)
                        print('\nterminating')
                        await stream.send('done')

            print('closed stream connection')

            assert len(counts) == consumers
            mx = max(counts.values())
            # make sure each task received all stream values
            assert all(val == mx for val in counts.values())

            await p.cancel_actor()

    trio.run(main)
@@ -79,33 +79,36 @@ async def stream_from_single_subactor(

                seq = range(10)

                async with portal.open_stream_from(
                    stream_func,
                    sequence=list(seq),  # has to be msgpack serializable
                ) as stream:
                with trio.fail_after(5):
                    async with portal.open_stream_from(
                        stream_func,
                        sequence=list(seq),  # has to be msgpack serializable
                    ) as stream:

                    # it'd sure be nice to have an asyncitertools here...
                    iseq = iter(seq)
                    ival = next(iseq)
                        # it'd sure be nice to have an asyncitertools here...
                        iseq = iter(seq)
                        ival = next(iseq)

                    async for val in stream:
                        assert val == ival
                        async for val in stream:
                            assert val == ival

                            try:
                                ival = next(iseq)
                            except StopIteration:
                                # should cancel far end task which will be
                                # caught and no error is raised
                                await stream.aclose()

                        await trio.sleep(0.3)

                        # ensure EOC signalled-state translates
                        # XXX: not really sure this is correct,
                        # shouldn't it be a `ClosedResourceError`?
                        try:
                            ival = next(iseq)
                        except StopIteration:
                            # should cancel far end task which will be
                            # caught and no error is raised
                            await stream.aclose()

                    await trio.sleep(0.3)

                    try:
                        await stream.__anext__()
                    except StopAsyncIteration:
                        # stop all spawned subactors
                        await portal.cancel_actor()
                    # await nursery.cancel()
                        try:
                            await stream.__anext__()
                        except StopAsyncIteration:
                            # stop all spawned subactors
                            await portal.cancel_actor()


@pytest.mark.parametrize(
@@ -78,6 +78,7 @@ class ReceiveMsgStream(trio.abc.ReceiveChannel):

        # flag to denote end of stream
        self._eoc: bool = False
        self._closed: bool = False

    # delegate directly to underlying mem channel
    def receive_nowait(self):
@@ -98,7 +99,7 @@ class ReceiveMsgStream(trio.abc.ReceiveChannel):
            msg = await self._rx_chan.receive()
            return msg['yield']

        except KeyError:
        except KeyError as err:
            # internal error should never get here
            assert msg.get('cid'), ("Received internal error at portal?")

@@ -107,9 +108,15 @@ class ReceiveMsgStream(trio.abc.ReceiveChannel):
            # - 'error'
            # possibly just handle msg['stop'] here!

            if msg.get('stop'):
            if msg.get('stop') or self._eoc:
                log.debug(f"{self} was stopped at remote end")

                # XXX: important to set so that a new ``.receive()``
                # call (likely by another task using a broadcast receiver)
                # doesn't accidentally pull the ``return`` message
                # value out of the underlying feed mem chan!
                self._eoc = True

                # # when the send is closed we assume the stream has
                # # terminated and signal this local iterator to stop
                # await self.aclose()
@@ -117,7 +124,7 @@ class ReceiveMsgStream(trio.abc.ReceiveChannel):
                # XXX: this causes ``ReceiveChannel.__anext__()`` to
                # raise a ``StopAsyncIteration`` **and** in our catch
                # block below it will trigger ``.aclose()``.
                raise trio.EndOfChannel
                raise trio.EndOfChannel from err

            # TODO: test that shows stream raising an expected error!!!
            elif msg.get('error'):
@@ -162,10 +169,11 @@ class ReceiveMsgStream(trio.abc.ReceiveChannel):
            raise  # propagate

    async def aclose(self):
        """Cancel associated remote actor task and local memory channel
        on close.
        '''
        Cancel associated remote actor task and local memory channel on
        close.

        """
        '''
        # XXX: keep proper adherence to trio's `.aclose()` semantics:
        # https://trio.readthedocs.io/en/stable/reference-io.html#trio.abc.AsyncResource.aclose
        rx_chan = self._rx_chan
@@ -179,6 +187,7 @@ class ReceiveMsgStream(trio.abc.ReceiveChannel):
            return

        self._eoc = True
        self._closed = True

        # NOTE: this is super subtle IPC messaging stuff:
        # Relay stop iteration to far end **iff** we're
@@ -310,15 +319,16 @@ class MsgStream(ReceiveMsgStream, trio.abc.Channel):
        self,
        data: Any
    ) -> None:
        '''Send a message over this stream to the far end.
        '''
        Send a message over this stream to the far end.

        '''
        # if self._eoc:
        #     raise trio.ClosedResourceError('This stream is already ded')

        if self._ctx._error:
            raise self._ctx._error  # from None

        if self._closed:
            raise trio.ClosedResourceError('This stream was already closed')

        await self._ctx.chan.send({'yield': data, 'cid': self._ctx.cid})

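
With the new ``_closed`` flag, ``MsgStream.send()`` gains a definite failure mode after local closure. A hedged usage sketch (``stream`` is assumed to come from ``ctx.open_stream()`` as in the test code above; the helper name is illustrative):

import trio


async def send_until_closed(stream) -> None:
    # works while the stream is open
    await stream.send('ping')

    # closing sets the new ``_closed`` (and ``_eoc``) flags locally
    await stream.aclose()

    try:
        # further sends are rejected immediately instead of being
        # pushed into a dead channel
        await stream.send('pong')
    except trio.ClosedResourceError:
        print('stream already closed, as expected')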
@@ -100,6 +100,15 @@ class BroadcastState:
    # on a newly produced value from the sender.
    recv_ready: Optional[tuple[int, trio.Event]] = None

    # if a ``trio.EndOfChannel`` is received on any
    # consumer all consumers should be placed in this state
    # such that the group is notified of the end-of-broadcast.
    # For now, this is solely for testing/debugging purposes.
    eoc: bool = False

    # If the broadcaster was cancelled, we might as well track it
    cancelled: bool = False


class BroadcastReceiver(ReceiveChannel):
    '''A memory receive channel broadcaster which is non-lossy for the
@@ -222,10 +231,23 @@ class BroadcastReceiver(ReceiveChannel):
                event.set()
                return value

            except trio.Cancelled:
            except trio.EndOfChannel:
                # if any one consumer gets an EOC from the underlying
                # receiver we need to unblock and send that signal to
                # all other consumers.
                self._state.eoc = True
                if event.statistics().tasks_waiting:
                    event.set()
                raise

            except (
                trio.Cancelled,
            ):
                # handle cancelled specially otherwise sibling
                # consumers will be awoken with a sequence of -1
                # state.recv_ready = trio.Cancelled
                # and will potentially try to rewait the underlying
                # receiver instead of just cancelling immediately.
                self._state.cancelled = True
                if event.statistics().tasks_waiting:
                    event.set()
                raise
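
Both handlers above use the same wake-up mechanism: whichever consumer hits the terminal condition records it on the shared ``BroadcastState`` and sets the ``trio.Event`` that any sibling consumers are parked on, so they re-check that state instead of waiting forever. A stripped-down sketch of that pattern in plain ``trio`` (illustrative names only, not the library code itself):

import trio


async def sibling(name: str, state: dict, event: trio.Event) -> None:
    # a parked consumer: waits on the shared event rather than the
    # underlying channel, then re-checks the recorded terminal state.
    await event.wait()
    if state['eoc']:
        print(f'{name}: woke to end-of-channel, terminating')


async def main() -> None:
    state = {'eoc': False}
    event = trio.Event()

    async with trio.open_nursery() as n:
        n.start_soon(sibling, 'task-a', state, event)
        n.start_soon(sibling, 'task-b', state, event)
        await trio.sleep(0.1)

        # the "leader" consumer hits ``trio.EndOfChannel`` on the real
        # receiver: record it and wake everyone parked on the event.
        state['eoc'] = True
        event.set()


trio.run(main)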
@@ -274,11 +296,12 @@ class BroadcastReceiver(ReceiveChannel):
    async def subscribe(
        self,
    ) -> AsyncIterator[BroadcastReceiver]:
        '''Subscribe for values from this broadcast receiver.
        '''
        Subscribe for values from this broadcast receiver.

        Returns a new ``BroadcastReceiver`` which is registered for and
        pulls data from a clone of the original ``trio.abc.ReceiveChannel``
        provided at creation.
        pulls data from a clone of the original
        ``trio.abc.ReceiveChannel`` provided at creation.

        '''
        if self._closed:
@@ -301,7 +324,10 @@ class BroadcastReceiver(ReceiveChannel):
    async def aclose(
        self,
    ) -> None:
        '''
        Close this receiver without affecting other consumers.

        '''
        if self._closed:
            return

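
The ``aclose()`` docstring above states the per-consumer close semantics. A brief hedged sketch (again assuming the ``broadcast_receiver()`` factory from ``tractor.trionics`` used in the earlier example; task names are illustrative): one subscription can bow out early without disturbing its siblings, which keep receiving until the underlying channel itself ends.

import trio
import tractor


async def impatient(bcaster) -> None:
    # take a single value then close only this subscription;
    # sibling consumers keep receiving.
    async with bcaster.subscribe() as rx:
        print('impatient got', await rx.receive())
        await rx.aclose()


async def patient(bcaster) -> None:
    # iterates until the underlying channel ends with
    # ``trio.EndOfChannel``.
    async with bcaster.subscribe() as rx:
        async for value in rx:
            print('patient got', value)


async def main() -> None:
    tx, rx = trio.open_memory_channel(1)
    bcaster = tractor.trionics.broadcast_receiver(rx, 16)

    async with trio.open_nursery() as n:
        n.start_soon(impatient, bcaster)
        n.start_soon(patient, bcaster)

        async with tx:
            for value in range(4):
                await tx.send(value)


trio.run(main)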