forked from goodboy/tractor
1
0
Fork 0

Allocate a `PldRx` per `Context`, new pld-spec API

Since the state mgmt becomes quite messy with multiple sub-tasks inside
an IPC ctx, AND bc generally speaking the payload-type-spec should map
1-to-1 with the `Context`, it doesn't make a lot of sense to be using
`ContextVar`s to modify the `Context.pld_rx: PldRx` instance.

Instead, always allocate a full instance inside `mk_context()` with the
default `.pld_rx: PldRx` set to use the `msg._ops._def_any_pldec: MsgDec`

In support, simplify the `.msg._ops` impl and APIs:
- drop `_ctxvar_PldRx`, `_def_pld_rx` and `current_pldrx()`.
- rename `PldRx._pldec` -> `._pld_dec`.
- rename the unused `PldRx.apply_to_ipc()` -> `.wraps_ipc()`.
- add a required `PldRx._ctx: Context` attr since it is needed
  internally in some meths and each pld-rx now maps to a specific ctx.
- modify all recv methods to accept a `ipc: Context|MsgStream` (instead
  of a `ctx` arg) since both have a ref to the same `._rx_chan` and there
  are only a couple spots (in `.dec_msg()`) where we need the `ctx`
  explicitly (which can now be easily accessed via a new `MsgStream.ctx`
  property, see below).
- always show the `.dec_msg()` frame in tbs if there's a reference error
  when calling `_raise_from_unexpected_msg()` in the fallthrough case.
- implement `limit_plds()` as light wrapper around getting the
  `current_ipc_ctx()` and mutating its `MsgDec` via
  `Context.pld_rx.limit_plds()`.
- add a `maybe_limit_plds()` which just provides an `@acm` equivalent of
  `limit_plds()` handy for composing in a `async with ():` style block
  (avoiding additional indent levels in the body of async funcs).

Obvi extend the `Context` and `MsgStream` interfaces as needed
to match the above:
- add a `Context.pld_rx` pub prop.
- new private refs to `Context._started_msg: Started` and
  a `._started_pld` (mostly for internal debugging / testing / logging)
  and set inside `.open_context()` immediately after the syncing phase.
- a `Context.has_outcome() -> bool:` predicate which can be used to more
  easily determine if the ctx errored or has a final result.
- pub props for `MsgStream.ctx: Context` and `.chan: Channel` providing
  full `ipc`-arg compat with the `PldRx` method signatures.
runtime_to_msgspec
Tyler Goodlet 2024-05-20 14:34:50 -04:00
parent d93135acd8
commit 262a0e36c6
3 changed files with 212 additions and 163 deletions

View File

@ -41,6 +41,7 @@ from typing import (
Callable, Callable,
Mapping, Mapping,
Type, Type,
TypeAlias,
TYPE_CHECKING, TYPE_CHECKING,
Union, Union,
) )
@ -155,6 +156,41 @@ class Context:
# payload receiver # payload receiver
_pld_rx: msgops.PldRx _pld_rx: msgops.PldRx
@property
def pld_rx(self) -> msgops.PldRx:
'''
The current `tractor.Context`'s msg-payload-receiver.
A payload receiver is the IPC-msg processing sub-sys which
filters inter-actor-task communicated payload data, i.e. the
`PayloadMsg.pld: PayloadT` field value, AFTER its container
shuttlle msg (eg. `Started`/`Yield`/`Return) has been
delivered up from `tractor`'s transport layer but BEFORE the
data is yielded to `tractor` application code.
The "IPC-primitive API" is normally one of a `Context` (this)` or a `MsgStream`
or some higher level API using one of them.
For ex. `pld_data: PayloadT = MsgStream.receive()` implicitly
calls into the stream's parent `Context.pld_rx.recv_pld().` to
receive the latest `PayloadMsg.pld` value.
Modification of the current payload spec via `limit_plds()`
allows a `tractor` application to contextually filter IPC
payload content with a type specification as supported by the
interchange backend.
- for `msgspec` see <PUTLINKHERE>.
Note that the `PldRx` itself is a per-`Context` instance that
normally only changes when some (sub-)task, on a given "side"
of the IPC ctx (either a "child"-side RPC or inside
a "parent"-side `Portal.open_context()` block), modifies it
using the `.msg._ops.limit_plds()` API.
'''
return self._pld_rx
# full "namespace-path" to target RPC function # full "namespace-path" to target RPC function
_nsf: NamespacePath _nsf: NamespacePath
@ -231,6 +267,8 @@ class Context:
# init and streaming state # init and streaming state
_started_called: bool = False _started_called: bool = False
_started_msg: MsgType|None = None
_started_pld: Any = None
_stream_opened: bool = False _stream_opened: bool = False
_stream: MsgStream|None = None _stream: MsgStream|None = None
@ -623,7 +661,7 @@ class Context:
log.runtime( log.runtime(
'Setting remote error for ctx\n\n' 'Setting remote error for ctx\n\n'
f'<= {self.peer_side!r}: {self.chan.uid}\n' f'<= {self.peer_side!r}: {self.chan.uid}\n'
f'=> {self.side!r}\n\n' f'=> {self.side!r}: {self._actor.uid}\n\n'
f'{error}' f'{error}'
) )
self._remote_error: BaseException = error self._remote_error: BaseException = error
@ -678,7 +716,7 @@ class Context:
log.error( log.error(
f'Remote context error:\n\n' f'Remote context error:\n\n'
# f'{pformat(self)}\n' # f'{pformat(self)}\n'
f'{error}\n' f'{error}'
) )
if self._canceller is None: if self._canceller is None:
@ -724,8 +762,10 @@ class Context:
) )
else: else:
message: str = 'NOT cancelling `Context._scope` !\n\n' message: str = 'NOT cancelling `Context._scope` !\n\n'
# from .devx import mk_pdb
# mk_pdb().set_trace()
fmt_str: str = 'No `self._scope: CancelScope` was set/used ?' fmt_str: str = 'No `self._scope: CancelScope` was set/used ?\n'
if ( if (
cs cs
and and
@ -805,6 +845,7 @@ class Context:
# f'{ci.api_nsp}()\n' # f'{ci.api_nsp}()\n'
# ) # )
# TODO: use `.dev._frame_stack` scanning to find caller!
return 'Portal.open_context()' return 'Portal.open_context()'
async def cancel( async def cancel(
@ -1304,17 +1345,6 @@ class Context:
ctx=self, ctx=self,
hide_tb=hide_tb, hide_tb=hide_tb,
) )
for msg in drained_msgs:
# TODO: mask this by default..
if isinstance(msg, Return):
# from .devx import pause
# await pause()
# raise InternalError(
log.warning(
'Final `return` msg should never be drained !?!?\n\n'
f'{msg}\n'
)
drained_status: str = ( drained_status: str = (
'Ctx drained to final outcome msg\n\n' 'Ctx drained to final outcome msg\n\n'
@ -1435,6 +1465,10 @@ class Context:
self._result self._result
) )
@property
def has_outcome(self) -> bool:
return bool(self.maybe_error) or self._final_result_is_set()
# @property # @property
def repr_outcome( def repr_outcome(
self, self,
@ -1637,8 +1671,6 @@ class Context:
) )
if rt_started != started_msg: if rt_started != started_msg:
# TODO: break these methods out from the struct subtype?
# TODO: make that one a mod func too.. # TODO: make that one a mod func too..
diff = pretty_struct.Struct.__sub__( diff = pretty_struct.Struct.__sub__(
rt_started, rt_started,
@ -1674,6 +1706,8 @@ class Context:
) from verr ) from verr
self._started_called = True self._started_called = True
self._started_msg = started_msg
self._started_pld = value
async def _drain_overflows( async def _drain_overflows(
self, self,
@ -1961,6 +1995,7 @@ async def open_context_from_portal(
portal: Portal, portal: Portal,
func: Callable, func: Callable,
pld_spec: TypeAlias|None = None,
allow_overruns: bool = False, allow_overruns: bool = False,
# TODO: if we set this the wrapping `@acm` body will # TODO: if we set this the wrapping `@acm` body will
@ -2026,7 +2061,7 @@ async def open_context_from_portal(
# XXX NOTE XXX: currenly we do NOT allow opening a contex # XXX NOTE XXX: currenly we do NOT allow opening a contex
# with "self" since the local feeder mem-chan processing # with "self" since the local feeder mem-chan processing
# is not built for it. # is not built for it.
if portal.channel.uid == portal.actor.uid: if (uid := portal.channel.uid) == portal.actor.uid:
raise RuntimeError( raise RuntimeError(
'** !! Invalid Operation !! **\n' '** !! Invalid Operation !! **\n'
'Can not open an IPC ctx with the local actor!\n' 'Can not open an IPC ctx with the local actor!\n'
@ -2054,32 +2089,45 @@ async def open_context_from_portal(
assert ctx._caller_info assert ctx._caller_info
_ctxvar_Context.set(ctx) _ctxvar_Context.set(ctx)
# XXX NOTE since `._scope` is NOT set BEFORE we retreive the
# `Started`-msg any cancellation triggered
# in `._maybe_cancel_and_set_remote_error()` will
# NOT actually cancel the below line!
# -> it's expected that if there is an error in this phase of
# the dialog, the `Error` msg should be raised from the `msg`
# handling block below.
first: Any = await ctx._pld_rx.recv_pld(
ctx=ctx,
expect_msg=Started,
)
ctx._started_called: bool = True
uid: tuple = portal.channel.uid
cid: str = ctx.cid
# placeholder for any exception raised in the runtime # placeholder for any exception raised in the runtime
# or by user tasks which cause this context's closure. # or by user tasks which cause this context's closure.
scope_err: BaseException|None = None scope_err: BaseException|None = None
ctxc_from_callee: ContextCancelled|None = None ctxc_from_callee: ContextCancelled|None = None
try: try:
async with trio.open_nursery() as nurse: async with (
trio.open_nursery() as tn,
msgops.maybe_limit_plds(
ctx=ctx,
spec=pld_spec,
) as maybe_msgdec,
):
if maybe_msgdec:
assert maybe_msgdec.pld_spec == pld_spec
# NOTE: used to start overrun queuing tasks # XXX NOTE since `._scope` is NOT set BEFORE we retreive the
ctx._scope_nursery: trio.Nursery = nurse # `Started`-msg any cancellation triggered
ctx._scope: trio.CancelScope = nurse.cancel_scope # in `._maybe_cancel_and_set_remote_error()` will
# NOT actually cancel the below line!
# -> it's expected that if there is an error in this phase of
# the dialog, the `Error` msg should be raised from the `msg`
# handling block below.
started_msg, first = await ctx._pld_rx.recv_msg_w_pld(
ipc=ctx,
expect_msg=Started,
passthrough_non_pld_msgs=False,
)
# from .devx import pause
# await pause()
ctx._started_called: bool = True
ctx._started_msg: bool = started_msg
ctx._started_pld: bool = first
# NOTE: this in an implicit runtime nursery used to,
# - start overrun queuing tasks when as well as
# for cancellation of the scope opened by the user.
ctx._scope_nursery: trio.Nursery = tn
ctx._scope: trio.CancelScope = tn.cancel_scope
# deliver context instance and .started() msg value # deliver context instance and .started() msg value
# in enter tuple. # in enter tuple.
@ -2126,13 +2174,13 @@ async def open_context_from_portal(
# when in allow_overruns mode there may be # when in allow_overruns mode there may be
# lingering overflow sender tasks remaining? # lingering overflow sender tasks remaining?
if nurse.child_tasks: if tn.child_tasks:
# XXX: ensure we are in overrun state # XXX: ensure we are in overrun state
# with ``._allow_overruns=True`` bc otherwise # with ``._allow_overruns=True`` bc otherwise
# there should be no tasks in this nursery! # there should be no tasks in this nursery!
if ( if (
not ctx._allow_overruns not ctx._allow_overruns
or len(nurse.child_tasks) > 1 or len(tn.child_tasks) > 1
): ):
raise InternalError( raise InternalError(
'Context has sub-tasks but is ' 'Context has sub-tasks but is '
@ -2304,8 +2352,8 @@ async def open_context_from_portal(
): ):
log.warning( log.warning(
'IPC connection for context is broken?\n' 'IPC connection for context is broken?\n'
f'task:{cid}\n' f'task: {ctx.cid}\n'
f'actor:{uid}' f'actor: {uid}'
) )
raise # duh raise # duh
@ -2455,9 +2503,8 @@ async def open_context_from_portal(
and ctx.cancel_acked and ctx.cancel_acked
): ):
log.cancel( log.cancel(
'Context cancelled by {ctx.side!r}-side task\n' f'Context cancelled by {ctx.side!r}-side task\n'
f'|_{ctx._task}\n\n' f'|_{ctx._task}\n\n'
f'{repr(scope_err)}\n' f'{repr(scope_err)}\n'
) )
@ -2485,7 +2532,7 @@ async def open_context_from_portal(
f'cid: {ctx.cid}\n' f'cid: {ctx.cid}\n'
) )
portal.actor._contexts.pop( portal.actor._contexts.pop(
(uid, cid), (uid, ctx.cid),
None, None,
) )
@ -2516,8 +2563,9 @@ def mk_context(
from .devx._frame_stack import find_caller_info from .devx._frame_stack import find_caller_info
caller_info: CallerInfo|None = find_caller_info() caller_info: CallerInfo|None = find_caller_info()
# TODO: when/how do we apply `.limit_plds()` from here? pld_rx = msgops.PldRx(
pld_rx: msgops.PldRx = msgops.current_pldrx() _pld_dec=msgops._def_any_pldec,
)
ctx = Context( ctx = Context(
chan=chan, chan=chan,
@ -2531,13 +2579,16 @@ def mk_context(
_caller_info=caller_info, _caller_info=caller_info,
**kwargs, **kwargs,
) )
pld_rx._ctx = ctx
ctx._result = Unresolved ctx._result = Unresolved
return ctx return ctx
# TODO: use the new type-parameters to annotate this in 3.13? # TODO: use the new type-parameters to annotate this in 3.13?
# -[ ] https://peps.python.org/pep-0718/#unknown-types # -[ ] https://peps.python.org/pep-0718/#unknown-types
def context(func: Callable) -> Callable: def context(
func: Callable,
) -> Callable:
''' '''
Mark an (async) function as an SC-supervised, inter-`Actor`, Mark an (async) function as an SC-supervised, inter-`Actor`,
child-`trio.Task`, IPC endpoint otherwise known more child-`trio.Task`, IPC endpoint otherwise known more

View File

@ -52,6 +52,7 @@ from tractor.msg import (
if TYPE_CHECKING: if TYPE_CHECKING:
from ._context import Context from ._context import Context
from ._ipc import Channel
log = get_logger(__name__) log = get_logger(__name__)
@ -65,10 +66,10 @@ log = get_logger(__name__)
class MsgStream(trio.abc.Channel): class MsgStream(trio.abc.Channel):
''' '''
A bidirectional message stream for receiving logically sequenced A bidirectional message stream for receiving logically sequenced
values over an inter-actor IPC ``Channel``. values over an inter-actor IPC `Channel`.
This is the type returned to a local task which entered either This is the type returned to a local task which entered either
``Portal.open_stream_from()`` or ``Context.open_stream()``. `Portal.open_stream_from()` or `Context.open_stream()`.
Termination rules: Termination rules:
@ -95,6 +96,22 @@ class MsgStream(trio.abc.Channel):
self._eoc: bool|trio.EndOfChannel = False self._eoc: bool|trio.EndOfChannel = False
self._closed: bool|trio.ClosedResourceError = False self._closed: bool|trio.ClosedResourceError = False
@property
def ctx(self) -> Context:
'''
This stream's IPC `Context` ref.
'''
return self._ctx
@property
def chan(self) -> Channel:
'''
Ref to the containing `Context`'s transport `Channel`.
'''
return self._ctx.chan
# TODO: could we make this a direct method bind to `PldRx`? # TODO: could we make this a direct method bind to `PldRx`?
# -> receive_nowait = PldRx.recv_pld # -> receive_nowait = PldRx.recv_pld
# |_ means latter would have to accept `MsgStream`-as-`self`? # |_ means latter would have to accept `MsgStream`-as-`self`?
@ -109,7 +126,7 @@ class MsgStream(trio.abc.Channel):
): ):
ctx: Context = self._ctx ctx: Context = self._ctx
return ctx._pld_rx.recv_pld_nowait( return ctx._pld_rx.recv_pld_nowait(
ctx=ctx, ipc=self,
expect_msg=expect_msg, expect_msg=expect_msg,
) )
@ -148,7 +165,7 @@ class MsgStream(trio.abc.Channel):
try: try:
ctx: Context = self._ctx ctx: Context = self._ctx
return await ctx._pld_rx.recv_pld(ctx=ctx) return await ctx._pld_rx.recv_pld(ipc=self)
# XXX: the stream terminates on either of: # XXX: the stream terminates on either of:
# - via `self._rx_chan.receive()` raising after manual closure # - via `self._rx_chan.receive()` raising after manual closure

View File

@ -22,10 +22,9 @@ operational helpers for processing transaction flows.
''' '''
from __future__ import annotations from __future__ import annotations
from contextlib import ( from contextlib import (
# asynccontextmanager as acm, asynccontextmanager as acm,
contextmanager as cm, contextmanager as cm,
) )
from contextvars import ContextVar
from typing import ( from typing import (
Any, Any,
Type, Type,
@ -50,6 +49,7 @@ from tractor._exceptions import (
_mk_msg_type_err, _mk_msg_type_err,
pack_from_raise, pack_from_raise,
) )
from tractor._state import current_ipc_ctx
from ._codec import ( from ._codec import (
mk_dec, mk_dec,
MsgDec, MsgDec,
@ -75,7 +75,7 @@ if TYPE_CHECKING:
log = get_logger(__name__) log = get_logger(__name__)
_def_any_pldec: MsgDec = mk_dec() _def_any_pldec: MsgDec[Any] = mk_dec()
class PldRx(Struct): class PldRx(Struct):
@ -104,15 +104,19 @@ class PldRx(Struct):
''' '''
# TODO: better to bind it here? # TODO: better to bind it here?
# _rx_mc: trio.MemoryReceiveChannel # _rx_mc: trio.MemoryReceiveChannel
_pldec: MsgDec _pld_dec: MsgDec
_ctx: Context|None = None
_ipc: Context|MsgStream|None = None _ipc: Context|MsgStream|None = None
@property @property
def pld_dec(self) -> MsgDec: def pld_dec(self) -> MsgDec:
return self._pldec return self._pld_dec
# TODO: a better name?
# -[ ] when would this be used as it avoids needingn to pass the
# ipc prim to every method
@cm @cm
def apply_to_ipc( def wraps_ipc(
self, self,
ipc_prim: Context|MsgStream, ipc_prim: Context|MsgStream,
@ -140,49 +144,50 @@ class PldRx(Struct):
exit. exit.
''' '''
orig_dec: MsgDec = self._pldec orig_dec: MsgDec = self._pld_dec
limit_dec: MsgDec = mk_dec(spec=spec) limit_dec: MsgDec = mk_dec(spec=spec)
try: try:
self._pldec = limit_dec self._pld_dec = limit_dec
yield limit_dec yield limit_dec
finally: finally:
self._pldec = orig_dec self._pld_dec = orig_dec
@property @property
def dec(self) -> msgpack.Decoder: def dec(self) -> msgpack.Decoder:
return self._pldec.dec return self._pld_dec.dec
def recv_pld_nowait( def recv_pld_nowait(
self, self,
# TODO: make this `MsgStream` compat as well, see above^ # TODO: make this `MsgStream` compat as well, see above^
# ipc_prim: Context|MsgStream, # ipc_prim: Context|MsgStream,
ctx: Context, ipc: Context|MsgStream,
ipc_msg: MsgType|None = None, ipc_msg: MsgType|None = None,
expect_msg: Type[MsgType]|None = None, expect_msg: Type[MsgType]|None = None,
hide_tb: bool = False,
**dec_msg_kwargs, **dec_msg_kwargs,
) -> Any|Raw: ) -> Any|Raw:
__tracebackhide__: bool = True __tracebackhide__: bool = hide_tb
msg: MsgType = ( msg: MsgType = (
ipc_msg ipc_msg
or or
# sync-rx msg from underlying IPC feeder (mem-)chan # sync-rx msg from underlying IPC feeder (mem-)chan
ctx._rx_chan.receive_nowait() ipc._rx_chan.receive_nowait()
) )
return self.dec_msg( return self.dec_msg(
msg, msg,
ctx=ctx, ipc=ipc,
expect_msg=expect_msg, expect_msg=expect_msg,
hide_tb=hide_tb,
**dec_msg_kwargs, **dec_msg_kwargs,
) )
async def recv_pld( async def recv_pld(
self, self,
ctx: Context, ipc: Context|MsgStream,
ipc_msg: MsgType|None = None, ipc_msg: MsgType|None = None,
expect_msg: Type[MsgType]|None = None, expect_msg: Type[MsgType]|None = None,
hide_tb: bool = True, hide_tb: bool = True,
@ -200,11 +205,11 @@ class PldRx(Struct):
or or
# async-rx msg from underlying IPC feeder (mem-)chan # async-rx msg from underlying IPC feeder (mem-)chan
await ctx._rx_chan.receive() await ipc._rx_chan.receive()
) )
return self.dec_msg( return self.dec_msg(
msg=msg, msg=msg,
ctx=ctx, ipc=ipc,
expect_msg=expect_msg, expect_msg=expect_msg,
**dec_msg_kwargs, **dec_msg_kwargs,
) )
@ -212,7 +217,7 @@ class PldRx(Struct):
def dec_msg( def dec_msg(
self, self,
msg: MsgType, msg: MsgType,
ctx: Context, ipc: Context|MsgStream,
expect_msg: Type[MsgType]|None, expect_msg: Type[MsgType]|None,
raise_error: bool = True, raise_error: bool = True,
@ -225,6 +230,9 @@ class PldRx(Struct):
''' '''
__tracebackhide__: bool = hide_tb __tracebackhide__: bool = hide_tb
_src_err = None
src_err: BaseException|None = None
match msg: match msg:
# payload-data shuttle msg; deliver the `.pld` value # payload-data shuttle msg; deliver the `.pld` value
# directly to IPC (primitive) client-consumer code. # directly to IPC (primitive) client-consumer code.
@ -234,7 +242,7 @@ class PldRx(Struct):
|Return(pld=pld) # termination phase |Return(pld=pld) # termination phase
): ):
try: try:
pld: PayloadT = self._pldec.decode(pld) pld: PayloadT = self._pld_dec.decode(pld)
log.runtime( log.runtime(
'Decoded msg payload\n\n' 'Decoded msg payload\n\n'
f'{msg}\n\n' f'{msg}\n\n'
@ -243,25 +251,30 @@ class PldRx(Struct):
) )
return pld return pld
# XXX pld-type failure # XXX pld-value type failure
except ValidationError as src_err: except ValidationError as valerr:
# pack mgterr into error-msg for
# reraise below; ensure remote-actor-err
# info is displayed nicely?
msgterr: MsgTypeError = _mk_msg_type_err( msgterr: MsgTypeError = _mk_msg_type_err(
msg=msg, msg=msg,
codec=self.pld_dec, codec=self.pld_dec,
src_validation_error=src_err, src_validation_error=valerr,
is_invalid_payload=True, is_invalid_payload=True,
) )
msg: Error = pack_from_raise( msg: Error = pack_from_raise(
local_err=msgterr, local_err=msgterr,
cid=msg.cid, cid=msg.cid,
src_uid=ctx.chan.uid, src_uid=ipc.chan.uid,
) )
src_err = valerr
# XXX some other decoder specific failure? # XXX some other decoder specific failure?
# except TypeError as src_error: # except TypeError as src_error:
# from .devx import mk_pdb # from .devx import mk_pdb
# mk_pdb().set_trace() # mk_pdb().set_trace()
# raise src_error # raise src_error
# ^-TODO-^ can remove?
# a runtime-internal RPC endpoint response. # a runtime-internal RPC endpoint response.
# always passthrough since (internal) runtime # always passthrough since (internal) runtime
@ -299,6 +312,7 @@ class PldRx(Struct):
return src_err return src_err
case Stop(cid=cid): case Stop(cid=cid):
ctx: Context = getattr(ipc, 'ctx', ipc)
message: str = ( message: str = (
f'{ctx.side!r}-side of ctx received stream-`Stop` from ' f'{ctx.side!r}-side of ctx received stream-`Stop` from '
f'{ctx.peer_side!r} peer ?\n' f'{ctx.peer_side!r} peer ?\n'
@ -341,14 +355,21 @@ class PldRx(Struct):
# |_https://docs.python.org/3.11/library/exceptions.html#BaseException.add_note # |_https://docs.python.org/3.11/library/exceptions.html#BaseException.add_note
# #
# fallthrough and raise from `src_err` # fallthrough and raise from `src_err`
_raise_from_unexpected_msg( try:
ctx=ctx, _raise_from_unexpected_msg(
msg=msg, ctx=getattr(ipc, 'ctx', ipc),
src_err=src_err, msg=msg,
log=log, src_err=src_err,
expect_msg=expect_msg, log=log,
hide_tb=hide_tb, expect_msg=expect_msg,
) hide_tb=hide_tb,
)
except UnboundLocalError:
# XXX if there's an internal lookup error in the above
# code (prolly on `src_err`) we want to show this frame
# in the tb!
__tracebackhide__: bool = False
raise
async def recv_msg_w_pld( async def recv_msg_w_pld(
self, self,
@ -378,52 +399,13 @@ class PldRx(Struct):
# msg instance? # msg instance?
pld: PayloadT = self.dec_msg( pld: PayloadT = self.dec_msg(
msg, msg,
ctx=ipc, ipc=ipc,
expect_msg=expect_msg, expect_msg=expect_msg,
**kwargs, **kwargs,
) )
return msg, pld return msg, pld
# Always maintain a task-context-global `PldRx`
_def_pld_rx: PldRx = PldRx(
_pldec=_def_any_pldec,
)
_ctxvar_PldRx: ContextVar[PldRx] = ContextVar(
'pld_rx',
default=_def_pld_rx,
)
def current_pldrx() -> PldRx:
'''
Return the current `trio.Task.context`'s msg-payload-receiver.
A payload receiver is the IPC-msg processing sub-sys which
filters inter-actor-task communicated payload data, i.e. the
`PayloadMsg.pld: PayloadT` field value, AFTER it's container
shuttlle msg (eg. `Started`/`Yield`/`Return) has been delivered
up from `tractor`'s transport layer but BEFORE the data is
yielded to application code, normally via an IPC primitive API
like, for ex., `pld_data: PayloadT = MsgStream.receive()`.
Modification of the current payload spec via `limit_plds()`
allows a `tractor` application to contextually filter IPC
payload content with a type specification as supported by
the interchange backend.
- for `msgspec` see <PUTLINKHERE>.
NOTE that the `PldRx` itself is a per-`Context` global sub-system
that normally does not change other then the applied pld-spec
for the current `trio.Task`.
'''
# ctx: context = current_ipc_ctx()
# return ctx._pld_rx
return _ctxvar_PldRx.get()
@cm @cm
def limit_plds( def limit_plds(
spec: Union[Type[Struct]], spec: Union[Type[Struct]],
@ -439,29 +421,55 @@ def limit_plds(
''' '''
__tracebackhide__: bool = True __tracebackhide__: bool = True
try: try:
# sanity on orig settings curr_ctx: Context = current_ipc_ctx()
orig_pldrx: PldRx = current_pldrx() rx: PldRx = curr_ctx._pld_rx
orig_pldec: MsgDec = orig_pldrx.pld_dec orig_pldec: MsgDec = rx.pld_dec
with orig_pldrx.limit_plds( with rx.limit_plds(
spec=spec, spec=spec,
**kwargs, **kwargs,
) as pldec: ) as pldec:
log.info( log.runtime(
'Applying payload-decoder\n\n' 'Applying payload-decoder\n\n'
f'{pldec}\n' f'{pldec}\n'
) )
yield pldec yield pldec
finally: finally:
log.info( log.runtime(
'Reverted to previous payload-decoder\n\n' 'Reverted to previous payload-decoder\n\n'
f'{orig_pldec}\n' f'{orig_pldec}\n'
) )
assert ( # sanity on orig settings
(pldrx := current_pldrx()) is orig_pldrx assert rx.pld_dec is orig_pldec
and
pldrx.pld_dec is orig_pldec
) @acm
async def maybe_limit_plds(
ctx: Context,
spec: Union[Type[Struct]]|None = None,
**kwargs,
) -> MsgDec|None:
'''
Async compat maybe-payload type limiter.
Mostly for use inside other internal `@acm`s such that a separate
indent block isn't needed when an async one is already being
used.
'''
if spec is None:
yield None
return
# sanity on scoping
curr_ctx: Context = current_ipc_ctx()
assert ctx is curr_ctx
with ctx._pld_rx.limit_plds(spec=spec) as msgdec:
yield msgdec
curr_ctx: Context = current_ipc_ctx()
assert ctx is curr_ctx
async def drain_to_final_msg( async def drain_to_final_msg(
@ -543,21 +551,12 @@ async def drain_to_final_msg(
match msg: match msg:
# final result arrived! # final result arrived!
case Return( case Return():
# cid=cid,
# pld=res,
):
# ctx._result: Any = res
ctx._result: Any = pld
log.runtime( log.runtime(
'Context delivered final draining msg:\n' 'Context delivered final draining msg:\n'
f'{pretty_struct.pformat(msg)}' f'{pretty_struct.pformat(msg)}'
) )
# XXX: only close the rx mem chan AFTER ctx._result: Any = pld
# a final result is retreived.
# if ctx._rx_chan:
# await ctx._rx_chan.aclose()
# TODO: ^ we don't need it right?
result_msg = msg result_msg = msg
break break
@ -664,24 +663,6 @@ async def drain_to_final_msg(
result_msg = msg result_msg = msg
break # OOOOOF, yeah obvi we need this.. break # OOOOOF, yeah obvi we need this..
# XXX we should never really get here
# right! since `._deliver_msg()` should
# always have detected an {'error': ..}
# msg and already called this right!?!
# elif error := unpack_error(
# msg=msg,
# chan=ctx._portal.channel,
# hide_tb=False,
# ):
# log.critical('SHOULD NEVER GET HERE!?')
# assert msg is ctx._cancel_msg
# assert error.msgdata == ctx._remote_error.msgdata
# assert error.ipc_msg == ctx._remote_error.ipc_msg
# from .devx._debug import pause
# await pause()
# ctx._maybe_cancel_and_set_remote_error(error)
# ctx._maybe_raise_remote_err(error)
else: else:
# bubble the original src key error # bubble the original src key error
raise raise