diff --git a/tractor/ipc/_ringbuf/_pubsub.py b/tractor/ipc/_ringbuf/_pubsub.py index 4d5e0d20..37d54308 100644 --- a/tractor/ipc/_ringbuf/_pubsub.py +++ b/tractor/ipc/_ringbuf/_pubsub.py @@ -58,7 +58,7 @@ class ChannelInfo: cancel_scope: trio.CancelScope -class ChannelManager(trio.abc.AsyncResource, Generic[ChannelType]): +class ChannelManager(Generic[ChannelType]): ''' Helper for managing channel resources and their handler tasks with cancellation, add or remove channels dynamically! @@ -89,18 +89,15 @@ class ChannelManager(trio.abc.AsyncResource, Generic[ChannelType]): # methods that modify self._channels should be ordered by FIFO self._lock = trio.StrictFIFOLock() - @acm - async def maybe_lock(self): - ''' - If lock is not held, acquire + self._is_closed: bool = True - ''' - if self._lock.locked(): - yield - return + @property + def closed(self) -> bool: + return self._is_closed - async with self._lock: - yield + @property + def lock(self) -> trio.StrictFIFOLock: + return self._lock @property def channels(self) -> list[ChannelInfo]: @@ -139,7 +136,7 @@ class ChannelManager(trio.abc.AsyncResource, Generic[ChannelType]): with cancel_scope: await self._channel_task(info) - await self._maybe_destroy_channel(name) + self._maybe_destroy_channel(name) def _find_channel(self, name: str) -> tuple[int, ChannelInfo] | None: ''' @@ -156,25 +153,27 @@ class ChannelManager(trio.abc.AsyncResource, Generic[ChannelType]): return None - async def _maybe_destroy_channel(self, name: str): + def _maybe_destroy_channel(self, name: str): ''' If channel exists cancel its scope and remove from internal _channels list. ''' - async with self.maybe_lock(): - maybe_entry = self._find_channel(name) - if maybe_entry: - i, info = maybe_entry - info.cancel_scope.cancel() - del self._channels[i] + maybe_entry = self._find_channel(name) + if maybe_entry: + i, info = maybe_entry + info.cancel_scope.cancel() + del self._channels[i] async def add_channel(self, name: str, **kwargs): ''' Add a new channel to be handled ''' - async with self.maybe_lock(): + if self.closed: + raise trio.ClosedResourceError + + async with self._lock: await self._n.start(partial( self._channel_handler_task, name, @@ -186,8 +185,11 @@ class ChannelManager(trio.abc.AsyncResource, Generic[ChannelType]): Remove a channel and stop its handling ''' - async with self.maybe_lock(): - await self._maybe_destroy_channel(name) + if self.closed: + raise trio.ClosedResourceError + + async with self._lock: + self._maybe_destroy_channel(name) # if that was last channel reset connect event if len(self) == 0: @@ -198,6 +200,9 @@ class ChannelManager(trio.abc.AsyncResource, Generic[ChannelType]): Wait until at least one channel added ''' + if self.closed: + raise trio.ClosedResourceError + await self._connect_event.wait() self._connect_event = trio.Event() @@ -212,10 +217,18 @@ class ChannelManager(trio.abc.AsyncResource, Generic[ChannelType]): raise KeyError(f'Channel {name} not found!') - async def aclose(self) -> None: - async with self.maybe_lock(): - for info in self._channels: - await self.remove_channel(info.name) + def open(self): + self._is_closed = False + + async def close(self) -> None: + if self.closed: + log.warning('tried to close ChannelManager but its already closed...') + return + + for info in self._channels: + await self.remove_channel(info.name) + + self._is_closed = True ''' @@ -264,6 +277,12 @@ class RingBufferPublisher(trio.abc.SendChannel[bytes]): self._next_turn: int = 0 + self._is_closed: bool = True + + @property + def closed(self) -> bool: + return self._is_closed + @property def batch_size(self) -> int: return self._batch_size @@ -289,36 +308,10 @@ class RingBufferPublisher(trio.abc.SendChannel[bytes]): name: str, must_exist: bool = False ): - ''' - Store additional runtime info for channel and add channel to underlying - ChannelManager - - ''' await self._chanmngr.add_channel(name, must_exist=must_exist) async def remove_channel(self, name: str): - ''' - Send EOF to channel (does `channel.flush` also) then remove from - `ChannelManager` acquire both `self._send_lock` and - `self._chanmngr.maybe_lock()` in order to ensure no channel - modifications or sends happen concurrenty - ''' - async with self._chanmngr.maybe_lock(): - # ensure all pending messages are sent - info = self.get_channel(name) - - try: - while True: - msg = info.channel.rchan.receive_nowait() - await info.channel.ring.send(msg) - - except trio.WouldBlock: - await info.channel.ring.flush() - - await info.channel.schan.aclose() - - # finally remove from ChannelManager - await self._chanmngr.remove_channel(name) + await self._chanmngr.remove_channel(name) @acm async def _open_channel( @@ -345,6 +338,13 @@ class RingBufferPublisher(trio.abc.SendChannel[bytes]): schan=schan, rchan=rchan ) + try: + while True: + msg = rchan.receive_nowait() + await ring.send(msg) + + except trio.WouldBlock: + ... async def _channel_task(self, info: ChannelInfo) -> None: ''' @@ -352,8 +352,12 @@ class RingBufferPublisher(trio.abc.SendChannel[bytes]): payloads update event then drain all into send channel. ''' - async for msg in info.channel.rchan: - await info.channel.ring.send(msg) + try: + async for msg in info.channel.rchan: + await info.channel.ring.send(msg) + + except trio.Cancelled: + ... async def send(self, msg: bytes): ''' @@ -365,6 +369,12 @@ class RingBufferPublisher(trio.abc.SendChannel[bytes]): variables dont happen out of order. ''' + if self.closed: + raise trio.ClosedResourceError + + if self._send_lock.locked(): + raise trio.BusyResourceError + async with self._send_lock: # wait at least one decoder connected if len(self.channels) == 0: @@ -378,8 +388,23 @@ class RingBufferPublisher(trio.abc.SendChannel[bytes]): self._next_turn += 1 + async def flush(self, new_batch_size: int | None = None): + async with self._chanmngr.lock: + for info in self.channels: + await info.channel.ring.flush(new_batch_size=new_batch_size) + + async def __aenter__(self): + self._chanmngr.open() + self._is_closed = False + return self + async def aclose(self) -> None: - await self._chanmngr.aclose() + if self.closed: + log.warning('tried to close RingBufferPublisher but its already closed...') + return + + await self._chanmngr.close() + self._is_closed = True @acm @@ -445,6 +470,14 @@ class RingBufferSubscriber(trio.abc.ReceiveChannel[bytes]): self._schan, self._rchan = trio.open_memory_channel(0) + self._is_closed: bool = True + + self._receive_lock = trio.StrictFIFOLock() + + @property + def closed(self) -> bool: + return self._is_closed + @property def channels(self) -> list[ChannelInfo]: return self._chanmngr.channels @@ -453,17 +486,9 @@ class RingBufferSubscriber(trio.abc.ReceiveChannel[bytes]): return self._chanmngr[name] async def add_channel(self, name: str, must_exist: bool = False): - ''' - Add new input channel by name - - ''' await self._chanmngr.add_channel(name, must_exist=must_exist) async def remove_channel(self, name: str): - ''' - Remove an input channel by name - - ''' await self._chanmngr.remove_channel(name) @acm @@ -506,14 +531,36 @@ class RingBufferSubscriber(trio.abc.ReceiveChannel[bytes]): except trio.EndOfChannel: break + except trio.ClosedResourceError: + break + async def receive(self) -> bytes: ''' Receive next in order msg ''' - return await self._rchan.receive() + if self.closed: + raise trio.ClosedResourceError + + if self._receive_lock.locked(): + raise trio.BusyResourceError + + async with self._receive_lock: + return await self._rchan.receive() + + async def __aenter__(self): + self._is_closed = False + self._chanmngr.open() + return self async def aclose(self) -> None: - await self._chanmngr.aclose() + if self.closed: + log.warning('tried to close RingBufferSubscriber but its already closed...') + return + + await self._chanmngr.close() + await self._schan.aclose() + await self._rchan.aclose() + self._is_closed = True @acm async def open_ringbuf_subscriber(