forked from goodboy/tractor
				
			Drop uuid4 keys, raise closed error on subscription after close
							parent
							
								
									ac14f611b2
								
							
						
					
					
						commit
						aad6cf9070
					
				| 
						 | 
					@ -12,7 +12,6 @@ from functools import partial
 | 
				
			||||||
from operator import ne
 | 
					from operator import ne
 | 
				
			||||||
from typing import Optional, Callable, Awaitable, Any, AsyncIterator, Protocol
 | 
					from typing import Optional, Callable, Awaitable, Any, AsyncIterator, Protocol
 | 
				
			||||||
from typing import Generic, TypeVar
 | 
					from typing import Generic, TypeVar
 | 
				
			||||||
from uuid import uuid4
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
import trio
 | 
					import trio
 | 
				
			||||||
from trio._core._run import Task
 | 
					from trio._core._run import Task
 | 
				
			||||||
| 
						 | 
					@ -23,9 +22,8 @@ from trio.lowlevel import current_task
 | 
				
			||||||
# A regular invariant generic type
 | 
					# A regular invariant generic type
 | 
				
			||||||
T = TypeVar("T")
 | 
					T = TypeVar("T")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# The type of object produced by a ReceiveChannel (covariant because
 | 
					# covariant because AsyncReceiver[Derived] can be passed to someone
 | 
				
			||||||
# ReceiveChannel[Derived] can be passed to someone expecting
 | 
					# expecting AsyncReceiver[Base])
 | 
				
			||||||
# ReceiveChannel[Base])
 | 
					 | 
				
			||||||
ReceiveType = TypeVar("ReceiveType", covariant=True)
 | 
					ReceiveType = TypeVar("ReceiveType", covariant=True)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -33,9 +31,13 @@ class AsyncReceiver(
 | 
				
			||||||
    Protocol,
 | 
					    Protocol,
 | 
				
			||||||
    Generic[ReceiveType],
 | 
					    Generic[ReceiveType],
 | 
				
			||||||
):
 | 
					):
 | 
				
			||||||
 | 
					    '''An async receivable duck-type that quacks much like trio's
 | 
				
			||||||
 | 
					    ``trio.abc.ReceieveChannel``.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    '''
 | 
				
			||||||
    @abstractmethod
 | 
					    @abstractmethod
 | 
				
			||||||
    async def receive(self) -> ReceiveType:
 | 
					    async def receive(self) -> ReceiveType:
 | 
				
			||||||
        '''Same as in ``trio``.'''
 | 
					        ...
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @abstractmethod
 | 
					    @abstractmethod
 | 
				
			||||||
    def __aiter__(self) -> AsyncIterator[ReceiveType]:
 | 
					    def __aiter__(self) -> AsyncIterator[ReceiveType]:
 | 
				
			||||||
| 
						 | 
					@ -60,7 +62,10 @@ class AsyncReceiver(
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class Lagged(trio.TooSlowError):
 | 
					class Lagged(trio.TooSlowError):
 | 
				
			||||||
    '''Subscribed consumer task was too slow'''
 | 
					    '''Subscribed consumer task was too slow and was overrun
 | 
				
			||||||
 | 
					    by the fastest consumer-producer pair.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    '''
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@dataclass
 | 
					@dataclass
 | 
				
			||||||
| 
						 | 
					@ -70,11 +75,11 @@ class BroadcastState:
 | 
				
			||||||
    '''
 | 
					    '''
 | 
				
			||||||
    queue: deque
 | 
					    queue: deque
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # map of underlying uuid keys to receiver instances which must be
 | 
					    # map of underlying instance id keys to receiver instances which
 | 
				
			||||||
    # provided as a singleton per broadcaster set.
 | 
					    # must be provided as a singleton per broadcaster set.
 | 
				
			||||||
    subs: dict[str, int]
 | 
					    subs: dict[str, int]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # broadcast event to wakeup all sleeping consumer tasks
 | 
					    # broadcast event to wake up all sleeping consumer tasks
 | 
				
			||||||
    # on a newly produced value from the sender.
 | 
					    # on a newly produced value from the sender.
 | 
				
			||||||
    recv_ready: Optional[tuple[str, trio.Event]] = None
 | 
					    recv_ready: Optional[tuple[str, trio.Event]] = None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -84,13 +89,12 @@ class BroadcastReceiver(ReceiveChannel):
 | 
				
			||||||
    fastest consumer.
 | 
					    fastest consumer.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    Additional consumer tasks can receive all produced values by registering
 | 
					    Additional consumer tasks can receive all produced values by registering
 | 
				
			||||||
    with ``.subscribe()`` and receiving from thew new instance it delivers.
 | 
					    with ``.subscribe()`` and receiving from the new instance it delivers.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    '''
 | 
					    '''
 | 
				
			||||||
    def __init__(
 | 
					    def __init__(
 | 
				
			||||||
        self,
 | 
					        self,
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        key: str,
 | 
					 | 
				
			||||||
        rx_chan: AsyncReceiver,
 | 
					        rx_chan: AsyncReceiver,
 | 
				
			||||||
        state: BroadcastState,
 | 
					        state: BroadcastState,
 | 
				
			||||||
        receive_afunc: Optional[Callable[[], Awaitable[Any]]] = None,
 | 
					        receive_afunc: Optional[Callable[[], Awaitable[Any]]] = None,
 | 
				
			||||||
| 
						 | 
					@ -98,9 +102,9 @@ class BroadcastReceiver(ReceiveChannel):
 | 
				
			||||||
    ) -> None:
 | 
					    ) -> None:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # register the original underlying (clone)
 | 
					        # register the original underlying (clone)
 | 
				
			||||||
        self.key = key
 | 
					        self.key = id(self)
 | 
				
			||||||
        self._state = state
 | 
					        self._state = state
 | 
				
			||||||
        state.subs[key] = -1
 | 
					        state.subs[self.key] = -1
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # underlying for this receiver
 | 
					        # underlying for this receiver
 | 
				
			||||||
        self._rx = rx_chan
 | 
					        self._rx = rx_chan
 | 
				
			||||||
| 
						 | 
					@ -216,29 +220,23 @@ class BroadcastReceiver(ReceiveChannel):
 | 
				
			||||||
        provided at creation.
 | 
					        provided at creation.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        '''
 | 
					        '''
 | 
				
			||||||
        # use a uuid4 for a tee-instance token
 | 
					        if self._closed:
 | 
				
			||||||
        key = str(uuid4())
 | 
					            raise trio.ClosedResourceError
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        state = self._state
 | 
					        state = self._state
 | 
				
			||||||
        br = BroadcastReceiver(
 | 
					        br = BroadcastReceiver(
 | 
				
			||||||
            key=key,
 | 
					 | 
				
			||||||
            rx_chan=self._rx,
 | 
					            rx_chan=self._rx,
 | 
				
			||||||
            state=state,
 | 
					            state=state,
 | 
				
			||||||
            receive_afunc=self._recv,
 | 
					            receive_afunc=self._recv,
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
        # assert clone in state.subs
 | 
					        # assert clone in state.subs
 | 
				
			||||||
        assert key in state.subs
 | 
					        assert br.key in state.subs
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        try:
 | 
					        try:
 | 
				
			||||||
            yield br
 | 
					            yield br
 | 
				
			||||||
        finally:
 | 
					        finally:
 | 
				
			||||||
            await br.aclose()
 | 
					            await br.aclose()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # TODO:
 | 
					 | 
				
			||||||
    # - should there be some ._closed flag that causes
 | 
					 | 
				
			||||||
    #   consumers to die **before** they read all queued values?
 | 
					 | 
				
			||||||
    # - if subs only open and close clones then the underlying
 | 
					 | 
				
			||||||
    #   will never be killed until the last instance closes?
 | 
					 | 
				
			||||||
    #   This is correct right?
 | 
					 | 
				
			||||||
    async def aclose(
 | 
					    async def aclose(
 | 
				
			||||||
        self,
 | 
					        self,
 | 
				
			||||||
    ) -> None:
 | 
					    ) -> None:
 | 
				
			||||||
| 
						 | 
					@ -248,10 +246,7 @@ class BroadcastReceiver(ReceiveChannel):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # XXX: leaving it like this consumers can still get values
 | 
					        # XXX: leaving it like this consumers can still get values
 | 
				
			||||||
        # up to the last received that still reside in the queue.
 | 
					        # up to the last received that still reside in the queue.
 | 
				
			||||||
        # Is this what we want?
 | 
					 | 
				
			||||||
        self._state.subs.pop(self.key)
 | 
					        self._state.subs.pop(self.key)
 | 
				
			||||||
        # if not self._state.subs:
 | 
					 | 
				
			||||||
        #     await self._rx.aclose()
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
        self._closed = True
 | 
					        self._closed = True
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -265,7 +260,6 @@ def broadcast_receiver(
 | 
				
			||||||
) -> BroadcastReceiver:
 | 
					) -> BroadcastReceiver:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    return BroadcastReceiver(
 | 
					    return BroadcastReceiver(
 | 
				
			||||||
        str(uuid4()),
 | 
					 | 
				
			||||||
        recv_chan,
 | 
					        recv_chan,
 | 
				
			||||||
        state=BroadcastState(
 | 
					        state=BroadcastState(
 | 
				
			||||||
            queue=deque(maxlen=max_buffer_size),
 | 
					            queue=deque(maxlen=max_buffer_size),
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue