Factor `.started()` validation into `.msg._ops`

Filling out the helper `validate_payload_msg()` staged in a prior commit
and adjusting all imports to match.

Also add a `raise_mte: bool` flag for potential usage where the caller
wants to handle the MTE instance themselves.
runtime_to_msgspec
Tyler Goodlet 2024-05-28 11:08:27 -04:00
parent f7fd8278af
commit 6c2efc96dc
2 changed files with 67 additions and 52 deletions

View File

@ -58,9 +58,6 @@ from typing import (
import warnings import warnings
# ------ - ------ # ------ - ------
import trio import trio
from msgspec import (
ValidationError,
)
# ------ - ------ # ------ - ------
from ._exceptions import ( from ._exceptions import (
ContextCancelled, ContextCancelled,
@ -78,19 +75,16 @@ from .log import (
from .msg import ( from .msg import (
Error, Error,
MsgType, MsgType,
MsgCodec,
NamespacePath, NamespacePath,
PayloadT, PayloadT,
Started, Started,
Stop, Stop,
Yield, Yield,
current_codec,
pretty_struct, pretty_struct,
_ops as msgops, _ops as msgops,
) )
from ._ipc import ( from ._ipc import (
Channel, Channel,
_mk_msg_type_err,
) )
from ._streaming import MsgStream from ._streaming import MsgStream
from ._state import ( from ._state import (
@ -1657,54 +1651,21 @@ class Context:
# #
__tracebackhide__: bool = hide_tb __tracebackhide__: bool = hide_tb
if validate_pld_spec: if validate_pld_spec:
# __tracebackhide__: bool = False msgops.validate_payload_msg(
codec: MsgCodec = current_codec() pld_msg=started_msg,
msg_bytes: bytes = codec.encode(started_msg) pld_value=value,
try:
roundtripped: Started = codec.decode(msg_bytes)
# pld: PayloadT = await self.pld_rx.recv_pld(
pld: PayloadT = self.pld_rx.dec_msg(
msg=roundtripped,
ipc=self, ipc=self,
expect_msg=Started, strict_pld_parity=strict_pld_parity,
hide_tb=hide_tb, hide_tb=hide_tb,
is_started_send_side=True,
) )
if (
strict_pld_parity
and
pld != value
):
# TODO: make that one a mod func too..
diff = pretty_struct.Struct.__sub__(
roundtripped,
started_msg,
)
complaint: str = (
'Started value does not match after roundtrip?\n\n'
f'{diff}'
)
raise ValidationError(complaint)
# raise any msg type error NO MATTER WHAT!
except ValidationError as verr:
# always show this src frame in the tb
# __tracebackhide__: bool = False
raise _mk_msg_type_err(
msg=roundtripped,
codec=codec,
src_validation_error=verr,
verb_header='Trying to send ',
is_invalid_payload=True,
) from verr
# TODO: maybe a flag to by-pass encode op if already done # TODO: maybe a flag to by-pass encode op if already done
# here in caller? # here in caller?
await self.chan.send(started_msg) await self.chan.send(started_msg)
# set msg-related internal runtime-state # set msg-related internal runtime-state
self._started_called = True self._started_called: bool = True
self._started_msg = started_msg self._started_msg: Started = started_msg
self._started_pld = value self._started_pld = value
async def _drain_overflows( async def _drain_overflows(

View File

@ -53,6 +53,8 @@ from tractor._state import current_ipc_ctx
from ._codec import ( from ._codec import (
mk_dec, mk_dec,
MsgDec, MsgDec,
MsgCodec,
current_codec,
) )
from .types import ( from .types import (
CancelAck, CancelAck,
@ -737,9 +739,61 @@ async def drain_to_final_msg(
) )
# TODO: factor logic from `.Context.started()` for send-side
# validate raising!
def validate_payload_msg( def validate_payload_msg(
msg: Started|Yield|Return, pld_msg: Started|Yield|Return,
pld_value: PayloadT,
ipc: Context|MsgStream,
raise_mte: bool = True,
strict_pld_parity: bool = False,
hide_tb: bool = True,
) -> MsgTypeError|None: ) -> MsgTypeError|None:
... '''
Validate a `PayloadMsg.pld` value with the current
IPC ctx's `PldRx` and raise an appropriate `MsgTypeError`
on failure.
'''
__tracebackhide__: bool = hide_tb
codec: MsgCodec = current_codec()
msg_bytes: bytes = codec.encode(pld_msg)
try:
roundtripped: Started = codec.decode(msg_bytes)
ctx: Context = getattr(ipc, 'ctx', ipc)
pld: PayloadT = ctx.pld_rx.dec_msg(
msg=roundtripped,
ipc=ipc,
expect_msg=Started,
hide_tb=hide_tb,
is_started_send_side=True,
)
if (
strict_pld_parity
and
pld != pld_value
):
# TODO: make that one a mod func too..
diff = pretty_struct.Struct.__sub__(
roundtripped,
pld_msg,
)
complaint: str = (
'Started value does not match after roundtrip?\n\n'
f'{diff}'
)
raise ValidationError(complaint)
# raise any msg type error NO MATTER WHAT!
except ValidationError as verr:
mte: MsgTypeError = _mk_msg_type_err(
msg=roundtripped,
codec=codec,
src_validation_error=verr,
verb_header='Trying to send ',
is_invalid_payload=True,
)
if not raise_mte:
return mte
raise mte from verr