Re-add boxed struct type system on _codec & create enc/dec hook auto factory
parent
112ed27cda
commit
51746a71ac
|
@ -11,9 +11,4 @@ pkgs.mkShell {
|
||||||
|
|
||||||
LD_LIBRARY_PATH = pkgs.lib.makeLibraryPath nativeBuildInputs;
|
LD_LIBRARY_PATH = pkgs.lib.makeLibraryPath nativeBuildInputs;
|
||||||
TMPDIR = "/tmp";
|
TMPDIR = "/tmp";
|
||||||
|
|
||||||
shellHook = ''
|
|
||||||
set -e
|
|
||||||
uv venv .venv --python=3.12
|
|
||||||
'';
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -5,6 +5,7 @@ Low-level functional audits for our
|
||||||
B~)
|
B~)
|
||||||
|
|
||||||
'''
|
'''
|
||||||
|
from __future__ import annotations
|
||||||
from contextlib import (
|
from contextlib import (
|
||||||
contextmanager as cm,
|
contextmanager as cm,
|
||||||
# nullcontext,
|
# nullcontext,
|
||||||
|
@ -20,7 +21,7 @@ from msgspec import (
|
||||||
# structs,
|
# structs,
|
||||||
# msgpack,
|
# msgpack,
|
||||||
Raw,
|
Raw,
|
||||||
# Struct,
|
Struct,
|
||||||
ValidationError,
|
ValidationError,
|
||||||
)
|
)
|
||||||
import pytest
|
import pytest
|
||||||
|
@ -46,6 +47,11 @@ from tractor.msg import (
|
||||||
apply_codec,
|
apply_codec,
|
||||||
current_codec,
|
current_codec,
|
||||||
)
|
)
|
||||||
|
from tractor.msg._codec import (
|
||||||
|
default_builtins,
|
||||||
|
mk_dec_hook,
|
||||||
|
mk_codec_from_spec,
|
||||||
|
)
|
||||||
from tractor.msg.types import (
|
from tractor.msg.types import (
|
||||||
log,
|
log,
|
||||||
Started,
|
Started,
|
||||||
|
@ -743,6 +749,143 @@ def test_ext_types_over_ipc(
|
||||||
assert exc.boxed_type is TypeError
|
assert exc.boxed_type is TypeError
|
||||||
|
|
||||||
|
|
||||||
|
'''
|
||||||
|
Test the auto enc & dec hooks
|
||||||
|
|
||||||
|
Create a codec which will work for:
|
||||||
|
- builtins
|
||||||
|
- custom types
|
||||||
|
- lists of custom types
|
||||||
|
'''
|
||||||
|
|
||||||
|
|
||||||
|
class TestBytesClass(Struct, tag=True):
|
||||||
|
raw: bytes
|
||||||
|
|
||||||
|
def encode(self) -> bytes:
|
||||||
|
return self.raw
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_bytes(self, raw: bytes) -> TestBytesClass:
|
||||||
|
return TestBytesClass(raw=raw)
|
||||||
|
|
||||||
|
|
||||||
|
class TestStrClass(Struct, tag=True):
|
||||||
|
s: str
|
||||||
|
|
||||||
|
def encode(self) -> str:
|
||||||
|
return self.s
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_str(self, s: str) -> TestStrClass:
|
||||||
|
return TestStrClass(s=s)
|
||||||
|
|
||||||
|
|
||||||
|
class TestIntClass(Struct, tag=True):
|
||||||
|
num: int
|
||||||
|
|
||||||
|
def encode(self) -> int:
|
||||||
|
return self.num
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_int(self, num: int) -> TestIntClass:
|
||||||
|
return TestIntClass(num=num)
|
||||||
|
|
||||||
|
|
||||||
|
builtins = tuple((
|
||||||
|
builtin
|
||||||
|
for builtin in default_builtins
|
||||||
|
if builtin is not list
|
||||||
|
))
|
||||||
|
|
||||||
|
TestClasses = (TestBytesClass, TestStrClass, TestIntClass)
|
||||||
|
|
||||||
|
|
||||||
|
TestSpec = (
|
||||||
|
*TestClasses, list[Union[*TestClasses]]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
test_codec = mk_codec_from_spec(
|
||||||
|
spec=TestSpec
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@tractor.context
|
||||||
|
async def child_custom_codec(
|
||||||
|
ctx: tractor.Context,
|
||||||
|
msgs: list[Union[*TestSpec]],
|
||||||
|
):
|
||||||
|
'''
|
||||||
|
Apply codec and send all msgs passed through stream
|
||||||
|
|
||||||
|
'''
|
||||||
|
with (
|
||||||
|
apply_codec(test_codec),
|
||||||
|
limit_plds(
|
||||||
|
test_codec.pld_spec,
|
||||||
|
dec_hook=mk_dec_hook(TestSpec),
|
||||||
|
ext_types=TestSpec + builtins
|
||||||
|
),
|
||||||
|
):
|
||||||
|
await ctx.started(None)
|
||||||
|
async with ctx.open_stream() as stream:
|
||||||
|
for msg in msgs:
|
||||||
|
await stream.send(msg)
|
||||||
|
|
||||||
|
|
||||||
|
def test_multi_custom_codec():
|
||||||
|
'''
|
||||||
|
Open subactor setup codec and pld_rx and wait to receive & assert from
|
||||||
|
stream
|
||||||
|
|
||||||
|
'''
|
||||||
|
msgs = [
|
||||||
|
None,
|
||||||
|
True, False,
|
||||||
|
0xdeadbeef,
|
||||||
|
.42069,
|
||||||
|
b'deadbeef',
|
||||||
|
TestBytesClass(raw=b'deadbeef'),
|
||||||
|
TestStrClass(s='deadbeef'),
|
||||||
|
TestIntClass(num=0xdeadbeef),
|
||||||
|
[
|
||||||
|
TestBytesClass(raw=b'deadbeef'),
|
||||||
|
TestStrClass(s='deadbeef'),
|
||||||
|
TestIntClass(num=0xdeadbeef),
|
||||||
|
]
|
||||||
|
]
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
async with tractor.open_nursery() as an:
|
||||||
|
p: tractor.Portal = await an.start_actor(
|
||||||
|
'child',
|
||||||
|
enable_modules=[__name__],
|
||||||
|
)
|
||||||
|
async with (
|
||||||
|
p.open_context(
|
||||||
|
child_custom_codec,
|
||||||
|
msgs=msgs,
|
||||||
|
) as (ctx, _),
|
||||||
|
ctx.open_stream() as ipc
|
||||||
|
):
|
||||||
|
with (
|
||||||
|
apply_codec(test_codec),
|
||||||
|
limit_plds(
|
||||||
|
test_codec.pld_spec,
|
||||||
|
dec_hook=mk_dec_hook(TestSpec),
|
||||||
|
ext_types=TestSpec + builtins
|
||||||
|
)
|
||||||
|
):
|
||||||
|
msg_iter = iter(msgs)
|
||||||
|
async for recv_msg in ipc:
|
||||||
|
assert recv_msg == next(msg_iter)
|
||||||
|
|
||||||
|
await p.cancel_actor()
|
||||||
|
|
||||||
|
trio.run(main)
|
||||||
|
|
||||||
|
|
||||||
# def chk_pld_type(
|
# def chk_pld_type(
|
||||||
# payload_spec: Type[Struct]|Any,
|
# payload_spec: Type[Struct]|Any,
|
||||||
# pld: Any,
|
# pld: Any,
|
||||||
|
|
|
@ -39,13 +39,11 @@ from contextvars import (
|
||||||
)
|
)
|
||||||
import textwrap
|
import textwrap
|
||||||
from typing import (
|
from typing import (
|
||||||
Any,
|
|
||||||
Callable,
|
|
||||||
Protocol,
|
|
||||||
Type,
|
|
||||||
TYPE_CHECKING,
|
TYPE_CHECKING,
|
||||||
TypeVar,
|
Any,
|
||||||
|
Type,
|
||||||
Union,
|
Union,
|
||||||
|
Callable,
|
||||||
)
|
)
|
||||||
from types import ModuleType
|
from types import ModuleType
|
||||||
|
|
||||||
|
@ -54,6 +52,13 @@ from msgspec import (
|
||||||
msgpack,
|
msgpack,
|
||||||
Raw,
|
Raw,
|
||||||
)
|
)
|
||||||
|
from msgspec.inspect import (
|
||||||
|
CustomType,
|
||||||
|
UnionType,
|
||||||
|
SetType,
|
||||||
|
ListType,
|
||||||
|
TupleType
|
||||||
|
)
|
||||||
# TODO: see notes below from @mikenerone..
|
# TODO: see notes below from @mikenerone..
|
||||||
# from tricycle import TreeVar
|
# from tricycle import TreeVar
|
||||||
|
|
||||||
|
@ -81,7 +86,7 @@ class MsgDec(Struct):
|
||||||
|
|
||||||
'''
|
'''
|
||||||
_dec: msgpack.Decoder
|
_dec: msgpack.Decoder
|
||||||
# _ext_types_box: Struct|None = None
|
_ext_types_boxes: dict[Type, Struct] = {}
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def dec(self) -> msgpack.Decoder:
|
def dec(self) -> msgpack.Decoder:
|
||||||
|
@ -226,6 +231,8 @@ def mk_dec(
|
||||||
f'ext_types = {ext_types!r}\n'
|
f'ext_types = {ext_types!r}\n'
|
||||||
)
|
)
|
||||||
|
|
||||||
|
_boxed_structs: dict[Type, Struct] = {}
|
||||||
|
|
||||||
if dec_hook:
|
if dec_hook:
|
||||||
if ext_types is None:
|
if ext_types is None:
|
||||||
raise TypeError(
|
raise TypeError(
|
||||||
|
@ -237,17 +244,15 @@ def mk_dec(
|
||||||
f'ext_types = {ext_types!r}\n'
|
f'ext_types = {ext_types!r}\n'
|
||||||
)
|
)
|
||||||
|
|
||||||
# XXX, i *thought* we would require a boxing struct as per docs,
|
if len(ext_types) > 1:
|
||||||
# https://jcristharif.com/msgspec/extending.html#mapping-to-from-native-types
|
_boxed_structs = mk_boxed_ext_structs(ext_types)
|
||||||
# |_ see comment,
|
ext_types = [
|
||||||
# > Note that typed deserialization is required for
|
etype
|
||||||
# > successful roundtripping here, so we pass `MyMessage` to
|
for etype in ext_types
|
||||||
# > `Decoder`.
|
if etype not in _boxed_structs
|
||||||
#
|
]
|
||||||
# BUT, turns out as long as you spec a union with `Raw` it
|
ext_types += list(_boxed_structs.values())
|
||||||
# will work? kk B)
|
|
||||||
#
|
|
||||||
# maybe_box_struct = mk_boxed_ext_struct(ext_types)
|
|
||||||
spec = Raw | Union[*ext_types]
|
spec = Raw | Union[*ext_types]
|
||||||
|
|
||||||
return MsgDec(
|
return MsgDec(
|
||||||
|
@ -255,29 +260,26 @@ def mk_dec(
|
||||||
type=spec, # like `MsgType[Any]`
|
type=spec, # like `MsgType[Any]`
|
||||||
dec_hook=dec_hook,
|
dec_hook=dec_hook,
|
||||||
),
|
),
|
||||||
|
_ext_types_boxes=_boxed_structs
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# TODO? remove since didn't end up needing this?
|
def mk_boxed_ext_structs(
|
||||||
def mk_boxed_ext_struct(
|
|
||||||
ext_types: list[Type],
|
ext_types: list[Type],
|
||||||
) -> Struct:
|
) -> Struct:
|
||||||
# NOTE, originally was to wrap non-msgpack-supported "extension
|
box_types: dict[Type, Struct] = {}
|
||||||
# types" in a field-typed boxing struct, see notes around the
|
for ext_type in ext_types:
|
||||||
# `dec_hook()` branch in `mk_dec()`.
|
info = msgspec.inspect.type_info(ext_type)
|
||||||
ext_types_union = Union[*ext_types]
|
if isinstance(info, CustomType):
|
||||||
repr_ext_types_union: str = (
|
box_types[ext_type] = msgspec.defstruct(
|
||||||
str(ext_types_union)
|
f'Box{ext_type.__name__}',
|
||||||
or
|
tag=True,
|
||||||
"|".join(ext_types)
|
|
||||||
)
|
|
||||||
BoxedExtType = msgspec.defstruct(
|
|
||||||
f'BoxedExts[{repr_ext_types_union}]',
|
|
||||||
fields=[
|
fields=[
|
||||||
('boxed', ext_types_union),
|
('inner', ext_type),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
return BoxedExtType
|
|
||||||
|
return box_types
|
||||||
|
|
||||||
|
|
||||||
def unpack_spec_types(
|
def unpack_spec_types(
|
||||||
|
@ -378,7 +380,7 @@ class MsgCodec(Struct):
|
||||||
_dec: msgpack.Decoder
|
_dec: msgpack.Decoder
|
||||||
_pld_spec: Type[Struct]|Raw|Any
|
_pld_spec: Type[Struct]|Raw|Any
|
||||||
|
|
||||||
# _ext_types_box: Struct|None = None
|
_ext_types_boxes: dict[Type, Struct] = {}
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
speclines: str = textwrap.indent(
|
speclines: str = textwrap.indent(
|
||||||
|
@ -465,45 +467,29 @@ class MsgCodec(Struct):
|
||||||
|
|
||||||
'''
|
'''
|
||||||
__tracebackhide__: bool = hide_tb
|
__tracebackhide__: bool = hide_tb
|
||||||
|
|
||||||
|
try:
|
||||||
|
|
||||||
|
box: Struct|None = self._ext_types_boxes.get(type(py_obj), None)
|
||||||
|
if (
|
||||||
|
as_ext_type
|
||||||
|
or
|
||||||
|
box
|
||||||
|
):
|
||||||
|
py_obj = box(inner=py_obj)
|
||||||
|
|
||||||
if use_buf:
|
if use_buf:
|
||||||
self._enc.encode_into(py_obj, self._buf)
|
self._enc.encode_into(py_obj, self._buf)
|
||||||
return self._buf
|
return self._buf
|
||||||
|
|
||||||
return self._enc.encode(py_obj)
|
return self._enc.encode(py_obj)
|
||||||
# try:
|
|
||||||
# return self._enc.encode(py_obj)
|
|
||||||
# except TypeError as typerr:
|
|
||||||
# typerr.add_note(
|
|
||||||
# '|_src error from `msgspec`'
|
|
||||||
# # f'|_{self._enc.encode!r}'
|
|
||||||
# )
|
|
||||||
# raise typerr
|
|
||||||
|
|
||||||
# TODO! REMOVE once i'm confident we won't ever need it!
|
except TypeError as typerr:
|
||||||
#
|
typerr.add_note(
|
||||||
# box: Struct = self._ext_types_box
|
'|_src error from `msgspec`'
|
||||||
# if (
|
# f'|_{self._enc.encode!r}'
|
||||||
# as_ext_type
|
)
|
||||||
# or
|
raise typerr
|
||||||
# (
|
|
||||||
# # XXX NOTE, auto-detect if the input type
|
|
||||||
# box
|
|
||||||
# and
|
|
||||||
# (ext_types := unpack_spec_types(
|
|
||||||
# spec=box.__annotations__['boxed'])
|
|
||||||
# )
|
|
||||||
# )
|
|
||||||
# ):
|
|
||||||
# match py_obj:
|
|
||||||
# # case PayloadMsg(pld=pld) if (
|
|
||||||
# # type(pld) in ext_types
|
|
||||||
# # ):
|
|
||||||
# # py_obj.pld = box(boxed=py_obj)
|
|
||||||
# # breakpoint()
|
|
||||||
# case _ if (
|
|
||||||
# type(py_obj) in ext_types
|
|
||||||
# ):
|
|
||||||
# py_obj = box(boxed=py_obj)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def dec(self) -> msgpack.Decoder:
|
def dec(self) -> msgpack.Decoder:
|
||||||
|
@ -565,11 +551,6 @@ def mk_codec(
|
||||||
enc_hook: Callable|None = None,
|
enc_hook: Callable|None = None,
|
||||||
ext_types: list[Type]|None = None,
|
ext_types: list[Type]|None = None,
|
||||||
|
|
||||||
# optionally provided msg-decoder from which we pull its,
|
|
||||||
# |_.dec_hook()
|
|
||||||
# |_.type
|
|
||||||
ext_dec: MsgDec|None = None
|
|
||||||
#
|
|
||||||
# ?TODO? other params we might want to support
|
# ?TODO? other params we might want to support
|
||||||
# Encoder:
|
# Encoder:
|
||||||
# write_buffer_size=write_buffer_size,
|
# write_buffer_size=write_buffer_size,
|
||||||
|
@ -597,12 +578,6 @@ def mk_codec(
|
||||||
)
|
)
|
||||||
|
|
||||||
dec_hook: Callable|None = None
|
dec_hook: Callable|None = None
|
||||||
if ext_dec:
|
|
||||||
dec: msgspec.Decoder = ext_dec.dec
|
|
||||||
dec_hook = dec.dec_hook
|
|
||||||
pld_spec |= dec.type
|
|
||||||
if ext_types:
|
|
||||||
pld_spec |= Union[*ext_types]
|
|
||||||
|
|
||||||
# (manually) generate a msg-spec (how appropes) for all relevant
|
# (manually) generate a msg-spec (how appropes) for all relevant
|
||||||
# payload-boxing-struct-msg-types, parameterizing the
|
# payload-boxing-struct-msg-types, parameterizing the
|
||||||
|
@ -630,10 +605,16 @@ def mk_codec(
|
||||||
enc = msgpack.Encoder(
|
enc = msgpack.Encoder(
|
||||||
enc_hook=enc_hook,
|
enc_hook=enc_hook,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
boxes = {}
|
||||||
|
if ext_types and len(ext_types) > 1:
|
||||||
|
boxes = mk_boxed_ext_structs(ext_types)
|
||||||
|
|
||||||
codec = MsgCodec(
|
codec = MsgCodec(
|
||||||
_enc=enc,
|
_enc=enc,
|
||||||
_dec=dec,
|
_dec=dec,
|
||||||
_pld_spec=pld_spec,
|
_pld_spec=pld_spec,
|
||||||
|
_ext_types_boxes=boxes
|
||||||
)
|
)
|
||||||
# sanity on expected backend support
|
# sanity on expected backend support
|
||||||
assert codec.lib.__name__ == libname
|
assert codec.lib.__name__ == libname
|
||||||
|
@ -809,78 +790,282 @@ def limit_msg_spec(
|
||||||
assert curr_codec is current_codec()
|
assert curr_codec is current_codec()
|
||||||
|
|
||||||
|
|
||||||
# XXX: msgspec won't allow this with non-struct custom types
|
'''
|
||||||
# like `NamespacePath`!@!
|
Encoder / Decoder generic hook factory
|
||||||
# @cm
|
|
||||||
# def extend_msg_spec(
|
|
||||||
# payload_spec: Union[Type[Struct]],
|
|
||||||
|
|
||||||
# ) -> MsgCodec:
|
'''
|
||||||
# '''
|
|
||||||
# Extend the current `MsgCodec.pld_spec` (type set) by extending
|
|
||||||
# the payload spec to **include** the types specified by
|
|
||||||
# `payload_spec`.
|
|
||||||
|
|
||||||
# '''
|
|
||||||
# codec: MsgCodec = current_codec()
|
|
||||||
# pld_spec: Union[Type] = codec.pld_spec
|
|
||||||
# extended_spec: Union[Type] = pld_spec|payload_spec
|
|
||||||
|
|
||||||
# with limit_msg_spec(payload_types=extended_spec) as ext_codec:
|
|
||||||
# # import pdbp; pdbp.set_trace()
|
|
||||||
# assert ext_codec.pld_spec == extended_spec
|
|
||||||
# yield ext_codec
|
|
||||||
#
|
|
||||||
# ^-TODO-^ is it impossible to make something like this orr!?
|
|
||||||
|
|
||||||
# TODO: make an auto-custom hook generator from a set of input custom
|
|
||||||
# types?
|
|
||||||
# -[ ] below is a proto design using a `TypeCodec` idea?
|
|
||||||
#
|
|
||||||
# type var for the expected interchange-lib's
|
|
||||||
# IPC-transport type when not available as a built-in
|
|
||||||
# serialization output.
|
|
||||||
WireT = TypeVar('WireT')
|
|
||||||
|
|
||||||
|
|
||||||
# TODO: some kinda (decorator) API for built-in subtypes
|
# builtins we can have in same pld_spec as custom types
|
||||||
# that builds this implicitly by inspecting the `mro()`?
|
default_builtins = (
|
||||||
class TypeCodec(Protocol):
|
None,
|
||||||
|
bool,
|
||||||
|
int,
|
||||||
|
float,
|
||||||
|
bytes,
|
||||||
|
list
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# spec definition type
|
||||||
|
TypeSpec = (
|
||||||
|
Type |
|
||||||
|
Union[Type] |
|
||||||
|
list[Type] |
|
||||||
|
tuple[Type] |
|
||||||
|
set[Type]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TypeCodec:
|
||||||
'''
|
'''
|
||||||
A per-custom-type wire-transport serialization translator
|
This class describes a way of encoding to or decoding from a "wire type",
|
||||||
description type.
|
objects that have `encode_fn` and `decode_fn` can be used with
|
||||||
|
`.encode/.decode`.
|
||||||
|
|
||||||
'''
|
'''
|
||||||
src_type: Type
|
|
||||||
wire_type: WireT
|
|
||||||
|
|
||||||
def encode(obj: Type) -> WireT:
|
def __init__(
|
||||||
...
|
self,
|
||||||
|
wire_type: Type,
|
||||||
|
decode_fn: str,
|
||||||
|
encode_fn: str = 'encode',
|
||||||
|
):
|
||||||
|
self._encode_fn: str = encode_fn
|
||||||
|
self._decode_fn: str = decode_fn
|
||||||
|
self._wire_type: Type = wire_type
|
||||||
|
|
||||||
def decode(
|
@property
|
||||||
obj_type: Type[WireT],
|
def encode_fn(self) -> str:
|
||||||
obj: WireT,
|
return self._encode_fn
|
||||||
) -> Type:
|
|
||||||
...
|
@property
|
||||||
|
def decode_fn(self) -> str:
|
||||||
|
return self._decode_fn
|
||||||
|
|
||||||
|
@property
|
||||||
|
def wire_type(self) -> str:
|
||||||
|
return self._wire_type
|
||||||
|
|
||||||
|
def is_type_compat(self, obj: any) -> bool:
|
||||||
|
return (
|
||||||
|
hasattr(obj, self._encode_fn)
|
||||||
|
and
|
||||||
|
hasattr(obj, self._decode_fn)
|
||||||
|
)
|
||||||
|
|
||||||
|
def encode(self, obj: any) -> any:
|
||||||
|
return getattr(obj, self._encode_fn)(obj)
|
||||||
|
|
||||||
|
def decode(self, cls: Type, raw: any) -> any:
|
||||||
|
return getattr(cls, self._decode_fn)(raw)
|
||||||
|
|
||||||
|
|
||||||
class MsgpackTypeCodec(TypeCodec):
|
'''
|
||||||
...
|
Default codec descriptions for wire types:
|
||||||
|
|
||||||
|
- bytes
|
||||||
|
- str
|
||||||
|
- int
|
||||||
|
|
||||||
|
'''
|
||||||
|
|
||||||
|
|
||||||
def mk_codec_hooks(
|
BytesCodec = TypeCodec(
|
||||||
type_codecs: list[TypeCodec],
|
decode_fn='from_bytes',
|
||||||
|
wire_type=bytes
|
||||||
|
)
|
||||||
|
|
||||||
) -> tuple[Callable, Callable]:
|
|
||||||
|
StrCodec = TypeCodec(
|
||||||
|
decode_fn='from_str',
|
||||||
|
wire_type=str
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
IntCodec = TypeCodec(
|
||||||
|
decode_fn='from_int',
|
||||||
|
wire_type=int
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
default_codecs: dict[Type, TypeCodec] = {
|
||||||
|
bytes: BytesCodec,
|
||||||
|
str: StrCodec,
|
||||||
|
int: IntCodec
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def mk_spec_set(
|
||||||
|
spec: TypeSpec
|
||||||
|
) -> set[Type]:
|
||||||
'''
|
'''
|
||||||
Deliver a `enc_hook()`/`dec_hook()` pair which handle
|
Given any of the different spec definitions, always return a `set[Type]`
|
||||||
manual convertion from an input `Type` set such that whenever
|
with each spec type as an item.
|
||||||
the `TypeCodec.filter()` predicate matches the
|
|
||||||
`TypeCodec.decode()` is called on the input native object by
|
- When passed list|tuple|set do nothing
|
||||||
the `dec_hook()` and whenever the
|
- When passed a single type we wrap it in tuple
|
||||||
`isiinstance(obj, TypeCodec.type)` matches against an
|
- When passed a Union we wrap its inner types in tuple
|
||||||
`enc_hook(obj=obj)` the return value is taken from a
|
|
||||||
`TypeCodec.encode(obj)` callback.
|
|
||||||
|
|
||||||
'''
|
'''
|
||||||
...
|
if not (
|
||||||
|
isinstance(spec, set)
|
||||||
|
or
|
||||||
|
isinstance(spec, list)
|
||||||
|
or
|
||||||
|
isinstance(spec, tuple)
|
||||||
|
):
|
||||||
|
spec_info = msgspec.inspect.type_info(spec)
|
||||||
|
match spec_info:
|
||||||
|
case UnionType():
|
||||||
|
return set((
|
||||||
|
t.cls
|
||||||
|
for t in spec_info.types
|
||||||
|
))
|
||||||
|
|
||||||
|
case (
|
||||||
|
SetType() |
|
||||||
|
ListType() |
|
||||||
|
TupleType()
|
||||||
|
):
|
||||||
|
return set((spec_info.item_type, ))
|
||||||
|
|
||||||
|
case _:
|
||||||
|
return set((spec, ))
|
||||||
|
|
||||||
|
return set(spec)
|
||||||
|
|
||||||
|
|
||||||
|
def mk_codec_map_from_spec(
|
||||||
|
spec: TypeSpec,
|
||||||
|
codecs: dict[Type, TypeCodec] = default_codecs
|
||||||
|
) -> dict[Type, TypeCodec]:
|
||||||
|
'''
|
||||||
|
Generate a map of spec type -> supported codec
|
||||||
|
|
||||||
|
'''
|
||||||
|
|
||||||
|
spec: set[Type] = mk_spec_set(spec)
|
||||||
|
|
||||||
|
spec_codecs: dict[Type, TypeCodec] = {}
|
||||||
|
for t in spec:
|
||||||
|
for codec in codecs.values():
|
||||||
|
if codec.is_type_compat(t):
|
||||||
|
spec_codecs[t] = codec
|
||||||
|
|
||||||
|
return spec_codecs
|
||||||
|
|
||||||
|
|
||||||
|
def mk_enc_hook(
|
||||||
|
spec: TypeSpec,
|
||||||
|
with_builtins: bool = True,
|
||||||
|
builtins: set[Type] = default_builtins,
|
||||||
|
codecs: dict[Type, TypeCodec] = default_codecs
|
||||||
|
) -> Callable:
|
||||||
|
'''
|
||||||
|
Given a type specification return a msgspec enc_hook fn
|
||||||
|
|
||||||
|
'''
|
||||||
|
spec_codecs = mk_codec_map_from_spec(spec)
|
||||||
|
|
||||||
|
def enc_hook(obj: any) -> any:
|
||||||
|
t = type(obj)
|
||||||
|
maybe_codec = spec_codecs.get(t, None)
|
||||||
|
if maybe_codec:
|
||||||
|
return maybe_codec.encode(obj)
|
||||||
|
|
||||||
|
# passthrough built ins
|
||||||
|
if builtins and t in builtins:
|
||||||
|
return obj
|
||||||
|
|
||||||
|
raise NotImplementedError(
|
||||||
|
f"Objects of type {type(obj)} are not supported:\n{obj}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return enc_hook
|
||||||
|
|
||||||
|
|
||||||
|
def mk_dec_hook(
|
||||||
|
spec: TypeSpec,
|
||||||
|
with_builtins: bool = True,
|
||||||
|
builtins: set[Type] = default_builtins,
|
||||||
|
codecs: dict[Type, TypeCodec] = default_codecs
|
||||||
|
) -> Callable:
|
||||||
|
'''
|
||||||
|
Given a type specification return a msgspec dec_hook fn
|
||||||
|
|
||||||
|
'''
|
||||||
|
spec_codecs = mk_codec_map_from_spec(spec)
|
||||||
|
|
||||||
|
def dec_hook(t: Type, obj: any) -> any:
|
||||||
|
maybe_codec = spec_codecs.get(t, None)
|
||||||
|
if maybe_codec:
|
||||||
|
return maybe_codec.decode(t, obj)
|
||||||
|
|
||||||
|
# passthrough builtins
|
||||||
|
if builtins and type in builtins:
|
||||||
|
return obj
|
||||||
|
|
||||||
|
raise NotImplementedError(
|
||||||
|
f"Objects of type {type} are not supported from {obj}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return dec_hook
|
||||||
|
|
||||||
|
|
||||||
|
def mk_codec_hooks(*args, **kwargs) -> tuple[Callable, Callable]:
|
||||||
|
'''
|
||||||
|
Given a type specification return a msgspec enc & dec hook fn pair
|
||||||
|
|
||||||
|
'''
|
||||||
|
return (
|
||||||
|
mk_enc_hook(*args, **kwargs),
|
||||||
|
|
||||||
|
mk_dec_hook(*args, **kwargs)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def mk_codec_from_spec(
|
||||||
|
spec: TypeSpec,
|
||||||
|
with_builtins: bool = True,
|
||||||
|
builtins: set[Type] = default_builtins,
|
||||||
|
codecs: dict[Type, TypeCodec] = default_codecs
|
||||||
|
) -> MsgCodec:
|
||||||
|
'''
|
||||||
|
Given a type specification return a MsgCodec
|
||||||
|
|
||||||
|
'''
|
||||||
|
spec: set[Type] = mk_spec_set(spec)
|
||||||
|
|
||||||
|
return mk_codec(
|
||||||
|
enc_hook=mk_enc_hook(
|
||||||
|
spec,
|
||||||
|
with_builtins=with_builtins,
|
||||||
|
builtins=builtins,
|
||||||
|
codecs=codecs
|
||||||
|
),
|
||||||
|
ext_types=spec
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def mk_msgpack_codec(
|
||||||
|
spec: TypeSpec,
|
||||||
|
with_builtins: bool = True,
|
||||||
|
builtins: set[Type] = default_builtins,
|
||||||
|
codecs: dict[Type, TypeCodec] = default_codecs
|
||||||
|
) -> tuple[msgpack.Encoder, msgpack.Decoder]:
|
||||||
|
'''
|
||||||
|
Get a msgpack Encoder, Decoder pair for a given type spec
|
||||||
|
|
||||||
|
'''
|
||||||
|
spec: set[Type] = mk_spec_set(spec)
|
||||||
|
|
||||||
|
enc_hook, dec_hook = mk_codec_hooks(
|
||||||
|
spec,
|
||||||
|
with_builtins=with_builtins,
|
||||||
|
builtins=builtins,
|
||||||
|
codecs=codecs
|
||||||
|
)
|
||||||
|
encoder = msgpack.Encoder(enc_hook=enc_hook)
|
||||||
|
decoder = msgpack.Decoder(spec, dec_hook=dec_hook)
|
||||||
|
return encoder, decoder
|
||||||
|
|
Loading…
Reference in New Issue