From debe63f4f2e39f0201c75d3c9abf92cf35e853d7 Mon Sep 17 00:00:00 2001
From: Tyler Goodlet <jgbt@protonmail.com>
Date: Wed, 12 Mar 2025 13:49:58 -0400
Subject: [PATCH] Slight `PldRx` rework to simplify

Namely renaming and tweaking the `MsgType` receiving methods,
- `.recv_msg()` from what was `.recv_msg_w_pld()` which both receives
  the IPC msg from the underlying `._rx_chan` and then decodes its
  payload with `.decode_pld()`; it now also log reports on the different
  "stage of SC dialog protocol" msg types via a `match/case`.
- a new `.recv_msg_nowait()` sync equivalent of ^ (*was*
  `.recv_pld_nowait()`) who's use was the source of a recently
  discovered bug where any final `Return.pld` is being
  consumed-n-discarded by by `MsgStream.aclose()` depending on
  ctx/stream teardown race conditions..

Also,
- remove all the "instance persistent" ipc-ctx attrs, specifically the
  optional `_ipc`, `_ctx` and the `.wraps_ipc()` cm, since none of them
  were ever really needed/used; all methods which require
  a `Context/MsgStream` are explicitly always passed.
- update a buncha typing namely to use the more generic-styled
  `PayloadT` over `Any` and obviously `MsgType[PayloadT]`.
---
 tractor/msg/_ops.py | 151 ++++++++++++++++++++------------------------
 1 file changed, 68 insertions(+), 83 deletions(-)

diff --git a/tractor/msg/_ops.py b/tractor/msg/_ops.py
index 5f4b9fe8..fbbbecff 100644
--- a/tractor/msg/_ops.py
+++ b/tractor/msg/_ops.py
@@ -110,33 +110,11 @@ class PldRx(Struct):
     # TODO: better to bind it here?
     # _rx_mc: trio.MemoryReceiveChannel
     _pld_dec: MsgDec
-    _ctx: Context|None = None
-    _ipc: Context|MsgStream|None = None
 
     @property
     def pld_dec(self) -> MsgDec:
         return self._pld_dec
 
-    # TODO: a better name?
-    # -[ ] when would this be used as it avoids needingn to pass the
-    #   ipc prim to every method
-    @cm
-    def wraps_ipc(
-        self,
-        ipc_prim: Context|MsgStream,
-
-    ) -> PldRx:
-        '''
-        Apply this payload receiver to an IPC primitive type, one
-        of `Context` or `MsgStream`.
-
-        '''
-        self._ipc = ipc_prim
-        try:
-            yield self
-        finally:
-            self._ipc = None
-
     @cm
     def limit_plds(
         self,
@@ -169,7 +147,7 @@ class PldRx(Struct):
     def dec(self) -> msgpack.Decoder:
         return self._pld_dec.dec
 
-    def recv_pld_nowait(
+    def recv_msg_nowait(
         self,
         # TODO: make this `MsgStream` compat as well, see above^
         # ipc_prim: Context|MsgStream,
@@ -180,7 +158,15 @@ class PldRx(Struct):
         hide_tb: bool = False,
         **dec_pld_kwargs,
 
-    ) -> Any|Raw:
+    ) -> tuple[
+        MsgType[PayloadT],
+        PayloadT,
+    ]:
+        '''
+        Attempt to non-blocking receive a message from the `._rx_chan` and
+        unwrap it's payload delivering the pair to the caller.
+
+        '''
         __tracebackhide__: bool = hide_tb
 
         msg: MsgType = (
@@ -189,31 +175,78 @@ class PldRx(Struct):
             # sync-rx msg from underlying IPC feeder (mem-)chan
             ipc._rx_chan.receive_nowait()
         )
-        if (
-            type(msg) is Return
-        ):
-            log.info(
-                f'Rxed final result msg\n'
-                f'{msg}\n'
-            )
-        return self.decode_pld(
+        pld: PayloadT = self.decode_pld(
             msg,
             ipc=ipc,
             expect_msg=expect_msg,
             hide_tb=hide_tb,
             **dec_pld_kwargs,
         )
+        return (
+            msg,
+            pld,
+        )
+
+    async def recv_msg(
+        self,
+        ipc: Context|MsgStream,
+        expect_msg: MsgType,
+
+        # NOTE: ONLY for handling `Stop`-msgs that arrive during
+        # a call to `drain_to_final_msg()` above!
+        passthrough_non_pld_msgs: bool = True,
+        hide_tb: bool = True,
+
+        **decode_pld_kwargs,
+
+    ) -> tuple[MsgType, PayloadT]:
+        '''
+        Retrieve the next avail IPC msg, decode its payload, and
+        return the (msg, pld) pair.
+
+        '''
+        __tracebackhide__: bool = hide_tb
+        msg: MsgType = await ipc._rx_chan.receive()
+        match msg:
+            case Return()|Error():
+                log.runtime(
+                    f'Rxed final outcome msg\n'
+                    f'{msg}\n'
+                )
+            case Stop():
+                log.runtime(
+                    f'Rxed stream stopped msg\n'
+                    f'{msg}\n'
+                )
+                if passthrough_non_pld_msgs:
+                    return msg, None
+
+        # TODO: is there some way we can inject the decoded
+        # payload into an existing output buffer for the original
+        # msg instance?
+        pld: PayloadT = self.decode_pld(
+            msg,
+            ipc=ipc,
+            expect_msg=expect_msg,
+            hide_tb=hide_tb,
+
+            **decode_pld_kwargs,
+        )
+        return (
+            msg,
+            pld,
+        )
 
     async def recv_pld(
         self,
         ipc: Context|MsgStream,
-        ipc_msg: MsgType|None = None,
+        ipc_msg: MsgType[PayloadT]|None = None,
         expect_msg: Type[MsgType]|None = None,
         hide_tb: bool = True,
 
         **dec_pld_kwargs,
 
-    ) -> Any|Raw:
+    ) -> PayloadT:
         '''
         Receive a `MsgType`, then decode and return its `.pld` field.
 
@@ -420,54 +453,6 @@ class PldRx(Struct):
             __tracebackhide__: bool = False
             raise
 
-    async def recv_msg_w_pld(
-        self,
-        ipc: Context|MsgStream,
-        expect_msg: MsgType,
-
-        # NOTE: generally speaking only for handling `Stop`-msgs that
-        # arrive during a call to `drain_to_final_msg()` above!
-        passthrough_non_pld_msgs: bool = True,
-        hide_tb: bool = True,
-        **kwargs,
-
-    ) -> tuple[MsgType, PayloadT]:
-        '''
-        Retrieve the next avail IPC msg, decode it's payload, and
-        return the pair of refs.
-
-        '''
-        __tracebackhide__: bool = hide_tb
-        msg: MsgType = await ipc._rx_chan.receive()
-        if (
-            type(msg) is Return
-        ):
-            log.info(
-                f'Rxed final result msg\n'
-                f'{msg}\n'
-            )
-
-        if passthrough_non_pld_msgs:
-            match msg:
-                case Stop():
-                    return msg, None
-
-        # TODO: is there some way we can inject the decoded
-        # payload into an existing output buffer for the original
-        # msg instance?
-        pld: PayloadT = self.decode_pld(
-            msg,
-            ipc=ipc,
-            expect_msg=expect_msg,
-            hide_tb=hide_tb,
-            **kwargs,
-        )
-        # log.runtime(
-        #     f'Delivering payload msg\n'
-        #     f'{msg}\n'
-        # )
-        return msg, pld
-
 
 @cm
 def limit_plds(
@@ -607,7 +592,7 @@ async def drain_to_final_msg(
             # receive all msgs, scanning for either a final result
             # or error; the underlying call should never raise any
             # remote error directly!
-            msg, pld = await ctx._pld_rx.recv_msg_w_pld(
+            msg, pld = await ctx._pld_rx.recv_msg(
                 ipc=ctx,
                 expect_msg=Return,
                 raise_error=False,