Drop uuid4 keys, raise closed error on subscription after close
parent
2bad2bac50
commit
bec3f5999d
|
@ -12,7 +12,6 @@ from functools import partial
|
||||||
from operator import ne
|
from operator import ne
|
||||||
from typing import Optional, Callable, Awaitable, Any, AsyncIterator, Protocol
|
from typing import Optional, Callable, Awaitable, Any, AsyncIterator, Protocol
|
||||||
from typing import Generic, TypeVar
|
from typing import Generic, TypeVar
|
||||||
from uuid import uuid4
|
|
||||||
|
|
||||||
import trio
|
import trio
|
||||||
from trio._core._run import Task
|
from trio._core._run import Task
|
||||||
|
@ -23,9 +22,8 @@ from trio.lowlevel import current_task
|
||||||
# A regular invariant generic type
|
# A regular invariant generic type
|
||||||
T = TypeVar("T")
|
T = TypeVar("T")
|
||||||
|
|
||||||
# The type of object produced by a ReceiveChannel (covariant because
|
# covariant because AsyncReceiver[Derived] can be passed to someone
|
||||||
# ReceiveChannel[Derived] can be passed to someone expecting
|
# expecting AsyncReceiver[Base])
|
||||||
# ReceiveChannel[Base])
|
|
||||||
ReceiveType = TypeVar("ReceiveType", covariant=True)
|
ReceiveType = TypeVar("ReceiveType", covariant=True)
|
||||||
|
|
||||||
|
|
||||||
|
@ -33,9 +31,13 @@ class AsyncReceiver(
|
||||||
Protocol,
|
Protocol,
|
||||||
Generic[ReceiveType],
|
Generic[ReceiveType],
|
||||||
):
|
):
|
||||||
|
'''An async receivable duck-type that quacks much like trio's
|
||||||
|
``trio.abc.ReceieveChannel``.
|
||||||
|
|
||||||
|
'''
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def receive(self) -> ReceiveType:
|
async def receive(self) -> ReceiveType:
|
||||||
'''Same as in ``trio``.'''
|
...
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def __aiter__(self) -> AsyncIterator[ReceiveType]:
|
def __aiter__(self) -> AsyncIterator[ReceiveType]:
|
||||||
|
@ -60,7 +62,10 @@ class AsyncReceiver(
|
||||||
|
|
||||||
|
|
||||||
class Lagged(trio.TooSlowError):
|
class Lagged(trio.TooSlowError):
|
||||||
'''Subscribed consumer task was too slow'''
|
'''Subscribed consumer task was too slow and was overrun
|
||||||
|
by the fastest consumer-producer pair.
|
||||||
|
|
||||||
|
'''
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
@ -70,8 +75,8 @@ class BroadcastState:
|
||||||
'''
|
'''
|
||||||
queue: deque
|
queue: deque
|
||||||
|
|
||||||
# map of underlying uuid keys to receiver instances which must be
|
# map of underlying instance id keys to receiver instances which
|
||||||
# provided as a singleton per broadcaster set.
|
# must be provided as a singleton per broadcaster set.
|
||||||
subs: dict[str, int]
|
subs: dict[str, int]
|
||||||
|
|
||||||
# broadcast event to wake up all sleeping consumer tasks
|
# broadcast event to wake up all sleeping consumer tasks
|
||||||
|
@ -84,13 +89,12 @@ class BroadcastReceiver(ReceiveChannel):
|
||||||
fastest consumer.
|
fastest consumer.
|
||||||
|
|
||||||
Additional consumer tasks can receive all produced values by registering
|
Additional consumer tasks can receive all produced values by registering
|
||||||
with ``.subscribe()`` and receiving from thew new instance it delivers.
|
with ``.subscribe()`` and receiving from the new instance it delivers.
|
||||||
|
|
||||||
'''
|
'''
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|
||||||
key: str,
|
|
||||||
rx_chan: AsyncReceiver,
|
rx_chan: AsyncReceiver,
|
||||||
state: BroadcastState,
|
state: BroadcastState,
|
||||||
receive_afunc: Optional[Callable[[], Awaitable[Any]]] = None,
|
receive_afunc: Optional[Callable[[], Awaitable[Any]]] = None,
|
||||||
|
@ -98,9 +102,9 @@ class BroadcastReceiver(ReceiveChannel):
|
||||||
) -> None:
|
) -> None:
|
||||||
|
|
||||||
# register the original underlying (clone)
|
# register the original underlying (clone)
|
||||||
self.key = key
|
self.key = id(self)
|
||||||
self._state = state
|
self._state = state
|
||||||
state.subs[key] = -1
|
state.subs[self.key] = -1
|
||||||
|
|
||||||
# underlying for this receiver
|
# underlying for this receiver
|
||||||
self._rx = rx_chan
|
self._rx = rx_chan
|
||||||
|
@ -216,29 +220,23 @@ class BroadcastReceiver(ReceiveChannel):
|
||||||
provided at creation.
|
provided at creation.
|
||||||
|
|
||||||
'''
|
'''
|
||||||
# use a uuid4 for a tee-instance token
|
if self._closed:
|
||||||
key = str(uuid4())
|
raise trio.ClosedResourceError
|
||||||
|
|
||||||
state = self._state
|
state = self._state
|
||||||
br = BroadcastReceiver(
|
br = BroadcastReceiver(
|
||||||
key=key,
|
|
||||||
rx_chan=self._rx,
|
rx_chan=self._rx,
|
||||||
state=state,
|
state=state,
|
||||||
receive_afunc=self._recv,
|
receive_afunc=self._recv,
|
||||||
)
|
)
|
||||||
# assert clone in state.subs
|
# assert clone in state.subs
|
||||||
assert key in state.subs
|
assert br.key in state.subs
|
||||||
|
|
||||||
try:
|
try:
|
||||||
yield br
|
yield br
|
||||||
finally:
|
finally:
|
||||||
await br.aclose()
|
await br.aclose()
|
||||||
|
|
||||||
# TODO:
|
|
||||||
# - should there be some ._closed flag that causes
|
|
||||||
# consumers to die **before** they read all queued values?
|
|
||||||
# - if subs only open and close clones then the underlying
|
|
||||||
# will never be killed until the last instance closes?
|
|
||||||
# This is correct right?
|
|
||||||
async def aclose(
|
async def aclose(
|
||||||
self,
|
self,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
@ -248,10 +246,7 @@ class BroadcastReceiver(ReceiveChannel):
|
||||||
|
|
||||||
# XXX: leaving it like this consumers can still get values
|
# XXX: leaving it like this consumers can still get values
|
||||||
# up to the last received that still reside in the queue.
|
# up to the last received that still reside in the queue.
|
||||||
# Is this what we want?
|
|
||||||
self._state.subs.pop(self.key)
|
self._state.subs.pop(self.key)
|
||||||
# if not self._state.subs:
|
|
||||||
# await self._rx.aclose()
|
|
||||||
|
|
||||||
self._closed = True
|
self._closed = True
|
||||||
|
|
||||||
|
@ -265,7 +260,6 @@ def broadcast_receiver(
|
||||||
) -> BroadcastReceiver:
|
) -> BroadcastReceiver:
|
||||||
|
|
||||||
return BroadcastReceiver(
|
return BroadcastReceiver(
|
||||||
str(uuid4()),
|
|
||||||
recv_chan,
|
recv_chan,
|
||||||
state=BroadcastState(
|
state=BroadcastState(
|
||||||
queue=deque(maxlen=max_buffer_size),
|
queue=deque(maxlen=max_buffer_size),
|
||||||
|
|
Loading…
Reference in New Issue