317 lines
7.9 KiB
Python
317 lines
7.9 KiB
Python
|
'''
|
||
|
Audit sub-sys APIs from `.msg._ops`
|
||
|
mostly for ensuring correct `contextvars`
|
||
|
related settings around IPC contexts.
|
||
|
|
||
|
'''
|
||
|
from contextlib import (
|
||
|
asynccontextmanager as acm,
|
||
|
contextmanager as cm,
|
||
|
)
|
||
|
# import typing
|
||
|
from typing import (
|
||
|
# Any,
|
||
|
TypeAlias,
|
||
|
# Union,
|
||
|
)
|
||
|
from contextvars import (
|
||
|
Context,
|
||
|
)
|
||
|
|
||
|
from msgspec import (
|
||
|
# structs,
|
||
|
# msgpack,
|
||
|
Struct,
|
||
|
# ValidationError,
|
||
|
)
|
||
|
import pytest
|
||
|
import trio
|
||
|
|
||
|
import tractor
|
||
|
from tractor import (
|
||
|
# _state,
|
||
|
MsgTypeError,
|
||
|
current_ipc_ctx,
|
||
|
Portal,
|
||
|
)
|
||
|
from tractor.msg import (
|
||
|
_ops as msgops,
|
||
|
Return,
|
||
|
)
|
||
|
from tractor.msg import (
|
||
|
_codec,
|
||
|
# _ctxvar_MsgCodec,
|
||
|
|
||
|
# NamespacePath,
|
||
|
# MsgCodec,
|
||
|
# mk_codec,
|
||
|
# apply_codec,
|
||
|
# current_codec,
|
||
|
)
|
||
|
from tractor.msg.types import (
|
||
|
log,
|
||
|
# _payload_msgs,
|
||
|
# PayloadMsg,
|
||
|
# Started,
|
||
|
# mk_msg_spec,
|
||
|
)
|
||
|
|
||
|
|
||
|
class PldMsg(Struct):
    '''
    Minimal payload type for these tests: a `msgspec.Struct`
    carrying a single string `field`. Together with `None` it
    makes up the allowed payload spec (`maybe_msg_spec`).

    '''
    field: str
|
||
|
|
||
|
|
||
|
# The payload-type spec applied on both sides of the IPC context in
# this module: a `.pld` must decode as either a `PldMsg` or `None`.
maybe_msg_spec: TypeAlias = PldMsg | None
|
||
|
|
||
|
|
||
|
@cm
def custom_spec(
    ctx: tractor.Context,
    spec: TypeAlias,
) -> _codec.MsgDec:
    '''
    Apply a custom payload spec, remove on exit.

    Delegates to `msgops.limit_plds(spec=spec)` and yields the
    limited payload decoder. It sanity-checks that the decoder is
    installed on the given `ctx`'s payload receiver while inside the
    block, and that the prior decoder is restored on exit.

    NOTE(review): `limit_plds()` is called without a `ctx` (matching
    its use in `child()`), so this presumably only works when `ctx`
    is the *current* IPC context - TODO confirm.

    '''
    rx: msgops.PldRx = ctx._pld_rx
    orig_pldec: _codec.MsgDec = rx.pld_dec

    with msgops.limit_plds(spec=spec) as pldec:
        # sanity: the limited decoder was applied to *this* ctx.
        assert rx.pld_dec is pldec
        yield pldec

    # sanity: the previous decoder is reverted on (clean) exit.
    assert rx.pld_dec is orig_pldec
|
||
|
|
||
|
|
||
|
@acm
async def maybe_expect_raises(
    raises: type[BaseException]|None = None,
    ensure_in_message: list[str]|None = None,

    reraise: bool = False,
    timeout: int = 3,
) -> None:
    '''
    Async wrapper for ensuring errors propagate from the inner scope.

    - `raises`: the exact exception type the inner block must raise,
      or `None` if it must not raise (any error is then re-raised
      unchanged).
    - `ensure_in_message`: sub-strings which must each appear in (the
      `str()` of) at least one of the raised error's `.args`, else a
      `ValueError` is raised.
    - `reraise`: re-raise the expected error after checking it
      instead of suppressing it.
    - `timeout`: seconds allowed for the inner block, enforced via
      `trio.fail_after()`.

    Raises `RuntimeError` if `raises` is set but nothing was raised.

    '''
    with trio.fail_after(timeout):
        try:
            yield

        except BaseException as inner_err:
            # wasn't-expected to error..
            if raises is None:
                raise

            # never morph a cancellation (eg. from our own
            # `fail_after()` timeout) into an assertion error; let it
            # unwind so the real cause (`trio.TooSlowError`) surfaces.
            if (
                isinstance(inner_err, trio.Cancelled)
                and
                type(inner_err) is not raises
            ):
                raise

            assert type(inner_err) is raises, (
                f'Expected a {raises.__name__!r} to be raised, '
                f'got a {type(inner_err).__name__!r} instead?\n\n'
                f'{inner_err!r}'
            )

            # maybe check for error txt content: every part must be
            # a sub-string of at least one arg; non-`str` args are
            # compared via their `str()` instead of crashing `in`.
            err_args: tuple = inner_err.args
            for part in ensure_in_message or ():
                if not any(
                    part in str(arg)
                    for arg in err_args
                ):
                    raise ValueError(
                        'Failed to find error message content?\n\n'
                        f'expected: {ensure_in_message!r}\n'
                        f'part: {part!r}\n\n'
                        f'{inner_err.args}'
                    )

            if reraise:
                raise

        else:
            if raises:
                raise RuntimeError(
                    f'Expected a {raises.__name__!r} to be raised?'
                )
|
||
|
|
||
|
|
||
|
@tractor.context
async def child(
    ctx: tractor.Context,
    started_value: int|PldMsg|None,
    return_value: str|None,
    validate_pld_spec: bool,
    raise_on_started_mte: bool = True,

) -> None:
    '''
    Child-side IPC endpoint: limit the payload spec to
    `maybe_msg_spec`, send `started_value` via `Context.started()`,
    then return `return_value`.

    `started_value == 10` is the designated invalid payload. With
    `validate_pld_spec=True` the send-side `.started()` call must
    raise a `MsgTypeError` for it. That error is re-raised when
    `raise_on_started_mte` is set. Without send-side validation, the
    bad value is sent as-is and only the parent's recv side can
    reject it.

    '''
    expect_started_mte: bool = started_value == 10

    # sanity check that child RPC context is the current one
    curr_ctx: tractor.Context = current_ipc_ctx()
    assert ctx is curr_ctx

    rx: msgops.PldRx = ctx._pld_rx
    orig_pldec: _codec.MsgDec = rx.pld_dec
    # sanity that the default pld-spec should be set
    assert (
        rx.pld_dec
        is
        msgops._def_any_pldec
    )

    try:
        with msgops.limit_plds(
            spec=maybe_msg_spec,
        ) as pldec:
            # sanity on `MsgDec` state
            assert rx.pld_dec is pldec
            assert pldec.spec is maybe_msg_spec

            # 2 cases: handle send-side and recv-only validation
            # - when `validate_pld_spec == True`, send-side validate
            #   (and propagate the MTE if `raise_on_started_mte`),
            # - else, parent-recv-side only validation.
            try:
                await ctx.started(
                    value=started_value,
                    validate_pld_spec=validate_pld_spec,
                )

            except MsgTypeError:
                log.exception('started()` raised an MTE!\n')
                if not expect_started_mte:
                    raise RuntimeError(
                        'Child-ctx-task SHOULD NOT HAVE raised an MTE for\n\n'
                        f'{started_value!r}\n'
                    )

                # propagate to parent?
                if raise_on_started_mte:
                    raise

            else:
                # only send-side validation can raise for a bad
                # payload; when it's disabled the invalid value is
                # sent as-is and the parent's recv side is expected
                # to raise instead, so no MTE here is correct.
                if (
                    validate_pld_spec
                    and
                    expect_started_mte
                ):
                    raise RuntimeError(
                        'Child-ctx-task SHOULD HAVE raised an MTE for\n\n'
                        f'{started_value!r}\n'
                    )

            # XXX should always fail on recv side since we can't
            # really do much else beside terminate and relay the
            # msg-type-error from this RPC task ;)
            return return_value

    finally:
        # sanity on `limit_plds()` reversion
        assert (
            rx.pld_dec
            is
            msgops._def_any_pldec
        )
        log.runtime(
            'Reverted to previous pld-spec\n\n'
            f'{orig_pldec}\n'
        )
|
||
|
|
||
|
|
||
|
@pytest.mark.parametrize(
    'return_value',
    [
        None,
        'yo',
    ],
    # NOTE: ids must line up index-for-index with the values above.
    ids=[
        'return[valid-None]',
        'return[invalid-"yo"]',
    ],
)
@pytest.mark.parametrize(
    'started_value',
    [
        10,
        PldMsg(field='yo'),
    ],
    ids=[
        'Started[invalid-10]',
        'Started[valid-PldMsg]',
    ],
)
@pytest.mark.parametrize(
    'pld_check_started_value',
    [
        True,
        False,
    ],
    ids=[
        'check-started-pld',
        'no-started-pld-validate',
    ],
)
def test_basic_payload_spec(
    debug_mode: bool,
    loglevel: str,
    return_value: str|None,
    started_value: int|PldMsg,
    pld_check_started_value: bool,
):
    '''
    Validate the most basic `PldRx` msg-type-spec semantics around
    a IPC `Context` endpoint start, started-sync, and final return
    value depending on set payload types and the currently applied
    pld-spec.

    '''
    invalid_return: bool = return_value == 'yo'
    invalid_started: bool = started_value == 10

    # an invalid `Started` payload fails the ctx before any `Return`
    # is processed, so only check for `Return`-specific error content
    # when a bad return value is the (only) expected failure.
    expect_return_mte: bool = (
        invalid_return
        and
        not invalid_started
    )
    ensure_in_message: list[str]|None = [
        "invalid `Return` payload",
        "value: `'yo'` does not match type-spec: `Return.pld: PldMsg|NoneType`",
    ] if expect_return_mte else None

    async def main():
        async with tractor.open_nursery(
            debug_mode=debug_mode,
            loglevel=loglevel,
        ) as an:
            p: Portal = await an.start_actor(
                'child',
                enable_modules=[__name__],
            )

            # since not opened yet.
            assert current_ipc_ctx() is None

            async with (
                maybe_expect_raises(
                    raises=MsgTypeError if (
                        invalid_return
                        or
                        invalid_started
                    ) else None,
                    ensure_in_message=ensure_in_message,
                ),
                p.open_context(
                    child,
                    return_value=return_value,
                    started_value=started_value,
                    pld_spec=maybe_msg_spec,
                    validate_pld_spec=pld_check_started_value,
                ) as (ctx, first),
            ):
                # now opened with 'child' sub
                assert current_ipc_ctx() is ctx

                assert type(first) is PldMsg
                assert first.field == 'yo'

                try:
                    assert (await ctx.result()) is None
                except MsgTypeError as mte:
                    if not invalid_return:
                        raise

                    # expected this invalid `Return.pld`
                    assert mte.cid == ctx.cid

                    # verify expected remote mte deats
                    # (no interactive `tractor.pause()` breakpoint
                    # here, it would block non-debug test runs).
                    assert ctx._remote_error is mte
                    assert mte.expected_msg_type is Return

            await p.cancel_actor()

    trio.run(main)
|