From 6c2efc96dc102b8348ca6035890db8be2bcaccb7 Mon Sep 17 00:00:00 2001 From: Tyler Goodlet Date: Tue, 28 May 2024 11:08:27 -0400 Subject: [PATCH] Factor `.started()` validation into `.msg._ops` Filling out the helper `validate_payload_msg()` staged in a prior commit and adjusting all imports to match. Also add a `raise_mte: bool` flag for potential usage where the caller wants to handle the MTE instance themselves. --- tractor/_context.py | 57 +++++++---------------------------------- tractor/msg/_ops.py | 62 ++++++++++++++++++++++++++++++++++++++++++--- 2 files changed, 67 insertions(+), 52 deletions(-) diff --git a/tractor/_context.py b/tractor/_context.py index 42271b0..e973092 100644 --- a/tractor/_context.py +++ b/tractor/_context.py @@ -58,9 +58,6 @@ from typing import ( import warnings # ------ - ------ import trio -from msgspec import ( - ValidationError, -) # ------ - ------ from ._exceptions import ( ContextCancelled, @@ -78,19 +75,16 @@ from .log import ( from .msg import ( Error, MsgType, - MsgCodec, NamespacePath, PayloadT, Started, Stop, Yield, - current_codec, pretty_struct, _ops as msgops, ) from ._ipc import ( Channel, - _mk_msg_type_err, ) from ._streaming import MsgStream from ._state import ( @@ -1657,54 +1651,21 @@ class Context: # __tracebackhide__: bool = hide_tb if validate_pld_spec: - # __tracebackhide__: bool = False - codec: MsgCodec = current_codec() - msg_bytes: bytes = codec.encode(started_msg) - try: - roundtripped: Started = codec.decode(msg_bytes) - # pld: PayloadT = await self.pld_rx.recv_pld( - pld: PayloadT = self.pld_rx.dec_msg( - msg=roundtripped, - ipc=self, - expect_msg=Started, - hide_tb=hide_tb, - is_started_send_side=True, - ) - if ( - strict_pld_parity - and - pld != value - ): - # TODO: make that one a mod func too.. - diff = pretty_struct.Struct.__sub__( - roundtripped, - started_msg, - ) - complaint: str = ( - 'Started value does not match after roundtrip?\n\n' - f'{diff}' - ) - raise ValidationError(complaint) - - # raise any msg type error NO MATTER WHAT! - except ValidationError as verr: - # always show this src frame in the tb - # __tracebackhide__: bool = False - raise _mk_msg_type_err( - msg=roundtripped, - codec=codec, - src_validation_error=verr, - verb_header='Trying to send ', - is_invalid_payload=True, - ) from verr + msgops.validate_payload_msg( + pld_msg=started_msg, + pld_value=value, + ipc=self, + strict_pld_parity=strict_pld_parity, + hide_tb=hide_tb, + ) # TODO: maybe a flag to by-pass encode op if already done # here in caller? await self.chan.send(started_msg) # set msg-related internal runtime-state - self._started_called = True - self._started_msg = started_msg + self._started_called: bool = True + self._started_msg: Started = started_msg self._started_pld = value async def _drain_overflows( diff --git a/tractor/msg/_ops.py b/tractor/msg/_ops.py index 6faf78e..e22d39f 100644 --- a/tractor/msg/_ops.py +++ b/tractor/msg/_ops.py @@ -53,6 +53,8 @@ from tractor._state import current_ipc_ctx from ._codec import ( mk_dec, MsgDec, + MsgCodec, + current_codec, ) from .types import ( CancelAck, @@ -737,9 +739,61 @@ async def drain_to_final_msg( ) -# TODO: factor logic from `.Context.started()` for send-side -# validate raising! def validate_payload_msg( - msg: Started|Yield|Return, + pld_msg: Started|Yield|Return, + pld_value: PayloadT, + ipc: Context|MsgStream, + + raise_mte: bool = True, + strict_pld_parity: bool = False, + hide_tb: bool = True, + ) -> MsgTypeError|None: - ... + ''' + Validate a `PayloadMsg.pld` value with the current + IPC ctx's `PldRx` and raise an appropriate `MsgTypeError` + on failure. + + ''' + __tracebackhide__: bool = hide_tb + codec: MsgCodec = current_codec() + msg_bytes: bytes = codec.encode(pld_msg) + try: + roundtripped: Started = codec.decode(msg_bytes) + ctx: Context = getattr(ipc, 'ctx', ipc) + pld: PayloadT = ctx.pld_rx.dec_msg( + msg=roundtripped, + ipc=ipc, + expect_msg=Started, + hide_tb=hide_tb, + is_started_send_side=True, + ) + if ( + strict_pld_parity + and + pld != pld_value + ): + # TODO: make that one a mod func too.. + diff = pretty_struct.Struct.__sub__( + roundtripped, + pld_msg, + ) + complaint: str = ( + 'Started value does not match after roundtrip?\n\n' + f'{diff}' + ) + raise ValidationError(complaint) + + # raise any msg type error NO MATTER WHAT! + except ValidationError as verr: + mte: MsgTypeError = _mk_msg_type_err( + msg=roundtripped, + codec=codec, + src_validation_error=verr, + verb_header='Trying to send ', + is_invalid_payload=True, + ) + if not raise_mte: + return mte + + raise mte from verr