From a553446619eec994f4cbb1bf814c0b15132597d7 Mon Sep 17 00:00:00 2001 From: Guillermo Rodriguez Date: Tue, 22 Apr 2025 01:46:41 -0300 Subject: [PATCH] Pubsub topics, enc & decoders Implicit aclose on all channels on ChannelManager aclose Implicit nursery cancel on pubsub acms Use long running actor portal for open_{pub,sub}_channel_at fns Add optional encoder/decoder on pubsub Add topic system for multiple pub or sub on same actor Add wait fn for sub and pub channel register --- tractor/ipc/_ringbuf/_pubsub.py | 292 ++++++++++++++++++-------------- 1 file changed, 161 insertions(+), 131 deletions(-) diff --git a/tractor/ipc/_ringbuf/_pubsub.py b/tractor/ipc/_ringbuf/_pubsub.py index 6a33e42a..7de1d9b2 100644 --- a/tractor/ipc/_ringbuf/_pubsub.py +++ b/tractor/ipc/_ringbuf/_pubsub.py @@ -31,8 +31,14 @@ from dataclasses import dataclass import trio import tractor +from msgspec.msgpack import ( + Encoder, + Decoder +) + from tractor.ipc._ringbuf import ( RBToken, + PayloadT, RingBufferSendChannel, RingBufferReceiveChannel, attach_to_ringbuf_sender, @@ -242,6 +248,7 @@ class ChannelManager(Generic[ChannelType]): if info.channel.closed: continue + await info.channel.aclose() await self.remove_channel(info.token.shm_name) self._is_closed = True @@ -253,7 +260,7 @@ Ring buffer publisher & subscribe pattern mediated by `ringd` actor. ''' -class RingBufferPublisher(trio.abc.SendChannel[bytes]): +class RingBufferPublisher(trio.abc.SendChannel[PayloadT]): ''' Use ChannelManager to create a multi ringbuf round robin sender that can dynamically add or remove more outputs. @@ -270,13 +277,16 @@ class RingBufferPublisher(trio.abc.SendChannel[bytes]): msgs_per_turn: int = 1, # global batch size for all channels - batch_size: int = 1 + batch_size: int = 1, + + encoder: Encoder | None = None ): self._batch_size: int = batch_size self.msgs_per_turn = msgs_per_turn + self._enc = encoder # helper to manage acms + long running tasks - self._chanmngr = ChannelManager[RingBufferSendChannel]( + self._chanmngr = ChannelManager[RingBufferSendChannel[PayloadT]]( n, self._open_channel, self._channel_task @@ -349,10 +359,11 @@ class RingBufferPublisher(trio.abc.SendChannel[bytes]): self, token: RBToken - ) -> AsyncContextManager[RingBufferSendChannel]: + ) -> AsyncContextManager[RingBufferSendChannel[PayloadT]]: async with attach_to_ringbuf_sender( token, - batch_size=self._batch_size + batch_size=self._batch_size, + encoder=self._enc ) as ring: yield ring @@ -387,7 +398,7 @@ class RingBufferPublisher(trio.abc.SendChannel[bytes]): info = self.channels[turn] await info.channel.send(msg) - async def broadcast(self, msg: bytes): + async def broadcast(self, msg: PayloadT): ''' Send a msg to all channels, if no channels connected, does nothing. ''' @@ -406,8 +417,8 @@ class RingBufferPublisher(trio.abc.SendChannel[bytes]): ... async def __aenter__(self): - self._chanmngr.open() self._is_closed = False + self._chanmngr.open() return self async def aclose(self) -> None: @@ -420,7 +431,7 @@ class RingBufferPublisher(trio.abc.SendChannel[bytes]): self._is_closed = True -class RingBufferSubscriber(trio.abc.ReceiveChannel[bytes]): +class RingBufferSubscriber(trio.abc.ReceiveChannel[PayloadT]): ''' Use ChannelManager to create a multi ringbuf receiver that can dynamically add or remove more inputs and combine all into a single output. @@ -440,11 +451,10 @@ class RingBufferSubscriber(trio.abc.ReceiveChannel[bytes]): self, n: trio.Nursery, - # if connecting to a publisher that has already sent messages set - # to the next expected payload index this subscriber will receive - start_index: int = 0 + decoder: Decoder | None = None ): - self._chanmngr = ChannelManager[RingBufferReceiveChannel]( + self._dec = decoder + self._chanmngr = ChannelManager[RingBufferReceiveChannel[PayloadT]]( n, self._open_channel, self._channel_task @@ -483,7 +493,10 @@ class RingBufferSubscriber(trio.abc.ReceiveChannel[bytes]): token: RBToken ) -> AsyncContextManager[RingBufferSendChannel]: - async with attach_to_ringbuf_receiver(token) as ring: + async with attach_to_ringbuf_receiver( + token, + decoder=self._dec + ) as ring: yield ring async def _channel_task(self, info: ChannelInfo) -> None: @@ -509,7 +522,7 @@ class RingBufferSubscriber(trio.abc.ReceiveChannel[bytes]): except trio.ClosedResourceError: break - async def receive(self) -> bytes: + async def receive(self) -> PayloadT: ''' Receive next in order msg ''' @@ -543,73 +556,74 @@ Actor module for managing publisher & subscriber channels remotely through `tractor.context` rpc ''' -_publisher: RingBufferPublisher | None = None -_subscriber: RingBufferSubscriber | None = None +@dataclass +class PublisherEntry: + publisher: RingBufferPublisher | None = None + is_set: trio.Event = trio.Event() -def set_publisher(pub: RingBufferPublisher): - global _publisher +_publishers: dict[str, PublisherEntry] = {} - if _publisher: + +def maybe_init_publisher(topic: str) -> PublisherEntry: + entry = _publishers.get(topic, None) + if not entry: + entry = PublisherEntry() + _publishers[topic] = entry + + return entry + + +def set_publisher(topic: str, pub: RingBufferPublisher): + global _publishers + + entry = _publishers.get(topic, None) + if not entry: + entry = maybe_init_publisher(topic) + + if entry.publisher: raise RuntimeError( - f'publisher already set on {tractor.current_actor()}' + f'publisher for topic {topic} already set on {tractor.current_actor()}' ) - _publisher = pub + entry.publisher = pub + entry.is_set.set() -def set_subscriber(sub: RingBufferSubscriber): - global _subscriber - - if _subscriber: - raise RuntimeError( - f'subscriber already set on {tractor.current_actor()}' - ) - - _subscriber = sub - - -def get_publisher() -> RingBufferPublisher: - global _publisher - - if not _publisher: +def get_publisher(topic: str) -> RingBufferPublisher: + entry = _publishers.get(topic, None) + if not entry or not entry.publisher: raise RuntimeError( f'{tractor.current_actor()} tried to get publisher' 'but it\'s not set' ) - return _publisher + return entry.publisher -def get_subscriber() -> RingBufferSubscriber: - global _subscriber - - if not _subscriber: - raise RuntimeError( - f'{tractor.current_actor()} tried to get subscriber' - 'but it\'s not set' - ) - - return _subscriber +async def wait_publisher(topic: str) -> RingBufferPublisher: + entry = maybe_init_publisher(topic) + await entry.is_set.wait() + return entry.publisher @tractor.context async def _add_pub_channel( ctx: tractor.Context, + topic: str, token: RBToken ): - publisher = get_publisher() - await ctx.started() + publisher = await wait_publisher(topic) await publisher.add_channel(token) @tractor.context async def _remove_pub_channel( ctx: tractor.Context, + topic: str, ring_name: str ): - publisher = get_publisher() - await ctx.started() + publisher = await wait_publisher(topic) maybe_token = fdshare.maybe_get_fds(ring_name) if maybe_token: await publisher.remove_channel(ring_name) @@ -619,59 +633,92 @@ async def _remove_pub_channel( async def open_pub_channel_at( actor_name: str, token: RBToken, - cleanup: bool = True, + topic: str = 'default', ): - async with ( - tractor.find_actor(actor_name) as portal, + async with tractor.find_actor(actor_name) as portal: + await portal.run(_add_pub_channel, topic=topic, token=token) + try: + yield - portal.open_context( - _add_pub_channel, - token=token - ) as (ctx, _) - ): - ... + except trio.Cancelled: + log.warning( + 'open_pub_channel_at got cancelled!\n' + f'\tactor_name = {actor_name}\n' + f'\ttoken = {token}\n' + ) + raise - try: - yield + await portal.run(_remove_pub_channel, topic=topic, ring_name=token.shm_name) - except trio.Cancelled: - log.warning( - 'open_pub_channel_at got cancelled!\n' - f'\tactor_name = {actor_name}\n' - f'\ttoken = {token}\n' + +@dataclass +class SubscriberEntry: + subscriber: RingBufferSubscriber | None = None + is_set: trio.Event = trio.Event() + + +_subscribers: dict[str, SubscriberEntry] = {} + + +def maybe_init_subscriber(topic: str) -> SubscriberEntry: + entry = _subscribers.get(topic, None) + if not entry: + entry = SubscriberEntry() + _subscribers[topic] = entry + + return entry + + +def set_subscriber(topic: str, sub: RingBufferSubscriber): + global _subscribers + + entry = _subscribers.get(topic, None) + if not entry: + entry = maybe_init_subscriber(topic) + + if entry.subscriber: + raise RuntimeError( + f'subscriber for topic {topic} already set on {tractor.current_actor()}' ) - raise - finally: - if not cleanup: - return + entry.subscriber = sub + entry.is_set.set() - async with tractor.find_actor(actor_name) as portal: - if portal: - async with portal.open_context( - _remove_pub_channel, - ring_name=token.shm_name - ) as (ctx, _): - ... + +def get_subscriber(topic: str) -> RingBufferSubscriber: + entry = _subscribers.get(topic, None) + if not entry or not entry.subscriber: + raise RuntimeError( + f'{tractor.current_actor()} tried to get subscriber' + 'but it\'s not set' + ) + + return entry.subscriber + + +async def wait_subscriber(topic: str) -> RingBufferSubscriber: + entry = maybe_init_subscriber(topic) + await entry.is_set.wait() + return entry.subscriber @tractor.context async def _add_sub_channel( ctx: tractor.Context, + topic: str, token: RBToken ): - subscriber = get_subscriber() - await ctx.started() + subscriber = await wait_subscriber(topic) await subscriber.add_channel(token) @tractor.context async def _remove_sub_channel( ctx: tractor.Context, + topic: str, ring_name: str ): - subscriber = get_subscriber() - await ctx.started() + subscriber = await wait_subscriber(topic) maybe_token = fdshare.maybe_get_fds(ring_name) if maybe_token: await subscriber.remove_channel(ring_name) @@ -681,41 +728,22 @@ async def _remove_sub_channel( async def open_sub_channel_at( actor_name: str, token: RBToken, - cleanup: bool = True, + topic: str = 'default', ): - async with ( - tractor.find_actor(actor_name) as portal, + async with tractor.find_actor(actor_name) as portal: + await portal.run(_add_sub_channel, topic=topic, token=token) + try: + yield - portal.open_context( - _add_sub_channel, - token=token - ) as (ctx, _) - ): - ... - - try: - yield - - except trio.Cancelled: - log.warning( - 'open_sub_channel_at got cancelled!\n' - f'\tactor_name = {actor_name}\n' - f'\ttoken = {token}\n' - ) - raise - - finally: - if not cleanup: - return - - async with tractor.find_actor(actor_name) as portal: - if portal: - async with portal.open_context( - _remove_sub_channel, - ring_name=token.shm_name - ) as (ctx, _): - ... + except trio.Cancelled: + log.warning( + 'open_sub_channel_at got cancelled!\n' + f'\tactor_name = {actor_name}\n' + f'\ttoken = {token}\n' + ) + raise + await portal.run(_remove_sub_channel, topic=topic, ring_name=token.shm_name) ''' @@ -725,12 +753,17 @@ High level helpers to open publisher & subscriber @acm async def open_ringbuf_publisher( + # name to distinguish this publisher + topic: str = 'default', + # global batch size for channels batch_size: int = 1, # messages before changing output channel msgs_per_turn: int = 1, + encoder: Encoder | None = None, + # ensure subscriber receives in same order publisher sent # causes it to use wrapped payloads which contain the og # index @@ -750,26 +783,28 @@ async def open_ringbuf_publisher( trio.open_nursery(strict_exception_groups=False) as n, RingBufferPublisher( n, - batch_size=batch_size + batch_size=batch_size, + encoder=encoder, ) as publisher ): if guarantee_order: order_send_channel(publisher) if set_module_var: - set_publisher(publisher) + set_publisher(topic, publisher) - try: - yield publisher + yield publisher - except trio.Cancelled: - with trio.CancelScope(shield=True): - await publisher.aclose() - raise + n.cancel_scope.cancel() @acm async def open_ringbuf_subscriber( + # name to distinguish this subscriber + topic: str = 'default', + + decoder: Decoder | None = None, + # expect indexed payloads and unwrap them in order guarantee_order: bool = False, @@ -784,7 +819,7 @@ async def open_ringbuf_subscriber( ''' async with ( trio.open_nursery(strict_exception_groups=False) as n, - RingBufferSubscriber(n) as subscriber + RingBufferSubscriber(n, decoder=decoder) as subscriber ): # maybe monkey patch `.receive` to use indexed payloads if guarantee_order: @@ -792,13 +827,8 @@ async def open_ringbuf_subscriber( # maybe set global module var for remote actor channel updates if set_module_var: - global _subscriber - set_subscriber(subscriber) + set_subscriber(topic, subscriber) - try: - yield subscriber + yield subscriber - except trio.Cancelled: - with trio.CancelScope(shield=True): - await subscriber.aclose() - raise + n.cancel_scope.cancel()