Switch to using typing.Protocl instead of abc.ABC on ChannelManager, improve abstraction and add comments

one_ring_to_rule_them_all
Guillermo Rodriguez 2025-04-03 12:34:40 -03:00
parent 4b9d6b9276
commit b1e1187a19
No known key found for this signature in database
GPG Key ID: 002CC5F1E6BDA53E
1 changed files with 120 additions and 44 deletions

View File

@ -18,9 +18,12 @@ Ring buffer ipc publish-subscribe mechanism brokered by ringd
can dynamically add new outputs (publisher) or inputs (subscriber) can dynamically add new outputs (publisher) or inputs (subscriber)
''' '''
import time import time
from abc import ( from typing import (
ABC, runtime_checkable,
abstractmethod Protocol,
TypeVar,
Self,
AsyncContextManager
) )
from contextlib import asynccontextmanager as acm from contextlib import asynccontextmanager as acm
from dataclasses import dataclass from dataclasses import dataclass
@ -30,6 +33,7 @@ import tractor
from tractor.ipc import ( from tractor.ipc import (
RingBuffBytesSender, RingBuffBytesSender,
RingBuffBytesReceiver,
attach_to_ringbuf_schannel, attach_to_ringbuf_schannel,
attach_to_ringbuf_rchannel attach_to_ringbuf_rchannel
) )
@ -40,28 +44,72 @@ import tractor.ipc._ringbuf._ringd as ringd
log = tractor.log.get_logger(__name__) log = tractor.log.get_logger(__name__)
ChannelType = TypeVar('ChannelType')
@dataclass @dataclass
class ChannelInfo: class ChannelInfo:
connect_time: float connect_time: float
name: str name: str
channel: RingBuffBytesSender channel: ChannelType
cancel_scope: trio.CancelScope cancel_scope: trio.CancelScope
class ChannelManager(ABC): # TODO: maybe move this abstraction to another module or standalone?
# its not ring buf specific and allows fan out and fan in an a dynamic
# amount of channels
@runtime_checkable
class ChannelManager(Protocol[ChannelType]):
'''
Common data structures and methods pubsub classes use to manage channels &
their related handler background tasks, as well as cancellation of them.
'''
def __init__( def __init__(
self, self,
n: trio.Nursery, n: trio.Nursery,
): ):
self._n = n self._n = n
self._channels: list[ChannelInfo] = [] self._channels: list[Self.ChannelInfo] = []
@abstractmethod async def _open_channel(
async def _channel_handler_task(self, name: str): self,
name: str
) -> AsyncContextManager[ChannelType]:
'''
Used to instantiate channel resources given a name
'''
... ...
async def _channel_task(self, info: ChannelInfo) -> None:
'''
Long running task that manages the channel
'''
...
async def _channel_handler_task(self, name: str):
async with self._open_channel(name) as chan:
with trio.CancelScope() as cancel_scope:
info = Self.ChannelInfo(
connect_time=time.time(),
name=name,
channel=chan,
cancel_scope=cancel_scope
)
self._channels.append(info)
await self._channel_task(info)
self._maybe_destroy_channel(name)
def find_channel(self, name: str) -> tuple[int, ChannelInfo] | None: def find_channel(self, name: str) -> tuple[int, ChannelInfo] | None:
'''
Given a channel name maybe return its index and value from
internal _channels list.
'''
for entry in enumerate(self._channels): for entry in enumerate(self._channels):
i, info = entry i, info = entry
if info.name == name: if info.name == name:
@ -70,6 +118,11 @@ class ChannelManager(ABC):
return None return None
def _maybe_destroy_channel(self, name: str): def _maybe_destroy_channel(self, name: str):
'''
If channel exists cancel its scope and remove from internal
_channels list.
'''
maybe_entry = self.find_channel(name) maybe_entry = self.find_channel(name)
if maybe_entry: if maybe_entry:
i, info = maybe_entry i, info = maybe_entry
@ -77,12 +130,20 @@ class ChannelManager(ABC):
del self._channels[i] del self._channels[i]
def add_channel(self, name: str): def add_channel(self, name: str):
'''
Add a new channel to be handled
'''
self._n.start_soon( self._n.start_soon(
self._channel_handler_task, self._channel_handler_task,
name name
) )
def remove_channel(self, name: str): def remove_channel(self, name: str):
'''
Remove a channel and stop its handling
'''
self._maybe_destroy_channel(name) self._maybe_destroy_channel(name)
def __len__(self) -> int: def __len__(self) -> int:
@ -92,8 +153,24 @@ class ChannelManager(ABC):
for chan in self._channels: for chan in self._channels:
self._maybe_destroy_channel(chan.name) self._maybe_destroy_channel(chan.name)
async def __aenter__(self):
return self
class RingBuffPublisher(ChannelManager, trio.abc.SendChannel[bytes]): async def __aexit__(self, exc_type, exc_val, exc_tb):
await self.aclose()
class RingBuffPublisher(
ChannelManager[RingBuffBytesSender]
):
'''
Implement ChannelManager protocol + trio.abc.SendChannel[bytes]
using ring buffers as transport.
- use a `trio.Event` to make sure `send` blocks until at least one channel
available.
'''
def __init__( def __init__(
self, self,
@ -107,29 +184,24 @@ class RingBuffPublisher(ChannelManager, trio.abc.SendChannel[bytes]):
self._batch_size: int = batch_size self._batch_size: int = batch_size
async def _channel_handler_task( @acm
async def _open_channel(
self, self,
name: str name: str
): ) -> AsyncContextManager[RingBuffBytesSender]:
async with ( async with (
ringd.open_ringbuf( ringd.open_ringbuf(
name=name, name=name,
must_exist=True, must_exist=True,
) as token, ) as token,
attach_to_ringbuf_schannel(token) as schan attach_to_ringbuf_schannel(token) as chan
): ):
with trio.CancelScope() as cancel_scope: yield chan
self._channels.append(ChannelInfo(
connect_time=time.time(), async def _channel_task(self, info: Self.ChannelInfo) -> None:
name=name,
channel=schan,
cancel_scope=cancel_scope
))
self._connect_event.set() self._connect_event.set()
await trio.sleep_forever() await trio.sleep_forever()
self._maybe_destroy_channel(name)
async def send(self, msg: bytes): async def send(self, msg: bytes):
# wait at least one decoder connected # wait at least one decoder connected
if len(self) == 0: if len(self) == 0:
@ -182,11 +254,21 @@ async def open_ringbuf_publisher(
) as outputs ) as outputs
): ):
yield outputs yield outputs
await outputs.aclose()
class RingBuffSubscriber(ChannelManager, trio.abc.ReceiveChannel[bytes]): class RingBuffSubscriber(
ChannelManager[RingBuffBytesReceiver]
):
'''
Implement ChannelManager protocol + trio.abc.ReceiveChannel[bytes]
using ring buffers as transport.
- use a trio memory channel pair to multiplex all received messages into a
single `trio.MemoryReceiveChannel`, give a sender channel clone to each
_channel_task.
'''
def __init__( def __init__(
self, self,
n: trio.Nursery, n: trio.Nursery,
@ -194,35 +276,30 @@ class RingBuffSubscriber(ChannelManager, trio.abc.ReceiveChannel[bytes]):
super().__init__(n) super().__init__(n)
self._send_chan, self._recv_chan = trio.open_memory_channel(0) self._send_chan, self._recv_chan = trio.open_memory_channel(0)
async def _channel_handler_task( @acm
async def _open_channel(
self, self,
name: str name: str
): ) -> AsyncContextManager[RingBuffBytesReceiver]:
async with ( async with (
ringd.open_ringbuf( ringd.open_ringbuf(
name=name, name=name,
must_exist=True must_exist=True,
) as token, ) as token,
attach_to_ringbuf_rchannel(token) as chan
attach_to_ringbuf_rchannel(token) as rchan
): ):
with trio.CancelScope() as cancel_scope: yield chan
self._channels.append(ChannelInfo(
connect_time=time.time(), async def _channel_task(self, info: ChannelInfo) -> None:
name=name,
channel=rchan,
cancel_scope=cancel_scope
))
send_chan = self._send_chan.clone() send_chan = self._send_chan.clone()
try: try:
async for msg in rchan: async for msg in info.channel:
await send_chan.send(msg) await send_chan.send(msg)
except tractor._exceptions.InternalError: except tractor._exceptions.InternalError:
# TODO: cleaner cancellation!
... ...
self._maybe_destroy_channel(name)
async def receive(self) -> bytes: async def receive(self) -> bytes:
return await self._recv_chan.receive() return await self._recv_chan.receive()
@ -234,5 +311,4 @@ async def open_ringbuf_subscriber():
RingBuffSubscriber(n) as inputs RingBuffSubscriber(n) as inputs
): ):
yield inputs yield inputs
await inputs.aclose()