Fully test and fix bugs on _ringbuf._pubsub
Add generic channel orderer (branch: one_ring_to_rule_them_all)

parent bebd327023
commit 1dfc639e54
@@ -92,187 +92,187 @@ def test_ringd():
     trio.run(main)


-# class Struct(msgspec.Struct):
-#
-# def encode(self) -> bytes:
-# return msgspec.msgpack.encode(self)
-#
-#
-# class AddChannelMsg(Struct, frozen=True, tag=True):
-# name: str
-#
-#
-# class RemoveChannelMsg(Struct, frozen=True, tag=True):
-# name: str
-#
-#
-# class RangeMsg(Struct, frozen=True, tag=True):
-# start: int
-# end: int
-#
-#
-# ControlMessages = AddChannelMsg | RemoveChannelMsg | RangeMsg
-#
-#
-# @tractor.context
-# async def subscriber_child(ctx: tractor.Context):
-# await ctx.started()
-# async with (
-# open_ringbuf_subscriber(guarantee_order=True) as subs,
-# trio.open_nursery() as n,
-# ctx.open_stream() as stream
-# ):
-# range_msg = None
-# range_event = trio.Event()
-# range_scope = trio.CancelScope()
-#
-# async def _control_listen_task():
-# nonlocal range_msg, range_event
-# async for msg in stream:
-# msg = msgspec.msgpack.decode(msg, type=ControlMessages)
-# match msg:
-# case AddChannelMsg():
-# await subs.add_channel(msg.name, must_exist=False)
-#
-# case RemoveChannelMsg():
-# await subs.remove_channel(msg.name)
-#
-# case RangeMsg():
-# range_msg = msg
-# range_event.set()
-#
-# await stream.send(b'ack')
-#
-# range_scope.cancel()
-#
-# n.start_soon(_control_listen_task)
-#
-# with range_scope:
-# while True:
-# await range_event.wait()
-# range_event = trio.Event()
-# for i in range(range_msg.start, range_msg.end):
-# recv = int.from_bytes(await subs.receive())
-# # if recv != i:
-# # raise AssertionError(
-# # f'received: {recv} expected: {i}'
-# # )
-#
-# log.info(f'received: {recv} expected: {i}')
-#
-# await stream.send(b'valid range')
-# log.info('FINISHED RANGE')
-#
-# log.info('subscriber exit')
-#
-#
-# @tractor.context
-# async def publisher_child(ctx: tractor.Context):
-# await ctx.started()
-# async with (
-# open_ringbuf_publisher(batch_size=1, guarantee_order=True) as pub,
-# ctx.open_stream() as stream
-# ):
-# abs_index = 0
-# async for msg in stream:
-# msg = msgspec.msgpack.decode(msg, type=ControlMessages)
-# match msg:
-# case AddChannelMsg():
-# await pub.add_channel(msg.name, must_exist=True)
-#
-# case RemoveChannelMsg():
-# await pub.remove_channel(msg.name)
-#
-# case RangeMsg():
-# for i in range(msg.start, msg.end):
-# await pub.send(i.to_bytes(4))
-# log.info(f'sent {i}, index: {abs_index}')
-# abs_index += 1
-#
-# await stream.send(b'ack')
-#
-# log.info('publisher exit')
-#
-#
-#
-# def test_pubsub():
-# '''
-# Spawn ringd actor and two childs that access same ringbuf through ringd.
-#
-# Both will use `ringd.open_ringbuf` to allocate the ringbuf, then attach to
-# them as sender and receiver.
-#
-# '''
-# async def main():
-# async with (
-# tractor.open_nursery(
-# loglevel='info',
-# # debug_mode=True,
-# # enable_stack_on_sig=True
-# ) as an,
-#
-# ringd.open_ringd()
-# ):
-# recv_portal = await an.start_actor(
-# 'recv',
-# enable_modules=[__name__]
-# )
-# send_portal = await an.start_actor(
-# 'send',
-# enable_modules=[__name__]
-# )
-#
-# async with (
-# recv_portal.open_context(subscriber_child) as (rctx, _),
-# rctx.open_stream() as recv_stream,
-# send_portal.open_context(publisher_child) as (sctx, _),
-# sctx.open_stream() as send_stream,
-# ):
-# async def send_wait_ack(msg: bytes):
-# await recv_stream.send(msg)
-# ack = await recv_stream.receive()
-# assert ack == b'ack'
-#
-# await send_stream.send(msg)
-# ack = await send_stream.receive()
-# assert ack == b'ack'
-#
-# async def add_channel(name: str):
-# await send_wait_ack(AddChannelMsg(name=name).encode())
-#
-# async def remove_channel(name: str):
-# await send_wait_ack(RemoveChannelMsg(name=name).encode())
-#
-# async def send_range(start: int, end: int):
-# await send_wait_ack(RangeMsg(start=start, end=end).encode())
-# range_ack = await recv_stream.receive()
-# assert range_ack == b'valid range'
-#
-# # simple test, open one channel and send 0..100 range
-# ring_name = 'ring-first'
-# await add_channel(ring_name)
-# await send_range(0, 100)
-# await remove_channel(ring_name)
-#
-# # redo
-# ring_name = 'ring-redo'
-# await add_channel(ring_name)
-# await send_range(0, 100)
-# await remove_channel(ring_name)
-#
-# # multi chan test
-# ring_names = []
-# for i in range(3):
-# ring_names.append(f'multi-ring-{i}')
-#
-# for name in ring_names:
-# await add_channel(name)
-#
-# await send_range(0, 300)
-#
-# for name in ring_names:
-# await remove_channel(name)
-#
-# await an.cancel()
-#
-# trio.run(main)
+class Struct(msgspec.Struct):
+
+    def encode(self) -> bytes:
+        return msgspec.msgpack.encode(self)
+
+
+class AddChannelMsg(Struct, frozen=True, tag=True):
+    name: str
+
+
+class RemoveChannelMsg(Struct, frozen=True, tag=True):
+    name: str
+
+
+class RangeMsg(Struct, frozen=True, tag=True):
+    start: int
+    end: int
+
+
+ControlMessages = AddChannelMsg | RemoveChannelMsg | RangeMsg
+
+
+@tractor.context
+async def subscriber_child(ctx: tractor.Context):
+    await ctx.started()
+    async with (
+        open_ringbuf_subscriber(guarantee_order=True) as subs,
+        trio.open_nursery() as n,
+        ctx.open_stream() as stream
+    ):
+        range_msg = None
+        range_event = trio.Event()
+        range_scope = trio.CancelScope()
+
+        async def _control_listen_task():
+            nonlocal range_msg, range_event
+            async for msg in stream:
+                msg = msgspec.msgpack.decode(msg, type=ControlMessages)
+                match msg:
+                    case AddChannelMsg():
+                        await subs.add_channel(msg.name, must_exist=False)
+
+                    case RemoveChannelMsg():
+                        await subs.remove_channel(msg.name)
+
+                    case RangeMsg():
+                        range_msg = msg
+                        range_event.set()
+
+                await stream.send(b'ack')
+
+            range_scope.cancel()
+
+        n.start_soon(_control_listen_task)
+
+        with range_scope:
+            while True:
+                await range_event.wait()
+                range_event = trio.Event()
+                for i in range(range_msg.start, range_msg.end):
+                    recv = int.from_bytes(await subs.receive())
+                    # if recv != i:
+                    #     raise AssertionError(
+                    #         f'received: {recv} expected: {i}'
+                    #     )
+
+                    log.info(f'received: {recv} expected: {i}')
+
+                await stream.send(b'valid range')
+                log.info('FINISHED RANGE')
+
+    log.info('subscriber exit')
+
+
+@tractor.context
+async def publisher_child(ctx: tractor.Context):
+    await ctx.started()
+    async with (
+        open_ringbuf_publisher(batch_size=1, guarantee_order=True) as pub,
+        ctx.open_stream() as stream
+    ):
+        abs_index = 0
+        async for msg in stream:
+            msg = msgspec.msgpack.decode(msg, type=ControlMessages)
+            match msg:
+                case AddChannelMsg():
+                    await pub.add_channel(msg.name, must_exist=True)
+
+                case RemoveChannelMsg():
+                    await pub.remove_channel(msg.name)
+
+                case RangeMsg():
+                    for i in range(msg.start, msg.end):
+                        await pub.send(i.to_bytes(4))
+                        log.info(f'sent {i}, index: {abs_index}')
+                        abs_index += 1
+
+            await stream.send(b'ack')
+
+    log.info('publisher exit')
+
+
+def test_pubsub():
+    '''
+    Spawn ringd actor and two children that access same ringbuf through ringd.
+
+    Both will use `ringd.open_ringbuf` to allocate the ringbuf, then attach to
+    it as sender and receiver.
+
+    '''
+    async def main():
+        async with (
+            tractor.open_nursery(
+                loglevel='info',
+                # debug_mode=True,
+                # enable_stack_on_sig=True
+            ) as an,

+            ringd.open_ringd()
+        ):
+            recv_portal = await an.start_actor(
+                'recv',
+                enable_modules=[__name__]
+            )
+            send_portal = await an.start_actor(
+                'send',
+                enable_modules=[__name__]
+            )
+
+            async with (
+                recv_portal.open_context(subscriber_child) as (rctx, _),
+                rctx.open_stream() as recv_stream,
+                send_portal.open_context(publisher_child) as (sctx, _),
+                sctx.open_stream() as send_stream,
+            ):
+                async def send_wait_ack(msg: bytes):
+                    await recv_stream.send(msg)
+                    ack = await recv_stream.receive()
+                    assert ack == b'ack'
+
+                    await send_stream.send(msg)
+                    ack = await send_stream.receive()
+                    assert ack == b'ack'
+
+                async def add_channel(name: str):
+                    await send_wait_ack(AddChannelMsg(name=name).encode())
+
+                async def remove_channel(name: str):
+                    await send_wait_ack(RemoveChannelMsg(name=name).encode())
+
+                async def send_range(start: int, end: int):
+                    await send_wait_ack(RangeMsg(start=start, end=end).encode())
+                    range_ack = await recv_stream.receive()
+                    assert range_ack == b'valid range'
+
+                # simple test, open one channel and send 0..100 range
+                ring_name = 'ring-first'
+                await add_channel(ring_name)
+                await send_range(0, 100)
+                await remove_channel(ring_name)
+
+                # redo
+                ring_name = 'ring-redo'
+                await add_channel(ring_name)
+                await send_range(0, 100)
+                await remove_channel(ring_name)
+
+                # multi chan test
+                ring_names = []
+                for i in range(3):
+                    ring_names.append(f'multi-ring-{i}')
+
+                for name in ring_names:
+                    await add_channel(name)
+
+                await send_range(0, 300)
+
+                for name in ring_names:
+                    await remove_channel(name)
+
+                await an.cancel()
+
+    trio.run(main)
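The control protocol in the test above is just msgspec tagged-union structs round-tripped over the context stream. For reference, the encode/decode/match pattern can be exercised standalone; a minimal sketch (msgspec only, struct names mirror the test's):

import msgspec


class AddChannelMsg(msgspec.Struct, frozen=True, tag=True):
    name: str


class RangeMsg(msgspec.Struct, frozen=True, tag=True):
    start: int
    end: int


# `tag=True` embeds each struct's name so the union decodes unambiguously
ControlMessages = AddChannelMsg | RangeMsg

wire: bytes = msgspec.msgpack.encode(RangeMsg(start=0, end=100))

match msgspec.msgpack.decode(wire, type=ControlMessages):
    case AddChannelMsg(name=name):
        print(f'add channel: {name}')
    case RangeMsg(start=start, end=end):
        print(f'send range: {start}..{end}')  # -> send range: 0..100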
@@ -17,13 +17,14 @@
 Ring buffer ipc publish-subscribe mechanism brokered by ringd
 can dynamically add new outputs (publisher) or inputs (subscriber)
 '''
-import time
 from typing import (
-    runtime_checkable,
-    Protocol,
     TypeVar,
+    Generic,
+    Callable,
+    Awaitable,
     AsyncContextManager
 )
+from functools import partial
 from contextlib import asynccontextmanager as acm
 from dataclasses import dataclass

@@ -31,12 +32,16 @@ import trio
 import tractor

 from tractor.ipc import (
-    RingBuffBytesSender,
-    RingBuffBytesReceiver,
-    attach_to_ringbuf_schannel,
-    attach_to_ringbuf_rchannel
+    RingBufferSendChannel,
+    RingBufferReceiveChannel,
+    attach_to_ringbuf_sender,
+    attach_to_ringbuf_receiver
 )

+from tractor.trionics import (
+    order_send_channel,
+    order_receive_channel
+)
 import tractor.ipc._ringbuf._ringd as ringd


@@ -48,66 +53,100 @@ ChannelType = TypeVar('ChannelType')

 @dataclass
 class ChannelInfo:
-    connect_time: float
     name: str
     channel: ChannelType
     cancel_scope: trio.CancelScope


-# TODO: maybe move this abstraction to another module or standalone?
-# its not ring buf specific and allows fan out and fan in an a dynamic
-# amount of channels
-@runtime_checkable
-class ChannelManager(Protocol[ChannelType]):
+class ChannelManager(trio.abc.AsyncResource, Generic[ChannelType]):
     '''
-    Common data structures and methods pubsub classes use to manage channels &
-    their related handler background tasks, as well as cancellation of them.
+    Helper for managing channel resources and their handler tasks with
+    cancellation, add or remove channels dynamically!

     '''

     def __init__(
         self,
+        # nursery used to spawn channel handler tasks
         n: trio.Nursery,
+
+        # acm will be used for setup & teardown of channel resources
+        open_channel_acm: Callable[..., AsyncContextManager[ChannelType]],
+
+        # long running bg task to handle channel
+        channel_task: Callable[..., Awaitable[None]]
     ):
         self._n = n
+        self._open_channel = open_channel_acm
+        self._channel_task = channel_task
+
+        # signal when a new channel connects and we previously had none
+        self._connect_event = trio.Event()
+
+        # store channel runtime variables
         self._channels: list[ChannelInfo] = []

-    async def _open_channel(
+        # methods that modify self._channels should be ordered by FIFO
+        self._lock = trio.StrictFIFOLock()
+
+    @acm
+    async def maybe_lock(self):
+        '''
+        If lock is not held, acquire
+
+        '''
+        if self._lock.locked():
+            yield
+            return
+
+        async with self._lock:
+            yield
+
+    @property
+    def channels(self) -> list[ChannelInfo]:
+        return self._channels
+
+    async def _channel_handler_task(
         self,
-        name: str
-    ) -> AsyncContextManager[ChannelType]:
+        name: str,
+        task_status: trio.TASK_STATUS_IGNORED,
+        **kwargs
+    ):
         '''
-        Used to instantiate channel resources given a name
+        Open channel resources, add to internal data structures, signal channel
+        connect through trio.Event, and run `channel_task` with cancel scope,
+        and finally, maybe remove channel from internal data structures.
+
+        Spawned by `add_channel` function, lock is held from beginning of fn
+        until `task_status.started()` call.
+
+        kwargs are proxied to `self._open_channel` acm.
         '''
-        ...
+        async with self._open_channel(name, **kwargs) as chan:
+            cancel_scope = trio.CancelScope()
+            info = ChannelInfo(
+                name=name,
+                channel=chan,
+                cancel_scope=cancel_scope
+            )
+            self._channels.append(info)

-    async def _channel_task(self, info: ChannelInfo) -> None:
-        '''
-        Long running task that manages the channel
-
-        '''
-        ...
+            if len(self) == 1:
+                self._connect_event.set()

-    async def _channel_handler_task(self, name: str):
-        async with self._open_channel(name) as chan:
-            with trio.CancelScope() as cancel_scope:
-                info = ChannelInfo(
-                    connect_time=time.time(),
-                    name=name,
-                    channel=chan,
-                    cancel_scope=cancel_scope
-                )
-                self._channels.append(info)
+            task_status.started()
+
+            with cancel_scope:
                 await self._channel_task(info)

-        self._maybe_destroy_channel(name)
+            await self._maybe_destroy_channel(name)

-    def find_channel(self, name: str) -> tuple[int, ChannelInfo] | None:
+    def _find_channel(self, name: str) -> tuple[int, ChannelInfo] | None:
         '''
         Given a channel name maybe return its index and value from
         internal _channels list.
+
+        Only use after acquiring lock.
         '''
         for entry in enumerate(self._channels):
             i, info = entry
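The new `_channel_handler_task` leans on `trio.Nursery.start` so that `add_channel` only returns once the channel resource is actually set up. A self-contained sketch of that handshake (hypothetical `open_channel`/`handler_task` names; only trio and the stdlib):

from contextlib import asynccontextmanager as acm
from functools import partial

import trio


@acm
async def open_channel(name: str):
    # stand-in for the ringd-backed setup/teardown acm
    schan, rchan = trio.open_memory_channel(0)
    try:
        yield (name, schan, rchan)
    finally:
        await schan.aclose()


async def handler_task(
    name: str,
    task_status=trio.TASK_STATUS_IGNORED,
):
    async with open_channel(name) as chan:
        cancel_scope = trio.CancelScope()
        # resource is ready: unblock the caller's `await n.start(...)`
        task_status.started()
        with cancel_scope:
            await trio.sleep_forever()  # long running per-channel work


async def main():
    async with trio.open_nursery() as n:
        # returns only after `task_status.started()` runs
        await n.start(partial(handler_task, 'chan-0'))
        n.cancel_scope.cancel()


trio.run(main)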
@@ -116,105 +155,114 @@ class ChannelManager(Protocol[ChannelType]):

         return None

-    def _maybe_destroy_channel(self, name: str):
+    async def _maybe_destroy_channel(self, name: str):
         '''
         If channel exists cancel its scope and remove from internal
         _channels list.

         '''
-        maybe_entry = self.find_channel(name)
-        if maybe_entry:
-            i, info = maybe_entry
-            info.cancel_scope.cancel()
-            del self._channels[i]
+        async with self.maybe_lock():
+            maybe_entry = self._find_channel(name)
+            if maybe_entry:
+                i, info = maybe_entry
+                info.cancel_scope.cancel()
+                del self._channels[i]

-    def add_channel(self, name: str):
+    async def add_channel(self, name: str, **kwargs):
         '''
         Add a new channel to be handled

         '''
-        self._n.start_soon(
-            self._channel_handler_task,
-            name
-        )
+        async with self.maybe_lock():
+            await self._n.start(partial(
+                self._channel_handler_task,
+                name,
+                **kwargs
+            ))

-    def remove_channel(self, name: str):
+    async def remove_channel(self, name: str):
         '''
         Remove a channel and stop its handling

         '''
-        self._maybe_destroy_channel(name)
+        async with self.maybe_lock():
+            await self._maybe_destroy_channel(name)
+
+            # if that was last channel reset connect event
+            if len(self) == 0:
+                self._connect_event = trio.Event()
+
+    async def wait_for_channel(self):
+        '''
+        Wait until at least one channel added
+
+        '''
+        await self._connect_event.wait()
+        self._connect_event = trio.Event()

     def __len__(self) -> int:
         return len(self._channels)

+    def __getitem__(self, name: str):
+        maybe_entry = self._find_channel(name)
+        if maybe_entry:
+            _, info = maybe_entry
+            return info
+
+        raise KeyError(f'Channel {name} not found!')
+
     async def aclose(self) -> None:
-        for chan in self._channels:
-            self._maybe_destroy_channel(chan.name)
+        async with self.maybe_lock():
+            for info in self._channels:
+                await self.remove_channel(info.name)

-    async def __aenter__(self):
-        return self
-
-    async def __aexit__(self, exc_type, exc_val, exc_tb):
-        await self.aclose()


-class RingBuffPublisher(
-    ChannelManager[RingBuffBytesSender]
-):
+'''
+Ring buffer publisher & subscribe pattern mediated by `ringd` actor.
+
+'''
+
+
+@dataclass
+class PublisherChannels:
+    ring: RingBufferSendChannel
+    schan: trio.MemorySendChannel
+    rchan: trio.MemoryReceiveChannel
+
+
+class RingBufferPublisher(trio.abc.SendChannel[bytes]):
     '''
-    Implement ChannelManager protocol + trio.abc.SendChannel[bytes]
-    using ring buffers as transport.
+    Use ChannelManager to create a multi ringbuf round robin sender that can
+    dynamically add or remove more outputs.

-    - use a `trio.Event` to make sure `send` blocks until at least one channel
-    available.
+    Don't instantiate directly, use `open_ringbuf_publisher` acm to manage its
+    lifecycle.

     '''

     def __init__(
         self,
         n: trio.Nursery,
+
+        # new ringbufs created will have this buf_size
         buf_size: int = 10 * 1024,
+
+        # global batch size for all channels
         batch_size: int = 1
     ):
-        super().__init__(n)
-        self._connect_event = trio.Event()
-        self._next_turn: int = 0
+        self._buf_size = buf_size
         self._batch_size: int = batch_size

-    @acm
-    async def _open_channel(
-        self,
-        name: str
-    ) -> AsyncContextManager[RingBuffBytesSender]:
-        async with (
-            ringd.open_ringbuf(
-                name=name,
-                must_exist=True,
-            ) as token,
-            attach_to_ringbuf_schannel(token) as chan
-        ):
-            yield chan
+        self._chanmngr = ChannelManager[PublisherChannels](
+            n,
+            self._open_channel,
+            self._channel_task
+        )

-    async def _channel_task(self, info: ChannelInfo) -> None:
-        self._connect_event.set()
-        await trio.sleep_forever()
+        # methods that send data over the channels need to acquire the send
+        # lock in order to guarantee order of operations
+        self._send_lock = trio.StrictFIFOLock()

-    async def send(self, msg: bytes):
-        # wait at least one decoder connected
-        if len(self) == 0:
-            await self._connect_event.wait()
-            self._connect_event = trio.Event()
-
-        if self._next_turn >= len(self):
-            self._next_turn = 0
-
-        turn = self._next_turn
-        self._next_turn += 1
-
-        output = self._channels[turn]
-        await output.channel.send(msg)
+        self._next_turn: int = 0

     @property
     def batch_size(self) -> int:
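Note how `maybe_lock` keeps the mutation paths composable: `aclose` holds the lock and then calls `remove_channel`, which locks again and would deadlock on a plain `trio.Lock`. A toy reproduction of the idiom (caveat: `locked()` checks whether *any* task holds the lock, not this one, so it assumes nested calls only ever come from an already-locked frame):

from contextlib import asynccontextmanager as acm

import trio

lock = trio.StrictFIFOLock()


@acm
async def maybe_lock():
    # skip re-acquiring when the lock is already held; `locked()` is not
    # owner-aware, so nested callers must already be inside a locked frame
    if lock.locked():
        yield
        return

    async with lock:
        yield


async def main():
    async with maybe_lock():      # outer frame takes the lock
        async with maybe_lock():  # nested call: a plain Lock would deadlock here
            print('entered both frames')


trio.run(main)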
@@ -222,92 +270,273 @@ class RingBuffPublisher(

     @batch_size.setter
     def set_batch_size(self, value: int) -> None:
-        for output in self._channels:
-            output.channel.batch_size = value
+        for info in self.channels:
+            info.channel.ring.batch_size = value

-    async def flush(
-        self,
-        new_batch_size: int | None = None
-    ):
-        for output in self._channels:
-            await output.channel.flush(
-                new_batch_size=new_batch_size
+    @property
+    def channels(self) -> list[ChannelInfo]:
+        return self._chanmngr.channels
+
+    def get_channel(self, name: str) -> ChannelInfo:
+        '''
+        Get underlying ChannelInfo from name
+
+        '''
+        return self._chanmngr[name]
+
+    async def add_channel(
+        self,
+        name: str,
+        must_exist: bool = False
+    ):
+        '''
+        Store additional runtime info for channel and add channel to underlying
+        ChannelManager
+
+        '''
+        await self._chanmngr.add_channel(name, must_exist=must_exist)
+
+    async def remove_channel(self, name: str):
+        '''
+        Send EOF to channel (does `channel.flush` also) then remove from
+        `ChannelManager`; acquire both `self._send_lock` and
+        `self._chanmngr.maybe_lock()` in order to ensure no channel
+        modifications or sends happen concurrently
+        '''
+        async with self._chanmngr.maybe_lock():
+            # ensure all pending messages are sent
+            info = self.get_channel(name)
+
+            try:
+                while True:
+                    msg = info.channel.rchan.receive_nowait()
+                    await info.channel.ring.send(msg)
+
+            except trio.WouldBlock:
+                await info.channel.ring.flush()
+
+            await info.channel.schan.aclose()
+
+            # finally remove from ChannelManager
+            await self._chanmngr.remove_channel(name)
+
+    @acm
+    async def _open_channel(
+        self,
+        name: str,
+        must_exist: bool = False
+    ) -> AsyncContextManager[PublisherChannels]:
+        '''
+        Open a ringbuf through `ringd` and attach as send side
+        '''
+        async with (
+            ringd.open_ringbuf(
+                name=name,
+                buf_size=self._buf_size,
+                must_exist=must_exist,
+            ) as token,
+            attach_to_ringbuf_sender(token) as ring,
+        ):
+            schan, rchan = trio.open_memory_channel(0)
+            yield PublisherChannels(
+                ring=ring,
+                schan=schan,
+                rchan=rchan
             )

-    async def send_eof(self):
-        for output in self._channels:
-            await output.channel.send_eof()
+    async def _channel_task(self, info: ChannelInfo) -> None:
+        '''
+        Forever get current runtime info for channel, wait on its next pending
+        payloads update event then drain all into send channel.
+
+        '''
+        async for msg in info.channel.rchan:
+            await info.channel.ring.send(msg)
+
+    async def send(self, msg: bytes):
+        '''
+        If no output channels connected, wait until one, then fetch the next
+        channel based on turn, add the indexed payload and update
+        `self._next_turn` & `self._next_index`.
+
+        Needs to acquire `self._send_lock` to make sure updates to turn & index
+        variables don't happen out of order.
+
+        '''
+        async with self._send_lock:
+            # wait at least one decoder connected
+            if len(self.channels) == 0:
+                await self._chanmngr.wait_for_channel()
+
+            if self._next_turn >= len(self.channels):
+                self._next_turn = 0
+
+            info = self.channels[self._next_turn]
+            await info.channel.schan.send(msg)
+
+            self._next_turn += 1
+
+    async def aclose(self) -> None:
+        await self._chanmngr.aclose()

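`RingBufferPublisher.send` above is round-robin fan-out guarded by a FIFO send lock. The same turn logic, minus the ring buffers, runs standalone over trio memory channels (illustrative sketch):

import trio


async def round_robin_send(
    outputs: list[trio.MemorySendChannel],
    msgs,
):
    # mirrors RingBufferPublisher.send's turn logic, minus the ring bufs
    next_turn = 0
    send_lock = trio.StrictFIFOLock()  # keeps turn updates strictly ordered
    for msg in msgs:
        async with send_lock:
            if next_turn >= len(outputs):
                next_turn = 0
            await outputs[next_turn].send(msg)
            next_turn += 1


async def main():
    chans = [trio.open_memory_channel(10) for _ in range(3)]
    await round_robin_send([s for s, _ in chans], range(6))

    # each of the 3 outputs got every 3rd message
    for i, (_, rchan) in enumerate(chans):
        assert rchan.receive_nowait() == i
        assert rchan.receive_nowait() == i + 3


trio.run(main)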
@acm
|
@acm
|
||||||
async def open_ringbuf_publisher(
|
async def open_ringbuf_publisher(
|
||||||
|
|
||||||
buf_size: int = 10 * 1024,
|
buf_size: int = 10 * 1024,
|
||||||
batch_size: int = 1
|
batch_size: int = 1,
|
||||||
):
|
guarantee_order: bool = False,
|
||||||
|
force_cancel: bool = False
|
||||||
|
|
||||||
|
) -> AsyncContextManager[RingBufferPublisher]:
|
||||||
|
'''
|
||||||
|
Open a new ringbuf publisher
|
||||||
|
|
||||||
|
'''
|
||||||
async with (
|
async with (
|
||||||
trio.open_nursery() as n,
|
trio.open_nursery() as n,
|
||||||
RingBuffPublisher(
|
RingBufferPublisher(
|
||||||
n,
|
n,
|
||||||
buf_size=buf_size,
|
buf_size=buf_size,
|
||||||
batch_size=batch_size
|
batch_size=batch_size
|
||||||
) as outputs
|
) as publisher
|
||||||
):
|
):
|
||||||
yield outputs
|
if guarantee_order:
|
||||||
|
order_send_channel(publisher)
|
||||||
|
|
||||||
|
yield publisher
|
||||||
|
|
||||||
|
if force_cancel:
|
||||||
|
# implicitly cancel any running channel handler task
|
||||||
|
n.cancel_scope.cancel()
|
||||||
|
|
||||||
|
|
||||||
|
class RingBufferSubscriber(trio.abc.ReceiveChannel[bytes]):
|
||||||
class RingBuffSubscriber(
|
|
||||||
ChannelManager[RingBuffBytesReceiver]
|
|
||||||
):
|
|
||||||
'''
|
'''
|
||||||
Implement ChannelManager protocol + trio.abc.ReceiveChannel[bytes]
|
Use ChannelManager to create a multi ringbuf receiver that can
|
||||||
using ring buffers as transport.
|
dynamically add or remove more inputs and combine all into a single output.
|
||||||
|
|
||||||
- use a trio memory channel pair to multiplex all received messages into a
|
In order for `self.receive` messages to be returned in order, publisher
|
||||||
single `trio.MemoryReceiveChannel`, give a sender channel clone to each
|
will send all payloads as `OrderedPayload` msgpack encoded msgs, this
|
||||||
_channel_task.
|
allows our channel handler tasks to just stash the out of order payloads
|
||||||
|
inside `self._pending_payloads` and if a in order payload is available
|
||||||
|
signal through `self._new_payload_event`.
|
||||||
|
|
||||||
|
On `self.receive` we wait until at least one channel is connected, then if
|
||||||
|
an in order payload is pending, we pop and return it, in case no in order
|
||||||
|
payload is available wait until next `self._new_payload_event.set()`.
|
||||||
|
|
||||||
'''
|
'''
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
n: trio.Nursery,
|
n: trio.Nursery,
|
||||||
|
|
||||||
|
# if connecting to a publisher that has already sent messages set
|
||||||
|
# to the next expected payload index this subscriber will receive
|
||||||
|
start_index: int = 0
|
||||||
):
|
):
|
||||||
super().__init__(n)
|
self._chanmngr = ChannelManager[RingBufferReceiveChannel](
|
||||||
self._send_chan, self._recv_chan = trio.open_memory_channel(0)
|
n,
|
||||||
|
self._open_channel,
|
||||||
|
self._channel_task
|
||||||
|
)
|
||||||
|
|
||||||
|
self._schan, self._rchan = trio.open_memory_channel(0)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def channels(self) -> list[ChannelInfo]:
|
||||||
|
return self._chanmngr.channels
|
||||||
|
|
||||||
|
def get_channel(self, name: str):
|
||||||
|
return self._chanmngr[name]
|
||||||
|
|
||||||
|
async def add_channel(self, name: str, must_exist: bool = False):
|
||||||
|
'''
|
||||||
|
Add new input channel by name
|
||||||
|
|
||||||
|
'''
|
||||||
|
await self._chanmngr.add_channel(name, must_exist=must_exist)
|
||||||
|
|
||||||
|
async def remove_channel(self, name: str):
|
||||||
|
'''
|
||||||
|
Remove an input channel by name
|
||||||
|
|
||||||
|
'''
|
||||||
|
await self._chanmngr.remove_channel(name)
|
||||||
|
|
||||||
@acm
|
@acm
|
||||||
async def _open_channel(
|
async def _open_channel(
|
||||||
|
|
||||||
self,
|
self,
|
||||||
name: str
|
name: str,
|
||||||
) -> AsyncContextManager[RingBuffBytesReceiver]:
|
must_exist: bool = False
|
||||||
|
|
||||||
|
) -> AsyncContextManager[RingBufferReceiveChannel]:
|
||||||
|
'''
|
||||||
|
Open a ringbuf through `ringd` and attach as receiver side
|
||||||
|
'''
|
||||||
async with (
|
async with (
|
||||||
ringd.open_ringbuf(
|
ringd.open_ringbuf(
|
||||||
name=name,
|
name=name,
|
||||||
must_exist=True,
|
must_exist=must_exist,
|
||||||
) as token,
|
) as token,
|
||||||
attach_to_ringbuf_rchannel(token) as chan
|
attach_to_ringbuf_receiver(token) as chan
|
||||||
):
|
):
|
||||||
yield chan
|
yield chan
|
||||||
|
|
||||||
async def _channel_task(self, info: ChannelInfo) -> None:
|
async def _channel_task(self, info: ChannelInfo) -> None:
|
||||||
send_chan = self._send_chan.clone()
|
'''
|
||||||
try:
|
Iterate over receive channel messages, decode them as `OrderedPayload`s
|
||||||
async for msg in info.channel:
|
and stash them in `self._pending_payloads`, in case we can pop next in
|
||||||
await send_chan.send(msg)
|
order payload, signal through setting `self._new_payload_event`.
|
||||||
|
|
||||||
except tractor._exceptions.InternalError:
|
'''
|
||||||
# TODO: cleaner cancellation!
|
while True:
|
||||||
...
|
try:
|
||||||
|
msg = await info.channel.receive()
|
||||||
|
await self._schan.send(msg)
|
||||||
|
|
||||||
|
except tractor.linux.eventfd.EFDReadCancelled as e:
|
||||||
|
# when channel gets removed while we are doing a receive
|
||||||
|
log.exception(e)
|
||||||
|
break
|
||||||
|
|
||||||
|
except trio.EndOfChannel:
|
||||||
|
break
|
||||||
|
|
||||||
async def receive(self) -> bytes:
|
async def receive(self) -> bytes:
|
||||||
return await self._recv_chan.receive()
|
'''
|
||||||
|
Receive next in order msg
|
||||||
|
'''
|
||||||
|
return await self._rchan.receive()
|
||||||
|
|
||||||
|
async def aclose(self) -> None:
|
||||||
|
await self._chanmngr.aclose()
|
||||||
|
|
||||||
@acm
|
@acm
|
||||||
async def open_ringbuf_subscriber():
|
async def open_ringbuf_subscriber(
|
||||||
|
|
||||||
|
guarantee_order: bool = False,
|
||||||
|
force_cancel: bool = False
|
||||||
|
|
||||||
|
) -> AsyncContextManager[RingBufferPublisher]:
|
||||||
|
'''
|
||||||
|
Open a new ringbuf subscriber
|
||||||
|
|
||||||
|
'''
|
||||||
async with (
|
async with (
|
||||||
trio.open_nursery() as n,
|
trio.open_nursery() as n,
|
||||||
RingBuffSubscriber(n) as inputs
|
RingBufferSubscriber(
|
||||||
|
n,
|
||||||
|
) as subscriber
|
||||||
):
|
):
|
||||||
yield inputs
|
if guarantee_order:
|
||||||
|
order_receive_channel(subscriber)
|
||||||
|
|
||||||
|
yield subscriber
|
||||||
|
|
||||||
|
if force_cancel:
|
||||||
|
# implicitly cancel any running channel handler task
|
||||||
|
n.cancel_scope.cancel()
|
||||||
|
|
|
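The subscriber's `_channel_task`s all push into one shared memory-channel pair, the usual trio fan-in shape: many producers, one `receive` endpoint. A minimal runnable version (interleaving between inputs is unspecified, which is exactly why the orderer below exists):

import trio


async def main():
    schan, rchan = trio.open_memory_channel(0)

    async def channel_task(tag: str):
        # each input's handler drains its source into the shared send side
        for i in range(2):
            await schan.send(f'{tag}-{i}')

    async with trio.open_nursery() as n:
        n.start_soon(channel_task, 'ring-a')
        n.start_soon(channel_task, 'ring-b')

        # the single consumer sees the merged stream
        got = {await rchan.receive() for _ in range(4)}
        assert got == {'ring-a-0', 'ring-a-1', 'ring-b-0', 'ring-b-1'}


trio.run(main)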
@@ -32,3 +32,8 @@ from ._broadcast import (
 from ._beg import (
     collapse_eg as collapse_eg,
 )
+
+from ._ordering import (
+    order_send_channel as order_send_channel,
+    order_receive_channel as order_receive_channel
+)
@@ -0,0 +1,89 @@
+from __future__ import annotations
+from heapq import (
+    heappush,
+    heappop
+)
+
+import trio
+import msgspec
+
+
+class OrderedPayload(msgspec.Struct, frozen=True):
+    index: int
+    payload: bytes
+
+    @classmethod
+    def from_msg(cls, msg: bytes) -> OrderedPayload:
+        return msgspec.msgpack.decode(msg, type=OrderedPayload)
+
+    def encode(self) -> bytes:
+        return msgspec.msgpack.encode(self)
+
+
+def order_send_channel(
+    channel: trio.abc.SendChannel[bytes],
+    start_index: int = 0
+):
+
+    next_index = start_index
+    send_lock = trio.StrictFIFOLock()
+
+    channel._send = channel.send
+    channel._aclose = channel.aclose
+
+    async def send(msg: bytes):
+        nonlocal next_index
+        async with send_lock:
+            await channel._send(
+                OrderedPayload(
+                    index=next_index,
+                    payload=msg
+                ).encode()
+            )
+            next_index += 1
+
+    async def aclose():
+        async with send_lock:
+            await channel._aclose()
+
+    channel.send = send
+    channel.aclose = aclose
+
+
+def order_receive_channel(
+    channel: trio.abc.ReceiveChannel[bytes],
+    start_index: int = 0
+):
+    next_index = start_index
+    pqueue = []
+
+    channel._receive = channel.receive
+
+    def can_pop_next() -> bool:
+        return (
+            len(pqueue) > 0
+            and
+            pqueue[0][0] == next_index
+        )
+
+    async def drain_to_heap():
+        while not can_pop_next():
+            msg = await channel._receive()
+            msg = OrderedPayload.from_msg(msg)
+            heappush(pqueue, (msg.index, msg.payload))
+
+    def pop_next():
+        nonlocal next_index
+        _, msg = heappop(pqueue)
+        next_index += 1
+        return msg
+
+    async def receive() -> bytes:
+        if can_pop_next():
+            return pop_next()
+
+        await drain_to_heap()
+
+        return pop_next()
+
+    channel.receive = receive
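To see the orderer in isolation: feed `OrderedPayload` frames out of index order into a receive channel patched by `order_receive_channel` and they come back sorted. A sketch assuming tractor with this patch installed; the plain wrapper class exists only because the orderers monkey-patch `.receive`, which may not be possible directly on trio's attrs-based memory channels:

import trio
from tractor.trionics import order_receive_channel
from tractor.trionics._ordering import OrderedPayload


class WrappedReceive:
    # plain class so `order_receive_channel` can re-point `.receive`
    def __init__(self, rchan):
        self._rchan = rchan

    async def receive(self) -> bytes:
        return await self._rchan.receive()


async def main():
    schan, rchan = trio.open_memory_channel(10)
    rx = WrappedReceive(rchan)
    order_receive_channel(rx)  # installs the heap based reorderer

    # emulate out of order arrival, e.g. round robin over several rings
    for index in (1, 2, 0):
        await schan.send(
            OrderedPayload(index=index, payload=b'%d' % index).encode()
        )

    # payloads come back strictly in index order
    for i in range(3):
        assert await rx.receive() == b'%d' % i


trio.run(main)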