From 5d6fa643ba15a45f994cebf2bda593247567e095 Mon Sep 17 00:00:00 2001
From: Guillermo Rodriguez
Date: Thu, 10 Apr 2025 13:13:08 -0300
Subject: [PATCH] Better APIs for ringd and pubsub

Pubsub:
Remove unnecessary ChannelManager locking mechanism
Make ChannelManager.close wait for all channel removals
Make publisher turn switching configurable via the `msgs_per_turn` variable
Fix batch_size setter on publisher
Add broadcast to publisher
Add endpoints on pubsub for remote actors to dynamically add and remove
channels

Ringd:
Add FIFO lock and use it in methods that modify _rings state
Add comments
Break up ringd.open_ringbuf APIs into attach_, open_ & maybe_open_
When attaching, it's no longer a long-running context; only opens are
Adapt ringd test to new APIs
---
 tests/test_ringd.py             |  21 +-
 tractor/ipc/_ringbuf/_pubsub.py | 465 ++++++++++++++++++++++++--------
 tractor/ipc/_ringbuf/_ringd.py  | 364 ++++++++++++++++++-------
 3 files changed, 617 insertions(+), 233 deletions(-)

diff --git a/tests/test_ringd.py b/tests/test_ringd.py
index 4b5c792e..e08b7c1c 100644
--- a/tests/test_ringd.py
+++ b/tests/test_ringd.py
@@ -1,5 +1,3 @@
-from contextlib import asynccontextmanager as acm
-
 import trio
 import tractor
 import msgspec
@@ -40,7 +38,7 @@ async def send_child(
     ring_name: str
 ):
     async with (
-        ringd.open_ringbuf(ring_name) as token,
+        ringd.attach_ringbuf(ring_name) as token,

         attach_to_ringbuf_sender(token) as chan,
     ):
@@ -63,9 +61,7 @@ def test_ringd():
         async with (
             tractor.open_nursery() as an,

-            ringd.open_ringd(
-                loglevel='info'
-            )
+            ringd.open_ringd()
         ):
             recv_portal = await an.start_actor(
                 'recv',
@@ -133,10 +129,10 @@ async def subscriber_child(ctx: tractor.Context):
                 msg = msgspec.msgpack.decode(msg, type=ControlMessages)
                 match msg:
                     case AddChannelMsg():
-                        await subs.add_channel(msg.name, must_exist=False)
+                        await subs.add_channel(msg.name)

                     case RemoveChannelMsg():
-                        await subs.remove_channel(msg.name)
+                        subs.remove_channel(msg.name)

                     case RangeMsg():
                         range_msg = msg
@@ -171,7 +167,7 @@ async def subscriber_child(ctx: tractor.Context):
 async def publisher_child(ctx: tractor.Context):
     await ctx.started()
     async with (
-        open_ringbuf_publisher(batch_size=100, guarantee_order=True) as pub,
+        open_ringbuf_publisher(guarantee_order=True) as pub,
         ctx.open_stream() as stream
     ):
         async for msg in stream:
@@ -181,7 +177,7 @@ async def publisher_child(ctx: tractor.Context):
                     await pub.add_channel(msg.name, must_exist=True)

                 case RemoveChannelMsg():
-                    await pub.remove_channel(msg.name)
+                    pub.remove_channel(msg.name)

                 case RangeMsg():
                     for i in range(msg.size):
@@ -258,11 +254,6 @@ def test_pubsub():
             await send_range(100)
             await remove_channel(ring_name)

-            # try using same ring name
-            await add_channel(ring_name)
-            await send_range(100)
-            await remove_channel(ring_name)
-
             # multi chan test
             ring_names = []
             for i in range(3):
diff --git a/tractor/ipc/_ringbuf/_pubsub.py b/tractor/ipc/_ringbuf/_pubsub.py
index 37d54308..e2575ab6 100644
--- a/tractor/ipc/_ringbuf/_pubsub.py
+++ b/tractor/ipc/_ringbuf/_pubsub.py
@@ -86,19 +86,14 @@ class ChannelManager(Generic[ChannelType]):
         # store channel runtime variables
         self._channels: list[ChannelInfo] = []

-        # methods that modify self._channels should be ordered by FIFO
-        self._lock = trio.StrictFIFOLock()
-
         self._is_closed: bool = True

+        self._teardown = trio.Event()
+
     @property
     def closed(self) -> bool:
         return self._is_closed

-    @property
-    def lock(self) -> trio.StrictFIFOLock:
-        return self._lock
-
     @property
     def channels(self) -> list[ChannelInfo]:
         return self._channels

@@ -106,8 +101,8 @@ class ChannelManager(Generic[ChannelType]):
     async def _channel_handler_task(
         self,
         name: str,
-        task_status: trio.TASK_STATUS_IGNORED,
-        **kwargs
+        must_exist: bool = False,
+        task_status=trio.TASK_STATUS_IGNORED,
     ):
         '''
         Open channel resources, add to internal data structures, signal channel
@@ -119,7 +114,7 @@ class ChannelManager(Generic[ChannelType]):
         kwargs are proxied to `self._open_channel` acm.

         '''
-        async with self._open_channel(name, **kwargs) as chan:
+        async with self._open_channel(name, must_exist=must_exist) as chan:
             cancel_scope = trio.CancelScope()
             info = ChannelInfo(
                 name=name,
@@ -138,6 +133,9 @@ class ChannelManager(Generic[ChannelType]):

         self._maybe_destroy_channel(name)

+        if len(self) == 0:
+            self._teardown.set()
+
     def _find_channel(self, name: str) -> tuple[int, ChannelInfo] | None:
         '''
         Given a channel name maybe return its index and value from
@@ -165,7 +163,7 @@ class ChannelManager(Generic[ChannelType]):
             info.cancel_scope.cancel()
             del self._channels[i]

-    async def add_channel(self, name: str, **kwargs):
+    async def add_channel(self, name: str, must_exist: bool = False):
         '''
         Add a new channel to be handled

@@ -173,14 +171,13 @@ class ChannelManager(Generic[ChannelType]):
         if self.closed:
             raise trio.ClosedResourceError

-        async with self._lock:
-            await self._n.start(partial(
-                self._channel_handler_task,
-                name,
-                **kwargs
-            ))
+        await self._n.start(partial(
+            self._channel_handler_task,
+            name,
+            must_exist=must_exist
+        ))

-    async def remove_channel(self, name: str):
+    def remove_channel(self, name: str):
         '''
         Remove a channel and stop its handling

@@ -188,12 +185,11 @@ class ChannelManager(Generic[ChannelType]):
         if self.closed:
             raise trio.ClosedResourceError

-        async with self._lock:
-            self._maybe_destroy_channel(name)
+        self._maybe_destroy_channel(name)

-            # if that was last channel reset connect event
-            if len(self) == 0:
-                self._connect_event = trio.Event()
+        # if that was the last channel, reset the connect event
+        if len(self) == 0:
+            self._connect_event = trio.Event()

     async def wait_for_channel(self):
         '''
@@ -226,7 +222,18 @@ class ChannelManager(Generic[ChannelType]):
             return

         for info in self._channels:
-            await self.remove_channel(info.name)
+            if info.channel.closed:
+                continue
+
+            self.remove_channel(info.name)
+
+        try:
+            await self._teardown.wait()
+
+        except trio.Cancelled:
+            # log.exception('close was cancelled')
+            raise
+
         self._is_closed = True

@@ -236,12 +243,6 @@ Ring buffer publisher & subscribe pattern mediated by `ringd` actor.

 '''

-@dataclass
-class PublisherChannels:
-    ring: RingBufferSendChannel
-    schan: trio.MemorySendChannel
-    rchan: trio.MemoryReceiveChannel
-

 class RingBufferPublisher(trio.abc.SendChannel[bytes]):
     '''
@@ -259,24 +260,32 @@ class RingBufferPublisher(trio.abc.SendChannel[bytes]):
         # new ringbufs created will have this buf_size
         buf_size: int = 10 * 1024,

+        # number of msgs sent to each ring before switching turns
+        msgs_per_turn: int = 1,
+
         # global batch size for all channels
         batch_size: int = 1
     ):
         self._buf_size = buf_size
         self._batch_size: int = batch_size
+        self.msgs_per_turn = msgs_per_turn

-        self._chanmngr = ChannelManager[PublisherChannels](
+        # helper to manage acms + long running tasks
+        self._chanmngr = ChannelManager[RingBufferSendChannel](
             n,
             self._open_channel,
             self._channel_task
         )

-        # methods that send data over the channels need to be acquire send lock
-        # in order to guarantee order of operations
+        # ensure no concurrent `.send()` calls
         self._send_lock = trio.StrictFIFOLock()

+        # index of channel to be used for next send
         self._next_turn: int = 0
-
+        # number of messages sent this turn
+        self._turn_msgs: int = 0
+        # have we closed this publisher?
+        # set to `False` on `.__aenter__()`
         self._is_closed: bool = True

     @property
@@ -288,14 +297,31 @@ class RingBufferPublisher(trio.abc.SendChannel[bytes]):
         return self._batch_size

     @batch_size.setter
-    def set_batch_size(self, value: int) -> None:
+    def batch_size(self, value: int) -> None:
         for info in self.channels:
-            info.channel.ring.batch_size = value
+            info.channel.batch_size = value

     @property
     def channels(self) -> list[ChannelInfo]:
         return self._chanmngr.channels

+    def _get_next_turn(self) -> int:
+        '''
+        Maybe switch turns and reset self._turn_msgs, or just increment it.
+        Return the current turn.
+        '''
+        if self._turn_msgs == self.msgs_per_turn:
+            self._turn_msgs = 0
+            self._next_turn += 1
+
+            if self._next_turn >= len(self.channels):
+                self._next_turn = 0
+
+        else:
+            self._turn_msgs += 1
+
+        return self._next_turn
+
     def get_channel(self, name: str) -> ChannelInfo:
         '''
         Get underlying ChannelInfo from name
@@ -310,8 +336,8 @@ class RingBufferPublisher(trio.abc.SendChannel[bytes]):
     ):
         await self._chanmngr.add_channel(name, must_exist=must_exist)

-    async def remove_channel(self, name: str):
-        await self._chanmngr.remove_channel(name)
+    def remove_channel(self, name: str):
+        self._chanmngr.remove_channel(name)

     @acm
     async def _open_channel(
@@ -320,41 +346,45 @@ class RingBufferPublisher(trio.abc.SendChannel[bytes]):

         name: str,
         must_exist: bool = False

-    ) -> AsyncContextManager[PublisherChannels]:
+    ) -> AsyncContextManager[RingBufferSendChannel]:
         '''
         Open a ringbuf through `ringd` and attach as send side

         '''
-        async with (
-            ringd.open_ringbuf(
-                name=name,
-                buf_size=self._buf_size,
-                must_exist=must_exist,
-            ) as token,
-            attach_to_ringbuf_sender(token) as ring,
-        ):
-            schan, rchan = trio.open_memory_channel(0)
-            yield PublisherChannels(
-                ring=ring,
-                schan=schan,
-                rchan=rchan
-            )
-            try:
-                while True:
-                    msg = rchan.receive_nowait()
-                    await ring.send(msg)
+        if must_exist:
+            ringd_fn = ringd.attach_ringbuf
+            kwargs = {}

-            except trio.WouldBlock:
-                ...
+        else:
+            ringd_fn = ringd.open_ringbuf
+            kwargs = {'buf_size': self._buf_size}
+
+        async with (
+            ringd_fn(
+                name=name,
+                **kwargs
+            ) as token,
+
+            attach_to_ringbuf_sender(
+                token,
+                batch_size=self._batch_size
+            ) as ring,
+        ):
+            yield ring
+            # try:
+            #     # ensure all messages are sent
+            #     await ring.flush()

+            # except Exception as e:
+            #     e.add_note(f'while closing ringbuf send channel {name}')
+            #     log.exception(e)

     async def _channel_task(self, info: ChannelInfo) -> None:
         '''
-        Forever get current runtime info for channel, wait on its next pending
-        payloads update event then drain all into send channel.
+        Wait forever until channel cancellation.

         '''
         try:
-            async for msg in info.channel.rchan:
-                await info.channel.ring.send(msg)
+            await trio.sleep_forever()

         except trio.Cancelled:
             ...
@@ -362,11 +392,9 @@ class RingBufferPublisher(trio.abc.SendChannel[bytes]):
     async def send(self, msg: bytes):
         '''
         If no output channels connected, wait until one, then fetch the next
-        channel based on turn, add the indexed payload and update
-        `self._next_turn` & `self._next_index`.
+        channel based on turn.

-        Needs to acquire `self._send_lock` to make sure updates to turn & index
-        variables dont happen out of order.
+        Needs to acquire `self._send_lock` to ensure no concurrent calls.

         '''
         if self.closed:
@@ -380,18 +408,28 @@ class RingBufferPublisher(trio.abc.SendChannel[bytes]):
             if len(self.channels) == 0:
                 await self._chanmngr.wait_for_channel()

-            if self._next_turn >= len(self.channels):
-                self._next_turn = 0
+            turn = self._get_next_turn()

-            info = self.channels[self._next_turn]
-            await info.channel.schan.send(msg)
+            info = self.channels[turn]
+            await info.channel.send(msg)

-            self._next_turn += 1
+    async def broadcast(self, msg: bytes):
+        '''
+        Send a msg to all channels; if none are connected, do nothing.
+        '''
+        if self.closed:
+            raise trio.ClosedResourceError
+
+        for info in self.channels:
+            await info.channel.send(msg)

     async def flush(self, new_batch_size: int | None = None):
-        async with self._chanmngr.lock:
-            for info in self.channels:
-                await info.channel.ring.flush(new_batch_size=new_batch_size)
+        for info in self.channels:
+            try:
+                await info.channel.flush(new_batch_size=new_batch_size)
+
+            except trio.ClosedResourceError:
+                ...

     async def __aenter__(self):
         self._chanmngr.open()
@@ -403,41 +441,12 @@ class RingBufferPublisher(trio.abc.SendChannel[bytes]):
             log.warning('tried to close RingBufferPublisher but its already closed...')
             return

-        await self._chanmngr.close()
+        with trio.CancelScope(shield=True):
+            await self._chanmngr.close()
+
         self._is_closed = True


-@acm
-async def open_ringbuf_publisher(
-
-    buf_size: int = 10 * 1024,
-    batch_size: int = 1,
-    guarantee_order: bool = False,
-    force_cancel: bool = False
-
-) -> AsyncContextManager[RingBufferPublisher]:
-    '''
-    Open a new ringbuf publisher
-
-    '''
-    async with (
-        trio.open_nursery() as n,
-        RingBufferPublisher(
-            n,
-            buf_size=buf_size,
-            batch_size=batch_size
-        ) as publisher
-    ):
-        if guarantee_order:
-            order_send_channel(publisher)
-
-        yield publisher
-
-        if force_cancel:
-            # implicitly cancel any running channel handler task
-            n.cancel_scope.cancel()
-
-
 class RingBufferSubscriber(trio.abc.ReceiveChannel[bytes]):
     '''
     Use ChannelManager to create a multi ringbuf receiver that can
@@ -458,10 +467,15 @@ class RingBufferSubscriber(trio.abc.ReceiveChannel[bytes]):
         self,
         n: trio.Nursery,

+        # new ringbufs created will have this buf_size
+        buf_size: int = 10 * 1024,
+
         # if connecting to a publisher that has already sent messages set
         # to the next expected payload index this subscriber will receive
         start_index: int = 0
     ):
+        self._buf_size = buf_size
+
         self._chanmngr = ChannelManager[RingBufferReceiveChannel](
             n,
             self._open_channel,
@@ -488,8 +502,8 @@ class RingBufferSubscriber(trio.abc.ReceiveChannel[bytes]):
     async def add_channel(self, name: str, must_exist: bool = False):
         await self._chanmngr.add_channel(name, must_exist=must_exist)

-    async def remove_channel(self, name: str):
-        await self._chanmngr.remove_channel(name)
+    def remove_channel(self, name: str):
+        self._chanmngr.remove_channel(name)

     @acm
     async def _open_channel(
@@ -502,11 +516,20 @@ class RingBufferSubscriber(trio.abc.ReceiveChannel[bytes]):
         '''
         Open a ringbuf through `ringd` and attach as receiver side

         '''
+        if must_exist:
+            ringd_fn = ringd.attach_ringbuf
+            kwargs = {}
+
+        else:
+            ringd_fn = ringd.open_ringbuf
+            kwargs = {'buf_size': self._buf_size}
+
         async with (
-            ringd.open_ringbuf(
+            ringd_fn(
                 name=name,
-                must_exist=must_exist,
+                **kwargs
             ) as token,
+
             attach_to_ringbuf_receiver(token) as chan
         ):
             yield chan
@@ -554,7 +577,6 @@ class RingBufferSubscriber(trio.abc.ReceiveChannel[bytes]):

     async def aclose(self) -> None:
         if self.closed:
-            log.warning('tried to close RingBufferSubscriber but its already closed...')
             return

         await self._chanmngr.close()
@@ -562,26 +584,241 @@ class RingBufferSubscriber(trio.abc.ReceiveChannel[bytes]):
         await self._rchan.aclose()
         self._is_closed = True

+
+'''
+Actor module for managing publisher & subscriber channels remotely through
+`tractor.context` rpc
+
+'''
+
+_publisher: RingBufferPublisher | None = None
+_subscriber: RingBufferSubscriber | None = None
+
+
+def set_publisher(pub: RingBufferPublisher):
+    global _publisher
+
+    if _publisher:
+        raise RuntimeError(
+            f'publisher already set on {tractor.current_actor()}'
+        )
+
+    _publisher = pub
+
+
+def set_subscriber(sub: RingBufferSubscriber):
+    global _subscriber
+
+    if _subscriber:
+        raise RuntimeError(
+            f'subscriber already set on {tractor.current_actor()}'
+        )
+
+    _subscriber = sub
+
+
+def get_publisher() -> RingBufferPublisher:
+    global _publisher
+
+    if not _publisher:
+        raise RuntimeError(
+            f'{tractor.current_actor()} tried to get publisher '
+            'but it\'s not set'
+        )
+
+    return _publisher
+
+
+def get_subscriber() -> RingBufferSubscriber:
+    global _subscriber
+
+    if not _subscriber:
+        raise RuntimeError(
+            f'{tractor.current_actor()} tried to get subscriber '
+            'but it\'s not set'
+        )
+
+    return _subscriber
+
+
+@tractor.context
+async def open_pub_channel(
+    ctx: tractor.Context,
+    ring_name: str,
+    must_exist: bool = False
+):
+    publisher = get_publisher()
+    await publisher.add_channel(
+        ring_name,
+        must_exist=must_exist
+    )
+
+    await ctx.started()
+
+    try:
+        await trio.sleep_forever()
+
+    finally:
+        try:
+            publisher.remove_channel(ring_name)
+
+        except trio.ClosedResourceError:
+            ...
+
+
+@acm
+async def open_pub_channel_at(
+    actor_name: str,
+    ring_name: str,
+    must_exist: bool = False
+):
+    async with (
+        tractor.find_actor(actor_name) as portal,
+
+        portal.open_context(
+            open_pub_channel,
+            ring_name=ring_name,
+            must_exist=must_exist
+        ) as (ctx, _)
+    ):
+        yield
+        await ctx.cancel()
+
+
+@tractor.context
+async def open_sub_channel(
+    ctx: tractor.Context,
+    ring_name: str,
+    must_exist: bool = False
+):
+    subscriber = get_subscriber()
+    await subscriber.add_channel(
+        ring_name,
+        must_exist=must_exist
+    )
+
+    await ctx.started()
+
+    try:
+        await trio.sleep_forever()
+
+    finally:
+        try:
+            subscriber.remove_channel(ring_name)
+
+        except trio.ClosedResourceError:
+            ...
+
+
+@acm
+async def open_sub_channel_at(
+    actor_name: str,
+    ring_name: str,
+    must_exist: bool = False
+):
+    async with (
+        tractor.find_actor(actor_name) as portal,
+
+        portal.open_context(
+            open_sub_channel,
+            ring_name=ring_name,
+            must_exist=must_exist
+        ) as (ctx, _)
+    ):
+        yield
+        await ctx.cancel()
+
+
+'''
+High level helpers to open publisher & subscriber
+
+'''
+
+
+@acm
+async def open_ringbuf_publisher(
+    # buf size for created rings
+    buf_size: int = 10 * 1024,
+
+    # global batch size for channels
+    batch_size: int = 1,
+
+    # messages before changing output channel
+    msgs_per_turn: int = 1,
+
+    # ensure subscriber receives in same order publisher sent
+    # causes it to use wrapped payloads which contain the original
+    # index
+    guarantee_order: bool = False,
+
+    # explicit nursery cancel call on cleanup
+    force_cancel: bool = False,
+
+    # on creation, set the `_publisher` global in order to use the provided
+    # tractor.context & helper utils for adding and removing new channels from
+    # remote actors
+    set_module_var: bool = True
+
+) -> AsyncContextManager[RingBufferPublisher]:
+    '''
+    Open a new ringbuf publisher
+
+    '''
+    async with (
+        trio.open_nursery(strict_exception_groups=False) as n,
+        RingBufferPublisher(
+            n,
+            buf_size=buf_size,
+            batch_size=batch_size,
+            msgs_per_turn=msgs_per_turn,
+        ) as publisher
+    ):
+        if guarantee_order:
+            order_send_channel(publisher)
+
+        if set_module_var:
+            set_publisher(publisher)
+
+        try:
+            yield publisher
+
+        finally:
+            if force_cancel:
+                # implicitly cancel any running channel handler task
+                n.cancel_scope.cancel()
+
+
 @acm
 async def open_ringbuf_subscriber(
+    # buf size for created rings
+    buf_size: int = 10 * 1024,
+
     # expect indexed payloads and unwrap them in order
     guarantee_order: bool = False,

-    force_cancel: bool = False
+    # explicit nursery cancel call on cleanup
+    force_cancel: bool = False,
+
+    # on creation, set the `_subscriber` global in order to use the provided
+    # tractor.context & helper utils for adding and removing new channels from
+    # remote actors
+    set_module_var: bool = True

 ) -> AsyncContextManager[RingBufferPublisher]:
     '''
     Open a new ringbuf subscriber

     '''
     async with (
-        trio.open_nursery() as n,
+        trio.open_nursery(strict_exception_groups=False) as n,
         RingBufferSubscriber(
             n,
+            buf_size=buf_size
         ) as subscriber
     ):
+        # maybe monkey patch `.receive` to use indexed payloads
         if guarantee_order:
             order_receive_channel(subscriber)

+        # maybe set global module var for remote actor channel updates
+        if set_module_var:
+            global _subscriber
+            set_subscriber(subscriber)
+
         yield subscriber

         if force_cancel:
diff --git a/tractor/ipc/_ringbuf/_ringd.py b/tractor/ipc/_ringbuf/_ringd.py
index 24c3e530..51818e34 100644
--- a/tractor/ipc/_ringbuf/_ringd.py
+++ b/tractor/ipc/_ringbuf/_ringd.py
@@ -25,6 +25,7 @@ open_ringbuf acm, will automatically contact ringd.
 '''
 import os
 import tempfile
+from typing import AsyncContextManager
 from pathlib import Path
 from contextlib import (
     asynccontextmanager as acm
@@ -33,111 +34,118 @@ from dataclasses import dataclass

 import trio
 import tractor
-from tractor.linux import send_fds, recv_fds
+from tractor.linux import (
+    send_fds,
+    recv_fds,
+)

 import tractor.ipc._ringbuf as ringbuf
 from tractor.ipc._ringbuf import RBToken


 log = tractor.log.get_logger(__name__)
-# log = tractor.log.get_console_log(level='info')


-class RingNotFound(Exception):
-    ...
+'''
+Daemon implementation
+
+'''

-_ringd_actor_name = 'ringd'
-_root_key = _ringd_actor_name + f'-{os.getpid()}'
+_ringd_actor_name: str = 'ringd'
+
+
+_root_name: str = f'{_ringd_actor_name}-{os.getpid()}'
+
+
+def _make_ring_name(name: str) -> str:
+    '''
+    User provided ring names will be prefixed by the ringd actor name and pid.
+
+    '''
+    return f'{_root_name}.{name}'


 @dataclass
 class RingInfo:
     token: RBToken
     creator: str
-    unlink: trio.Event()


 _rings: dict[str, RingInfo] = {}
+_ring_lock = trio.StrictFIFOLock()


 def _maybe_get_ring(name: str) -> RingInfo | None:
-    if name in _rings:
-        return _rings[name]
+    '''
+    Maybe return RingInfo for a given name str

-    return None
+    '''
+    # if full name was passed, strip root name
+    if _root_name in name:
+        name = name.replace(f'{_root_name}.', '')
+
+    return _rings.get(name, None)
+
+
+def _get_ring(name: str) -> RingInfo:
+    '''
+    Return a RingInfo for a given name or raise
+
+    '''
+    info = _maybe_get_ring(name)
+
+    if not info:
+        raise RuntimeError(f'Ring \"{name}\" not found!')
+
+    return info


 def _insert_ring(name: str, info: RingInfo):
+    '''
+    Add a new ring
+
+    '''
+    if name in _rings:
+        raise RuntimeError(f'A ring with name {name} already exists!')
+
     _rings[name] = info


 def _destroy_ring(name: str):
+    '''
+    Delete information about a ring
+
+    '''
+    if name not in _rings:
+        raise RuntimeError(f'Tried to delete non-existent {name} ring!')
+
     del _rings[name]


-async def _attach_to_ring(
-    ringd_pid: int,
-    ring_name: str
-) -> RBToken:
-    actor = tractor.current_actor()
-
-    fd_amount = 3
-    sock_path = str(
-        Path(tempfile.gettempdir())
-        /
-        f'ringd-{ringd_pid}-{ring_name}-to-{actor.name}.sock'
-    )
-
-    log.info(f'trying to attach to ring {ring_name}...')
-
-    async with (
-        tractor.find_actor(_ringd_actor_name) as ringd,
-
-        ringd.open_context(
-            _pass_fds,
-            name=ring_name,
-            sock_path=sock_path
-        ) as (ctx, token),
-    ):
-        fds = await recv_fds(sock_path, fd_amount)
-        log.info(
-            f'received fds: {fds}'
-        )
-
-        token = RBToken.from_msg(token)
-
-        write, wrap, eof = fds
-
-        return RBToken(
-            shm_name=token.shm_name,
-            write_eventfd=write,
-            wrap_eventfd=wrap,
-            eof_eventfd=eof,
-            buf_size=token.buf_size
-        )
-
-
 @tractor.context
 async def _pass_fds(
     ctx: tractor.Context,
     name: str,
     sock_path: str
 ):
-    global _rings
-    info = _maybe_get_ring(name)
+    '''
+    Ringd endpoint to request passing fds of a ring.

-    if not info:
-        raise RingNotFound(f'Ring \"{name}\" not found!')
+    Supports passing the fullname or not (the ringd actor name and pid
+    prefixed before the ring name).

-    token = info.token
+    See the `_request_ring_fds` client helper for usage.
+    '''
+    async with _ring_lock:
+        # get ring fds or raise error
+        token = _get_ring(name).token

-    async with send_fds(token.fds, sock_path):
-        log.info(f'connected to {sock_path} for fd passing')
-        await ctx.started(token)
+        # start fd passing context using socket on `sock_path`
+        async with send_fds(token.fds, sock_path):
+            log.info(f'connected to {sock_path} for fd passing')
+            # use started to signal the socket is ready and send the token
+            # so the client can get extra info like buf_size
+            await ctx.started(token)
+            # send_fds will block until receive side acks

-        log.info(f'fds {token.fds} sent')
-
-    return token
+    log.info(f'ring {name} fds: {token.fds} sent')


 @tractor.context
@@ -145,60 +153,105 @@ async def _open_ringbuf(
     ctx: tractor.Context,
     caller: str,
     name: str,
-    buf_size: int = 10 * 1024,
-    must_exist: bool = False,
+    buf_size: int = 10 * 1024
 ):
-    global _root_key, _rings
-    log.info(f'maybe open ring {name} from {caller}, must_exist = {must_exist}')
+    '''
+    Ringd endpoint to create and allocate resources for a new ring.

-    info = _maybe_get_ring(name)
+    '''
+    await _ring_lock.acquire()
+    maybe_info = _maybe_get_ring(name)

-    if info:
-        log.info(f'ring {name} exists, {caller} attached')
-
-        await ctx.started(os.getpid())
-
-        async with ctx.open_stream() as stream:
-            await stream.receive()
-
-        info.unlink.set()
-
-        log.info(f'{caller} detached from ring {name}')
-
-        return
-
-    if must_exist:
-        raise RingNotFound(
-            f'Tried to open_ringbuf but it doesn\'t exist: {name}'
+    if maybe_info:
+        raise RuntimeError(
+            f'Tried to create ringbuf but it already exists: {name}'
         )

+    fullname = _make_ring_name(name)
+
     with ringbuf.open_ringbuf(
-        _root_key + name,
+        fullname,
         buf_size=buf_size
     ) as token:
-        unlink_event = trio.Event()
+
         _insert_ring(
             name,
             RingInfo(
                 token=token,
                 creator=caller,
-                unlink=unlink_event,
             )
         )
-        log.info(f'ring {name} created by {caller}')
-        await ctx.started(os.getpid())

-        async with ctx.open_stream() as stream:
-            await stream.receive()
+        _ring_lock.release()

-        await unlink_event.wait()
-        _destroy_ring(name)
+        # yield full ring name to rebuild token after fd passing
+        await ctx.started(fullname)

-        log.info(f'ring {name} destroyed by {caller}')
+        # await ctx cancel to remove ring from tracking and cleanup
+        try:
+            log.info(f'ring {name} created by {caller}')
+            await trio.sleep_forever()
+
+        finally:
+            _destroy_ring(name)
+
+            log.info(f'ring {name} destroyed by {caller}')
+
+
+@tractor.context
+async def _attach_ringbuf(
+    ctx: tractor.Context,
+    caller: str,
+    name: str
+) -> str:
+    '''
+    Ringd endpoint to "attach" to an existing ring; just ensures the ring
+    actually exists and returns its full name.
+
+    '''
+    async with _ring_lock:
+        info = _maybe_get_ring(name)
+
+        if not info:
+            raise RuntimeError(
+                f'{caller} tried to open_ringbuf but it doesn\'t exist: {name}'
+            )
+
+    await ctx.started()
+
+    # return full ring name to rebuild token after fd passing
+    return info.token.shm_name
+
+
+@tractor.context
+async def _maybe_open_ringbuf(
+    ctx: tractor.Context,
+    caller: str,
+    name: str,
+    buf_size: int = 10 * 1024,
+):
+    '''
+    If the ring already exists attach to it, if not create it.
+
+    '''
+    maybe_info = _maybe_get_ring(name)
+
+    if maybe_info:
+        return await _attach_ringbuf(ctx, caller, name)
+
+    return await _open_ringbuf(ctx, caller, name, buf_size=buf_size)
+
+
+'''
+Ringd client side helpers
+
+'''


 @acm
 async def open_ringd(**kwargs) -> tractor.Portal:
+    '''
+    Spawn new ringd actor.
+
+    '''
     async with tractor.open_nursery(**kwargs) as an:
         portal = await an.start_actor(
             _ringd_actor_name,
@@ -210,21 +263,69 @@

 @acm
 async def wait_for_ringd() -> tractor.Portal:
+    '''
+    Wait for ringd actor to be up.
+
+    '''
     async with tractor.wait_for_actor(
         _ringd_actor_name
     ) as portal:
         yield portal


+async def _request_ring_fds(
+    fullname: str
+) -> RBToken:
+    '''
+    Private helper to fetch ring fds from ringd actor.
+
+    '''
+    actor = tractor.current_actor()
+
+    fd_amount = 3
+    sock_path = str(
+        Path(tempfile.gettempdir())
+        /
+        f'{fullname}-to-{actor.name}.sock'
+    )
+
+    log.info(f'trying to attach to {fullname}...')
+
+    async with (
+        tractor.find_actor(_ringd_actor_name) as ringd,
+
+        ringd.open_context(
+            _pass_fds,
+            name=fullname,
+            sock_path=sock_path
+        ) as (ctx, token),
+    ):
+        fds = await recv_fds(sock_path, fd_amount)
+        write, wrap, eof = fds
+        log.info(
+            f'received fds, write: {write}, wrap: {wrap}, eof: {eof}'
+        )
+
+        token = RBToken.from_msg(token)
+
+        return RBToken(
+            shm_name=fullname,
+            write_eventfd=write,
+            wrap_eventfd=wrap,
+            eof_eventfd=eof,
+            buf_size=token.buf_size
+        )
+
+
 @acm
 async def open_ringbuf(
-
     name: str,
     buf_size: int = 10 * 1024,
+) -> AsyncContextManager[RBToken]:
+    '''
+    Create a new ring and retrieve its fds.

-    must_exist: bool = False,
-
-) -> RBToken:
+    '''
     actor = tractor.current_actor()
     async with (
         wait_for_ringd() as ringd,
@@ -234,12 +335,67 @@ async def open_ringbuf(
             caller=actor.name,
             name=name,
             buf_size=buf_size,
-            must_exist=must_exist
-        ) as (rd_ctx, ringd_pid),
-
-        rd_ctx.open_stream() as _stream,
+        ) as (ctx, fullname),
     ):
-        token = await _attach_to_ring(ringd_pid, name)
-        log.info(f'attached to {token}')
+        token = await _request_ring_fds(fullname)
+        log.info(f'{actor.name} opened {token}')
+
+        try:
+            yield token
+
+        finally:
+            with trio.CancelScope(shield=True):
+                await ctx.cancel()
+
+
+@acm
+async def attach_ringbuf(
+    name: str,
+) -> AsyncContextManager[RBToken]:
+    '''
+    Attach to an existing ring and retrieve its fds.
+
+    '''
+    actor = tractor.current_actor()
+    async with (
+        wait_for_ringd() as ringd,
+
+        ringd.open_context(
+            _attach_ringbuf,
+            caller=actor.name,
+            name=name,
+        ) as (ctx, _),
+    ):
+        fullname = await ctx.wait_for_result()
+        token = await _request_ring_fds(fullname)
+        log.info(f'{actor.name} attached {token}')
         yield token
+
+
+@acm
+async def maybe_open_ringbuf(
+    name: str,
+    buf_size: int = 10 * 1024,
+) -> AsyncContextManager[RBToken]:
+    '''
+    Attach to or create a ring and retrieve its fds.
+
+    '''
+    actor = tractor.current_actor()
+    async with (
+        wait_for_ringd() as ringd,

+        ringd.open_context(
+            _maybe_open_ringbuf,
+            caller=actor.name,
+            name=name,
+            buf_size=buf_size,
+        ) as (ctx, fullname),
+    ):
+        token = await _request_ring_fds(fullname)
+        log.info(f'{actor.name} opened {token}')
+
+        try:
+            yield token
+
+        finally:
+            with trio.CancelScope(shield=True):
+                await ctx.cancel()
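
Usage notes follow; these are reviewer sketches against the patched APIs, not
part of the diff.

The reworked ringd client API composes as below. A minimal sketch: the
attach-helper import path is assumed from the parent `tractor.ipc._ringbuf`
package (as used by `_pubsub.py`), the ring name 'demo' is a placeholder, and
both ends are wired in a single task purely for brevity (the test suite drives
the two sides from separate subactors):

    import trio
    import tractor.ipc._ringbuf._ringd as ringd
    from tractor.ipc._ringbuf import (
        attach_to_ringbuf_sender,
        attach_to_ringbuf_receiver,
    )


    async def main():
        async with (
            # spawn the `ringd` actor in this tree
            ringd.open_ringd(),

            # creator side: allocate the ring through ringd; it stays
            # tracked for the lifetime of this scope and is destroyed on
            # exit (via the shielded ctx.cancel() in open_ringbuf)
            ringd.open_ringbuf('demo', buf_size=10 * 1024) as stoken,
            attach_to_ringbuf_sender(stoken) as sender,

            # attach side: a plain lookup plus fd transfer which raises
            # inside ringd if 'demo' was never opened; use
            # maybe_open_ringbuf() to attach-or-create instead
            ringd.attach_ringbuf('demo') as rtoken,
            attach_to_ringbuf_receiver(rtoken) as receiver,
        ):
            await sender.send(b'hello')
            print(await receiver.receive())


    trio.run(main)

Since attaching no longer holds a long-running ringd context, only the
creator's scope controls when the ring is unlinked, which is why the test's
send_child switched from open_ringbuf to attach_ringbuf.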
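
The publisher side with the new knobs; 'publisher', 'ring-a' etc. are
placeholder actor and ring names, a ringd actor is assumed to be running, and
the remote-control helper requires the target actor to expose this module
(e.g. via enable_modules) so the open_pub_channel endpoint can reach the
module-global publisher:

    import trio
    from tractor.ipc._ringbuf._pubsub import (
        open_ringbuf_publisher,
        open_pub_channel_at,
    )


    async def publisher_side():
        # runs inside the 'publisher' actor; set_module_var=True (the
        # default) stores the instance so remote actors can add and
        # remove channels through the open_pub_channel endpoint
        async with open_ringbuf_publisher(
            batch_size=1,
            msgs_per_turn=2,       # sends per ring before switching turns
            guarantee_order=True,  # wrap payloads with an index
        ) as pub:
            await pub.add_channel('ring-a')
            await pub.add_channel('ring-b')

            # distributed across rings, switching turns per msgs_per_turn
            for i in range(8):
                await pub.send(f'msg {i}'.encode())

            # delivered to every connected ring regardless of turn
            await pub.broadcast(b'to-everyone')


    async def remote_controller():
        # from any other actor: plug 'ring-c' into the publisher running
        # in the 'publisher' actor, unplugging it again on scope exit
        async with open_pub_channel_at('publisher', 'ring-c'):
            await trio.sleep(1)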
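
And the mirror-image subscriber side, fanning all added rings into a single
receive stream; again the names are placeholders and here the subscriber
attaches to rings the publisher already created (must_exist=True), the reverse
of the test's arrangement where the subscriber side creates them:

    import trio
    from tractor.ipc._ringbuf._pubsub import (
        open_ringbuf_subscriber,
        open_sub_channel_at,
    )


    async def subscriber_side():
        # guarantee_order=True must match the publisher since it patches
        # .receive() to unwrap the indexed payloads in order
        async with open_ringbuf_subscriber(
            guarantee_order=True,
        ) as subs:
            await subs.add_channel('ring-a', must_exist=True)
            await subs.add_channel('ring-b', must_exist=True)

            async for msg in subs:
                print(msg)


    async def remote_controller():
        # symmetric remote endpoint for the subscriber side
        async with open_sub_channel_at('subscriber', 'ring-c'):
            await trio.sleep(1)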