diff --git a/tractor/ipc/_ringbuf/_pubsub.py b/tractor/ipc/_ringbuf/_pubsub.py index 94ce7460..50c48366 100644 --- a/tractor/ipc/_ringbuf/_pubsub.py +++ b/tractor/ipc/_ringbuf/_pubsub.py @@ -18,9 +18,12 @@ Ring buffer ipc publish-subscribe mechanism brokered by ringd can dynamically add new outputs (publisher) or inputs (subscriber) ''' import time -from abc import ( - ABC, - abstractmethod +from typing import ( + runtime_checkable, + Protocol, + TypeVar, + Self, + AsyncContextManager ) from contextlib import asynccontextmanager as acm from dataclasses import dataclass @@ -30,6 +33,7 @@ import tractor from tractor.ipc import ( RingBuffBytesSender, + RingBuffBytesReceiver, attach_to_ringbuf_schannel, attach_to_ringbuf_rchannel ) @@ -40,28 +44,72 @@ import tractor.ipc._ringbuf._ringd as ringd log = tractor.log.get_logger(__name__) +ChannelType = TypeVar('ChannelType') + + @dataclass class ChannelInfo: connect_time: float name: str - channel: RingBuffBytesSender + channel: ChannelType cancel_scope: trio.CancelScope -class ChannelManager(ABC): +# TODO: maybe move this abstraction to another module or standalone? +# its not ring buf specific and allows fan out and fan in an a dynamic +# amount of channels +@runtime_checkable +class ChannelManager(Protocol[ChannelType]): + ''' + Common data structures and methods pubsub classes use to manage channels & + their related handler background tasks, as well as cancellation of them. + + ''' def __init__( self, n: trio.Nursery, ): self._n = n - self._channels: list[ChannelInfo] = [] + self._channels: list[Self.ChannelInfo] = [] - @abstractmethod - async def _channel_handler_task(self, name: str): + async def _open_channel( + self, + name: str + ) -> AsyncContextManager[ChannelType]: + ''' + Used to instantiate channel resources given a name + + ''' ... + async def _channel_task(self, info: ChannelInfo) -> None: + ''' + Long running task that manages the channel + + ''' + ... + + async def _channel_handler_task(self, name: str): + async with self._open_channel(name) as chan: + with trio.CancelScope() as cancel_scope: + info = Self.ChannelInfo( + connect_time=time.time(), + name=name, + channel=chan, + cancel_scope=cancel_scope + ) + self._channels.append(info) + await self._channel_task(info) + + self._maybe_destroy_channel(name) + def find_channel(self, name: str) -> tuple[int, ChannelInfo] | None: + ''' + Given a channel name maybe return its index and value from + internal _channels list. + + ''' for entry in enumerate(self._channels): i, info = entry if info.name == name: @@ -70,6 +118,11 @@ class ChannelManager(ABC): return None def _maybe_destroy_channel(self, name: str): + ''' + If channel exists cancel its scope and remove from internal + _channels list. + + ''' maybe_entry = self.find_channel(name) if maybe_entry: i, info = maybe_entry @@ -77,12 +130,20 @@ class ChannelManager(ABC): del self._channels[i] def add_channel(self, name: str): + ''' + Add a new channel to be handled + + ''' self._n.start_soon( self._channel_handler_task, name ) def remove_channel(self, name: str): + ''' + Remove a channel and stop its handling + + ''' self._maybe_destroy_channel(name) def __len__(self) -> int: @@ -92,8 +153,24 @@ class ChannelManager(ABC): for chan in self._channels: self._maybe_destroy_channel(chan.name) + async def __aenter__(self): + return self -class RingBuffPublisher(ChannelManager, trio.abc.SendChannel[bytes]): + async def __aexit__(self, exc_type, exc_val, exc_tb): + await self.aclose() + + +class RingBuffPublisher( + ChannelManager[RingBuffBytesSender] +): + ''' + Implement ChannelManager protocol + trio.abc.SendChannel[bytes] + using ring buffers as transport. + + - use a `trio.Event` to make sure `send` blocks until at least one channel + available. + + ''' def __init__( self, @@ -107,28 +184,23 @@ class RingBuffPublisher(ChannelManager, trio.abc.SendChannel[bytes]): self._batch_size: int = batch_size - async def _channel_handler_task( + @acm + async def _open_channel( self, name: str - ): + ) -> AsyncContextManager[RingBuffBytesSender]: async with ( ringd.open_ringbuf( name=name, must_exist=True, ) as token, - attach_to_ringbuf_schannel(token) as schan + attach_to_ringbuf_schannel(token) as chan ): - with trio.CancelScope() as cancel_scope: - self._channels.append(ChannelInfo( - connect_time=time.time(), - name=name, - channel=schan, - cancel_scope=cancel_scope - )) - self._connect_event.set() - await trio.sleep_forever() + yield chan - self._maybe_destroy_channel(name) + async def _channel_task(self, info: Self.ChannelInfo) -> None: + self._connect_event.set() + await trio.sleep_forever() async def send(self, msg: bytes): # wait at least one decoder connected @@ -182,11 +254,21 @@ async def open_ringbuf_publisher( ) as outputs ): yield outputs - await outputs.aclose() -class RingBuffSubscriber(ChannelManager, trio.abc.ReceiveChannel[bytes]): +class RingBuffSubscriber( + ChannelManager[RingBuffBytesReceiver] +): + ''' + Implement ChannelManager protocol + trio.abc.ReceiveChannel[bytes] + using ring buffers as transport. + + - use a trio memory channel pair to multiplex all received messages into a + single `trio.MemoryReceiveChannel`, give a sender channel clone to each + _channel_task. + + ''' def __init__( self, n: trio.Nursery, @@ -194,34 +276,29 @@ class RingBuffSubscriber(ChannelManager, trio.abc.ReceiveChannel[bytes]): super().__init__(n) self._send_chan, self._recv_chan = trio.open_memory_channel(0) - async def _channel_handler_task( + @acm + async def _open_channel( self, name: str - ): + ) -> AsyncContextManager[RingBuffBytesReceiver]: async with ( ringd.open_ringbuf( name=name, - must_exist=True + must_exist=True, ) as token, - - attach_to_ringbuf_rchannel(token) as rchan + attach_to_ringbuf_rchannel(token) as chan ): - with trio.CancelScope() as cancel_scope: - self._channels.append(ChannelInfo( - connect_time=time.time(), - name=name, - channel=rchan, - cancel_scope=cancel_scope - )) - send_chan = self._send_chan.clone() - try: - async for msg in rchan: - await send_chan.send(msg) + yield chan - except tractor._exceptions.InternalError: - ... + async def _channel_task(self, info: ChannelInfo) -> None: + send_chan = self._send_chan.clone() + try: + async for msg in info.channel: + await send_chan.send(msg) - self._maybe_destroy_channel(name) + except tractor._exceptions.InternalError: + # TODO: cleaner cancellation! + ... async def receive(self) -> bytes: return await self._recv_chan.receive() @@ -234,5 +311,4 @@ async def open_ringbuf_subscriber(): RingBuffSubscriber(n) as inputs ): yield inputs - await inputs.aclose()