Compare commits
	
		
			1 Commits 
		
	
	
		
			main
			...
			multicast_
		
	
	| Author | SHA1 | Date | 
|---|---|---|
|  | eb4bd203f0 | 
|  | @ -67,7 +67,6 @@ async def ensure_sequence( | |||
| 
 | ||||
| @acm | ||||
| async def open_sequence_streamer( | ||||
| 
 | ||||
|     sequence: list[int], | ||||
|     reg_addr: tuple[str, int], | ||||
|     start_method: str, | ||||
|  | @ -96,39 +95,43 @@ async def open_sequence_streamer( | |||
| 
 | ||||
| 
 | ||||
| def test_stream_fan_out_to_local_subscriptions( | ||||
|     reg_addr, | ||||
|     debug_mode: bool, | ||||
|     reg_addr: tuple, | ||||
|     start_method, | ||||
| ): | ||||
| 
 | ||||
|     sequence = list(range(1000)) | ||||
| 
 | ||||
|     async def main(): | ||||
|         with trio.fail_after(9): | ||||
|             async with open_sequence_streamer( | ||||
|                 sequence, | ||||
|                 reg_addr, | ||||
|                 start_method, | ||||
|             ) as stream: | ||||
| 
 | ||||
|         async with open_sequence_streamer( | ||||
|             sequence, | ||||
|             reg_addr, | ||||
|             start_method, | ||||
|         ) as stream: | ||||
|                 async with ( | ||||
|                     collapse_eg(), | ||||
|                     trio.open_nursery() as tn, | ||||
|                 ): | ||||
|                     for i in range(10): | ||||
|                         tn.start_soon( | ||||
|                             ensure_sequence, | ||||
|                             stream, | ||||
|                             sequence.copy(), | ||||
|                             name=f'consumer_{i}', | ||||
|                         ) | ||||
| 
 | ||||
|             async with trio.open_nursery() as n: | ||||
|                 for i in range(10): | ||||
|                     n.start_soon( | ||||
|                         ensure_sequence, | ||||
|                         stream, | ||||
|                         sequence.copy(), | ||||
|                         name=f'consumer_{i}', | ||||
|                     ) | ||||
|                     await stream.send(tuple(sequence)) | ||||
| 
 | ||||
|                 await stream.send(tuple(sequence)) | ||||
|                     async for value in stream: | ||||
|                         print(f'source stream rx: {value}') | ||||
|                         assert value == sequence[0] | ||||
|                         sequence.remove(value) | ||||
| 
 | ||||
|                 async for value in stream: | ||||
|                     print(f'source stream rx: {value}') | ||||
|                     assert value == sequence[0] | ||||
|                     sequence.remove(value) | ||||
| 
 | ||||
|                     if not sequence: | ||||
|                         # fully consumed | ||||
|                         break | ||||
|                         if not sequence: | ||||
|                             # fully consumed | ||||
|                             break | ||||
| 
 | ||||
|     trio.run(main) | ||||
| 
 | ||||
|  | @ -151,67 +154,69 @@ def test_consumer_and_parent_maybe_lag( | |||
|         sequence = list(range(300)) | ||||
|         parent_delay, sub_delay = task_delays | ||||
| 
 | ||||
|         async with open_sequence_streamer( | ||||
|             sequence, | ||||
|             reg_addr, | ||||
|             start_method, | ||||
|         ) as stream: | ||||
|         # TODO, maybe mak a cm-deco for main()s? | ||||
|         with trio.fail_after(3): | ||||
|             async with open_sequence_streamer( | ||||
|                 sequence, | ||||
|                 reg_addr, | ||||
|                 start_method, | ||||
|             ) as stream: | ||||
| 
 | ||||
|             try: | ||||
|                 async with ( | ||||
|                     collapse_eg(), | ||||
|                     trio.open_nursery() as tn, | ||||
|                 ): | ||||
|                 try: | ||||
|                     async with ( | ||||
|                         collapse_eg(), | ||||
|                         trio.open_nursery() as tn, | ||||
|                     ): | ||||
| 
 | ||||
|                     tn.start_soon( | ||||
|                         ensure_sequence, | ||||
|                         stream, | ||||
|                         sequence.copy(), | ||||
|                         sub_delay, | ||||
|                         name='consumer_task', | ||||
|                     ) | ||||
|                         tn.start_soon( | ||||
|                             ensure_sequence, | ||||
|                             stream, | ||||
|                             sequence.copy(), | ||||
|                             sub_delay, | ||||
|                             name='consumer_task', | ||||
|                         ) | ||||
| 
 | ||||
|                     await stream.send(tuple(sequence)) | ||||
|                         await stream.send(tuple(sequence)) | ||||
| 
 | ||||
|                     # async for value in stream: | ||||
|                     lagged = False | ||||
|                     lag_count = 0 | ||||
|                         # async for value in stream: | ||||
|                         lagged = False | ||||
|                         lag_count = 0 | ||||
| 
 | ||||
|                     while True: | ||||
|                         try: | ||||
|                             value = await stream.receive() | ||||
|                             print(f'source stream rx: {value}') | ||||
|                         while True: | ||||
|                             try: | ||||
|                                 value = await stream.receive() | ||||
|                                 print(f'source stream rx: {value}') | ||||
| 
 | ||||
|                             if lagged: | ||||
|                                 # re set the sequence starting at our last | ||||
|                                 # value | ||||
|                                 sequence = sequence[sequence.index(value) + 1:] | ||||
|                             else: | ||||
|                                 assert value == sequence[0] | ||||
|                                 sequence.remove(value) | ||||
|                                 if lagged: | ||||
|                                     # re set the sequence starting at our last | ||||
|                                     # value | ||||
|                                     sequence = sequence[sequence.index(value) + 1:] | ||||
|                                 else: | ||||
|                                     assert value == sequence[0] | ||||
|                                     sequence.remove(value) | ||||
| 
 | ||||
|                             lagged = False | ||||
|                                 lagged = False | ||||
| 
 | ||||
|                         except Lagged: | ||||
|                             lagged = True | ||||
|                             print(f'source stream lagged after {value}') | ||||
|                             lag_count += 1 | ||||
|                             continue | ||||
|                             except Lagged: | ||||
|                                 lagged = True | ||||
|                                 print(f'source stream lagged after {value}') | ||||
|                                 lag_count += 1 | ||||
|                                 continue | ||||
| 
 | ||||
|                         # lag the parent | ||||
|                         await trio.sleep(parent_delay) | ||||
|                             # lag the parent | ||||
|                             await trio.sleep(parent_delay) | ||||
| 
 | ||||
|                         if not sequence: | ||||
|                             # fully consumed | ||||
|                             break | ||||
|                     print(f'parent + source stream lagged: {lag_count}') | ||||
|                             if not sequence: | ||||
|                                 # fully consumed | ||||
|                                 break | ||||
|                         print(f'parent + source stream lagged: {lag_count}') | ||||
| 
 | ||||
|                     if parent_delay > sub_delay: | ||||
|                         assert lag_count > 0 | ||||
|                         if parent_delay > sub_delay: | ||||
|                             assert lag_count > 0 | ||||
| 
 | ||||
|             except Lagged: | ||||
|                 # child was lagged | ||||
|                 assert parent_delay < sub_delay | ||||
|                 except Lagged: | ||||
|                     # child was lagged | ||||
|                     assert parent_delay < sub_delay | ||||
| 
 | ||||
|     trio.run(main) | ||||
| 
 | ||||
|  | @ -285,7 +290,11 @@ def test_faster_task_to_recv_is_cancelled_by_slower( | |||
| 
 | ||||
| 
 | ||||
| def test_subscribe_errors_after_close(): | ||||
|     ''' | ||||
|     Verify after calling `BroadcastReceiver.aclose()` you can't | ||||
|     "re-open" it via `.subscribe()`. | ||||
| 
 | ||||
|     ''' | ||||
|     async def main(): | ||||
| 
 | ||||
|         size = 1 | ||||
|  | @ -293,6 +302,8 @@ def test_subscribe_errors_after_close(): | |||
|         async with broadcast_receiver(rx, size) as brx: | ||||
|             pass | ||||
| 
 | ||||
|         assert brx.key not in brx._state.subs | ||||
| 
 | ||||
|         try: | ||||
|             # open and close | ||||
|             async with brx.subscribe(): | ||||
|  | @ -302,7 +313,7 @@ def test_subscribe_errors_after_close(): | |||
|             assert brx.key not in brx._state.subs | ||||
| 
 | ||||
|         else: | ||||
|             assert 0 | ||||
|             pytest.fail('brx.subscribe() never raised!?') | ||||
| 
 | ||||
|     trio.run(main) | ||||
| 
 | ||||
|  |  | |||
|  | @ -102,6 +102,9 @@ class MsgStream(trio.abc.Channel): | |||
|         self._eoc: bool|trio.EndOfChannel = False | ||||
|         self._closed: bool|trio.ClosedResourceError = False | ||||
| 
 | ||||
|     def is_eoc(self) -> bool|trio.EndOfChannel: | ||||
|         return self._eoc | ||||
| 
 | ||||
|     @property | ||||
|     def ctx(self) -> Context: | ||||
|         ''' | ||||
|  | @ -188,7 +191,14 @@ class MsgStream(trio.abc.Channel): | |||
| 
 | ||||
|         return pld | ||||
| 
 | ||||
|     async def receive( | ||||
|     # XXX NOTE, this is left private because in `.subscribe()` usage | ||||
|     # we rebind the public `.recieve()` to a `BroadcastReceiver` but | ||||
|     # on `.subscribe().__aexit__()`, for the first task which enters, | ||||
|     # we want to revert to this msg-stream-instance's method since | ||||
|     # mult-task-tracking provided by the b-caster is then no longer | ||||
|     # necessary. | ||||
|     # | ||||
|     async def _receive( | ||||
|         self, | ||||
|         hide_tb: bool = False, | ||||
|     ): | ||||
|  | @ -313,6 +323,8 @@ class MsgStream(trio.abc.Channel): | |||
| 
 | ||||
|             raise src_err | ||||
| 
 | ||||
|     receive = _receive | ||||
| 
 | ||||
|     async def aclose(self) -> list[Exception|dict]: | ||||
|         ''' | ||||
|         Cancel associated remote actor task and local memory channel on | ||||
|  | @ -528,10 +540,15 @@ class MsgStream(trio.abc.Channel): | |||
|         receiver wrapper. | ||||
| 
 | ||||
|         ''' | ||||
|         # NOTE: This operation is indempotent and non-reversible, so be | ||||
|         # sure you can deal with any (theoretical) overhead of the the | ||||
|         # allocated ``BroadcastReceiver`` before calling this method for | ||||
|         # the first time. | ||||
|         # XXX NOTE, This operation was originally implemented as | ||||
|         # indempotent and non-reversible, so you had to be **VERY** | ||||
|         # aware of any (theoretical) overhead from the allocated | ||||
|         # `BroadcastReceiver.receive()`. | ||||
|         # | ||||
|         # HOWEVER, NOw we do revert and de-alloc the ._broadcaster | ||||
|         # when the final caller (task) exits. | ||||
|         # | ||||
|         bcast: BroadcastReceiver|None = None | ||||
|         if self._broadcaster is None: | ||||
| 
 | ||||
|             bcast = self._broadcaster = broadcast_receiver( | ||||
|  | @ -541,29 +558,60 @@ class MsgStream(trio.abc.Channel): | |||
| 
 | ||||
|                 # TODO: can remove this kwarg right since | ||||
|                 # by default behaviour is to do this anyway? | ||||
|                 receive_afunc=self.receive, | ||||
|                 receive_afunc=self._receive, | ||||
|             ) | ||||
| 
 | ||||
|             # NOTE: we override the original stream instance's receive | ||||
|             # method to now delegate to the broadcaster's ``.receive()`` | ||||
|             # such that new subscribers will be copied received values | ||||
|             # and this stream doesn't have to expect it's original | ||||
|             # consumer(s) to get a new broadcast rx handle. | ||||
|             # XXX NOTE, we override the original stream instance's | ||||
|             # receive method to instead delegate to the broadcaster's | ||||
|             # `.receive()` such that new subscribers (multiple | ||||
|             # `trio.Task`s) will be copied received values and the | ||||
|             # *first* task to enter here doesn't have to expect its original consumer(s) | ||||
|             # to get a new broadcast rx handle; everything happens | ||||
|             # underneath this iface seemlessly. | ||||
|             # | ||||
|             self.receive = bcast.receive  # type: ignore | ||||
|             # seems there's no graceful way to type this with ``mypy``? | ||||
|             # seems there's no graceful way to type this with `mypy`? | ||||
|             # https://github.com/python/mypy/issues/708 | ||||
| 
 | ||||
|         async with self._broadcaster.subscribe() as bstream: | ||||
|             assert bstream.key != self._broadcaster.key | ||||
|             assert bstream._recv == self._broadcaster._recv | ||||
|         # TODO, prevent re-entrant sub scope? | ||||
|         # if self._broadcaster._closed: | ||||
|         #     raise RuntimeError( | ||||
|         #         'This stream | ||||
| 
 | ||||
|             # NOTE: we patch on a `.send()` to the bcaster so that the | ||||
|             # caller can still conduct 2-way streaming using this | ||||
|             # ``bstream`` handle transparently as though it was the msg | ||||
|             # stream instance. | ||||
|             bstream.send = self.send  # type: ignore | ||||
|         try: | ||||
|             aenter = self._broadcaster.subscribe() | ||||
|             async with aenter as bstream: | ||||
|                 # ?TODO, move into test suite? | ||||
|                 assert bstream.key != self._broadcaster.key | ||||
|                 assert bstream._recv == self._broadcaster._recv | ||||
| 
 | ||||
|             yield bstream | ||||
|                 # NOTE: we patch on a `.send()` to the bcaster so that the | ||||
|                 # caller can still conduct 2-way streaming using this | ||||
|                 # ``bstream`` handle transparently as though it was the msg | ||||
|                 # stream instance. | ||||
|                 bstream.send = self.send  # type: ignore | ||||
| 
 | ||||
|                 # newly-allocated instance | ||||
|                 yield bstream | ||||
| 
 | ||||
|         finally: | ||||
|             # XXX, the first-enterer task should, like all other | ||||
|             # subs, close the first allocated bcrx, which adjusts the | ||||
|             # common `bcrx.state` | ||||
|             with trio.CancelScope(shield=True): | ||||
|                 if bcast is not None: | ||||
|                     await bcast.aclose() | ||||
| 
 | ||||
|                 # XXX, when the bcrx.state reports there are no more subs | ||||
|                 # we can revert to this obj's method, removing any | ||||
|                 # delegation overhead! | ||||
|                 if ( | ||||
|                     (orig_bcast := self._broadcaster) | ||||
|                     and | ||||
|                     not orig_bcast.state.subs | ||||
|                 ): | ||||
|                     self.receive = self._receive | ||||
|                     # self._broadcaster = None | ||||
| 
 | ||||
|     async def send( | ||||
|         self, | ||||
|  |  | |||
|  | @ -100,6 +100,32 @@ class Lagged(trio.TooSlowError): | |||
|     ''' | ||||
| 
 | ||||
| 
 | ||||
| def wrap_rx_for_eoc( | ||||
|     rx: AsyncReceiver, | ||||
| ) -> AsyncReceiver: | ||||
| 
 | ||||
|     match rx: | ||||
|         case trio.MemoryReceiveChannel(): | ||||
| 
 | ||||
|             # XXX, taken verbatim from .receive_nowait() | ||||
|             def is_eoc() -> bool: | ||||
|                 if not rx._state.open_send_channels: | ||||
|                     return trio.EndOfChannel | ||||
| 
 | ||||
|                 return False | ||||
| 
 | ||||
|             rx.is_eoc = is_eoc | ||||
| 
 | ||||
|         case _: | ||||
|             # XXX, ensure we define a private field! | ||||
|             # case tractor.MsgStream: | ||||
|             assert ( | ||||
|                 getattr(rx, '_eoc', False) is not None | ||||
|             ) | ||||
| 
 | ||||
|     return rx | ||||
| 
 | ||||
| 
 | ||||
| class BroadcastState(Struct): | ||||
|     ''' | ||||
|     Common state to all receivers of a broadcast. | ||||
|  | @ -186,11 +212,23 @@ class BroadcastReceiver(ReceiveChannel): | |||
|         state.subs[self.key] = -1 | ||||
| 
 | ||||
|         # underlying for this receiver | ||||
|         self._rx = rx_chan | ||||
|         self._rx = wrap_rx_for_eoc(rx_chan) | ||||
|         self._recv = receive_afunc or rx_chan.receive | ||||
|         self._closed: bool = False | ||||
|         self._raise_on_lag = raise_on_lag | ||||
| 
 | ||||
|     @property | ||||
|     def state(self) -> BroadcastState: | ||||
|         ''' | ||||
|         Read-only access to this receivers internal `._state` | ||||
|         instance ref. | ||||
| 
 | ||||
|         If you just want to read the high-level state metrics, | ||||
|         use `.state.statistics()`. | ||||
| 
 | ||||
|         ''' | ||||
|         return self._state | ||||
| 
 | ||||
|     def receive_nowait( | ||||
|         self, | ||||
|         _key: int | None = None, | ||||
|  | @ -215,7 +253,23 @@ class BroadcastReceiver(ReceiveChannel): | |||
|         try: | ||||
|             seq = state.subs[key] | ||||
|         except KeyError: | ||||
|             # from tractor import pause_from_sync | ||||
|             # pause_from_sync(shield=True) | ||||
|             if ( | ||||
|                 (rx_eoc := self._rx.is_eoc()) | ||||
|                 or | ||||
|                 self.state.eoc | ||||
|             ): | ||||
|                 raise trio.EndOfChannel( | ||||
|                     'Broadcast-Rx underlying already ended!' | ||||
|                 ) from rx_eoc | ||||
| 
 | ||||
|             if self._closed: | ||||
|                 # if (rx_eoc := self._rx._eoc): | ||||
|                 #     raise trio.EndOfChannel( | ||||
|                 #         'Broadcast-Rx underlying already ended!' | ||||
|                 #     ) from rx_eoc | ||||
| 
 | ||||
|                 raise trio.ClosedResourceError | ||||
| 
 | ||||
|             raise RuntimeError( | ||||
|  | @ -453,8 +507,9 @@ class BroadcastReceiver(ReceiveChannel): | |||
|         self._closed = True | ||||
| 
 | ||||
| 
 | ||||
| # NOTE, this can we use as an `@acm` since `BroadcastReceiver` | ||||
| # derives from `ReceiveChannel`. | ||||
| def broadcast_receiver( | ||||
| 
 | ||||
|     recv_chan: AsyncReceiver, | ||||
|     max_buffer_size: int, | ||||
|     receive_afunc: Callable[[], Awaitable[Any]]|None = None, | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue