forked from goodboy/tractor
				
			Add common state delegate type for all consumers
For every set of broadcast receivers which pull from the same producer, we need a singleton state for all of: the subscriptions, the sender-ready event, and the queue. Add a `BroadcastState` dataclass for this and pass it to all subscriptions. This makes the design much more like the built-in memory channels, which do something very similar with `MemoryChannelState`. Use a `filter()` on the subs list in the sequence update step, plus some other commented approaches we can try for speed.
tokio_backup
							parent
							
								
									9d12cc80dd
								
							
						
					
					
						commit
						b9863fc4ab
					
				|  | @ -4,23 +4,42 @@ https://tokio-rs.github.io/tokio/doc/tokio/sync/broadcast/index.html | ||||||
| 
 | 
 | ||||||
| ''' | ''' | ||||||
| from __future__ import annotations | from __future__ import annotations | ||||||
| from itertools import cycle |  | ||||||
| from collections import deque | from collections import deque | ||||||
| from contextlib import asynccontextmanager | from contextlib import asynccontextmanager | ||||||
|  | from dataclasses import dataclass | ||||||
| from functools import partial | from functools import partial | ||||||
|  | from itertools import cycle | ||||||
|  | from operator import ne | ||||||
| from typing import Optional | from typing import Optional | ||||||
| 
 | 
 | ||||||
| import trio | import trio | ||||||
| import tractor |  | ||||||
| from trio.lowlevel import current_task |  | ||||||
| from trio.abc import ReceiveChannel |  | ||||||
| from trio._core._run import Task | from trio._core._run import Task | ||||||
|  | from trio.abc import ReceiveChannel | ||||||
|  | from trio.lowlevel import current_task | ||||||
|  | import tractor | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class Lagged(trio.TooSlowError): | class Lagged(trio.TooSlowError): | ||||||
|     '''Subscribed consumer task was too slow''' |     '''Subscribed consumer task was too slow''' | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | @dataclass | ||||||
|  | class BroadcastState: | ||||||
|  |     '''Common state to all receivers of a broadcast. | ||||||
|  | 
 | ||||||
|  |     ''' | ||||||
|  |     queue: deque | ||||||
|  | 
 | ||||||
|  |     # map of underlying clones to receiver wrappers | ||||||
|  |     # which must be provided as a singleton per broadcaster | ||||||
|  |     # clone-subscription set. | ||||||
|  |     subs: dict[trio.ReceiveChannel, BroadcastReceiver] | ||||||
|  | 
 | ||||||
|  |     # broadcast event to wakeup all sleeping consumer tasks | ||||||
|  |     # on a newly produced value from the sender. | ||||||
|  |     sender_ready: Optional[trio.Event] = None | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
| class BroadcastReceiver(ReceiveChannel): | class BroadcastReceiver(ReceiveChannel): | ||||||
|     '''A memory receive channel broadcaster which is non-lossy for the |     '''A memory receive channel broadcaster which is non-lossy for the | ||||||
|     fastest consumer. |     fastest consumer. | ||||||
|  | @ -33,28 +52,21 @@ class BroadcastReceiver(ReceiveChannel): | ||||||
|         self, |         self, | ||||||
| 
 | 
 | ||||||
|         rx_chan: ReceiveChannel, |         rx_chan: ReceiveChannel, | ||||||
|         queue: deque, |         state: BroadcastState, | ||||||
|         _subs: dict[trio.ReceiveChannel, BroadcastReceiver], |  | ||||||
| 
 | 
 | ||||||
|     ) -> None: |     ) -> None: | ||||||
| 
 | 
 | ||||||
|         # map of underlying clones to receiver wrappers |         # register the original underlying (clone) | ||||||
|         # which must be provided as a singleton per broadcaster |         self._state = state | ||||||
|         # clone-subscription set. |         state.subs[rx_chan] = -1 | ||||||
|         self._subs = _subs |  | ||||||
| 
 | 
 | ||||||
|         # underlying for this receiver |         # underlying for this receiver | ||||||
|         self._rx = rx_chan |         self._rx = rx_chan | ||||||
| 
 | 
 | ||||||
|         # register the original underlying (clone) |  | ||||||
|         self._subs[rx_chan] = -1 |  | ||||||
| 
 |  | ||||||
|         self._queue = queue |  | ||||||
|         self._value_received: Optional[trio.Event] = None |  | ||||||
| 
 |  | ||||||
|     async def receive(self): |     async def receive(self): | ||||||
| 
 | 
 | ||||||
|         key = self._rx |         key = self._rx | ||||||
|  |         state = self._state | ||||||
| 
 | 
 | ||||||
|         # TODO: ideally we can make some way to "lock out" the |         # TODO: ideally we can make some way to "lock out" the | ||||||
|         # underlying receive channel in some way such that if some task |         # underlying receive channel in some way such that if some task | ||||||
|  | @ -64,7 +76,7 @@ class BroadcastReceiver(ReceiveChannel): | ||||||
|         # only tasks which have entered ``.subscribe()`` can |         # only tasks which have entered ``.subscribe()`` can | ||||||
|         # receive on this broadcaster. |         # receive on this broadcaster. | ||||||
|         try: |         try: | ||||||
|             seq = self._subs[key] |             seq = state.subs[key] | ||||||
|         except KeyError: |         except KeyError: | ||||||
|             raise RuntimeError( |             raise RuntimeError( | ||||||
|                 f'{self} is not registerd as subscriber') |                 f'{self} is not registerd as subscriber') | ||||||
|  | @ -74,7 +86,7 @@ class BroadcastReceiver(ReceiveChannel): | ||||||
|         if seq > -1: |         if seq > -1: | ||||||
|             # get the oldest value we haven't received immediately |             # get the oldest value we haven't received immediately | ||||||
|             try: |             try: | ||||||
|                 value = self._queue[seq] |                 value = state.queue[seq] | ||||||
|             except IndexError: |             except IndexError: | ||||||
| 
 | 
 | ||||||
|                 # adhere to ``tokio`` style "lagging": |                 # adhere to ``tokio`` style "lagging": | ||||||
|  | @ -87,51 +99,61 @@ class BroadcastReceiver(ReceiveChannel): | ||||||
|                 # decrement to the last value and expect |                 # decrement to the last value and expect | ||||||
|                 # consumer to either handle the ``Lagged`` and come back |                 # consumer to either handle the ``Lagged`` and come back | ||||||
|                 # or bail out on its own (thus un-subscribing) |                 # or bail out on its own (thus un-subscribing) | ||||||
|                 self._subs[key] = self._queue.maxlen - 1 |                 state.subs[key] = state.queue.maxlen - 1 | ||||||
| 
 | 
 | ||||||
|                 # this task was overrun by the producer side |                 # this task was overrun by the producer side | ||||||
|                 task: Task = current_task() |                 task: Task = current_task() | ||||||
|                 raise Lagged(f'Task {task.name} was overrun') |                 raise Lagged(f'Task {task.name} was overrun') | ||||||
| 
 | 
 | ||||||
|             self._subs[key] -= 1 |             state.subs[key] -= 1 | ||||||
|             return value |             return value | ||||||
| 
 | 
 | ||||||
|         # current task already has the latest value **and** is the |         # current task already has the latest value **and** is the | ||||||
|         # first task to begin waiting for a new one |         # first task to begin waiting for a new one | ||||||
|         if self._value_received is None: |         if state.sender_ready is None: | ||||||
| 
 | 
 | ||||||
|             event = self._value_received = trio.Event() |             event = state.sender_ready = trio.Event() | ||||||
|             value = await self._rx.receive() |             value = await self._rx.receive() | ||||||
| 
 | 
 | ||||||
|             # items with lower indices are "newer" |             # items with lower indices are "newer" | ||||||
|             self._queue.appendleft(value) |             state.queue.appendleft(value) | ||||||
| 
 | 
 | ||||||
|             # broadcast new value to all subscribers by increasing |             # broadcast new value to all subscribers by increasing | ||||||
|             # all sequence numbers that will point in the queue to |             # all sequence numbers that will point in the queue to | ||||||
|             # their latest available value. |             # their latest available value. | ||||||
| 
 | 
 | ||||||
|             subs = self._subs.copy() |             # don't decrement the sequence for this task since we | ||||||
|             # don't decrement the sequence # for this task since we |  | ||||||
|             # already retreived the last value |             # already retreived the last value | ||||||
|             subs.pop(key) | 
 | ||||||
|             for sub_key, seq in subs.items(): |             # XXX: which of these impls is fastest? | ||||||
|                 self._subs[sub_key] += 1 | 
 | ||||||
|  |             # subs = state.subs.copy() | ||||||
|  |             # subs.pop(key) | ||||||
|  | 
 | ||||||
|  |             for sub_key in filter( | ||||||
|  |                 # lambda k: k != key, state.subs, | ||||||
|  |                 partial(ne, key), state.subs, | ||||||
|  |             ): | ||||||
|  |                 state.subs[sub_key] += 1 | ||||||
| 
 | 
 | ||||||
|             # reset receiver waiter task event for next blocking condition |             # reset receiver waiter task event for next blocking condition | ||||||
|             self._value_received = None |  | ||||||
|             event.set() |             event.set() | ||||||
|  |             state.sender_ready = None | ||||||
|             return value |             return value | ||||||
| 
 | 
 | ||||||
|         # This task is all caught up and ready to receive the latest |         # This task is all caught up and ready to receive the latest | ||||||
|         # value, so queue sched it on the internal event. |         # value, so queue sched it on the internal event. | ||||||
|         else: |         else: | ||||||
|             await self._value_received.wait() |             await state.sender_ready.wait() | ||||||
| 
 | 
 | ||||||
|             seq = self._subs[key] |             # TODO: optimization: if this is always true can't we just | ||||||
|             assert seq > -1, 'Internal error?' |             # skip iterating these sequence numbers on the fastest | ||||||
|  |             # task's wakeup and always read from state.queue[0]? | ||||||
|  |             seq = state.subs[key] | ||||||
|  |             assert seq == 0, 'Internal error?' | ||||||
| 
 | 
 | ||||||
|             self._subs[key] -= 1 |             state.subs[key] -= 1 | ||||||
|             return self._queue[0] |             return state.queue[seq] | ||||||
| 
 | 
 | ||||||
|     @asynccontextmanager |     @asynccontextmanager | ||||||
|     async def subscribe( |     async def subscribe( | ||||||
|  | @ -145,12 +167,12 @@ class BroadcastReceiver(ReceiveChannel): | ||||||
| 
 | 
 | ||||||
|         ''' |         ''' | ||||||
|         clone = self._rx.clone() |         clone = self._rx.clone() | ||||||
|  |         state = self._state | ||||||
|         br = BroadcastReceiver( |         br = BroadcastReceiver( | ||||||
|             clone, |             rx_chan=clone, | ||||||
|             self._queue, |             state=state, | ||||||
|             _subs=self._subs, |  | ||||||
|         ) |         ) | ||||||
|         assert clone in self._subs |         assert clone in state.subs | ||||||
| 
 | 
 | ||||||
|         try: |         try: | ||||||
|             yield br |             yield br | ||||||
|  | @ -159,7 +181,7 @@ class BroadcastReceiver(ReceiveChannel): | ||||||
|             # ``AsyncResource`` api. |             # ``AsyncResource`` api. | ||||||
|             await clone.aclose() |             await clone.aclose() | ||||||
|             # drop from subscribers and close |             # drop from subscribers and close | ||||||
|             self._subs.pop(clone) |             state.subs.pop(clone) | ||||||
| 
 | 
 | ||||||
|     # TODO: |     # TODO: | ||||||
|     # - should there be some ._closed flag that causes |     # - should there be some ._closed flag that causes | ||||||
|  | @ -186,8 +208,10 @@ def broadcast_receiver( | ||||||
| 
 | 
 | ||||||
|     return BroadcastReceiver( |     return BroadcastReceiver( | ||||||
|         recv_chan, |         recv_chan, | ||||||
|  |         state=BroadcastState( | ||||||
|             queue=deque(maxlen=max_buffer_size), |             queue=deque(maxlen=max_buffer_size), | ||||||
|         _subs={},  # this is singleton over all subscriptions |             subs={}, | ||||||
|  |         ), | ||||||
|     ) |     ) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | @ -210,7 +234,7 @@ if __name__ == '__main__': | ||||||
|             ) -> None: |             ) -> None: | ||||||
| 
 | 
 | ||||||
|                 task = current_task() |                 task = current_task() | ||||||
|                 count = 0 |                 lags = 0 | ||||||
| 
 | 
 | ||||||
|                 while True: |                 while True: | ||||||
|                     async with rx.subscribe() as brx: |                     async with rx.subscribe() as brx: | ||||||
|  | @ -218,22 +242,22 @@ if __name__ == '__main__': | ||||||
|                             async for value in brx: |                             async for value in brx: | ||||||
|                                 print(f'{task.name}: {value}') |                                 print(f'{task.name}: {value}') | ||||||
|                                 await trio.sleep(delay) |                                 await trio.sleep(delay) | ||||||
|                                 count += 1 |  | ||||||
| 
 | 
 | ||||||
|                         except Lagged: |                         except Lagged: | ||||||
|                             print( |                             print( | ||||||
|                                 f'restarting slow ass {task.name}' |                                 f'restarting slow ass {task.name}' | ||||||
|                                 f'that bailed out on {count}:{value}') |                                 f'that bailed out on {lags}:{value}') | ||||||
|                             if count <= retries: |                             if lags <= retries: | ||||||
|  |                                 lags += 1 | ||||||
|                                 continue |                                 continue | ||||||
|                             else: |                             else: | ||||||
|                                 print( |                                 print( | ||||||
|                                     f'{task.name} was too slow and terminated ' |                                     f'{task.name} was too slow and terminated ' | ||||||
|                                     f'on {count}:{value}') |                                     f'on {lags}:{value}') | ||||||
|                                 return |                                 return | ||||||
| 
 | 
 | ||||||
|             async with trio.open_nursery() as n: |             async with trio.open_nursery() as n: | ||||||
|                 for i in range(1, size): |                 for i in range(1, 10): | ||||||
|                     n.start_soon( |                     n.start_soon( | ||||||
|                         partial( |                         partial( | ||||||
|                             sub_and_print, |                             sub_and_print, | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue