(Re)type annot some tests

- For the (still not finished) `test_caps_based_msging`, switch to
  using the new `PayloadMsg`.
- add `testdir` fixture type.
multihost_exs
Tyler Goodlet 2024-06-28 19:24:03 -04:00
parent edac717613
commit 18d440c207
2 changed files with 17 additions and 30 deletions

View File

@ -11,9 +11,6 @@ from typing import (
Type, Type,
Union, Union,
) )
from contextvars import (
Context,
)
from msgspec import ( from msgspec import (
structs, structs,
@ -27,6 +24,7 @@ import tractor
from tractor import ( from tractor import (
_state, _state,
MsgTypeError, MsgTypeError,
Context,
) )
from tractor.msg import ( from tractor.msg import (
_codec, _codec,
@ -41,7 +39,7 @@ from tractor.msg import (
from tractor.msg.types import ( from tractor.msg.types import (
_payload_msgs, _payload_msgs,
log, log,
Msg, PayloadMsg,
Started, Started,
mk_msg_spec, mk_msg_spec,
) )
@ -61,7 +59,7 @@ def mk_custom_codec(
uid: tuple[str, str] = tractor.current_actor().uid uid: tuple[str, str] = tractor.current_actor().uid
# XXX NOTE XXX: despite defining `NamespacePath` as a type # XXX NOTE XXX: despite defining `NamespacePath` as a type
# field on our `Msg.pld`, we still need a enc/dec_hook() pair # field on our `PayloadMsg.pld`, we still need a enc/dec_hook() pair
# to cast to/from that type on the wire. See the docs: # to cast to/from that type on the wire. See the docs:
# https://jcristharif.com/msgspec/extending.html#mapping-to-from-native-types # https://jcristharif.com/msgspec/extending.html#mapping-to-from-native-types
@ -321,12 +319,12 @@ def dec_type_union(
import importlib import importlib
types: list[Type] = [] types: list[Type] = []
for type_name in type_names: for type_name in type_names:
for ns in [ for mod in [
typing, typing,
importlib.import_module(__name__), importlib.import_module(__name__),
]: ]:
if type_ref := getattr( if type_ref := getattr(
ns, mod,
type_name, type_name,
False, False,
): ):
@ -744,7 +742,7 @@ def chk_pld_type(
# 'Error', .pld: ErrorData # 'Error', .pld: ErrorData
codec: MsgCodec = mk_codec( codec: MsgCodec = mk_codec(
# NOTE: this ONLY accepts `Msg.pld` fields of a specified # NOTE: this ONLY accepts `PayloadMsg.pld` fields of a specified
# type union. # type union.
ipc_pld_spec=payload_spec, ipc_pld_spec=payload_spec,
) )
@ -752,7 +750,7 @@ def chk_pld_type(
# make a one-off dec to compare with our `MsgCodec` instance # make a one-off dec to compare with our `MsgCodec` instance
# which does the below `mk_msg_spec()` call internally # which does the below `mk_msg_spec()` call internally
ipc_msg_spec: Union[Type[Struct]] ipc_msg_spec: Union[Type[Struct]]
msg_types: list[Msg[payload_spec]] msg_types: list[PayloadMsg[payload_spec]]
( (
ipc_msg_spec, ipc_msg_spec,
msg_types, msg_types,
@ -761,7 +759,7 @@ def chk_pld_type(
) )
_enc = msgpack.Encoder() _enc = msgpack.Encoder()
_dec = msgpack.Decoder( _dec = msgpack.Decoder(
type=ipc_msg_spec or Any, # like `Msg[Any]` type=ipc_msg_spec or Any, # like `PayloadMsg[Any]`
) )
assert ( assert (
@ -806,7 +804,7 @@ def chk_pld_type(
'cid': '666', 'cid': '666',
'pld': pld, 'pld': pld,
} }
enc_msg: Msg = typedef(**kwargs) enc_msg: PayloadMsg = typedef(**kwargs)
_wire_bytes: bytes = _enc.encode(enc_msg) _wire_bytes: bytes = _enc.encode(enc_msg)
wire_bytes: bytes = codec.enc.encode(enc_msg) wire_bytes: bytes = codec.enc.encode(enc_msg)
@ -883,25 +881,16 @@ def test_limit_msgspec():
debug_mode=True debug_mode=True
): ):
# ensure we can round-trip a boxing `Msg` # ensure we can round-trip a boxing `PayloadMsg`
assert chk_pld_type( assert chk_pld_type(
# Msg, payload_spec=Any,
Any, pld=None,
None,
expect_roundtrip=True, expect_roundtrip=True,
) )
# TODO: don't need this any more right since
# `msgspec>=0.15` has the nice generics stuff yah??
#
# manually override the type annot of the payload
# field and ensure it propagates to all msg-subtypes.
# Msg.__annotations__['pld'] = Any
# verify that a mis-typed payload value won't decode # verify that a mis-typed payload value won't decode
assert not chk_pld_type( assert not chk_pld_type(
# Msg, payload_spec=int,
int,
pld='doggy', pld='doggy',
) )
@ -913,18 +902,16 @@ def test_limit_msgspec():
value: Any value: Any
assert not chk_pld_type( assert not chk_pld_type(
# Msg, payload_spec=CustomPayload,
CustomPayload,
pld='doggy', pld='doggy',
) )
assert chk_pld_type( assert chk_pld_type(
# Msg, payload_spec=CustomPayload,
CustomPayload,
pld=CustomPayload(name='doggy', value='urmom') pld=CustomPayload(name='doggy', value='urmom')
) )
# uhh bc we can `.pause_from_sync()` now! :surfer: # yah, we can `.pause_from_sync()` now!
# breakpoint() # breakpoint()
trio.run(main) trio.run(main)

View File

@ -19,7 +19,7 @@ from tractor._testing import (
@pytest.fixture @pytest.fixture
def run_example_in_subproc( def run_example_in_subproc(
loglevel: str, loglevel: str,
testdir, testdir: pytest.Testdir,
reg_addr: tuple[str, int], reg_addr: tuple[str, int],
): ):