Add trio resource semantics to ring pubsub
							parent
							
								
									7b668c2f33
								
							
						
					
					
						commit
						a15b852b18
					
				|  | @ -58,7 +58,7 @@ class ChannelInfo: | ||||||
|     cancel_scope: trio.CancelScope |     cancel_scope: trio.CancelScope | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class ChannelManager(trio.abc.AsyncResource, Generic[ChannelType]): | class ChannelManager(Generic[ChannelType]): | ||||||
|     ''' |     ''' | ||||||
|     Helper for managing channel resources and their handler tasks with |     Helper for managing channel resources and their handler tasks with | ||||||
|     cancellation, add or remove channels dynamically! |     cancellation, add or remove channels dynamically! | ||||||
|  | @ -89,18 +89,15 @@ class ChannelManager(trio.abc.AsyncResource, Generic[ChannelType]): | ||||||
|         # methods that modify self._channels should be ordered by FIFO |         # methods that modify self._channels should be ordered by FIFO | ||||||
|         self._lock = trio.StrictFIFOLock() |         self._lock = trio.StrictFIFOLock() | ||||||
| 
 | 
 | ||||||
|     @acm |         self._is_closed: bool = True | ||||||
|     async def maybe_lock(self): |  | ||||||
|         ''' |  | ||||||
|         If lock is not held, acquire |  | ||||||
| 
 | 
 | ||||||
|         ''' |     @property | ||||||
|         if self._lock.locked(): |     def closed(self) -> bool: | ||||||
|             yield |         return self._is_closed | ||||||
|             return |  | ||||||
| 
 | 
 | ||||||
|         async with self._lock: |     @property | ||||||
|             yield |     def lock(self) -> trio.StrictFIFOLock: | ||||||
|  |         return self._lock | ||||||
| 
 | 
 | ||||||
|     @property |     @property | ||||||
|     def channels(self) -> list[ChannelInfo]: |     def channels(self) -> list[ChannelInfo]: | ||||||
|  | @ -139,7 +136,7 @@ class ChannelManager(trio.abc.AsyncResource, Generic[ChannelType]): | ||||||
|             with cancel_scope: |             with cancel_scope: | ||||||
|                 await self._channel_task(info) |                 await self._channel_task(info) | ||||||
| 
 | 
 | ||||||
|         await self._maybe_destroy_channel(name) |         self._maybe_destroy_channel(name) | ||||||
| 
 | 
 | ||||||
|     def _find_channel(self, name: str) -> tuple[int, ChannelInfo] | None: |     def _find_channel(self, name: str) -> tuple[int, ChannelInfo] | None: | ||||||
|         ''' |         ''' | ||||||
|  | @ -156,13 +153,12 @@ class ChannelManager(trio.abc.AsyncResource, Generic[ChannelType]): | ||||||
|         return None |         return None | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|     async def _maybe_destroy_channel(self, name: str): |     def _maybe_destroy_channel(self, name: str): | ||||||
|         ''' |         ''' | ||||||
|         If channel exists cancel its scope and remove from internal |         If channel exists cancel its scope and remove from internal | ||||||
|         _channels list. |         _channels list. | ||||||
| 
 | 
 | ||||||
|         ''' |         ''' | ||||||
|         async with self.maybe_lock(): |  | ||||||
|         maybe_entry = self._find_channel(name) |         maybe_entry = self._find_channel(name) | ||||||
|         if maybe_entry: |         if maybe_entry: | ||||||
|             i, info = maybe_entry |             i, info = maybe_entry | ||||||
|  | @ -174,7 +170,10 @@ class ChannelManager(trio.abc.AsyncResource, Generic[ChannelType]): | ||||||
|         Add a new channel to be handled |         Add a new channel to be handled | ||||||
| 
 | 
 | ||||||
|         ''' |         ''' | ||||||
|         async with self.maybe_lock(): |         if self.closed: | ||||||
|  |             raise trio.ClosedResourceError | ||||||
|  | 
 | ||||||
|  |         async with self._lock: | ||||||
|             await self._n.start(partial( |             await self._n.start(partial( | ||||||
|                 self._channel_handler_task, |                 self._channel_handler_task, | ||||||
|                 name, |                 name, | ||||||
|  | @ -186,8 +185,11 @@ class ChannelManager(trio.abc.AsyncResource, Generic[ChannelType]): | ||||||
|         Remove a channel and stop its handling |         Remove a channel and stop its handling | ||||||
| 
 | 
 | ||||||
|         ''' |         ''' | ||||||
|         async with self.maybe_lock(): |         if self.closed: | ||||||
|             await self._maybe_destroy_channel(name) |             raise trio.ClosedResourceError | ||||||
|  | 
 | ||||||
|  |         async with self._lock: | ||||||
|  |             self._maybe_destroy_channel(name) | ||||||
| 
 | 
 | ||||||
|             # if that was last channel reset connect event |             # if that was last channel reset connect event | ||||||
|             if len(self) == 0: |             if len(self) == 0: | ||||||
|  | @ -198,6 +200,9 @@ class ChannelManager(trio.abc.AsyncResource, Generic[ChannelType]): | ||||||
|         Wait until at least one channel added |         Wait until at least one channel added | ||||||
| 
 | 
 | ||||||
|         ''' |         ''' | ||||||
|  |         if self.closed: | ||||||
|  |             raise trio.ClosedResourceError | ||||||
|  | 
 | ||||||
|         await self._connect_event.wait() |         await self._connect_event.wait() | ||||||
|         self._connect_event = trio.Event() |         self._connect_event = trio.Event() | ||||||
| 
 | 
 | ||||||
|  | @ -212,11 +217,19 @@ class ChannelManager(trio.abc.AsyncResource, Generic[ChannelType]): | ||||||
| 
 | 
 | ||||||
|         raise KeyError(f'Channel {name} not found!') |         raise KeyError(f'Channel {name} not found!') | ||||||
| 
 | 
 | ||||||
|     async def aclose(self) -> None: |     def open(self): | ||||||
|         async with self.maybe_lock(): |         self._is_closed = False | ||||||
|  | 
 | ||||||
|  |     async def close(self) -> None: | ||||||
|  |         if self.closed: | ||||||
|  |             log.warning('tried to close ChannelManager but its already closed...') | ||||||
|  |             return | ||||||
|  | 
 | ||||||
|         for info in self._channels: |         for info in self._channels: | ||||||
|             await self.remove_channel(info.name) |             await self.remove_channel(info.name) | ||||||
| 
 | 
 | ||||||
|  |         self._is_closed = True | ||||||
|  | 
 | ||||||
| 
 | 
 | ||||||
| ''' | ''' | ||||||
| Ring buffer publisher & subscribe pattern mediated by `ringd` actor. | Ring buffer publisher & subscribe pattern mediated by `ringd` actor. | ||||||
|  | @ -264,6 +277,12 @@ class RingBufferPublisher(trio.abc.SendChannel[bytes]): | ||||||
| 
 | 
 | ||||||
|         self._next_turn: int = 0 |         self._next_turn: int = 0 | ||||||
| 
 | 
 | ||||||
|  |         self._is_closed: bool = True | ||||||
|  | 
 | ||||||
|  |     @property | ||||||
|  |     def closed(self) -> bool: | ||||||
|  |         return self._is_closed | ||||||
|  | 
 | ||||||
|     @property |     @property | ||||||
|     def batch_size(self) -> int: |     def batch_size(self) -> int: | ||||||
|         return self._batch_size |         return self._batch_size | ||||||
|  | @ -289,35 +308,9 @@ class RingBufferPublisher(trio.abc.SendChannel[bytes]): | ||||||
|         name: str, |         name: str, | ||||||
|         must_exist: bool = False |         must_exist: bool = False | ||||||
|     ): |     ): | ||||||
|         ''' |  | ||||||
|         Store additional runtime info for channel and add channel to underlying |  | ||||||
|         ChannelManager |  | ||||||
| 
 |  | ||||||
|         ''' |  | ||||||
|         await self._chanmngr.add_channel(name, must_exist=must_exist) |         await self._chanmngr.add_channel(name, must_exist=must_exist) | ||||||
| 
 | 
 | ||||||
|     async def remove_channel(self, name: str): |     async def remove_channel(self, name: str): | ||||||
|         ''' |  | ||||||
|         Send EOF to channel (does `channel.flush` also) then remove from |  | ||||||
|         `ChannelManager` acquire both `self._send_lock` and |  | ||||||
|         `self._chanmngr.maybe_lock()` in order to ensure no channel |  | ||||||
|         modifications or sends happen concurrenty |  | ||||||
|         ''' |  | ||||||
|         async with self._chanmngr.maybe_lock(): |  | ||||||
|             # ensure all pending messages are sent |  | ||||||
|             info = self.get_channel(name) |  | ||||||
| 
 |  | ||||||
|             try: |  | ||||||
|                 while True: |  | ||||||
|                     msg = info.channel.rchan.receive_nowait() |  | ||||||
|                     await info.channel.ring.send(msg) |  | ||||||
| 
 |  | ||||||
|             except trio.WouldBlock: |  | ||||||
|                 await info.channel.ring.flush() |  | ||||||
| 
 |  | ||||||
|             await info.channel.schan.aclose() |  | ||||||
| 
 |  | ||||||
|             # finally remove from ChannelManager |  | ||||||
|         await self._chanmngr.remove_channel(name) |         await self._chanmngr.remove_channel(name) | ||||||
| 
 | 
 | ||||||
|     @acm |     @acm | ||||||
|  | @ -345,6 +338,13 @@ class RingBufferPublisher(trio.abc.SendChannel[bytes]): | ||||||
|                 schan=schan, |                 schan=schan, | ||||||
|                 rchan=rchan |                 rchan=rchan | ||||||
|             ) |             ) | ||||||
|  |             try: | ||||||
|  |                 while True: | ||||||
|  |                     msg = rchan.receive_nowait() | ||||||
|  |                     await ring.send(msg) | ||||||
|  | 
 | ||||||
|  |             except trio.WouldBlock: | ||||||
|  |                 ... | ||||||
| 
 | 
 | ||||||
|     async def _channel_task(self, info: ChannelInfo) -> None: |     async def _channel_task(self, info: ChannelInfo) -> None: | ||||||
|         ''' |         ''' | ||||||
|  | @ -352,9 +352,13 @@ class RingBufferPublisher(trio.abc.SendChannel[bytes]): | ||||||
|         payloads update event then drain all into send channel. |         payloads update event then drain all into send channel. | ||||||
| 
 | 
 | ||||||
|         ''' |         ''' | ||||||
|  |         try: | ||||||
|             async for msg in info.channel.rchan: |             async for msg in info.channel.rchan: | ||||||
|                 await info.channel.ring.send(msg) |                 await info.channel.ring.send(msg) | ||||||
| 
 | 
 | ||||||
|  |         except trio.Cancelled: | ||||||
|  |             ... | ||||||
|  | 
 | ||||||
|     async def send(self, msg: bytes): |     async def send(self, msg: bytes): | ||||||
|         ''' |         ''' | ||||||
|         If no output channels connected, wait until one, then fetch the next |         If no output channels connected, wait until one, then fetch the next | ||||||
|  | @ -365,6 +369,12 @@ class RingBufferPublisher(trio.abc.SendChannel[bytes]): | ||||||
|         variables dont happen out of order. |         variables dont happen out of order. | ||||||
| 
 | 
 | ||||||
|         ''' |         ''' | ||||||
|  |         if self.closed: | ||||||
|  |             raise trio.ClosedResourceError | ||||||
|  | 
 | ||||||
|  |         if self._send_lock.locked(): | ||||||
|  |             raise trio.BusyResourceError | ||||||
|  | 
 | ||||||
|         async with self._send_lock: |         async with self._send_lock: | ||||||
|             # wait at least one decoder connected |             # wait at least one decoder connected | ||||||
|             if len(self.channels) == 0: |             if len(self.channels) == 0: | ||||||
|  | @ -378,8 +388,23 @@ class RingBufferPublisher(trio.abc.SendChannel[bytes]): | ||||||
| 
 | 
 | ||||||
|             self._next_turn += 1 |             self._next_turn += 1 | ||||||
| 
 | 
 | ||||||
|  |     async def flush(self, new_batch_size: int | None = None): | ||||||
|  |         async with self._chanmngr.lock: | ||||||
|  |             for info in self.channels: | ||||||
|  |                 await info.channel.ring.flush(new_batch_size=new_batch_size) | ||||||
|  | 
 | ||||||
|  |     async def __aenter__(self): | ||||||
|  |         self._chanmngr.open() | ||||||
|  |         self._is_closed = False | ||||||
|  |         return self | ||||||
|  | 
 | ||||||
|     async def aclose(self) -> None: |     async def aclose(self) -> None: | ||||||
|         await self._chanmngr.aclose() |         if self.closed: | ||||||
|  |             log.warning('tried to close RingBufferPublisher but its already closed...') | ||||||
|  |             return | ||||||
|  | 
 | ||||||
|  |         await self._chanmngr.close() | ||||||
|  |         self._is_closed = True | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @acm | @acm | ||||||
|  | @ -445,6 +470,14 @@ class RingBufferSubscriber(trio.abc.ReceiveChannel[bytes]): | ||||||
| 
 | 
 | ||||||
|         self._schan, self._rchan = trio.open_memory_channel(0) |         self._schan, self._rchan = trio.open_memory_channel(0) | ||||||
| 
 | 
 | ||||||
|  |         self._is_closed: bool = True | ||||||
|  | 
 | ||||||
|  |         self._receive_lock = trio.StrictFIFOLock() | ||||||
|  | 
 | ||||||
|  |     @property | ||||||
|  |     def closed(self) -> bool: | ||||||
|  |         return self._is_closed | ||||||
|  | 
 | ||||||
|     @property |     @property | ||||||
|     def channels(self) -> list[ChannelInfo]: |     def channels(self) -> list[ChannelInfo]: | ||||||
|         return self._chanmngr.channels |         return self._chanmngr.channels | ||||||
|  | @ -453,17 +486,9 @@ class RingBufferSubscriber(trio.abc.ReceiveChannel[bytes]): | ||||||
|         return self._chanmngr[name] |         return self._chanmngr[name] | ||||||
| 
 | 
 | ||||||
|     async def add_channel(self, name: str, must_exist: bool = False): |     async def add_channel(self, name: str, must_exist: bool = False): | ||||||
|         ''' |  | ||||||
|         Add new input channel by name |  | ||||||
| 
 |  | ||||||
|         ''' |  | ||||||
|         await self._chanmngr.add_channel(name, must_exist=must_exist) |         await self._chanmngr.add_channel(name, must_exist=must_exist) | ||||||
| 
 | 
 | ||||||
|     async def remove_channel(self, name: str): |     async def remove_channel(self, name: str): | ||||||
|         ''' |  | ||||||
|         Remove an input channel by name |  | ||||||
| 
 |  | ||||||
|         ''' |  | ||||||
|         await self._chanmngr.remove_channel(name) |         await self._chanmngr.remove_channel(name) | ||||||
| 
 | 
 | ||||||
|     @acm |     @acm | ||||||
|  | @ -506,14 +531,36 @@ class RingBufferSubscriber(trio.abc.ReceiveChannel[bytes]): | ||||||
|             except trio.EndOfChannel: |             except trio.EndOfChannel: | ||||||
|                 break |                 break | ||||||
| 
 | 
 | ||||||
|  |             except trio.ClosedResourceError: | ||||||
|  |                 break | ||||||
|  | 
 | ||||||
|     async def receive(self) -> bytes: |     async def receive(self) -> bytes: | ||||||
|         ''' |         ''' | ||||||
|         Receive next in order msg |         Receive next in order msg | ||||||
|         ''' |         ''' | ||||||
|  |         if self.closed: | ||||||
|  |             raise trio.ClosedResourceError | ||||||
|  | 
 | ||||||
|  |         if self._receive_lock.locked(): | ||||||
|  |             raise trio.BusyResourceError | ||||||
|  | 
 | ||||||
|  |         async with self._receive_lock: | ||||||
|             return await self._rchan.receive() |             return await self._rchan.receive() | ||||||
| 
 | 
 | ||||||
|  |     async def __aenter__(self): | ||||||
|  |         self._is_closed = False | ||||||
|  |         self._chanmngr.open() | ||||||
|  |         return self | ||||||
|  | 
 | ||||||
|     async def aclose(self) -> None: |     async def aclose(self) -> None: | ||||||
|         await self._chanmngr.aclose() |         if self.closed: | ||||||
|  |             log.warning('tried to close RingBufferSubscriber but its already closed...') | ||||||
|  |             return | ||||||
|  | 
 | ||||||
|  |         await self._chanmngr.close() | ||||||
|  |         await self._schan.aclose() | ||||||
|  |         await self._rchan.aclose() | ||||||
|  |         self._is_closed = True | ||||||
| 
 | 
 | ||||||
| @acm | @acm | ||||||
| async def open_ringbuf_subscriber( | async def open_ringbuf_subscriber( | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue