Store handle to underlying channel's `.receive()`
This allows for wrapping an existing stream by re-assigning its receive method to the allocated broadcaster's `.receive()` so as to avoid expecting any original consumer(s) of the stream to now know about the broadcaster; this instead mutates the stream to delegate to the new receive call behind the scenes any time `.subscribe()` is called. Add a `typing.Protocol` for so called "cloneable channels" until we decide/figure out a better keying system for each subscription and mask all undesired typing failures.tokio_backup
parent
eaa761b0c7
commit
43820e194e
|
@ -4,13 +4,15 @@ https://tokio-rs.github.io/tokio/doc/tokio/sync/broadcast/index.html
|
||||||
|
|
||||||
'''
|
'''
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
from abc import abstractmethod
|
||||||
from collections import deque
|
from collections import deque
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from itertools import cycle
|
from itertools import cycle
|
||||||
from operator import ne
|
from operator import ne
|
||||||
from typing import Optional
|
from typing import Optional, Callable, Awaitable, Any, AsyncIterator, Protocol
|
||||||
|
from typing import Generic, TypeVar
|
||||||
|
|
||||||
import trio
|
import trio
|
||||||
from trio._core._run import Task
|
from trio._core._run import Task
|
||||||
|
@ -19,6 +21,49 @@ from trio.lowlevel import current_task
|
||||||
import tractor
|
import tractor
|
||||||
|
|
||||||
|
|
||||||
|
# A regular invariant generic type
|
||||||
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
# The type of object produced by a ReceiveChannel (covariant because
|
||||||
|
# ReceiveChannel[Derived] can be passed to someone expecting
|
||||||
|
# ReceiveChannel[Base])
|
||||||
|
ReceiveType = TypeVar("ReceiveType", covariant=True)
|
||||||
|
|
||||||
|
|
||||||
|
class CloneableReceiveChannel(
|
||||||
|
Protocol,
|
||||||
|
Generic[ReceiveType],
|
||||||
|
):
|
||||||
|
@abstractmethod
|
||||||
|
def clone(self) -> CloneableReceiveChannel[ReceiveType]:
|
||||||
|
'''Clone this receiver usually by making a copy.'''
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def receive(self) -> ReceiveType:
|
||||||
|
'''Same as in ``trio``.'''
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def __aiter__(self) -> AsyncIterator[ReceiveType]:
|
||||||
|
...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def __anext__(self) -> ReceiveType:
|
||||||
|
...
|
||||||
|
|
||||||
|
# ``trio.abc.AsyncResource`` methods
|
||||||
|
@abstractmethod
|
||||||
|
async def aclose(self):
|
||||||
|
...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def __aenter__(self) -> CloneableReceiveChannel[ReceiveType]:
|
||||||
|
...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def __aexit__(self, *args) -> None:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
class Lagged(trio.TooSlowError):
|
class Lagged(trio.TooSlowError):
|
||||||
'''Subscribed consumer task was too slow'''
|
'''Subscribed consumer task was too slow'''
|
||||||
|
|
||||||
|
@ -33,7 +78,7 @@ class BroadcastState:
|
||||||
# map of underlying clones to receiver wrappers
|
# map of underlying clones to receiver wrappers
|
||||||
# which must be provided as a singleton per broadcaster
|
# which must be provided as a singleton per broadcaster
|
||||||
# clone-subscription set.
|
# clone-subscription set.
|
||||||
subs: dict[trio.ReceiveChannel, BroadcastReceiver]
|
subs: dict[CloneableReceiveChannel, int]
|
||||||
|
|
||||||
# broadcast event to wakeup all sleeping consumer tasks
|
# broadcast event to wakeup all sleeping consumer tasks
|
||||||
# on a newly produced value from the sender.
|
# on a newly produced value from the sender.
|
||||||
|
@ -51,8 +96,9 @@ class BroadcastReceiver(ReceiveChannel):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|
||||||
rx_chan: ReceiveChannel,
|
rx_chan: CloneableReceiveChannel,
|
||||||
state: BroadcastState,
|
state: BroadcastState,
|
||||||
|
receive_afunc: Optional[Callable[[], Awaitable[Any]]] = None,
|
||||||
|
|
||||||
) -> None:
|
) -> None:
|
||||||
|
|
||||||
|
@ -62,6 +108,7 @@ class BroadcastReceiver(ReceiveChannel):
|
||||||
|
|
||||||
# underlying for this receiver
|
# underlying for this receiver
|
||||||
self._rx = rx_chan
|
self._rx = rx_chan
|
||||||
|
self._recv = receive_afunc or rx_chan.receive
|
||||||
|
|
||||||
async def receive(self):
|
async def receive(self):
|
||||||
|
|
||||||
|
@ -113,7 +160,7 @@ class BroadcastReceiver(ReceiveChannel):
|
||||||
if state.sender_ready is None:
|
if state.sender_ready is None:
|
||||||
|
|
||||||
event = state.sender_ready = trio.Event()
|
event = state.sender_ready = trio.Event()
|
||||||
value = await self._rx.receive()
|
value = await self._recv()
|
||||||
|
|
||||||
# items with lower indices are "newer"
|
# items with lower indices are "newer"
|
||||||
state.queue.appendleft(value)
|
state.queue.appendleft(value)
|
||||||
|
@ -152,7 +199,7 @@ class BroadcastReceiver(ReceiveChannel):
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def subscribe(
|
async def subscribe(
|
||||||
self,
|
self,
|
||||||
) -> BroadcastReceiver:
|
) -> AsyncIterator[BroadcastReceiver]:
|
||||||
'''Subscribe for values from this broadcast receiver.
|
'''Subscribe for values from this broadcast receiver.
|
||||||
|
|
||||||
Returns a new ``BroadCastReceiver`` which is registered for and
|
Returns a new ``BroadCastReceiver`` which is registered for and
|
||||||
|
@ -160,6 +207,8 @@ class BroadcastReceiver(ReceiveChannel):
|
||||||
provided at creation.
|
provided at creation.
|
||||||
|
|
||||||
'''
|
'''
|
||||||
|
# if we didn't want to enforce "clone-ability" how would
|
||||||
|
# we key arbitrary subscriptions? Use a token system?
|
||||||
clone = self._rx.clone()
|
clone = self._rx.clone()
|
||||||
state = self._state
|
state = self._state
|
||||||
br = BroadcastReceiver(
|
br = BroadcastReceiver(
|
||||||
|
@ -190,13 +239,14 @@ class BroadcastReceiver(ReceiveChannel):
|
||||||
# up to the last received that still reside in the queue.
|
# up to the last received that still reside in the queue.
|
||||||
# Is this what we want?
|
# Is this what we want?
|
||||||
await self._rx.aclose()
|
await self._rx.aclose()
|
||||||
self._subs.pop(self._rx)
|
self._state.subs.pop(self._rx)
|
||||||
|
|
||||||
|
|
||||||
def broadcast_receiver(
|
def broadcast_receiver(
|
||||||
|
|
||||||
recv_chan: ReceiveChannel,
|
recv_chan: CloneableReceiveChannel,
|
||||||
max_buffer_size: int,
|
max_buffer_size: int,
|
||||||
|
**kwargs,
|
||||||
|
|
||||||
) -> BroadcastReceiver:
|
) -> BroadcastReceiver:
|
||||||
|
|
||||||
|
@ -206,6 +256,7 @@ def broadcast_receiver(
|
||||||
queue=deque(maxlen=max_buffer_size),
|
queue=deque(maxlen=max_buffer_size),
|
||||||
subs={},
|
subs={},
|
||||||
),
|
),
|
||||||
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -9,6 +9,7 @@ from dataclasses import dataclass
|
||||||
from typing import (
|
from typing import (
|
||||||
Any, Iterator, Optional, Callable,
|
Any, Iterator, Optional, Callable,
|
||||||
AsyncGenerator, Dict,
|
AsyncGenerator, Dict,
|
||||||
|
AsyncIterator, Awaitable
|
||||||
)
|
)
|
||||||
|
|
||||||
import warnings
|
import warnings
|
||||||
|
@ -47,7 +48,7 @@ class ReceiveMsgStream(trio.abc.ReceiveChannel):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
ctx: 'Context', # typing: ignore # noqa
|
ctx: 'Context', # typing: ignore # noqa
|
||||||
rx_chan: trio.abc.ReceiveChannel,
|
rx_chan: trio.MemoryReceiveChannel,
|
||||||
) -> None:
|
) -> None:
|
||||||
self._ctx = ctx
|
self._ctx = ctx
|
||||||
self._rx_chan = rx_chan
|
self._rx_chan = rx_chan
|
||||||
|
@ -246,7 +247,7 @@ class ReceiveMsgStream(trio.abc.ReceiveChannel):
|
||||||
async def subscribe(
|
async def subscribe(
|
||||||
self,
|
self,
|
||||||
|
|
||||||
) -> BroadcastReceiver:
|
) -> AsyncIterator[BroadcastReceiver]:
|
||||||
'''Allocate and return a ``BroadcastReceiver`` which delegates
|
'''Allocate and return a ``BroadcastReceiver`` which delegates
|
||||||
to this message stream.
|
to this message stream.
|
||||||
|
|
||||||
|
@ -259,21 +260,24 @@ class ReceiveMsgStream(trio.abc.ReceiveChannel):
|
||||||
receiver wrapper.
|
receiver wrapper.
|
||||||
|
|
||||||
'''
|
'''
|
||||||
|
# NOTE: This operation is indempotent and non-reversible, so be
|
||||||
|
# sure you can deal with any (theoretical) overhead of the the
|
||||||
|
# allocated ``BroadcastReceiver`` before calling this method for
|
||||||
|
# the first time.
|
||||||
if self._broadcaster is None:
|
if self._broadcaster is None:
|
||||||
self._broadcaster = broadcast_receiver(
|
self._broadcaster = broadcast_receiver(
|
||||||
self,
|
self,
|
||||||
self._rx_chan._state.max_buffer_size,
|
self._rx_chan._state.max_buffer_size, # type: ignore
|
||||||
)
|
)
|
||||||
# override the original stream instance's receive to
|
|
||||||
# delegate to the broadcaster receive such that
|
|
||||||
# new subscribers will be copied received values
|
|
||||||
# XXX: this operation is indempotent and non-reversible,
|
|
||||||
# so be sure you can deal with any (theoretical) overhead
|
|
||||||
# of the the ``BroadcastReceiver`` before calling
|
|
||||||
# this method for the first time.
|
|
||||||
|
|
||||||
# XXX: why does this work without a recursion issue?!
|
# NOTE: we override the original stream instance's receive
|
||||||
self.receive = self._broadcaster.receive
|
# method to now delegate to the broadcaster's ``.receive()``
|
||||||
|
# such that new subscribers will be copied received values
|
||||||
|
# and this stream doesn't have to expect it's original
|
||||||
|
# consumer(s) to get a new broadcast rx handle.
|
||||||
|
self.receive = self._broadcaster.receive # type: ignore
|
||||||
|
# seems there's no graceful way to type this with ``mypy``?
|
||||||
|
# https://github.com/python/mypy/issues/708
|
||||||
|
|
||||||
async with self._broadcaster.subscribe() as bstream:
|
async with self._broadcaster.subscribe() as bstream:
|
||||||
# a ``MsgStream`` clone is allocated for the
|
# a ``MsgStream`` clone is allocated for the
|
||||||
|
|
Loading…
Reference in New Issue