From 78476c9c28d1e3ac8da4867352499cf31a65d875 Mon Sep 17 00:00:00 2001
From: Tyler Goodlet <jgbt@protonmail.com>
Date: Mon, 27 May 2024 22:36:05 -0400
Subject: [PATCH] Add `MsgTypeError` "bad msg" capture

Such that if caught by user code and/or the runtime we can introspect
the original msg which caused the type error. Previously this was kinda
half-baked with a `.msg_dict` which was delivered from an `Any`-decode
of the shuttle msg in `_mk_msg_type_err()` but now this more explicitly
refines the API and supports both `PayloadMsg`-instance or the msg-dict
style injection:
- allow passing either of `bad_msg: PayloadMsg|None` or
  `bad_msg_as_dict: dict|None` to `MsgTypeError.from_decode()`.
- expose public props for both ^ whilst dropping prior `.msgdict`.
- rework `.from_decode()` to explicitly accept `**extra_msgdata: dict`
  |_ only overriding it from any `bad_msg_as_dict` if the keys are found in
    `_ipcmsg_keys`, **except** for `_bad_msg` when `bad_msg` is passed.
  |_ drop `.ipc_msg` passthrough.
  |_ drop `msgdict` input.
- adjust `.cid` to only pull from the `.bad_msg` if set.

Related fixes/adjustments:
- `pack_from_raise()` should pull `boxed_type_str` from
  `boxed_type.__name__`, not the `type()` of it.. also add a
  `hide_tb: bool` flag.
- don't include `_msg_dict` and `_bad_msg` in the `_body_fields` set.
- allow more granular boxed traceback-str controls:
  |_ allow passing a `tb_str: str` explicitly in which case we use it
    verbatim and presume caller knows what they're doing.
  |_ when not provided, use the more explicit
    `traceback.format_exception(exc)` since the error instance is
    a required input (we still fail back to the old `.format_exc()` call
    if for some reason the caller passes `None`; but that should be
    a bug right?).
  |_ if a `tb: TracebackType` and a `tb_str` is passed, concat them.
- in `RemoteActorError.pformat()` don't indent the `._message` part used
  for the `body` when `with_type_header == False`.
- update `_mk_msg_type_err()` to use `bad_msg`/`bad_msg_as_dict`
  appropriately and drop passing `ipc_msg`.
---
 tractor/_exceptions.py | 220 +++++++++++++++++++++++++++--------------
 1 file changed, 148 insertions(+), 72 deletions(-)

diff --git a/tractor/_exceptions.py b/tractor/_exceptions.py
index 9a94bbdb..85957356 100644
--- a/tractor/_exceptions.py
+++ b/tractor/_exceptions.py
@@ -22,6 +22,9 @@ from __future__ import annotations
 import builtins
 import importlib
 from pprint import pformat
+from types import (
+    TracebackType,
+)
 from typing import (
     Any,
     Callable,
@@ -92,26 +95,30 @@ _ipcmsg_keys: list[str] = [
     fi.name
     for fi, k, v
     in iter_fields(Error)
-
 ]
 
 _body_fields: list[str] = list(
     set(_ipcmsg_keys)
 
-    # NOTE: don't show fields that either don't provide
-    # any extra useful info or that are already shown
-    # as part of `.__repr__()` output.
+    # XXX NOTE: DON'T-SHOW-FIELDS
+    # - don't provide any extra useful info or,
+    # - are already shown as part of `.__repr__()` or,
+    # - are sub-type specific.
     - {
         'src_type_str',
         'boxed_type_str',
         'tb_str',
         'relay_path',
-        '_msg_dict',
         'cid',
 
-        # since only ctxc should show it but `Error` does
+        # only ctxc should show it but `Error` does
         # have it as an optional field.
         'canceller',
+
+        # only for MTEs and generally only used
+        # when devving/testing/debugging.
+        '_msg_dict',
+        '_bad_msg',
     }
 )
 
@@ -146,6 +153,7 @@ def pack_from_raise(
         |MsgTypeError
     ),
     cid: str,
+    hide_tb: bool = True,
 
     **rae_fields,
 
@@ -156,7 +164,7 @@ def pack_from_raise(
     `Error`-msg using `pack_error()` to extract the tb info.
 
     '''
-    __tracebackhide__: bool = True
+    __tracebackhide__: bool = hide_tb
     try:
         raise local_err
     except type(local_err) as local_err:
@@ -231,7 +239,8 @@ class RemoteActorError(Exception):
 
         if (
             extra_msgdata
-            and ipc_msg
+            and
+            ipc_msg
         ):
             # XXX mutate the orig msg directly from
             # manually provided input params.
@@ -261,17 +270,16 @@ class RemoteActorError(Exception):
         # either by customizing `ContextCancelled.__init__()` or
         # through a special factor func?
         elif boxed_type:
-            boxed_type_str: str = type(boxed_type).__name__
+            boxed_type_str: str = boxed_type.__name__
             if (
                 ipc_msg
-                and not self._ipc_msg.boxed_type_str
+                and
+                self._ipc_msg.boxed_type_str != boxed_type_str
             ):
                 self._ipc_msg.boxed_type_str = boxed_type_str
                 assert self.boxed_type_str == self._ipc_msg.boxed_type_str
 
-            else:
-                self._extra_msgdata['boxed_type_str'] = boxed_type_str
-
+            # ensure any roundtripping evals to the input value
             assert self.boxed_type is boxed_type
 
     @property
@@ -309,7 +317,9 @@ class RemoteActorError(Exception):
             if self._ipc_msg
             else {}
         )
-        return self._extra_msgdata | msgdata
+        return {
+            k: v for k, v in self._extra_msgdata.items()
+        } | msgdata
 
     @property
     def src_type_str(self) -> str:
@@ -502,6 +512,8 @@ class RemoteActorError(Exception):
 
         '''
         header: str = ''
+        body: str = ''
+
         if with_type_header:
             header: str = f'<{type(self).__name__}(\n'
 
@@ -525,24 +537,22 @@ class RemoteActorError(Exception):
             )
             if not with_type_header:
                 body = '\n' + body
-        else:
-            first: str = ''
-            message: str = self._message
 
+        elif message := self._message:
             # split off the first line so it isn't indented
             # the same like the "boxed content".
             if not with_type_header:
                 lines: list[str] = message.splitlines()
-                first = lines[0]
-                message = ''.join(lines[1:])
+                first: str = lines[0]
+                message: str = message.removeprefix(first)
+
+            else:
+                first: str = ''
 
             body: str = (
                 first
                 +
-                textwrap.indent(
-                    message,
-                    prefix='  ',
-                )
+                message
                 +
                 '\n'
             )
@@ -708,52 +718,72 @@ class MsgTypeError(
     ]
 
     @property
-    def msg_dict(self) -> dict[str, Any]:
+    def bad_msg(self) -> PayloadMsg|None:
         '''
-        If the underlying IPC `MsgType` was received from a remote
-        actor but was unable to be decoded to a native
-        `Yield`|`Started`|`Return` struct, the interchange backend
-        native format decoder can be used to stash a `dict`
-        version for introspection by the invalidating RPC task.
+        Ref to the the original invalid IPC shuttle msg which failed
+        to decode thus providing for the reason for this error.
 
         '''
-        return self.msgdata.get('_msg_dict')
+        if (
+            (_bad_msg := self.msgdata.get('_bad_msg'))
+            and
+            isinstance(_bad_msg, PayloadMsg)
+        ):
+            return _bad_msg
 
-    @property
-    def expected_msg(self) -> MsgType|None:
-        '''
-        Attempt to construct what would have been the original
-        `MsgType`-with-payload subtype (i.e. an instance from the set
-        of msgs in `.msg.types._payload_msgs`) which failed
-        validation.
-
-        '''
-        if msg_dict := self.msg_dict.copy():
+        elif bad_msg_dict := self.bad_msg_as_dict:
             return msgtypes.from_dict_msg(
-                dict_msg=msg_dict,
+                dict_msg=bad_msg_dict.copy(),
                 # use_pretty=True,
                 # ^-TODO-^ would luv to use this BUT then the
                 # `field_prefix` in `pformat_boxed_tb()` cucks it
                 # all up.. XD
             )
+
         return None
 
+    @property
+    def bad_msg_as_dict(self) -> dict[str, Any]:
+        '''
+        If the underlying IPC `MsgType` was received from a remote
+        actor but was unable to be decoded to a native `PayloadMsg`
+        (`Yield`|`Started`|`Return`) struct, the interchange backend
+        native format decoder can be used to stash a `dict` version
+        for introspection by the invalidating RPC task.
+
+        Optionally when this error is constructed from
+        `.from_decode()` the caller can attempt to construct what
+        would have been the original `MsgType`-with-payload subtype
+        (i.e. an instance from the set of msgs in
+        `.msg.types._payload_msgs`) which failed validation.
+
+        '''
+        return self.msgdata.get('_bad_msg_as_dict')
+
     @property
     def expected_msg_type(self) -> Type[MsgType]|None:
-        return type(self.expected_msg)
+        return type(self.bad_msg)
 
     @property
     def cid(self) -> str:
-        # pre-packed using `.from_decode()` constructor
-        return self.msgdata.get('cid')
+        # pull from required `.bad_msg` ref (or src dict)
+        if bad_msg := self.bad_msg:
+            return bad_msg.cid
+
+        return self.msgdata['cid']
 
     @classmethod
     def from_decode(
         cls,
         message: str,
 
-        ipc_msg: PayloadMsg|None = None,
-        msgdict: dict|None = None,
+        bad_msg: PayloadMsg|None = None,
+        bad_msg_as_dict: dict|None = None,
+
+        # if provided, expand and pack all RAE compat fields into the
+        # `._extra_msgdata` auxillary data `dict` internal to
+        # `RemoteActorError`.
+        **extra_msgdata,
 
     ) -> MsgTypeError:
         '''
@@ -763,25 +793,44 @@ class MsgTypeError(
         (which is normally the caller of this).
 
         '''
-        # if provided, expand and pack all RAE compat fields into the
-        # `._extra_msgdata` auxillary data `dict` internal to
-        # `RemoteActorError`.
-        extra_msgdata: dict = {}
-        if msgdict:
-            extra_msgdata: dict = {
-                k: v
-                for k, v in msgdict.items()
-                if k in _ipcmsg_keys
-            }
+        if bad_msg_as_dict:
             # NOTE: original "vanilla decode" of the msg-bytes
             # is placed inside a value readable from
             # `.msgdata['_msg_dict']`
-            extra_msgdata['_msg_dict'] = msgdict
+            extra_msgdata['_bad_msg_as_dict'] = bad_msg_as_dict
+
+            # scrape out any underlying fields from the
+            # msg that failed validation.
+            for k, v in bad_msg_as_dict.items():
+                if (
+                    # always skip a duplicate entry
+                    # if already provided as an arg
+                    k == '_bad_msg' and bad_msg
+                    or
+                    # skip anything not in the default msg-field set.
+                    k not in _ipcmsg_keys
+                    # k not in _body_fields
+                ):
+                    continue
+
+                extra_msgdata[k] = v
+
+
+        elif bad_msg:
+            if not isinstance(bad_msg, PayloadMsg):
+                raise TypeError(
+                    'The provided `bad_msg` is not a `PayloadMsg` type?\n\n'
+                    f'{bad_msg}'
+                )
+            extra_msgdata['_bad_msg'] = bad_msg
+            extra_msgdata['cid'] = bad_msg.cid
+
+        if 'cid' not in extra_msgdata:
+            import pdbp; pdbp.set_trace()
 
         return cls(
             message=message,
             boxed_type=cls,
-            ipc_msg=ipc_msg,
             **extra_msgdata,
         )
 
@@ -836,9 +885,10 @@ class MessagingError(Exception):
 def pack_error(
     exc: BaseException|RemoteActorError,
 
-    tb: str|None = None,
     cid: str|None = None,
     src_uid: tuple[str, str]|None = None,
+    tb: TracebackType|None = None,
+    tb_str: str = '',
 
 ) -> Error:
     '''
@@ -848,10 +898,28 @@ def pack_error(
     the receiver side using `unpack_error()` below.
 
     '''
-    if tb:
-        tb_str = ''.join(traceback.format_tb(tb))
+    if not tb_str:
+        tb_str: str = (
+            ''.join(traceback.format_exception(exc))
+
+            # TODO: can we remove this is `exc` is required?
+            or
+            # NOTE: this is just a shorthand for the "last error" as
+            # provided by `sys.exeception()`, see:
+            # - https://docs.python.org/3/library/traceback.html#traceback.print_exc
+            # - https://docs.python.org/3/library/traceback.html#traceback.format_exc
+            traceback.format_exc()
+        )
     else:
-        tb_str = traceback.format_exc()
+        if tb_str[-2:] != '\n':
+            tb_str += '\n'
+
+    # when caller provides a tb instance (say pulled from some other
+    # src error's `.__traceback__`) we use that as the "boxed"
+    # tb-string instead.
+    if tb:
+        # https://docs.python.org/3/library/traceback.html#traceback.format_list
+        tb_str: str = ''.join(traceback.format_tb(tb)) + tb_str
 
     error_msg: dict[  # for IPC
         str,
@@ -1115,7 +1183,7 @@ def _mk_msg_type_err(
     src_validation_error: ValidationError|None = None,
     src_type_error: TypeError|None = None,
     is_invalid_payload: bool = False,
-    src_err_msg: Error|None = None,
+    # src_err_msg: Error|None = None,
 
     **mte_kwargs,
 
@@ -1164,10 +1232,10 @@ def _mk_msg_type_err(
                     '|_ https://jcristharif.com/msgspec/extending.html#defining-a-custom-extension-messagepack-only\n'
                 )
 
-
         msgtyperr = MsgTypeError(
             message=message,
             ipc_msg=msg,
+            bad_msg=msg,
         )
         # ya, might be `None`
         msgtyperr.__cause__ = src_type_error
@@ -1175,6 +1243,9 @@ def _mk_msg_type_err(
 
     # `Channel.recv()` case
     else:
+        msg_dict: dict|None = None
+        bad_msg: PayloadMsg|None = None
+
         if is_invalid_payload:
             msg_type: str = type(msg)
             any_pld: Any = msgpack.decode(msg.pld)
@@ -1186,19 +1257,20 @@ def _mk_msg_type_err(
                 # f' |_pld: {codec.pld_spec_str}\n'# != {any_pld!r}\n'
                 # f')>\n\n'
             )
+            # src_err_msg = msg
+            bad_msg = msg
             # TODO: should we just decode the msg to a dict despite
             # only the payload being wrong?
             # -[ ] maybe the better design is to break this construct
             #   logic into a separate explicit helper raiser-func?
-            msg_dict = None
 
         else:
-            msg: bytes
             # decode the msg-bytes using the std msgpack
-            # interchange-prot (i.e. without any
-            # `msgspec.Struct` handling) so that we can
-            # determine what `.msg.types.Msg` is the culprit
-            # by reporting the received value.
+            # interchange-prot (i.e. without any `msgspec.Struct`
+            # handling) so that we can determine what
+            # `.msg.types.PayloadMsg` is the culprit by reporting the
+            # received value.
+            msg: bytes
             msg_dict: dict = msgpack.decode(msg)
             msg_type_name: str = msg_dict['msg_type']
             msg_type = getattr(msgtypes, msg_type_name)
@@ -1235,9 +1307,13 @@ def _mk_msg_type_err(
         if verb_header:
             message = f'{verb_header} ' + message
 
+        # if not isinstance(bad_msg, PayloadMsg):
+        #     import pdbp; pdbp.set_trace()
+
         msgtyperr = MsgTypeError.from_decode(
             message=message,
-            msgdict=msg_dict,
+            bad_msg=bad_msg,
+            bad_msg_as_dict=msg_dict,
 
             # NOTE: for the send-side `.started()` pld-validate
             # case we actually set the `._ipc_msg` AFTER we return
@@ -1245,7 +1321,7 @@ def _mk_msg_type_err(
             # want to emulate the `Error` from the mte we build here
             # Bo
             # so by default in that case this is set to `None`
-            ipc_msg=src_err_msg,
+            # ipc_msg=src_err_msg,
         )
         msgtyperr.__cause__ = src_validation_error
         return msgtyperr