Factor `.started()` validation into `.msg._ops`
Filling out the helper `validate_payload_msg()` staged in a prior commit and adjusting all imports to match. Also add a `raise_mte: bool` flag for potential usage where the caller wants to handle the MTE instance themselves.runtime_to_msgspec
parent
f7fd8278af
commit
6c2efc96dc
|
@ -58,9 +58,6 @@ from typing import (
|
||||||
import warnings
|
import warnings
|
||||||
# ------ - ------
|
# ------ - ------
|
||||||
import trio
|
import trio
|
||||||
from msgspec import (
|
|
||||||
ValidationError,
|
|
||||||
)
|
|
||||||
# ------ - ------
|
# ------ - ------
|
||||||
from ._exceptions import (
|
from ._exceptions import (
|
||||||
ContextCancelled,
|
ContextCancelled,
|
||||||
|
@ -78,19 +75,16 @@ from .log import (
|
||||||
from .msg import (
|
from .msg import (
|
||||||
Error,
|
Error,
|
||||||
MsgType,
|
MsgType,
|
||||||
MsgCodec,
|
|
||||||
NamespacePath,
|
NamespacePath,
|
||||||
PayloadT,
|
PayloadT,
|
||||||
Started,
|
Started,
|
||||||
Stop,
|
Stop,
|
||||||
Yield,
|
Yield,
|
||||||
current_codec,
|
|
||||||
pretty_struct,
|
pretty_struct,
|
||||||
_ops as msgops,
|
_ops as msgops,
|
||||||
)
|
)
|
||||||
from ._ipc import (
|
from ._ipc import (
|
||||||
Channel,
|
Channel,
|
||||||
_mk_msg_type_err,
|
|
||||||
)
|
)
|
||||||
from ._streaming import MsgStream
|
from ._streaming import MsgStream
|
||||||
from ._state import (
|
from ._state import (
|
||||||
|
@ -1657,54 +1651,21 @@ class Context:
|
||||||
#
|
#
|
||||||
__tracebackhide__: bool = hide_tb
|
__tracebackhide__: bool = hide_tb
|
||||||
if validate_pld_spec:
|
if validate_pld_spec:
|
||||||
# __tracebackhide__: bool = False
|
msgops.validate_payload_msg(
|
||||||
codec: MsgCodec = current_codec()
|
pld_msg=started_msg,
|
||||||
msg_bytes: bytes = codec.encode(started_msg)
|
pld_value=value,
|
||||||
try:
|
ipc=self,
|
||||||
roundtripped: Started = codec.decode(msg_bytes)
|
strict_pld_parity=strict_pld_parity,
|
||||||
# pld: PayloadT = await self.pld_rx.recv_pld(
|
hide_tb=hide_tb,
|
||||||
pld: PayloadT = self.pld_rx.dec_msg(
|
)
|
||||||
msg=roundtripped,
|
|
||||||
ipc=self,
|
|
||||||
expect_msg=Started,
|
|
||||||
hide_tb=hide_tb,
|
|
||||||
is_started_send_side=True,
|
|
||||||
)
|
|
||||||
if (
|
|
||||||
strict_pld_parity
|
|
||||||
and
|
|
||||||
pld != value
|
|
||||||
):
|
|
||||||
# TODO: make that one a mod func too..
|
|
||||||
diff = pretty_struct.Struct.__sub__(
|
|
||||||
roundtripped,
|
|
||||||
started_msg,
|
|
||||||
)
|
|
||||||
complaint: str = (
|
|
||||||
'Started value does not match after roundtrip?\n\n'
|
|
||||||
f'{diff}'
|
|
||||||
)
|
|
||||||
raise ValidationError(complaint)
|
|
||||||
|
|
||||||
# raise any msg type error NO MATTER WHAT!
|
|
||||||
except ValidationError as verr:
|
|
||||||
# always show this src frame in the tb
|
|
||||||
# __tracebackhide__: bool = False
|
|
||||||
raise _mk_msg_type_err(
|
|
||||||
msg=roundtripped,
|
|
||||||
codec=codec,
|
|
||||||
src_validation_error=verr,
|
|
||||||
verb_header='Trying to send ',
|
|
||||||
is_invalid_payload=True,
|
|
||||||
) from verr
|
|
||||||
|
|
||||||
# TODO: maybe a flag to by-pass encode op if already done
|
# TODO: maybe a flag to by-pass encode op if already done
|
||||||
# here in caller?
|
# here in caller?
|
||||||
await self.chan.send(started_msg)
|
await self.chan.send(started_msg)
|
||||||
|
|
||||||
# set msg-related internal runtime-state
|
# set msg-related internal runtime-state
|
||||||
self._started_called = True
|
self._started_called: bool = True
|
||||||
self._started_msg = started_msg
|
self._started_msg: Started = started_msg
|
||||||
self._started_pld = value
|
self._started_pld = value
|
||||||
|
|
||||||
async def _drain_overflows(
|
async def _drain_overflows(
|
||||||
|
|
|
@ -53,6 +53,8 @@ from tractor._state import current_ipc_ctx
|
||||||
from ._codec import (
|
from ._codec import (
|
||||||
mk_dec,
|
mk_dec,
|
||||||
MsgDec,
|
MsgDec,
|
||||||
|
MsgCodec,
|
||||||
|
current_codec,
|
||||||
)
|
)
|
||||||
from .types import (
|
from .types import (
|
||||||
CancelAck,
|
CancelAck,
|
||||||
|
@ -737,9 +739,61 @@ async def drain_to_final_msg(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# TODO: factor logic from `.Context.started()` for send-side
|
|
||||||
# validate raising!
|
|
||||||
def validate_payload_msg(
|
def validate_payload_msg(
|
||||||
msg: Started|Yield|Return,
|
pld_msg: Started|Yield|Return,
|
||||||
|
pld_value: PayloadT,
|
||||||
|
ipc: Context|MsgStream,
|
||||||
|
|
||||||
|
raise_mte: bool = True,
|
||||||
|
strict_pld_parity: bool = False,
|
||||||
|
hide_tb: bool = True,
|
||||||
|
|
||||||
) -> MsgTypeError|None:
|
) -> MsgTypeError|None:
|
||||||
...
|
'''
|
||||||
|
Validate a `PayloadMsg.pld` value with the current
|
||||||
|
IPC ctx's `PldRx` and raise an appropriate `MsgTypeError`
|
||||||
|
on failure.
|
||||||
|
|
||||||
|
'''
|
||||||
|
__tracebackhide__: bool = hide_tb
|
||||||
|
codec: MsgCodec = current_codec()
|
||||||
|
msg_bytes: bytes = codec.encode(pld_msg)
|
||||||
|
try:
|
||||||
|
roundtripped: Started = codec.decode(msg_bytes)
|
||||||
|
ctx: Context = getattr(ipc, 'ctx', ipc)
|
||||||
|
pld: PayloadT = ctx.pld_rx.dec_msg(
|
||||||
|
msg=roundtripped,
|
||||||
|
ipc=ipc,
|
||||||
|
expect_msg=Started,
|
||||||
|
hide_tb=hide_tb,
|
||||||
|
is_started_send_side=True,
|
||||||
|
)
|
||||||
|
if (
|
||||||
|
strict_pld_parity
|
||||||
|
and
|
||||||
|
pld != pld_value
|
||||||
|
):
|
||||||
|
# TODO: make that one a mod func too..
|
||||||
|
diff = pretty_struct.Struct.__sub__(
|
||||||
|
roundtripped,
|
||||||
|
pld_msg,
|
||||||
|
)
|
||||||
|
complaint: str = (
|
||||||
|
'Started value does not match after roundtrip?\n\n'
|
||||||
|
f'{diff}'
|
||||||
|
)
|
||||||
|
raise ValidationError(complaint)
|
||||||
|
|
||||||
|
# raise any msg type error NO MATTER WHAT!
|
||||||
|
except ValidationError as verr:
|
||||||
|
mte: MsgTypeError = _mk_msg_type_err(
|
||||||
|
msg=roundtripped,
|
||||||
|
codec=codec,
|
||||||
|
src_validation_error=verr,
|
||||||
|
verb_header='Trying to send ',
|
||||||
|
is_invalid_payload=True,
|
||||||
|
)
|
||||||
|
if not raise_mte:
|
||||||
|
return mte
|
||||||
|
|
||||||
|
raise mte from verr
|
||||||
|
|
Loading…
Reference in New Issue