From df548257adc2a7f7252341a1e031fd1f28e1fdc1 Mon Sep 17 00:00:00 2001
From: Tyler Goodlet
Date: Sat, 13 Apr 2024 15:19:08 -0400
Subject: [PATCH] IPC ctx refinements around `MsgTypeError` awareness

Add a bit of special handling for msg-type-errors with a dedicated
log-msg detailing which `.side: str` is the sender/causer and avoiding
a `._scope.cancel()` call in such cases since the local task might be
written to handle and tolerate the badly (typed) IPC msg.

As part of ^, change the ctx task-pair "side" semantics from
"caller" -> "callee" to be "parent" -> "child" which better matches the
cross-process SC-linked-task supervision hierarchy, and
`trio.Nursery.parent_task`; in `trio` the task that opens a nursery is
also named the "parent".

Impl deats / fixes around the `.side` semantics:
- ensure that `._portal: Portal` is set ASAP after
  `Actor.start_remote_task()` such that if the `Started` transaction
  fails, the parent-vs.-child sides are still denoted correctly (since
  `._portal` being set is the predicate for that).
- add a helper func `Context.peer_side(side: str) -> str:` which
  inverts from "child" to "parent" and vice versa, useful for logging
  info.

Other tweaks:
- make `_drain_to_final_msg()` return a tuple of a maybe-`Return` and
  the list of other `pre_result_drained: list[MsgType]` such that we
  don't ever have to warn about the return msg getting captured as a
  pre-"result" msg.
- Add some strictness flags to `.started()` which allow for toggling
  whether to error or warn log about mismatching roundtripped `Started`
  msgs prior to IPC transit.
---
 tractor/_context.py | 179 ++++++++++++++++++++++++++++++++------------
 1 file changed, 132 insertions(+), 47 deletions(-)

diff --git a/tractor/_context.py b/tractor/_context.py
index 69f28ac..fc16289 100644
--- a/tractor/_context.py
+++ b/tractor/_context.py
@@ -47,6 +47,7 @@ import trio
 from ._exceptions import (
     ContextCancelled,
     InternalError,
+    MsgTypeError,
     RemoteActorError,
     StreamOverrun,
     pack_from_raise,
@@ -59,12 +60,14 @@ from .msg import (
     MsgType,
     MsgCodec,
     NamespacePath,
+    PayloadT,
     Return,
     Started,
     Stop,
     Yield,
     current_codec,
     pretty_struct,
+    types as msgtypes,
 )
 from ._ipc import Channel
 from ._streaming import MsgStream
@@ -88,7 +91,10 @@ async def _drain_to_final_msg(
     hide_tb: bool = True,
     msg_limit: int = 6,

-) -> list[dict]:
+) -> tuple[
+    Return|None,
+    list[MsgType]
+]:
     '''
     Drain IPC msgs delivered to the underlying rx-mem-chan
     `Context._recv_chan` from the runtime in search for a final
@@ -109,6 +115,7 @@ async def _drain_to_final_msg(
     # basically ignoring) any bi-dir-stream msgs still in transit
     # from the far end.
     pre_result_drained: list[MsgType] = []
+    return_msg: Return|None = None
     while not (
         ctx.maybe_error
         and not ctx._final_result_is_set()
@@ -169,8 +176,6 @@ async def _drain_to_final_msg(
             # pray to the `trio` gawds that we're corrent with this
             # msg: dict = await ctx._recv_chan.receive()
             msg: MsgType = await ctx._recv_chan.receive()
-            # always capture unexpected/non-result msgs
-            pre_result_drained.append(msg)

             # NOTE: we get here if the far end was
             # `ContextCancelled` in 2 cases:
@@ -207,11 +212,13 @@ async def _drain_to_final_msg(
                     # if ctx._recv_chan:
                     #     await ctx._recv_chan.aclose()
                     # TODO: ^ we don't need it right?
+                    return_msg = msg
                     break

                 # far end task is still streaming to us so discard
                 # and report depending on local ctx state.
                 case Yield():
+                    pre_result_drained.append(msg)
                     if (
                         (ctx._stream.closed
                          and (reason := 'stream was already closed')
                         )
@@ -236,7 +243,10 @@ async def _drain_to_final_msg(
                             f'{pformat(msg)}\n'
                         )
-                        return pre_result_drained
+                        return (
+                            return_msg,
+                            pre_result_drained,
+                        )

                 # drain up to the `msg_limit` hoping to get
                 # a final result or error/ctxc.
@@ -260,6 +270,7 @@ async def _drain_to_final_msg(
                 # -[ ] should be a runtime error if a stream is open right?
                 # Stop()
                 case Stop():
+                    pre_result_drained.append(msg)
                     log.cancel(
                         'Remote stream terminated due to "stop" msg:\n\n'
                         f'{pformat(msg)}\n'
                     )
@@ -269,7 +280,6 @@ async def _drain_to_final_msg(
                 # remote error msg, likely already handled inside
                 # `Context._deliver_msg()`
                 case Error():
-
                     # TODO: can we replace this with `ctx.maybe_raise()`?
                     # -[ ] would this be handier for this case maybe?
                     # async with maybe_raise_on_exit() as raises:
@@ -336,6 +346,7 @@ async def _drain_to_final_msg(
                 # XXX should pretty much never get here unless someone
                 # overrides the default `MsgType` spec.
                 case _:
+                    pre_result_drained.append(msg)
                     # It's definitely an internal error if any other
                     # msg type without a`'cid'` field arrives here!
                     if not msg.cid:
@@ -352,7 +363,10 @@ async def _drain_to_final_msg(
         f'{ctx.outcome}\n'
     )
-    return pre_result_drained
+    return (
+        return_msg,
+        pre_result_drained,
+    )


 class Unresolved:
@@ -719,21 +733,36 @@ class Context:
         Return string indicating which task this instance is wrapping.

         '''
-        return 'caller' if self._portal else 'callee'
+        return 'parent' if self._portal else 'child'

+    @staticmethod
+    def peer_side(side: str) -> str:
+        match side:
+            case 'child':
+                return 'parent'
+            case 'parent':
+                return 'child'
+
+    # TODO: remove stat!
+    # -[ ] re-implement the `.experiemental._pubsub` stuff
+    #     with `MsgStream` and that should be last usage?
+    # -[ ] remove from `tests/legacy_one_way_streaming.py`!
     async def send_yield(
         self,
         data: Any,
-
     ) -> None:
+        '''
+        Deprecated method for what now is implemented in `MsgStream`.
+
+        We need to rework / remove some stuff tho, see above.
+
+        '''
         warnings.warn(
             "`Context.send_yield()` is now deprecated. "
             "Use ``MessageStream.send()``. ",
             DeprecationWarning,
             stacklevel=2,
         )
-        # await self.chan.send({'yield': data, 'cid': self.cid})
         await self.chan.send(
             Yield(
                 cid=self.cid,
@@ -742,12 +771,11 @@ class Context:
         )

     async def send_stop(self) -> None:
-        # await pause()
-        # await self.chan.send({
-        #     # Stop(
-        #     'stop': True,
-        #     'cid': self.cid
-        # })
+        '''
+        Terminate a `MsgStream` dialog-phase by sending the IPC
+        equiv of a `StopIteration`.
+
+        '''
         await self.chan.send(
             Stop(cid=self.cid)
         )
@@ -843,6 +871,7 @@ class Context:
         # self-cancel (ack) or,
         # peer propagated remote cancellation.
+        msgtyperr: bool = False
         if isinstance(error, ContextCancelled):

             whom: str = (
@@ -854,6 +883,16 @@ class Context:
                 f'{error}'
             )

+        elif isinstance(error, MsgTypeError):
+            msgtyperr = True
+            peer_side: str = self.peer_side(self.side)
+            log.error(
+                f'IPC dialog error due to msg-type caused by {peer_side!r} side\n\n'
+
+                f'{error}\n'
+                f'{pformat(self)}\n'
+            )
+
         else:
             log.error(
                 f'Remote context error:\n\n'
@@ -894,9 +933,9 @@ class Context:
             # if `._cancel_called` then `.cancel_acked and .cancel_called`
             # always should be set.
             and not self._is_self_cancelled()
-            and not cs.cancel_called
             and not cs.cancelled_caught
+            and not msgtyperr
         ):
             # TODO: it'd sure be handy to inject our own
             # `trio.Cancelled` subtype here ;)
@@ -1001,7 +1040,7 @@ class Context:
         # when the runtime finally receives it during teardown
         # (normally in `.result()` called from
         # `Portal.open_context().__aexit__()`)
-        if side == 'caller':
+        if side == 'parent':
             if not self._portal:
                 raise InternalError(
                     'No portal found!?\n'
@@ -1423,7 +1462,10 @@ class Context:
             # wait for a final context result/error by "draining"
             # (by more or less ignoring) any bi-dir-stream "yield"
             # msgs still in transit from the far end.
-            drained_msgs: list[dict] = await _drain_to_final_msg(
+            (
+                return_msg,
+                drained_msgs,
+            ) = await _drain_to_final_msg(
                 ctx=self,
                 hide_tb=hide_tb,
             )
@@ -1441,7 +1483,10 @@ class Context:

             log.cancel(
                 'Ctx drained pre-result msgs:\n'
-                f'{pformat(drained_msgs)}'
+                f'{pformat(drained_msgs)}\n\n'
+
+                f'Final return msg:\n'
+                f'{return_msg}\n'
             )

             self.maybe_raise(
@@ -1608,7 +1653,13 @@ class Context:

     async def started(
         self,
-        value: Any | None = None
+
+        # TODO: how to type this so that it's the
+        # same as the payload type? Is this enough?
+        value: PayloadT|None = None,
+
+        strict_parity: bool = False,
+        complain_no_parity: bool = True,

     ) -> None:
         '''
@@ -1629,7 +1680,7 @@ class Context:
             f'called `.started()` twice on context with {self.chan.uid}'
         )

-        started = Started(
+        started_msg = Started(
             cid=self.cid,
             pld=value,
         )
@@ -1650,28 +1701,54 @@ class Context:
         # https://zguide.zeromq.org/docs/chapter7/#The-Cheap-or-Nasty-Pattern
         # codec: MsgCodec = current_codec()
-        msg_bytes: bytes = codec.encode(started)
+        msg_bytes: bytes = codec.encode(started_msg)
         try:
             # be a "cheap" dialog (see above!)
-            rt_started = codec.decode(msg_bytes)
-            if rt_started != started:
+            if (
+                strict_parity
+                or
+                complain_no_parity
+            ):
+                rt_started: Started = codec.decode(msg_bytes)

-                # TODO: break these methods out from the struct subtype?
-                diff = pretty_struct.Struct.__sub__(rt_started, started)
+                # XXX something is prolly totes cucked with the
+                # codec state!
+                if isinstance(rt_started, dict):
+                    rt_started = msgtypes.from_dict_msg(
+                        dict_msg=rt_started,
+                    )
+                    raise RuntimeError(
+                        'Failed to roundtrip `Started` msg?\n'
+                        f'{pformat(rt_started)}\n'
+                    )

-                complaint: str = (
-                    'Started value does not match after codec rountrip?\n\n'
-                    f'{diff}'
-                )
-                # TODO: rn this will pretty much always fail with
-                # any other sequence type embeded in the
-                # payload...
-                if self._strict_started:
-                    raise ValueError(complaint)
-                else:
-                    log.warning(complaint)
+                if rt_started != started_msg:
+                    # TODO: break these methods out from the struct subtype?

-            await self.chan.send(rt_started)
+                    diff = pretty_struct.Struct.__sub__(
+                        rt_started,
+                        started_msg,
+                    )
+                    complaint: str = (
+                        'Started value does not match after codec rountrip?\n\n'
+                        f'{diff}'
+                    )
+
+                    # TODO: rn this will pretty much always fail with
+                    # any other sequence type embeded in the
+                    # payload...
+                    if (
+                        self._strict_started
+                        or
+                        strict_parity
+                    ):
+                        raise ValueError(complaint)
+                    else:
+                        log.warning(complaint)
+
+            # started_msg = rt_started
+
+            await self.chan.send(started_msg)

         # raise any msg type error NO MATTER WHAT!
         except msgspec.ValidationError as verr:
@@ -1682,7 +1759,7 @@ class Context:
                 src_validation_error=verr,
                 verb_header='Trying to send payload'
                 # > 'invalid `Started IPC msgs\n'
-            )
+            ) from verr

         self._started_called = True

@@ -1783,13 +1860,17 @@ class Context:
         else:
             log_meth = log.runtime

-        log_meth(
-            f'Delivering error-msg to caller\n\n'
+        side: str = self.side

-            f'<= peer: {from_uid}\n'
+        peer_side: str = self.peer_side(side)
+
+        log_meth(
+            f'Delivering IPC ctx error from {peer_side!r} to {side!r} task\n\n'
+
+            f'<= peer {peer_side!r}: {from_uid}\n'
             f' |_ {nsf}()\n\n'

-            f'=> cid: {cid}\n'
+            f'=> {side!r} cid: {cid}\n'
             f' |_{self._task}\n\n'

             f'{pformat(re)}\n'
@@ -1804,6 +1885,7 @@ class Context:
             self._maybe_cancel_and_set_remote_error(re)

         # XXX only case where returning early is fine!
+        structfmt = pretty_struct.Struct.pformat
         if self._in_overrun:
             log.warning(
                 f'Queueing OVERRUN msg on caller task:\n'
@@ -1813,7 +1895,7 @@ class Context:
                 f'=> cid: {cid}\n'
                 f' |_{self._task}\n\n'

-                f'{pformat(msg)}\n'
+                f'{structfmt(msg)}\n'
             )
             self._overflow_q.append(msg)
             return False
@@ -1827,7 +1909,7 @@ class Context:
             f'=> {self._task}\n'
             f' |_cid={self.cid}\n\n'

-            f'{pformat(msg)}\n'
+            f'{structfmt(msg)}\n'
         )

         # NOTE: if an error is deteced we should always still
@@ -2047,6 +2129,9 @@ async def open_context_from_portal(
         # place..
         allow_overruns=allow_overruns,
     )
+    # ASAP, so that `Context.side: str` can be determined for
+    # logging / tracing / debug!
+    ctx._portal: Portal = portal

     assert ctx._remote_func_type == 'context'
     msg: Started = await ctx._recv_chan.receive()
@@ -2065,10 +2150,10 @@ async def open_context_from_portal(
             msg=msg,
             src_err=src_error,
             log=log,
-            expect_key='started',
+            expect_msg=Started,
+            # expect_key='started',
         )

-    ctx._portal: Portal = portal
     uid: tuple = portal.channel.uid
     cid: str = ctx.cid
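
Not part of the patch, just a review aid: a minimal usage sketch of the
semantics touched above, i.e. the parent/child `.side` naming, the
`Context.peer_side()` helper and the new `.started()` parity flags. The
child fn (`echo_child`) and actor name (`'echoer'`) are made up for
illustration; the flag names and side strings follow the hunks above.

import tractor
import trio


@tractor.context
async def echo_child(
    ctx: tractor.Context,
    value: dict,
) -> None:
    # this end is the 'child' side: no `._portal` is set here.
    # `strict_parity=True` asks for a hard error (vs. the default
    # warning) if the `Started` msg fails to roundtrip through the
    # current codec before being sent.
    await ctx.started(
        value,
        strict_parity=True,
    )


async def main() -> None:
    async with tractor.open_nursery() as an:
        portal = await an.start_actor(
            'echoer',
            enable_modules=[__name__],
        )
        async with portal.open_context(
            echo_child,
            value={'some': 'payload'},
        ) as (ctx, first):
            # this end is the 'parent' side since `._portal` is set
            # ASAP after `Actor.start_remote_task()`.
            assert first == {'some': 'payload'}
            assert ctx.side == 'parent'
            assert ctx.peer_side(ctx.side) == 'child'

        await portal.cancel_actor()


if __name__ == '__main__':
    trio.run(main)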