From b5bdd20eb566c8aca0e37fd4316c5fcf857f7afa Mon Sep 17 00:00:00 2001
From: Tyler Goodlet <jgbt@protonmail.com>
Date: Tue, 2 Apr 2024 11:14:43 -0400
Subject: [PATCH] Get `test_codec_hooks_mod` working with `Msg`s

Though the runtime hasn't been changed over in this patch (it was in the
local index at the time however), the test does now demonstrate that
using a `Started` the correctly typed `.pld` will codec correctly when
passed manually to `MsgCodec.encode/decode()`.

Despite not having the runtime ported to the new shuttle msg set
(meaning the mentioned test will fail without the runtime port patch),
I was able to get this first original test working that limits payload
packets as a `Msg.pld: NamespacePath`this as long as we spec
`enc/dec_hook()`s then the `Msg.pld` will be processed correctly as per:
https://jcristharif.com/msgspec/extending.html#mapping-to-from-native-types
in both the `Any` and `NamespacePath|None` spec cases.
^- turns out in this case -^ that the codec hooks only get invoked on
the unknown-fields NOT the entire `Struct`-msg.

A further gotcha was merging a `|None` into the `pld_spec` since this
test spawns a subactor and opens a context via `send_back_nsp()` and
that func has no explicit `return` - so of course it delivers
a `Return(pld=None)` which will fail if we only spec `NamespacePath`.
---
 tests/test_caps_based_msging.py | 305 ++++++++++++++++++++++++--------
 1 file changed, 236 insertions(+), 69 deletions(-)

diff --git a/tests/test_caps_based_msging.py b/tests/test_caps_based_msging.py
index abdda0a5..b42d9e35 100644
--- a/tests/test_caps_based_msging.py
+++ b/tests/test_caps_based_msging.py
@@ -7,7 +7,6 @@ B~)
 '''
 from typing import (
     Any,
-    _GenericAlias,
     Type,
     Union,
 )
@@ -26,20 +25,23 @@ from msgspec import (
 import pytest
 import tractor
 from tractor.msg import (
-    _def_msgspec_codec,
+    _codec,
     _ctxvar_MsgCodec,
 
     NamespacePath,
     MsgCodec,
     mk_codec,
     apply_codec,
-    current_msgspec_codec,
+    current_codec,
 )
-from tractor.msg import types
+from tractor.msg import (
+    types,
+)
+from tractor import _state
 from tractor.msg.types import (
     # PayloadT,
     Msg,
-    # Started,
+    Started,
     mk_msg_spec,
 )
 import trio
@@ -60,56 +62,110 @@ def test_msg_spec_xor_pld_spec():
         )
 
 
-# TODO: wrap these into `._codec` such that user can just pass
-# a type table of some sort?
-def enc_hook(obj: Any) -> Any:
-    if isinstance(obj, NamespacePath):
-        return str(obj)
-    else:
-        raise NotImplementedError(
-            f'Objects of type {type(obj)} are not supported'
-        )
-
-
-def dec_hook(type: Type, obj: Any) -> Any:
-    print(f'type is: {type}')
-    if type is NamespacePath:
-        return NamespacePath(obj)
-    else:
-        raise NotImplementedError(
-            f'Objects of type {type(obj)} are not supported'
-        )
-
-
 def ex_func(*args):
     print(f'ex_func({args})')
 
 
 def mk_custom_codec(
-    ipc_msg_spec: Type[Any] = Any,
-) -> MsgCodec:
-    # apply custom hooks and set a `Decoder` which only
-    # loads `NamespacePath` types.
-    nsp_codec: MsgCodec = mk_codec(
-        ipc_msg_spec=ipc_msg_spec,
-        enc_hook=enc_hook,
-        dec_hook=dec_hook,
-    )
+    pld_spec: Union[Type]|Any,
 
-    # TODO: validate `MsgCodec` interface/semantics?
-    # -[ ] simple field tests to ensure caching + reset is workin?
-    # -[ ] custom / changing `.decoder()` calls?
-    #
-    # dec = nsp_codec.decoder(
-    #     types=NamespacePath,
-    # )
-    # assert nsp_codec.dec is dec
+) -> MsgCodec:
+    '''
+    Create custom `msgpack` enc/dec-hooks and set a `Decoder`
+    which only loads `NamespacePath` types.
+
+    '''
+    uid: tuple[str, str] = tractor.current_actor().uid
+
+    # XXX NOTE XXX: despite defining `NamespacePath` as a type
+    # field on our `Msg.pld`, we still need a enc/dec_hook() pair
+    # to cast to/from that type on the wire. See the docs:
+    # https://jcristharif.com/msgspec/extending.html#mapping-to-from-native-types
+
+    def enc_nsp(obj: Any) -> Any:
+        match obj:
+            case NamespacePath():
+                print(
+                    f'{uid}: `NamespacePath`-Only ENCODE?\n'
+                    f'type: {type(obj)}\n'
+                    f'obj: {obj}\n'
+                )
+
+                return str(obj)
+
+        logmsg: str = (
+            f'{uid}: Encoding `{obj}: <{type(obj)}>` not supported'
+            f'type: {type(obj)}\n'
+            f'obj: {obj}\n'
+        )
+        print(logmsg)
+        raise NotImplementedError(logmsg)
+
+    def dec_nsp(
+        type: Type,
+        obj: Any,
+
+    ) -> Any:
+        print(
+            f'{uid}: CUSTOM DECODE\n'
+            f'input type: {type}\n'
+            f'obj: {obj}\n'
+            f'type(obj): `{type(obj).__class__}`\n'
+        )
+        nsp = None
+
+        # This never seems to hit?
+        if isinstance(obj, Msg):
+            print(f'Msg type: {obj}')
+
+        if (
+            type is NamespacePath
+            and isinstance(obj, str)
+            and ':' in obj
+        ):
+            nsp = NamespacePath(obj)
+
+        if nsp:
+            print(f'Returning NSP instance: {nsp}')
+            return nsp
+
+        logmsg: str = (
+            f'{uid}: Decoding `{obj}: <{type(obj)}>` not supported'
+            f'input type: {type(obj)}\n'
+            f'obj: {obj}\n'
+            f'type(obj): `{type(obj).__class__}`\n'
+        )
+        print(logmsg)
+        raise NotImplementedError(logmsg)
+
+
+    nsp_codec: MsgCodec = mk_codec(
+        ipc_pld_spec=pld_spec,
+
+        # NOTE XXX: the encode hook MUST be used no matter what since
+        # our `NamespacePath` is not any of a `Any` native type nor
+        # a `msgspec.Struct` subtype - so `msgspec` has no way to know
+        # how to encode it unless we provide the custom hook.
+        #
+        # AGAIN that is, regardless of whether we spec an
+        # `Any`-decoded-pld the enc has no knowledge (by default)
+        # how to enc `NamespacePath` (nsp), so we add a custom
+        # hook to do that ALWAYS.
+        enc_hook=enc_nsp,
+
+        # XXX NOTE: pretty sure this is mutex with the `type=` to
+        # `Decoder`? so it won't work in tandem with the
+        # `ipc_pld_spec` passed above?
+        dec_hook=dec_nsp,
+    )
     return nsp_codec
 
 
 @tractor.context
 async def send_back_nsp(
-    ctx: tractor.Context,
+    ctx: Context,
+    expect_debug: bool,
+    use_any_spec: bool,
 
 ) -> None:
     '''
@@ -117,28 +173,65 @@ async def send_back_nsp(
     and ensure we can round trip a func ref with our parent.
 
     '''
-    task: trio.Task = trio.lowlevel.current_task()
-    task_ctx: Context = task.context
-    assert _ctxvar_MsgCodec not in task_ctx
+    # debug mode sanity check
+    assert expect_debug == _state.debug_mode()
 
-    nsp_codec: MsgCodec = mk_custom_codec()
+    # task: trio.Task = trio.lowlevel.current_task()
+
+    # TreeVar
+    # curr_codec = _ctxvar_MsgCodec.get_in(task)
+
+    # ContextVar
+    # task_ctx: Context = task.context
+    # assert _ctxvar_MsgCodec not in task_ctx
+
+    curr_codec = _ctxvar_MsgCodec.get()
+    assert curr_codec is _codec._def_tractor_codec
+
+    if use_any_spec:
+        pld_spec = Any
+    else:
+        # NOTE: don't need the |None here since
+        # the parent side will never send `None` like
+        # we do here in the implicit return at the end of this
+        # `@context` body.
+        pld_spec = NamespacePath  # |None
+
+    nsp_codec: MsgCodec = mk_custom_codec(
+        pld_spec=pld_spec,
+    )
     with apply_codec(nsp_codec) as codec:
         chk_codec_applied(
             custom_codec=nsp_codec,
             enter_value=codec,
         )
 
+        # ensure roundtripping works locally
         nsp = NamespacePath.from_ref(ex_func)
-        await ctx.started(nsp)
+        wire_bytes: bytes = nsp_codec.encode(
+            Started(
+                cid=ctx.cid,
+                pld=nsp
+            )
+        )
+        msg: Started = nsp_codec.decode(wire_bytes)
+        pld = msg.pld
+        assert pld == nsp
 
+        await ctx.started(nsp)
         async with ctx.open_stream() as ipc:
             async for msg in ipc:
 
-                assert msg == f'{__name__}:ex_func'
+                if use_any_spec:
+                    assert msg == f'{__name__}:ex_func'
 
-                # TODO: as per below
-                # assert isinstance(msg, NamespacePath)
-                assert isinstance(msg, str)
+                    # TODO: as per below
+                    # assert isinstance(msg, NamespacePath)
+                    assert isinstance(msg, str)
+                else:
+                    assert isinstance(msg, NamespacePath)
+
+                await ipc.send(msg)
 
 
 def chk_codec_applied(
@@ -146,11 +239,20 @@ def chk_codec_applied(
     enter_value: MsgCodec,
 ) -> MsgCodec:
 
-    task: trio.Task = trio.lowlevel.current_task()
-    task_ctx: Context = task.context
+    # task: trio.Task = trio.lowlevel.current_task()
 
-    assert _ctxvar_MsgCodec in task_ctx
-    curr_codec: MsgCodec = task.context[_ctxvar_MsgCodec]
+    # TreeVar
+    # curr_codec = _ctxvar_MsgCodec.get_in(task)
+
+    # ContextVar
+    # task_ctx: Context = task.context
+    # assert _ctxvar_MsgCodec in task_ctx
+    # curr_codec: MsgCodec = task.context[_ctxvar_MsgCodec]
+
+    # RunVar
+    curr_codec: MsgCodec = _ctxvar_MsgCodec.get()
+    last_read_codec = _ctxvar_MsgCodec.get()
+    assert curr_codec is last_read_codec
 
     assert (
         # returned from `mk_codec()`
@@ -163,14 +265,31 @@ def chk_codec_applied(
         curr_codec is
 
         # public API for all of the above
-        current_msgspec_codec()
+        current_codec()
 
         # the default `msgspec` settings
-        is not _def_msgspec_codec
+        is not _codec._def_msgspec_codec
+        is not _codec._def_tractor_codec
     )
 
 
-def test_codec_hooks_mod():
+@pytest.mark.parametrize(
+    'ipc_pld_spec',
+    [
+        # _codec._def_msgspec_codec,
+        Any,
+        # _codec._def_tractor_codec,
+        NamespacePath|None,
+    ],
+    ids=[
+        'any_type',
+        'nsp_type',
+    ]
+)
+def test_codec_hooks_mod(
+    debug_mode: bool,
+    ipc_pld_spec: Union[Type]|Any,
+):
     '''
     Audit the `.msg.MsgCodec` override apis details given our impl
     uses `contextvars` to accomplish per `trio` task codec
@@ -178,11 +297,21 @@ def test_codec_hooks_mod():
 
     '''
     async def main():
-        task: trio.Task = trio.lowlevel.current_task()
-        task_ctx: Context = task.context
-        assert _ctxvar_MsgCodec not in task_ctx
 
-        async with tractor.open_nursery() as an:
+        # task: trio.Task = trio.lowlevel.current_task()
+
+        # ContextVar
+        # task_ctx: Context = task.context
+        # assert _ctxvar_MsgCodec not in task_ctx
+
+        # TreeVar
+        # def_codec: MsgCodec = _ctxvar_MsgCodec.get_in(task)
+        def_codec = _ctxvar_MsgCodec.get()
+        assert def_codec is _codec._def_tractor_codec
+
+        async with tractor.open_nursery(
+            debug_mode=debug_mode,
+        ) as an:
             p: tractor.Portal = await an.start_actor(
                 'sub',
                 enable_modules=[__name__],
@@ -192,7 +321,9 @@ def test_codec_hooks_mod():
             # - codec not modified -> decode nsp as `str`
             # - codec modified with hooks -> decode nsp as
             #   `NamespacePath`
-            nsp_codec: MsgCodec = mk_custom_codec()
+            nsp_codec: MsgCodec = mk_custom_codec(
+                pld_spec=ipc_pld_spec,
+            )
             with apply_codec(nsp_codec) as codec:
                 chk_codec_applied(
                     custom_codec=nsp_codec,
@@ -202,9 +333,22 @@ def test_codec_hooks_mod():
                 async with (
                     p.open_context(
                         send_back_nsp,
+                        # TODO: send the original nsp here and
+                        # test with `limit_msg_spec()` above?
+                        expect_debug=debug_mode,
+                        use_any_spec=(ipc_pld_spec==Any),
+
                     ) as (ctx, first),
                     ctx.open_stream() as ipc,
                 ):
+                    if ipc_pld_spec is NamespacePath:
+                        assert isinstance(first, NamespacePath)
+
+                    print(
+                        'root: ENTERING CONTEXT BLOCK\n'
+                        f'type(first): {type(first)}\n'
+                        f'first: {first}\n'
+                    )
                     # ensure codec is still applied across
                     # `tractor.Context` + its embedded nursery.
                     chk_codec_applied(
@@ -212,23 +356,46 @@ def test_codec_hooks_mod():
                         enter_value=codec,
                     )
 
-                    assert first == f'{__name__}:ex_func'
+                    first_nsp = NamespacePath(first)
+
+                    # ensure roundtripping works
+                    wire_bytes: bytes = nsp_codec.encode(
+                        Started(
+                            cid=ctx.cid,
+                            pld=first_nsp
+                        )
+                    )
+                    msg: Started = nsp_codec.decode(wire_bytes)
+                    pld = msg.pld
+                    assert  pld == first_nsp
+
+                    # try a manual decode of the started msg+pld
+
                     # TODO: actually get the decoder loading
                     # to native once we spec our SCIPP msgspec
                     # (structurred-conc-inter-proc-protocol)
                     # implemented as per,
                     # https://github.com/goodboy/tractor/issues/36
                     #
-                    # assert isinstance(first, NamespacePath)
-                    assert isinstance(first, str)
+                    if ipc_pld_spec is NamespacePath:
+                        assert isinstance(first, NamespacePath)
+
+                    # `Any`-payload-spec case
+                    else:
+                        assert isinstance(first, str)
+                        assert first == f'{__name__}:ex_func'
+
                     await ipc.send(first)
 
-                    with trio.move_on_after(1):
+                    with trio.move_on_after(.6):
                         async for msg in ipc:
+                            print(msg)
 
                             # TODO: as per above
                             # assert isinstance(msg, NamespacePath)
                             assert isinstance(msg, str)
+                            await ipc.send(msg)
+                            await trio.sleep(0.1)
 
             await p.cancel_actor()