Factor `.started()` validation into `.msg._ops`
Filling out the helper `validate_payload_msg()` staged in a prior commit and adjusting all imports to match. Also add a `raise_mte: bool` flag for potential usage where the caller wants to handle the MTE instance themselves.runtime_to_msgspec
parent
f7fd8278af
commit
6c2efc96dc
|
@ -58,9 +58,6 @@ from typing import (
|
|||
import warnings
|
||||
# ------ - ------
|
||||
import trio
|
||||
from msgspec import (
|
||||
ValidationError,
|
||||
)
|
||||
# ------ - ------
|
||||
from ._exceptions import (
|
||||
ContextCancelled,
|
||||
|
@ -78,19 +75,16 @@ from .log import (
|
|||
from .msg import (
|
||||
Error,
|
||||
MsgType,
|
||||
MsgCodec,
|
||||
NamespacePath,
|
||||
PayloadT,
|
||||
Started,
|
||||
Stop,
|
||||
Yield,
|
||||
current_codec,
|
||||
pretty_struct,
|
||||
_ops as msgops,
|
||||
)
|
||||
from ._ipc import (
|
||||
Channel,
|
||||
_mk_msg_type_err,
|
||||
)
|
||||
from ._streaming import MsgStream
|
||||
from ._state import (
|
||||
|
@ -1657,54 +1651,21 @@ class Context:
|
|||
#
|
||||
__tracebackhide__: bool = hide_tb
|
||||
if validate_pld_spec:
|
||||
# __tracebackhide__: bool = False
|
||||
codec: MsgCodec = current_codec()
|
||||
msg_bytes: bytes = codec.encode(started_msg)
|
||||
try:
|
||||
roundtripped: Started = codec.decode(msg_bytes)
|
||||
# pld: PayloadT = await self.pld_rx.recv_pld(
|
||||
pld: PayloadT = self.pld_rx.dec_msg(
|
||||
msg=roundtripped,
|
||||
ipc=self,
|
||||
expect_msg=Started,
|
||||
hide_tb=hide_tb,
|
||||
is_started_send_side=True,
|
||||
)
|
||||
if (
|
||||
strict_pld_parity
|
||||
and
|
||||
pld != value
|
||||
):
|
||||
# TODO: make that one a mod func too..
|
||||
diff = pretty_struct.Struct.__sub__(
|
||||
roundtripped,
|
||||
started_msg,
|
||||
)
|
||||
complaint: str = (
|
||||
'Started value does not match after roundtrip?\n\n'
|
||||
f'{diff}'
|
||||
)
|
||||
raise ValidationError(complaint)
|
||||
|
||||
# raise any msg type error NO MATTER WHAT!
|
||||
except ValidationError as verr:
|
||||
# always show this src frame in the tb
|
||||
# __tracebackhide__: bool = False
|
||||
raise _mk_msg_type_err(
|
||||
msg=roundtripped,
|
||||
codec=codec,
|
||||
src_validation_error=verr,
|
||||
verb_header='Trying to send ',
|
||||
is_invalid_payload=True,
|
||||
) from verr
|
||||
msgops.validate_payload_msg(
|
||||
pld_msg=started_msg,
|
||||
pld_value=value,
|
||||
ipc=self,
|
||||
strict_pld_parity=strict_pld_parity,
|
||||
hide_tb=hide_tb,
|
||||
)
|
||||
|
||||
# TODO: maybe a flag to by-pass encode op if already done
|
||||
# here in caller?
|
||||
await self.chan.send(started_msg)
|
||||
|
||||
# set msg-related internal runtime-state
|
||||
self._started_called = True
|
||||
self._started_msg = started_msg
|
||||
self._started_called: bool = True
|
||||
self._started_msg: Started = started_msg
|
||||
self._started_pld = value
|
||||
|
||||
async def _drain_overflows(
|
||||
|
|
|
@ -53,6 +53,8 @@ from tractor._state import current_ipc_ctx
|
|||
from ._codec import (
|
||||
mk_dec,
|
||||
MsgDec,
|
||||
MsgCodec,
|
||||
current_codec,
|
||||
)
|
||||
from .types import (
|
||||
CancelAck,
|
||||
|
@ -737,9 +739,61 @@ async def drain_to_final_msg(
|
|||
)
|
||||
|
||||
|
||||
# TODO: factor logic from `.Context.started()` for send-side
|
||||
# validate raising!
|
||||
def validate_payload_msg(
|
||||
msg: Started|Yield|Return,
|
||||
pld_msg: Started|Yield|Return,
|
||||
pld_value: PayloadT,
|
||||
ipc: Context|MsgStream,
|
||||
|
||||
raise_mte: bool = True,
|
||||
strict_pld_parity: bool = False,
|
||||
hide_tb: bool = True,
|
||||
|
||||
) -> MsgTypeError|None:
|
||||
...
|
||||
'''
|
||||
Validate a `PayloadMsg.pld` value with the current
|
||||
IPC ctx's `PldRx` and raise an appropriate `MsgTypeError`
|
||||
on failure.
|
||||
|
||||
'''
|
||||
__tracebackhide__: bool = hide_tb
|
||||
codec: MsgCodec = current_codec()
|
||||
msg_bytes: bytes = codec.encode(pld_msg)
|
||||
try:
|
||||
roundtripped: Started = codec.decode(msg_bytes)
|
||||
ctx: Context = getattr(ipc, 'ctx', ipc)
|
||||
pld: PayloadT = ctx.pld_rx.dec_msg(
|
||||
msg=roundtripped,
|
||||
ipc=ipc,
|
||||
expect_msg=Started,
|
||||
hide_tb=hide_tb,
|
||||
is_started_send_side=True,
|
||||
)
|
||||
if (
|
||||
strict_pld_parity
|
||||
and
|
||||
pld != pld_value
|
||||
):
|
||||
# TODO: make that one a mod func too..
|
||||
diff = pretty_struct.Struct.__sub__(
|
||||
roundtripped,
|
||||
pld_msg,
|
||||
)
|
||||
complaint: str = (
|
||||
'Started value does not match after roundtrip?\n\n'
|
||||
f'{diff}'
|
||||
)
|
||||
raise ValidationError(complaint)
|
||||
|
||||
# raise any msg type error NO MATTER WHAT!
|
||||
except ValidationError as verr:
|
||||
mte: MsgTypeError = _mk_msg_type_err(
|
||||
msg=roundtripped,
|
||||
codec=codec,
|
||||
src_validation_error=verr,
|
||||
verb_header='Trying to send ',
|
||||
is_invalid_payload=True,
|
||||
)
|
||||
if not raise_mte:
|
||||
return mte
|
||||
|
||||
raise mte from verr
|
||||
|
|
Loading…
Reference in New Issue