Adhere to trio semantics on channels for closed and busy resource cases

one_ring_to_rule_them_all
Guillermo Rodriguez 2025-04-06 17:02:15 -03:00
parent 3a1eda9d6d
commit 1451feb159
No known key found for this signature in database
GPG Key ID: 002CC5F1E6BDA53E
2 changed files with 147 additions and 65 deletions

View File

@ -56,6 +56,7 @@ async def child_read_shm(
print(f'\n\telapsed ms: {elapsed_ms}') print(f'\n\telapsed ms: {elapsed_ms}')
print(f'\tmsg/sec: {int(msg_amount / elapsed):,}') print(f'\tmsg/sec: {int(msg_amount / elapsed):,}')
print(f'\tbytes/sec: {int(recvd_bytes / elapsed):,}') print(f'\tbytes/sec: {int(recvd_bytes / elapsed):,}')
print(f'\treceived msgs: {msg_amount:,}')
print(f'\treceived bytes: {recvd_bytes:,}') print(f'\treceived bytes: {recvd_bytes:,}')
return recvd_hash.hexdigest() return recvd_hash.hexdigest()
@ -165,7 +166,6 @@ def test_ringbuf(
await send_p.cancel_actor() await send_p.cancel_actor()
await recv_p.cancel_actor() await recv_p.cancel_actor()
trio.run(main) trio.run(main)

View File

@ -200,23 +200,28 @@ class RingBufferSendChannel(trio.abc.SendChannel[bytes]):
self._eof_event = EventFD(self._token.eof_eventfd, 'w') self._eof_event = EventFD(self._token.eof_eventfd, 'w')
# current write pointer # current write pointer
self._ptr = 0 self._ptr: int = 0
# when `batch_size` > 1 store messages on `self._batch` and write them # when `batch_size` > 1 store messages on `self._batch` and write them
# all, once `len(self._batch) == `batch_size` # all, once `len(self._batch) == `batch_size`
self._batch: list[bytes] = [] self._batch: list[bytes] = []
self._cleanup = cleanup # close shm & fds on exit?
self._cleanup: bool = cleanup
# have we closed this ringbuf?
# set to `False` on `.open()`
self._is_closed: bool = True
# ensure no concurrent `.send_all()` calls
self._send_all_lock = trio.StrictFIFOLock()
# ensure no concurrent `.send()` calls
self._send_lock = trio.StrictFIFOLock() self._send_lock = trio.StrictFIFOLock()
@acm @property
async def _maybe_lock(self) -> AsyncContextManager[None]: def closed(self) -> bool:
if self._send_lock.locked(): return self._is_closed
yield
return
async with self._send_lock:
yield
@property @property
def name(self) -> str: def name(self) -> str:
@ -252,7 +257,13 @@ class RingBufferSendChannel(trio.abc.SendChannel[bytes]):
await self._wrap_event.read() await self._wrap_event.read()
async def send_all(self, data: Buffer): async def send_all(self, data: Buffer):
async with self._maybe_lock(): if self.closed:
raise trio.ClosedResourceError
if self._send_all_lock.locked():
raise trio.BusyResourceError
async with self._send_all_lock:
# while data is larger than the remaining buf # while data is larger than the remaining buf
target_ptr = self.ptr + len(data) target_ptr = self.ptr + len(data)
while target_ptr > self.size: while target_ptr > self.size:
@ -274,13 +285,16 @@ class RingBufferSendChannel(trio.abc.SendChannel[bytes]):
self._ptr = target_ptr self._ptr = target_ptr
async def wait_send_all_might_not_block(self): async def wait_send_all_might_not_block(self):
raise NotImplementedError return
async def flush( async def flush(
self, self,
new_batch_size: int | None = None new_batch_size: int | None = None
) -> None: ) -> None:
async with self._maybe_lock(): if self.closed:
raise trio.ClosedResourceError
async with self._send_lock:
for msg in self._batch: for msg in self._batch:
await self.send_all(msg) await self.send_all(msg)
@ -289,7 +303,13 @@ class RingBufferSendChannel(trio.abc.SendChannel[bytes]):
self.batch_size = new_batch_size self.batch_size = new_batch_size
async def send(self, value: bytes) -> None: async def send(self, value: bytes) -> None:
async with self._maybe_lock(): if self.closed:
raise trio.ClosedResourceError
if self._send_lock.locked():
raise trio.BusyResourceError
async with self._send_lock:
msg: bytes = struct.pack("<I", len(value)) + value msg: bytes = struct.pack("<I", len(value)) + value
if self.batch_size == 1: if self.batch_size == 1:
await self.send_all(msg) await self.send_all(msg)
@ -299,11 +319,6 @@ class RingBufferSendChannel(trio.abc.SendChannel[bytes]):
if self.must_flush: if self.must_flush:
await self.flush() await self.flush()
async def send_eof(self) -> None:
async with self._send_lock:
await self.flush(new_batch_size=1)
await self.send(b'')
def open(self): def open(self):
try: try:
self._shm = SharedMemory( self._shm = SharedMemory(
@ -314,12 +329,13 @@ class RingBufferSendChannel(trio.abc.SendChannel[bytes]):
self._write_event.open() self._write_event.open()
self._wrap_event.open() self._wrap_event.open()
self._eof_event.open() self._eof_event.open()
self._is_closed = False
except Exception as e: except Exception as e:
e.add_note(f'while opening sender for {self._token.as_msg()}') e.add_note(f'while opening sender for {self._token.as_msg()}')
raise e raise e
def close(self): def _close(self):
self._eof_event.write( self._eof_event.write(
self._ptr if self._ptr > 0 else self.size self._ptr if self._ptr > 0 else self.size
) )
@ -330,8 +346,14 @@ class RingBufferSendChannel(trio.abc.SendChannel[bytes]):
self._eof_event.close() self._eof_event.close()
self._shm.close() self._shm.close()
self._is_closed = True
async def aclose(self): async def aclose(self):
self.close() if not self.closed:
await self.send(b'')
await self.flush()
self._close()
async def __aenter__(self): async def __aenter__(self):
self.open() self.open()
@ -362,6 +384,16 @@ class RingBufferReceiveChannel(trio.abc.ReceiveChannel[bytes]):
self._cleanup: bool = cleanup self._cleanup: bool = cleanup
self._is_closed: bool = True
self._receive_some_lock = trio.StrictFIFOLock()
self._receive_exactly_lock = trio.StrictFIFOLock()
self._receive_lock = trio.StrictFIFOLock()
@property
def closed(self) -> bool:
return self._is_closed
@property @property
def name(self) -> str: def name(self) -> str:
if not self._shm: if not self._shm:
@ -409,12 +441,25 @@ class RingBufferReceiveChannel(trio.abc.ReceiveChannel[bytes]):
Try to receive any bytes we can without blocking or raise Try to receive any bytes we can without blocking or raise
`trio.WouldBlock`. `trio.WouldBlock`.
Returns b'' when no more bytes can be read (EOF signaled & read all).
''' '''
if max_bytes < 1: if max_bytes < 1:
raise ValueError("max_bytes must be >= 1") raise ValueError("max_bytes must be >= 1")
delta = self._write_ptr - self._ptr # in case `end_ptr` is set that means eof was signaled.
# it will be >= `write_ptr`, use it for delta calc
highest_ptr = max(self._write_ptr, self._end_ptr)
delta = highest_ptr - self._ptr
# no more bytes to read
if delta == 0: if delta == 0:
# if `end_ptr` is set that means we read all bytes before EOF
if self._end_ptr > 0:
return b''
# signal the need to wait on `write_event`
raise trio.WouldBlock raise trio.WouldBlock
# dont overflow caller # dont overflow caller
@ -442,35 +487,47 @@ class RingBufferReceiveChannel(trio.abc.ReceiveChannel[bytes]):
Can return < max_bytes. Can return < max_bytes.
''' '''
try: if self.closed:
return self.receive_nowait(max_bytes=max_bytes) raise trio.ClosedResourceError
except trio.WouldBlock: if self._receive_some_lock.locked():
# we have read all we can, see if new data is available raise trio.BusyResourceError
if self._end_ptr < 0:
# if we havent been signaled about EOF yet
try:
delta = await self._write_event.read()
self._write_ptr += delta
except EFDReadCancelled: async with self._receive_some_lock:
# while waiting for new data `self._write_event` was closed try:
# this means writer signaled EOF # attempt direct read
if self._end_ptr > 0: return self.receive_nowait(max_bytes=max_bytes)
# final self._write_ptr modification and recalculate delta
self._write_ptr = self._end_ptr
delta = self._end_ptr - self._ptr
else: except trio.WouldBlock as e:
# shouldnt happen cause self._eof_monitor_task always sets # we have read all we can, see if new data is available
# self._end_ptr before closing self._write_event if self._end_ptr == -1:
raise InternalError( # if we havent been signaled about EOF yet
'self._write_event.read cancelled but self._end_ptr is not set' try:
) # wait next write and advance `write_ptr`
delta = await self._write_event.read()
self._write_ptr += delta
# yield lock and re-enter
else: except EFDReadCancelled:
# no more bytes to read and self._end_ptr set, EOF reached # while waiting for new data `self._write_event` was closed
return b'' # this means writer signaled EOF
if self._end_ptr > 0:
# receive_nowait will handle read until EOF
return self.receive_nowait(max_bytes=max_bytes)
else:
# shouldnt happen because self._eof_monitor_task always sets
# self._end_ptr before closing self._write_event
raise InternalError(
'self._write_event.read cancelled but self._end_ptr is not set'
)
else:
# shouldnt happen because receive_nowait does not raise
# trio.WouldBlock when `end_ptr` is set
raise InternalError(
'self._end_ptr is set but receive_nowait raised trio.WouldBlock'
) from e
return await self.receive_some(max_bytes=max_bytes) return await self.receive_some(max_bytes=max_bytes)
@ -479,32 +536,50 @@ class RingBufferReceiveChannel(trio.abc.ReceiveChannel[bytes]):
Fetch bytes until we read exactly `num_bytes` or EOF. Fetch bytes until we read exactly `num_bytes` or EOF.
''' '''
payload = b'' if self.closed:
while len(payload) < num_bytes: raise trio.ClosedResourceError
remaining = num_bytes - len(payload)
new_bytes = await self.receive_some( if self._receive_exactly_lock.locked():
max_bytes=remaining raise trio.BusyResourceError
)
if new_bytes == b'': async with self._receive_exactly_lock:
payload = b''
while len(payload) < num_bytes:
remaining = num_bytes - len(payload)
new_bytes = await self.receive_some(
max_bytes=remaining
)
if new_bytes == b'':
break
payload += new_bytes
if payload == b'':
raise trio.EndOfChannel raise trio.EndOfChannel
payload += new_bytes return payload
return payload
async def receive(self) -> bytes: async def receive(self) -> bytes:
''' '''
Receive a complete payload Receive a complete payload
''' '''
header: bytes = await self.receive_exactly(4) if self.closed:
size: int raise trio.ClosedResourceError
size, = struct.unpack("<I", header)
if size == 0: if self._receive_lock.locked():
raise trio.EndOfChannel raise trio.BusyResourceError
return await self.receive_exactly(size)
async with self._receive_lock:
header: bytes = await self.receive_exactly(4)
size: int
size, = struct.unpack("<I", header)
if size == 0:
raise trio.EndOfChannel
return await self.receive_exactly(size)
def open(self): def open(self):
try: try:
@ -516,6 +591,7 @@ class RingBufferReceiveChannel(trio.abc.ReceiveChannel[bytes]):
self._write_event.open() self._write_event.open()
self._wrap_event.open() self._wrap_event.open()
self._eof_event.open() self._eof_event.open()
self._is_closed = False
except Exception as e: except Exception as e:
e.add_note(f'while opening receiver for {self._token.as_msg()}') e.add_note(f'while opening receiver for {self._token.as_msg()}')
@ -528,6 +604,8 @@ class RingBufferReceiveChannel(trio.abc.ReceiveChannel[bytes]):
self._eof_event.close() self._eof_event.close()
self._shm.close() self._shm.close()
self._is_closed = True
async def aclose(self): async def aclose(self):
self.close() self.close()
@ -564,6 +642,7 @@ async def attach_to_ringbuf_receiver(
async def attach_to_ringbuf_sender( async def attach_to_ringbuf_sender(
token: RBToken, token: RBToken,
batch_size: int = 1,
cleanup: bool = True cleanup: bool = True
) -> AsyncContextManager[RingBufferSendChannel]: ) -> AsyncContextManager[RingBufferSendChannel]:
@ -574,6 +653,7 @@ async def attach_to_ringbuf_sender(
''' '''
async with RingBufferSendChannel( async with RingBufferSendChannel(
token, token,
batch_size=batch_size,
cleanup=cleanup cleanup=cleanup
) as sender: ) as sender:
yield sender yield sender
@ -644,8 +724,9 @@ class RingBufferChannel(trio.abc.Channel[bytes]):
async def attach_to_ringbuf_channel( async def attach_to_ringbuf_channel(
token_in: RBToken, token_in: RBToken,
token_out: RBToken, token_out: RBToken,
batch_size: int = 1,
cleanup_in: bool = True, cleanup_in: bool = True,
cleanup_out: bool = True cleanup_out: bool = True,
) -> AsyncContextManager[trio.StapledStream]: ) -> AsyncContextManager[trio.StapledStream]:
''' '''
Attach to two previously opened `RBToken`s and return a `RingBufferChannel` Attach to two previously opened `RBToken`s and return a `RingBufferChannel`
@ -658,6 +739,7 @@ async def attach_to_ringbuf_channel(
) as receiver, ) as receiver,
attach_to_ringbuf_sender( attach_to_ringbuf_sender(
token_out, token_out,
batch_size=batch_size,
cleanup=cleanup_out cleanup=cleanup_out
) as sender, ) as sender,
): ):