Better APIs for ringd and pubsub
Pubsub:
- Remove unnecessary ChannelManager locking mechanism
- Make ChannelManager.close wait for all channel removals
- Make publisher turn switching configurable with the `msgs_per_turn` variable
- Fix batch_size setter on publisher
- Add broadcast to publisher
- Add endpoints on pubsub for remote actors to dynamically add and remove channels

Ringd:
- Add FIFO lock and use it on methods that modify _rings state
- Add comments
- Break up ringd.open_ringbuf APIs into attach_, open_ & maybe_open_
- Attaching is no longer a long-running context; only opens are
- Adapt ringd test to new APIs
branch one_ring_to_rule_them_all
parent e4868ded54
commit 5d6fa643ba
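Before the diff proper, a minimal sketch of the new client-side split this commit introduces (the `ringd` import path below is assumed; the diff only shows the module being used as `ringd`):

import tractor.ipc._ringbuf._ringd as ringd  # hypothetical import path


async def creator():
    # create a new ring; the acm keeps a context with the ringd actor
    # open, so the ring lives exactly as long as this block
    async with ringd.open_ringbuf('my-ring', buf_size=10 * 1024) as token:
        ...


async def attacher():
    # attach to an existing ring: per this commit no longer a
    # long-running context, it just validates + fetches the fds
    async with ringd.attach_ringbuf('my-ring') as token:
        ...


async def either():
    # attach if the ring exists, create it otherwise
    async with ringd.maybe_open_ringbuf('my-ring') as token:
        ...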
@@ -1,5 +1,3 @@
-from contextlib import asynccontextmanager as acm
-
 import trio
 import tractor
 import msgspec
@@ -40,7 +38,7 @@ async def send_child(
     ring_name: str
 ):
     async with (
-        ringd.open_ringbuf(ring_name) as token,
+        ringd.attach_ringbuf(ring_name) as token,

         attach_to_ringbuf_sender(token) as chan,
     ):
@@ -63,9 +61,7 @@ def test_ringd():
     async with (
         tractor.open_nursery() as an,

-        ringd.open_ringd(
-            loglevel='info'
-        )
+        ringd.open_ringd()
     ):
         recv_portal = await an.start_actor(
             'recv',
@@ -133,10 +129,10 @@ async def subscriber_child(ctx: tractor.Context):
         msg = msgspec.msgpack.decode(msg, type=ControlMessages)
         match msg:
             case AddChannelMsg():
-                await subs.add_channel(msg.name, must_exist=False)
+                await subs.add_channel(msg.name)

             case RemoveChannelMsg():
-                await subs.remove_channel(msg.name)
+                subs.remove_channel(msg.name)

             case RangeMsg():
                 range_msg = msg
@@ -171,7 +167,7 @@ async def subscriber_child(ctx: tractor.Context):
 async def publisher_child(ctx: tractor.Context):
     await ctx.started()
     async with (
-        open_ringbuf_publisher(batch_size=100, guarantee_order=True) as pub,
+        open_ringbuf_publisher(guarantee_order=True) as pub,
         ctx.open_stream() as stream
     ):
         async for msg in stream:
@@ -181,7 +177,7 @@ async def publisher_child(ctx: tractor.Context):
                 await pub.add_channel(msg.name, must_exist=True)

             case RemoveChannelMsg():
-                await pub.remove_channel(msg.name)
+                pub.remove_channel(msg.name)

             case RangeMsg():
                 for i in range(msg.size):
@@ -258,11 +254,6 @@ def test_pubsub():
         await send_range(100)
         await remove_channel(ring_name)

-        # try using same ring name
-        await add_channel(ring_name)
-        await send_range(100)
-        await remove_channel(ring_name)
-
         # multi chan test
         ring_names = []
         for i in range(3):
@@ -86,19 +86,14 @@ class ChannelManager(Generic[ChannelType]):
         # store channel runtime variables
         self._channels: list[ChannelInfo] = []

-        # methods that modify self._channels should be ordered by FIFO
-        self._lock = trio.StrictFIFOLock()
-
         self._is_closed: bool = True

+        self._teardown = trio.Event()
+
     @property
     def closed(self) -> bool:
         return self._is_closed

-    @property
-    def lock(self) -> trio.StrictFIFOLock:
-        return self._lock
-
     @property
     def channels(self) -> list[ChannelInfo]:
         return self._channels
@@ -106,8 +101,8 @@ class ChannelManager(Generic[ChannelType]):
     async def _channel_handler_task(
         self,
         name: str,
-        task_status: trio.TASK_STATUS_IGNORED,
-        **kwargs
+        must_exist: bool = False,
+        task_status=trio.TASK_STATUS_IGNORED,
     ):
         '''
         Open channel resources, add to internal data structures, signal channel
@@ -119,7 +114,7 @@ class ChannelManager(Generic[ChannelType]):

         kwargs are proxied to `self._open_channel` acm.
         '''
-        async with self._open_channel(name, **kwargs) as chan:
+        async with self._open_channel(name, must_exist=must_exist) as chan:
             cancel_scope = trio.CancelScope()
             info = ChannelInfo(
                 name=name,
@@ -138,6 +133,9 @@ class ChannelManager(Generic[ChannelType]):

         self._maybe_destroy_channel(name)

+        if len(self) == 0:
+            self._teardown.set()
+
     def _find_channel(self, name: str) -> tuple[int, ChannelInfo] | None:
         '''
         Given a channel name maybe return its index and value from
@@ -165,7 +163,7 @@ class ChannelManager(Generic[ChannelType]):
         info.cancel_scope.cancel()
         del self._channels[i]

-    async def add_channel(self, name: str, **kwargs):
+    async def add_channel(self, name: str, must_exist: bool = False):
         '''
         Add a new channel to be handled

@@ -173,14 +171,13 @@ class ChannelManager(Generic[ChannelType]):
         if self.closed:
             raise trio.ClosedResourceError

-        async with self._lock:
-            await self._n.start(partial(
-                self._channel_handler_task,
-                name,
-                **kwargs
-            ))
+        await self._n.start(partial(
+            self._channel_handler_task,
+            name,
+            must_exist=must_exist
+        ))

-    async def remove_channel(self, name: str):
+    def remove_channel(self, name: str):
         '''
         Remove a channel and stop its handling

@@ -188,12 +185,11 @@ class ChannelManager(Generic[ChannelType]):
         if self.closed:
             raise trio.ClosedResourceError

-        async with self._lock:
-            self._maybe_destroy_channel(name)
-
-            # if that was last channel reset connect event
-            if len(self) == 0:
-                self._connect_event = trio.Event()
+        self._maybe_destroy_channel(name)
+
+        # if that was last channel reset connect event
+        if len(self) == 0:
+            self._connect_event = trio.Event()

     async def wait_for_channel(self):
         '''
@@ -226,7 +222,18 @@ class ChannelManager(Generic[ChannelType]):
             return

         for info in self._channels:
-            await self.remove_channel(info.name)
+            if info.channel.closed:
+                continue
+
+            self.remove_channel(info.name)
+
+        try:
+            await self._teardown.wait()
+
+        except trio.Cancelled:
+            # log.exception('close was cancelled')
+            raise

         self._is_closed = True
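In prose: `close()` now cancels each channel's handler task and then parks on a `trio.Event` that the last exiting handler sets. A standalone sketch of that pattern (simplified names, not the actual `ChannelManager`):

import trio


class Manager:
    def __init__(self):
        self._handlers: list[trio.CancelScope] = []
        self._teardown = trio.Event()

    async def handler_task(self, task_status=trio.TASK_STATUS_IGNORED):
        with trio.CancelScope() as cs:
            self._handlers.append(cs)
            task_status.started()
            try:
                await trio.sleep_forever()
            finally:
                # teardown bookkeeping lives in one place: handler exit
                self._handlers.remove(cs)
                if not self._handlers:
                    # last handler gone -> unblock close()
                    self._teardown.set()

    async def close(self):
        if not self._handlers:
            return
        for cs in list(self._handlers):
            cs.cancel()
        # wait until every handler has actually finished its cleanup
        await self._teardown.wait()


async def main():
    async with trio.open_nursery() as n:
        mngr = Manager()
        await n.start(mngr.handler_task)
        await mngr.close()

trio.run(main)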
@@ -236,12 +243,6 @@ Ring buffer publisher & subscribe pattern mediated by `ringd` actor.

 '''

-@dataclass
-class PublisherChannels:
-    ring: RingBufferSendChannel
-    schan: trio.MemorySendChannel
-    rchan: trio.MemoryReceiveChannel
-
-
 class RingBufferPublisher(trio.abc.SendChannel[bytes]):
     '''
@@ -259,24 +260,32 @@ class RingBufferPublisher(trio.abc.SendChannel[bytes]):
         # new ringbufs created will have this buf_size
         buf_size: int = 10 * 1024,

+        # amount of msgs to each ring before switching turns
+        msgs_per_turn: int = 1,
+
         # global batch size for all channels
         batch_size: int = 1
     ):
         self._buf_size = buf_size
         self._batch_size: int = batch_size
+        self.msgs_per_turn = msgs_per_turn

-        self._chanmngr = ChannelManager[PublisherChannels](
+        # helper to manage acms + long running tasks
+        self._chanmngr = ChannelManager[RingBufferSendChannel](
             n,
             self._open_channel,
             self._channel_task
         )

-        # methods that send data over the channels need to be acquire send lock
-        # in order to guarantee order of operations
+        # ensure no concurrent `.send()` calls
         self._send_lock = trio.StrictFIFOLock()

         # index of channel to be used for next send
         self._next_turn: int = 0

+        # amount of messages sent this turn
+        self._turn_msgs: int = 0
+
+        # have we closed this publisher?
+        # set to `False` on `.__aenter__()`
         self._is_closed: bool = True

     @property
@@ -288,14 +297,31 @@ class RingBufferPublisher(trio.abc.SendChannel[bytes]):
         return self._batch_size

     @batch_size.setter
-    def set_batch_size(self, value: int) -> None:
+    def batch_size(self, value: int) -> None:
         for info in self.channels:
-            info.channel.ring.batch_size = value
+            info.channel.batch_size = value

     @property
     def channels(self) -> list[ChannelInfo]:
         return self._chanmngr.channels

+    def _get_next_turn(self) -> int:
+        '''
+        Maybe switch turn and reset self._turn_msgs or just increment it.
+        Return current turn
+        '''
+        if self._turn_msgs == self.msgs_per_turn:
+            self._turn_msgs = 0
+            self._next_turn += 1
+
+            if self._next_turn >= len(self.channels):
+                self._next_turn = 0
+
+        else:
+            self._turn_msgs += 1
+
+        return self._next_turn
+
     def get_channel(self, name: str) -> ChannelInfo:
         '''
         Get underlying ChannelInfo from name
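Two notes on this hunk. First, the setter rename is the actual "fix batch_size setter" item: a property setter must reuse the property's name, so the old `def set_batch_size` under `@batch_size.setter` bound the new property to the wrong name and left `batch_size` itself without a setter. Second, `_get_next_turn()` is the heart of `msgs_per_turn`; tracing it standalone makes the rotation (and one quirk) visible. This is a copy of the diff's branch structure with the channel list swapped for a plain count, for illustration only:

class Turns:
    # same logic as `_get_next_turn` above, with `len(self.channels)`
    # replaced by a fixed `n_channels`
    def __init__(self, msgs_per_turn: int, n_channels: int):
        self.msgs_per_turn = msgs_per_turn
        self.n_channels = n_channels
        self._next_turn = 0
        self._turn_msgs = 0

    def get_next_turn(self) -> int:
        if self._turn_msgs == self.msgs_per_turn:
            self._turn_msgs = 0
            self._next_turn += 1

            if self._next_turn >= self.n_channels:
                self._next_turn = 0

        else:
            self._turn_msgs += 1

        return self._next_turn


t = Turns(msgs_per_turn=2, n_channels=3)
print([t.get_next_turn() for _ in range(9)])
# -> [0, 0, 1, 1, 1, 2, 2, 2, 0]

Note the call that performs the switch is not itself counted, so in steady state each turn serves `msgs_per_turn + 1` messages; only the very first turn serves exactly `msgs_per_turn`.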
@@ -310,8 +336,8 @@ class RingBufferPublisher(trio.abc.SendChannel[bytes]):
     ):
         await self._chanmngr.add_channel(name, must_exist=must_exist)

-    async def remove_channel(self, name: str):
-        await self._chanmngr.remove_channel(name)
+    def remove_channel(self, name: str):
+        self._chanmngr.remove_channel(name)

     @acm
     async def _open_channel(
@@ -320,41 +346,45 @@ class RingBufferPublisher(trio.abc.SendChannel[bytes]):
         name: str,
         must_exist: bool = False

-    ) -> AsyncContextManager[PublisherChannels]:
+    ) -> AsyncContextManager[RingBufferSendChannel]:
         '''
         Open a ringbuf through `ringd` and attach as send side
         '''
-        async with (
-            ringd.open_ringbuf(
-                name=name,
-                buf_size=self._buf_size,
-                must_exist=must_exist,
-            ) as token,
-            attach_to_ringbuf_sender(token) as ring,
-        ):
-            schan, rchan = trio.open_memory_channel(0)
-            yield PublisherChannels(
-                ring=ring,
-                schan=schan,
-                rchan=rchan
-            )
-            try:
-                while True:
-                    msg = rchan.receive_nowait()
-                    await ring.send(msg)
-
-            except trio.WouldBlock:
-                ...
+        if must_exist:
+            ringd_fn = ringd.attach_ringbuf
+            kwargs = {}
+
+        else:
+            ringd_fn = ringd.open_ringbuf
+            kwargs = {'buf_size': self._buf_size}
+
+        async with (
+            ringd_fn(
+                name=name,
+                **kwargs
+            ) as token,
+
+            attach_to_ringbuf_sender(
+                token,
+                batch_size=self._batch_size
+            ) as ring,
+        ):
+            yield ring
+            # try:
+            #     # ensure all messages are sent
+            #     await ring.flush()
+
+            # except Exception as e:
+            #     e.add_note(f'while closing ringbuf send channel {name}')
+            #     log.exception(e)

     async def _channel_task(self, info: ChannelInfo) -> None:
         '''
-        Forever get current runtime info for channel, wait on its next pending
-        payloads update event then drain all into send channel.
+        Wait forever until channel cancellation
+
         '''
         try:
-            async for msg in info.channel.rchan:
-                await info.channel.ring.send(msg)
+            await trio.sleep_forever()

         except trio.Cancelled:
             ...
@@ -362,11 +392,9 @@ class RingBufferPublisher(trio.abc.SendChannel[bytes]):
     async def send(self, msg: bytes):
         '''
         If no output channels connected, wait until one, then fetch the next
-        channel based on turn, add the indexed payload and update
-        `self._next_turn` & `self._next_index`.
+        channel based on turn.

-        Needs to acquire `self._send_lock` to make sure updates to turn & index
-        variables dont happen out of order.
+        Needs to acquire `self._send_lock` to ensure no concurrent calls.

         '''
         if self.closed:
@@ -380,18 +408,28 @@ class RingBufferPublisher(trio.abc.SendChannel[bytes]):
         if len(self.channels) == 0:
             await self._chanmngr.wait_for_channel()

-        if self._next_turn >= len(self.channels):
-            self._next_turn = 0
+        turn = self._get_next_turn()

-        info = self.channels[self._next_turn]
-        await info.channel.schan.send(msg)
+        info = self.channels[turn]
+        await info.channel.send(msg)

-        self._next_turn += 1
+    async def broadcast(self, msg: bytes):
+        '''
+        Send a msg to all channels, if no channels connected, does nothing.
+        '''
+        if self.closed:
+            raise trio.ClosedResourceError
+
+        for info in self.channels:
+            await info.channel.send(msg)

     async def flush(self, new_batch_size: int | None = None):
-        async with self._chanmngr.lock:
-            for info in self.channels:
-                await info.channel.ring.flush(new_batch_size=new_batch_size)
+        for info in self.channels:
+            try:
+                await info.channel.flush(new_batch_size=new_batch_size)
+
+            except trio.ClosedResourceError:
+                ...

     async def __aenter__(self):
         self._chanmngr.open()
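A usage sketch contrasting the two publisher send paths (only APIs visible in this diff; `pub` is assumed to come from `open_ringbuf_publisher(...)`):

async def fan_out(pub, payloads: list[bytes]):
    # round-robin path: each payload lands on exactly one channel,
    # rotating every `msgs_per_turn` messages
    for p in payloads:
        await pub.send(p)

    # broadcast path: every connected channel gets a copy; a no-op
    # when no channels are connected (unlike `send`, which waits
    # for a channel to appear)
    await pub.broadcast(b'done')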
@@ -403,41 +441,12 @@ class RingBufferPublisher(trio.abc.SendChannel[bytes]):
             log.warning('tried to close RingBufferPublisher but its already closed...')
             return

-        await self._chanmngr.close()
+        with trio.CancelScope(shield=True):
+            await self._chanmngr.close()

         self._is_closed = True


-@acm
-async def open_ringbuf_publisher(
-
-    buf_size: int = 10 * 1024,
-    batch_size: int = 1,
-    guarantee_order: bool = False,
-    force_cancel: bool = False
-
-) -> AsyncContextManager[RingBufferPublisher]:
-    '''
-    Open a new ringbuf publisher
-
-    '''
-    async with (
-        trio.open_nursery() as n,
-        RingBufferPublisher(
-            n,
-            buf_size=buf_size,
-            batch_size=batch_size
-        ) as publisher
-    ):
-        if guarantee_order:
-            order_send_channel(publisher)
-
-        yield publisher
-
-        if force_cancel:
-            # implicitly cancel any running channel handler task
-            n.cancel_scope.cancel()
-
-
 class RingBufferSubscriber(trio.abc.ReceiveChannel[bytes]):
     '''
     Use ChannelManager to create a multi ringbuf receiver that can
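The `CancelScope(shield=True)` wrap matters because `aclose()` may run while the surrounding scope is already cancelled, and `close()` now awaits `self._teardown.wait()`. A minimal sketch of the general pattern (not the actual class):

import trio


class Resource:
    async def aclose(self):
        # without the shield, the awaits below would raise Cancelled
        # immediately inside an already-cancelled scope, skipping cleanup
        with trio.CancelScope(shield=True):
            await trio.sleep(0)  # stand-in for real async teardown


async def main():
    r = Resource()
    with trio.CancelScope() as cs:
        cs.cancel()
        # scope is cancelled, but the shielded teardown still completes
        await r.aclose()


trio.run(main)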
@@ -458,10 +467,15 @@ class RingBufferSubscriber(trio.abc.ReceiveChannel[bytes]):
         self,
         n: trio.Nursery,

+        # new ringbufs created will have this buf_size
         buf_size: int = 10 * 1024,
+
+        # if connecting to a publisher that has already sent messages set
+        # to the next expected payload index this subscriber will receive
+        start_index: int = 0
     ):
         self._buf_size = buf_size

         self._chanmngr = ChannelManager[RingBufferReceiveChannel](
             n,
             self._open_channel,
@@ -488,8 +502,8 @@ class RingBufferSubscriber(trio.abc.ReceiveChannel[bytes]):
     async def add_channel(self, name: str, must_exist: bool = False):
         await self._chanmngr.add_channel(name, must_exist=must_exist)

-    async def remove_channel(self, name: str):
-        await self._chanmngr.remove_channel(name)
+    def remove_channel(self, name: str):
+        self._chanmngr.remove_channel(name)

     @acm
     async def _open_channel(
@@ -502,11 +516,20 @@ class RingBufferSubscriber(trio.abc.ReceiveChannel[bytes]):
         '''
         Open a ringbuf through `ringd` and attach as receiver side
         '''
+        if must_exist:
+            ringd_fn = ringd.attach_ringbuf
+            kwargs = {}
+
+        else:
+            ringd_fn = ringd.open_ringbuf
+            kwargs = {'buf_size': self._buf_size}
+
         async with (
-            ringd.open_ringbuf(
+            ringd_fn(
                 name=name,
-                must_exist=must_exist,
+                **kwargs
             ) as token,

             attach_to_ringbuf_receiver(token) as chan
         ):
             yield chan
@@ -554,7 +577,6 @@ class RingBufferSubscriber(trio.abc.ReceiveChannel[bytes]):

     async def aclose(self) -> None:
         if self.closed:
             log.warning('tried to close RingBufferSubscriber but its already closed...')
             return

         await self._chanmngr.close()
@@ -562,26 +584,241 @@ class RingBufferSubscriber(trio.abc.ReceiveChannel[bytes]):
         await self._rchan.aclose()
         self._is_closed = True


+'''
+Actor module for managing publisher & subscriber channels remotely through
+`tractor.context` rpc
+'''
+
+_publisher: RingBufferPublisher | None = None
+_subscriber: RingBufferSubscriber | None = None
+
+
+def set_publisher(pub: RingBufferPublisher):
+    global _publisher
+
+    if _publisher:
+        raise RuntimeError(
+            f'publisher already set on {tractor.current_actor()}'
+        )
+
+    _publisher = pub
+
+
+def set_subscriber(sub: RingBufferSubscriber):
+    global _subscriber
+
+    if _subscriber:
+        raise RuntimeError(
+            f'subscriber already set on {tractor.current_actor()}'
+        )
+
+    _subscriber = sub
+
+
+def get_publisher() -> RingBufferPublisher:
+    global _publisher
+
+    if not _publisher:
+        raise RuntimeError(
+            f'{tractor.current_actor()} tried to get publisher'
+            'but it\'s not set'
+        )
+
+    return _publisher
+
+
+def get_subscriber() -> RingBufferSubscriber:
+    global _subscriber
+
+    if not _subscriber:
+        raise RuntimeError(
+            f'{tractor.current_actor()} tried to get subscriber'
+            'but it\'s not set'
+        )
+
+    return _subscriber
+
+
+@tractor.context
+async def open_pub_channel(
+    ctx: tractor.Context,
+    ring_name: str,
+    must_exist: bool = False
+):
+    publisher = get_publisher()
+    await publisher.add_channel(
+        ring_name,
+        must_exist=must_exist
+    )
+
+    await ctx.started()
+
+    try:
+        await trio.sleep_forever()
+
+    finally:
+        try:
+            publisher.remove_channel(ring_name)
+
+        except trio.ClosedResourceError:
+            ...
+
+
+@acm
+async def open_pub_channel_at(
+    actor_name: str,
+    ring_name: str,
+    must_exist: bool = False
+):
+    async with (
+        tractor.find_actor(actor_name) as portal,
+
+        portal.open_context(
+            open_pub_channel,
+            ring_name=ring_name,
+            must_exist=must_exist
+        ) as (ctx, _)
+    ):
+        yield
+        await ctx.cancel()
+
+
+@tractor.context
+async def open_sub_channel(
+    ctx: tractor.Context,
+    ring_name: str,
+    must_exist: bool = False
+):
+    subscriber = get_subscriber()
+    await subscriber.add_channel(
+        ring_name,
+        must_exist=must_exist
+    )
+
+    await ctx.started()
+
+    try:
+        await trio.sleep_forever()
+
+    finally:
+        try:
+            subscriber.remove_channel(ring_name)
+
+        except trio.ClosedResourceError:
+            ...
+
+
+@acm
+async def open_sub_channel_at(
+    actor_name: str,
+    ring_name: str,
+    must_exist: bool = False
+):
+    async with (
+        tractor.find_actor(actor_name) as portal,
+
+        portal.open_context(
+            open_sub_channel,
+            ring_name=ring_name,
+            must_exist=must_exist
+        ) as (ctx, _)
+    ):
+        yield
+        await ctx.cancel()
+
+
+'''
+High level helpers to open publisher & subscriber
+'''
+
+
+@acm
+async def open_ringbuf_publisher(
+    # buf size for created rings
+    buf_size: int = 10 * 1024,
+
+    # global batch size for channels
+    batch_size: int = 1,
+
+    # messages before changing output channel
+    msgs_per_turn: int = 1,
+
+    # ensure subscriber receives in same order publisher sent
+    # causes it to use wrapped payloads which contain the og
+    # index
+    guarantee_order: bool = False,
+
+    # explicit nursery cancel call on cleanup
+    force_cancel: bool = False,
+
+    # on creation, set the `_publisher` global in order to use the provided
+    # tractor.context & helper utils for adding and removing new channels from
+    # remote actors
+    set_module_var: bool = True
+
+) -> AsyncContextManager[RingBufferPublisher]:
+    '''
+    Open a new ringbuf publisher
+
+    '''
+    async with (
+        trio.open_nursery(strict_exception_groups=False) as n,
+        RingBufferPublisher(
+            n,
+            buf_size=buf_size,
+            batch_size=batch_size
+        ) as publisher
+    ):
+        if guarantee_order:
+            order_send_channel(publisher)
+
+        if set_module_var:
+            set_publisher(publisher)
+
+        try:
+            yield publisher
+
+        finally:
+            if force_cancel:
+                # implicitly cancel any running channel handler task
+                n.cancel_scope.cancel()


 @acm
 async def open_ringbuf_subscriber(
+    # buf size for created rings
+    buf_size: int = 10 * 1024,
+
+    # expect indexed payloads and unwrap them in order
     guarantee_order: bool = False,
-    force_cancel: bool = False

+    # explicit nursery cancel call on cleanup
+    force_cancel: bool = False,
+
+    # on creation, set the `_subscriber` global in order to use the provided
+    # tractor.context & helper utils for adding and removing new channels from
+    # remote actors
+    set_module_var: bool = True
 ) -> AsyncContextManager[RingBufferPublisher]:
     '''
     Open a new ringbuf subscriber

     '''
     async with (
-        trio.open_nursery() as n,
+        trio.open_nursery(strict_exception_groups=False) as n,
         RingBufferSubscriber(
             n,
             buf_size=buf_size
         ) as subscriber
     ):
+        # maybe monkey patch `.receive` to use indexed payloads
         if guarantee_order:
             order_receive_channel(subscriber)

+        # maybe set global module var for remote actor channel updates
+        if set_module_var:
+            global _subscriber
+            set_subscriber(subscriber)

         yield subscriber

         if force_cancel:
             # implicitly cancel any running channel handler task
             n.cancel_scope.cancel()
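A sketch of driving the new remote-management endpoints from a controller actor (the actor and ring names below are made up; the APIs are the ones added in this hunk):

async def manage_remote_ring():
    # 'pub-actor' must have entered `open_ringbuf_publisher(set_module_var=True)`
    # so its module-level `_publisher` global is set
    async with open_pub_channel_at(
        'pub-actor',   # hypothetical actor name
        'ring-0',      # hypothetical ring name
    ):
        ...  # the remote publisher keeps 'ring-0' attached in here

    # exiting the block cancels the remote context, whose `finally:`
    # calls `publisher.remove_channel('ring-0')` on the remote side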
@@ -25,6 +25,7 @@ open_ringbuf acm, will automatically contact ringd.
 '''
 import os
+import tempfile
 from typing import AsyncContextManager
 from pathlib import Path
 from contextlib import (
     asynccontextmanager as acm
@@ -33,111 +34,118 @@ from dataclasses import dataclass

 import trio
 import tractor
-from tractor.linux import send_fds, recv_fds
+from tractor.linux import (
+    send_fds,
+    recv_fds,
+)

 import tractor.ipc._ringbuf as ringbuf
 from tractor.ipc._ringbuf import RBToken


 log = tractor.log.get_logger(__name__)
-# log = tractor.log.get_console_log(level='info')


-class RingNotFound(Exception):
-    ...
+'''
+Daemon implementation
+
+'''

-_ringd_actor_name = 'ringd'
-_root_key = _ringd_actor_name + f'-{os.getpid()}'
+_ringd_actor_name: str = 'ringd'
+
+_root_name: str = f'{_ringd_actor_name}-{os.getpid()}'
+
+
+def _make_ring_name(name: str) -> str:
+    '''
+    User provided ring names will be prefixed by the ringd actor name and pid.
+    '''
+    return f'{_root_name}.{name}'


 @dataclass
 class RingInfo:
     token: RBToken
     creator: str
-    unlink: trio.Event()


 _rings: dict[str, RingInfo] = {}
+_ring_lock = trio.StrictFIFOLock()


 def _maybe_get_ring(name: str) -> RingInfo | None:
-    if name in _rings:
-        return _rings[name]
-
-    return None
+    '''
+    Maybe return RingInfo for a given name str
+
+    '''
+    # if full name was passed, strip root name
+    if _root_name in name:
+        name = name.replace(f'{_root_name}.', '')
+
+    return _rings.get(name, None)
+
+
+def _get_ring(name: str) -> RingInfo:
+    '''
+    Return a RingInfo for a given name or raise
+    '''
+    info = _maybe_get_ring(name)
+
+    if not info:
+        raise RuntimeError(f'Ring \"{name}\" not found!')
+
+    return info
+
+
+def _insert_ring(name: str, info: RingInfo):
+    '''
+    Add a new ring
+    '''
+    if name in _rings:
+        raise RuntimeError(f'A ring with name {name} already exists!')
+
+    _rings[name] = info
+
+
+def _destroy_ring(name: str):
+    '''
+    Delete information about a ring
+    '''
+    if name not in _rings:
+        raise RuntimeError(f'Tried to delete non existant {name} ring!')
+
+    del _rings[name]


-async def _attach_to_ring(
-    ringd_pid: int,
-    ring_name: str
-) -> RBToken:
-    actor = tractor.current_actor()
-
-    fd_amount = 3
-    sock_path = str(
-        Path(tempfile.gettempdir())
-        /
-        f'ringd-{ringd_pid}-{ring_name}-to-{actor.name}.sock'
-    )
-
-    log.info(f'trying to attach to ring {ring_name}...')
-
-    async with (
-        tractor.find_actor(_ringd_actor_name) as ringd,
-        ringd.open_context(
-            _pass_fds,
-            name=ring_name,
-            sock_path=sock_path
-        ) as (ctx, token),
-    ):
-        fds = await recv_fds(sock_path, fd_amount)
-        log.info(
-            f'received fds: {fds}'
-        )
-
-        token = RBToken.from_msg(token)
-
-        write, wrap, eof = fds
-
-        return RBToken(
-            shm_name=token.shm_name,
-            write_eventfd=write,
-            wrap_eventfd=wrap,
-            eof_eventfd=eof,
-            buf_size=token.buf_size
-        )


 @tractor.context
 async def _pass_fds(
     ctx: tractor.Context,
     name: str,
     sock_path: str
 ):
-    global _rings
-    info = _maybe_get_ring(name)
-
-    if not info:
-        raise RingNotFound(f'Ring \"{name}\" not found!')
-
-    token = info.token
-
-    async with send_fds(token.fds, sock_path):
-        log.info(f'connected to {sock_path} for fd passing')
-        await ctx.started(token)
-
-    log.info(f'fds {token.fds} sent')
-
-    return token
+    '''
+    Ringd endpoint to request passing fds of a ring.
+
+    Supports passing fullname or not (ringd actor name and pid before ring
+    name).
+
+    See `_attach_to_ring` function for usage.
+    '''
+    async with _ring_lock:
+        # get ring fds or raise error
+        token = _get_ring(name).token
+
+    # start fd passing context using socket on `sock_path`
+    async with send_fds(token.fds, sock_path):
+        log.info(f'connected to {sock_path} for fd passing')
+        # use started to signal socket is ready and send token in order for
+        # client to get extra info like buf_size
+        await ctx.started(token)
+        # send_fds will block until receive side acks
+
+    log.info(f'ring {name} fds: {token.fds}, sent')
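The `_pass_fds` endpoint above is one half of a unix-socket fd-passing handshake; the client half (`_request_ring_fds`) is added further down in this diff. A simplified sketch of the two sides, using only the calls visible here (the tractor ctx plumbing and error handling are elided):

import tempfile
from pathlib import Path

from tractor.linux import send_fds, recv_fds

FD_AMOUNT = 3  # the ring's write, wrap & eof eventfds


async def daemon_side(ctx, token, sock_path: str):
    # opens the socket, then blocks until the receiver acks the fds
    async with send_fds(token.fds, sock_path):
        # socket is listening: signal the client it can connect, and
        # piggyback ring metadata (shm name, buf_size) on the token
        await ctx.started(token)


async def client_side(fullname: str, actor_name: str):
    sock_path = str(
        Path(tempfile.gettempdir()) / f'{fullname}-to-{actor_name}.sock'
    )
    # the kernel remaps the fds into this process's fd table, so the
    # numbers received here differ from the daemon's
    write, wrap, eof = await recv_fds(sock_path, FD_AMOUNT)
    return write, wrap, eof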
@@ -145,60 +153,105 @@ async def _open_ringbuf(
     ctx: tractor.Context,
     caller: str,
     name: str,
-    buf_size: int = 10 * 1024,
-    must_exist: bool = False,
+    buf_size: int = 10 * 1024
 ):
-    global _root_key, _rings
-    log.info(f'maybe open ring {name} from {caller}, must_exist = {must_exist}')
-
-    info = _maybe_get_ring(name)
-
-    if info:
-        log.info(f'ring {name} exists, {caller} attached')
-
-        await ctx.started(os.getpid())
-
-        async with ctx.open_stream() as stream:
-            await stream.receive()
-
-        info.unlink.set()
-
-        log.info(f'{caller} detached from ring {name}')
-
-        return
-
-    if must_exist:
-        raise RingNotFound(
-            f'Tried to open_ringbuf but it doesn\'t exist: {name}'
-        )
+    '''
+    Ringd endpoint to create and allocate resources for a new ring.
+
+    '''
+    await _ring_lock.acquire()
+    maybe_info = _maybe_get_ring(name)
+
+    if maybe_info:
+        raise RuntimeError(
+            f'Tried to create ringbuf but it already exists: {name}'
+        )
+
+    fullname = _make_ring_name(name)

     with ringbuf.open_ringbuf(
-        _root_key + name,
+        fullname,
         buf_size=buf_size
     ) as token:
-        unlink_event = trio.Event()
         _insert_ring(
             name,
             RingInfo(
                 token=token,
                 creator=caller,
-                unlink=unlink_event,
             )
         )
-        log.info(f'ring {name} created by {caller}')
-        await ctx.started(os.getpid())
-
-        async with ctx.open_stream() as stream:
-            await stream.receive()
-
-        await unlink_event.wait()
-        _destroy_ring(name)
-
-        log.info(f'ring {name} destroyed by {caller}')
+
+        _ring_lock.release()
+
+        # yield full ring name to rebuild token after fd passing
+        await ctx.started(fullname)
+
+        # await ctx cancel to remove ring from tracking and cleanup
+        try:
+            log.info(f'ring {name} created by {caller}')
+            await trio.sleep_forever()
+
+        finally:
+            _destroy_ring(name)
+
+            log.info(f'ring {name} destroyed by {caller}')
+
+
+@tractor.context
+async def _attach_ringbuf(
+    ctx: tractor.Context,
+    caller: str,
+    name: str
+) -> str:
+    '''
+    Ringd endpoint to "attach" to an existing ring, this just ensures ring
+    actually exists and returns its full name.
+    '''
+    async with _ring_lock:
+        info = _maybe_get_ring(name)
+
+        if not info:
+            raise RuntimeError(
+                f'{caller} tried to open_ringbuf but it doesn\'t exist: {name}'
+            )
+
+    await ctx.started()
+
+    # return full ring name to rebuild token after fd passing
+    return info.token.shm_name
+
+
+@tractor.context
+async def _maybe_open_ringbuf(
+    ctx: tractor.Context,
+    caller: str,
+    name: str,
+    buf_size: int = 10 * 1024,
+):
+    '''
+    If ring already exists attach, if not create it.
+    '''
+    maybe_info = _maybe_get_ring(name)
+
+    if maybe_info:
+        return await _attach_ringbuf(ctx, caller, name)
+
+    return await _open_ringbuf(ctx, caller, name, buf_size=buf_size)
+
+
+'''
+Ringd client side helpers
+
+'''


 @acm
 async def open_ringd(**kwargs) -> tractor.Portal:
     '''
     Spawn new ringd actor.

     '''
     async with tractor.open_nursery(**kwargs) as an:
         portal = await an.start_actor(
             _ringd_actor_name,
@@ -210,21 +263,69 @@ async def open_ringd(**kwargs) -> tractor.Portal:

 @acm
 async def wait_for_ringd() -> tractor.Portal:
     '''
     Wait for ringd actor to be up.

     '''
     async with tractor.wait_for_actor(
         _ringd_actor_name
     ) as portal:
         yield portal


+async def _request_ring_fds(
+    fullname: str
+) -> RBToken:
+    '''
+    Private helper to fetch ring fds from ringd actor.
+    '''
+    actor = tractor.current_actor()
+
+    fd_amount = 3
+    sock_path = str(
+        Path(tempfile.gettempdir())
+        /
+        f'{fullname}-to-{actor.name}.sock'
+    )
+
+    log.info(f'trying to attach to {fullname}...')
+
+    async with (
+        tractor.find_actor(_ringd_actor_name) as ringd,
+
+        ringd.open_context(
+            _pass_fds,
+            name=fullname,
+            sock_path=sock_path
+        ) as (ctx, token),
+    ):
+        fds = await recv_fds(sock_path, fd_amount)
+        write, wrap, eof = fds
+        log.info(
+            f'received fds, write: {write}, wrap: {wrap}, eof: {eof}'
+        )
+
+        token = RBToken.from_msg(token)
+
+        return RBToken(
+            shm_name=fullname,
+            write_eventfd=write,
+            wrap_eventfd=wrap,
+            eof_eventfd=eof,
+            buf_size=token.buf_size
+        )
+
+
 @acm
 async def open_ringbuf(
+
     name: str,
     buf_size: int = 10 * 1024,
-    must_exist: bool = False,
-) -> RBToken:
+
+) -> AsyncContextManager[RBToken]:
+    '''
+    Create a new ring and retrieve its fds.
+
+    '''
     actor = tractor.current_actor()
     async with (
         wait_for_ringd() as ringd,
@@ -234,12 +335,67 @@ async def open_ringbuf(
             caller=actor.name,
             name=name,
             buf_size=buf_size,
-            must_exist=must_exist
-        ) as (rd_ctx, ringd_pid),
-
-        rd_ctx.open_stream() as _stream,
+        ) as (ctx, fullname),
     ):
-        token = await _attach_to_ring(ringd_pid, name)
-        log.info(f'attached to {token}')
+        token = await _request_ring_fds(fullname)
+        log.info(f'{actor.name} opened {token}')
+        try:
+            yield token
+
+        finally:
+            with trio.CancelScope(shield=True):
+                await ctx.cancel()
+
+
+@acm
+async def attach_ringbuf(
+    name: str,
+) -> AsyncContextManager[RBToken]:
+    '''
+    Attach to an existing ring and retreive its fds.
+
+    '''
+    actor = tractor.current_actor()
+    async with (
+        wait_for_ringd() as ringd,
+
+        ringd.open_context(
+            _attach_ringbuf,
+            caller=actor.name,
+            name=name,
+        ) as (ctx, _),
+    ):
+        fullname = await ctx.wait_for_result()
+        token = await _request_ring_fds(fullname)
+        log.info(f'{actor.name} attached {token}')
+        yield token
+
+
+@acm
+async def maybe_open_ringbuf(
+    name: str,
+    buf_size: int = 10 * 1024,
+) -> AsyncContextManager[RBToken]:
+    '''
+    Attach or create a ring and retreive its fds.
+
+    '''
+    actor = tractor.current_actor()
+    async with (
+        wait_for_ringd() as ringd,
+
+        ringd.open_context(
+            _maybe_open_ringbuf,
+            caller=actor.name,
+            name=name,
+            buf_size=buf_size,
+        ) as (ctx, fullname),
+    ):
+        token = await _request_ring_fds(fullname)
+        log.info(f'{actor.name} opened {token}')
+        try:
+            yield token
+
+        finally:
+            with trio.CancelScope(shield=True):
+                await ctx.cancel()