Switch `tractor.ipc.MsgTransport.stream` type to `trio.abc.Stream`
Add EOF signaling mechanism Support proper `receive_some` end of stream semantics Add StapledStream non-ipc test Create MsgpackRBStream similar to MsgpackTCPStream for buffered whole-msg reads Add EventFD.read cancellation on EventFD.close mechanism using cancel scope Add test for eventfd cancellation Improve and add docstrings
							parent
							
								
									ba353bf46f
								
							
						
					
					
						commit
						be818a720a
					
				|  | @ -0,0 +1,32 @@ | |||
| import trio | ||||
| import pytest | ||||
| from tractor.ipc import ( | ||||
|     open_eventfd, | ||||
|     EFDReadCancelled, | ||||
|     EventFD | ||||
| ) | ||||
| 
 | ||||
| 
 | ||||
| def test_eventfd_read_cancellation(): | ||||
|     ''' | ||||
|     Ensure EventFD.read raises EFDReadCancelled if EventFD.close() | ||||
|     is called. | ||||
| 
 | ||||
|     ''' | ||||
|     fd = open_eventfd() | ||||
| 
 | ||||
|     async def _read(event: EventFD): | ||||
|         with pytest.raises(EFDReadCancelled): | ||||
|             await event.read() | ||||
| 
 | ||||
|     async def main(): | ||||
|         async with trio.open_nursery() as n: | ||||
|             with ( | ||||
|                 EventFD(fd, 'w') as event, | ||||
|                 trio.fail_after(3) | ||||
|             ): | ||||
|                 n.start_soon(_read, event) | ||||
|                 await trio.sleep(0.2) | ||||
|                 event.close() | ||||
| 
 | ||||
|     trio.run(main) | ||||
|  | @ -5,11 +5,16 @@ import pytest | |||
| import tractor | ||||
| from tractor.ipc import ( | ||||
|     open_ringbuf, | ||||
|     attach_to_ringbuf_receiver, | ||||
|     attach_to_ringbuf_sender, | ||||
|     attach_to_ringbuf_pair, | ||||
|     attach_to_ringbuf_stream, | ||||
|     RBToken, | ||||
|     RingBuffSender, | ||||
|     RingBuffReceiver | ||||
| ) | ||||
| from tractor._testing.samples import generate_sample_messages | ||||
| from tractor._testing.samples import ( | ||||
|     generate_single_byte_msgs, | ||||
|     generate_sample_messages | ||||
| ) | ||||
| 
 | ||||
| 
 | ||||
| @tractor.context | ||||
|  | @ -17,20 +22,14 @@ async def child_read_shm( | |||
|     ctx: tractor.Context, | ||||
|     msg_amount: int, | ||||
|     token: RBToken, | ||||
|     total_bytes: int, | ||||
| ) -> None: | ||||
|     recvd_bytes = 0 | ||||
|     await ctx.started() | ||||
|     start_ts = time.time() | ||||
|     async with RingBuffReceiver(token) as receiver: | ||||
|         while recvd_bytes < total_bytes: | ||||
|             msg = await receiver.receive_some() | ||||
|     async with attach_to_ringbuf_receiver(token) as receiver: | ||||
|         async for msg in receiver: | ||||
|             recvd_bytes += len(msg) | ||||
| 
 | ||||
|         # make sure we dont hold any memoryviews | ||||
|         # before the ctx manager aclose() | ||||
|         msg = None | ||||
| 
 | ||||
|     end_ts = time.time() | ||||
|     elapsed = end_ts - start_ts | ||||
|     elapsed_ms = int(elapsed * 1000) | ||||
|  | @ -38,6 +37,7 @@ async def child_read_shm( | |||
|     print(f'\n\telapsed ms: {elapsed_ms}') | ||||
|     print(f'\tmsg/sec: {int(msg_amount / elapsed):,}') | ||||
|     print(f'\tbytes/sec: {int(recvd_bytes / elapsed):,}') | ||||
|     print(f'\treceived bytes: {recvd_bytes}') | ||||
| 
 | ||||
| 
 | ||||
| @tractor.context | ||||
|  | @ -54,7 +54,7 @@ async def child_write_shm( | |||
|         rand_max=rand_max, | ||||
|     ) | ||||
|     await ctx.started(total_bytes) | ||||
|     async with RingBuffSender(token) as sender: | ||||
|     async with attach_to_ringbuf_sender(token, cleanup=False) as sender: | ||||
|         for msg in msgs: | ||||
|             await sender.send_all(msg) | ||||
| 
 | ||||
|  | @ -99,14 +99,8 @@ def test_ringbuf( | |||
|             'test_ringbuf', | ||||
|             buf_size=buf_size | ||||
|         ) as token: | ||||
|             proc_kwargs = { | ||||
|                 'pass_fds': (token.write_eventfd, token.wrap_eventfd) | ||||
|             } | ||||
|             proc_kwargs = {'pass_fds': token.fds} | ||||
| 
 | ||||
|             common_kwargs = { | ||||
|                 'msg_amount': msg_amount, | ||||
|                 'token': token, | ||||
|             } | ||||
|             async with tractor.open_nursery() as an: | ||||
|                 send_p = await an.start_actor( | ||||
|                     'ring_sender', | ||||
|  | @ -121,14 +115,15 @@ def test_ringbuf( | |||
|                 async with ( | ||||
|                     send_p.open_context( | ||||
|                         child_write_shm, | ||||
|                         token=token, | ||||
|                         msg_amount=msg_amount, | ||||
|                         rand_min=rand_min, | ||||
|                         rand_max=rand_max, | ||||
|                         **common_kwargs | ||||
|                     ) as (sctx, total_bytes), | ||||
|                     recv_p.open_context( | ||||
|                         child_read_shm, | ||||
|                         **common_kwargs, | ||||
|                         total_bytes=total_bytes, | ||||
|                         token=token, | ||||
|                         msg_amount=msg_amount | ||||
|                     ) as (sctx, _sent), | ||||
|                 ): | ||||
|                     await recv_p.result() | ||||
|  | @ -145,7 +140,7 @@ async def child_blocked_receiver( | |||
|     ctx: tractor.Context, | ||||
|     token: RBToken | ||||
| ): | ||||
|     async with RingBuffReceiver(token) as receiver: | ||||
|     async with attach_to_ringbuf_receiver(token) as receiver: | ||||
|         await ctx.started() | ||||
|         await receiver.receive_some() | ||||
| 
 | ||||
|  | @ -160,13 +155,13 @@ def test_ring_reader_cancel(): | |||
|         with open_ringbuf('test_ring_cancel_reader') as token: | ||||
|             async with ( | ||||
|                 tractor.open_nursery() as an, | ||||
|                 RingBuffSender(token) as _sender, | ||||
|                 attach_to_ringbuf_sender(token) as _sender, | ||||
|             ): | ||||
|                 recv_p = await an.start_actor( | ||||
|                     'ring_blocked_receiver', | ||||
|                     enable_modules=[__name__], | ||||
|                     proc_kwargs={ | ||||
|                         'pass_fds': (token.write_eventfd, token.wrap_eventfd) | ||||
|                         'pass_fds': token.fds | ||||
|                     } | ||||
|                 ) | ||||
|                 async with ( | ||||
|  | @ -188,7 +183,7 @@ async def child_blocked_sender( | |||
|     ctx: tractor.Context, | ||||
|     token: RBToken | ||||
| ): | ||||
|     async with RingBuffSender(token) as sender: | ||||
|     async with attach_to_ringbuf_sender(token) as sender: | ||||
|         await ctx.started() | ||||
|         await sender.send_all(b'this will wrap') | ||||
| 
 | ||||
|  | @ -209,7 +204,7 @@ def test_ring_sender_cancel(): | |||
|                     'ring_blocked_sender', | ||||
|                     enable_modules=[__name__], | ||||
|                     proc_kwargs={ | ||||
|                         'pass_fds': (token.write_eventfd, token.wrap_eventfd) | ||||
|                         'pass_fds': token.fds | ||||
|                     } | ||||
|                 ) | ||||
|                 async with ( | ||||
|  | @ -235,7 +230,7 @@ def test_ringbuf_max_bytes(): | |||
|     msgs with original message | ||||
| 
 | ||||
|     ''' | ||||
|     msg = b''.join(str(i % 10).encode() for i in range(100)) | ||||
|     msg = generate_single_byte_msgs(100) | ||||
|     msgs = [] | ||||
| 
 | ||||
|     async def main(): | ||||
|  | @ -245,15 +240,153 @@ def test_ringbuf_max_bytes(): | |||
|         ) as token: | ||||
|             async with ( | ||||
|                 trio.open_nursery() as n, | ||||
|                 RingBuffSender(token, is_ipc=False) as sender, | ||||
|                 RingBuffReceiver(token, is_ipc=False) as receiver | ||||
|                 attach_to_ringbuf_sender(token, cleanup=False) as sender, | ||||
|                 attach_to_ringbuf_receiver(token, cleanup=False) as receiver | ||||
|             ): | ||||
|                 n.start_soon(sender.send_all, msg) | ||||
|                 async def _send_and_close(): | ||||
|                     await sender.send_all(msg) | ||||
|                     await sender.aclose() | ||||
| 
 | ||||
|                 n.start_soon(_send_and_close) | ||||
|                 while len(msgs) < len(msg): | ||||
|                     msg_part = await receiver.receive_some(max_bytes=1) | ||||
|                     msg_part = bytes(msg_part) | ||||
|                     assert len(msg_part) == 1 | ||||
|                     msgs.append(msg_part) | ||||
| 
 | ||||
|     trio.run(main) | ||||
|     assert msg == b''.join(msgs) | ||||
| 
 | ||||
| 
 | ||||
| def test_stapled_ringbuf(): | ||||
|     ''' | ||||
|     Open two ringbufs and give tokens to tasks (swap them such that in/out tokens | ||||
|     are inversed on each task) which will open the streams and use trio.StapledStream | ||||
|     to have a single bidirectional stream. | ||||
| 
 | ||||
|     Then take turns to send and receive messages. | ||||
| 
 | ||||
|     ''' | ||||
|     msg = generate_single_byte_msgs(100) | ||||
|     pair_0_msgs = [] | ||||
|     pair_1_msgs = [] | ||||
| 
 | ||||
|     pair_0_done = trio.Event() | ||||
|     pair_1_done = trio.Event() | ||||
| 
 | ||||
|     async def pair_0(token_in: RBToken, token_out: RBToken): | ||||
|         async with attach_to_ringbuf_pair( | ||||
|             token_in, | ||||
|             token_out, | ||||
|             cleanup_in=False, | ||||
|             cleanup_out=False | ||||
|         ) as stream: | ||||
|             # first turn to send | ||||
|             await stream.send_all(msg) | ||||
| 
 | ||||
|             # second turn to receive | ||||
|             while len(pair_0_msgs) != len(msg): | ||||
|                 _msg = await stream.receive_some(max_bytes=1) | ||||
|                 pair_0_msgs.append(_msg) | ||||
| 
 | ||||
|             pair_0_done.set() | ||||
|             await pair_1_done.wait() | ||||
| 
 | ||||
| 
 | ||||
|     async def pair_1(token_in: RBToken, token_out: RBToken): | ||||
|         async with attach_to_ringbuf_pair( | ||||
|             token_in, | ||||
|             token_out, | ||||
|             cleanup_in=False, | ||||
|             cleanup_out=False | ||||
|         ) as stream: | ||||
|             # first turn to receive | ||||
|             while len(pair_1_msgs) != len(msg): | ||||
|                 _msg = await stream.receive_some(max_bytes=1) | ||||
|                 pair_1_msgs.append(_msg) | ||||
| 
 | ||||
|             # second turn to send | ||||
|             await stream.send_all(msg) | ||||
| 
 | ||||
|             pair_1_done.set() | ||||
|             await pair_0_done.wait() | ||||
| 
 | ||||
| 
 | ||||
|     async def main(): | ||||
|         with tractor.ipc.open_ringbuf_pair( | ||||
|             'test_stapled_ringbuf' | ||||
|         ) as (token_0, token_1): | ||||
|             async with trio.open_nursery() as n: | ||||
|                 n.start_soon(pair_0, token_0, token_1) | ||||
|                 n.start_soon(pair_1, token_1, token_0) | ||||
| 
 | ||||
| 
 | ||||
|     trio.run(main) | ||||
| 
 | ||||
|     assert msg == b''.join(pair_0_msgs) | ||||
|     assert msg == b''.join(pair_1_msgs) | ||||
| 
 | ||||
| 
 | ||||
| @tractor.context | ||||
| async def child_transport_sender( | ||||
|     ctx: tractor.Context, | ||||
|     msg_amount_min: int, | ||||
|     msg_amount_max: int, | ||||
|     token_in: RBToken, | ||||
|     token_out: RBToken | ||||
| ): | ||||
|     import random | ||||
|     msgs, _total_bytes = generate_sample_messages( | ||||
|         random.randint(msg_amount_min, msg_amount_max), | ||||
|         rand_min=256, | ||||
|         rand_max=1024, | ||||
|     ) | ||||
|     async with attach_to_ringbuf_stream( | ||||
|         token_in, | ||||
|         token_out | ||||
|     ) as transport: | ||||
|         await ctx.started(msgs) | ||||
| 
 | ||||
|         for msg in msgs: | ||||
|             await transport.send(msg) | ||||
| 
 | ||||
|         await transport.recv() | ||||
| 
 | ||||
| 
 | ||||
| def test_ringbuf_transport(): | ||||
| 
 | ||||
|     msg_amount_min = 100 | ||||
|     msg_amount_max = 1000 | ||||
| 
 | ||||
|     async def main(): | ||||
|         with tractor.ipc.open_ringbuf_pair( | ||||
|             'test_ringbuf_transport' | ||||
|         ) as (token_0, token_1): | ||||
|             async with ( | ||||
|                 attach_to_ringbuf_stream(token_0, token_1) as transport, | ||||
|                 tractor.open_nursery() as an | ||||
|             ): | ||||
|                 recv_p = await an.start_actor( | ||||
|                     'test_ringbuf_transport_sender', | ||||
|                     enable_modules=[__name__], | ||||
|                     proc_kwargs={ | ||||
|                         'pass_fds': token_0.fds + token_1.fds | ||||
|                     } | ||||
|                 ) | ||||
|                 async with ( | ||||
|                     recv_p.open_context( | ||||
|                         child_transport_sender, | ||||
|                         msg_amount_min=msg_amount_min, | ||||
|                         msg_amount_max=msg_amount_max, | ||||
|                         token_in=token_1, | ||||
|                         token_out=token_0 | ||||
|                     ) as (ctx, msgs), | ||||
|                 ): | ||||
|                     recv_msgs = [] | ||||
|                     while len(recv_msgs) < len(msgs): | ||||
|                         recv_msgs.append(await transport.recv()) | ||||
| 
 | ||||
|                     await transport.send(b'end') | ||||
|                     await recv_p.cancel_actor() | ||||
|                     assert recv_msgs == msgs | ||||
| 
 | ||||
|     trio.run(main) | ||||
|  |  | |||
|  | @ -2,6 +2,10 @@ import os | |||
| import random | ||||
| 
 | ||||
| 
 | ||||
| def generate_single_byte_msgs(amount: int) -> bytes: | ||||
|     return b''.join(str(i % 10).encode() for i in range(amount)) | ||||
| 
 | ||||
| 
 | ||||
| def generate_sample_messages( | ||||
|     amount: int, | ||||
|     rand_min: int = 0, | ||||
|  |  | |||
|  | @ -39,12 +39,19 @@ if platform.system() == 'Linux': | |||
|         write_eventfd as write_eventfd, | ||||
|         read_eventfd as read_eventfd, | ||||
|         close_eventfd as close_eventfd, | ||||
|         EFDReadCancelled as EFDReadCancelled, | ||||
|         EventFD as EventFD, | ||||
|     ) | ||||
| 
 | ||||
|     from ._ringbuf import ( | ||||
|         RBToken as RBToken, | ||||
|         open_ringbuf as open_ringbuf, | ||||
|         RingBuffSender as RingBuffSender, | ||||
|         RingBuffReceiver as RingBuffReceiver, | ||||
|         open_ringbuf as open_ringbuf | ||||
|         open_ringbuf_pair as open_ringbuf_pair, | ||||
|         attach_to_ringbuf_receiver as attach_to_ringbuf_receiver, | ||||
|         attach_to_ringbuf_sender as attach_to_ringbuf_sender, | ||||
|         attach_to_ringbuf_pair as attach_to_ringbuf_pair, | ||||
|         attach_to_ringbuf_stream as attach_to_ringbuf_stream, | ||||
|         MsgpackRBStream as MsgpackRBStream | ||||
|     ) | ||||
|  |  | |||
|  | @ -108,6 +108,10 @@ def close_eventfd(fd: int) -> int: | |||
|         raise OSError(errno.errorcode[ffi.errno], 'close failed') | ||||
| 
 | ||||
| 
 | ||||
| class EFDReadCancelled(Exception): | ||||
|     ... | ||||
| 
 | ||||
| 
 | ||||
| class EventFD: | ||||
|     ''' | ||||
|     Use a previously opened eventfd(2), meant to be used in | ||||
|  | @ -124,6 +128,7 @@ class EventFD: | |||
|         self._fd: int = fd | ||||
|         self._omode: str = omode | ||||
|         self._fobj = None | ||||
|         self._cscope: trio.CancelScope | None = None | ||||
| 
 | ||||
|     @property | ||||
|     def fd(self) -> int | None: | ||||
|  | @ -133,17 +138,38 @@ class EventFD: | |||
|         return write_eventfd(self._fd, value) | ||||
| 
 | ||||
|     async def read(self) -> int: | ||||
|         return await trio.to_thread.run_sync( | ||||
|             read_eventfd, self._fd, | ||||
|             abandon_on_cancel=True | ||||
|         ) | ||||
|         ''' | ||||
|         Async wrapper for `read_eventfd(self.fd)` | ||||
| 
 | ||||
|         `trio.to_thread.run_sync` is used, need to use a `trio.CancelScope` | ||||
|         in order to make it cancellable when `self.close()` is called. | ||||
| 
 | ||||
|         ''' | ||||
|         self._cscope = trio.CancelScope() | ||||
|         with self._cscope: | ||||
|             return await trio.to_thread.run_sync( | ||||
|                 read_eventfd, self._fd, | ||||
|                 abandon_on_cancel=True | ||||
|             ) | ||||
| 
 | ||||
|         if self._cscope.cancelled_caught: | ||||
|             raise EFDReadCancelled | ||||
| 
 | ||||
|         self._cscope = None | ||||
| 
 | ||||
|     def open(self): | ||||
|         self._fobj = os.fdopen(self._fd, self._omode) | ||||
| 
 | ||||
|     def close(self): | ||||
|         if self._fobj: | ||||
|             self._fobj.close() | ||||
|             try: | ||||
|                 self._fobj.close() | ||||
| 
 | ||||
|             except OSError: | ||||
|                 ... | ||||
| 
 | ||||
|         if self._cscope: | ||||
|             self._cscope.cancel() | ||||
| 
 | ||||
|     def __enter__(self): | ||||
|         self.open() | ||||
|  |  | |||
|  | @ -18,10 +18,22 @@ IPC Reliable RingBuffer implementation | |||
| 
 | ||||
| ''' | ||||
| from __future__ import annotations | ||||
| from contextlib import contextmanager as cm | ||||
| import struct | ||||
| from collections.abc import ( | ||||
|     AsyncGenerator, | ||||
|     AsyncIterator | ||||
| ) | ||||
| from contextlib import ( | ||||
|     contextmanager as cm, | ||||
|     asynccontextmanager as acm | ||||
| ) | ||||
| from typing import ( | ||||
|     Any | ||||
| ) | ||||
| from multiprocessing.shared_memory import SharedMemory | ||||
| 
 | ||||
| import trio | ||||
| from tricycle import BufferedReceiveStream | ||||
| from msgspec import ( | ||||
|     Struct, | ||||
|     to_builtins | ||||
|  | @ -30,10 +42,16 @@ from msgspec import ( | |||
| from ._linux import ( | ||||
|     open_eventfd, | ||||
|     close_eventfd, | ||||
|     EFDReadCancelled, | ||||
|     EventFD | ||||
| ) | ||||
| from ._mp_bs import disable_mantracker | ||||
| from tractor.log import get_logger | ||||
| from tractor._exceptions import ( | ||||
|     TransportClosed, | ||||
|     InternalError | ||||
| ) | ||||
| from tractor.ipc import MsgTransport | ||||
| 
 | ||||
| 
 | ||||
| log = get_logger(__name__) | ||||
|  | @ -41,16 +59,21 @@ log = get_logger(__name__) | |||
| 
 | ||||
| disable_mantracker() | ||||
| 
 | ||||
| _DEFAULT_RB_SIZE = 10 * 1024 | ||||
| 
 | ||||
| 
 | ||||
| class RBToken(Struct, frozen=True): | ||||
|     ''' | ||||
|     RingBuffer token contains necesary info to open the two | ||||
|     RingBuffer token contains necesary info to open the three | ||||
|     eventfds and the shared memory | ||||
| 
 | ||||
|     ''' | ||||
|     shm_name: str | ||||
|     write_eventfd: int | ||||
|     wrap_eventfd: int | ||||
| 
 | ||||
|     write_eventfd: int  # used to signal writer ptr advance | ||||
|     wrap_eventfd: int  # used to signal reader ready after wrap around | ||||
|     eof_eventfd: int  # used to signal writer closed | ||||
| 
 | ||||
|     buf_size: int | ||||
| 
 | ||||
|     def as_msg(self): | ||||
|  | @ -63,12 +86,29 @@ class RBToken(Struct, frozen=True): | |||
| 
 | ||||
|         return RBToken(**msg) | ||||
| 
 | ||||
|     @property | ||||
|     def fds(self) -> tuple[int, int, int]: | ||||
|         ''' | ||||
|         Useful for `pass_fds` params | ||||
| 
 | ||||
|         ''' | ||||
|         return ( | ||||
|             self.write_eventfd, | ||||
|             self.wrap_eventfd, | ||||
|             self.eof_eventfd | ||||
|         ) | ||||
| 
 | ||||
| 
 | ||||
| @cm | ||||
| def open_ringbuf( | ||||
|     shm_name: str, | ||||
|     buf_size: int = 10 * 1024, | ||||
|     buf_size: int = _DEFAULT_RB_SIZE, | ||||
| ) -> RBToken: | ||||
|     ''' | ||||
|     Handle resources for a ringbuf (shm, eventfd), yield `RBToken` to | ||||
|     be used with `attach_to_ringbuf_sender` and `attach_to_ringbuf_receiver` | ||||
| 
 | ||||
|     ''' | ||||
|     shm = SharedMemory( | ||||
|         name=shm_name, | ||||
|         size=buf_size, | ||||
|  | @ -79,11 +119,27 @@ def open_ringbuf( | |||
|             shm_name=shm_name, | ||||
|             write_eventfd=open_eventfd(), | ||||
|             wrap_eventfd=open_eventfd(), | ||||
|             eof_eventfd=open_eventfd(), | ||||
|             buf_size=buf_size | ||||
|         ) | ||||
|         yield token | ||||
|         close_eventfd(token.write_eventfd) | ||||
|         close_eventfd(token.wrap_eventfd) | ||||
|         try: | ||||
|             close_eventfd(token.write_eventfd) | ||||
| 
 | ||||
|         except OSError: | ||||
|             ... | ||||
| 
 | ||||
|         try: | ||||
|             close_eventfd(token.wrap_eventfd) | ||||
| 
 | ||||
|         except OSError: | ||||
|             ... | ||||
| 
 | ||||
|         try: | ||||
|             close_eventfd(token.eof_eventfd) | ||||
| 
 | ||||
|         except OSError: | ||||
|             ... | ||||
| 
 | ||||
|     finally: | ||||
|         shm.unlink() | ||||
|  | @ -91,28 +147,36 @@ def open_ringbuf( | |||
| 
 | ||||
| Buffer = bytes | bytearray | memoryview | ||||
| 
 | ||||
| ''' | ||||
| IPC Reliable Ring Buffer | ||||
| 
 | ||||
| `eventfd(2)` is used for wrap around sync, to signal writes to | ||||
| the reader and end of stream. | ||||
| 
 | ||||
| ''' | ||||
| 
 | ||||
| 
 | ||||
| class RingBuffSender(trio.abc.SendStream): | ||||
|     ''' | ||||
|     IPC Reliable Ring Buffer sender side implementation | ||||
|     Ring Buffer sender side implementation | ||||
| 
 | ||||
|     `eventfd(2)` is used for wrap around sync, and also to signal | ||||
|     writes to the reader. | ||||
|     Do not use directly! manage with `attach_to_ringbuf_sender` | ||||
|     after having opened a ringbuf context with `open_ringbuf`. | ||||
| 
 | ||||
|     ''' | ||||
|     def __init__( | ||||
|         self, | ||||
|         token: RBToken, | ||||
|         start_ptr: int = 0, | ||||
|         is_ipc: bool = True | ||||
|         cleanup: bool = False | ||||
|     ): | ||||
|         self._token = RBToken.from_msg(token) | ||||
|         self._shm: SharedMemory | None = None | ||||
|         self._write_event = EventFD(self._token.write_eventfd, 'w') | ||||
|         self._wrap_event = EventFD(self._token.wrap_eventfd, 'r') | ||||
|         self._ptr = start_ptr | ||||
|         self._eof_event = EventFD(self._token.eof_eventfd, 'w') | ||||
|         self._ptr = 0 | ||||
| 
 | ||||
|         self._is_ipc = is_ipc | ||||
|         self._cleanup = cleanup | ||||
|         self._send_lock = trio.StrictFIFOLock() | ||||
| 
 | ||||
|     @property | ||||
|  | @ -170,13 +234,21 @@ class RingBuffSender(trio.abc.SendStream): | |||
|         ) | ||||
|         self._write_event.open() | ||||
|         self._wrap_event.open() | ||||
|         self._eof_event.open() | ||||
| 
 | ||||
|     async def aclose(self): | ||||
|         if self._is_ipc: | ||||
|     def close(self): | ||||
|         self._eof_event.write( | ||||
|             self._ptr if self._ptr > 0 else self.size | ||||
|         ) | ||||
|         if self._cleanup: | ||||
|             self._write_event.close() | ||||
|             self._wrap_event.close() | ||||
|             self._eof_event.close() | ||||
|             self._shm.close() | ||||
| 
 | ||||
|     async def aclose(self): | ||||
|         self.close() | ||||
| 
 | ||||
|     async def __aenter__(self): | ||||
|         self.open() | ||||
|         return self | ||||
|  | @ -184,25 +256,27 @@ class RingBuffSender(trio.abc.SendStream): | |||
| 
 | ||||
| class RingBuffReceiver(trio.abc.ReceiveStream): | ||||
|     ''' | ||||
|     IPC Reliable Ring Buffer receiver side implementation | ||||
|     Ring Buffer receiver side implementation | ||||
| 
 | ||||
|     `eventfd(2)` is used for wrap around sync, and also to signal | ||||
|     writes to the reader. | ||||
|     Do not use directly! manage with `attach_to_ringbuf_receiver` | ||||
|     after having opened a ringbuf context with `open_ringbuf`. | ||||
| 
 | ||||
|     ''' | ||||
|     def __init__( | ||||
|         self, | ||||
|         token: RBToken, | ||||
|         start_ptr: int = 0, | ||||
|         is_ipc: bool = True | ||||
|         cleanup: bool = True, | ||||
|     ): | ||||
|         self._token = RBToken.from_msg(token) | ||||
|         self._shm: SharedMemory | None = None | ||||
|         self._write_event = EventFD(self._token.write_eventfd, 'w') | ||||
|         self._wrap_event = EventFD(self._token.wrap_eventfd, 'r') | ||||
|         self._ptr = start_ptr | ||||
|         self._write_ptr = start_ptr | ||||
|         self._is_ipc = is_ipc | ||||
|         self._eof_event = EventFD(self._token.eof_eventfd, 'r') | ||||
|         self._ptr: int = 0 | ||||
|         self._write_ptr: int = 0 | ||||
|         self._end_ptr: int = -1 | ||||
| 
 | ||||
|         self._cleanup: bool = cleanup | ||||
| 
 | ||||
|     @property | ||||
|     def name(self) -> str: | ||||
|  | @ -226,21 +300,71 @@ class RingBuffReceiver(trio.abc.ReceiveStream): | |||
|     def wrap_fd(self) -> int: | ||||
|         return self._wrap_event.fd | ||||
| 
 | ||||
|     async def receive_some(self, max_bytes: int | None = None) -> memoryview: | ||||
|     async def _eof_monitor_task(self): | ||||
|         ''' | ||||
|         Long running EOF event monitor, automatically run in bg by | ||||
|         `attach_to_ringbuf_receiver` context manager, if EOF event | ||||
|         is set its value will be the end pointer (highest valid | ||||
|         index to be read from buf, after setting the `self._end_ptr` | ||||
|         we close the write event which should cancel any blocked | ||||
|         `self._write_event.read()`s on it. | ||||
| 
 | ||||
|         ''' | ||||
|         try: | ||||
|             self._end_ptr = await self._eof_event.read() | ||||
|             self._write_event.close() | ||||
| 
 | ||||
|         except EFDReadCancelled: | ||||
|             ... | ||||
| 
 | ||||
|     async def receive_some(self, max_bytes: int | None = None) -> bytes: | ||||
|         ''' | ||||
|         Receive up to `max_bytes`, if no `max_bytes` is provided | ||||
|         a reasonable default is used. | ||||
| 
 | ||||
|         ''' | ||||
|         if max_bytes is None: | ||||
|             max_bytes: int = _DEFAULT_RB_SIZE | ||||
| 
 | ||||
|         if max_bytes < 1: | ||||
|             raise ValueError("max_bytes must be >= 1") | ||||
| 
 | ||||
|         # delta is remaining bytes we havent read | ||||
|         delta = self._write_ptr - self._ptr | ||||
|         if delta == 0: | ||||
|             delta = await self._write_event.read() | ||||
|             self._write_ptr += delta | ||||
|             # we have read all we can, see if new data is available | ||||
|             if self._end_ptr < 0: | ||||
|                 # if we havent been signaled about EOF yet | ||||
|                 try: | ||||
|                     delta = await self._write_event.read() | ||||
|                     self._write_ptr += delta | ||||
| 
 | ||||
|         if isinstance(max_bytes, int): | ||||
|             if max_bytes == 0: | ||||
|                 raise ValueError('if set, max_bytes must be > 0') | ||||
|             delta = min(delta, max_bytes) | ||||
|                 except EFDReadCancelled: | ||||
|                     # while waiting for new data `self._write_event` was closed | ||||
|                     # this means writer signaled EOF | ||||
|                     if self._end_ptr > 0: | ||||
|                         # final self._write_ptr modification and recalculate delta | ||||
|                         self._write_ptr = self._end_ptr | ||||
|                         delta = self._end_ptr - self._ptr | ||||
| 
 | ||||
|                     else: | ||||
|                         # shouldnt happen cause self._eof_monitor_task always sets | ||||
|                         # self._end_ptr before closing self._write_event | ||||
|                         raise InternalError( | ||||
|                             'self._write_event.read cancelled but self._end_ptr is not set' | ||||
|                         ) | ||||
| 
 | ||||
|             else: | ||||
|                 # no more bytes to read and self._end_ptr set, EOF reached | ||||
|                 return b'' | ||||
| 
 | ||||
|         # dont overflow caller | ||||
|         delta = min(delta, max_bytes) | ||||
| 
 | ||||
|         target_ptr = self._ptr + delta | ||||
| 
 | ||||
|         # fetch next segment and advance ptr | ||||
|         segment = self._shm.buf[self._ptr:target_ptr] | ||||
|         segment = bytes(self._shm.buf[self._ptr:target_ptr]) | ||||
|         self._ptr = target_ptr | ||||
| 
 | ||||
|         if self._ptr == self.size: | ||||
|  | @ -259,13 +383,284 @@ class RingBuffReceiver(trio.abc.ReceiveStream): | |||
|         ) | ||||
|         self._write_event.open() | ||||
|         self._wrap_event.open() | ||||
|         self._eof_event.open() | ||||
| 
 | ||||
|     async def aclose(self): | ||||
|         if self._is_ipc: | ||||
|     def close(self): | ||||
|         if self._cleanup: | ||||
|             self._write_event.close() | ||||
|             self._wrap_event.close() | ||||
|             self._eof_event.close() | ||||
|             self._shm.close() | ||||
| 
 | ||||
|     async def aclose(self): | ||||
|         self.close() | ||||
| 
 | ||||
|     async def __aenter__(self): | ||||
|         self.open() | ||||
|         return self | ||||
| 
 | ||||
| 
 | ||||
| @acm | ||||
| async def attach_to_ringbuf_receiver( | ||||
|     token: RBToken, | ||||
|     cleanup: bool = True | ||||
| ): | ||||
|     ''' | ||||
|     Instantiate a RingBuffReceiver from a previously opened | ||||
|     RBToken. | ||||
| 
 | ||||
|     Launches `receiver._eof_monitor_task` in a `trio.Nursery`. | ||||
|     ''' | ||||
|     async with ( | ||||
|         trio.open_nursery() as n, | ||||
|         RingBuffReceiver( | ||||
|             token, | ||||
|             cleanup=cleanup | ||||
|         ) as receiver | ||||
|     ): | ||||
|         n.start_soon(receiver._eof_monitor_task) | ||||
|         yield receiver | ||||
| 
 | ||||
| @acm | ||||
| async def attach_to_ringbuf_sender( | ||||
|     token: RBToken, | ||||
|     cleanup: bool = True | ||||
| ): | ||||
|     ''' | ||||
|     Instantiate a RingBuffSender from a previously opened | ||||
|     RBToken. | ||||
| 
 | ||||
|     ''' | ||||
|     async with RingBuffSender( | ||||
|         token, | ||||
|         cleanup=cleanup | ||||
|     ) as sender: | ||||
|         yield sender | ||||
| 
 | ||||
| 
 | ||||
| @cm | ||||
| def open_ringbuf_pair( | ||||
|     name: str, | ||||
|     buf_size: int = _DEFAULT_RB_SIZE | ||||
| ): | ||||
|     ''' | ||||
|     Handle resources for a ringbuf pair to be used for | ||||
|     bidirectional messaging. | ||||
| 
 | ||||
|     ''' | ||||
|     with ( | ||||
|         open_ringbuf( | ||||
|             name + '.pair0', | ||||
|             buf_size=buf_size | ||||
|         ) as token_0, | ||||
| 
 | ||||
|         open_ringbuf( | ||||
|             name + '.pair1', | ||||
|             buf_size=buf_size | ||||
|         ) as token_1 | ||||
|     ): | ||||
|         yield token_0, token_1 | ||||
| 
 | ||||
| 
 | ||||
| @acm | ||||
| async def attach_to_ringbuf_pair( | ||||
|     token_in: RBToken, | ||||
|     token_out: RBToken, | ||||
|     cleanup_in: bool = True, | ||||
|     cleanup_out: bool = True | ||||
| ): | ||||
|     ''' | ||||
|     Instantiate a trio.StapledStream from a previously opened | ||||
|     ringbuf pair. | ||||
| 
 | ||||
|     ''' | ||||
|     async with ( | ||||
|         attach_to_ringbuf_receiver( | ||||
|             token_in, | ||||
|             cleanup=cleanup_in | ||||
|         ) as receiver, | ||||
|         attach_to_ringbuf_sender( | ||||
|             token_out, | ||||
|             cleanup=cleanup_out | ||||
|         ) as sender, | ||||
|     ): | ||||
|         yield trio.StapledStream(sender, receiver) | ||||
| 
 | ||||
| 
 | ||||
| class MsgpackRBStream(MsgTransport): | ||||
| 
 | ||||
|     def __init__( | ||||
|         self, | ||||
|         stream: trio.StapledStream | ||||
|     ): | ||||
|         self.stream = stream | ||||
| 
 | ||||
|         # create read loop intance | ||||
|         self._aiter_pkts = self._iter_packets() | ||||
|         self._send_lock = trio.StrictFIFOLock() | ||||
| 
 | ||||
|         self.drained: list[dict] = [] | ||||
| 
 | ||||
|         self.recv_stream = BufferedReceiveStream( | ||||
|             transport_stream=stream | ||||
|         ) | ||||
| 
 | ||||
|     async def _iter_packets(self) -> AsyncGenerator[dict, None]: | ||||
|         ''' | ||||
|         Yield `bytes`-blob decoded packets from the underlying TCP | ||||
|         stream using the current task's `MsgCodec`. | ||||
| 
 | ||||
|         This is a streaming routine implemented as an async generator | ||||
|         func (which was the original design, but could be changed?) | ||||
|         and is allocated by a `.__call__()` inside `.__init__()` where | ||||
|         it is assigned to the `._aiter_pkts` attr. | ||||
| 
 | ||||
|         ''' | ||||
| 
 | ||||
|         while True: | ||||
|             try: | ||||
|                 header: bytes = await self.recv_stream.receive_exactly(4) | ||||
|             except ( | ||||
|                 ValueError, | ||||
|                 ConnectionResetError, | ||||
| 
 | ||||
|                 # not sure entirely why we need this but without it we | ||||
|                 # seem to be getting racy failures here on | ||||
|                 # arbiter/registry name subs.. | ||||
|                 trio.BrokenResourceError, | ||||
| 
 | ||||
|             ) as trans_err: | ||||
| 
 | ||||
|                 loglevel = 'transport' | ||||
|                 match trans_err: | ||||
|                     # case ( | ||||
|                     #     ConnectionResetError() | ||||
|                     # ): | ||||
|                     #     loglevel = 'transport' | ||||
| 
 | ||||
|                     # peer actor (graceful??) TCP EOF but `tricycle` | ||||
|                     # seems to raise a 0-bytes-read? | ||||
|                     case ValueError() if ( | ||||
|                         'unclean EOF' in trans_err.args[0] | ||||
|                     ): | ||||
|                         pass | ||||
| 
 | ||||
|                     # peer actor (task) prolly shutdown quickly due | ||||
|                     # to cancellation | ||||
|                     case trio.BrokenResourceError() if ( | ||||
|                         'Connection reset by peer' in trans_err.args[0] | ||||
|                     ): | ||||
|                         pass | ||||
| 
 | ||||
|                     # unless the disconnect condition falls under "a | ||||
|                     # normal operation breakage" we usualy console warn | ||||
|                     # about it. | ||||
|                     case _: | ||||
|                         loglevel: str = 'warning' | ||||
| 
 | ||||
| 
 | ||||
|                 raise TransportClosed( | ||||
|                     message=( | ||||
|                         f'IPC transport already closed by peer\n' | ||||
|                         f'x)> {type(trans_err)}\n' | ||||
|                         f' |_{self}\n' | ||||
|                     ), | ||||
|                     loglevel=loglevel, | ||||
|                 ) from trans_err | ||||
| 
 | ||||
|             # XXX definitely can happen if transport is closed | ||||
|             # manually by another `trio.lowlevel.Task` in the | ||||
|             # same actor; we use this in some simulated fault | ||||
|             # testing for ex, but generally should never happen | ||||
|             # under normal operation! | ||||
|             # | ||||
|             # NOTE: as such we always re-raise this error from the | ||||
|             #       RPC msg loop! | ||||
|             except trio.ClosedResourceError as closure_err: | ||||
|                 raise TransportClosed( | ||||
|                     message=( | ||||
|                         f'IPC transport already manually closed locally?\n' | ||||
|                         f'x)> {type(closure_err)} \n' | ||||
|                         f' |_{self}\n' | ||||
|                     ), | ||||
|                     loglevel='error', | ||||
|                     raise_on_report=( | ||||
|                         closure_err.args[0] == 'another task closed this fd' | ||||
|                         or | ||||
|                         closure_err.args[0] in ['another task closed this fd'] | ||||
|                     ), | ||||
|                 ) from closure_err | ||||
| 
 | ||||
|             # graceful EOF disconnect | ||||
|             if header == b'': | ||||
|                 raise TransportClosed( | ||||
|                     message=( | ||||
|                         f'IPC transport already gracefully closed\n' | ||||
|                         f')>\n' | ||||
|                         f'|_{self}\n' | ||||
|                     ), | ||||
|                     loglevel='transport', | ||||
|                     # cause=???  # handy or no? | ||||
|                 ) | ||||
| 
 | ||||
|             size: int | ||||
|             size, = struct.unpack("<I", header) | ||||
| 
 | ||||
|             log.transport(f'received header {size}')  # type: ignore | ||||
|             msg_bytes: bytes = await self.recv_stream.receive_exactly(size) | ||||
| 
 | ||||
|             log.transport(f"received {msg_bytes}")  # type: ignore | ||||
|             yield msg_bytes | ||||
| 
 | ||||
|     async def send( | ||||
|         self, | ||||
|         msg: bytes, | ||||
| 
 | ||||
|     ) -> None: | ||||
|         ''' | ||||
|         Send a msgpack encoded py-object-blob-as-msg. | ||||
| 
 | ||||
|         ''' | ||||
|         async with self._send_lock: | ||||
|             size: bytes = struct.pack("<I", len(msg)) | ||||
|             return await self.stream.send_all(size + msg) | ||||
| 
 | ||||
|     async def recv(self) -> Any: | ||||
|         return await self._aiter_pkts.asend(None) | ||||
| 
 | ||||
|     async def drain(self) -> AsyncIterator[dict]: | ||||
|         ''' | ||||
|         Drain the stream's remaining messages sent from | ||||
|         the far end until the connection is closed by | ||||
|         the peer. | ||||
| 
 | ||||
|         ''' | ||||
|         try: | ||||
|             async for msg in self._iter_packets(): | ||||
|                 self.drained.append(msg) | ||||
|         except TransportClosed: | ||||
|             for msg in self.drained: | ||||
|                 yield msg | ||||
| 
 | ||||
|     def __aiter__(self): | ||||
|         return self._aiter_pkts | ||||
| 
 | ||||
| 
 | ||||
| @acm | ||||
| async def attach_to_ringbuf_stream( | ||||
|     token_in: RBToken, | ||||
|     token_out: RBToken, | ||||
|     cleanup_in: bool = True, | ||||
|     cleanup_out: bool = True | ||||
| ): | ||||
|     ''' | ||||
|     Wrap a ringbuf trio.StapledStream in a MsgpackRBStream | ||||
| 
 | ||||
|     ''' | ||||
|     async with attach_to_ringbuf_pair( | ||||
|         token_in, | ||||
|         token_out, | ||||
|         cleanup_in=cleanup_in, | ||||
|         cleanup_out=cleanup_out | ||||
|     ) as stream: | ||||
|         yield MsgpackRBStream(stream) | ||||
|  |  | |||
|  | @ -26,7 +26,6 @@ import struct | |||
| from typing import ( | ||||
|     Any, | ||||
|     Callable, | ||||
|     Type, | ||||
| ) | ||||
| 
 | ||||
| import msgspec | ||||
|  |  | |||
|  | @ -41,10 +41,10 @@ class MsgTransport(Protocol[MsgType]): | |||
| # eventual msg definition/types? | ||||
| # - https://docs.python.org/3/library/typing.html#typing.Protocol | ||||
| 
 | ||||
|     stream: trio.SocketStream | ||||
|     stream: trio.abc.Stream | ||||
|     drained: list[MsgType] | ||||
| 
 | ||||
|     def __init__(self, stream: trio.SocketStream) -> None: | ||||
|     def __init__(self, stream: trio.abc.Stream) -> None: | ||||
|         ... | ||||
| 
 | ||||
|     # XXX: should this instead be called `.sendall()`? | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue