diff --git a/default.nix b/default.nix
index 08e46d06..9036aecf 100644
--- a/default.nix
+++ b/default.nix
@@ -11,9 +11,4 @@ pkgs.mkShell {
 
   LD_LIBRARY_PATH = pkgs.lib.makeLibraryPath nativeBuildInputs;
   TMPDIR = "/tmp";
-
-  shellHook = ''
-    set -e
-    uv venv .venv --python=3.12
-  '';
 }
diff --git a/tests/test_ext_types_msgspec.py b/tests/test_ext_types_msgspec.py
index b334b64f..d82759d6 100644
--- a/tests/test_ext_types_msgspec.py
+++ b/tests/test_ext_types_msgspec.py
@@ -5,6 +5,7 @@ Low-level functional audits for our
 B~)
 
 '''
+from __future__ import annotations
 from contextlib import (
     contextmanager as cm,
     # nullcontext,
@@ -20,7 +21,7 @@ from msgspec import (
     # structs,
     # msgpack,
     Raw,
-    # Struct,
+    Struct,
     ValidationError,
 )
 import pytest
@@ -46,6 +47,11 @@ from tractor.msg import (
     apply_codec,
     current_codec,
 )
+from tractor.msg._codec import (
+    default_builtins,
+    mk_dec_hook,
+    mk_codec_from_spec,
+)
 from tractor.msg.types import (
     log,
     Started,
@@ -743,6 +749,156 @@ def test_ext_types_over_ipc(
     assert exc.boxed_type is TypeError
 
 
+'''
+Test the auto enc & dec hooks
+
+Create a codec which will work for:
+    - builtins
+    - custom types
+    - lists of custom types
+'''
+
+
+class TestBytesClass(Struct, tag=True):
+    '''
+    Custom pld type whose wire form is `bytes`.
+
+    '''
+    raw: bytes
+
+    def encode(self) -> bytes:
+        return self.raw
+
+    @classmethod
+    def from_bytes(cls, raw: bytes) -> TestBytesClass:
+        return cls(raw=raw)
+
+
+class TestStrClass(Struct, tag=True):
+    '''
+    Custom pld type whose wire form is `str`.
+
+    '''
+    s: str
+
+    def encode(self) -> str:
+        return self.s
+
+    @classmethod
+    def from_str(cls, s: str) -> TestStrClass:
+        return cls(s=s)
+
+
+class TestIntClass(Struct, tag=True):
+    '''
+    Custom pld type whose wire form is `int`.
+
+    '''
+    num: int
+
+    def encode(self) -> int:
+        return self.num
+
+    @classmethod
+    def from_int(cls, num: int) -> TestIntClass:
+        return cls(num=num)
+
+
+builtins = tuple((
+    builtin
+    for builtin in default_builtins
+    if builtin is not list
+))
+
+TestClasses = (TestBytesClass, TestStrClass, TestIntClass)
+
+
+TestSpec = (
+    *TestClasses, list[Union[*TestClasses]]
+)
+
+
+test_codec = mk_codec_from_spec(
+    spec=TestSpec
+)
+
+
+@tractor.context
+async def child_custom_codec(
+    ctx: tractor.Context,
+    msgs: list[Union[*TestSpec]],
+):
+    '''
+    Apply codec and send all msgs passed through stream
+
+    '''
+    with (
+        apply_codec(test_codec),
+        limit_plds(
+            test_codec.pld_spec,
+            dec_hook=mk_dec_hook(TestSpec),
+            ext_types=TestSpec + builtins
+        ),
+    ):
+        await ctx.started(None)
+        async with ctx.open_stream() as stream:
+            for msg in msgs:
+                await stream.send(msg)
+
+
+def test_multi_custom_codec():
+    '''
+    Open subactor setup codec and pld_rx and wait to receive & assert from
+    stream
+
+    '''
+    msgs = [
+        None,
+        True, False,
+        0xdeadbeef,
+        .42069,
+        b'deadbeef',
+        TestBytesClass(raw=b'deadbeef'),
+        TestStrClass(s='deadbeef'),
+        TestIntClass(num=0xdeadbeef),
+        [
+            TestBytesClass(raw=b'deadbeef'),
+            TestStrClass(s='deadbeef'),
+            TestIntClass(num=0xdeadbeef),
+        ]
+    ]
+
+    async def main():
+        async with tractor.open_nursery() as an:
+            p: tractor.Portal = await an.start_actor(
+                'child',
+                enable_modules=[__name__],
+            )
+            async with (
+                p.open_context(
+                    child_custom_codec,
+                    msgs=msgs,
+                ) as (ctx, _),
+                ctx.open_stream() as ipc
+            ):
+                with (
+                    apply_codec(test_codec),
+                    limit_plds(
+                        test_codec.pld_spec,
+                        dec_hook=mk_dec_hook(TestSpec),
+                        ext_types=TestSpec + builtins
+                    )
+                ):
+                    # drain the stream 1st: comparing whole lists
+                    # also fails when too few or too many msgs arrive.
+                    rx_msgs: list = [msg async for msg in ipc]
+                    assert rx_msgs == msgs
+
+            await p.cancel_actor()
+
+    trio.run(main)
+
+
 # def chk_pld_type(
 #     payload_spec: Type[Struct]|Any,
 #     pld: Any,
diff --git a/tractor/msg/_codec.py b/tractor/msg/_codec.py
index 1e9623af..f8bc3aa6 100644
--- a/tractor/msg/_codec.py
+++ b/tractor/msg/_codec.py
@@ -39,13 +39,11 @@ from contextvars import (
 )
 import textwrap
 from typing import (
-    Any,
-    Callable,
-    Protocol,
-    Type,
     TYPE_CHECKING,
-    TypeVar,
+    Any,
+    Type,
     Union,
+    Callable,
 )
 from types import ModuleType
 
@@ -54,6 +52,13 @@ from msgspec import (
     msgpack,
     Raw,
 )
+from 
msgspec.inspect import ( + CustomType, + UnionType, + SetType, + ListType, + TupleType +) # TODO: see notes below from @mikenerone.. # from tricycle import TreeVar @@ -81,7 +86,7 @@ class MsgDec(Struct): ''' _dec: msgpack.Decoder - # _ext_types_box: Struct|None = None + _ext_types_boxes: dict[Type, Struct] = {} @property def dec(self) -> msgpack.Decoder: @@ -226,6 +231,8 @@ def mk_dec( f'ext_types = {ext_types!r}\n' ) + _boxed_structs: dict[Type, Struct] = {} + if dec_hook: if ext_types is None: raise TypeError( @@ -237,17 +244,15 @@ def mk_dec( f'ext_types = {ext_types!r}\n' ) - # XXX, i *thought* we would require a boxing struct as per docs, - # https://jcristharif.com/msgspec/extending.html#mapping-to-from-native-types - # |_ see comment, - # > Note that typed deserialization is required for - # > successful roundtripping here, so we pass `MyMessage` to - # > `Decoder`. - # - # BUT, turns out as long as you spec a union with `Raw` it - # will work? kk B) - # - # maybe_box_struct = mk_boxed_ext_struct(ext_types) + if len(ext_types) > 1: + _boxed_structs = mk_boxed_ext_structs(ext_types) + ext_types = [ + etype + for etype in ext_types + if etype not in _boxed_structs + ] + ext_types += list(_boxed_structs.values()) + spec = Raw | Union[*ext_types] return MsgDec( @@ -255,29 +260,26 @@ def mk_dec( type=spec, # like `MsgType[Any]` dec_hook=dec_hook, ), + _ext_types_boxes=_boxed_structs ) -# TODO? remove since didn't end up needing this? -def mk_boxed_ext_struct( +def mk_boxed_ext_structs( ext_types: list[Type], ) -> Struct: - # NOTE, originally was to wrap non-msgpack-supported "extension - # types" in a field-typed boxing struct, see notes around the - # `dec_hook()` branch in `mk_dec()`. 
- ext_types_union = Union[*ext_types] - repr_ext_types_union: str = ( - str(ext_types_union) - or - "|".join(ext_types) - ) - BoxedExtType = msgspec.defstruct( - f'BoxedExts[{repr_ext_types_union}]', - fields=[ - ('boxed', ext_types_union), - ], - ) - return BoxedExtType + box_types: dict[Type, Struct] = {} + for ext_type in ext_types: + info = msgspec.inspect.type_info(ext_type) + if isinstance(info, CustomType): + box_types[ext_type] = msgspec.defstruct( + f'Box{ext_type.__name__}', + tag=True, + fields=[ + ('inner', ext_type), + ], + ) + + return box_types def unpack_spec_types( @@ -378,7 +380,7 @@ class MsgCodec(Struct): _dec: msgpack.Decoder _pld_spec: Type[Struct]|Raw|Any - # _ext_types_box: Struct|None = None + _ext_types_boxes: dict[Type, Struct] = {} def __repr__(self) -> str: speclines: str = textwrap.indent( @@ -465,45 +467,29 @@ class MsgCodec(Struct): ''' __tracebackhide__: bool = hide_tb - if use_buf: - self._enc.encode_into(py_obj, self._buf) - return self._buf - return self._enc.encode(py_obj) - # try: - # return self._enc.encode(py_obj) - # except TypeError as typerr: - # typerr.add_note( - # '|_src error from `msgspec`' - # # f'|_{self._enc.encode!r}' - # ) - # raise typerr + try: - # TODO! REMOVE once i'm confident we won't ever need it! 
- # - # box: Struct = self._ext_types_box - # if ( - # as_ext_type - # or - # ( - # # XXX NOTE, auto-detect if the input type - # box - # and - # (ext_types := unpack_spec_types( - # spec=box.__annotations__['boxed']) - # ) - # ) - # ): - # match py_obj: - # # case PayloadMsg(pld=pld) if ( - # # type(pld) in ext_types - # # ): - # # py_obj.pld = box(boxed=py_obj) - # # breakpoint() - # case _ if ( - # type(py_obj) in ext_types - # ): - # py_obj = box(boxed=py_obj) + box: Struct|None = self._ext_types_boxes.get(type(py_obj), None) + if ( + as_ext_type + or + box + ): + py_obj = box(inner=py_obj) + + if use_buf: + self._enc.encode_into(py_obj, self._buf) + return self._buf + + return self._enc.encode(py_obj) + + except TypeError as typerr: + typerr.add_note( + '|_src error from `msgspec`' + # f'|_{self._enc.encode!r}' + ) + raise typerr @property def dec(self) -> msgpack.Decoder: @@ -565,11 +551,6 @@ def mk_codec( enc_hook: Callable|None = None, ext_types: list[Type]|None = None, - # optionally provided msg-decoder from which we pull its, - # |_.dec_hook() - # |_.type - ext_dec: MsgDec|None = None - # # ?TODO? 
other params we might want to support # Encoder: # write_buffer_size=write_buffer_size, @@ -597,12 +578,6 @@ def mk_codec( ) dec_hook: Callable|None = None - if ext_dec: - dec: msgspec.Decoder = ext_dec.dec - dec_hook = dec.dec_hook - pld_spec |= dec.type - if ext_types: - pld_spec |= Union[*ext_types] # (manually) generate a msg-spec (how appropes) for all relevant # payload-boxing-struct-msg-types, parameterizing the @@ -630,10 +605,16 @@ def mk_codec( enc = msgpack.Encoder( enc_hook=enc_hook, ) + + boxes = {} + if ext_types and len(ext_types) > 1: + boxes = mk_boxed_ext_structs(ext_types) + codec = MsgCodec( _enc=enc, _dec=dec, _pld_spec=pld_spec, + _ext_types_boxes=boxes ) # sanity on expected backend support assert codec.lib.__name__ == libname @@ -809,78 +790,282 @@ def limit_msg_spec( assert curr_codec is current_codec() -# XXX: msgspec won't allow this with non-struct custom types -# like `NamespacePath`!@! -# @cm -# def extend_msg_spec( -# payload_spec: Union[Type[Struct]], +''' +Encoder / Decoder generic hook factory -# ) -> MsgCodec: -# ''' -# Extend the current `MsgCodec.pld_spec` (type set) by extending -# the payload spec to **include** the types specified by -# `payload_spec`. - -# ''' -# codec: MsgCodec = current_codec() -# pld_spec: Union[Type] = codec.pld_spec -# extended_spec: Union[Type] = pld_spec|payload_spec - -# with limit_msg_spec(payload_types=extended_spec) as ext_codec: -# # import pdbp; pdbp.set_trace() -# assert ext_codec.pld_spec == extended_spec -# yield ext_codec -# -# ^-TODO-^ is it impossible to make something like this orr!? - -# TODO: make an auto-custom hook generator from a set of input custom -# types? -# -[ ] below is a proto design using a `TypeCodec` idea? -# -# type var for the expected interchange-lib's -# IPC-transport type when not available as a built-in -# serialization output. 
-WireT = TypeVar('WireT')
+'''
 
-# TODO: some kinda (decorator) API for built-in subtypes
-# that builds this implicitly by inspecting the `mro()`?
-class TypeCodec(Protocol):
+# builtins we can have in same pld_spec as custom types
+default_builtins = (
+    None,
+    bool,
+    int,
+    float,
+    bytes,
+    list
+)
+
+
+# spec definition type
+TypeSpec = (
+    Type |
+    Union[Type] |
+    list[Type] |
+    tuple[Type] |
+    set[Type]
+)
+
+
+class TypeCodec:
     '''
-    A per-custom-type wire-transport serialization translator
-    description type.
+    Describe how a custom type is translated to and from a "wire
+    type" by *naming* the attrs which do the conversion:
+
+    - `encode_fn`: looked up on the instance and called with no
+      args, it returns the wire-type value,
+    - `decode_fn`: looked up on the class and called with the
+      wire-type value, it returns the decoded object.
 
     '''
-    src_type: Type
-    wire_type: WireT
 
-    def encode(obj: Type) -> WireT:
-        ...
+    def __init__(
+        self,
+        wire_type: Type,
+        decode_fn: str,
+        encode_fn: str = 'encode',
+    ):
+        self._encode_fn: str = encode_fn
+        self._decode_fn: str = decode_fn
+        self._wire_type: Type = wire_type
 
-    def decode(
-        obj_type: Type[WireT],
-        obj: WireT,
-    ) -> Type:
-        ...
+    @property
+    def encode_fn(self) -> str:
+        return self._encode_fn
+
+    @property
+    def decode_fn(self) -> str:
+        return self._decode_fn
+
+    @property
+    def wire_type(self) -> Type:
+        return self._wire_type
+
+    def is_type_compat(self, obj: Any) -> bool:
+        '''
+        Return `True` when `obj` (a type or an instance) exposes
+        both the encode and the decode attr named by this codec.
+
+        '''
+        return (
+            hasattr(obj, self._encode_fn)
+            and
+            hasattr(obj, self._decode_fn)
+        )
+
+    def encode(self, obj: Any) -> Any:
+        '''
+        Convert `obj` to its wire value via its `encode_fn` method.
+
+        '''
+        # XXX, `getattr()` on an instance delivers a *bound* method
+        # so `obj` must NOT be passed again, doing so raises
+        # `TypeError` for any `def encode(self)`.
+        return getattr(obj, self._encode_fn)()
+
+    def decode(self, cls: Type, raw: Any) -> Any:
+        '''
+        Build a `cls` object from the wire value `raw` via the
+        `decode_fn` attr of `cls`.
+
+        '''
+        return getattr(cls, self._decode_fn)(raw)
 
 
-class MsgpackTypeCodec(TypeCodec):
-    ...
+'''
+Default codec descriptions for wire types:
+
+    - bytes
+    - str
+    - int
+
+'''
 
 
-def mk_codec_hooks(
-    type_codecs: list[TypeCodec],
+BytesCodec = TypeCodec(
+    decode_fn='from_bytes',
+    wire_type=bytes
+)
 
-) -> tuple[Callable, Callable]:
+
+StrCodec = TypeCodec(
+    decode_fn='from_str',
+    wire_type=str
+)
+
+
+IntCodec = TypeCodec(
+    decode_fn='from_int',
+    wire_type=int
+)
+
+
+default_codecs: dict[Type, TypeCodec] = {
+    bytes: BytesCodec,
+    str: StrCodec,
+    int: IntCodec
+}
+
+
+def mk_spec_set(
+    spec: TypeSpec
+) -> set[Type]:
     '''
-    Deliver a `enc_hook()`/`dec_hook()` pair which handle
-    manual convertion from an input `Type` set such that whenever
-    the `TypeCodec.filter()` predicate matches the
-    `TypeCodec.decode()` is called on the input native object by
-    the `dec_hook()` and whenever the
-    `isiinstance(obj, TypeCodec.type)` matches against an
-    `enc_hook(obj=obj)` the return value is taken from a
-    `TypeCodec.encode(obj)` callback.
+    Normalize any of the accepted spec definitions to a `set[Type]`
+    with one item per spec type.
+
+    - a `list|tuple|set` instance has its items copied to a set,
+    - a `Union` or a parameterized `list`/`set`/`tuple` alias
+      yields its type arguments,
+    - anything else is treated as a single type and wrapped.
 
     '''
-    ...
+    if isinstance(spec, (set, list, tuple)):
+        return set(spec)
+
+    # local import, this name is not part of the module's imports.
+    from typing import get_args
+
+    # XXX, the `msgspec.inspect` info objs are NOT python types and
+    # only some of them carry a `.cls` (and a fixed-len tuple has
+    # `.item_types` not `.item_type`), so only classify with
+    # `type_info()` and pull the member types from the spec itself.
+    spec_info = msgspec.inspect.type_info(spec)
+    if isinstance(
+        spec_info,
+        (UnionType, SetType, ListType, TupleType),
+    ):
+        inner_types: tuple = get_args(spec)
+        if inner_types:
+            return set(inner_types)
+
+    return {spec}
+
+
+def mk_codec_map_from_spec(
+    spec: TypeSpec,
+    codecs: dict[Type, TypeCodec] = default_codecs
+) -> dict[Type, TypeCodec]:
+    '''
+    Generate a map of spec type -> supported codec
+
+    Every codec in `codecs` is probed for every spec type, when
+    more than one is compatible the last one (in `codecs` order)
+    wins; spec types without a compatible codec are left out.
+
+    '''
+    spec_types: set[Type] = mk_spec_set(spec)
+
+    spec_codecs: dict[Type, TypeCodec] = {}
+    for t in spec_types:
+        for codec in codecs.values():
+            if codec.is_type_compat(t):
+                spec_codecs[t] = codec
+
+    return spec_codecs
+
+
+def mk_enc_hook(
+    spec: TypeSpec,
+    with_builtins: bool = True,
+    builtins: tuple[Type, ...]|set[Type] = default_builtins,
+    codecs: dict[Type, TypeCodec] = default_codecs
+) -> Callable:
+    '''
+    Given a type specification return a msgspec enc_hook fn
+
+    The hook encodes objs whose exact type has a codec in `codecs`,
+    passes objs of a `builtins` type through unchanged (only when
+    `with_builtins` is set) and raises `NotImplementedError` for
+    anything else.
+
+    '''
+    # honor the caller's codec table instead of always falling back
+    # to the defaults.
+    spec_codecs = mk_codec_map_from_spec(spec, codecs=codecs)
+    passthrough = builtins if with_builtins else ()
+
+    def enc_hook(obj: Any) -> Any:
+        t = type(obj)
+        maybe_codec = spec_codecs.get(t, None)
+        if maybe_codec:
+            return maybe_codec.encode(obj)
+
+        # passthrough built ins
+        if passthrough and t in passthrough:
+            return obj
+
+        raise NotImplementedError(
+            f"Objects of type {type(obj)} are not supported:\n{obj}"
+        )
+
+    return enc_hook
+
+
+def mk_dec_hook(
+    spec: TypeSpec,
+    with_builtins: bool = True,
+    builtins: tuple[Type, ...]|set[Type] = default_builtins,
+    codecs: dict[Type, TypeCodec] = default_codecs
+) -> Callable:
+    '''
+    Given a type specification return a msgspec dec_hook fn
+
+    The hook decodes to any requested type which has a codec in
+    `codecs`, passes the raw obj through when the requested type is
+    in `builtins` (only when `with_builtins` is set) and raises
+    `NotImplementedError` for anything else.
+
+    '''
+    spec_codecs = mk_codec_map_from_spec(spec, codecs=codecs)
+    passthrough = builtins if with_builtins else ()
+
+    def dec_hook(t: Type, obj: Any) -> Any:
+        maybe_codec = spec_codecs.get(t, None)
+        if maybe_codec:
+            return maybe_codec.decode(t, obj)
+
+        # passthrough builtins
+        # XXX, must test the *requested* type `t`, NOT the builtin
+        # `type` which is never a member of `builtins`.
+        if passthrough and t in passthrough:
+            return obj
+
+        raise NotImplementedError(
+            f"Objects of type {t} are not supported from {obj}"
+        )
+
+    return dec_hook
+
+
+def mk_codec_hooks(*args, **kwargs) -> tuple[Callable, Callable]:
+    '''
+    Given a type specification return a msgspec enc & dec hook fn pair
+
+    All args are forwarded verbatim to both `mk_enc_hook()` and
+    `mk_dec_hook()`.
+
+    '''
+    return (
+        mk_enc_hook(*args, **kwargs),
+
+        mk_dec_hook(*args, **kwargs)
+    )
+
+
+def mk_codec_from_spec(
+    spec: TypeSpec,
+    with_builtins: bool = True,
+    builtins: tuple[Type, ...]|set[Type] = default_builtins,
+    codecs: dict[Type, TypeCodec] = default_codecs
+) -> MsgCodec:
+    '''
+    Given a type specification return a MsgCodec
+
+    The codec gets an `enc_hook` generated from the spec and the
+    spec's types as its `ext_types`.
+
+    '''
+    spec_types: set[Type] = mk_spec_set(spec)
+
+    return mk_codec(
+        enc_hook=mk_enc_hook(
+            spec_types,
+            with_builtins=with_builtins,
+            builtins=builtins,
+            codecs=codecs
+        ),
+        # `mk_codec()` declares `ext_types: list[Type]|None`
+        ext_types=list(spec_types),
+    )
+
+
+def mk_msgpack_codec(
+    spec: TypeSpec,
+    with_builtins: bool = True,
+    builtins: tuple[Type, ...]|set[Type] = default_builtins,
+    codecs: dict[Type, TypeCodec] = default_codecs
+) -> tuple[msgpack.Encoder, msgpack.Decoder]:
+    '''
+    Get a msgpack Encoder, Decoder pair for a given type spec
+
+    '''
+    spec_types: set[Type] = mk_spec_set(spec)
+
+    enc_hook, dec_hook = mk_codec_hooks(
+        spec_types,
+        with_builtins=with_builtins,
+        builtins=builtins,
+        codecs=codecs
+    )
+    encoder = msgpack.Encoder(enc_hook=enc_hook)
+    # the decoder takes a *type*, not a set of types, so fold the
+    # spec into a single union.
+    decoder = msgpack.Decoder(
+        Union[*spec_types],
+        dec_hook=dec_hook,
+    )
+    return encoder, decoder