Switch to using typing.Protocl instead of abc.ABC on ChannelManager, improve abstraction and add comments
parent
4b9d6b9276
commit
b1e1187a19
|
@ -18,9 +18,12 @@ Ring buffer ipc publish-subscribe mechanism brokered by ringd
|
||||||
can dynamically add new outputs (publisher) or inputs (subscriber)
|
can dynamically add new outputs (publisher) or inputs (subscriber)
|
||||||
'''
|
'''
|
||||||
import time
|
import time
|
||||||
from abc import (
|
from typing import (
|
||||||
ABC,
|
runtime_checkable,
|
||||||
abstractmethod
|
Protocol,
|
||||||
|
TypeVar,
|
||||||
|
Self,
|
||||||
|
AsyncContextManager
|
||||||
)
|
)
|
||||||
from contextlib import asynccontextmanager as acm
|
from contextlib import asynccontextmanager as acm
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
@ -30,6 +33,7 @@ import tractor
|
||||||
|
|
||||||
from tractor.ipc import (
|
from tractor.ipc import (
|
||||||
RingBuffBytesSender,
|
RingBuffBytesSender,
|
||||||
|
RingBuffBytesReceiver,
|
||||||
attach_to_ringbuf_schannel,
|
attach_to_ringbuf_schannel,
|
||||||
attach_to_ringbuf_rchannel
|
attach_to_ringbuf_rchannel
|
||||||
)
|
)
|
||||||
|
@ -40,28 +44,72 @@ import tractor.ipc._ringbuf._ringd as ringd
|
||||||
log = tractor.log.get_logger(__name__)
|
log = tractor.log.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
ChannelType = TypeVar('ChannelType')
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ChannelInfo:
|
class ChannelInfo:
|
||||||
connect_time: float
|
connect_time: float
|
||||||
name: str
|
name: str
|
||||||
channel: RingBuffBytesSender
|
channel: ChannelType
|
||||||
cancel_scope: trio.CancelScope
|
cancel_scope: trio.CancelScope
|
||||||
|
|
||||||
|
|
||||||
class ChannelManager(ABC):
|
# TODO: maybe move this abstraction to another module or standalone?
|
||||||
|
# its not ring buf specific and allows fan out and fan in an a dynamic
|
||||||
|
# amount of channels
|
||||||
|
@runtime_checkable
|
||||||
|
class ChannelManager(Protocol[ChannelType]):
|
||||||
|
'''
|
||||||
|
Common data structures and methods pubsub classes use to manage channels &
|
||||||
|
their related handler background tasks, as well as cancellation of them.
|
||||||
|
|
||||||
|
'''
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
n: trio.Nursery,
|
n: trio.Nursery,
|
||||||
):
|
):
|
||||||
self._n = n
|
self._n = n
|
||||||
self._channels: list[ChannelInfo] = []
|
self._channels: list[Self.ChannelInfo] = []
|
||||||
|
|
||||||
@abstractmethod
|
async def _open_channel(
|
||||||
async def _channel_handler_task(self, name: str):
|
self,
|
||||||
|
name: str
|
||||||
|
) -> AsyncContextManager[ChannelType]:
|
||||||
|
'''
|
||||||
|
Used to instantiate channel resources given a name
|
||||||
|
|
||||||
|
'''
|
||||||
...
|
...
|
||||||
|
|
||||||
|
async def _channel_task(self, info: ChannelInfo) -> None:
|
||||||
|
'''
|
||||||
|
Long running task that manages the channel
|
||||||
|
|
||||||
|
'''
|
||||||
|
...
|
||||||
|
|
||||||
|
async def _channel_handler_task(self, name: str):
|
||||||
|
async with self._open_channel(name) as chan:
|
||||||
|
with trio.CancelScope() as cancel_scope:
|
||||||
|
info = Self.ChannelInfo(
|
||||||
|
connect_time=time.time(),
|
||||||
|
name=name,
|
||||||
|
channel=chan,
|
||||||
|
cancel_scope=cancel_scope
|
||||||
|
)
|
||||||
|
self._channels.append(info)
|
||||||
|
await self._channel_task(info)
|
||||||
|
|
||||||
|
self._maybe_destroy_channel(name)
|
||||||
|
|
||||||
def find_channel(self, name: str) -> tuple[int, ChannelInfo] | None:
|
def find_channel(self, name: str) -> tuple[int, ChannelInfo] | None:
|
||||||
|
'''
|
||||||
|
Given a channel name maybe return its index and value from
|
||||||
|
internal _channels list.
|
||||||
|
|
||||||
|
'''
|
||||||
for entry in enumerate(self._channels):
|
for entry in enumerate(self._channels):
|
||||||
i, info = entry
|
i, info = entry
|
||||||
if info.name == name:
|
if info.name == name:
|
||||||
|
@ -70,6 +118,11 @@ class ChannelManager(ABC):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def _maybe_destroy_channel(self, name: str):
|
def _maybe_destroy_channel(self, name: str):
|
||||||
|
'''
|
||||||
|
If channel exists cancel its scope and remove from internal
|
||||||
|
_channels list.
|
||||||
|
|
||||||
|
'''
|
||||||
maybe_entry = self.find_channel(name)
|
maybe_entry = self.find_channel(name)
|
||||||
if maybe_entry:
|
if maybe_entry:
|
||||||
i, info = maybe_entry
|
i, info = maybe_entry
|
||||||
|
@ -77,12 +130,20 @@ class ChannelManager(ABC):
|
||||||
del self._channels[i]
|
del self._channels[i]
|
||||||
|
|
||||||
def add_channel(self, name: str):
|
def add_channel(self, name: str):
|
||||||
|
'''
|
||||||
|
Add a new channel to be handled
|
||||||
|
|
||||||
|
'''
|
||||||
self._n.start_soon(
|
self._n.start_soon(
|
||||||
self._channel_handler_task,
|
self._channel_handler_task,
|
||||||
name
|
name
|
||||||
)
|
)
|
||||||
|
|
||||||
def remove_channel(self, name: str):
|
def remove_channel(self, name: str):
|
||||||
|
'''
|
||||||
|
Remove a channel and stop its handling
|
||||||
|
|
||||||
|
'''
|
||||||
self._maybe_destroy_channel(name)
|
self._maybe_destroy_channel(name)
|
||||||
|
|
||||||
def __len__(self) -> int:
|
def __len__(self) -> int:
|
||||||
|
@ -92,8 +153,24 @@ class ChannelManager(ABC):
|
||||||
for chan in self._channels:
|
for chan in self._channels:
|
||||||
self._maybe_destroy_channel(chan.name)
|
self._maybe_destroy_channel(chan.name)
|
||||||
|
|
||||||
|
async def __aenter__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
class RingBuffPublisher(ChannelManager, trio.abc.SendChannel[bytes]):
|
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||||
|
await self.aclose()
|
||||||
|
|
||||||
|
|
||||||
|
class RingBuffPublisher(
|
||||||
|
ChannelManager[RingBuffBytesSender]
|
||||||
|
):
|
||||||
|
'''
|
||||||
|
Implement ChannelManager protocol + trio.abc.SendChannel[bytes]
|
||||||
|
using ring buffers as transport.
|
||||||
|
|
||||||
|
- use a `trio.Event` to make sure `send` blocks until at least one channel
|
||||||
|
available.
|
||||||
|
|
||||||
|
'''
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
@ -107,29 +184,24 @@ class RingBuffPublisher(ChannelManager, trio.abc.SendChannel[bytes]):
|
||||||
|
|
||||||
self._batch_size: int = batch_size
|
self._batch_size: int = batch_size
|
||||||
|
|
||||||
async def _channel_handler_task(
|
@acm
|
||||||
|
async def _open_channel(
|
||||||
self,
|
self,
|
||||||
name: str
|
name: str
|
||||||
):
|
) -> AsyncContextManager[RingBuffBytesSender]:
|
||||||
async with (
|
async with (
|
||||||
ringd.open_ringbuf(
|
ringd.open_ringbuf(
|
||||||
name=name,
|
name=name,
|
||||||
must_exist=True,
|
must_exist=True,
|
||||||
) as token,
|
) as token,
|
||||||
attach_to_ringbuf_schannel(token) as schan
|
attach_to_ringbuf_schannel(token) as chan
|
||||||
):
|
):
|
||||||
with trio.CancelScope() as cancel_scope:
|
yield chan
|
||||||
self._channels.append(ChannelInfo(
|
|
||||||
connect_time=time.time(),
|
async def _channel_task(self, info: Self.ChannelInfo) -> None:
|
||||||
name=name,
|
|
||||||
channel=schan,
|
|
||||||
cancel_scope=cancel_scope
|
|
||||||
))
|
|
||||||
self._connect_event.set()
|
self._connect_event.set()
|
||||||
await trio.sleep_forever()
|
await trio.sleep_forever()
|
||||||
|
|
||||||
self._maybe_destroy_channel(name)
|
|
||||||
|
|
||||||
async def send(self, msg: bytes):
|
async def send(self, msg: bytes):
|
||||||
# wait at least one decoder connected
|
# wait at least one decoder connected
|
||||||
if len(self) == 0:
|
if len(self) == 0:
|
||||||
|
@ -182,11 +254,21 @@ async def open_ringbuf_publisher(
|
||||||
) as outputs
|
) as outputs
|
||||||
):
|
):
|
||||||
yield outputs
|
yield outputs
|
||||||
await outputs.aclose()
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class RingBuffSubscriber(ChannelManager, trio.abc.ReceiveChannel[bytes]):
|
class RingBuffSubscriber(
|
||||||
|
ChannelManager[RingBuffBytesReceiver]
|
||||||
|
):
|
||||||
|
'''
|
||||||
|
Implement ChannelManager protocol + trio.abc.ReceiveChannel[bytes]
|
||||||
|
using ring buffers as transport.
|
||||||
|
|
||||||
|
- use a trio memory channel pair to multiplex all received messages into a
|
||||||
|
single `trio.MemoryReceiveChannel`, give a sender channel clone to each
|
||||||
|
_channel_task.
|
||||||
|
|
||||||
|
'''
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
n: trio.Nursery,
|
n: trio.Nursery,
|
||||||
|
@ -194,35 +276,30 @@ class RingBuffSubscriber(ChannelManager, trio.abc.ReceiveChannel[bytes]):
|
||||||
super().__init__(n)
|
super().__init__(n)
|
||||||
self._send_chan, self._recv_chan = trio.open_memory_channel(0)
|
self._send_chan, self._recv_chan = trio.open_memory_channel(0)
|
||||||
|
|
||||||
async def _channel_handler_task(
|
@acm
|
||||||
|
async def _open_channel(
|
||||||
self,
|
self,
|
||||||
name: str
|
name: str
|
||||||
):
|
) -> AsyncContextManager[RingBuffBytesReceiver]:
|
||||||
async with (
|
async with (
|
||||||
ringd.open_ringbuf(
|
ringd.open_ringbuf(
|
||||||
name=name,
|
name=name,
|
||||||
must_exist=True
|
must_exist=True,
|
||||||
) as token,
|
) as token,
|
||||||
|
attach_to_ringbuf_rchannel(token) as chan
|
||||||
attach_to_ringbuf_rchannel(token) as rchan
|
|
||||||
):
|
):
|
||||||
with trio.CancelScope() as cancel_scope:
|
yield chan
|
||||||
self._channels.append(ChannelInfo(
|
|
||||||
connect_time=time.time(),
|
async def _channel_task(self, info: ChannelInfo) -> None:
|
||||||
name=name,
|
|
||||||
channel=rchan,
|
|
||||||
cancel_scope=cancel_scope
|
|
||||||
))
|
|
||||||
send_chan = self._send_chan.clone()
|
send_chan = self._send_chan.clone()
|
||||||
try:
|
try:
|
||||||
async for msg in rchan:
|
async for msg in info.channel:
|
||||||
await send_chan.send(msg)
|
await send_chan.send(msg)
|
||||||
|
|
||||||
except tractor._exceptions.InternalError:
|
except tractor._exceptions.InternalError:
|
||||||
|
# TODO: cleaner cancellation!
|
||||||
...
|
...
|
||||||
|
|
||||||
self._maybe_destroy_channel(name)
|
|
||||||
|
|
||||||
async def receive(self) -> bytes:
|
async def receive(self) -> bytes:
|
||||||
return await self._recv_chan.receive()
|
return await self._recv_chan.receive()
|
||||||
|
|
||||||
|
@ -234,5 +311,4 @@ async def open_ringbuf_subscriber():
|
||||||
RingBuffSubscriber(n) as inputs
|
RingBuffSubscriber(n) as inputs
|
||||||
):
|
):
|
||||||
yield inputs
|
yield inputs
|
||||||
await inputs.aclose()
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue