RingBufferReceiveChannel fixes for the non clean eof case, add comments
parent
8e1f95881c
commit
853aa740aa
|
@ -108,16 +108,11 @@ def open_ringbuf(
|
|||
create=True
|
||||
)
|
||||
try:
|
||||
with (
|
||||
EventFD(open_eventfd(), 'r') as write_event,
|
||||
EventFD(open_eventfd(), 'r') as wrap_event,
|
||||
EventFD(open_eventfd(), 'r') as eof_event,
|
||||
):
|
||||
token = RBToken(
|
||||
shm_name=shm_name,
|
||||
write_eventfd=write_event.fd,
|
||||
wrap_eventfd=wrap_event.fd,
|
||||
eof_eventfd=eof_event.fd,
|
||||
write_eventfd=open_eventfd(),
|
||||
wrap_eventfd=open_eventfd(),
|
||||
eof_eventfd=open_eventfd(),
|
||||
buf_size=buf_size
|
||||
)
|
||||
yield token
|
||||
|
@ -219,6 +214,9 @@ class RingBufferSendChannel(trio.abc.SendChannel[bytes]):
|
|||
# ensure no concurrent `.send()` calls
|
||||
self._send_lock = trio.StrictFIFOLock()
|
||||
|
||||
# ensure no concurrent `.flush()` calls
|
||||
self._flush_lock = trio.StrictFIFOLock()
|
||||
|
||||
@property
|
||||
def closed(self) -> bool:
|
||||
return self._is_closed
|
||||
|
@ -294,7 +292,7 @@ class RingBufferSendChannel(trio.abc.SendChannel[bytes]):
|
|||
if self.closed:
|
||||
raise trio.ClosedResourceError
|
||||
|
||||
async with self._send_lock:
|
||||
async with self._flush_lock:
|
||||
for msg in self._batch:
|
||||
await self.send_all(msg)
|
||||
|
||||
|
@ -312,6 +310,9 @@ class RingBufferSendChannel(trio.abc.SendChannel[bytes]):
|
|||
async with self._send_lock:
|
||||
msg: bytes = struct.pack("<I", len(value)) + value
|
||||
if self.batch_size == 1:
|
||||
if len(self._batch) > 0:
|
||||
await self.flush()
|
||||
|
||||
await self.send_all(msg)
|
||||
return
|
||||
|
||||
|
@ -349,9 +350,8 @@ class RingBufferSendChannel(trio.abc.SendChannel[bytes]):
|
|||
self._is_closed = True
|
||||
|
||||
async def aclose(self):
|
||||
if not self.closed:
|
||||
await self.send(b'')
|
||||
await self.flush()
|
||||
if self.closed:
|
||||
return
|
||||
|
||||
self._close()
|
||||
|
||||
|
@ -374,20 +374,37 @@ class RingBufferReceiveChannel(trio.abc.ReceiveChannel[bytes]):
|
|||
cleanup: bool = True,
|
||||
):
|
||||
self._token = RBToken.from_msg(token)
|
||||
|
||||
# ringbuf os resources
|
||||
self._shm: SharedMemory | None = None
|
||||
self._write_event = EventFD(self._token.write_eventfd, 'w')
|
||||
self._wrap_event = EventFD(self._token.wrap_eventfd, 'r')
|
||||
self._eof_event = EventFD(self._token.eof_eventfd, 'r')
|
||||
|
||||
# current read ptr
|
||||
self._ptr: int = 0
|
||||
|
||||
# current write_ptr (max bytes we can read from buf)
|
||||
self._write_ptr: int = 0
|
||||
|
||||
# end ptr is used when EOF is signaled, it will contain maximun
|
||||
# readable position on buf
|
||||
self._end_ptr: int = -1
|
||||
|
||||
# close shm & fds on exit?
|
||||
self._cleanup: bool = cleanup
|
||||
|
||||
# have we closed this ringbuf?
|
||||
# set to `False` on `.open()`
|
||||
self._is_closed: bool = True
|
||||
|
||||
# ensure no concurrent `.receive_some()` calls
|
||||
self._receive_some_lock = trio.StrictFIFOLock()
|
||||
|
||||
# ensure no concurrent `.receive_exactly()` calls
|
||||
self._receive_exactly_lock = trio.StrictFIFOLock()
|
||||
|
||||
# ensure no concurrent `.receive()` calls
|
||||
self._receive_lock = trio.StrictFIFOLock()
|
||||
|
||||
@property
|
||||
|
@ -416,6 +433,10 @@ class RingBufferReceiveChannel(trio.abc.ReceiveChannel[bytes]):
|
|||
def wrap_fd(self) -> int:
|
||||
return self._wrap_event.fd
|
||||
|
||||
@property
|
||||
def eof_was_signaled(self) -> bool:
|
||||
return self._end_ptr != -1
|
||||
|
||||
async def _eof_monitor_task(self):
|
||||
'''
|
||||
Long running EOF event monitor, automatically run in bg by
|
||||
|
@ -428,7 +449,6 @@ class RingBufferReceiveChannel(trio.abc.ReceiveChannel[bytes]):
|
|||
'''
|
||||
try:
|
||||
self._end_ptr = await self._eof_event.read()
|
||||
self._write_event.close()
|
||||
|
||||
except EFDReadCancelled:
|
||||
...
|
||||
|
@ -436,6 +456,11 @@ class RingBufferReceiveChannel(trio.abc.ReceiveChannel[bytes]):
|
|||
except trio.Cancelled:
|
||||
...
|
||||
|
||||
finally:
|
||||
# closing write_event should trigger `EFDReadCancelled`
|
||||
# on any pending read
|
||||
self._write_event.close()
|
||||
|
||||
def receive_nowait(self, max_bytes: int = _DEFAULT_RB_SIZE) -> bytes:
|
||||
'''
|
||||
Try to receive any bytes we can without blocking or raise
|
||||
|
@ -456,7 +481,7 @@ class RingBufferReceiveChannel(trio.abc.ReceiveChannel[bytes]):
|
|||
# no more bytes to read
|
||||
if delta == 0:
|
||||
# if `end_ptr` is set that means we read all bytes before EOF
|
||||
if self._end_ptr > 0:
|
||||
if self.eof_was_signaled:
|
||||
return b''
|
||||
|
||||
# signal the need to wait on `write_event`
|
||||
|
@ -500,7 +525,7 @@ class RingBufferReceiveChannel(trio.abc.ReceiveChannel[bytes]):
|
|||
|
||||
except trio.WouldBlock as e:
|
||||
# we have read all we can, see if new data is available
|
||||
if self._end_ptr == -1:
|
||||
if not self.eof_was_signaled:
|
||||
# if we havent been signaled about EOF yet
|
||||
try:
|
||||
# wait next write and advance `write_ptr`
|
||||
|
@ -510,17 +535,15 @@ class RingBufferReceiveChannel(trio.abc.ReceiveChannel[bytes]):
|
|||
|
||||
except EFDReadCancelled:
|
||||
# while waiting for new data `self._write_event` was closed
|
||||
# this means writer signaled EOF
|
||||
if self._end_ptr > 0:
|
||||
# receive_nowait will handle read until EOF
|
||||
try:
|
||||
# if eof was signaled receive no wait will not raise
|
||||
# trio.WouldBlock and will push remaining until EOF
|
||||
return self.receive_nowait(max_bytes=max_bytes)
|
||||
|
||||
else:
|
||||
# shouldnt happen because self._eof_monitor_task always sets
|
||||
# self._end_ptr before closing self._write_event
|
||||
raise InternalError(
|
||||
'self._write_event.read cancelled but self._end_ptr is not set'
|
||||
)
|
||||
except trio.WouldBlock:
|
||||
# eof was not signaled but `self._wrap_event` is closed
|
||||
# this means send side closed without EOF signal
|
||||
return b''
|
||||
|
||||
else:
|
||||
# shouldnt happen because receive_nowait does not raise
|
||||
|
@ -533,7 +556,7 @@ class RingBufferReceiveChannel(trio.abc.ReceiveChannel[bytes]):
|
|||
|
||||
async def receive_exactly(self, num_bytes: int) -> bytes:
|
||||
'''
|
||||
Fetch bytes until we read exactly `num_bytes` or EOF.
|
||||
Fetch bytes until we read exactly `num_bytes` or EOC.
|
||||
|
||||
'''
|
||||
if self.closed:
|
||||
|
@ -563,7 +586,7 @@ class RingBufferReceiveChannel(trio.abc.ReceiveChannel[bytes]):
|
|||
|
||||
async def receive(self) -> bytes:
|
||||
'''
|
||||
Receive a complete payload
|
||||
Receive a complete payload or raise EOC
|
||||
|
||||
'''
|
||||
if self.closed:
|
||||
|
@ -607,6 +630,9 @@ class RingBufferReceiveChannel(trio.abc.ReceiveChannel[bytes]):
|
|||
self._is_closed = True
|
||||
|
||||
async def aclose(self):
|
||||
if self.closed:
|
||||
return
|
||||
|
||||
self.close()
|
||||
|
||||
async def __aenter__(self):
|
||||
|
|
Loading…
Reference in New Issue