From 2786a0533255f5d83d5e9375f8de1082091c4204 Mon Sep 17 00:00:00 2001
From: Tyler Goodlet <jgbt@protonmail.com>
Date: Mon, 25 Mar 2024 16:31:16 -0400
Subject: [PATCH] Prepare to offer (dynamic) `.msg.Codec` overrides

By simply allowing an input `codec: tuple` of funcs for now to the
`MsgpackTCPStream` transport but, ideally wrapping this in a `Codec`
type with an API for dynamic extension of the interchange lib's msg
processing settings. Right now we're tied to `msgspec.msgpack` for this
transport but with the right design this can likely extend to other libs
in the future.

Relates to starting feature work toward #36, #196, #365.
---
 tractor/_ipc.py | 43 ++++++++++++++++++++++++++++++++++++++-----
 1 file changed, 38 insertions(+), 5 deletions(-)

diff --git a/tractor/_ipc.py b/tractor/_ipc.py
index f57d3bd8..2b5df698 100644
--- a/tractor/_ipc.py
+++ b/tractor/_ipc.py
@@ -30,6 +30,7 @@ import struct
 import typing
 from typing import (
     Any,
+    Callable,
     runtime_checkable,
     Protocol,
     Type,
@@ -123,6 +124,16 @@ class MsgpackTCPStream(MsgTransport):
         stream: trio.SocketStream,
         prefix_size: int = 4,
 
+        # XXX optionally provided codec pair for `msgspec`:
+        # https://jcristharif.com/msgspec/extending.html#mapping-to-from-native-types
+        #
+        # TODO: define this as a `Codec` struct which can be
+        # overriden dynamically by the application/runtime.
+        codec: tuple[
+            Callable[[Any], Any]|None,  # coder
+            Callable[[type, Any], Any]|None,  # decoder
+        ]|None = None,
+
     ) -> None:
 
         self.stream = stream
@@ -138,12 +149,18 @@ class MsgpackTCPStream(MsgTransport):
         # public i guess?
         self.drained: list[dict] = []
 
-        self.recv_stream = BufferedReceiveStream(transport_stream=stream)
+        self.recv_stream = BufferedReceiveStream(
+            transport_stream=stream
+        )
         self.prefix_size = prefix_size
 
         # TODO: struct aware messaging coders
-        self.encode = msgspec.msgpack.Encoder().encode
-        self.decode = msgspec.msgpack.Decoder().decode  # dict[str, Any])
+        self.encode = msgspec.msgpack.Encoder(
+            enc_hook=codec[0] if codec else None,
+        ).encode
+        self.decode = msgspec.msgpack.Decoder(
+            dec_hook=codec[1] if codec else None,
+        ).decode
 
     async def _iter_packets(self) -> AsyncGenerator[dict, None]:
         '''Yield packets from the underlying stream.
@@ -349,9 +366,25 @@ class Channel:
         stream: trio.SocketStream,
         type_key: tuple[str, str]|None = None,
 
+        # XXX optionally provided codec pair for `msgspec`:
+        # https://jcristharif.com/msgspec/extending.html#mapping-to-from-native-types
+        codec: tuple[
+            Callable[[Any], Any],  # coder
+            Callable[[type, Any], Any],  # decoder
+        ]|None = None,
+
     ) -> MsgTransport:
-        type_key = type_key or self._transport_key
-        self._transport = get_msg_transport(type_key)(stream)
+        type_key = (
+            type_key
+            or
+            self._transport_key
+        )
+        self._transport = get_msg_transport(
+            type_key
+        )(
+            stream,
+            codec=codec,
+        )
         return self._transport
 
     def __repr__(self) -> str: