Rename RingBuff -> RingBuffer
Combine RingBuffer stream and channel apis Implement RingBufferReceiveChannel.receive_nowait Make msg generator calculate hash
							parent
							
								
									70d72fd173
								
							
						
					
					
						commit
						eb20e5ea8d
					
				|  | @ -8,7 +8,6 @@ from tractor.ipc import ( | |||
|     open_ringbuf, | ||||
|     attach_to_ringbuf_receiver, | ||||
|     attach_to_ringbuf_sender, | ||||
|     attach_to_ringbuf_stream, | ||||
|     attach_to_ringbuf_channel, | ||||
|     RBToken, | ||||
| ) | ||||
|  | @ -21,7 +20,6 @@ from tractor._testing.samples import ( | |||
| @tractor.context | ||||
| async def child_read_shm( | ||||
|     ctx: tractor.Context, | ||||
|     msg_amount: int, | ||||
|     token: RBToken, | ||||
| ) -> str: | ||||
|     ''' | ||||
|  | @ -37,11 +35,13 @@ async def child_read_shm( | |||
|     ''' | ||||
|     await ctx.started() | ||||
|     print('reader started') | ||||
|     msg_amount = 0 | ||||
|     recvd_bytes = 0 | ||||
|     recvd_hash = hashlib.sha256() | ||||
|     start_ts = time.time() | ||||
|     async with attach_to_ringbuf_receiver(token) as receiver: | ||||
|         async for msg in receiver: | ||||
|             msg_amount += 1 | ||||
|             recvd_hash.update(msg) | ||||
|             recvd_bytes += len(msg) | ||||
| 
 | ||||
|  | @ -75,19 +75,16 @@ async def child_write_shm( | |||
|     Attach to ringbuf and send all generated messages. | ||||
| 
 | ||||
|     ''' | ||||
|     msgs, _total_bytes = generate_sample_messages( | ||||
|     sent_hash, msgs, _total_bytes = generate_sample_messages( | ||||
|         msg_amount, | ||||
|         rand_min=rand_min, | ||||
|         rand_max=rand_max, | ||||
|     ) | ||||
|     print('writer hashing payload...') | ||||
|     sent_hash = hashlib.sha256(b''.join(msgs)).hexdigest() | ||||
|     print('writer done hashing.') | ||||
|     await ctx.started(sent_hash) | ||||
|     print('writer started') | ||||
|     async with attach_to_ringbuf_sender(token, cleanup=False) as sender: | ||||
|         for msg in msgs: | ||||
|             await sender.send_all(msg) | ||||
|             await sender.send(msg) | ||||
| 
 | ||||
|     print('writer exit') | ||||
| 
 | ||||
|  | @ -155,7 +152,6 @@ def test_ringbuf( | |||
|                     recv_p.open_context( | ||||
|                         child_read_shm, | ||||
|                         token=token, | ||||
|                         msg_amount=msg_amount | ||||
|                     ) as (rctx, _sent), | ||||
|                 ): | ||||
|                     recvd_hash = await rctx.result() | ||||
|  | @ -291,75 +287,6 @@ def test_receiver_max_bytes(): | |||
|     assert msg == b''.join(msgs) | ||||
| 
 | ||||
| 
 | ||||
| def test_stapled_ringbuf(): | ||||
|     ''' | ||||
|     Open two ringbufs and give tokens to tasks (swap them such that in/out tokens | ||||
|     are inversed on each task) which will open the streams and use trio.StapledStream | ||||
|     to have a single bidirectional stream. | ||||
| 
 | ||||
|     Then take turns to send and receive messages. | ||||
| 
 | ||||
|     ''' | ||||
|     msg = generate_single_byte_msgs(100) | ||||
|     pair_0_msgs = [] | ||||
|     pair_1_msgs = [] | ||||
| 
 | ||||
|     pair_0_done = trio.Event() | ||||
|     pair_1_done = trio.Event() | ||||
| 
 | ||||
|     async def pair_0(token_in: RBToken, token_out: RBToken): | ||||
|         async with attach_to_ringbuf_stream( | ||||
|             token_in, | ||||
|             token_out, | ||||
|             cleanup_in=False, | ||||
|             cleanup_out=False | ||||
|         ) as stream: | ||||
|             # first turn to send | ||||
|             await stream.send_all(msg) | ||||
| 
 | ||||
|             # second turn to receive | ||||
|             while len(pair_0_msgs) != len(msg): | ||||
|                 _msg = await stream.receive_some(max_bytes=1) | ||||
|                 pair_0_msgs.append(_msg) | ||||
| 
 | ||||
|             pair_0_done.set() | ||||
|             await pair_1_done.wait() | ||||
| 
 | ||||
| 
 | ||||
|     async def pair_1(token_in: RBToken, token_out: RBToken): | ||||
|         async with attach_to_ringbuf_stream( | ||||
|             token_in, | ||||
|             token_out, | ||||
|             cleanup_in=False, | ||||
|             cleanup_out=False | ||||
|         ) as stream: | ||||
|             # first turn to receive | ||||
|             while len(pair_1_msgs) != len(msg): | ||||
|                 _msg = await stream.receive_some(max_bytes=1) | ||||
|                 pair_1_msgs.append(_msg) | ||||
| 
 | ||||
|             # second turn to send | ||||
|             await stream.send_all(msg) | ||||
| 
 | ||||
|             pair_1_done.set() | ||||
|             await pair_0_done.wait() | ||||
| 
 | ||||
| 
 | ||||
|     async def main(): | ||||
|         with tractor.ipc.open_ringbuf_pair( | ||||
|             'test_stapled_ringbuf' | ||||
|         ) as (token_0, token_1): | ||||
|             async with trio.open_nursery() as n: | ||||
|                 n.start_soon(pair_0, token_0, token_1) | ||||
|                 n.start_soon(pair_1, token_1, token_0) | ||||
| 
 | ||||
| 
 | ||||
|     trio.run(main) | ||||
| 
 | ||||
|     assert msg == b''.join(pair_0_msgs) | ||||
|     assert msg == b''.join(pair_1_msgs) | ||||
| 
 | ||||
| 
 | ||||
| @tractor.context | ||||
| async def child_channel_sender( | ||||
|     ctx: tractor.Context, | ||||
|  | @ -369,7 +296,7 @@ async def child_channel_sender( | |||
|     token_out: RBToken | ||||
| ): | ||||
|     import random | ||||
|     msgs, _total_bytes = generate_sample_messages( | ||||
|     _hash, msgs, _total_bytes = generate_sample_messages( | ||||
|         random.randint(msg_amount_min, msg_amount_max), | ||||
|         rand_min=256, | ||||
|         rand_max=1024, | ||||
|  | @ -379,7 +306,6 @@ async def child_channel_sender( | |||
|         token_out | ||||
|     ) as chan: | ||||
|         await ctx.started(msgs) | ||||
| 
 | ||||
|         for msg in msgs: | ||||
|             await chan.send(msg) | ||||
| 
 | ||||
|  | @ -392,16 +318,16 @@ def test_channel(): | |||
|     async def main(): | ||||
|         with tractor.ipc.open_ringbuf_pair( | ||||
|             'test_ringbuf_transport' | ||||
|         ) as (token_0, token_1): | ||||
|         ) as (send_token, recv_token): | ||||
|             async with ( | ||||
|                 attach_to_ringbuf_channel(token_0, token_1) as chan, | ||||
|                 attach_to_ringbuf_channel(send_token, recv_token) as chan, | ||||
|                 tractor.open_nursery() as an | ||||
|             ): | ||||
|                 recv_p = await an.start_actor( | ||||
|                     'test_ringbuf_transport_sender', | ||||
|                     enable_modules=[__name__], | ||||
|                     proc_kwargs={ | ||||
|                         'pass_fds': token_0.fds + token_1.fds | ||||
|                         'pass_fds': send_token.fds + recv_token.fds | ||||
|                     } | ||||
|                 ) | ||||
|                 async with ( | ||||
|  | @ -409,8 +335,8 @@ def test_channel(): | |||
|                         child_channel_sender, | ||||
|                         msg_amount_min=msg_amount_min, | ||||
|                         msg_amount_max=msg_amount_max, | ||||
|                         token_in=token_1, | ||||
|                         token_out=token_0 | ||||
|                         token_in=recv_token, | ||||
|                         token_out=send_token | ||||
|                     ) as (ctx, msgs), | ||||
|                 ): | ||||
|                     recv_msgs = [] | ||||
|  |  | |||
|  | @ -1,5 +1,6 @@ | |||
| import os | ||||
| import random | ||||
| import hashlib | ||||
| 
 | ||||
| 
 | ||||
| def generate_single_byte_msgs(amount: int) -> bytes: | ||||
|  | @ -23,7 +24,7 @@ def generate_sample_messages( | |||
|     rand_min: int = 0, | ||||
|     rand_max: int = 0, | ||||
|     silent: bool = False, | ||||
| ) -> tuple[list[bytes], int]: | ||||
| ) -> tuple[str, list[bytes], int]: | ||||
|     ''' | ||||
|     Generate bytes msgs for tests. | ||||
| 
 | ||||
|  | @ -55,6 +56,7 @@ def generate_sample_messages( | |||
|         else: | ||||
|             log_interval = 1000 | ||||
| 
 | ||||
|     payload_hash = hashlib.sha256() | ||||
|     for i in range(amount): | ||||
|         msg = f'[{i:08}]'.encode('utf-8') | ||||
| 
 | ||||
|  | @ -64,6 +66,7 @@ def generate_sample_messages( | |||
| 
 | ||||
|         size += len(msg) | ||||
| 
 | ||||
|         payload_hash.update(msg) | ||||
|         msgs.append(msg) | ||||
| 
 | ||||
|         if ( | ||||
|  | @ -78,4 +81,4 @@ def generate_sample_messages( | |||
|     if not silent: | ||||
|         print(f'done, {size:,} bytes in total') | ||||
| 
 | ||||
|     return msgs, size | ||||
|     return payload_hash.hexdigest(), msgs, size | ||||
|  |  | |||
|  | @ -31,17 +31,16 @@ from ._chan import ( | |||
| if platform.system() == 'Linux': | ||||
|     from ._ringbuf import ( | ||||
|         RBToken as RBToken, | ||||
| 
 | ||||
|         open_ringbuf as open_ringbuf, | ||||
|         RingBuffSender as RingBuffSender, | ||||
|         RingBuffReceiver as RingBuffReceiver, | ||||
|         open_ringbuf_pair as open_ringbuf_pair, | ||||
|         attach_to_ringbuf_receiver as attach_to_ringbuf_receiver, | ||||
| 
 | ||||
|         RingBufferSendChannel as RingBufferSendChannel, | ||||
|         attach_to_ringbuf_sender as attach_to_ringbuf_sender, | ||||
|         attach_to_ringbuf_stream as attach_to_ringbuf_stream, | ||||
|         RingBuffBytesSender as RingBuffBytesSender, | ||||
|         RingBuffBytesReceiver as RingBuffBytesReceiver, | ||||
|         RingBuffChannel as RingBuffChannel, | ||||
|         attach_to_ringbuf_schannel as attach_to_ringbuf_schannel, | ||||
|         attach_to_ringbuf_rchannel as attach_to_ringbuf_rchannel, | ||||
| 
 | ||||
|         RingBufferReceiveChannel as RingBufferReceiveChannel, | ||||
|         attach_to_ringbuf_receiver as attach_to_ringbuf_receiver, | ||||
| 
 | ||||
|         RingBufferChannel as RingBufferChannel, | ||||
|         attach_to_ringbuf_channel as attach_to_ringbuf_channel, | ||||
|     ) | ||||
|  |  | |||
|  | @ -126,6 +126,30 @@ def open_ringbuf( | |||
|         shm.unlink() | ||||
| 
 | ||||
| 
 | ||||
| @cm | ||||
| def open_ringbuf_pair( | ||||
|     name: str, | ||||
|     buf_size: int = _DEFAULT_RB_SIZE | ||||
| ) -> ContextManager[tuple(RBToken, RBToken)]: | ||||
|     ''' | ||||
|     Handle resources for a ringbuf pair to be used for | ||||
|     bidirectional messaging. | ||||
| 
 | ||||
|     ''' | ||||
|     with ( | ||||
|         open_ringbuf( | ||||
|             name + '.send', | ||||
|             buf_size=buf_size | ||||
|         ) as send_token, | ||||
| 
 | ||||
|         open_ringbuf( | ||||
|             name + '.recv', | ||||
|             buf_size=buf_size | ||||
|         ) as recv_token | ||||
|     ): | ||||
|         yield send_token, recv_token | ||||
| 
 | ||||
| 
 | ||||
| Buffer = bytes | bytearray | memoryview | ||||
| 
 | ||||
| 
 | ||||
|  | @ -135,32 +159,65 @@ IPC Reliable Ring Buffer | |||
| `eventfd(2)` is used for wrap around sync, to signal writes to | ||||
| the reader and end of stream. | ||||
| 
 | ||||
| In order to guarantee full messages are received, all bytes | ||||
| sent by `RingBufferSendChannel` are preceded with a 4 byte header | ||||
| which decodes into a uint32 indicating the actual size of the | ||||
| next full payload. | ||||
| 
 | ||||
| ''' | ||||
| 
 | ||||
| 
 | ||||
| class RingBuffSender(trio.abc.SendStream): | ||||
| class RingBufferSendChannel(trio.abc.SendChannel[bytes]): | ||||
|     ''' | ||||
|     Ring Buffer sender side implementation | ||||
| 
 | ||||
|     Do not use directly! manage with `attach_to_ringbuf_sender` | ||||
|     after having opened a ringbuf context with `open_ringbuf`. | ||||
| 
 | ||||
|     Optional batch mode: | ||||
| 
 | ||||
|     If `batch_size` > 1 messages wont get sent immediately but will be | ||||
|     stored until `batch_size` messages are pending, then it will send | ||||
|     them all at once. | ||||
| 
 | ||||
|     `batch_size` can be changed dynamically but always call, `flush()` | ||||
|     right before. | ||||
| 
 | ||||
|     ''' | ||||
|     def __init__( | ||||
|         self, | ||||
|         token: RBToken, | ||||
|         batch_size: int = 1, | ||||
|         cleanup: bool = False | ||||
|     ): | ||||
|         self._token = RBToken.from_msg(token) | ||||
|         self.batch_size = batch_size | ||||
| 
 | ||||
|         # ringbuf os resources | ||||
|         self._shm: SharedMemory | None = None | ||||
|         self._write_event = EventFD(self._token.write_eventfd, 'w') | ||||
|         self._wrap_event = EventFD(self._token.wrap_eventfd, 'r') | ||||
|         self._eof_event = EventFD(self._token.eof_eventfd, 'w') | ||||
| 
 | ||||
|         # current write pointer | ||||
|         self._ptr = 0 | ||||
| 
 | ||||
|         # when `batch_size` > 1 store messages on `self._batch` and write them | ||||
|         # all, once `len(self._batch) == `batch_size` | ||||
|         self._batch: list[bytes] = [] | ||||
| 
 | ||||
|         self._cleanup = cleanup | ||||
|         self._send_lock = trio.StrictFIFOLock() | ||||
| 
 | ||||
|     @acm | ||||
|     async def _maybe_lock(self) -> AsyncContextManager[None]: | ||||
|         if self._send_lock.locked(): | ||||
|             yield | ||||
|             return | ||||
| 
 | ||||
|         async with self._send_lock: | ||||
|             yield | ||||
| 
 | ||||
|     @property | ||||
|     def name(self) -> str: | ||||
|         if not self._shm: | ||||
|  | @ -183,11 +240,19 @@ class RingBuffSender(trio.abc.SendStream): | |||
|     def wrap_fd(self) -> int: | ||||
|         return self._wrap_event.fd | ||||
| 
 | ||||
|     @property | ||||
|     def pending_msgs(self) -> int: | ||||
|         return len(self._batch) | ||||
| 
 | ||||
|     @property | ||||
|     def must_flush(self) -> bool: | ||||
|         return self.pending_msgs >= self.batch_size | ||||
| 
 | ||||
|     async def _wait_wrap(self): | ||||
|         await self._wrap_event.read() | ||||
| 
 | ||||
|     async def send_all(self, data: Buffer): | ||||
|         async with self._send_lock: | ||||
|         async with self._maybe_lock(): | ||||
|             # while data is larger than the remaining buf | ||||
|             target_ptr = self.ptr + len(data) | ||||
|             while target_ptr > self.size: | ||||
|  | @ -211,6 +276,34 @@ class RingBuffSender(trio.abc.SendStream): | |||
|     async def wait_send_all_might_not_block(self): | ||||
|         raise NotImplementedError | ||||
| 
 | ||||
|     async def flush( | ||||
|         self, | ||||
|         new_batch_size: int | None = None | ||||
|     ) -> None: | ||||
|         async with self._maybe_lock(): | ||||
|             for msg in self._batch: | ||||
|                 await self.send_all(msg) | ||||
| 
 | ||||
|             self._batch = [] | ||||
|             if new_batch_size: | ||||
|                 self.batch_size = new_batch_size | ||||
| 
 | ||||
|     async def send(self, value: bytes) -> None: | ||||
|         async with self._maybe_lock(): | ||||
|             msg: bytes = struct.pack("<I", len(value)) + value | ||||
|             if self.batch_size == 1: | ||||
|                 await self.send_all(msg) | ||||
|                 return | ||||
| 
 | ||||
|             self._batch.append(msg) | ||||
|             if self.must_flush: | ||||
|                 await self.flush() | ||||
| 
 | ||||
|     async def send_eof(self) -> None: | ||||
|         async with self._send_lock: | ||||
|             await self.flush(new_batch_size=1) | ||||
|             await self.send(b'') | ||||
| 
 | ||||
|     def open(self): | ||||
|         try: | ||||
|             self._shm = SharedMemory( | ||||
|  | @ -238,7 +331,6 @@ class RingBuffSender(trio.abc.SendStream): | |||
|             self._shm.close() | ||||
| 
 | ||||
|     async def aclose(self): | ||||
|         async with self._send_lock: | ||||
|         self.close() | ||||
| 
 | ||||
|     async def __aenter__(self): | ||||
|  | @ -246,7 +338,7 @@ class RingBuffSender(trio.abc.SendStream): | |||
|         return self | ||||
| 
 | ||||
| 
 | ||||
| class RingBuffReceiver(trio.abc.ReceiveStream): | ||||
| class RingBufferReceiveChannel(trio.abc.ReceiveChannel[bytes]): | ||||
|     ''' | ||||
|     Ring Buffer receiver side implementation | ||||
| 
 | ||||
|  | @ -312,21 +404,48 @@ class RingBuffReceiver(trio.abc.ReceiveStream): | |||
|         except trio.Cancelled: | ||||
|             ... | ||||
| 
 | ||||
|     async def receive_some(self, max_bytes: int | None = None) -> bytes: | ||||
|     def receive_nowait(self, max_bytes: int = _DEFAULT_RB_SIZE) -> bytes: | ||||
|         ''' | ||||
|         Try to receive any bytes we can without blocking or raise | ||||
|         `trio.WouldBlock`. | ||||
| 
 | ||||
|         ''' | ||||
|         if max_bytes < 1: | ||||
|             raise ValueError("max_bytes must be >= 1") | ||||
| 
 | ||||
|         delta = self._write_ptr - self._ptr | ||||
|         if delta == 0: | ||||
|             raise trio.WouldBlock | ||||
| 
 | ||||
|         # dont overflow caller | ||||
|         delta = min(delta, max_bytes) | ||||
| 
 | ||||
|         target_ptr = self._ptr + delta | ||||
| 
 | ||||
|         # fetch next segment and advance ptr | ||||
|         segment = bytes(self._shm.buf[self._ptr:target_ptr]) | ||||
|         self._ptr = target_ptr | ||||
| 
 | ||||
|         if self._ptr == self.size: | ||||
|             # reached the end, signal wrap around | ||||
|             self._ptr = 0 | ||||
|             self._write_ptr = 0 | ||||
|             self._wrap_event.write(1) | ||||
| 
 | ||||
|         return segment | ||||
| 
 | ||||
|     async def receive_some(self, max_bytes: int = _DEFAULT_RB_SIZE) -> bytes: | ||||
|         ''' | ||||
|         Receive up to `max_bytes`, if no `max_bytes` is provided | ||||
|         a reasonable default is used. | ||||
| 
 | ||||
|         Can return < max_bytes. | ||||
| 
 | ||||
|         ''' | ||||
|         if max_bytes is None: | ||||
|             max_bytes: int = _DEFAULT_RB_SIZE | ||||
|         try: | ||||
|             return self.receive_nowait(max_bytes=max_bytes) | ||||
| 
 | ||||
|         if max_bytes < 1: | ||||
|             raise ValueError("max_bytes must be >= 1") | ||||
| 
 | ||||
|         # delta is remaining bytes we havent read | ||||
|         delta = self._write_ptr - self._ptr | ||||
|         if delta == 0: | ||||
|         except trio.WouldBlock: | ||||
|             # we have read all we can, see if new data is available | ||||
|             if self._end_ptr < 0: | ||||
|                 # if we havent been signaled about EOF yet | ||||
|  | @ -353,22 +472,39 @@ class RingBuffReceiver(trio.abc.ReceiveStream): | |||
|                 # no more bytes to read and self._end_ptr set, EOF reached | ||||
|                 return b'' | ||||
| 
 | ||||
|         # dont overflow caller | ||||
|         delta = min(delta, max_bytes) | ||||
|         return await self.receive_some(max_bytes=max_bytes) | ||||
| 
 | ||||
|         target_ptr = self._ptr + delta | ||||
|     async def receive_exactly(self, num_bytes: int) -> bytes: | ||||
|         ''' | ||||
|         Fetch bytes until we read exactly `num_bytes` or EOF. | ||||
| 
 | ||||
|         # fetch next segment and advance ptr | ||||
|         segment = bytes(self._shm.buf[self._ptr:target_ptr]) | ||||
|         self._ptr = target_ptr | ||||
|         ''' | ||||
|         payload = b'' | ||||
|         while len(payload) < num_bytes: | ||||
|             remaining = num_bytes - len(payload) | ||||
| 
 | ||||
|         if self._ptr == self.size: | ||||
|             # reached the end, signal wrap around | ||||
|             self._ptr = 0 | ||||
|             self._write_ptr = 0 | ||||
|             self._wrap_event.write(1) | ||||
|             new_bytes = await self.receive_some( | ||||
|                 max_bytes=remaining | ||||
|             ) | ||||
| 
 | ||||
|         return segment | ||||
|             if new_bytes == b'': | ||||
|                 raise trio.EndOfChannel | ||||
| 
 | ||||
|             payload += new_bytes | ||||
| 
 | ||||
|         return payload | ||||
| 
 | ||||
|     async def receive(self) -> bytes: | ||||
|         ''' | ||||
|         Receive a complete payload | ||||
| 
 | ||||
|         ''' | ||||
|         header: bytes = await self.receive_exactly(4) | ||||
|         size: int | ||||
|         size, = struct.unpack("<I", header) | ||||
|         if size == 0: | ||||
|             raise trio.EndOfChannel | ||||
|         return await self.receive_exactly(size) | ||||
| 
 | ||||
|     def open(self): | ||||
|         try: | ||||
|  | @ -402,18 +538,20 @@ class RingBuffReceiver(trio.abc.ReceiveStream): | |||
| 
 | ||||
| @acm | ||||
| async def attach_to_ringbuf_receiver( | ||||
| 
 | ||||
|     token: RBToken, | ||||
|     cleanup: bool = True | ||||
| ) -> AsyncContextManager[RingBuffReceiver]: | ||||
| 
 | ||||
| ) -> AsyncContextManager[RingBufferReceiveChannel]: | ||||
|     ''' | ||||
|     Attach a RingBuffReceiver from a previously opened | ||||
|     Attach a RingBufferReceiveChannel from a previously opened | ||||
|     RBToken. | ||||
| 
 | ||||
|     Launches `receiver._eof_monitor_task` in a `trio.Nursery`. | ||||
|     ''' | ||||
|     async with ( | ||||
|         trio.open_nursery() as n, | ||||
|         RingBuffReceiver( | ||||
|         RingBufferReceiveChannel( | ||||
|             token, | ||||
|             cleanup=cleanup | ||||
|         ) as receiver | ||||
|  | @ -424,232 +562,33 @@ async def attach_to_ringbuf_receiver( | |||
| 
 | ||||
| @acm | ||||
| async def attach_to_ringbuf_sender( | ||||
| 
 | ||||
|     token: RBToken, | ||||
|     cleanup: bool = True | ||||
| ) -> AsyncContextManager[RingBuffSender]: | ||||
| 
 | ||||
| ) -> AsyncContextManager[RingBufferSendChannel]: | ||||
|     ''' | ||||
|     Attach a RingBuffSender from a previously opened | ||||
|     Attach a RingBufferSendChannel from a previously opened | ||||
|     RBToken. | ||||
| 
 | ||||
|     ''' | ||||
|     async with RingBuffSender( | ||||
|     async with RingBufferSendChannel( | ||||
|         token, | ||||
|         cleanup=cleanup | ||||
|     ) as sender: | ||||
|         yield sender | ||||
| 
 | ||||
| 
 | ||||
| @cm | ||||
| def open_ringbuf_pair( | ||||
|     name: str, | ||||
|     buf_size: int = _DEFAULT_RB_SIZE | ||||
| ) -> ContextManager[tuple(RBToken, RBToken)]: | ||||
| class RingBufferChannel(trio.abc.Channel[bytes]): | ||||
|     ''' | ||||
|     Handle resources for a ringbuf pair to be used for | ||||
|     bidirectional messaging. | ||||
| 
 | ||||
|     ''' | ||||
|     with ( | ||||
|         open_ringbuf( | ||||
|             name + '.pair0', | ||||
|             buf_size=buf_size | ||||
|         ) as token_0, | ||||
| 
 | ||||
|         open_ringbuf( | ||||
|             name + '.pair1', | ||||
|             buf_size=buf_size | ||||
|         ) as token_1 | ||||
|     ): | ||||
|         yield token_0, token_1 | ||||
| 
 | ||||
| 
 | ||||
| @acm | ||||
| async def attach_to_ringbuf_stream( | ||||
|     token_in: RBToken, | ||||
|     token_out: RBToken, | ||||
|     cleanup_in: bool = True, | ||||
|     cleanup_out: bool = True | ||||
| ) -> AsyncContextManager[trio.StapledStream]: | ||||
|     ''' | ||||
|     Attach a trio.StapledStream from a previously opened | ||||
|     ringbuf pair. | ||||
| 
 | ||||
|     ''' | ||||
|     async with ( | ||||
|         attach_to_ringbuf_receiver( | ||||
|             token_in, | ||||
|             cleanup=cleanup_in | ||||
|         ) as receiver, | ||||
|         attach_to_ringbuf_sender( | ||||
|             token_out, | ||||
|             cleanup=cleanup_out | ||||
|         ) as sender, | ||||
|     ): | ||||
|         yield trio.StapledStream(sender, receiver) | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
| class RingBuffBytesSender(trio.abc.SendChannel[bytes]): | ||||
|     ''' | ||||
|     In order to guarantee full messages are received, all bytes | ||||
|     sent by `RingBuffBytesSender` are preceded with a 4 byte header | ||||
|     which decodes into a uint32 indicating the actual size of the | ||||
|     next payload. | ||||
| 
 | ||||
|     Optional batch mode: | ||||
| 
 | ||||
|     If `batch_size` > 1 messages wont get sent immediately but will be | ||||
|     stored until `batch_size` messages are pending, then it will send | ||||
|     them all at once. | ||||
| 
 | ||||
|     `batch_size` can be changed dynamically but always call, `flush()` | ||||
|     right before. | ||||
| 
 | ||||
|     ''' | ||||
|     def __init__( | ||||
|         self, | ||||
|         sender: RingBuffSender, | ||||
|         batch_size: int = 1 | ||||
|     ): | ||||
|         self._sender = sender | ||||
|         self.batch_size = batch_size | ||||
|         self._batch_msg_len = 0 | ||||
|         self._batch: bytes = b'' | ||||
|         self._send_lock = trio.StrictFIFOLock() | ||||
| 
 | ||||
|     @property | ||||
|     def pending_msgs(self) -> int: | ||||
|         return self._batch_msg_len | ||||
| 
 | ||||
|     @property | ||||
|     def must_flush(self) -> bool: | ||||
|         return self._batch_msg_len >= self.batch_size | ||||
| 
 | ||||
|     async def _flush( | ||||
|         self, | ||||
|         new_batch_size: int | None = None | ||||
|     ) -> None: | ||||
|         await self._sender.send_all(self._batch) | ||||
|         self._batch = b'' | ||||
|         self._batch_msg_len = 0 | ||||
|         if new_batch_size: | ||||
|             self.batch_size = new_batch_size | ||||
| 
 | ||||
|     async def flush( | ||||
|         self, | ||||
|         new_batch_size: int | None = None | ||||
|     ) -> None: | ||||
|         async with self._send_lock: | ||||
|             await self._flush(new_batch_size=new_batch_size) | ||||
| 
 | ||||
|     async def send(self, value: bytes) -> None: | ||||
|         async with self._send_lock: | ||||
|             msg: bytes = struct.pack("<I", len(value)) + value | ||||
|             if self.batch_size == 1: | ||||
|                 await self._sender.send_all(msg) | ||||
|                 return | ||||
| 
 | ||||
|             self._batch += msg | ||||
|             self._batch_msg_len += 1 | ||||
|             if self.must_flush: | ||||
|                 await self._flush() | ||||
| 
 | ||||
|     async def send_eof(self) -> None: | ||||
|         await self.flush(new_batch_size=1) | ||||
|         await self.send(b'') | ||||
| 
 | ||||
|     async def aclose(self) -> None: | ||||
|         async with self._send_lock: | ||||
|             await self._sender.aclose() | ||||
| 
 | ||||
| 
 | ||||
| class RingBuffBytesReceiver(trio.abc.ReceiveChannel[bytes]): | ||||
|     ''' | ||||
|     See `RingBuffBytesSender` docstring. | ||||
| 
 | ||||
|     A `tricycle.BufferedReceiveStream` is used for the | ||||
|     `receive_exactly` API. | ||||
|     ''' | ||||
|     def __init__( | ||||
|         self, | ||||
|         receiver: RingBuffReceiver | ||||
|     ): | ||||
|         self._receiver = receiver | ||||
| 
 | ||||
|     async def _receive_exactly(self, num_bytes: int) -> bytes: | ||||
|         ''' | ||||
|         Fetch bytes from receiver until we read exactly `num_bytes` | ||||
|         or end of stream is signaled. | ||||
| 
 | ||||
|         ''' | ||||
|         payload = b'' | ||||
|         while len(payload) < num_bytes: | ||||
|             remaining = num_bytes - len(payload) | ||||
| 
 | ||||
|             new_bytes = await self._receiver.receive_some( | ||||
|                 max_bytes=remaining | ||||
|             ) | ||||
| 
 | ||||
|             if new_bytes == b'': | ||||
|                 raise trio.EndOfChannel | ||||
| 
 | ||||
|             payload += new_bytes | ||||
| 
 | ||||
|         return payload | ||||
| 
 | ||||
|     async def receive(self) -> bytes: | ||||
|         header: bytes = await self._receive_exactly(4) | ||||
|         size: int | ||||
|         size, = struct.unpack("<I", header) | ||||
|         if size == 0: | ||||
|             raise trio.EndOfChannel | ||||
|         return await self._receive_exactly(size) | ||||
| 
 | ||||
|     async def aclose(self) -> None: | ||||
|         await self._receiver.aclose() | ||||
| 
 | ||||
| 
 | ||||
| @acm | ||||
| async def attach_to_ringbuf_rchannel( | ||||
|     token: RBToken, | ||||
|     cleanup: bool = True | ||||
| ) -> AsyncContextManager[RingBuffBytesReceiver]: | ||||
|     ''' | ||||
|     Attach a RingBuffBytesReceiver from a previously opened | ||||
|     RBToken. | ||||
|     ''' | ||||
|     async with attach_to_ringbuf_receiver( | ||||
|         token, cleanup=cleanup | ||||
|     ) as receiver: | ||||
|         yield RingBuffBytesReceiver(receiver) | ||||
| 
 | ||||
| 
 | ||||
| @acm | ||||
| async def attach_to_ringbuf_schannel( | ||||
|     token: RBToken, | ||||
|     cleanup: bool = True, | ||||
|     batch_size: int = 1, | ||||
| ) -> AsyncContextManager[RingBuffBytesSender]: | ||||
|     ''' | ||||
|     Attach a RingBuffBytesSender from a previously opened | ||||
|     RBToken. | ||||
|     ''' | ||||
|     async with attach_to_ringbuf_sender( | ||||
|         token, cleanup=cleanup | ||||
|     ) as sender: | ||||
|         yield RingBuffBytesSender(sender, batch_size=batch_size) | ||||
| 
 | ||||
| 
 | ||||
| class RingBuffChannel(trio.abc.Channel[bytes]): | ||||
|     ''' | ||||
|     Combine `RingBuffBytesSender` and `RingBuffBytesReceiver` | ||||
|     Combine `RingBufferSendChannel` and `RingBufferReceiveChannel` | ||||
|     in order to expose the bidirectional `trio.abc.Channel` API. | ||||
| 
 | ||||
|     ''' | ||||
|     def __init__( | ||||
|         self, | ||||
|         sender: RingBuffBytesSender, | ||||
|         receiver: RingBuffBytesReceiver | ||||
|         sender: RingBufferSendChannel, | ||||
|         receiver: RingBufferReceiveChannel | ||||
|     ): | ||||
|         self._sender = sender | ||||
|         self._receiver = receiver | ||||
|  | @ -666,6 +605,12 @@ class RingBuffChannel(trio.abc.Channel[bytes]): | |||
|     def pending_msgs(self) -> int: | ||||
|         return self._sender.pending_msgs | ||||
| 
 | ||||
|     async def send_all(self, value: bytes) -> None: | ||||
|         await self._sender.send_all(value) | ||||
| 
 | ||||
|     async def wait_send_all_might_not_block(self): | ||||
|         await self._sender.wait_send_all_might_not_block() | ||||
| 
 | ||||
|     async def flush( | ||||
|         self, | ||||
|         new_batch_size: int | None = None | ||||
|  | @ -678,6 +623,15 @@ class RingBuffChannel(trio.abc.Channel[bytes]): | |||
|     async def send_eof(self) -> None: | ||||
|         await self._sender.send_eof() | ||||
| 
 | ||||
|     def receive_nowait(self, max_bytes: int = _DEFAULT_RB_SIZE) -> bytes: | ||||
|         return self._receiver.receive_nowait(max_bytes=max_bytes) | ||||
| 
 | ||||
|     async def receive_some(self, max_bytes: int = _DEFAULT_RB_SIZE) -> bytes: | ||||
|         return await self._receiver.receive_some(max_bytes=max_bytes) | ||||
| 
 | ||||
|     async def receive_exactly(self, num_bytes: int) -> bytes: | ||||
|         return await self._receiver.receive_exactly(num_bytes) | ||||
| 
 | ||||
|     async def receive(self) -> bytes: | ||||
|         return await self._receiver.receive() | ||||
| 
 | ||||
|  | @ -691,23 +645,20 @@ async def attach_to_ringbuf_channel( | |||
|     token_in: RBToken, | ||||
|     token_out: RBToken, | ||||
|     cleanup_in: bool = True, | ||||
|     cleanup_out: bool = True, | ||||
|     batch_size: int = 1 | ||||
| ) -> AsyncContextManager[RingBuffChannel]: | ||||
|     cleanup_out: bool = True | ||||
| ) -> AsyncContextManager[trio.StapledStream]: | ||||
|     ''' | ||||
|     Attach to an already opened ringbuf pair and return | ||||
|     a `RingBuffChannel`. | ||||
|     Attach to two previously opened `RBToken`s and return a `RingBufferChannel` | ||||
| 
 | ||||
|     ''' | ||||
|     async with ( | ||||
|         attach_to_ringbuf_rchannel( | ||||
|         attach_to_ringbuf_receiver( | ||||
|             token_in, | ||||
|             cleanup=cleanup_in | ||||
|         ) as receiver, | ||||
|         attach_to_ringbuf_schannel( | ||||
|         attach_to_ringbuf_sender( | ||||
|             token_out, | ||||
|             cleanup=cleanup_out, | ||||
|             batch_size=batch_size | ||||
|             cleanup=cleanup_out | ||||
|         ) as sender, | ||||
|     ): | ||||
|         yield RingBuffChannel(sender, receiver) | ||||
|         yield RingBufferChannel(sender, receiver) | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue