Factor `MsgpackTCPStream` msg-type checks

Move the msg-type checks from both the `.send()` and `.recv()` handling
blocks into a common `_raise_msg_type_err()` which includes detailed
error-msg formatting:

- the `.recv()` side case does introspection of the `Msg` fields,
  attempting to report the exact (field-type related) issue
- the `.send()` side does some boxed-error style tb formatting like
  `RemoteActorError`.
- add a `strict_types: bool` to `.send()` to allow just warning on bad
  inputs versus raising, but always raise from any `Encoder` type
  error (see the usage sketch below).
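
A minimal usage sketch of the resulting `.send()` behavior (not part of
the diff; `stream` is assumed to be an already-connected
`MsgpackTCPStream` and the plain `dict` msg is deliberately out of
spec):

```python
async def send_demo(stream) -> None:
    bad_msg: dict = {'not': 'a Msg-spec struct'}

    # default `strict_types=True`: the pre-encode type check fails and
    # `_raise_msg_type_err()` raises a `MsgTypeError` listing the valid
    # IPC msg types plus a boxed-tb of this call site.
    try:
        await stream.send(bad_msg)
    except Exception as mte:  # expected: `MsgTypeError`
        print(mte)

    # `strict_types=False`: the same input only triggers a warning log;
    # if the object then can't be encoded, the `Encoder`'s `TypeError`
    # is re-raised as a `MsgTypeError`.
    await stream.send(bad_msg, strict_types=False)
```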
runtime_to_msgspec
Tyler Goodlet 2024-04-05 16:34:07 -04:00
parent 97bfbdbc1c
commit 4cfe4979ff
1 changed file with 123 additions and 51 deletions


@@ -54,7 +54,8 @@ from tractor.msg import (
_ctxvar_MsgCodec,
_codec,
MsgCodec,
types,
types as msgtypes,
pretty_struct,
)
log = get_logger(__name__)
@@ -72,6 +73,7 @@ def get_stream_addrs(stream: trio.SocketStream) -> tuple:
)
# TODO: this should be our `Union[*msgtypes.__spec__]` now right?
MsgType = TypeVar("MsgType")
# TODO: consider using a generic def and indexing with our eventual
@@ -116,6 +118,73 @@ class MsgTransport(Protocol[MsgType]):
...
def _raise_msg_type_err(
msg: Any|bytes,
codec: MsgCodec,
validation_err: msgspec.ValidationError|None = None,
verb_header: str = '',
) -> None:
# if side == 'send':
if validation_err is None: # send-side
import traceback
from tractor._exceptions import pformat_boxed_tb
fmt_spec: str = '\n'.join(
map(str, codec.msg_spec.__args__)
)
fmt_stack: str = (
'\n'.join(traceback.format_stack(limit=3))
)
tb_fmt: str = pformat_boxed_tb(
tb_str=fmt_stack,
# fields_str=header,
field_prefix=' ',
indent='',
)
raise MsgTypeError(
f'invalid msg -> {msg}: {type(msg)}\n\n'
f'{tb_fmt}\n'
f'Valid IPC msgs are:\n\n'
# f' ------ - ------\n'
f'{fmt_spec}\n'
)
else:
# decode the msg-bytes using the std msgpack
# interchange-prot (i.e. without any
# `msgspec.Struct` handling) so that we can
# determine what `.msg.types.Msg` is the culprit
# by reporting the received value.
msg_dict: dict = msgspec.msgpack.decode(msg)
msg_type_name: str = msg_dict['msg_type']
msg_type = getattr(msgtypes, msg_type_name)
errmsg: str = (
f'invalid `{msg_type_name}` IPC msg\n\n'
)
if verb_header:
errmsg = f'{verb_header} ' + errmsg
# XXX see if we can determine the exact invalid field
# such that we can comprehensively report the
# specific field's type problem
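# NOTE: msgspec's `ValidationError` text typically ends with a
# JSON-path style pointer, eg. "Expected `str`, got `int` - at
# `$.cid`", so partitioning on '$.' below recovers the offending
# field name.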
msgspec_msg: str = validation_err.args[0].rstrip('`')
msg, _, maybe_field = msgspec_msg.rpartition('$.')
if field_val := msg_dict.get(maybe_field):
field_type: Union[Type] = msg_type.__signature__.parameters[
maybe_field
].annotation
errmsg += (
f'{msg.rstrip("`")}\n\n'
f'{msg_type}\n'
f' |_.{maybe_field}: {field_type} = {field_val!r}\n'
)
raise MsgTypeError(errmsg) from validation_err
# TODO: not sure why we have to inherit here, but it seems to be an
# issue with ``get_msg_transport()`` returning a ``Type[Protocol]``;
# probably should make a `mypy` issue?
@@ -175,9 +244,10 @@ class MsgpackTCPStream(MsgTransport):
or
_codec._ctxvar_MsgCodec.get()
)
log.critical(
'!?!: USING STD `tractor` CODEC !?!?\n'
f'{self._codec}\n'
# TODO: mask out before release?
log.runtime(
f'New {self} created with codec\n'
f'codec: {self._codec}\n'
)
async def _iter_packets(self) -> AsyncGenerator[dict, None]:
@@ -221,16 +291,18 @@ class MsgpackTCPStream(MsgTransport):
# NOTE: lookup the `trio.Task.context`'s var for
# the current `MsgCodec`.
codec: MsgCodec = _ctxvar_MsgCodec.get()
# TODO: mask out before release?
if self._codec.pld_spec != codec.pld_spec:
# assert (
# task := trio.lowlevel.current_task()
# ) is not self._task
# self._task = task
self._codec = codec
log.critical(
'.recv() USING NEW CODEC !?!?\n'
f'{self._codec}\n\n'
f'msg_bytes -> {msg_bytes}\n'
log.runtime(
f'Using new codec in {self}.recv()\n'
f'codec: {self._codec}\n\n'
f'msg_bytes: {msg_bytes}\n'
)
yield codec.decode(msg_bytes)
@@ -252,36 +324,13 @@ class MsgpackTCPStream(MsgTransport):
# and always raise such that spec violations
# are never allowed to be caught silently!
except msgspec.ValidationError as verr:
# decode the msg-bytes using the std msgpack
# interchange-prot (i.e. without any
# `msgspec.Struct` handling) so that we can
# determine what `.msg.types.Msg` is the culprit
# by reporting the received value.
msg_dict: dict = msgspec.msgpack.decode(msg_bytes)
msg_type_name: str = msg_dict['msg_type']
msg_type = getattr(types, msg_type_name)
errmsg: str = (
f'Received invalid IPC `{msg_type_name}` msg\n\n'
# re-raise as type error
_raise_msg_type_err(
msg=msg_bytes,
codec=codec,
validation_err=verr,
)
# XXX see if we can determine the exact invalid field
# such that we can comprehensively report the
# specific field's type problem
msgspec_msg: str = verr.args[0].rstrip('`')
msg, _, maybe_field = msgspec_msg.rpartition('$.')
if field_val := msg_dict.get(maybe_field):
field_type: Union[Type] = msg_type.__signature__.parameters[
maybe_field
].annotation
errmsg += (
f'{msg.rstrip("`")}\n\n'
f'{msg_type}\n'
f' |_.{maybe_field}: {field_type} = {field_val}\n'
)
raise MsgTypeError(errmsg) from verr
except (
msgspec.DecodeError,
UnicodeDecodeError,
@@ -307,12 +356,16 @@ class MsgpackTCPStream(MsgTransport):
async def send(
self,
msg: Any,
msg: msgtypes.Msg,
strict_types: bool = True,
# hide_tb: bool = False,
) -> None:
'''
Send a msgpack coded blob-as-msg over TCP.
Send a msgpack encoded py-object-blob-as-msg over TCP.
If `strict_types == True` then a `MsgTypeError` will be raised on
any invalid msg type.
'''
# __tracebackhide__: bool = hide_tb
@@ -321,25 +374,40 @@ class MsgpackTCPStream(MsgTransport):
# NOTE: lookup the `trio.Task.context`'s var for
# the current `MsgCodec`.
codec: MsgCodec = _ctxvar_MsgCodec.get()
# if self._codec != codec:
# TODO: mask out before release?
if self._codec.pld_spec != codec.pld_spec:
self._codec = codec
log.critical(
'.send() using NEW CODEC !?!?\n'
f'{self._codec}\n\n'
f'OBJ -> {msg}\n'
log.runtime(
f'Using new codec in {self}.send()\n'
f'codec: {self._codec}\n\n'
f'msg: {msg}\n'
)
if type(msg) not in types.__spec__:
log.warning(
'Sending non-`Msg`-spec msg?\n\n'
f'{msg}\n'
)
bytes_data: bytes = codec.encode(msg)
if type(msg) not in msgtypes.__msg_types__:
if strict_types:
_raise_msg_type_err(
msg,
codec=codec,
)
else:
log.warning(
'Sending non-`Msg`-spec msg?\n\n'
f'{msg}\n'
)
try:
bytes_data: bytes = codec.encode(msg)
except TypeError as typerr:
raise MsgTypeError(
'A msg field violates the current spec\n'
f'{codec.pld_spec}\n\n'
f'{pretty_struct.Struct.pformat(msg)}'
) from typerr
# supposedly the fastest says,
# https://stackoverflow.com/a/54027962
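# a 4-byte little-endian unsigned-int length header prefixes each
# msgpack frame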
size: bytes = struct.pack("<I", len(bytes_data))
return await self.stream.send_all(size + bytes_data)
@property
@@ -567,7 +635,6 @@ class Channel:
f'{pformat(payload)}\n'
) # type: ignore
assert self._transport
await self._transport.send(
payload,
# hide_tb=hide_tb,
@@ -577,6 +644,11 @@ class Channel:
assert self._transport
return await self._transport.recv()
# TODO: auto-reconnect features like 0mq/nanomsg?
# -[ ] implement it manually with nods to SC prot
# possibly on multiple transport backends?
# -> seems like that might be re-inventing scalability
# prots tho no?
# try:
# return await self._transport.recv()
# except trio.BrokenResourceError: