Adhere to trio semantics on channels for closed and busy resource cases
parent
3a1eda9d6d
commit
1451feb159
|
@ -56,6 +56,7 @@ async def child_read_shm(
|
|||
print(f'\n\telapsed ms: {elapsed_ms}')
|
||||
print(f'\tmsg/sec: {int(msg_amount / elapsed):,}')
|
||||
print(f'\tbytes/sec: {int(recvd_bytes / elapsed):,}')
|
||||
print(f'\treceived msgs: {msg_amount:,}')
|
||||
print(f'\treceived bytes: {recvd_bytes:,}')
|
||||
|
||||
return recvd_hash.hexdigest()
|
||||
|
@ -165,7 +166,6 @@ def test_ringbuf(
|
|||
await send_p.cancel_actor()
|
||||
await recv_p.cancel_actor()
|
||||
|
||||
|
||||
trio.run(main)
|
||||
|
||||
|
||||
|
|
|
@ -200,23 +200,28 @@ class RingBufferSendChannel(trio.abc.SendChannel[bytes]):
|
|||
self._eof_event = EventFD(self._token.eof_eventfd, 'w')
|
||||
|
||||
# current write pointer
|
||||
self._ptr = 0
|
||||
self._ptr: int = 0
|
||||
|
||||
# when `batch_size` > 1 store messages on `self._batch` and write them
|
||||
# all, once `len(self._batch) == `batch_size`
|
||||
self._batch: list[bytes] = []
|
||||
|
||||
self._cleanup = cleanup
|
||||
# close shm & fds on exit?
|
||||
self._cleanup: bool = cleanup
|
||||
|
||||
# have we closed this ringbuf?
|
||||
# set to `False` on `.open()`
|
||||
self._is_closed: bool = True
|
||||
|
||||
# ensure no concurrent `.send_all()` calls
|
||||
self._send_all_lock = trio.StrictFIFOLock()
|
||||
|
||||
# ensure no concurrent `.send()` calls
|
||||
self._send_lock = trio.StrictFIFOLock()
|
||||
|
||||
@acm
|
||||
async def _maybe_lock(self) -> AsyncContextManager[None]:
|
||||
if self._send_lock.locked():
|
||||
yield
|
||||
return
|
||||
|
||||
async with self._send_lock:
|
||||
yield
|
||||
@property
|
||||
def closed(self) -> bool:
|
||||
return self._is_closed
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
|
@ -252,7 +257,13 @@ class RingBufferSendChannel(trio.abc.SendChannel[bytes]):
|
|||
await self._wrap_event.read()
|
||||
|
||||
async def send_all(self, data: Buffer):
|
||||
async with self._maybe_lock():
|
||||
if self.closed:
|
||||
raise trio.ClosedResourceError
|
||||
|
||||
if self._send_all_lock.locked():
|
||||
raise trio.BusyResourceError
|
||||
|
||||
async with self._send_all_lock:
|
||||
# while data is larger than the remaining buf
|
||||
target_ptr = self.ptr + len(data)
|
||||
while target_ptr > self.size:
|
||||
|
@ -274,13 +285,16 @@ class RingBufferSendChannel(trio.abc.SendChannel[bytes]):
|
|||
self._ptr = target_ptr
|
||||
|
||||
async def wait_send_all_might_not_block(self):
|
||||
raise NotImplementedError
|
||||
return
|
||||
|
||||
async def flush(
|
||||
self,
|
||||
new_batch_size: int | None = None
|
||||
) -> None:
|
||||
async with self._maybe_lock():
|
||||
if self.closed:
|
||||
raise trio.ClosedResourceError
|
||||
|
||||
async with self._send_lock:
|
||||
for msg in self._batch:
|
||||
await self.send_all(msg)
|
||||
|
||||
|
@ -289,7 +303,13 @@ class RingBufferSendChannel(trio.abc.SendChannel[bytes]):
|
|||
self.batch_size = new_batch_size
|
||||
|
||||
async def send(self, value: bytes) -> None:
|
||||
async with self._maybe_lock():
|
||||
if self.closed:
|
||||
raise trio.ClosedResourceError
|
||||
|
||||
if self._send_lock.locked():
|
||||
raise trio.BusyResourceError
|
||||
|
||||
async with self._send_lock:
|
||||
msg: bytes = struct.pack("<I", len(value)) + value
|
||||
if self.batch_size == 1:
|
||||
await self.send_all(msg)
|
||||
|
@ -299,11 +319,6 @@ class RingBufferSendChannel(trio.abc.SendChannel[bytes]):
|
|||
if self.must_flush:
|
||||
await self.flush()
|
||||
|
||||
async def send_eof(self) -> None:
|
||||
async with self._send_lock:
|
||||
await self.flush(new_batch_size=1)
|
||||
await self.send(b'')
|
||||
|
||||
def open(self):
|
||||
try:
|
||||
self._shm = SharedMemory(
|
||||
|
@ -314,12 +329,13 @@ class RingBufferSendChannel(trio.abc.SendChannel[bytes]):
|
|||
self._write_event.open()
|
||||
self._wrap_event.open()
|
||||
self._eof_event.open()
|
||||
self._is_closed = False
|
||||
|
||||
except Exception as e:
|
||||
e.add_note(f'while opening sender for {self._token.as_msg()}')
|
||||
raise e
|
||||
|
||||
def close(self):
|
||||
def _close(self):
|
||||
self._eof_event.write(
|
||||
self._ptr if self._ptr > 0 else self.size
|
||||
)
|
||||
|
@ -330,8 +346,14 @@ class RingBufferSendChannel(trio.abc.SendChannel[bytes]):
|
|||
self._eof_event.close()
|
||||
self._shm.close()
|
||||
|
||||
self._is_closed = True
|
||||
|
||||
async def aclose(self):
|
||||
self.close()
|
||||
if not self.closed:
|
||||
await self.send(b'')
|
||||
await self.flush()
|
||||
|
||||
self._close()
|
||||
|
||||
async def __aenter__(self):
|
||||
self.open()
|
||||
|
@ -362,6 +384,16 @@ class RingBufferReceiveChannel(trio.abc.ReceiveChannel[bytes]):
|
|||
|
||||
self._cleanup: bool = cleanup
|
||||
|
||||
self._is_closed: bool = True
|
||||
|
||||
self._receive_some_lock = trio.StrictFIFOLock()
|
||||
self._receive_exactly_lock = trio.StrictFIFOLock()
|
||||
self._receive_lock = trio.StrictFIFOLock()
|
||||
|
||||
@property
|
||||
def closed(self) -> bool:
|
||||
return self._is_closed
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
if not self._shm:
|
||||
|
@ -409,12 +441,25 @@ class RingBufferReceiveChannel(trio.abc.ReceiveChannel[bytes]):
|
|||
Try to receive any bytes we can without blocking or raise
|
||||
`trio.WouldBlock`.
|
||||
|
||||
Returns b'' when no more bytes can be read (EOF signaled & read all).
|
||||
|
||||
'''
|
||||
if max_bytes < 1:
|
||||
raise ValueError("max_bytes must be >= 1")
|
||||
|
||||
delta = self._write_ptr - self._ptr
|
||||
# in case `end_ptr` is set that means eof was signaled.
|
||||
# it will be >= `write_ptr`, use it for delta calc
|
||||
highest_ptr = max(self._write_ptr, self._end_ptr)
|
||||
|
||||
delta = highest_ptr - self._ptr
|
||||
|
||||
# no more bytes to read
|
||||
if delta == 0:
|
||||
# if `end_ptr` is set that means we read all bytes before EOF
|
||||
if self._end_ptr > 0:
|
||||
return b''
|
||||
|
||||
# signal the need to wait on `write_event`
|
||||
raise trio.WouldBlock
|
||||
|
||||
# dont overflow caller
|
||||
|
@ -442,35 +487,47 @@ class RingBufferReceiveChannel(trio.abc.ReceiveChannel[bytes]):
|
|||
Can return < max_bytes.
|
||||
|
||||
'''
|
||||
try:
|
||||
return self.receive_nowait(max_bytes=max_bytes)
|
||||
if self.closed:
|
||||
raise trio.ClosedResourceError
|
||||
|
||||
except trio.WouldBlock:
|
||||
# we have read all we can, see if new data is available
|
||||
if self._end_ptr < 0:
|
||||
# if we havent been signaled about EOF yet
|
||||
try:
|
||||
delta = await self._write_event.read()
|
||||
self._write_ptr += delta
|
||||
if self._receive_some_lock.locked():
|
||||
raise trio.BusyResourceError
|
||||
|
||||
except EFDReadCancelled:
|
||||
# while waiting for new data `self._write_event` was closed
|
||||
# this means writer signaled EOF
|
||||
if self._end_ptr > 0:
|
||||
# final self._write_ptr modification and recalculate delta
|
||||
self._write_ptr = self._end_ptr
|
||||
delta = self._end_ptr - self._ptr
|
||||
async with self._receive_some_lock:
|
||||
try:
|
||||
# attempt direct read
|
||||
return self.receive_nowait(max_bytes=max_bytes)
|
||||
|
||||
else:
|
||||
# shouldnt happen cause self._eof_monitor_task always sets
|
||||
# self._end_ptr before closing self._write_event
|
||||
raise InternalError(
|
||||
'self._write_event.read cancelled but self._end_ptr is not set'
|
||||
)
|
||||
except trio.WouldBlock as e:
|
||||
# we have read all we can, see if new data is available
|
||||
if self._end_ptr == -1:
|
||||
# if we havent been signaled about EOF yet
|
||||
try:
|
||||
# wait next write and advance `write_ptr`
|
||||
delta = await self._write_event.read()
|
||||
self._write_ptr += delta
|
||||
# yield lock and re-enter
|
||||
|
||||
else:
|
||||
# no more bytes to read and self._end_ptr set, EOF reached
|
||||
return b''
|
||||
except EFDReadCancelled:
|
||||
# while waiting for new data `self._write_event` was closed
|
||||
# this means writer signaled EOF
|
||||
if self._end_ptr > 0:
|
||||
# receive_nowait will handle read until EOF
|
||||
return self.receive_nowait(max_bytes=max_bytes)
|
||||
|
||||
else:
|
||||
# shouldnt happen because self._eof_monitor_task always sets
|
||||
# self._end_ptr before closing self._write_event
|
||||
raise InternalError(
|
||||
'self._write_event.read cancelled but self._end_ptr is not set'
|
||||
)
|
||||
|
||||
else:
|
||||
# shouldnt happen because receive_nowait does not raise
|
||||
# trio.WouldBlock when `end_ptr` is set
|
||||
raise InternalError(
|
||||
'self._end_ptr is set but receive_nowait raised trio.WouldBlock'
|
||||
) from e
|
||||
|
||||
return await self.receive_some(max_bytes=max_bytes)
|
||||
|
||||
|
@ -479,32 +536,50 @@ class RingBufferReceiveChannel(trio.abc.ReceiveChannel[bytes]):
|
|||
Fetch bytes until we read exactly `num_bytes` or EOF.
|
||||
|
||||
'''
|
||||
payload = b''
|
||||
while len(payload) < num_bytes:
|
||||
remaining = num_bytes - len(payload)
|
||||
if self.closed:
|
||||
raise trio.ClosedResourceError
|
||||
|
||||
new_bytes = await self.receive_some(
|
||||
max_bytes=remaining
|
||||
)
|
||||
if self._receive_exactly_lock.locked():
|
||||
raise trio.BusyResourceError
|
||||
|
||||
if new_bytes == b'':
|
||||
async with self._receive_exactly_lock:
|
||||
payload = b''
|
||||
while len(payload) < num_bytes:
|
||||
remaining = num_bytes - len(payload)
|
||||
|
||||
new_bytes = await self.receive_some(
|
||||
max_bytes=remaining
|
||||
)
|
||||
|
||||
if new_bytes == b'':
|
||||
break
|
||||
|
||||
payload += new_bytes
|
||||
|
||||
if payload == b'':
|
||||
raise trio.EndOfChannel
|
||||
|
||||
payload += new_bytes
|
||||
|
||||
return payload
|
||||
return payload
|
||||
|
||||
async def receive(self) -> bytes:
|
||||
'''
|
||||
Receive a complete payload
|
||||
|
||||
'''
|
||||
header: bytes = await self.receive_exactly(4)
|
||||
size: int
|
||||
size, = struct.unpack("<I", header)
|
||||
if size == 0:
|
||||
raise trio.EndOfChannel
|
||||
return await self.receive_exactly(size)
|
||||
if self.closed:
|
||||
raise trio.ClosedResourceError
|
||||
|
||||
if self._receive_lock.locked():
|
||||
raise trio.BusyResourceError
|
||||
|
||||
async with self._receive_lock:
|
||||
header: bytes = await self.receive_exactly(4)
|
||||
size: int
|
||||
size, = struct.unpack("<I", header)
|
||||
if size == 0:
|
||||
raise trio.EndOfChannel
|
||||
|
||||
return await self.receive_exactly(size)
|
||||
|
||||
def open(self):
|
||||
try:
|
||||
|
@ -516,6 +591,7 @@ class RingBufferReceiveChannel(trio.abc.ReceiveChannel[bytes]):
|
|||
self._write_event.open()
|
||||
self._wrap_event.open()
|
||||
self._eof_event.open()
|
||||
self._is_closed = False
|
||||
|
||||
except Exception as e:
|
||||
e.add_note(f'while opening receiver for {self._token.as_msg()}')
|
||||
|
@ -528,6 +604,8 @@ class RingBufferReceiveChannel(trio.abc.ReceiveChannel[bytes]):
|
|||
self._eof_event.close()
|
||||
self._shm.close()
|
||||
|
||||
self._is_closed = True
|
||||
|
||||
async def aclose(self):
|
||||
self.close()
|
||||
|
||||
|
@ -564,6 +642,7 @@ async def attach_to_ringbuf_receiver(
|
|||
async def attach_to_ringbuf_sender(
|
||||
|
||||
token: RBToken,
|
||||
batch_size: int = 1,
|
||||
cleanup: bool = True
|
||||
|
||||
) -> AsyncContextManager[RingBufferSendChannel]:
|
||||
|
@ -574,6 +653,7 @@ async def attach_to_ringbuf_sender(
|
|||
'''
|
||||
async with RingBufferSendChannel(
|
||||
token,
|
||||
batch_size=batch_size,
|
||||
cleanup=cleanup
|
||||
) as sender:
|
||||
yield sender
|
||||
|
@ -644,8 +724,9 @@ class RingBufferChannel(trio.abc.Channel[bytes]):
|
|||
async def attach_to_ringbuf_channel(
|
||||
token_in: RBToken,
|
||||
token_out: RBToken,
|
||||
batch_size: int = 1,
|
||||
cleanup_in: bool = True,
|
||||
cleanup_out: bool = True
|
||||
cleanup_out: bool = True,
|
||||
) -> AsyncContextManager[trio.StapledStream]:
|
||||
'''
|
||||
Attach to two previously opened `RBToken`s and return a `RingBufferChannel`
|
||||
|
@ -658,6 +739,7 @@ async def attach_to_ringbuf_channel(
|
|||
) as receiver,
|
||||
attach_to_ringbuf_sender(
|
||||
token_out,
|
||||
batch_size=batch_size,
|
||||
cleanup=cleanup_out
|
||||
) as sender,
|
||||
):
|
||||
|
|
Loading…
Reference in New Issue