From 5449bd567309b943a0f3317aabf4a9cb95cb0fd3 Mon Sep 17 00:00:00 2001 From: Tyler Goodlet Date: Sun, 16 Jun 2024 22:50:43 -0400 Subject: [PATCH] Offer a `@context(pld_spec=)` API Instead of the WIP/prototyped `Portal.open_context()` offering a `pld_spec` input arg, this changes to a proper decorator API for specifying the "payload spec" on `@context` endpoints. The impl change details actually cover 2-birds: - monkey patch decorated functions with a new `._tractor_context_meta: dict[str, Any]` and insert any provided input `@context` kwargs: `_pld_spec`, `enc_hook`, `enc_hook`. - use `inspect.get_annotations()` to scan for a `func` arg type-annotated with `tractor.Context` and use the name of that arg as the RPC task-side injected `Context`, thus injecting the needed arg by type instead of by name (a longstanding TODO); raise a type-error when not found. - pull the `pld_spec` from the `._tractor_context_meta` attr both in the `.open_context()` parent-side and child-side `._invoke()`-cation of the RPC task and use the `msg._ops.maybe_limit_plds()` API to apply it internally in the runtime for each case. --- tractor/_context.py | 89 +++++++++++++++++++++++++++++++++++---------- tractor/_rpc.py | 25 +++++++++++-- 2 files changed, 92 insertions(+), 22 deletions(-) diff --git a/tractor/_context.py b/tractor/_context.py index dd14361..f5d9d69 100644 --- a/tractor/_context.py +++ b/tractor/_context.py @@ -1792,7 +1792,6 @@ async def open_context_from_portal( portal: Portal, func: Callable, - pld_spec: TypeAlias|None = None, allow_overruns: bool = False, hide_tb: bool = True, @@ -1838,12 +1837,20 @@ async def open_context_from_portal( # NOTE: 2 bc of the wrapping `@acm` __runtimeframe__: int = 2 # noqa - # conduct target func method structural checks - if not inspect.iscoroutinefunction(func) and ( - getattr(func, '_tractor_contex_function', False) + # if NOT an async func but decorated with `@context`, error. + if ( + not inspect.iscoroutinefunction(func) + and getattr(func, '_tractor_context_meta', False) ): raise TypeError( - f'{func} must be an async generator function!') + f'{func!r} must be an async function!' + ) + + ctx_meta: dict[str, Any]|None = getattr( + func, + '_tractor_context_meta', + None, + ) # TODO: i think from here onward should probably # just be factored into an `@acm` inside a new @@ -1890,12 +1897,9 @@ async def open_context_from_portal( trio.open_nursery() as tn, msgops.maybe_limit_plds( ctx=ctx, - spec=pld_spec, - ) as maybe_msgdec, + spec=ctx_meta.get('pld_spec'), + ), ): - if maybe_msgdec: - assert maybe_msgdec.pld_spec == pld_spec - # NOTE: this in an implicit runtime nursery used to, # - start overrun queuing tasks when as well as # for cancellation of the scope opened by the user. @@ -2398,7 +2402,15 @@ def mk_context( # a `contextlib.ContextDecorator`? # def context( - func: Callable, + func: Callable|None = None, + + *, + + # must be named! + pld_spec: Union[Type]|TypeAlias = Any, + dec_hook: Callable|None = None, + enc_hook: Callable|None = None, + ) -> Callable: ''' Mark an async function as an SC-supervised, inter-`Actor`, RPC @@ -2409,15 +2421,54 @@ def context( `tractor`. ''' + # XXX for the `@context(pld_spec=MyMsg|None)` case + if func is None: + return partial( + context, + pld_spec=pld_spec, + dec_hook=dec_hook, + enc_hook=enc_hook, + ) + + # TODO: from this, enforcing a `Start.sig` type + # check when invoking RPC tasks by ensuring the input + # args validate against the endpoint def. + sig: inspect.Signature = inspect.signature(func) + # params: inspect.Parameters = sig.parameters + + # https://docs.python.org/3/library/inspect.html#inspect.get_annotations + annots: dict[str, Type] = inspect.get_annotations( + func, + eval_str=True, + ) + name: str + param: Type + for name, param in annots.items(): + if param is Context: + ctx_var_name: str = name + break + else: + raise TypeError( + 'At least one (normally the first) argument to the `@context` function ' + f'{func.__name__!r} must be typed as `tractor.Context`, for ex,\n\n' + f'`ctx: tractor.Context`\n' + ) + # TODO: apply whatever solution ``mypy`` ends up picking for this: # https://github.com/python/mypy/issues/2087#issuecomment-769266912 - func._tractor_context_function = True # type: ignore + # func._tractor_context_function = True # type: ignore + func._tractor_context_meta: dict[str, Any] = { + 'ctx_var_name': ctx_var_name, + # `msgspec` related settings + 'pld_spec': pld_spec, + 'enc_hook': enc_hook, + 'dec_hook': dec_hook, - sig: inspect.Signature = inspect.signature(func) - params: Mapping = sig.parameters - if 'ctx' not in params: - raise TypeError( - "The first argument to the context function " - f"{func.__name__} must be `ctx: tractor.Context`" - ) + # TODO: eventually we need to "signature-check" with these + # vs. the `Start` msg fields! + # => this would allow for TPC endpoint argument-type-spec + # limiting and we could then error on + # invalid inputs passed to `.open_context(rpc_ep, arg0='blah')` + 'sig': sig, + } return func diff --git a/tractor/_rpc.py b/tractor/_rpc.py index c9eb845..166ee96 100644 --- a/tractor/_rpc.py +++ b/tractor/_rpc.py @@ -69,6 +69,7 @@ from .msg import ( PayloadT, NamespacePath, pretty_struct, + _ops as msgops, ) from tractor.msg.types import ( CancelAck, @@ -500,8 +501,19 @@ async def _invoke( # handle decorated ``@tractor.context`` async function - elif getattr(func, '_tractor_context_function', False): - kwargs['ctx'] = ctx + # - pull out any typed-pld-spec info and apply (below) + # - (TODO) store func-ref meta data for API-frame-info logging + elif ( + ctx_meta := getattr( + func, + '_tractor_context_meta', + False, + ) + ): + # kwargs['ctx'] = ctx + # set the required `tractor.Context` typed input argument to + # the allocated RPC task context. + kwargs[ctx_meta['ctx_var_name']] = ctx context_ep_func = True # errors raised inside this block are propgated back to caller @@ -595,7 +607,14 @@ async def _invoke( # `@context` marked RPC function. # - `._portal` is never set. try: - async with trio.open_nursery() as tn: + async with ( + trio.open_nursery() as tn, + msgops.maybe_limit_plds( + ctx=ctx, + spec=ctx_meta.get('pld_spec'), + dec_hook=ctx_meta.get('dec_hook'), + ), + ): ctx._scope_nursery = tn ctx._scope = tn.cancel_scope task_status.started(ctx)