Better encapsulate RingBuff ctx managment methods and support non ipc usage
Add trio.StrictFIFOLock on sender.send_all Support max_bytes argument on receive_some, keep track of write_ptr on receiver Add max_bytes receive test test_ringbuf_max_bytes Add docstrings to all ringbuf tests Remove EFD_NONBLOCK support, not necesary anymore since we can use abandon_on_cancel=True on trio.to_thread.run_sync Close eventfd's after usage on open_ringbuf
							parent
							
								
									017bc50582
								
							
						
					
					
						commit
						c424f6c8d5
					
				|  | @ -58,6 +58,8 @@ async def child_write_shm( | |||
|         for msg in msgs: | ||||
|             await sender.send_all(msg) | ||||
| 
 | ||||
|     print('writer exit') | ||||
| 
 | ||||
| 
 | ||||
| @pytest.mark.parametrize( | ||||
|     'msg_amount,rand_min,rand_max,buf_size', | ||||
|  | @ -83,6 +85,15 @@ def test_ringbuf( | |||
|     rand_max: int, | ||||
|     buf_size: int | ||||
| ): | ||||
|     ''' | ||||
|     - Open a new ring buf on root actor | ||||
|     - Create a sender subactor and generate {msg_amount} messages | ||||
|     optionally with a random amount of bytes at the end of each, | ||||
|     return total_bytes on `ctx.started`, then send all messages | ||||
|     - Create a receiver subactor and receive until total_bytes are | ||||
|     read, print simple perf stats. | ||||
| 
 | ||||
|     ''' | ||||
|     async def main(): | ||||
|         with open_ringbuf( | ||||
|             'test_ringbuf', | ||||
|  | @ -140,6 +151,11 @@ async def child_blocked_receiver( | |||
| 
 | ||||
| 
 | ||||
| def test_ring_reader_cancel(): | ||||
|     ''' | ||||
|     Test that a receiver blocked on eventfd(2) read responds to | ||||
|     cancellation. | ||||
| 
 | ||||
|     ''' | ||||
|     async def main(): | ||||
|         with open_ringbuf('test_ring_cancel_reader') as token: | ||||
|             async with ( | ||||
|  | @ -178,6 +194,11 @@ async def child_blocked_sender( | |||
| 
 | ||||
| 
 | ||||
| def test_ring_sender_cancel(): | ||||
|     ''' | ||||
|     Test that a sender blocked on eventfd(2) read responds to | ||||
|     cancellation. | ||||
| 
 | ||||
|     ''' | ||||
|     async def main(): | ||||
|         with open_ringbuf( | ||||
|             'test_ring_cancel_sender', | ||||
|  | @ -203,3 +224,36 @@ def test_ring_sender_cancel(): | |||
| 
 | ||||
|     with pytest.raises(tractor._exceptions.ContextCancelled): | ||||
|         trio.run(main) | ||||
| 
 | ||||
| 
 | ||||
| def test_ringbuf_max_bytes(): | ||||
|     ''' | ||||
|     Test that RingBuffReceiver.receive_some's max_bytes optional | ||||
|     argument works correctly, send a msg of size 100, then | ||||
|     force receive of messages with max_bytes == 1, wait until | ||||
|     100 of these messages are received, then compare join of | ||||
|     msgs with original message | ||||
| 
 | ||||
|     ''' | ||||
|     msg = b''.join(str(i % 10).encode() for i in range(100)) | ||||
|     msgs = [] | ||||
| 
 | ||||
|     async def main(): | ||||
|         with open_ringbuf( | ||||
|             'test_ringbuf_max_bytes', | ||||
|             buf_size=10 | ||||
|         ) as token: | ||||
|             async with ( | ||||
|                 trio.open_nursery() as n, | ||||
|                 RingBuffSender(token, is_ipc=False) as sender, | ||||
|                 RingBuffReceiver(token, is_ipc=False) as receiver | ||||
|             ): | ||||
|                 n.start_soon(sender.send_all, msg) | ||||
|                 while len(msgs) < len(msg): | ||||
|                     msg_part = await receiver.receive_some(max_bytes=1) | ||||
|                     msg_part = bytes(msg_part) | ||||
|                     assert len(msg_part) == 1 | ||||
|                     msgs.append(msg_part) | ||||
| 
 | ||||
|     trio.run(main) | ||||
|     assert msg == b''.join(msgs) | ||||
|  |  | |||
|  | @ -28,11 +28,15 @@ from msgspec import ( | |||
| ) | ||||
| 
 | ||||
| from ._linux import ( | ||||
|     EFD_NONBLOCK, | ||||
|     open_eventfd, | ||||
|     close_eventfd, | ||||
|     EventFD | ||||
| ) | ||||
| from ._mp_bs import disable_mantracker | ||||
| from tractor.log import get_logger | ||||
| 
 | ||||
| 
 | ||||
| log = get_logger(__name__) | ||||
| 
 | ||||
| 
 | ||||
| disable_mantracker() | ||||
|  | @ -64,8 +68,6 @@ class RBToken(Struct, frozen=True): | |||
| def open_ringbuf( | ||||
|     shm_name: str, | ||||
|     buf_size: int = 10 * 1024, | ||||
|     write_efd_flags: int = 0, | ||||
|     wrap_efd_flags: int = 0 | ||||
| ) -> RBToken: | ||||
|     shm = SharedMemory( | ||||
|         name=shm_name, | ||||
|  | @ -75,16 +77,21 @@ def open_ringbuf( | |||
|     try: | ||||
|         token = RBToken( | ||||
|             shm_name=shm_name, | ||||
|             write_eventfd=open_eventfd(flags=write_efd_flags), | ||||
|             wrap_eventfd=open_eventfd(flags=wrap_efd_flags), | ||||
|             write_eventfd=open_eventfd(), | ||||
|             wrap_eventfd=open_eventfd(), | ||||
|             buf_size=buf_size | ||||
|         ) | ||||
|         yield token | ||||
|         close_eventfd(token.write_eventfd) | ||||
|         close_eventfd(token.wrap_eventfd) | ||||
| 
 | ||||
|     finally: | ||||
|         shm.unlink() | ||||
| 
 | ||||
| 
 | ||||
| Buffer = bytes | bytearray | memoryview | ||||
| 
 | ||||
| 
 | ||||
| class RingBuffSender(trio.abc.SendStream): | ||||
|     ''' | ||||
|     IPC Reliable Ring Buffer sender side implementation | ||||
|  | @ -97,24 +104,26 @@ class RingBuffSender(trio.abc.SendStream): | |||
|         self, | ||||
|         token: RBToken, | ||||
|         start_ptr: int = 0, | ||||
|         is_ipc: bool = True | ||||
|     ): | ||||
|         token = RBToken.from_msg(token) | ||||
|         self._shm = SharedMemory( | ||||
|             name=token.shm_name, | ||||
|             size=token.buf_size, | ||||
|             create=False | ||||
|         ) | ||||
|         self._write_event = EventFD(token.write_eventfd, 'w') | ||||
|         self._wrap_event = EventFD(token.wrap_eventfd, 'r') | ||||
|         self._token = RBToken.from_msg(token) | ||||
|         self._shm: SharedMemory | None = None | ||||
|         self._write_event = EventFD(self._token.write_eventfd, 'w') | ||||
|         self._wrap_event = EventFD(self._token.wrap_eventfd, 'r') | ||||
|         self._ptr = start_ptr | ||||
| 
 | ||||
|         self._is_ipc = is_ipc | ||||
|         self._send_lock = trio.StrictFIFOLock() | ||||
| 
 | ||||
|     @property | ||||
|     def key(self) -> str: | ||||
|     def name(self) -> str: | ||||
|         if not self._shm: | ||||
|             raise ValueError('shared memory not initialized yet!') | ||||
|         return self._shm.name | ||||
| 
 | ||||
|     @property | ||||
|     def size(self) -> int: | ||||
|         return self._shm.size | ||||
|         return self._token.buf_size | ||||
| 
 | ||||
|     @property | ||||
|     def ptr(self) -> int: | ||||
|  | @ -128,38 +137,48 @@ class RingBuffSender(trio.abc.SendStream): | |||
|     def wrap_fd(self) -> int: | ||||
|         return self._wrap_event.fd | ||||
| 
 | ||||
|     async def send_all(self, data: bytes | bytearray | memoryview): | ||||
|         # while data is larger than the remaining buf | ||||
|         target_ptr = self.ptr + len(data) | ||||
|         while target_ptr > self.size: | ||||
|             # write all bytes that fit | ||||
|             remaining = self.size - self.ptr | ||||
|             self._shm.buf[self.ptr:] = data[:remaining] | ||||
|             # signal write and wait for reader wrap around | ||||
|             self._write_event.write(remaining) | ||||
|             await self._wrap_event.read() | ||||
|     async def send_all(self, data: Buffer): | ||||
|         async with self._send_lock: | ||||
|             # while data is larger than the remaining buf | ||||
|             target_ptr = self.ptr + len(data) | ||||
|             while target_ptr > self.size: | ||||
|                 # write all bytes that fit | ||||
|                 remaining = self.size - self.ptr | ||||
|                 self._shm.buf[self.ptr:] = data[:remaining] | ||||
|                 # signal write and wait for reader wrap around | ||||
|                 self._write_event.write(remaining) | ||||
|                 await self._wrap_event.read() | ||||
| 
 | ||||
|             # wrap around and trim already written bytes | ||||
|             self._ptr = 0 | ||||
|             data = data[remaining:] | ||||
|             target_ptr = self._ptr + len(data) | ||||
|                 # wrap around and trim already written bytes | ||||
|                 self._ptr = 0 | ||||
|                 data = data[remaining:] | ||||
|                 target_ptr = self._ptr + len(data) | ||||
| 
 | ||||
|         # remaining data fits on buffer | ||||
|         self._shm.buf[self.ptr:target_ptr] = data | ||||
|         self._write_event.write(len(data)) | ||||
|         self._ptr = target_ptr | ||||
|             # remaining data fits on buffer | ||||
|             self._shm.buf[self.ptr:target_ptr] = data | ||||
|             self._write_event.write(len(data)) | ||||
|             self._ptr = target_ptr | ||||
| 
 | ||||
|     async def wait_send_all_might_not_block(self): | ||||
|         raise NotImplementedError | ||||
| 
 | ||||
|     async def aclose(self): | ||||
|         self._write_event.close() | ||||
|         self._wrap_event.close() | ||||
|         self._shm.close() | ||||
| 
 | ||||
|     async def __aenter__(self): | ||||
|     def open(self): | ||||
|         self._shm = SharedMemory( | ||||
|             name=self._token.shm_name, | ||||
|             size=self._token.buf_size, | ||||
|             create=False | ||||
|         ) | ||||
|         self._write_event.open() | ||||
|         self._wrap_event.open() | ||||
| 
 | ||||
|     async def aclose(self): | ||||
|         if self._is_ipc: | ||||
|             self._write_event.close() | ||||
|             self._wrap_event.close() | ||||
|             self._shm.close() | ||||
| 
 | ||||
|     async def __aenter__(self): | ||||
|         self.open() | ||||
|         return self | ||||
| 
 | ||||
| 
 | ||||
|  | @ -175,26 +194,25 @@ class RingBuffReceiver(trio.abc.ReceiveStream): | |||
|         self, | ||||
|         token: RBToken, | ||||
|         start_ptr: int = 0, | ||||
|         flags: int = 0 | ||||
|         is_ipc: bool = True | ||||
|     ): | ||||
|         token = RBToken.from_msg(token) | ||||
|         self._shm = SharedMemory( | ||||
|             name=token.shm_name, | ||||
|             size=token.buf_size, | ||||
|             create=False | ||||
|         ) | ||||
|         self._write_event = EventFD(token.write_eventfd, 'w') | ||||
|         self._wrap_event = EventFD(token.wrap_eventfd, 'r') | ||||
|         self._token = RBToken.from_msg(token) | ||||
|         self._shm: SharedMemory | None = None | ||||
|         self._write_event = EventFD(self._token.write_eventfd, 'w') | ||||
|         self._wrap_event = EventFD(self._token.wrap_eventfd, 'r') | ||||
|         self._ptr = start_ptr | ||||
|         self._flags = flags | ||||
|         self._write_ptr = start_ptr | ||||
|         self._is_ipc = is_ipc | ||||
| 
 | ||||
|     @property | ||||
|     def key(self) -> str: | ||||
|     def name(self) -> str: | ||||
|         if not self._shm: | ||||
|             raise ValueError('shared memory not initialized yet!') | ||||
|         return self._shm.name | ||||
| 
 | ||||
|     @property | ||||
|     def size(self) -> int: | ||||
|         return self._shm.size | ||||
|         return self._token.buf_size | ||||
| 
 | ||||
|     @property | ||||
|     def ptr(self) -> int: | ||||
|  | @ -208,46 +226,46 @@ class RingBuffReceiver(trio.abc.ReceiveStream): | |||
|     def wrap_fd(self) -> int: | ||||
|         return self._wrap_event.fd | ||||
| 
 | ||||
|     async def receive_some( | ||||
|         self, | ||||
|         max_bytes: int | None = None, | ||||
|         nb_timeout: float = 0.1 | ||||
|     ) -> memoryview: | ||||
|         # if non blocking eventfd enabled, do polling | ||||
|         # until next write, this allows signal handling | ||||
|         if self._flags | EFD_NONBLOCK: | ||||
|             delta = None | ||||
|             while delta is None: | ||||
|                 try: | ||||
|                     delta = await self._write_event.read() | ||||
| 
 | ||||
|                 except OSError as e: | ||||
|                     if e.errno == 'EAGAIN': | ||||
|                         continue | ||||
| 
 | ||||
|                     raise e | ||||
| 
 | ||||
|         else: | ||||
|     async def receive_some(self, max_bytes: int | None = None) -> memoryview: | ||||
|         delta = self._write_ptr - self._ptr | ||||
|         if delta == 0: | ||||
|             delta = await self._write_event.read() | ||||
|             self._write_ptr += delta | ||||
| 
 | ||||
|         if isinstance(max_bytes, int): | ||||
|             if max_bytes == 0: | ||||
|                 raise ValueError('if set, max_bytes must be > 0') | ||||
|             delta = min(delta, max_bytes) | ||||
| 
 | ||||
|         target_ptr = self._ptr + delta | ||||
| 
 | ||||
|         # fetch next segment and advance ptr | ||||
|         next_ptr = self._ptr + delta | ||||
|         segment = self._shm.buf[self._ptr:next_ptr] | ||||
|         self._ptr = next_ptr | ||||
|         segment = self._shm.buf[self._ptr:target_ptr] | ||||
|         self._ptr = target_ptr | ||||
| 
 | ||||
|         if self.ptr == self.size: | ||||
|         if self._ptr == self.size: | ||||
|             # reached the end, signal wrap around | ||||
|             self._ptr = 0 | ||||
|             self._write_ptr = 0 | ||||
|             self._wrap_event.write(1) | ||||
| 
 | ||||
|         return segment | ||||
| 
 | ||||
|     async def aclose(self): | ||||
|         self._write_event.close() | ||||
|         self._wrap_event.close() | ||||
|         self._shm.close() | ||||
| 
 | ||||
|     async def __aenter__(self): | ||||
|     def open(self): | ||||
|         self._shm = SharedMemory( | ||||
|             name=self._token.shm_name, | ||||
|             size=self._token.buf_size, | ||||
|             create=False | ||||
|         ) | ||||
|         self._write_event.open() | ||||
|         self._wrap_event.open() | ||||
| 
 | ||||
|     async def aclose(self): | ||||
|         if self._is_ipc: | ||||
|             self._write_event.close() | ||||
|             self._wrap_event.close() | ||||
|             self._shm.close() | ||||
| 
 | ||||
|     async def __aenter__(self): | ||||
|         self.open() | ||||
|         return self | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue