Fully test and fix bugs on _ringbuf._pubsub

Add generic channel orderer
one_ring_to_rule_them_all
Guillermo Rodriguez 2025-04-04 02:44:45 -03:00
parent bebd327023
commit 1dfc639e54
4 changed files with 648 additions and 325 deletions


@@ -92,187 +92,187 @@ def test_ringd():
trio.run(main)
# class Struct(msgspec.Struct):
#
# def encode(self) -> bytes:
# return msgspec.msgpack.encode(self)
#
#
# class AddChannelMsg(Struct, frozen=True, tag=True):
# name: str
#
#
# class RemoveChannelMsg(Struct, frozen=True, tag=True):
# name: str
#
#
# class RangeMsg(Struct, frozen=True, tag=True):
# start: int
# end: int
#
#
# ControlMessages = AddChannelMsg | RemoveChannelMsg | RangeMsg
#
#
# @tractor.context
# async def subscriber_child(ctx: tractor.Context):
# await ctx.started()
# async with (
# open_ringbuf_subscriber(guarantee_order=True) as subs,
# trio.open_nursery() as n,
# ctx.open_stream() as stream
# ):
# range_msg = None
# range_event = trio.Event()
# range_scope = trio.CancelScope()
#
# async def _control_listen_task():
# nonlocal range_msg, range_event
# async for msg in stream:
# msg = msgspec.msgpack.decode(msg, type=ControlMessages)
# match msg:
# case AddChannelMsg():
# await subs.add_channel(msg.name, must_exist=False)
#
# case RemoveChannelMsg():
# await subs.remove_channel(msg.name)
#
# case RangeMsg():
# range_msg = msg
# range_event.set()
#
# await stream.send(b'ack')
#
# range_scope.cancel()
#
# n.start_soon(_control_listen_task)
#
# with range_scope:
# while True:
# await range_event.wait()
# range_event = trio.Event()
# for i in range(range_msg.start, range_msg.end):
# recv = int.from_bytes(await subs.receive())
# # if recv != i:
# # raise AssertionError(
# # f'received: {recv} expected: {i}'
# # )
#
# log.info(f'received: {recv} expected: {i}')
#
# await stream.send(b'valid range')
# log.info('FINISHED RANGE')
#
# log.info('subscriber exit')
#
#
# @tractor.context
# async def publisher_child(ctx: tractor.Context):
# await ctx.started()
# async with (
# open_ringbuf_publisher(batch_size=1, guarantee_order=True) as pub,
# ctx.open_stream() as stream
# ):
# abs_index = 0
# async for msg in stream:
# msg = msgspec.msgpack.decode(msg, type=ControlMessages)
# match msg:
# case AddChannelMsg():
# await pub.add_channel(msg.name, must_exist=True)
#
# case RemoveChannelMsg():
# await pub.remove_channel(msg.name)
#
# case RangeMsg():
# for i in range(msg.start, msg.end):
# await pub.send(i.to_bytes(4))
# log.info(f'sent {i}, index: {abs_index}')
# abs_index += 1
#
# await stream.send(b'ack')
#
# log.info('publisher exit')
#
#
#
# def test_pubsub():
# '''
# Spawn ringd actor and two childs that access same ringbuf through ringd.
#
# Both will use `ringd.open_ringbuf` to allocate the ringbuf, then attach to
# them as sender and receiver.
#
# '''
# async def main():
# async with (
# tractor.open_nursery(
# loglevel='info',
# # debug_mode=True,
# # enable_stack_on_sig=True
# ) as an,
#
# ringd.open_ringd()
# ):
# recv_portal = await an.start_actor(
# 'recv',
# enable_modules=[__name__]
class Struct(msgspec.Struct):
def encode(self) -> bytes:
return msgspec.msgpack.encode(self)
class AddChannelMsg(Struct, frozen=True, tag=True):
name: str
class RemoveChannelMsg(Struct, frozen=True, tag=True):
name: str
class RangeMsg(Struct, frozen=True, tag=True):
start: int
end: int
ControlMessages = AddChannelMsg | RemoveChannelMsg | RangeMsg
@tractor.context
async def subscriber_child(ctx: tractor.Context):
await ctx.started()
async with (
open_ringbuf_subscriber(guarantee_order=True) as subs,
trio.open_nursery() as n,
ctx.open_stream() as stream
):
range_msg = None
range_event = trio.Event()
range_scope = trio.CancelScope()
async def _control_listen_task():
nonlocal range_msg, range_event
async for msg in stream:
msg = msgspec.msgpack.decode(msg, type=ControlMessages)
match msg:
case AddChannelMsg():
await subs.add_channel(msg.name, must_exist=False)
case RemoveChannelMsg():
await subs.remove_channel(msg.name)
case RangeMsg():
range_msg = msg
range_event.set()
await stream.send(b'ack')
range_scope.cancel()
n.start_soon(_control_listen_task)
with range_scope:
while True:
await range_event.wait()
range_event = trio.Event()
for i in range(range_msg.start, range_msg.end):
recv = int.from_bytes(await subs.receive())
# if recv != i:
# raise AssertionError(
# f'received: {recv} expected: {i}'
# )
# send_portal = await an.start_actor(
# 'send',
# enable_modules=[__name__]
# )
#
# async with (
# recv_portal.open_context(subscriber_child) as (rctx, _),
# rctx.open_stream() as recv_stream,
# send_portal.open_context(publisher_child) as (sctx, _),
# sctx.open_stream() as send_stream,
# ):
# async def send_wait_ack(msg: bytes):
# await recv_stream.send(msg)
# ack = await recv_stream.receive()
# assert ack == b'ack'
#
# await send_stream.send(msg)
# ack = await send_stream.receive()
# assert ack == b'ack'
#
# async def add_channel(name: str):
# await send_wait_ack(AddChannelMsg(name=name).encode())
#
# async def remove_channel(name: str):
# await send_wait_ack(RemoveChannelMsg(name=name).encode())
#
# async def send_range(start: int, end: int):
# await send_wait_ack(RangeMsg(start=start, end=end).encode())
# range_ack = await recv_stream.receive()
# assert range_ack == b'valid range'
#
# # simple test, open one channel and send 0..100 range
# ring_name = 'ring-first'
# await add_channel(ring_name)
# await send_range(0, 100)
# await remove_channel(ring_name)
#
# # redo
# ring_name = 'ring-redo'
# await add_channel(ring_name)
# await send_range(0, 100)
# await remove_channel(ring_name)
#
# # multi chan test
# ring_names = []
# for i in range(3):
# ring_names.append(f'multi-ring-{i}')
#
# for name in ring_names:
# await add_channel(name)
#
# await send_range(0, 300)
#
# for name in ring_names:
# await remove_channel(name)
#
# await an.cancel()
#
# trio.run(main)
log.info(f'received: {recv} expected: {i}')
await stream.send(b'valid range')
log.info('FINISHED RANGE')
log.info('subscriber exit')
@tractor.context
async def publisher_child(ctx: tractor.Context):
await ctx.started()
async with (
open_ringbuf_publisher(batch_size=1, guarantee_order=True) as pub,
ctx.open_stream() as stream
):
abs_index = 0
async for msg in stream:
msg = msgspec.msgpack.decode(msg, type=ControlMessages)
match msg:
case AddChannelMsg():
await pub.add_channel(msg.name, must_exist=True)
case RemoveChannelMsg():
await pub.remove_channel(msg.name)
case RangeMsg():
for i in range(msg.start, msg.end):
await pub.send(i.to_bytes(4))
log.info(f'sent {i}, index: {abs_index}')
abs_index += 1
await stream.send(b'ack')
log.info('publisher exit')
def test_pubsub():
'''
Spawn a ringd actor and two children that access the same ringbuf through ringd.
Both will use `ringd.open_ringbuf` to allocate the ringbuf, then attach to
it as sender and receiver.
'''
async def main():
async with (
tractor.open_nursery(
loglevel='info',
# debug_mode=True,
# enable_stack_on_sig=True
) as an,
ringd.open_ringd()
):
recv_portal = await an.start_actor(
'recv',
enable_modules=[__name__]
)
send_portal = await an.start_actor(
'send',
enable_modules=[__name__]
)
async with (
recv_portal.open_context(subscriber_child) as (rctx, _),
rctx.open_stream() as recv_stream,
send_portal.open_context(publisher_child) as (sctx, _),
sctx.open_stream() as send_stream,
):
async def send_wait_ack(msg: bytes):
await recv_stream.send(msg)
ack = await recv_stream.receive()
assert ack == b'ack'
await send_stream.send(msg)
ack = await send_stream.receive()
assert ack == b'ack'
async def add_channel(name: str):
await send_wait_ack(AddChannelMsg(name=name).encode())
async def remove_channel(name: str):
await send_wait_ack(RemoveChannelMsg(name=name).encode())
async def send_range(start: int, end: int):
await send_wait_ack(RangeMsg(start=start, end=end).encode())
range_ack = await recv_stream.receive()
assert range_ack == b'valid range'
# simple test, open one channel and send 0..100 range
ring_name = 'ring-first'
await add_channel(ring_name)
await send_range(0, 100)
await remove_channel(ring_name)
# redo
ring_name = 'ring-redo'
await add_channel(ring_name)
await send_range(0, 100)
await remove_channel(ring_name)
# multi chan test
ring_names = []
for i in range(3):
ring_names.append(f'multi-ring-{i}')
for name in ring_names:
await add_channel(name)
await send_range(0, 300)
for name in ring_names:
await remove_channel(name)
await an.cancel()
trio.run(main)


@@ -17,13 +17,14 @@
Ring buffer ipc publish-subscribe mechanism brokered by ringd
can dynamically add new outputs (publisher) or inputs (subscriber)
'''
import time
from typing import (
runtime_checkable,
Protocol,
TypeVar,
Generic,
Callable,
Awaitable,
AsyncContextManager
)
from functools import partial
from contextlib import asynccontextmanager as acm
from dataclasses import dataclass
@@ -31,12 +32,16 @@ import trio
import tractor
from tractor.ipc import (
RingBuffBytesSender,
RingBuffBytesReceiver,
attach_to_ringbuf_schannel,
attach_to_ringbuf_rchannel
RingBufferSendChannel,
RingBufferReceiveChannel,
attach_to_ringbuf_sender,
attach_to_ringbuf_receiver
)
from tractor.trionics import (
order_send_channel,
order_receive_channel
)
import tractor.ipc._ringbuf._ringd as ringd
@@ -48,66 +53,100 @@ ChannelType = TypeVar('ChannelType')
@dataclass
class ChannelInfo:
connect_time: float
name: str
channel: ChannelType
cancel_scope: trio.CancelScope
# TODO: maybe move this abstraction to another module or make it standalone?
# it's not ring-buf specific and allows fan-out and fan-in over a dynamic
# number of channels
@runtime_checkable
class ChannelManager(Protocol[ChannelType]):
class ChannelManager(trio.abc.AsyncResource, Generic[ChannelType]):
'''
Common data structures and methods pubsub classes use to manage channels &
their related handler background tasks, as well as cancellation of them.
Helper for managing channel resources and their handler tasks, with
cancellation; channels can be added or removed dynamically!
'''
def __init__(
self,
# nursery used to spawn channel handler tasks
n: trio.Nursery,
# acm will be used for setup & teardown of channel resources
open_channel_acm: Callable[..., AsyncContextManager[ChannelType]],
# long running bg task to handle channel
channel_task: Callable[..., Awaitable[None]]
):
self._n = n
self._open_channel = open_channel_acm
self._channel_task = channel_task
# signal when a new channel connects and we previously had none
self._connect_event = trio.Event()
# store channel runtime variables
self._channels: list[ChannelInfo] = []
async def _open_channel(
# methods that modify self._channels should be serialized in FIFO order
self._lock = trio.StrictFIFOLock()
@acm
async def maybe_lock(self):
'''
Acquire the lock only if it's not already held, else just yield.
'''
if self._lock.locked():
yield
return
async with self._lock:
yield
@property
def channels(self) -> list[ChannelInfo]:
return self._channels
async def _channel_handler_task(
self,
name: str
) -> AsyncContextManager[ChannelType]:
name: str,
task_status: trio.TASK_STATUS_IGNORED,
**kwargs
):
'''
Used to instantiate channel resources given a name
Open channel resources, add them to the internal data structures, signal the
channel connect through a trio.Event, run `channel_task` inside a cancel
scope, and finally maybe remove the channel from the internal data structures.
Spawned by `add_channel`; the lock is held from the beginning of the fn until
the `task_status.started()` call.
kwargs are proxied to the `self._open_channel` acm.
'''
...
async def _channel_task(self, info: ChannelInfo) -> None:
'''
Long running task that manages the channel
'''
...
async def _channel_handler_task(self, name: str):
async with self._open_channel(name) as chan:
with trio.CancelScope() as cancel_scope:
async with self._open_channel(name, **kwargs) as chan:
cancel_scope = trio.CancelScope()
info = ChannelInfo(
connect_time=time.time(),
name=name,
channel=chan,
cancel_scope=cancel_scope
)
self._channels.append(info)
if len(self) == 1:
self._connect_event.set()
task_status.started()
with cancel_scope:
await self._channel_task(info)
self._maybe_destroy_channel(name)
await self._maybe_destroy_channel(name)
def find_channel(self, name: str) -> tuple[int, ChannelInfo] | None:
def _find_channel(self, name: str) -> tuple[int, ChannelInfo] | None:
'''
Given a channel name, maybe return its index and value from the
internal _channels list.
Only use after acquiring the lock.
'''
for entry in enumerate(self._channels):
i, info = entry
@@ -116,105 +155,114 @@ class ChannelManager(Protocol[ChannelType]):
return None
def _maybe_destroy_channel(self, name: str):
async def _maybe_destroy_channel(self, name: str):
'''
If the channel exists, cancel its scope and remove it from the
internal _channels list.
'''
maybe_entry = self.find_channel(name)
async with self.maybe_lock():
maybe_entry = self._find_channel(name)
if maybe_entry:
i, info = maybe_entry
info.cancel_scope.cancel()
del self._channels[i]
def add_channel(self, name: str):
async def add_channel(self, name: str, **kwargs):
'''
Add a new channel to be handled
'''
self._n.start_soon(
async with self.maybe_lock():
await self._n.start(partial(
self._channel_handler_task,
name
)
name,
**kwargs
))
def remove_channel(self, name: str):
async def remove_channel(self, name: str):
'''
Remove a channel and stop its handling
'''
self._maybe_destroy_channel(name)
async with self.maybe_lock():
await self._maybe_destroy_channel(name)
# if that was the last channel, reset the connect event
if len(self) == 0:
self._connect_event = trio.Event()
async def wait_for_channel(self):
'''
Wait until at least one channel has been added
'''
await self._connect_event.wait()
self._connect_event = trio.Event()
def __len__(self) -> int:
return len(self._channels)
def __getitem__(self, name: str):
maybe_entry = self._find_channel(name)
if maybe_entry:
_, info = maybe_entry
return info
raise KeyError(f'Channel {name} not found!')
async def aclose(self) -> None:
for chan in self._channels:
self._maybe_destroy_channel(chan.name)
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
await self.aclose()
async with self.maybe_lock():
for info in self._channels:
await self.remove_channel(info.name)
class RingBuffPublisher(
ChannelManager[RingBuffBytesSender]
):
'''
Implement ChannelManager protocol + trio.abc.SendChannel[bytes]
using ring buffers as transport.
- use a `trio.Event` to make sure `send` blocks until at least one channel
available.
Ring buffer publish & subscribe pattern mediated by the `ringd` actor.
'''
@dataclass
class PublisherChannels:
ring: RingBufferSendChannel
schan: trio.MemorySendChannel
rchan: trio.MemoryReceiveChannel
class RingBufferPublisher(trio.abc.SendChannel[bytes]):
'''
Use ChannelManager to create a multi-ringbuf, round-robin sender that can
dynamically add or remove outputs.
Don't instantiate directly; use the `open_ringbuf_publisher` acm to manage its
lifecycle.
'''
def __init__(
self,
n: trio.Nursery,
# new ringbufs created will have this buf_size
buf_size: int = 10 * 1024,
# global batch size for all channels
batch_size: int = 1
):
super().__init__(n)
self._connect_event = trio.Event()
self._next_turn: int = 0
self._buf_size = buf_size
self._batch_size: int = batch_size
@acm
async def _open_channel(
self,
name: str
) -> AsyncContextManager[RingBuffBytesSender]:
async with (
ringd.open_ringbuf(
name=name,
must_exist=True,
) as token,
attach_to_ringbuf_schannel(token) as chan
):
yield chan
self._chanmngr = ChannelManager[PublisherChannels](
n,
self._open_channel,
self._channel_task
)
async def _channel_task(self, info: ChannelInfo) -> None:
self._connect_event.set()
await trio.sleep_forever()
# methods that send data over the channels need to acquire the send lock
# in order to guarantee ordering of operations
self._send_lock = trio.StrictFIFOLock()
async def send(self, msg: bytes):
# wait at least one decoder connected
if len(self) == 0:
await self._connect_event.wait()
self._connect_event = trio.Event()
if self._next_turn >= len(self):
self._next_turn = 0
turn = self._next_turn
self._next_turn += 1
output = self._channels[turn]
await output.channel.send(msg)
self._next_turn: int = 0
@property
def batch_size(self) -> int:
@@ -222,92 +270,273 @@ class RingBuffPublisher(
@batch_size.setter
def batch_size(self, value: int) -> None:
for output in self._channels:
output.channel.batch_size = value
for info in self.channels:
info.channel.ring.batch_size = value
async def flush(
@property
def channels(self) -> list[ChannelInfo]:
return self._chanmngr.channels
def get_channel(self, name: str) -> ChannelInfo:
'''
Get underlying ChannelInfo from name
'''
return self._chanmngr[name]
async def add_channel(
self,
new_batch_size: int | None = None
name: str,
must_exist: bool = False
):
for output in self._channels:
await output.channel.flush(
new_batch_size=new_batch_size
'''
Add a new output channel by name to the underlying
ChannelManager
'''
await self._chanmngr.add_channel(name, must_exist=must_exist)
async def remove_channel(self, name: str):
'''
Drain any pending messages into the ring and flush it, close the channel's
memory send channel, then remove it from the `ChannelManager`; the manager's
lock is held throughout so no channel modifications happen
concurrently
'''
async with self._chanmngr.maybe_lock():
# ensure all pending messages are sent
info = self.get_channel(name)
try:
while True:
msg = info.channel.rchan.receive_nowait()
await info.channel.ring.send(msg)
except trio.WouldBlock:
await info.channel.ring.flush()
await info.channel.schan.aclose()
# finally remove from ChannelManager
await self._chanmngr.remove_channel(name)
@acm
async def _open_channel(
self,
name: str,
must_exist: bool = False
) -> AsyncContextManager[PublisherChannels]:
'''
Open a ringbuf through `ringd` and attach as send side
'''
async with (
ringd.open_ringbuf(
name=name,
buf_size=self._buf_size,
must_exist=must_exist,
) as token,
attach_to_ringbuf_sender(token) as ring,
):
schan, rchan = trio.open_memory_channel(0)
yield PublisherChannels(
ring=ring,
schan=schan,
rchan=rchan
)
async def send_eof(self):
for output in self._channels:
await output.channel.send_eof()
async def _channel_task(self, info: ChannelInfo) -> None:
'''
Drain messages from this channel's memory receive channel into its ringbuf
sender until the memory channel is closed.
'''
async for msg in info.channel.rchan:
await info.channel.ring.send(msg)
async def send(self, msg: bytes):
'''
If no output channels are connected, wait until at least one is, then pick
the next channel based on turn, send the payload to it and update
`self._next_turn`.
Needs to acquire `self._send_lock` to make sure updates to the turn
variable don't happen out of order.
'''
async with self._send_lock:
# wait at least one decoder connected
if len(self.channels) == 0:
await self._chanmngr.wait_for_channel()
if self._next_turn >= len(self.channels):
self._next_turn = 0
info = self.channels[self._next_turn]
await info.channel.schan.send(msg)
self._next_turn += 1
async def aclose(self) -> None:
await self._chanmngr.aclose()
@acm
async def open_ringbuf_publisher(
buf_size: int = 10 * 1024,
batch_size: int = 1
):
batch_size: int = 1,
guarantee_order: bool = False,
force_cancel: bool = False
) -> AsyncContextManager[RingBufferPublisher]:
'''
Open a new ringbuf publisher
'''
async with (
trio.open_nursery() as n,
RingBuffPublisher(
RingBufferPublisher(
n,
buf_size=buf_size,
batch_size=batch_size
) as outputs
) as publisher
):
yield outputs
if guarantee_order:
order_send_channel(publisher)
yield publisher
if force_cancel:
# implicitly cancel any running channel handler task
n.cancel_scope.cancel()
class RingBuffSubscriber(
ChannelManager[RingBuffBytesReceiver]
):
class RingBufferSubscriber(trio.abc.ReceiveChannel[bytes]):
'''
Implement ChannelManager protocol + trio.abc.ReceiveChannel[bytes]
using ring buffers as transport.
Use ChannelManager to create a multi-ringbuf receiver that can dynamically
add or remove inputs and combine them all into a single output.
- use a trio memory channel pair to multiplex all received messages into a
single `trio.MemoryReceiveChannel`, give a sender channel clone to each
_channel_task.
In order for `self.receive` messages to be returned in order, the publisher
will send all payloads as `OrderedPayload` msgpack-encoded msgs; this allows
our channel handler tasks to just stash the out-of-order payloads inside
`self._pending_payloads` and, if an in-order payload is available, signal
through `self._new_payload_event`.
On `self.receive` we wait until at least one channel is connected, then if an
in-order payload is pending, we pop and return it; in case no in-order payload
is available, wait until the next `self._new_payload_event.set()`.
'''
def __init__(
self,
n: trio.Nursery,
# if connecting to a publisher that has already sent messages set
# to the next expected payload index this subscriber will receive
start_index: int = 0
):
super().__init__(n)
self._send_chan, self._recv_chan = trio.open_memory_channel(0)
self._chanmngr = ChannelManager[RingBufferReceiveChannel](
n,
self._open_channel,
self._channel_task
)
self._schan, self._rchan = trio.open_memory_channel(0)
@property
def channels(self) -> list[ChannelInfo]:
return self._chanmngr.channels
def get_channel(self, name: str):
return self._chanmngr[name]
async def add_channel(self, name: str, must_exist: bool = False):
'''
Add new input channel by name
'''
await self._chanmngr.add_channel(name, must_exist=must_exist)
async def remove_channel(self, name: str):
'''
Remove an input channel by name
'''
await self._chanmngr.remove_channel(name)
@acm
async def _open_channel(
self,
name: str
) -> AsyncContextManager[RingBuffBytesReceiver]:
name: str,
must_exist: bool = False
) -> AsyncContextManager[RingBufferReceiveChannel]:
'''
Open a ringbuf through `ringd` and attach as receiver side
'''
async with (
ringd.open_ringbuf(
name=name,
must_exist=True,
must_exist=must_exist,
) as token,
attach_to_ringbuf_rchannel(token) as chan
attach_to_ringbuf_receiver(token) as chan
):
yield chan
async def _channel_task(self, info: ChannelInfo) -> None:
send_chan = self._send_chan.clone()
try:
async for msg in info.channel:
await send_chan.send(msg)
'''
Iterate over receive channel messages, decode them as `OrderedPayload`s and
stash them in `self._pending_payloads`; in case we can pop the next in-order
payload, signal through setting `self._new_payload_event`.
except tractor._exceptions.InternalError:
# TODO: cleaner cancellation!
...
'''
while True:
try:
msg = await info.channel.receive()
await self._schan.send(msg)
except tractor.linux.eventfd.EFDReadCancelled as e:
# when channel gets removed while we are doing a receive
log.exception(e)
break
except trio.EndOfChannel:
break
async def receive(self) -> bytes:
return await self._recv_chan.receive()
'''
Receive the next in-order msg
'''
return await self._rchan.receive()
async def aclose(self) -> None:
await self._chanmngr.aclose()
@acm
async def open_ringbuf_subscriber():
async def open_ringbuf_subscriber(
guarantee_order: bool = False,
force_cancel: bool = False
) -> AsyncContextManager[RingBufferSubscriber]:
'''
Open a new ringbuf subscriber
'''
async with (
trio.open_nursery() as n,
RingBuffSubscriber(n) as inputs
RingBufferSubscriber(
n,
) as subscriber
):
yield inputs
if guarantee_order:
order_receive_channel(subscriber)
yield subscriber
if force_cancel:
# implicitly cancel any running channel handler task
n.cancel_scope.cancel()
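
# --- illustrative usage sketch, not part of this commit ---
# Rough picture of how the publisher & subscriber acms are meant to be combined
# across two `tractor` actors (see the test module for the full version); the
# ring name and actor layout here are made up, and a `ringd` actor must already
# be running.  The subscriber side allocates the ring first (must_exist=False),
# then the publisher attaches to it (must_exist=True).
#
# @tractor.context
# async def receiver(ctx: tractor.Context):
#     await ctx.started()
#     async with open_ringbuf_subscriber(guarantee_order=True) as subs:
#         await subs.add_channel('example-ring', must_exist=False)
#         for i in range(100):
#             assert int.from_bytes(await subs.receive()) == i
#
#
# @tractor.context
# async def sender(ctx: tractor.Context):
#     await ctx.started()
#     async with open_ringbuf_publisher(
#         batch_size=1,
#         guarantee_order=True,
#     ) as pub:
#         await pub.add_channel('example-ring', must_exist=True)
#         for i in range(100):
#             await pub.send(i.to_bytes(4))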


@@ -32,3 +32,8 @@ from ._broadcast import (
from ._beg import (
collapse_eg as collapse_eg,
)
from ._ordering import (
order_send_channel as order_send_channel,
order_receive_channel as order_receive_channel
)


@@ -0,0 +1,89 @@
from __future__ import annotations
from heapq import (
heappush,
heappop
)
import trio
import msgspec
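# wire format used by the channel orderer: each payload gets tagged with a
# monotonically increasing index so the receive side can reorder via a heap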
class OrderedPayload(msgspec.Struct, frozen=True):
index: int
payload: bytes
@classmethod
def from_msg(cls, msg: bytes) -> OrderedPayload:
return msgspec.msgpack.decode(msg, type=OrderedPayload)
def encode(self) -> bytes:
return msgspec.msgpack.encode(self)
def order_send_channel(
channel: trio.abc.SendChannel[bytes],
start_index: int = 0
):
next_index = start_index
send_lock = trio.StrictFIFOLock()
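    # patch the channel in-place: stash the original send/aclose then swap in
    # wrappers that tag every outgoing payload with the next expected index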
channel._send = channel.send
channel._aclose = channel.aclose
async def send(msg: bytes):
nonlocal next_index
async with send_lock:
await channel._send(
OrderedPayload(
index=next_index,
payload=msg
).encode()
)
next_index += 1
async def aclose():
async with send_lock:
await channel._aclose()
channel.send = send
channel.aclose = aclose
def order_receive_channel(
channel: trio.abc.ReceiveChannel[bytes],
start_index: int = 0
):
next_index = start_index
pqueue = []
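    # patch receive in-place: out-of-order payloads are parked in a min-heap
    # keyed by index and only popped once the next expected index is available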
channel._receive = channel.receive
def can_pop_next() -> bool:
return (
len(pqueue) > 0
and
pqueue[0][0] == next_index
)
async def drain_to_heap():
while not can_pop_next():
msg = await channel._receive()
msg = OrderedPayload.from_msg(msg)
heappush(pqueue, (msg.index, msg.payload))
def pop_next():
nonlocal next_index
_, msg = heappop(pqueue)
next_index += 1
return msg
async def receive() -> bytes:
if can_pop_next():
return pop_next()
await drain_to_heap()
return pop_next()
channel.receive = receive
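
# --- illustrative sketch, not part of this commit ---
# Minimal self-contained demo of the two wrappers above: `ScrambledChannel` is
# a made-up toy channel that hands buffered messages back in random order,
# standing in for payloads that arrive out of order over multiple ring buffers;
# after patching with `order_send_channel`/`order_receive_channel` the payloads
# still come out in send order.
import random


class ScrambledChannel:

    def __init__(self):
        self._buf: list[bytes] = []

    async def send(self, msg: bytes) -> None:
        self._buf.append(msg)

    async def receive(self) -> bytes:
        await trio.lowlevel.checkpoint()
        # pop a random buffered message to simulate out-of-order delivery
        return self._buf.pop(random.randrange(len(self._buf)))

    async def aclose(self) -> None:
        ...


async def _demo():
    chan = ScrambledChannel()
    order_send_channel(chan)      # tag each outgoing payload with an index
    order_receive_channel(chan)   # reorder incoming payloads via the heap

    for i in range(10):
        await chan.send(i.to_bytes(4, 'big'))

    for i in range(10):
        assert int.from_bytes(await chan.receive(), 'big') == i


if __name__ == '__main__':
    trio.run(_demo)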