Factor `MsgpackTCPStream` msg-type checks
Add both the `.send()` and `.recv()` handling blocks to a common `_raise_msg_type_err()` which includes detailed error msg formatting: - the `.recv()` side case does introspection of the `Msg` fields and attempting to report the exact (field type related) issue - `.send()` side does some boxed-error style tb formatting like `RemoteActorError`. - add a `strict_types: bool` to `.send()` to allow for just warning on bad inputs versus raising, but always raise from any `Encoder` type error.runtime_to_msgspec
parent
97bfbdbc1c
commit
4cfe4979ff
174
tractor/_ipc.py
174
tractor/_ipc.py
|
@ -54,7 +54,8 @@ from tractor.msg import (
|
|||
_ctxvar_MsgCodec,
|
||||
_codec,
|
||||
MsgCodec,
|
||||
types,
|
||||
types as msgtypes,
|
||||
pretty_struct,
|
||||
)
|
||||
|
||||
log = get_logger(__name__)
|
||||
|
@ -72,6 +73,7 @@ def get_stream_addrs(stream: trio.SocketStream) -> tuple:
|
|||
)
|
||||
|
||||
|
||||
# TODO: this should be our `Union[*msgtypes.__spec__]` now right?
|
||||
MsgType = TypeVar("MsgType")
|
||||
|
||||
# TODO: consider using a generic def and indexing with our eventual
|
||||
|
@ -116,6 +118,73 @@ class MsgTransport(Protocol[MsgType]):
|
|||
...
|
||||
|
||||
|
||||
def _raise_msg_type_err(
|
||||
msg: Any|bytes,
|
||||
codec: MsgCodec,
|
||||
validation_err: msgspec.ValidationError|None = None,
|
||||
verb_header: str = '',
|
||||
|
||||
) -> None:
|
||||
|
||||
# if side == 'send':
|
||||
if validation_err is None: # send-side
|
||||
|
||||
import traceback
|
||||
from tractor._exceptions import pformat_boxed_tb
|
||||
|
||||
fmt_spec: str = '\n'.join(
|
||||
map(str, codec.msg_spec.__args__)
|
||||
)
|
||||
fmt_stack: str = (
|
||||
'\n'.join(traceback.format_stack(limit=3))
|
||||
)
|
||||
tb_fmt: str = pformat_boxed_tb(
|
||||
tb_str=fmt_stack,
|
||||
# fields_str=header,
|
||||
field_prefix=' ',
|
||||
indent='',
|
||||
)
|
||||
raise MsgTypeError(
|
||||
f'invalid msg -> {msg}: {type(msg)}\n\n'
|
||||
f'{tb_fmt}\n'
|
||||
f'Valid IPC msgs are:\n\n'
|
||||
# f' ------ - ------\n'
|
||||
f'{fmt_spec}\n'
|
||||
)
|
||||
|
||||
else:
|
||||
# decode the msg-bytes using the std msgpack
|
||||
# interchange-prot (i.e. without any
|
||||
# `msgspec.Struct` handling) so that we can
|
||||
# determine what `.msg.types.Msg` is the culprit
|
||||
# by reporting the received value.
|
||||
msg_dict: dict = msgspec.msgpack.decode(msg)
|
||||
msg_type_name: str = msg_dict['msg_type']
|
||||
msg_type = getattr(msgtypes, msg_type_name)
|
||||
errmsg: str = (
|
||||
f'invalid `{msg_type_name}` IPC msg\n\n'
|
||||
)
|
||||
if verb_header:
|
||||
errmsg = f'{verb_header} ' + errmsg
|
||||
|
||||
# XXX see if we can determine the exact invalid field
|
||||
# such that we can comprehensively report the
|
||||
# specific field's type problem
|
||||
msgspec_msg: str = validation_err.args[0].rstrip('`')
|
||||
msg, _, maybe_field = msgspec_msg.rpartition('$.')
|
||||
if field_val := msg_dict.get(maybe_field):
|
||||
field_type: Union[Type] = msg_type.__signature__.parameters[
|
||||
maybe_field
|
||||
].annotation
|
||||
errmsg += (
|
||||
f'{msg.rstrip("`")}\n\n'
|
||||
f'{msg_type}\n'
|
||||
f' |_.{maybe_field}: {field_type} = {field_val!r}\n'
|
||||
)
|
||||
|
||||
raise MsgTypeError(errmsg) from validation_err
|
||||
|
||||
|
||||
# TODO: not sure why we have to inherit here, but it seems to be an
|
||||
# issue with ``get_msg_transport()`` returning a ``Type[Protocol]``;
|
||||
# probably should make a `mypy` issue?
|
||||
|
@ -175,9 +244,10 @@ class MsgpackTCPStream(MsgTransport):
|
|||
or
|
||||
_codec._ctxvar_MsgCodec.get()
|
||||
)
|
||||
log.critical(
|
||||
'!?!: USING STD `tractor` CODEC !?!?\n'
|
||||
f'{self._codec}\n'
|
||||
# TODO: mask out before release?
|
||||
log.runtime(
|
||||
f'New {self} created with codec\n'
|
||||
f'codec: {self._codec}\n'
|
||||
)
|
||||
|
||||
async def _iter_packets(self) -> AsyncGenerator[dict, None]:
|
||||
|
@ -221,16 +291,18 @@ class MsgpackTCPStream(MsgTransport):
|
|||
# NOTE: lookup the `trio.Task.context`'s var for
|
||||
# the current `MsgCodec`.
|
||||
codec: MsgCodec = _ctxvar_MsgCodec.get()
|
||||
|
||||
# TODO: mask out before release?
|
||||
if self._codec.pld_spec != codec.pld_spec:
|
||||
# assert (
|
||||
# task := trio.lowlevel.current_task()
|
||||
# ) is not self._task
|
||||
# self._task = task
|
||||
self._codec = codec
|
||||
log.critical(
|
||||
'.recv() USING NEW CODEC !?!?\n'
|
||||
f'{self._codec}\n\n'
|
||||
f'msg_bytes -> {msg_bytes}\n'
|
||||
log.runtime(
|
||||
'Using new codec in {self}.recv()\n'
|
||||
f'codec: {self._codec}\n\n'
|
||||
f'msg_bytes: {msg_bytes}\n'
|
||||
)
|
||||
yield codec.decode(msg_bytes)
|
||||
|
||||
|
@ -252,36 +324,13 @@ class MsgpackTCPStream(MsgTransport):
|
|||
# and always raise such that spec violations
|
||||
# are never allowed to be caught silently!
|
||||
except msgspec.ValidationError as verr:
|
||||
|
||||
# decode the msg-bytes using the std msgpack
|
||||
# interchange-prot (i.e. without any
|
||||
# `msgspec.Struct` handling) so that we can
|
||||
# determine what `.msg.types.Msg` is the culprit
|
||||
# by reporting the received value.
|
||||
msg_dict: dict = msgspec.msgpack.decode(msg_bytes)
|
||||
msg_type_name: str = msg_dict['msg_type']
|
||||
msg_type = getattr(types, msg_type_name)
|
||||
errmsg: str = (
|
||||
f'Received invalid IPC `{msg_type_name}` msg\n\n'
|
||||
# re-raise as type error
|
||||
_raise_msg_type_err(
|
||||
msg=msg_bytes,
|
||||
codec=codec,
|
||||
validation_err=verr,
|
||||
)
|
||||
|
||||
# XXX see if we can determine the exact invalid field
|
||||
# such that we can comprehensively report the
|
||||
# specific field's type problem
|
||||
msgspec_msg: str = verr.args[0].rstrip('`')
|
||||
msg, _, maybe_field = msgspec_msg.rpartition('$.')
|
||||
if field_val := msg_dict.get(maybe_field):
|
||||
field_type: Union[Type] = msg_type.__signature__.parameters[
|
||||
maybe_field
|
||||
].annotation
|
||||
errmsg += (
|
||||
f'{msg.rstrip("`")}\n\n'
|
||||
f'{msg_type}\n'
|
||||
f' |_.{maybe_field}: {field_type} = {field_val}\n'
|
||||
)
|
||||
|
||||
raise MsgTypeError(errmsg) from verr
|
||||
|
||||
except (
|
||||
msgspec.DecodeError,
|
||||
UnicodeDecodeError,
|
||||
|
@ -307,12 +356,16 @@ class MsgpackTCPStream(MsgTransport):
|
|||
|
||||
async def send(
|
||||
self,
|
||||
msg: Any,
|
||||
msg: msgtypes.Msg,
|
||||
|
||||
strict_types: bool = True,
|
||||
# hide_tb: bool = False,
|
||||
) -> None:
|
||||
'''
|
||||
Send a msgpack coded blob-as-msg over TCP.
|
||||
Send a msgpack encoded py-object-blob-as-msg over TCP.
|
||||
|
||||
If `strict_types == True` then a `MsgTypeError` will be raised on any
|
||||
invalid msg type
|
||||
|
||||
'''
|
||||
# __tracebackhide__: bool = hide_tb
|
||||
|
@ -321,25 +374,40 @@ class MsgpackTCPStream(MsgTransport):
|
|||
# NOTE: lookup the `trio.Task.context`'s var for
|
||||
# the current `MsgCodec`.
|
||||
codec: MsgCodec = _ctxvar_MsgCodec.get()
|
||||
# if self._codec != codec:
|
||||
|
||||
# TODO: mask out before release?
|
||||
if self._codec.pld_spec != codec.pld_spec:
|
||||
self._codec = codec
|
||||
log.critical(
|
||||
'.send() using NEW CODEC !?!?\n'
|
||||
f'{self._codec}\n\n'
|
||||
f'OBJ -> {msg}\n'
|
||||
log.runtime(
|
||||
'Using new codec in {self}.send()\n'
|
||||
f'codec: {self._codec}\n\n'
|
||||
f'msg: {msg}\n'
|
||||
)
|
||||
if type(msg) not in types.__spec__:
|
||||
log.warning(
|
||||
'Sending non-`Msg`-spec msg?\n\n'
|
||||
f'{msg}\n'
|
||||
)
|
||||
bytes_data: bytes = codec.encode(msg)
|
||||
|
||||
if type(msg) not in msgtypes.__msg_types__:
|
||||
if strict_types:
|
||||
_raise_msg_type_err(
|
||||
msg,
|
||||
codec=codec,
|
||||
)
|
||||
else:
|
||||
log.warning(
|
||||
'Sending non-`Msg`-spec msg?\n\n'
|
||||
f'{msg}\n'
|
||||
)
|
||||
|
||||
try:
|
||||
bytes_data: bytes = codec.encode(msg)
|
||||
except TypeError as typerr:
|
||||
raise MsgTypeError(
|
||||
'A msg field violates the current spec\n'
|
||||
f'{codec.pld_spec}\n\n'
|
||||
f'{pretty_struct.Struct.pformat(msg)}'
|
||||
) from typerr
|
||||
|
||||
# supposedly the fastest says,
|
||||
# https://stackoverflow.com/a/54027962
|
||||
size: bytes = struct.pack("<I", len(bytes_data))
|
||||
|
||||
return await self.stream.send_all(size + bytes_data)
|
||||
|
||||
@property
|
||||
|
@ -567,7 +635,6 @@ class Channel:
|
|||
f'{pformat(payload)}\n'
|
||||
) # type: ignore
|
||||
assert self._transport
|
||||
|
||||
await self._transport.send(
|
||||
payload,
|
||||
# hide_tb=hide_tb,
|
||||
|
@ -577,6 +644,11 @@ class Channel:
|
|||
assert self._transport
|
||||
return await self._transport.recv()
|
||||
|
||||
# TODO: auto-reconnect features like 0mq/nanomsg?
|
||||
# -[ ] implement it manually with nods to SC prot
|
||||
# possibly on multiple transport backends?
|
||||
# -> seems like that might be re-inventing scalability
|
||||
# prots tho no?
|
||||
# try:
|
||||
# return await self._transport.recv()
|
||||
# except trio.BrokenResourceError:
|
||||
|
|
Loading…
Reference in New Issue