Better encapsulate RingBuff ctx managment methods and support non ipc usage
Add trio.StrictFIFOLock on sender.send_all Support max_bytes argument on receive_some, keep track of write_ptr on receiver Add max_bytes receive test test_ringbuf_max_bytes Add docstrings to all ringbuf tests Remove EFD_NONBLOCK support, not necesary anymore since we can use abandon_on_cancel=True on trio.to_thread.run_sync Close eventfd's after usage on open_ringbuf
							parent
							
								
									017bc50582
								
							
						
					
					
						commit
						c424f6c8d5
					
				|  | @ -58,6 +58,8 @@ async def child_write_shm( | ||||||
|         for msg in msgs: |         for msg in msgs: | ||||||
|             await sender.send_all(msg) |             await sender.send_all(msg) | ||||||
| 
 | 
 | ||||||
|  |     print('writer exit') | ||||||
|  | 
 | ||||||
| 
 | 
 | ||||||
| @pytest.mark.parametrize( | @pytest.mark.parametrize( | ||||||
|     'msg_amount,rand_min,rand_max,buf_size', |     'msg_amount,rand_min,rand_max,buf_size', | ||||||
|  | @ -83,6 +85,15 @@ def test_ringbuf( | ||||||
|     rand_max: int, |     rand_max: int, | ||||||
|     buf_size: int |     buf_size: int | ||||||
| ): | ): | ||||||
|  |     ''' | ||||||
|  |     - Open a new ring buf on root actor | ||||||
|  |     - Create a sender subactor and generate {msg_amount} messages | ||||||
|  |     optionally with a random amount of bytes at the end of each, | ||||||
|  |     return total_bytes on `ctx.started`, then send all messages | ||||||
|  |     - Create a receiver subactor and receive until total_bytes are | ||||||
|  |     read, print simple perf stats. | ||||||
|  | 
 | ||||||
|  |     ''' | ||||||
|     async def main(): |     async def main(): | ||||||
|         with open_ringbuf( |         with open_ringbuf( | ||||||
|             'test_ringbuf', |             'test_ringbuf', | ||||||
|  | @ -140,6 +151,11 @@ async def child_blocked_receiver( | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def test_ring_reader_cancel(): | def test_ring_reader_cancel(): | ||||||
|  |     ''' | ||||||
|  |     Test that a receiver blocked on eventfd(2) read responds to | ||||||
|  |     cancellation. | ||||||
|  | 
 | ||||||
|  |     ''' | ||||||
|     async def main(): |     async def main(): | ||||||
|         with open_ringbuf('test_ring_cancel_reader') as token: |         with open_ringbuf('test_ring_cancel_reader') as token: | ||||||
|             async with ( |             async with ( | ||||||
|  | @ -178,6 +194,11 @@ async def child_blocked_sender( | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def test_ring_sender_cancel(): | def test_ring_sender_cancel(): | ||||||
|  |     ''' | ||||||
|  |     Test that a sender blocked on eventfd(2) read responds to | ||||||
|  |     cancellation. | ||||||
|  | 
 | ||||||
|  |     ''' | ||||||
|     async def main(): |     async def main(): | ||||||
|         with open_ringbuf( |         with open_ringbuf( | ||||||
|             'test_ring_cancel_sender', |             'test_ring_cancel_sender', | ||||||
|  | @ -203,3 +224,36 @@ def test_ring_sender_cancel(): | ||||||
| 
 | 
 | ||||||
|     with pytest.raises(tractor._exceptions.ContextCancelled): |     with pytest.raises(tractor._exceptions.ContextCancelled): | ||||||
|         trio.run(main) |         trio.run(main) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def test_ringbuf_max_bytes(): | ||||||
|  |     ''' | ||||||
|  |     Test that RingBuffReceiver.receive_some's max_bytes optional | ||||||
|  |     argument works correctly, send a msg of size 100, then | ||||||
|  |     force receive of messages with max_bytes == 1, wait until | ||||||
|  |     100 of these messages are received, then compare join of | ||||||
|  |     msgs with original message | ||||||
|  | 
 | ||||||
|  |     ''' | ||||||
|  |     msg = b''.join(str(i % 10).encode() for i in range(100)) | ||||||
|  |     msgs = [] | ||||||
|  | 
 | ||||||
|  |     async def main(): | ||||||
|  |         with open_ringbuf( | ||||||
|  |             'test_ringbuf_max_bytes', | ||||||
|  |             buf_size=10 | ||||||
|  |         ) as token: | ||||||
|  |             async with ( | ||||||
|  |                 trio.open_nursery() as n, | ||||||
|  |                 RingBuffSender(token, is_ipc=False) as sender, | ||||||
|  |                 RingBuffReceiver(token, is_ipc=False) as receiver | ||||||
|  |             ): | ||||||
|  |                 n.start_soon(sender.send_all, msg) | ||||||
|  |                 while len(msgs) < len(msg): | ||||||
|  |                     msg_part = await receiver.receive_some(max_bytes=1) | ||||||
|  |                     msg_part = bytes(msg_part) | ||||||
|  |                     assert len(msg_part) == 1 | ||||||
|  |                     msgs.append(msg_part) | ||||||
|  | 
 | ||||||
|  |     trio.run(main) | ||||||
|  |     assert msg == b''.join(msgs) | ||||||
|  |  | ||||||
|  | @ -28,11 +28,15 @@ from msgspec import ( | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| from ._linux import ( | from ._linux import ( | ||||||
|     EFD_NONBLOCK, |  | ||||||
|     open_eventfd, |     open_eventfd, | ||||||
|  |     close_eventfd, | ||||||
|     EventFD |     EventFD | ||||||
| ) | ) | ||||||
| from ._mp_bs import disable_mantracker | from ._mp_bs import disable_mantracker | ||||||
|  | from tractor.log import get_logger | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | log = get_logger(__name__) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| disable_mantracker() | disable_mantracker() | ||||||
|  | @ -64,8 +68,6 @@ class RBToken(Struct, frozen=True): | ||||||
| def open_ringbuf( | def open_ringbuf( | ||||||
|     shm_name: str, |     shm_name: str, | ||||||
|     buf_size: int = 10 * 1024, |     buf_size: int = 10 * 1024, | ||||||
|     write_efd_flags: int = 0, |  | ||||||
|     wrap_efd_flags: int = 0 |  | ||||||
| ) -> RBToken: | ) -> RBToken: | ||||||
|     shm = SharedMemory( |     shm = SharedMemory( | ||||||
|         name=shm_name, |         name=shm_name, | ||||||
|  | @ -75,16 +77,21 @@ def open_ringbuf( | ||||||
|     try: |     try: | ||||||
|         token = RBToken( |         token = RBToken( | ||||||
|             shm_name=shm_name, |             shm_name=shm_name, | ||||||
|             write_eventfd=open_eventfd(flags=write_efd_flags), |             write_eventfd=open_eventfd(), | ||||||
|             wrap_eventfd=open_eventfd(flags=wrap_efd_flags), |             wrap_eventfd=open_eventfd(), | ||||||
|             buf_size=buf_size |             buf_size=buf_size | ||||||
|         ) |         ) | ||||||
|         yield token |         yield token | ||||||
|  |         close_eventfd(token.write_eventfd) | ||||||
|  |         close_eventfd(token.wrap_eventfd) | ||||||
| 
 | 
 | ||||||
|     finally: |     finally: | ||||||
|         shm.unlink() |         shm.unlink() | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | Buffer = bytes | bytearray | memoryview | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
| class RingBuffSender(trio.abc.SendStream): | class RingBuffSender(trio.abc.SendStream): | ||||||
|     ''' |     ''' | ||||||
|     IPC Reliable Ring Buffer sender side implementation |     IPC Reliable Ring Buffer sender side implementation | ||||||
|  | @ -97,24 +104,26 @@ class RingBuffSender(trio.abc.SendStream): | ||||||
|         self, |         self, | ||||||
|         token: RBToken, |         token: RBToken, | ||||||
|         start_ptr: int = 0, |         start_ptr: int = 0, | ||||||
|  |         is_ipc: bool = True | ||||||
|     ): |     ): | ||||||
|         token = RBToken.from_msg(token) |         self._token = RBToken.from_msg(token) | ||||||
|         self._shm = SharedMemory( |         self._shm: SharedMemory | None = None | ||||||
|             name=token.shm_name, |         self._write_event = EventFD(self._token.write_eventfd, 'w') | ||||||
|             size=token.buf_size, |         self._wrap_event = EventFD(self._token.wrap_eventfd, 'r') | ||||||
|             create=False |  | ||||||
|         ) |  | ||||||
|         self._write_event = EventFD(token.write_eventfd, 'w') |  | ||||||
|         self._wrap_event = EventFD(token.wrap_eventfd, 'r') |  | ||||||
|         self._ptr = start_ptr |         self._ptr = start_ptr | ||||||
| 
 | 
 | ||||||
|  |         self._is_ipc = is_ipc | ||||||
|  |         self._send_lock = trio.StrictFIFOLock() | ||||||
|  | 
 | ||||||
|     @property |     @property | ||||||
|     def key(self) -> str: |     def name(self) -> str: | ||||||
|  |         if not self._shm: | ||||||
|  |             raise ValueError('shared memory not initialized yet!') | ||||||
|         return self._shm.name |         return self._shm.name | ||||||
| 
 | 
 | ||||||
|     @property |     @property | ||||||
|     def size(self) -> int: |     def size(self) -> int: | ||||||
|         return self._shm.size |         return self._token.buf_size | ||||||
| 
 | 
 | ||||||
|     @property |     @property | ||||||
|     def ptr(self) -> int: |     def ptr(self) -> int: | ||||||
|  | @ -128,38 +137,48 @@ class RingBuffSender(trio.abc.SendStream): | ||||||
|     def wrap_fd(self) -> int: |     def wrap_fd(self) -> int: | ||||||
|         return self._wrap_event.fd |         return self._wrap_event.fd | ||||||
| 
 | 
 | ||||||
|     async def send_all(self, data: bytes | bytearray | memoryview): |     async def send_all(self, data: Buffer): | ||||||
|         # while data is larger than the remaining buf |         async with self._send_lock: | ||||||
|         target_ptr = self.ptr + len(data) |             # while data is larger than the remaining buf | ||||||
|         while target_ptr > self.size: |             target_ptr = self.ptr + len(data) | ||||||
|             # write all bytes that fit |             while target_ptr > self.size: | ||||||
|             remaining = self.size - self.ptr |                 # write all bytes that fit | ||||||
|             self._shm.buf[self.ptr:] = data[:remaining] |                 remaining = self.size - self.ptr | ||||||
|             # signal write and wait for reader wrap around |                 self._shm.buf[self.ptr:] = data[:remaining] | ||||||
|             self._write_event.write(remaining) |                 # signal write and wait for reader wrap around | ||||||
|             await self._wrap_event.read() |                 self._write_event.write(remaining) | ||||||
|  |                 await self._wrap_event.read() | ||||||
| 
 | 
 | ||||||
|             # wrap around and trim already written bytes |                 # wrap around and trim already written bytes | ||||||
|             self._ptr = 0 |                 self._ptr = 0 | ||||||
|             data = data[remaining:] |                 data = data[remaining:] | ||||||
|             target_ptr = self._ptr + len(data) |                 target_ptr = self._ptr + len(data) | ||||||
| 
 | 
 | ||||||
|         # remaining data fits on buffer |             # remaining data fits on buffer | ||||||
|         self._shm.buf[self.ptr:target_ptr] = data |             self._shm.buf[self.ptr:target_ptr] = data | ||||||
|         self._write_event.write(len(data)) |             self._write_event.write(len(data)) | ||||||
|         self._ptr = target_ptr |             self._ptr = target_ptr | ||||||
| 
 | 
 | ||||||
|     async def wait_send_all_might_not_block(self): |     async def wait_send_all_might_not_block(self): | ||||||
|         raise NotImplementedError |         raise NotImplementedError | ||||||
| 
 | 
 | ||||||
|     async def aclose(self): |     def open(self): | ||||||
|         self._write_event.close() |         self._shm = SharedMemory( | ||||||
|         self._wrap_event.close() |             name=self._token.shm_name, | ||||||
|         self._shm.close() |             size=self._token.buf_size, | ||||||
| 
 |             create=False | ||||||
|     async def __aenter__(self): |         ) | ||||||
|         self._write_event.open() |         self._write_event.open() | ||||||
|         self._wrap_event.open() |         self._wrap_event.open() | ||||||
|  | 
 | ||||||
|  |     async def aclose(self): | ||||||
|  |         if self._is_ipc: | ||||||
|  |             self._write_event.close() | ||||||
|  |             self._wrap_event.close() | ||||||
|  |             self._shm.close() | ||||||
|  | 
 | ||||||
|  |     async def __aenter__(self): | ||||||
|  |         self.open() | ||||||
|         return self |         return self | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | @ -175,26 +194,25 @@ class RingBuffReceiver(trio.abc.ReceiveStream): | ||||||
|         self, |         self, | ||||||
|         token: RBToken, |         token: RBToken, | ||||||
|         start_ptr: int = 0, |         start_ptr: int = 0, | ||||||
|         flags: int = 0 |         is_ipc: bool = True | ||||||
|     ): |     ): | ||||||
|         token = RBToken.from_msg(token) |         self._token = RBToken.from_msg(token) | ||||||
|         self._shm = SharedMemory( |         self._shm: SharedMemory | None = None | ||||||
|             name=token.shm_name, |         self._write_event = EventFD(self._token.write_eventfd, 'w') | ||||||
|             size=token.buf_size, |         self._wrap_event = EventFD(self._token.wrap_eventfd, 'r') | ||||||
|             create=False |  | ||||||
|         ) |  | ||||||
|         self._write_event = EventFD(token.write_eventfd, 'w') |  | ||||||
|         self._wrap_event = EventFD(token.wrap_eventfd, 'r') |  | ||||||
|         self._ptr = start_ptr |         self._ptr = start_ptr | ||||||
|         self._flags = flags |         self._write_ptr = start_ptr | ||||||
|  |         self._is_ipc = is_ipc | ||||||
| 
 | 
 | ||||||
|     @property |     @property | ||||||
|     def key(self) -> str: |     def name(self) -> str: | ||||||
|  |         if not self._shm: | ||||||
|  |             raise ValueError('shared memory not initialized yet!') | ||||||
|         return self._shm.name |         return self._shm.name | ||||||
| 
 | 
 | ||||||
|     @property |     @property | ||||||
|     def size(self) -> int: |     def size(self) -> int: | ||||||
|         return self._shm.size |         return self._token.buf_size | ||||||
| 
 | 
 | ||||||
|     @property |     @property | ||||||
|     def ptr(self) -> int: |     def ptr(self) -> int: | ||||||
|  | @ -208,46 +226,46 @@ class RingBuffReceiver(trio.abc.ReceiveStream): | ||||||
|     def wrap_fd(self) -> int: |     def wrap_fd(self) -> int: | ||||||
|         return self._wrap_event.fd |         return self._wrap_event.fd | ||||||
| 
 | 
 | ||||||
|     async def receive_some( |     async def receive_some(self, max_bytes: int | None = None) -> memoryview: | ||||||
|         self, |         delta = self._write_ptr - self._ptr | ||||||
|         max_bytes: int | None = None, |         if delta == 0: | ||||||
|         nb_timeout: float = 0.1 |  | ||||||
|     ) -> memoryview: |  | ||||||
|         # if non blocking eventfd enabled, do polling |  | ||||||
|         # until next write, this allows signal handling |  | ||||||
|         if self._flags | EFD_NONBLOCK: |  | ||||||
|             delta = None |  | ||||||
|             while delta is None: |  | ||||||
|                 try: |  | ||||||
|                     delta = await self._write_event.read() |  | ||||||
| 
 |  | ||||||
|                 except OSError as e: |  | ||||||
|                     if e.errno == 'EAGAIN': |  | ||||||
|                         continue |  | ||||||
| 
 |  | ||||||
|                     raise e |  | ||||||
| 
 |  | ||||||
|         else: |  | ||||||
|             delta = await self._write_event.read() |             delta = await self._write_event.read() | ||||||
|  |             self._write_ptr += delta | ||||||
|  | 
 | ||||||
|  |         if isinstance(max_bytes, int): | ||||||
|  |             if max_bytes == 0: | ||||||
|  |                 raise ValueError('if set, max_bytes must be > 0') | ||||||
|  |             delta = min(delta, max_bytes) | ||||||
|  | 
 | ||||||
|  |         target_ptr = self._ptr + delta | ||||||
| 
 | 
 | ||||||
|         # fetch next segment and advance ptr |         # fetch next segment and advance ptr | ||||||
|         next_ptr = self._ptr + delta |         segment = self._shm.buf[self._ptr:target_ptr] | ||||||
|         segment = self._shm.buf[self._ptr:next_ptr] |         self._ptr = target_ptr | ||||||
|         self._ptr = next_ptr |  | ||||||
| 
 | 
 | ||||||
|         if self.ptr == self.size: |         if self._ptr == self.size: | ||||||
|             # reached the end, signal wrap around |             # reached the end, signal wrap around | ||||||
|             self._ptr = 0 |             self._ptr = 0 | ||||||
|  |             self._write_ptr = 0 | ||||||
|             self._wrap_event.write(1) |             self._wrap_event.write(1) | ||||||
| 
 | 
 | ||||||
|         return segment |         return segment | ||||||
| 
 | 
 | ||||||
|     async def aclose(self): |     def open(self): | ||||||
|         self._write_event.close() |         self._shm = SharedMemory( | ||||||
|         self._wrap_event.close() |             name=self._token.shm_name, | ||||||
|         self._shm.close() |             size=self._token.buf_size, | ||||||
| 
 |             create=False | ||||||
|     async def __aenter__(self): |         ) | ||||||
|         self._write_event.open() |         self._write_event.open() | ||||||
|         self._wrap_event.open() |         self._wrap_event.open() | ||||||
|  | 
 | ||||||
|  |     async def aclose(self): | ||||||
|  |         if self._is_ipc: | ||||||
|  |             self._write_event.close() | ||||||
|  |             self._wrap_event.close() | ||||||
|  |             self._shm.close() | ||||||
|  | 
 | ||||||
|  |     async def __aenter__(self): | ||||||
|  |         self.open() | ||||||
|         return self |         return self | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue