Fully test and fix bugs on _ringbuf._pubsub

Add generic channel orderer
one_ring_to_rule_them_all
Guillermo Rodriguez 2025-04-04 02:44:45 -03:00
parent bebd327023
commit 1dfc639e54
No known key found for this signature in database
GPG Key ID: 002CC5F1E6BDA53E
4 changed files with 648 additions and 325 deletions

View File

@ -92,187 +92,187 @@ def test_ringd():
trio.run(main) trio.run(main)
# class Struct(msgspec.Struct): class Struct(msgspec.Struct):
#
# def encode(self) -> bytes: def encode(self) -> bytes:
# return msgspec.msgpack.encode(self) return msgspec.msgpack.encode(self)
#
#
# class AddChannelMsg(Struct, frozen=True, tag=True): class AddChannelMsg(Struct, frozen=True, tag=True):
# name: str name: str
#
#
# class RemoveChannelMsg(Struct, frozen=True, tag=True): class RemoveChannelMsg(Struct, frozen=True, tag=True):
# name: str name: str
#
#
# class RangeMsg(Struct, frozen=True, tag=True): class RangeMsg(Struct, frozen=True, tag=True):
# start: int start: int
# end: int end: int
#
#
# ControlMessages = AddChannelMsg | RemoveChannelMsg | RangeMsg ControlMessages = AddChannelMsg | RemoveChannelMsg | RangeMsg
#
#
# @tractor.context @tractor.context
# async def subscriber_child(ctx: tractor.Context): async def subscriber_child(ctx: tractor.Context):
# await ctx.started() await ctx.started()
# async with ( async with (
# open_ringbuf_subscriber(guarantee_order=True) as subs, open_ringbuf_subscriber(guarantee_order=True) as subs,
# trio.open_nursery() as n, trio.open_nursery() as n,
# ctx.open_stream() as stream ctx.open_stream() as stream
# ): ):
# range_msg = None range_msg = None
# range_event = trio.Event() range_event = trio.Event()
# range_scope = trio.CancelScope() range_scope = trio.CancelScope()
#
# async def _control_listen_task(): async def _control_listen_task():
# nonlocal range_msg, range_event nonlocal range_msg, range_event
# async for msg in stream: async for msg in stream:
# msg = msgspec.msgpack.decode(msg, type=ControlMessages) msg = msgspec.msgpack.decode(msg, type=ControlMessages)
# match msg: match msg:
# case AddChannelMsg(): case AddChannelMsg():
# await subs.add_channel(msg.name, must_exist=False) await subs.add_channel(msg.name, must_exist=False)
#
# case RemoveChannelMsg(): case RemoveChannelMsg():
# await subs.remove_channel(msg.name) await subs.remove_channel(msg.name)
#
# case RangeMsg(): case RangeMsg():
# range_msg = msg range_msg = msg
# range_event.set() range_event.set()
#
# await stream.send(b'ack') await stream.send(b'ack')
#
# range_scope.cancel() range_scope.cancel()
#
# n.start_soon(_control_listen_task) n.start_soon(_control_listen_task)
#
# with range_scope: with range_scope:
# while True: while True:
# await range_event.wait() await range_event.wait()
# range_event = trio.Event() range_event = trio.Event()
# for i in range(range_msg.start, range_msg.end): for i in range(range_msg.start, range_msg.end):
# recv = int.from_bytes(await subs.receive()) recv = int.from_bytes(await subs.receive())
# # if recv != i: # if recv != i:
# # raise AssertionError( # raise AssertionError(
# # f'received: {recv} expected: {i}' # f'received: {recv} expected: {i}'
# # ) # )
#
# log.info(f'received: {recv} expected: {i}') log.info(f'received: {recv} expected: {i}')
#
# await stream.send(b'valid range') await stream.send(b'valid range')
# log.info('FINISHED RANGE') log.info('FINISHED RANGE')
#
# log.info('subscriber exit') log.info('subscriber exit')
#
#
# @tractor.context @tractor.context
# async def publisher_child(ctx: tractor.Context): async def publisher_child(ctx: tractor.Context):
# await ctx.started() await ctx.started()
# async with ( async with (
# open_ringbuf_publisher(batch_size=1, guarantee_order=True) as pub, open_ringbuf_publisher(batch_size=1, guarantee_order=True) as pub,
# ctx.open_stream() as stream ctx.open_stream() as stream
# ): ):
# abs_index = 0 abs_index = 0
# async for msg in stream: async for msg in stream:
# msg = msgspec.msgpack.decode(msg, type=ControlMessages) msg = msgspec.msgpack.decode(msg, type=ControlMessages)
# match msg: match msg:
# case AddChannelMsg(): case AddChannelMsg():
# await pub.add_channel(msg.name, must_exist=True) await pub.add_channel(msg.name, must_exist=True)
#
# case RemoveChannelMsg(): case RemoveChannelMsg():
# await pub.remove_channel(msg.name) await pub.remove_channel(msg.name)
#
# case RangeMsg(): case RangeMsg():
# for i in range(msg.start, msg.end): for i in range(msg.start, msg.end):
# await pub.send(i.to_bytes(4)) await pub.send(i.to_bytes(4))
# log.info(f'sent {i}, index: {abs_index}') log.info(f'sent {i}, index: {abs_index}')
# abs_index += 1 abs_index += 1
#
# await stream.send(b'ack') await stream.send(b'ack')
#
# log.info('publisher exit') log.info('publisher exit')
#
#
#
# def test_pubsub(): def test_pubsub():
# ''' '''
# Spawn ringd actor and two childs that access same ringbuf through ringd. Spawn ringd actor and two childs that access same ringbuf through ringd.
#
# Both will use `ringd.open_ringbuf` to allocate the ringbuf, then attach to Both will use `ringd.open_ringbuf` to allocate the ringbuf, then attach to
# them as sender and receiver. them as sender and receiver.
#
# ''' '''
# async def main(): async def main():
# async with ( async with (
# tractor.open_nursery( tractor.open_nursery(
# loglevel='info', loglevel='info',
# # debug_mode=True, # debug_mode=True,
# # enable_stack_on_sig=True # enable_stack_on_sig=True
# ) as an, ) as an,
#
# ringd.open_ringd() ringd.open_ringd()
# ): ):
# recv_portal = await an.start_actor( recv_portal = await an.start_actor(
# 'recv', 'recv',
# enable_modules=[__name__] enable_modules=[__name__]
# ) )
# send_portal = await an.start_actor( send_portal = await an.start_actor(
# 'send', 'send',
# enable_modules=[__name__] enable_modules=[__name__]
# ) )
#
# async with ( async with (
# recv_portal.open_context(subscriber_child) as (rctx, _), recv_portal.open_context(subscriber_child) as (rctx, _),
# rctx.open_stream() as recv_stream, rctx.open_stream() as recv_stream,
# send_portal.open_context(publisher_child) as (sctx, _), send_portal.open_context(publisher_child) as (sctx, _),
# sctx.open_stream() as send_stream, sctx.open_stream() as send_stream,
# ): ):
# async def send_wait_ack(msg: bytes): async def send_wait_ack(msg: bytes):
# await recv_stream.send(msg) await recv_stream.send(msg)
# ack = await recv_stream.receive() ack = await recv_stream.receive()
# assert ack == b'ack' assert ack == b'ack'
#
# await send_stream.send(msg) await send_stream.send(msg)
# ack = await send_stream.receive() ack = await send_stream.receive()
# assert ack == b'ack' assert ack == b'ack'
#
# async def add_channel(name: str): async def add_channel(name: str):
# await send_wait_ack(AddChannelMsg(name=name).encode()) await send_wait_ack(AddChannelMsg(name=name).encode())
#
# async def remove_channel(name: str): async def remove_channel(name: str):
# await send_wait_ack(RemoveChannelMsg(name=name).encode()) await send_wait_ack(RemoveChannelMsg(name=name).encode())
#
# async def send_range(start: int, end: int): async def send_range(start: int, end: int):
# await send_wait_ack(RangeMsg(start=start, end=end).encode()) await send_wait_ack(RangeMsg(start=start, end=end).encode())
# range_ack = await recv_stream.receive() range_ack = await recv_stream.receive()
# assert range_ack == b'valid range' assert range_ack == b'valid range'
#
# # simple test, open one channel and send 0..100 range # simple test, open one channel and send 0..100 range
# ring_name = 'ring-first' ring_name = 'ring-first'
# await add_channel(ring_name) await add_channel(ring_name)
# await send_range(0, 100) await send_range(0, 100)
# await remove_channel(ring_name) await remove_channel(ring_name)
#
# # redo # redo
# ring_name = 'ring-redo' ring_name = 'ring-redo'
# await add_channel(ring_name) await add_channel(ring_name)
# await send_range(0, 100) await send_range(0, 100)
# await remove_channel(ring_name) await remove_channel(ring_name)
#
# # multi chan test # multi chan test
# ring_names = [] ring_names = []
# for i in range(3): for i in range(3):
# ring_names.append(f'multi-ring-{i}') ring_names.append(f'multi-ring-{i}')
#
# for name in ring_names: for name in ring_names:
# await add_channel(name) await add_channel(name)
#
# await send_range(0, 300) await send_range(0, 300)
#
# for name in ring_names: for name in ring_names:
# await remove_channel(name) await remove_channel(name)
#
# await an.cancel() await an.cancel()
#
# trio.run(main) trio.run(main)

View File

@ -17,13 +17,14 @@
Ring buffer ipc publish-subscribe mechanism brokered by ringd Ring buffer ipc publish-subscribe mechanism brokered by ringd
can dynamically add new outputs (publisher) or inputs (subscriber) can dynamically add new outputs (publisher) or inputs (subscriber)
''' '''
import time
from typing import ( from typing import (
runtime_checkable,
Protocol,
TypeVar, TypeVar,
Generic,
Callable,
Awaitable,
AsyncContextManager AsyncContextManager
) )
from functools import partial
from contextlib import asynccontextmanager as acm from contextlib import asynccontextmanager as acm
from dataclasses import dataclass from dataclasses import dataclass
@ -31,12 +32,16 @@ import trio
import tractor import tractor
from tractor.ipc import ( from tractor.ipc import (
RingBuffBytesSender, RingBufferSendChannel,
RingBuffBytesReceiver, RingBufferReceiveChannel,
attach_to_ringbuf_schannel, attach_to_ringbuf_sender,
attach_to_ringbuf_rchannel attach_to_ringbuf_receiver
) )
from tractor.trionics import (
order_send_channel,
order_receive_channel
)
import tractor.ipc._ringbuf._ringd as ringd import tractor.ipc._ringbuf._ringd as ringd
@ -48,66 +53,100 @@ ChannelType = TypeVar('ChannelType')
@dataclass @dataclass
class ChannelInfo: class ChannelInfo:
connect_time: float
name: str name: str
channel: ChannelType channel: ChannelType
cancel_scope: trio.CancelScope cancel_scope: trio.CancelScope
# TODO: maybe move this abstraction to another module or standalone? class ChannelManager(trio.abc.AsyncResource, Generic[ChannelType]):
# its not ring buf specific and allows fan out and fan in an a dynamic
# amount of channels
@runtime_checkable
class ChannelManager(Protocol[ChannelType]):
''' '''
Common data structures and methods pubsub classes use to manage channels & Helper for managing channel resources and their handler tasks with
their related handler background tasks, as well as cancellation of them. cancellation, add or remove channels dynamically!
''' '''
def __init__( def __init__(
self, self,
# nursery used to spawn channel handler tasks
n: trio.Nursery, n: trio.Nursery,
# acm will be used for setup & teardown of channel resources
open_channel_acm: Callable[..., AsyncContextManager[ChannelType]],
# long running bg task to handle channel
channel_task: Callable[..., Awaitable[None]]
): ):
self._n = n self._n = n
self._open_channel = open_channel_acm
self._channel_task = channel_task
# signal when a new channel conects and we previously had none
self._connect_event = trio.Event()
# store channel runtime variables
self._channels: list[ChannelInfo] = [] self._channels: list[ChannelInfo] = []
async def _open_channel( # methods that modify self._channels should be ordered by FIFO
self._lock = trio.StrictFIFOLock()
@acm
async def maybe_lock(self):
'''
If lock is not held, acquire
'''
if self._lock.locked():
yield
return
async with self._lock:
yield
@property
def channels(self) -> list[ChannelInfo]:
return self._channels
async def _channel_handler_task(
self, self,
name: str name: str,
) -> AsyncContextManager[ChannelType]: task_status: trio.TASK_STATUS_IGNORED,
**kwargs
):
''' '''
Used to instantiate channel resources given a name Open channel resources, add to internal data structures, signal channel
connect through trio.Event, and run `channel_task` with cancel scope,
and finally, maybe remove channel from internal data structures.
Spawned by `add_channel` function, lock is held from begining of fn
until `task_status.started()` call.
kwargs are proxied to `self._open_channel` acm.
''' '''
... async with self._open_channel(name, **kwargs) as chan:
cancel_scope = trio.CancelScope()
info = ChannelInfo(
name=name,
channel=chan,
cancel_scope=cancel_scope
)
self._channels.append(info)
async def _channel_task(self, info: ChannelInfo) -> None: if len(self) == 1:
''' self._connect_event.set()
Long running task that manages the channel
''' task_status.started()
...
async def _channel_handler_task(self, name: str): with cancel_scope:
async with self._open_channel(name) as chan:
with trio.CancelScope() as cancel_scope:
info = ChannelInfo(
connect_time=time.time(),
name=name,
channel=chan,
cancel_scope=cancel_scope
)
self._channels.append(info)
await self._channel_task(info) await self._channel_task(info)
self._maybe_destroy_channel(name) await self._maybe_destroy_channel(name)
def find_channel(self, name: str) -> tuple[int, ChannelInfo] | None: def _find_channel(self, name: str) -> tuple[int, ChannelInfo] | None:
''' '''
Given a channel name maybe return its index and value from Given a channel name maybe return its index and value from
internal _channels list. internal _channels list.
Only use after acquiring lock.
''' '''
for entry in enumerate(self._channels): for entry in enumerate(self._channels):
i, info = entry i, info = entry
@ -116,105 +155,114 @@ class ChannelManager(Protocol[ChannelType]):
return None return None
def _maybe_destroy_channel(self, name: str):
async def _maybe_destroy_channel(self, name: str):
''' '''
If channel exists cancel its scope and remove from internal If channel exists cancel its scope and remove from internal
_channels list. _channels list.
''' '''
maybe_entry = self.find_channel(name) async with self.maybe_lock():
if maybe_entry: maybe_entry = self._find_channel(name)
i, info = maybe_entry if maybe_entry:
info.cancel_scope.cancel() i, info = maybe_entry
del self._channels[i] info.cancel_scope.cancel()
del self._channels[i]
def add_channel(self, name: str): async def add_channel(self, name: str, **kwargs):
''' '''
Add a new channel to be handled Add a new channel to be handled
''' '''
self._n.start_soon( async with self.maybe_lock():
self._channel_handler_task, await self._n.start(partial(
name self._channel_handler_task,
) name,
**kwargs
))
def remove_channel(self, name: str): async def remove_channel(self, name: str):
''' '''
Remove a channel and stop its handling Remove a channel and stop its handling
''' '''
self._maybe_destroy_channel(name) async with self.maybe_lock():
await self._maybe_destroy_channel(name)
# if that was last channel reset connect event
if len(self) == 0:
self._connect_event = trio.Event()
async def wait_for_channel(self):
'''
Wait until at least one channel added
'''
await self._connect_event.wait()
self._connect_event = trio.Event()
def __len__(self) -> int: def __len__(self) -> int:
return len(self._channels) return len(self._channels)
def __getitem__(self, name: str):
maybe_entry = self._find_channel(name)
if maybe_entry:
_, info = maybe_entry
return info
raise KeyError(f'Channel {name} not found!')
async def aclose(self) -> None: async def aclose(self) -> None:
for chan in self._channels: async with self.maybe_lock():
self._maybe_destroy_channel(chan.name) for info in self._channels:
await self.remove_channel(info.name)
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
await self.aclose()
class RingBuffPublisher( '''
ChannelManager[RingBuffBytesSender] Ring buffer publisher & subscribe pattern mediated by `ringd` actor.
):
'''
@dataclass
class PublisherChannels:
ring: RingBufferSendChannel
schan: trio.MemorySendChannel
rchan: trio.MemoryReceiveChannel
class RingBufferPublisher(trio.abc.SendChannel[bytes]):
''' '''
Implement ChannelManager protocol + trio.abc.SendChannel[bytes] Use ChannelManager to create a multi ringbuf round robin sender that can
using ring buffers as transport. dynamically add or remove more outputs.
- use a `trio.Event` to make sure `send` blocks until at least one channel Don't instantiate directly, use `open_ringbuf_publisher` acm to manage its
available. lifecycle.
''' '''
def __init__( def __init__(
self, self,
n: trio.Nursery, n: trio.Nursery,
# new ringbufs created will have this buf_size
buf_size: int = 10 * 1024, buf_size: int = 10 * 1024,
# global batch size for all channels
batch_size: int = 1 batch_size: int = 1
): ):
super().__init__(n) self._buf_size = buf_size
self._connect_event = trio.Event()
self._next_turn: int = 0
self._batch_size: int = batch_size self._batch_size: int = batch_size
@acm self._chanmngr = ChannelManager[PublisherChannels](
async def _open_channel( n,
self, self._open_channel,
name: str self._channel_task
) -> AsyncContextManager[RingBuffBytesSender]: )
async with (
ringd.open_ringbuf(
name=name,
must_exist=True,
) as token,
attach_to_ringbuf_schannel(token) as chan
):
yield chan
async def _channel_task(self, info: ChannelInfo) -> None: # methods that send data over the channels need to be acquire send lock
self._connect_event.set() # in order to guarantee order of operations
await trio.sleep_forever() self._send_lock = trio.StrictFIFOLock()
async def send(self, msg: bytes): self._next_turn: int = 0
# wait at least one decoder connected
if len(self) == 0:
await self._connect_event.wait()
self._connect_event = trio.Event()
if self._next_turn >= len(self):
self._next_turn = 0
turn = self._next_turn
self._next_turn += 1
output = self._channels[turn]
await output.channel.send(msg)
@property @property
def batch_size(self) -> int: def batch_size(self) -> int:
@ -222,92 +270,273 @@ class RingBuffPublisher(
@batch_size.setter @batch_size.setter
def set_batch_size(self, value: int) -> None: def set_batch_size(self, value: int) -> None:
for output in self._channels: for info in self.channels:
output.channel.batch_size = value info.channel.ring.batch_size = value
async def flush( @property
def channels(self) -> list[ChannelInfo]:
return self._chanmngr.channels
def get_channel(self, name: str) -> ChannelInfo:
'''
Get underlying ChannelInfo from name
'''
return self._chanmngr[name]
async def add_channel(
self, self,
new_batch_size: int | None = None name: str,
must_exist: bool = False
): ):
for output in self._channels: '''
await output.channel.flush( Store additional runtime info for channel and add channel to underlying
new_batch_size=new_batch_size ChannelManager
'''
await self._chanmngr.add_channel(name, must_exist=must_exist)
async def remove_channel(self, name: str):
'''
Send EOF to channel (does `channel.flush` also) then remove from
`ChannelManager` acquire both `self._send_lock` and
`self._chanmngr.maybe_lock()` in order to ensure no channel
modifications or sends happen concurrenty
'''
async with self._chanmngr.maybe_lock():
# ensure all pending messages are sent
info = self.get_channel(name)
try:
while True:
msg = info.channel.rchan.receive_nowait()
await info.channel.ring.send(msg)
except trio.WouldBlock:
await info.channel.ring.flush()
await info.channel.schan.aclose()
# finally remove from ChannelManager
await self._chanmngr.remove_channel(name)
@acm
async def _open_channel(
self,
name: str,
must_exist: bool = False
) -> AsyncContextManager[PublisherChannels]:
'''
Open a ringbuf through `ringd` and attach as send side
'''
async with (
ringd.open_ringbuf(
name=name,
buf_size=self._buf_size,
must_exist=must_exist,
) as token,
attach_to_ringbuf_sender(token) as ring,
):
schan, rchan = trio.open_memory_channel(0)
yield PublisherChannels(
ring=ring,
schan=schan,
rchan=rchan
) )
async def send_eof(self): async def _channel_task(self, info: ChannelInfo) -> None:
for output in self._channels: '''
await output.channel.send_eof() Forever get current runtime info for channel, wait on its next pending
payloads update event then drain all into send channel.
'''
async for msg in info.channel.rchan:
await info.channel.ring.send(msg)
async def send(self, msg: bytes):
'''
If no output channels connected, wait until one, then fetch the next
channel based on turn, add the indexed payload and update
`self._next_turn` & `self._next_index`.
Needs to acquire `self._send_lock` to make sure updates to turn & index
variables dont happen out of order.
'''
async with self._send_lock:
# wait at least one decoder connected
if len(self.channels) == 0:
await self._chanmngr.wait_for_channel()
if self._next_turn >= len(self.channels):
self._next_turn = 0
info = self.channels[self._next_turn]
await info.channel.schan.send(msg)
self._next_turn += 1
async def aclose(self) -> None:
await self._chanmngr.aclose()
@acm @acm
async def open_ringbuf_publisher( async def open_ringbuf_publisher(
buf_size: int = 10 * 1024, buf_size: int = 10 * 1024,
batch_size: int = 1 batch_size: int = 1,
): guarantee_order: bool = False,
force_cancel: bool = False
) -> AsyncContextManager[RingBufferPublisher]:
'''
Open a new ringbuf publisher
'''
async with ( async with (
trio.open_nursery() as n, trio.open_nursery() as n,
RingBuffPublisher( RingBufferPublisher(
n, n,
buf_size=buf_size, buf_size=buf_size,
batch_size=batch_size batch_size=batch_size
) as outputs ) as publisher
): ):
yield outputs if guarantee_order:
order_send_channel(publisher)
yield publisher
if force_cancel:
# implicitly cancel any running channel handler task
n.cancel_scope.cancel()
class RingBufferSubscriber(trio.abc.ReceiveChannel[bytes]):
class RingBuffSubscriber(
ChannelManager[RingBuffBytesReceiver]
):
''' '''
Implement ChannelManager protocol + trio.abc.ReceiveChannel[bytes] Use ChannelManager to create a multi ringbuf receiver that can
using ring buffers as transport. dynamically add or remove more inputs and combine all into a single output.
- use a trio memory channel pair to multiplex all received messages into a In order for `self.receive` messages to be returned in order, publisher
single `trio.MemoryReceiveChannel`, give a sender channel clone to each will send all payloads as `OrderedPayload` msgpack encoded msgs, this
_channel_task. allows our channel handler tasks to just stash the out of order payloads
inside `self._pending_payloads` and if a in order payload is available
signal through `self._new_payload_event`.
On `self.receive` we wait until at least one channel is connected, then if
an in order payload is pending, we pop and return it, in case no in order
payload is available wait until next `self._new_payload_event.set()`.
''' '''
def __init__( def __init__(
self, self,
n: trio.Nursery, n: trio.Nursery,
# if connecting to a publisher that has already sent messages set
# to the next expected payload index this subscriber will receive
start_index: int = 0
): ):
super().__init__(n) self._chanmngr = ChannelManager[RingBufferReceiveChannel](
self._send_chan, self._recv_chan = trio.open_memory_channel(0) n,
self._open_channel,
self._channel_task
)
self._schan, self._rchan = trio.open_memory_channel(0)
@property
def channels(self) -> list[ChannelInfo]:
return self._chanmngr.channels
def get_channel(self, name: str):
return self._chanmngr[name]
async def add_channel(self, name: str, must_exist: bool = False):
'''
Add new input channel by name
'''
await self._chanmngr.add_channel(name, must_exist=must_exist)
async def remove_channel(self, name: str):
'''
Remove an input channel by name
'''
await self._chanmngr.remove_channel(name)
@acm @acm
async def _open_channel( async def _open_channel(
self, self,
name: str name: str,
) -> AsyncContextManager[RingBuffBytesReceiver]: must_exist: bool = False
) -> AsyncContextManager[RingBufferReceiveChannel]:
'''
Open a ringbuf through `ringd` and attach as receiver side
'''
async with ( async with (
ringd.open_ringbuf( ringd.open_ringbuf(
name=name, name=name,
must_exist=True, must_exist=must_exist,
) as token, ) as token,
attach_to_ringbuf_rchannel(token) as chan attach_to_ringbuf_receiver(token) as chan
): ):
yield chan yield chan
async def _channel_task(self, info: ChannelInfo) -> None: async def _channel_task(self, info: ChannelInfo) -> None:
send_chan = self._send_chan.clone() '''
try: Iterate over receive channel messages, decode them as `OrderedPayload`s
async for msg in info.channel: and stash them in `self._pending_payloads`, in case we can pop next in
await send_chan.send(msg) order payload, signal through setting `self._new_payload_event`.
except tractor._exceptions.InternalError: '''
# TODO: cleaner cancellation! while True:
... try:
msg = await info.channel.receive()
await self._schan.send(msg)
except tractor.linux.eventfd.EFDReadCancelled as e:
# when channel gets removed while we are doing a receive
log.exception(e)
break
except trio.EndOfChannel:
break
async def receive(self) -> bytes: async def receive(self) -> bytes:
return await self._recv_chan.receive() '''
Receive next in order msg
'''
return await self._rchan.receive()
async def aclose(self) -> None:
await self._chanmngr.aclose()
@acm @acm
async def open_ringbuf_subscriber(): async def open_ringbuf_subscriber(
guarantee_order: bool = False,
force_cancel: bool = False
) -> AsyncContextManager[RingBufferPublisher]:
'''
Open a new ringbuf subscriber
'''
async with ( async with (
trio.open_nursery() as n, trio.open_nursery() as n,
RingBuffSubscriber(n) as inputs RingBufferSubscriber(
n,
) as subscriber
): ):
yield inputs if guarantee_order:
order_receive_channel(subscriber)
yield subscriber
if force_cancel:
# implicitly cancel any running channel handler task
n.cancel_scope.cancel()

View File

@ -32,3 +32,8 @@ from ._broadcast import (
from ._beg import ( from ._beg import (
collapse_eg as collapse_eg, collapse_eg as collapse_eg,
) )
from ._ordering import (
order_send_channel as order_send_channel,
order_receive_channel as order_receive_channel
)

View File

@ -0,0 +1,89 @@
from __future__ import annotations
from heapq import (
heappush,
heappop
)
import trio
import msgspec
class OrderedPayload(msgspec.Struct, frozen=True):
index: int
payload: bytes
@classmethod
def from_msg(cls, msg: bytes) -> OrderedPayload:
return msgspec.msgpack.decode(msg, type=OrderedPayload)
def encode(self) -> bytes:
return msgspec.msgpack.encode(self)
def order_send_channel(
channel: trio.abc.SendChannel[bytes],
start_index: int = 0
):
next_index = start_index
send_lock = trio.StrictFIFOLock()
channel._send = channel.send
channel._aclose = channel.aclose
async def send(msg: bytes):
nonlocal next_index
async with send_lock:
await channel._send(
OrderedPayload(
index=next_index,
payload=msg
).encode()
)
next_index += 1
async def aclose():
async with send_lock:
await channel._aclose()
channel.send = send
channel.aclose = aclose
def order_receive_channel(
channel: trio.abc.ReceiveChannel[bytes],
start_index: int = 0
):
next_index = start_index
pqueue = []
channel._receive = channel.receive
def can_pop_next() -> bool:
return (
len(pqueue) > 0
and
pqueue[0][0] == next_index
)
async def drain_to_heap():
while not can_pop_next():
msg = await channel._receive()
msg = OrderedPayload.from_msg(msg)
heappush(pqueue, (msg.index, msg.payload))
def pop_next():
nonlocal next_index
_, msg = heappop(pqueue)
next_index += 1
return msg
async def receive() -> bytes:
if can_pop_next():
return pop_next()
await drain_to_heap()
return pop_next()
channel.receive = receive