Factor `MsgpackTCPStream` msg-type checks
Move both the `.send()` and `.recv()` handling blocks into a common `_raise_msg_type_err()` helper which includes detailed error-message formatting:
- the `.recv()` side introspects the `Msg` fields and attempts to report the exact (field-type related) issue.
- the `.send()` side does some boxed-error style traceback formatting, like `RemoteActorError`.
- add a `strict_types: bool` parameter to `.send()` to allow just warning on bad inputs instead of raising, but always raise on any `Encoder` type error.
rae_message_packing
							parent
							
								
									97bfbdbc1c
								
							
						
					
					
						commit
						4cfe4979ff
					
				
							
								
								
									
										164
									
								
								tractor/_ipc.py
								
								
								
								
							
							
						
						
									
										164
									
								
								tractor/_ipc.py
								
								
								
								
							| 
						 | 
					@ -54,7 +54,8 @@ from tractor.msg import (
 | 
				
			||||||
    _ctxvar_MsgCodec,
 | 
					    _ctxvar_MsgCodec,
 | 
				
			||||||
    _codec,
 | 
					    _codec,
 | 
				
			||||||
    MsgCodec,
 | 
					    MsgCodec,
 | 
				
			||||||
    types,
 | 
					    types as msgtypes,
 | 
				
			||||||
 | 
					    pretty_struct,
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
log = get_logger(__name__)
 | 
					log = get_logger(__name__)
 | 
				
			||||||
| 
						 | 
					@ -72,6 +73,7 @@ def get_stream_addrs(stream: trio.SocketStream) -> tuple:
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# TODO: this should be our `Union[*msgtypes.__spec__]` now right?
 | 
				
			||||||
MsgType = TypeVar("MsgType")
 | 
					MsgType = TypeVar("MsgType")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# TODO: consider using a generic def and indexing with our eventual
 | 
					# TODO: consider using a generic def and indexing with our eventual
 | 
				
			||||||
| 
						 | 
					@ -116,6 +118,73 @@ class MsgTransport(Protocol[MsgType]):
 | 
				
			||||||
        ...
 | 
					        ...
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def _raise_msg_type_err(
 | 
				
			||||||
 | 
					    msg: Any|bytes,
 | 
				
			||||||
 | 
					    codec: MsgCodec,
 | 
				
			||||||
 | 
					    validation_err: msgspec.ValidationError|None = None,
 | 
				
			||||||
 | 
					    verb_header: str = '',
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					) -> None:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # if side == 'send':
 | 
				
			||||||
 | 
					    if validation_err is None:  # send-side
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        import traceback
 | 
				
			||||||
 | 
					        from tractor._exceptions import pformat_boxed_tb
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        fmt_spec: str = '\n'.join(
 | 
				
			||||||
 | 
					            map(str, codec.msg_spec.__args__)
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        fmt_stack: str = (
 | 
				
			||||||
 | 
					            '\n'.join(traceback.format_stack(limit=3))
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        tb_fmt: str = pformat_boxed_tb(
 | 
				
			||||||
 | 
					            tb_str=fmt_stack,
 | 
				
			||||||
 | 
					            # fields_str=header,
 | 
				
			||||||
 | 
					            field_prefix='  ',
 | 
				
			||||||
 | 
					            indent='',
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        raise MsgTypeError(
 | 
				
			||||||
 | 
					            f'invalid msg -> {msg}: {type(msg)}\n\n'
 | 
				
			||||||
 | 
					            f'{tb_fmt}\n'
 | 
				
			||||||
 | 
					            f'Valid IPC msgs are:\n\n'
 | 
				
			||||||
 | 
					            # f'  ------ - ------\n'
 | 
				
			||||||
 | 
					            f'{fmt_spec}\n'
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    else:
 | 
				
			||||||
 | 
					        # decode the msg-bytes using the std msgpack
 | 
				
			||||||
 | 
					        # interchange-prot (i.e. without any
 | 
				
			||||||
 | 
					        # `msgspec.Struct` handling) so that we can
 | 
				
			||||||
 | 
					        # determine what `.msg.types.Msg` is the culprit
 | 
				
			||||||
 | 
					        # by reporting the received value.
 | 
				
			||||||
 | 
					        msg_dict: dict = msgspec.msgpack.decode(msg)
 | 
				
			||||||
 | 
					        msg_type_name: str = msg_dict['msg_type']
 | 
				
			||||||
 | 
					        msg_type = getattr(msgtypes, msg_type_name)
 | 
				
			||||||
 | 
					        errmsg: str = (
 | 
				
			||||||
 | 
					            f'invalid `{msg_type_name}` IPC msg\n\n'
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        if verb_header:
 | 
				
			||||||
 | 
					            errmsg = f'{verb_header} ' + errmsg
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # XXX see if we can determine the exact invalid field
 | 
				
			||||||
 | 
					        # such that we can comprehensively report the
 | 
				
			||||||
 | 
					        # specific field's type problem
 | 
				
			||||||
 | 
					        msgspec_msg: str = validation_err.args[0].rstrip('`')
 | 
				
			||||||
 | 
					        msg, _, maybe_field = msgspec_msg.rpartition('$.')
 | 
				
			||||||
 | 
					        if field_val := msg_dict.get(maybe_field):
 | 
				
			||||||
 | 
					            field_type: Union[Type] = msg_type.__signature__.parameters[
 | 
				
			||||||
 | 
					                maybe_field
 | 
				
			||||||
 | 
					            ].annotation
 | 
				
			||||||
 | 
					            errmsg += (
 | 
				
			||||||
 | 
					                f'{msg.rstrip("`")}\n\n'
 | 
				
			||||||
 | 
					                f'{msg_type}\n'
 | 
				
			||||||
 | 
					                f' |_.{maybe_field}: {field_type} = {field_val!r}\n'
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        raise MsgTypeError(errmsg) from validation_err
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# TODO: not sure why we have to inherit here, but it seems to be an
 | 
					# TODO: not sure why we have to inherit here, but it seems to be an
 | 
				
			||||||
# issue with ``get_msg_transport()`` returning a ``Type[Protocol]``;
 | 
					# issue with ``get_msg_transport()`` returning a ``Type[Protocol]``;
 | 
				
			||||||
# probably should make a `mypy` issue?
 | 
					# probably should make a `mypy` issue?
 | 
				
			||||||
| 
						 | 
					@ -175,9 +244,10 @@ class MsgpackTCPStream(MsgTransport):
 | 
				
			||||||
            or
 | 
					            or
 | 
				
			||||||
            _codec._ctxvar_MsgCodec.get()
 | 
					            _codec._ctxvar_MsgCodec.get()
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
        log.critical(
 | 
					        # TODO: mask out before release?
 | 
				
			||||||
            '!?!: USING STD `tractor` CODEC !?!?\n'
 | 
					        log.runtime(
 | 
				
			||||||
            f'{self._codec}\n'
 | 
					            f'New {self} created with codec\n'
 | 
				
			||||||
 | 
					            f'codec: {self._codec}\n'
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    async def _iter_packets(self) -> AsyncGenerator[dict, None]:
 | 
					    async def _iter_packets(self) -> AsyncGenerator[dict, None]:
 | 
				
			||||||
| 
						 | 
					@ -221,16 +291,18 @@ class MsgpackTCPStream(MsgTransport):
 | 
				
			||||||
                # NOTE: lookup the `trio.Task.context`'s var for
 | 
					                # NOTE: lookup the `trio.Task.context`'s var for
 | 
				
			||||||
                # the current `MsgCodec`.
 | 
					                # the current `MsgCodec`.
 | 
				
			||||||
                codec: MsgCodec = _ctxvar_MsgCodec.get()
 | 
					                codec: MsgCodec = _ctxvar_MsgCodec.get()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                # TODO: mask out before release?
 | 
				
			||||||
                if self._codec.pld_spec != codec.pld_spec:
 | 
					                if self._codec.pld_spec != codec.pld_spec:
 | 
				
			||||||
                    # assert (
 | 
					                    # assert (
 | 
				
			||||||
                    #     task := trio.lowlevel.current_task()
 | 
					                    #     task := trio.lowlevel.current_task()
 | 
				
			||||||
                    # ) is not self._task
 | 
					                    # ) is not self._task
 | 
				
			||||||
                    # self._task = task
 | 
					                    # self._task = task
 | 
				
			||||||
                    self._codec = codec
 | 
					                    self._codec = codec
 | 
				
			||||||
                    log.critical(
 | 
					                    log.runtime(
 | 
				
			||||||
                        '.recv() USING NEW CODEC !?!?\n'
 | 
					                        'Using new codec in {self}.recv()\n'
 | 
				
			||||||
                        f'{self._codec}\n\n'
 | 
					                        f'codec: {self._codec}\n\n'
 | 
				
			||||||
                        f'msg_bytes -> {msg_bytes}\n'
 | 
					                        f'msg_bytes: {msg_bytes}\n'
 | 
				
			||||||
                    )
 | 
					                    )
 | 
				
			||||||
                yield codec.decode(msg_bytes)
 | 
					                yield codec.decode(msg_bytes)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -252,36 +324,13 @@ class MsgpackTCPStream(MsgTransport):
 | 
				
			||||||
            # and always raise such that spec violations
 | 
					            # and always raise such that spec violations
 | 
				
			||||||
            # are never allowed to be caught silently!
 | 
					            # are never allowed to be caught silently!
 | 
				
			||||||
            except msgspec.ValidationError as verr:
 | 
					            except msgspec.ValidationError as verr:
 | 
				
			||||||
 | 
					                # re-raise as type error
 | 
				
			||||||
                # decode the msg-bytes using the std msgpack
 | 
					                _raise_msg_type_err(
 | 
				
			||||||
                # interchange-prot (i.e. without any
 | 
					                    msg=msg_bytes,
 | 
				
			||||||
                # `msgspec.Struct` handling) so that we can
 | 
					                    codec=codec,
 | 
				
			||||||
                # determine what `.msg.types.Msg` is the culprit
 | 
					                    validation_err=verr,
 | 
				
			||||||
                # by reporting the received value.
 | 
					 | 
				
			||||||
                msg_dict: dict = msgspec.msgpack.decode(msg_bytes)
 | 
					 | 
				
			||||||
                msg_type_name: str = msg_dict['msg_type']
 | 
					 | 
				
			||||||
                msg_type = getattr(types, msg_type_name)
 | 
					 | 
				
			||||||
                errmsg: str = (
 | 
					 | 
				
			||||||
                    f'Received invalid IPC `{msg_type_name}` msg\n\n'
 | 
					 | 
				
			||||||
                )
 | 
					                )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                # XXX see if we can determine the exact invalid field
 | 
					 | 
				
			||||||
                # such that we can comprehensively report the
 | 
					 | 
				
			||||||
                # specific field's type problem
 | 
					 | 
				
			||||||
                msgspec_msg: str = verr.args[0].rstrip('`')
 | 
					 | 
				
			||||||
                msg, _, maybe_field = msgspec_msg.rpartition('$.')
 | 
					 | 
				
			||||||
                if field_val := msg_dict.get(maybe_field):
 | 
					 | 
				
			||||||
                    field_type: Union[Type] = msg_type.__signature__.parameters[
 | 
					 | 
				
			||||||
                        maybe_field
 | 
					 | 
				
			||||||
                    ].annotation
 | 
					 | 
				
			||||||
                    errmsg += (
 | 
					 | 
				
			||||||
                        f'{msg.rstrip("`")}\n\n'
 | 
					 | 
				
			||||||
                        f'{msg_type}\n'
 | 
					 | 
				
			||||||
                        f' |_.{maybe_field}: {field_type} = {field_val}\n'
 | 
					 | 
				
			||||||
                    )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                raise MsgTypeError(errmsg) from verr
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            except (
 | 
					            except (
 | 
				
			||||||
                msgspec.DecodeError,
 | 
					                msgspec.DecodeError,
 | 
				
			||||||
                UnicodeDecodeError,
 | 
					                UnicodeDecodeError,
 | 
				
			||||||
| 
						 | 
					@ -307,12 +356,16 @@ class MsgpackTCPStream(MsgTransport):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    async def send(
 | 
					    async def send(
 | 
				
			||||||
        self,
 | 
					        self,
 | 
				
			||||||
        msg: Any,
 | 
					        msg: msgtypes.Msg,
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        strict_types: bool = True,
 | 
				
			||||||
        # hide_tb: bool = False,
 | 
					        # hide_tb: bool = False,
 | 
				
			||||||
    ) -> None:
 | 
					    ) -> None:
 | 
				
			||||||
        '''
 | 
					        '''
 | 
				
			||||||
        Send a msgpack coded blob-as-msg over TCP.
 | 
					        Send a msgpack encoded py-object-blob-as-msg over TCP.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        If `strict_types == True` then a `MsgTypeError` will be raised on any
 | 
				
			||||||
 | 
					        invalid msg type
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        '''
 | 
					        '''
 | 
				
			||||||
        # __tracebackhide__: bool = hide_tb
 | 
					        # __tracebackhide__: bool = hide_tb
 | 
				
			||||||
| 
						 | 
					@ -321,25 +374,40 @@ class MsgpackTCPStream(MsgTransport):
 | 
				
			||||||
            # NOTE: lookup the `trio.Task.context`'s var for
 | 
					            # NOTE: lookup the `trio.Task.context`'s var for
 | 
				
			||||||
            # the current `MsgCodec`.
 | 
					            # the current `MsgCodec`.
 | 
				
			||||||
            codec: MsgCodec = _ctxvar_MsgCodec.get()
 | 
					            codec: MsgCodec = _ctxvar_MsgCodec.get()
 | 
				
			||||||
            # if self._codec != codec:
 | 
					
 | 
				
			||||||
 | 
					            # TODO: mask out before release?
 | 
				
			||||||
            if self._codec.pld_spec != codec.pld_spec:
 | 
					            if self._codec.pld_spec != codec.pld_spec:
 | 
				
			||||||
                self._codec = codec
 | 
					                self._codec = codec
 | 
				
			||||||
                log.critical(
 | 
					                log.runtime(
 | 
				
			||||||
                    '.send() using NEW CODEC !?!?\n'
 | 
					                    'Using new codec in {self}.send()\n'
 | 
				
			||||||
                    f'{self._codec}\n\n'
 | 
					                    f'codec: {self._codec}\n\n'
 | 
				
			||||||
                    f'OBJ -> {msg}\n'
 | 
					                    f'msg: {msg}\n'
 | 
				
			||||||
                )
 | 
					                )
 | 
				
			||||||
            if type(msg) not in types.__spec__:
 | 
					
 | 
				
			||||||
 | 
					            if type(msg) not in msgtypes.__msg_types__:
 | 
				
			||||||
 | 
					                if strict_types:
 | 
				
			||||||
 | 
					                    _raise_msg_type_err(
 | 
				
			||||||
 | 
					                        msg,
 | 
				
			||||||
 | 
					                        codec=codec,
 | 
				
			||||||
 | 
					                    )
 | 
				
			||||||
 | 
					                else:
 | 
				
			||||||
                    log.warning(
 | 
					                    log.warning(
 | 
				
			||||||
                        'Sending non-`Msg`-spec msg?\n\n'
 | 
					                        'Sending non-`Msg`-spec msg?\n\n'
 | 
				
			||||||
                        f'{msg}\n'
 | 
					                        f'{msg}\n'
 | 
				
			||||||
                    )
 | 
					                    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            try:
 | 
				
			||||||
                bytes_data: bytes = codec.encode(msg)
 | 
					                bytes_data: bytes = codec.encode(msg)
 | 
				
			||||||
 | 
					            except TypeError as typerr:
 | 
				
			||||||
 | 
					                raise MsgTypeError(
 | 
				
			||||||
 | 
					                    'A msg field violates the current spec\n'
 | 
				
			||||||
 | 
					                    f'{codec.pld_spec}\n\n'
 | 
				
			||||||
 | 
					                    f'{pretty_struct.Struct.pformat(msg)}'
 | 
				
			||||||
 | 
					                ) from typerr
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            # supposedly the fastest says,
 | 
					            # supposedly the fastest says,
 | 
				
			||||||
            # https://stackoverflow.com/a/54027962
 | 
					            # https://stackoverflow.com/a/54027962
 | 
				
			||||||
            size: bytes = struct.pack("<I", len(bytes_data))
 | 
					            size: bytes = struct.pack("<I", len(bytes_data))
 | 
				
			||||||
 | 
					 | 
				
			||||||
            return await self.stream.send_all(size + bytes_data)
 | 
					            return await self.stream.send_all(size + bytes_data)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @property
 | 
					    @property
 | 
				
			||||||
| 
						 | 
					@ -567,7 +635,6 @@ class Channel:
 | 
				
			||||||
            f'{pformat(payload)}\n'
 | 
					            f'{pformat(payload)}\n'
 | 
				
			||||||
        )  # type: ignore
 | 
					        )  # type: ignore
 | 
				
			||||||
        assert self._transport
 | 
					        assert self._transport
 | 
				
			||||||
 | 
					 | 
				
			||||||
        await self._transport.send(
 | 
					        await self._transport.send(
 | 
				
			||||||
            payload,
 | 
					            payload,
 | 
				
			||||||
            # hide_tb=hide_tb,
 | 
					            # hide_tb=hide_tb,
 | 
				
			||||||
| 
						 | 
					@ -577,6 +644,11 @@ class Channel:
 | 
				
			||||||
        assert self._transport
 | 
					        assert self._transport
 | 
				
			||||||
        return await self._transport.recv()
 | 
					        return await self._transport.recv()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # TODO: auto-reconnect features like 0mq/nanomsg?
 | 
				
			||||||
 | 
					        # -[ ] implement it manually with nods to SC prot
 | 
				
			||||||
 | 
					        #      possibly on multiple transport backends?
 | 
				
			||||||
 | 
					        #  -> seems like that might be re-inventing scalability
 | 
				
			||||||
 | 
					        #     prots tho no?
 | 
				
			||||||
        # try:
 | 
					        # try:
 | 
				
			||||||
        #     return await self._transport.recv()
 | 
					        #     return await self._transport.recv()
 | 
				
			||||||
        # except trio.BrokenResourceError:
 | 
					        # except trio.BrokenResourceError:
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue