Store handle to underlying channel's `.receive()`
This allows for wrapping an existing stream by re-assigning its receive method to the allocated broadcaster's `.receive()` so as to avoid expecting any original consumer(s) of the stream to now know about the broadcaster; this instead mutates the stream to delegate to the new receive call behind the scenes any time `.subscribe()` is called. Add a `typing.Protocol` for so called "cloneable channels" until we decide/figure out a better keying system for each subscription and mask all undesired typing failures.live_on_air_from_tokio
parent
2d1c24112b
commit
6c17c7367a
|
@ -4,13 +4,15 @@ https://tokio-rs.github.io/tokio/doc/tokio/sync/broadcast/index.html
|
|||
|
||||
'''
|
||||
from __future__ import annotations
|
||||
from abc import abstractmethod
|
||||
from collections import deque
|
||||
from contextlib import asynccontextmanager
|
||||
from dataclasses import dataclass
|
||||
from functools import partial
|
||||
from itertools import cycle
|
||||
from operator import ne
|
||||
from typing import Optional
|
||||
from typing import Optional, Callable, Awaitable, Any, AsyncIterator, Protocol
|
||||
from typing import Generic, TypeVar
|
||||
|
||||
import trio
|
||||
from trio._core._run import Task
|
||||
|
@ -19,6 +21,49 @@ from trio.lowlevel import current_task
|
|||
import tractor
|
||||
|
||||
|
||||
# A regular invariant generic type
|
||||
T = TypeVar("T")
|
||||
|
||||
# The type of object produced by a ReceiveChannel (covariant because
|
||||
# ReceiveChannel[Derived] can be passed to someone expecting
|
||||
# ReceiveChannel[Base])
|
||||
ReceiveType = TypeVar("ReceiveType", covariant=True)
|
||||
|
||||
|
||||
class CloneableReceiveChannel(
|
||||
Protocol,
|
||||
Generic[ReceiveType],
|
||||
):
|
||||
@abstractmethod
|
||||
def clone(self) -> CloneableReceiveChannel[ReceiveType]:
|
||||
'''Clone this receiver usually by making a copy.'''
|
||||
|
||||
@abstractmethod
|
||||
async def receive(self) -> ReceiveType:
|
||||
'''Same as in ``trio``.'''
|
||||
|
||||
@abstractmethod
|
||||
def __aiter__(self) -> AsyncIterator[ReceiveType]:
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def __anext__(self) -> ReceiveType:
|
||||
...
|
||||
|
||||
# ``trio.abc.AsyncResource`` methods
|
||||
@abstractmethod
|
||||
async def aclose(self):
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def __aenter__(self) -> CloneableReceiveChannel[ReceiveType]:
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def __aexit__(self, *args) -> None:
|
||||
...
|
||||
|
||||
|
||||
class Lagged(trio.TooSlowError):
|
||||
'''Subscribed consumer task was too slow'''
|
||||
|
||||
|
@ -33,7 +78,7 @@ class BroadcastState:
|
|||
# map of underlying clones to receiver wrappers
|
||||
# which must be provided as a singleton per broadcaster
|
||||
# clone-subscription set.
|
||||
subs: dict[trio.ReceiveChannel, BroadcastReceiver]
|
||||
subs: dict[CloneableReceiveChannel, int]
|
||||
|
||||
# broadcast event to wakeup all sleeping consumer tasks
|
||||
# on a newly produced value from the sender.
|
||||
|
@ -51,8 +96,9 @@ class BroadcastReceiver(ReceiveChannel):
|
|||
def __init__(
|
||||
self,
|
||||
|
||||
rx_chan: ReceiveChannel,
|
||||
rx_chan: CloneableReceiveChannel,
|
||||
state: BroadcastState,
|
||||
receive_afunc: Optional[Callable[[], Awaitable[Any]]] = None,
|
||||
|
||||
) -> None:
|
||||
|
||||
|
@ -62,6 +108,7 @@ class BroadcastReceiver(ReceiveChannel):
|
|||
|
||||
# underlying for this receiver
|
||||
self._rx = rx_chan
|
||||
self._recv = receive_afunc or rx_chan.receive
|
||||
|
||||
async def receive(self):
|
||||
|
||||
|
@ -113,7 +160,7 @@ class BroadcastReceiver(ReceiveChannel):
|
|||
if state.sender_ready is None:
|
||||
|
||||
event = state.sender_ready = trio.Event()
|
||||
value = await self._rx.receive()
|
||||
value = await self._recv()
|
||||
|
||||
# items with lower indices are "newer"
|
||||
state.queue.appendleft(value)
|
||||
|
@ -152,7 +199,7 @@ class BroadcastReceiver(ReceiveChannel):
|
|||
@asynccontextmanager
|
||||
async def subscribe(
|
||||
self,
|
||||
) -> BroadcastReceiver:
|
||||
) -> AsyncIterator[BroadcastReceiver]:
|
||||
'''Subscribe for values from this broadcast receiver.
|
||||
|
||||
Returns a new ``BroadCastReceiver`` which is registered for and
|
||||
|
@ -160,6 +207,8 @@ class BroadcastReceiver(ReceiveChannel):
|
|||
provided at creation.
|
||||
|
||||
'''
|
||||
# if we didn't want to enforce "clone-ability" how would
|
||||
# we key arbitrary subscriptions? Use a token system?
|
||||
clone = self._rx.clone()
|
||||
state = self._state
|
||||
br = BroadcastReceiver(
|
||||
|
@ -190,13 +239,14 @@ class BroadcastReceiver(ReceiveChannel):
|
|||
# up to the last received that still reside in the queue.
|
||||
# Is this what we want?
|
||||
await self._rx.aclose()
|
||||
self._subs.pop(self._rx)
|
||||
self._state.subs.pop(self._rx)
|
||||
|
||||
|
||||
def broadcast_receiver(
|
||||
|
||||
recv_chan: ReceiveChannel,
|
||||
recv_chan: CloneableReceiveChannel,
|
||||
max_buffer_size: int,
|
||||
**kwargs,
|
||||
|
||||
) -> BroadcastReceiver:
|
||||
|
||||
|
@ -206,6 +256,7 @@ def broadcast_receiver(
|
|||
queue=deque(maxlen=max_buffer_size),
|
||||
subs={},
|
||||
),
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
|
|
|
@ -9,6 +9,7 @@ from dataclasses import dataclass
|
|||
from typing import (
|
||||
Any, Iterator, Optional, Callable,
|
||||
AsyncGenerator, Dict,
|
||||
AsyncIterator, Awaitable
|
||||
)
|
||||
|
||||
import warnings
|
||||
|
@ -49,8 +50,7 @@ class ReceiveMsgStream(trio.abc.ReceiveChannel):
|
|||
def __init__(
|
||||
self,
|
||||
ctx: 'Context', # typing: ignore # noqa
|
||||
rx_chan: trio.abc.ReceiveChannel,
|
||||
|
||||
rx_chan: trio.MemoryReceiveChannel,
|
||||
) -> None:
|
||||
self._ctx = ctx
|
||||
self._rx_chan = rx_chan
|
||||
|
@ -248,7 +248,7 @@ class ReceiveMsgStream(trio.abc.ReceiveChannel):
|
|||
async def subscribe(
|
||||
self,
|
||||
|
||||
) -> BroadcastReceiver:
|
||||
) -> AsyncIterator[BroadcastReceiver]:
|
||||
'''Allocate and return a ``BroadcastReceiver`` which delegates
|
||||
to this message stream.
|
||||
|
||||
|
@ -261,21 +261,24 @@ class ReceiveMsgStream(trio.abc.ReceiveChannel):
|
|||
receiver wrapper.
|
||||
|
||||
'''
|
||||
# NOTE: This operation is indempotent and non-reversible, so be
|
||||
# sure you can deal with any (theoretical) overhead of the the
|
||||
# allocated ``BroadcastReceiver`` before calling this method for
|
||||
# the first time.
|
||||
if self._broadcaster is None:
|
||||
self._broadcaster = broadcast_receiver(
|
||||
self,
|
||||
self._rx_chan._state.max_buffer_size,
|
||||
self._rx_chan._state.max_buffer_size, # type: ignore
|
||||
)
|
||||
# override the original stream instance's receive to
|
||||
# delegate to the broadcaster receive such that
|
||||
# new subscribers will be copied received values
|
||||
# XXX: this operation is indempotent and non-reversible,
|
||||
# so be sure you can deal with any (theoretical) overhead
|
||||
# of the the ``BroadcastReceiver`` before calling
|
||||
# this method for the first time.
|
||||
|
||||
# XXX: why does this work without a recursion issue?!
|
||||
self.receive = self._broadcaster.receive
|
||||
# NOTE: we override the original stream instance's receive
|
||||
# method to now delegate to the broadcaster's ``.receive()``
|
||||
# such that new subscribers will be copied received values
|
||||
# and this stream doesn't have to expect it's original
|
||||
# consumer(s) to get a new broadcast rx handle.
|
||||
self.receive = self._broadcaster.receive # type: ignore
|
||||
# seems there's no graceful way to type this with ``mypy``?
|
||||
# https://github.com/python/mypy/issues/708
|
||||
|
||||
async with self._broadcaster.subscribe() as bstream:
|
||||
# a ``MsgStream`` clone is allocated for the
|
||||
|
|
Loading…
Reference in New Issue