forked from goodboy/tractor
				
			Factor `.started()` validation into `.msg._ops`
Filling out the helper `validate_payload_msg()` staged in a prior commit and adjusting all imports to match. Also add a `raise_mte: bool` flag for potential usage where the caller wants to handle the MTE instance themselves.runtime_to_msgspec
							parent
							
								
									f7fd8278af
								
							
						
					
					
						commit
						6c2efc96dc
					
				| 
						 | 
				
			
			@ -58,9 +58,6 @@ from typing import (
 | 
			
		|||
import warnings
 | 
			
		||||
# ------ - ------
 | 
			
		||||
import trio
 | 
			
		||||
from msgspec import (
 | 
			
		||||
    ValidationError,
 | 
			
		||||
)
 | 
			
		||||
# ------ - ------
 | 
			
		||||
from ._exceptions import (
 | 
			
		||||
    ContextCancelled,
 | 
			
		||||
| 
						 | 
				
			
			@ -78,19 +75,16 @@ from .log import (
 | 
			
		|||
from .msg import (
 | 
			
		||||
    Error,
 | 
			
		||||
    MsgType,
 | 
			
		||||
    MsgCodec,
 | 
			
		||||
    NamespacePath,
 | 
			
		||||
    PayloadT,
 | 
			
		||||
    Started,
 | 
			
		||||
    Stop,
 | 
			
		||||
    Yield,
 | 
			
		||||
    current_codec,
 | 
			
		||||
    pretty_struct,
 | 
			
		||||
    _ops as msgops,
 | 
			
		||||
)
 | 
			
		||||
from ._ipc import (
 | 
			
		||||
    Channel,
 | 
			
		||||
    _mk_msg_type_err,
 | 
			
		||||
)
 | 
			
		||||
from ._streaming import MsgStream
 | 
			
		||||
from ._state import (
 | 
			
		||||
| 
						 | 
				
			
			@ -1657,54 +1651,21 @@ class Context:
 | 
			
		|||
        #
 | 
			
		||||
        __tracebackhide__: bool = hide_tb
 | 
			
		||||
        if validate_pld_spec:
 | 
			
		||||
            # __tracebackhide__: bool = False
 | 
			
		||||
            codec: MsgCodec = current_codec()
 | 
			
		||||
            msg_bytes: bytes = codec.encode(started_msg)
 | 
			
		||||
            try:
 | 
			
		||||
                roundtripped: Started = codec.decode(msg_bytes)
 | 
			
		||||
                # pld: PayloadT = await self.pld_rx.recv_pld(
 | 
			
		||||
                pld: PayloadT = self.pld_rx.dec_msg(
 | 
			
		||||
                    msg=roundtripped,
 | 
			
		||||
                    ipc=self,
 | 
			
		||||
                    expect_msg=Started,
 | 
			
		||||
                    hide_tb=hide_tb,
 | 
			
		||||
                    is_started_send_side=True,
 | 
			
		||||
                )
 | 
			
		||||
                if (
 | 
			
		||||
                    strict_pld_parity
 | 
			
		||||
                    and
 | 
			
		||||
                    pld != value
 | 
			
		||||
                ):
 | 
			
		||||
                    # TODO: make that one a mod func too..
 | 
			
		||||
                    diff = pretty_struct.Struct.__sub__(
 | 
			
		||||
                        roundtripped,
 | 
			
		||||
                        started_msg,
 | 
			
		||||
                    )
 | 
			
		||||
                    complaint: str = (
 | 
			
		||||
                        'Started value does not match after roundtrip?\n\n'
 | 
			
		||||
                        f'{diff}'
 | 
			
		||||
                    )
 | 
			
		||||
                    raise ValidationError(complaint)
 | 
			
		||||
 | 
			
		||||
            # raise any msg type error NO MATTER WHAT!
 | 
			
		||||
            except ValidationError as verr:
 | 
			
		||||
                # always show this src frame in the tb
 | 
			
		||||
                # __tracebackhide__: bool = False
 | 
			
		||||
                raise _mk_msg_type_err(
 | 
			
		||||
                    msg=roundtripped,
 | 
			
		||||
                    codec=codec,
 | 
			
		||||
                    src_validation_error=verr,
 | 
			
		||||
                    verb_header='Trying to send ',
 | 
			
		||||
                    is_invalid_payload=True,
 | 
			
		||||
                ) from verr
 | 
			
		||||
            msgops.validate_payload_msg(
 | 
			
		||||
                pld_msg=started_msg,
 | 
			
		||||
                pld_value=value,
 | 
			
		||||
                ipc=self,
 | 
			
		||||
                strict_pld_parity=strict_pld_parity,
 | 
			
		||||
                hide_tb=hide_tb,
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        # TODO: maybe a flag to by-pass encode op if already done
 | 
			
		||||
        # here in caller?
 | 
			
		||||
        await self.chan.send(started_msg)
 | 
			
		||||
 | 
			
		||||
        # set msg-related internal runtime-state
 | 
			
		||||
        self._started_called = True
 | 
			
		||||
        self._started_msg = started_msg
 | 
			
		||||
        self._started_called: bool = True
 | 
			
		||||
        self._started_msg: Started = started_msg
 | 
			
		||||
        self._started_pld = value
 | 
			
		||||
 | 
			
		||||
    async def _drain_overflows(
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -53,6 +53,8 @@ from tractor._state import current_ipc_ctx
 | 
			
		|||
from ._codec import (
 | 
			
		||||
    mk_dec,
 | 
			
		||||
    MsgDec,
 | 
			
		||||
    MsgCodec,
 | 
			
		||||
    current_codec,
 | 
			
		||||
)
 | 
			
		||||
from .types import (
 | 
			
		||||
    CancelAck,
 | 
			
		||||
| 
						 | 
				
			
			@ -737,9 +739,61 @@ async def drain_to_final_msg(
 | 
			
		|||
    )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# TODO: factor logic from `.Context.started()` for send-side
 | 
			
		||||
# validate raising!
 | 
			
		||||
def validate_payload_msg(
 | 
			
		||||
    msg: Started|Yield|Return,
 | 
			
		||||
    pld_msg: Started|Yield|Return,
 | 
			
		||||
    pld_value: PayloadT,
 | 
			
		||||
    ipc: Context|MsgStream,
 | 
			
		||||
 | 
			
		||||
    raise_mte: bool = True,
 | 
			
		||||
    strict_pld_parity: bool = False,
 | 
			
		||||
    hide_tb: bool = True,
 | 
			
		||||
 | 
			
		||||
) -> MsgTypeError|None:
 | 
			
		||||
    ...
 | 
			
		||||
    '''
 | 
			
		||||
    Validate a `PayloadMsg.pld` value with the current
 | 
			
		||||
    IPC ctx's `PldRx` and raise an appropriate `MsgTypeError`
 | 
			
		||||
    on failure.
 | 
			
		||||
 | 
			
		||||
    '''
 | 
			
		||||
    __tracebackhide__: bool = hide_tb
 | 
			
		||||
    codec: MsgCodec = current_codec()
 | 
			
		||||
    msg_bytes: bytes = codec.encode(pld_msg)
 | 
			
		||||
    try:
 | 
			
		||||
        roundtripped: Started = codec.decode(msg_bytes)
 | 
			
		||||
        ctx: Context = getattr(ipc, 'ctx', ipc)
 | 
			
		||||
        pld: PayloadT = ctx.pld_rx.dec_msg(
 | 
			
		||||
            msg=roundtripped,
 | 
			
		||||
            ipc=ipc,
 | 
			
		||||
            expect_msg=Started,
 | 
			
		||||
            hide_tb=hide_tb,
 | 
			
		||||
            is_started_send_side=True,
 | 
			
		||||
        )
 | 
			
		||||
        if (
 | 
			
		||||
            strict_pld_parity
 | 
			
		||||
            and
 | 
			
		||||
            pld != pld_value
 | 
			
		||||
        ):
 | 
			
		||||
            # TODO: make that one a mod func too..
 | 
			
		||||
            diff = pretty_struct.Struct.__sub__(
 | 
			
		||||
                roundtripped,
 | 
			
		||||
                pld_msg,
 | 
			
		||||
            )
 | 
			
		||||
            complaint: str = (
 | 
			
		||||
                'Started value does not match after roundtrip?\n\n'
 | 
			
		||||
                f'{diff}'
 | 
			
		||||
            )
 | 
			
		||||
            raise ValidationError(complaint)
 | 
			
		||||
 | 
			
		||||
    # raise any msg type error NO MATTER WHAT!
 | 
			
		||||
    except ValidationError as verr:
 | 
			
		||||
        mte: MsgTypeError = _mk_msg_type_err(
 | 
			
		||||
            msg=roundtripped,
 | 
			
		||||
            codec=codec,
 | 
			
		||||
            src_validation_error=verr,
 | 
			
		||||
            verb_header='Trying to send ',
 | 
			
		||||
            is_invalid_payload=True,
 | 
			
		||||
        )
 | 
			
		||||
        if not raise_mte:
 | 
			
		||||
            return mte
 | 
			
		||||
 | 
			
		||||
        raise mte from verr
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue