Pubsub topics, enc & decoders

Implicit aclose on all channels on ChannelManager aclose
Implicit nursery cancel on pubsub acms
Use long running actor portal for open_{pub,sub}_channel_at fns
Add optional encoder/decoder on pubsub
Add topic system for multiple pub or sub on same actor
Add wait fn for sub and pub channel register
one_ring_to_rule_them_all
Guillermo Rodriguez 2025-04-22 01:46:41 -03:00
parent 8799cf3b78
commit a553446619
No known key found for this signature in database
GPG Key ID: 002CC5F1E6BDA53E
1 changed files with 161 additions and 131 deletions

View File

@ -31,8 +31,14 @@ from dataclasses import dataclass
from dataclasses import field

import trio
import tractor
from msgspec.msgpack import (
Encoder,
Decoder
)
from tractor.ipc._ringbuf import (
RBToken,
PayloadT,
RingBufferSendChannel,
RingBufferReceiveChannel,
attach_to_ringbuf_sender,
@ -242,6 +248,7 @@ class ChannelManager(Generic[ChannelType]):
if info.channel.closed:
continue
await info.channel.aclose()
await self.remove_channel(info.token.shm_name)
self._is_closed = True
@ -253,7 +260,7 @@ Ring buffer publisher & subscribe pattern mediated by `ringd` actor.
'''
class RingBufferPublisher(trio.abc.SendChannel[bytes]):
class RingBufferPublisher(trio.abc.SendChannel[PayloadT]):
'''
Use ChannelManager to create a multi ringbuf round robin sender that can
dynamically add or remove more outputs.
@ -270,13 +277,16 @@ class RingBufferPublisher(trio.abc.SendChannel[bytes]):
msgs_per_turn: int = 1,
# global batch size for all channels
batch_size: int = 1
batch_size: int = 1,
encoder: Encoder | None = None
):
self._batch_size: int = batch_size
self.msgs_per_turn = msgs_per_turn
self._enc = encoder
# helper to manage acms + long running tasks
self._chanmngr = ChannelManager[RingBufferSendChannel](
self._chanmngr = ChannelManager[RingBufferSendChannel[PayloadT]](
n,
self._open_channel,
self._channel_task
@ -349,10 +359,11 @@ class RingBufferPublisher(trio.abc.SendChannel[bytes]):
self,
token: RBToken
) -> AsyncContextManager[RingBufferSendChannel]:
) -> AsyncContextManager[RingBufferSendChannel[PayloadT]]:
async with attach_to_ringbuf_sender(
token,
batch_size=self._batch_size
batch_size=self._batch_size,
encoder=self._enc
) as ring:
yield ring
@ -387,7 +398,7 @@ class RingBufferPublisher(trio.abc.SendChannel[bytes]):
info = self.channels[turn]
await info.channel.send(msg)
async def broadcast(self, msg: bytes):
async def broadcast(self, msg: PayloadT):
'''
Send a msg to all channels, if no channels connected, does nothing.
'''
@ -406,8 +417,8 @@ class RingBufferPublisher(trio.abc.SendChannel[bytes]):
...
async def __aenter__(self):
self._chanmngr.open()
self._is_closed = False
self._chanmngr.open()
return self
async def aclose(self) -> None:
@ -420,7 +431,7 @@ class RingBufferPublisher(trio.abc.SendChannel[bytes]):
self._is_closed = True
class RingBufferSubscriber(trio.abc.ReceiveChannel[bytes]):
class RingBufferSubscriber(trio.abc.ReceiveChannel[PayloadT]):
'''
Use ChannelManager to create a multi ringbuf receiver that can
dynamically add or remove more inputs and combine all into a single output.
@ -440,11 +451,10 @@ class RingBufferSubscriber(trio.abc.ReceiveChannel[bytes]):
self,
n: trio.Nursery,
# if connecting to a publisher that has already sent messages set
# to the next expected payload index this subscriber will receive
start_index: int = 0
decoder: Decoder | None = None
):
self._chanmngr = ChannelManager[RingBufferReceiveChannel](
self._dec = decoder
self._chanmngr = ChannelManager[RingBufferReceiveChannel[PayloadT]](
n,
self._open_channel,
self._channel_task
@ -483,7 +493,10 @@ class RingBufferSubscriber(trio.abc.ReceiveChannel[bytes]):
token: RBToken
) -> AsyncContextManager[RingBufferSendChannel]:
async with attach_to_ringbuf_receiver(token) as ring:
async with attach_to_ringbuf_receiver(
token,
decoder=self._dec
) as ring:
yield ring
async def _channel_task(self, info: ChannelInfo) -> None:
@ -509,7 +522,7 @@ class RingBufferSubscriber(trio.abc.ReceiveChannel[bytes]):
except trio.ClosedResourceError:
break
async def receive(self) -> bytes:
async def receive(self) -> PayloadT:
'''
Receive next in order msg
'''
@ -543,73 +556,74 @@ Actor module for managing publisher & subscriber channels remotely through
`tractor.context` rpc
'''
_publisher: RingBufferPublisher | None = None
_subscriber: RingBufferSubscriber | None = None
@dataclass
class PublisherEntry:
    '''
    Per-topic registry slot for a `RingBufferPublisher`.

    '''
    # publisher registered for the topic, `None` until one is set
    publisher: RingBufferPublisher | None = None

    # set once `publisher` has been assigned; built per instance through
    # a factory because a plain `trio.Event()` default is evaluated a
    # single time in the class body and would be shared by every entry,
    # waking waiters of unrelated topics
    is_set: trio.Event = field(default_factory=trio.Event)
def set_publisher(pub: RingBufferPublisher):
global _publisher
_publishers: dict[str, PublisherEntry] = {}
if _publisher:
def maybe_init_publisher(topic: str) -> PublisherEntry:
    '''
    Return the registry entry for `topic`, creating and storing an empty
    one first when the topic has no entry yet.

    '''
    existing = _publishers.get(topic)
    if existing:
        return existing

    fresh = PublisherEntry()
    _publishers[topic] = fresh
    return fresh
def set_publisher(topic: str, pub: RingBufferPublisher):
global _publishers
entry = _publishers.get(topic, None)
if not entry:
entry = maybe_init_publisher(topic)
if entry.publisher:
raise RuntimeError(
f'publisher already set on {tractor.current_actor()}'
f'publisher for topic {topic} already set on {tractor.current_actor()}'
)
_publisher = pub
entry.publisher = pub
entry.is_set.set()
def set_subscriber(sub: RingBufferSubscriber):
global _subscriber
if _subscriber:
raise RuntimeError(
f'subscriber already set on {tractor.current_actor()}'
)
_subscriber = sub
def get_publisher() -> RingBufferPublisher:
global _publisher
if not _publisher:
def get_publisher(topic: str) -> RingBufferPublisher:
    '''
    Return the publisher registered for `topic` on this actor.

    Raises `RuntimeError` when the topic has no entry or its entry has
    no publisher set.

    '''
    entry = _publishers.get(topic)
    if not entry or not entry.publisher:
        # adjacent literals are concatenated verbatim, so the separating
        # space has to be part of the first fragment
        raise RuntimeError(
            f'{tractor.current_actor()} tried to get publisher '
            f'for topic {topic} but it\'s not set'
        )

    return entry.publisher
def get_subscriber() -> RingBufferSubscriber:
global _subscriber
if not _subscriber:
raise RuntimeError(
f'{tractor.current_actor()} tried to get subscriber'
'but it\'s not set'
)
return _subscriber
async def wait_publisher(topic: str) -> RingBufferPublisher:
    '''
    Wait until the entry for `topic` signals `is_set`, then return its
    publisher.

    The entry is created on demand, so waiting can start before
    `set_publisher()` has been called for the topic.

    '''
    slot = maybe_init_publisher(topic)
    await slot.is_set.wait()
    return slot.publisher
@tractor.context
async def _add_pub_channel(
ctx: tractor.Context,
topic: str,
token: RBToken
):
publisher = get_publisher()
await ctx.started()
publisher = await wait_publisher(topic)
await publisher.add_channel(token)
@tractor.context
async def _remove_pub_channel(
ctx: tractor.Context,
topic: str,
ring_name: str
):
publisher = get_publisher()
await ctx.started()
publisher = await wait_publisher(topic)
maybe_token = fdshare.maybe_get_fds(ring_name)
if maybe_token:
await publisher.remove_channel(ring_name)
@ -619,18 +633,10 @@ async def _remove_pub_channel(
async def open_pub_channel_at(
actor_name: str,
token: RBToken,
cleanup: bool = True,
topic: str = 'default',
):
async with (
tractor.find_actor(actor_name) as portal,
portal.open_context(
_add_pub_channel,
token=token
) as (ctx, _)
):
...
async with tractor.find_actor(actor_name) as portal:
await portal.run(_add_pub_channel, topic=topic, token=token)
try:
yield
@ -642,36 +648,77 @@ async def open_pub_channel_at(
)
raise
finally:
if not cleanup:
return
await portal.run(_remove_pub_channel, topic=topic, ring_name=token.shm_name)
async with tractor.find_actor(actor_name) as portal:
if portal:
async with portal.open_context(
_remove_pub_channel,
ring_name=token.shm_name
) as (ctx, _):
...
@dataclass
class SubscriberEntry:
    '''
    Per-topic registry slot for a `RingBufferSubscriber`.

    '''
    # subscriber registered for the topic, `None` until one is set
    subscriber: RingBufferSubscriber | None = None

    # set once `subscriber` has been assigned; built per instance through
    # a factory because a plain `trio.Event()` default is evaluated a
    # single time in the class body and would be shared by every entry,
    # waking waiters of unrelated topics
    is_set: trio.Event = field(default_factory=trio.Event)
_subscribers: dict[str, SubscriberEntry] = {}
def maybe_init_subscriber(topic: str) -> SubscriberEntry:
    '''
    Return the registry entry for `topic`, creating and storing an empty
    one first when the topic has no entry yet.

    '''
    existing = _subscribers.get(topic)
    if existing:
        return existing

    fresh = SubscriberEntry()
    _subscribers[topic] = fresh
    return fresh
def set_subscriber(topic: str, sub: RingBufferSubscriber):
    '''
    Register `sub` as the subscriber for `topic` and signal the entry's
    `is_set` event.

    Raises `RuntimeError` when the topic already has a subscriber.

    '''
    slot = _subscribers.get(topic) or maybe_init_subscriber(topic)

    if slot.subscriber:
        raise RuntimeError(
            f'subscriber for topic {topic} already set on {tractor.current_actor()}'
        )

    slot.subscriber = sub
    slot.is_set.set()
def get_subscriber(topic: str) -> RingBufferSubscriber:
    '''
    Return the subscriber registered for `topic` on this actor.

    Raises `RuntimeError` when the topic has no entry or its entry has
    no subscriber set.

    '''
    entry = _subscribers.get(topic)
    if not entry or not entry.subscriber:
        # adjacent literals are concatenated verbatim, so the separating
        # space has to be part of the first fragment
        raise RuntimeError(
            f'{tractor.current_actor()} tried to get subscriber '
            f'for topic {topic} but it\'s not set'
        )

    return entry.subscriber
async def wait_subscriber(topic: str) -> RingBufferSubscriber:
    '''
    Wait until the entry for `topic` signals `is_set`, then return its
    subscriber.

    The entry is created on demand, so waiting can start before
    `set_subscriber()` has been called for the topic.

    '''
    slot = maybe_init_subscriber(topic)
    await slot.is_set.wait()
    return slot.subscriber
@tractor.context
async def _add_sub_channel(
ctx: tractor.Context,
topic: str,
token: RBToken
):
subscriber = get_subscriber()
await ctx.started()
subscriber = await wait_subscriber(topic)
await subscriber.add_channel(token)
@tractor.context
async def _remove_sub_channel(
ctx: tractor.Context,
topic: str,
ring_name: str
):
subscriber = get_subscriber()
await ctx.started()
subscriber = await wait_subscriber(topic)
maybe_token = fdshare.maybe_get_fds(ring_name)
if maybe_token:
await subscriber.remove_channel(ring_name)
@ -681,18 +728,10 @@ async def _remove_sub_channel(
async def open_sub_channel_at(
actor_name: str,
token: RBToken,
cleanup: bool = True,
topic: str = 'default',
):
async with (
tractor.find_actor(actor_name) as portal,
portal.open_context(
_add_sub_channel,
token=token
) as (ctx, _)
):
...
async with tractor.find_actor(actor_name) as portal:
await portal.run(_add_sub_channel, topic=topic, token=token)
try:
yield
@ -704,18 +743,7 @@ async def open_sub_channel_at(
)
raise
finally:
if not cleanup:
return
async with tractor.find_actor(actor_name) as portal:
if portal:
async with portal.open_context(
_remove_sub_channel,
ring_name=token.shm_name
) as (ctx, _):
...
await portal.run(_remove_sub_channel, topic=topic, ring_name=token.shm_name)
'''
@ -725,12 +753,17 @@ High level helpers to open publisher & subscriber
@acm
async def open_ringbuf_publisher(
# name to distinguish this publisher
topic: str = 'default',
# global batch size for channels
batch_size: int = 1,
# messages before changing output channel
msgs_per_turn: int = 1,
encoder: Encoder | None = None,
# ensure subscriber receives in same order publisher sent
# causes it to use wrapped payloads which contain the og
# index
@ -750,26 +783,28 @@ async def open_ringbuf_publisher(
trio.open_nursery(strict_exception_groups=False) as n,
RingBufferPublisher(
n,
batch_size=batch_size
batch_size=batch_size,
encoder=encoder,
) as publisher
):
if guarantee_order:
order_send_channel(publisher)
if set_module_var:
set_publisher(publisher)
set_publisher(topic, publisher)
try:
yield publisher
except trio.Cancelled:
with trio.CancelScope(shield=True):
await publisher.aclose()
raise
n.cancel_scope.cancel()
@acm
async def open_ringbuf_subscriber(
# name to distinguish this subscriber
topic: str = 'default',
decoder: Decoder | None = None,
# expect indexed payloads and unwrap them in order
guarantee_order: bool = False,
@ -784,7 +819,7 @@ async def open_ringbuf_subscriber(
'''
async with (
trio.open_nursery(strict_exception_groups=False) as n,
RingBufferSubscriber(n) as subscriber
RingBufferSubscriber(n, decoder=decoder) as subscriber
):
# maybe monkey patch `.receive` to use indexed payloads
if guarantee_order:
@ -792,13 +827,8 @@ async def open_ringbuf_subscriber(
# maybe set global module var for remote actor channel updates
if set_module_var:
global _subscriber
set_subscriber(subscriber)
set_subscriber(topic, subscriber)
try:
yield subscriber
except trio.Cancelled:
with trio.CancelScope(shield=True):
await subscriber.aclose()
raise
n.cancel_scope.cancel()