diff --git a/tractor/_broadcast.py b/tractor/_broadcast.py index 7d56430..bba8021 100644 --- a/tractor/_broadcast.py +++ b/tractor/_broadcast.py @@ -4,13 +4,15 @@ https://tokio-rs.github.io/tokio/doc/tokio/sync/broadcast/index.html ''' from __future__ import annotations +from abc import abstractmethod from collections import deque from contextlib import asynccontextmanager from dataclasses import dataclass from functools import partial from itertools import cycle from operator import ne -from typing import Optional +from typing import Optional, Callable, Awaitable, Any, AsyncIterator, Protocol +from typing import Generic, TypeVar import trio from trio._core._run import Task @@ -19,6 +21,49 @@ from trio.lowlevel import current_task import tractor +# A regular invariant generic type +T = TypeVar("T") + +# The type of object produced by a ReceiveChannel (covariant because +# ReceiveChannel[Derived] can be passed to someone expecting +# ReceiveChannel[Base]) +ReceiveType = TypeVar("ReceiveType", covariant=True) + + +class CloneableReceiveChannel( + Protocol, + Generic[ReceiveType], +): + @abstractmethod + def clone(self) -> CloneableReceiveChannel[ReceiveType]: + '''Clone this receiver usually by making a copy.''' + + @abstractmethod + async def receive(self) -> ReceiveType: + '''Same as in ``trio``.''' + + @abstractmethod + def __aiter__(self) -> AsyncIterator[ReceiveType]: + ... + + @abstractmethod + async def __anext__(self) -> ReceiveType: + ... + + # ``trio.abc.AsyncResource`` methods + @abstractmethod + async def aclose(self): + ... + + @abstractmethod + async def __aenter__(self) -> CloneableReceiveChannel[ReceiveType]: + ... + + @abstractmethod + async def __aexit__(self, *args) -> None: + ... + + class Lagged(trio.TooSlowError): '''Subscribed consumer task was too slow''' @@ -33,7 +78,7 @@ class BroadcastState: # map of underlying clones to receiver wrappers # which must be provided as a singleton per broadcaster # clone-subscription set. - subs: dict[trio.ReceiveChannel, BroadcastReceiver] + subs: dict[CloneableReceiveChannel, int] # broadcast event to wakeup all sleeping consumer tasks # on a newly produced value from the sender. @@ -51,8 +96,9 @@ class BroadcastReceiver(ReceiveChannel): def __init__( self, - rx_chan: ReceiveChannel, + rx_chan: CloneableReceiveChannel, state: BroadcastState, + receive_afunc: Optional[Callable[[], Awaitable[Any]]] = None, ) -> None: @@ -62,6 +108,7 @@ class BroadcastReceiver(ReceiveChannel): # underlying for this receiver self._rx = rx_chan + self._recv = receive_afunc or rx_chan.receive async def receive(self): @@ -113,7 +160,7 @@ class BroadcastReceiver(ReceiveChannel): if state.sender_ready is None: event = state.sender_ready = trio.Event() - value = await self._rx.receive() + value = await self._recv() # items with lower indices are "newer" state.queue.appendleft(value) @@ -152,7 +199,7 @@ class BroadcastReceiver(ReceiveChannel): @asynccontextmanager async def subscribe( self, - ) -> BroadcastReceiver: + ) -> AsyncIterator[BroadcastReceiver]: '''Subscribe for values from this broadcast receiver. Returns a new ``BroadCastReceiver`` which is registered for and @@ -160,6 +207,8 @@ class BroadcastReceiver(ReceiveChannel): provided at creation. ''' + # if we didn't want to enforce "clone-ability" how would + # we key arbitrary subscriptions? Use a token system? clone = self._rx.clone() state = self._state br = BroadcastReceiver( @@ -190,13 +239,14 @@ class BroadcastReceiver(ReceiveChannel): # up to the last received that still reside in the queue. # Is this what we want? await self._rx.aclose() - self._subs.pop(self._rx) + self._state.subs.pop(self._rx) def broadcast_receiver( - recv_chan: ReceiveChannel, + recv_chan: CloneableReceiveChannel, max_buffer_size: int, + **kwargs, ) -> BroadcastReceiver: @@ -206,6 +256,7 @@ def broadcast_receiver( queue=deque(maxlen=max_buffer_size), subs={}, ), + **kwargs, ) diff --git a/tractor/_streaming.py b/tractor/_streaming.py index 6e08272..2956451 100644 --- a/tractor/_streaming.py +++ b/tractor/_streaming.py @@ -9,6 +9,7 @@ from dataclasses import dataclass from typing import ( Any, Iterator, Optional, Callable, AsyncGenerator, Dict, + AsyncIterator, Awaitable ) import warnings @@ -47,7 +48,7 @@ class ReceiveMsgStream(trio.abc.ReceiveChannel): def __init__( self, ctx: 'Context', # typing: ignore # noqa - rx_chan: trio.abc.ReceiveChannel, + rx_chan: trio.MemoryReceiveChannel, ) -> None: self._ctx = ctx self._rx_chan = rx_chan @@ -246,7 +247,7 @@ class ReceiveMsgStream(trio.abc.ReceiveChannel): async def subscribe( self, - ) -> BroadcastReceiver: + ) -> AsyncIterator[BroadcastReceiver]: '''Allocate and return a ``BroadcastReceiver`` which delegates to this message stream. @@ -259,21 +260,24 @@ class ReceiveMsgStream(trio.abc.ReceiveChannel): receiver wrapper. ''' + # NOTE: This operation is indempotent and non-reversible, so be + # sure you can deal with any (theoretical) overhead of the the + # allocated ``BroadcastReceiver`` before calling this method for + # the first time. if self._broadcaster is None: self._broadcaster = broadcast_receiver( self, - self._rx_chan._state.max_buffer_size, + self._rx_chan._state.max_buffer_size, # type: ignore ) - # override the original stream instance's receive to - # delegate to the broadcaster receive such that - # new subscribers will be copied received values - # XXX: this operation is indempotent and non-reversible, - # so be sure you can deal with any (theoretical) overhead - # of the the ``BroadcastReceiver`` before calling - # this method for the first time. - # XXX: why does this work without a recursion issue?! - self.receive = self._broadcaster.receive + # NOTE: we override the original stream instance's receive + # method to now delegate to the broadcaster's ``.receive()`` + # such that new subscribers will be copied received values + # and this stream doesn't have to expect it's original + # consumer(s) to get a new broadcast rx handle. + self.receive = self._broadcaster.receive # type: ignore + # seems there's no graceful way to type this with ``mypy``? + # https://github.com/python/mypy/issues/708 async with self._broadcaster.subscribe() as bstream: # a ``MsgStream`` clone is allocated for the