Store handle to underlying channel's `.receive()`

This allows for wrapping an existing stream by re-assigning its receive
method to the allocated broadcaster's `.receive()` so as to avoid
expecting any original consumer(s) of the stream to now know about the
broadcaster; this instead mutates the stream to delegate to the new
receive call behind the scenes any time `.subscribe()` is called.

Add a `typing.Protocol` for so called "cloneable channels" until we
decide/figure out a better keying system for each subscription and
mask all undesired typing failures.
tokio_backup
Tyler Goodlet 2021-08-16 12:47:49 -04:00
parent eaa761b0c7
commit 43820e194e
2 changed files with 74 additions and 19 deletions

View File

@ -4,13 +4,15 @@ https://tokio-rs.github.io/tokio/doc/tokio/sync/broadcast/index.html
'''
from __future__ import annotations
from abc import abstractmethod
from collections import deque
from contextlib import asynccontextmanager
from dataclasses import dataclass
from functools import partial
from itertools import cycle
from operator import ne
from typing import Optional
from typing import Optional, Callable, Awaitable, Any, AsyncIterator, Protocol
from typing import Generic, TypeVar
import trio
from trio._core._run import Task
@ -19,6 +21,49 @@ from trio.lowlevel import current_task
import tractor
# A regular invariant generic type
T = TypeVar("T")
# The type of object produced by a ReceiveChannel (covariant because
# ReceiveChannel[Derived] can be passed to someone expecting
# ReceiveChannel[Base])
ReceiveType = TypeVar("ReceiveType", covariant=True)
class CloneableReceiveChannel(
Protocol,
Generic[ReceiveType],
):
@abstractmethod
def clone(self) -> CloneableReceiveChannel[ReceiveType]:
'''Clone this receiver usually by making a copy.'''
@abstractmethod
async def receive(self) -> ReceiveType:
'''Same as in ``trio``.'''
@abstractmethod
def __aiter__(self) -> AsyncIterator[ReceiveType]:
...
@abstractmethod
async def __anext__(self) -> ReceiveType:
...
# ``trio.abc.AsyncResource`` methods
@abstractmethod
async def aclose(self):
...
@abstractmethod
async def __aenter__(self) -> CloneableReceiveChannel[ReceiveType]:
...
@abstractmethod
async def __aexit__(self, *args) -> None:
...
class Lagged(trio.TooSlowError):
'''Subscribed consumer task was too slow'''
@ -33,7 +78,7 @@ class BroadcastState:
# map of underlying clones to receiver wrappers
# which must be provided as a singleton per broadcaster
# clone-subscription set.
subs: dict[trio.ReceiveChannel, BroadcastReceiver]
subs: dict[CloneableReceiveChannel, int]
# broadcast event to wakeup all sleeping consumer tasks
# on a newly produced value from the sender.
@ -51,8 +96,9 @@ class BroadcastReceiver(ReceiveChannel):
def __init__(
self,
rx_chan: ReceiveChannel,
rx_chan: CloneableReceiveChannel,
state: BroadcastState,
receive_afunc: Optional[Callable[[], Awaitable[Any]]] = None,
) -> None:
@ -62,6 +108,7 @@ class BroadcastReceiver(ReceiveChannel):
# underlying for this receiver
self._rx = rx_chan
self._recv = receive_afunc or rx_chan.receive
async def receive(self):
@ -113,7 +160,7 @@ class BroadcastReceiver(ReceiveChannel):
if state.sender_ready is None:
event = state.sender_ready = trio.Event()
value = await self._rx.receive()
value = await self._recv()
# items with lower indices are "newer"
state.queue.appendleft(value)
@ -152,7 +199,7 @@ class BroadcastReceiver(ReceiveChannel):
@asynccontextmanager
async def subscribe(
self,
) -> BroadcastReceiver:
) -> AsyncIterator[BroadcastReceiver]:
'''Subscribe for values from this broadcast receiver.
Returns a new ``BroadCastReceiver`` which is registered for and
@ -160,6 +207,8 @@ class BroadcastReceiver(ReceiveChannel):
provided at creation.
'''
# if we didn't want to enforce "clone-ability" how would
# we key arbitrary subscriptions? Use a token system?
clone = self._rx.clone()
state = self._state
br = BroadcastReceiver(
@ -190,13 +239,14 @@ class BroadcastReceiver(ReceiveChannel):
# up to the last received that still reside in the queue.
# Is this what we want?
await self._rx.aclose()
self._subs.pop(self._rx)
self._state.subs.pop(self._rx)
def broadcast_receiver(
recv_chan: ReceiveChannel,
recv_chan: CloneableReceiveChannel,
max_buffer_size: int,
**kwargs,
) -> BroadcastReceiver:
@ -206,6 +256,7 @@ def broadcast_receiver(
queue=deque(maxlen=max_buffer_size),
subs={},
),
**kwargs,
)

View File

@ -9,6 +9,7 @@ from dataclasses import dataclass
from typing import (
Any, Iterator, Optional, Callable,
AsyncGenerator, Dict,
AsyncIterator, Awaitable
)
import warnings
@ -47,7 +48,7 @@ class ReceiveMsgStream(trio.abc.ReceiveChannel):
def __init__(
self,
ctx: 'Context', # typing: ignore # noqa
rx_chan: trio.abc.ReceiveChannel,
rx_chan: trio.MemoryReceiveChannel,
) -> None:
self._ctx = ctx
self._rx_chan = rx_chan
@ -246,7 +247,7 @@ class ReceiveMsgStream(trio.abc.ReceiveChannel):
async def subscribe(
self,
) -> BroadcastReceiver:
) -> AsyncIterator[BroadcastReceiver]:
'''Allocate and return a ``BroadcastReceiver`` which delegates
to this message stream.
@ -259,21 +260,24 @@ class ReceiveMsgStream(trio.abc.ReceiveChannel):
receiver wrapper.
'''
# NOTE: This operation is indempotent and non-reversible, so be
# sure you can deal with any (theoretical) overhead of the the
# allocated ``BroadcastReceiver`` before calling this method for
# the first time.
if self._broadcaster is None:
self._broadcaster = broadcast_receiver(
self,
self._rx_chan._state.max_buffer_size,
self._rx_chan._state.max_buffer_size, # type: ignore
)
# override the original stream instance's receive to
# delegate to the broadcaster receive such that
# new subscribers will be copied received values
# XXX: this operation is indempotent and non-reversible,
# so be sure you can deal with any (theoretical) overhead
# of the the ``BroadcastReceiver`` before calling
# this method for the first time.
# XXX: why does this work without a recursion issue?!
self.receive = self._broadcaster.receive
# NOTE: we override the original stream instance's receive
# method to now delegate to the broadcaster's ``.receive()``
# such that new subscribers will be copied received values
# and this stream doesn't have to expect it's original
# consumer(s) to get a new broadcast rx handle.
self.receive = self._broadcaster.receive # type: ignore
# seems there's no graceful way to type this with ``mypy``?
# https://github.com/python/mypy/issues/708
async with self._broadcaster.subscribe() as bstream:
# a ``MsgStream`` clone is allocated for the