Adhere to trio semantics on channels for closed and busy resource cases
parent
3a1eda9d6d
commit
1451feb159
|
@ -56,6 +56,7 @@ async def child_read_shm(
|
||||||
print(f'\n\telapsed ms: {elapsed_ms}')
|
print(f'\n\telapsed ms: {elapsed_ms}')
|
||||||
print(f'\tmsg/sec: {int(msg_amount / elapsed):,}')
|
print(f'\tmsg/sec: {int(msg_amount / elapsed):,}')
|
||||||
print(f'\tbytes/sec: {int(recvd_bytes / elapsed):,}')
|
print(f'\tbytes/sec: {int(recvd_bytes / elapsed):,}')
|
||||||
|
print(f'\treceived msgs: {msg_amount:,}')
|
||||||
print(f'\treceived bytes: {recvd_bytes:,}')
|
print(f'\treceived bytes: {recvd_bytes:,}')
|
||||||
|
|
||||||
return recvd_hash.hexdigest()
|
return recvd_hash.hexdigest()
|
||||||
|
@ -165,7 +166,6 @@ def test_ringbuf(
|
||||||
await send_p.cancel_actor()
|
await send_p.cancel_actor()
|
||||||
await recv_p.cancel_actor()
|
await recv_p.cancel_actor()
|
||||||
|
|
||||||
|
|
||||||
trio.run(main)
|
trio.run(main)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -200,23 +200,28 @@ class RingBufferSendChannel(trio.abc.SendChannel[bytes]):
|
||||||
self._eof_event = EventFD(self._token.eof_eventfd, 'w')
|
self._eof_event = EventFD(self._token.eof_eventfd, 'w')
|
||||||
|
|
||||||
# current write pointer
|
# current write pointer
|
||||||
self._ptr = 0
|
self._ptr: int = 0
|
||||||
|
|
||||||
# when `batch_size` > 1 store messages on `self._batch` and write them
|
# when `batch_size` > 1 store messages on `self._batch` and write them
|
||||||
# all, once `len(self._batch) == `batch_size`
|
# all, once `len(self._batch) == `batch_size`
|
||||||
self._batch: list[bytes] = []
|
self._batch: list[bytes] = []
|
||||||
|
|
||||||
self._cleanup = cleanup
|
# close shm & fds on exit?
|
||||||
|
self._cleanup: bool = cleanup
|
||||||
|
|
||||||
|
# have we closed this ringbuf?
|
||||||
|
# set to `False` on `.open()`
|
||||||
|
self._is_closed: bool = True
|
||||||
|
|
||||||
|
# ensure no concurrent `.send_all()` calls
|
||||||
|
self._send_all_lock = trio.StrictFIFOLock()
|
||||||
|
|
||||||
|
# ensure no concurrent `.send()` calls
|
||||||
self._send_lock = trio.StrictFIFOLock()
|
self._send_lock = trio.StrictFIFOLock()
|
||||||
|
|
||||||
@acm
|
@property
|
||||||
async def _maybe_lock(self) -> AsyncContextManager[None]:
|
def closed(self) -> bool:
|
||||||
if self._send_lock.locked():
|
return self._is_closed
|
||||||
yield
|
|
||||||
return
|
|
||||||
|
|
||||||
async with self._send_lock:
|
|
||||||
yield
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def name(self) -> str:
|
def name(self) -> str:
|
||||||
|
@ -252,7 +257,13 @@ class RingBufferSendChannel(trio.abc.SendChannel[bytes]):
|
||||||
await self._wrap_event.read()
|
await self._wrap_event.read()
|
||||||
|
|
||||||
async def send_all(self, data: Buffer):
|
async def send_all(self, data: Buffer):
|
||||||
async with self._maybe_lock():
|
if self.closed:
|
||||||
|
raise trio.ClosedResourceError
|
||||||
|
|
||||||
|
if self._send_all_lock.locked():
|
||||||
|
raise trio.BusyResourceError
|
||||||
|
|
||||||
|
async with self._send_all_lock:
|
||||||
# while data is larger than the remaining buf
|
# while data is larger than the remaining buf
|
||||||
target_ptr = self.ptr + len(data)
|
target_ptr = self.ptr + len(data)
|
||||||
while target_ptr > self.size:
|
while target_ptr > self.size:
|
||||||
|
@ -274,13 +285,16 @@ class RingBufferSendChannel(trio.abc.SendChannel[bytes]):
|
||||||
self._ptr = target_ptr
|
self._ptr = target_ptr
|
||||||
|
|
||||||
async def wait_send_all_might_not_block(self):
|
async def wait_send_all_might_not_block(self):
|
||||||
raise NotImplementedError
|
return
|
||||||
|
|
||||||
async def flush(
|
async def flush(
|
||||||
self,
|
self,
|
||||||
new_batch_size: int | None = None
|
new_batch_size: int | None = None
|
||||||
) -> None:
|
) -> None:
|
||||||
async with self._maybe_lock():
|
if self.closed:
|
||||||
|
raise trio.ClosedResourceError
|
||||||
|
|
||||||
|
async with self._send_lock:
|
||||||
for msg in self._batch:
|
for msg in self._batch:
|
||||||
await self.send_all(msg)
|
await self.send_all(msg)
|
||||||
|
|
||||||
|
@ -289,7 +303,13 @@ class RingBufferSendChannel(trio.abc.SendChannel[bytes]):
|
||||||
self.batch_size = new_batch_size
|
self.batch_size = new_batch_size
|
||||||
|
|
||||||
async def send(self, value: bytes) -> None:
|
async def send(self, value: bytes) -> None:
|
||||||
async with self._maybe_lock():
|
if self.closed:
|
||||||
|
raise trio.ClosedResourceError
|
||||||
|
|
||||||
|
if self._send_lock.locked():
|
||||||
|
raise trio.BusyResourceError
|
||||||
|
|
||||||
|
async with self._send_lock:
|
||||||
msg: bytes = struct.pack("<I", len(value)) + value
|
msg: bytes = struct.pack("<I", len(value)) + value
|
||||||
if self.batch_size == 1:
|
if self.batch_size == 1:
|
||||||
await self.send_all(msg)
|
await self.send_all(msg)
|
||||||
|
@ -299,11 +319,6 @@ class RingBufferSendChannel(trio.abc.SendChannel[bytes]):
|
||||||
if self.must_flush:
|
if self.must_flush:
|
||||||
await self.flush()
|
await self.flush()
|
||||||
|
|
||||||
async def send_eof(self) -> None:
|
|
||||||
async with self._send_lock:
|
|
||||||
await self.flush(new_batch_size=1)
|
|
||||||
await self.send(b'')
|
|
||||||
|
|
||||||
def open(self):
|
def open(self):
|
||||||
try:
|
try:
|
||||||
self._shm = SharedMemory(
|
self._shm = SharedMemory(
|
||||||
|
@ -314,12 +329,13 @@ class RingBufferSendChannel(trio.abc.SendChannel[bytes]):
|
||||||
self._write_event.open()
|
self._write_event.open()
|
||||||
self._wrap_event.open()
|
self._wrap_event.open()
|
||||||
self._eof_event.open()
|
self._eof_event.open()
|
||||||
|
self._is_closed = False
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
e.add_note(f'while opening sender for {self._token.as_msg()}')
|
e.add_note(f'while opening sender for {self._token.as_msg()}')
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
def close(self):
|
def _close(self):
|
||||||
self._eof_event.write(
|
self._eof_event.write(
|
||||||
self._ptr if self._ptr > 0 else self.size
|
self._ptr if self._ptr > 0 else self.size
|
||||||
)
|
)
|
||||||
|
@ -330,8 +346,14 @@ class RingBufferSendChannel(trio.abc.SendChannel[bytes]):
|
||||||
self._eof_event.close()
|
self._eof_event.close()
|
||||||
self._shm.close()
|
self._shm.close()
|
||||||
|
|
||||||
|
self._is_closed = True
|
||||||
|
|
||||||
async def aclose(self):
|
async def aclose(self):
|
||||||
self.close()
|
if not self.closed:
|
||||||
|
await self.send(b'')
|
||||||
|
await self.flush()
|
||||||
|
|
||||||
|
self._close()
|
||||||
|
|
||||||
async def __aenter__(self):
|
async def __aenter__(self):
|
||||||
self.open()
|
self.open()
|
||||||
|
@ -362,6 +384,16 @@ class RingBufferReceiveChannel(trio.abc.ReceiveChannel[bytes]):
|
||||||
|
|
||||||
self._cleanup: bool = cleanup
|
self._cleanup: bool = cleanup
|
||||||
|
|
||||||
|
self._is_closed: bool = True
|
||||||
|
|
||||||
|
self._receive_some_lock = trio.StrictFIFOLock()
|
||||||
|
self._receive_exactly_lock = trio.StrictFIFOLock()
|
||||||
|
self._receive_lock = trio.StrictFIFOLock()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def closed(self) -> bool:
|
||||||
|
return self._is_closed
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def name(self) -> str:
|
def name(self) -> str:
|
||||||
if not self._shm:
|
if not self._shm:
|
||||||
|
@ -409,12 +441,25 @@ class RingBufferReceiveChannel(trio.abc.ReceiveChannel[bytes]):
|
||||||
Try to receive any bytes we can without blocking or raise
|
Try to receive any bytes we can without blocking or raise
|
||||||
`trio.WouldBlock`.
|
`trio.WouldBlock`.
|
||||||
|
|
||||||
|
Returns b'' when no more bytes can be read (EOF signaled & read all).
|
||||||
|
|
||||||
'''
|
'''
|
||||||
if max_bytes < 1:
|
if max_bytes < 1:
|
||||||
raise ValueError("max_bytes must be >= 1")
|
raise ValueError("max_bytes must be >= 1")
|
||||||
|
|
||||||
delta = self._write_ptr - self._ptr
|
# in case `end_ptr` is set that means eof was signaled.
|
||||||
|
# it will be >= `write_ptr`, use it for delta calc
|
||||||
|
highest_ptr = max(self._write_ptr, self._end_ptr)
|
||||||
|
|
||||||
|
delta = highest_ptr - self._ptr
|
||||||
|
|
||||||
|
# no more bytes to read
|
||||||
if delta == 0:
|
if delta == 0:
|
||||||
|
# if `end_ptr` is set that means we read all bytes before EOF
|
||||||
|
if self._end_ptr > 0:
|
||||||
|
return b''
|
||||||
|
|
||||||
|
# signal the need to wait on `write_event`
|
||||||
raise trio.WouldBlock
|
raise trio.WouldBlock
|
||||||
|
|
||||||
# dont overflow caller
|
# dont overflow caller
|
||||||
|
@ -442,35 +487,47 @@ class RingBufferReceiveChannel(trio.abc.ReceiveChannel[bytes]):
|
||||||
Can return < max_bytes.
|
Can return < max_bytes.
|
||||||
|
|
||||||
'''
|
'''
|
||||||
|
if self.closed:
|
||||||
|
raise trio.ClosedResourceError
|
||||||
|
|
||||||
|
if self._receive_some_lock.locked():
|
||||||
|
raise trio.BusyResourceError
|
||||||
|
|
||||||
|
async with self._receive_some_lock:
|
||||||
try:
|
try:
|
||||||
|
# attempt direct read
|
||||||
return self.receive_nowait(max_bytes=max_bytes)
|
return self.receive_nowait(max_bytes=max_bytes)
|
||||||
|
|
||||||
except trio.WouldBlock:
|
except trio.WouldBlock as e:
|
||||||
# we have read all we can, see if new data is available
|
# we have read all we can, see if new data is available
|
||||||
if self._end_ptr < 0:
|
if self._end_ptr == -1:
|
||||||
# if we havent been signaled about EOF yet
|
# if we havent been signaled about EOF yet
|
||||||
try:
|
try:
|
||||||
|
# wait next write and advance `write_ptr`
|
||||||
delta = await self._write_event.read()
|
delta = await self._write_event.read()
|
||||||
self._write_ptr += delta
|
self._write_ptr += delta
|
||||||
|
# yield lock and re-enter
|
||||||
|
|
||||||
except EFDReadCancelled:
|
except EFDReadCancelled:
|
||||||
# while waiting for new data `self._write_event` was closed
|
# while waiting for new data `self._write_event` was closed
|
||||||
# this means writer signaled EOF
|
# this means writer signaled EOF
|
||||||
if self._end_ptr > 0:
|
if self._end_ptr > 0:
|
||||||
# final self._write_ptr modification and recalculate delta
|
# receive_nowait will handle read until EOF
|
||||||
self._write_ptr = self._end_ptr
|
return self.receive_nowait(max_bytes=max_bytes)
|
||||||
delta = self._end_ptr - self._ptr
|
|
||||||
|
|
||||||
else:
|
else:
|
||||||
# shouldnt happen cause self._eof_monitor_task always sets
|
# shouldnt happen because self._eof_monitor_task always sets
|
||||||
# self._end_ptr before closing self._write_event
|
# self._end_ptr before closing self._write_event
|
||||||
raise InternalError(
|
raise InternalError(
|
||||||
'self._write_event.read cancelled but self._end_ptr is not set'
|
'self._write_event.read cancelled but self._end_ptr is not set'
|
||||||
)
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
# no more bytes to read and self._end_ptr set, EOF reached
|
# shouldnt happen because receive_nowait does not raise
|
||||||
return b''
|
# trio.WouldBlock when `end_ptr` is set
|
||||||
|
raise InternalError(
|
||||||
|
'self._end_ptr is set but receive_nowait raised trio.WouldBlock'
|
||||||
|
) from e
|
||||||
|
|
||||||
return await self.receive_some(max_bytes=max_bytes)
|
return await self.receive_some(max_bytes=max_bytes)
|
||||||
|
|
||||||
|
@ -479,6 +536,13 @@ class RingBufferReceiveChannel(trio.abc.ReceiveChannel[bytes]):
|
||||||
Fetch bytes until we read exactly `num_bytes` or EOF.
|
Fetch bytes until we read exactly `num_bytes` or EOF.
|
||||||
|
|
||||||
'''
|
'''
|
||||||
|
if self.closed:
|
||||||
|
raise trio.ClosedResourceError
|
||||||
|
|
||||||
|
if self._receive_exactly_lock.locked():
|
||||||
|
raise trio.BusyResourceError
|
||||||
|
|
||||||
|
async with self._receive_exactly_lock:
|
||||||
payload = b''
|
payload = b''
|
||||||
while len(payload) < num_bytes:
|
while len(payload) < num_bytes:
|
||||||
remaining = num_bytes - len(payload)
|
remaining = num_bytes - len(payload)
|
||||||
|
@ -488,10 +552,13 @@ class RingBufferReceiveChannel(trio.abc.ReceiveChannel[bytes]):
|
||||||
)
|
)
|
||||||
|
|
||||||
if new_bytes == b'':
|
if new_bytes == b'':
|
||||||
raise trio.EndOfChannel
|
break
|
||||||
|
|
||||||
payload += new_bytes
|
payload += new_bytes
|
||||||
|
|
||||||
|
if payload == b'':
|
||||||
|
raise trio.EndOfChannel
|
||||||
|
|
||||||
return payload
|
return payload
|
||||||
|
|
||||||
async def receive(self) -> bytes:
|
async def receive(self) -> bytes:
|
||||||
|
@ -499,11 +566,19 @@ class RingBufferReceiveChannel(trio.abc.ReceiveChannel[bytes]):
|
||||||
Receive a complete payload
|
Receive a complete payload
|
||||||
|
|
||||||
'''
|
'''
|
||||||
|
if self.closed:
|
||||||
|
raise trio.ClosedResourceError
|
||||||
|
|
||||||
|
if self._receive_lock.locked():
|
||||||
|
raise trio.BusyResourceError
|
||||||
|
|
||||||
|
async with self._receive_lock:
|
||||||
header: bytes = await self.receive_exactly(4)
|
header: bytes = await self.receive_exactly(4)
|
||||||
size: int
|
size: int
|
||||||
size, = struct.unpack("<I", header)
|
size, = struct.unpack("<I", header)
|
||||||
if size == 0:
|
if size == 0:
|
||||||
raise trio.EndOfChannel
|
raise trio.EndOfChannel
|
||||||
|
|
||||||
return await self.receive_exactly(size)
|
return await self.receive_exactly(size)
|
||||||
|
|
||||||
def open(self):
|
def open(self):
|
||||||
|
@ -516,6 +591,7 @@ class RingBufferReceiveChannel(trio.abc.ReceiveChannel[bytes]):
|
||||||
self._write_event.open()
|
self._write_event.open()
|
||||||
self._wrap_event.open()
|
self._wrap_event.open()
|
||||||
self._eof_event.open()
|
self._eof_event.open()
|
||||||
|
self._is_closed = False
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
e.add_note(f'while opening receiver for {self._token.as_msg()}')
|
e.add_note(f'while opening receiver for {self._token.as_msg()}')
|
||||||
|
@ -528,6 +604,8 @@ class RingBufferReceiveChannel(trio.abc.ReceiveChannel[bytes]):
|
||||||
self._eof_event.close()
|
self._eof_event.close()
|
||||||
self._shm.close()
|
self._shm.close()
|
||||||
|
|
||||||
|
self._is_closed = True
|
||||||
|
|
||||||
async def aclose(self):
|
async def aclose(self):
|
||||||
self.close()
|
self.close()
|
||||||
|
|
||||||
|
@ -564,6 +642,7 @@ async def attach_to_ringbuf_receiver(
|
||||||
async def attach_to_ringbuf_sender(
|
async def attach_to_ringbuf_sender(
|
||||||
|
|
||||||
token: RBToken,
|
token: RBToken,
|
||||||
|
batch_size: int = 1,
|
||||||
cleanup: bool = True
|
cleanup: bool = True
|
||||||
|
|
||||||
) -> AsyncContextManager[RingBufferSendChannel]:
|
) -> AsyncContextManager[RingBufferSendChannel]:
|
||||||
|
@ -574,6 +653,7 @@ async def attach_to_ringbuf_sender(
|
||||||
'''
|
'''
|
||||||
async with RingBufferSendChannel(
|
async with RingBufferSendChannel(
|
||||||
token,
|
token,
|
||||||
|
batch_size=batch_size,
|
||||||
cleanup=cleanup
|
cleanup=cleanup
|
||||||
) as sender:
|
) as sender:
|
||||||
yield sender
|
yield sender
|
||||||
|
@ -644,8 +724,9 @@ class RingBufferChannel(trio.abc.Channel[bytes]):
|
||||||
async def attach_to_ringbuf_channel(
|
async def attach_to_ringbuf_channel(
|
||||||
token_in: RBToken,
|
token_in: RBToken,
|
||||||
token_out: RBToken,
|
token_out: RBToken,
|
||||||
|
batch_size: int = 1,
|
||||||
cleanup_in: bool = True,
|
cleanup_in: bool = True,
|
||||||
cleanup_out: bool = True
|
cleanup_out: bool = True,
|
||||||
) -> AsyncContextManager[trio.StapledStream]:
|
) -> AsyncContextManager[trio.StapledStream]:
|
||||||
'''
|
'''
|
||||||
Attach to two previously opened `RBToken`s and return a `RingBufferChannel`
|
Attach to two previously opened `RBToken`s and return a `RingBufferChannel`
|
||||||
|
@ -658,6 +739,7 @@ async def attach_to_ringbuf_channel(
|
||||||
) as receiver,
|
) as receiver,
|
||||||
attach_to_ringbuf_sender(
|
attach_to_ringbuf_sender(
|
||||||
token_out,
|
token_out,
|
||||||
|
batch_size=batch_size,
|
||||||
cleanup=cleanup_out
|
cleanup=cleanup_out
|
||||||
) as sender,
|
) as sender,
|
||||||
):
|
):
|
||||||
|
|
Loading…
Reference in New Issue