Pubsub topics, enc & decoders
Implicit aclose on all channels on ChannelManager aclose Implicit nursery cancel on pubsub acms Use long running actor portal for open_{pub,sub}_channel_at fns Add optional encoder/decoder on pubsub Add topic system for multiple pub or sub on same actor Add wait fn for sub and pub channel registerone_ring_to_rule_them_all
parent
8799cf3b78
commit
a553446619
|
@ -31,8 +31,14 @@ from dataclasses import dataclass
|
|||
import trio
|
||||
import tractor
|
||||
|
||||
from msgspec.msgpack import (
|
||||
Encoder,
|
||||
Decoder
|
||||
)
|
||||
|
||||
from tractor.ipc._ringbuf import (
|
||||
RBToken,
|
||||
PayloadT,
|
||||
RingBufferSendChannel,
|
||||
RingBufferReceiveChannel,
|
||||
attach_to_ringbuf_sender,
|
||||
|
@ -242,6 +248,7 @@ class ChannelManager(Generic[ChannelType]):
|
|||
if info.channel.closed:
|
||||
continue
|
||||
|
||||
await info.channel.aclose()
|
||||
await self.remove_channel(info.token.shm_name)
|
||||
|
||||
self._is_closed = True
|
||||
|
@ -253,7 +260,7 @@ Ring buffer publisher & subscribe pattern mediated by `ringd` actor.
|
|||
'''
|
||||
|
||||
|
||||
class RingBufferPublisher(trio.abc.SendChannel[bytes]):
|
||||
class RingBufferPublisher(trio.abc.SendChannel[PayloadT]):
|
||||
'''
|
||||
Use ChannelManager to create a multi ringbuf round robin sender that can
|
||||
dynamically add or remove more outputs.
|
||||
|
@ -270,13 +277,16 @@ class RingBufferPublisher(trio.abc.SendChannel[bytes]):
|
|||
msgs_per_turn: int = 1,
|
||||
|
||||
# global batch size for all channels
|
||||
batch_size: int = 1
|
||||
batch_size: int = 1,
|
||||
|
||||
encoder: Encoder | None = None
|
||||
):
|
||||
self._batch_size: int = batch_size
|
||||
self.msgs_per_turn = msgs_per_turn
|
||||
self._enc = encoder
|
||||
|
||||
# helper to manage acms + long running tasks
|
||||
self._chanmngr = ChannelManager[RingBufferSendChannel](
|
||||
self._chanmngr = ChannelManager[RingBufferSendChannel[PayloadT]](
|
||||
n,
|
||||
self._open_channel,
|
||||
self._channel_task
|
||||
|
@ -349,10 +359,11 @@ class RingBufferPublisher(trio.abc.SendChannel[bytes]):
|
|||
self,
|
||||
token: RBToken
|
||||
|
||||
) -> AsyncContextManager[RingBufferSendChannel]:
|
||||
) -> AsyncContextManager[RingBufferSendChannel[PayloadT]]:
|
||||
async with attach_to_ringbuf_sender(
|
||||
token,
|
||||
batch_size=self._batch_size
|
||||
batch_size=self._batch_size,
|
||||
encoder=self._enc
|
||||
) as ring:
|
||||
yield ring
|
||||
|
||||
|
@ -387,7 +398,7 @@ class RingBufferPublisher(trio.abc.SendChannel[bytes]):
|
|||
info = self.channels[turn]
|
||||
await info.channel.send(msg)
|
||||
|
||||
async def broadcast(self, msg: bytes):
|
||||
async def broadcast(self, msg: PayloadT):
|
||||
'''
|
||||
Send a msg to all channels, if no channels connected, does nothing.
|
||||
'''
|
||||
|
@ -406,8 +417,8 @@ class RingBufferPublisher(trio.abc.SendChannel[bytes]):
|
|||
...
|
||||
|
||||
async def __aenter__(self):
|
||||
self._chanmngr.open()
|
||||
self._is_closed = False
|
||||
self._chanmngr.open()
|
||||
return self
|
||||
|
||||
async def aclose(self) -> None:
|
||||
|
@ -420,7 +431,7 @@ class RingBufferPublisher(trio.abc.SendChannel[bytes]):
|
|||
self._is_closed = True
|
||||
|
||||
|
||||
class RingBufferSubscriber(trio.abc.ReceiveChannel[bytes]):
|
||||
class RingBufferSubscriber(trio.abc.ReceiveChannel[PayloadT]):
|
||||
'''
|
||||
Use ChannelManager to create a multi ringbuf receiver that can
|
||||
dynamically add or remove more inputs and combine all into a single output.
|
||||
|
@ -440,11 +451,10 @@ class RingBufferSubscriber(trio.abc.ReceiveChannel[bytes]):
|
|||
self,
|
||||
n: trio.Nursery,
|
||||
|
||||
# if connecting to a publisher that has already sent messages set
|
||||
# to the next expected payload index this subscriber will receive
|
||||
start_index: int = 0
|
||||
decoder: Decoder | None = None
|
||||
):
|
||||
self._chanmngr = ChannelManager[RingBufferReceiveChannel](
|
||||
self._dec = decoder
|
||||
self._chanmngr = ChannelManager[RingBufferReceiveChannel[PayloadT]](
|
||||
n,
|
||||
self._open_channel,
|
||||
self._channel_task
|
||||
|
@ -483,7 +493,10 @@ class RingBufferSubscriber(trio.abc.ReceiveChannel[bytes]):
|
|||
token: RBToken
|
||||
|
||||
) -> AsyncContextManager[RingBufferSendChannel]:
|
||||
async with attach_to_ringbuf_receiver(token) as ring:
|
||||
async with attach_to_ringbuf_receiver(
|
||||
token,
|
||||
decoder=self._dec
|
||||
) as ring:
|
||||
yield ring
|
||||
|
||||
async def _channel_task(self, info: ChannelInfo) -> None:
|
||||
|
@ -509,7 +522,7 @@ class RingBufferSubscriber(trio.abc.ReceiveChannel[bytes]):
|
|||
except trio.ClosedResourceError:
|
||||
break
|
||||
|
||||
async def receive(self) -> bytes:
|
||||
async def receive(self) -> PayloadT:
|
||||
'''
|
||||
Receive next in order msg
|
||||
'''
|
||||
|
@ -543,73 +556,74 @@ Actor module for managing publisher & subscriber channels remotely through
|
|||
`tractor.context` rpc
|
||||
'''
|
||||
|
||||
_publisher: RingBufferPublisher | None = None
|
||||
_subscriber: RingBufferSubscriber | None = None
|
||||
@dataclass
|
||||
class PublisherEntry:
|
||||
publisher: RingBufferPublisher | None = None
|
||||
is_set: trio.Event = trio.Event()
|
||||
|
||||
|
||||
def set_publisher(pub: RingBufferPublisher):
|
||||
global _publisher
|
||||
_publishers: dict[str, PublisherEntry] = {}
|
||||
|
||||
if _publisher:
|
||||
|
||||
def maybe_init_publisher(topic: str) -> PublisherEntry:
|
||||
entry = _publishers.get(topic, None)
|
||||
if not entry:
|
||||
entry = PublisherEntry()
|
||||
_publishers[topic] = entry
|
||||
|
||||
return entry
|
||||
|
||||
|
||||
def set_publisher(topic: str, pub: RingBufferPublisher):
|
||||
global _publishers
|
||||
|
||||
entry = _publishers.get(topic, None)
|
||||
if not entry:
|
||||
entry = maybe_init_publisher(topic)
|
||||
|
||||
if entry.publisher:
|
||||
raise RuntimeError(
|
||||
f'publisher already set on {tractor.current_actor()}'
|
||||
f'publisher for topic {topic} already set on {tractor.current_actor()}'
|
||||
)
|
||||
|
||||
_publisher = pub
|
||||
entry.publisher = pub
|
||||
entry.is_set.set()
|
||||
|
||||
|
||||
def set_subscriber(sub: RingBufferSubscriber):
|
||||
global _subscriber
|
||||
|
||||
if _subscriber:
|
||||
raise RuntimeError(
|
||||
f'subscriber already set on {tractor.current_actor()}'
|
||||
)
|
||||
|
||||
_subscriber = sub
|
||||
|
||||
|
||||
def get_publisher() -> RingBufferPublisher:
|
||||
global _publisher
|
||||
|
||||
if not _publisher:
|
||||
def get_publisher(topic: str) -> RingBufferPublisher:
|
||||
entry = _publishers.get(topic, None)
|
||||
if not entry or not entry.publisher:
|
||||
raise RuntimeError(
|
||||
f'{tractor.current_actor()} tried to get publisher'
|
||||
'but it\'s not set'
|
||||
)
|
||||
|
||||
return _publisher
|
||||
return entry.publisher
|
||||
|
||||
|
||||
def get_subscriber() -> RingBufferSubscriber:
|
||||
global _subscriber
|
||||
|
||||
if not _subscriber:
|
||||
raise RuntimeError(
|
||||
f'{tractor.current_actor()} tried to get subscriber'
|
||||
'but it\'s not set'
|
||||
)
|
||||
|
||||
return _subscriber
|
||||
async def wait_publisher(topic: str) -> RingBufferPublisher:
|
||||
entry = maybe_init_publisher(topic)
|
||||
await entry.is_set.wait()
|
||||
return entry.publisher
|
||||
|
||||
|
||||
@tractor.context
|
||||
async def _add_pub_channel(
|
||||
ctx: tractor.Context,
|
||||
topic: str,
|
||||
token: RBToken
|
||||
):
|
||||
publisher = get_publisher()
|
||||
await ctx.started()
|
||||
publisher = await wait_publisher(topic)
|
||||
await publisher.add_channel(token)
|
||||
|
||||
|
||||
@tractor.context
|
||||
async def _remove_pub_channel(
|
||||
ctx: tractor.Context,
|
||||
topic: str,
|
||||
ring_name: str
|
||||
):
|
||||
publisher = get_publisher()
|
||||
await ctx.started()
|
||||
publisher = await wait_publisher(topic)
|
||||
maybe_token = fdshare.maybe_get_fds(ring_name)
|
||||
if maybe_token:
|
||||
await publisher.remove_channel(ring_name)
|
||||
|
@ -619,59 +633,92 @@ async def _remove_pub_channel(
|
|||
async def open_pub_channel_at(
|
||||
actor_name: str,
|
||||
token: RBToken,
|
||||
cleanup: bool = True,
|
||||
topic: str = 'default',
|
||||
):
|
||||
async with (
|
||||
tractor.find_actor(actor_name) as portal,
|
||||
async with tractor.find_actor(actor_name) as portal:
|
||||
await portal.run(_add_pub_channel, topic=topic, token=token)
|
||||
try:
|
||||
yield
|
||||
|
||||
portal.open_context(
|
||||
_add_pub_channel,
|
||||
token=token
|
||||
) as (ctx, _)
|
||||
):
|
||||
...
|
||||
except trio.Cancelled:
|
||||
log.warning(
|
||||
'open_pub_channel_at got cancelled!\n'
|
||||
f'\tactor_name = {actor_name}\n'
|
||||
f'\ttoken = {token}\n'
|
||||
)
|
||||
raise
|
||||
|
||||
try:
|
||||
yield
|
||||
await portal.run(_remove_pub_channel, topic=topic, ring_name=token.shm_name)
|
||||
|
||||
except trio.Cancelled:
|
||||
log.warning(
|
||||
'open_pub_channel_at got cancelled!\n'
|
||||
f'\tactor_name = {actor_name}\n'
|
||||
f'\ttoken = {token}\n'
|
||||
|
||||
@dataclass
|
||||
class SubscriberEntry:
|
||||
subscriber: RingBufferSubscriber | None = None
|
||||
is_set: trio.Event = trio.Event()
|
||||
|
||||
|
||||
_subscribers: dict[str, SubscriberEntry] = {}
|
||||
|
||||
|
||||
def maybe_init_subscriber(topic: str) -> SubscriberEntry:
|
||||
entry = _subscribers.get(topic, None)
|
||||
if not entry:
|
||||
entry = SubscriberEntry()
|
||||
_subscribers[topic] = entry
|
||||
|
||||
return entry
|
||||
|
||||
|
||||
def set_subscriber(topic: str, sub: RingBufferSubscriber):
|
||||
global _subscribers
|
||||
|
||||
entry = _subscribers.get(topic, None)
|
||||
if not entry:
|
||||
entry = maybe_init_subscriber(topic)
|
||||
|
||||
if entry.subscriber:
|
||||
raise RuntimeError(
|
||||
f'subscriber for topic {topic} already set on {tractor.current_actor()}'
|
||||
)
|
||||
raise
|
||||
|
||||
finally:
|
||||
if not cleanup:
|
||||
return
|
||||
entry.subscriber = sub
|
||||
entry.is_set.set()
|
||||
|
||||
async with tractor.find_actor(actor_name) as portal:
|
||||
if portal:
|
||||
async with portal.open_context(
|
||||
_remove_pub_channel,
|
||||
ring_name=token.shm_name
|
||||
) as (ctx, _):
|
||||
...
|
||||
|
||||
def get_subscriber(topic: str) -> RingBufferSubscriber:
|
||||
entry = _subscribers.get(topic, None)
|
||||
if not entry or not entry.subscriber:
|
||||
raise RuntimeError(
|
||||
f'{tractor.current_actor()} tried to get subscriber'
|
||||
'but it\'s not set'
|
||||
)
|
||||
|
||||
return entry.subscriber
|
||||
|
||||
|
||||
async def wait_subscriber(topic: str) -> RingBufferSubscriber:
|
||||
entry = maybe_init_subscriber(topic)
|
||||
await entry.is_set.wait()
|
||||
return entry.subscriber
|
||||
|
||||
|
||||
@tractor.context
|
||||
async def _add_sub_channel(
|
||||
ctx: tractor.Context,
|
||||
topic: str,
|
||||
token: RBToken
|
||||
):
|
||||
subscriber = get_subscriber()
|
||||
await ctx.started()
|
||||
subscriber = await wait_subscriber(topic)
|
||||
await subscriber.add_channel(token)
|
||||
|
||||
|
||||
@tractor.context
|
||||
async def _remove_sub_channel(
|
||||
ctx: tractor.Context,
|
||||
topic: str,
|
||||
ring_name: str
|
||||
):
|
||||
subscriber = get_subscriber()
|
||||
await ctx.started()
|
||||
subscriber = await wait_subscriber(topic)
|
||||
maybe_token = fdshare.maybe_get_fds(ring_name)
|
||||
if maybe_token:
|
||||
await subscriber.remove_channel(ring_name)
|
||||
|
@ -681,41 +728,22 @@ async def _remove_sub_channel(
|
|||
async def open_sub_channel_at(
|
||||
actor_name: str,
|
||||
token: RBToken,
|
||||
cleanup: bool = True,
|
||||
topic: str = 'default',
|
||||
):
|
||||
async with (
|
||||
tractor.find_actor(actor_name) as portal,
|
||||
async with tractor.find_actor(actor_name) as portal:
|
||||
await portal.run(_add_sub_channel, topic=topic, token=token)
|
||||
try:
|
||||
yield
|
||||
|
||||
portal.open_context(
|
||||
_add_sub_channel,
|
||||
token=token
|
||||
) as (ctx, _)
|
||||
):
|
||||
...
|
||||
|
||||
try:
|
||||
yield
|
||||
|
||||
except trio.Cancelled:
|
||||
log.warning(
|
||||
'open_sub_channel_at got cancelled!\n'
|
||||
f'\tactor_name = {actor_name}\n'
|
||||
f'\ttoken = {token}\n'
|
||||
)
|
||||
raise
|
||||
|
||||
finally:
|
||||
if not cleanup:
|
||||
return
|
||||
|
||||
async with tractor.find_actor(actor_name) as portal:
|
||||
if portal:
|
||||
async with portal.open_context(
|
||||
_remove_sub_channel,
|
||||
ring_name=token.shm_name
|
||||
) as (ctx, _):
|
||||
...
|
||||
except trio.Cancelled:
|
||||
log.warning(
|
||||
'open_sub_channel_at got cancelled!\n'
|
||||
f'\tactor_name = {actor_name}\n'
|
||||
f'\ttoken = {token}\n'
|
||||
)
|
||||
raise
|
||||
|
||||
await portal.run(_remove_sub_channel, topic=topic, ring_name=token.shm_name)
|
||||
|
||||
|
||||
'''
|
||||
|
@ -725,12 +753,17 @@ High level helpers to open publisher & subscriber
|
|||
|
||||
@acm
|
||||
async def open_ringbuf_publisher(
|
||||
# name to distinguish this publisher
|
||||
topic: str = 'default',
|
||||
|
||||
# global batch size for channels
|
||||
batch_size: int = 1,
|
||||
|
||||
# messages before changing output channel
|
||||
msgs_per_turn: int = 1,
|
||||
|
||||
encoder: Encoder | None = None,
|
||||
|
||||
# ensure subscriber receives in same order publisher sent
|
||||
# causes it to use wrapped payloads which contain the og
|
||||
# index
|
||||
|
@ -750,26 +783,28 @@ async def open_ringbuf_publisher(
|
|||
trio.open_nursery(strict_exception_groups=False) as n,
|
||||
RingBufferPublisher(
|
||||
n,
|
||||
batch_size=batch_size
|
||||
batch_size=batch_size,
|
||||
encoder=encoder,
|
||||
) as publisher
|
||||
):
|
||||
if guarantee_order:
|
||||
order_send_channel(publisher)
|
||||
|
||||
if set_module_var:
|
||||
set_publisher(publisher)
|
||||
set_publisher(topic, publisher)
|
||||
|
||||
try:
|
||||
yield publisher
|
||||
yield publisher
|
||||
|
||||
except trio.Cancelled:
|
||||
with trio.CancelScope(shield=True):
|
||||
await publisher.aclose()
|
||||
raise
|
||||
n.cancel_scope.cancel()
|
||||
|
||||
|
||||
@acm
|
||||
async def open_ringbuf_subscriber(
|
||||
# name to distinguish this subscriber
|
||||
topic: str = 'default',
|
||||
|
||||
decoder: Decoder | None = None,
|
||||
|
||||
# expect indexed payloads and unwrap them in order
|
||||
guarantee_order: bool = False,
|
||||
|
||||
|
@ -784,7 +819,7 @@ async def open_ringbuf_subscriber(
|
|||
'''
|
||||
async with (
|
||||
trio.open_nursery(strict_exception_groups=False) as n,
|
||||
RingBufferSubscriber(n) as subscriber
|
||||
RingBufferSubscriber(n, decoder=decoder) as subscriber
|
||||
):
|
||||
# maybe monkey patch `.receive` to use indexed payloads
|
||||
if guarantee_order:
|
||||
|
@ -792,13 +827,8 @@ async def open_ringbuf_subscriber(
|
|||
|
||||
# maybe set global module var for remote actor channel updates
|
||||
if set_module_var:
|
||||
global _subscriber
|
||||
set_subscriber(subscriber)
|
||||
set_subscriber(topic, subscriber)
|
||||
|
||||
try:
|
||||
yield subscriber
|
||||
yield subscriber
|
||||
|
||||
except trio.Cancelled:
|
||||
with trio.CancelScope(shield=True):
|
||||
await subscriber.aclose()
|
||||
raise
|
||||
n.cancel_scope.cancel()
|
||||
|
|
Loading…
Reference in New Issue