diff --git a/tractor/_streaming.py b/tractor/_streaming.py index b112956..06f2629 100644 --- a/tractor/_streaming.py +++ b/tractor/_streaming.py @@ -699,6 +699,18 @@ class Context: await self.chan.send({'started': value, 'cid': self.cid}) self._started_called = True + + # TODO: msg capability context api1 + # @acm + # async def enable_msg_caps( + # self, + # msg_subtypes: Union[ + # list[list[Struct]], + # Protocol, # hypothetical type that wraps a msg set + # ], + # ) -> tuple[Callable, Callable]: # payload enc, dec pair + # ... + # TODO: do we need a restart api? # async def restart(self) -> None: # pass diff --git a/tractor/msg.py b/tractor/msg.py index 2a03b8f..6a4eb32 100644 --- a/tractor/msg.py +++ b/tractor/msg.py @@ -45,10 +45,10 @@ Built-in messaging patterns, types, APIs and helpers. from __future__ import annotations from contextlib import contextmanager as cm from pkgutil import resolve_name -from typing import Union, Any +from typing import Union, Any, Optional -from msgspec import Struct +from msgspec import Struct, Raw from msgspec.msgpack import ( Encoder, Decoder, @@ -125,8 +125,6 @@ def configure_native_msgs( in all IPC transports and pop the codec on exit. ''' - global _lifo_codecs - # See "tagged unions" docs: # https://jcristharif.com/msgspec/structs.html#tagged-unions @@ -149,3 +147,91 @@ def configure_native_msgs( finally: print("NONONONONON") _lifo_codecs.pop() + + +class Header(Struct, tag=True): + ''' + A msg header which defines payload properties + + ''' + uid: str + msgtype: Optional[str] = None + + +class Msg(Struct, tag=True): + ''' + The "god" msg type, a box for task level msg types. + + ''' + header: Header + payload: Raw + + +_root_dec = Decoder(Msg) +_root_enc = Encoder() + +# sub-decoders for retreiving embedded +# payload data and decoding to a sender +# side defined (struct) type. +_subdecs: dict[ + Optional[str], + Decoder] = { + None: Decoder(Any), +} + + +@cm +def enable_context( + msg_subtypes: list[list[Struct]] +) -> Decoder: + + for types in msg_subtypes: + first = types[0] + + # register using the default tag_field of "type" + # which seems to map to the class "name". + tags = [first.__name__] + + # create a tagged union decoder for this type set + type_union = Union[first] + for typ in types[1:]: + type_union |= typ + tags.append(typ.__name__) + + dec = Decoder(type_union) + + # register all tags for this union sub-decoder + for tag in tags: + _subdecs[tag] = dec + try: + yield dec + finally: + for tag in tags: + _subdecs.pop(tag) + + +def decmsg(msg: Msg) -> Any: + msg = _root_dec.decode(msg) + tag_field = msg.header.msgtype + dec = _subdecs[tag_field] + return dec.decode(msg.payload) + + +def encmsg( + dialog_id: str | int, + payload: Any, +) -> Msg: + + tag_field = None + + plbytes = _root_enc.encode(payload) + if b'type' in plbytes: + assert isinstance(payload, Struct) + tag_field = type(payload).__name__ + payload = Raw(plbytes) + + msg = Msg( + Header(dialog_id, tag_field), + payload, + ) + return _root_enc.encode(msg)