forked from goodboy/tractor
				
			IPC ctx refinements around `MsgTypeError` awareness
Add a bit of special handling for msg-type-errors with a dedicated log-msg detailing which `.side: str` is the sender/causer and avoiding a `._scope.cancel()` call in such cases since the local task might be written to handle and tolerate the badly (typed) IPC msg. As part of ^, change the ctx task-pair "side" semantics from "caller" -> "callee" to be "parent" -> "child" which better matches the cross-process SC-linked-task supervision hierarchy, and `trio.Nursery.parent_task`; in `trio` the task that opens a nursery is also named the "parent". Impl deats / fixes around the `.side` semantics: - ensure that `._portal: Portal` is set ASAP after `Actor.start_remote_task()` such that if the `Started` transaction fails, the parent-vs.-child sides are still denoted correctly (since `._portal` being set is the predicate for that). - add a helper func `Context.peer_side(side: str) -> str:` which inverts from "child" to "parent" and vice versa, useful for logging info. Other tweaks: - make `_drain_to_final_msg()` return a tuple of a maybe-`Return` and the list of other `pre_result_drained: list[MsgType]` such that we don't ever have to warn about the return msg getting captured as a pre-"result" msg. - Add some strictness flags to `.started()` which allow for toggling whether to error or warn log about mismatching roundtripped `Started` msgs prior to IPC transit.remotes/1757153874605917753/main
							parent
							
								
									8690a88e50
								
							
						
					
					
						commit
						26a3ff6b37
					
				|  | @ -47,6 +47,7 @@ import trio | |||
| from ._exceptions import ( | ||||
|     ContextCancelled, | ||||
|     InternalError, | ||||
|     MsgTypeError, | ||||
|     RemoteActorError, | ||||
|     StreamOverrun, | ||||
|     pack_from_raise, | ||||
|  | @ -59,12 +60,14 @@ from .msg import ( | |||
|     MsgType, | ||||
|     MsgCodec, | ||||
|     NamespacePath, | ||||
|     PayloadT, | ||||
|     Return, | ||||
|     Started, | ||||
|     Stop, | ||||
|     Yield, | ||||
|     current_codec, | ||||
|     pretty_struct, | ||||
|     types as msgtypes, | ||||
| ) | ||||
| from ._ipc import Channel | ||||
| from ._streaming import MsgStream | ||||
|  | @ -88,7 +91,10 @@ async def _drain_to_final_msg( | |||
|     hide_tb: bool = True, | ||||
|     msg_limit: int = 6, | ||||
| 
 | ||||
| ) -> list[dict]: | ||||
| ) -> tuple[ | ||||
|     Return|None, | ||||
|     list[MsgType] | ||||
| ]: | ||||
|     ''' | ||||
|     Drain IPC msgs delivered to the underlying rx-mem-chan | ||||
|     `Context._recv_chan` from the runtime in search for a final | ||||
|  | @ -109,6 +115,7 @@ async def _drain_to_final_msg( | |||
|     # basically ignoring) any bi-dir-stream msgs still in transit | ||||
|     # from the far end. | ||||
|     pre_result_drained: list[MsgType] = [] | ||||
|     return_msg: Return|None = None | ||||
|     while not ( | ||||
|         ctx.maybe_error | ||||
|         and not ctx._final_result_is_set() | ||||
|  | @ -169,8 +176,6 @@ async def _drain_to_final_msg( | |||
|             # pray to the `trio` gawds that we're corrent with this | ||||
|             # msg: dict = await ctx._recv_chan.receive() | ||||
|             msg: MsgType = await ctx._recv_chan.receive() | ||||
|             # always capture unexpected/non-result msgs | ||||
|             pre_result_drained.append(msg) | ||||
| 
 | ||||
|         # NOTE: we get here if the far end was | ||||
|         # `ContextCancelled` in 2 cases: | ||||
|  | @ -207,11 +212,13 @@ async def _drain_to_final_msg( | |||
|                 # if ctx._recv_chan: | ||||
|                 #     await ctx._recv_chan.aclose() | ||||
|                 # TODO: ^ we don't need it right? | ||||
|                 return_msg = msg | ||||
|                 break | ||||
| 
 | ||||
|             # far end task is still streaming to us so discard | ||||
|             # and report depending on local ctx state. | ||||
|             case Yield(): | ||||
|                 pre_result_drained.append(msg) | ||||
|                 if ( | ||||
|                     (ctx._stream.closed | ||||
|                      and (reason := 'stream was already closed') | ||||
|  | @ -236,7 +243,10 @@ async def _drain_to_final_msg( | |||
| 
 | ||||
|                         f'{pformat(msg)}\n' | ||||
|                     ) | ||||
|                     return pre_result_drained | ||||
|                     return ( | ||||
|                         return_msg, | ||||
|                         pre_result_drained, | ||||
|                     ) | ||||
| 
 | ||||
|                 # drain up to the `msg_limit` hoping to get | ||||
|                 # a final result or error/ctxc. | ||||
|  | @ -260,6 +270,7 @@ async def _drain_to_final_msg( | |||
|             # -[ ] should be a runtime error if a stream is open right? | ||||
|             # Stop() | ||||
|             case Stop(): | ||||
|                 pre_result_drained.append(msg) | ||||
|                 log.cancel( | ||||
|                     'Remote stream terminated due to "stop" msg:\n\n' | ||||
|                     f'{pformat(msg)}\n' | ||||
|  | @ -269,7 +280,6 @@ async def _drain_to_final_msg( | |||
|             # remote error msg, likely already handled inside | ||||
|             # `Context._deliver_msg()` | ||||
|             case Error(): | ||||
| 
 | ||||
|                 # TODO: can we replace this with `ctx.maybe_raise()`? | ||||
|                 # -[ ]  would this be handier for this case maybe? | ||||
|                 #     async with maybe_raise_on_exit() as raises: | ||||
|  | @ -336,6 +346,7 @@ async def _drain_to_final_msg( | |||
|             # XXX should pretty much never get here unless someone | ||||
|             # overrides the default `MsgType` spec. | ||||
|             case _: | ||||
|                 pre_result_drained.append(msg) | ||||
|                 # It's definitely an internal error if any other | ||||
|                 # msg type without a`'cid'` field arrives here! | ||||
|                 if not msg.cid: | ||||
|  | @ -352,7 +363,10 @@ async def _drain_to_final_msg( | |||
|             f'{ctx.outcome}\n' | ||||
|         ) | ||||
| 
 | ||||
|     return pre_result_drained | ||||
|     return ( | ||||
|         return_msg, | ||||
|         pre_result_drained, | ||||
|     ) | ||||
| 
 | ||||
| 
 | ||||
| class Unresolved: | ||||
|  | @ -719,21 +733,36 @@ class Context: | |||
|         Return string indicating which task this instance is wrapping. | ||||
| 
 | ||||
|         ''' | ||||
|         return 'caller' if self._portal else 'callee' | ||||
|         return 'parent' if self._portal else 'child' | ||||
| 
 | ||||
|     @staticmethod | ||||
|     def peer_side(side: str) -> str: | ||||
|         match side: | ||||
|             case 'child': | ||||
|                 return 'parent' | ||||
|             case 'parent': | ||||
|                 return 'child' | ||||
| 
 | ||||
|     # TODO: remove stat! | ||||
|     # -[ ] re-implement the `.experiemental._pubsub` stuff | ||||
|     #     with `MsgStream` and that should be last usage? | ||||
|     # -[ ] remove from `tests/legacy_one_way_streaming.py`! | ||||
|     async def send_yield( | ||||
|         self, | ||||
|         data: Any, | ||||
| 
 | ||||
|     ) -> None: | ||||
|         ''' | ||||
|         Deprecated method for what now is implemented in `MsgStream`. | ||||
| 
 | ||||
|         We need to rework / remove some stuff tho, see above. | ||||
| 
 | ||||
|         ''' | ||||
|         warnings.warn( | ||||
|             "`Context.send_yield()` is now deprecated. " | ||||
|             "Use ``MessageStream.send()``. ", | ||||
|             DeprecationWarning, | ||||
|             stacklevel=2, | ||||
|         ) | ||||
|         # await self.chan.send({'yield': data, 'cid': self.cid}) | ||||
|         await self.chan.send( | ||||
|             Yield( | ||||
|                 cid=self.cid, | ||||
|  | @ -742,12 +771,11 @@ class Context: | |||
|         ) | ||||
| 
 | ||||
|     async def send_stop(self) -> None: | ||||
|         # await pause() | ||||
|         # await self.chan.send({ | ||||
|         #     # Stop( | ||||
|         #     'stop': True, | ||||
|         #     'cid': self.cid | ||||
|         # }) | ||||
|         ''' | ||||
|         Terminate a `MsgStream` dialog-phase by sending the IPC | ||||
|         equiv of a `StopIteration`. | ||||
| 
 | ||||
|         ''' | ||||
|         await self.chan.send( | ||||
|             Stop(cid=self.cid) | ||||
|         ) | ||||
|  | @ -843,6 +871,7 @@ class Context: | |||
| 
 | ||||
|         # self-cancel (ack) or, | ||||
|         # peer propagated remote cancellation. | ||||
|         msgtyperr: bool = False | ||||
|         if isinstance(error, ContextCancelled): | ||||
| 
 | ||||
|             whom: str = ( | ||||
|  | @ -854,6 +883,16 @@ class Context: | |||
|                 f'{error}' | ||||
|             ) | ||||
| 
 | ||||
|         elif isinstance(error, MsgTypeError): | ||||
|             msgtyperr = True | ||||
|             peer_side: str = self.peer_side(self.side) | ||||
|             log.error( | ||||
|                 f'IPC dialog error due to msg-type caused by {peer_side!r} side\n\n' | ||||
| 
 | ||||
|                 f'{error}\n' | ||||
|                 f'{pformat(self)}\n' | ||||
|             ) | ||||
| 
 | ||||
|         else: | ||||
|             log.error( | ||||
|                 f'Remote context error:\n\n' | ||||
|  | @ -894,9 +933,9 @@ class Context: | |||
|             # if `._cancel_called` then `.cancel_acked and .cancel_called` | ||||
|             # always should be set. | ||||
|             and not self._is_self_cancelled() | ||||
| 
 | ||||
|             and not cs.cancel_called | ||||
|             and not cs.cancelled_caught | ||||
|             and not msgtyperr | ||||
|         ): | ||||
|             # TODO: it'd sure be handy to inject our own | ||||
|             # `trio.Cancelled` subtype here ;) | ||||
|  | @ -1004,7 +1043,7 @@ class Context: | |||
|         # when the runtime finally receives it during teardown | ||||
|         # (normally in `.result()` called from | ||||
|         # `Portal.open_context().__aexit__()`) | ||||
|         if side == 'caller': | ||||
|         if side == 'parent': | ||||
|             if not self._portal: | ||||
|                 raise InternalError( | ||||
|                     'No portal found!?\n' | ||||
|  | @ -1426,7 +1465,10 @@ class Context: | |||
|             # wait for a final context result/error by "draining" | ||||
|             # (by more or less ignoring) any bi-dir-stream "yield" | ||||
|             # msgs still in transit from the far end. | ||||
|             drained_msgs: list[dict] = await _drain_to_final_msg( | ||||
|             ( | ||||
|                 return_msg, | ||||
|                 drained_msgs, | ||||
|             ) = await _drain_to_final_msg( | ||||
|                 ctx=self, | ||||
|                 hide_tb=hide_tb, | ||||
|             ) | ||||
|  | @ -1444,7 +1486,10 @@ class Context: | |||
| 
 | ||||
|             log.cancel( | ||||
|                 'Ctx drained pre-result msgs:\n' | ||||
|                 f'{pformat(drained_msgs)}' | ||||
|                 f'{pformat(drained_msgs)}\n\n' | ||||
| 
 | ||||
|                 f'Final return msg:\n' | ||||
|                 f'{return_msg}\n' | ||||
|             ) | ||||
| 
 | ||||
|         self.maybe_raise( | ||||
|  | @ -1611,7 +1656,13 @@ class Context: | |||
| 
 | ||||
|     async def started( | ||||
|         self, | ||||
|         value: Any | None = None | ||||
| 
 | ||||
|         # TODO: how to type this so that it's the | ||||
|         # same as the payload type? Is this enough? | ||||
|         value: PayloadT|None = None, | ||||
| 
 | ||||
|         strict_parity: bool = False, | ||||
|         complain_no_parity: bool = True, | ||||
| 
 | ||||
|     ) -> None: | ||||
|         ''' | ||||
|  | @ -1632,7 +1683,7 @@ class Context: | |||
|                 f'called `.started()` twice on context with {self.chan.uid}' | ||||
|             ) | ||||
| 
 | ||||
|         started = Started( | ||||
|         started_msg = Started( | ||||
|             cid=self.cid, | ||||
|             pld=value, | ||||
|         ) | ||||
|  | @ -1653,28 +1704,54 @@ class Context: | |||
|         # https://zguide.zeromq.org/docs/chapter7/#The-Cheap-or-Nasty-Pattern | ||||
|         # | ||||
|         codec: MsgCodec = current_codec() | ||||
|         msg_bytes: bytes = codec.encode(started) | ||||
|         msg_bytes: bytes = codec.encode(started_msg) | ||||
|         try: | ||||
|             # be a "cheap" dialog (see above!) | ||||
|             rt_started = codec.decode(msg_bytes) | ||||
|             if rt_started != started: | ||||
|             if ( | ||||
|                 strict_parity | ||||
|                 or | ||||
|                 complain_no_parity | ||||
|             ): | ||||
|                 rt_started: Started = codec.decode(msg_bytes) | ||||
| 
 | ||||
|                 # TODO: break these methods out from the struct subtype? | ||||
|                 diff = pretty_struct.Struct.__sub__(rt_started, started) | ||||
|                 # XXX something is prolly totes cucked with the | ||||
|                 # codec state! | ||||
|                 if isinstance(rt_started, dict): | ||||
|                     rt_started = msgtypes.from_dict_msg( | ||||
|                         dict_msg=rt_started, | ||||
|                     ) | ||||
|                     raise RuntimeError( | ||||
|                         'Failed to roundtrip `Started` msg?\n' | ||||
|                         f'{pformat(rt_started)}\n' | ||||
|                     ) | ||||
| 
 | ||||
|                 complaint: str = ( | ||||
|                     'Started value does not match after codec rountrip?\n\n' | ||||
|                     f'{diff}' | ||||
|                 ) | ||||
|                 # TODO: rn this will pretty much always fail with | ||||
|                 # any other sequence type embeded in the | ||||
|                 # payload... | ||||
|                 if self._strict_started: | ||||
|                     raise ValueError(complaint) | ||||
|                 else: | ||||
|                     log.warning(complaint) | ||||
|                 if rt_started != started_msg: | ||||
|                     # TODO: break these methods out from the struct subtype? | ||||
| 
 | ||||
|             await self.chan.send(rt_started) | ||||
|                     diff = pretty_struct.Struct.__sub__( | ||||
|                         rt_started, | ||||
|                         started_msg, | ||||
|                     ) | ||||
|                     complaint: str = ( | ||||
|                         'Started value does not match after codec rountrip?\n\n' | ||||
|                         f'{diff}' | ||||
|                     ) | ||||
| 
 | ||||
|                     # TODO: rn this will pretty much always fail with | ||||
|                     # any other sequence type embeded in the | ||||
|                     # payload... | ||||
|                     if ( | ||||
|                         self._strict_started | ||||
|                         or | ||||
|                         strict_parity | ||||
|                     ): | ||||
|                         raise ValueError(complaint) | ||||
|                     else: | ||||
|                         log.warning(complaint) | ||||
| 
 | ||||
|                 # started_msg = rt_started | ||||
| 
 | ||||
|             await self.chan.send(started_msg) | ||||
| 
 | ||||
|         # raise any msg type error NO MATTER WHAT! | ||||
|         except msgspec.ValidationError as verr: | ||||
|  | @ -1685,7 +1762,7 @@ class Context: | |||
|                 src_validation_error=verr, | ||||
|                 verb_header='Trying to send payload' | ||||
|                 # > 'invalid `Started IPC msgs\n' | ||||
|             ) | ||||
|             ) from verr | ||||
| 
 | ||||
|         self._started_called = True | ||||
| 
 | ||||
|  | @ -1786,13 +1863,17 @@ class Context: | |||
|             else: | ||||
|                 log_meth = log.runtime | ||||
| 
 | ||||
|             log_meth( | ||||
|                 f'Delivering error-msg to caller\n\n' | ||||
|             side: str = self.side | ||||
| 
 | ||||
|                 f'<= peer: {from_uid}\n' | ||||
|             peer_side: str = self.peer_side(side) | ||||
| 
 | ||||
|             log_meth( | ||||
|                 f'Delivering IPC ctx error from {peer_side!r} to {side!r} task\n\n' | ||||
| 
 | ||||
|                 f'<= peer {peer_side!r}: {from_uid}\n' | ||||
|                 f'  |_ {nsf}()\n\n' | ||||
| 
 | ||||
|                 f'=> cid: {cid}\n' | ||||
|                 f'=> {side!r} cid: {cid}\n' | ||||
|                 f'  |_{self._task}\n\n' | ||||
| 
 | ||||
|                 f'{pformat(re)}\n' | ||||
|  | @ -1807,6 +1888,7 @@ class Context: | |||
|             self._maybe_cancel_and_set_remote_error(re) | ||||
| 
 | ||||
|         # XXX only case where returning early is fine! | ||||
|         structfmt = pretty_struct.Struct.pformat | ||||
|         if self._in_overrun: | ||||
|             log.warning( | ||||
|                 f'Queueing OVERRUN msg on caller task:\n' | ||||
|  | @ -1816,7 +1898,7 @@ class Context: | |||
|                 f'=> cid: {cid}\n' | ||||
|                 f'  |_{self._task}\n\n' | ||||
| 
 | ||||
|                 f'{pformat(msg)}\n' | ||||
|                 f'{structfmt(msg)}\n' | ||||
|             ) | ||||
|             self._overflow_q.append(msg) | ||||
|             return False | ||||
|  | @ -1830,7 +1912,7 @@ class Context: | |||
|                 f'=> {self._task}\n' | ||||
|                 f'  |_cid={self.cid}\n\n' | ||||
| 
 | ||||
|                 f'{pformat(msg)}\n' | ||||
|                 f'{structfmt(msg)}\n' | ||||
|             ) | ||||
| 
 | ||||
|             # NOTE: if an error is deteced we should always still | ||||
|  | @ -2050,6 +2132,9 @@ async def open_context_from_portal( | |||
|         # place.. | ||||
|         allow_overruns=allow_overruns, | ||||
|     ) | ||||
|     # ASAP, so that `Context.side: str` can be determined for | ||||
|     # logging / tracing / debug! | ||||
|     ctx._portal: Portal = portal | ||||
| 
 | ||||
|     assert ctx._remote_func_type == 'context' | ||||
|     msg: Started = await ctx._recv_chan.receive() | ||||
|  | @ -2068,10 +2153,10 @@ async def open_context_from_portal( | |||
|             msg=msg, | ||||
|             src_err=src_error, | ||||
|             log=log, | ||||
|             expect_key='started', | ||||
|             expect_msg=Started, | ||||
|             # expect_key='started', | ||||
|         ) | ||||
| 
 | ||||
|     ctx._portal: Portal = portal | ||||
|     uid: tuple = portal.channel.uid | ||||
|     cid: str = ctx.cid | ||||
| 
 | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue