forked from goodboy/tractor
				
			Store handle to underlying channel's `.receive()`
This allows for wrapping an existing stream by re-assigning its receive method to the allocated broadcaster's `.receive()` so as to avoid expecting any original consumer(s) of the stream to now know about the broadcaster; this instead mutates the stream to delegate to the new receive call behind the scenes any time `.subscribe()` is called. Add a `typing.Protocol` for so called "cloneable channels" until we decide/figure out a better keying system for each subscription and mask all undesired typing failures.tokio_backup
							parent
							
								
									eaa761b0c7
								
							
						
					
					
						commit
						43820e194e
					
				|  | @ -4,13 +4,15 @@ https://tokio-rs.github.io/tokio/doc/tokio/sync/broadcast/index.html | ||||||
| 
 | 
 | ||||||
| ''' | ''' | ||||||
| from __future__ import annotations | from __future__ import annotations | ||||||
|  | from abc import abstractmethod | ||||||
| from collections import deque | from collections import deque | ||||||
| from contextlib import asynccontextmanager | from contextlib import asynccontextmanager | ||||||
| from dataclasses import dataclass | from dataclasses import dataclass | ||||||
| from functools import partial | from functools import partial | ||||||
| from itertools import cycle | from itertools import cycle | ||||||
| from operator import ne | from operator import ne | ||||||
| from typing import Optional | from typing import Optional, Callable, Awaitable, Any, AsyncIterator, Protocol | ||||||
|  | from typing import Generic, TypeVar | ||||||
| 
 | 
 | ||||||
| import trio | import trio | ||||||
| from trio._core._run import Task | from trio._core._run import Task | ||||||
|  | @ -19,6 +21,49 @@ from trio.lowlevel import current_task | ||||||
| import tractor | import tractor | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | # A regular invariant generic type | ||||||
|  | T = TypeVar("T") | ||||||
|  | 
 | ||||||
|  | # The type of object produced by a ReceiveChannel (covariant because | ||||||
|  | # ReceiveChannel[Derived] can be passed to someone expecting | ||||||
|  | # ReceiveChannel[Base]) | ||||||
|  | ReceiveType = TypeVar("ReceiveType", covariant=True) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class CloneableReceiveChannel( | ||||||
|  |     Protocol, | ||||||
|  |     Generic[ReceiveType], | ||||||
|  | ): | ||||||
|  |     @abstractmethod | ||||||
|  |     def clone(self) -> CloneableReceiveChannel[ReceiveType]: | ||||||
|  |         '''Clone this receiver usually by making a copy.''' | ||||||
|  | 
 | ||||||
|  |     @abstractmethod | ||||||
|  |     async def receive(self) -> ReceiveType: | ||||||
|  |         '''Same as in ``trio``.''' | ||||||
|  | 
 | ||||||
|  |     @abstractmethod | ||||||
|  |     def __aiter__(self) -> AsyncIterator[ReceiveType]: | ||||||
|  |         ... | ||||||
|  | 
 | ||||||
|  |     @abstractmethod | ||||||
|  |     async def __anext__(self) -> ReceiveType: | ||||||
|  |         ... | ||||||
|  | 
 | ||||||
|  |     # ``trio.abc.AsyncResource`` methods | ||||||
|  |     @abstractmethod | ||||||
|  |     async def aclose(self): | ||||||
|  |         ... | ||||||
|  | 
 | ||||||
|  |     @abstractmethod | ||||||
|  |     async def __aenter__(self) -> CloneableReceiveChannel[ReceiveType]: | ||||||
|  |         ... | ||||||
|  | 
 | ||||||
|  |     @abstractmethod | ||||||
|  |     async def __aexit__(self, *args) -> None: | ||||||
|  |         ... | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
| class Lagged(trio.TooSlowError): | class Lagged(trio.TooSlowError): | ||||||
|     '''Subscribed consumer task was too slow''' |     '''Subscribed consumer task was too slow''' | ||||||
| 
 | 
 | ||||||
|  | @ -33,7 +78,7 @@ class BroadcastState: | ||||||
|     # map of underlying clones to receiver wrappers |     # map of underlying clones to receiver wrappers | ||||||
|     # which must be provided as a singleton per broadcaster |     # which must be provided as a singleton per broadcaster | ||||||
|     # clone-subscription set. |     # clone-subscription set. | ||||||
|     subs: dict[trio.ReceiveChannel, BroadcastReceiver] |     subs: dict[CloneableReceiveChannel, int] | ||||||
| 
 | 
 | ||||||
|     # broadcast event to wakeup all sleeping consumer tasks |     # broadcast event to wakeup all sleeping consumer tasks | ||||||
|     # on a newly produced value from the sender. |     # on a newly produced value from the sender. | ||||||
|  | @ -51,8 +96,9 @@ class BroadcastReceiver(ReceiveChannel): | ||||||
|     def __init__( |     def __init__( | ||||||
|         self, |         self, | ||||||
| 
 | 
 | ||||||
|         rx_chan: ReceiveChannel, |         rx_chan: CloneableReceiveChannel, | ||||||
|         state: BroadcastState, |         state: BroadcastState, | ||||||
|  |         receive_afunc: Optional[Callable[[], Awaitable[Any]]] = None, | ||||||
| 
 | 
 | ||||||
|     ) -> None: |     ) -> None: | ||||||
| 
 | 
 | ||||||
|  | @ -62,6 +108,7 @@ class BroadcastReceiver(ReceiveChannel): | ||||||
| 
 | 
 | ||||||
|         # underlying for this receiver |         # underlying for this receiver | ||||||
|         self._rx = rx_chan |         self._rx = rx_chan | ||||||
|  |         self._recv = receive_afunc or rx_chan.receive | ||||||
| 
 | 
 | ||||||
|     async def receive(self): |     async def receive(self): | ||||||
| 
 | 
 | ||||||
|  | @ -113,7 +160,7 @@ class BroadcastReceiver(ReceiveChannel): | ||||||
|         if state.sender_ready is None: |         if state.sender_ready is None: | ||||||
| 
 | 
 | ||||||
|             event = state.sender_ready = trio.Event() |             event = state.sender_ready = trio.Event() | ||||||
|             value = await self._rx.receive() |             value = await self._recv() | ||||||
| 
 | 
 | ||||||
|             # items with lower indices are "newer" |             # items with lower indices are "newer" | ||||||
|             state.queue.appendleft(value) |             state.queue.appendleft(value) | ||||||
|  | @ -152,7 +199,7 @@ class BroadcastReceiver(ReceiveChannel): | ||||||
|     @asynccontextmanager |     @asynccontextmanager | ||||||
|     async def subscribe( |     async def subscribe( | ||||||
|         self, |         self, | ||||||
|     ) -> BroadcastReceiver: |     ) -> AsyncIterator[BroadcastReceiver]: | ||||||
|         '''Subscribe for values from this broadcast receiver. |         '''Subscribe for values from this broadcast receiver. | ||||||
| 
 | 
 | ||||||
|         Returns a new ``BroadCastReceiver`` which is registered for and |         Returns a new ``BroadCastReceiver`` which is registered for and | ||||||
|  | @ -160,6 +207,8 @@ class BroadcastReceiver(ReceiveChannel): | ||||||
|         provided at creation. |         provided at creation. | ||||||
| 
 | 
 | ||||||
|         ''' |         ''' | ||||||
|  |         # if we didn't want to enforce "clone-ability" how would | ||||||
|  |         # we key arbitrary subscriptions? Use a token system? | ||||||
|         clone = self._rx.clone() |         clone = self._rx.clone() | ||||||
|         state = self._state |         state = self._state | ||||||
|         br = BroadcastReceiver( |         br = BroadcastReceiver( | ||||||
|  | @ -190,13 +239,14 @@ class BroadcastReceiver(ReceiveChannel): | ||||||
|         # up to the last received that still reside in the queue. |         # up to the last received that still reside in the queue. | ||||||
|         # Is this what we want? |         # Is this what we want? | ||||||
|         await self._rx.aclose() |         await self._rx.aclose() | ||||||
|         self._subs.pop(self._rx) |         self._state.subs.pop(self._rx) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def broadcast_receiver( | def broadcast_receiver( | ||||||
| 
 | 
 | ||||||
|     recv_chan: ReceiveChannel, |     recv_chan: CloneableReceiveChannel, | ||||||
|     max_buffer_size: int, |     max_buffer_size: int, | ||||||
|  |     **kwargs, | ||||||
| 
 | 
 | ||||||
| ) -> BroadcastReceiver: | ) -> BroadcastReceiver: | ||||||
| 
 | 
 | ||||||
|  | @ -206,6 +256,7 @@ def broadcast_receiver( | ||||||
|             queue=deque(maxlen=max_buffer_size), |             queue=deque(maxlen=max_buffer_size), | ||||||
|             subs={}, |             subs={}, | ||||||
|         ), |         ), | ||||||
|  |         **kwargs, | ||||||
|     ) |     ) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -9,6 +9,7 @@ from dataclasses import dataclass | ||||||
| from typing import ( | from typing import ( | ||||||
|     Any, Iterator, Optional, Callable, |     Any, Iterator, Optional, Callable, | ||||||
|     AsyncGenerator, Dict, |     AsyncGenerator, Dict, | ||||||
|  |     AsyncIterator, Awaitable | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| import warnings | import warnings | ||||||
|  | @ -47,7 +48,7 @@ class ReceiveMsgStream(trio.abc.ReceiveChannel): | ||||||
|     def __init__( |     def __init__( | ||||||
|         self, |         self, | ||||||
|         ctx: 'Context',  # typing: ignore # noqa |         ctx: 'Context',  # typing: ignore # noqa | ||||||
|         rx_chan: trio.abc.ReceiveChannel, |         rx_chan: trio.MemoryReceiveChannel, | ||||||
|     ) -> None: |     ) -> None: | ||||||
|         self._ctx = ctx |         self._ctx = ctx | ||||||
|         self._rx_chan = rx_chan |         self._rx_chan = rx_chan | ||||||
|  | @ -246,7 +247,7 @@ class ReceiveMsgStream(trio.abc.ReceiveChannel): | ||||||
|     async def subscribe( |     async def subscribe( | ||||||
|         self, |         self, | ||||||
| 
 | 
 | ||||||
|     ) -> BroadcastReceiver: |     ) -> AsyncIterator[BroadcastReceiver]: | ||||||
|         '''Allocate and return a ``BroadcastReceiver`` which delegates |         '''Allocate and return a ``BroadcastReceiver`` which delegates | ||||||
|         to this message stream. |         to this message stream. | ||||||
| 
 | 
 | ||||||
|  | @ -259,21 +260,24 @@ class ReceiveMsgStream(trio.abc.ReceiveChannel): | ||||||
|         receiver wrapper. |         receiver wrapper. | ||||||
| 
 | 
 | ||||||
|         ''' |         ''' | ||||||
|  |         # NOTE: This operation is indempotent and non-reversible, so be | ||||||
|  |         # sure you can deal with any (theoretical) overhead of the the | ||||||
|  |         # allocated ``BroadcastReceiver`` before calling this method for | ||||||
|  |         # the first time. | ||||||
|         if self._broadcaster is None: |         if self._broadcaster is None: | ||||||
|             self._broadcaster = broadcast_receiver( |             self._broadcaster = broadcast_receiver( | ||||||
|                 self, |                 self, | ||||||
|                 self._rx_chan._state.max_buffer_size, |                 self._rx_chan._state.max_buffer_size,  # type: ignore | ||||||
|             ) |             ) | ||||||
|             # override the original stream instance's receive to |  | ||||||
|             # delegate to the broadcaster receive such that |  | ||||||
|             # new subscribers will be copied received values |  | ||||||
|             # XXX: this operation is indempotent and non-reversible, |  | ||||||
|             # so be sure you can deal with any (theoretical) overhead |  | ||||||
|             # of the the ``BroadcastReceiver`` before calling |  | ||||||
|             # this method for the first time. |  | ||||||
| 
 | 
 | ||||||
|             # XXX: why does this work without a recursion issue?! |             # NOTE: we override the original stream instance's receive | ||||||
|             self.receive = self._broadcaster.receive |             # method to now delegate to the broadcaster's ``.receive()`` | ||||||
|  |             # such that new subscribers will be copied received values | ||||||
|  |             # and this stream doesn't have to expect it's original | ||||||
|  |             # consumer(s) to get a new broadcast rx handle. | ||||||
|  |             self.receive = self._broadcaster.receive  # type: ignore | ||||||
|  |             # seems there's no graceful way to type this with ``mypy``? | ||||||
|  |             # https://github.com/python/mypy/issues/708 | ||||||
| 
 | 
 | ||||||
|         async with self._broadcaster.subscribe() as bstream: |         async with self._broadcaster.subscribe() as bstream: | ||||||
|             # a ``MsgStream`` clone is allocated for the |             # a ``MsgStream`` clone is allocated for the | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue