Drop uuid4 keys, raise closed error on subscription after close

tokio_backup
Tyler Goodlet 2021-08-20 13:04:17 -04:00
parent ac14f611b2
commit aad6cf9070
1 changed files with 21 additions and 27 deletions

View File

@ -12,7 +12,6 @@ from functools import partial
from operator import ne from operator import ne
from typing import Optional, Callable, Awaitable, Any, AsyncIterator, Protocol from typing import Optional, Callable, Awaitable, Any, AsyncIterator, Protocol
from typing import Generic, TypeVar from typing import Generic, TypeVar
from uuid import uuid4
import trio import trio
from trio._core._run import Task from trio._core._run import Task
@ -23,9 +22,8 @@ from trio.lowlevel import current_task
# A regular invariant generic type # A regular invariant generic type
T = TypeVar("T") T = TypeVar("T")
# The type of object produced by a ReceiveChannel (covariant because # covariant because AsyncReceiver[Derived] can be passed to someone
# ReceiveChannel[Derived] can be passed to someone expecting # expecting AsyncReceiver[Base])
# ReceiveChannel[Base])
ReceiveType = TypeVar("ReceiveType", covariant=True) ReceiveType = TypeVar("ReceiveType", covariant=True)
@ -33,9 +31,13 @@ class AsyncReceiver(
Protocol, Protocol,
Generic[ReceiveType], Generic[ReceiveType],
): ):
'''An async receivable duck-type that quacks much like trio's
``trio.abc.ReceieveChannel``.
'''
@abstractmethod @abstractmethod
async def receive(self) -> ReceiveType: async def receive(self) -> ReceiveType:
'''Same as in ``trio``.''' ...
@abstractmethod @abstractmethod
def __aiter__(self) -> AsyncIterator[ReceiveType]: def __aiter__(self) -> AsyncIterator[ReceiveType]:
@ -60,7 +62,10 @@ class AsyncReceiver(
class Lagged(trio.TooSlowError): class Lagged(trio.TooSlowError):
'''Subscribed consumer task was too slow''' '''Subscribed consumer task was too slow and was overrun
by the fastest consumer-producer pair.
'''
@dataclass @dataclass
@ -70,11 +75,11 @@ class BroadcastState:
''' '''
queue: deque queue: deque
# map of underlying uuid keys to receiver instances which must be # map of underlying instance id keys to receiver instances which
# provided as a singleton per broadcaster set. # must be provided as a singleton per broadcaster set.
subs: dict[str, int] subs: dict[str, int]
# broadcast event to wakeup all sleeping consumer tasks # broadcast event to wake up all sleeping consumer tasks
# on a newly produced value from the sender. # on a newly produced value from the sender.
recv_ready: Optional[tuple[str, trio.Event]] = None recv_ready: Optional[tuple[str, trio.Event]] = None
@ -84,13 +89,12 @@ class BroadcastReceiver(ReceiveChannel):
fastest consumer. fastest consumer.
Additional consumer tasks can receive all produced values by registering Additional consumer tasks can receive all produced values by registering
with ``.subscribe()`` and receiving from thew new instance it delivers. with ``.subscribe()`` and receiving from the new instance it delivers.
''' '''
def __init__( def __init__(
self, self,
key: str,
rx_chan: AsyncReceiver, rx_chan: AsyncReceiver,
state: BroadcastState, state: BroadcastState,
receive_afunc: Optional[Callable[[], Awaitable[Any]]] = None, receive_afunc: Optional[Callable[[], Awaitable[Any]]] = None,
@ -98,9 +102,9 @@ class BroadcastReceiver(ReceiveChannel):
) -> None: ) -> None:
# register the original underlying (clone) # register the original underlying (clone)
self.key = key self.key = id(self)
self._state = state self._state = state
state.subs[key] = -1 state.subs[self.key] = -1
# underlying for this receiver # underlying for this receiver
self._rx = rx_chan self._rx = rx_chan
@ -216,29 +220,23 @@ class BroadcastReceiver(ReceiveChannel):
provided at creation. provided at creation.
''' '''
# use a uuid4 for a tee-instance token if self._closed:
key = str(uuid4()) raise trio.ClosedResourceError
state = self._state state = self._state
br = BroadcastReceiver( br = BroadcastReceiver(
key=key,
rx_chan=self._rx, rx_chan=self._rx,
state=state, state=state,
receive_afunc=self._recv, receive_afunc=self._recv,
) )
# assert clone in state.subs # assert clone in state.subs
assert key in state.subs assert br.key in state.subs
try: try:
yield br yield br
finally: finally:
await br.aclose() await br.aclose()
# TODO:
# - should there be some ._closed flag that causes
# consumers to die **before** they read all queued values?
# - if subs only open and close clones then the underlying
# will never be killed until the last instance closes?
# This is correct right?
async def aclose( async def aclose(
self, self,
) -> None: ) -> None:
@ -248,10 +246,7 @@ class BroadcastReceiver(ReceiveChannel):
# XXX: leaving it like this consumers can still get values # XXX: leaving it like this consumers can still get values
# up to the last received that still reside in the queue. # up to the last received that still reside in the queue.
# Is this what we want?
self._state.subs.pop(self.key) self._state.subs.pop(self.key)
# if not self._state.subs:
# await self._rx.aclose()
self._closed = True self._closed = True
@ -265,7 +260,6 @@ def broadcast_receiver(
) -> BroadcastReceiver: ) -> BroadcastReceiver:
return BroadcastReceiver( return BroadcastReceiver(
str(uuid4()),
recv_chan, recv_chan,
state=BroadcastState( state=BroadcastState(
queue=deque(maxlen=max_buffer_size), queue=deque(maxlen=max_buffer_size),