From 4cfe4979ff555bcff2d0257603aa44de38c1de96 Mon Sep 17 00:00:00 2001 From: Tyler Goodlet Date: Fri, 5 Apr 2024 16:34:07 -0400 Subject: [PATCH] Factor `MsgpackTCPStream` msg-type checks Add both the `.send()` and `.recv()` handling blocks to a common `_raise_msg_type_err()` which includes detailed error msg formatting: - the `.recv()` side case does introspection of the `Msg` fields and attempts to report the exact (field type related) issue - `.send()` side does some boxed-error style tb formatting like `RemoteActorError`. - add a `strict_types: bool` to `.send()` to allow for just warning on bad inputs versus raising, but always raise from any `Encoder` type error. --- tractor/_ipc.py | 174 ++++++++++++++++++++++++++++++++++-------------- 1 file changed, 123 insertions(+), 51 deletions(-) diff --git a/tractor/_ipc.py b/tractor/_ipc.py index 6168c77..9af28e5 100644 --- a/tractor/_ipc.py +++ b/tractor/_ipc.py @@ -54,7 +54,8 @@ from tractor.msg import ( _ctxvar_MsgCodec, _codec, MsgCodec, - types, + types as msgtypes, + pretty_struct, ) log = get_logger(__name__) @@ -72,6 +73,7 @@ def get_stream_addrs(stream: trio.SocketStream) -> tuple: ) +# TODO: this should be our `Union[*msgtypes.__spec__]` now right? MsgType = TypeVar("MsgType") # TODO: consider using a generic def and indexing with our eventual @@ -116,6 +118,73 @@ class MsgTransport(Protocol[MsgType]): ... 
def _raise_msg_type_err(
    msg: Any|bytes,
    codec: MsgCodec,
    validation_err: msgspec.ValidationError|None = None,
    verb_header: str = '',

) -> None:
    '''
    Always raise a `MsgTypeError` describing why `msg` violates the
    current IPC msg-spec; this function never returns normally.

    Two modes, selected by `validation_err`:

    - send-side (`validation_err is None`): `msg` is the py-object
      that was about to be sent. The error reports the object, its
      type, a short boxed stack trace of the call site and the list
      of msg types accepted by `codec.msg_spec`.

    - recv-side (`validation_err` provided): `msg` is the raw
      msgpack-encoded `bytes` which failed `msgspec` validation. The
      bytes are re-decoded without any struct handling so that the
      offending msg type, and when determinable the offending
      field, its annotated type and the received value, can be
      reported. The raised error is chained from `validation_err`.

    `verb_header` is an optional prefix for the recv-side error
    text.

    '''
    if validation_err is None:  # send-side

        # local imports keep these off the module's import path
        # for the (hot) non-error case.
        import traceback
        from tractor._exceptions import pformat_boxed_tb

        fmt_spec: str = '\n'.join(
            map(str, codec.msg_spec.__args__)
        )
        fmt_stack: str = (
            '\n'.join(traceback.format_stack(limit=3))
        )
        tb_fmt: str = pformat_boxed_tb(
            tb_str=fmt_stack,
            field_prefix=' ',
            indent='',
        )
        raise MsgTypeError(
            f'invalid msg -> {msg}: {type(msg)}\n\n'
            f'{tb_fmt}\n'
            f'Valid IPC msgs are:\n\n'
            f'{fmt_spec}\n'
        )

    # recv-side: decode the msg-bytes using the std msgpack
    # interchange-prot (i.e. without any `msgspec.Struct`
    # handling) so that we can determine what `.msg.types.Msg` is
    # the culprit by reporting the received value.
    decoded: Any = msgspec.msgpack.decode(msg)

    # XXX never let a lookup failure in this *reporting* code mask
    # the real validation error: a non-map payload, a missing
    # 'msg_type' tag or an unknown tag value must all still end in
    # a `MsgTypeError`.
    msg_dict: dict = decoded if isinstance(decoded, dict) else {}
    msg_type_name: str = str(msg_dict.get('msg_type', '<unknown>'))
    msg_type = getattr(msgtypes, msg_type_name, None)

    errmsg: str = (
        f'invalid `{msg_type_name}` IPC msg\n\n'
    )
    if verb_header:
        errmsg = f'{verb_header} ' + errmsg

    # XXX see if we can determine the exact invalid field such
    # that we can comprehensively report the specific field's type
    # problem. `msgspec` suffixes its error text with the path of
    # the offending value, eg. "... - at `$.cid`".
    msgspec_msg: str = validation_err.args[0].rstrip('`')
    reason, _, maybe_field = msgspec_msg.rpartition('$.')

    # NOTE: test for key *presence*, not value truthiness, since
    # falsy values (`None`, `0`, `''`, `[]`) are very often exactly
    # the invalid value that needs reporting.
    if maybe_field in msg_dict:
        field_val: Any = msg_dict[maybe_field]

        # the declared field type is only known when the tag maps
        # to a real msg type which declares that field.
        field_type: Any = '<unknown>'
        sig = getattr(msg_type, '__signature__', None)
        if sig is not None:
            param = sig.parameters.get(maybe_field)
            if param is not None:
                field_type = param.annotation

        reason = reason.rstrip('`')
        errmsg += (
            f'{reason}\n\n'
            f'{msg_type}\n'
            f' |_.{maybe_field}: {field_type} = {field_val!r}\n'
        )

    raise MsgTypeError(errmsg) from validation_err


# TODO: not sure why we have to inherit here, but it seems to be an
# issue with ``get_msg_transport()`` returning a ``Type[Protocol]``;
# probably should make a `mypy` issue?
@@ -175,9 +244,10 @@ class MsgpackTCPStream(MsgTransport): or _codec._ctxvar_MsgCodec.get() ) - log.critical( - '!?!: USING STD `tractor` CODEC !?!?\n' - f'{self._codec}\n' + # TODO: mask out before release? + log.runtime( + f'New {self} created with codec\n' + f'codec: {self._codec}\n' ) async def _iter_packets(self) -> AsyncGenerator[dict, None]: @@ -221,16 +291,18 @@ class MsgpackTCPStream(MsgTransport): # NOTE: lookup the `trio.Task.context`'s var for # the current `MsgCodec`. codec: MsgCodec = _ctxvar_MsgCodec.get() + + # TODO: mask out before release? if self._codec.pld_spec != codec.pld_spec: # assert ( # task := trio.lowlevel.current_task() # ) is not self._task # self._task = task self._codec = codec - log.critical( - '.recv() USING NEW CODEC !?!?\n' - f'{self._codec}\n\n' - f'msg_bytes -> {msg_bytes}\n' + log.runtime( + 'Using new codec in {self}.recv()\n' + f'codec: {self._codec}\n\n' + f'msg_bytes: {msg_bytes}\n' ) yield codec.decode(msg_bytes) @@ -252,36 +324,13 @@ class MsgpackTCPStream(MsgTransport): # and always raise such that spec violations # are never allowed to be caught silently! except msgspec.ValidationError as verr: - - # decode the msg-bytes using the std msgpack - # interchange-prot (i.e. without any - # `msgspec.Struct` handling) so that we can - # determine what `.msg.types.Msg` is the culprit - # by reporting the received value. 
- msg_dict: dict = msgspec.msgpack.decode(msg_bytes) - msg_type_name: str = msg_dict['msg_type'] - msg_type = getattr(types, msg_type_name) - errmsg: str = ( - f'Received invalid IPC `{msg_type_name}` msg\n\n' + # re-raise as type error + _raise_msg_type_err( + msg=msg_bytes, + codec=codec, + validation_err=verr, ) - # XXX see if we can determine the exact invalid field - # such that we can comprehensively report the - # specific field's type problem - msgspec_msg: str = verr.args[0].rstrip('`') - msg, _, maybe_field = msgspec_msg.rpartition('$.') - if field_val := msg_dict.get(maybe_field): - field_type: Union[Type] = msg_type.__signature__.parameters[ - maybe_field - ].annotation - errmsg += ( - f'{msg.rstrip("`")}\n\n' - f'{msg_type}\n' - f' |_.{maybe_field}: {field_type} = {field_val}\n' - ) - - raise MsgTypeError(errmsg) from verr - except ( msgspec.DecodeError, UnicodeDecodeError, @@ -307,12 +356,16 @@ class MsgpackTCPStream(MsgTransport): async def send( self, - msg: Any, + msg: msgtypes.Msg, + strict_types: bool = True, # hide_tb: bool = False, ) -> None: ''' - Send a msgpack coded blob-as-msg over TCP. + Send a msgpack encoded py-object-blob-as-msg over TCP. + + If `strict_types == True` then a `MsgTypeError` will be raised on any + invalid msg type ''' # __tracebackhide__: bool = hide_tb @@ -321,25 +374,40 @@ class MsgpackTCPStream(MsgTransport): # NOTE: lookup the `trio.Task.context`'s var for # the current `MsgCodec`. codec: MsgCodec = _ctxvar_MsgCodec.get() - # if self._codec != codec: + + # TODO: mask out before release? 
if self._codec.pld_spec != codec.pld_spec: self._codec = codec - log.critical( - '.send() using NEW CODEC !?!?\n' - f'{self._codec}\n\n' - f'OBJ -> {msg}\n' + log.runtime( + 'Using new codec in {self}.send()\n' + f'codec: {self._codec}\n\n' + f'msg: {msg}\n' ) - if type(msg) not in types.__spec__: - log.warning( - 'Sending non-`Msg`-spec msg?\n\n' - f'{msg}\n' - ) - bytes_data: bytes = codec.encode(msg) + + if type(msg) not in msgtypes.__msg_types__: + if strict_types: + _raise_msg_type_err( + msg, + codec=codec, + ) + else: + log.warning( + 'Sending non-`Msg`-spec msg?\n\n' + f'{msg}\n' + ) + + try: + bytes_data: bytes = codec.encode(msg) + except TypeError as typerr: + raise MsgTypeError( + 'A msg field violates the current spec\n' + f'{codec.pld_spec}\n\n' + f'{pretty_struct.Struct.pformat(msg)}' + ) from typerr # supposedly the fastest says, # https://stackoverflow.com/a/54027962 size: bytes = struct.pack(" seems like that might be re-inventing scalability + # prots tho no? # try: # return await self._transport.recv() # except trio.BrokenResourceError: