First draft payload-spec limit API

Add new task-scope oriented `PldRx.pld_spec` management API similar to
`.msg._codec.limit_msg_spec()`, but obvi built to process and filter
`MsgType.pld` values.

New API related changes include:
- new per-task singleton getter `msg._ops.current_pldrx()` which
  delivers the current (global) payload receiver via a new
  `_ctxvar_PldRx: ContextVar` configured with a default
  `_def_any_pldec: MsgDec[Any]` decoder.
- a `PldRx.limit_plds()` which sets the decoder (`.type` underneath)
  for the specific payload rx instance.
- `.msg._ops.limit_plds()` which obtains the current task-scoped `PldRx`
  and applies the pld spec via a new `PldRx.limit_plds()`.
- rename `PldRx._msgdec` -> `._pldec`.
- add `.pld_dec` as pub attr for -^

Unrelated adjustments:
- use `.msg.pretty_struct.pformat()` where handy.
- always pass `expect_msg: MsgType`.
- add a `case Stop()` to `PldRx.dec_msg()` which will `log.warning()`
  when a stop is received by no stream was open on this receiving side
  since we rarely want that to raise since it's prolly just a runtime
  race or mistake in user code.

Other:
runtime_to_msgspec
Tyler Goodlet 2024-04-26 15:29:50 -04:00
parent d285a3479a
commit a3429268ea
1 changed files with 145 additions and 17 deletions

View File

@ -25,12 +25,12 @@ from contextlib import (
# asynccontextmanager as acm,
contextmanager as cm,
)
from pprint import pformat
from contextvars import ContextVar
from typing import (
Any,
Type,
TYPE_CHECKING,
# Union,
Union,
)
# ------ - ------
from msgspec import (
@ -63,7 +63,7 @@ from .types import (
Started,
Stop,
Yield,
# pretty_struct,
pretty_struct,
)
@ -75,6 +75,9 @@ if TYPE_CHECKING:
log = get_logger(__name__)
_def_any_pldec: MsgDec = mk_dec()
class PldRx(Struct):
'''
A "msg payload receiver".
@ -101,10 +104,13 @@ class PldRx(Struct):
'''
# TODO: better to bind it here?
# _rx_mc: trio.MemoryReceiveChannel
_msgdec: MsgDec = mk_dec(spec=Any)
_pldec: MsgDec
_ipc: Context|MsgStream|None = None
@property
def pld_dec(self) -> MsgDec:
return self._pldec
@cm
def apply_to_ipc(
self,
@ -122,9 +128,29 @@ class PldRx(Struct):
finally:
self._ipc = None
@cm
def limit_plds(
self,
spec: Union[Type[Struct]],
) -> MsgDec:
'''
Type-limit the loadable msg payloads via an applied
`MsgDec` given an input spec, revert to prior decoder on
exit.
'''
orig_dec: MsgDec = self._pldec
limit_dec: MsgDec = mk_dec(spec=spec)
try:
self._pldec = limit_dec
yield limit_dec
finally:
self._pldec = orig_dec
@property
def dec(self) -> msgpack.Decoder:
return self._msgdec.dec
return self._pldec.dec
def recv_pld_nowait(
self,
@ -182,7 +208,7 @@ class PldRx(Struct):
self,
msg: MsgType,
ctx: Context,
expect_msg: Type[MsgType]|None = None,
expect_msg: Type[MsgType]|None,
) -> PayloadT|Raw:
'''
@ -199,11 +225,11 @@ class PldRx(Struct):
|Return(pld=pld) # termination phase
):
try:
pld: PayloadT = self._msgdec.decode(pld)
pld: PayloadT = self._pldec.decode(pld)
log.runtime(
'Decode msg payload\n\n'
f'{msg}\n\n'
f'{pld}\n'
'Decoded msg payload\n\n'
f'{msg}\n'
f'|_pld={pld!r}'
)
return pld
@ -237,9 +263,42 @@ class PldRx(Struct):
case Error():
src_err = MessagingError(
'IPC dialog termination by msg'
'IPC ctx dialog terminated without `Return`-ing a result'
)
case Stop(cid=cid):
message: str = (
f'{ctx.side!r}-side of ctx received stream-`Stop` from '
f'{ctx.peer_side!r} peer ?\n'
f'|_cid: {cid}\n\n'
f'{pretty_struct.pformat(msg)}\n'
)
if ctx._stream is None:
explain: str = (
f'BUT, no `MsgStream` (was) open(ed) on this '
f'{ctx.side!r}-side of the IPC ctx?\n'
f'Maybe check your code for streaming phase race conditions?\n'
)
log.warning(
message
+
explain
)
# let caller decide what to do when only one
# side opened a stream, don't raise.
return msg
else:
explain: str = (
'Received a `Stop` when it should NEVER be possible!?!?\n'
)
# TODO: this is constructed inside
# `_raise_from_unexpected_msg()` but maybe we
# should pass it in?
# src_err = trio.EndOfChannel(explain)
src_err = None
case _:
src_err = InternalError(
'Unknown IPC msg ??\n\n'
@ -259,6 +318,7 @@ class PldRx(Struct):
async def recv_msg_w_pld(
self,
ipc: Context|MsgStream,
expect_msg: MsgType,
) -> tuple[MsgType, PayloadT]:
'''
@ -274,10 +334,75 @@ class PldRx(Struct):
pld: PayloadT = self.dec_msg(
msg,
ctx=ipc,
expect_msg=expect_msg,
)
return msg, pld
# Always maintain a task-context-global `PldRx`
_def_pld_rx: PldRx = PldRx(
_pldec=_def_any_pldec,
)
_ctxvar_PldRx: ContextVar[PldRx] = ContextVar(
'pld_rx',
default=_def_pld_rx,
)
def current_pldrx() -> PldRx:
'''
Return the current `trio.Task.context`'s msg-payload
receiver, the post IPC but pre-app code `MsgType.pld`
filter.
Modification of the current payload spec via `limit_plds()`
allows an application to contextually filter typed IPC msg
content delivered via wire transport.
'''
return _ctxvar_PldRx.get()
@cm
def limit_plds(
spec: Union[Type[Struct]],
**kwargs,
) -> MsgDec:
'''
Apply a `MsgCodec` that will natively decode the SC-msg set's
`Msg.pld: Union[Type[Struct]]` payload fields using
tagged-unions of `msgspec.Struct`s from the `payload_types`
for all IPC contexts in use by the current `trio.Task`.
'''
__tracebackhide__: bool = True
try:
# sanity on orig settings
orig_pldrx: PldRx = current_pldrx()
orig_pldec: MsgDec = orig_pldrx.pld_dec
with orig_pldrx.limit_plds(
spec=spec,
**kwargs,
) as pldec:
log.info(
'Applying payload-decoder\n\n'
f'{pldec}\n'
)
yield pldec
finally:
log.info(
'Reverted to previous payload-decoder\n\n'
f'{orig_pldec}\n'
)
assert (
(pldrx := current_pldrx()) is orig_pldrx
and
pldrx.pld_dec is orig_pldec
)
async def drain_to_final_msg(
ctx: Context,
@ -368,7 +493,10 @@ async def drain_to_final_msg(
# pray to the `trio` gawds that we're corrent with this
# msg: dict = await ctx._rx_chan.receive()
msg, pld = await ctx._pld_rx.recv_msg_w_pld(ipc=ctx)
msg, pld = await ctx._pld_rx.recv_msg_w_pld(
ipc=ctx,
expect_msg=Return,
)
# NOTE: we get here if the far end was
# `ContextCancelled` in 2 cases:
@ -399,7 +527,7 @@ async def drain_to_final_msg(
ctx._result: Any = pld
log.runtime(
'Context delivered final draining msg:\n'
f'{pformat(msg)}'
f'{pretty_struct.pformat(msg)}'
)
# XXX: only close the rx mem chan AFTER
# a final result is retreived.
@ -435,7 +563,7 @@ async def drain_to_final_msg(
f'=> {ctx._task}\n'
f' |_{ctx._stream}\n\n'
f'{pformat(msg)}\n'
f'{pretty_struct.pformat(msg)}\n'
)
return (
return_msg,
@ -452,7 +580,7 @@ async def drain_to_final_msg(
f'=> {ctx._task}\n'
f' |_{ctx._stream}\n\n'
f'{pformat(msg)}\n'
f'{pretty_struct.pformat(msg)}\n'
)
continue
@ -467,7 +595,7 @@ async def drain_to_final_msg(
pre_result_drained.append(msg)
log.cancel(
'Remote stream terminated due to "stop" msg:\n\n'
f'{pformat(msg)}\n'
f'{pretty_struct.pformat(msg)}\n'
)
continue