Add trio resource semantics to ring pubsub

one_ring_to_rule_them_all
Guillermo Rodriguez 2025-04-06 21:19:39 -03:00
parent 853aa740aa
commit 171545e4fb
No known key found for this signature in database
GPG Key ID: 002CC5F1E6BDA53E
1 changed files with 113 additions and 66 deletions

View File

@ -58,7 +58,7 @@ class ChannelInfo:
cancel_scope: trio.CancelScope cancel_scope: trio.CancelScope
class ChannelManager(trio.abc.AsyncResource, Generic[ChannelType]): class ChannelManager(Generic[ChannelType]):
''' '''
Helper for managing channel resources and their handler tasks with Helper for managing channel resources and their handler tasks with
cancellation, add or remove channels dynamically! cancellation, add or remove channels dynamically!
@ -89,18 +89,15 @@ class ChannelManager(trio.abc.AsyncResource, Generic[ChannelType]):
# methods that modify self._channels should be ordered by FIFO # methods that modify self._channels should be ordered by FIFO
self._lock = trio.StrictFIFOLock() self._lock = trio.StrictFIFOLock()
@acm self._is_closed: bool = True
async def maybe_lock(self):
'''
If lock is not held, acquire
''' @property
if self._lock.locked(): def closed(self) -> bool:
yield return self._is_closed
return
async with self._lock: @property
yield def lock(self) -> trio.StrictFIFOLock:
return self._lock
@property @property
def channels(self) -> list[ChannelInfo]: def channels(self) -> list[ChannelInfo]:
@ -139,7 +136,7 @@ class ChannelManager(trio.abc.AsyncResource, Generic[ChannelType]):
with cancel_scope: with cancel_scope:
await self._channel_task(info) await self._channel_task(info)
await self._maybe_destroy_channel(name) self._maybe_destroy_channel(name)
def _find_channel(self, name: str) -> tuple[int, ChannelInfo] | None: def _find_channel(self, name: str) -> tuple[int, ChannelInfo] | None:
''' '''
@ -156,13 +153,12 @@ class ChannelManager(trio.abc.AsyncResource, Generic[ChannelType]):
return None return None
async def _maybe_destroy_channel(self, name: str): def _maybe_destroy_channel(self, name: str):
''' '''
If channel exists cancel its scope and remove from internal If channel exists cancel its scope and remove from internal
_channels list. _channels list.
''' '''
async with self.maybe_lock():
maybe_entry = self._find_channel(name) maybe_entry = self._find_channel(name)
if maybe_entry: if maybe_entry:
i, info = maybe_entry i, info = maybe_entry
@ -174,7 +170,10 @@ class ChannelManager(trio.abc.AsyncResource, Generic[ChannelType]):
Add a new channel to be handled Add a new channel to be handled
''' '''
async with self.maybe_lock(): if self.closed:
raise trio.ClosedResourceError
async with self._lock:
await self._n.start(partial( await self._n.start(partial(
self._channel_handler_task, self._channel_handler_task,
name, name,
@ -186,8 +185,11 @@ class ChannelManager(trio.abc.AsyncResource, Generic[ChannelType]):
Remove a channel and stop its handling Remove a channel and stop its handling
''' '''
async with self.maybe_lock(): if self.closed:
await self._maybe_destroy_channel(name) raise trio.ClosedResourceError
async with self._lock:
self._maybe_destroy_channel(name)
# if that was last channel reset connect event # if that was last channel reset connect event
if len(self) == 0: if len(self) == 0:
@ -198,6 +200,9 @@ class ChannelManager(trio.abc.AsyncResource, Generic[ChannelType]):
Wait until at least one channel added Wait until at least one channel added
''' '''
if self.closed:
raise trio.ClosedResourceError
await self._connect_event.wait() await self._connect_event.wait()
self._connect_event = trio.Event() self._connect_event = trio.Event()
@ -212,11 +217,19 @@ class ChannelManager(trio.abc.AsyncResource, Generic[ChannelType]):
raise KeyError(f'Channel {name} not found!') raise KeyError(f'Channel {name} not found!')
async def aclose(self) -> None: def open(self):
async with self.maybe_lock(): self._is_closed = False
async def close(self) -> None:
if self.closed:
log.warning('tried to close ChannelManager but its already closed...')
return
for info in self._channels: for info in self._channels:
await self.remove_channel(info.name) await self.remove_channel(info.name)
self._is_closed = True
''' '''
Ring buffer publisher & subscribe pattern mediated by `ringd` actor. Ring buffer publisher & subscribe pattern mediated by `ringd` actor.
@ -264,6 +277,12 @@ class RingBufferPublisher(trio.abc.SendChannel[bytes]):
self._next_turn: int = 0 self._next_turn: int = 0
self._is_closed: bool = True
@property
def closed(self) -> bool:
return self._is_closed
@property @property
def batch_size(self) -> int: def batch_size(self) -> int:
return self._batch_size return self._batch_size
@ -289,35 +308,9 @@ class RingBufferPublisher(trio.abc.SendChannel[bytes]):
name: str, name: str,
must_exist: bool = False must_exist: bool = False
): ):
'''
Store additional runtime info for channel and add channel to underlying
ChannelManager
'''
await self._chanmngr.add_channel(name, must_exist=must_exist) await self._chanmngr.add_channel(name, must_exist=must_exist)
async def remove_channel(self, name: str): async def remove_channel(self, name: str):
'''
Send EOF to channel (does `channel.flush` also) then remove from
`ChannelManager` acquire both `self._send_lock` and
`self._chanmngr.maybe_lock()` in order to ensure no channel
modifications or sends happen concurrenty
'''
async with self._chanmngr.maybe_lock():
# ensure all pending messages are sent
info = self.get_channel(name)
try:
while True:
msg = info.channel.rchan.receive_nowait()
await info.channel.ring.send(msg)
except trio.WouldBlock:
await info.channel.ring.flush()
await info.channel.schan.aclose()
# finally remove from ChannelManager
await self._chanmngr.remove_channel(name) await self._chanmngr.remove_channel(name)
@acm @acm
@ -345,6 +338,13 @@ class RingBufferPublisher(trio.abc.SendChannel[bytes]):
schan=schan, schan=schan,
rchan=rchan rchan=rchan
) )
try:
while True:
msg = rchan.receive_nowait()
await ring.send(msg)
except trio.WouldBlock:
...
async def _channel_task(self, info: ChannelInfo) -> None: async def _channel_task(self, info: ChannelInfo) -> None:
''' '''
@ -352,9 +352,13 @@ class RingBufferPublisher(trio.abc.SendChannel[bytes]):
payloads update event then drain all into send channel. payloads update event then drain all into send channel.
''' '''
try:
async for msg in info.channel.rchan: async for msg in info.channel.rchan:
await info.channel.ring.send(msg) await info.channel.ring.send(msg)
except trio.Cancelled:
...
async def send(self, msg: bytes): async def send(self, msg: bytes):
''' '''
If no output channels connected, wait until one, then fetch the next If no output channels connected, wait until one, then fetch the next
@ -365,6 +369,12 @@ class RingBufferPublisher(trio.abc.SendChannel[bytes]):
variables dont happen out of order. variables dont happen out of order.
''' '''
if self.closed:
raise trio.ClosedResourceError
if self._send_lock.locked():
raise trio.BusyResourceError
async with self._send_lock: async with self._send_lock:
# wait at least one decoder connected # wait at least one decoder connected
if len(self.channels) == 0: if len(self.channels) == 0:
@ -378,8 +388,23 @@ class RingBufferPublisher(trio.abc.SendChannel[bytes]):
self._next_turn += 1 self._next_turn += 1
async def flush(self, new_batch_size: int | None = None):
async with self._chanmngr.lock:
for info in self.channels:
await info.channel.ring.flush(new_batch_size=new_batch_size)
async def __aenter__(self):
self._chanmngr.open()
self._is_closed = False
return self
async def aclose(self) -> None: async def aclose(self) -> None:
await self._chanmngr.aclose() if self.closed:
log.warning('tried to close RingBufferPublisher but its already closed...')
return
await self._chanmngr.close()
self._is_closed = True
@acm @acm
@ -445,6 +470,14 @@ class RingBufferSubscriber(trio.abc.ReceiveChannel[bytes]):
self._schan, self._rchan = trio.open_memory_channel(0) self._schan, self._rchan = trio.open_memory_channel(0)
self._is_closed: bool = True
self._receive_lock = trio.StrictFIFOLock()
@property
def closed(self) -> bool:
return self._is_closed
@property @property
def channels(self) -> list[ChannelInfo]: def channels(self) -> list[ChannelInfo]:
return self._chanmngr.channels return self._chanmngr.channels
@ -453,17 +486,9 @@ class RingBufferSubscriber(trio.abc.ReceiveChannel[bytes]):
return self._chanmngr[name] return self._chanmngr[name]
async def add_channel(self, name: str, must_exist: bool = False): async def add_channel(self, name: str, must_exist: bool = False):
'''
Add new input channel by name
'''
await self._chanmngr.add_channel(name, must_exist=must_exist) await self._chanmngr.add_channel(name, must_exist=must_exist)
async def remove_channel(self, name: str): async def remove_channel(self, name: str):
'''
Remove an input channel by name
'''
await self._chanmngr.remove_channel(name) await self._chanmngr.remove_channel(name)
@acm @acm
@ -506,14 +531,36 @@ class RingBufferSubscriber(trio.abc.ReceiveChannel[bytes]):
except trio.EndOfChannel: except trio.EndOfChannel:
break break
except trio.ClosedResourceError:
break
async def receive(self) -> bytes: async def receive(self) -> bytes:
''' '''
Receive next in order msg Receive next in order msg
''' '''
if self.closed:
raise trio.ClosedResourceError
if self._receive_lock.locked():
raise trio.BusyResourceError
async with self._receive_lock:
return await self._rchan.receive() return await self._rchan.receive()
async def __aenter__(self):
self._is_closed = False
self._chanmngr.open()
return self
async def aclose(self) -> None: async def aclose(self) -> None:
await self._chanmngr.aclose() if self.closed:
log.warning('tried to close RingBufferSubscriber but its already closed...')
return
await self._chanmngr.close()
await self._schan.aclose()
await self._rchan.aclose()
self._is_closed = True
@acm @acm
async def open_ringbuf_subscriber( async def open_ringbuf_subscriber(