Pubsub topics, encoders & decoders

Implicitly aclose all channels when the ChannelManager is aclosed
Implicitly cancel the nursery on exit of the pubsub acms
Use a long-running actor portal for the open_{pub,sub}_channel_at fns
Add an optional encoder/decoder to pubsub
Add a topic system to allow multiple pubs or subs on the same actor
Add wait fns for sub and pub channel registration
one_ring_to_rule_them_all
Guillermo Rodriguez 2025-04-22 01:46:41 -03:00
parent 8799cf3b78
commit a553446619
No known key found for this signature in database
GPG Key ID: 002CC5F1E6BDA53E
1 changed file with 161 additions and 131 deletions

View File

@ -31,8 +31,14 @@ from dataclasses import dataclass
import trio import trio
import tractor import tractor
from msgspec.msgpack import (
Encoder,
Decoder
)
from tractor.ipc._ringbuf import ( from tractor.ipc._ringbuf import (
RBToken, RBToken,
PayloadT,
RingBufferSendChannel, RingBufferSendChannel,
RingBufferReceiveChannel, RingBufferReceiveChannel,
attach_to_ringbuf_sender, attach_to_ringbuf_sender,
@ -242,6 +248,7 @@ class ChannelManager(Generic[ChannelType]):
if info.channel.closed: if info.channel.closed:
continue continue
await info.channel.aclose()
await self.remove_channel(info.token.shm_name) await self.remove_channel(info.token.shm_name)
self._is_closed = True self._is_closed = True
@ -253,7 +260,7 @@ Ring buffer publisher & subscribe pattern mediated by `ringd` actor.
''' '''
class RingBufferPublisher(trio.abc.SendChannel[bytes]): class RingBufferPublisher(trio.abc.SendChannel[PayloadT]):
''' '''
Use ChannelManager to create a multi ringbuf round robin sender that can Use ChannelManager to create a multi ringbuf round robin sender that can
dynamically add or remove more outputs. dynamically add or remove more outputs.
@ -270,13 +277,16 @@ class RingBufferPublisher(trio.abc.SendChannel[bytes]):
msgs_per_turn: int = 1, msgs_per_turn: int = 1,
# global batch size for all channels # global batch size for all channels
batch_size: int = 1 batch_size: int = 1,
encoder: Encoder | None = None
): ):
self._batch_size: int = batch_size self._batch_size: int = batch_size
self.msgs_per_turn = msgs_per_turn self.msgs_per_turn = msgs_per_turn
self._enc = encoder
# helper to manage acms + long running tasks # helper to manage acms + long running tasks
self._chanmngr = ChannelManager[RingBufferSendChannel]( self._chanmngr = ChannelManager[RingBufferSendChannel[PayloadT]](
n, n,
self._open_channel, self._open_channel,
self._channel_task self._channel_task
@ -349,10 +359,11 @@ class RingBufferPublisher(trio.abc.SendChannel[bytes]):
self, self,
token: RBToken token: RBToken
) -> AsyncContextManager[RingBufferSendChannel]: ) -> AsyncContextManager[RingBufferSendChannel[PayloadT]]:
async with attach_to_ringbuf_sender( async with attach_to_ringbuf_sender(
token, token,
batch_size=self._batch_size batch_size=self._batch_size,
encoder=self._enc
) as ring: ) as ring:
yield ring yield ring
@ -387,7 +398,7 @@ class RingBufferPublisher(trio.abc.SendChannel[bytes]):
info = self.channels[turn] info = self.channels[turn]
await info.channel.send(msg) await info.channel.send(msg)
async def broadcast(self, msg: bytes): async def broadcast(self, msg: PayloadT):
''' '''
Send a msg to all channels, if no channels connected, does nothing. Send a msg to all channels, if no channels connected, does nothing.
''' '''
@ -406,8 +417,8 @@ class RingBufferPublisher(trio.abc.SendChannel[bytes]):
... ...
async def __aenter__(self): async def __aenter__(self):
self._chanmngr.open()
self._is_closed = False self._is_closed = False
self._chanmngr.open()
return self return self
async def aclose(self) -> None: async def aclose(self) -> None:
@ -420,7 +431,7 @@ class RingBufferPublisher(trio.abc.SendChannel[bytes]):
self._is_closed = True self._is_closed = True
class RingBufferSubscriber(trio.abc.ReceiveChannel[bytes]): class RingBufferSubscriber(trio.abc.ReceiveChannel[PayloadT]):
''' '''
Use ChannelManager to create a multi ringbuf receiver that can Use ChannelManager to create a multi ringbuf receiver that can
dynamically add or remove more inputs and combine all into a single output. dynamically add or remove more inputs and combine all into a single output.
@ -440,11 +451,10 @@ class RingBufferSubscriber(trio.abc.ReceiveChannel[bytes]):
self, self,
n: trio.Nursery, n: trio.Nursery,
# if connecting to a publisher that has already sent messages set decoder: Decoder | None = None
# to the next expected payload index this subscriber will receive
start_index: int = 0
): ):
self._chanmngr = ChannelManager[RingBufferReceiveChannel]( self._dec = decoder
self._chanmngr = ChannelManager[RingBufferReceiveChannel[PayloadT]](
n, n,
self._open_channel, self._open_channel,
self._channel_task self._channel_task
@ -483,7 +493,10 @@ class RingBufferSubscriber(trio.abc.ReceiveChannel[bytes]):
token: RBToken token: RBToken
) -> AsyncContextManager[RingBufferSendChannel]: ) -> AsyncContextManager[RingBufferSendChannel]:
async with attach_to_ringbuf_receiver(token) as ring: async with attach_to_ringbuf_receiver(
token,
decoder=self._dec
) as ring:
yield ring yield ring
async def _channel_task(self, info: ChannelInfo) -> None: async def _channel_task(self, info: ChannelInfo) -> None:
@ -509,7 +522,7 @@ class RingBufferSubscriber(trio.abc.ReceiveChannel[bytes]):
except trio.ClosedResourceError: except trio.ClosedResourceError:
break break
async def receive(self) -> bytes: async def receive(self) -> PayloadT:
''' '''
Receive next in order msg Receive next in order msg
''' '''
@ -543,73 +556,74 @@ Actor module for managing publisher & subscriber channels remotely through
`tractor.context` rpc `tractor.context` rpc
''' '''
_publisher: RingBufferPublisher | None = None @dataclass
_subscriber: RingBufferSubscriber | None = None class PublisherEntry:
publisher: RingBufferPublisher | None = None
is_set: trio.Event = trio.Event()
def set_publisher(pub: RingBufferPublisher): _publishers: dict[str, PublisherEntry] = {}
global _publisher
if _publisher:
def maybe_init_publisher(topic: str) -> PublisherEntry:
entry = _publishers.get(topic, None)
if not entry:
entry = PublisherEntry()
_publishers[topic] = entry
return entry
def set_publisher(topic: str, pub: RingBufferPublisher):
global _publishers
entry = _publishers.get(topic, None)
if not entry:
entry = maybe_init_publisher(topic)
if entry.publisher:
raise RuntimeError( raise RuntimeError(
f'publisher already set on {tractor.current_actor()}' f'publisher for topic {topic} already set on {tractor.current_actor()}'
) )
_publisher = pub entry.publisher = pub
entry.is_set.set()
def set_subscriber(sub: RingBufferSubscriber): def get_publisher(topic: str) -> RingBufferPublisher:
global _subscriber entry = _publishers.get(topic, None)
if not entry or not entry.publisher:
if _subscriber:
raise RuntimeError(
f'subscriber already set on {tractor.current_actor()}'
)
_subscriber = sub
def get_publisher() -> RingBufferPublisher:
global _publisher
if not _publisher:
raise RuntimeError( raise RuntimeError(
f'{tractor.current_actor()} tried to get publisher' f'{tractor.current_actor()} tried to get publisher'
'but it\'s not set' 'but it\'s not set'
) )
return _publisher return entry.publisher
def get_subscriber() -> RingBufferSubscriber: async def wait_publisher(topic: str) -> RingBufferPublisher:
global _subscriber entry = maybe_init_publisher(topic)
await entry.is_set.wait()
if not _subscriber: return entry.publisher
raise RuntimeError(
f'{tractor.current_actor()} tried to get subscriber'
'but it\'s not set'
)
return _subscriber
@tractor.context @tractor.context
async def _add_pub_channel( async def _add_pub_channel(
ctx: tractor.Context, ctx: tractor.Context,
topic: str,
token: RBToken token: RBToken
): ):
publisher = get_publisher() publisher = await wait_publisher(topic)
await ctx.started()
await publisher.add_channel(token) await publisher.add_channel(token)
@tractor.context @tractor.context
async def _remove_pub_channel( async def _remove_pub_channel(
ctx: tractor.Context, ctx: tractor.Context,
topic: str,
ring_name: str ring_name: str
): ):
publisher = get_publisher() publisher = await wait_publisher(topic)
await ctx.started()
maybe_token = fdshare.maybe_get_fds(ring_name) maybe_token = fdshare.maybe_get_fds(ring_name)
if maybe_token: if maybe_token:
await publisher.remove_channel(ring_name) await publisher.remove_channel(ring_name)
@ -619,18 +633,10 @@ async def _remove_pub_channel(
async def open_pub_channel_at( async def open_pub_channel_at(
actor_name: str, actor_name: str,
token: RBToken, token: RBToken,
cleanup: bool = True, topic: str = 'default',
): ):
async with ( async with tractor.find_actor(actor_name) as portal:
tractor.find_actor(actor_name) as portal, await portal.run(_add_pub_channel, topic=topic, token=token)
portal.open_context(
_add_pub_channel,
token=token
) as (ctx, _)
):
...
try: try:
yield yield
@ -642,36 +648,77 @@ async def open_pub_channel_at(
) )
raise raise
finally: await portal.run(_remove_pub_channel, topic=topic, ring_name=token.shm_name)
if not cleanup:
return
async with tractor.find_actor(actor_name) as portal:
if portal: @dataclass
async with portal.open_context( class SubscriberEntry:
_remove_pub_channel, subscriber: RingBufferSubscriber | None = None
ring_name=token.shm_name is_set: trio.Event = trio.Event()
) as (ctx, _):
...
_subscribers: dict[str, SubscriberEntry] = {}
def maybe_init_subscriber(topic: str) -> SubscriberEntry:
entry = _subscribers.get(topic, None)
if not entry:
entry = SubscriberEntry()
_subscribers[topic] = entry
return entry
def set_subscriber(topic: str, sub: RingBufferSubscriber):
global _subscribers
entry = _subscribers.get(topic, None)
if not entry:
entry = maybe_init_subscriber(topic)
if entry.subscriber:
raise RuntimeError(
f'subscriber for topic {topic} already set on {tractor.current_actor()}'
)
entry.subscriber = sub
entry.is_set.set()
def get_subscriber(topic: str) -> RingBufferSubscriber:
entry = _subscribers.get(topic, None)
if not entry or not entry.subscriber:
raise RuntimeError(
f'{tractor.current_actor()} tried to get subscriber'
'but it\'s not set'
)
return entry.subscriber
async def wait_subscriber(topic: str) -> RingBufferSubscriber:
entry = maybe_init_subscriber(topic)
await entry.is_set.wait()
return entry.subscriber
@tractor.context @tractor.context
async def _add_sub_channel( async def _add_sub_channel(
ctx: tractor.Context, ctx: tractor.Context,
topic: str,
token: RBToken token: RBToken
): ):
subscriber = get_subscriber() subscriber = await wait_subscriber(topic)
await ctx.started()
await subscriber.add_channel(token) await subscriber.add_channel(token)
@tractor.context @tractor.context
async def _remove_sub_channel( async def _remove_sub_channel(
ctx: tractor.Context, ctx: tractor.Context,
topic: str,
ring_name: str ring_name: str
): ):
subscriber = get_subscriber() subscriber = await wait_subscriber(topic)
await ctx.started()
maybe_token = fdshare.maybe_get_fds(ring_name) maybe_token = fdshare.maybe_get_fds(ring_name)
if maybe_token: if maybe_token:
await subscriber.remove_channel(ring_name) await subscriber.remove_channel(ring_name)
@ -681,18 +728,10 @@ async def _remove_sub_channel(
async def open_sub_channel_at( async def open_sub_channel_at(
actor_name: str, actor_name: str,
token: RBToken, token: RBToken,
cleanup: bool = True, topic: str = 'default',
): ):
async with ( async with tractor.find_actor(actor_name) as portal:
tractor.find_actor(actor_name) as portal, await portal.run(_add_sub_channel, topic=topic, token=token)
portal.open_context(
_add_sub_channel,
token=token
) as (ctx, _)
):
...
try: try:
yield yield
@ -704,18 +743,7 @@ async def open_sub_channel_at(
) )
raise raise
finally: await portal.run(_remove_sub_channel, topic=topic, ring_name=token.shm_name)
if not cleanup:
return
async with tractor.find_actor(actor_name) as portal:
if portal:
async with portal.open_context(
_remove_sub_channel,
ring_name=token.shm_name
) as (ctx, _):
...
''' '''
@ -725,12 +753,17 @@ High level helpers to open publisher & subscriber
@acm @acm
async def open_ringbuf_publisher( async def open_ringbuf_publisher(
# name to distinguish this publisher
topic: str = 'default',
# global batch size for channels # global batch size for channels
batch_size: int = 1, batch_size: int = 1,
# messages before changing output channel # messages before changing output channel
msgs_per_turn: int = 1, msgs_per_turn: int = 1,
encoder: Encoder | None = None,
# ensure subscriber receives in same order publisher sent # ensure subscriber receives in same order publisher sent
# causes it to use wrapped payloads which contain the og # causes it to use wrapped payloads which contain the og
# index # index
@ -750,26 +783,28 @@ async def open_ringbuf_publisher(
trio.open_nursery(strict_exception_groups=False) as n, trio.open_nursery(strict_exception_groups=False) as n,
RingBufferPublisher( RingBufferPublisher(
n, n,
batch_size=batch_size batch_size=batch_size,
encoder=encoder,
) as publisher ) as publisher
): ):
if guarantee_order: if guarantee_order:
order_send_channel(publisher) order_send_channel(publisher)
if set_module_var: if set_module_var:
set_publisher(publisher) set_publisher(topic, publisher)
try:
yield publisher yield publisher
except trio.Cancelled: n.cancel_scope.cancel()
with trio.CancelScope(shield=True):
await publisher.aclose()
raise
@acm @acm
async def open_ringbuf_subscriber( async def open_ringbuf_subscriber(
# name to distinguish this subscriber
topic: str = 'default',
decoder: Decoder | None = None,
# expect indexed payloads and unwrap them in order # expect indexed payloads and unwrap them in order
guarantee_order: bool = False, guarantee_order: bool = False,
@ -784,7 +819,7 @@ async def open_ringbuf_subscriber(
''' '''
async with ( async with (
trio.open_nursery(strict_exception_groups=False) as n, trio.open_nursery(strict_exception_groups=False) as n,
RingBufferSubscriber(n) as subscriber RingBufferSubscriber(n, decoder=decoder) as subscriber
): ):
# maybe monkey patch `.receive` to use indexed payloads # maybe monkey patch `.receive` to use indexed payloads
if guarantee_order: if guarantee_order:
@ -792,13 +827,8 @@ async def open_ringbuf_subscriber(
# maybe set global module var for remote actor channel updates # maybe set global module var for remote actor channel updates
if set_module_var: if set_module_var:
global _subscriber set_subscriber(topic, subscriber)
set_subscriber(subscriber)
try:
yield subscriber yield subscriber
except trio.Cancelled: n.cancel_scope.cancel()
with trio.CancelScope(shield=True):
await subscriber.aclose()
raise