Fully test and fix bugs on _ringbuf._pubsub
Add generic channel orderer
one_ring_to_rule_them_all
parent
bebd327023
commit
1dfc639e54
|
@ -92,187 +92,187 @@ def test_ringd():
|
|||
trio.run(main)
|
||||
|
||||
|
||||
# class Struct(msgspec.Struct):
|
||||
#
|
||||
# def encode(self) -> bytes:
|
||||
# return msgspec.msgpack.encode(self)
|
||||
#
|
||||
#
|
||||
# class AddChannelMsg(Struct, frozen=True, tag=True):
|
||||
# name: str
|
||||
#
|
||||
#
|
||||
# class RemoveChannelMsg(Struct, frozen=True, tag=True):
|
||||
# name: str
|
||||
#
|
||||
#
|
||||
# class RangeMsg(Struct, frozen=True, tag=True):
|
||||
# start: int
|
||||
# end: int
|
||||
#
|
||||
#
|
||||
# ControlMessages = AddChannelMsg | RemoveChannelMsg | RangeMsg
|
||||
#
|
||||
#
|
||||
# @tractor.context
|
||||
# async def subscriber_child(ctx: tractor.Context):
|
||||
# await ctx.started()
|
||||
# async with (
|
||||
# open_ringbuf_subscriber(guarantee_order=True) as subs,
|
||||
# trio.open_nursery() as n,
|
||||
# ctx.open_stream() as stream
|
||||
# ):
|
||||
# range_msg = None
|
||||
# range_event = trio.Event()
|
||||
# range_scope = trio.CancelScope()
|
||||
#
|
||||
# async def _control_listen_task():
|
||||
# nonlocal range_msg, range_event
|
||||
# async for msg in stream:
|
||||
# msg = msgspec.msgpack.decode(msg, type=ControlMessages)
|
||||
# match msg:
|
||||
# case AddChannelMsg():
|
||||
# await subs.add_channel(msg.name, must_exist=False)
|
||||
#
|
||||
# case RemoveChannelMsg():
|
||||
# await subs.remove_channel(msg.name)
|
||||
#
|
||||
# case RangeMsg():
|
||||
# range_msg = msg
|
||||
# range_event.set()
|
||||
#
|
||||
# await stream.send(b'ack')
|
||||
#
|
||||
# range_scope.cancel()
|
||||
#
|
||||
# n.start_soon(_control_listen_task)
|
||||
#
|
||||
# with range_scope:
|
||||
# while True:
|
||||
# await range_event.wait()
|
||||
# range_event = trio.Event()
|
||||
# for i in range(range_msg.start, range_msg.end):
|
||||
# recv = int.from_bytes(await subs.receive())
|
||||
# # if recv != i:
|
||||
# # raise AssertionError(
|
||||
# # f'received: {recv} expected: {i}'
|
||||
# # )
|
||||
#
|
||||
# log.info(f'received: {recv} expected: {i}')
|
||||
#
|
||||
# await stream.send(b'valid range')
|
||||
# log.info('FINISHED RANGE')
|
||||
#
|
||||
# log.info('subscriber exit')
|
||||
#
|
||||
#
|
||||
# @tractor.context
|
||||
# async def publisher_child(ctx: tractor.Context):
|
||||
# await ctx.started()
|
||||
# async with (
|
||||
# open_ringbuf_publisher(batch_size=1, guarantee_order=True) as pub,
|
||||
# ctx.open_stream() as stream
|
||||
# ):
|
||||
# abs_index = 0
|
||||
# async for msg in stream:
|
||||
# msg = msgspec.msgpack.decode(msg, type=ControlMessages)
|
||||
# match msg:
|
||||
# case AddChannelMsg():
|
||||
# await pub.add_channel(msg.name, must_exist=True)
|
||||
#
|
||||
# case RemoveChannelMsg():
|
||||
# await pub.remove_channel(msg.name)
|
||||
#
|
||||
# case RangeMsg():
|
||||
# for i in range(msg.start, msg.end):
|
||||
# await pub.send(i.to_bytes(4))
|
||||
# log.info(f'sent {i}, index: {abs_index}')
|
||||
# abs_index += 1
|
||||
#
|
||||
# await stream.send(b'ack')
|
||||
#
|
||||
# log.info('publisher exit')
|
||||
#
|
||||
#
|
||||
#
|
||||
# def test_pubsub():
|
||||
# '''
|
||||
# Spawn ringd actor and two childs that access same ringbuf through ringd.
|
||||
#
|
||||
# Both will use `ringd.open_ringbuf` to allocate the ringbuf, then attach to
|
||||
# them as sender and receiver.
|
||||
#
|
||||
# '''
|
||||
# async def main():
|
||||
# async with (
|
||||
# tractor.open_nursery(
|
||||
# loglevel='info',
|
||||
# # debug_mode=True,
|
||||
# # enable_stack_on_sig=True
|
||||
# ) as an,
|
||||
#
|
||||
# ringd.open_ringd()
|
||||
# ):
|
||||
# recv_portal = await an.start_actor(
|
||||
# 'recv',
|
||||
# enable_modules=[__name__]
|
||||
class Struct(msgspec.Struct):
    '''
    Common base for the control messages used by this test module, adds
    a msgpack `encode` helper on top of `msgspec.Struct`.

    '''
    def encode(self) -> bytes:
        '''
        Return this struct serialized as msgpack bytes.

        '''
        packed = msgspec.msgpack.encode(self)
        return packed
|
||||
|
||||
|
||||
class AddChannelMsg(Struct, frozen=True, tag=True):
    '''
    Control message: ask the receiving child to add the channel `name`.

    '''
    # name of the channel to add
    name: str
|
||||
|
||||
|
||||
class RemoveChannelMsg(Struct, frozen=True, tag=True):
    '''
    Control message: ask the receiving child to remove the channel `name`.

    '''
    # name of the channel to remove
    name: str
|
||||
|
||||
|
||||
class RangeMsg(Struct, frozen=True, tag=True):
    '''
    Control message: the half-open integer range `[start, end)` that the
    children iterate with `range(start, end)`.

    '''
    # first integer of the range (inclusive)
    start: int
    # end of the range (exclusive)
    end: int
|
||||
|
||||
|
||||
# union of every control message type, passed as `type=` when the child
# actors msgpack-decode what they read from their control stream
ControlMessages = AddChannelMsg | RemoveChannelMsg | RangeMsg
|
||||
|
||||
|
||||
@tractor.context
|
||||
async def subscriber_child(ctx: tractor.Context):
|
||||
await ctx.started()
|
||||
async with (
|
||||
open_ringbuf_subscriber(guarantee_order=True) as subs,
|
||||
trio.open_nursery() as n,
|
||||
ctx.open_stream() as stream
|
||||
):
|
||||
range_msg = None
|
||||
range_event = trio.Event()
|
||||
range_scope = trio.CancelScope()
|
||||
|
||||
async def _control_listen_task():
|
||||
nonlocal range_msg, range_event
|
||||
async for msg in stream:
|
||||
msg = msgspec.msgpack.decode(msg, type=ControlMessages)
|
||||
match msg:
|
||||
case AddChannelMsg():
|
||||
await subs.add_channel(msg.name, must_exist=False)
|
||||
|
||||
case RemoveChannelMsg():
|
||||
await subs.remove_channel(msg.name)
|
||||
|
||||
case RangeMsg():
|
||||
range_msg = msg
|
||||
range_event.set()
|
||||
|
||||
await stream.send(b'ack')
|
||||
|
||||
range_scope.cancel()
|
||||
|
||||
n.start_soon(_control_listen_task)
|
||||
|
||||
with range_scope:
|
||||
while True:
|
||||
await range_event.wait()
|
||||
range_event = trio.Event()
|
||||
for i in range(range_msg.start, range_msg.end):
|
||||
recv = int.from_bytes(await subs.receive())
|
||||
# if recv != i:
|
||||
# raise AssertionError(
|
||||
# f'received: {recv} expected: {i}'
|
||||
# )
|
||||
# send_portal = await an.start_actor(
|
||||
# 'send',
|
||||
# enable_modules=[__name__]
|
||||
# )
|
||||
#
|
||||
# async with (
|
||||
# recv_portal.open_context(subscriber_child) as (rctx, _),
|
||||
# rctx.open_stream() as recv_stream,
|
||||
# send_portal.open_context(publisher_child) as (sctx, _),
|
||||
# sctx.open_stream() as send_stream,
|
||||
# ):
|
||||
# async def send_wait_ack(msg: bytes):
|
||||
# await recv_stream.send(msg)
|
||||
# ack = await recv_stream.receive()
|
||||
# assert ack == b'ack'
|
||||
#
|
||||
# await send_stream.send(msg)
|
||||
# ack = await send_stream.receive()
|
||||
# assert ack == b'ack'
|
||||
#
|
||||
# async def add_channel(name: str):
|
||||
# await send_wait_ack(AddChannelMsg(name=name).encode())
|
||||
#
|
||||
# async def remove_channel(name: str):
|
||||
# await send_wait_ack(RemoveChannelMsg(name=name).encode())
|
||||
#
|
||||
# async def send_range(start: int, end: int):
|
||||
# await send_wait_ack(RangeMsg(start=start, end=end).encode())
|
||||
# range_ack = await recv_stream.receive()
|
||||
# assert range_ack == b'valid range'
|
||||
#
|
||||
# # simple test, open one channel and send 0..100 range
|
||||
# ring_name = 'ring-first'
|
||||
# await add_channel(ring_name)
|
||||
# await send_range(0, 100)
|
||||
# await remove_channel(ring_name)
|
||||
#
|
||||
# # redo
|
||||
# ring_name = 'ring-redo'
|
||||
# await add_channel(ring_name)
|
||||
# await send_range(0, 100)
|
||||
# await remove_channel(ring_name)
|
||||
#
|
||||
# # multi chan test
|
||||
# ring_names = []
|
||||
# for i in range(3):
|
||||
# ring_names.append(f'multi-ring-{i}')
|
||||
#
|
||||
# for name in ring_names:
|
||||
# await add_channel(name)
|
||||
#
|
||||
# await send_range(0, 300)
|
||||
#
|
||||
# for name in ring_names:
|
||||
# await remove_channel(name)
|
||||
#
|
||||
# await an.cancel()
|
||||
#
|
||||
# trio.run(main)
|
||||
|
||||
log.info(f'received: {recv} expected: {i}')
|
||||
|
||||
await stream.send(b'valid range')
|
||||
log.info('FINISHED RANGE')
|
||||
|
||||
log.info('subscriber exit')
|
||||
|
||||
|
||||
@tractor.context
|
||||
async def publisher_child(ctx: tractor.Context):
|
||||
await ctx.started()
|
||||
async with (
|
||||
open_ringbuf_publisher(batch_size=1, guarantee_order=True) as pub,
|
||||
ctx.open_stream() as stream
|
||||
):
|
||||
abs_index = 0
|
||||
async for msg in stream:
|
||||
msg = msgspec.msgpack.decode(msg, type=ControlMessages)
|
||||
match msg:
|
||||
case AddChannelMsg():
|
||||
await pub.add_channel(msg.name, must_exist=True)
|
||||
|
||||
case RemoveChannelMsg():
|
||||
await pub.remove_channel(msg.name)
|
||||
|
||||
case RangeMsg():
|
||||
for i in range(msg.start, msg.end):
|
||||
await pub.send(i.to_bytes(4))
|
||||
log.info(f'sent {i}, index: {abs_index}')
|
||||
abs_index += 1
|
||||
|
||||
await stream.send(b'ack')
|
||||
|
||||
log.info('publisher exit')
|
||||
|
||||
|
||||
|
||||
def test_pubsub():
    '''
    Drive a publisher and a subscriber child actor through a series of
    add channel / send range / remove channel rounds.

    Opens a tractor nursery plus `ringd.open_ringd()`, starts the 'recv'
    and 'send' actors, opens `subscriber_child` & `publisher_child`
    contexts with a stream each, and forwards every control message to
    both children, expecting a `b'ack'` reply from each of them.

    '''
    async def main():
        async with (
            tractor.open_nursery(
                loglevel='info',
                # debug_mode=True,
                # enable_stack_on_sig=True
            ) as an,

            ringd.open_ringd()
        ):
            recv_portal = await an.start_actor(
                'recv',
                enable_modules=[__name__]
            )
            send_portal = await an.start_actor(
                'send',
                enable_modules=[__name__]
            )

            async with (
                recv_portal.open_context(subscriber_child) as (rctx, _),
                rctx.open_stream() as recv_stream,
                send_portal.open_context(publisher_child) as (sctx, _),
                sctx.open_stream() as send_stream,
            ):
                async def send_wait_ack(msg: bytes):
                    # subscriber first, then publisher; each one must
                    # ack before we move on
                    for ctl_stream in (recv_stream, send_stream):
                        await ctl_stream.send(msg)
                        ack = await ctl_stream.receive()
                        assert ack == b'ack'

                async def add_channel(name: str):
                    await send_wait_ack(AddChannelMsg(name=name).encode())

                async def remove_channel(name: str):
                    await send_wait_ack(RemoveChannelMsg(name=name).encode())

                async def send_range(start: int, end: int):
                    await send_wait_ack(RangeMsg(start=start, end=end).encode())
                    # subscriber reports once it consumed the whole range
                    range_ack = await recv_stream.receive()
                    assert range_ack == b'valid range'

                # simple test, open one channel and send 0..100 range,
                # then redo the same thing on a fresh ring
                for ring_name in ('ring-first', 'ring-redo'):
                    await add_channel(ring_name)
                    await send_range(0, 100)
                    await remove_channel(ring_name)

                # multi chan test
                ring_names = [f'multi-ring-{i}' for i in range(3)]

                for name in ring_names:
                    await add_channel(name)

                await send_range(0, 300)

                for name in ring_names:
                    await remove_channel(name)

            # NOTE(review): nesting level of this call was inferred,
            # confirm it runs after the contexts are closed
            await an.cancel()

    trio.run(main)
|
||||
|
|
|
@ -17,13 +17,14 @@
|
|||
Ring buffer ipc publish-subscribe mechanism brokered by ringd
|
||||
can dynamically add new outputs (publisher) or inputs (subscriber)
|
||||
'''
|
||||
import time
|
||||
from typing import (
|
||||
runtime_checkable,
|
||||
Protocol,
|
||||
TypeVar,
|
||||
Generic,
|
||||
Callable,
|
||||
Awaitable,
|
||||
AsyncContextManager
|
||||
)
|
||||
from functools import partial
|
||||
from contextlib import asynccontextmanager as acm
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
@ -31,12 +32,16 @@ import trio
|
|||
import tractor
|
||||
|
||||
from tractor.ipc import (
|
||||
RingBuffBytesSender,
|
||||
RingBuffBytesReceiver,
|
||||
attach_to_ringbuf_schannel,
|
||||
attach_to_ringbuf_rchannel
|
||||
RingBufferSendChannel,
|
||||
RingBufferReceiveChannel,
|
||||
attach_to_ringbuf_sender,
|
||||
attach_to_ringbuf_receiver
|
||||
)
|
||||
|
||||
from tractor.trionics import (
|
||||
order_send_channel,
|
||||
order_receive_channel
|
||||
)
|
||||
import tractor.ipc._ringbuf._ringd as ringd
|
||||
|
||||
|
||||
|
@ -48,66 +53,100 @@ ChannelType = TypeVar('ChannelType')
|
|||
|
||||
@dataclass
|
||||
class ChannelInfo:
|
||||
connect_time: float
|
||||
name: str
|
||||
channel: ChannelType
|
||||
cancel_scope: trio.CancelScope
|
||||
|
||||
|
||||
# TODO: maybe move this abstraction to another module or standalone?
|
||||
# it's not ring buf specific and allows fan out and fan in on a dynamic
|
||||
# amount of channels
|
||||
@runtime_checkable
|
||||
class ChannelManager(Protocol[ChannelType]):
|
||||
class ChannelManager(trio.abc.AsyncResource, Generic[ChannelType]):
|
||||
'''
|
||||
Common data structures and methods pubsub classes use to manage channels &
|
||||
their related handler background tasks, as well as cancellation of them.
|
||||
Helper for managing channel resources and their handler tasks with
|
||||
cancellation, add or remove channels dynamically!
|
||||
|
||||
'''
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
# nursery used to spawn channel handler tasks
|
||||
n: trio.Nursery,
|
||||
|
||||
# acm will be used for setup & teardown of channel resources
|
||||
open_channel_acm: Callable[..., AsyncContextManager[ChannelType]],
|
||||
|
||||
# long running bg task to handle channel
|
||||
channel_task: Callable[..., Awaitable[None]]
|
||||
):
|
||||
self._n = n
|
||||
self._open_channel = open_channel_acm
|
||||
self._channel_task = channel_task
|
||||
|
||||
# signal when a new channel connects and we previously had none
|
||||
self._connect_event = trio.Event()
|
||||
|
||||
# store channel runtime variables
|
||||
self._channels: list[ChannelInfo] = []
|
||||
|
||||
async def _open_channel(
|
||||
# methods that modify self._channels should be ordered by FIFO
|
||||
self._lock = trio.StrictFIFOLock()
|
||||
|
||||
@acm
|
||||
async def maybe_lock(self):
|
||||
'''
|
||||
If lock is not held, acquire
|
||||
|
||||
'''
|
||||
if self._lock.locked():
|
||||
yield
|
||||
return
|
||||
|
||||
async with self._lock:
|
||||
yield
|
||||
|
||||
@property
|
||||
def channels(self) -> list[ChannelInfo]:
|
||||
return self._channels
|
||||
|
||||
async def _channel_handler_task(
|
||||
self,
|
||||
name: str
|
||||
) -> AsyncContextManager[ChannelType]:
|
||||
name: str,
|
||||
task_status: trio.TASK_STATUS_IGNORED,
|
||||
**kwargs
|
||||
):
|
||||
'''
|
||||
Used to instantiate channel resources given a name
|
||||
Open channel resources, add to internal data structures, signal channel
|
||||
connect through trio.Event, and run `channel_task` with cancel scope,
|
||||
and finally, maybe remove channel from internal data structures.
|
||||
|
||||
Spawned by `add_channel` function, lock is held from beginning of fn
|
||||
until `task_status.started()` call.
|
||||
|
||||
kwargs are proxied to `self._open_channel` acm.
|
||||
'''
|
||||
...
|
||||
|
||||
async def _channel_task(self, info: ChannelInfo) -> None:
|
||||
'''
|
||||
Long running task that manages the channel
|
||||
|
||||
'''
|
||||
...
|
||||
|
||||
async def _channel_handler_task(self, name: str):
|
||||
async with self._open_channel(name) as chan:
|
||||
with trio.CancelScope() as cancel_scope:
|
||||
async with self._open_channel(name, **kwargs) as chan:
|
||||
cancel_scope = trio.CancelScope()
|
||||
info = ChannelInfo(
|
||||
connect_time=time.time(),
|
||||
name=name,
|
||||
channel=chan,
|
||||
cancel_scope=cancel_scope
|
||||
)
|
||||
self._channels.append(info)
|
||||
|
||||
if len(self) == 1:
|
||||
self._connect_event.set()
|
||||
|
||||
task_status.started()
|
||||
|
||||
with cancel_scope:
|
||||
await self._channel_task(info)
|
||||
|
||||
self._maybe_destroy_channel(name)
|
||||
await self._maybe_destroy_channel(name)
|
||||
|
||||
def find_channel(self, name: str) -> tuple[int, ChannelInfo] | None:
|
||||
def _find_channel(self, name: str) -> tuple[int, ChannelInfo] | None:
|
||||
'''
|
||||
Given a channel name maybe return its index and value from
|
||||
internal _channels list.
|
||||
|
||||
Only use after acquiring lock.
|
||||
'''
|
||||
for entry in enumerate(self._channels):
|
||||
i, info = entry
|
||||
|
@ -116,105 +155,114 @@ class ChannelManager(Protocol[ChannelType]):
|
|||
|
||||
return None
|
||||
|
||||
def _maybe_destroy_channel(self, name: str):
|
||||
|
||||
async def _maybe_destroy_channel(self, name: str):
|
||||
'''
|
||||
If channel exists cancel its scope and remove from internal
|
||||
_channels list.
|
||||
|
||||
'''
|
||||
maybe_entry = self.find_channel(name)
|
||||
async with self.maybe_lock():
|
||||
maybe_entry = self._find_channel(name)
|
||||
if maybe_entry:
|
||||
i, info = maybe_entry
|
||||
info.cancel_scope.cancel()
|
||||
del self._channels[i]
|
||||
|
||||
def add_channel(self, name: str):
|
||||
async def add_channel(self, name: str, **kwargs):
|
||||
'''
|
||||
Add a new channel to be handled
|
||||
|
||||
'''
|
||||
self._n.start_soon(
|
||||
async with self.maybe_lock():
|
||||
await self._n.start(partial(
|
||||
self._channel_handler_task,
|
||||
name
|
||||
)
|
||||
name,
|
||||
**kwargs
|
||||
))
|
||||
|
||||
def remove_channel(self, name: str):
|
||||
async def remove_channel(self, name: str):
|
||||
'''
|
||||
Remove a channel and stop its handling
|
||||
|
||||
'''
|
||||
self._maybe_destroy_channel(name)
|
||||
async with self.maybe_lock():
|
||||
await self._maybe_destroy_channel(name)
|
||||
|
||||
# if that was last channel reset connect event
|
||||
if len(self) == 0:
|
||||
self._connect_event = trio.Event()
|
||||
|
||||
async def wait_for_channel(self):
|
||||
'''
|
||||
Wait until at least one channel added
|
||||
|
||||
'''
|
||||
await self._connect_event.wait()
|
||||
self._connect_event = trio.Event()
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self._channels)
|
||||
|
||||
def __getitem__(self, name: str):
|
||||
maybe_entry = self._find_channel(name)
|
||||
if maybe_entry:
|
||||
_, info = maybe_entry
|
||||
return info
|
||||
|
||||
raise KeyError(f'Channel {name} not found!')
|
||||
|
||||
async def aclose(self) -> None:
|
||||
for chan in self._channels:
|
||||
self._maybe_destroy_channel(chan.name)
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
await self.aclose()
|
||||
async with self.maybe_lock():
|
||||
for info in self._channels:
|
||||
await self.remove_channel(info.name)
|
||||
|
||||
|
||||
class RingBuffPublisher(
|
||||
ChannelManager[RingBuffBytesSender]
|
||||
):
|
||||
'''
|
||||
Implement ChannelManager protocol + trio.abc.SendChannel[bytes]
|
||||
using ring buffers as transport.
|
||||
|
||||
- use a `trio.Event` to make sure `send` blocks until at least one channel
|
||||
available.
|
||||
Ring buffer publisher & subscribe pattern mediated by `ringd` actor.
|
||||
|
||||
'''
|
||||
|
||||
@dataclass
|
||||
class PublisherChannels:
|
||||
ring: RingBufferSendChannel
|
||||
schan: trio.MemorySendChannel
|
||||
rchan: trio.MemoryReceiveChannel
|
||||
|
||||
|
||||
class RingBufferPublisher(trio.abc.SendChannel[bytes]):
|
||||
'''
|
||||
Use ChannelManager to create a multi ringbuf round robin sender that can
|
||||
dynamically add or remove more outputs.
|
||||
|
||||
Don't instantiate directly, use `open_ringbuf_publisher` acm to manage its
|
||||
lifecycle.
|
||||
|
||||
'''
|
||||
def __init__(
|
||||
self,
|
||||
n: trio.Nursery,
|
||||
|
||||
# new ringbufs created will have this buf_size
|
||||
buf_size: int = 10 * 1024,
|
||||
|
||||
# global batch size for all channels
|
||||
batch_size: int = 1
|
||||
):
|
||||
super().__init__(n)
|
||||
self._connect_event = trio.Event()
|
||||
self._next_turn: int = 0
|
||||
|
||||
self._buf_size = buf_size
|
||||
self._batch_size: int = batch_size
|
||||
|
||||
@acm
|
||||
async def _open_channel(
|
||||
self,
|
||||
name: str
|
||||
) -> AsyncContextManager[RingBuffBytesSender]:
|
||||
async with (
|
||||
ringd.open_ringbuf(
|
||||
name=name,
|
||||
must_exist=True,
|
||||
) as token,
|
||||
attach_to_ringbuf_schannel(token) as chan
|
||||
):
|
||||
yield chan
|
||||
self._chanmngr = ChannelManager[PublisherChannels](
|
||||
n,
|
||||
self._open_channel,
|
||||
self._channel_task
|
||||
)
|
||||
|
||||
async def _channel_task(self, info: ChannelInfo) -> None:
|
||||
self._connect_event.set()
|
||||
await trio.sleep_forever()
|
||||
# methods that send data over the channels need to acquire the send lock
|
||||
# in order to guarantee order of operations
|
||||
self._send_lock = trio.StrictFIFOLock()
|
||||
|
||||
async def send(self, msg: bytes):
|
||||
# wait at least one decoder connected
|
||||
if len(self) == 0:
|
||||
await self._connect_event.wait()
|
||||
self._connect_event = trio.Event()
|
||||
|
||||
if self._next_turn >= len(self):
|
||||
self._next_turn = 0
|
||||
|
||||
turn = self._next_turn
|
||||
self._next_turn += 1
|
||||
|
||||
output = self._channels[turn]
|
||||
await output.channel.send(msg)
|
||||
self._next_turn: int = 0
|
||||
|
||||
@property
|
||||
def batch_size(self) -> int:
|
||||
|
@ -222,92 +270,273 @@ class RingBuffPublisher(
|
|||
|
||||
@batch_size.setter
|
||||
def set_batch_size(self, value: int) -> None:
|
||||
for output in self._channels:
|
||||
output.channel.batch_size = value
|
||||
for info in self.channels:
|
||||
info.channel.ring.batch_size = value
|
||||
|
||||
async def flush(
|
||||
@property
|
||||
def channels(self) -> list[ChannelInfo]:
|
||||
return self._chanmngr.channels
|
||||
|
||||
def get_channel(self, name: str) -> ChannelInfo:
|
||||
'''
|
||||
Get underlying ChannelInfo from name
|
||||
|
||||
'''
|
||||
return self._chanmngr[name]
|
||||
|
||||
async def add_channel(
|
||||
self,
|
||||
new_batch_size: int | None = None
|
||||
name: str,
|
||||
must_exist: bool = False
|
||||
):
|
||||
for output in self._channels:
|
||||
await output.channel.flush(
|
||||
new_batch_size=new_batch_size
|
||||
'''
|
||||
Store additional runtime info for channel and add channel to underlying
|
||||
ChannelManager
|
||||
|
||||
'''
|
||||
await self._chanmngr.add_channel(name, must_exist=must_exist)
|
||||
|
||||
async def remove_channel(self, name: str):
|
||||
'''
|
||||
Send EOF to channel (does `channel.flush` also) then remove from
|
||||
`ChannelManager` acquire both `self._send_lock` and
|
||||
`self._chanmngr.maybe_lock()` in order to ensure no channel
|
||||
modifications or sends happen concurrently
|
||||
'''
|
||||
async with self._chanmngr.maybe_lock():
|
||||
# ensure all pending messages are sent
|
||||
info = self.get_channel(name)
|
||||
|
||||
try:
|
||||
while True:
|
||||
msg = info.channel.rchan.receive_nowait()
|
||||
await info.channel.ring.send(msg)
|
||||
|
||||
except trio.WouldBlock:
|
||||
await info.channel.ring.flush()
|
||||
|
||||
await info.channel.schan.aclose()
|
||||
|
||||
# finally remove from ChannelManager
|
||||
await self._chanmngr.remove_channel(name)
|
||||
|
||||
@acm
|
||||
async def _open_channel(
|
||||
|
||||
self,
|
||||
name: str,
|
||||
must_exist: bool = False
|
||||
|
||||
) -> AsyncContextManager[PublisherChannels]:
|
||||
'''
|
||||
Open a ringbuf through `ringd` and attach as send side
|
||||
'''
|
||||
async with (
|
||||
ringd.open_ringbuf(
|
||||
name=name,
|
||||
buf_size=self._buf_size,
|
||||
must_exist=must_exist,
|
||||
) as token,
|
||||
attach_to_ringbuf_sender(token) as ring,
|
||||
):
|
||||
schan, rchan = trio.open_memory_channel(0)
|
||||
yield PublisherChannels(
|
||||
ring=ring,
|
||||
schan=schan,
|
||||
rchan=rchan
|
||||
)
|
||||
|
||||
async def send_eof(self):
|
||||
for output in self._channels:
|
||||
await output.channel.send_eof()
|
||||
async def _channel_task(self, info: ChannelInfo) -> None:
|
||||
'''
|
||||
Forever get current runtime info for channel, wait on its next pending
|
||||
payloads update event then drain all into send channel.
|
||||
|
||||
'''
|
||||
async for msg in info.channel.rchan:
|
||||
await info.channel.ring.send(msg)
|
||||
|
||||
async def send(self, msg: bytes):
|
||||
'''
|
||||
If no output channels connected, wait until one, then fetch the next
|
||||
channel based on turn, add the indexed payload and update
|
||||
`self._next_turn` & `self._next_index`.
|
||||
|
||||
Needs to acquire `self._send_lock` to make sure updates to turn & index
|
||||
variables dont happen out of order.
|
||||
|
||||
'''
|
||||
async with self._send_lock:
|
||||
# wait at least one decoder connected
|
||||
if len(self.channels) == 0:
|
||||
await self._chanmngr.wait_for_channel()
|
||||
|
||||
if self._next_turn >= len(self.channels):
|
||||
self._next_turn = 0
|
||||
|
||||
info = self.channels[self._next_turn]
|
||||
await info.channel.schan.send(msg)
|
||||
|
||||
self._next_turn += 1
|
||||
|
||||
async def aclose(self) -> None:
|
||||
await self._chanmngr.aclose()
|
||||
|
||||
|
||||
@acm
|
||||
async def open_ringbuf_publisher(
|
||||
|
||||
buf_size: int = 10 * 1024,
|
||||
batch_size: int = 1
|
||||
):
|
||||
batch_size: int = 1,
|
||||
guarantee_order: bool = False,
|
||||
force_cancel: bool = False
|
||||
|
||||
) -> AsyncContextManager[RingBufferPublisher]:
|
||||
'''
|
||||
Open a new ringbuf publisher
|
||||
|
||||
'''
|
||||
async with (
|
||||
trio.open_nursery() as n,
|
||||
RingBuffPublisher(
|
||||
RingBufferPublisher(
|
||||
n,
|
||||
buf_size=buf_size,
|
||||
batch_size=batch_size
|
||||
) as outputs
|
||||
) as publisher
|
||||
):
|
||||
yield outputs
|
||||
if guarantee_order:
|
||||
order_send_channel(publisher)
|
||||
|
||||
yield publisher
|
||||
|
||||
if force_cancel:
|
||||
# implicitly cancel any running channel handler task
|
||||
n.cancel_scope.cancel()
|
||||
|
||||
|
||||
|
||||
class RingBuffSubscriber(
|
||||
ChannelManager[RingBuffBytesReceiver]
|
||||
):
|
||||
class RingBufferSubscriber(trio.abc.ReceiveChannel[bytes]):
|
||||
'''
|
||||
Implement ChannelManager protocol + trio.abc.ReceiveChannel[bytes]
|
||||
using ring buffers as transport.
|
||||
Use ChannelManager to create a multi ringbuf receiver that can
|
||||
dynamically add or remove more inputs and combine all into a single output.
|
||||
|
||||
- use a trio memory channel pair to multiplex all received messages into a
|
||||
single `trio.MemoryReceiveChannel`, give a sender channel clone to each
|
||||
_channel_task.
|
||||
In order for `self.receive` messages to be returned in order, publisher
|
||||
will send all payloads as `OrderedPayload` msgpack encoded msgs, this
|
||||
allows our channel handler tasks to just stash the out of order payloads
|
||||
inside `self._pending_payloads` and if a in order payload is available
|
||||
signal through `self._new_payload_event`.
|
||||
|
||||
On `self.receive` we wait until at least one channel is connected, then if
|
||||
an in order payload is pending, we pop and return it, in case no in order
|
||||
payload is available wait until next `self._new_payload_event.set()`.
|
||||
|
||||
'''
|
||||
def __init__(
|
||||
self,
|
||||
n: trio.Nursery,
|
||||
|
||||
# if connecting to a publisher that has already sent messages set
|
||||
# to the next expected payload index this subscriber will receive
|
||||
start_index: int = 0
|
||||
):
|
||||
super().__init__(n)
|
||||
self._send_chan, self._recv_chan = trio.open_memory_channel(0)
|
||||
self._chanmngr = ChannelManager[RingBufferReceiveChannel](
|
||||
n,
|
||||
self._open_channel,
|
||||
self._channel_task
|
||||
)
|
||||
|
||||
self._schan, self._rchan = trio.open_memory_channel(0)
|
||||
|
||||
@property
|
||||
def channels(self) -> list[ChannelInfo]:
|
||||
return self._chanmngr.channels
|
||||
|
||||
def get_channel(self, name: str):
|
||||
return self._chanmngr[name]
|
||||
|
||||
async def add_channel(self, name: str, must_exist: bool = False):
|
||||
'''
|
||||
Add new input channel by name
|
||||
|
||||
'''
|
||||
await self._chanmngr.add_channel(name, must_exist=must_exist)
|
||||
|
||||
async def remove_channel(self, name: str):
|
||||
'''
|
||||
Remove an input channel by name
|
||||
|
||||
'''
|
||||
await self._chanmngr.remove_channel(name)
|
||||
|
||||
@acm
|
||||
async def _open_channel(
|
||||
|
||||
self,
|
||||
name: str
|
||||
) -> AsyncContextManager[RingBuffBytesReceiver]:
|
||||
name: str,
|
||||
must_exist: bool = False
|
||||
|
||||
) -> AsyncContextManager[RingBufferReceiveChannel]:
|
||||
'''
|
||||
Open a ringbuf through `ringd` and attach as receiver side
|
||||
'''
|
||||
async with (
|
||||
ringd.open_ringbuf(
|
||||
name=name,
|
||||
must_exist=True,
|
||||
must_exist=must_exist,
|
||||
) as token,
|
||||
attach_to_ringbuf_rchannel(token) as chan
|
||||
attach_to_ringbuf_receiver(token) as chan
|
||||
):
|
||||
yield chan
|
||||
|
||||
async def _channel_task(self, info: ChannelInfo) -> None:
|
||||
send_chan = self._send_chan.clone()
|
||||
try:
|
||||
async for msg in info.channel:
|
||||
await send_chan.send(msg)
|
||||
'''
|
||||
Iterate over receive channel messages, decode them as `OrderedPayload`s
|
||||
and stash them in `self._pending_payloads`, in case we can pop next in
|
||||
order payload, signal through setting `self._new_payload_event`.
|
||||
|
||||
except tractor._exceptions.InternalError:
|
||||
# TODO: cleaner cancellation!
|
||||
...
|
||||
'''
|
||||
while True:
|
||||
try:
|
||||
msg = await info.channel.receive()
|
||||
await self._schan.send(msg)
|
||||
|
||||
except tractor.linux.eventfd.EFDReadCancelled as e:
|
||||
# when channel gets removed while we are doing a receive
|
||||
log.exception(e)
|
||||
break
|
||||
|
||||
except trio.EndOfChannel:
|
||||
break
|
||||
|
||||
async def receive(self) -> bytes:
|
||||
return await self._recv_chan.receive()
|
||||
'''
|
||||
Receive next in order msg
|
||||
'''
|
||||
return await self._rchan.receive()
|
||||
|
||||
async def aclose(self) -> None:
|
||||
await self._chanmngr.aclose()
|
||||
|
||||
@acm
|
||||
async def open_ringbuf_subscriber():
|
||||
async def open_ringbuf_subscriber(
|
||||
|
||||
guarantee_order: bool = False,
|
||||
force_cancel: bool = False
|
||||
|
||||
) -> AsyncContextManager[RingBufferPublisher]:
|
||||
'''
|
||||
Open a new ringbuf subscriber
|
||||
|
||||
'''
|
||||
async with (
|
||||
trio.open_nursery() as n,
|
||||
RingBuffSubscriber(n) as inputs
|
||||
RingBufferSubscriber(
|
||||
n,
|
||||
) as subscriber
|
||||
):
|
||||
yield inputs
|
||||
if guarantee_order:
|
||||
order_receive_channel(subscriber)
|
||||
|
||||
yield subscriber
|
||||
|
||||
if force_cancel:
|
||||
# implicitly cancel any running channel handler task
|
||||
n.cancel_scope.cancel()
|
||||
|
|
|
@ -32,3 +32,8 @@ from ._broadcast import (
|
|||
from ._beg import (
|
||||
collapse_eg as collapse_eg,
|
||||
)
|
||||
|
||||
from ._ordering import (
|
||||
order_send_channel as order_send_channel,
|
||||
order_receive_channel as order_receive_channel
|
||||
)
|
||||
|
|
|
@ -0,0 +1,89 @@
|
|||
from __future__ import annotations
|
||||
from heapq import (
|
||||
heappush,
|
||||
heappop
|
||||
)
|
||||
|
||||
import trio
|
||||
import msgspec
|
||||
|
||||
|
||||
class OrderedPayload(msgspec.Struct, frozen=True):
    '''
    Wire envelope pairing a raw payload with its sequence index, encoded
    & decoded as msgpack.

    '''
    # sequence number assigned by the sending side
    index: int
    # the wrapped message bytes
    payload: bytes

    @classmethod
    def from_msg(cls, msg: bytes) -> OrderedPayload:
        '''
        Decode msgpack bytes into an `OrderedPayload`.

        '''
        decoded = msgspec.msgpack.decode(msg, type=OrderedPayload)
        return decoded

    def encode(self) -> bytes:
        '''
        Return this envelope serialized as msgpack bytes.

        '''
        packed = msgspec.msgpack.encode(self)
        return packed
|
||||
|
||||
|
||||
def order_send_channel(
    channel: trio.abc.SendChannel[bytes],
    start_index: int = 0
):
    '''
    Patch `channel` in place so every message given to `channel.send` is
    wrapped in an `OrderedPayload` before going out.

    Indexes start at `start_index` and grow by one per sent message. The
    original bound `send` / `aclose` stay reachable on the channel as
    `_send` / `_aclose`; the replacements both run under one shared
    `trio.StrictFIFOLock`.

    '''
    lock = trio.StrictFIFOLock()
    next_index = start_index

    # keep the un-patched methods around, the wrappers delegate to them
    channel._send = channel.send
    channel._aclose = channel.aclose

    async def send(msg: bytes):
        nonlocal next_index
        async with lock:
            wire_msg = OrderedPayload(
                index=next_index,
                payload=msg
            ).encode()
            await channel._send(wire_msg)
            # only advance once the underlying send returned
            next_index += 1

    async def aclose():
        async with lock:
            await channel._aclose()

    channel.send = send
    channel.aclose = aclose
|
||||
|
||||
|
||||
def order_receive_channel(
    channel: trio.abc.ReceiveChannel[bytes],
    start_index: int = 0
):
    '''
    Patch `channel` in place so `channel.receive` hands out payloads in
    index order.

    Incoming messages are decoded as `OrderedPayload` and buffered in a
    min-heap keyed by index; `receive` only returns once the entry at
    the top of the heap carries the next expected index (counting up
    from `start_index`). The original bound `receive` stays reachable as
    `channel._receive`.

    NOTE(review): a payload whose index is lower than the next expected
    one is never popped and would keep `receive` from returning; confirm
    the sending side can never produce one.

    '''
    pqueue = []
    next_index = start_index

    channel._receive = channel.receive

    def can_pop_next() -> bool:
        # true when the lowest buffered index is the one we wait for
        if len(pqueue) == 0:
            return False

        lowest_index = pqueue[0][0]
        return lowest_index == next_index

    async def drain_to_heap():
        # pull from the underlying channel until the next in order
        # payload sits at the top of the heap
        while not can_pop_next():
            raw = await channel._receive()
            ordered = OrderedPayload.from_msg(raw)
            heappush(pqueue, (ordered.index, ordered.payload))

    def pop_next():
        nonlocal next_index
        _, payload = heappop(pqueue)
        next_index += 1
        return payload

    async def receive() -> bytes:
        if not can_pop_next():
            await drain_to_heap()

        return pop_next()

    channel.receive = receive
|
Loading…
Reference in New Issue