From 1dfc639e54de24aaf856aef6ea91c60239ce8ca6 Mon Sep 17 00:00:00 2001 From: Guillermo Rodriguez Date: Fri, 4 Apr 2025 02:44:45 -0300 Subject: [PATCH] Fully test and fix bugs on _ringbuf._pubsub Add generic channel orderer --- tests/test_ringd.py | 368 +++++++++++------------ tractor/ipc/_ringbuf/_pubsub.py | 511 +++++++++++++++++++++++--------- tractor/trionics/__init__.py | 5 + tractor/trionics/_ordering.py | 89 ++++++ 4 files changed, 648 insertions(+), 325 deletions(-) create mode 100644 tractor/trionics/_ordering.py diff --git a/tests/test_ringd.py b/tests/test_ringd.py index 40040a43..3eda428a 100644 --- a/tests/test_ringd.py +++ b/tests/test_ringd.py @@ -92,187 +92,187 @@ def test_ringd(): trio.run(main) -# class Struct(msgspec.Struct): -# -# def encode(self) -> bytes: -# return msgspec.msgpack.encode(self) -# -# -# class AddChannelMsg(Struct, frozen=True, tag=True): -# name: str -# -# -# class RemoveChannelMsg(Struct, frozen=True, tag=True): -# name: str -# -# -# class RangeMsg(Struct, frozen=True, tag=True): -# start: int -# end: int -# -# -# ControlMessages = AddChannelMsg | RemoveChannelMsg | RangeMsg -# -# -# @tractor.context -# async def subscriber_child(ctx: tractor.Context): -# await ctx.started() -# async with ( -# open_ringbuf_subscriber(guarantee_order=True) as subs, -# trio.open_nursery() as n, -# ctx.open_stream() as stream -# ): -# range_msg = None -# range_event = trio.Event() -# range_scope = trio.CancelScope() -# -# async def _control_listen_task(): -# nonlocal range_msg, range_event -# async for msg in stream: -# msg = msgspec.msgpack.decode(msg, type=ControlMessages) -# match msg: -# case AddChannelMsg(): -# await subs.add_channel(msg.name, must_exist=False) -# -# case RemoveChannelMsg(): -# await subs.remove_channel(msg.name) -# -# case RangeMsg(): -# range_msg = msg -# range_event.set() -# -# await stream.send(b'ack') -# -# range_scope.cancel() -# -# n.start_soon(_control_listen_task) -# -# with range_scope: -# while True: -# await range_event.wait() -# range_event = trio.Event() -# for i in range(range_msg.start, range_msg.end): -# recv = int.from_bytes(await subs.receive()) -# # if recv != i: -# # raise AssertionError( -# # f'received: {recv} expected: {i}' -# # ) -# -# log.info(f'received: {recv} expected: {i}') -# -# await stream.send(b'valid range') -# log.info('FINISHED RANGE') -# -# log.info('subscriber exit') -# -# -# @tractor.context -# async def publisher_child(ctx: tractor.Context): -# await ctx.started() -# async with ( -# open_ringbuf_publisher(batch_size=1, guarantee_order=True) as pub, -# ctx.open_stream() as stream -# ): -# abs_index = 0 -# async for msg in stream: -# msg = msgspec.msgpack.decode(msg, type=ControlMessages) -# match msg: -# case AddChannelMsg(): -# await pub.add_channel(msg.name, must_exist=True) -# -# case RemoveChannelMsg(): -# await pub.remove_channel(msg.name) -# -# case RangeMsg(): -# for i in range(msg.start, msg.end): -# await pub.send(i.to_bytes(4)) -# log.info(f'sent {i}, index: {abs_index}') -# abs_index += 1 -# -# await stream.send(b'ack') -# -# log.info('publisher exit') -# -# -# -# def test_pubsub(): -# ''' -# Spawn ringd actor and two childs that access same ringbuf through ringd. -# -# Both will use `ringd.open_ringbuf` to allocate the ringbuf, then attach to -# them as sender and receiver. 
-# -# ''' -# async def main(): -# async with ( -# tractor.open_nursery( -# loglevel='info', -# # debug_mode=True, -# # enable_stack_on_sig=True -# ) as an, -# -# ringd.open_ringd() -# ): -# recv_portal = await an.start_actor( -# 'recv', -# enable_modules=[__name__] -# ) -# send_portal = await an.start_actor( -# 'send', -# enable_modules=[__name__] -# ) -# -# async with ( -# recv_portal.open_context(subscriber_child) as (rctx, _), -# rctx.open_stream() as recv_stream, -# send_portal.open_context(publisher_child) as (sctx, _), -# sctx.open_stream() as send_stream, -# ): -# async def send_wait_ack(msg: bytes): -# await recv_stream.send(msg) -# ack = await recv_stream.receive() -# assert ack == b'ack' -# -# await send_stream.send(msg) -# ack = await send_stream.receive() -# assert ack == b'ack' -# -# async def add_channel(name: str): -# await send_wait_ack(AddChannelMsg(name=name).encode()) -# -# async def remove_channel(name: str): -# await send_wait_ack(RemoveChannelMsg(name=name).encode()) -# -# async def send_range(start: int, end: int): -# await send_wait_ack(RangeMsg(start=start, end=end).encode()) -# range_ack = await recv_stream.receive() -# assert range_ack == b'valid range' -# -# # simple test, open one channel and send 0..100 range -# ring_name = 'ring-first' -# await add_channel(ring_name) -# await send_range(0, 100) -# await remove_channel(ring_name) -# -# # redo -# ring_name = 'ring-redo' -# await add_channel(ring_name) -# await send_range(0, 100) -# await remove_channel(ring_name) -# -# # multi chan test -# ring_names = [] -# for i in range(3): -# ring_names.append(f'multi-ring-{i}') -# -# for name in ring_names: -# await add_channel(name) -# -# await send_range(0, 300) -# -# for name in ring_names: -# await remove_channel(name) -# -# await an.cancel() -# -# trio.run(main) +class Struct(msgspec.Struct): + + def encode(self) -> bytes: + return msgspec.msgpack.encode(self) + + +class AddChannelMsg(Struct, frozen=True, tag=True): + name: str + + +class RemoveChannelMsg(Struct, frozen=True, tag=True): + name: str + + +class RangeMsg(Struct, frozen=True, tag=True): + start: int + end: int + + +ControlMessages = AddChannelMsg | RemoveChannelMsg | RangeMsg + + +@tractor.context +async def subscriber_child(ctx: tractor.Context): + await ctx.started() + async with ( + open_ringbuf_subscriber(guarantee_order=True) as subs, + trio.open_nursery() as n, + ctx.open_stream() as stream + ): + range_msg = None + range_event = trio.Event() + range_scope = trio.CancelScope() + + async def _control_listen_task(): + nonlocal range_msg, range_event + async for msg in stream: + msg = msgspec.msgpack.decode(msg, type=ControlMessages) + match msg: + case AddChannelMsg(): + await subs.add_channel(msg.name, must_exist=False) + + case RemoveChannelMsg(): + await subs.remove_channel(msg.name) + + case RangeMsg(): + range_msg = msg + range_event.set() + + await stream.send(b'ack') + + range_scope.cancel() + + n.start_soon(_control_listen_task) + + with range_scope: + while True: + await range_event.wait() + range_event = trio.Event() + for i in range(range_msg.start, range_msg.end): + recv = int.from_bytes(await subs.receive()) + # if recv != i: + # raise AssertionError( + # f'received: {recv} expected: {i}' + # ) + + log.info(f'received: {recv} expected: {i}') + + await stream.send(b'valid range') + log.info('FINISHED RANGE') + + log.info('subscriber exit') + + +@tractor.context +async def publisher_child(ctx: tractor.Context): + await ctx.started() + async with ( + open_ringbuf_publisher(batch_size=1, 
guarantee_order=True) as pub, + ctx.open_stream() as stream + ): + abs_index = 0 + async for msg in stream: + msg = msgspec.msgpack.decode(msg, type=ControlMessages) + match msg: + case AddChannelMsg(): + await pub.add_channel(msg.name, must_exist=True) + + case RemoveChannelMsg(): + await pub.remove_channel(msg.name) + + case RangeMsg(): + for i in range(msg.start, msg.end): + await pub.send(i.to_bytes(4)) + log.info(f'sent {i}, index: {abs_index}') + abs_index += 1 + + await stream.send(b'ack') + + log.info('publisher exit') + + + +def test_pubsub(): + ''' + Spawn ringd actor and two childs that access same ringbuf through ringd. + + Both will use `ringd.open_ringbuf` to allocate the ringbuf, then attach to + them as sender and receiver. + + ''' + async def main(): + async with ( + tractor.open_nursery( + loglevel='info', + # debug_mode=True, + # enable_stack_on_sig=True + ) as an, + + ringd.open_ringd() + ): + recv_portal = await an.start_actor( + 'recv', + enable_modules=[__name__] + ) + send_portal = await an.start_actor( + 'send', + enable_modules=[__name__] + ) + + async with ( + recv_portal.open_context(subscriber_child) as (rctx, _), + rctx.open_stream() as recv_stream, + send_portal.open_context(publisher_child) as (sctx, _), + sctx.open_stream() as send_stream, + ): + async def send_wait_ack(msg: bytes): + await recv_stream.send(msg) + ack = await recv_stream.receive() + assert ack == b'ack' + + await send_stream.send(msg) + ack = await send_stream.receive() + assert ack == b'ack' + + async def add_channel(name: str): + await send_wait_ack(AddChannelMsg(name=name).encode()) + + async def remove_channel(name: str): + await send_wait_ack(RemoveChannelMsg(name=name).encode()) + + async def send_range(start: int, end: int): + await send_wait_ack(RangeMsg(start=start, end=end).encode()) + range_ack = await recv_stream.receive() + assert range_ack == b'valid range' + + # simple test, open one channel and send 0..100 range + ring_name = 'ring-first' + await add_channel(ring_name) + await send_range(0, 100) + await remove_channel(ring_name) + + # redo + ring_name = 'ring-redo' + await add_channel(ring_name) + await send_range(0, 100) + await remove_channel(ring_name) + + # multi chan test + ring_names = [] + for i in range(3): + ring_names.append(f'multi-ring-{i}') + + for name in ring_names: + await add_channel(name) + + await send_range(0, 300) + + for name in ring_names: + await remove_channel(name) + + await an.cancel() + + trio.run(main) diff --git a/tractor/ipc/_ringbuf/_pubsub.py b/tractor/ipc/_ringbuf/_pubsub.py index fe8b5b5b..4d5e0d20 100644 --- a/tractor/ipc/_ringbuf/_pubsub.py +++ b/tractor/ipc/_ringbuf/_pubsub.py @@ -17,13 +17,14 @@ Ring buffer ipc publish-subscribe mechanism brokered by ringd can dynamically add new outputs (publisher) or inputs (subscriber) ''' -import time from typing import ( - runtime_checkable, - Protocol, TypeVar, + Generic, + Callable, + Awaitable, AsyncContextManager ) +from functools import partial from contextlib import asynccontextmanager as acm from dataclasses import dataclass @@ -31,12 +32,16 @@ import trio import tractor from tractor.ipc import ( - RingBuffBytesSender, - RingBuffBytesReceiver, - attach_to_ringbuf_schannel, - attach_to_ringbuf_rchannel + RingBufferSendChannel, + RingBufferReceiveChannel, + attach_to_ringbuf_sender, + attach_to_ringbuf_receiver ) +from tractor.trionics import ( + order_send_channel, + order_receive_channel +) import tractor.ipc._ringbuf._ringd as ringd @@ -48,66 +53,100 @@ ChannelType = TypeVar('ChannelType') 
@dataclass class ChannelInfo: - connect_time: float name: str channel: ChannelType cancel_scope: trio.CancelScope -# TODO: maybe move this abstraction to another module or standalone? -# its not ring buf specific and allows fan out and fan in an a dynamic -# amount of channels -@runtime_checkable -class ChannelManager(Protocol[ChannelType]): +class ChannelManager(trio.abc.AsyncResource, Generic[ChannelType]): ''' - Common data structures and methods pubsub classes use to manage channels & - their related handler background tasks, as well as cancellation of them. + Helper for managing channel resources and their handler tasks with + cancellation, add or remove channels dynamically! ''' def __init__( self, + # nursery used to spawn channel handler tasks n: trio.Nursery, + + # acm will be used for setup & teardown of channel resources + open_channel_acm: Callable[..., AsyncContextManager[ChannelType]], + + # long running bg task to handle channel + channel_task: Callable[..., Awaitable[None]] ): self._n = n + self._open_channel = open_channel_acm + self._channel_task = channel_task + + # signal when a new channel conects and we previously had none + self._connect_event = trio.Event() + + # store channel runtime variables self._channels: list[ChannelInfo] = [] - async def _open_channel( + # methods that modify self._channels should be ordered by FIFO + self._lock = trio.StrictFIFOLock() + + @acm + async def maybe_lock(self): + ''' + If lock is not held, acquire + + ''' + if self._lock.locked(): + yield + return + + async with self._lock: + yield + + @property + def channels(self) -> list[ChannelInfo]: + return self._channels + + async def _channel_handler_task( self, - name: str - ) -> AsyncContextManager[ChannelType]: + name: str, + task_status: trio.TASK_STATUS_IGNORED, + **kwargs + ): ''' - Used to instantiate channel resources given a name + Open channel resources, add to internal data structures, signal channel + connect through trio.Event, and run `channel_task` with cancel scope, + and finally, maybe remove channel from internal data structures. + Spawned by `add_channel` function, lock is held from begining of fn + until `task_status.started()` call. + + kwargs are proxied to `self._open_channel` acm. ''' - ... + async with self._open_channel(name, **kwargs) as chan: + cancel_scope = trio.CancelScope() + info = ChannelInfo( + name=name, + channel=chan, + cancel_scope=cancel_scope + ) + self._channels.append(info) - async def _channel_task(self, info: ChannelInfo) -> None: - ''' - Long running task that manages the channel + if len(self) == 1: + self._connect_event.set() - ''' - ... + task_status.started() - async def _channel_handler_task(self, name: str): - async with self._open_channel(name) as chan: - with trio.CancelScope() as cancel_scope: - info = ChannelInfo( - connect_time=time.time(), - name=name, - channel=chan, - cancel_scope=cancel_scope - ) - self._channels.append(info) + with cancel_scope: await self._channel_task(info) - self._maybe_destroy_channel(name) + await self._maybe_destroy_channel(name) - def find_channel(self, name: str) -> tuple[int, ChannelInfo] | None: + def _find_channel(self, name: str) -> tuple[int, ChannelInfo] | None: ''' Given a channel name maybe return its index and value from internal _channels list. + Only use after acquiring lock. 
''' for entry in enumerate(self._channels): i, info = entry @@ -116,105 +155,114 @@ class ChannelManager(Protocol[ChannelType]): return None - def _maybe_destroy_channel(self, name: str): + + async def _maybe_destroy_channel(self, name: str): ''' If channel exists cancel its scope and remove from internal _channels list. ''' - maybe_entry = self.find_channel(name) - if maybe_entry: - i, info = maybe_entry - info.cancel_scope.cancel() - del self._channels[i] + async with self.maybe_lock(): + maybe_entry = self._find_channel(name) + if maybe_entry: + i, info = maybe_entry + info.cancel_scope.cancel() + del self._channels[i] - def add_channel(self, name: str): + async def add_channel(self, name: str, **kwargs): ''' Add a new channel to be handled ''' - self._n.start_soon( - self._channel_handler_task, - name - ) + async with self.maybe_lock(): + await self._n.start(partial( + self._channel_handler_task, + name, + **kwargs + )) - def remove_channel(self, name: str): + async def remove_channel(self, name: str): ''' Remove a channel and stop its handling ''' - self._maybe_destroy_channel(name) + async with self.maybe_lock(): + await self._maybe_destroy_channel(name) + + # if that was last channel reset connect event + if len(self) == 0: + self._connect_event = trio.Event() + + async def wait_for_channel(self): + ''' + Wait until at least one channel added + + ''' + await self._connect_event.wait() + self._connect_event = trio.Event() def __len__(self) -> int: return len(self._channels) + def __getitem__(self, name: str): + maybe_entry = self._find_channel(name) + if maybe_entry: + _, info = maybe_entry + return info + + raise KeyError(f'Channel {name} not found!') + async def aclose(self) -> None: - for chan in self._channels: - self._maybe_destroy_channel(chan.name) - - async def __aenter__(self): - return self - - async def __aexit__(self, exc_type, exc_val, exc_tb): - await self.aclose() + async with self.maybe_lock(): + for info in self._channels: + await self.remove_channel(info.name) -class RingBuffPublisher( - ChannelManager[RingBuffBytesSender] -): +''' +Ring buffer publisher & subscribe pattern mediated by `ringd` actor. + +''' + +@dataclass +class PublisherChannels: + ring: RingBufferSendChannel + schan: trio.MemorySendChannel + rchan: trio.MemoryReceiveChannel + + +class RingBufferPublisher(trio.abc.SendChannel[bytes]): ''' - Implement ChannelManager protocol + trio.abc.SendChannel[bytes] - using ring buffers as transport. + Use ChannelManager to create a multi ringbuf round robin sender that can + dynamically add or remove more outputs. - - use a `trio.Event` to make sure `send` blocks until at least one channel - available. + Don't instantiate directly, use `open_ringbuf_publisher` acm to manage its + lifecycle. 
''' - def __init__( self, n: trio.Nursery, + + # new ringbufs created will have this buf_size buf_size: int = 10 * 1024, + + # global batch size for all channels batch_size: int = 1 ): - super().__init__(n) - self._connect_event = trio.Event() - self._next_turn: int = 0 - + self._buf_size = buf_size self._batch_size: int = batch_size - @acm - async def _open_channel( - self, - name: str - ) -> AsyncContextManager[RingBuffBytesSender]: - async with ( - ringd.open_ringbuf( - name=name, - must_exist=True, - ) as token, - attach_to_ringbuf_schannel(token) as chan - ): - yield chan + self._chanmngr = ChannelManager[PublisherChannels]( + n, + self._open_channel, + self._channel_task + ) - async def _channel_task(self, info: ChannelInfo) -> None: - self._connect_event.set() - await trio.sleep_forever() + # methods that send data over the channels need to be acquire send lock + # in order to guarantee order of operations + self._send_lock = trio.StrictFIFOLock() - async def send(self, msg: bytes): - # wait at least one decoder connected - if len(self) == 0: - await self._connect_event.wait() - self._connect_event = trio.Event() - - if self._next_turn >= len(self): - self._next_turn = 0 - - turn = self._next_turn - self._next_turn += 1 - - output = self._channels[turn] - await output.channel.send(msg) + self._next_turn: int = 0 @property def batch_size(self) -> int: @@ -222,92 +270,273 @@ class RingBuffPublisher( @batch_size.setter def set_batch_size(self, value: int) -> None: - for output in self._channels: - output.channel.batch_size = value + for info in self.channels: + info.channel.ring.batch_size = value - async def flush( + @property + def channels(self) -> list[ChannelInfo]: + return self._chanmngr.channels + + def get_channel(self, name: str) -> ChannelInfo: + ''' + Get underlying ChannelInfo from name + + ''' + return self._chanmngr[name] + + async def add_channel( self, - new_batch_size: int | None = None + name: str, + must_exist: bool = False ): - for output in self._channels: - await output.channel.flush( - new_batch_size=new_batch_size + ''' + Store additional runtime info for channel and add channel to underlying + ChannelManager + + ''' + await self._chanmngr.add_channel(name, must_exist=must_exist) + + async def remove_channel(self, name: str): + ''' + Send EOF to channel (does `channel.flush` also) then remove from + `ChannelManager` acquire both `self._send_lock` and + `self._chanmngr.maybe_lock()` in order to ensure no channel + modifications or sends happen concurrenty + ''' + async with self._chanmngr.maybe_lock(): + # ensure all pending messages are sent + info = self.get_channel(name) + + try: + while True: + msg = info.channel.rchan.receive_nowait() + await info.channel.ring.send(msg) + + except trio.WouldBlock: + await info.channel.ring.flush() + + await info.channel.schan.aclose() + + # finally remove from ChannelManager + await self._chanmngr.remove_channel(name) + + @acm + async def _open_channel( + + self, + name: str, + must_exist: bool = False + + ) -> AsyncContextManager[PublisherChannels]: + ''' + Open a ringbuf through `ringd` and attach as send side + ''' + async with ( + ringd.open_ringbuf( + name=name, + buf_size=self._buf_size, + must_exist=must_exist, + ) as token, + attach_to_ringbuf_sender(token) as ring, + ): + schan, rchan = trio.open_memory_channel(0) + yield PublisherChannels( + ring=ring, + schan=schan, + rchan=rchan ) - async def send_eof(self): - for output in self._channels: - await output.channel.send_eof() + async def _channel_task(self, info: 
ChannelInfo) -> None: + ''' + Forever get current runtime info for channel, wait on its next pending + payloads update event then drain all into send channel. + + ''' + async for msg in info.channel.rchan: + await info.channel.ring.send(msg) + + async def send(self, msg: bytes): + ''' + If no output channels connected, wait until one, then fetch the next + channel based on turn, add the indexed payload and update + `self._next_turn` & `self._next_index`. + + Needs to acquire `self._send_lock` to make sure updates to turn & index + variables dont happen out of order. + + ''' + async with self._send_lock: + # wait at least one decoder connected + if len(self.channels) == 0: + await self._chanmngr.wait_for_channel() + + if self._next_turn >= len(self.channels): + self._next_turn = 0 + + info = self.channels[self._next_turn] + await info.channel.schan.send(msg) + + self._next_turn += 1 + + async def aclose(self) -> None: + await self._chanmngr.aclose() @acm async def open_ringbuf_publisher( + buf_size: int = 10 * 1024, - batch_size: int = 1 -): + batch_size: int = 1, + guarantee_order: bool = False, + force_cancel: bool = False + +) -> AsyncContextManager[RingBufferPublisher]: + ''' + Open a new ringbuf publisher + + ''' async with ( trio.open_nursery() as n, - RingBuffPublisher( + RingBufferPublisher( n, buf_size=buf_size, batch_size=batch_size - ) as outputs + ) as publisher ): - yield outputs + if guarantee_order: + order_send_channel(publisher) + + yield publisher + + if force_cancel: + # implicitly cancel any running channel handler task + n.cancel_scope.cancel() - -class RingBuffSubscriber( - ChannelManager[RingBuffBytesReceiver] -): +class RingBufferSubscriber(trio.abc.ReceiveChannel[bytes]): ''' - Implement ChannelManager protocol + trio.abc.ReceiveChannel[bytes] - using ring buffers as transport. + Use ChannelManager to create a multi ringbuf receiver that can + dynamically add or remove more inputs and combine all into a single output. - - use a trio memory channel pair to multiplex all received messages into a - single `trio.MemoryReceiveChannel`, give a sender channel clone to each - _channel_task. + In order for `self.receive` messages to be returned in order, publisher + will send all payloads as `OrderedPayload` msgpack encoded msgs, this + allows our channel handler tasks to just stash the out of order payloads + inside `self._pending_payloads` and if a in order payload is available + signal through `self._new_payload_event`. + + On `self.receive` we wait until at least one channel is connected, then if + an in order payload is pending, we pop and return it, in case no in order + payload is available wait until next `self._new_payload_event.set()`. 
''' def __init__( self, n: trio.Nursery, + + # if connecting to a publisher that has already sent messages set + # to the next expected payload index this subscriber will receive + start_index: int = 0 ): - super().__init__(n) - self._send_chan, self._recv_chan = trio.open_memory_channel(0) + self._chanmngr = ChannelManager[RingBufferReceiveChannel]( + n, + self._open_channel, + self._channel_task + ) + + self._schan, self._rchan = trio.open_memory_channel(0) + + @property + def channels(self) -> list[ChannelInfo]: + return self._chanmngr.channels + + def get_channel(self, name: str): + return self._chanmngr[name] + + async def add_channel(self, name: str, must_exist: bool = False): + ''' + Add new input channel by name + + ''' + await self._chanmngr.add_channel(name, must_exist=must_exist) + + async def remove_channel(self, name: str): + ''' + Remove an input channel by name + + ''' + await self._chanmngr.remove_channel(name) @acm async def _open_channel( + self, - name: str - ) -> AsyncContextManager[RingBuffBytesReceiver]: + name: str, + must_exist: bool = False + + ) -> AsyncContextManager[RingBufferReceiveChannel]: + ''' + Open a ringbuf through `ringd` and attach as receiver side + ''' async with ( ringd.open_ringbuf( name=name, - must_exist=True, + must_exist=must_exist, ) as token, - attach_to_ringbuf_rchannel(token) as chan + attach_to_ringbuf_receiver(token) as chan ): yield chan async def _channel_task(self, info: ChannelInfo) -> None: - send_chan = self._send_chan.clone() - try: - async for msg in info.channel: - await send_chan.send(msg) + ''' + Iterate over receive channel messages, decode them as `OrderedPayload`s + and stash them in `self._pending_payloads`, in case we can pop next in + order payload, signal through setting `self._new_payload_event`. - except tractor._exceptions.InternalError: - # TODO: cleaner cancellation! - ... 
+ ''' + while True: + try: + msg = await info.channel.receive() + await self._schan.send(msg) + + except tractor.linux.eventfd.EFDReadCancelled as e: + # when channel gets removed while we are doing a receive + log.exception(e) + break + + except trio.EndOfChannel: + break async def receive(self) -> bytes: - return await self._recv_chan.receive() + ''' + Receive next in order msg + ''' + return await self._rchan.receive() + async def aclose(self) -> None: + await self._chanmngr.aclose() @acm -async def open_ringbuf_subscriber(): +async def open_ringbuf_subscriber( + + guarantee_order: bool = False, + force_cancel: bool = False + +) -> AsyncContextManager[RingBufferPublisher]: + ''' + Open a new ringbuf subscriber + + ''' async with ( trio.open_nursery() as n, - RingBuffSubscriber(n) as inputs + RingBufferSubscriber( + n, + ) as subscriber ): - yield inputs + if guarantee_order: + order_receive_channel(subscriber) + yield subscriber + + if force_cancel: + # implicitly cancel any running channel handler task + n.cancel_scope.cancel() diff --git a/tractor/trionics/__init__.py b/tractor/trionics/__init__.py index df9b6f26..97d03da7 100644 --- a/tractor/trionics/__init__.py +++ b/tractor/trionics/__init__.py @@ -32,3 +32,8 @@ from ._broadcast import ( from ._beg import ( collapse_eg as collapse_eg, ) + +from ._ordering import ( + order_send_channel as order_send_channel, + order_receive_channel as order_receive_channel +) diff --git a/tractor/trionics/_ordering.py b/tractor/trionics/_ordering.py new file mode 100644 index 00000000..2d7e9082 --- /dev/null +++ b/tractor/trionics/_ordering.py @@ -0,0 +1,89 @@ +from __future__ import annotations +from heapq import ( + heappush, + heappop +) + +import trio +import msgspec + + +class OrderedPayload(msgspec.Struct, frozen=True): + index: int + payload: bytes + + @classmethod + def from_msg(cls, msg: bytes) -> OrderedPayload: + return msgspec.msgpack.decode(msg, type=OrderedPayload) + + def encode(self) -> bytes: + return msgspec.msgpack.encode(self) + + +def order_send_channel( + channel: trio.abc.SendChannel[bytes], + start_index: int = 0 +): + + next_index = start_index + send_lock = trio.StrictFIFOLock() + + channel._send = channel.send + channel._aclose = channel.aclose + + async def send(msg: bytes): + nonlocal next_index + async with send_lock: + await channel._send( + OrderedPayload( + index=next_index, + payload=msg + ).encode() + ) + next_index += 1 + + async def aclose(): + async with send_lock: + await channel._aclose() + + channel.send = send + channel.aclose = aclose + + +def order_receive_channel( + channel: trio.abc.ReceiveChannel[bytes], + start_index: int = 0 +): + next_index = start_index + pqueue = [] + + channel._receive = channel.receive + + def can_pop_next() -> bool: + return ( + len(pqueue) > 0 + and + pqueue[0][0] == next_index + ) + + async def drain_to_heap(): + while not can_pop_next(): + msg = await channel._receive() + msg = OrderedPayload.from_msg(msg) + heappush(pqueue, (msg.index, msg.payload)) + + def pop_next(): + nonlocal next_index + _, msg = heappop(pqueue) + next_index += 1 + return msg + + async def receive() -> bytes: + if can_pop_next(): + return pop_next() + + await drain_to_heap() + + return pop_next() + + channel.receive = receive
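
Usage note (illustrative sketch, not part of the patch itself): the ordering
helpers added in `tractor/trionics/_ordering.py` are transport agnostic and can
be exercised without `ringd` or any ring buffers. The sketch below assumes only
the `order_send_channel`/`order_receive_channel` exports introduced by this
patch; the `ShuffledBytesChannel` stand-in is hypothetical and exists purely to
simulate out-of-order delivery like payloads interleaved across multiple rings.

    import random

    import trio
    from tractor.trionics import (
        order_send_channel,
        order_receive_channel,
    )


    class ShuffledBytesChannel:
        '''
        Duck-typed send/receive pair backed by a list that is drained in
        random order, standing in for an out-of-order bytes transport.

        '''
        def __init__(self):
            self._msgs: list[bytes] = []

        async def send(self, msg: bytes):
            self._msgs.append(msg)

        async def aclose(self):
            ...

        async def receive(self) -> bytes:
            await trio.sleep(0)  # checkpoint, like a real async channel
            return self._msgs.pop(random.randrange(len(self._msgs)))


    async def main():
        chan = ShuffledBytesChannel()

        # wrap both ends: each send is tagged with an increasing index,
        # receives are replayed from a heap strictly in index order
        order_send_channel(chan)
        order_receive_channel(chan)

        for i in range(10):
            await chan.send(i.to_bytes(4, 'big'))

        received = [
            int.from_bytes(await chan.receive(), 'big')
            for _ in range(10)
        ]
        assert received == list(range(10))


    trio.run(main)

Even though the stand-in hands payloads back in random order, the wrapped
`receive` stashes early arrivals on a heap and only yields the next expected
index, which is the same guarantee `guarantee_order=True` adds on top of the
round-robin ringbuf publisher/subscriber above.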