diff --git a/tractor/msg/_ops.py b/tractor/msg/_ops.py index e78b79a..5a9ab46 100644 --- a/tractor/msg/_ops.py +++ b/tractor/msg/_ops.py @@ -25,12 +25,12 @@ from contextlib import ( # asynccontextmanager as acm, contextmanager as cm, ) -from pprint import pformat +from contextvars import ContextVar from typing import ( Any, Type, TYPE_CHECKING, - # Union, + Union, ) # ------ - ------ from msgspec import ( @@ -63,7 +63,7 @@ from .types import ( Started, Stop, Yield, - # pretty_struct, + pretty_struct, ) @@ -75,6 +75,9 @@ if TYPE_CHECKING: log = get_logger(__name__) +_def_any_pldec: MsgDec = mk_dec() + + class PldRx(Struct): ''' A "msg payload receiver". @@ -101,10 +104,13 @@ class PldRx(Struct): ''' # TODO: better to bind it here? # _rx_mc: trio.MemoryReceiveChannel - _msgdec: MsgDec = mk_dec(spec=Any) - + _pldec: MsgDec _ipc: Context|MsgStream|None = None + @property + def pld_dec(self) -> MsgDec: + return self._pldec + @cm def apply_to_ipc( self, @@ -122,9 +128,29 @@ class PldRx(Struct): finally: self._ipc = None + @cm + def limit_plds( + self, + spec: Union[Type[Struct]], + + ) -> MsgDec: + ''' + Type-limit the loadable msg payloads via an applied + `MsgDec` given an input spec, revert to prior decoder on + exit. + + ''' + orig_dec: MsgDec = self._pldec + limit_dec: MsgDec = mk_dec(spec=spec) + try: + self._pldec = limit_dec + yield limit_dec + finally: + self._pldec = orig_dec + @property def dec(self) -> msgpack.Decoder: - return self._msgdec.dec + return self._pldec.dec def recv_pld_nowait( self, @@ -182,7 +208,7 @@ class PldRx(Struct): self, msg: MsgType, ctx: Context, - expect_msg: Type[MsgType]|None = None, + expect_msg: Type[MsgType]|None, ) -> PayloadT|Raw: ''' @@ -199,11 +225,11 @@ class PldRx(Struct): |Return(pld=pld) # termination phase ): try: - pld: PayloadT = self._msgdec.decode(pld) + pld: PayloadT = self._pldec.decode(pld) log.runtime( - 'Decode msg payload\n\n' - f'{msg}\n\n' - f'{pld}\n' + 'Decoded msg payload\n\n' + f'{msg}\n' + f'|_pld={pld!r}' ) return pld @@ -237,9 +263,42 @@ class PldRx(Struct): case Error(): src_err = MessagingError( - 'IPC dialog termination by msg' + 'IPC ctx dialog terminated without `Return`-ing a result' ) + case Stop(cid=cid): + message: str = ( + f'{ctx.side!r}-side of ctx received stream-`Stop` from ' + f'{ctx.peer_side!r} peer ?\n' + f'|_cid: {cid}\n\n' + + f'{pretty_struct.pformat(msg)}\n' + ) + if ctx._stream is None: + explain: str = ( + f'BUT, no `MsgStream` (was) open(ed) on this ' + f'{ctx.side!r}-side of the IPC ctx?\n' + f'Maybe check your code for streaming phase race conditions?\n' + ) + log.warning( + message + + + explain + ) + # let caller decide what to do when only one + # side opened a stream, don't raise. + return msg + + else: + explain: str = ( + 'Received a `Stop` when it should NEVER be possible!?!?\n' + ) + # TODO: this is constructed inside + # `_raise_from_unexpected_msg()` but maybe we + # should pass it in? + # src_err = trio.EndOfChannel(explain) + src_err = None + case _: src_err = InternalError( 'Unknown IPC msg ??\n\n' @@ -259,6 +318,7 @@ class PldRx(Struct): async def recv_msg_w_pld( self, ipc: Context|MsgStream, + expect_msg: MsgType, ) -> tuple[MsgType, PayloadT]: ''' @@ -274,10 +334,75 @@ class PldRx(Struct): pld: PayloadT = self.dec_msg( msg, ctx=ipc, + expect_msg=expect_msg, ) return msg, pld +# Always maintain a task-context-global `PldRx` +_def_pld_rx: PldRx = PldRx( + _pldec=_def_any_pldec, +) +_ctxvar_PldRx: ContextVar[PldRx] = ContextVar( + 'pld_rx', + default=_def_pld_rx, +) + + +def current_pldrx() -> PldRx: + ''' + Return the current `trio.Task.context`'s msg-payload + receiver, the post IPC but pre-app code `MsgType.pld` + filter. + + Modification of the current payload spec via `limit_plds()` + allows an application to contextually filter typed IPC msg + content delivered via wire transport. + + ''' + return _ctxvar_PldRx.get() + + +@cm +def limit_plds( + spec: Union[Type[Struct]], + **kwargs, + +) -> MsgDec: + ''' + Apply a `MsgCodec` that will natively decode the SC-msg set's + `Msg.pld: Union[Type[Struct]]` payload fields using + tagged-unions of `msgspec.Struct`s from the `payload_types` + for all IPC contexts in use by the current `trio.Task`. + + ''' + __tracebackhide__: bool = True + try: + # sanity on orig settings + orig_pldrx: PldRx = current_pldrx() + orig_pldec: MsgDec = orig_pldrx.pld_dec + + with orig_pldrx.limit_plds( + spec=spec, + **kwargs, + ) as pldec: + log.info( + 'Applying payload-decoder\n\n' + f'{pldec}\n' + ) + yield pldec + finally: + log.info( + 'Reverted to previous payload-decoder\n\n' + f'{orig_pldec}\n' + ) + assert ( + (pldrx := current_pldrx()) is orig_pldrx + and + pldrx.pld_dec is orig_pldec + ) + + async def drain_to_final_msg( ctx: Context, @@ -368,7 +493,10 @@ async def drain_to_final_msg( # pray to the `trio` gawds that we're corrent with this # msg: dict = await ctx._rx_chan.receive() - msg, pld = await ctx._pld_rx.recv_msg_w_pld(ipc=ctx) + msg, pld = await ctx._pld_rx.recv_msg_w_pld( + ipc=ctx, + expect_msg=Return, + ) # NOTE: we get here if the far end was # `ContextCancelled` in 2 cases: @@ -399,7 +527,7 @@ async def drain_to_final_msg( ctx._result: Any = pld log.runtime( 'Context delivered final draining msg:\n' - f'{pformat(msg)}' + f'{pretty_struct.pformat(msg)}' ) # XXX: only close the rx mem chan AFTER # a final result is retreived. @@ -435,7 +563,7 @@ async def drain_to_final_msg( f'=> {ctx._task}\n' f' |_{ctx._stream}\n\n' - f'{pformat(msg)}\n' + f'{pretty_struct.pformat(msg)}\n' ) return ( return_msg, @@ -452,7 +580,7 @@ async def drain_to_final_msg( f'=> {ctx._task}\n' f' |_{ctx._stream}\n\n' - f'{pformat(msg)}\n' + f'{pretty_struct.pformat(msg)}\n' ) continue @@ -467,7 +595,7 @@ async def drain_to_final_msg( pre_result_drained.append(msg) log.cancel( 'Remote stream terminated due to "stop" msg:\n\n' - f'{pformat(msg)}\n' + f'{pretty_struct.pformat(msg)}\n' ) continue