Fully test and fix bugs on _ringbuf._pubsub
Add generic channel orderer
							parent
							
								
									0b9c2de3ad
								
							
						
					
					
						commit
						3c1873c68a
					
				|  | @ -92,187 +92,187 @@ def test_ringd(): | ||||||
|     trio.run(main) |     trio.run(main) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| # class Struct(msgspec.Struct): | class Struct(msgspec.Struct): | ||||||
| #  | 
 | ||||||
| #     def encode(self) -> bytes: |     def encode(self) -> bytes: | ||||||
| #         return msgspec.msgpack.encode(self) |         return msgspec.msgpack.encode(self) | ||||||
| #  | 
 | ||||||
| #  | 
 | ||||||
| # class AddChannelMsg(Struct, frozen=True, tag=True): | class AddChannelMsg(Struct, frozen=True, tag=True): | ||||||
| #     name: str |     name: str | ||||||
| #  | 
 | ||||||
| #  | 
 | ||||||
| # class RemoveChannelMsg(Struct, frozen=True, tag=True): | class RemoveChannelMsg(Struct, frozen=True, tag=True): | ||||||
| #     name: str |     name: str | ||||||
| #  | 
 | ||||||
| #  | 
 | ||||||
| # class RangeMsg(Struct, frozen=True, tag=True): | class RangeMsg(Struct, frozen=True, tag=True): | ||||||
| #     start: int |     start: int | ||||||
| #     end: int |     end: int | ||||||
| #  | 
 | ||||||
| #  | 
 | ||||||
| # ControlMessages = AddChannelMsg | RemoveChannelMsg | RangeMsg | ControlMessages = AddChannelMsg | RemoveChannelMsg | RangeMsg | ||||||
| #  | 
 | ||||||
| #  | 
 | ||||||
| # @tractor.context | @tractor.context | ||||||
| # async def subscriber_child(ctx: tractor.Context): | async def subscriber_child(ctx: tractor.Context): | ||||||
| #     await ctx.started() |     await ctx.started() | ||||||
| #     async with ( |     async with ( | ||||||
| #         open_ringbuf_subscriber(guarantee_order=True) as subs, |         open_ringbuf_subscriber(guarantee_order=True) as subs, | ||||||
| #         trio.open_nursery() as n, |         trio.open_nursery() as n, | ||||||
| #         ctx.open_stream() as stream |         ctx.open_stream() as stream | ||||||
| #     ): |     ): | ||||||
| #         range_msg = None |         range_msg = None | ||||||
| #         range_event = trio.Event() |         range_event = trio.Event() | ||||||
| #         range_scope = trio.CancelScope() |         range_scope = trio.CancelScope() | ||||||
| #  | 
 | ||||||
| #         async def _control_listen_task(): |         async def _control_listen_task(): | ||||||
| #             nonlocal range_msg, range_event |             nonlocal range_msg, range_event | ||||||
| #             async for msg in stream: |             async for msg in stream: | ||||||
| #                 msg = msgspec.msgpack.decode(msg, type=ControlMessages) |                 msg = msgspec.msgpack.decode(msg, type=ControlMessages) | ||||||
| #                 match msg: |                 match msg: | ||||||
| #                     case AddChannelMsg(): |                     case AddChannelMsg(): | ||||||
| #                         await subs.add_channel(msg.name, must_exist=False) |                         await subs.add_channel(msg.name, must_exist=False) | ||||||
| #  | 
 | ||||||
| #                     case RemoveChannelMsg(): |                     case RemoveChannelMsg(): | ||||||
| #                         await subs.remove_channel(msg.name) |                         await subs.remove_channel(msg.name) | ||||||
| #  | 
 | ||||||
| #                     case RangeMsg(): |                     case RangeMsg(): | ||||||
| #                         range_msg = msg |                         range_msg = msg | ||||||
| #                         range_event.set() |                         range_event.set() | ||||||
| #  | 
 | ||||||
| #                 await stream.send(b'ack') |                 await stream.send(b'ack') | ||||||
| #  | 
 | ||||||
| #             range_scope.cancel() |             range_scope.cancel() | ||||||
| #  | 
 | ||||||
| #         n.start_soon(_control_listen_task) |         n.start_soon(_control_listen_task) | ||||||
| #  | 
 | ||||||
| #         with range_scope: |         with range_scope: | ||||||
| #             while True: |             while True: | ||||||
| #                 await range_event.wait() |                 await range_event.wait() | ||||||
| #                 range_event = trio.Event() |                 range_event = trio.Event() | ||||||
| #                 for i in range(range_msg.start, range_msg.end): |                 for i in range(range_msg.start, range_msg.end): | ||||||
| #                     recv = int.from_bytes(await subs.receive()) |                     recv = int.from_bytes(await subs.receive()) | ||||||
| #                     # if recv != i: |                     # if recv != i: | ||||||
| #                     #     raise AssertionError( |                     #     raise AssertionError( | ||||||
| #                     #         f'received: {recv} expected: {i}' |                     #         f'received: {recv} expected: {i}' | ||||||
| #                     #     ) |                     #     ) | ||||||
| #  | 
 | ||||||
| #                     log.info(f'received: {recv} expected: {i}') |                     log.info(f'received: {recv} expected: {i}') | ||||||
| #  | 
 | ||||||
| #                 await stream.send(b'valid range') |                 await stream.send(b'valid range') | ||||||
| #                 log.info('FINISHED RANGE') |                 log.info('FINISHED RANGE') | ||||||
| #  | 
 | ||||||
| #     log.info('subscriber exit') |     log.info('subscriber exit') | ||||||
| #  | 
 | ||||||
| #  | 
 | ||||||
| # @tractor.context | @tractor.context | ||||||
| # async def publisher_child(ctx: tractor.Context): | async def publisher_child(ctx: tractor.Context): | ||||||
| #     await ctx.started() |     await ctx.started() | ||||||
| #     async with ( |     async with ( | ||||||
| #         open_ringbuf_publisher(batch_size=1, guarantee_order=True) as pub, |         open_ringbuf_publisher(batch_size=1, guarantee_order=True) as pub, | ||||||
| #         ctx.open_stream() as stream |         ctx.open_stream() as stream | ||||||
| #     ): |     ): | ||||||
| #         abs_index = 0 |         abs_index = 0 | ||||||
| #         async for msg in stream: |         async for msg in stream: | ||||||
| #             msg = msgspec.msgpack.decode(msg, type=ControlMessages) |             msg = msgspec.msgpack.decode(msg, type=ControlMessages) | ||||||
| #             match msg: |             match msg: | ||||||
| #                 case AddChannelMsg(): |                 case AddChannelMsg(): | ||||||
| #                     await pub.add_channel(msg.name, must_exist=True) |                     await pub.add_channel(msg.name, must_exist=True) | ||||||
| #  | 
 | ||||||
| #                 case RemoveChannelMsg(): |                 case RemoveChannelMsg(): | ||||||
| #                     await pub.remove_channel(msg.name) |                     await pub.remove_channel(msg.name) | ||||||
| #  | 
 | ||||||
| #                 case RangeMsg(): |                 case RangeMsg(): | ||||||
| #                     for i in range(msg.start, msg.end): |                     for i in range(msg.start, msg.end): | ||||||
| #                         await pub.send(i.to_bytes(4)) |                         await pub.send(i.to_bytes(4)) | ||||||
| #                         log.info(f'sent {i}, index: {abs_index}') |                         log.info(f'sent {i}, index: {abs_index}') | ||||||
| #                         abs_index += 1 |                         abs_index += 1 | ||||||
| #  | 
 | ||||||
| #             await stream.send(b'ack') |             await stream.send(b'ack') | ||||||
| #  | 
 | ||||||
| #     log.info('publisher exit') |     log.info('publisher exit') | ||||||
| #  | 
 | ||||||
| #  | 
 | ||||||
| #  | 
 | ||||||
| # def test_pubsub(): | def test_pubsub(): | ||||||
| #     ''' |     ''' | ||||||
| #     Spawn ringd actor and two childs that access same ringbuf through ringd. |     Spawn ringd actor and two childs that access same ringbuf through ringd. | ||||||
| #  | 
 | ||||||
| #     Both will use `ringd.open_ringbuf` to allocate the ringbuf, then attach to |     Both will use `ringd.open_ringbuf` to allocate the ringbuf, then attach to | ||||||
| #     them as sender and receiver. |     them as sender and receiver. | ||||||
| #  | 
 | ||||||
| #     ''' |     ''' | ||||||
| #     async def main(): |     async def main(): | ||||||
| #         async with ( |         async with ( | ||||||
| #             tractor.open_nursery( |             tractor.open_nursery( | ||||||
| #                 loglevel='info', |                 loglevel='info', | ||||||
| #                 # debug_mode=True, |                 # debug_mode=True, | ||||||
| #                 # enable_stack_on_sig=True |                 # enable_stack_on_sig=True | ||||||
| #             ) as an, |             ) as an, | ||||||
| #  | 
 | ||||||
| #             ringd.open_ringd() |             ringd.open_ringd() | ||||||
| #         ): |         ): | ||||||
| #             recv_portal = await an.start_actor( |             recv_portal = await an.start_actor( | ||||||
| #                 'recv', |                 'recv', | ||||||
| #                 enable_modules=[__name__] |                 enable_modules=[__name__] | ||||||
| #             ) |             ) | ||||||
| #             send_portal = await an.start_actor( |             send_portal = await an.start_actor( | ||||||
| #                 'send', |                 'send', | ||||||
| #                 enable_modules=[__name__] |                 enable_modules=[__name__] | ||||||
| #             ) |             ) | ||||||
| #  | 
 | ||||||
| #             async with ( |             async with ( | ||||||
| #                 recv_portal.open_context(subscriber_child) as (rctx, _), |                 recv_portal.open_context(subscriber_child) as (rctx, _), | ||||||
| #                 rctx.open_stream() as recv_stream, |                 rctx.open_stream() as recv_stream, | ||||||
| #                 send_portal.open_context(publisher_child) as (sctx, _), |                 send_portal.open_context(publisher_child) as (sctx, _), | ||||||
| #                 sctx.open_stream() as send_stream, |                 sctx.open_stream() as send_stream, | ||||||
| #             ): |             ): | ||||||
| #                 async def send_wait_ack(msg: bytes): |                 async def send_wait_ack(msg: bytes): | ||||||
| #                     await recv_stream.send(msg) |                     await recv_stream.send(msg) | ||||||
| #                     ack = await recv_stream.receive() |                     ack = await recv_stream.receive() | ||||||
| #                     assert ack == b'ack' |                     assert ack == b'ack' | ||||||
| #  | 
 | ||||||
| #                     await send_stream.send(msg) |                     await send_stream.send(msg) | ||||||
| #                     ack = await send_stream.receive() |                     ack = await send_stream.receive() | ||||||
| #                     assert ack == b'ack' |                     assert ack == b'ack' | ||||||
| #  | 
 | ||||||
| #                 async def add_channel(name: str): |                 async def add_channel(name: str): | ||||||
| #                     await send_wait_ack(AddChannelMsg(name=name).encode()) |                     await send_wait_ack(AddChannelMsg(name=name).encode()) | ||||||
| #  | 
 | ||||||
| #                 async def remove_channel(name: str): |                 async def remove_channel(name: str): | ||||||
| #                     await send_wait_ack(RemoveChannelMsg(name=name).encode()) |                     await send_wait_ack(RemoveChannelMsg(name=name).encode()) | ||||||
| #  | 
 | ||||||
| #                 async def send_range(start: int, end: int): |                 async def send_range(start: int, end: int): | ||||||
| #                     await send_wait_ack(RangeMsg(start=start, end=end).encode()) |                     await send_wait_ack(RangeMsg(start=start, end=end).encode()) | ||||||
| #                     range_ack = await recv_stream.receive() |                     range_ack = await recv_stream.receive() | ||||||
| #                     assert range_ack == b'valid range' |                     assert range_ack == b'valid range' | ||||||
| #  | 
 | ||||||
| #                 # simple test, open one channel and send 0..100 range |                 # simple test, open one channel and send 0..100 range | ||||||
| #                 ring_name = 'ring-first' |                 ring_name = 'ring-first' | ||||||
| #                 await add_channel(ring_name) |                 await add_channel(ring_name) | ||||||
| #                 await send_range(0, 100) |                 await send_range(0, 100) | ||||||
| #                 await remove_channel(ring_name) |                 await remove_channel(ring_name) | ||||||
| #  | 
 | ||||||
| #                 # redo |                 # redo | ||||||
| #                 ring_name = 'ring-redo' |                 ring_name = 'ring-redo' | ||||||
| #                 await add_channel(ring_name) |                 await add_channel(ring_name) | ||||||
| #                 await send_range(0, 100) |                 await send_range(0, 100) | ||||||
| #                 await remove_channel(ring_name) |                 await remove_channel(ring_name) | ||||||
| #  | 
 | ||||||
| #                 # multi chan test |                 # multi chan test | ||||||
| #                 ring_names = [] |                 ring_names = [] | ||||||
| #                 for i in range(3): |                 for i in range(3): | ||||||
| #                     ring_names.append(f'multi-ring-{i}') |                     ring_names.append(f'multi-ring-{i}') | ||||||
| #  | 
 | ||||||
| #                 for name in ring_names: |                 for name in ring_names: | ||||||
| #                     await add_channel(name) |                     await add_channel(name) | ||||||
| #  | 
 | ||||||
| #                 await send_range(0, 300) |                 await send_range(0, 300) | ||||||
| #  | 
 | ||||||
| #                 for name in ring_names: |                 for name in ring_names: | ||||||
| #                     await remove_channel(name) |                     await remove_channel(name) | ||||||
| #  | 
 | ||||||
| #             await an.cancel() |             await an.cancel() | ||||||
| #  | 
 | ||||||
| #     trio.run(main) |     trio.run(main) | ||||||
|  |  | ||||||
|  | @ -17,13 +17,14 @@ | ||||||
| Ring buffer ipc publish-subscribe mechanism brokered by ringd | Ring buffer ipc publish-subscribe mechanism brokered by ringd | ||||||
| can dynamically add new outputs (publisher) or inputs (subscriber) | can dynamically add new outputs (publisher) or inputs (subscriber) | ||||||
| ''' | ''' | ||||||
| import time |  | ||||||
| from typing import ( | from typing import ( | ||||||
|     runtime_checkable, |  | ||||||
|     Protocol, |  | ||||||
|     TypeVar, |     TypeVar, | ||||||
|  |     Generic, | ||||||
|  |     Callable, | ||||||
|  |     Awaitable, | ||||||
|     AsyncContextManager |     AsyncContextManager | ||||||
| ) | ) | ||||||
|  | from functools import partial | ||||||
| from contextlib import asynccontextmanager as acm | from contextlib import asynccontextmanager as acm | ||||||
| from dataclasses import dataclass | from dataclasses import dataclass | ||||||
| 
 | 
 | ||||||
|  | @ -31,12 +32,16 @@ import trio | ||||||
| import tractor | import tractor | ||||||
| 
 | 
 | ||||||
| from tractor.ipc import ( | from tractor.ipc import ( | ||||||
|     RingBuffBytesSender, |     RingBufferSendChannel, | ||||||
|     RingBuffBytesReceiver, |     RingBufferReceiveChannel, | ||||||
|     attach_to_ringbuf_schannel, |     attach_to_ringbuf_sender, | ||||||
|     attach_to_ringbuf_rchannel |     attach_to_ringbuf_receiver | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
|  | from tractor.trionics import ( | ||||||
|  |     order_send_channel, | ||||||
|  |     order_receive_channel | ||||||
|  | ) | ||||||
| import tractor.ipc._ringbuf._ringd as ringd | import tractor.ipc._ringbuf._ringd as ringd | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | @ -48,66 +53,100 @@ ChannelType = TypeVar('ChannelType') | ||||||
| 
 | 
 | ||||||
| @dataclass | @dataclass | ||||||
| class ChannelInfo: | class ChannelInfo: | ||||||
|     connect_time: float |  | ||||||
|     name: str |     name: str | ||||||
|     channel: ChannelType |     channel: ChannelType | ||||||
|     cancel_scope: trio.CancelScope |     cancel_scope: trio.CancelScope | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| # TODO: maybe move this abstraction to another module or standalone? | class ChannelManager(trio.abc.AsyncResource, Generic[ChannelType]): | ||||||
| # its not ring buf specific and allows fan out and fan in an a dynamic |  | ||||||
| # amount of channels |  | ||||||
| @runtime_checkable |  | ||||||
| class ChannelManager(Protocol[ChannelType]): |  | ||||||
|     ''' |     ''' | ||||||
|     Common data structures and methods pubsub classes use to manage channels & |     Helper for managing channel resources and their handler tasks with | ||||||
|     their related handler background tasks, as well as cancellation of them. |     cancellation, add or remove channels dynamically! | ||||||
| 
 | 
 | ||||||
|     ''' |     ''' | ||||||
| 
 | 
 | ||||||
|     def __init__( |     def __init__( | ||||||
|         self, |         self, | ||||||
|  |         # nursery used to spawn channel handler tasks | ||||||
|         n: trio.Nursery, |         n: trio.Nursery, | ||||||
|  | 
 | ||||||
|  |         # acm will be used for setup & teardown of channel resources | ||||||
|  |         open_channel_acm: Callable[..., AsyncContextManager[ChannelType]], | ||||||
|  | 
 | ||||||
|  |         # long running bg task to handle channel | ||||||
|  |         channel_task: Callable[..., Awaitable[None]] | ||||||
|     ): |     ): | ||||||
|         self._n = n |         self._n = n | ||||||
|  |         self._open_channel = open_channel_acm | ||||||
|  |         self._channel_task = channel_task | ||||||
|  | 
 | ||||||
|  |         # signal when a new channel connects and we previously had none | ||||||
|  |         self._connect_event = trio.Event() | ||||||
|  | 
 | ||||||
|  |         # store channel runtime variables | ||||||
|         self._channels: list[ChannelInfo] = [] |         self._channels: list[ChannelInfo] = [] | ||||||
| 
 | 
 | ||||||
|     async def _open_channel( |         # methods that modify self._channels should be ordered by FIFO | ||||||
|  |         self._lock = trio.StrictFIFOLock() | ||||||
|  | 
 | ||||||
|  |     @acm | ||||||
|  |     async def maybe_lock(self): | ||||||
|  |         ''' | ||||||
|  |         If lock is not held, acquire | ||||||
|  | 
 | ||||||
|  |         ''' | ||||||
|  |         if self._lock.locked(): | ||||||
|  |             yield | ||||||
|  |             return | ||||||
|  | 
 | ||||||
|  |         async with self._lock: | ||||||
|  |             yield | ||||||
|  | 
 | ||||||
|  |     @property | ||||||
|  |     def channels(self) -> list[ChannelInfo]: | ||||||
|  |         return self._channels | ||||||
|  | 
 | ||||||
|  |     async def _channel_handler_task( | ||||||
|         self, |         self, | ||||||
|         name: str |         name: str, | ||||||
|     ) -> AsyncContextManager[ChannelType]: |         task_status: trio.TASK_STATUS_IGNORED, | ||||||
|  |         **kwargs | ||||||
|  |     ): | ||||||
|         ''' |         ''' | ||||||
|         Used to instantiate channel resources given a name |         Open channel resources, add to internal data structures, signal channel | ||||||
|  |         connect through trio.Event, and run `channel_task` with cancel scope, | ||||||
|  |         and finally, maybe remove channel from internal data structures. | ||||||
| 
 | 
 | ||||||
|  |         Spawned by `add_channel` function, lock is held from beginning of fn | ||||||
|  |         until `task_status.started()` call. | ||||||
|  | 
 | ||||||
|  |         kwargs are proxied to `self._open_channel` acm. | ||||||
|         ''' |         ''' | ||||||
|         ... |         async with self._open_channel(name, **kwargs) as chan: | ||||||
| 
 |             cancel_scope = trio.CancelScope() | ||||||
|     async def _channel_task(self, info: ChannelInfo) -> None: |  | ||||||
|         ''' |  | ||||||
|         Long running task that manages the channel |  | ||||||
| 
 |  | ||||||
|         ''' |  | ||||||
|         ... |  | ||||||
| 
 |  | ||||||
|     async def _channel_handler_task(self, name: str): |  | ||||||
|         async with self._open_channel(name) as chan: |  | ||||||
|             with trio.CancelScope() as cancel_scope: |  | ||||||
|             info = ChannelInfo( |             info = ChannelInfo( | ||||||
|                     connect_time=time.time(), |  | ||||||
|                 name=name, |                 name=name, | ||||||
|                 channel=chan, |                 channel=chan, | ||||||
|                 cancel_scope=cancel_scope |                 cancel_scope=cancel_scope | ||||||
|             ) |             ) | ||||||
|             self._channels.append(info) |             self._channels.append(info) | ||||||
|  | 
 | ||||||
|  |             if len(self) == 1: | ||||||
|  |                 self._connect_event.set() | ||||||
|  | 
 | ||||||
|  |             task_status.started() | ||||||
|  | 
 | ||||||
|  |             with cancel_scope: | ||||||
|                 await self._channel_task(info) |                 await self._channel_task(info) | ||||||
| 
 | 
 | ||||||
|         self._maybe_destroy_channel(name) |         await self._maybe_destroy_channel(name) | ||||||
| 
 | 
 | ||||||
|     def find_channel(self, name: str) -> tuple[int, ChannelInfo] | None: |     def _find_channel(self, name: str) -> tuple[int, ChannelInfo] | None: | ||||||
|         ''' |         ''' | ||||||
|         Given a channel name maybe return its index and value from |         Given a channel name maybe return its index and value from | ||||||
|         internal _channels list. |         internal _channels list. | ||||||
| 
 | 
 | ||||||
|  |         Only use after acquiring lock. | ||||||
|         ''' |         ''' | ||||||
|         for entry in enumerate(self._channels): |         for entry in enumerate(self._channels): | ||||||
|             i, info = entry |             i, info = entry | ||||||
|  | @ -116,105 +155,114 @@ class ChannelManager(Protocol[ChannelType]): | ||||||
| 
 | 
 | ||||||
|         return None |         return None | ||||||
| 
 | 
 | ||||||
|     def _maybe_destroy_channel(self, name: str): | 
 | ||||||
|  |     async def _maybe_destroy_channel(self, name: str): | ||||||
|         ''' |         ''' | ||||||
|         If channel exists cancel its scope and remove from internal |         If channel exists cancel its scope and remove from internal | ||||||
|         _channels list. |         _channels list. | ||||||
| 
 | 
 | ||||||
|         ''' |         ''' | ||||||
|         maybe_entry = self.find_channel(name) |         async with self.maybe_lock(): | ||||||
|  |             maybe_entry = self._find_channel(name) | ||||||
|             if maybe_entry: |             if maybe_entry: | ||||||
|                 i, info = maybe_entry |                 i, info = maybe_entry | ||||||
|                 info.cancel_scope.cancel() |                 info.cancel_scope.cancel() | ||||||
|                 del self._channels[i] |                 del self._channels[i] | ||||||
| 
 | 
 | ||||||
|     def add_channel(self, name: str): |     async def add_channel(self, name: str, **kwargs): | ||||||
|         ''' |         ''' | ||||||
|         Add a new channel to be handled |         Add a new channel to be handled | ||||||
| 
 | 
 | ||||||
|         ''' |         ''' | ||||||
|         self._n.start_soon( |         async with self.maybe_lock(): | ||||||
|  |             await self._n.start(partial( | ||||||
|                 self._channel_handler_task, |                 self._channel_handler_task, | ||||||
|             name |                 name, | ||||||
|         ) |                 **kwargs | ||||||
|  |             )) | ||||||
| 
 | 
 | ||||||
|     def remove_channel(self, name: str): |     async def remove_channel(self, name: str): | ||||||
|         ''' |         ''' | ||||||
|         Remove a channel and stop its handling |         Remove a channel and stop its handling | ||||||
| 
 | 
 | ||||||
|         ''' |         ''' | ||||||
|         self._maybe_destroy_channel(name) |         async with self.maybe_lock(): | ||||||
|  |             await self._maybe_destroy_channel(name) | ||||||
|  | 
 | ||||||
|  |             # if that was last channel reset connect event | ||||||
|  |             if len(self) == 0: | ||||||
|  |                 self._connect_event = trio.Event() | ||||||
|  | 
 | ||||||
|  |     async def wait_for_channel(self): | ||||||
|  |         ''' | ||||||
|  |         Wait until at least one channel is added | ||||||
|  | 
 | ||||||
|  |         ''' | ||||||
|  |         await self._connect_event.wait() | ||||||
|  |         self._connect_event = trio.Event() | ||||||
| 
 | 
 | ||||||
|     def __len__(self) -> int: |     def __len__(self) -> int: | ||||||
|         return len(self._channels) |         return len(self._channels) | ||||||
| 
 | 
 | ||||||
|  |     def __getitem__(self, name: str): | ||||||
|  |         maybe_entry = self._find_channel(name) | ||||||
|  |         if maybe_entry: | ||||||
|  |             _, info = maybe_entry | ||||||
|  |             return info | ||||||
|  | 
 | ||||||
|  |         raise KeyError(f'Channel {name} not found!') | ||||||
|  | 
 | ||||||
|     async def aclose(self) -> None: |     async def aclose(self) -> None: | ||||||
|         for chan in self._channels: |         async with self.maybe_lock(): | ||||||
|             self._maybe_destroy_channel(chan.name) |             for info in self._channels: | ||||||
| 
 |                 await self.remove_channel(info.name) | ||||||
|     async def __aenter__(self): |  | ||||||
|         return self |  | ||||||
| 
 |  | ||||||
|     async def __aexit__(self, exc_type, exc_val, exc_tb): |  | ||||||
|         await self.aclose() |  | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class RingBuffPublisher( | ''' | ||||||
|     ChannelManager[RingBuffBytesSender] | Ring buffer publisher & subscribe pattern mediated by `ringd` actor. | ||||||
| ): | 
 | ||||||
|  | ''' | ||||||
|  | 
 | ||||||
|  | @dataclass | ||||||
|  | class PublisherChannels: | ||||||
|  |     ring: RingBufferSendChannel | ||||||
|  |     schan: trio.MemorySendChannel | ||||||
|  |     rchan: trio.MemoryReceiveChannel | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class RingBufferPublisher(trio.abc.SendChannel[bytes]): | ||||||
|     ''' |     ''' | ||||||
|     Implement ChannelManager protocol + trio.abc.SendChannel[bytes] |     Use ChannelManager to create a multi ringbuf round robin sender that can | ||||||
|     using ring buffers as transport. |     dynamically add or remove more outputs. | ||||||
| 
 | 
 | ||||||
|     - use a `trio.Event` to make sure `send` blocks until at least one channel |     Don't instantiate directly, use `open_ringbuf_publisher` acm to manage its | ||||||
|       available. |     lifecycle. | ||||||
| 
 | 
 | ||||||
|     ''' |     ''' | ||||||
| 
 |  | ||||||
|     def __init__( |     def __init__( | ||||||
|         self, |         self, | ||||||
|         n: trio.Nursery, |         n: trio.Nursery, | ||||||
|  | 
 | ||||||
|  |         # new ringbufs created will have this buf_size | ||||||
|         buf_size: int = 10 * 1024, |         buf_size: int = 10 * 1024, | ||||||
|  | 
 | ||||||
|  |         # global batch size for all channels | ||||||
|         batch_size: int = 1 |         batch_size: int = 1 | ||||||
|     ): |     ): | ||||||
|         super().__init__(n) |         self._buf_size = buf_size | ||||||
|         self._connect_event = trio.Event() |  | ||||||
|         self._next_turn: int = 0 |  | ||||||
| 
 |  | ||||||
|         self._batch_size: int = batch_size |         self._batch_size: int = batch_size | ||||||
| 
 | 
 | ||||||
|     @acm |         self._chanmngr = ChannelManager[PublisherChannels]( | ||||||
|     async def _open_channel( |             n, | ||||||
|         self, |             self._open_channel, | ||||||
|         name: str |             self._channel_task | ||||||
|     ) -> AsyncContextManager[RingBuffBytesSender]: |         ) | ||||||
|         async with ( |  | ||||||
|             ringd.open_ringbuf( |  | ||||||
|                 name=name, |  | ||||||
|                 must_exist=True, |  | ||||||
|             ) as token, |  | ||||||
|             attach_to_ringbuf_schannel(token) as chan |  | ||||||
|         ): |  | ||||||
|             yield chan |  | ||||||
| 
 | 
 | ||||||
|     async def _channel_task(self, info: ChannelInfo) -> None: |         # methods that send data over the channels need to acquire the send lock | ||||||
|         self._connect_event.set() |         # in order to guarantee order of operations | ||||||
|         await trio.sleep_forever() |         self._send_lock = trio.StrictFIFOLock() | ||||||
| 
 | 
 | ||||||
|     async def send(self, msg: bytes): |         self._next_turn: int = 0 | ||||||
|         # wait at least one decoder connected |  | ||||||
|         if len(self) == 0: |  | ||||||
|             await self._connect_event.wait() |  | ||||||
|             self._connect_event = trio.Event() |  | ||||||
| 
 |  | ||||||
|         if self._next_turn >= len(self): |  | ||||||
|             self._next_turn = 0 |  | ||||||
| 
 |  | ||||||
|         turn = self._next_turn |  | ||||||
|         self._next_turn += 1 |  | ||||||
| 
 |  | ||||||
|         output = self._channels[turn] |  | ||||||
|         await output.channel.send(msg) |  | ||||||
| 
 | 
 | ||||||
|     @property |     @property | ||||||
|     def batch_size(self) -> int: |     def batch_size(self) -> int: | ||||||
|  | @ -222,92 +270,273 @@ class RingBuffPublisher( | ||||||
| 
 | 
 | ||||||
|     @batch_size.setter |     @batch_size.setter | ||||||
|     def set_batch_size(self, value: int) -> None: |     def set_batch_size(self, value: int) -> None: | ||||||
|         for output in self._channels: |         for info in self.channels: | ||||||
|             output.channel.batch_size = value |             info.channel.ring.batch_size = value | ||||||
| 
 | 
 | ||||||
|     async def flush( |     @property | ||||||
|  |     def channels(self) -> list[ChannelInfo]: | ||||||
|  |         return self._chanmngr.channels | ||||||
|  | 
 | ||||||
|  |     def get_channel(self, name: str) -> ChannelInfo: | ||||||
|  |         ''' | ||||||
|  |         Get underlying ChannelInfo from name | ||||||
|  | 
 | ||||||
|  |         ''' | ||||||
|  |         return self._chanmngr[name] | ||||||
|  | 
 | ||||||
|  |     async def add_channel( | ||||||
|         self, |         self, | ||||||
|         new_batch_size: int | None = None |         name: str, | ||||||
|  |         must_exist: bool = False | ||||||
|     ): |     ): | ||||||
|         for output in self._channels: |         ''' | ||||||
|             await output.channel.flush( |         Store additional runtime info for channel and add channel to underlying | ||||||
|                 new_batch_size=new_batch_size |         ChannelManager | ||||||
|  | 
 | ||||||
|  |         ''' | ||||||
|  |         await self._chanmngr.add_channel(name, must_exist=must_exist) | ||||||
|  | 
 | ||||||
|  |     async def remove_channel(self, name: str): | ||||||
|  |         ''' | ||||||
|  |         Send EOF to channel (does `channel.flush` also), then remove from | ||||||
|  |         `ChannelManager`. Acquire both `self._send_lock` and | ||||||
|  |         `self._chanmngr.maybe_lock()` in order to ensure no channel | ||||||
|  |         modifications or sends happen concurrently | ||||||
|  |         ''' | ||||||
|  |         async with self._chanmngr.maybe_lock(): | ||||||
|  |             # ensure all pending messages are sent | ||||||
|  |             info = self.get_channel(name) | ||||||
|  | 
 | ||||||
|  |             try: | ||||||
|  |                 while True: | ||||||
|  |                     msg = info.channel.rchan.receive_nowait() | ||||||
|  |                     await info.channel.ring.send(msg) | ||||||
|  | 
 | ||||||
|  |             except trio.WouldBlock: | ||||||
|  |                 await info.channel.ring.flush() | ||||||
|  | 
 | ||||||
|  |             await info.channel.schan.aclose() | ||||||
|  | 
 | ||||||
|  |             # finally remove from ChannelManager | ||||||
|  |             await self._chanmngr.remove_channel(name) | ||||||
|  | 
 | ||||||
@acm
async def _open_channel(

    self,
    name: str,
    must_exist: bool = False

) -> AsyncContextManager[PublisherChannels]:
    '''
    Open the ringbuf called `name` through `ringd`, attach to it as the
    sending side and yield it bundled with a fresh zero-capacity
    in-memory channel pair.

    '''
    ringbuf_ctx = ringd.open_ringbuf(
        name=name,
        buf_size=self._buf_size,
        must_exist=must_exist,
    )
    async with ringbuf_ctx as token:
        async with attach_to_ringbuf_sender(token) as ring:
            schan, rchan = trio.open_memory_channel(0)
            channels = PublisherChannels(
                ring=ring,
                schan=schan,
                rchan=rchan
            )
            yield channels
| 
 | 
 | ||||||
async def _channel_task(self, info: ChannelInfo) -> None:
    '''
    Forward every message read from the channel's in-memory `rchan` to
    its ring sender, until iteration over `rchan` ends.

    '''
    inbox = info.channel.rchan
    async for payload in inbox:
        await info.channel.ring.send(payload)
|  | 
 | ||||||
async def send(self, msg: bytes):
    '''
    Hand `msg` to the next output channel, cycling through
    `self.channels` one message per turn.

    If no channel is registered yet, first wait on the channel manager
    for one.  Everything runs under `self._send_lock` so concurrent
    callers update `self._next_turn` one at a time.

    '''
    async with self._send_lock:
        # need at least one output before a turn can be picked
        if not self.channels:
            await self._chanmngr.wait_for_channel()

        # start over once the turn has run past the last channel
        turn = self._next_turn
        if turn >= len(self.channels):
            turn = self._next_turn = 0

        target = self.channels[turn]
        await target.channel.schan.send(msg)

        self._next_turn += 1
|  | 
 | ||||||
async def aclose(self) -> None:
    '''
    Close the underlying channel manager.

    '''
    manager = self._chanmngr
    await manager.aclose()
| 
 | 
 | ||||||
| 
 | 
 | ||||||
@acm
async def open_ringbuf_publisher(

    buf_size: int = 10 * 1024,
    batch_size: int = 1,
    guarantee_order: bool = False,
    force_cancel: bool = False

) -> AsyncContextManager[RingBufferPublisher]:
    '''
    Open a nursery plus a `RingBufferPublisher` running inside it and
    yield the publisher.

    - `buf_size` / `batch_size`: passed straight to the publisher.
    - `guarantee_order`: when set, the publisher is patched with
      `order_send_channel` before being yielded.
    - `force_cancel`: when set, the nursery's cancel scope is cancelled
      once the caller's block has finished without raising.

    '''
    async with trio.open_nursery() as n:
        async with RingBufferPublisher(
            n,
            buf_size=buf_size,
            batch_size=batch_size
        ) as publisher:

            if guarantee_order:
                order_send_channel(publisher)

            yield publisher

            if force_cancel:
                # implicitly cancel any running channel handler task
                n.cancel_scope.cancel()
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| 
 | class RingBufferSubscriber(trio.abc.ReceiveChannel[bytes]): | ||||||
| class RingBuffSubscriber( |  | ||||||
|     ChannelManager[RingBuffBytesReceiver] |  | ||||||
| ): |  | ||||||
|     ''' |     ''' | ||||||
|     Implement ChannelManager protocol + trio.abc.ReceiveChannel[bytes] |     Use ChannelManager to create a multi ringbuf receiver that can | ||||||
|     using ring buffers as transport. |     dynamically add or remove more inputs and combine all into a single output. | ||||||
| 
 | 
 | ||||||
|     - use a trio memory channel pair to multiplex all received messages into a |     In order for `self.receive` messages to be returned in order, publisher | ||||||
|       single `trio.MemoryReceiveChannel`, give a sender channel clone to each |     will send all payloads as `OrderedPayload` msgpack encoded msgs, this | ||||||
|       _channel_task. |     allows our channel handler tasks to just stash the out of order payloads | ||||||
|  |     inside `self._pending_payloads` and if a in order payload is available | ||||||
|  |     signal through `self._new_payload_event`. | ||||||
|  | 
 | ||||||
|  |     On `self.receive` we wait until at least one channel is connected, then if | ||||||
|  |     an in order payload is pending, we pop and return it, in case no in order | ||||||
|  |     payload is available wait until next `self._new_payload_event.set()`. | ||||||
| 
 | 
 | ||||||
|     ''' |     ''' | ||||||
|     def __init__( |     def __init__( | ||||||
|         self, |         self, | ||||||
|         n: trio.Nursery, |         n: trio.Nursery, | ||||||
|  | 
 | ||||||
|  |         # if connecting to a publisher that has already sent messages set  | ||||||
|  |         # to the next expected payload index this subscriber will receive | ||||||
|  |         start_index: int = 0 | ||||||
|     ): |     ): | ||||||
|         super().__init__(n) |         self._chanmngr = ChannelManager[RingBufferReceiveChannel]( | ||||||
|         self._send_chan, self._recv_chan = trio.open_memory_channel(0) |             n, | ||||||
|  |             self._open_channel, | ||||||
|  |             self._channel_task | ||||||
|  |         ) | ||||||
|  | 
 | ||||||
|  |         self._schan, self._rchan = trio.open_memory_channel(0) | ||||||
|  | 
 | ||||||
|  |     @property | ||||||
|  |     def channels(self) -> list[ChannelInfo]: | ||||||
|  |         return self._chanmngr.channels | ||||||
|  | 
 | ||||||
|  |     def get_channel(self, name: str): | ||||||
|  |         return self._chanmngr[name] | ||||||
|  | 
 | ||||||
|  |     async def add_channel(self, name: str, must_exist: bool = False): | ||||||
|  |         ''' | ||||||
|  |         Add new input channel by name | ||||||
|  | 
 | ||||||
|  |         ''' | ||||||
|  |         await self._chanmngr.add_channel(name, must_exist=must_exist) | ||||||
|  | 
 | ||||||
|  |     async def remove_channel(self, name: str): | ||||||
|  |         ''' | ||||||
|  |         Remove an input channel by name | ||||||
|  | 
 | ||||||
|  |         ''' | ||||||
|  |         await self._chanmngr.remove_channel(name) | ||||||
| 
 | 
 | ||||||
|     @acm |     @acm | ||||||
|     async def _open_channel( |     async def _open_channel( | ||||||
|  | 
 | ||||||
|         self, |         self, | ||||||
|         name: str |         name: str, | ||||||
|     ) -> AsyncContextManager[RingBuffBytesReceiver]: |         must_exist: bool = False | ||||||
|  | 
 | ||||||
|  |     ) -> AsyncContextManager[RingBufferReceiveChannel]: | ||||||
|  |         ''' | ||||||
|  |         Open a ringbuf through `ringd` and attach as receiver side | ||||||
|  |         ''' | ||||||
|         async with ( |         async with ( | ||||||
|             ringd.open_ringbuf( |             ringd.open_ringbuf( | ||||||
|                 name=name, |                 name=name, | ||||||
|                 must_exist=True, |                 must_exist=must_exist, | ||||||
|             ) as token, |             ) as token, | ||||||
|             attach_to_ringbuf_rchannel(token) as chan |             attach_to_ringbuf_receiver(token) as chan | ||||||
|         ): |         ): | ||||||
|             yield chan |             yield chan | ||||||
| 
 | 
 | ||||||
|     async def _channel_task(self, info: ChannelInfo) -> None: |     async def _channel_task(self, info: ChannelInfo) -> None: | ||||||
|         send_chan = self._send_chan.clone() |         ''' | ||||||
|         try: |         Iterate over receive channel messages, decode them as `OrderedPayload`s | ||||||
|             async for msg in info.channel: |         and stash them in `self._pending_payloads`, in case we can pop next in | ||||||
|                 await send_chan.send(msg) |         order payload, signal through setting `self._new_payload_event`. | ||||||
| 
 | 
 | ||||||
|         except tractor._exceptions.InternalError: |         ''' | ||||||
|             # TODO: cleaner cancellation! |         while True: | ||||||
|             ... |             try: | ||||||
|  |                 msg = await info.channel.receive() | ||||||
|  |                 await self._schan.send(msg) | ||||||
|  | 
 | ||||||
|  |             except tractor.linux.eventfd.EFDReadCancelled as e: | ||||||
|  |                 # when channel gets removed while we are doing a receive | ||||||
|  |                 log.exception(e) | ||||||
|  |                 break | ||||||
|  | 
 | ||||||
|  |             except trio.EndOfChannel: | ||||||
|  |                 break | ||||||
| 
 | 
 | ||||||
|     async def receive(self) -> bytes: |     async def receive(self) -> bytes: | ||||||
|         return await self._recv_chan.receive() |         ''' | ||||||
|  |         Receive next in order msg | ||||||
|  |         ''' | ||||||
|  |         return await self._rchan.receive() | ||||||
| 
 | 
 | ||||||
|  |     async def aclose(self) -> None: | ||||||
|  |         await self._chanmngr.aclose() | ||||||
| 
 | 
 | ||||||
@acm
async def open_ringbuf_subscriber(

    guarantee_order: bool = False,
    force_cancel: bool = False

# the context manager yields a subscriber, not a publisher
) -> AsyncContextManager[RingBufferSubscriber]:
    '''
    Open a nursery plus a `RingBufferSubscriber` running inside it and
    yield the subscriber.

    - `guarantee_order`: when set, the subscriber is patched with
      `order_receive_channel` before being yielded.
    - `force_cancel`: when set, the nursery's cancel scope is cancelled
      once the caller's block has finished without raising.

    '''
    async with (
        trio.open_nursery() as n,
        RingBufferSubscriber(
            n,
        ) as subscriber
    ):
        if guarantee_order:
            order_receive_channel(subscriber)

        yield subscriber

        if force_cancel:
            # implicitly cancel any running channel handler task
            n.cancel_scope.cancel()
|  |  | ||||||
|  | @ -32,3 +32,8 @@ from ._broadcast import ( | ||||||
| from ._beg import ( | from ._beg import ( | ||||||
|     collapse_eg as collapse_eg, |     collapse_eg as collapse_eg, | ||||||
| ) | ) | ||||||
|  | 
 | ||||||
|  | from ._ordering import ( | ||||||
|  |     order_send_channel as order_send_channel, | ||||||
|  |     order_receive_channel as order_receive_channel | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | @ -0,0 +1,89 @@ | ||||||
|  | from __future__ import annotations | ||||||
|  | from heapq import ( | ||||||
|  |     heappush, | ||||||
|  |     heappop | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | import trio | ||||||
|  | import msgspec | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
class OrderedPayload(msgspec.Struct, frozen=True):
    '''
    Envelope pairing a raw payload with the index it was sent under;
    encoded to and decoded from msgpack.

    '''
    # position of this payload in the sender's sequence
    index: int
    # the wrapped message bytes
    payload: bytes

    @classmethod
    def from_msg(cls, msg: bytes) -> OrderedPayload:
        '''
        Decode a msgpack encoded `msg` into an instance of `cls`.

        '''
        # decode as `cls` (not the hard-coded base) so subclasses
        # round-trip to their own type
        return msgspec.msgpack.decode(msg, type=cls)

    def encode(self) -> bytes:
        '''
        Return this struct encoded as msgpack bytes.

        '''
        return msgspec.msgpack.encode(self)
|  | 
 | ||||||
|  | 
 | ||||||
def order_send_channel(
    channel: trio.abc.SendChannel[bytes],
    start_index: int = 0
):
    '''
    Patch `channel` in place so every message given to `channel.send`
    is wrapped in an `OrderedPayload` whose index starts at
    `start_index` and grows by one per successfully sent message.

    `channel.aclose` is patched as well so that it takes the same lock
    as `send`.  The original bound methods remain reachable as
    `channel._send` / `channel._aclose`.  Returns `None`.

    '''
    next_index = start_index
    send_lock = trio.StrictFIFOLock()

    # Keep the originals in the closure instead of looking them up via
    # `channel._send` on every call: if this function is applied to the
    # same channel again, `_send` gets rebound to the wrapper below and
    # a dynamic lookup would make the wrapper call itself forever.
    orig_send = channel.send
    orig_aclose = channel.aclose

    # still exposed under the private names used previously
    channel._send = orig_send
    channel._aclose = orig_aclose

    async def send(msg: bytes):
        nonlocal next_index
        async with send_lock:
            await orig_send(
                OrderedPayload(
                    index=next_index,
                    payload=msg
                ).encode()
            )
            # only advance once the payload was actually handed over
            next_index += 1

    async def aclose():
        # serialize with `send` through the same lock
        async with send_lock:
            await orig_aclose()

    channel.send = send
    channel.aclose = aclose
|  | 
 | ||||||
|  | 
 | ||||||
def order_receive_channel(
    channel: trio.abc.ReceiveChannel[bytes],
    start_index: int = 0
):
    '''
    Patch `channel` in place so `channel.receive` returns payloads in
    index order, starting at `start_index`.

    Incoming messages are decoded as `OrderedPayload` and parked in a
    min-heap keyed by index until the next expected index shows up.
    Payloads whose index is lower than the next expected one can never
    be delivered in order and are discarded.  The original bound method
    remains reachable as `channel._receive`.  Returns `None`.

    '''
    next_index = start_index
    pqueue: list[tuple[int, bytes]] = []

    # Keep the original in the closure instead of looking it up via
    # `channel._receive` on every call: if this function is applied to
    # the same channel again, `_receive` gets rebound to the wrapper
    # below and a dynamic lookup would make the wrapper call itself.
    orig_receive = channel.receive
    channel._receive = orig_receive

    def discard_stale() -> None:
        # an entry below `next_index` (e.g. a duplicate of something
        # already delivered) would otherwise sit on top of the heap
        # forever and block every later payload
        while pqueue and pqueue[0][0] < next_index:
            heappop(pqueue)

    def can_pop_next() -> bool:
        return bool(pqueue) and pqueue[0][0] == next_index

    async def receive() -> bytes:
        nonlocal next_index

        discard_stale()

        # all state lives in the heap, so a cancellation while waiting
        # on the underlying channel loses nothing already read
        while not can_pop_next():
            ordered = OrderedPayload.from_msg(await orig_receive())
            if ordered.index < next_index:
                # sent before our starting point / already delivered
                continue

            heappush(pqueue, (ordered.index, ordered.payload))

        _, payload = heappop(pqueue)
        next_index += 1
        return payload

    channel.receive = receive
		Loading…
	
		Reference in New Issue