IPC ctx refinements around `MsgTypeError` awareness

Add a bit of special handling for msg-type-errors with a dedicated
log-msg detailing which `.side: str` is the sender/causer and avoiding
a `._scope.cancel()` call in such cases since the local task might be
written to handle and tolerate the badly-typed IPC msg.

As part of ^, change the ctx task-pair "side" semantics from "caller" ->
"callee" to be "parent" -> "child" which better matches the
cross-process SC-linked-task supervision hierarchy, and
`trio.Nursery.parent_task`; in `trio` the task that opens a nursery is
also named the "parent".

Implementation details / fixes around the `.side` semantics:
- ensure that `._portal: Portal` is set ASAP after
  `Actor.start_remote_task()` such that if the `Started` transaction
  fails, the parent-vs.-child sides are still denoted correctly (since
  `._portal` being set is the predicate for that).
- add a helper func `Context.peer_side(side: str) -> str:` which inverts
  from "child" to "parent" and vice versa, useful for logging info.

Other tweaks:
- make `_drain_to_final_msg()` return a tuple of a maybe-`Return` and
  the list of other `pre_result_drained: list[MsgType]` such that we
  don't ever have to warn about the return msg getting captured as
  a pre-"result" msg.
- Add some strictness flags to `.started()` which allow for toggling
  whether to error or warn log about mismatching roundtripped `Started`
  msgs prior to IPC transit.
runtime_to_msgspec
Tyler Goodlet 2024-04-13 15:19:08 -04:00
parent 3fb3608879
commit df548257ad
1 changed files with 132 additions and 47 deletions

View File

@ -47,6 +47,7 @@ import trio
from ._exceptions import (
ContextCancelled,
InternalError,
MsgTypeError,
RemoteActorError,
StreamOverrun,
pack_from_raise,
@ -59,12 +60,14 @@ from .msg import (
MsgType,
MsgCodec,
NamespacePath,
PayloadT,
Return,
Started,
Stop,
Yield,
current_codec,
pretty_struct,
types as msgtypes,
)
from ._ipc import Channel
from ._streaming import MsgStream
@ -88,7 +91,10 @@ async def _drain_to_final_msg(
hide_tb: bool = True,
msg_limit: int = 6,
) -> list[dict]:
) -> tuple[
Return|None,
list[MsgType]
]:
'''
Drain IPC msgs delivered to the underlying rx-mem-chan
`Context._recv_chan` from the runtime in search for a final
@ -109,6 +115,7 @@ async def _drain_to_final_msg(
# basically ignoring) any bi-dir-stream msgs still in transit
# from the far end.
pre_result_drained: list[MsgType] = []
return_msg: Return|None = None
while not (
ctx.maybe_error
and not ctx._final_result_is_set()
@ -169,8 +176,6 @@ async def _drain_to_final_msg(
# pray to the `trio` gawds that we're correct with this
# msg: dict = await ctx._recv_chan.receive()
msg: MsgType = await ctx._recv_chan.receive()
# always capture unexpected/non-result msgs
pre_result_drained.append(msg)
# NOTE: we get here if the far end was
# `ContextCancelled` in 2 cases:
@ -207,11 +212,13 @@ async def _drain_to_final_msg(
# if ctx._recv_chan:
# await ctx._recv_chan.aclose()
# TODO: ^ we don't need it right?
return_msg = msg
break
# far end task is still streaming to us so discard
# and report depending on local ctx state.
case Yield():
pre_result_drained.append(msg)
if (
(ctx._stream.closed
and (reason := 'stream was already closed')
@ -236,7 +243,10 @@ async def _drain_to_final_msg(
f'{pformat(msg)}\n'
)
return pre_result_drained
return (
return_msg,
pre_result_drained,
)
# drain up to the `msg_limit` hoping to get
# a final result or error/ctxc.
@ -260,6 +270,7 @@ async def _drain_to_final_msg(
# -[ ] should be a runtime error if a stream is open right?
# Stop()
case Stop():
pre_result_drained.append(msg)
log.cancel(
'Remote stream terminated due to "stop" msg:\n\n'
f'{pformat(msg)}\n'
@ -269,7 +280,6 @@ async def _drain_to_final_msg(
# remote error msg, likely already handled inside
# `Context._deliver_msg()`
case Error():
# TODO: can we replace this with `ctx.maybe_raise()`?
# -[ ] would this be handier for this case maybe?
# async with maybe_raise_on_exit() as raises:
@ -336,6 +346,7 @@ async def _drain_to_final_msg(
# XXX should pretty much never get here unless someone
# overrides the default `MsgType` spec.
case _:
pre_result_drained.append(msg)
# It's definitely an internal error if any other
# msg type without a`'cid'` field arrives here!
if not msg.cid:
@ -352,7 +363,10 @@ async def _drain_to_final_msg(
f'{ctx.outcome}\n'
)
return pre_result_drained
return (
return_msg,
pre_result_drained,
)
class Unresolved:
@ -719,21 +733,36 @@ class Context:
Return string indicating which task this instance is wrapping.
'''
return 'caller' if self._portal else 'callee'
return 'parent' if self._portal else 'child'
@staticmethod
def peer_side(side: str) -> str:
match side:
case 'child':
return 'parent'
case 'parent':
return 'child'
# TODO: remove stat!
# -[ ] re-implement the `.experiemental._pubsub` stuff
# with `MsgStream` and that should be last usage?
# -[ ] remove from `tests/legacy_one_way_streaming.py`!
async def send_yield(
self,
data: Any,
) -> None:
'''
Deprecated method for what now is implemented in `MsgStream`.
We need to rework / remove some stuff tho, see above.
'''
warnings.warn(
"`Context.send_yield()` is now deprecated. "
"Use ``MessageStream.send()``. ",
DeprecationWarning,
stacklevel=2,
)
# await self.chan.send({'yield': data, 'cid': self.cid})
await self.chan.send(
Yield(
cid=self.cid,
@ -742,12 +771,11 @@ class Context:
)
async def send_stop(self) -> None:
# await pause()
# await self.chan.send({
# # Stop(
# 'stop': True,
# 'cid': self.cid
# })
'''
Terminate a `MsgStream` dialog-phase by sending the IPC
equiv of a `StopIteration`.
'''
await self.chan.send(
Stop(cid=self.cid)
)
@ -843,6 +871,7 @@ class Context:
# self-cancel (ack) or,
# peer propagated remote cancellation.
msgtyperr: bool = False
if isinstance(error, ContextCancelled):
whom: str = (
@ -854,6 +883,16 @@ class Context:
f'{error}'
)
elif isinstance(error, MsgTypeError):
msgtyperr = True
peer_side: str = self.peer_side(self.side)
log.error(
f'IPC dialog error due to msg-type caused by {peer_side!r} side\n\n'
f'{error}\n'
f'{pformat(self)}\n'
)
else:
log.error(
f'Remote context error:\n\n'
@ -894,9 +933,9 @@ class Context:
# if `._cancel_called` then `.cancel_acked and .cancel_called`
# always should be set.
and not self._is_self_cancelled()
and not cs.cancel_called
and not cs.cancelled_caught
and not msgtyperr
):
# TODO: it'd sure be handy to inject our own
# `trio.Cancelled` subtype here ;)
@ -1001,7 +1040,7 @@ class Context:
# when the runtime finally receives it during teardown
# (normally in `.result()` called from
# `Portal.open_context().__aexit__()`)
if side == 'caller':
if side == 'parent':
if not self._portal:
raise InternalError(
'No portal found!?\n'
@ -1423,7 +1462,10 @@ class Context:
# wait for a final context result/error by "draining"
# (by more or less ignoring) any bi-dir-stream "yield"
# msgs still in transit from the far end.
drained_msgs: list[dict] = await _drain_to_final_msg(
(
return_msg,
drained_msgs,
) = await _drain_to_final_msg(
ctx=self,
hide_tb=hide_tb,
)
@ -1441,7 +1483,10 @@ class Context:
log.cancel(
'Ctx drained pre-result msgs:\n'
f'{pformat(drained_msgs)}'
f'{pformat(drained_msgs)}\n\n'
f'Final return msg:\n'
f'{return_msg}\n'
)
self.maybe_raise(
@ -1608,7 +1653,13 @@ class Context:
async def started(
self,
value: Any | None = None
# TODO: how to type this so that it's the
# same as the payload type? Is this enough?
value: PayloadT|None = None,
strict_parity: bool = False,
complain_no_parity: bool = True,
) -> None:
'''
@ -1629,7 +1680,7 @@ class Context:
f'called `.started()` twice on context with {self.chan.uid}'
)
started = Started(
started_msg = Started(
cid=self.cid,
pld=value,
)
@ -1650,28 +1701,54 @@ class Context:
# https://zguide.zeromq.org/docs/chapter7/#The-Cheap-or-Nasty-Pattern
#
codec: MsgCodec = current_codec()
msg_bytes: bytes = codec.encode(started)
msg_bytes: bytes = codec.encode(started_msg)
try:
# be a "cheap" dialog (see above!)
rt_started = codec.decode(msg_bytes)
if rt_started != started:
if (
strict_parity
or
complain_no_parity
):
rt_started: Started = codec.decode(msg_bytes)
# XXX something is prolly totes cucked with the
# codec state!
if isinstance(rt_started, dict):
rt_started = msgtypes.from_dict_msg(
dict_msg=rt_started,
)
raise RuntimeError(
'Failed to roundtrip `Started` msg?\n'
f'{pformat(rt_started)}\n'
)
if rt_started != started_msg:
# TODO: break these methods out from the struct subtype?
diff = pretty_struct.Struct.__sub__(rt_started, started)
diff = pretty_struct.Struct.__sub__(
rt_started,
started_msg,
)
complaint: str = (
'Started value does not match after codec rountrip?\n\n'
f'{diff}'
)
# TODO: rn this will pretty much always fail with
# any other sequence type embedded in the
# payload...
if self._strict_started:
if (
self._strict_started
or
strict_parity
):
raise ValueError(complaint)
else:
log.warning(complaint)
await self.chan.send(rt_started)
# started_msg = rt_started
await self.chan.send(started_msg)
# raise any msg type error NO MATTER WHAT!
except msgspec.ValidationError as verr:
@ -1682,7 +1759,7 @@ class Context:
src_validation_error=verr,
verb_header='Trying to send payload'
# > 'invalid `Started IPC msgs\n'
)
) from verr
self._started_called = True
@ -1783,13 +1860,17 @@ class Context:
else:
log_meth = log.runtime
log_meth(
f'Delivering error-msg to caller\n\n'
side: str = self.side
f'<= peer: {from_uid}\n'
peer_side: str = self.peer_side(side)
log_meth(
f'Delivering IPC ctx error from {peer_side!r} to {side!r} task\n\n'
f'<= peer {peer_side!r}: {from_uid}\n'
f' |_ {nsf}()\n\n'
f'=> cid: {cid}\n'
f'=> {side!r} cid: {cid}\n'
f' |_{self._task}\n\n'
f'{pformat(re)}\n'
@ -1804,6 +1885,7 @@ class Context:
self._maybe_cancel_and_set_remote_error(re)
# XXX only case where returning early is fine!
structfmt = pretty_struct.Struct.pformat
if self._in_overrun:
log.warning(
f'Queueing OVERRUN msg on caller task:\n'
@ -1813,7 +1895,7 @@ class Context:
f'=> cid: {cid}\n'
f' |_{self._task}\n\n'
f'{pformat(msg)}\n'
f'{structfmt(msg)}\n'
)
self._overflow_q.append(msg)
return False
@ -1827,7 +1909,7 @@ class Context:
f'=> {self._task}\n'
f' |_cid={self.cid}\n\n'
f'{pformat(msg)}\n'
f'{structfmt(msg)}\n'
)
# NOTE: if an error is detected we should always still
@ -2047,6 +2129,9 @@ async def open_context_from_portal(
# place..
allow_overruns=allow_overruns,
)
# ASAP, so that `Context.side: str` can be determined for
# logging / tracing / debug!
ctx._portal: Portal = portal
assert ctx._remote_func_type == 'context'
msg: Started = await ctx._recv_chan.receive()
@ -2065,10 +2150,10 @@ async def open_context_from_portal(
msg=msg,
src_err=src_error,
log=log,
expect_key='started',
expect_msg=Started,
# expect_key='started',
)
ctx._portal: Portal = portal
uid: tuple = portal.channel.uid
cid: str = ctx.cid