Pubsub topics, enc & decoders
Implicit aclose on all channels on ChannelManager aclose Implicit nursery cancel on pubsub acms Use long running actor portal for open_{pub,sub}_channel_at fns Add optional encoder/decoder on pubsub Add topic system for multiple pub or sub on same actor Add wait fn for sub and pub channel registerone_ring_to_rule_them_all
parent
8799cf3b78
commit
a553446619
|
@ -31,8 +31,14 @@ from dataclasses import dataclass
|
||||||
import trio
|
import trio
|
||||||
import tractor
|
import tractor
|
||||||
|
|
||||||
|
from msgspec.msgpack import (
|
||||||
|
Encoder,
|
||||||
|
Decoder
|
||||||
|
)
|
||||||
|
|
||||||
from tractor.ipc._ringbuf import (
|
from tractor.ipc._ringbuf import (
|
||||||
RBToken,
|
RBToken,
|
||||||
|
PayloadT,
|
||||||
RingBufferSendChannel,
|
RingBufferSendChannel,
|
||||||
RingBufferReceiveChannel,
|
RingBufferReceiveChannel,
|
||||||
attach_to_ringbuf_sender,
|
attach_to_ringbuf_sender,
|
||||||
|
@ -242,6 +248,7 @@ class ChannelManager(Generic[ChannelType]):
|
||||||
if info.channel.closed:
|
if info.channel.closed:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
await info.channel.aclose()
|
||||||
await self.remove_channel(info.token.shm_name)
|
await self.remove_channel(info.token.shm_name)
|
||||||
|
|
||||||
self._is_closed = True
|
self._is_closed = True
|
||||||
|
@ -253,7 +260,7 @@ Ring buffer publisher & subscribe pattern mediated by `ringd` actor.
|
||||||
'''
|
'''
|
||||||
|
|
||||||
|
|
||||||
class RingBufferPublisher(trio.abc.SendChannel[bytes]):
|
class RingBufferPublisher(trio.abc.SendChannel[PayloadT]):
|
||||||
'''
|
'''
|
||||||
Use ChannelManager to create a multi ringbuf round robin sender that can
|
Use ChannelManager to create a multi ringbuf round robin sender that can
|
||||||
dynamically add or remove more outputs.
|
dynamically add or remove more outputs.
|
||||||
|
@ -270,13 +277,16 @@ class RingBufferPublisher(trio.abc.SendChannel[bytes]):
|
||||||
msgs_per_turn: int = 1,
|
msgs_per_turn: int = 1,
|
||||||
|
|
||||||
# global batch size for all channels
|
# global batch size for all channels
|
||||||
batch_size: int = 1
|
batch_size: int = 1,
|
||||||
|
|
||||||
|
encoder: Encoder | None = None
|
||||||
):
|
):
|
||||||
self._batch_size: int = batch_size
|
self._batch_size: int = batch_size
|
||||||
self.msgs_per_turn = msgs_per_turn
|
self.msgs_per_turn = msgs_per_turn
|
||||||
|
self._enc = encoder
|
||||||
|
|
||||||
# helper to manage acms + long running tasks
|
# helper to manage acms + long running tasks
|
||||||
self._chanmngr = ChannelManager[RingBufferSendChannel](
|
self._chanmngr = ChannelManager[RingBufferSendChannel[PayloadT]](
|
||||||
n,
|
n,
|
||||||
self._open_channel,
|
self._open_channel,
|
||||||
self._channel_task
|
self._channel_task
|
||||||
|
@ -349,10 +359,11 @@ class RingBufferPublisher(trio.abc.SendChannel[bytes]):
|
||||||
self,
|
self,
|
||||||
token: RBToken
|
token: RBToken
|
||||||
|
|
||||||
) -> AsyncContextManager[RingBufferSendChannel]:
|
) -> AsyncContextManager[RingBufferSendChannel[PayloadT]]:
|
||||||
async with attach_to_ringbuf_sender(
|
async with attach_to_ringbuf_sender(
|
||||||
token,
|
token,
|
||||||
batch_size=self._batch_size
|
batch_size=self._batch_size,
|
||||||
|
encoder=self._enc
|
||||||
) as ring:
|
) as ring:
|
||||||
yield ring
|
yield ring
|
||||||
|
|
||||||
|
@ -387,7 +398,7 @@ class RingBufferPublisher(trio.abc.SendChannel[bytes]):
|
||||||
info = self.channels[turn]
|
info = self.channels[turn]
|
||||||
await info.channel.send(msg)
|
await info.channel.send(msg)
|
||||||
|
|
||||||
async def broadcast(self, msg: bytes):
|
async def broadcast(self, msg: PayloadT):
|
||||||
'''
|
'''
|
||||||
Send a msg to all channels, if no channels connected, does nothing.
|
Send a msg to all channels, if no channels connected, does nothing.
|
||||||
'''
|
'''
|
||||||
|
@ -406,8 +417,8 @@ class RingBufferPublisher(trio.abc.SendChannel[bytes]):
|
||||||
...
|
...
|
||||||
|
|
||||||
async def __aenter__(self):
|
async def __aenter__(self):
|
||||||
self._chanmngr.open()
|
|
||||||
self._is_closed = False
|
self._is_closed = False
|
||||||
|
self._chanmngr.open()
|
||||||
return self
|
return self
|
||||||
|
|
||||||
async def aclose(self) -> None:
|
async def aclose(self) -> None:
|
||||||
|
@ -420,7 +431,7 @@ class RingBufferPublisher(trio.abc.SendChannel[bytes]):
|
||||||
self._is_closed = True
|
self._is_closed = True
|
||||||
|
|
||||||
|
|
||||||
class RingBufferSubscriber(trio.abc.ReceiveChannel[bytes]):
|
class RingBufferSubscriber(trio.abc.ReceiveChannel[PayloadT]):
|
||||||
'''
|
'''
|
||||||
Use ChannelManager to create a multi ringbuf receiver that can
|
Use ChannelManager to create a multi ringbuf receiver that can
|
||||||
dynamically add or remove more inputs and combine all into a single output.
|
dynamically add or remove more inputs and combine all into a single output.
|
||||||
|
@ -440,11 +451,10 @@ class RingBufferSubscriber(trio.abc.ReceiveChannel[bytes]):
|
||||||
self,
|
self,
|
||||||
n: trio.Nursery,
|
n: trio.Nursery,
|
||||||
|
|
||||||
# if connecting to a publisher that has already sent messages set
|
decoder: Decoder | None = None
|
||||||
# to the next expected payload index this subscriber will receive
|
|
||||||
start_index: int = 0
|
|
||||||
):
|
):
|
||||||
self._chanmngr = ChannelManager[RingBufferReceiveChannel](
|
self._dec = decoder
|
||||||
|
self._chanmngr = ChannelManager[RingBufferReceiveChannel[PayloadT]](
|
||||||
n,
|
n,
|
||||||
self._open_channel,
|
self._open_channel,
|
||||||
self._channel_task
|
self._channel_task
|
||||||
|
@ -483,7 +493,10 @@ class RingBufferSubscriber(trio.abc.ReceiveChannel[bytes]):
|
||||||
token: RBToken
|
token: RBToken
|
||||||
|
|
||||||
) -> AsyncContextManager[RingBufferSendChannel]:
|
) -> AsyncContextManager[RingBufferSendChannel]:
|
||||||
async with attach_to_ringbuf_receiver(token) as ring:
|
async with attach_to_ringbuf_receiver(
|
||||||
|
token,
|
||||||
|
decoder=self._dec
|
||||||
|
) as ring:
|
||||||
yield ring
|
yield ring
|
||||||
|
|
||||||
async def _channel_task(self, info: ChannelInfo) -> None:
|
async def _channel_task(self, info: ChannelInfo) -> None:
|
||||||
|
@ -509,7 +522,7 @@ class RingBufferSubscriber(trio.abc.ReceiveChannel[bytes]):
|
||||||
except trio.ClosedResourceError:
|
except trio.ClosedResourceError:
|
||||||
break
|
break
|
||||||
|
|
||||||
async def receive(self) -> bytes:
|
async def receive(self) -> PayloadT:
|
||||||
'''
|
'''
|
||||||
Receive next in order msg
|
Receive next in order msg
|
||||||
'''
|
'''
|
||||||
|
@ -543,73 +556,74 @@ Actor module for managing publisher & subscriber channels remotely through
|
||||||
`tractor.context` rpc
|
`tractor.context` rpc
|
||||||
'''
|
'''
|
||||||
|
|
||||||
_publisher: RingBufferPublisher | None = None
|
@dataclass
|
||||||
_subscriber: RingBufferSubscriber | None = None
|
class PublisherEntry:
|
||||||
|
publisher: RingBufferPublisher | None = None
|
||||||
|
is_set: trio.Event = trio.Event()
|
||||||
|
|
||||||
|
|
||||||
def set_publisher(pub: RingBufferPublisher):
|
_publishers: dict[str, PublisherEntry] = {}
|
||||||
global _publisher
|
|
||||||
|
|
||||||
if _publisher:
|
|
||||||
|
def maybe_init_publisher(topic: str) -> PublisherEntry:
|
||||||
|
entry = _publishers.get(topic, None)
|
||||||
|
if not entry:
|
||||||
|
entry = PublisherEntry()
|
||||||
|
_publishers[topic] = entry
|
||||||
|
|
||||||
|
return entry
|
||||||
|
|
||||||
|
|
||||||
|
def set_publisher(topic: str, pub: RingBufferPublisher):
|
||||||
|
global _publishers
|
||||||
|
|
||||||
|
entry = _publishers.get(topic, None)
|
||||||
|
if not entry:
|
||||||
|
entry = maybe_init_publisher(topic)
|
||||||
|
|
||||||
|
if entry.publisher:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
f'publisher already set on {tractor.current_actor()}'
|
f'publisher for topic {topic} already set on {tractor.current_actor()}'
|
||||||
)
|
)
|
||||||
|
|
||||||
_publisher = pub
|
entry.publisher = pub
|
||||||
|
entry.is_set.set()
|
||||||
|
|
||||||
|
|
||||||
def set_subscriber(sub: RingBufferSubscriber):
|
def get_publisher(topic: str) -> RingBufferPublisher:
|
||||||
global _subscriber
|
entry = _publishers.get(topic, None)
|
||||||
|
if not entry or not entry.publisher:
|
||||||
if _subscriber:
|
|
||||||
raise RuntimeError(
|
|
||||||
f'subscriber already set on {tractor.current_actor()}'
|
|
||||||
)
|
|
||||||
|
|
||||||
_subscriber = sub
|
|
||||||
|
|
||||||
|
|
||||||
def get_publisher() -> RingBufferPublisher:
|
|
||||||
global _publisher
|
|
||||||
|
|
||||||
if not _publisher:
|
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
f'{tractor.current_actor()} tried to get publisher'
|
f'{tractor.current_actor()} tried to get publisher'
|
||||||
'but it\'s not set'
|
'but it\'s not set'
|
||||||
)
|
)
|
||||||
|
|
||||||
return _publisher
|
return entry.publisher
|
||||||
|
|
||||||
|
|
||||||
def get_subscriber() -> RingBufferSubscriber:
|
async def wait_publisher(topic: str) -> RingBufferPublisher:
|
||||||
global _subscriber
|
entry = maybe_init_publisher(topic)
|
||||||
|
await entry.is_set.wait()
|
||||||
if not _subscriber:
|
return entry.publisher
|
||||||
raise RuntimeError(
|
|
||||||
f'{tractor.current_actor()} tried to get subscriber'
|
|
||||||
'but it\'s not set'
|
|
||||||
)
|
|
||||||
|
|
||||||
return _subscriber
|
|
||||||
|
|
||||||
|
|
||||||
@tractor.context
|
@tractor.context
|
||||||
async def _add_pub_channel(
|
async def _add_pub_channel(
|
||||||
ctx: tractor.Context,
|
ctx: tractor.Context,
|
||||||
|
topic: str,
|
||||||
token: RBToken
|
token: RBToken
|
||||||
):
|
):
|
||||||
publisher = get_publisher()
|
publisher = await wait_publisher(topic)
|
||||||
await ctx.started()
|
|
||||||
await publisher.add_channel(token)
|
await publisher.add_channel(token)
|
||||||
|
|
||||||
|
|
||||||
@tractor.context
|
@tractor.context
|
||||||
async def _remove_pub_channel(
|
async def _remove_pub_channel(
|
||||||
ctx: tractor.Context,
|
ctx: tractor.Context,
|
||||||
|
topic: str,
|
||||||
ring_name: str
|
ring_name: str
|
||||||
):
|
):
|
||||||
publisher = get_publisher()
|
publisher = await wait_publisher(topic)
|
||||||
await ctx.started()
|
|
||||||
maybe_token = fdshare.maybe_get_fds(ring_name)
|
maybe_token = fdshare.maybe_get_fds(ring_name)
|
||||||
if maybe_token:
|
if maybe_token:
|
||||||
await publisher.remove_channel(ring_name)
|
await publisher.remove_channel(ring_name)
|
||||||
|
@ -619,18 +633,10 @@ async def _remove_pub_channel(
|
||||||
async def open_pub_channel_at(
|
async def open_pub_channel_at(
|
||||||
actor_name: str,
|
actor_name: str,
|
||||||
token: RBToken,
|
token: RBToken,
|
||||||
cleanup: bool = True,
|
topic: str = 'default',
|
||||||
):
|
):
|
||||||
async with (
|
async with tractor.find_actor(actor_name) as portal:
|
||||||
tractor.find_actor(actor_name) as portal,
|
await portal.run(_add_pub_channel, topic=topic, token=token)
|
||||||
|
|
||||||
portal.open_context(
|
|
||||||
_add_pub_channel,
|
|
||||||
token=token
|
|
||||||
) as (ctx, _)
|
|
||||||
):
|
|
||||||
...
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
yield
|
yield
|
||||||
|
|
||||||
|
@ -642,36 +648,77 @@ async def open_pub_channel_at(
|
||||||
)
|
)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
finally:
|
await portal.run(_remove_pub_channel, topic=topic, ring_name=token.shm_name)
|
||||||
if not cleanup:
|
|
||||||
return
|
|
||||||
|
|
||||||
async with tractor.find_actor(actor_name) as portal:
|
|
||||||
if portal:
|
@dataclass
|
||||||
async with portal.open_context(
|
class SubscriberEntry:
|
||||||
_remove_pub_channel,
|
subscriber: RingBufferSubscriber | None = None
|
||||||
ring_name=token.shm_name
|
is_set: trio.Event = trio.Event()
|
||||||
) as (ctx, _):
|
|
||||||
...
|
|
||||||
|
_subscribers: dict[str, SubscriberEntry] = {}
|
||||||
|
|
||||||
|
|
||||||
|
def maybe_init_subscriber(topic: str) -> SubscriberEntry:
|
||||||
|
entry = _subscribers.get(topic, None)
|
||||||
|
if not entry:
|
||||||
|
entry = SubscriberEntry()
|
||||||
|
_subscribers[topic] = entry
|
||||||
|
|
||||||
|
return entry
|
||||||
|
|
||||||
|
|
||||||
|
def set_subscriber(topic: str, sub: RingBufferSubscriber):
|
||||||
|
global _subscribers
|
||||||
|
|
||||||
|
entry = _subscribers.get(topic, None)
|
||||||
|
if not entry:
|
||||||
|
entry = maybe_init_subscriber(topic)
|
||||||
|
|
||||||
|
if entry.subscriber:
|
||||||
|
raise RuntimeError(
|
||||||
|
f'subscriber for topic {topic} already set on {tractor.current_actor()}'
|
||||||
|
)
|
||||||
|
|
||||||
|
entry.subscriber = sub
|
||||||
|
entry.is_set.set()
|
||||||
|
|
||||||
|
|
||||||
|
def get_subscriber(topic: str) -> RingBufferSubscriber:
|
||||||
|
entry = _subscribers.get(topic, None)
|
||||||
|
if not entry or not entry.subscriber:
|
||||||
|
raise RuntimeError(
|
||||||
|
f'{tractor.current_actor()} tried to get subscriber'
|
||||||
|
'but it\'s not set'
|
||||||
|
)
|
||||||
|
|
||||||
|
return entry.subscriber
|
||||||
|
|
||||||
|
|
||||||
|
async def wait_subscriber(topic: str) -> RingBufferSubscriber:
|
||||||
|
entry = maybe_init_subscriber(topic)
|
||||||
|
await entry.is_set.wait()
|
||||||
|
return entry.subscriber
|
||||||
|
|
||||||
|
|
||||||
@tractor.context
|
@tractor.context
|
||||||
async def _add_sub_channel(
|
async def _add_sub_channel(
|
||||||
ctx: tractor.Context,
|
ctx: tractor.Context,
|
||||||
|
topic: str,
|
||||||
token: RBToken
|
token: RBToken
|
||||||
):
|
):
|
||||||
subscriber = get_subscriber()
|
subscriber = await wait_subscriber(topic)
|
||||||
await ctx.started()
|
|
||||||
await subscriber.add_channel(token)
|
await subscriber.add_channel(token)
|
||||||
|
|
||||||
|
|
||||||
@tractor.context
|
@tractor.context
|
||||||
async def _remove_sub_channel(
|
async def _remove_sub_channel(
|
||||||
ctx: tractor.Context,
|
ctx: tractor.Context,
|
||||||
|
topic: str,
|
||||||
ring_name: str
|
ring_name: str
|
||||||
):
|
):
|
||||||
subscriber = get_subscriber()
|
subscriber = await wait_subscriber(topic)
|
||||||
await ctx.started()
|
|
||||||
maybe_token = fdshare.maybe_get_fds(ring_name)
|
maybe_token = fdshare.maybe_get_fds(ring_name)
|
||||||
if maybe_token:
|
if maybe_token:
|
||||||
await subscriber.remove_channel(ring_name)
|
await subscriber.remove_channel(ring_name)
|
||||||
|
@ -681,18 +728,10 @@ async def _remove_sub_channel(
|
||||||
async def open_sub_channel_at(
|
async def open_sub_channel_at(
|
||||||
actor_name: str,
|
actor_name: str,
|
||||||
token: RBToken,
|
token: RBToken,
|
||||||
cleanup: bool = True,
|
topic: str = 'default',
|
||||||
):
|
):
|
||||||
async with (
|
async with tractor.find_actor(actor_name) as portal:
|
||||||
tractor.find_actor(actor_name) as portal,
|
await portal.run(_add_sub_channel, topic=topic, token=token)
|
||||||
|
|
||||||
portal.open_context(
|
|
||||||
_add_sub_channel,
|
|
||||||
token=token
|
|
||||||
) as (ctx, _)
|
|
||||||
):
|
|
||||||
...
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
yield
|
yield
|
||||||
|
|
||||||
|
@ -704,18 +743,7 @@ async def open_sub_channel_at(
|
||||||
)
|
)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
finally:
|
await portal.run(_remove_sub_channel, topic=topic, ring_name=token.shm_name)
|
||||||
if not cleanup:
|
|
||||||
return
|
|
||||||
|
|
||||||
async with tractor.find_actor(actor_name) as portal:
|
|
||||||
if portal:
|
|
||||||
async with portal.open_context(
|
|
||||||
_remove_sub_channel,
|
|
||||||
ring_name=token.shm_name
|
|
||||||
) as (ctx, _):
|
|
||||||
...
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
'''
|
'''
|
||||||
|
@ -725,12 +753,17 @@ High level helpers to open publisher & subscriber
|
||||||
|
|
||||||
@acm
|
@acm
|
||||||
async def open_ringbuf_publisher(
|
async def open_ringbuf_publisher(
|
||||||
|
# name to distinguish this publisher
|
||||||
|
topic: str = 'default',
|
||||||
|
|
||||||
# global batch size for channels
|
# global batch size for channels
|
||||||
batch_size: int = 1,
|
batch_size: int = 1,
|
||||||
|
|
||||||
# messages before changing output channel
|
# messages before changing output channel
|
||||||
msgs_per_turn: int = 1,
|
msgs_per_turn: int = 1,
|
||||||
|
|
||||||
|
encoder: Encoder | None = None,
|
||||||
|
|
||||||
# ensure subscriber receives in same order publisher sent
|
# ensure subscriber receives in same order publisher sent
|
||||||
# causes it to use wrapped payloads which contain the og
|
# causes it to use wrapped payloads which contain the og
|
||||||
# index
|
# index
|
||||||
|
@ -750,26 +783,28 @@ async def open_ringbuf_publisher(
|
||||||
trio.open_nursery(strict_exception_groups=False) as n,
|
trio.open_nursery(strict_exception_groups=False) as n,
|
||||||
RingBufferPublisher(
|
RingBufferPublisher(
|
||||||
n,
|
n,
|
||||||
batch_size=batch_size
|
batch_size=batch_size,
|
||||||
|
encoder=encoder,
|
||||||
) as publisher
|
) as publisher
|
||||||
):
|
):
|
||||||
if guarantee_order:
|
if guarantee_order:
|
||||||
order_send_channel(publisher)
|
order_send_channel(publisher)
|
||||||
|
|
||||||
if set_module_var:
|
if set_module_var:
|
||||||
set_publisher(publisher)
|
set_publisher(topic, publisher)
|
||||||
|
|
||||||
try:
|
|
||||||
yield publisher
|
yield publisher
|
||||||
|
|
||||||
except trio.Cancelled:
|
n.cancel_scope.cancel()
|
||||||
with trio.CancelScope(shield=True):
|
|
||||||
await publisher.aclose()
|
|
||||||
raise
|
|
||||||
|
|
||||||
|
|
||||||
@acm
|
@acm
|
||||||
async def open_ringbuf_subscriber(
|
async def open_ringbuf_subscriber(
|
||||||
|
# name to distinguish this subscriber
|
||||||
|
topic: str = 'default',
|
||||||
|
|
||||||
|
decoder: Decoder | None = None,
|
||||||
|
|
||||||
# expect indexed payloads and unwrap them in order
|
# expect indexed payloads and unwrap them in order
|
||||||
guarantee_order: bool = False,
|
guarantee_order: bool = False,
|
||||||
|
|
||||||
|
@ -784,7 +819,7 @@ async def open_ringbuf_subscriber(
|
||||||
'''
|
'''
|
||||||
async with (
|
async with (
|
||||||
trio.open_nursery(strict_exception_groups=False) as n,
|
trio.open_nursery(strict_exception_groups=False) as n,
|
||||||
RingBufferSubscriber(n) as subscriber
|
RingBufferSubscriber(n, decoder=decoder) as subscriber
|
||||||
):
|
):
|
||||||
# maybe monkey patch `.receive` to use indexed payloads
|
# maybe monkey patch `.receive` to use indexed payloads
|
||||||
if guarantee_order:
|
if guarantee_order:
|
||||||
|
@ -792,13 +827,8 @@ async def open_ringbuf_subscriber(
|
||||||
|
|
||||||
# maybe set global module var for remote actor channel updates
|
# maybe set global module var for remote actor channel updates
|
||||||
if set_module_var:
|
if set_module_var:
|
||||||
global _subscriber
|
set_subscriber(topic, subscriber)
|
||||||
set_subscriber(subscriber)
|
|
||||||
|
|
||||||
try:
|
|
||||||
yield subscriber
|
yield subscriber
|
||||||
|
|
||||||
except trio.Cancelled:
|
n.cancel_scope.cancel()
|
||||||
with trio.CancelScope(shield=True):
|
|
||||||
await subscriber.aclose()
|
|
||||||
raise
|
|
||||||
|
|
Loading…
Reference in New Issue