forked from goodboy/tractor
First draft payload-spec limit API
Add new task-scope oriented `PldRx.pld_spec` management API similar to `.msg._codec.limit_msg_spec()`, but obvi built to process and filter `MsgType.pld` values. New API related changes include: - new per-task singleton getter `msg._ops.current_pldrx()` which delivers the current (global) payload receiver via a new `_ctxvar_PldRx: ContextVar` configured with a default `_def_any_pldec: MsgDec[Any]` decoder. - a `PldRx.limit_plds()` which sets the decoder (`.type` underneath) for the specific payload rx instance. - `.msg._ops.limit_plds()` which obtains the current task-scoped `PldRx` and applies the pld spec via a new `PldRx.limit_plds()`. - rename `PldRx._msgdec` -> `._pldec`. - add `.pld_dec` as pub attr for -^ Unrelated adjustments: - use `.msg.pretty_struct.pformat()` where handy. - always pass `expect_msg: MsgType`. - add a `case Stop()` to `PldRx.dec_msg()` which will `log.warning()` when a stop is received by no stream was open on this receiving side since we rarely want that to raise since it's prolly just a runtime race or mistake in user code. Other:runtime_to_msgspec
parent
d285a3479a
commit
a3429268ea
tractor/msg
|
@ -25,12 +25,12 @@ from contextlib import (
|
|||
# asynccontextmanager as acm,
|
||||
contextmanager as cm,
|
||||
)
|
||||
from pprint import pformat
|
||||
from contextvars import ContextVar
|
||||
from typing import (
|
||||
Any,
|
||||
Type,
|
||||
TYPE_CHECKING,
|
||||
# Union,
|
||||
Union,
|
||||
)
|
||||
# ------ - ------
|
||||
from msgspec import (
|
||||
|
@ -63,7 +63,7 @@ from .types import (
|
|||
Started,
|
||||
Stop,
|
||||
Yield,
|
||||
# pretty_struct,
|
||||
pretty_struct,
|
||||
)
|
||||
|
||||
|
||||
|
@ -75,6 +75,9 @@ if TYPE_CHECKING:
|
|||
log = get_logger(__name__)
|
||||
|
||||
|
||||
_def_any_pldec: MsgDec = mk_dec()
|
||||
|
||||
|
||||
class PldRx(Struct):
|
||||
'''
|
||||
A "msg payload receiver".
|
||||
|
@ -101,10 +104,13 @@ class PldRx(Struct):
|
|||
'''
|
||||
# TODO: better to bind it here?
|
||||
# _rx_mc: trio.MemoryReceiveChannel
|
||||
_msgdec: MsgDec = mk_dec(spec=Any)
|
||||
|
||||
_pldec: MsgDec
|
||||
_ipc: Context|MsgStream|None = None
|
||||
|
||||
@property
|
||||
def pld_dec(self) -> MsgDec:
|
||||
return self._pldec
|
||||
|
||||
@cm
|
||||
def apply_to_ipc(
|
||||
self,
|
||||
|
@ -122,9 +128,29 @@ class PldRx(Struct):
|
|||
finally:
|
||||
self._ipc = None
|
||||
|
||||
@cm
|
||||
def limit_plds(
|
||||
self,
|
||||
spec: Union[Type[Struct]],
|
||||
|
||||
) -> MsgDec:
|
||||
'''
|
||||
Type-limit the loadable msg payloads via an applied
|
||||
`MsgDec` given an input spec, revert to prior decoder on
|
||||
exit.
|
||||
|
||||
'''
|
||||
orig_dec: MsgDec = self._pldec
|
||||
limit_dec: MsgDec = mk_dec(spec=spec)
|
||||
try:
|
||||
self._pldec = limit_dec
|
||||
yield limit_dec
|
||||
finally:
|
||||
self._pldec = orig_dec
|
||||
|
||||
@property
|
||||
def dec(self) -> msgpack.Decoder:
|
||||
return self._msgdec.dec
|
||||
return self._pldec.dec
|
||||
|
||||
def recv_pld_nowait(
|
||||
self,
|
||||
|
@ -182,7 +208,7 @@ class PldRx(Struct):
|
|||
self,
|
||||
msg: MsgType,
|
||||
ctx: Context,
|
||||
expect_msg: Type[MsgType]|None = None,
|
||||
expect_msg: Type[MsgType]|None,
|
||||
|
||||
) -> PayloadT|Raw:
|
||||
'''
|
||||
|
@ -199,11 +225,11 @@ class PldRx(Struct):
|
|||
|Return(pld=pld) # termination phase
|
||||
):
|
||||
try:
|
||||
pld: PayloadT = self._msgdec.decode(pld)
|
||||
pld: PayloadT = self._pldec.decode(pld)
|
||||
log.runtime(
|
||||
'Decode msg payload\n\n'
|
||||
f'{msg}\n\n'
|
||||
f'{pld}\n'
|
||||
'Decoded msg payload\n\n'
|
||||
f'{msg}\n'
|
||||
f'|_pld={pld!r}'
|
||||
)
|
||||
return pld
|
||||
|
||||
|
@ -237,9 +263,42 @@ class PldRx(Struct):
|
|||
|
||||
case Error():
|
||||
src_err = MessagingError(
|
||||
'IPC dialog termination by msg'
|
||||
'IPC ctx dialog terminated without `Return`-ing a result'
|
||||
)
|
||||
|
||||
case Stop(cid=cid):
|
||||
message: str = (
|
||||
f'{ctx.side!r}-side of ctx received stream-`Stop` from '
|
||||
f'{ctx.peer_side!r} peer ?\n'
|
||||
f'|_cid: {cid}\n\n'
|
||||
|
||||
f'{pretty_struct.pformat(msg)}\n'
|
||||
)
|
||||
if ctx._stream is None:
|
||||
explain: str = (
|
||||
f'BUT, no `MsgStream` (was) open(ed) on this '
|
||||
f'{ctx.side!r}-side of the IPC ctx?\n'
|
||||
f'Maybe check your code for streaming phase race conditions?\n'
|
||||
)
|
||||
log.warning(
|
||||
message
|
||||
+
|
||||
explain
|
||||
)
|
||||
# let caller decide what to do when only one
|
||||
# side opened a stream, don't raise.
|
||||
return msg
|
||||
|
||||
else:
|
||||
explain: str = (
|
||||
'Received a `Stop` when it should NEVER be possible!?!?\n'
|
||||
)
|
||||
# TODO: this is constructed inside
|
||||
# `_raise_from_unexpected_msg()` but maybe we
|
||||
# should pass it in?
|
||||
# src_err = trio.EndOfChannel(explain)
|
||||
src_err = None
|
||||
|
||||
case _:
|
||||
src_err = InternalError(
|
||||
'Unknown IPC msg ??\n\n'
|
||||
|
@ -259,6 +318,7 @@ class PldRx(Struct):
|
|||
async def recv_msg_w_pld(
|
||||
self,
|
||||
ipc: Context|MsgStream,
|
||||
expect_msg: MsgType,
|
||||
|
||||
) -> tuple[MsgType, PayloadT]:
|
||||
'''
|
||||
|
@ -274,10 +334,75 @@ class PldRx(Struct):
|
|||
pld: PayloadT = self.dec_msg(
|
||||
msg,
|
||||
ctx=ipc,
|
||||
expect_msg=expect_msg,
|
||||
)
|
||||
return msg, pld
|
||||
|
||||
|
||||
# Always maintain a task-context-global `PldRx`
|
||||
_def_pld_rx: PldRx = PldRx(
|
||||
_pldec=_def_any_pldec,
|
||||
)
|
||||
_ctxvar_PldRx: ContextVar[PldRx] = ContextVar(
|
||||
'pld_rx',
|
||||
default=_def_pld_rx,
|
||||
)
|
||||
|
||||
|
||||
def current_pldrx() -> PldRx:
|
||||
'''
|
||||
Return the current `trio.Task.context`'s msg-payload
|
||||
receiver, the post IPC but pre-app code `MsgType.pld`
|
||||
filter.
|
||||
|
||||
Modification of the current payload spec via `limit_plds()`
|
||||
allows an application to contextually filter typed IPC msg
|
||||
content delivered via wire transport.
|
||||
|
||||
'''
|
||||
return _ctxvar_PldRx.get()
|
||||
|
||||
|
||||
@cm
|
||||
def limit_plds(
|
||||
spec: Union[Type[Struct]],
|
||||
**kwargs,
|
||||
|
||||
) -> MsgDec:
|
||||
'''
|
||||
Apply a `MsgCodec` that will natively decode the SC-msg set's
|
||||
`Msg.pld: Union[Type[Struct]]` payload fields using
|
||||
tagged-unions of `msgspec.Struct`s from the `payload_types`
|
||||
for all IPC contexts in use by the current `trio.Task`.
|
||||
|
||||
'''
|
||||
__tracebackhide__: bool = True
|
||||
try:
|
||||
# sanity on orig settings
|
||||
orig_pldrx: PldRx = current_pldrx()
|
||||
orig_pldec: MsgDec = orig_pldrx.pld_dec
|
||||
|
||||
with orig_pldrx.limit_plds(
|
||||
spec=spec,
|
||||
**kwargs,
|
||||
) as pldec:
|
||||
log.info(
|
||||
'Applying payload-decoder\n\n'
|
||||
f'{pldec}\n'
|
||||
)
|
||||
yield pldec
|
||||
finally:
|
||||
log.info(
|
||||
'Reverted to previous payload-decoder\n\n'
|
||||
f'{orig_pldec}\n'
|
||||
)
|
||||
assert (
|
||||
(pldrx := current_pldrx()) is orig_pldrx
|
||||
and
|
||||
pldrx.pld_dec is orig_pldec
|
||||
)
|
||||
|
||||
|
||||
async def drain_to_final_msg(
|
||||
ctx: Context,
|
||||
|
||||
|
@ -368,7 +493,10 @@ async def drain_to_final_msg(
|
|||
|
||||
# pray to the `trio` gawds that we're corrent with this
|
||||
# msg: dict = await ctx._rx_chan.receive()
|
||||
msg, pld = await ctx._pld_rx.recv_msg_w_pld(ipc=ctx)
|
||||
msg, pld = await ctx._pld_rx.recv_msg_w_pld(
|
||||
ipc=ctx,
|
||||
expect_msg=Return,
|
||||
)
|
||||
|
||||
# NOTE: we get here if the far end was
|
||||
# `ContextCancelled` in 2 cases:
|
||||
|
@ -399,7 +527,7 @@ async def drain_to_final_msg(
|
|||
ctx._result: Any = pld
|
||||
log.runtime(
|
||||
'Context delivered final draining msg:\n'
|
||||
f'{pformat(msg)}'
|
||||
f'{pretty_struct.pformat(msg)}'
|
||||
)
|
||||
# XXX: only close the rx mem chan AFTER
|
||||
# a final result is retreived.
|
||||
|
@ -435,7 +563,7 @@ async def drain_to_final_msg(
|
|||
f'=> {ctx._task}\n'
|
||||
f' |_{ctx._stream}\n\n'
|
||||
|
||||
f'{pformat(msg)}\n'
|
||||
f'{pretty_struct.pformat(msg)}\n'
|
||||
)
|
||||
return (
|
||||
return_msg,
|
||||
|
@ -452,7 +580,7 @@ async def drain_to_final_msg(
|
|||
f'=> {ctx._task}\n'
|
||||
f' |_{ctx._stream}\n\n'
|
||||
|
||||
f'{pformat(msg)}\n'
|
||||
f'{pretty_struct.pformat(msg)}\n'
|
||||
)
|
||||
continue
|
||||
|
||||
|
@ -467,7 +595,7 @@ async def drain_to_final_msg(
|
|||
pre_result_drained.append(msg)
|
||||
log.cancel(
|
||||
'Remote stream terminated due to "stop" msg:\n\n'
|
||||
f'{pformat(msg)}\n'
|
||||
f'{pretty_struct.pformat(msg)}\n'
|
||||
)
|
||||
continue
|
||||
|
||||
|
|
Loading…
Reference in New Issue