forked from goodboy/tractor
				
			Factor `.started()` validation into `.msg._ops`
Filling out the helper `validate_payload_msg()` staged in a prior commit and adjusting all imports to match. Also add a `raise_mte: bool` flag for potential usage where the caller wants to handle the MTE instance themselves.runtime_to_msgspec
							parent
							
								
									f7fd8278af
								
							
						
					
					
						commit
						6c2efc96dc
					
				| 
						 | 
					@ -58,9 +58,6 @@ from typing import (
 | 
				
			||||||
import warnings
 | 
					import warnings
 | 
				
			||||||
# ------ - ------
 | 
					# ------ - ------
 | 
				
			||||||
import trio
 | 
					import trio
 | 
				
			||||||
from msgspec import (
 | 
					 | 
				
			||||||
    ValidationError,
 | 
					 | 
				
			||||||
)
 | 
					 | 
				
			||||||
# ------ - ------
 | 
					# ------ - ------
 | 
				
			||||||
from ._exceptions import (
 | 
					from ._exceptions import (
 | 
				
			||||||
    ContextCancelled,
 | 
					    ContextCancelled,
 | 
				
			||||||
| 
						 | 
					@ -78,19 +75,16 @@ from .log import (
 | 
				
			||||||
from .msg import (
 | 
					from .msg import (
 | 
				
			||||||
    Error,
 | 
					    Error,
 | 
				
			||||||
    MsgType,
 | 
					    MsgType,
 | 
				
			||||||
    MsgCodec,
 | 
					 | 
				
			||||||
    NamespacePath,
 | 
					    NamespacePath,
 | 
				
			||||||
    PayloadT,
 | 
					    PayloadT,
 | 
				
			||||||
    Started,
 | 
					    Started,
 | 
				
			||||||
    Stop,
 | 
					    Stop,
 | 
				
			||||||
    Yield,
 | 
					    Yield,
 | 
				
			||||||
    current_codec,
 | 
					 | 
				
			||||||
    pretty_struct,
 | 
					    pretty_struct,
 | 
				
			||||||
    _ops as msgops,
 | 
					    _ops as msgops,
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
from ._ipc import (
 | 
					from ._ipc import (
 | 
				
			||||||
    Channel,
 | 
					    Channel,
 | 
				
			||||||
    _mk_msg_type_err,
 | 
					 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
from ._streaming import MsgStream
 | 
					from ._streaming import MsgStream
 | 
				
			||||||
from ._state import (
 | 
					from ._state import (
 | 
				
			||||||
| 
						 | 
					@ -1657,54 +1651,21 @@ class Context:
 | 
				
			||||||
        #
 | 
					        #
 | 
				
			||||||
        __tracebackhide__: bool = hide_tb
 | 
					        __tracebackhide__: bool = hide_tb
 | 
				
			||||||
        if validate_pld_spec:
 | 
					        if validate_pld_spec:
 | 
				
			||||||
            # __tracebackhide__: bool = False
 | 
					            msgops.validate_payload_msg(
 | 
				
			||||||
            codec: MsgCodec = current_codec()
 | 
					                pld_msg=started_msg,
 | 
				
			||||||
            msg_bytes: bytes = codec.encode(started_msg)
 | 
					                pld_value=value,
 | 
				
			||||||
            try:
 | 
					 | 
				
			||||||
                roundtripped: Started = codec.decode(msg_bytes)
 | 
					 | 
				
			||||||
                # pld: PayloadT = await self.pld_rx.recv_pld(
 | 
					 | 
				
			||||||
                pld: PayloadT = self.pld_rx.dec_msg(
 | 
					 | 
				
			||||||
                    msg=roundtripped,
 | 
					 | 
				
			||||||
                ipc=self,
 | 
					                ipc=self,
 | 
				
			||||||
                    expect_msg=Started,
 | 
					                strict_pld_parity=strict_pld_parity,
 | 
				
			||||||
                hide_tb=hide_tb,
 | 
					                hide_tb=hide_tb,
 | 
				
			||||||
                    is_started_send_side=True,
 | 
					 | 
				
			||||||
            )
 | 
					            )
 | 
				
			||||||
                if (
 | 
					 | 
				
			||||||
                    strict_pld_parity
 | 
					 | 
				
			||||||
                    and
 | 
					 | 
				
			||||||
                    pld != value
 | 
					 | 
				
			||||||
                ):
 | 
					 | 
				
			||||||
                    # TODO: make that one a mod func too..
 | 
					 | 
				
			||||||
                    diff = pretty_struct.Struct.__sub__(
 | 
					 | 
				
			||||||
                        roundtripped,
 | 
					 | 
				
			||||||
                        started_msg,
 | 
					 | 
				
			||||||
                    )
 | 
					 | 
				
			||||||
                    complaint: str = (
 | 
					 | 
				
			||||||
                        'Started value does not match after roundtrip?\n\n'
 | 
					 | 
				
			||||||
                        f'{diff}'
 | 
					 | 
				
			||||||
                    )
 | 
					 | 
				
			||||||
                    raise ValidationError(complaint)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            # raise any msg type error NO MATTER WHAT!
 | 
					 | 
				
			||||||
            except ValidationError as verr:
 | 
					 | 
				
			||||||
                # always show this src frame in the tb
 | 
					 | 
				
			||||||
                # __tracebackhide__: bool = False
 | 
					 | 
				
			||||||
                raise _mk_msg_type_err(
 | 
					 | 
				
			||||||
                    msg=roundtripped,
 | 
					 | 
				
			||||||
                    codec=codec,
 | 
					 | 
				
			||||||
                    src_validation_error=verr,
 | 
					 | 
				
			||||||
                    verb_header='Trying to send ',
 | 
					 | 
				
			||||||
                    is_invalid_payload=True,
 | 
					 | 
				
			||||||
                ) from verr
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # TODO: maybe a flag to by-pass encode op if already done
 | 
					        # TODO: maybe a flag to by-pass encode op if already done
 | 
				
			||||||
        # here in caller?
 | 
					        # here in caller?
 | 
				
			||||||
        await self.chan.send(started_msg)
 | 
					        await self.chan.send(started_msg)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # set msg-related internal runtime-state
 | 
					        # set msg-related internal runtime-state
 | 
				
			||||||
        self._started_called = True
 | 
					        self._started_called: bool = True
 | 
				
			||||||
        self._started_msg = started_msg
 | 
					        self._started_msg: Started = started_msg
 | 
				
			||||||
        self._started_pld = value
 | 
					        self._started_pld = value
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    async def _drain_overflows(
 | 
					    async def _drain_overflows(
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -53,6 +53,8 @@ from tractor._state import current_ipc_ctx
 | 
				
			||||||
from ._codec import (
 | 
					from ._codec import (
 | 
				
			||||||
    mk_dec,
 | 
					    mk_dec,
 | 
				
			||||||
    MsgDec,
 | 
					    MsgDec,
 | 
				
			||||||
 | 
					    MsgCodec,
 | 
				
			||||||
 | 
					    current_codec,
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
from .types import (
 | 
					from .types import (
 | 
				
			||||||
    CancelAck,
 | 
					    CancelAck,
 | 
				
			||||||
| 
						 | 
					@ -737,9 +739,61 @@ async def drain_to_final_msg(
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# TODO: factor logic from `.Context.started()` for send-side
 | 
					 | 
				
			||||||
# validate raising!
 | 
					 | 
				
			||||||
def validate_payload_msg(
 | 
					def validate_payload_msg(
 | 
				
			||||||
    msg: Started|Yield|Return,
 | 
					    pld_msg: Started|Yield|Return,
 | 
				
			||||||
 | 
					    pld_value: PayloadT,
 | 
				
			||||||
 | 
					    ipc: Context|MsgStream,
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    raise_mte: bool = True,
 | 
				
			||||||
 | 
					    strict_pld_parity: bool = False,
 | 
				
			||||||
 | 
					    hide_tb: bool = True,
 | 
				
			||||||
 | 
					
 | 
				
			||||||
) -> MsgTypeError|None:
 | 
					) -> MsgTypeError|None:
 | 
				
			||||||
    ...
 | 
					    '''
 | 
				
			||||||
 | 
					    Validate a `PayloadMsg.pld` value with the current
 | 
				
			||||||
 | 
					    IPC ctx's `PldRx` and raise an appropriate `MsgTypeError`
 | 
				
			||||||
 | 
					    on failure.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    '''
 | 
				
			||||||
 | 
					    __tracebackhide__: bool = hide_tb
 | 
				
			||||||
 | 
					    codec: MsgCodec = current_codec()
 | 
				
			||||||
 | 
					    msg_bytes: bytes = codec.encode(pld_msg)
 | 
				
			||||||
 | 
					    try:
 | 
				
			||||||
 | 
					        roundtripped: Started = codec.decode(msg_bytes)
 | 
				
			||||||
 | 
					        ctx: Context = getattr(ipc, 'ctx', ipc)
 | 
				
			||||||
 | 
					        pld: PayloadT = ctx.pld_rx.dec_msg(
 | 
				
			||||||
 | 
					            msg=roundtripped,
 | 
				
			||||||
 | 
					            ipc=ipc,
 | 
				
			||||||
 | 
					            expect_msg=Started,
 | 
				
			||||||
 | 
					            hide_tb=hide_tb,
 | 
				
			||||||
 | 
					            is_started_send_side=True,
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        if (
 | 
				
			||||||
 | 
					            strict_pld_parity
 | 
				
			||||||
 | 
					            and
 | 
				
			||||||
 | 
					            pld != pld_value
 | 
				
			||||||
 | 
					        ):
 | 
				
			||||||
 | 
					            # TODO: make that one a mod func too..
 | 
				
			||||||
 | 
					            diff = pretty_struct.Struct.__sub__(
 | 
				
			||||||
 | 
					                roundtripped,
 | 
				
			||||||
 | 
					                pld_msg,
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					            complaint: str = (
 | 
				
			||||||
 | 
					                'Started value does not match after roundtrip?\n\n'
 | 
				
			||||||
 | 
					                f'{diff}'
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					            raise ValidationError(complaint)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # raise any msg type error NO MATTER WHAT!
 | 
				
			||||||
 | 
					    except ValidationError as verr:
 | 
				
			||||||
 | 
					        mte: MsgTypeError = _mk_msg_type_err(
 | 
				
			||||||
 | 
					            msg=roundtripped,
 | 
				
			||||||
 | 
					            codec=codec,
 | 
				
			||||||
 | 
					            src_validation_error=verr,
 | 
				
			||||||
 | 
					            verb_header='Trying to send ',
 | 
				
			||||||
 | 
					            is_invalid_payload=True,
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        if not raise_mte:
 | 
				
			||||||
 | 
					            return mte
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        raise mte from verr
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue