Add `MsgTypeError` "bad msg" capture

Such that if caught by user code and/or the runtime we can introspect
the original msg which caused the type error. Previously this was kinda
half-baked with a `.msg_dict` which was delivered from an `Any`-decode
of the shuttle msg in `_mk_msg_type_err()` but now this more explicitly
refines the API and supports both `PayloadMsg`-instance or the msg-dict
style injection:
- allow passing either of `bad_msg: PayloadMsg|None` or
  `bad_msg_as_dict: dict|None` to `MsgTypeError.from_decode()`.
- expose public props for both ^ whilst dropping prior `.msgdict`.
- rework `.from_decode()` to explicitly accept `**extra_msgdata: dict`
  |_ only overriding it from any `bad_msg_as_dict` if the keys are found in
    `_ipcmsg_keys`, **except** for `_bad_msg` when `bad_msg` is passed.
  |_ drop `.ipc_msg` passthrough.
  |_ drop `msgdict` input.
- adjust `.cid` to only pull from the `.bad_msg` if set.

Related fixes/adjustments:
- `pack_from_raise()` should pull `boxed_type_str` from
  `boxed_type.__name__`, not the `type()` of it.. also add a
  `hide_tb: bool` flag.
- don't include `_msg_dict` and `_bad_msg` in the `_body_fields` set.
- allow more granular boxed traceback-str controls:
  |_ allow passing a `tb_str: str` explicitly in which case we use it
    verbatim and presume caller knows what they're doing.
  |_ when not provided, use the more explicit
    `traceback.format_exception(exc)` since the error instance is
    a required input (we still fail back to the old `.format_exc()` call
    if for some reason the caller passes `None`; but that should be
    a bug right?).
  |_ if a `tb: TracebackType` and a `tb_str` is passed, concat them.
- in `RemoteActorError.pformat()` don't indent the `._message` part used
  for the `body` when `with_type_header == False`.
- update `_mk_msg_type_err()` to use `bad_msg`/`bad_msg_as_dict`
  appropriately and drop passing `ipc_msg`.
runtime_to_msgspec
Tyler Goodlet 2024-05-27 22:36:05 -04:00
parent 42ba855d1b
commit eee4c61b51
1 changed files with 148 additions and 72 deletions

View File

@ -22,6 +22,9 @@ from __future__ import annotations
import builtins import builtins
import importlib import importlib
from pprint import pformat from pprint import pformat
from types import (
TracebackType,
)
from typing import ( from typing import (
Any, Any,
Callable, Callable,
@ -92,26 +95,30 @@ _ipcmsg_keys: list[str] = [
fi.name fi.name
for fi, k, v for fi, k, v
in iter_fields(Error) in iter_fields(Error)
] ]
_body_fields: list[str] = list( _body_fields: list[str] = list(
set(_ipcmsg_keys) set(_ipcmsg_keys)
# NOTE: don't show fields that either don't provide # XXX NOTE: DON'T-SHOW-FIELDS
# any extra useful info or that are already shown # - don't provide any extra useful info or,
# as part of `.__repr__()` output. # - are already shown as part of `.__repr__()` or,
# - are sub-type specific.
- { - {
'src_type_str', 'src_type_str',
'boxed_type_str', 'boxed_type_str',
'tb_str', 'tb_str',
'relay_path', 'relay_path',
'_msg_dict',
'cid', 'cid',
# since only ctxc should show it but `Error` does # only ctxc should show it but `Error` does
# have it as an optional field. # have it as an optional field.
'canceller', 'canceller',
# only for MTEs and generally only used
# when devving/testing/debugging.
'_msg_dict',
'_bad_msg',
} }
) )
@ -146,6 +153,7 @@ def pack_from_raise(
|MsgTypeError |MsgTypeError
), ),
cid: str, cid: str,
hide_tb: bool = True,
**rae_fields, **rae_fields,
@ -156,7 +164,7 @@ def pack_from_raise(
`Error`-msg using `pack_error()` to extract the tb info. `Error`-msg using `pack_error()` to extract the tb info.
''' '''
__tracebackhide__: bool = True __tracebackhide__: bool = hide_tb
try: try:
raise local_err raise local_err
except type(local_err) as local_err: except type(local_err) as local_err:
@ -231,7 +239,8 @@ class RemoteActorError(Exception):
if ( if (
extra_msgdata extra_msgdata
and ipc_msg and
ipc_msg
): ):
# XXX mutate the orig msg directly from # XXX mutate the orig msg directly from
# manually provided input params. # manually provided input params.
@ -261,17 +270,16 @@ class RemoteActorError(Exception):
# either by customizing `ContextCancelled.__init__()` or # either by customizing `ContextCancelled.__init__()` or
# through a special factor func? # through a special factor func?
elif boxed_type: elif boxed_type:
boxed_type_str: str = type(boxed_type).__name__ boxed_type_str: str = boxed_type.__name__
if ( if (
ipc_msg ipc_msg
and not self._ipc_msg.boxed_type_str and
self._ipc_msg.boxed_type_str != boxed_type_str
): ):
self._ipc_msg.boxed_type_str = boxed_type_str self._ipc_msg.boxed_type_str = boxed_type_str
assert self.boxed_type_str == self._ipc_msg.boxed_type_str assert self.boxed_type_str == self._ipc_msg.boxed_type_str
else: # ensure any roundtripping evals to the input value
self._extra_msgdata['boxed_type_str'] = boxed_type_str
assert self.boxed_type is boxed_type assert self.boxed_type is boxed_type
@property @property
@ -309,7 +317,9 @@ class RemoteActorError(Exception):
if self._ipc_msg if self._ipc_msg
else {} else {}
) )
return self._extra_msgdata | msgdata return {
k: v for k, v in self._extra_msgdata.items()
} | msgdata
@property @property
def src_type_str(self) -> str: def src_type_str(self) -> str:
@ -502,6 +512,8 @@ class RemoteActorError(Exception):
''' '''
header: str = '' header: str = ''
body: str = ''
if with_type_header: if with_type_header:
header: str = f'<{type(self).__name__}(\n' header: str = f'<{type(self).__name__}(\n'
@ -525,24 +537,22 @@ class RemoteActorError(Exception):
) )
if not with_type_header: if not with_type_header:
body = '\n' + body body = '\n' + body
else:
first: str = ''
message: str = self._message
elif message := self._message:
# split off the first line so it isn't indented # split off the first line so it isn't indented
# the same like the "boxed content". # the same like the "boxed content".
if not with_type_header: if not with_type_header:
lines: list[str] = message.splitlines() lines: list[str] = message.splitlines()
first = lines[0] first: str = lines[0]
message = ''.join(lines[1:]) message: str = message.removeprefix(first)
else:
first: str = ''
body: str = ( body: str = (
first first
+ +
textwrap.indent( message
message,
prefix=' ',
)
+ +
'\n' '\n'
) )
@ -708,52 +718,72 @@ class MsgTypeError(
] ]
@property @property
def msg_dict(self) -> dict[str, Any]: def bad_msg(self) -> PayloadMsg|None:
''' '''
If the underlying IPC `MsgType` was received from a remote Ref to the the original invalid IPC shuttle msg which failed
actor but was unable to be decoded to a native to decode thus providing for the reason for this error.
`Yield`|`Started`|`Return` struct, the interchange backend
native format decoder can be used to stash a `dict`
version for introspection by the invalidating RPC task.
''' '''
return self.msgdata.get('_msg_dict') if (
(_bad_msg := self.msgdata.get('_bad_msg'))
and
isinstance(_bad_msg, PayloadMsg)
):
return _bad_msg
@property elif bad_msg_dict := self.bad_msg_as_dict:
def expected_msg(self) -> MsgType|None:
'''
Attempt to construct what would have been the original
`MsgType`-with-payload subtype (i.e. an instance from the set
of msgs in `.msg.types._payload_msgs`) which failed
validation.
'''
if msg_dict := self.msg_dict.copy():
return msgtypes.from_dict_msg( return msgtypes.from_dict_msg(
dict_msg=msg_dict, dict_msg=bad_msg_dict.copy(),
# use_pretty=True, # use_pretty=True,
# ^-TODO-^ would luv to use this BUT then the # ^-TODO-^ would luv to use this BUT then the
# `field_prefix` in `pformat_boxed_tb()` cucks it # `field_prefix` in `pformat_boxed_tb()` cucks it
# all up.. XD # all up.. XD
) )
return None return None
@property
def bad_msg_as_dict(self) -> dict[str, Any]:
'''
If the underlying IPC `MsgType` was received from a remote
actor but was unable to be decoded to a native `PayloadMsg`
(`Yield`|`Started`|`Return`) struct, the interchange backend
native format decoder can be used to stash a `dict` version
for introspection by the invalidating RPC task.
Optionally when this error is constructed from
`.from_decode()` the caller can attempt to construct what
would have been the original `MsgType`-with-payload subtype
(i.e. an instance from the set of msgs in
`.msg.types._payload_msgs`) which failed validation.
'''
return self.msgdata.get('_bad_msg_as_dict')
@property @property
def expected_msg_type(self) -> Type[MsgType]|None: def expected_msg_type(self) -> Type[MsgType]|None:
return type(self.expected_msg) return type(self.bad_msg)
@property @property
def cid(self) -> str: def cid(self) -> str:
# pre-packed using `.from_decode()` constructor # pull from required `.bad_msg` ref (or src dict)
return self.msgdata.get('cid') if bad_msg := self.bad_msg:
return bad_msg.cid
return self.msgdata['cid']
@classmethod @classmethod
def from_decode( def from_decode(
cls, cls,
message: str, message: str,
ipc_msg: PayloadMsg|None = None, bad_msg: PayloadMsg|None = None,
msgdict: dict|None = None, bad_msg_as_dict: dict|None = None,
# if provided, expand and pack all RAE compat fields into the
# `._extra_msgdata` auxillary data `dict` internal to
# `RemoteActorError`.
**extra_msgdata,
) -> MsgTypeError: ) -> MsgTypeError:
''' '''
@ -763,25 +793,44 @@ class MsgTypeError(
(which is normally the caller of this). (which is normally the caller of this).
''' '''
# if provided, expand and pack all RAE compat fields into the if bad_msg_as_dict:
# `._extra_msgdata` auxillary data `dict` internal to
# `RemoteActorError`.
extra_msgdata: dict = {}
if msgdict:
extra_msgdata: dict = {
k: v
for k, v in msgdict.items()
if k in _ipcmsg_keys
}
# NOTE: original "vanilla decode" of the msg-bytes # NOTE: original "vanilla decode" of the msg-bytes
# is placed inside a value readable from # is placed inside a value readable from
# `.msgdata['_msg_dict']` # `.msgdata['_msg_dict']`
extra_msgdata['_msg_dict'] = msgdict extra_msgdata['_bad_msg_as_dict'] = bad_msg_as_dict
# scrape out any underlying fields from the
# msg that failed validation.
for k, v in bad_msg_as_dict.items():
if (
# always skip a duplicate entry
# if already provided as an arg
k == '_bad_msg' and bad_msg
or
# skip anything not in the default msg-field set.
k not in _ipcmsg_keys
# k not in _body_fields
):
continue
extra_msgdata[k] = v
elif bad_msg:
if not isinstance(bad_msg, PayloadMsg):
raise TypeError(
'The provided `bad_msg` is not a `PayloadMsg` type?\n\n'
f'{bad_msg}'
)
extra_msgdata['_bad_msg'] = bad_msg
extra_msgdata['cid'] = bad_msg.cid
if 'cid' not in extra_msgdata:
import pdbp; pdbp.set_trace()
return cls( return cls(
message=message, message=message,
boxed_type=cls, boxed_type=cls,
ipc_msg=ipc_msg,
**extra_msgdata, **extra_msgdata,
) )
@ -836,9 +885,10 @@ class MessagingError(Exception):
def pack_error( def pack_error(
exc: BaseException|RemoteActorError, exc: BaseException|RemoteActorError,
tb: str|None = None,
cid: str|None = None, cid: str|None = None,
src_uid: tuple[str, str]|None = None, src_uid: tuple[str, str]|None = None,
tb: TracebackType|None = None,
tb_str: str = '',
) -> Error: ) -> Error:
''' '''
@ -848,10 +898,28 @@ def pack_error(
the receiver side using `unpack_error()` below. the receiver side using `unpack_error()` below.
''' '''
if tb: if not tb_str:
tb_str = ''.join(traceback.format_tb(tb)) tb_str: str = (
''.join(traceback.format_exception(exc))
# TODO: can we remove this is `exc` is required?
or
# NOTE: this is just a shorthand for the "last error" as
# provided by `sys.exeception()`, see:
# - https://docs.python.org/3/library/traceback.html#traceback.print_exc
# - https://docs.python.org/3/library/traceback.html#traceback.format_exc
traceback.format_exc()
)
else: else:
tb_str = traceback.format_exc() if tb_str[-2:] != '\n':
tb_str += '\n'
# when caller provides a tb instance (say pulled from some other
# src error's `.__traceback__`) we use that as the "boxed"
# tb-string instead.
if tb:
# https://docs.python.org/3/library/traceback.html#traceback.format_list
tb_str: str = ''.join(traceback.format_tb(tb)) + tb_str
error_msg: dict[ # for IPC error_msg: dict[ # for IPC
str, str,
@ -1115,7 +1183,7 @@ def _mk_msg_type_err(
src_validation_error: ValidationError|None = None, src_validation_error: ValidationError|None = None,
src_type_error: TypeError|None = None, src_type_error: TypeError|None = None,
is_invalid_payload: bool = False, is_invalid_payload: bool = False,
src_err_msg: Error|None = None, # src_err_msg: Error|None = None,
**mte_kwargs, **mte_kwargs,
@ -1164,10 +1232,10 @@ def _mk_msg_type_err(
'|_ https://jcristharif.com/msgspec/extending.html#defining-a-custom-extension-messagepack-only\n' '|_ https://jcristharif.com/msgspec/extending.html#defining-a-custom-extension-messagepack-only\n'
) )
msgtyperr = MsgTypeError( msgtyperr = MsgTypeError(
message=message, message=message,
ipc_msg=msg, ipc_msg=msg,
bad_msg=msg,
) )
# ya, might be `None` # ya, might be `None`
msgtyperr.__cause__ = src_type_error msgtyperr.__cause__ = src_type_error
@ -1175,6 +1243,9 @@ def _mk_msg_type_err(
# `Channel.recv()` case # `Channel.recv()` case
else: else:
msg_dict: dict|None = None
bad_msg: PayloadMsg|None = None
if is_invalid_payload: if is_invalid_payload:
msg_type: str = type(msg) msg_type: str = type(msg)
any_pld: Any = msgpack.decode(msg.pld) any_pld: Any = msgpack.decode(msg.pld)
@ -1186,19 +1257,20 @@ def _mk_msg_type_err(
# f' |_pld: {codec.pld_spec_str}\n'# != {any_pld!r}\n' # f' |_pld: {codec.pld_spec_str}\n'# != {any_pld!r}\n'
# f')>\n\n' # f')>\n\n'
) )
# src_err_msg = msg
bad_msg = msg
# TODO: should we just decode the msg to a dict despite # TODO: should we just decode the msg to a dict despite
# only the payload being wrong? # only the payload being wrong?
# -[ ] maybe the better design is to break this construct # -[ ] maybe the better design is to break this construct
# logic into a separate explicit helper raiser-func? # logic into a separate explicit helper raiser-func?
msg_dict = None
else: else:
msg: bytes
# decode the msg-bytes using the std msgpack # decode the msg-bytes using the std msgpack
# interchange-prot (i.e. without any # interchange-prot (i.e. without any `msgspec.Struct`
# `msgspec.Struct` handling) so that we can # handling) so that we can determine what
# determine what `.msg.types.Msg` is the culprit # `.msg.types.PayloadMsg` is the culprit by reporting the
# by reporting the received value. # received value.
msg: bytes
msg_dict: dict = msgpack.decode(msg) msg_dict: dict = msgpack.decode(msg)
msg_type_name: str = msg_dict['msg_type'] msg_type_name: str = msg_dict['msg_type']
msg_type = getattr(msgtypes, msg_type_name) msg_type = getattr(msgtypes, msg_type_name)
@ -1235,9 +1307,13 @@ def _mk_msg_type_err(
if verb_header: if verb_header:
message = f'{verb_header} ' + message message = f'{verb_header} ' + message
# if not isinstance(bad_msg, PayloadMsg):
# import pdbp; pdbp.set_trace()
msgtyperr = MsgTypeError.from_decode( msgtyperr = MsgTypeError.from_decode(
message=message, message=message,
msgdict=msg_dict, bad_msg=bad_msg,
bad_msg_as_dict=msg_dict,
# NOTE: for the send-side `.started()` pld-validate # NOTE: for the send-side `.started()` pld-validate
# case we actually set the `._ipc_msg` AFTER we return # case we actually set the `._ipc_msg` AFTER we return
@ -1245,7 +1321,7 @@ def _mk_msg_type_err(
# want to emulate the `Error` from the mte we build here # want to emulate the `Error` from the mte we build here
# Bo # Bo
# so by default in that case this is set to `None` # so by default in that case this is set to `None`
ipc_msg=src_err_msg, # ipc_msg=src_err_msg,
) )
msgtyperr.__cause__ = src_validation_error msgtyperr.__cause__ = src_validation_error
return msgtyperr return msgtyperr