Rename RingBuff -> RingBuffer
Combine RingBuffer stream and channel apis Implement RingBufferReceiveChannel.receive_nowait Make msg generator calculate hash
							parent
							
								
									70d72fd173
								
							
						
					
					
						commit
						eb20e5ea8d
					
				|  | @ -8,7 +8,6 @@ from tractor.ipc import ( | ||||||
|     open_ringbuf, |     open_ringbuf, | ||||||
|     attach_to_ringbuf_receiver, |     attach_to_ringbuf_receiver, | ||||||
|     attach_to_ringbuf_sender, |     attach_to_ringbuf_sender, | ||||||
|     attach_to_ringbuf_stream, |  | ||||||
|     attach_to_ringbuf_channel, |     attach_to_ringbuf_channel, | ||||||
|     RBToken, |     RBToken, | ||||||
| ) | ) | ||||||
|  | @ -21,7 +20,6 @@ from tractor._testing.samples import ( | ||||||
| @tractor.context | @tractor.context | ||||||
| async def child_read_shm( | async def child_read_shm( | ||||||
|     ctx: tractor.Context, |     ctx: tractor.Context, | ||||||
|     msg_amount: int, |  | ||||||
|     token: RBToken, |     token: RBToken, | ||||||
| ) -> str: | ) -> str: | ||||||
|     ''' |     ''' | ||||||
|  | @ -37,11 +35,13 @@ async def child_read_shm( | ||||||
|     ''' |     ''' | ||||||
|     await ctx.started() |     await ctx.started() | ||||||
|     print('reader started') |     print('reader started') | ||||||
|  |     msg_amount = 0 | ||||||
|     recvd_bytes = 0 |     recvd_bytes = 0 | ||||||
|     recvd_hash = hashlib.sha256() |     recvd_hash = hashlib.sha256() | ||||||
|     start_ts = time.time() |     start_ts = time.time() | ||||||
|     async with attach_to_ringbuf_receiver(token) as receiver: |     async with attach_to_ringbuf_receiver(token) as receiver: | ||||||
|         async for msg in receiver: |         async for msg in receiver: | ||||||
|  |             msg_amount += 1 | ||||||
|             recvd_hash.update(msg) |             recvd_hash.update(msg) | ||||||
|             recvd_bytes += len(msg) |             recvd_bytes += len(msg) | ||||||
| 
 | 
 | ||||||
|  | @ -75,19 +75,16 @@ async def child_write_shm( | ||||||
|     Attach to ringbuf and send all generated messages. |     Attach to ringbuf and send all generated messages. | ||||||
| 
 | 
 | ||||||
|     ''' |     ''' | ||||||
|     msgs, _total_bytes = generate_sample_messages( |     sent_hash, msgs, _total_bytes = generate_sample_messages( | ||||||
|         msg_amount, |         msg_amount, | ||||||
|         rand_min=rand_min, |         rand_min=rand_min, | ||||||
|         rand_max=rand_max, |         rand_max=rand_max, | ||||||
|     ) |     ) | ||||||
|     print('writer hashing payload...') |  | ||||||
|     sent_hash = hashlib.sha256(b''.join(msgs)).hexdigest() |  | ||||||
|     print('writer done hashing.') |  | ||||||
|     await ctx.started(sent_hash) |     await ctx.started(sent_hash) | ||||||
|     print('writer started') |     print('writer started') | ||||||
|     async with attach_to_ringbuf_sender(token, cleanup=False) as sender: |     async with attach_to_ringbuf_sender(token, cleanup=False) as sender: | ||||||
|         for msg in msgs: |         for msg in msgs: | ||||||
|             await sender.send_all(msg) |             await sender.send(msg) | ||||||
| 
 | 
 | ||||||
|     print('writer exit') |     print('writer exit') | ||||||
| 
 | 
 | ||||||
|  | @ -155,7 +152,6 @@ def test_ringbuf( | ||||||
|                     recv_p.open_context( |                     recv_p.open_context( | ||||||
|                         child_read_shm, |                         child_read_shm, | ||||||
|                         token=token, |                         token=token, | ||||||
|                         msg_amount=msg_amount |  | ||||||
|                     ) as (rctx, _sent), |                     ) as (rctx, _sent), | ||||||
|                 ): |                 ): | ||||||
|                     recvd_hash = await rctx.result() |                     recvd_hash = await rctx.result() | ||||||
|  | @ -291,75 +287,6 @@ def test_receiver_max_bytes(): | ||||||
|     assert msg == b''.join(msgs) |     assert msg == b''.join(msgs) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def test_stapled_ringbuf(): |  | ||||||
|     ''' |  | ||||||
|     Open two ringbufs and give tokens to tasks (swap them such that in/out tokens |  | ||||||
|     are inversed on each task) which will open the streams and use trio.StapledStream |  | ||||||
|     to have a single bidirectional stream. |  | ||||||
| 
 |  | ||||||
|     Then take turns to send and receive messages. |  | ||||||
| 
 |  | ||||||
|     ''' |  | ||||||
|     msg = generate_single_byte_msgs(100) |  | ||||||
|     pair_0_msgs = [] |  | ||||||
|     pair_1_msgs = [] |  | ||||||
| 
 |  | ||||||
|     pair_0_done = trio.Event() |  | ||||||
|     pair_1_done = trio.Event() |  | ||||||
| 
 |  | ||||||
|     async def pair_0(token_in: RBToken, token_out: RBToken): |  | ||||||
|         async with attach_to_ringbuf_stream( |  | ||||||
|             token_in, |  | ||||||
|             token_out, |  | ||||||
|             cleanup_in=False, |  | ||||||
|             cleanup_out=False |  | ||||||
|         ) as stream: |  | ||||||
|             # first turn to send |  | ||||||
|             await stream.send_all(msg) |  | ||||||
| 
 |  | ||||||
|             # second turn to receive |  | ||||||
|             while len(pair_0_msgs) != len(msg): |  | ||||||
|                 _msg = await stream.receive_some(max_bytes=1) |  | ||||||
|                 pair_0_msgs.append(_msg) |  | ||||||
| 
 |  | ||||||
|             pair_0_done.set() |  | ||||||
|             await pair_1_done.wait() |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|     async def pair_1(token_in: RBToken, token_out: RBToken): |  | ||||||
|         async with attach_to_ringbuf_stream( |  | ||||||
|             token_in, |  | ||||||
|             token_out, |  | ||||||
|             cleanup_in=False, |  | ||||||
|             cleanup_out=False |  | ||||||
|         ) as stream: |  | ||||||
|             # first turn to receive |  | ||||||
|             while len(pair_1_msgs) != len(msg): |  | ||||||
|                 _msg = await stream.receive_some(max_bytes=1) |  | ||||||
|                 pair_1_msgs.append(_msg) |  | ||||||
| 
 |  | ||||||
|             # second turn to send |  | ||||||
|             await stream.send_all(msg) |  | ||||||
| 
 |  | ||||||
|             pair_1_done.set() |  | ||||||
|             await pair_0_done.wait() |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|     async def main(): |  | ||||||
|         with tractor.ipc.open_ringbuf_pair( |  | ||||||
|             'test_stapled_ringbuf' |  | ||||||
|         ) as (token_0, token_1): |  | ||||||
|             async with trio.open_nursery() as n: |  | ||||||
|                 n.start_soon(pair_0, token_0, token_1) |  | ||||||
|                 n.start_soon(pair_1, token_1, token_0) |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|     trio.run(main) |  | ||||||
| 
 |  | ||||||
|     assert msg == b''.join(pair_0_msgs) |  | ||||||
|     assert msg == b''.join(pair_1_msgs) |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| @tractor.context | @tractor.context | ||||||
| async def child_channel_sender( | async def child_channel_sender( | ||||||
|     ctx: tractor.Context, |     ctx: tractor.Context, | ||||||
|  | @ -369,7 +296,7 @@ async def child_channel_sender( | ||||||
|     token_out: RBToken |     token_out: RBToken | ||||||
| ): | ): | ||||||
|     import random |     import random | ||||||
|     msgs, _total_bytes = generate_sample_messages( |     _hash, msgs, _total_bytes = generate_sample_messages( | ||||||
|         random.randint(msg_amount_min, msg_amount_max), |         random.randint(msg_amount_min, msg_amount_max), | ||||||
|         rand_min=256, |         rand_min=256, | ||||||
|         rand_max=1024, |         rand_max=1024, | ||||||
|  | @ -379,7 +306,6 @@ async def child_channel_sender( | ||||||
|         token_out |         token_out | ||||||
|     ) as chan: |     ) as chan: | ||||||
|         await ctx.started(msgs) |         await ctx.started(msgs) | ||||||
| 
 |  | ||||||
|         for msg in msgs: |         for msg in msgs: | ||||||
|             await chan.send(msg) |             await chan.send(msg) | ||||||
| 
 | 
 | ||||||
|  | @ -392,16 +318,16 @@ def test_channel(): | ||||||
|     async def main(): |     async def main(): | ||||||
|         with tractor.ipc.open_ringbuf_pair( |         with tractor.ipc.open_ringbuf_pair( | ||||||
|             'test_ringbuf_transport' |             'test_ringbuf_transport' | ||||||
|         ) as (token_0, token_1): |         ) as (send_token, recv_token): | ||||||
|             async with ( |             async with ( | ||||||
|                 attach_to_ringbuf_channel(token_0, token_1) as chan, |                 attach_to_ringbuf_channel(send_token, recv_token) as chan, | ||||||
|                 tractor.open_nursery() as an |                 tractor.open_nursery() as an | ||||||
|             ): |             ): | ||||||
|                 recv_p = await an.start_actor( |                 recv_p = await an.start_actor( | ||||||
|                     'test_ringbuf_transport_sender', |                     'test_ringbuf_transport_sender', | ||||||
|                     enable_modules=[__name__], |                     enable_modules=[__name__], | ||||||
|                     proc_kwargs={ |                     proc_kwargs={ | ||||||
|                         'pass_fds': token_0.fds + token_1.fds |                         'pass_fds': send_token.fds + recv_token.fds | ||||||
|                     } |                     } | ||||||
|                 ) |                 ) | ||||||
|                 async with ( |                 async with ( | ||||||
|  | @ -409,8 +335,8 @@ def test_channel(): | ||||||
|                         child_channel_sender, |                         child_channel_sender, | ||||||
|                         msg_amount_min=msg_amount_min, |                         msg_amount_min=msg_amount_min, | ||||||
|                         msg_amount_max=msg_amount_max, |                         msg_amount_max=msg_amount_max, | ||||||
|                         token_in=token_1, |                         token_in=recv_token, | ||||||
|                         token_out=token_0 |                         token_out=send_token | ||||||
|                     ) as (ctx, msgs), |                     ) as (ctx, msgs), | ||||||
|                 ): |                 ): | ||||||
|                     recv_msgs = [] |                     recv_msgs = [] | ||||||
|  |  | ||||||
|  | @ -1,5 +1,6 @@ | ||||||
| import os | import os | ||||||
| import random | import random | ||||||
|  | import hashlib | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def generate_single_byte_msgs(amount: int) -> bytes: | def generate_single_byte_msgs(amount: int) -> bytes: | ||||||
|  | @ -23,7 +24,7 @@ def generate_sample_messages( | ||||||
|     rand_min: int = 0, |     rand_min: int = 0, | ||||||
|     rand_max: int = 0, |     rand_max: int = 0, | ||||||
|     silent: bool = False, |     silent: bool = False, | ||||||
| ) -> tuple[list[bytes], int]: | ) -> tuple[str, list[bytes], int]: | ||||||
|     ''' |     ''' | ||||||
|     Generate bytes msgs for tests. |     Generate bytes msgs for tests. | ||||||
| 
 | 
 | ||||||
|  | @ -55,6 +56,7 @@ def generate_sample_messages( | ||||||
|         else: |         else: | ||||||
|             log_interval = 1000 |             log_interval = 1000 | ||||||
| 
 | 
 | ||||||
|  |     payload_hash = hashlib.sha256() | ||||||
|     for i in range(amount): |     for i in range(amount): | ||||||
|         msg = f'[{i:08}]'.encode('utf-8') |         msg = f'[{i:08}]'.encode('utf-8') | ||||||
| 
 | 
 | ||||||
|  | @ -64,6 +66,7 @@ def generate_sample_messages( | ||||||
| 
 | 
 | ||||||
|         size += len(msg) |         size += len(msg) | ||||||
| 
 | 
 | ||||||
|  |         payload_hash.update(msg) | ||||||
|         msgs.append(msg) |         msgs.append(msg) | ||||||
| 
 | 
 | ||||||
|         if ( |         if ( | ||||||
|  | @ -78,4 +81,4 @@ def generate_sample_messages( | ||||||
|     if not silent: |     if not silent: | ||||||
|         print(f'done, {size:,} bytes in total') |         print(f'done, {size:,} bytes in total') | ||||||
| 
 | 
 | ||||||
|     return msgs, size |     return payload_hash.hexdigest(), msgs, size | ||||||
|  |  | ||||||
|  | @ -31,17 +31,16 @@ from ._chan import ( | ||||||
| if platform.system() == 'Linux': | if platform.system() == 'Linux': | ||||||
|     from ._ringbuf import ( |     from ._ringbuf import ( | ||||||
|         RBToken as RBToken, |         RBToken as RBToken, | ||||||
|  | 
 | ||||||
|         open_ringbuf as open_ringbuf, |         open_ringbuf as open_ringbuf, | ||||||
|         RingBuffSender as RingBuffSender, |  | ||||||
|         RingBuffReceiver as RingBuffReceiver, |  | ||||||
|         open_ringbuf_pair as open_ringbuf_pair, |         open_ringbuf_pair as open_ringbuf_pair, | ||||||
|         attach_to_ringbuf_receiver as attach_to_ringbuf_receiver, | 
 | ||||||
|  |         RingBufferSendChannel as RingBufferSendChannel, | ||||||
|         attach_to_ringbuf_sender as attach_to_ringbuf_sender, |         attach_to_ringbuf_sender as attach_to_ringbuf_sender, | ||||||
|         attach_to_ringbuf_stream as attach_to_ringbuf_stream, | 
 | ||||||
|         RingBuffBytesSender as RingBuffBytesSender, |         RingBufferReceiveChannel as RingBufferReceiveChannel, | ||||||
|         RingBuffBytesReceiver as RingBuffBytesReceiver, |         attach_to_ringbuf_receiver as attach_to_ringbuf_receiver, | ||||||
|         RingBuffChannel as RingBuffChannel, | 
 | ||||||
|         attach_to_ringbuf_schannel as attach_to_ringbuf_schannel, |         RingBufferChannel as RingBufferChannel, | ||||||
|         attach_to_ringbuf_rchannel as attach_to_ringbuf_rchannel, |  | ||||||
|         attach_to_ringbuf_channel as attach_to_ringbuf_channel, |         attach_to_ringbuf_channel as attach_to_ringbuf_channel, | ||||||
|     ) |     ) | ||||||
|  |  | ||||||
|  | @ -126,6 +126,30 @@ def open_ringbuf( | ||||||
|         shm.unlink() |         shm.unlink() | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | @cm | ||||||
|  | def open_ringbuf_pair( | ||||||
|  |     name: str, | ||||||
|  |     buf_size: int = _DEFAULT_RB_SIZE | ||||||
|  | ) -> ContextManager[tuple(RBToken, RBToken)]: | ||||||
|  |     ''' | ||||||
|  |     Handle resources for a ringbuf pair to be used for | ||||||
|  |     bidirectional messaging. | ||||||
|  | 
 | ||||||
|  |     ''' | ||||||
|  |     with ( | ||||||
|  |         open_ringbuf( | ||||||
|  |             name + '.send', | ||||||
|  |             buf_size=buf_size | ||||||
|  |         ) as send_token, | ||||||
|  | 
 | ||||||
|  |         open_ringbuf( | ||||||
|  |             name + '.recv', | ||||||
|  |             buf_size=buf_size | ||||||
|  |         ) as recv_token | ||||||
|  |     ): | ||||||
|  |         yield send_token, recv_token | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
| Buffer = bytes | bytearray | memoryview | Buffer = bytes | bytearray | memoryview | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | @ -135,32 +159,65 @@ IPC Reliable Ring Buffer | ||||||
| `eventfd(2)` is used for wrap around sync, to signal writes to | `eventfd(2)` is used for wrap around sync, to signal writes to | ||||||
| the reader and end of stream. | the reader and end of stream. | ||||||
| 
 | 
 | ||||||
|  | In order to guarantee full messages are received, all bytes | ||||||
|  | sent by `RingBufferSendChannel` are preceded with a 4 byte header | ||||||
|  | which decodes into a uint32 indicating the actual size of the | ||||||
|  | next full payload. | ||||||
|  | 
 | ||||||
| ''' | ''' | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class RingBuffSender(trio.abc.SendStream): | class RingBufferSendChannel(trio.abc.SendChannel[bytes]): | ||||||
|     ''' |     ''' | ||||||
|     Ring Buffer sender side implementation |     Ring Buffer sender side implementation | ||||||
| 
 | 
 | ||||||
|     Do not use directly! manage with `attach_to_ringbuf_sender` |     Do not use directly! manage with `attach_to_ringbuf_sender` | ||||||
|     after having opened a ringbuf context with `open_ringbuf`. |     after having opened a ringbuf context with `open_ringbuf`. | ||||||
| 
 | 
 | ||||||
|  |     Optional batch mode: | ||||||
|  | 
 | ||||||
|  |     If `batch_size` > 1 messages wont get sent immediately but will be | ||||||
|  |     stored until `batch_size` messages are pending, then it will send | ||||||
|  |     them all at once. | ||||||
|  | 
 | ||||||
|  |     `batch_size` can be changed dynamically but always call, `flush()` | ||||||
|  |     right before. | ||||||
|  | 
 | ||||||
|     ''' |     ''' | ||||||
|     def __init__( |     def __init__( | ||||||
|         self, |         self, | ||||||
|         token: RBToken, |         token: RBToken, | ||||||
|  |         batch_size: int = 1, | ||||||
|         cleanup: bool = False |         cleanup: bool = False | ||||||
|     ): |     ): | ||||||
|         self._token = RBToken.from_msg(token) |         self._token = RBToken.from_msg(token) | ||||||
|  |         self.batch_size = batch_size | ||||||
|  | 
 | ||||||
|  |         # ringbuf os resources | ||||||
|         self._shm: SharedMemory | None = None |         self._shm: SharedMemory | None = None | ||||||
|         self._write_event = EventFD(self._token.write_eventfd, 'w') |         self._write_event = EventFD(self._token.write_eventfd, 'w') | ||||||
|         self._wrap_event = EventFD(self._token.wrap_eventfd, 'r') |         self._wrap_event = EventFD(self._token.wrap_eventfd, 'r') | ||||||
|         self._eof_event = EventFD(self._token.eof_eventfd, 'w') |         self._eof_event = EventFD(self._token.eof_eventfd, 'w') | ||||||
|  | 
 | ||||||
|  |         # current write pointer | ||||||
|         self._ptr = 0 |         self._ptr = 0 | ||||||
| 
 | 
 | ||||||
|  |         # when `batch_size` > 1 store messages on `self._batch` and write them | ||||||
|  |         # all, once `len(self._batch) == `batch_size` | ||||||
|  |         self._batch: list[bytes] = [] | ||||||
|  | 
 | ||||||
|         self._cleanup = cleanup |         self._cleanup = cleanup | ||||||
|         self._send_lock = trio.StrictFIFOLock() |         self._send_lock = trio.StrictFIFOLock() | ||||||
| 
 | 
 | ||||||
|  |     @acm | ||||||
|  |     async def _maybe_lock(self) -> AsyncContextManager[None]: | ||||||
|  |         if self._send_lock.locked(): | ||||||
|  |             yield | ||||||
|  |             return | ||||||
|  | 
 | ||||||
|  |         async with self._send_lock: | ||||||
|  |             yield | ||||||
|  | 
 | ||||||
|     @property |     @property | ||||||
|     def name(self) -> str: |     def name(self) -> str: | ||||||
|         if not self._shm: |         if not self._shm: | ||||||
|  | @ -183,11 +240,19 @@ class RingBuffSender(trio.abc.SendStream): | ||||||
|     def wrap_fd(self) -> int: |     def wrap_fd(self) -> int: | ||||||
|         return self._wrap_event.fd |         return self._wrap_event.fd | ||||||
| 
 | 
 | ||||||
|  |     @property | ||||||
|  |     def pending_msgs(self) -> int: | ||||||
|  |         return len(self._batch) | ||||||
|  | 
 | ||||||
|  |     @property | ||||||
|  |     def must_flush(self) -> bool: | ||||||
|  |         return self.pending_msgs >= self.batch_size | ||||||
|  | 
 | ||||||
|     async def _wait_wrap(self): |     async def _wait_wrap(self): | ||||||
|         await self._wrap_event.read() |         await self._wrap_event.read() | ||||||
| 
 | 
 | ||||||
|     async def send_all(self, data: Buffer): |     async def send_all(self, data: Buffer): | ||||||
|         async with self._send_lock: |         async with self._maybe_lock(): | ||||||
|             # while data is larger than the remaining buf |             # while data is larger than the remaining buf | ||||||
|             target_ptr = self.ptr + len(data) |             target_ptr = self.ptr + len(data) | ||||||
|             while target_ptr > self.size: |             while target_ptr > self.size: | ||||||
|  | @ -211,6 +276,34 @@ class RingBuffSender(trio.abc.SendStream): | ||||||
|     async def wait_send_all_might_not_block(self): |     async def wait_send_all_might_not_block(self): | ||||||
|         raise NotImplementedError |         raise NotImplementedError | ||||||
| 
 | 
 | ||||||
|  |     async def flush( | ||||||
|  |         self, | ||||||
|  |         new_batch_size: int | None = None | ||||||
|  |     ) -> None: | ||||||
|  |         async with self._maybe_lock(): | ||||||
|  |             for msg in self._batch: | ||||||
|  |                 await self.send_all(msg) | ||||||
|  | 
 | ||||||
|  |             self._batch = [] | ||||||
|  |             if new_batch_size: | ||||||
|  |                 self.batch_size = new_batch_size | ||||||
|  | 
 | ||||||
|  |     async def send(self, value: bytes) -> None: | ||||||
|  |         async with self._maybe_lock(): | ||||||
|  |             msg: bytes = struct.pack("<I", len(value)) + value | ||||||
|  |             if self.batch_size == 1: | ||||||
|  |                 await self.send_all(msg) | ||||||
|  |                 return | ||||||
|  | 
 | ||||||
|  |             self._batch.append(msg) | ||||||
|  |             if self.must_flush: | ||||||
|  |                 await self.flush() | ||||||
|  | 
 | ||||||
|  |     async def send_eof(self) -> None: | ||||||
|  |         async with self._send_lock: | ||||||
|  |             await self.flush(new_batch_size=1) | ||||||
|  |             await self.send(b'') | ||||||
|  | 
 | ||||||
|     def open(self): |     def open(self): | ||||||
|         try: |         try: | ||||||
|             self._shm = SharedMemory( |             self._shm = SharedMemory( | ||||||
|  | @ -238,7 +331,6 @@ class RingBuffSender(trio.abc.SendStream): | ||||||
|             self._shm.close() |             self._shm.close() | ||||||
| 
 | 
 | ||||||
|     async def aclose(self): |     async def aclose(self): | ||||||
|         async with self._send_lock: |  | ||||||
|         self.close() |         self.close() | ||||||
| 
 | 
 | ||||||
|     async def __aenter__(self): |     async def __aenter__(self): | ||||||
|  | @ -246,7 +338,7 @@ class RingBuffSender(trio.abc.SendStream): | ||||||
|         return self |         return self | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class RingBuffReceiver(trio.abc.ReceiveStream): | class RingBufferReceiveChannel(trio.abc.ReceiveChannel[bytes]): | ||||||
|     ''' |     ''' | ||||||
|     Ring Buffer receiver side implementation |     Ring Buffer receiver side implementation | ||||||
| 
 | 
 | ||||||
|  | @ -312,21 +404,48 @@ class RingBuffReceiver(trio.abc.ReceiveStream): | ||||||
|         except trio.Cancelled: |         except trio.Cancelled: | ||||||
|             ... |             ... | ||||||
| 
 | 
 | ||||||
|     async def receive_some(self, max_bytes: int | None = None) -> bytes: |     def receive_nowait(self, max_bytes: int = _DEFAULT_RB_SIZE) -> bytes: | ||||||
|  |         ''' | ||||||
|  |         Try to receive any bytes we can without blocking or raise | ||||||
|  |         `trio.WouldBlock`. | ||||||
|  | 
 | ||||||
|  |         ''' | ||||||
|  |         if max_bytes < 1: | ||||||
|  |             raise ValueError("max_bytes must be >= 1") | ||||||
|  | 
 | ||||||
|  |         delta = self._write_ptr - self._ptr | ||||||
|  |         if delta == 0: | ||||||
|  |             raise trio.WouldBlock | ||||||
|  | 
 | ||||||
|  |         # dont overflow caller | ||||||
|  |         delta = min(delta, max_bytes) | ||||||
|  | 
 | ||||||
|  |         target_ptr = self._ptr + delta | ||||||
|  | 
 | ||||||
|  |         # fetch next segment and advance ptr | ||||||
|  |         segment = bytes(self._shm.buf[self._ptr:target_ptr]) | ||||||
|  |         self._ptr = target_ptr | ||||||
|  | 
 | ||||||
|  |         if self._ptr == self.size: | ||||||
|  |             # reached the end, signal wrap around | ||||||
|  |             self._ptr = 0 | ||||||
|  |             self._write_ptr = 0 | ||||||
|  |             self._wrap_event.write(1) | ||||||
|  | 
 | ||||||
|  |         return segment | ||||||
|  | 
 | ||||||
|  |     async def receive_some(self, max_bytes: int = _DEFAULT_RB_SIZE) -> bytes: | ||||||
|         ''' |         ''' | ||||||
|         Receive up to `max_bytes`, if no `max_bytes` is provided |         Receive up to `max_bytes`, if no `max_bytes` is provided | ||||||
|         a reasonable default is used. |         a reasonable default is used. | ||||||
| 
 | 
 | ||||||
|  |         Can return < max_bytes. | ||||||
|  | 
 | ||||||
|         ''' |         ''' | ||||||
|         if max_bytes is None: |         try: | ||||||
|             max_bytes: int = _DEFAULT_RB_SIZE |             return self.receive_nowait(max_bytes=max_bytes) | ||||||
| 
 | 
 | ||||||
|         if max_bytes < 1: |         except trio.WouldBlock: | ||||||
|             raise ValueError("max_bytes must be >= 1") |  | ||||||
| 
 |  | ||||||
|         # delta is remaining bytes we havent read |  | ||||||
|         delta = self._write_ptr - self._ptr |  | ||||||
|         if delta == 0: |  | ||||||
|             # we have read all we can, see if new data is available |             # we have read all we can, see if new data is available | ||||||
|             if self._end_ptr < 0: |             if self._end_ptr < 0: | ||||||
|                 # if we havent been signaled about EOF yet |                 # if we havent been signaled about EOF yet | ||||||
|  | @ -353,22 +472,39 @@ class RingBuffReceiver(trio.abc.ReceiveStream): | ||||||
|                 # no more bytes to read and self._end_ptr set, EOF reached |                 # no more bytes to read and self._end_ptr set, EOF reached | ||||||
|                 return b'' |                 return b'' | ||||||
| 
 | 
 | ||||||
|         # dont overflow caller |         return await self.receive_some(max_bytes=max_bytes) | ||||||
|         delta = min(delta, max_bytes) |  | ||||||
| 
 | 
 | ||||||
|         target_ptr = self._ptr + delta |     async def receive_exactly(self, num_bytes: int) -> bytes: | ||||||
|  |         ''' | ||||||
|  |         Fetch bytes until we read exactly `num_bytes` or EOF. | ||||||
| 
 | 
 | ||||||
|         # fetch next segment and advance ptr |         ''' | ||||||
|         segment = bytes(self._shm.buf[self._ptr:target_ptr]) |         payload = b'' | ||||||
|         self._ptr = target_ptr |         while len(payload) < num_bytes: | ||||||
|  |             remaining = num_bytes - len(payload) | ||||||
| 
 | 
 | ||||||
|         if self._ptr == self.size: |             new_bytes = await self.receive_some( | ||||||
|             # reached the end, signal wrap around |                 max_bytes=remaining | ||||||
|             self._ptr = 0 |             ) | ||||||
|             self._write_ptr = 0 |  | ||||||
|             self._wrap_event.write(1) |  | ||||||
| 
 | 
 | ||||||
|         return segment |             if new_bytes == b'': | ||||||
|  |                 raise trio.EndOfChannel | ||||||
|  | 
 | ||||||
|  |             payload += new_bytes | ||||||
|  | 
 | ||||||
|  |         return payload | ||||||
|  | 
 | ||||||
|  |     async def receive(self) -> bytes: | ||||||
|  |         ''' | ||||||
|  |         Receive a complete payload | ||||||
|  | 
 | ||||||
|  |         ''' | ||||||
|  |         header: bytes = await self.receive_exactly(4) | ||||||
|  |         size: int | ||||||
|  |         size, = struct.unpack("<I", header) | ||||||
|  |         if size == 0: | ||||||
|  |             raise trio.EndOfChannel | ||||||
|  |         return await self.receive_exactly(size) | ||||||
| 
 | 
 | ||||||
|     def open(self): |     def open(self): | ||||||
|         try: |         try: | ||||||
|  | @ -402,18 +538,20 @@ class RingBuffReceiver(trio.abc.ReceiveStream): | ||||||
| 
 | 
 | ||||||
| @acm | @acm | ||||||
| async def attach_to_ringbuf_receiver( | async def attach_to_ringbuf_receiver( | ||||||
|  | 
 | ||||||
|     token: RBToken, |     token: RBToken, | ||||||
|     cleanup: bool = True |     cleanup: bool = True | ||||||
| ) -> AsyncContextManager[RingBuffReceiver]: | 
 | ||||||
|  | ) -> AsyncContextManager[RingBufferReceiveChannel]: | ||||||
|     ''' |     ''' | ||||||
|     Attach a RingBuffReceiver from a previously opened |     Attach a RingBufferReceiveChannel from a previously opened | ||||||
|     RBToken. |     RBToken. | ||||||
| 
 | 
 | ||||||
|     Launches `receiver._eof_monitor_task` in a `trio.Nursery`. |     Launches `receiver._eof_monitor_task` in a `trio.Nursery`. | ||||||
|     ''' |     ''' | ||||||
|     async with ( |     async with ( | ||||||
|         trio.open_nursery() as n, |         trio.open_nursery() as n, | ||||||
|         RingBuffReceiver( |         RingBufferReceiveChannel( | ||||||
|             token, |             token, | ||||||
|             cleanup=cleanup |             cleanup=cleanup | ||||||
|         ) as receiver |         ) as receiver | ||||||
|  | @ -424,232 +562,33 @@ async def attach_to_ringbuf_receiver( | ||||||
| 
 | 
 | ||||||
| @acm | @acm | ||||||
| async def attach_to_ringbuf_sender( | async def attach_to_ringbuf_sender( | ||||||
|  | 
 | ||||||
|     token: RBToken, |     token: RBToken, | ||||||
|     cleanup: bool = True |     cleanup: bool = True | ||||||
| ) -> AsyncContextManager[RingBuffSender]: | 
 | ||||||
|  | ) -> AsyncContextManager[RingBufferSendChannel]: | ||||||
|     ''' |     ''' | ||||||
|     Attach a RingBuffSender from a previously opened |     Attach a RingBufferSendChannel from a previously opened | ||||||
|     RBToken. |     RBToken. | ||||||
| 
 | 
 | ||||||
|     ''' |     ''' | ||||||
|     async with RingBuffSender( |     async with RingBufferSendChannel( | ||||||
|         token, |         token, | ||||||
|         cleanup=cleanup |         cleanup=cleanup | ||||||
|     ) as sender: |     ) as sender: | ||||||
|         yield sender |         yield sender | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @cm | class RingBufferChannel(trio.abc.Channel[bytes]): | ||||||
| def open_ringbuf_pair( |  | ||||||
|     name: str, |  | ||||||
|     buf_size: int = _DEFAULT_RB_SIZE |  | ||||||
| ) -> ContextManager[tuple(RBToken, RBToken)]: |  | ||||||
|     ''' |     ''' | ||||||
|     Handle resources for a ringbuf pair to be used for |     Combine `RingBufferSendChannel` and `RingBufferReceiveChannel` | ||||||
|     bidirectional messaging. |  | ||||||
| 
 |  | ||||||
|     ''' |  | ||||||
|     with ( |  | ||||||
|         open_ringbuf( |  | ||||||
|             name + '.pair0', |  | ||||||
|             buf_size=buf_size |  | ||||||
|         ) as token_0, |  | ||||||
| 
 |  | ||||||
|         open_ringbuf( |  | ||||||
|             name + '.pair1', |  | ||||||
|             buf_size=buf_size |  | ||||||
|         ) as token_1 |  | ||||||
|     ): |  | ||||||
|         yield token_0, token_1 |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| @acm |  | ||||||
| async def attach_to_ringbuf_stream( |  | ||||||
|     token_in: RBToken, |  | ||||||
|     token_out: RBToken, |  | ||||||
|     cleanup_in: bool = True, |  | ||||||
|     cleanup_out: bool = True |  | ||||||
| ) -> AsyncContextManager[trio.StapledStream]: |  | ||||||
|     ''' |  | ||||||
|     Attach a trio.StapledStream from a previously opened |  | ||||||
|     ringbuf pair. |  | ||||||
| 
 |  | ||||||
|     ''' |  | ||||||
|     async with ( |  | ||||||
|         attach_to_ringbuf_receiver( |  | ||||||
|             token_in, |  | ||||||
|             cleanup=cleanup_in |  | ||||||
|         ) as receiver, |  | ||||||
|         attach_to_ringbuf_sender( |  | ||||||
|             token_out, |  | ||||||
|             cleanup=cleanup_out |  | ||||||
|         ) as sender, |  | ||||||
|     ): |  | ||||||
|         yield trio.StapledStream(sender, receiver) |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| class RingBuffBytesSender(trio.abc.SendChannel[bytes]): |  | ||||||
|     ''' |  | ||||||
|     In order to guarantee full messages are received, all bytes |  | ||||||
|     sent by `RingBuffBytesSender` are preceded with a 4 byte header |  | ||||||
|     which decodes into a uint32 indicating the actual size of the |  | ||||||
|     next payload. |  | ||||||
| 
 |  | ||||||
|     Optional batch mode: |  | ||||||
| 
 |  | ||||||
|     If `batch_size` > 1 messages wont get sent immediately but will be |  | ||||||
|     stored until `batch_size` messages are pending, then it will send |  | ||||||
|     them all at once. |  | ||||||
| 
 |  | ||||||
|     `batch_size` can be changed dynamically but always call, `flush()` |  | ||||||
|     right before. |  | ||||||
| 
 |  | ||||||
|     ''' |  | ||||||
|     def __init__( |  | ||||||
|         self, |  | ||||||
|         sender: RingBuffSender, |  | ||||||
|         batch_size: int = 1 |  | ||||||
|     ): |  | ||||||
|         self._sender = sender |  | ||||||
|         self.batch_size = batch_size |  | ||||||
|         self._batch_msg_len = 0 |  | ||||||
|         self._batch: bytes = b'' |  | ||||||
|         self._send_lock = trio.StrictFIFOLock() |  | ||||||
| 
 |  | ||||||
|     @property |  | ||||||
|     def pending_msgs(self) -> int: |  | ||||||
|         return self._batch_msg_len |  | ||||||
| 
 |  | ||||||
|     @property |  | ||||||
|     def must_flush(self) -> bool: |  | ||||||
|         return self._batch_msg_len >= self.batch_size |  | ||||||
| 
 |  | ||||||
|     async def _flush( |  | ||||||
|         self, |  | ||||||
|         new_batch_size: int | None = None |  | ||||||
|     ) -> None: |  | ||||||
|         await self._sender.send_all(self._batch) |  | ||||||
|         self._batch = b'' |  | ||||||
|         self._batch_msg_len = 0 |  | ||||||
|         if new_batch_size: |  | ||||||
|             self.batch_size = new_batch_size |  | ||||||
| 
 |  | ||||||
|     async def flush( |  | ||||||
|         self, |  | ||||||
|         new_batch_size: int | None = None |  | ||||||
|     ) -> None: |  | ||||||
|         async with self._send_lock: |  | ||||||
|             await self._flush(new_batch_size=new_batch_size) |  | ||||||
| 
 |  | ||||||
|     async def send(self, value: bytes) -> None: |  | ||||||
|         async with self._send_lock: |  | ||||||
|             msg: bytes = struct.pack("<I", len(value)) + value |  | ||||||
|             if self.batch_size == 1: |  | ||||||
|                 await self._sender.send_all(msg) |  | ||||||
|                 return |  | ||||||
| 
 |  | ||||||
|             self._batch += msg |  | ||||||
|             self._batch_msg_len += 1 |  | ||||||
|             if self.must_flush: |  | ||||||
|                 await self._flush() |  | ||||||
| 
 |  | ||||||
|     async def send_eof(self) -> None: |  | ||||||
|         await self.flush(new_batch_size=1) |  | ||||||
|         await self.send(b'') |  | ||||||
| 
 |  | ||||||
|     async def aclose(self) -> None: |  | ||||||
|         async with self._send_lock: |  | ||||||
|             await self._sender.aclose() |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| class RingBuffBytesReceiver(trio.abc.ReceiveChannel[bytes]): |  | ||||||
|     ''' |  | ||||||
|     See `RingBuffBytesSender` docstring. |  | ||||||
| 
 |  | ||||||
|     A `tricycle.BufferedReceiveStream` is used for the |  | ||||||
|     `receive_exactly` API. |  | ||||||
|     ''' |  | ||||||
|     def __init__( |  | ||||||
|         self, |  | ||||||
|         receiver: RingBuffReceiver |  | ||||||
|     ): |  | ||||||
|         self._receiver = receiver |  | ||||||
| 
 |  | ||||||
|     async def _receive_exactly(self, num_bytes: int) -> bytes: |  | ||||||
|         ''' |  | ||||||
|         Fetch bytes from receiver until we read exactly `num_bytes` |  | ||||||
|         or end of stream is signaled. |  | ||||||
| 
 |  | ||||||
|         ''' |  | ||||||
|         payload = b'' |  | ||||||
|         while len(payload) < num_bytes: |  | ||||||
|             remaining = num_bytes - len(payload) |  | ||||||
| 
 |  | ||||||
|             new_bytes = await self._receiver.receive_some( |  | ||||||
|                 max_bytes=remaining |  | ||||||
|             ) |  | ||||||
| 
 |  | ||||||
|             if new_bytes == b'': |  | ||||||
|                 raise trio.EndOfChannel |  | ||||||
| 
 |  | ||||||
|             payload += new_bytes |  | ||||||
| 
 |  | ||||||
|         return payload |  | ||||||
| 
 |  | ||||||
|     async def receive(self) -> bytes: |  | ||||||
|         header: bytes = await self._receive_exactly(4) |  | ||||||
|         size: int |  | ||||||
|         size, = struct.unpack("<I", header) |  | ||||||
|         if size == 0: |  | ||||||
|             raise trio.EndOfChannel |  | ||||||
|         return await self._receive_exactly(size) |  | ||||||
| 
 |  | ||||||
|     async def aclose(self) -> None: |  | ||||||
|         await self._receiver.aclose() |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| @acm |  | ||||||
| async def attach_to_ringbuf_rchannel( |  | ||||||
|     token: RBToken, |  | ||||||
|     cleanup: bool = True |  | ||||||
| ) -> AsyncContextManager[RingBuffBytesReceiver]: |  | ||||||
|     ''' |  | ||||||
|     Attach a RingBuffBytesReceiver from a previously opened |  | ||||||
|     RBToken. |  | ||||||
|     ''' |  | ||||||
|     async with attach_to_ringbuf_receiver( |  | ||||||
|         token, cleanup=cleanup |  | ||||||
|     ) as receiver: |  | ||||||
|         yield RingBuffBytesReceiver(receiver) |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| @acm |  | ||||||
| async def attach_to_ringbuf_schannel( |  | ||||||
|     token: RBToken, |  | ||||||
|     cleanup: bool = True, |  | ||||||
|     batch_size: int = 1, |  | ||||||
| ) -> AsyncContextManager[RingBuffBytesSender]: |  | ||||||
|     ''' |  | ||||||
|     Attach a RingBuffBytesSender from a previously opened |  | ||||||
|     RBToken. |  | ||||||
|     ''' |  | ||||||
|     async with attach_to_ringbuf_sender( |  | ||||||
|         token, cleanup=cleanup |  | ||||||
|     ) as sender: |  | ||||||
|         yield RingBuffBytesSender(sender, batch_size=batch_size) |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| class RingBuffChannel(trio.abc.Channel[bytes]): |  | ||||||
|     ''' |  | ||||||
|     Combine `RingBuffBytesSender` and `RingBuffBytesReceiver` |  | ||||||
|     in order to expose the bidirectional `trio.abc.Channel` API. |     in order to expose the bidirectional `trio.abc.Channel` API. | ||||||
| 
 | 
 | ||||||
|     ''' |     ''' | ||||||
|     def __init__( |     def __init__( | ||||||
|         self, |         self, | ||||||
|         sender: RingBuffBytesSender, |         sender: RingBufferSendChannel, | ||||||
|         receiver: RingBuffBytesReceiver |         receiver: RingBufferReceiveChannel | ||||||
|     ): |     ): | ||||||
|         self._sender = sender |         self._sender = sender | ||||||
|         self._receiver = receiver |         self._receiver = receiver | ||||||
|  | @ -666,6 +605,12 @@ class RingBuffChannel(trio.abc.Channel[bytes]): | ||||||
|     def pending_msgs(self) -> int: |     def pending_msgs(self) -> int: | ||||||
|         return self._sender.pending_msgs |         return self._sender.pending_msgs | ||||||
| 
 | 
 | ||||||
|  |     async def send_all(self, value: bytes) -> None: | ||||||
|  |         await self._sender.send_all(value) | ||||||
|  | 
 | ||||||
|  |     async def wait_send_all_might_not_block(self): | ||||||
|  |         await self._sender.wait_send_all_might_not_block() | ||||||
|  | 
 | ||||||
|     async def flush( |     async def flush( | ||||||
|         self, |         self, | ||||||
|         new_batch_size: int | None = None |         new_batch_size: int | None = None | ||||||
|  | @ -678,6 +623,15 @@ class RingBuffChannel(trio.abc.Channel[bytes]): | ||||||
|     async def send_eof(self) -> None: |     async def send_eof(self) -> None: | ||||||
|         await self._sender.send_eof() |         await self._sender.send_eof() | ||||||
| 
 | 
 | ||||||
|  |     def receive_nowait(self, max_bytes: int = _DEFAULT_RB_SIZE) -> bytes: | ||||||
|  |         return self._receiver.receive_nowait(max_bytes=max_bytes) | ||||||
|  | 
 | ||||||
|  |     async def receive_some(self, max_bytes: int = _DEFAULT_RB_SIZE) -> bytes: | ||||||
|  |         return await self._receiver.receive_some(max_bytes=max_bytes) | ||||||
|  | 
 | ||||||
|  |     async def receive_exactly(self, num_bytes: int) -> bytes: | ||||||
|  |         return await self._receiver.receive_exactly(num_bytes) | ||||||
|  | 
 | ||||||
|     async def receive(self) -> bytes: |     async def receive(self) -> bytes: | ||||||
|         return await self._receiver.receive() |         return await self._receiver.receive() | ||||||
| 
 | 
 | ||||||
|  | @ -691,23 +645,20 @@ async def attach_to_ringbuf_channel( | ||||||
|     token_in: RBToken, |     token_in: RBToken, | ||||||
|     token_out: RBToken, |     token_out: RBToken, | ||||||
|     cleanup_in: bool = True, |     cleanup_in: bool = True, | ||||||
|     cleanup_out: bool = True, |     cleanup_out: bool = True | ||||||
|     batch_size: int = 1 | ) -> AsyncContextManager[trio.StapledStream]: | ||||||
| ) -> AsyncContextManager[RingBuffChannel]: |  | ||||||
|     ''' |     ''' | ||||||
|     Attach to an already opened ringbuf pair and return |     Attach to two previously opened `RBToken`s and return a `RingBufferChannel` | ||||||
|     a `RingBuffChannel`. |  | ||||||
| 
 | 
 | ||||||
|     ''' |     ''' | ||||||
|     async with ( |     async with ( | ||||||
|         attach_to_ringbuf_rchannel( |         attach_to_ringbuf_receiver( | ||||||
|             token_in, |             token_in, | ||||||
|             cleanup=cleanup_in |             cleanup=cleanup_in | ||||||
|         ) as receiver, |         ) as receiver, | ||||||
|         attach_to_ringbuf_schannel( |         attach_to_ringbuf_sender( | ||||||
|             token_out, |             token_out, | ||||||
|             cleanup=cleanup_out, |             cleanup=cleanup_out | ||||||
|             batch_size=batch_size |  | ||||||
|         ) as sender, |         ) as sender, | ||||||
|     ): |     ): | ||||||
|         yield RingBuffChannel(sender, receiver) |         yield RingBufferChannel(sender, receiver) | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue