diff --git a/tractor/_context.py b/tractor/_context.py index fe5d654..ed720a2 100644 --- a/tractor/_context.py +++ b/tractor/_context.py @@ -41,6 +41,7 @@ from typing import ( Callable, Mapping, Type, + TypeAlias, TYPE_CHECKING, Union, ) @@ -155,6 +156,41 @@ class Context: # payload receiver _pld_rx: msgops.PldRx + @property + def pld_rx(self) -> msgops.PldRx: + ''' + The current `tractor.Context`'s msg-payload-receiver. + + A payload receiver is the IPC-msg processing sub-sys which + filters inter-actor-task communicated payload data, i.e. the + `PayloadMsg.pld: PayloadT` field value, AFTER its container + shuttlle msg (eg. `Started`/`Yield`/`Return) has been + delivered up from `tractor`'s transport layer but BEFORE the + data is yielded to `tractor` application code. + + The "IPC-primitive API" is normally one of a `Context` (this)` or a `MsgStream` + or some higher level API using one of them. + + For ex. `pld_data: PayloadT = MsgStream.receive()` implicitly + calls into the stream's parent `Context.pld_rx.recv_pld().` to + receive the latest `PayloadMsg.pld` value. + + Modification of the current payload spec via `limit_plds()` + allows a `tractor` application to contextually filter IPC + payload content with a type specification as supported by the + interchange backend. + + - for `msgspec` see . + + Note that the `PldRx` itself is a per-`Context` instance that + normally only changes when some (sub-)task, on a given "side" + of the IPC ctx (either a "child"-side RPC or inside + a "parent"-side `Portal.open_context()` block), modifies it + using the `.msg._ops.limit_plds()` API. + + ''' + return self._pld_rx + # full "namespace-path" to target RPC function _nsf: NamespacePath @@ -231,6 +267,8 @@ class Context: # init and streaming state _started_called: bool = False + _started_msg: MsgType|None = None + _started_pld: Any = None _stream_opened: bool = False _stream: MsgStream|None = None @@ -623,7 +661,7 @@ class Context: log.runtime( 'Setting remote error for ctx\n\n' f'<= {self.peer_side!r}: {self.chan.uid}\n' - f'=> {self.side!r}\n\n' + f'=> {self.side!r}: {self._actor.uid}\n\n' f'{error}' ) self._remote_error: BaseException = error @@ -678,7 +716,7 @@ class Context: log.error( f'Remote context error:\n\n' # f'{pformat(self)}\n' - f'{error}\n' + f'{error}' ) if self._canceller is None: @@ -724,8 +762,10 @@ class Context: ) else: message: str = 'NOT cancelling `Context._scope` !\n\n' + # from .devx import mk_pdb + # mk_pdb().set_trace() - fmt_str: str = 'No `self._scope: CancelScope` was set/used ?' + fmt_str: str = 'No `self._scope: CancelScope` was set/used ?\n' if ( cs and @@ -805,6 +845,7 @@ class Context: # f'{ci.api_nsp}()\n' # ) + # TODO: use `.dev._frame_stack` scanning to find caller! return 'Portal.open_context()' async def cancel( @@ -1304,17 +1345,6 @@ class Context: ctx=self, hide_tb=hide_tb, ) - for msg in drained_msgs: - - # TODO: mask this by default.. - if isinstance(msg, Return): - # from .devx import pause - # await pause() - # raise InternalError( - log.warning( - 'Final `return` msg should never be drained !?!?\n\n' - f'{msg}\n' - ) drained_status: str = ( 'Ctx drained to final outcome msg\n\n' @@ -1435,6 +1465,10 @@ class Context: self._result ) + @property + def has_outcome(self) -> bool: + return bool(self.maybe_error) or self._final_result_is_set() + # @property def repr_outcome( self, @@ -1637,8 +1671,6 @@ class Context: ) if rt_started != started_msg: - # TODO: break these methods out from the struct subtype? - # TODO: make that one a mod func too.. diff = pretty_struct.Struct.__sub__( rt_started, @@ -1674,6 +1706,8 @@ class Context: ) from verr self._started_called = True + self._started_msg = started_msg + self._started_pld = value async def _drain_overflows( self, @@ -1961,6 +1995,7 @@ async def open_context_from_portal( portal: Portal, func: Callable, + pld_spec: TypeAlias|None = None, allow_overruns: bool = False, # TODO: if we set this the wrapping `@acm` body will @@ -2026,7 +2061,7 @@ async def open_context_from_portal( # XXX NOTE XXX: currenly we do NOT allow opening a contex # with "self" since the local feeder mem-chan processing # is not built for it. - if portal.channel.uid == portal.actor.uid: + if (uid := portal.channel.uid) == portal.actor.uid: raise RuntimeError( '** !! Invalid Operation !! **\n' 'Can not open an IPC ctx with the local actor!\n' @@ -2054,32 +2089,45 @@ async def open_context_from_portal( assert ctx._caller_info _ctxvar_Context.set(ctx) - # XXX NOTE since `._scope` is NOT set BEFORE we retreive the - # `Started`-msg any cancellation triggered - # in `._maybe_cancel_and_set_remote_error()` will - # NOT actually cancel the below line! - # -> it's expected that if there is an error in this phase of - # the dialog, the `Error` msg should be raised from the `msg` - # handling block below. - first: Any = await ctx._pld_rx.recv_pld( - ctx=ctx, - expect_msg=Started, - ) - ctx._started_called: bool = True - - uid: tuple = portal.channel.uid - cid: str = ctx.cid - # placeholder for any exception raised in the runtime # or by user tasks which cause this context's closure. scope_err: BaseException|None = None ctxc_from_callee: ContextCancelled|None = None try: - async with trio.open_nursery() as nurse: + async with ( + trio.open_nursery() as tn, + msgops.maybe_limit_plds( + ctx=ctx, + spec=pld_spec, + ) as maybe_msgdec, + ): + if maybe_msgdec: + assert maybe_msgdec.pld_spec == pld_spec - # NOTE: used to start overrun queuing tasks - ctx._scope_nursery: trio.Nursery = nurse - ctx._scope: trio.CancelScope = nurse.cancel_scope + # XXX NOTE since `._scope` is NOT set BEFORE we retreive the + # `Started`-msg any cancellation triggered + # in `._maybe_cancel_and_set_remote_error()` will + # NOT actually cancel the below line! + # -> it's expected that if there is an error in this phase of + # the dialog, the `Error` msg should be raised from the `msg` + # handling block below. + started_msg, first = await ctx._pld_rx.recv_msg_w_pld( + ipc=ctx, + expect_msg=Started, + passthrough_non_pld_msgs=False, + ) + + # from .devx import pause + # await pause() + ctx._started_called: bool = True + ctx._started_msg: bool = started_msg + ctx._started_pld: bool = first + + # NOTE: this in an implicit runtime nursery used to, + # - start overrun queuing tasks when as well as + # for cancellation of the scope opened by the user. + ctx._scope_nursery: trio.Nursery = tn + ctx._scope: trio.CancelScope = tn.cancel_scope # deliver context instance and .started() msg value # in enter tuple. @@ -2126,13 +2174,13 @@ async def open_context_from_portal( # when in allow_overruns mode there may be # lingering overflow sender tasks remaining? - if nurse.child_tasks: + if tn.child_tasks: # XXX: ensure we are in overrun state # with ``._allow_overruns=True`` bc otherwise # there should be no tasks in this nursery! if ( not ctx._allow_overruns - or len(nurse.child_tasks) > 1 + or len(tn.child_tasks) > 1 ): raise InternalError( 'Context has sub-tasks but is ' @@ -2304,8 +2352,8 @@ async def open_context_from_portal( ): log.warning( 'IPC connection for context is broken?\n' - f'task:{cid}\n' - f'actor:{uid}' + f'task: {ctx.cid}\n' + f'actor: {uid}' ) raise # duh @@ -2455,9 +2503,8 @@ async def open_context_from_portal( and ctx.cancel_acked ): log.cancel( - 'Context cancelled by {ctx.side!r}-side task\n' + f'Context cancelled by {ctx.side!r}-side task\n' f'|_{ctx._task}\n\n' - f'{repr(scope_err)}\n' ) @@ -2485,7 +2532,7 @@ async def open_context_from_portal( f'cid: {ctx.cid}\n' ) portal.actor._contexts.pop( - (uid, cid), + (uid, ctx.cid), None, ) @@ -2516,8 +2563,9 @@ def mk_context( from .devx._frame_stack import find_caller_info caller_info: CallerInfo|None = find_caller_info() - # TODO: when/how do we apply `.limit_plds()` from here? - pld_rx: msgops.PldRx = msgops.current_pldrx() + pld_rx = msgops.PldRx( + _pld_dec=msgops._def_any_pldec, + ) ctx = Context( chan=chan, @@ -2531,13 +2579,16 @@ def mk_context( _caller_info=caller_info, **kwargs, ) + pld_rx._ctx = ctx ctx._result = Unresolved return ctx # TODO: use the new type-parameters to annotate this in 3.13? # -[ ] https://peps.python.org/pep-0718/#unknown-types -def context(func: Callable) -> Callable: +def context( + func: Callable, +) -> Callable: ''' Mark an (async) function as an SC-supervised, inter-`Actor`, child-`trio.Task`, IPC endpoint otherwise known more diff --git a/tractor/_streaming.py b/tractor/_streaming.py index dd4cd0e..a008eaf 100644 --- a/tractor/_streaming.py +++ b/tractor/_streaming.py @@ -52,6 +52,7 @@ from tractor.msg import ( if TYPE_CHECKING: from ._context import Context + from ._ipc import Channel log = get_logger(__name__) @@ -65,10 +66,10 @@ log = get_logger(__name__) class MsgStream(trio.abc.Channel): ''' A bidirectional message stream for receiving logically sequenced - values over an inter-actor IPC ``Channel``. + values over an inter-actor IPC `Channel`. This is the type returned to a local task which entered either - ``Portal.open_stream_from()`` or ``Context.open_stream()``. + `Portal.open_stream_from()` or `Context.open_stream()`. Termination rules: @@ -95,6 +96,22 @@ class MsgStream(trio.abc.Channel): self._eoc: bool|trio.EndOfChannel = False self._closed: bool|trio.ClosedResourceError = False + @property + def ctx(self) -> Context: + ''' + This stream's IPC `Context` ref. + + ''' + return self._ctx + + @property + def chan(self) -> Channel: + ''' + Ref to the containing `Context`'s transport `Channel`. + + ''' + return self._ctx.chan + # TODO: could we make this a direct method bind to `PldRx`? # -> receive_nowait = PldRx.recv_pld # |_ means latter would have to accept `MsgStream`-as-`self`? @@ -109,7 +126,7 @@ class MsgStream(trio.abc.Channel): ): ctx: Context = self._ctx return ctx._pld_rx.recv_pld_nowait( - ctx=ctx, + ipc=self, expect_msg=expect_msg, ) @@ -148,7 +165,7 @@ class MsgStream(trio.abc.Channel): try: ctx: Context = self._ctx - return await ctx._pld_rx.recv_pld(ctx=ctx) + return await ctx._pld_rx.recv_pld(ipc=self) # XXX: the stream terminates on either of: # - via `self._rx_chan.receive()` raising after manual closure diff --git a/tractor/msg/_ops.py b/tractor/msg/_ops.py index 3b0b833..3014c15 100644 --- a/tractor/msg/_ops.py +++ b/tractor/msg/_ops.py @@ -22,10 +22,9 @@ operational helpers for processing transaction flows. ''' from __future__ import annotations from contextlib import ( - # asynccontextmanager as acm, + asynccontextmanager as acm, contextmanager as cm, ) -from contextvars import ContextVar from typing import ( Any, Type, @@ -50,6 +49,7 @@ from tractor._exceptions import ( _mk_msg_type_err, pack_from_raise, ) +from tractor._state import current_ipc_ctx from ._codec import ( mk_dec, MsgDec, @@ -75,7 +75,7 @@ if TYPE_CHECKING: log = get_logger(__name__) -_def_any_pldec: MsgDec = mk_dec() +_def_any_pldec: MsgDec[Any] = mk_dec() class PldRx(Struct): @@ -104,15 +104,19 @@ class PldRx(Struct): ''' # TODO: better to bind it here? # _rx_mc: trio.MemoryReceiveChannel - _pldec: MsgDec + _pld_dec: MsgDec + _ctx: Context|None = None _ipc: Context|MsgStream|None = None @property def pld_dec(self) -> MsgDec: - return self._pldec + return self._pld_dec + # TODO: a better name? + # -[ ] when would this be used as it avoids needingn to pass the + # ipc prim to every method @cm - def apply_to_ipc( + def wraps_ipc( self, ipc_prim: Context|MsgStream, @@ -140,49 +144,50 @@ class PldRx(Struct): exit. ''' - orig_dec: MsgDec = self._pldec + orig_dec: MsgDec = self._pld_dec limit_dec: MsgDec = mk_dec(spec=spec) try: - self._pldec = limit_dec + self._pld_dec = limit_dec yield limit_dec finally: - self._pldec = orig_dec + self._pld_dec = orig_dec @property def dec(self) -> msgpack.Decoder: - return self._pldec.dec + return self._pld_dec.dec def recv_pld_nowait( self, # TODO: make this `MsgStream` compat as well, see above^ # ipc_prim: Context|MsgStream, - ctx: Context, + ipc: Context|MsgStream, ipc_msg: MsgType|None = None, expect_msg: Type[MsgType]|None = None, - + hide_tb: bool = False, **dec_msg_kwargs, ) -> Any|Raw: - __tracebackhide__: bool = True + __tracebackhide__: bool = hide_tb msg: MsgType = ( ipc_msg or # sync-rx msg from underlying IPC feeder (mem-)chan - ctx._rx_chan.receive_nowait() + ipc._rx_chan.receive_nowait() ) return self.dec_msg( msg, - ctx=ctx, + ipc=ipc, expect_msg=expect_msg, + hide_tb=hide_tb, **dec_msg_kwargs, ) async def recv_pld( self, - ctx: Context, + ipc: Context|MsgStream, ipc_msg: MsgType|None = None, expect_msg: Type[MsgType]|None = None, hide_tb: bool = True, @@ -200,11 +205,11 @@ class PldRx(Struct): or # async-rx msg from underlying IPC feeder (mem-)chan - await ctx._rx_chan.receive() + await ipc._rx_chan.receive() ) return self.dec_msg( msg=msg, - ctx=ctx, + ipc=ipc, expect_msg=expect_msg, **dec_msg_kwargs, ) @@ -212,7 +217,7 @@ class PldRx(Struct): def dec_msg( self, msg: MsgType, - ctx: Context, + ipc: Context|MsgStream, expect_msg: Type[MsgType]|None, raise_error: bool = True, @@ -225,6 +230,9 @@ class PldRx(Struct): ''' __tracebackhide__: bool = hide_tb + + _src_err = None + src_err: BaseException|None = None match msg: # payload-data shuttle msg; deliver the `.pld` value # directly to IPC (primitive) client-consumer code. @@ -234,7 +242,7 @@ class PldRx(Struct): |Return(pld=pld) # termination phase ): try: - pld: PayloadT = self._pldec.decode(pld) + pld: PayloadT = self._pld_dec.decode(pld) log.runtime( 'Decoded msg payload\n\n' f'{msg}\n\n' @@ -243,25 +251,30 @@ class PldRx(Struct): ) return pld - # XXX pld-type failure - except ValidationError as src_err: + # XXX pld-value type failure + except ValidationError as valerr: + # pack mgterr into error-msg for + # reraise below; ensure remote-actor-err + # info is displayed nicely? msgterr: MsgTypeError = _mk_msg_type_err( msg=msg, codec=self.pld_dec, - src_validation_error=src_err, + src_validation_error=valerr, is_invalid_payload=True, ) msg: Error = pack_from_raise( local_err=msgterr, cid=msg.cid, - src_uid=ctx.chan.uid, + src_uid=ipc.chan.uid, ) + src_err = valerr # XXX some other decoder specific failure? # except TypeError as src_error: # from .devx import mk_pdb # mk_pdb().set_trace() # raise src_error + # ^-TODO-^ can remove? # a runtime-internal RPC endpoint response. # always passthrough since (internal) runtime @@ -299,6 +312,7 @@ class PldRx(Struct): return src_err case Stop(cid=cid): + ctx: Context = getattr(ipc, 'ctx', ipc) message: str = ( f'{ctx.side!r}-side of ctx received stream-`Stop` from ' f'{ctx.peer_side!r} peer ?\n' @@ -341,14 +355,21 @@ class PldRx(Struct): # |_https://docs.python.org/3.11/library/exceptions.html#BaseException.add_note # # fallthrough and raise from `src_err` - _raise_from_unexpected_msg( - ctx=ctx, - msg=msg, - src_err=src_err, - log=log, - expect_msg=expect_msg, - hide_tb=hide_tb, - ) + try: + _raise_from_unexpected_msg( + ctx=getattr(ipc, 'ctx', ipc), + msg=msg, + src_err=src_err, + log=log, + expect_msg=expect_msg, + hide_tb=hide_tb, + ) + except UnboundLocalError: + # XXX if there's an internal lookup error in the above + # code (prolly on `src_err`) we want to show this frame + # in the tb! + __tracebackhide__: bool = False + raise async def recv_msg_w_pld( self, @@ -378,52 +399,13 @@ class PldRx(Struct): # msg instance? pld: PayloadT = self.dec_msg( msg, - ctx=ipc, + ipc=ipc, expect_msg=expect_msg, **kwargs, ) return msg, pld -# Always maintain a task-context-global `PldRx` -_def_pld_rx: PldRx = PldRx( - _pldec=_def_any_pldec, -) -_ctxvar_PldRx: ContextVar[PldRx] = ContextVar( - 'pld_rx', - default=_def_pld_rx, -) - - -def current_pldrx() -> PldRx: - ''' - Return the current `trio.Task.context`'s msg-payload-receiver. - - A payload receiver is the IPC-msg processing sub-sys which - filters inter-actor-task communicated payload data, i.e. the - `PayloadMsg.pld: PayloadT` field value, AFTER it's container - shuttlle msg (eg. `Started`/`Yield`/`Return) has been delivered - up from `tractor`'s transport layer but BEFORE the data is - yielded to application code, normally via an IPC primitive API - like, for ex., `pld_data: PayloadT = MsgStream.receive()`. - - Modification of the current payload spec via `limit_plds()` - allows a `tractor` application to contextually filter IPC - payload content with a type specification as supported by - the interchange backend. - - - for `msgspec` see . - - NOTE that the `PldRx` itself is a per-`Context` global sub-system - that normally does not change other then the applied pld-spec - for the current `trio.Task`. - - ''' - # ctx: context = current_ipc_ctx() - # return ctx._pld_rx - return _ctxvar_PldRx.get() - - @cm def limit_plds( spec: Union[Type[Struct]], @@ -439,29 +421,55 @@ def limit_plds( ''' __tracebackhide__: bool = True try: - # sanity on orig settings - orig_pldrx: PldRx = current_pldrx() - orig_pldec: MsgDec = orig_pldrx.pld_dec + curr_ctx: Context = current_ipc_ctx() + rx: PldRx = curr_ctx._pld_rx + orig_pldec: MsgDec = rx.pld_dec - with orig_pldrx.limit_plds( + with rx.limit_plds( spec=spec, **kwargs, ) as pldec: - log.info( + log.runtime( 'Applying payload-decoder\n\n' f'{pldec}\n' ) yield pldec finally: - log.info( + log.runtime( 'Reverted to previous payload-decoder\n\n' f'{orig_pldec}\n' ) - assert ( - (pldrx := current_pldrx()) is orig_pldrx - and - pldrx.pld_dec is orig_pldec - ) + # sanity on orig settings + assert rx.pld_dec is orig_pldec + + +@acm +async def maybe_limit_plds( + ctx: Context, + spec: Union[Type[Struct]]|None = None, + **kwargs, +) -> MsgDec|None: + ''' + Async compat maybe-payload type limiter. + + Mostly for use inside other internal `@acm`s such that a separate + indent block isn't needed when an async one is already being + used. + + ''' + if spec is None: + yield None + return + + # sanity on scoping + curr_ctx: Context = current_ipc_ctx() + assert ctx is curr_ctx + + with ctx._pld_rx.limit_plds(spec=spec) as msgdec: + yield msgdec + + curr_ctx: Context = current_ipc_ctx() + assert ctx is curr_ctx async def drain_to_final_msg( @@ -543,21 +551,12 @@ async def drain_to_final_msg( match msg: # final result arrived! - case Return( - # cid=cid, - # pld=res, - ): - # ctx._result: Any = res - ctx._result: Any = pld + case Return(): log.runtime( 'Context delivered final draining msg:\n' f'{pretty_struct.pformat(msg)}' ) - # XXX: only close the rx mem chan AFTER - # a final result is retreived. - # if ctx._rx_chan: - # await ctx._rx_chan.aclose() - # TODO: ^ we don't need it right? + ctx._result: Any = pld result_msg = msg break @@ -664,24 +663,6 @@ async def drain_to_final_msg( result_msg = msg break # OOOOOF, yeah obvi we need this.. - # XXX we should never really get here - # right! since `._deliver_msg()` should - # always have detected an {'error': ..} - # msg and already called this right!?! - # elif error := unpack_error( - # msg=msg, - # chan=ctx._portal.channel, - # hide_tb=False, - # ): - # log.critical('SHOULD NEVER GET HERE!?') - # assert msg is ctx._cancel_msg - # assert error.msgdata == ctx._remote_error.msgdata - # assert error.ipc_msg == ctx._remote_error.ipc_msg - # from .devx._debug import pause - # await pause() - # ctx._maybe_cancel_and_set_remote_error(error) - # ctx._maybe_raise_remote_err(error) - else: # bubble the original src key error raise