Switch to using typing.Protocl instead of abc.ABC on ChannelManager, improve abstraction and add comments

one_ring_to_rule_them_all
Guillermo Rodriguez 2025-04-03 12:34:40 -03:00
parent 4b9d6b9276
commit b1e1187a19
No known key found for this signature in database
GPG Key ID: 002CC5F1E6BDA53E
1 changed files with 120 additions and 44 deletions

View File

@ -18,9 +18,12 @@ Ring buffer ipc publish-subscribe mechanism brokered by ringd
can dynamically add new outputs (publisher) or inputs (subscriber)
'''
import time
from abc import (
ABC,
abstractmethod
from typing import (
runtime_checkable,
Protocol,
TypeVar,
Self,
AsyncContextManager
)
from contextlib import asynccontextmanager as acm
from dataclasses import dataclass
@ -30,6 +33,7 @@ import tractor
from tractor.ipc import (
RingBuffBytesSender,
RingBuffBytesReceiver,
attach_to_ringbuf_schannel,
attach_to_ringbuf_rchannel
)
@ -40,28 +44,72 @@ import tractor.ipc._ringbuf._ringd as ringd
log = tractor.log.get_logger(__name__)
ChannelType = TypeVar('ChannelType')
@dataclass
class ChannelInfo:
connect_time: float
name: str
channel: RingBuffBytesSender
channel: ChannelType
cancel_scope: trio.CancelScope
class ChannelManager(ABC):
# TODO: maybe move this abstraction to another module or standalone?
# its not ring buf specific and allows fan out and fan in an a dynamic
# amount of channels
@runtime_checkable
class ChannelManager(Protocol[ChannelType]):
'''
Common data structures and methods pubsub classes use to manage channels &
their related handler background tasks, as well as cancellation of them.
'''
def __init__(
self,
n: trio.Nursery,
):
self._n = n
self._channels: list[ChannelInfo] = []
self._channels: list[Self.ChannelInfo] = []
@abstractmethod
async def _channel_handler_task(self, name: str):
async def _open_channel(
self,
name: str
) -> AsyncContextManager[ChannelType]:
'''
Used to instantiate channel resources given a name
'''
...
async def _channel_task(self, info: ChannelInfo) -> None:
'''
Long running task that manages the channel
'''
...
async def _channel_handler_task(self, name: str):
async with self._open_channel(name) as chan:
with trio.CancelScope() as cancel_scope:
info = Self.ChannelInfo(
connect_time=time.time(),
name=name,
channel=chan,
cancel_scope=cancel_scope
)
self._channels.append(info)
await self._channel_task(info)
self._maybe_destroy_channel(name)
def find_channel(self, name: str) -> tuple[int, ChannelInfo] | None:
'''
Given a channel name maybe return its index and value from
internal _channels list.
'''
for entry in enumerate(self._channels):
i, info = entry
if info.name == name:
@ -70,6 +118,11 @@ class ChannelManager(ABC):
return None
def _maybe_destroy_channel(self, name: str):
'''
If channel exists cancel its scope and remove from internal
_channels list.
'''
maybe_entry = self.find_channel(name)
if maybe_entry:
i, info = maybe_entry
@ -77,12 +130,20 @@ class ChannelManager(ABC):
del self._channels[i]
def add_channel(self, name: str):
'''
Add a new channel to be handled
'''
self._n.start_soon(
self._channel_handler_task,
name
)
def remove_channel(self, name: str):
'''
Remove a channel and stop its handling
'''
self._maybe_destroy_channel(name)
def __len__(self) -> int:
@ -92,8 +153,24 @@ class ChannelManager(ABC):
for chan in self._channels:
self._maybe_destroy_channel(chan.name)
async def __aenter__(self):
return self
class RingBuffPublisher(ChannelManager, trio.abc.SendChannel[bytes]):
async def __aexit__(self, exc_type, exc_val, exc_tb):
await self.aclose()
class RingBuffPublisher(
ChannelManager[RingBuffBytesSender]
):
'''
Implement ChannelManager protocol + trio.abc.SendChannel[bytes]
using ring buffers as transport.
- use a `trio.Event` to make sure `send` blocks until at least one channel
available.
'''
def __init__(
self,
@ -107,29 +184,24 @@ class RingBuffPublisher(ChannelManager, trio.abc.SendChannel[bytes]):
self._batch_size: int = batch_size
async def _channel_handler_task(
@acm
async def _open_channel(
self,
name: str
):
) -> AsyncContextManager[RingBuffBytesSender]:
async with (
ringd.open_ringbuf(
name=name,
must_exist=True,
) as token,
attach_to_ringbuf_schannel(token) as schan
attach_to_ringbuf_schannel(token) as chan
):
with trio.CancelScope() as cancel_scope:
self._channels.append(ChannelInfo(
connect_time=time.time(),
name=name,
channel=schan,
cancel_scope=cancel_scope
))
yield chan
async def _channel_task(self, info: Self.ChannelInfo) -> None:
self._connect_event.set()
await trio.sleep_forever()
self._maybe_destroy_channel(name)
async def send(self, msg: bytes):
# wait at least one decoder connected
if len(self) == 0:
@ -182,11 +254,21 @@ async def open_ringbuf_publisher(
) as outputs
):
yield outputs
await outputs.aclose()
class RingBuffSubscriber(ChannelManager, trio.abc.ReceiveChannel[bytes]):
class RingBuffSubscriber(
ChannelManager[RingBuffBytesReceiver]
):
'''
Implement ChannelManager protocol + trio.abc.ReceiveChannel[bytes]
using ring buffers as transport.
- use a trio memory channel pair to multiplex all received messages into a
single `trio.MemoryReceiveChannel`, give a sender channel clone to each
_channel_task.
'''
def __init__(
self,
n: trio.Nursery,
@ -194,35 +276,30 @@ class RingBuffSubscriber(ChannelManager, trio.abc.ReceiveChannel[bytes]):
super().__init__(n)
self._send_chan, self._recv_chan = trio.open_memory_channel(0)
async def _channel_handler_task(
@acm
async def _open_channel(
self,
name: str
):
) -> AsyncContextManager[RingBuffBytesReceiver]:
async with (
ringd.open_ringbuf(
name=name,
must_exist=True
must_exist=True,
) as token,
attach_to_ringbuf_rchannel(token) as rchan
attach_to_ringbuf_rchannel(token) as chan
):
with trio.CancelScope() as cancel_scope:
self._channels.append(ChannelInfo(
connect_time=time.time(),
name=name,
channel=rchan,
cancel_scope=cancel_scope
))
yield chan
async def _channel_task(self, info: ChannelInfo) -> None:
send_chan = self._send_chan.clone()
try:
async for msg in rchan:
async for msg in info.channel:
await send_chan.send(msg)
except tractor._exceptions.InternalError:
# TODO: cleaner cancellation!
...
self._maybe_destroy_channel(name)
async def receive(self) -> bytes:
return await self._recv_chan.receive()
@ -234,5 +311,4 @@ async def open_ringbuf_subscriber():
RingBuffSubscriber(n) as inputs
):
yield inputs
await inputs.aclose()