RingBufferReceiveChannel fixes for the non clean eof case, add comments

one_ring_to_rule_them_all
Guillermo Rodriguez 2025-04-06 21:16:55 -03:00
parent 8e1f95881c
commit 853aa740aa
No known key found for this signature in database
GPG Key ID: 002CC5F1E6BDA53E
1 changed files with 57 additions and 31 deletions

View File

@ -108,19 +108,14 @@ def open_ringbuf(
create=True create=True
) )
try: try:
with ( token = RBToken(
EventFD(open_eventfd(), 'r') as write_event, shm_name=shm_name,
EventFD(open_eventfd(), 'r') as wrap_event, write_eventfd=open_eventfd(),
EventFD(open_eventfd(), 'r') as eof_event, wrap_eventfd=open_eventfd(),
): eof_eventfd=open_eventfd(),
token = RBToken( buf_size=buf_size
shm_name=shm_name, )
write_eventfd=write_event.fd, yield token
wrap_eventfd=wrap_event.fd,
eof_eventfd=eof_event.fd,
buf_size=buf_size
)
yield token
finally: finally:
shm.unlink() shm.unlink()
@ -219,6 +214,9 @@ class RingBufferSendChannel(trio.abc.SendChannel[bytes]):
# ensure no concurrent `.send()` calls # ensure no concurrent `.send()` calls
self._send_lock = trio.StrictFIFOLock() self._send_lock = trio.StrictFIFOLock()
# ensure no concurrent `.flush()` calls
self._flush_lock = trio.StrictFIFOLock()
@property @property
def closed(self) -> bool: def closed(self) -> bool:
return self._is_closed return self._is_closed
@ -294,7 +292,7 @@ class RingBufferSendChannel(trio.abc.SendChannel[bytes]):
if self.closed: if self.closed:
raise trio.ClosedResourceError raise trio.ClosedResourceError
async with self._send_lock: async with self._flush_lock:
for msg in self._batch: for msg in self._batch:
await self.send_all(msg) await self.send_all(msg)
@ -312,6 +310,9 @@ class RingBufferSendChannel(trio.abc.SendChannel[bytes]):
async with self._send_lock: async with self._send_lock:
msg: bytes = struct.pack("<I", len(value)) + value msg: bytes = struct.pack("<I", len(value)) + value
if self.batch_size == 1: if self.batch_size == 1:
if len(self._batch) > 0:
await self.flush()
await self.send_all(msg) await self.send_all(msg)
return return
@ -349,9 +350,8 @@ class RingBufferSendChannel(trio.abc.SendChannel[bytes]):
self._is_closed = True self._is_closed = True
async def aclose(self): async def aclose(self):
if not self.closed: if self.closed:
await self.send(b'') return
await self.flush()
self._close() self._close()
@ -374,20 +374,37 @@ class RingBufferReceiveChannel(trio.abc.ReceiveChannel[bytes]):
cleanup: bool = True, cleanup: bool = True,
): ):
self._token = RBToken.from_msg(token) self._token = RBToken.from_msg(token)
# ringbuf os resources
self._shm: SharedMemory | None = None self._shm: SharedMemory | None = None
self._write_event = EventFD(self._token.write_eventfd, 'w') self._write_event = EventFD(self._token.write_eventfd, 'w')
self._wrap_event = EventFD(self._token.wrap_eventfd, 'r') self._wrap_event = EventFD(self._token.wrap_eventfd, 'r')
self._eof_event = EventFD(self._token.eof_eventfd, 'r') self._eof_event = EventFD(self._token.eof_eventfd, 'r')
# current read ptr
self._ptr: int = 0 self._ptr: int = 0
# current write_ptr (max bytes we can read from buf)
self._write_ptr: int = 0 self._write_ptr: int = 0
# end ptr is used when EOF is signaled, it will contain maximun
# readable position on buf
self._end_ptr: int = -1 self._end_ptr: int = -1
# close shm & fds on exit?
self._cleanup: bool = cleanup self._cleanup: bool = cleanup
# have we closed this ringbuf?
# set to `False` on `.open()`
self._is_closed: bool = True self._is_closed: bool = True
# ensure no concurrent `.receive_some()` calls
self._receive_some_lock = trio.StrictFIFOLock() self._receive_some_lock = trio.StrictFIFOLock()
# ensure no concurrent `.receive_exactly()` calls
self._receive_exactly_lock = trio.StrictFIFOLock() self._receive_exactly_lock = trio.StrictFIFOLock()
# ensure no concurrent `.receive()` calls
self._receive_lock = trio.StrictFIFOLock() self._receive_lock = trio.StrictFIFOLock()
@property @property
@ -416,6 +433,10 @@ class RingBufferReceiveChannel(trio.abc.ReceiveChannel[bytes]):
def wrap_fd(self) -> int: def wrap_fd(self) -> int:
return self._wrap_event.fd return self._wrap_event.fd
@property
def eof_was_signaled(self) -> bool:
return self._end_ptr != -1
async def _eof_monitor_task(self): async def _eof_monitor_task(self):
''' '''
Long running EOF event monitor, automatically run in bg by Long running EOF event monitor, automatically run in bg by
@ -428,7 +449,6 @@ class RingBufferReceiveChannel(trio.abc.ReceiveChannel[bytes]):
''' '''
try: try:
self._end_ptr = await self._eof_event.read() self._end_ptr = await self._eof_event.read()
self._write_event.close()
except EFDReadCancelled: except EFDReadCancelled:
... ...
@ -436,6 +456,11 @@ class RingBufferReceiveChannel(trio.abc.ReceiveChannel[bytes]):
except trio.Cancelled: except trio.Cancelled:
... ...
finally:
# closing write_event should trigger `EFDReadCancelled`
# on any pending read
self._write_event.close()
def receive_nowait(self, max_bytes: int = _DEFAULT_RB_SIZE) -> bytes: def receive_nowait(self, max_bytes: int = _DEFAULT_RB_SIZE) -> bytes:
''' '''
Try to receive any bytes we can without blocking or raise Try to receive any bytes we can without blocking or raise
@ -456,7 +481,7 @@ class RingBufferReceiveChannel(trio.abc.ReceiveChannel[bytes]):
# no more bytes to read # no more bytes to read
if delta == 0: if delta == 0:
# if `end_ptr` is set that means we read all bytes before EOF # if `end_ptr` is set that means we read all bytes before EOF
if self._end_ptr > 0: if self.eof_was_signaled:
return b'' return b''
# signal the need to wait on `write_event` # signal the need to wait on `write_event`
@ -500,7 +525,7 @@ class RingBufferReceiveChannel(trio.abc.ReceiveChannel[bytes]):
except trio.WouldBlock as e: except trio.WouldBlock as e:
# we have read all we can, see if new data is available # we have read all we can, see if new data is available
if self._end_ptr == -1: if not self.eof_was_signaled:
# if we havent been signaled about EOF yet # if we havent been signaled about EOF yet
try: try:
# wait next write and advance `write_ptr` # wait next write and advance `write_ptr`
@ -510,17 +535,15 @@ class RingBufferReceiveChannel(trio.abc.ReceiveChannel[bytes]):
except EFDReadCancelled: except EFDReadCancelled:
# while waiting for new data `self._write_event` was closed # while waiting for new data `self._write_event` was closed
# this means writer signaled EOF try:
if self._end_ptr > 0: # if eof was signaled receive no wait will not raise
# receive_nowait will handle read until EOF # trio.WouldBlock and will push remaining until EOF
return self.receive_nowait(max_bytes=max_bytes) return self.receive_nowait(max_bytes=max_bytes)
else: except trio.WouldBlock:
# shouldnt happen because self._eof_monitor_task always sets # eof was not signaled but `self._wrap_event` is closed
# self._end_ptr before closing self._write_event # this means send side closed without EOF signal
raise InternalError( return b''
'self._write_event.read cancelled but self._end_ptr is not set'
)
else: else:
# shouldnt happen because receive_nowait does not raise # shouldnt happen because receive_nowait does not raise
@ -533,7 +556,7 @@ class RingBufferReceiveChannel(trio.abc.ReceiveChannel[bytes]):
async def receive_exactly(self, num_bytes: int) -> bytes: async def receive_exactly(self, num_bytes: int) -> bytes:
''' '''
Fetch bytes until we read exactly `num_bytes` or EOF. Fetch bytes until we read exactly `num_bytes` or EOC.
''' '''
if self.closed: if self.closed:
@ -563,7 +586,7 @@ class RingBufferReceiveChannel(trio.abc.ReceiveChannel[bytes]):
async def receive(self) -> bytes: async def receive(self) -> bytes:
''' '''
Receive a complete payload Receive a complete payload or raise EOC
''' '''
if self.closed: if self.closed:
@ -607,6 +630,9 @@ class RingBufferReceiveChannel(trio.abc.ReceiveChannel[bytes]):
self._is_closed = True self._is_closed = True
async def aclose(self): async def aclose(self):
if self.closed:
return
self.close() self.close()
async def __aenter__(self): async def __aenter__(self):