Factor `MsgpackTCPStream` msg-type checks
Add both the `.send()` and `.recv()` handling blocks to a common `_raise_msg_type_err()` which includes detailed error msg formatting: - the `.recv()` side case does introspection of the `Msg` fields and attempting to report the exact (field type related) issue - `.send()` side does some boxed-error style tb formatting like `RemoteActorError`. - add a `strict_types: bool` to `.send()` to allow for just warning on bad inputs versus raising, but always raise from any `Encoder` type error.rae_message_packing
							parent
							
								
									97bfbdbc1c
								
							
						
					
					
						commit
						4cfe4979ff
					
				
							
								
								
									
										164
									
								
								tractor/_ipc.py
								
								
								
								
							
							
						
						
									
										164
									
								
								tractor/_ipc.py
								
								
								
								
							| 
						 | 
				
			
			@ -54,7 +54,8 @@ from tractor.msg import (
 | 
			
		|||
    _ctxvar_MsgCodec,
 | 
			
		||||
    _codec,
 | 
			
		||||
    MsgCodec,
 | 
			
		||||
    types,
 | 
			
		||||
    types as msgtypes,
 | 
			
		||||
    pretty_struct,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
log = get_logger(__name__)
 | 
			
		||||
| 
						 | 
				
			
			@ -72,6 +73,7 @@ def get_stream_addrs(stream: trio.SocketStream) -> tuple:
 | 
			
		|||
    )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# TODO: this should be our `Union[*msgtypes.__spec__]` now right?
 | 
			
		||||
MsgType = TypeVar("MsgType")
 | 
			
		||||
 | 
			
		||||
# TODO: consider using a generic def and indexing with our eventual
 | 
			
		||||
| 
						 | 
				
			
			@ -116,6 +118,73 @@ class MsgTransport(Protocol[MsgType]):
 | 
			
		|||
        ...
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _raise_msg_type_err(
 | 
			
		||||
    msg: Any|bytes,
 | 
			
		||||
    codec: MsgCodec,
 | 
			
		||||
    validation_err: msgspec.ValidationError|None = None,
 | 
			
		||||
    verb_header: str = '',
 | 
			
		||||
 | 
			
		||||
) -> None:
 | 
			
		||||
 | 
			
		||||
    # if side == 'send':
 | 
			
		||||
    if validation_err is None:  # send-side
 | 
			
		||||
 | 
			
		||||
        import traceback
 | 
			
		||||
        from tractor._exceptions import pformat_boxed_tb
 | 
			
		||||
 | 
			
		||||
        fmt_spec: str = '\n'.join(
 | 
			
		||||
            map(str, codec.msg_spec.__args__)
 | 
			
		||||
        )
 | 
			
		||||
        fmt_stack: str = (
 | 
			
		||||
            '\n'.join(traceback.format_stack(limit=3))
 | 
			
		||||
        )
 | 
			
		||||
        tb_fmt: str = pformat_boxed_tb(
 | 
			
		||||
            tb_str=fmt_stack,
 | 
			
		||||
            # fields_str=header,
 | 
			
		||||
            field_prefix='  ',
 | 
			
		||||
            indent='',
 | 
			
		||||
        )
 | 
			
		||||
        raise MsgTypeError(
 | 
			
		||||
            f'invalid msg -> {msg}: {type(msg)}\n\n'
 | 
			
		||||
            f'{tb_fmt}\n'
 | 
			
		||||
            f'Valid IPC msgs are:\n\n'
 | 
			
		||||
            # f'  ------ - ------\n'
 | 
			
		||||
            f'{fmt_spec}\n'
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    else:
 | 
			
		||||
        # decode the msg-bytes using the std msgpack
 | 
			
		||||
        # interchange-prot (i.e. without any
 | 
			
		||||
        # `msgspec.Struct` handling) so that we can
 | 
			
		||||
        # determine what `.msg.types.Msg` is the culprit
 | 
			
		||||
        # by reporting the received value.
 | 
			
		||||
        msg_dict: dict = msgspec.msgpack.decode(msg)
 | 
			
		||||
        msg_type_name: str = msg_dict['msg_type']
 | 
			
		||||
        msg_type = getattr(msgtypes, msg_type_name)
 | 
			
		||||
        errmsg: str = (
 | 
			
		||||
            f'invalid `{msg_type_name}` IPC msg\n\n'
 | 
			
		||||
        )
 | 
			
		||||
        if verb_header:
 | 
			
		||||
            errmsg = f'{verb_header} ' + errmsg
 | 
			
		||||
 | 
			
		||||
        # XXX see if we can determine the exact invalid field
 | 
			
		||||
        # such that we can comprehensively report the
 | 
			
		||||
        # specific field's type problem
 | 
			
		||||
        msgspec_msg: str = validation_err.args[0].rstrip('`')
 | 
			
		||||
        msg, _, maybe_field = msgspec_msg.rpartition('$.')
 | 
			
		||||
        if field_val := msg_dict.get(maybe_field):
 | 
			
		||||
            field_type: Union[Type] = msg_type.__signature__.parameters[
 | 
			
		||||
                maybe_field
 | 
			
		||||
            ].annotation
 | 
			
		||||
            errmsg += (
 | 
			
		||||
                f'{msg.rstrip("`")}\n\n'
 | 
			
		||||
                f'{msg_type}\n'
 | 
			
		||||
                f' |_.{maybe_field}: {field_type} = {field_val!r}\n'
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        raise MsgTypeError(errmsg) from validation_err
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# TODO: not sure why we have to inherit here, but it seems to be an
 | 
			
		||||
# issue with ``get_msg_transport()`` returning a ``Type[Protocol]``;
 | 
			
		||||
# probably should make a `mypy` issue?
 | 
			
		||||
| 
						 | 
				
			
			@ -175,9 +244,10 @@ class MsgpackTCPStream(MsgTransport):
 | 
			
		|||
            or
 | 
			
		||||
            _codec._ctxvar_MsgCodec.get()
 | 
			
		||||
        )
 | 
			
		||||
        log.critical(
 | 
			
		||||
            '!?!: USING STD `tractor` CODEC !?!?\n'
 | 
			
		||||
            f'{self._codec}\n'
 | 
			
		||||
        # TODO: mask out before release?
 | 
			
		||||
        log.runtime(
 | 
			
		||||
            f'New {self} created with codec\n'
 | 
			
		||||
            f'codec: {self._codec}\n'
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    async def _iter_packets(self) -> AsyncGenerator[dict, None]:
 | 
			
		||||
| 
						 | 
				
			
			@ -221,16 +291,18 @@ class MsgpackTCPStream(MsgTransport):
 | 
			
		|||
                # NOTE: lookup the `trio.Task.context`'s var for
 | 
			
		||||
                # the current `MsgCodec`.
 | 
			
		||||
                codec: MsgCodec = _ctxvar_MsgCodec.get()
 | 
			
		||||
 | 
			
		||||
                # TODO: mask out before release?
 | 
			
		||||
                if self._codec.pld_spec != codec.pld_spec:
 | 
			
		||||
                    # assert (
 | 
			
		||||
                    #     task := trio.lowlevel.current_task()
 | 
			
		||||
                    # ) is not self._task
 | 
			
		||||
                    # self._task = task
 | 
			
		||||
                    self._codec = codec
 | 
			
		||||
                    log.critical(
 | 
			
		||||
                        '.recv() USING NEW CODEC !?!?\n'
 | 
			
		||||
                        f'{self._codec}\n\n'
 | 
			
		||||
                        f'msg_bytes -> {msg_bytes}\n'
 | 
			
		||||
                    log.runtime(
 | 
			
		||||
                        'Using new codec in {self}.recv()\n'
 | 
			
		||||
                        f'codec: {self._codec}\n\n'
 | 
			
		||||
                        f'msg_bytes: {msg_bytes}\n'
 | 
			
		||||
                    )
 | 
			
		||||
                yield codec.decode(msg_bytes)
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -252,36 +324,13 @@ class MsgpackTCPStream(MsgTransport):
 | 
			
		|||
            # and always raise such that spec violations
 | 
			
		||||
            # are never allowed to be caught silently!
 | 
			
		||||
            except msgspec.ValidationError as verr:
 | 
			
		||||
 | 
			
		||||
                # decode the msg-bytes using the std msgpack
 | 
			
		||||
                # interchange-prot (i.e. without any
 | 
			
		||||
                # `msgspec.Struct` handling) so that we can
 | 
			
		||||
                # determine what `.msg.types.Msg` is the culprit
 | 
			
		||||
                # by reporting the received value.
 | 
			
		||||
                msg_dict: dict = msgspec.msgpack.decode(msg_bytes)
 | 
			
		||||
                msg_type_name: str = msg_dict['msg_type']
 | 
			
		||||
                msg_type = getattr(types, msg_type_name)
 | 
			
		||||
                errmsg: str = (
 | 
			
		||||
                    f'Received invalid IPC `{msg_type_name}` msg\n\n'
 | 
			
		||||
                # re-raise as type error
 | 
			
		||||
                _raise_msg_type_err(
 | 
			
		||||
                    msg=msg_bytes,
 | 
			
		||||
                    codec=codec,
 | 
			
		||||
                    validation_err=verr,
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
                # XXX see if we can determine the exact invalid field
 | 
			
		||||
                # such that we can comprehensively report the
 | 
			
		||||
                # specific field's type problem
 | 
			
		||||
                msgspec_msg: str = verr.args[0].rstrip('`')
 | 
			
		||||
                msg, _, maybe_field = msgspec_msg.rpartition('$.')
 | 
			
		||||
                if field_val := msg_dict.get(maybe_field):
 | 
			
		||||
                    field_type: Union[Type] = msg_type.__signature__.parameters[
 | 
			
		||||
                        maybe_field
 | 
			
		||||
                    ].annotation
 | 
			
		||||
                    errmsg += (
 | 
			
		||||
                        f'{msg.rstrip("`")}\n\n'
 | 
			
		||||
                        f'{msg_type}\n'
 | 
			
		||||
                        f' |_.{maybe_field}: {field_type} = {field_val}\n'
 | 
			
		||||
                    )
 | 
			
		||||
 | 
			
		||||
                raise MsgTypeError(errmsg) from verr
 | 
			
		||||
 | 
			
		||||
            except (
 | 
			
		||||
                msgspec.DecodeError,
 | 
			
		||||
                UnicodeDecodeError,
 | 
			
		||||
| 
						 | 
				
			
			@ -307,12 +356,16 @@ class MsgpackTCPStream(MsgTransport):
 | 
			
		|||
 | 
			
		||||
    async def send(
 | 
			
		||||
        self,
 | 
			
		||||
        msg: Any,
 | 
			
		||||
        msg: msgtypes.Msg,
 | 
			
		||||
 | 
			
		||||
        strict_types: bool = True,
 | 
			
		||||
        # hide_tb: bool = False,
 | 
			
		||||
    ) -> None:
 | 
			
		||||
        '''
 | 
			
		||||
        Send a msgpack coded blob-as-msg over TCP.
 | 
			
		||||
        Send a msgpack encoded py-object-blob-as-msg over TCP.
 | 
			
		||||
 | 
			
		||||
        If `strict_types == True` then a `MsgTypeError` will be raised on any
 | 
			
		||||
        invalid msg type
 | 
			
		||||
 | 
			
		||||
        '''
 | 
			
		||||
        # __tracebackhide__: bool = hide_tb
 | 
			
		||||
| 
						 | 
				
			
			@ -321,25 +374,40 @@ class MsgpackTCPStream(MsgTransport):
 | 
			
		|||
            # NOTE: lookup the `trio.Task.context`'s var for
 | 
			
		||||
            # the current `MsgCodec`.
 | 
			
		||||
            codec: MsgCodec = _ctxvar_MsgCodec.get()
 | 
			
		||||
            # if self._codec != codec:
 | 
			
		||||
 | 
			
		||||
            # TODO: mask out before release?
 | 
			
		||||
            if self._codec.pld_spec != codec.pld_spec:
 | 
			
		||||
                self._codec = codec
 | 
			
		||||
                log.critical(
 | 
			
		||||
                    '.send() using NEW CODEC !?!?\n'
 | 
			
		||||
                    f'{self._codec}\n\n'
 | 
			
		||||
                    f'OBJ -> {msg}\n'
 | 
			
		||||
                log.runtime(
 | 
			
		||||
                    'Using new codec in {self}.send()\n'
 | 
			
		||||
                    f'codec: {self._codec}\n\n'
 | 
			
		||||
                    f'msg: {msg}\n'
 | 
			
		||||
                )
 | 
			
		||||
            if type(msg) not in types.__spec__:
 | 
			
		||||
 | 
			
		||||
            if type(msg) not in msgtypes.__msg_types__:
 | 
			
		||||
                if strict_types:
 | 
			
		||||
                    _raise_msg_type_err(
 | 
			
		||||
                        msg,
 | 
			
		||||
                        codec=codec,
 | 
			
		||||
                    )
 | 
			
		||||
                else:
 | 
			
		||||
                    log.warning(
 | 
			
		||||
                        'Sending non-`Msg`-spec msg?\n\n'
 | 
			
		||||
                        f'{msg}\n'
 | 
			
		||||
                    )
 | 
			
		||||
 | 
			
		||||
            try:
 | 
			
		||||
                bytes_data: bytes = codec.encode(msg)
 | 
			
		||||
            except TypeError as typerr:
 | 
			
		||||
                raise MsgTypeError(
 | 
			
		||||
                    'A msg field violates the current spec\n'
 | 
			
		||||
                    f'{codec.pld_spec}\n\n'
 | 
			
		||||
                    f'{pretty_struct.Struct.pformat(msg)}'
 | 
			
		||||
                ) from typerr
 | 
			
		||||
 | 
			
		||||
            # supposedly the fastest says,
 | 
			
		||||
            # https://stackoverflow.com/a/54027962
 | 
			
		||||
            size: bytes = struct.pack("<I", len(bytes_data))
 | 
			
		||||
 | 
			
		||||
            return await self.stream.send_all(size + bytes_data)
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
| 
						 | 
				
			
			@ -567,7 +635,6 @@ class Channel:
 | 
			
		|||
            f'{pformat(payload)}\n'
 | 
			
		||||
        )  # type: ignore
 | 
			
		||||
        assert self._transport
 | 
			
		||||
 | 
			
		||||
        await self._transport.send(
 | 
			
		||||
            payload,
 | 
			
		||||
            # hide_tb=hide_tb,
 | 
			
		||||
| 
						 | 
				
			
			@ -577,6 +644,11 @@ class Channel:
 | 
			
		|||
        assert self._transport
 | 
			
		||||
        return await self._transport.recv()
 | 
			
		||||
 | 
			
		||||
        # TODO: auto-reconnect features like 0mq/nanomsg?
 | 
			
		||||
        # -[ ] implement it manually with nods to SC prot
 | 
			
		||||
        #      possibly on multiple transport backends?
 | 
			
		||||
        #  -> seems like that might be re-inventing scalability
 | 
			
		||||
        #     prots tho no?
 | 
			
		||||
        # try:
 | 
			
		||||
        #     return await self._transport.recv()
 | 
			
		||||
        # except trio.BrokenResourceError:
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue