diff --git a/tractor/msg/_ops.py b/tractor/msg/_ops.py index 5f4b9fe8..fbbbecff 100644 --- a/tractor/msg/_ops.py +++ b/tractor/msg/_ops.py @@ -110,33 +110,11 @@ class PldRx(Struct): # TODO: better to bind it here? # _rx_mc: trio.MemoryReceiveChannel _pld_dec: MsgDec - _ctx: Context|None = None - _ipc: Context|MsgStream|None = None @property def pld_dec(self) -> MsgDec: return self._pld_dec - # TODO: a better name? - # -[ ] when would this be used as it avoids needingn to pass the - # ipc prim to every method - @cm - def wraps_ipc( - self, - ipc_prim: Context|MsgStream, - - ) -> PldRx: - ''' - Apply this payload receiver to an IPC primitive type, one - of `Context` or `MsgStream`. - - ''' - self._ipc = ipc_prim - try: - yield self - finally: - self._ipc = None - @cm def limit_plds( self, @@ -169,7 +147,7 @@ class PldRx(Struct): def dec(self) -> msgpack.Decoder: return self._pld_dec.dec - def recv_pld_nowait( + def recv_msg_nowait( self, # TODO: make this `MsgStream` compat as well, see above^ # ipc_prim: Context|MsgStream, @@ -180,7 +158,15 @@ class PldRx(Struct): hide_tb: bool = False, **dec_pld_kwargs, - ) -> Any|Raw: + ) -> tuple[ + MsgType[PayloadT], + PayloadT, + ]: + ''' + Attempt to non-blocking receive a message from the `._rx_chan` and + unwrap it's payload delivering the pair to the caller. + + ''' __tracebackhide__: bool = hide_tb msg: MsgType = ( @@ -189,31 +175,78 @@ class PldRx(Struct): # sync-rx msg from underlying IPC feeder (mem-)chan ipc._rx_chan.receive_nowait() ) - if ( - type(msg) is Return - ): - log.info( - f'Rxed final result msg\n' - f'{msg}\n' - ) - return self.decode_pld( + pld: PayloadT = self.decode_pld( msg, ipc=ipc, expect_msg=expect_msg, hide_tb=hide_tb, **dec_pld_kwargs, ) + return ( + msg, + pld, + ) + + async def recv_msg( + self, + ipc: Context|MsgStream, + expect_msg: MsgType, + + # NOTE: ONLY for handling `Stop`-msgs that arrive during + # a call to `drain_to_final_msg()` above! + passthrough_non_pld_msgs: bool = True, + hide_tb: bool = True, + + **decode_pld_kwargs, + + ) -> tuple[MsgType, PayloadT]: + ''' + Retrieve the next avail IPC msg, decode its payload, and + return the (msg, pld) pair. + + ''' + __tracebackhide__: bool = hide_tb + msg: MsgType = await ipc._rx_chan.receive() + match msg: + case Return()|Error(): + log.runtime( + f'Rxed final outcome msg\n' + f'{msg}\n' + ) + case Stop(): + log.runtime( + f'Rxed stream stopped msg\n' + f'{msg}\n' + ) + if passthrough_non_pld_msgs: + return msg, None + + # TODO: is there some way we can inject the decoded + # payload into an existing output buffer for the original + # msg instance? + pld: PayloadT = self.decode_pld( + msg, + ipc=ipc, + expect_msg=expect_msg, + hide_tb=hide_tb, + + **decode_pld_kwargs, + ) + return ( + msg, + pld, + ) async def recv_pld( self, ipc: Context|MsgStream, - ipc_msg: MsgType|None = None, + ipc_msg: MsgType[PayloadT]|None = None, expect_msg: Type[MsgType]|None = None, hide_tb: bool = True, **dec_pld_kwargs, - ) -> Any|Raw: + ) -> PayloadT: ''' Receive a `MsgType`, then decode and return its `.pld` field. @@ -420,54 +453,6 @@ class PldRx(Struct): __tracebackhide__: bool = False raise - async def recv_msg_w_pld( - self, - ipc: Context|MsgStream, - expect_msg: MsgType, - - # NOTE: generally speaking only for handling `Stop`-msgs that - # arrive during a call to `drain_to_final_msg()` above! - passthrough_non_pld_msgs: bool = True, - hide_tb: bool = True, - **kwargs, - - ) -> tuple[MsgType, PayloadT]: - ''' - Retrieve the next avail IPC msg, decode it's payload, and - return the pair of refs. - - ''' - __tracebackhide__: bool = hide_tb - msg: MsgType = await ipc._rx_chan.receive() - if ( - type(msg) is Return - ): - log.info( - f'Rxed final result msg\n' - f'{msg}\n' - ) - - if passthrough_non_pld_msgs: - match msg: - case Stop(): - return msg, None - - # TODO: is there some way we can inject the decoded - # payload into an existing output buffer for the original - # msg instance? - pld: PayloadT = self.decode_pld( - msg, - ipc=ipc, - expect_msg=expect_msg, - hide_tb=hide_tb, - **kwargs, - ) - # log.runtime( - # f'Delivering payload msg\n' - # f'{msg}\n' - # ) - return msg, pld - @cm def limit_plds( @@ -607,7 +592,7 @@ async def drain_to_final_msg( # receive all msgs, scanning for either a final result # or error; the underlying call should never raise any # remote error directly! - msg, pld = await ctx._pld_rx.recv_msg_w_pld( + msg, pld = await ctx._pld_rx.recv_msg( ipc=ctx, expect_msg=Return, raise_error=False,