Adhere to trio semantics on channels for closed and busy resource cases

one_ring_to_rule_them_all
Guillermo Rodriguez 2025-04-06 17:02:15 -03:00
parent 3a1eda9d6d
commit 1451feb159
No known key found for this signature in database
GPG Key ID: 002CC5F1E6BDA53E
2 changed files with 147 additions and 65 deletions

View File

@ -56,6 +56,7 @@ async def child_read_shm(
print(f'\n\telapsed ms: {elapsed_ms}')
print(f'\tmsg/sec: {int(msg_amount / elapsed):,}')
print(f'\tbytes/sec: {int(recvd_bytes / elapsed):,}')
print(f'\treceived msgs: {msg_amount:,}')
print(f'\treceived bytes: {recvd_bytes:,}')
return recvd_hash.hexdigest()
@ -165,7 +166,6 @@ def test_ringbuf(
await send_p.cancel_actor()
await recv_p.cancel_actor()
trio.run(main)

View File

@ -200,23 +200,28 @@ class RingBufferSendChannel(trio.abc.SendChannel[bytes]):
self._eof_event = EventFD(self._token.eof_eventfd, 'w')
# current write pointer
self._ptr = 0
self._ptr: int = 0
# when `batch_size` > 1 store messages on `self._batch` and write them
# all, once `len(self._batch) == `batch_size`
self._batch: list[bytes] = []
self._cleanup = cleanup
# close shm & fds on exit?
self._cleanup: bool = cleanup
# have we closed this ringbuf?
# set to `False` on `.open()`
self._is_closed: bool = True
# ensure no concurrent `.send_all()` calls
self._send_all_lock = trio.StrictFIFOLock()
# ensure no concurrent `.send()` calls
self._send_lock = trio.StrictFIFOLock()
@acm
async def _maybe_lock(self) -> AsyncContextManager[None]:
if self._send_lock.locked():
yield
return
async with self._send_lock:
yield
@property
def closed(self) -> bool:
return self._is_closed
@property
def name(self) -> str:
@ -252,7 +257,13 @@ class RingBufferSendChannel(trio.abc.SendChannel[bytes]):
await self._wrap_event.read()
async def send_all(self, data: Buffer):
async with self._maybe_lock():
if self.closed:
raise trio.ClosedResourceError
if self._send_all_lock.locked():
raise trio.BusyResourceError
async with self._send_all_lock:
# while data is larger than the remaining buf
target_ptr = self.ptr + len(data)
while target_ptr > self.size:
@ -274,13 +285,16 @@ class RingBufferSendChannel(trio.abc.SendChannel[bytes]):
self._ptr = target_ptr
async def wait_send_all_might_not_block(self):
raise NotImplementedError
return
async def flush(
self,
new_batch_size: int | None = None
) -> None:
async with self._maybe_lock():
if self.closed:
raise trio.ClosedResourceError
async with self._send_lock:
for msg in self._batch:
await self.send_all(msg)
@ -289,7 +303,13 @@ class RingBufferSendChannel(trio.abc.SendChannel[bytes]):
self.batch_size = new_batch_size
async def send(self, value: bytes) -> None:
async with self._maybe_lock():
if self.closed:
raise trio.ClosedResourceError
if self._send_lock.locked():
raise trio.BusyResourceError
async with self._send_lock:
msg: bytes = struct.pack("<I", len(value)) + value
if self.batch_size == 1:
await self.send_all(msg)
@ -299,11 +319,6 @@ class RingBufferSendChannel(trio.abc.SendChannel[bytes]):
if self.must_flush:
await self.flush()
async def send_eof(self) -> None:
async with self._send_lock:
await self.flush(new_batch_size=1)
await self.send(b'')
def open(self):
try:
self._shm = SharedMemory(
@ -314,12 +329,13 @@ class RingBufferSendChannel(trio.abc.SendChannel[bytes]):
self._write_event.open()
self._wrap_event.open()
self._eof_event.open()
self._is_closed = False
except Exception as e:
e.add_note(f'while opening sender for {self._token.as_msg()}')
raise e
def close(self):
def _close(self):
self._eof_event.write(
self._ptr if self._ptr > 0 else self.size
)
@ -330,8 +346,14 @@ class RingBufferSendChannel(trio.abc.SendChannel[bytes]):
self._eof_event.close()
self._shm.close()
self._is_closed = True
async def aclose(self):
self.close()
if not self.closed:
await self.send(b'')
await self.flush()
self._close()
async def __aenter__(self):
self.open()
@ -362,6 +384,16 @@ class RingBufferReceiveChannel(trio.abc.ReceiveChannel[bytes]):
self._cleanup: bool = cleanup
self._is_closed: bool = True
self._receive_some_lock = trio.StrictFIFOLock()
self._receive_exactly_lock = trio.StrictFIFOLock()
self._receive_lock = trio.StrictFIFOLock()
@property
def closed(self) -> bool:
return self._is_closed
@property
def name(self) -> str:
if not self._shm:
@ -409,12 +441,25 @@ class RingBufferReceiveChannel(trio.abc.ReceiveChannel[bytes]):
Try to receive any bytes we can without blocking or raise
`trio.WouldBlock`.
Returns b'' when no more bytes can be read (EOF signaled & read all).
'''
if max_bytes < 1:
raise ValueError("max_bytes must be >= 1")
delta = self._write_ptr - self._ptr
# in case `end_ptr` is set that means eof was signaled.
# it will be >= `write_ptr`, use it for delta calc
highest_ptr = max(self._write_ptr, self._end_ptr)
delta = highest_ptr - self._ptr
# no more bytes to read
if delta == 0:
# if `end_ptr` is set that means we read all bytes before EOF
if self._end_ptr > 0:
return b''
# signal the need to wait on `write_event`
raise trio.WouldBlock
# dont overflow caller
@ -442,35 +487,47 @@ class RingBufferReceiveChannel(trio.abc.ReceiveChannel[bytes]):
Can return < max_bytes.
'''
try:
return self.receive_nowait(max_bytes=max_bytes)
if self.closed:
raise trio.ClosedResourceError
except trio.WouldBlock:
# we have read all we can, see if new data is available
if self._end_ptr < 0:
# if we havent been signaled about EOF yet
try:
delta = await self._write_event.read()
self._write_ptr += delta
if self._receive_some_lock.locked():
raise trio.BusyResourceError
except EFDReadCancelled:
# while waiting for new data `self._write_event` was closed
# this means writer signaled EOF
if self._end_ptr > 0:
# final self._write_ptr modification and recalculate delta
self._write_ptr = self._end_ptr
delta = self._end_ptr - self._ptr
async with self._receive_some_lock:
try:
# attempt direct read
return self.receive_nowait(max_bytes=max_bytes)
else:
# shouldnt happen cause self._eof_monitor_task always sets
# self._end_ptr before closing self._write_event
raise InternalError(
'self._write_event.read cancelled but self._end_ptr is not set'
)
except trio.WouldBlock as e:
# we have read all we can, see if new data is available
if self._end_ptr == -1:
# if we havent been signaled about EOF yet
try:
# wait next write and advance `write_ptr`
delta = await self._write_event.read()
self._write_ptr += delta
# yield lock and re-enter
else:
# no more bytes to read and self._end_ptr set, EOF reached
return b''
except EFDReadCancelled:
# while waiting for new data `self._write_event` was closed
# this means writer signaled EOF
if self._end_ptr > 0:
# receive_nowait will handle read until EOF
return self.receive_nowait(max_bytes=max_bytes)
else:
# shouldnt happen because self._eof_monitor_task always sets
# self._end_ptr before closing self._write_event
raise InternalError(
'self._write_event.read cancelled but self._end_ptr is not set'
)
else:
# shouldnt happen because receive_nowait does not raise
# trio.WouldBlock when `end_ptr` is set
raise InternalError(
'self._end_ptr is set but receive_nowait raised trio.WouldBlock'
) from e
return await self.receive_some(max_bytes=max_bytes)
@ -479,32 +536,50 @@ class RingBufferReceiveChannel(trio.abc.ReceiveChannel[bytes]):
Fetch bytes until we read exactly `num_bytes` or EOF.
'''
payload = b''
while len(payload) < num_bytes:
remaining = num_bytes - len(payload)
if self.closed:
raise trio.ClosedResourceError
new_bytes = await self.receive_some(
max_bytes=remaining
)
if self._receive_exactly_lock.locked():
raise trio.BusyResourceError
if new_bytes == b'':
async with self._receive_exactly_lock:
payload = b''
while len(payload) < num_bytes:
remaining = num_bytes - len(payload)
new_bytes = await self.receive_some(
max_bytes=remaining
)
if new_bytes == b'':
break
payload += new_bytes
if payload == b'':
raise trio.EndOfChannel
payload += new_bytes
return payload
return payload
async def receive(self) -> bytes:
'''
Receive a complete payload
'''
header: bytes = await self.receive_exactly(4)
size: int
size, = struct.unpack("<I", header)
if size == 0:
raise trio.EndOfChannel
return await self.receive_exactly(size)
if self.closed:
raise trio.ClosedResourceError
if self._receive_lock.locked():
raise trio.BusyResourceError
async with self._receive_lock:
header: bytes = await self.receive_exactly(4)
size: int
size, = struct.unpack("<I", header)
if size == 0:
raise trio.EndOfChannel
return await self.receive_exactly(size)
def open(self):
try:
@ -516,6 +591,7 @@ class RingBufferReceiveChannel(trio.abc.ReceiveChannel[bytes]):
self._write_event.open()
self._wrap_event.open()
self._eof_event.open()
self._is_closed = False
except Exception as e:
e.add_note(f'while opening receiver for {self._token.as_msg()}')
@ -528,6 +604,8 @@ class RingBufferReceiveChannel(trio.abc.ReceiveChannel[bytes]):
self._eof_event.close()
self._shm.close()
self._is_closed = True
async def aclose(self):
self.close()
@ -564,6 +642,7 @@ async def attach_to_ringbuf_receiver(
async def attach_to_ringbuf_sender(
token: RBToken,
batch_size: int = 1,
cleanup: bool = True
) -> AsyncContextManager[RingBufferSendChannel]:
@ -574,6 +653,7 @@ async def attach_to_ringbuf_sender(
'''
async with RingBufferSendChannel(
token,
batch_size=batch_size,
cleanup=cleanup
) as sender:
yield sender
@ -644,8 +724,9 @@ class RingBufferChannel(trio.abc.Channel[bytes]):
async def attach_to_ringbuf_channel(
token_in: RBToken,
token_out: RBToken,
batch_size: int = 1,
cleanup_in: bool = True,
cleanup_out: bool = True
cleanup_out: bool = True,
) -> AsyncContextManager[trio.StapledStream]:
'''
Attach to two previously opened `RBToken`s and return a `RingBufferChannel`
@ -658,6 +739,7 @@ async def attach_to_ringbuf_channel(
) as receiver,
attach_to_ringbuf_sender(
token_out,
batch_size=batch_size,
cleanup=cleanup_out
) as sender,
):