forked from goodboy/tractor
Obviously keying on tasks isn't going to work

Using the current task as a subscription key fails horribly as soon as
you hand off a new subscription receiver to another task you've
spawned. Instead use a clone of the underlying
``trio.abc.ReceiveChannel`` as the key (so I guess we're assuming
cloning is supported by the underlying?), which makes this all work
just like default mem chans. As a bonus, we can now just close the
underlying rx (which may be a clone) on `.aclose()` and everything
should just work in terms of the underlying channel's lifetime (I
think?).

Change `.subscribe()` to be async since the receive channel type
interface only expects `.aclose()`, and it actually ends up being nicer
for 3.9+ style parenthesized ``async with`` anyway.
parent eeca3d0d50
commit 3f9b860210
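The gist of the fix in the diff below, as a minimal standalone sketch (illustrative names only, not this module's API): key the subscriber registry on ``.clone()``s of the underlying memory receive channel rather than on tasks, so a receiver can be registered in one task and handed off to another one you spawn.

import trio

async def main():
    tx, rx = trio.open_memory_channel(8)

    # registry keyed on channel clones instead of tasks: handing a
    # clone to a freshly spawned task doesn't invalidate its entry
    subs: dict[trio.MemoryReceiveChannel, int] = {}

    async def consumer(clone: trio.MemoryReceiveChannel):
        # the clone is a valid key no matter which task runs this code
        assert clone in subs
        async with clone:
            async for value in clone:
                print(f'got {value}')

    clone = rx.clone()
    subs[clone] = -1  # -1 meaning "caught up", as in the diff below
    async with trio.open_nursery() as n:
        n.start_soon(consumer, clone)
        async with tx:
            for i in range(3):
                await tx.send(i)

trio.run(main)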
@@ -6,7 +6,7 @@ https://tokio-rs.github.io/tokio/doc/tokio/sync/broadcast/index.html
 from __future__ import annotations
 from itertools import cycle
 from collections import deque
-from contextlib import contextmanager
+from contextlib import asynccontextmanager
 from functools import partial
 from typing import Optional
 
@@ -15,9 +15,6 @@ import tractor
 from trio.lowlevel import current_task
 from trio.abc import ReceiveChannel
 from trio._core._run import Task
-# from trio._channel import (
-#     MemoryReceiveChannel,
-# )
 
 
 class Lagged(trio.TooSlowError):
@@ -29,57 +26,71 @@ class BroadcastReceiver(ReceiveChannel):
     fastest consumer.
 
     Additional consumer tasks can receive all produced values by registering
-    with ``.subscribe()``.
+    with ``.subscribe()`` and receiving from the new instance it delivers.
 
     '''
+    # map of underlying channel clones to their last-read sequence number
+    _subs: dict[trio.ReceiveChannel, int] = {}
+
     def __init__(
         self,
 
-        rx_chan: MemoryReceiveChannel,
+        rx_chan: ReceiveChannel,
         queue: deque,
 
     ) -> None:
 
         self._rx = rx_chan
         self._queue = queue
-        self._subs: dict[Task, int] = {}  # {id(current_task()): -1}
-        self._clones: dict[Task, ReceiveChannel] = {}
         self._value_received: Optional[trio.Event] = None
 
     async def receive(self):
 
-        task: Task = current_task()
+        key = self._rx
+
+        # TODO: ideally we can make some way to "lock out" the
+        # underlying receive channel in some way such that if some task
+        # tries to pull from it directly (i.e. one we're unaware of)
+        # then it errors out.
+
+        # only tasks which have entered ``.subscribe()`` can
+        # receive on this broadcaster.
+        try:
+            seq = self._subs[key]
+        except KeyError:
+            raise RuntimeError(
+                f'{self} is not registered as a subscriber')
 
         # check that task does not already have a value it can receive
         # immediately and/or that it has lagged.
-        try:
-            seq = self._subs[task]
-        except KeyError:
-            raise RuntimeError(
-                f'Task {task.name} is not registerd as subscriber')
-
         if seq > -1:
             # get the oldest value we haven't received immediately
             try:
                 value = self._queue[seq]
             except IndexError:
+
+                # adhere to ``tokio`` style "lagging":
+                # "Once RecvError::Lagged is returned, the lagging
+                # receiver's position is updated to the oldest value
+                # contained by the channel. The next call to recv will
+                # return this value."
+                # https://tokio-rs.github.io/tokio/doc/tokio/sync/broadcast/index.html#lagging
+
                 # decrement to the last value and expect
                 # consumer to either handle the ``Lagged`` and come back
-                # or bail out on it's own (thus un-subscribing)
-                self._subs[task] = self._queue.maxlen - 1
+                # or bail out on its own (thus un-subscribing)
+                self._subs[key] = self._queue.maxlen - 1
 
                 # this task was overrun by the producer side
+                task: Task = current_task()
                 raise Lagged(f'Task {task.name} was overrun')
 
-            self._subs[task] -= 1
+            self._subs[key] -= 1
             return value
 
-        if self._value_received is None:
         # current task already has the latest value **and** is the
         # first task to begin waiting for a new one
 
-            # what sanity checks might we use for the underlying chan ?
-            # assert not self._rx._state.data
+        if self._value_received is None:
 
             event = self._value_received = trio.Event()
             value = await self._rx.receive()
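The ``tokio``-style lagging rule quoted in this hunk, reduced to just the sequence arithmetic as a toy model (not this module's API; the enqueue step falls outside the diff, but ``return self._queue[0]`` implies newest-first ``appendleft`` storage):

from collections import deque

maxlen = 4
queue: deque = deque(maxlen=maxlen)
seq = -1                     # -1 means this subscriber is caught up

def produce(value):
    global seq
    queue.appendleft(value)  # newest value sits at index 0
    seq += 1                 # one more unread value for this lagging sub

for v in range(6):           # 6 values into a 4-slot history
    produce(v)

try:
    queue[seq]               # seq == 5, but only 4 slots are retained
except IndexError:
    seq = maxlen - 1         # reset to the oldest value still in the queue

assert queue[seq] == 2       # values 0 and 1 were silently overwritten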
@@ -92,9 +103,9 @@ class BroadcastReceiver(ReceiveChannel):
             # their latest available value.
 
             subs = self._subs.copy()
-            # don't decerement the sequence # for this task since we
-            # already retreived the last value
-            subs.pop(task)
+            # don't decrement the sequence # for this task since we
+            # already retrieved the last value
+            subs.pop(key)
             for sub_key, seq in subs.items():
                 self._subs[sub_key] += 1
 
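A worked trace of the fan-out bookkeeping in this hunk, under the same toy-model assumptions as the sketch above: the first waiter pulls a fresh value from the underlying channel, then bumps every *other* registered key so their next ``receive()`` finds that value at ``queue[seq]``.

from collections import deque

queue: deque = deque(maxlen=4)
subs = {'A': -1, 'B': -1}   # both caught up; strings stand in for clone keys

# 'A' is the first waiter: it pulls from the underlying channel...
queue.appendleft('hello')

# ...then bumps everyone else, mirroring the loop in this hunk
others = subs.copy()
others.pop('A')             # 'A' already holds the new value
for key in others:
    subs[key] += 1

assert subs == {'A': -1, 'B': 0}
assert queue[subs['B']] == 'hello'  # B's next receive() reads queue[0]
subs['B'] -= 1                      # ...after which B is caught up again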
@@ -103,37 +114,56 @@ class BroadcastReceiver(ReceiveChannel):
             event.set()
             return value
 
+        # This task is all caught up and ready to receive the latest
+        # value, so wait on the internal event.
         else:
             await self._value_received.wait()
 
-            seq = self._subs[task]
+            seq = self._subs[key]
             assert seq > -1, 'Internal error?'
 
-            self._subs[task] -= 1
+            self._subs[key] -= 1
             return self._queue[0]
 
-    # @asynccontextmanager
-    @contextmanager
-    def subscribe(
+    @asynccontextmanager
+    async def subscribe(
         self,
     ) -> BroadcastReceiver:
-        task: task = current_task()
-        self._subs[task] = -1
-        # XXX: we only use this clone for closure tracking
-        clone = self._clones[task] = self._rx.clone()
-        try:
-            yield self
-        finally:
-            self._subs.pop(task)
-            clone.close()
+        '''Subscribe for values from this broadcast receiver.
 
-    # TODO: do we need anything here?
-    # if we're the last sub to close then close
-    # the underlying rx channel, but couldn't we just
-    # use ``.clone()``s trackign then?
-    async def aclose(self) -> None:
-        task: Task = current_task()
-        await self._clones[task].aclose()
+        Returns a new ``BroadcastReceiver`` which is registered for and
+        pulls data from a clone of the original ``trio.abc.ReceiveChannel``
+        provided at creation.
+
+        '''
+        clone = self._rx.clone()
+        self._subs[clone] = -1
+        try:
+            yield BroadcastReceiver(
+                clone,
+                self._queue,
+            )
+        finally:
+            # drop from subscribers and close
+            self._subs.pop(clone)
+            # XXX: this is the reason this function is async: the
+            # ``AsyncResource`` api.
+            await clone.aclose()
+
+    # TODO:
+    # - should there be some ._closed flag that causes
+    #   consumers to die **before** they read all queued values?
+    # - if subs only open and close clones then the underlying
+    #   will never be killed until the last instance closes?
+    #   This is correct right?
+    async def aclose(
+        self,
+    ) -> None:
+        # XXX: leaving it like this, consumers can still get values
+        # up to the last received that still reside in the queue.
+        # Is this what we want?
+        await self._rx.aclose()
+        self._subs.pop(self._rx, None)  # the root rx is never registered
 
 
 def broadcast_receiver(
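``broadcast_receiver()`` is what the demo below calls, but its body sits outside this diff's hunks; presumably it just pairs the channel with a bounded, newest-first history. A hedged reconstruction reusing the module's imports (``max_buffer_size`` is a guessed parameter name):

def broadcast_receiver(
    recv_chan: ReceiveChannel,
    max_buffer_size: int,
) -> BroadcastReceiver:
    # wrap the channel with a deque bounded to the requested history
    # size; ``maxlen`` is what makes old values fall off on overrun
    return BroadcastReceiver(
        recv_chan,
        queue=deque(maxlen=max_buffer_size),
    )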
@@ -158,6 +188,7 @@ if __name__ == '__main__':
             # loglevel='info',
         ):
 
+            retries = 3
             size = 100
             tx, rx = trio.open_memory_channel(size)
             rx = broadcast_receiver(rx, size)
@@ -170,9 +201,9 @@ if __name__ == '__main__':
                 count = 0
 
                 while True:
-                    with rx.subscribe():
+                    async with rx.subscribe() as brx:
                         try:
-                            async for value in rx:
+                            async for value in brx:
                                 print(f'{task.name}: {value}')
                                 await trio.sleep(delay)
                                 count += 1
@@ -181,10 +212,16 @@ if __name__ == '__main__':
                             print(
                                 f'restarting slow ass {task.name}'
                                 f'that bailed out on {count}:{value}')
-                            continue
+                            if count <= retries:
+                                continue
+                            else:
+                                print(
+                                    f'{task.name} was too slow and terminated '
+                                    f'on {count}:{value}')
+                                return
 
             async with trio.open_nursery() as n:
-                for i in range(1, 10):
+                for i in range(1, size):
                     n.start_soon(
                         partial(
                             sub_and_print,
@@ -194,7 +231,7 @@ if __name__ == '__main__':
                     )
 
                 async with tx:
-                    for i in cycle(range(1000)):
+                    for i in cycle(range(size)):
                         print(f'sending: {i}')
                         await tx.send(i)