Factor `MsgpackTCPStream` msg-type checks
Add both the `.send()` and `.recv()` handling blocks to a common `_raise_msg_type_err()` which includes detailed error msg formatting: - the `.recv()` side case does introspection of the `Msg` fields and attempting to report the exact (field type related) issue - `.send()` side does some boxed-error style tb formatting like `RemoteActorError`. - add a `strict_types: bool` to `.send()` to allow for just warning on bad inputs versus raising, but always raise from any `Encoder` type error.runtime_to_msgspec
parent
97bfbdbc1c
commit
4cfe4979ff
164
tractor/_ipc.py
164
tractor/_ipc.py
|
@ -54,7 +54,8 @@ from tractor.msg import (
|
||||||
_ctxvar_MsgCodec,
|
_ctxvar_MsgCodec,
|
||||||
_codec,
|
_codec,
|
||||||
MsgCodec,
|
MsgCodec,
|
||||||
types,
|
types as msgtypes,
|
||||||
|
pretty_struct,
|
||||||
)
|
)
|
||||||
|
|
||||||
log = get_logger(__name__)
|
log = get_logger(__name__)
|
||||||
|
@ -72,6 +73,7 @@ def get_stream_addrs(stream: trio.SocketStream) -> tuple:
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: this should be our `Union[*msgtypes.__spec__]` now right?
|
||||||
MsgType = TypeVar("MsgType")
|
MsgType = TypeVar("MsgType")
|
||||||
|
|
||||||
# TODO: consider using a generic def and indexing with our eventual
|
# TODO: consider using a generic def and indexing with our eventual
|
||||||
|
@ -116,6 +118,73 @@ class MsgTransport(Protocol[MsgType]):
|
||||||
...
|
...
|
||||||
|
|
||||||
|
|
||||||
|
def _raise_msg_type_err(
|
||||||
|
msg: Any|bytes,
|
||||||
|
codec: MsgCodec,
|
||||||
|
validation_err: msgspec.ValidationError|None = None,
|
||||||
|
verb_header: str = '',
|
||||||
|
|
||||||
|
) -> None:
|
||||||
|
|
||||||
|
# if side == 'send':
|
||||||
|
if validation_err is None: # send-side
|
||||||
|
|
||||||
|
import traceback
|
||||||
|
from tractor._exceptions import pformat_boxed_tb
|
||||||
|
|
||||||
|
fmt_spec: str = '\n'.join(
|
||||||
|
map(str, codec.msg_spec.__args__)
|
||||||
|
)
|
||||||
|
fmt_stack: str = (
|
||||||
|
'\n'.join(traceback.format_stack(limit=3))
|
||||||
|
)
|
||||||
|
tb_fmt: str = pformat_boxed_tb(
|
||||||
|
tb_str=fmt_stack,
|
||||||
|
# fields_str=header,
|
||||||
|
field_prefix=' ',
|
||||||
|
indent='',
|
||||||
|
)
|
||||||
|
raise MsgTypeError(
|
||||||
|
f'invalid msg -> {msg}: {type(msg)}\n\n'
|
||||||
|
f'{tb_fmt}\n'
|
||||||
|
f'Valid IPC msgs are:\n\n'
|
||||||
|
# f' ------ - ------\n'
|
||||||
|
f'{fmt_spec}\n'
|
||||||
|
)
|
||||||
|
|
||||||
|
else:
|
||||||
|
# decode the msg-bytes using the std msgpack
|
||||||
|
# interchange-prot (i.e. without any
|
||||||
|
# `msgspec.Struct` handling) so that we can
|
||||||
|
# determine what `.msg.types.Msg` is the culprit
|
||||||
|
# by reporting the received value.
|
||||||
|
msg_dict: dict = msgspec.msgpack.decode(msg)
|
||||||
|
msg_type_name: str = msg_dict['msg_type']
|
||||||
|
msg_type = getattr(msgtypes, msg_type_name)
|
||||||
|
errmsg: str = (
|
||||||
|
f'invalid `{msg_type_name}` IPC msg\n\n'
|
||||||
|
)
|
||||||
|
if verb_header:
|
||||||
|
errmsg = f'{verb_header} ' + errmsg
|
||||||
|
|
||||||
|
# XXX see if we can determine the exact invalid field
|
||||||
|
# such that we can comprehensively report the
|
||||||
|
# specific field's type problem
|
||||||
|
msgspec_msg: str = validation_err.args[0].rstrip('`')
|
||||||
|
msg, _, maybe_field = msgspec_msg.rpartition('$.')
|
||||||
|
if field_val := msg_dict.get(maybe_field):
|
||||||
|
field_type: Union[Type] = msg_type.__signature__.parameters[
|
||||||
|
maybe_field
|
||||||
|
].annotation
|
||||||
|
errmsg += (
|
||||||
|
f'{msg.rstrip("`")}\n\n'
|
||||||
|
f'{msg_type}\n'
|
||||||
|
f' |_.{maybe_field}: {field_type} = {field_val!r}\n'
|
||||||
|
)
|
||||||
|
|
||||||
|
raise MsgTypeError(errmsg) from validation_err
|
||||||
|
|
||||||
|
|
||||||
# TODO: not sure why we have to inherit here, but it seems to be an
|
# TODO: not sure why we have to inherit here, but it seems to be an
|
||||||
# issue with ``get_msg_transport()`` returning a ``Type[Protocol]``;
|
# issue with ``get_msg_transport()`` returning a ``Type[Protocol]``;
|
||||||
# probably should make a `mypy` issue?
|
# probably should make a `mypy` issue?
|
||||||
|
@ -175,9 +244,10 @@ class MsgpackTCPStream(MsgTransport):
|
||||||
or
|
or
|
||||||
_codec._ctxvar_MsgCodec.get()
|
_codec._ctxvar_MsgCodec.get()
|
||||||
)
|
)
|
||||||
log.critical(
|
# TODO: mask out before release?
|
||||||
'!?!: USING STD `tractor` CODEC !?!?\n'
|
log.runtime(
|
||||||
f'{self._codec}\n'
|
f'New {self} created with codec\n'
|
||||||
|
f'codec: {self._codec}\n'
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _iter_packets(self) -> AsyncGenerator[dict, None]:
|
async def _iter_packets(self) -> AsyncGenerator[dict, None]:
|
||||||
|
@ -221,16 +291,18 @@ class MsgpackTCPStream(MsgTransport):
|
||||||
# NOTE: lookup the `trio.Task.context`'s var for
|
# NOTE: lookup the `trio.Task.context`'s var for
|
||||||
# the current `MsgCodec`.
|
# the current `MsgCodec`.
|
||||||
codec: MsgCodec = _ctxvar_MsgCodec.get()
|
codec: MsgCodec = _ctxvar_MsgCodec.get()
|
||||||
|
|
||||||
|
# TODO: mask out before release?
|
||||||
if self._codec.pld_spec != codec.pld_spec:
|
if self._codec.pld_spec != codec.pld_spec:
|
||||||
# assert (
|
# assert (
|
||||||
# task := trio.lowlevel.current_task()
|
# task := trio.lowlevel.current_task()
|
||||||
# ) is not self._task
|
# ) is not self._task
|
||||||
# self._task = task
|
# self._task = task
|
||||||
self._codec = codec
|
self._codec = codec
|
||||||
log.critical(
|
log.runtime(
|
||||||
'.recv() USING NEW CODEC !?!?\n'
|
'Using new codec in {self}.recv()\n'
|
||||||
f'{self._codec}\n\n'
|
f'codec: {self._codec}\n\n'
|
||||||
f'msg_bytes -> {msg_bytes}\n'
|
f'msg_bytes: {msg_bytes}\n'
|
||||||
)
|
)
|
||||||
yield codec.decode(msg_bytes)
|
yield codec.decode(msg_bytes)
|
||||||
|
|
||||||
|
@ -252,36 +324,13 @@ class MsgpackTCPStream(MsgTransport):
|
||||||
# and always raise such that spec violations
|
# and always raise such that spec violations
|
||||||
# are never allowed to be caught silently!
|
# are never allowed to be caught silently!
|
||||||
except msgspec.ValidationError as verr:
|
except msgspec.ValidationError as verr:
|
||||||
|
# re-raise as type error
|
||||||
# decode the msg-bytes using the std msgpack
|
_raise_msg_type_err(
|
||||||
# interchange-prot (i.e. without any
|
msg=msg_bytes,
|
||||||
# `msgspec.Struct` handling) so that we can
|
codec=codec,
|
||||||
# determine what `.msg.types.Msg` is the culprit
|
validation_err=verr,
|
||||||
# by reporting the received value.
|
|
||||||
msg_dict: dict = msgspec.msgpack.decode(msg_bytes)
|
|
||||||
msg_type_name: str = msg_dict['msg_type']
|
|
||||||
msg_type = getattr(types, msg_type_name)
|
|
||||||
errmsg: str = (
|
|
||||||
f'Received invalid IPC `{msg_type_name}` msg\n\n'
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# XXX see if we can determine the exact invalid field
|
|
||||||
# such that we can comprehensively report the
|
|
||||||
# specific field's type problem
|
|
||||||
msgspec_msg: str = verr.args[0].rstrip('`')
|
|
||||||
msg, _, maybe_field = msgspec_msg.rpartition('$.')
|
|
||||||
if field_val := msg_dict.get(maybe_field):
|
|
||||||
field_type: Union[Type] = msg_type.__signature__.parameters[
|
|
||||||
maybe_field
|
|
||||||
].annotation
|
|
||||||
errmsg += (
|
|
||||||
f'{msg.rstrip("`")}\n\n'
|
|
||||||
f'{msg_type}\n'
|
|
||||||
f' |_.{maybe_field}: {field_type} = {field_val}\n'
|
|
||||||
)
|
|
||||||
|
|
||||||
raise MsgTypeError(errmsg) from verr
|
|
||||||
|
|
||||||
except (
|
except (
|
||||||
msgspec.DecodeError,
|
msgspec.DecodeError,
|
||||||
UnicodeDecodeError,
|
UnicodeDecodeError,
|
||||||
|
@ -307,12 +356,16 @@ class MsgpackTCPStream(MsgTransport):
|
||||||
|
|
||||||
async def send(
|
async def send(
|
||||||
self,
|
self,
|
||||||
msg: Any,
|
msg: msgtypes.Msg,
|
||||||
|
|
||||||
|
strict_types: bool = True,
|
||||||
# hide_tb: bool = False,
|
# hide_tb: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
'''
|
'''
|
||||||
Send a msgpack coded blob-as-msg over TCP.
|
Send a msgpack encoded py-object-blob-as-msg over TCP.
|
||||||
|
|
||||||
|
If `strict_types == True` then a `MsgTypeError` will be raised on any
|
||||||
|
invalid msg type
|
||||||
|
|
||||||
'''
|
'''
|
||||||
# __tracebackhide__: bool = hide_tb
|
# __tracebackhide__: bool = hide_tb
|
||||||
|
@ -321,25 +374,40 @@ class MsgpackTCPStream(MsgTransport):
|
||||||
# NOTE: lookup the `trio.Task.context`'s var for
|
# NOTE: lookup the `trio.Task.context`'s var for
|
||||||
# the current `MsgCodec`.
|
# the current `MsgCodec`.
|
||||||
codec: MsgCodec = _ctxvar_MsgCodec.get()
|
codec: MsgCodec = _ctxvar_MsgCodec.get()
|
||||||
# if self._codec != codec:
|
|
||||||
|
# TODO: mask out before release?
|
||||||
if self._codec.pld_spec != codec.pld_spec:
|
if self._codec.pld_spec != codec.pld_spec:
|
||||||
self._codec = codec
|
self._codec = codec
|
||||||
log.critical(
|
log.runtime(
|
||||||
'.send() using NEW CODEC !?!?\n'
|
'Using new codec in {self}.send()\n'
|
||||||
f'{self._codec}\n\n'
|
f'codec: {self._codec}\n\n'
|
||||||
f'OBJ -> {msg}\n'
|
f'msg: {msg}\n'
|
||||||
)
|
)
|
||||||
if type(msg) not in types.__spec__:
|
|
||||||
|
if type(msg) not in msgtypes.__msg_types__:
|
||||||
|
if strict_types:
|
||||||
|
_raise_msg_type_err(
|
||||||
|
msg,
|
||||||
|
codec=codec,
|
||||||
|
)
|
||||||
|
else:
|
||||||
log.warning(
|
log.warning(
|
||||||
'Sending non-`Msg`-spec msg?\n\n'
|
'Sending non-`Msg`-spec msg?\n\n'
|
||||||
f'{msg}\n'
|
f'{msg}\n'
|
||||||
)
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
bytes_data: bytes = codec.encode(msg)
|
bytes_data: bytes = codec.encode(msg)
|
||||||
|
except TypeError as typerr:
|
||||||
|
raise MsgTypeError(
|
||||||
|
'A msg field violates the current spec\n'
|
||||||
|
f'{codec.pld_spec}\n\n'
|
||||||
|
f'{pretty_struct.Struct.pformat(msg)}'
|
||||||
|
) from typerr
|
||||||
|
|
||||||
# supposedly the fastest says,
|
# supposedly the fastest says,
|
||||||
# https://stackoverflow.com/a/54027962
|
# https://stackoverflow.com/a/54027962
|
||||||
size: bytes = struct.pack("<I", len(bytes_data))
|
size: bytes = struct.pack("<I", len(bytes_data))
|
||||||
|
|
||||||
return await self.stream.send_all(size + bytes_data)
|
return await self.stream.send_all(size + bytes_data)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@ -567,7 +635,6 @@ class Channel:
|
||||||
f'{pformat(payload)}\n'
|
f'{pformat(payload)}\n'
|
||||||
) # type: ignore
|
) # type: ignore
|
||||||
assert self._transport
|
assert self._transport
|
||||||
|
|
||||||
await self._transport.send(
|
await self._transport.send(
|
||||||
payload,
|
payload,
|
||||||
# hide_tb=hide_tb,
|
# hide_tb=hide_tb,
|
||||||
|
@ -577,6 +644,11 @@ class Channel:
|
||||||
assert self._transport
|
assert self._transport
|
||||||
return await self._transport.recv()
|
return await self._transport.recv()
|
||||||
|
|
||||||
|
# TODO: auto-reconnect features like 0mq/nanomsg?
|
||||||
|
# -[ ] implement it manually with nods to SC prot
|
||||||
|
# possibly on multiple transport backends?
|
||||||
|
# -> seems like that might be re-inventing scalability
|
||||||
|
# prots tho no?
|
||||||
# try:
|
# try:
|
||||||
# return await self._transport.recv()
|
# return await self._transport.recv()
|
||||||
# except trio.BrokenResourceError:
|
# except trio.BrokenResourceError:
|
||||||
|
|
Loading…
Reference in New Issue