forked from goodboy/tractor

IPC ctx refinements around `MsgTypeError` awareness

Add a bit of special handling for msg-type-errors with a dedicated
log-msg detailing which `.side: str` is the sender/causer, and avoid
a `._scope.cancel()` call in such cases since the local task might be
written to handle and tolerate the badly-typed IPC msg.
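
For illustration only, a minimal self-contained sketch of that policy
using stand-in names (the corresponding runtime change is in the diff
below):

    class MsgTypeError(Exception):
        '''Stand-in for the real `tractor` msg-type error.'''

    def handle_remote_error(
        error: Exception,
        side: str,        # the local task's side: 'parent' | 'child'
        cancel_scope,     # anything with a `.cancel()` method
    ) -> None:
        peer_side = 'parent' if side == 'child' else 'child'
        if isinstance(error, MsgTypeError):
            # only log which side sent the bad msg; the local task may
            # be written to tolerate the badly-typed IPC msg.
            print(
                f'IPC dialog error due to msg-type caused by {peer_side!r} side\n'
                f'{error}\n'
            )
            return
        # any other remote error still cancels the local scope
        cancel_scope.cancel()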

As part of ^, change the ctx task-pair "side" semantics from "caller" ->
"callee" to "parent" -> "child", which better matches the cross-process
SC-linked-task supervision hierarchy as well as `trio.Nursery.parent_task`;
in `trio` the task that opens a nursery is likewise named the "parent".

Impl deats / fixes around the `.side` semantics:
- ensure that `._portal: Portal` is set ASAP after
  `Actor.start_remote_task()` such that if the `Started` transaction
  fails, the parent-vs.-child sides are still denoted correctly (since
  `._portal` being set is the predicate for that).
- add a helper func `Context.peer_side(side: str) -> str:` which inverts
  from "child" to "parent" and vice versa, useful for logging info; see
  the sketch after this list.

Other tweaks:
- make `_drain_to_final_msg()` return a tuple of a maybe-`Return` and
  the list of other `pre_result_drained: list[MsgType]` such that we
  never have to warn about the return msg getting captured as
  a pre-"result" msg; see the rough call-site sketch below.
- Add some strictness flags to `.started()` which allow toggling
  whether to error or just warn-log about mismatching round-tripped
  `Started` msgs prior to IPC transit; see the usage sketch below.
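
E.g. a rough call-site sketch for the runtime-internal drain helper
(names and kwargs mirror the updated usage in the diff below; not a
drop-in copy of the real method body):

    async def wait_for_final(ctx) -> None:
        # the final `Return` msg (or `None`) now comes back separately
        # from the other drained msgs, so the result never lands in the
        # "pre-result" list (and never triggers the old capture-warning).
        return_msg, drained_msgs = await _drain_to_final_msg(
            ctx=ctx,
            hide_tb=True,
        )
        if drained_msgs:
            print(f'drained {len(drained_msgs)} pre-result msgs')
        print(f'final return msg: {return_msg}')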
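
And a hedged usage sketch of the new `.started()` flags from a context
endpoint (the kwarg names are taken from the diff below and may still
change):

    import tractor

    @tractor.context
    async def child_ep(
        ctx: tractor.Context,
    ) -> None:
        # `strict_parity=True` raises on any pre-send codec roundtrip
        # mismatch of the `Started` msg; the default
        # `complain_no_parity=True` only warn-logs it.
        await ctx.started(
            'a-first-value',
            strict_parity=False,
            complain_no_parity=True,
        )
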
runtime_to_msgspec
Tyler Goodlet 2024-04-13 15:19:08 -04:00
parent 3fb3608879
commit df548257ad
1 changed file with 132 additions and 47 deletions


@@ -47,6 +47,7 @@ import trio
 from ._exceptions import (
     ContextCancelled,
     InternalError,
+    MsgTypeError,
     RemoteActorError,
     StreamOverrun,
     pack_from_raise,
@@ -59,12 +60,14 @@ from .msg import (
     MsgType,
     MsgCodec,
     NamespacePath,
+    PayloadT,
     Return,
     Started,
     Stop,
     Yield,
     current_codec,
     pretty_struct,
+    types as msgtypes,
 )
 from ._ipc import Channel
 from ._streaming import MsgStream
@@ -88,7 +91,10 @@ async def _drain_to_final_msg(
     hide_tb: bool = True,
     msg_limit: int = 6,
 
-) -> list[dict]:
+) -> tuple[
+    Return|None,
+    list[MsgType]
+]:
     '''
     Drain IPC msgs delivered to the underlying rx-mem-chan
     `Context._recv_chan` from the runtime in search for a final
@@ -109,6 +115,7 @@ async def _drain_to_final_msg(
     # basically ignoring) any bi-dir-stream msgs still in transit
     # from the far end.
     pre_result_drained: list[MsgType] = []
+    return_msg: Return|None = None
     while not (
         ctx.maybe_error
         and not ctx._final_result_is_set()
@@ -169,8 +176,6 @@
             # pray to the `trio` gawds that we're corrent with this
             # msg: dict = await ctx._recv_chan.receive()
             msg: MsgType = await ctx._recv_chan.receive()
-            # always capture unexpected/non-result msgs
-            pre_result_drained.append(msg)
 
             # NOTE: we get here if the far end was
             # `ContextCancelled` in 2 cases:
@@ -207,11 +212,13 @@
                     # if ctx._recv_chan:
                     #     await ctx._recv_chan.aclose()
                     # TODO: ^ we don't need it right?
+                    return_msg = msg
                     break
 
                 # far end task is still streaming to us so discard
                 # and report depending on local ctx state.
                 case Yield():
+                    pre_result_drained.append(msg)
                     if (
                         (ctx._stream.closed
                         and (reason := 'stream was already closed')
@@ -236,7 +243,10 @@
                            f'{pformat(msg)}\n'
                        )
 
-                        return pre_result_drained
+                        return (
+                            return_msg,
+                            pre_result_drained,
+                        )
 
                 # drain up to the `msg_limit` hoping to get
                 # a final result or error/ctxc.
@@ -260,6 +270,7 @@
                 # -[ ] should be a runtime error if a stream is open right?
                 # Stop()
                 case Stop():
+                    pre_result_drained.append(msg)
                     log.cancel(
                         'Remote stream terminated due to "stop" msg:\n\n'
                         f'{pformat(msg)}\n'
@@ -269,7 +280,6 @@
                 # remote error msg, likely already handled inside
                 # `Context._deliver_msg()`
                 case Error():
-
                     # TODO: can we replace this with `ctx.maybe_raise()`?
                     # -[ ] would this be handier for this case maybe?
                     # async with maybe_raise_on_exit() as raises:
@@ -336,6 +346,7 @@
                 # XXX should pretty much never get here unless someone
                 # overrides the default `MsgType` spec.
                 case _:
+                    pre_result_drained.append(msg)
                     # It's definitely an internal error if any other
                     # msg type without a`'cid'` field arrives here!
                     if not msg.cid:
@@ -352,7 +363,10 @@
         f'{ctx.outcome}\n'
     )
 
-    return pre_result_drained
+    return (
+        return_msg,
+        pre_result_drained,
+    )
 
 
 class Unresolved:
@@ -719,21 +733,36 @@ class Context:
         Return string indicating which task this instance is wrapping.
 
         '''
-        return 'caller' if self._portal else 'callee'
+        return 'parent' if self._portal else 'child'
+
+    @staticmethod
+    def peer_side(side: str) -> str:
+        match side:
+            case 'child':
+                return 'parent'
+            case 'parent':
+                return 'child'
 
+    # TODO: remove stat!
+    # -[ ] re-implement the `.experiemental._pubsub` stuff
+    #      with `MsgStream` and that should be last usage?
+    # -[ ] remove from `tests/legacy_one_way_streaming.py`!
     async def send_yield(
         self,
         data: Any,
 
     ) -> None:
+        '''
+        Deprecated method for what now is implemented in `MsgStream`.
+
+        We need to rework / remove some stuff tho, see above.
+
+        '''
         warnings.warn(
             "`Context.send_yield()` is now deprecated. "
             "Use ``MessageStream.send()``. ",
             DeprecationWarning,
             stacklevel=2,
         )
-        # await self.chan.send({'yield': data, 'cid': self.cid})
         await self.chan.send(
             Yield(
                 cid=self.cid,
@@ -742,12 +771,11 @@ class Context:
         )
 
     async def send_stop(self) -> None:
-        # await pause()
-        # await self.chan.send({
-        #     # Stop(
-        #     'stop': True,
-        #     'cid': self.cid
-        # })
+        '''
+        Terminate a `MsgStream` dialog-phase by sending the IPC
+        equiv of a `StopIteration`.
+
+        '''
         await self.chan.send(
             Stop(cid=self.cid)
         )
@@ -843,6 +871,7 @@
 
         # self-cancel (ack) or,
         # peer propagated remote cancellation.
+        msgtyperr: bool = False
 
         if isinstance(error, ContextCancelled):
             whom: str = (
@@ -854,6 +883,16 @@
                 f'{error}'
             )
 
+        elif isinstance(error, MsgTypeError):
+            msgtyperr = True
+            peer_side: str = self.peer_side(self.side)
+            log.error(
+                f'IPC dialog error due to msg-type caused by {peer_side!r} side\n\n'
+
+                f'{error}\n'
+                f'{pformat(self)}\n'
+            )
+
         else:
             log.error(
                 f'Remote context error:\n\n'
@@ -894,9 +933,9 @@
             # if `._cancel_called` then `.cancel_acked and .cancel_called`
             # always should be set.
             and not self._is_self_cancelled()
             and not cs.cancel_called
             and not cs.cancelled_caught
-
+            and not msgtyperr
         ):
             # TODO: it'd sure be handy to inject our own
             # `trio.Cancelled` subtype here ;)
@@ -1001,7 +1040,7 @@
         # when the runtime finally receives it during teardown
         # (normally in `.result()` called from
         # `Portal.open_context().__aexit__()`)
-        if side == 'caller':
+        if side == 'parent':
             if not self._portal:
                 raise InternalError(
                     'No portal found!?\n'
@@ -1423,7 +1462,10 @@
             # wait for a final context result/error by "draining"
             # (by more or less ignoring) any bi-dir-stream "yield"
             # msgs still in transit from the far end.
-            drained_msgs: list[dict] = await _drain_to_final_msg(
+            (
+                return_msg,
+                drained_msgs,
+            ) = await _drain_to_final_msg(
                 ctx=self,
                 hide_tb=hide_tb,
             )
@@ -1441,7 +1483,10 @@
             log.cancel(
                 'Ctx drained pre-result msgs:\n'
-                f'{pformat(drained_msgs)}'
+                f'{pformat(drained_msgs)}\n\n'
+
+                f'Final return msg:\n'
+                f'{return_msg}\n'
             )
 
             self.maybe_raise(
@@ -1608,7 +1653,13 @@
     async def started(
         self,
-        value: Any | None = None
+
+        # TODO: how to type this so that it's the
+        # same as the payload type? Is this enough?
+        value: PayloadT|None = None,
+
+        strict_parity: bool = False,
+        complain_no_parity: bool = True,
 
     ) -> None:
         '''
@@ -1629,7 +1680,7 @@
                 f'called `.started()` twice on context with {self.chan.uid}'
             )
 
-        started = Started(
+        started_msg = Started(
             cid=self.cid,
             pld=value,
         )
@@ -1650,28 +1701,54 @@
         # https://zguide.zeromq.org/docs/chapter7/#The-Cheap-or-Nasty-Pattern
         #
         codec: MsgCodec = current_codec()
-        msg_bytes: bytes = codec.encode(started)
+        msg_bytes: bytes = codec.encode(started_msg)
         try:
             # be a "cheap" dialog (see above!)
-            rt_started = codec.decode(msg_bytes)
-            if rt_started != started:
-
-                # TODO: break these methods out from the struct subtype?
-                diff = pretty_struct.Struct.__sub__(rt_started, started)
-
-                complaint: str = (
-                    'Started value does not match after codec rountrip?\n\n'
-                    f'{diff}'
-                )
-
-                # TODO: rn this will pretty much always fail with
-                # any other sequence type embeded in the
-                # payload...
-                if self._strict_started:
-                    raise ValueError(complaint)
-                else:
-                    log.warning(complaint)
-
-            await self.chan.send(rt_started)
+            if (
+                strict_parity
+                or
+                complain_no_parity
+            ):
+                rt_started: Started = codec.decode(msg_bytes)
+
+                # XXX something is prolly totes cucked with the
+                # codec state!
+                if isinstance(rt_started, dict):
+                    rt_started = msgtypes.from_dict_msg(
+                        dict_msg=rt_started,
+                    )
+                    raise RuntimeError(
+                        'Failed to roundtrip `Started` msg?\n'
+                        f'{pformat(rt_started)}\n'
+                    )
+
+                if rt_started != started_msg:
+
+                    # TODO: break these methods out from the struct subtype?
+                    diff = pretty_struct.Struct.__sub__(
+                        rt_started,
+                        started_msg,
+                    )
+                    complaint: str = (
+                        'Started value does not match after codec rountrip?\n\n'
+                        f'{diff}'
+                    )
+
+                    # TODO: rn this will pretty much always fail with
+                    # any other sequence type embeded in the
+                    # payload...
+                    if (
+                        self._strict_started
+                        or
+                        strict_parity
+                    ):
+                        raise ValueError(complaint)
+                    else:
+                        log.warning(complaint)
+
+            # started_msg = rt_started
+            await self.chan.send(started_msg)
 
         # raise any msg type error NO MATTER WHAT!
         except msgspec.ValidationError as verr:
@@ -1682,7 +1759,7 @@
                 src_validation_error=verr,
                 verb_header='Trying to send payload'
                 # > 'invalid `Started IPC msgs\n'
-            )
+            ) from verr
 
         self._started_called = True
 
@@ -1783,13 +1860,17 @@
         else:
             log_meth = log.runtime
 
-        log_meth(
-            f'Delivering error-msg to caller\n\n'
+        side: str = self.side
+
+        peer_side: str = self.peer_side(side)
+
+        log_meth(
+            f'Delivering IPC ctx error from {peer_side!r} to {side!r} task\n\n'
 
-            f'<= peer: {from_uid}\n'
+            f'<= peer {peer_side!r}: {from_uid}\n'
             f' |_ {nsf}()\n\n'
 
-            f'=> cid: {cid}\n'
+            f'=> {side!r} cid: {cid}\n'
             f' |_{self._task}\n\n'
 
             f'{pformat(re)}\n'
@@ -1804,6 +1885,7 @@
         self._maybe_cancel_and_set_remote_error(re)
 
         # XXX only case where returning early is fine!
+        structfmt = pretty_struct.Struct.pformat
         if self._in_overrun:
             log.warning(
                 f'Queueing OVERRUN msg on caller task:\n'
@@ -1813,7 +1895,7 @@
                 f'=> cid: {cid}\n'
                 f' |_{self._task}\n\n'
-                f'{pformat(msg)}\n'
+                f'{structfmt(msg)}\n'
             )
 
             self._overflow_q.append(msg)
             return False
@@ -1827,7 +1909,7 @@
             f'=> {self._task}\n'
             f' |_cid={self.cid}\n\n'
-            f'{pformat(msg)}\n'
+            f'{structfmt(msg)}\n'
         )
 
         # NOTE: if an error is deteced we should always still
@@ -2047,6 +2129,9 @@ async def open_context_from_portal(
         # place..
         allow_overruns=allow_overruns,
     )
 
+    # ASAP, so that `Context.side: str` can be determined for
+    # logging / tracing / debug!
+    ctx._portal: Portal = portal
     assert ctx._remote_func_type == 'context'
     msg: Started = await ctx._recv_chan.receive()
@@ -2065,10 +2150,10 @@
             msg=msg,
             src_err=src_error,
             log=log,
-            expect_key='started',
+            expect_msg=Started,
+            # expect_key='started',
         )
 
-    ctx._portal: Portal = portal
 
     uid: tuple = portal.channel.uid
     cid: str = ctx.cid