Add trio resource semantics to ring pubsub

one_ring_to_rule_them_all
Guillermo Rodriguez 2025-04-06 21:19:39 -03:00
parent 853aa740aa
commit 171545e4fb
No known key found for this signature in database
GPG Key ID: 002CC5F1E6BDA53E
1 changed files with 113 additions and 66 deletions

View File

@ -58,7 +58,7 @@ class ChannelInfo:
cancel_scope: trio.CancelScope
class ChannelManager(trio.abc.AsyncResource, Generic[ChannelType]):
class ChannelManager(Generic[ChannelType]):
'''
Helper for managing channel resources and their handler tasks with
cancellation, add or remove channels dynamically!
@ -89,18 +89,15 @@ class ChannelManager(trio.abc.AsyncResource, Generic[ChannelType]):
# methods that modify self._channels should be ordered by FIFO
self._lock = trio.StrictFIFOLock()
@acm
async def maybe_lock(self):
'''
If lock is not held, acquire
self._is_closed: bool = True
'''
if self._lock.locked():
yield
return
@property
def closed(self) -> bool:
return self._is_closed
async with self._lock:
yield
@property
def lock(self) -> trio.StrictFIFOLock:
return self._lock
@property
def channels(self) -> list[ChannelInfo]:
@ -139,7 +136,7 @@ class ChannelManager(trio.abc.AsyncResource, Generic[ChannelType]):
with cancel_scope:
await self._channel_task(info)
await self._maybe_destroy_channel(name)
self._maybe_destroy_channel(name)
def _find_channel(self, name: str) -> tuple[int, ChannelInfo] | None:
'''
@ -156,25 +153,27 @@ class ChannelManager(trio.abc.AsyncResource, Generic[ChannelType]):
return None
async def _maybe_destroy_channel(self, name: str):
def _maybe_destroy_channel(self, name: str):
'''
If channel exists cancel its scope and remove from internal
_channels list.
'''
async with self.maybe_lock():
maybe_entry = self._find_channel(name)
if maybe_entry:
i, info = maybe_entry
info.cancel_scope.cancel()
del self._channels[i]
maybe_entry = self._find_channel(name)
if maybe_entry:
i, info = maybe_entry
info.cancel_scope.cancel()
del self._channels[i]
async def add_channel(self, name: str, **kwargs):
'''
Add a new channel to be handled
'''
async with self.maybe_lock():
if self.closed:
raise trio.ClosedResourceError
async with self._lock:
await self._n.start(partial(
self._channel_handler_task,
name,
@ -186,8 +185,11 @@ class ChannelManager(trio.abc.AsyncResource, Generic[ChannelType]):
Remove a channel and stop its handling
'''
async with self.maybe_lock():
await self._maybe_destroy_channel(name)
if self.closed:
raise trio.ClosedResourceError
async with self._lock:
self._maybe_destroy_channel(name)
# if that was last channel reset connect event
if len(self) == 0:
@ -198,6 +200,9 @@ class ChannelManager(trio.abc.AsyncResource, Generic[ChannelType]):
Wait until at least one channel added
'''
if self.closed:
raise trio.ClosedResourceError
await self._connect_event.wait()
self._connect_event = trio.Event()
@ -212,10 +217,18 @@ class ChannelManager(trio.abc.AsyncResource, Generic[ChannelType]):
raise KeyError(f'Channel {name} not found!')
async def aclose(self) -> None:
async with self.maybe_lock():
for info in self._channels:
await self.remove_channel(info.name)
def open(self):
self._is_closed = False
async def close(self) -> None:
if self.closed:
log.warning('tried to close ChannelManager but its already closed...')
return
for info in self._channels:
await self.remove_channel(info.name)
self._is_closed = True
'''
@ -264,6 +277,12 @@ class RingBufferPublisher(trio.abc.SendChannel[bytes]):
self._next_turn: int = 0
self._is_closed: bool = True
@property
def closed(self) -> bool:
return self._is_closed
@property
def batch_size(self) -> int:
return self._batch_size
@ -289,36 +308,10 @@ class RingBufferPublisher(trio.abc.SendChannel[bytes]):
name: str,
must_exist: bool = False
):
'''
Store additional runtime info for channel and add channel to underlying
ChannelManager
'''
await self._chanmngr.add_channel(name, must_exist=must_exist)
async def remove_channel(self, name: str):
'''
Send EOF to channel (does `channel.flush` also) then remove from
`ChannelManager` acquire both `self._send_lock` and
`self._chanmngr.maybe_lock()` in order to ensure no channel
modifications or sends happen concurrenty
'''
async with self._chanmngr.maybe_lock():
# ensure all pending messages are sent
info = self.get_channel(name)
try:
while True:
msg = info.channel.rchan.receive_nowait()
await info.channel.ring.send(msg)
except trio.WouldBlock:
await info.channel.ring.flush()
await info.channel.schan.aclose()
# finally remove from ChannelManager
await self._chanmngr.remove_channel(name)
await self._chanmngr.remove_channel(name)
@acm
async def _open_channel(
@ -345,6 +338,13 @@ class RingBufferPublisher(trio.abc.SendChannel[bytes]):
schan=schan,
rchan=rchan
)
try:
while True:
msg = rchan.receive_nowait()
await ring.send(msg)
except trio.WouldBlock:
...
async def _channel_task(self, info: ChannelInfo) -> None:
'''
@ -352,8 +352,12 @@ class RingBufferPublisher(trio.abc.SendChannel[bytes]):
payloads update event then drain all into send channel.
'''
async for msg in info.channel.rchan:
await info.channel.ring.send(msg)
try:
async for msg in info.channel.rchan:
await info.channel.ring.send(msg)
except trio.Cancelled:
...
async def send(self, msg: bytes):
'''
@ -365,6 +369,12 @@ class RingBufferPublisher(trio.abc.SendChannel[bytes]):
variables dont happen out of order.
'''
if self.closed:
raise trio.ClosedResourceError
if self._send_lock.locked():
raise trio.BusyResourceError
async with self._send_lock:
# wait at least one decoder connected
if len(self.channels) == 0:
@ -378,8 +388,23 @@ class RingBufferPublisher(trio.abc.SendChannel[bytes]):
self._next_turn += 1
async def flush(self, new_batch_size: int | None = None):
async with self._chanmngr.lock:
for info in self.channels:
await info.channel.ring.flush(new_batch_size=new_batch_size)
async def __aenter__(self):
self._chanmngr.open()
self._is_closed = False
return self
async def aclose(self) -> None:
await self._chanmngr.aclose()
if self.closed:
log.warning('tried to close RingBufferPublisher but its already closed...')
return
await self._chanmngr.close()
self._is_closed = True
@acm
@ -445,6 +470,14 @@ class RingBufferSubscriber(trio.abc.ReceiveChannel[bytes]):
self._schan, self._rchan = trio.open_memory_channel(0)
self._is_closed: bool = True
self._receive_lock = trio.StrictFIFOLock()
@property
def closed(self) -> bool:
return self._is_closed
@property
def channels(self) -> list[ChannelInfo]:
return self._chanmngr.channels
@ -453,17 +486,9 @@ class RingBufferSubscriber(trio.abc.ReceiveChannel[bytes]):
return self._chanmngr[name]
async def add_channel(self, name: str, must_exist: bool = False):
'''
Add new input channel by name
'''
await self._chanmngr.add_channel(name, must_exist=must_exist)
async def remove_channel(self, name: str):
'''
Remove an input channel by name
'''
await self._chanmngr.remove_channel(name)
@acm
@ -506,14 +531,36 @@ class RingBufferSubscriber(trio.abc.ReceiveChannel[bytes]):
except trio.EndOfChannel:
break
except trio.ClosedResourceError:
break
async def receive(self) -> bytes:
'''
Receive next in order msg
'''
return await self._rchan.receive()
if self.closed:
raise trio.ClosedResourceError
if self._receive_lock.locked():
raise trio.BusyResourceError
async with self._receive_lock:
return await self._rchan.receive()
async def __aenter__(self):
self._is_closed = False
self._chanmngr.open()
return self
async def aclose(self) -> None:
await self._chanmngr.aclose()
if self.closed:
log.warning('tried to close RingBufferSubscriber but its already closed...')
return
await self._chanmngr.close()
await self._schan.aclose()
await self._rchan.aclose()
self._is_closed = True
@acm
async def open_ringbuf_subscriber(