Fully test and fix bugs on _ringbuf._pubsub
Add generic channel orderer
							parent
							
								
									0b9c2de3ad
								
							
						
					
					
						commit
						3c1873c68a
					
				|  | @ -92,187 +92,187 @@ def test_ringd(): | |||
|     trio.run(main) | ||||
| 
 | ||||
| 
 | ||||
| # class Struct(msgspec.Struct): | ||||
| #  | ||||
| #     def encode(self) -> bytes: | ||||
| #         return msgspec.msgpack.encode(self) | ||||
| #  | ||||
| #  | ||||
| # class AddChannelMsg(Struct, frozen=True, tag=True): | ||||
| #     name: str | ||||
| #  | ||||
| #  | ||||
| # class RemoveChannelMsg(Struct, frozen=True, tag=True): | ||||
| #     name: str | ||||
| #  | ||||
| #  | ||||
| # class RangeMsg(Struct, frozen=True, tag=True): | ||||
| #     start: int | ||||
| #     end: int | ||||
| #  | ||||
| #  | ||||
| # ControlMessages = AddChannelMsg | RemoveChannelMsg | RangeMsg | ||||
| #  | ||||
| #  | ||||
| # @tractor.context | ||||
| # async def subscriber_child(ctx: tractor.Context): | ||||
| #     await ctx.started() | ||||
| #     async with ( | ||||
| #         open_ringbuf_subscriber(guarantee_order=True) as subs, | ||||
| #         trio.open_nursery() as n, | ||||
| #         ctx.open_stream() as stream | ||||
| #     ): | ||||
| #         range_msg = None | ||||
| #         range_event = trio.Event() | ||||
| #         range_scope = trio.CancelScope() | ||||
| #  | ||||
| #         async def _control_listen_task(): | ||||
| #             nonlocal range_msg, range_event | ||||
| #             async for msg in stream: | ||||
| #                 msg = msgspec.msgpack.decode(msg, type=ControlMessages) | ||||
| #                 match msg: | ||||
| #                     case AddChannelMsg(): | ||||
| #                         await subs.add_channel(msg.name, must_exist=False) | ||||
| #  | ||||
| #                     case RemoveChannelMsg(): | ||||
| #                         await subs.remove_channel(msg.name) | ||||
| #  | ||||
| #                     case RangeMsg(): | ||||
| #                         range_msg = msg | ||||
| #                         range_event.set() | ||||
| #  | ||||
| #                 await stream.send(b'ack') | ||||
| #  | ||||
| #             range_scope.cancel() | ||||
| #  | ||||
| #         n.start_soon(_control_listen_task) | ||||
| #  | ||||
| #         with range_scope: | ||||
| #             while True: | ||||
| #                 await range_event.wait() | ||||
| #                 range_event = trio.Event() | ||||
| #                 for i in range(range_msg.start, range_msg.end): | ||||
| #                     recv = int.from_bytes(await subs.receive()) | ||||
| #                     # if recv != i: | ||||
| #                     #     raise AssertionError( | ||||
| #                     #         f'received: {recv} expected: {i}' | ||||
| #                     #     ) | ||||
| #  | ||||
| #                     log.info(f'received: {recv} expected: {i}') | ||||
| #  | ||||
| #                 await stream.send(b'valid range') | ||||
| #                 log.info('FINISHED RANGE') | ||||
| #  | ||||
| #     log.info('subscriber exit') | ||||
| #  | ||||
| #  | ||||
| # @tractor.context | ||||
| # async def publisher_child(ctx: tractor.Context): | ||||
| #     await ctx.started() | ||||
| #     async with ( | ||||
| #         open_ringbuf_publisher(batch_size=1, guarantee_order=True) as pub, | ||||
| #         ctx.open_stream() as stream | ||||
| #     ): | ||||
| #         abs_index = 0 | ||||
| #         async for msg in stream: | ||||
| #             msg = msgspec.msgpack.decode(msg, type=ControlMessages) | ||||
| #             match msg: | ||||
| #                 case AddChannelMsg(): | ||||
| #                     await pub.add_channel(msg.name, must_exist=True) | ||||
| #  | ||||
| #                 case RemoveChannelMsg(): | ||||
| #                     await pub.remove_channel(msg.name) | ||||
| #  | ||||
| #                 case RangeMsg(): | ||||
| #                     for i in range(msg.start, msg.end): | ||||
| #                         await pub.send(i.to_bytes(4)) | ||||
| #                         log.info(f'sent {i}, index: {abs_index}') | ||||
| #                         abs_index += 1 | ||||
| #  | ||||
| #             await stream.send(b'ack') | ||||
| #  | ||||
| #     log.info('publisher exit') | ||||
| #  | ||||
| #  | ||||
| #  | ||||
| # def test_pubsub(): | ||||
| #     ''' | ||||
| #     Spawn ringd actor and two childs that access same ringbuf through ringd. | ||||
| #  | ||||
| #     Both will use `ringd.open_ringbuf` to allocate the ringbuf, then attach to | ||||
| #     them as sender and receiver. | ||||
| #  | ||||
| #     ''' | ||||
| #     async def main(): | ||||
| #         async with ( | ||||
| #             tractor.open_nursery( | ||||
| #                 loglevel='info', | ||||
| #                 # debug_mode=True, | ||||
| #                 # enable_stack_on_sig=True | ||||
| #             ) as an, | ||||
| #  | ||||
| #             ringd.open_ringd() | ||||
| #         ): | ||||
| #             recv_portal = await an.start_actor( | ||||
| #                 'recv', | ||||
| #                 enable_modules=[__name__] | ||||
| class Struct(msgspec.Struct): | ||||
| 
 | ||||
|     def encode(self) -> bytes: | ||||
|         return msgspec.msgpack.encode(self) | ||||
| 
 | ||||
| 
 | ||||
| class AddChannelMsg(Struct, frozen=True, tag=True): | ||||
|     name: str | ||||
| 
 | ||||
| 
 | ||||
| class RemoveChannelMsg(Struct, frozen=True, tag=True): | ||||
|     name: str | ||||
| 
 | ||||
| 
 | ||||
| class RangeMsg(Struct, frozen=True, tag=True): | ||||
|     start: int | ||||
|     end: int | ||||
| 
 | ||||
| 
 | ||||
| ControlMessages = AddChannelMsg | RemoveChannelMsg | RangeMsg | ||||
| 
 | ||||
| 
 | ||||
| @tractor.context | ||||
| async def subscriber_child(ctx: tractor.Context): | ||||
|     await ctx.started() | ||||
|     async with ( | ||||
|         open_ringbuf_subscriber(guarantee_order=True) as subs, | ||||
|         trio.open_nursery() as n, | ||||
|         ctx.open_stream() as stream | ||||
|     ): | ||||
|         range_msg = None | ||||
|         range_event = trio.Event() | ||||
|         range_scope = trio.CancelScope() | ||||
| 
 | ||||
|         async def _control_listen_task(): | ||||
|             nonlocal range_msg, range_event | ||||
|             async for msg in stream: | ||||
|                 msg = msgspec.msgpack.decode(msg, type=ControlMessages) | ||||
|                 match msg: | ||||
|                     case AddChannelMsg(): | ||||
|                         await subs.add_channel(msg.name, must_exist=False) | ||||
| 
 | ||||
|                     case RemoveChannelMsg(): | ||||
|                         await subs.remove_channel(msg.name) | ||||
| 
 | ||||
|                     case RangeMsg(): | ||||
|                         range_msg = msg | ||||
|                         range_event.set() | ||||
| 
 | ||||
|                 await stream.send(b'ack') | ||||
| 
 | ||||
|             range_scope.cancel() | ||||
| 
 | ||||
|         n.start_soon(_control_listen_task) | ||||
| 
 | ||||
|         with range_scope: | ||||
|             while True: | ||||
|                 await range_event.wait() | ||||
|                 range_event = trio.Event() | ||||
|                 for i in range(range_msg.start, range_msg.end): | ||||
|                     recv = int.from_bytes(await subs.receive()) | ||||
|                     # if recv != i: | ||||
|                     #     raise AssertionError( | ||||
|                     #         f'received: {recv} expected: {i}' | ||||
|                     #     ) | ||||
| #             send_portal = await an.start_actor( | ||||
| #                 'send', | ||||
| #                 enable_modules=[__name__] | ||||
| #             ) | ||||
| #  | ||||
| #             async with ( | ||||
| #                 recv_portal.open_context(subscriber_child) as (rctx, _), | ||||
| #                 rctx.open_stream() as recv_stream, | ||||
| #                 send_portal.open_context(publisher_child) as (sctx, _), | ||||
| #                 sctx.open_stream() as send_stream, | ||||
| #             ): | ||||
| #                 async def send_wait_ack(msg: bytes): | ||||
| #                     await recv_stream.send(msg) | ||||
| #                     ack = await recv_stream.receive() | ||||
| #                     assert ack == b'ack' | ||||
| #  | ||||
| #                     await send_stream.send(msg) | ||||
| #                     ack = await send_stream.receive() | ||||
| #                     assert ack == b'ack' | ||||
| #  | ||||
| #                 async def add_channel(name: str): | ||||
| #                     await send_wait_ack(AddChannelMsg(name=name).encode()) | ||||
| #  | ||||
| #                 async def remove_channel(name: str): | ||||
| #                     await send_wait_ack(RemoveChannelMsg(name=name).encode()) | ||||
| #  | ||||
| #                 async def send_range(start: int, end: int): | ||||
| #                     await send_wait_ack(RangeMsg(start=start, end=end).encode()) | ||||
| #                     range_ack = await recv_stream.receive() | ||||
| #                     assert range_ack == b'valid range' | ||||
| #  | ||||
| #                 # simple test, open one channel and send 0..100 range | ||||
| #                 ring_name = 'ring-first' | ||||
| #                 await add_channel(ring_name) | ||||
| #                 await send_range(0, 100) | ||||
| #                 await remove_channel(ring_name) | ||||
| #  | ||||
| #                 # redo | ||||
| #                 ring_name = 'ring-redo' | ||||
| #                 await add_channel(ring_name) | ||||
| #                 await send_range(0, 100) | ||||
| #                 await remove_channel(ring_name) | ||||
| #  | ||||
| #                 # multi chan test | ||||
| #                 ring_names = [] | ||||
| #                 for i in range(3): | ||||
| #                     ring_names.append(f'multi-ring-{i}') | ||||
| #  | ||||
| #                 for name in ring_names: | ||||
| #                     await add_channel(name) | ||||
| #  | ||||
| #                 await send_range(0, 300) | ||||
| #  | ||||
| #                 for name in ring_names: | ||||
| #                     await remove_channel(name) | ||||
| #  | ||||
| #             await an.cancel() | ||||
| #  | ||||
| #     trio.run(main) | ||||
| 
 | ||||
|                     log.info(f'received: {recv} expected: {i}') | ||||
| 
 | ||||
|                 await stream.send(b'valid range') | ||||
|                 log.info('FINISHED RANGE') | ||||
| 
 | ||||
|     log.info('subscriber exit') | ||||
| 
 | ||||
| 
 | ||||
| @tractor.context | ||||
| async def publisher_child(ctx: tractor.Context): | ||||
|     await ctx.started() | ||||
|     async with ( | ||||
|         open_ringbuf_publisher(batch_size=1, guarantee_order=True) as pub, | ||||
|         ctx.open_stream() as stream | ||||
|     ): | ||||
|         abs_index = 0 | ||||
|         async for msg in stream: | ||||
|             msg = msgspec.msgpack.decode(msg, type=ControlMessages) | ||||
|             match msg: | ||||
|                 case AddChannelMsg(): | ||||
|                     await pub.add_channel(msg.name, must_exist=True) | ||||
| 
 | ||||
|                 case RemoveChannelMsg(): | ||||
|                     await pub.remove_channel(msg.name) | ||||
| 
 | ||||
|                 case RangeMsg(): | ||||
|                     for i in range(msg.start, msg.end): | ||||
|                         await pub.send(i.to_bytes(4)) | ||||
|                         log.info(f'sent {i}, index: {abs_index}') | ||||
|                         abs_index += 1 | ||||
| 
 | ||||
|             await stream.send(b'ack') | ||||
| 
 | ||||
|     log.info('publisher exit') | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
| def test_pubsub(): | ||||
|     ''' | ||||
|     Spawn ringd actor and two childs that access same ringbuf through ringd. | ||||
| 
 | ||||
|     Both will use `ringd.open_ringbuf` to allocate the ringbuf, then attach to | ||||
|     them as sender and receiver. | ||||
| 
 | ||||
|     ''' | ||||
|     async def main(): | ||||
|         async with ( | ||||
|             tractor.open_nursery( | ||||
|                 loglevel='info', | ||||
|                 # debug_mode=True, | ||||
|                 # enable_stack_on_sig=True | ||||
|             ) as an, | ||||
| 
 | ||||
|             ringd.open_ringd() | ||||
|         ): | ||||
|             recv_portal = await an.start_actor( | ||||
|                 'recv', | ||||
|                 enable_modules=[__name__] | ||||
|             ) | ||||
|             send_portal = await an.start_actor( | ||||
|                 'send', | ||||
|                 enable_modules=[__name__] | ||||
|             ) | ||||
| 
 | ||||
|             async with ( | ||||
|                 recv_portal.open_context(subscriber_child) as (rctx, _), | ||||
|                 rctx.open_stream() as recv_stream, | ||||
|                 send_portal.open_context(publisher_child) as (sctx, _), | ||||
|                 sctx.open_stream() as send_stream, | ||||
|             ): | ||||
|                 async def send_wait_ack(msg: bytes): | ||||
|                     await recv_stream.send(msg) | ||||
|                     ack = await recv_stream.receive() | ||||
|                     assert ack == b'ack' | ||||
| 
 | ||||
|                     await send_stream.send(msg) | ||||
|                     ack = await send_stream.receive() | ||||
|                     assert ack == b'ack' | ||||
| 
 | ||||
|                 async def add_channel(name: str): | ||||
|                     await send_wait_ack(AddChannelMsg(name=name).encode()) | ||||
| 
 | ||||
|                 async def remove_channel(name: str): | ||||
|                     await send_wait_ack(RemoveChannelMsg(name=name).encode()) | ||||
| 
 | ||||
|                 async def send_range(start: int, end: int): | ||||
|                     await send_wait_ack(RangeMsg(start=start, end=end).encode()) | ||||
|                     range_ack = await recv_stream.receive() | ||||
|                     assert range_ack == b'valid range' | ||||
| 
 | ||||
|                 # simple test, open one channel and send 0..100 range | ||||
|                 ring_name = 'ring-first' | ||||
|                 await add_channel(ring_name) | ||||
|                 await send_range(0, 100) | ||||
|                 await remove_channel(ring_name) | ||||
| 
 | ||||
|                 # redo | ||||
|                 ring_name = 'ring-redo' | ||||
|                 await add_channel(ring_name) | ||||
|                 await send_range(0, 100) | ||||
|                 await remove_channel(ring_name) | ||||
| 
 | ||||
|                 # multi chan test | ||||
|                 ring_names = [] | ||||
|                 for i in range(3): | ||||
|                     ring_names.append(f'multi-ring-{i}') | ||||
| 
 | ||||
|                 for name in ring_names: | ||||
|                     await add_channel(name) | ||||
| 
 | ||||
|                 await send_range(0, 300) | ||||
| 
 | ||||
|                 for name in ring_names: | ||||
|                     await remove_channel(name) | ||||
| 
 | ||||
|             await an.cancel() | ||||
| 
 | ||||
|     trio.run(main) | ||||
|  |  | |||
|  | @ -17,13 +17,14 @@ | |||
| Ring buffer ipc publish-subscribe mechanism brokered by ringd | ||||
| can dynamically add new outputs (publisher) or inputs (subscriber) | ||||
| ''' | ||||
| import time | ||||
| from typing import ( | ||||
|     runtime_checkable, | ||||
|     Protocol, | ||||
|     TypeVar, | ||||
|     Generic, | ||||
|     Callable, | ||||
|     Awaitable, | ||||
|     AsyncContextManager | ||||
| ) | ||||
| from functools import partial | ||||
| from contextlib import asynccontextmanager as acm | ||||
| from dataclasses import dataclass | ||||
| 
 | ||||
|  | @ -31,12 +32,16 @@ import trio | |||
| import tractor | ||||
| 
 | ||||
| from tractor.ipc import ( | ||||
|     RingBuffBytesSender, | ||||
|     RingBuffBytesReceiver, | ||||
|     attach_to_ringbuf_schannel, | ||||
|     attach_to_ringbuf_rchannel | ||||
|     RingBufferSendChannel, | ||||
|     RingBufferReceiveChannel, | ||||
|     attach_to_ringbuf_sender, | ||||
|     attach_to_ringbuf_receiver | ||||
| ) | ||||
| 
 | ||||
| from tractor.trionics import ( | ||||
|     order_send_channel, | ||||
|     order_receive_channel | ||||
| ) | ||||
| import tractor.ipc._ringbuf._ringd as ringd | ||||
| 
 | ||||
| 
 | ||||
|  | @ -48,66 +53,100 @@ ChannelType = TypeVar('ChannelType') | |||
| 
 | ||||
| @dataclass | ||||
| class ChannelInfo: | ||||
|     connect_time: float | ||||
|     name: str | ||||
|     channel: ChannelType | ||||
|     cancel_scope: trio.CancelScope | ||||
| 
 | ||||
| 
 | ||||
| # TODO: maybe move this abstraction to another module or standalone? | ||||
| # its not ring buf specific and allows fan out and fan in an a dynamic | ||||
| # amount of channels | ||||
| @runtime_checkable | ||||
| class ChannelManager(Protocol[ChannelType]): | ||||
| class ChannelManager(trio.abc.AsyncResource, Generic[ChannelType]): | ||||
|     ''' | ||||
|     Common data structures and methods pubsub classes use to manage channels & | ||||
|     their related handler background tasks, as well as cancellation of them. | ||||
|     Helper for managing channel resources and their handler tasks with | ||||
|     cancellation, add or remove channels dynamically! | ||||
| 
 | ||||
|     ''' | ||||
| 
 | ||||
|     def __init__( | ||||
|         self, | ||||
|         # nursery used to spawn channel handler tasks | ||||
|         n: trio.Nursery, | ||||
| 
 | ||||
|         # acm will be used for setup & teardown of channel resources | ||||
|         open_channel_acm: Callable[..., AsyncContextManager[ChannelType]], | ||||
| 
 | ||||
|         # long running bg task to handle channel | ||||
|         channel_task: Callable[..., Awaitable[None]] | ||||
|     ): | ||||
|         self._n = n | ||||
|         self._open_channel = open_channel_acm | ||||
|         self._channel_task = channel_task | ||||
| 
 | ||||
|         # signal when a new channel conects and we previously had none | ||||
|         self._connect_event = trio.Event() | ||||
| 
 | ||||
|         # store channel runtime variables | ||||
|         self._channels: list[ChannelInfo] = [] | ||||
| 
 | ||||
|     async def _open_channel( | ||||
|         # methods that modify self._channels should be ordered by FIFO | ||||
|         self._lock = trio.StrictFIFOLock() | ||||
| 
 | ||||
|     @acm | ||||
|     async def maybe_lock(self): | ||||
|         ''' | ||||
|         If lock is not held, acquire | ||||
| 
 | ||||
|         ''' | ||||
|         if self._lock.locked(): | ||||
|             yield | ||||
|             return | ||||
| 
 | ||||
|         async with self._lock: | ||||
|             yield | ||||
| 
 | ||||
|     @property | ||||
|     def channels(self) -> list[ChannelInfo]: | ||||
|         return self._channels | ||||
| 
 | ||||
|     async def _channel_handler_task( | ||||
|         self, | ||||
|         name: str | ||||
|     ) -> AsyncContextManager[ChannelType]: | ||||
|         name: str, | ||||
|         task_status: trio.TASK_STATUS_IGNORED, | ||||
|         **kwargs | ||||
|     ): | ||||
|         ''' | ||||
|         Used to instantiate channel resources given a name | ||||
|         Open channel resources, add to internal data structures, signal channel | ||||
|         connect through trio.Event, and run `channel_task` with cancel scope, | ||||
|         and finally, maybe remove channel from internal data structures. | ||||
| 
 | ||||
|         Spawned by `add_channel` function, lock is held from begining of fn | ||||
|         until `task_status.started()` call. | ||||
| 
 | ||||
|         kwargs are proxied to `self._open_channel` acm. | ||||
|         ''' | ||||
|         ... | ||||
| 
 | ||||
|     async def _channel_task(self, info: ChannelInfo) -> None: | ||||
|         ''' | ||||
|         Long running task that manages the channel | ||||
| 
 | ||||
|         ''' | ||||
|         ... | ||||
| 
 | ||||
|     async def _channel_handler_task(self, name: str): | ||||
|         async with self._open_channel(name) as chan: | ||||
|             with trio.CancelScope() as cancel_scope: | ||||
|         async with self._open_channel(name, **kwargs) as chan: | ||||
|             cancel_scope = trio.CancelScope() | ||||
|             info = ChannelInfo( | ||||
|                     connect_time=time.time(), | ||||
|                 name=name, | ||||
|                 channel=chan, | ||||
|                 cancel_scope=cancel_scope | ||||
|             ) | ||||
|             self._channels.append(info) | ||||
| 
 | ||||
|             if len(self) == 1: | ||||
|                 self._connect_event.set() | ||||
| 
 | ||||
|             task_status.started() | ||||
| 
 | ||||
|             with cancel_scope: | ||||
|                 await self._channel_task(info) | ||||
| 
 | ||||
|         self._maybe_destroy_channel(name) | ||||
|         await self._maybe_destroy_channel(name) | ||||
| 
 | ||||
|     def find_channel(self, name: str) -> tuple[int, ChannelInfo] | None: | ||||
|     def _find_channel(self, name: str) -> tuple[int, ChannelInfo] | None: | ||||
|         ''' | ||||
|         Given a channel name maybe return its index and value from | ||||
|         internal _channels list. | ||||
| 
 | ||||
|         Only use after acquiring lock. | ||||
|         ''' | ||||
|         for entry in enumerate(self._channels): | ||||
|             i, info = entry | ||||
|  | @ -116,105 +155,114 @@ class ChannelManager(Protocol[ChannelType]): | |||
| 
 | ||||
|         return None | ||||
| 
 | ||||
|     def _maybe_destroy_channel(self, name: str): | ||||
| 
 | ||||
|     async def _maybe_destroy_channel(self, name: str): | ||||
|         ''' | ||||
|         If channel exists cancel its scope and remove from internal | ||||
|         _channels list. | ||||
| 
 | ||||
|         ''' | ||||
|         maybe_entry = self.find_channel(name) | ||||
|         async with self.maybe_lock(): | ||||
|             maybe_entry = self._find_channel(name) | ||||
|             if maybe_entry: | ||||
|                 i, info = maybe_entry | ||||
|                 info.cancel_scope.cancel() | ||||
|                 del self._channels[i] | ||||
| 
 | ||||
|     def add_channel(self, name: str): | ||||
|     async def add_channel(self, name: str, **kwargs): | ||||
|         ''' | ||||
|         Add a new channel to be handled | ||||
| 
 | ||||
|         ''' | ||||
|         self._n.start_soon( | ||||
|         async with self.maybe_lock(): | ||||
|             await self._n.start(partial( | ||||
|                 self._channel_handler_task, | ||||
|             name | ||||
|         ) | ||||
|                 name, | ||||
|                 **kwargs | ||||
|             )) | ||||
| 
 | ||||
|     def remove_channel(self, name: str): | ||||
|     async def remove_channel(self, name: str): | ||||
|         ''' | ||||
|         Remove a channel and stop its handling | ||||
| 
 | ||||
|         ''' | ||||
|         self._maybe_destroy_channel(name) | ||||
|         async with self.maybe_lock(): | ||||
|             await self._maybe_destroy_channel(name) | ||||
| 
 | ||||
|             # if that was last channel reset connect event | ||||
|             if len(self) == 0: | ||||
|                 self._connect_event = trio.Event() | ||||
| 
 | ||||
|     async def wait_for_channel(self): | ||||
|         ''' | ||||
|         Wait until at least one channel added | ||||
| 
 | ||||
|         ''' | ||||
|         await self._connect_event.wait() | ||||
|         self._connect_event = trio.Event() | ||||
| 
 | ||||
|     def __len__(self) -> int: | ||||
|         return len(self._channels) | ||||
| 
 | ||||
|     def __getitem__(self, name: str): | ||||
|         maybe_entry = self._find_channel(name) | ||||
|         if maybe_entry: | ||||
|             _, info = maybe_entry | ||||
|             return info | ||||
| 
 | ||||
|         raise KeyError(f'Channel {name} not found!') | ||||
| 
 | ||||
|     async def aclose(self) -> None: | ||||
|         for chan in self._channels: | ||||
|             self._maybe_destroy_channel(chan.name) | ||||
| 
 | ||||
|     async def __aenter__(self): | ||||
|         return self | ||||
| 
 | ||||
|     async def __aexit__(self, exc_type, exc_val, exc_tb): | ||||
|         await self.aclose() | ||||
|         async with self.maybe_lock(): | ||||
|             for info in self._channels: | ||||
|                 await self.remove_channel(info.name) | ||||
| 
 | ||||
| 
 | ||||
| class RingBuffPublisher( | ||||
|     ChannelManager[RingBuffBytesSender] | ||||
| ): | ||||
| ''' | ||||
|     Implement ChannelManager protocol + trio.abc.SendChannel[bytes] | ||||
|     using ring buffers as transport. | ||||
| 
 | ||||
|     - use a `trio.Event` to make sure `send` blocks until at least one channel | ||||
|       available. | ||||
| Ring buffer publisher & subscribe pattern mediated by `ringd` actor. | ||||
| 
 | ||||
| ''' | ||||
| 
 | ||||
| @dataclass | ||||
| class PublisherChannels: | ||||
|     ring: RingBufferSendChannel | ||||
|     schan: trio.MemorySendChannel | ||||
|     rchan: trio.MemoryReceiveChannel | ||||
| 
 | ||||
| 
 | ||||
| class RingBufferPublisher(trio.abc.SendChannel[bytes]): | ||||
|     ''' | ||||
|     Use ChannelManager to create a multi ringbuf round robin sender that can | ||||
|     dynamically add or remove more outputs. | ||||
| 
 | ||||
|     Don't instantiate directly, use `open_ringbuf_publisher` acm to manage its | ||||
|     lifecycle. | ||||
| 
 | ||||
|     ''' | ||||
|     def __init__( | ||||
|         self, | ||||
|         n: trio.Nursery, | ||||
| 
 | ||||
|         # new ringbufs created will have this buf_size | ||||
|         buf_size: int = 10 * 1024, | ||||
| 
 | ||||
|         # global batch size for all channels | ||||
|         batch_size: int = 1 | ||||
|     ): | ||||
|         super().__init__(n) | ||||
|         self._connect_event = trio.Event() | ||||
|         self._next_turn: int = 0 | ||||
| 
 | ||||
|         self._buf_size = buf_size | ||||
|         self._batch_size: int = batch_size | ||||
| 
 | ||||
|     @acm | ||||
|     async def _open_channel( | ||||
|         self, | ||||
|         name: str | ||||
|     ) -> AsyncContextManager[RingBuffBytesSender]: | ||||
|         async with ( | ||||
|             ringd.open_ringbuf( | ||||
|                 name=name, | ||||
|                 must_exist=True, | ||||
|             ) as token, | ||||
|             attach_to_ringbuf_schannel(token) as chan | ||||
|         ): | ||||
|             yield chan | ||||
|         self._chanmngr = ChannelManager[PublisherChannels]( | ||||
|             n, | ||||
|             self._open_channel, | ||||
|             self._channel_task | ||||
|         ) | ||||
| 
 | ||||
|     async def _channel_task(self, info: ChannelInfo) -> None: | ||||
|         self._connect_event.set() | ||||
|         await trio.sleep_forever() | ||||
|         # methods that send data over the channels need to be acquire send lock | ||||
|         # in order to guarantee order of operations | ||||
|         self._send_lock = trio.StrictFIFOLock() | ||||
| 
 | ||||
|     async def send(self, msg: bytes): | ||||
|         # wait at least one decoder connected | ||||
|         if len(self) == 0: | ||||
|             await self._connect_event.wait() | ||||
|             self._connect_event = trio.Event() | ||||
| 
 | ||||
|         if self._next_turn >= len(self): | ||||
|             self._next_turn = 0 | ||||
| 
 | ||||
|         turn = self._next_turn | ||||
|         self._next_turn += 1 | ||||
| 
 | ||||
|         output = self._channels[turn] | ||||
|         await output.channel.send(msg) | ||||
|         self._next_turn: int = 0 | ||||
| 
 | ||||
|     @property | ||||
|     def batch_size(self) -> int: | ||||
|  | @ -222,92 +270,273 @@ class RingBuffPublisher( | |||
| 
 | ||||
|     @batch_size.setter | ||||
|     def set_batch_size(self, value: int) -> None: | ||||
|         for output in self._channels: | ||||
|             output.channel.batch_size = value | ||||
|         for info in self.channels: | ||||
|             info.channel.ring.batch_size = value | ||||
| 
 | ||||
|     async def flush( | ||||
|     @property | ||||
|     def channels(self) -> list[ChannelInfo]: | ||||
|         return self._chanmngr.channels | ||||
| 
 | ||||
|     def get_channel(self, name: str) -> ChannelInfo: | ||||
|         ''' | ||||
|         Get underlying ChannelInfo from name | ||||
| 
 | ||||
|         ''' | ||||
|         return self._chanmngr[name] | ||||
| 
 | ||||
|     async def add_channel( | ||||
|         self, | ||||
|         new_batch_size: int | None = None | ||||
|         name: str, | ||||
|         must_exist: bool = False | ||||
|     ): | ||||
|         for output in self._channels: | ||||
|             await output.channel.flush( | ||||
|                 new_batch_size=new_batch_size | ||||
|         ''' | ||||
|         Store additional runtime info for channel and add channel to underlying | ||||
|         ChannelManager | ||||
| 
 | ||||
|         ''' | ||||
|         await self._chanmngr.add_channel(name, must_exist=must_exist) | ||||
| 
 | ||||
|     async def remove_channel(self, name: str): | ||||
|         ''' | ||||
|         Send EOF to channel (does `channel.flush` also) then remove from | ||||
|         `ChannelManager` acquire both `self._send_lock` and | ||||
|         `self._chanmngr.maybe_lock()` in order to ensure no channel | ||||
|         modifications or sends happen concurrenty | ||||
|         ''' | ||||
|         async with self._chanmngr.maybe_lock(): | ||||
|             # ensure all pending messages are sent | ||||
|             info = self.get_channel(name) | ||||
| 
 | ||||
|             try: | ||||
|                 while True: | ||||
|                     msg = info.channel.rchan.receive_nowait() | ||||
|                     await info.channel.ring.send(msg) | ||||
| 
 | ||||
|             except trio.WouldBlock: | ||||
|                 await info.channel.ring.flush() | ||||
| 
 | ||||
|             await info.channel.schan.aclose() | ||||
| 
 | ||||
|             # finally remove from ChannelManager | ||||
|             await self._chanmngr.remove_channel(name) | ||||
| 
 | ||||
|     @acm | ||||
|     async def _open_channel( | ||||
| 
 | ||||
|         self, | ||||
|         name: str, | ||||
|         must_exist: bool = False | ||||
| 
 | ||||
|     ) -> AsyncContextManager[PublisherChannels]: | ||||
|         ''' | ||||
|         Open a ringbuf through `ringd` and attach as send side | ||||
|         ''' | ||||
|         async with ( | ||||
|             ringd.open_ringbuf( | ||||
|                 name=name, | ||||
|                 buf_size=self._buf_size, | ||||
|                 must_exist=must_exist, | ||||
|             ) as token, | ||||
|             attach_to_ringbuf_sender(token) as ring, | ||||
|         ): | ||||
|             schan, rchan = trio.open_memory_channel(0) | ||||
|             yield PublisherChannels( | ||||
|                 ring=ring, | ||||
|                 schan=schan, | ||||
|                 rchan=rchan | ||||
|             ) | ||||
| 
 | ||||
|     async def send_eof(self): | ||||
|         for output in self._channels: | ||||
|             await output.channel.send_eof() | ||||
|     async def _channel_task(self, info: ChannelInfo) -> None: | ||||
|         ''' | ||||
|         Forever get current runtime info for channel, wait on its next pending | ||||
|         payloads update event then drain all into send channel. | ||||
| 
 | ||||
|         ''' | ||||
|         async for msg in info.channel.rchan: | ||||
|             await info.channel.ring.send(msg) | ||||
| 
 | ||||
|     async def send(self, msg: bytes): | ||||
|         ''' | ||||
|         If no output channels connected, wait until one, then fetch the next | ||||
|         channel based on turn, add the indexed payload and update | ||||
|         `self._next_turn` & `self._next_index`. | ||||
| 
 | ||||
|         Needs to acquire `self._send_lock` to make sure updates to turn & index | ||||
|         variables dont happen out of order. | ||||
| 
 | ||||
|         ''' | ||||
|         async with self._send_lock: | ||||
|             # wait at least one decoder connected | ||||
|             if len(self.channels) == 0: | ||||
|                 await self._chanmngr.wait_for_channel() | ||||
| 
 | ||||
|             if self._next_turn >= len(self.channels): | ||||
|                 self._next_turn = 0 | ||||
| 
 | ||||
|             info = self.channels[self._next_turn] | ||||
|             await info.channel.schan.send(msg) | ||||
| 
 | ||||
|             self._next_turn += 1 | ||||
| 
 | ||||
|     async def aclose(self) -> None: | ||||
|         await self._chanmngr.aclose() | ||||
| 
 | ||||
| 
 | ||||
| @acm | ||||
| async def open_ringbuf_publisher( | ||||
| 
 | ||||
|     buf_size: int = 10 * 1024, | ||||
|     batch_size: int = 1 | ||||
| ): | ||||
|     batch_size: int = 1, | ||||
|     guarantee_order: bool = False, | ||||
|     force_cancel: bool = False | ||||
| 
 | ||||
| ) -> AsyncContextManager[RingBufferPublisher]: | ||||
|     ''' | ||||
|     Open a new ringbuf publisher | ||||
| 
 | ||||
|     ''' | ||||
|     async with ( | ||||
|         trio.open_nursery() as n, | ||||
|         RingBuffPublisher( | ||||
|         RingBufferPublisher( | ||||
|             n, | ||||
|             buf_size=buf_size, | ||||
|             batch_size=batch_size | ||||
|         ) as outputs | ||||
|         ) as publisher | ||||
|     ): | ||||
|         yield outputs | ||||
|         if guarantee_order: | ||||
|             order_send_channel(publisher) | ||||
| 
 | ||||
|         yield publisher | ||||
| 
 | ||||
|         if force_cancel: | ||||
|             # implicitly cancel any running channel handler task | ||||
|             n.cancel_scope.cancel() | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
| class RingBuffSubscriber( | ||||
|     ChannelManager[RingBuffBytesReceiver] | ||||
| ): | ||||
| class RingBufferSubscriber(trio.abc.ReceiveChannel[bytes]): | ||||
|     ''' | ||||
|     Implement ChannelManager protocol + trio.abc.ReceiveChannel[bytes] | ||||
|     using ring buffers as transport. | ||||
|     Use ChannelManager to create a multi ringbuf receiver that can | ||||
|     dynamically add or remove more inputs and combine all into a single output. | ||||
| 
 | ||||
|     - use a trio memory channel pair to multiplex all received messages into a | ||||
|       single `trio.MemoryReceiveChannel`, give a sender channel clone to each | ||||
|       _channel_task. | ||||
|     In order for `self.receive` messages to be returned in order, publisher | ||||
|     will send all payloads as `OrderedPayload` msgpack encoded msgs, this | ||||
|     allows our channel handler tasks to just stash the out of order payloads | ||||
|     inside `self._pending_payloads` and if a in order payload is available | ||||
|     signal through `self._new_payload_event`. | ||||
| 
 | ||||
|     On `self.receive` we wait until at least one channel is connected, then if | ||||
|     an in order payload is pending, we pop and return it, in case no in order | ||||
|     payload is available wait until next `self._new_payload_event.set()`. | ||||
| 
 | ||||
|     ''' | ||||
|     def __init__( | ||||
|         self, | ||||
|         n: trio.Nursery, | ||||
| 
 | ||||
|         # if connecting to a publisher that has already sent messages set  | ||||
|         # to the next expected payload index this subscriber will receive | ||||
|         start_index: int = 0 | ||||
|     ): | ||||
|         super().__init__(n) | ||||
|         self._send_chan, self._recv_chan = trio.open_memory_channel(0) | ||||
|         self._chanmngr = ChannelManager[RingBufferReceiveChannel]( | ||||
|             n, | ||||
|             self._open_channel, | ||||
|             self._channel_task | ||||
|         ) | ||||
| 
 | ||||
|         self._schan, self._rchan = trio.open_memory_channel(0) | ||||
| 
 | ||||
|     @property | ||||
|     def channels(self) -> list[ChannelInfo]: | ||||
|         return self._chanmngr.channels | ||||
| 
 | ||||
|     def get_channel(self, name: str): | ||||
|         return self._chanmngr[name] | ||||
| 
 | ||||
|     async def add_channel(self, name: str, must_exist: bool = False): | ||||
|         ''' | ||||
|         Add new input channel by name | ||||
| 
 | ||||
|         ''' | ||||
|         await self._chanmngr.add_channel(name, must_exist=must_exist) | ||||
| 
 | ||||
|     async def remove_channel(self, name: str): | ||||
|         ''' | ||||
|         Remove an input channel by name | ||||
| 
 | ||||
|         ''' | ||||
|         await self._chanmngr.remove_channel(name) | ||||
| 
 | ||||
|     @acm | ||||
|     async def _open_channel( | ||||
| 
 | ||||
|         self, | ||||
|         name: str | ||||
|     ) -> AsyncContextManager[RingBuffBytesReceiver]: | ||||
|         name: str, | ||||
|         must_exist: bool = False | ||||
| 
 | ||||
|     ) -> AsyncContextManager[RingBufferReceiveChannel]: | ||||
|         ''' | ||||
|         Open a ringbuf through `ringd` and attach as receiver side | ||||
|         ''' | ||||
|         async with ( | ||||
|             ringd.open_ringbuf( | ||||
|                 name=name, | ||||
|                 must_exist=True, | ||||
|                 must_exist=must_exist, | ||||
|             ) as token, | ||||
|             attach_to_ringbuf_rchannel(token) as chan | ||||
|             attach_to_ringbuf_receiver(token) as chan | ||||
|         ): | ||||
|             yield chan | ||||
| 
 | ||||
|     async def _channel_task(self, info: ChannelInfo) -> None: | ||||
|         send_chan = self._send_chan.clone() | ||||
|         try: | ||||
|             async for msg in info.channel: | ||||
|                 await send_chan.send(msg) | ||||
|         ''' | ||||
|         Iterate over receive channel messages, decode them as `OrderedPayload`s | ||||
|         and stash them in `self._pending_payloads`, in case we can pop next in | ||||
|         order payload, signal through setting `self._new_payload_event`. | ||||
| 
 | ||||
|         except tractor._exceptions.InternalError: | ||||
|             # TODO: cleaner cancellation! | ||||
|             ... | ||||
|         ''' | ||||
|         while True: | ||||
|             try: | ||||
|                 msg = await info.channel.receive() | ||||
|                 await self._schan.send(msg) | ||||
| 
 | ||||
|             except tractor.linux.eventfd.EFDReadCancelled as e: | ||||
|                 # when channel gets removed while we are doing a receive | ||||
|                 log.exception(e) | ||||
|                 break | ||||
| 
 | ||||
|             except trio.EndOfChannel: | ||||
|                 break | ||||
| 
 | ||||
|     async def receive(self) -> bytes: | ||||
|         return await self._recv_chan.receive() | ||||
|         ''' | ||||
|         Receive next in order msg | ||||
|         ''' | ||||
|         return await self._rchan.receive() | ||||
| 
 | ||||
|     async def aclose(self) -> None: | ||||
|         await self._chanmngr.aclose() | ||||
| 
 | ||||
| @acm | ||||
| async def open_ringbuf_subscriber(): | ||||
| async def open_ringbuf_subscriber( | ||||
| 
 | ||||
|     guarantee_order: bool = False, | ||||
|     force_cancel: bool = False | ||||
| 
 | ||||
| ) -> AsyncContextManager[RingBufferPublisher]: | ||||
|     ''' | ||||
|     Open a new ringbuf subscriber | ||||
| 
 | ||||
|     ''' | ||||
|     async with ( | ||||
|         trio.open_nursery() as n, | ||||
|         RingBuffSubscriber(n) as inputs | ||||
|         RingBufferSubscriber( | ||||
|             n, | ||||
|         ) as subscriber | ||||
|     ): | ||||
|         yield inputs | ||||
|         if guarantee_order: | ||||
|             order_receive_channel(subscriber) | ||||
| 
 | ||||
|         yield subscriber | ||||
| 
 | ||||
|         if force_cancel: | ||||
|             # implicitly cancel any running channel handler task | ||||
|             n.cancel_scope.cancel() | ||||
|  |  | |||
|  | @ -32,3 +32,8 @@ from ._broadcast import ( | |||
| from ._beg import ( | ||||
|     collapse_eg as collapse_eg, | ||||
| ) | ||||
| 
 | ||||
| from ._ordering import ( | ||||
|     order_send_channel as order_send_channel, | ||||
|     order_receive_channel as order_receive_channel | ||||
| ) | ||||
|  |  | |||
|  | @ -0,0 +1,89 @@ | |||
| from __future__ import annotations | ||||
| from heapq import ( | ||||
|     heappush, | ||||
|     heappop | ||||
| ) | ||||
| 
 | ||||
| import trio | ||||
| import msgspec | ||||
| 
 | ||||
| 
 | ||||
| class OrderedPayload(msgspec.Struct, frozen=True): | ||||
|     index: int | ||||
|     payload: bytes | ||||
| 
 | ||||
|     @classmethod | ||||
|     def from_msg(cls, msg: bytes) -> OrderedPayload: | ||||
|         return msgspec.msgpack.decode(msg, type=OrderedPayload) | ||||
| 
 | ||||
|     def encode(self) -> bytes: | ||||
|         return msgspec.msgpack.encode(self) | ||||
| 
 | ||||
| 
 | ||||
| def order_send_channel( | ||||
|     channel: trio.abc.SendChannel[bytes], | ||||
|     start_index: int = 0 | ||||
| ): | ||||
| 
 | ||||
|     next_index = start_index | ||||
|     send_lock = trio.StrictFIFOLock() | ||||
| 
 | ||||
|     channel._send = channel.send | ||||
|     channel._aclose = channel.aclose | ||||
| 
 | ||||
|     async def send(msg: bytes): | ||||
|         nonlocal next_index | ||||
|         async with send_lock: | ||||
|             await channel._send( | ||||
|                 OrderedPayload( | ||||
|                     index=next_index, | ||||
|                     payload=msg | ||||
|                 ).encode() | ||||
|             ) | ||||
|             next_index += 1 | ||||
| 
 | ||||
|     async def aclose(): | ||||
|         async with send_lock: | ||||
|             await channel._aclose() | ||||
| 
 | ||||
|     channel.send = send | ||||
|     channel.aclose = aclose | ||||
| 
 | ||||
| 
 | ||||
| def order_receive_channel( | ||||
|     channel: trio.abc.ReceiveChannel[bytes], | ||||
|     start_index: int = 0 | ||||
| ): | ||||
|     next_index = start_index | ||||
|     pqueue = [] | ||||
| 
 | ||||
|     channel._receive = channel.receive | ||||
| 
 | ||||
|     def can_pop_next() -> bool: | ||||
|         return ( | ||||
|             len(pqueue) > 0 | ||||
|             and | ||||
|             pqueue[0][0] == next_index | ||||
|         ) | ||||
| 
 | ||||
|     async def drain_to_heap(): | ||||
|         while not can_pop_next(): | ||||
|             msg = await channel._receive() | ||||
|             msg = OrderedPayload.from_msg(msg) | ||||
|             heappush(pqueue, (msg.index, msg.payload)) | ||||
| 
 | ||||
|     def pop_next(): | ||||
|         nonlocal next_index | ||||
|         _, msg = heappop(pqueue) | ||||
|         next_index += 1 | ||||
|         return msg | ||||
| 
 | ||||
|     async def receive() -> bytes: | ||||
|         if can_pop_next(): | ||||
|             return pop_next() | ||||
| 
 | ||||
|         await drain_to_heap() | ||||
| 
 | ||||
|         return pop_next() | ||||
| 
 | ||||
|     channel.receive = receive | ||||
		Loading…
	
		Reference in New Issue