Allocate a `PldRx` per `Context`, new pld-spec API

Since the state mgmt becomes quite messy with multiple sub-tasks inside
an IPC ctx, AND bc generally speaking the payload-type-spec should map
1-to-1 with the `Context`, it doesn't make a lot of sense to be using
`ContextVar`s to modify the `Context.pld_rx: PldRx` instance.

Instead, always allocate a full instance inside `mk_context()` with the
default `.pld_rx: PldRx` set to use the `msg._ops._def_any_pldec: MsgDec`

In support, simplify the `.msg._ops` impl and APIs:
- drop `_ctxvar_PldRx`, `_def_pld_rx` and `current_pldrx()`.
- rename `PldRx._pldec` -> `._pld_dec`.
- rename the unused `PldRx.apply_to_ipc()` -> `.wraps_ipc()`.
- add a required `PldRx._ctx: Context` attr since it is needed
  internally in some meths and each pld-rx now maps to a specific ctx.
- modify all recv methods to accept a `ipc: Context|MsgStream` (instead
  of a `ctx` arg) since both have a ref to the same `._rx_chan` and there
  are only a couple spots (in `.dec_msg()`) where we need the `ctx`
  explicitly (which can now be easily accessed via a new `MsgStream.ctx`
  property, see below).
- always show the `.dec_msg()` frame in tbs if there's a reference error
  when calling `_raise_from_unexpected_msg()` in the fallthrough case.
- implement `limit_plds()` as light wrapper around getting the
  `current_ipc_ctx()` and mutating its `MsgDec` via
  `Context.pld_rx.limit_plds()`.
- add a `maybe_limit_plds()` which just provides an `@acm` equivalent of
  `limit_plds()` handy for composing in a `async with ():` style block
  (avoiding additional indent levels in the body of async funcs).

Obvi extend the `Context` and `MsgStream` interfaces as needed
to match the above:
- add a `Context.pld_rx` pub prop.
- new private refs to `Context._started_msg: Started` and
  a `._started_pld` (mostly for internal debugging / testing / logging)
  and set inside `.open_context()` immediately after the syncing phase.
- a `Context.has_outcome() -> bool:` predicate which can be used to more
  easily determine if the ctx errored or has a final result.
- pub props for `MsgStream.ctx: Context` and `.chan: Channel` providing
  full `ipc`-arg compat with the `PldRx` method signatures.
runtime_to_msgspec
Tyler Goodlet 2024-05-20 14:34:50 -04:00
parent d93135acd8
commit 262a0e36c6
3 changed files with 212 additions and 163 deletions

View File

@ -41,6 +41,7 @@ from typing import (
Callable,
Mapping,
Type,
TypeAlias,
TYPE_CHECKING,
Union,
)
@ -155,6 +156,41 @@ class Context:
# payload receiver
_pld_rx: msgops.PldRx
@property
def pld_rx(self) -> msgops.PldRx:
'''
The current `tractor.Context`'s msg-payload-receiver.
A payload receiver is the IPC-msg processing sub-sys which
filters inter-actor-task communicated payload data, i.e. the
`PayloadMsg.pld: PayloadT` field value, AFTER its container
shuttlle msg (eg. `Started`/`Yield`/`Return) has been
delivered up from `tractor`'s transport layer but BEFORE the
data is yielded to `tractor` application code.
The "IPC-primitive API" is normally one of a `Context` (this)` or a `MsgStream`
or some higher level API using one of them.
For ex. `pld_data: PayloadT = MsgStream.receive()` implicitly
calls into the stream's parent `Context.pld_rx.recv_pld().` to
receive the latest `PayloadMsg.pld` value.
Modification of the current payload spec via `limit_plds()`
allows a `tractor` application to contextually filter IPC
payload content with a type specification as supported by the
interchange backend.
- for `msgspec` see <PUTLINKHERE>.
Note that the `PldRx` itself is a per-`Context` instance that
normally only changes when some (sub-)task, on a given "side"
of the IPC ctx (either a "child"-side RPC or inside
a "parent"-side `Portal.open_context()` block), modifies it
using the `.msg._ops.limit_plds()` API.
'''
return self._pld_rx
# full "namespace-path" to target RPC function
_nsf: NamespacePath
@ -231,6 +267,8 @@ class Context:
# init and streaming state
_started_called: bool = False
_started_msg: MsgType|None = None
_started_pld: Any = None
_stream_opened: bool = False
_stream: MsgStream|None = None
@ -623,7 +661,7 @@ class Context:
log.runtime(
'Setting remote error for ctx\n\n'
f'<= {self.peer_side!r}: {self.chan.uid}\n'
f'=> {self.side!r}\n\n'
f'=> {self.side!r}: {self._actor.uid}\n\n'
f'{error}'
)
self._remote_error: BaseException = error
@ -678,7 +716,7 @@ class Context:
log.error(
f'Remote context error:\n\n'
# f'{pformat(self)}\n'
f'{error}\n'
f'{error}'
)
if self._canceller is None:
@ -724,8 +762,10 @@ class Context:
)
else:
message: str = 'NOT cancelling `Context._scope` !\n\n'
# from .devx import mk_pdb
# mk_pdb().set_trace()
fmt_str: str = 'No `self._scope: CancelScope` was set/used ?'
fmt_str: str = 'No `self._scope: CancelScope` was set/used ?\n'
if (
cs
and
@ -805,6 +845,7 @@ class Context:
# f'{ci.api_nsp}()\n'
# )
# TODO: use `.dev._frame_stack` scanning to find caller!
return 'Portal.open_context()'
async def cancel(
@ -1304,17 +1345,6 @@ class Context:
ctx=self,
hide_tb=hide_tb,
)
for msg in drained_msgs:
# TODO: mask this by default..
if isinstance(msg, Return):
# from .devx import pause
# await pause()
# raise InternalError(
log.warning(
'Final `return` msg should never be drained !?!?\n\n'
f'{msg}\n'
)
drained_status: str = (
'Ctx drained to final outcome msg\n\n'
@ -1435,6 +1465,10 @@ class Context:
self._result
)
@property
def has_outcome(self) -> bool:
return bool(self.maybe_error) or self._final_result_is_set()
# @property
def repr_outcome(
self,
@ -1637,8 +1671,6 @@ class Context:
)
if rt_started != started_msg:
# TODO: break these methods out from the struct subtype?
# TODO: make that one a mod func too..
diff = pretty_struct.Struct.__sub__(
rt_started,
@ -1674,6 +1706,8 @@ class Context:
) from verr
self._started_called = True
self._started_msg = started_msg
self._started_pld = value
async def _drain_overflows(
self,
@ -1961,6 +1995,7 @@ async def open_context_from_portal(
portal: Portal,
func: Callable,
pld_spec: TypeAlias|None = None,
allow_overruns: bool = False,
# TODO: if we set this the wrapping `@acm` body will
@ -2026,7 +2061,7 @@ async def open_context_from_portal(
# XXX NOTE XXX: currenly we do NOT allow opening a contex
# with "self" since the local feeder mem-chan processing
# is not built for it.
if portal.channel.uid == portal.actor.uid:
if (uid := portal.channel.uid) == portal.actor.uid:
raise RuntimeError(
'** !! Invalid Operation !! **\n'
'Can not open an IPC ctx with the local actor!\n'
@ -2054,32 +2089,45 @@ async def open_context_from_portal(
assert ctx._caller_info
_ctxvar_Context.set(ctx)
# XXX NOTE since `._scope` is NOT set BEFORE we retreive the
# `Started`-msg any cancellation triggered
# in `._maybe_cancel_and_set_remote_error()` will
# NOT actually cancel the below line!
# -> it's expected that if there is an error in this phase of
# the dialog, the `Error` msg should be raised from the `msg`
# handling block below.
first: Any = await ctx._pld_rx.recv_pld(
ctx=ctx,
expect_msg=Started,
)
ctx._started_called: bool = True
uid: tuple = portal.channel.uid
cid: str = ctx.cid
# placeholder for any exception raised in the runtime
# or by user tasks which cause this context's closure.
scope_err: BaseException|None = None
ctxc_from_callee: ContextCancelled|None = None
try:
async with trio.open_nursery() as nurse:
async with (
trio.open_nursery() as tn,
msgops.maybe_limit_plds(
ctx=ctx,
spec=pld_spec,
) as maybe_msgdec,
):
if maybe_msgdec:
assert maybe_msgdec.pld_spec == pld_spec
# NOTE: used to start overrun queuing tasks
ctx._scope_nursery: trio.Nursery = nurse
ctx._scope: trio.CancelScope = nurse.cancel_scope
# XXX NOTE since `._scope` is NOT set BEFORE we retreive the
# `Started`-msg any cancellation triggered
# in `._maybe_cancel_and_set_remote_error()` will
# NOT actually cancel the below line!
# -> it's expected that if there is an error in this phase of
# the dialog, the `Error` msg should be raised from the `msg`
# handling block below.
started_msg, first = await ctx._pld_rx.recv_msg_w_pld(
ipc=ctx,
expect_msg=Started,
passthrough_non_pld_msgs=False,
)
# from .devx import pause
# await pause()
ctx._started_called: bool = True
ctx._started_msg: bool = started_msg
ctx._started_pld: bool = first
# NOTE: this in an implicit runtime nursery used to,
# - start overrun queuing tasks when as well as
# for cancellation of the scope opened by the user.
ctx._scope_nursery: trio.Nursery = tn
ctx._scope: trio.CancelScope = tn.cancel_scope
# deliver context instance and .started() msg value
# in enter tuple.
@ -2126,13 +2174,13 @@ async def open_context_from_portal(
# when in allow_overruns mode there may be
# lingering overflow sender tasks remaining?
if nurse.child_tasks:
if tn.child_tasks:
# XXX: ensure we are in overrun state
# with ``._allow_overruns=True`` bc otherwise
# there should be no tasks in this nursery!
if (
not ctx._allow_overruns
or len(nurse.child_tasks) > 1
or len(tn.child_tasks) > 1
):
raise InternalError(
'Context has sub-tasks but is '
@ -2304,8 +2352,8 @@ async def open_context_from_portal(
):
log.warning(
'IPC connection for context is broken?\n'
f'task:{cid}\n'
f'actor:{uid}'
f'task: {ctx.cid}\n'
f'actor: {uid}'
)
raise # duh
@ -2455,9 +2503,8 @@ async def open_context_from_portal(
and ctx.cancel_acked
):
log.cancel(
'Context cancelled by {ctx.side!r}-side task\n'
f'Context cancelled by {ctx.side!r}-side task\n'
f'|_{ctx._task}\n\n'
f'{repr(scope_err)}\n'
)
@ -2485,7 +2532,7 @@ async def open_context_from_portal(
f'cid: {ctx.cid}\n'
)
portal.actor._contexts.pop(
(uid, cid),
(uid, ctx.cid),
None,
)
@ -2516,8 +2563,9 @@ def mk_context(
from .devx._frame_stack import find_caller_info
caller_info: CallerInfo|None = find_caller_info()
# TODO: when/how do we apply `.limit_plds()` from here?
pld_rx: msgops.PldRx = msgops.current_pldrx()
pld_rx = msgops.PldRx(
_pld_dec=msgops._def_any_pldec,
)
ctx = Context(
chan=chan,
@ -2531,13 +2579,16 @@ def mk_context(
_caller_info=caller_info,
**kwargs,
)
pld_rx._ctx = ctx
ctx._result = Unresolved
return ctx
# TODO: use the new type-parameters to annotate this in 3.13?
# -[ ] https://peps.python.org/pep-0718/#unknown-types
def context(func: Callable) -> Callable:
def context(
func: Callable,
) -> Callable:
'''
Mark an (async) function as an SC-supervised, inter-`Actor`,
child-`trio.Task`, IPC endpoint otherwise known more

View File

@ -52,6 +52,7 @@ from tractor.msg import (
if TYPE_CHECKING:
from ._context import Context
from ._ipc import Channel
log = get_logger(__name__)
@ -65,10 +66,10 @@ log = get_logger(__name__)
class MsgStream(trio.abc.Channel):
'''
A bidirectional message stream for receiving logically sequenced
values over an inter-actor IPC ``Channel``.
values over an inter-actor IPC `Channel`.
This is the type returned to a local task which entered either
``Portal.open_stream_from()`` or ``Context.open_stream()``.
`Portal.open_stream_from()` or `Context.open_stream()`.
Termination rules:
@ -95,6 +96,22 @@ class MsgStream(trio.abc.Channel):
self._eoc: bool|trio.EndOfChannel = False
self._closed: bool|trio.ClosedResourceError = False
@property
def ctx(self) -> Context:
'''
This stream's IPC `Context` ref.
'''
return self._ctx
@property
def chan(self) -> Channel:
'''
Ref to the containing `Context`'s transport `Channel`.
'''
return self._ctx.chan
# TODO: could we make this a direct method bind to `PldRx`?
# -> receive_nowait = PldRx.recv_pld
# |_ means latter would have to accept `MsgStream`-as-`self`?
@ -109,7 +126,7 @@ class MsgStream(trio.abc.Channel):
):
ctx: Context = self._ctx
return ctx._pld_rx.recv_pld_nowait(
ctx=ctx,
ipc=self,
expect_msg=expect_msg,
)
@ -148,7 +165,7 @@ class MsgStream(trio.abc.Channel):
try:
ctx: Context = self._ctx
return await ctx._pld_rx.recv_pld(ctx=ctx)
return await ctx._pld_rx.recv_pld(ipc=self)
# XXX: the stream terminates on either of:
# - via `self._rx_chan.receive()` raising after manual closure

View File

@ -22,10 +22,9 @@ operational helpers for processing transaction flows.
'''
from __future__ import annotations
from contextlib import (
# asynccontextmanager as acm,
asynccontextmanager as acm,
contextmanager as cm,
)
from contextvars import ContextVar
from typing import (
Any,
Type,
@ -50,6 +49,7 @@ from tractor._exceptions import (
_mk_msg_type_err,
pack_from_raise,
)
from tractor._state import current_ipc_ctx
from ._codec import (
mk_dec,
MsgDec,
@ -75,7 +75,7 @@ if TYPE_CHECKING:
log = get_logger(__name__)
_def_any_pldec: MsgDec = mk_dec()
_def_any_pldec: MsgDec[Any] = mk_dec()
class PldRx(Struct):
@ -104,15 +104,19 @@ class PldRx(Struct):
'''
# TODO: better to bind it here?
# _rx_mc: trio.MemoryReceiveChannel
_pldec: MsgDec
_pld_dec: MsgDec
_ctx: Context|None = None
_ipc: Context|MsgStream|None = None
@property
def pld_dec(self) -> MsgDec:
return self._pldec
return self._pld_dec
# TODO: a better name?
# -[ ] when would this be used as it avoids needingn to pass the
# ipc prim to every method
@cm
def apply_to_ipc(
def wraps_ipc(
self,
ipc_prim: Context|MsgStream,
@ -140,49 +144,50 @@ class PldRx(Struct):
exit.
'''
orig_dec: MsgDec = self._pldec
orig_dec: MsgDec = self._pld_dec
limit_dec: MsgDec = mk_dec(spec=spec)
try:
self._pldec = limit_dec
self._pld_dec = limit_dec
yield limit_dec
finally:
self._pldec = orig_dec
self._pld_dec = orig_dec
@property
def dec(self) -> msgpack.Decoder:
return self._pldec.dec
return self._pld_dec.dec
def recv_pld_nowait(
self,
# TODO: make this `MsgStream` compat as well, see above^
# ipc_prim: Context|MsgStream,
ctx: Context,
ipc: Context|MsgStream,
ipc_msg: MsgType|None = None,
expect_msg: Type[MsgType]|None = None,
hide_tb: bool = False,
**dec_msg_kwargs,
) -> Any|Raw:
__tracebackhide__: bool = True
__tracebackhide__: bool = hide_tb
msg: MsgType = (
ipc_msg
or
# sync-rx msg from underlying IPC feeder (mem-)chan
ctx._rx_chan.receive_nowait()
ipc._rx_chan.receive_nowait()
)
return self.dec_msg(
msg,
ctx=ctx,
ipc=ipc,
expect_msg=expect_msg,
hide_tb=hide_tb,
**dec_msg_kwargs,
)
async def recv_pld(
self,
ctx: Context,
ipc: Context|MsgStream,
ipc_msg: MsgType|None = None,
expect_msg: Type[MsgType]|None = None,
hide_tb: bool = True,
@ -200,11 +205,11 @@ class PldRx(Struct):
or
# async-rx msg from underlying IPC feeder (mem-)chan
await ctx._rx_chan.receive()
await ipc._rx_chan.receive()
)
return self.dec_msg(
msg=msg,
ctx=ctx,
ipc=ipc,
expect_msg=expect_msg,
**dec_msg_kwargs,
)
@ -212,7 +217,7 @@ class PldRx(Struct):
def dec_msg(
self,
msg: MsgType,
ctx: Context,
ipc: Context|MsgStream,
expect_msg: Type[MsgType]|None,
raise_error: bool = True,
@ -225,6 +230,9 @@ class PldRx(Struct):
'''
__tracebackhide__: bool = hide_tb
_src_err = None
src_err: BaseException|None = None
match msg:
# payload-data shuttle msg; deliver the `.pld` value
# directly to IPC (primitive) client-consumer code.
@ -234,7 +242,7 @@ class PldRx(Struct):
|Return(pld=pld) # termination phase
):
try:
pld: PayloadT = self._pldec.decode(pld)
pld: PayloadT = self._pld_dec.decode(pld)
log.runtime(
'Decoded msg payload\n\n'
f'{msg}\n\n'
@ -243,25 +251,30 @@ class PldRx(Struct):
)
return pld
# XXX pld-type failure
except ValidationError as src_err:
# XXX pld-value type failure
except ValidationError as valerr:
# pack mgterr into error-msg for
# reraise below; ensure remote-actor-err
# info is displayed nicely?
msgterr: MsgTypeError = _mk_msg_type_err(
msg=msg,
codec=self.pld_dec,
src_validation_error=src_err,
src_validation_error=valerr,
is_invalid_payload=True,
)
msg: Error = pack_from_raise(
local_err=msgterr,
cid=msg.cid,
src_uid=ctx.chan.uid,
src_uid=ipc.chan.uid,
)
src_err = valerr
# XXX some other decoder specific failure?
# except TypeError as src_error:
# from .devx import mk_pdb
# mk_pdb().set_trace()
# raise src_error
# ^-TODO-^ can remove?
# a runtime-internal RPC endpoint response.
# always passthrough since (internal) runtime
@ -299,6 +312,7 @@ class PldRx(Struct):
return src_err
case Stop(cid=cid):
ctx: Context = getattr(ipc, 'ctx', ipc)
message: str = (
f'{ctx.side!r}-side of ctx received stream-`Stop` from '
f'{ctx.peer_side!r} peer ?\n'
@ -341,14 +355,21 @@ class PldRx(Struct):
# |_https://docs.python.org/3.11/library/exceptions.html#BaseException.add_note
#
# fallthrough and raise from `src_err`
_raise_from_unexpected_msg(
ctx=ctx,
msg=msg,
src_err=src_err,
log=log,
expect_msg=expect_msg,
hide_tb=hide_tb,
)
try:
_raise_from_unexpected_msg(
ctx=getattr(ipc, 'ctx', ipc),
msg=msg,
src_err=src_err,
log=log,
expect_msg=expect_msg,
hide_tb=hide_tb,
)
except UnboundLocalError:
# XXX if there's an internal lookup error in the above
# code (prolly on `src_err`) we want to show this frame
# in the tb!
__tracebackhide__: bool = False
raise
async def recv_msg_w_pld(
self,
@ -378,52 +399,13 @@ class PldRx(Struct):
# msg instance?
pld: PayloadT = self.dec_msg(
msg,
ctx=ipc,
ipc=ipc,
expect_msg=expect_msg,
**kwargs,
)
return msg, pld
# Always maintain a task-context-global `PldRx`
_def_pld_rx: PldRx = PldRx(
_pldec=_def_any_pldec,
)
_ctxvar_PldRx: ContextVar[PldRx] = ContextVar(
'pld_rx',
default=_def_pld_rx,
)
def current_pldrx() -> PldRx:
'''
Return the current `trio.Task.context`'s msg-payload-receiver.
A payload receiver is the IPC-msg processing sub-sys which
filters inter-actor-task communicated payload data, i.e. the
`PayloadMsg.pld: PayloadT` field value, AFTER it's container
shuttlle msg (eg. `Started`/`Yield`/`Return) has been delivered
up from `tractor`'s transport layer but BEFORE the data is
yielded to application code, normally via an IPC primitive API
like, for ex., `pld_data: PayloadT = MsgStream.receive()`.
Modification of the current payload spec via `limit_plds()`
allows a `tractor` application to contextually filter IPC
payload content with a type specification as supported by
the interchange backend.
- for `msgspec` see <PUTLINKHERE>.
NOTE that the `PldRx` itself is a per-`Context` global sub-system
that normally does not change other then the applied pld-spec
for the current `trio.Task`.
'''
# ctx: context = current_ipc_ctx()
# return ctx._pld_rx
return _ctxvar_PldRx.get()
@cm
def limit_plds(
spec: Union[Type[Struct]],
@ -439,29 +421,55 @@ def limit_plds(
'''
__tracebackhide__: bool = True
try:
# sanity on orig settings
orig_pldrx: PldRx = current_pldrx()
orig_pldec: MsgDec = orig_pldrx.pld_dec
curr_ctx: Context = current_ipc_ctx()
rx: PldRx = curr_ctx._pld_rx
orig_pldec: MsgDec = rx.pld_dec
with orig_pldrx.limit_plds(
with rx.limit_plds(
spec=spec,
**kwargs,
) as pldec:
log.info(
log.runtime(
'Applying payload-decoder\n\n'
f'{pldec}\n'
)
yield pldec
finally:
log.info(
log.runtime(
'Reverted to previous payload-decoder\n\n'
f'{orig_pldec}\n'
)
assert (
(pldrx := current_pldrx()) is orig_pldrx
and
pldrx.pld_dec is orig_pldec
)
# sanity on orig settings
assert rx.pld_dec is orig_pldec
@acm
async def maybe_limit_plds(
ctx: Context,
spec: Union[Type[Struct]]|None = None,
**kwargs,
) -> MsgDec|None:
'''
Async compat maybe-payload type limiter.
Mostly for use inside other internal `@acm`s such that a separate
indent block isn't needed when an async one is already being
used.
'''
if spec is None:
yield None
return
# sanity on scoping
curr_ctx: Context = current_ipc_ctx()
assert ctx is curr_ctx
with ctx._pld_rx.limit_plds(spec=spec) as msgdec:
yield msgdec
curr_ctx: Context = current_ipc_ctx()
assert ctx is curr_ctx
async def drain_to_final_msg(
@ -543,21 +551,12 @@ async def drain_to_final_msg(
match msg:
# final result arrived!
case Return(
# cid=cid,
# pld=res,
):
# ctx._result: Any = res
ctx._result: Any = pld
case Return():
log.runtime(
'Context delivered final draining msg:\n'
f'{pretty_struct.pformat(msg)}'
)
# XXX: only close the rx mem chan AFTER
# a final result is retreived.
# if ctx._rx_chan:
# await ctx._rx_chan.aclose()
# TODO: ^ we don't need it right?
ctx._result: Any = pld
result_msg = msg
break
@ -664,24 +663,6 @@ async def drain_to_final_msg(
result_msg = msg
break # OOOOOF, yeah obvi we need this..
# XXX we should never really get here
# right! since `._deliver_msg()` should
# always have detected an {'error': ..}
# msg and already called this right!?!
# elif error := unpack_error(
# msg=msg,
# chan=ctx._portal.channel,
# hide_tb=False,
# ):
# log.critical('SHOULD NEVER GET HERE!?')
# assert msg is ctx._cancel_msg
# assert error.msgdata == ctx._remote_error.msgdata
# assert error.ipc_msg == ctx._remote_error.ipc_msg
# from .devx._debug import pause
# await pause()
# ctx._maybe_cancel_and_set_remote_error(error)
# ctx._maybe_raise_remote_err(error)
else:
# bubble the original src key error
raise