Compare commits
	
		
			32 Commits 
		
	
	
		
			e94fd65e91
			...
			5cee222353
		
	
	| Author | SHA1 | Date | 
|---|---|---|
|  | 5cee222353 | |
|  | 8ebb1f09de | |
|  | 2683a7f33a | |
|  | 255209f881 | |
|  | 9a0d529b18 | |
|  | 1c441b0986 | |
|  | afbdb50a30 | |
|  | e46033cbe7 | |
|  | c932bb5911 | |
|  | 33482d8f41 | |
|  | 7ae194baed | |
|  | ef7ca49e9b | |
|  | fde681fa19 | |
|  | efcf81bcad | |
|  | 3988ea69f5 | |
|  | 8bd4490cad | |
|  | 622f840dfd | |
|  | 8ba315e60c | |
|  | 80f20b35b1 | |
|  | 9ec37dd13f | |
|  | 9be76b1dda | |
|  | 31f88b59f4 | |
|  | 155d581fa2 | |
|  | a810f6c8f6 | |
|  | 83b9dc3c62 | |
|  | f152a20025 | |
|  | 1ea8254ae3 | |
|  | 8ed890f892 | |
|  | d4e6f2b8dc | |
|  | 64fe767647 | |
|  | aca015f1c2 | |
|  | 818cd8535f | 
|  | @ -1,917 +0,0 @@ | |||
| ''' | ||||
| Low-level functional audits for our | ||||
| "capability based messaging"-spec feats. | ||||
| 
 | ||||
| B~) | ||||
| 
 | ||||
| ''' | ||||
| import typing | ||||
| from typing import ( | ||||
|     Any, | ||||
|     Type, | ||||
|     Union, | ||||
| ) | ||||
| 
 | ||||
| from msgspec import ( | ||||
|     structs, | ||||
|     msgpack, | ||||
|     Struct, | ||||
|     ValidationError, | ||||
| ) | ||||
| import pytest | ||||
| 
 | ||||
| import tractor | ||||
| from tractor import ( | ||||
|     _state, | ||||
|     MsgTypeError, | ||||
|     Context, | ||||
| ) | ||||
| from tractor.msg import ( | ||||
|     _codec, | ||||
|     _ctxvar_MsgCodec, | ||||
| 
 | ||||
|     NamespacePath, | ||||
|     MsgCodec, | ||||
|     mk_codec, | ||||
|     apply_codec, | ||||
|     current_codec, | ||||
| ) | ||||
| from tractor.msg.types import ( | ||||
|     _payload_msgs, | ||||
|     log, | ||||
|     PayloadMsg, | ||||
|     Started, | ||||
|     mk_msg_spec, | ||||
| ) | ||||
| import trio | ||||
| 
 | ||||
| 
 | ||||
| def mk_custom_codec( | ||||
|     pld_spec: Union[Type]|Any, | ||||
|     add_hooks: bool, | ||||
| 
 | ||||
| ) -> MsgCodec: | ||||
|     ''' | ||||
|     Create custom `msgpack` enc/dec-hooks and set a `Decoder` | ||||
|     which only loads `pld_spec` (like `NamespacePath`) types. | ||||
| 
 | ||||
|     ''' | ||||
|     uid: tuple[str, str] = tractor.current_actor().uid | ||||
| 
 | ||||
|     # XXX NOTE XXX: despite defining `NamespacePath` as a type | ||||
|     # field on our `PayloadMsg.pld`, we still need a enc/dec_hook() pair | ||||
|     # to cast to/from that type on the wire. See the docs: | ||||
|     # https://jcristharif.com/msgspec/extending.html#mapping-to-from-native-types | ||||
| 
 | ||||
|     def enc_nsp(obj: Any) -> Any: | ||||
|         print(f'{uid} ENC HOOK') | ||||
|         match obj: | ||||
|             case NamespacePath(): | ||||
|                 print( | ||||
|                     f'{uid}: `NamespacePath`-Only ENCODE?\n' | ||||
|                     f'obj-> `{obj}`: {type(obj)}\n' | ||||
|                 ) | ||||
|                 # if type(obj) != NamespacePath: | ||||
|                 #     breakpoint() | ||||
|                 return str(obj) | ||||
| 
 | ||||
|         print( | ||||
|             f'{uid}\n' | ||||
|             'CUSTOM ENCODE\n' | ||||
|             f'obj-arg-> `{obj}`: {type(obj)}\n' | ||||
|         ) | ||||
|         logmsg: str = ( | ||||
|             f'{uid}\n' | ||||
|             'FAILED ENCODE\n' | ||||
|             f'obj-> `{obj}: {type(obj)}`\n' | ||||
|         ) | ||||
|         raise NotImplementedError(logmsg) | ||||
| 
 | ||||
|     def dec_nsp( | ||||
|         obj_type: Type, | ||||
|         obj: Any, | ||||
| 
 | ||||
|     ) -> Any: | ||||
|         print( | ||||
|             f'{uid}\n' | ||||
|             'CUSTOM DECODE\n' | ||||
|             f'type-arg-> {obj_type}\n' | ||||
|             f'obj-arg-> `{obj}`: {type(obj)}\n' | ||||
|         ) | ||||
|         nsp = None | ||||
| 
 | ||||
|         if ( | ||||
|             obj_type is NamespacePath | ||||
|             and isinstance(obj, str) | ||||
|             and ':' in obj | ||||
|         ): | ||||
|             nsp = NamespacePath(obj) | ||||
|             # TODO: we could built a generic handler using | ||||
|             # JUST matching the obj_type part? | ||||
|             # nsp = obj_type(obj) | ||||
| 
 | ||||
|         if nsp: | ||||
|             print(f'Returning NSP instance: {nsp}') | ||||
|             return nsp | ||||
| 
 | ||||
|         logmsg: str = ( | ||||
|             f'{uid}\n' | ||||
|             'FAILED DECODE\n' | ||||
|             f'type-> {obj_type}\n' | ||||
|             f'obj-arg-> `{obj}`: {type(obj)}\n\n' | ||||
|             f'current codec:\n' | ||||
|             f'{current_codec()}\n' | ||||
|         ) | ||||
|         # TODO: figure out the ignore subsys for this! | ||||
|         # -[ ] option whether to defense-relay backc the msg | ||||
|         #   inside an `Invalid`/`Ignore` | ||||
|         # -[ ] how to make this handling pluggable such that a | ||||
|         #   `Channel`/`MsgTransport` can intercept and process | ||||
|         #   back msgs either via exception handling or some other | ||||
|         #   signal? | ||||
|         log.warning(logmsg) | ||||
|         # NOTE: this delivers the invalid | ||||
|         # value up to `msgspec`'s decoding | ||||
|         # machinery for error raising. | ||||
|         return obj | ||||
|         # raise NotImplementedError(logmsg) | ||||
| 
 | ||||
|     nsp_codec: MsgCodec = mk_codec( | ||||
|         ipc_pld_spec=pld_spec, | ||||
| 
 | ||||
|         # NOTE XXX: the encode hook MUST be used no matter what since | ||||
|         # our `NamespacePath` is not any of a `Any` native type nor | ||||
|         # a `msgspec.Struct` subtype - so `msgspec` has no way to know | ||||
|         # how to encode it unless we provide the custom hook. | ||||
|         # | ||||
|         # AGAIN that is, regardless of whether we spec an | ||||
|         # `Any`-decoded-pld the enc has no knowledge (by default) | ||||
|         # how to enc `NamespacePath` (nsp), so we add a custom | ||||
|         # hook to do that ALWAYS. | ||||
|         enc_hook=enc_nsp if add_hooks else None, | ||||
| 
 | ||||
|         # XXX NOTE: pretty sure this is mutex with the `type=` to | ||||
|         # `Decoder`? so it won't work in tandem with the | ||||
|         # `ipc_pld_spec` passed above? | ||||
|         dec_hook=dec_nsp if add_hooks else None, | ||||
|     ) | ||||
|     return nsp_codec | ||||
| 
 | ||||
| 
 | ||||
| def chk_codec_applied( | ||||
|     expect_codec: MsgCodec, | ||||
|     enter_value: MsgCodec|None = None, | ||||
| 
 | ||||
| ) -> MsgCodec: | ||||
|     ''' | ||||
|     buncha sanity checks ensuring that the IPC channel's | ||||
|     context-vars are set to the expected codec and that are | ||||
|     ctx-var wrapper APIs match the same. | ||||
| 
 | ||||
|     ''' | ||||
|     # TODO: play with tricyle again, bc this is supposed to work | ||||
|     # the way we want? | ||||
|     # | ||||
|     # TreeVar | ||||
|     # task: trio.Task = trio.lowlevel.current_task() | ||||
|     # curr_codec = _ctxvar_MsgCodec.get_in(task) | ||||
| 
 | ||||
|     # ContextVar | ||||
|     # task_ctx: Context = task.context | ||||
|     # assert _ctxvar_MsgCodec in task_ctx | ||||
|     # curr_codec: MsgCodec = task.context[_ctxvar_MsgCodec] | ||||
| 
 | ||||
|     # NOTE: currently we use this! | ||||
|     # RunVar | ||||
|     curr_codec: MsgCodec = current_codec() | ||||
|     last_read_codec = _ctxvar_MsgCodec.get() | ||||
|     # assert curr_codec is last_read_codec | ||||
| 
 | ||||
|     assert ( | ||||
|         (same_codec := expect_codec) is | ||||
|         # returned from `mk_codec()` | ||||
| 
 | ||||
|         # yielded value from `apply_codec()` | ||||
| 
 | ||||
|         # read from current task's `contextvars.Context` | ||||
|         curr_codec is | ||||
|         last_read_codec | ||||
| 
 | ||||
|         # the default `msgspec` settings | ||||
|         is not _codec._def_msgspec_codec | ||||
|         is not _codec._def_tractor_codec | ||||
|     ) | ||||
| 
 | ||||
|     if enter_value: | ||||
|         enter_value is same_codec | ||||
| 
 | ||||
| 
 | ||||
| def iter_maybe_sends( | ||||
|     send_items: dict[Union[Type], Any] | list[tuple], | ||||
|     ipc_pld_spec: Union[Type] | Any, | ||||
|     add_codec_hooks: bool, | ||||
| 
 | ||||
|     codec: MsgCodec|None = None, | ||||
| 
 | ||||
| ) -> tuple[Any, bool]: | ||||
| 
 | ||||
|     if isinstance(send_items, dict): | ||||
|         send_items = send_items.items() | ||||
| 
 | ||||
|     for ( | ||||
|         send_type_spec, | ||||
|         send_value, | ||||
|     ) in send_items: | ||||
| 
 | ||||
|         expect_roundtrip: bool = False | ||||
| 
 | ||||
|         # values-to-typespec santiy | ||||
|         send_type = type(send_value) | ||||
|         assert send_type == send_type_spec or ( | ||||
|             (subtypes := getattr(send_type_spec, '__args__', None)) | ||||
|             and send_type in subtypes | ||||
|         ) | ||||
| 
 | ||||
|         spec_subtypes: set[Union[Type]] = ( | ||||
|              getattr( | ||||
|                  ipc_pld_spec, | ||||
|                  '__args__', | ||||
|                  {ipc_pld_spec,}, | ||||
|              ) | ||||
|         ) | ||||
|         send_in_spec: bool = ( | ||||
|             send_type == ipc_pld_spec | ||||
|             or ( | ||||
|                 ipc_pld_spec != Any | ||||
|                 and  # presume `Union` of types | ||||
|                 send_type in spec_subtypes | ||||
|             ) | ||||
|             or ( | ||||
|                 ipc_pld_spec == Any | ||||
|                 and | ||||
|                 send_type != NamespacePath | ||||
|             ) | ||||
|         ) | ||||
|         expect_roundtrip = ( | ||||
|             send_in_spec | ||||
|             # any spec should support all other | ||||
|             # builtin py values that we send | ||||
|             # except our custom nsp type which | ||||
|             # we should be able to send as long | ||||
|             # as we provide the custom codec hooks. | ||||
|             or ( | ||||
|                 ipc_pld_spec == Any | ||||
|                 and | ||||
|                 send_type == NamespacePath | ||||
|                 and | ||||
|                 add_codec_hooks | ||||
|             ) | ||||
|         ) | ||||
| 
 | ||||
|         if codec is not None: | ||||
|             # XXX FIRST XXX ensure roundtripping works | ||||
|             # before touching any IPC primitives/APIs. | ||||
|             wire_bytes: bytes = codec.encode( | ||||
|                 Started( | ||||
|                     cid='blahblah', | ||||
|                     pld=send_value, | ||||
|                 ) | ||||
|             ) | ||||
|             # NOTE: demonstrates the decoder loading | ||||
|             # to via our native SCIPP msg-spec | ||||
|             # (structurred-conc-inter-proc-protocol) | ||||
|             # implemented as per, | ||||
|             try: | ||||
|                 msg: Started = codec.decode(wire_bytes) | ||||
|                 if not expect_roundtrip: | ||||
|                     pytest.fail( | ||||
|                         f'NOT-EXPECTED able to roundtrip value given spec:\n' | ||||
|                         f'ipc_pld_spec -> {ipc_pld_spec}\n' | ||||
|                         f'value -> {send_value}: {send_type}\n' | ||||
|                     ) | ||||
| 
 | ||||
|                 pld = msg.pld | ||||
|                 assert pld == send_value | ||||
| 
 | ||||
|             except ValidationError: | ||||
|                 if expect_roundtrip: | ||||
|                     pytest.fail( | ||||
|                         f'EXPECTED to roundtrip value given spec:\n' | ||||
|                         f'ipc_pld_spec -> {ipc_pld_spec}\n' | ||||
|                         f'value -> {send_value}: {send_type}\n' | ||||
|                     ) | ||||
| 
 | ||||
|         yield ( | ||||
|             str(send_type), | ||||
|             send_value, | ||||
|             expect_roundtrip, | ||||
|         ) | ||||
| 
 | ||||
| 
 | ||||
| def dec_type_union( | ||||
|     type_names: list[str], | ||||
| ) -> Type: | ||||
|     ''' | ||||
|     Look up types by name, compile into a list and then create and | ||||
|     return a `typing.Union` from the full set. | ||||
| 
 | ||||
|     ''' | ||||
|     import importlib | ||||
|     types: list[Type] = [] | ||||
|     for type_name in type_names: | ||||
|         for mod in [ | ||||
|             typing, | ||||
|             importlib.import_module(__name__), | ||||
|         ]: | ||||
|             if type_ref := getattr( | ||||
|                 mod, | ||||
|                 type_name, | ||||
|                 False, | ||||
|             ): | ||||
|                 types.append(type_ref) | ||||
| 
 | ||||
|     # special case handling only.. | ||||
|     # ipc_pld_spec: Union[Type] = eval( | ||||
|     #     pld_spec_str, | ||||
|     #     {},  # globals | ||||
|     #     {'typing': typing},  # locals | ||||
|     # ) | ||||
| 
 | ||||
|     return Union[*types] | ||||
| 
 | ||||
| 
 | ||||
| def enc_type_union( | ||||
|     union_or_type: Union[Type]|Type, | ||||
| ) -> list[str]: | ||||
|     ''' | ||||
|     Encode a type-union or single type to a list of type-name-strings | ||||
|     ready for IPC interchange. | ||||
| 
 | ||||
|     ''' | ||||
|     type_strs: list[str] = [] | ||||
|     for typ in getattr( | ||||
|         union_or_type, | ||||
|         '__args__', | ||||
|         {union_or_type,}, | ||||
|     ): | ||||
|         type_strs.append(typ.__qualname__) | ||||
| 
 | ||||
|     return type_strs | ||||
| 
 | ||||
| 
 | ||||
| @tractor.context | ||||
| async def send_back_values( | ||||
|     ctx: Context, | ||||
|     expect_debug: bool, | ||||
|     pld_spec_type_strs: list[str], | ||||
|     add_hooks: bool, | ||||
|     started_msg_bytes: bytes, | ||||
|     expect_ipc_send: dict[str, tuple[Any, bool]], | ||||
| 
 | ||||
| ) -> None: | ||||
|     ''' | ||||
|     Setup up a custom codec to load instances of `NamespacePath` | ||||
|     and ensure we can round trip a func ref with our parent. | ||||
| 
 | ||||
|     ''' | ||||
|     uid: tuple = tractor.current_actor().uid | ||||
| 
 | ||||
|     # debug mode sanity check (prolly superfluous but, meh) | ||||
|     assert expect_debug == _state.debug_mode() | ||||
| 
 | ||||
|     # init state in sub-actor should be default | ||||
|     chk_codec_applied( | ||||
|         expect_codec=_codec._def_tractor_codec, | ||||
|     ) | ||||
| 
 | ||||
|     # load pld spec from input str | ||||
|     ipc_pld_spec = dec_type_union( | ||||
|         pld_spec_type_strs, | ||||
|     ) | ||||
|     pld_spec_str = str(ipc_pld_spec) | ||||
| 
 | ||||
|     # same as on parent side config. | ||||
|     nsp_codec: MsgCodec = mk_custom_codec( | ||||
|         pld_spec=ipc_pld_spec, | ||||
|         add_hooks=add_hooks, | ||||
|     ) | ||||
|     with ( | ||||
|         apply_codec(nsp_codec) as codec, | ||||
|     ): | ||||
|         chk_codec_applied( | ||||
|             expect_codec=nsp_codec, | ||||
|             enter_value=codec, | ||||
|         ) | ||||
| 
 | ||||
|         print( | ||||
|             f'{uid}: attempting `Started`-bytes DECODE..\n' | ||||
|         ) | ||||
|         try: | ||||
|             msg: Started = nsp_codec.decode(started_msg_bytes) | ||||
|             expected_pld_spec_str: str = msg.pld | ||||
|             assert pld_spec_str == expected_pld_spec_str | ||||
| 
 | ||||
|         # TODO: maybe we should add our own wrapper error so as to | ||||
|         # be interchange-lib agnostic? | ||||
|         # -[ ] the error type is wtv is raised from the hook so we | ||||
|         #   could also require a type-class of errors for | ||||
|         #   indicating whether the hook-failure can be handled by | ||||
|         #   a nasty-dialog-unprot sub-sys? | ||||
|         except ValidationError: | ||||
| 
 | ||||
|             # NOTE: only in the `Any` spec case do we expect this to | ||||
|             # work since otherwise no spec covers a plain-ol' | ||||
|             # `.pld: str` | ||||
|             if pld_spec_str == 'Any': | ||||
|                 raise | ||||
|             else: | ||||
|                 print( | ||||
|                     f'{uid}: (correctly) unable to DECODE `Started`-bytes\n' | ||||
|                     f'{started_msg_bytes}\n' | ||||
|                 ) | ||||
| 
 | ||||
|         iter_send_val_items = iter(expect_ipc_send.values()) | ||||
|         sent: list[Any] = [] | ||||
|         for send_value, expect_send in iter_send_val_items: | ||||
|             try: | ||||
|                 print( | ||||
|                     f'{uid}: attempting to `.started({send_value})`\n' | ||||
|                     f'=> expect_send: {expect_send}\n' | ||||
|                     f'SINCE, ipc_pld_spec: {ipc_pld_spec}\n' | ||||
|                     f'AND, codec: {codec}\n' | ||||
|                 ) | ||||
|                 await ctx.started(send_value) | ||||
|                 sent.append(send_value) | ||||
|                 if not expect_send: | ||||
| 
 | ||||
|                     # XXX NOTE XXX THIS WON'T WORK WITHOUT SPECIAL | ||||
|                     # `str` handling! or special debug mode IPC | ||||
|                     # msgs! | ||||
|                     await tractor.pause() | ||||
| 
 | ||||
|                     raise RuntimeError( | ||||
|                         f'NOT-EXPECTED able to roundtrip value given spec:\n' | ||||
|                         f'ipc_pld_spec -> {ipc_pld_spec}\n' | ||||
|                         f'value -> {send_value}: {type(send_value)}\n' | ||||
|                     ) | ||||
| 
 | ||||
|                 break  # move on to streaming block.. | ||||
| 
 | ||||
|             except tractor.MsgTypeError: | ||||
|                 await tractor.pause() | ||||
| 
 | ||||
|                 if expect_send: | ||||
|                     raise RuntimeError( | ||||
|                         f'EXPECTED to `.started()` value given spec:\n' | ||||
|                         f'ipc_pld_spec -> {ipc_pld_spec}\n' | ||||
|                         f'value -> {send_value}: {type(send_value)}\n' | ||||
|                     ) | ||||
| 
 | ||||
|         async with ctx.open_stream() as ipc: | ||||
|             print( | ||||
|                 f'{uid}: Entering streaming block to send remaining values..' | ||||
|             ) | ||||
| 
 | ||||
|             for send_value, expect_send in iter_send_val_items: | ||||
|                 send_type: Type = type(send_value) | ||||
|                 print( | ||||
|                     '------ - ------\n' | ||||
|                     f'{uid}: SENDING NEXT VALUE\n' | ||||
|                     f'ipc_pld_spec: {ipc_pld_spec}\n' | ||||
|                     f'expect_send: {expect_send}\n' | ||||
|                     f'val: {send_value}\n' | ||||
|                     '------ - ------\n' | ||||
|                 ) | ||||
|                 try: | ||||
|                     await ipc.send(send_value) | ||||
|                     print(f'***\n{uid}-CHILD sent {send_value!r}\n***\n') | ||||
|                     sent.append(send_value) | ||||
| 
 | ||||
|                     # NOTE: should only raise above on | ||||
|                     # `.started()` or a `Return` | ||||
|                     # if not expect_send: | ||||
|                     #     raise RuntimeError( | ||||
|                     #         f'NOT-EXPECTED able to roundtrip value given spec:\n' | ||||
|                     #         f'ipc_pld_spec -> {ipc_pld_spec}\n' | ||||
|                     #         f'value -> {send_value}: {send_type}\n' | ||||
|                     #     ) | ||||
| 
 | ||||
|                 except ValidationError: | ||||
|                     print(f'{uid} FAILED TO SEND {send_value}!') | ||||
| 
 | ||||
|                     # await tractor.pause() | ||||
|                     if expect_send: | ||||
|                         raise RuntimeError( | ||||
|                             f'EXPECTED to roundtrip value given spec:\n' | ||||
|                             f'ipc_pld_spec -> {ipc_pld_spec}\n' | ||||
|                             f'value -> {send_value}: {send_type}\n' | ||||
|                         ) | ||||
|                     # continue | ||||
| 
 | ||||
|             else: | ||||
|                 print( | ||||
|                     f'{uid}: finished sending all values\n' | ||||
|                     'Should be exiting stream block!\n' | ||||
|                 ) | ||||
| 
 | ||||
|         print(f'{uid}: exited streaming block!') | ||||
| 
 | ||||
|         # TODO: this won't be true bc in streaming phase we DO NOT | ||||
|         # msgspec check outbound msgs! | ||||
|         # -[ ] once we implement the receiver side `InvalidMsg` | ||||
|         #   then we can expect it here? | ||||
|         # assert ( | ||||
|         #     len(sent) | ||||
|         #     == | ||||
|         #     len([val | ||||
|         #          for val, expect in | ||||
|         #          expect_ipc_send.values() | ||||
|         #          if expect is True]) | ||||
|         # ) | ||||
| 
 | ||||
| 
 | ||||
| def ex_func(*args): | ||||
|     print(f'ex_func({args})') | ||||
| 
 | ||||
| 
 | ||||
| @pytest.mark.parametrize( | ||||
|     'ipc_pld_spec', | ||||
|     [ | ||||
|         Any, | ||||
|         NamespacePath, | ||||
|         NamespacePath|None,  # the "maybe" spec Bo | ||||
|     ], | ||||
|     ids=[ | ||||
|         'any_type', | ||||
|         'nsp_type', | ||||
|         'maybe_nsp_type', | ||||
|     ] | ||||
| ) | ||||
| @pytest.mark.parametrize( | ||||
|     'add_codec_hooks', | ||||
|     [ | ||||
|         True, | ||||
|         False, | ||||
|     ], | ||||
|     ids=['use_codec_hooks', 'no_codec_hooks'], | ||||
| ) | ||||
| def test_codec_hooks_mod( | ||||
|     debug_mode: bool, | ||||
|     ipc_pld_spec: Union[Type]|Any, | ||||
|     # send_value: None|str|NamespacePath, | ||||
|     add_codec_hooks: bool, | ||||
| ): | ||||
|     ''' | ||||
|     Audit the `.msg.MsgCodec` override apis details given our impl | ||||
|     uses `contextvars` to accomplish per `trio` task codec | ||||
|     application around an inter-proc-task-comms context. | ||||
| 
 | ||||
|     ''' | ||||
|     async def main(): | ||||
|         nsp = NamespacePath.from_ref(ex_func) | ||||
|         send_items: dict[Union, Any] = { | ||||
|             Union[None]: None, | ||||
|             Union[NamespacePath]: nsp, | ||||
|             Union[str]: str(nsp), | ||||
|         } | ||||
| 
 | ||||
|         # init default state for actor | ||||
|         chk_codec_applied( | ||||
|             expect_codec=_codec._def_tractor_codec, | ||||
|         ) | ||||
| 
 | ||||
|         async with tractor.open_nursery( | ||||
|             debug_mode=debug_mode, | ||||
|         ) as an: | ||||
|             p: tractor.Portal = await an.start_actor( | ||||
|                 'sub', | ||||
|                 enable_modules=[__name__], | ||||
|             ) | ||||
| 
 | ||||
|             # TODO: 2 cases: | ||||
|             # - codec not modified -> decode nsp as `str` | ||||
|             # - codec modified with hooks -> decode nsp as | ||||
|             #   `NamespacePath` | ||||
|             nsp_codec: MsgCodec = mk_custom_codec( | ||||
|                 pld_spec=ipc_pld_spec, | ||||
|                 add_hooks=add_codec_hooks, | ||||
|             ) | ||||
|             with apply_codec(nsp_codec) as codec: | ||||
|                 chk_codec_applied( | ||||
|                     expect_codec=nsp_codec, | ||||
|                     enter_value=codec, | ||||
|                 ) | ||||
| 
 | ||||
|                 expect_ipc_send: dict[str, tuple[Any, bool]] = {} | ||||
| 
 | ||||
|                 report: str = ( | ||||
|                     'Parent report on send values with\n' | ||||
|                     f'ipc_pld_spec: {ipc_pld_spec}\n' | ||||
|                     '       ------ - ------\n' | ||||
|                 ) | ||||
|                 for val_type_str, val, expect_send in iter_maybe_sends( | ||||
|                     send_items, | ||||
|                     ipc_pld_spec, | ||||
|                     add_codec_hooks=add_codec_hooks, | ||||
|                 ): | ||||
|                     report += ( | ||||
|                         f'send_value: {val}: {type(val)} ' | ||||
|                         f'=> expect_send: {expect_send}\n' | ||||
|                     ) | ||||
|                     expect_ipc_send[val_type_str] = (val, expect_send) | ||||
| 
 | ||||
|                 print( | ||||
|                     report + | ||||
|                     '       ------ - ------\n' | ||||
|                 ) | ||||
|                 assert len(expect_ipc_send) == len(send_items) | ||||
|                 # now try over real IPC with a the subactor | ||||
|                 # expect_ipc_rountrip: bool = True | ||||
|                 expected_started = Started( | ||||
|                     cid='cid', | ||||
|                     pld=str(ipc_pld_spec), | ||||
|                 ) | ||||
|                 # build list of values we expect to receive from | ||||
|                 # the subactor. | ||||
|                 expect_to_send: list[Any] = [ | ||||
|                     val | ||||
|                     for val, expect_send in expect_ipc_send.values() | ||||
|                     if expect_send | ||||
|                 ] | ||||
| 
 | ||||
|                 pld_spec_type_strs: list[str] = enc_type_union(ipc_pld_spec) | ||||
| 
 | ||||
|                 # XXX should raise an mte (`MsgTypeError`) | ||||
|                 # when `add_codec_hooks == False` bc the input | ||||
|                 # `expect_ipc_send` kwarg has a nsp which can't be | ||||
|                 # serialized! | ||||
|                 # | ||||
|                 # TODO:can we ensure this happens from the | ||||
|                 # `Return`-side (aka the sub) as well? | ||||
|                 if not add_codec_hooks: | ||||
|                     try: | ||||
|                         async with p.open_context( | ||||
|                             send_back_values, | ||||
|                             expect_debug=debug_mode, | ||||
|                             pld_spec_type_strs=pld_spec_type_strs, | ||||
|                             add_hooks=add_codec_hooks, | ||||
|                             started_msg_bytes=nsp_codec.encode(expected_started), | ||||
| 
 | ||||
|                             # XXX NOTE bc we send a `NamespacePath` in this kwarg | ||||
|                             expect_ipc_send=expect_ipc_send, | ||||
| 
 | ||||
|                         ) as (ctx, first): | ||||
|                             pytest.fail('ctx should fail to open without custom enc_hook!?') | ||||
| 
 | ||||
|                     # this test passes bc we can go no further! | ||||
|                     except MsgTypeError: | ||||
|                         # teardown nursery | ||||
|                         await p.cancel_actor() | ||||
|                         return | ||||
| 
 | ||||
|                 # TODO: send the original nsp here and | ||||
|                 # test with `limit_msg_spec()` above? | ||||
|                 # await tractor.pause() | ||||
|                 print('PARENT opening IPC ctx!\n') | ||||
|                 async with ( | ||||
| 
 | ||||
|                     # XXX should raise an mte (`MsgTypeError`) | ||||
|                     # when `add_codec_hooks == False`.. | ||||
|                     p.open_context( | ||||
|                         send_back_values, | ||||
|                         expect_debug=debug_mode, | ||||
|                         pld_spec_type_strs=pld_spec_type_strs, | ||||
|                         add_hooks=add_codec_hooks, | ||||
|                         started_msg_bytes=nsp_codec.encode(expected_started), | ||||
|                         expect_ipc_send=expect_ipc_send, | ||||
|                     ) as (ctx, first), | ||||
| 
 | ||||
|                     ctx.open_stream() as ipc, | ||||
|                 ): | ||||
|                     # ensure codec is still applied across | ||||
|                     # `tractor.Context` + its embedded nursery. | ||||
|                     chk_codec_applied( | ||||
|                         expect_codec=nsp_codec, | ||||
|                         enter_value=codec, | ||||
|                     ) | ||||
|                     print( | ||||
|                         'root: ENTERING CONTEXT BLOCK\n' | ||||
|                         f'type(first): {type(first)}\n' | ||||
|                         f'first: {first}\n' | ||||
|                     ) | ||||
|                     expect_to_send.remove(first) | ||||
| 
 | ||||
|                     # TODO: explicit values we expect depending on | ||||
|                     # codec config! | ||||
|                     # assert first == first_val | ||||
|                     # assert first == f'{__name__}:ex_func' | ||||
| 
 | ||||
|                     async for next_sent in ipc: | ||||
|                         print( | ||||
|                             'Parent: child sent next value\n' | ||||
|                             f'{next_sent}: {type(next_sent)}\n' | ||||
|                         ) | ||||
|                         if expect_to_send: | ||||
|                             expect_to_send.remove(next_sent) | ||||
|                         else: | ||||
|                             print('PARENT should terminate stream loop + block!') | ||||
| 
 | ||||
|                     # all sent values should have arrived! | ||||
|                     assert not expect_to_send | ||||
| 
 | ||||
|             await p.cancel_actor() | ||||
| 
 | ||||
|     trio.run(main) | ||||
| 
 | ||||
| 
 | ||||
| def chk_pld_type( | ||||
|     payload_spec: Type[Struct]|Any, | ||||
|     pld: Any, | ||||
| 
 | ||||
|     expect_roundtrip: bool|None = None, | ||||
| 
 | ||||
| ) -> bool: | ||||
| 
 | ||||
|     pld_val_type: Type = type(pld) | ||||
| 
 | ||||
|     # TODO: verify that the overridden subtypes | ||||
|     # DO NOT have modified type-annots from original! | ||||
|     # 'Start',  .pld: FuncSpec | ||||
|     # 'StartAck',  .pld: IpcCtxSpec | ||||
|     # 'Stop',  .pld: UNSEt | ||||
|     # 'Error',  .pld: ErrorData | ||||
| 
 | ||||
|     codec: MsgCodec = mk_codec( | ||||
|         # NOTE: this ONLY accepts `PayloadMsg.pld` fields of a specified | ||||
|         # type union. | ||||
|         ipc_pld_spec=payload_spec, | ||||
|     ) | ||||
| 
 | ||||
|     # make a one-off dec to compare with our `MsgCodec` instance | ||||
|     # which does the below `mk_msg_spec()` call internally | ||||
|     ipc_msg_spec: Union[Type[Struct]] | ||||
|     msg_types: list[PayloadMsg[payload_spec]] | ||||
|     ( | ||||
|         ipc_msg_spec, | ||||
|         msg_types, | ||||
|     ) = mk_msg_spec( | ||||
|         payload_type_union=payload_spec, | ||||
|     ) | ||||
|     _enc = msgpack.Encoder() | ||||
|     _dec = msgpack.Decoder( | ||||
|         type=ipc_msg_spec or Any,  # like `PayloadMsg[Any]` | ||||
|     ) | ||||
| 
 | ||||
|     assert ( | ||||
|         payload_spec | ||||
|         == | ||||
|         codec.pld_spec | ||||
|     ) | ||||
| 
 | ||||
|     # assert codec.dec == dec | ||||
|     # | ||||
|     # ^-XXX-^ not sure why these aren't "equal" but when cast | ||||
|     # to `str` they seem to match ?? .. kk | ||||
| 
 | ||||
|     assert ( | ||||
|         str(ipc_msg_spec) | ||||
|         == | ||||
|         str(codec.msg_spec) | ||||
|         == | ||||
|         str(_dec.type) | ||||
|         == | ||||
|         str(codec.dec.type) | ||||
|     ) | ||||
| 
 | ||||
|     # verify the boxed-type for all variable payload-type msgs. | ||||
|     if not msg_types: | ||||
|         breakpoint() | ||||
| 
 | ||||
|     roundtrip: bool|None = None | ||||
|     pld_spec_msg_names: list[str] = [ | ||||
|         td.__name__ for td in _payload_msgs | ||||
|     ] | ||||
|     for typedef in msg_types: | ||||
| 
 | ||||
|         skip_runtime_msg: bool = typedef.__name__ not in pld_spec_msg_names | ||||
|         if skip_runtime_msg: | ||||
|             continue | ||||
| 
 | ||||
|         pld_field = structs.fields(typedef)[1] | ||||
|         assert pld_field.type is payload_spec # TODO-^ does this need to work to get all subtypes to adhere? | ||||
| 
 | ||||
|         kwargs: dict[str, Any] = { | ||||
|             'cid': '666', | ||||
|             'pld': pld, | ||||
|         } | ||||
|         enc_msg: PayloadMsg = typedef(**kwargs) | ||||
| 
 | ||||
|         _wire_bytes: bytes = _enc.encode(enc_msg) | ||||
|         wire_bytes: bytes = codec.enc.encode(enc_msg) | ||||
|         assert _wire_bytes == wire_bytes | ||||
| 
 | ||||
|         ve: ValidationError|None = None | ||||
|         try: | ||||
|             dec_msg = codec.dec.decode(wire_bytes) | ||||
|             _dec_msg = _dec.decode(wire_bytes) | ||||
| 
 | ||||
|             # decoded msg and thus payload should be exactly same! | ||||
|             assert (roundtrip := ( | ||||
|                 _dec_msg | ||||
|                 == | ||||
|                 dec_msg | ||||
|                 == | ||||
|                 enc_msg | ||||
|             )) | ||||
| 
 | ||||
|             if ( | ||||
|                 expect_roundtrip is not None | ||||
|                 and expect_roundtrip != roundtrip | ||||
|             ): | ||||
|                 breakpoint() | ||||
| 
 | ||||
|             assert ( | ||||
|                 pld | ||||
|                 == | ||||
|                 dec_msg.pld | ||||
|                 == | ||||
|                 enc_msg.pld | ||||
|             ) | ||||
|             # assert (roundtrip := (_dec_msg == enc_msg)) | ||||
| 
 | ||||
|         except ValidationError as _ve: | ||||
|             ve = _ve | ||||
|             roundtrip: bool = False | ||||
|             if pld_val_type is payload_spec: | ||||
|                 raise ValueError( | ||||
|                    'Got `ValidationError` despite type-var match!?\n' | ||||
|                     f'pld_val_type: {pld_val_type}\n' | ||||
|                     f'payload_type: {payload_spec}\n' | ||||
|                 ) from ve | ||||
| 
 | ||||
|             else: | ||||
|                 # ow we good cuz the pld spec mismatched. | ||||
|                 print( | ||||
|                     'Got expected `ValidationError` since,\n' | ||||
|                     f'{pld_val_type} is not {payload_spec}\n' | ||||
|                 ) | ||||
|         else: | ||||
|             if ( | ||||
|                 payload_spec is not Any | ||||
|                 and | ||||
|                 pld_val_type is not payload_spec | ||||
|             ): | ||||
|                 raise ValueError( | ||||
|                    'DID NOT `ValidationError` despite expected type match!?\n' | ||||
|                     f'pld_val_type: {pld_val_type}\n' | ||||
|                     f'payload_type: {payload_spec}\n' | ||||
|                 ) | ||||
| 
 | ||||
|     # full code decode should always be attempted! | ||||
|     if roundtrip is None: | ||||
|         breakpoint() | ||||
| 
 | ||||
|     return roundtrip | ||||
| 
 | ||||
| 
 | ||||
| def test_limit_msgspec( | ||||
|     debug_mode: bool, | ||||
| ): | ||||
|     async def main(): | ||||
|         async with tractor.open_root_actor( | ||||
|             debug_mode=debug_mode, | ||||
|         ): | ||||
|             # ensure we can round-trip a boxing `PayloadMsg` | ||||
|             assert chk_pld_type( | ||||
|                 payload_spec=Any, | ||||
|                 pld=None, | ||||
|                 expect_roundtrip=True, | ||||
|             ) | ||||
| 
 | ||||
|             # verify that a mis-typed payload value won't decode | ||||
|             assert not chk_pld_type( | ||||
|                 payload_spec=int, | ||||
|                 pld='doggy', | ||||
|             ) | ||||
| 
 | ||||
|             # parametrize the boxed `.pld` type as a custom-struct | ||||
|             # and ensure that parametrization propagates | ||||
|             # to all payload-msg-spec-able subtypes! | ||||
|             class CustomPayload(Struct): | ||||
|                 name: str | ||||
|                 value: Any | ||||
| 
 | ||||
|             assert not chk_pld_type( | ||||
|                 payload_spec=CustomPayload, | ||||
|                 pld='doggy', | ||||
|             ) | ||||
| 
 | ||||
|             assert chk_pld_type( | ||||
|                 payload_spec=CustomPayload, | ||||
|                 pld=CustomPayload(name='doggy', value='urmom') | ||||
|             ) | ||||
| 
 | ||||
|             # yah, we can `.pause_from_sync()` now! | ||||
|             # breakpoint() | ||||
| 
 | ||||
|     trio.run(main) | ||||
|  | @ -38,9 +38,9 @@ from tractor._testing import ( | |||
| # - standard setup/teardown: | ||||
| #   ``Portal.open_context()`` starts a new | ||||
| #   remote task context in another actor. The target actor's task must | ||||
| #   call ``Context.started()`` to unblock this entry on the caller side. | ||||
| #   the callee task executes until complete and returns a final value | ||||
| #   which is delivered to the caller side and retreived via | ||||
| #   call ``Context.started()`` to unblock this entry on the parent side. | ||||
| #   the child task executes until complete and returns a final value | ||||
| #   which is delivered to the parent side and retreived via | ||||
| #   ``Context.result()``. | ||||
| 
 | ||||
| # - cancel termination: | ||||
|  | @ -170,9 +170,9 @@ async def assert_state(value: bool): | |||
|     [False, ValueError, KeyboardInterrupt], | ||||
| ) | ||||
| @pytest.mark.parametrize( | ||||
|     'callee_blocks_forever', | ||||
|     'child_blocks_forever', | ||||
|     [False, True], | ||||
|     ids=lambda item: f'callee_blocks_forever={item}' | ||||
|     ids=lambda item: f'child_blocks_forever={item}' | ||||
| ) | ||||
| @pytest.mark.parametrize( | ||||
|     'pointlessly_open_stream', | ||||
|  | @ -181,7 +181,7 @@ async def assert_state(value: bool): | |||
| ) | ||||
| def test_simple_context( | ||||
|     error_parent, | ||||
|     callee_blocks_forever, | ||||
|     child_blocks_forever, | ||||
|     pointlessly_open_stream, | ||||
|     debug_mode: bool, | ||||
| ): | ||||
|  | @ -204,13 +204,13 @@ def test_simple_context( | |||
|                         portal.open_context( | ||||
|                             simple_setup_teardown, | ||||
|                             data=10, | ||||
|                             block_forever=callee_blocks_forever, | ||||
|                             block_forever=child_blocks_forever, | ||||
|                         ) as (ctx, sent), | ||||
|                     ): | ||||
|                         assert current_ipc_ctx() is ctx | ||||
|                         assert sent == 11 | ||||
| 
 | ||||
|                         if callee_blocks_forever: | ||||
|                         if child_blocks_forever: | ||||
|                             await portal.run(assert_state, value=True) | ||||
|                         else: | ||||
|                             assert await ctx.result() == 'yo' | ||||
|  | @ -220,7 +220,7 @@ def test_simple_context( | |||
|                                 if error_parent: | ||||
|                                     raise error_parent | ||||
| 
 | ||||
|                                 if callee_blocks_forever: | ||||
|                                 if child_blocks_forever: | ||||
|                                     await ctx.cancel() | ||||
|                                 else: | ||||
|                                     # in this case the stream will send a | ||||
|  | @ -259,9 +259,9 @@ def test_simple_context( | |||
| 
 | ||||
| 
 | ||||
| @pytest.mark.parametrize( | ||||
|     'callee_returns_early', | ||||
|     'child_returns_early', | ||||
|     [True, False], | ||||
|     ids=lambda item: f'callee_returns_early={item}' | ||||
|     ids=lambda item: f'child_returns_early={item}' | ||||
| ) | ||||
| @pytest.mark.parametrize( | ||||
|     'cancel_method', | ||||
|  | @ -273,14 +273,14 @@ def test_simple_context( | |||
|     [True, False], | ||||
|     ids=lambda item: f'chk_ctx_result_before_exit={item}' | ||||
| ) | ||||
| def test_caller_cancels( | ||||
| def test_parent_cancels( | ||||
|     cancel_method: str, | ||||
|     chk_ctx_result_before_exit: bool, | ||||
|     callee_returns_early: bool, | ||||
|     child_returns_early: bool, | ||||
|     debug_mode: bool, | ||||
| ): | ||||
|     ''' | ||||
|     Verify that when the opening side of a context (aka the caller) | ||||
|     Verify that when the opening side of a context (aka the parent) | ||||
|     cancels that context, the ctx does not raise a cancelled when | ||||
|     either calling `.result()` or on context exit. | ||||
| 
 | ||||
|  | @ -294,7 +294,7 @@ def test_caller_cancels( | |||
| 
 | ||||
|         if ( | ||||
|             cancel_method == 'portal' | ||||
|             and not callee_returns_early | ||||
|             and not child_returns_early | ||||
|         ): | ||||
|             try: | ||||
|                 res = await ctx.result() | ||||
|  | @ -318,7 +318,7 @@ def test_caller_cancels( | |||
|                 pytest.fail(f'should not have raised ctxc\n{ctxc}') | ||||
| 
 | ||||
|         # we actually get a result | ||||
|         if callee_returns_early: | ||||
|         if child_returns_early: | ||||
|             assert res == 'yo' | ||||
|             assert ctx.outcome is res | ||||
|             assert ctx.maybe_error is None | ||||
|  | @ -362,14 +362,14 @@ def test_caller_cancels( | |||
|             ) | ||||
|             timeout: float = ( | ||||
|                 0.5 | ||||
|                 if not callee_returns_early | ||||
|                 if not child_returns_early | ||||
|                 else 2 | ||||
|             ) | ||||
|             with trio.fail_after(timeout): | ||||
|                 async with ( | ||||
|                     expect_ctxc( | ||||
|                         yay=( | ||||
|                             not callee_returns_early | ||||
|                             not child_returns_early | ||||
|                             and cancel_method == 'portal' | ||||
|                         ) | ||||
|                     ), | ||||
|  | @ -377,13 +377,13 @@ def test_caller_cancels( | |||
|                     portal.open_context( | ||||
|                         simple_setup_teardown, | ||||
|                         data=10, | ||||
|                         block_forever=not callee_returns_early, | ||||
|                         block_forever=not child_returns_early, | ||||
|                     ) as (ctx, sent), | ||||
|                 ): | ||||
| 
 | ||||
|                     if callee_returns_early: | ||||
|                     if child_returns_early: | ||||
|                         # ensure we block long enough before sending | ||||
|                         # a cancel such that the callee has already | ||||
|                         # a cancel such that the child has already | ||||
|                         # returned it's result. | ||||
|                         await trio.sleep(0.5) | ||||
| 
 | ||||
|  | @ -421,7 +421,7 @@ def test_caller_cancels( | |||
|             #   which should in turn cause `ctx._scope` to | ||||
|             # catch any cancellation? | ||||
|             if ( | ||||
|                 not callee_returns_early | ||||
|                 not child_returns_early | ||||
|                 and cancel_method != 'portal' | ||||
|             ): | ||||
|                 assert not ctx._scope.cancelled_caught | ||||
|  | @ -430,11 +430,11 @@ def test_caller_cancels( | |||
| 
 | ||||
| 
 | ||||
| # basic stream terminations: | ||||
| # - callee context closes without using stream | ||||
| # - caller context closes without using stream | ||||
| # - caller context calls `Context.cancel()` while streaming | ||||
| #   is ongoing resulting in callee being cancelled | ||||
| # - callee calls `Context.cancel()` while streaming and caller | ||||
| # - child context closes without using stream | ||||
| # - parent context closes without using stream | ||||
| # - parent context calls `Context.cancel()` while streaming | ||||
| #   is ongoing resulting in child being cancelled | ||||
| # - child calls `Context.cancel()` while streaming and parent | ||||
| #   sees stream terminated in `RemoteActorError` | ||||
| 
 | ||||
| # TODO: future possible features | ||||
|  | @ -443,7 +443,6 @@ def test_caller_cancels( | |||
| 
 | ||||
| @tractor.context | ||||
| async def close_ctx_immediately( | ||||
| 
 | ||||
|     ctx: Context, | ||||
| 
 | ||||
| ) -> None: | ||||
|  | @ -454,13 +453,24 @@ async def close_ctx_immediately( | |||
|     async with ctx.open_stream(): | ||||
|         pass | ||||
| 
 | ||||
|     print('child returning!') | ||||
| 
 | ||||
| 
 | ||||
| @pytest.mark.parametrize( | ||||
|     'parent_send_before_receive', | ||||
|     [ | ||||
|         False, | ||||
|         True, | ||||
|     ], | ||||
|     ids=lambda item: f'child_send_before_receive={item}' | ||||
| ) | ||||
| @tractor_test | ||||
| async def test_callee_closes_ctx_after_stream_open( | ||||
| async def test_child_exits_ctx_after_stream_open( | ||||
|     debug_mode: bool, | ||||
|     parent_send_before_receive: bool, | ||||
| ): | ||||
|     ''' | ||||
|     callee context closes without using stream. | ||||
|     child context closes without using stream. | ||||
| 
 | ||||
|     This should result in a msg sequence | ||||
|     |_<root>_ | ||||
|  | @ -474,6 +484,9 @@ async def test_callee_closes_ctx_after_stream_open( | |||
|     => {'stop': True, 'cid': <str>} | ||||
| 
 | ||||
|     ''' | ||||
|     timeout: float = ( | ||||
|         0.5 if not debug_mode else 999 | ||||
|     ) | ||||
|     async with tractor.open_nursery( | ||||
|         debug_mode=debug_mode, | ||||
|     ) as an: | ||||
|  | @ -482,7 +495,7 @@ async def test_callee_closes_ctx_after_stream_open( | |||
|             enable_modules=[__name__], | ||||
|         ) | ||||
| 
 | ||||
|         with trio.fail_after(0.5): | ||||
|         with trio.fail_after(timeout): | ||||
|             async with portal.open_context( | ||||
|                 close_ctx_immediately, | ||||
| 
 | ||||
|  | @ -494,41 +507,56 @@ async def test_callee_closes_ctx_after_stream_open( | |||
| 
 | ||||
|                 with trio.fail_after(0.4): | ||||
|                     async with ctx.open_stream() as stream: | ||||
|                         if parent_send_before_receive: | ||||
|                             print('sending first msg from parent!') | ||||
|                             await stream.send('yo') | ||||
| 
 | ||||
|                         # should fall through since ``StopAsyncIteration`` | ||||
|                         # should be raised through translation of | ||||
|                         # a ``trio.EndOfChannel`` by | ||||
|                         # ``trio.abc.ReceiveChannel.__anext__()`` | ||||
|                         async for _ in stream: | ||||
|                         msg = 10 | ||||
|                         async for msg in stream: | ||||
|                             # trigger failure if we DO NOT | ||||
|                             # get an EOC! | ||||
|                             assert 0 | ||||
|                         else: | ||||
|                             # never should get anythinig new from | ||||
|                             # the underlying stream | ||||
|                             assert msg == 10 | ||||
| 
 | ||||
|                             # verify stream is now closed | ||||
|                             try: | ||||
|                                 with trio.fail_after(0.3): | ||||
|                                     print('parent trying to `.receive()` on EoC stream!') | ||||
|                                     await stream.receive() | ||||
|                                     assert 0, 'should have raised eoc!?' | ||||
|                             except trio.EndOfChannel: | ||||
|                                 print('parent got EoC as expected!') | ||||
|                                 pass | ||||
|                                 # raise | ||||
| 
 | ||||
|                 # TODO: should be just raise the closed resource err | ||||
|                 # directly here to enforce not allowing a re-open | ||||
|                 # of a stream to the context (at least until a time of | ||||
|                 # if/when we decide that's a good idea?) | ||||
|                 try: | ||||
|                     with trio.fail_after(0.5): | ||||
|                     with trio.fail_after(timeout): | ||||
|                         async with ctx.open_stream() as stream: | ||||
|                             pass | ||||
|                 except trio.ClosedResourceError: | ||||
|                     pass | ||||
| 
 | ||||
|                 # if ctx._rx_chan._state.data: | ||||
|                 #     await tractor.pause() | ||||
| 
 | ||||
|         await portal.cancel_actor() | ||||
| 
 | ||||
| 
 | ||||
| @tractor.context | ||||
| async def expect_cancelled( | ||||
|     ctx: Context, | ||||
|     send_before_receive: bool = False, | ||||
| 
 | ||||
| ) -> None: | ||||
|     global _state | ||||
|  | @ -538,6 +566,10 @@ async def expect_cancelled( | |||
| 
 | ||||
|     try: | ||||
|         async with ctx.open_stream() as stream: | ||||
| 
 | ||||
|             if send_before_receive: | ||||
|                 await stream.send('yo') | ||||
| 
 | ||||
|             async for msg in stream: | ||||
|                 await stream.send(msg)  # echo server | ||||
| 
 | ||||
|  | @ -564,26 +596,49 @@ async def expect_cancelled( | |||
|         raise | ||||
| 
 | ||||
|     else: | ||||
|         assert 0, "callee wasn't cancelled !?" | ||||
|         assert 0, "child wasn't cancelled !?" | ||||
| 
 | ||||
| 
 | ||||
| @pytest.mark.parametrize( | ||||
|     'child_send_before_receive', | ||||
|     [ | ||||
|         False, | ||||
|         True, | ||||
|     ], | ||||
|     ids=lambda item: f'child_send_before_receive={item}' | ||||
| ) | ||||
| @pytest.mark.parametrize( | ||||
|     'rent_wait_for_msg', | ||||
|     [ | ||||
|         False, | ||||
|         True, | ||||
|     ], | ||||
|     ids=lambda item: f'rent_wait_for_msg={item}' | ||||
| ) | ||||
| @pytest.mark.parametrize( | ||||
|     'use_ctx_cancel_method', | ||||
|     [False, True], | ||||
|     [ | ||||
|         False, | ||||
|         'pre_stream', | ||||
|         'post_stream_open', | ||||
|         'post_stream_close', | ||||
|     ], | ||||
|     ids=lambda item: f'use_ctx_cancel_method={item}' | ||||
| ) | ||||
| @tractor_test | ||||
| async def test_caller_closes_ctx_after_callee_opens_stream( | ||||
|     use_ctx_cancel_method: bool, | ||||
| async def test_parent_exits_ctx_after_child_enters_stream( | ||||
|     use_ctx_cancel_method: bool|str, | ||||
|     debug_mode: bool, | ||||
|     rent_wait_for_msg: bool, | ||||
|     child_send_before_receive: bool, | ||||
| ): | ||||
|     ''' | ||||
|     caller context closes without using/opening stream | ||||
|     Parent-side of IPC context closes without sending on `MsgStream`. | ||||
| 
 | ||||
|     ''' | ||||
|     async with tractor.open_nursery( | ||||
|         debug_mode=debug_mode, | ||||
|     ) as an: | ||||
| 
 | ||||
|         root: Actor = current_actor() | ||||
|         portal = await an.start_actor( | ||||
|             'ctx_cancelled', | ||||
|  | @ -592,41 +647,52 @@ async def test_caller_closes_ctx_after_callee_opens_stream( | |||
| 
 | ||||
|         async with portal.open_context( | ||||
|             expect_cancelled, | ||||
|             send_before_receive=child_send_before_receive, | ||||
|         ) as (ctx, sent): | ||||
|             assert sent is None | ||||
| 
 | ||||
|             await portal.run(assert_state, value=True) | ||||
| 
 | ||||
|             # call `ctx.cancel()` explicitly | ||||
|             if use_ctx_cancel_method: | ||||
|             if use_ctx_cancel_method == 'pre_stream': | ||||
|                 await ctx.cancel() | ||||
| 
 | ||||
|                 # NOTE: means the local side `ctx._scope` will | ||||
|                 # have been cancelled by an ctxc ack and thus | ||||
|                 # `._scope.cancelled_caught` should be set. | ||||
|                 try: | ||||
|                 async with ( | ||||
|                     expect_ctxc( | ||||
|                         # XXX: the cause is US since we call | ||||
|                         # `Context.cancel()` just above! | ||||
|                         yay=True, | ||||
| 
 | ||||
|                         # XXX: must be propagated to __aexit__ | ||||
|                         # and should be silently absorbed there | ||||
|                         # since we called `.cancel()` just above ;) | ||||
|                         reraise=True, | ||||
|                     ) as maybe_ctxc, | ||||
|                 ): | ||||
|                     async with ctx.open_stream() as stream: | ||||
|                         async for msg in stream: | ||||
|                             pass | ||||
| 
 | ||||
|                 except tractor.ContextCancelled as ctxc: | ||||
|                     # XXX: the cause is US since we call | ||||
|                     # `Context.cancel()` just above! | ||||
|                     assert ( | ||||
|                         ctxc.canceller | ||||
|                         == | ||||
|                         current_actor().uid | ||||
|                         == | ||||
|                         root.uid | ||||
|                     ) | ||||
|                         if rent_wait_for_msg: | ||||
|                             async for msg in stream: | ||||
|                                 print(f'PARENT rx: {msg!r}\n') | ||||
|                                 break | ||||
| 
 | ||||
|                     # XXX: must be propagated to __aexit__ | ||||
|                     # and should be silently absorbed there | ||||
|                     # since we called `.cancel()` just above ;) | ||||
|                     raise | ||||
|                         if use_ctx_cancel_method == 'post_stream_open': | ||||
|                             await ctx.cancel() | ||||
| 
 | ||||
|                 else: | ||||
|                     assert 0, "Should have context cancelled?" | ||||
|                     if use_ctx_cancel_method == 'post_stream_close': | ||||
|                         await ctx.cancel() | ||||
| 
 | ||||
|                 ctxc: tractor.ContextCancelled = maybe_ctxc.value | ||||
|                 assert ( | ||||
|                     ctxc.canceller | ||||
|                     == | ||||
|                     current_actor().uid | ||||
|                     == | ||||
|                     root.uid | ||||
|                 ) | ||||
| 
 | ||||
|                 # channel should still be up | ||||
|                 assert portal.channel.connected() | ||||
|  | @ -637,13 +703,20 @@ async def test_caller_closes_ctx_after_callee_opens_stream( | |||
|                     value=False, | ||||
|                 ) | ||||
| 
 | ||||
|             # XXX CHILD-BLOCKS case, we SHOULD NOT exit from the | ||||
|             # `.open_context()` before the child has returned, | ||||
|             # errored or been cancelled! | ||||
|             else: | ||||
|                 try: | ||||
|                     with trio.fail_after(0.2): | ||||
|                         await ctx.result() | ||||
|                     with trio.fail_after( | ||||
|                         0.5  # if not debug_mode else 999 | ||||
|                     ): | ||||
|                         res = await ctx.wait_for_result() | ||||
|                         assert res is not tractor._context.Unresolved | ||||
|                         assert 0, "Callee should have blocked!?" | ||||
|                 except trio.TooSlowError: | ||||
|                     # NO-OP -> since already called above | ||||
|                     # NO-OP -> since already triggered by | ||||
|                     # `trio.fail_after()` above! | ||||
|                     await ctx.cancel() | ||||
| 
 | ||||
|         # NOTE: local scope should have absorbed the cancellation since | ||||
|  | @ -683,7 +756,7 @@ async def test_caller_closes_ctx_after_callee_opens_stream( | |||
| 
 | ||||
| 
 | ||||
| @tractor_test | ||||
| async def test_multitask_caller_cancels_from_nonroot_task( | ||||
| async def test_multitask_parent_cancels_from_nonroot_task( | ||||
|     debug_mode: bool, | ||||
| ): | ||||
|     async with tractor.open_nursery( | ||||
|  | @ -735,7 +808,6 @@ async def test_multitask_caller_cancels_from_nonroot_task( | |||
| 
 | ||||
| @tractor.context | ||||
| async def cancel_self( | ||||
| 
 | ||||
|     ctx: Context, | ||||
| 
 | ||||
| ) -> None: | ||||
|  | @ -775,11 +847,11 @@ async def cancel_self( | |||
| 
 | ||||
| 
 | ||||
| @tractor_test | ||||
| async def test_callee_cancels_before_started( | ||||
| async def test_child_cancels_before_started( | ||||
|     debug_mode: bool, | ||||
| ): | ||||
|     ''' | ||||
|     Callee calls `Context.cancel()` while streaming and caller | ||||
|     Callee calls `Context.cancel()` while streaming and parent | ||||
|     sees stream terminated in `ContextCancelled`. | ||||
| 
 | ||||
|     ''' | ||||
|  | @ -826,14 +898,13 @@ async def never_open_stream( | |||
| 
 | ||||
| 
 | ||||
| @tractor.context | ||||
| async def keep_sending_from_callee( | ||||
| 
 | ||||
| async def keep_sending_from_child( | ||||
|     ctx:  Context, | ||||
|     msg_buffer_size: int|None = None, | ||||
| 
 | ||||
| ) -> None: | ||||
|     ''' | ||||
|     Send endlessly on the calleee stream. | ||||
|     Send endlessly on the child stream. | ||||
| 
 | ||||
|     ''' | ||||
|     await ctx.started() | ||||
|  | @ -841,7 +912,7 @@ async def keep_sending_from_callee( | |||
|         msg_buffer_size=msg_buffer_size, | ||||
|     ) as stream: | ||||
|         for msg in count(): | ||||
|             print(f'callee sending {msg}') | ||||
|             print(f'child sending {msg}') | ||||
|             await stream.send(msg) | ||||
|             await trio.sleep(0.01) | ||||
| 
 | ||||
|  | @ -849,12 +920,12 @@ async def keep_sending_from_callee( | |||
| @pytest.mark.parametrize( | ||||
|     'overrun_by', | ||||
|     [ | ||||
|         ('caller', 1, never_open_stream), | ||||
|         ('callee', 0, keep_sending_from_callee), | ||||
|         ('parent', 1, never_open_stream), | ||||
|         ('child', 0, keep_sending_from_child), | ||||
|     ], | ||||
|     ids=[ | ||||
|          ('caller_1buf_never_open_stream'), | ||||
|          ('callee_0buf_keep_sending_from_callee'), | ||||
|          ('parent_1buf_never_open_stream'), | ||||
|          ('child_0buf_keep_sending_from_child'), | ||||
|     ] | ||||
| ) | ||||
| def test_one_end_stream_not_opened( | ||||
|  | @ -885,8 +956,7 @@ def test_one_end_stream_not_opened( | |||
|                 ) as (ctx, sent): | ||||
|                     assert sent is None | ||||
| 
 | ||||
|                     if 'caller' in overrunner: | ||||
| 
 | ||||
|                     if 'parent' in overrunner: | ||||
|                         async with ctx.open_stream() as stream: | ||||
| 
 | ||||
|                             # itersend +1 msg more then the buffer size | ||||
|  | @ -901,7 +971,7 @@ def test_one_end_stream_not_opened( | |||
|                                 await trio.sleep_forever() | ||||
| 
 | ||||
|                     else: | ||||
|                         # callee overruns caller case so we do nothing here | ||||
|                         # child overruns parent case so we do nothing here | ||||
|                         await trio.sleep_forever() | ||||
| 
 | ||||
|             await portal.cancel_actor() | ||||
|  | @ -909,19 +979,19 @@ def test_one_end_stream_not_opened( | |||
|     # 2 overrun cases and the no overrun case (which pushes right up to | ||||
|     # the msg limit) | ||||
|     if ( | ||||
|         overrunner == 'caller' | ||||
|         overrunner == 'parent' | ||||
|     ): | ||||
|         with pytest.raises(tractor.RemoteActorError) as excinfo: | ||||
|             trio.run(main) | ||||
| 
 | ||||
|         assert excinfo.value.boxed_type == StreamOverrun | ||||
| 
 | ||||
|     elif overrunner == 'callee': | ||||
|     elif overrunner == 'child': | ||||
|         with pytest.raises(tractor.RemoteActorError) as excinfo: | ||||
|             trio.run(main) | ||||
| 
 | ||||
|         # TODO: embedded remote errors so that we can verify the source | ||||
|         # error? the callee delivers an error which is an overrun | ||||
|         # error? the child delivers an error which is an overrun | ||||
|         # wrapped in a remote actor error. | ||||
|         assert excinfo.value.boxed_type == tractor.RemoteActorError | ||||
| 
 | ||||
|  | @ -931,8 +1001,7 @@ def test_one_end_stream_not_opened( | |||
| 
 | ||||
| @tractor.context | ||||
| async def echo_back_sequence( | ||||
| 
 | ||||
|     ctx:  Context, | ||||
|     ctx: Context, | ||||
|     seq: list[int], | ||||
|     wait_for_cancel: bool, | ||||
|     allow_overruns_side: str, | ||||
|  | @ -941,12 +1010,12 @@ async def echo_back_sequence( | |||
| 
 | ||||
| ) -> None: | ||||
|     ''' | ||||
|     Send endlessly on the calleee stream using a small buffer size | ||||
|     Send endlessly on the child stream using a small buffer size | ||||
|     setting on the contex to simulate backlogging that would normally | ||||
|     cause overruns. | ||||
| 
 | ||||
|     ''' | ||||
|     # NOTE: ensure that if the caller is expecting to cancel this task | ||||
|     # NOTE: ensure that if the parent is expecting to cancel this task | ||||
|     # that we stay echoing much longer then they are so we don't | ||||
|     # return early instead of receive the cancel msg. | ||||
|     total_batches: int = ( | ||||
|  | @ -996,18 +1065,18 @@ async def echo_back_sequence( | |||
|                 if be_slow: | ||||
|                     await trio.sleep(0.05) | ||||
| 
 | ||||
|                 print('callee waiting on next') | ||||
|                 print('child waiting on next') | ||||
| 
 | ||||
|             print(f'callee echoing back latest batch\n{batch}') | ||||
|             print(f'child echoing back latest batch\n{batch}') | ||||
|             for msg in batch: | ||||
|                 print(f'callee sending msg\n{msg}') | ||||
|                 print(f'child sending msg\n{msg}') | ||||
|                 await stream.send(msg) | ||||
| 
 | ||||
|     try: | ||||
|         return 'yo' | ||||
|     finally: | ||||
|         print( | ||||
|             'exiting callee with context:\n' | ||||
|             'exiting child with context:\n' | ||||
|             f'{pformat(ctx)}\n' | ||||
|         ) | ||||
| 
 | ||||
|  | @ -1061,7 +1130,7 @@ def test_maybe_allow_overruns_stream( | |||
|             debug_mode=debug_mode, | ||||
|         ) as an: | ||||
|             portal = await an.start_actor( | ||||
|                 'callee_sends_forever', | ||||
|                 'child_sends_forever', | ||||
|                 enable_modules=[__name__], | ||||
|                 loglevel=loglevel, | ||||
|                 debug_mode=debug_mode, | ||||
|  |  | |||
|  | @ -0,0 +1,946 @@ | |||
| ''' | ||||
| Low-level functional audits for our | ||||
| "capability based messaging"-spec feats. | ||||
| 
 | ||||
| B~) | ||||
| 
 | ||||
| ''' | ||||
| from contextlib import ( | ||||
|     contextmanager as cm, | ||||
|     # nullcontext, | ||||
| ) | ||||
| import importlib | ||||
| from typing import ( | ||||
|     Any, | ||||
|     Type, | ||||
|     Union, | ||||
| ) | ||||
| 
 | ||||
| from msgspec import ( | ||||
|     # structs, | ||||
|     # msgpack, | ||||
|     Raw, | ||||
|     # Struct, | ||||
|     ValidationError, | ||||
| ) | ||||
| import pytest | ||||
| import trio | ||||
| 
 | ||||
| import tractor | ||||
| from tractor import ( | ||||
|     Actor, | ||||
|     # _state, | ||||
|     MsgTypeError, | ||||
|     Context, | ||||
| ) | ||||
| from tractor.msg import ( | ||||
|     _codec, | ||||
|     _ctxvar_MsgCodec, | ||||
|     _exts, | ||||
| 
 | ||||
|     NamespacePath, | ||||
|     MsgCodec, | ||||
|     MsgDec, | ||||
|     mk_codec, | ||||
|     mk_dec, | ||||
|     apply_codec, | ||||
|     current_codec, | ||||
| ) | ||||
| from tractor.msg.types import ( | ||||
|     log, | ||||
|     Started, | ||||
|     # _payload_msgs, | ||||
|     # PayloadMsg, | ||||
|     # mk_msg_spec, | ||||
| ) | ||||
| from tractor.msg._ops import ( | ||||
|     limit_plds, | ||||
| ) | ||||
| 
 | ||||
| def enc_nsp(obj: Any) -> Any: | ||||
|     actor: Actor = tractor.current_actor( | ||||
|         err_on_no_runtime=False, | ||||
|     ) | ||||
|     uid: tuple[str, str]|None = None if not actor else actor.uid | ||||
|     print(f'{uid} ENC HOOK') | ||||
| 
 | ||||
|     match obj: | ||||
|         # case NamespacePath()|str(): | ||||
|         case NamespacePath(): | ||||
|             encoded: str = str(obj) | ||||
|             print( | ||||
|                 f'----- ENCODING `NamespacePath` as `str` ------\n' | ||||
|                 f'|_obj:{type(obj)!r} = {obj!r}\n' | ||||
|                 f'|_encoded: str = {encoded!r}\n' | ||||
|             ) | ||||
|             # if type(obj) != NamespacePath: | ||||
|             #     breakpoint() | ||||
|             return encoded | ||||
|         case _: | ||||
|             logmsg: str = ( | ||||
|                 f'{uid}\n' | ||||
|                 'FAILED ENCODE\n' | ||||
|                 f'obj-> `{obj}: {type(obj)}`\n' | ||||
|             ) | ||||
|             raise NotImplementedError(logmsg) | ||||
| 
 | ||||
| 
 | ||||
| def dec_nsp( | ||||
|     obj_type: Type, | ||||
|     obj: Any, | ||||
| 
 | ||||
| ) -> Any: | ||||
|     # breakpoint() | ||||
|     actor: Actor = tractor.current_actor( | ||||
|         err_on_no_runtime=False, | ||||
|     ) | ||||
|     uid: tuple[str, str]|None = None if not actor else actor.uid | ||||
|     print( | ||||
|         f'{uid}\n' | ||||
|         'CUSTOM DECODE\n' | ||||
|         f'type-arg-> {obj_type}\n' | ||||
|         f'obj-arg-> `{obj}`: {type(obj)}\n' | ||||
|     ) | ||||
|     nsp = None | ||||
|     # XXX, never happens right? | ||||
|     if obj_type is Raw: | ||||
|         breakpoint() | ||||
| 
 | ||||
|     if ( | ||||
|         obj_type is NamespacePath | ||||
|         and isinstance(obj, str) | ||||
|         and ':' in obj | ||||
|     ): | ||||
|         nsp = NamespacePath(obj) | ||||
|         # TODO: we could built a generic handler using | ||||
|         # JUST matching the obj_type part? | ||||
|         # nsp = obj_type(obj) | ||||
| 
 | ||||
|     if nsp: | ||||
|         print(f'Returning NSP instance: {nsp}') | ||||
|         return nsp | ||||
| 
 | ||||
|     logmsg: str = ( | ||||
|         f'{uid}\n' | ||||
|         'FAILED DECODE\n' | ||||
|         f'type-> {obj_type}\n' | ||||
|         f'obj-arg-> `{obj}`: {type(obj)}\n\n' | ||||
|         f'current codec:\n' | ||||
|         f'{current_codec()}\n' | ||||
|     ) | ||||
|     # TODO: figure out the ignore subsys for this! | ||||
|     # -[ ] option whether to defense-relay backc the msg | ||||
|     #   inside an `Invalid`/`Ignore` | ||||
|     # -[ ] how to make this handling pluggable such that a | ||||
|     #   `Channel`/`MsgTransport` can intercept and process | ||||
|     #   back msgs either via exception handling or some other | ||||
|     #   signal? | ||||
|     log.warning(logmsg) | ||||
|     # NOTE: this delivers the invalid | ||||
|     # value up to `msgspec`'s decoding | ||||
|     # machinery for error raising. | ||||
|     return obj | ||||
|     # raise NotImplementedError(logmsg) | ||||
| 
 | ||||
| 
 | ||||
| def ex_func(*args): | ||||
|     ''' | ||||
|     A mod level func we can ref and load via our `NamespacePath` | ||||
|     python-object pointer `str` subtype. | ||||
| 
 | ||||
|     ''' | ||||
|     print(f'ex_func({args})') | ||||
| 
 | ||||
| 
 | ||||
| @pytest.mark.parametrize( | ||||
|     'add_codec_hooks', | ||||
|     [ | ||||
|         True, | ||||
|         False, | ||||
|     ], | ||||
|     ids=['use_codec_hooks', 'no_codec_hooks'], | ||||
| ) | ||||
| def test_custom_extension_types( | ||||
|     debug_mode: bool, | ||||
|     add_codec_hooks: bool | ||||
| ): | ||||
|     ''' | ||||
|     Verify that a `MsgCodec` (used for encoding all outbound IPC msgs | ||||
|     and decoding all inbound `PayloadMsg`s) and a paired `MsgDec` | ||||
|     (used for decoding the `PayloadMsg.pld: Raw` received within a given | ||||
|     task's ipc `Context` scope) can both send and receive "extension types" | ||||
|     as supported via custom converter hooks passed to `msgspec`. | ||||
| 
 | ||||
|     ''' | ||||
|     nsp_pld_dec: MsgDec = mk_dec( | ||||
|         spec=None,  # ONLY support the ext type | ||||
|         dec_hook=dec_nsp if add_codec_hooks else None, | ||||
|         ext_types=[NamespacePath], | ||||
|     ) | ||||
|     nsp_codec: MsgCodec = mk_codec( | ||||
|         # ipc_pld_spec=Raw,  # default! | ||||
| 
 | ||||
|         # NOTE XXX: the encode hook MUST be used no matter what since | ||||
|         # our `NamespacePath` is not any of a `Any` native type nor | ||||
|         # a `msgspec.Struct` subtype - so `msgspec` has no way to know | ||||
|         # how to encode it unless we provide the custom hook. | ||||
|         # | ||||
|         # AGAIN that is, regardless of whether we spec an | ||||
|         # `Any`-decoded-pld the enc has no knowledge (by default) | ||||
|         # how to enc `NamespacePath` (nsp), so we add a custom | ||||
|         # hook to do that ALWAYS. | ||||
|         enc_hook=enc_nsp if add_codec_hooks else None, | ||||
| 
 | ||||
|         # XXX NOTE: pretty sure this is mutex with the `type=` to | ||||
|         # `Decoder`? so it won't work in tandem with the | ||||
|         # `ipc_pld_spec` passed above? | ||||
|         ext_types=[NamespacePath], | ||||
| 
 | ||||
|         # TODO? is it useful to have the `.pld` decoded *prior* to | ||||
|         # the `PldRx`?? like perf or mem related? | ||||
|         # ext_dec=nsp_pld_dec, | ||||
|     ) | ||||
|     if add_codec_hooks: | ||||
|         assert nsp_codec.dec.dec_hook is None | ||||
| 
 | ||||
|         # TODO? if we pass `ext_dec` above? | ||||
|         # assert nsp_codec.dec.dec_hook is dec_nsp | ||||
| 
 | ||||
|         assert nsp_codec.enc.enc_hook is enc_nsp | ||||
| 
 | ||||
|     nsp = NamespacePath.from_ref(ex_func) | ||||
| 
 | ||||
|     try: | ||||
|         nsp_bytes: bytes = nsp_codec.encode(nsp) | ||||
|         nsp_rt_sin_msg = nsp_pld_dec.decode(nsp_bytes) | ||||
|         nsp_rt_sin_msg.load_ref() is ex_func | ||||
|     except TypeError: | ||||
|         if not add_codec_hooks: | ||||
|             pass | ||||
| 
 | ||||
|     try: | ||||
|         msg_bytes: bytes = nsp_codec.encode( | ||||
|             Started( | ||||
|                 cid='cid', | ||||
|                 pld=nsp, | ||||
|             ) | ||||
|         ) | ||||
|         # since the ext-type obj should also be set as the msg.pld | ||||
|         assert nsp_bytes in msg_bytes | ||||
|         started_rt: Started = nsp_codec.decode(msg_bytes) | ||||
|         pld: Raw = started_rt.pld | ||||
|         assert isinstance(pld, Raw) | ||||
|         nsp_rt: NamespacePath = nsp_pld_dec.decode(pld) | ||||
|         assert isinstance(nsp_rt, NamespacePath) | ||||
|         # in obj comparison terms they should be the same | ||||
|         assert nsp_rt == nsp | ||||
|         # ensure we've decoded to ext type! | ||||
|         assert nsp_rt.load_ref() is ex_func | ||||
| 
 | ||||
|     except TypeError: | ||||
|         if not add_codec_hooks: | ||||
|             pass | ||||
| 
 | ||||
| @tractor.context | ||||
| async def sleep_forever_in_sub( | ||||
|     ctx: Context, | ||||
| ) -> None: | ||||
|     await trio.sleep_forever() | ||||
| 
 | ||||
| 
 | ||||
| def mk_custom_codec( | ||||
|     add_hooks: bool, | ||||
| 
 | ||||
| ) -> tuple[ | ||||
|     MsgCodec,  # encode to send | ||||
|     MsgDec,  # pld receive-n-decode | ||||
| ]: | ||||
|     ''' | ||||
|     Create custom `msgpack` enc/dec-hooks and set a `Decoder` | ||||
|     which only loads `pld_spec` (like `NamespacePath`) types. | ||||
| 
 | ||||
|     ''' | ||||
| 
 | ||||
|     # XXX NOTE XXX: despite defining `NamespacePath` as a type | ||||
|     # field on our `PayloadMsg.pld`, we still need a enc/dec_hook() pair | ||||
|     # to cast to/from that type on the wire. See the docs: | ||||
|     # https://jcristharif.com/msgspec/extending.html#mapping-to-from-native-types | ||||
| 
 | ||||
|     # if pld_spec is Any: | ||||
|     #     pld_spec = Raw | ||||
| 
 | ||||
|     nsp_codec: MsgCodec = mk_codec( | ||||
|         # ipc_pld_spec=Raw,  # default! | ||||
| 
 | ||||
|         # NOTE XXX: the encode hook MUST be used no matter what since | ||||
|         # our `NamespacePath` is not any of a `Any` native type nor | ||||
|         # a `msgspec.Struct` subtype - so `msgspec` has no way to know | ||||
|         # how to encode it unless we provide the custom hook. | ||||
|         # | ||||
|         # AGAIN that is, regardless of whether we spec an | ||||
|         # `Any`-decoded-pld the enc has no knowledge (by default) | ||||
|         # how to enc `NamespacePath` (nsp), so we add a custom | ||||
|         # hook to do that ALWAYS. | ||||
|         enc_hook=enc_nsp if add_hooks else None, | ||||
| 
 | ||||
|         # XXX NOTE: pretty sure this is mutex with the `type=` to | ||||
|         # `Decoder`? so it won't work in tandem with the | ||||
|         # `ipc_pld_spec` passed above? | ||||
|         ext_types=[NamespacePath], | ||||
|     ) | ||||
|     # dec_hook=dec_nsp if add_hooks else None, | ||||
|     return nsp_codec | ||||
| 
 | ||||
| 
 | ||||
| @pytest.mark.parametrize( | ||||
|     'limit_plds_args', | ||||
|     [ | ||||
|         ( | ||||
|             {'dec_hook': None, 'ext_types': None}, | ||||
|             None, | ||||
|         ), | ||||
|         ( | ||||
|             {'dec_hook': dec_nsp, 'ext_types': None}, | ||||
|             TypeError, | ||||
|         ), | ||||
|         ( | ||||
|             {'dec_hook': dec_nsp, 'ext_types': [NamespacePath]}, | ||||
|             None, | ||||
|         ), | ||||
|         ( | ||||
|             {'dec_hook': dec_nsp, 'ext_types': [NamespacePath|None]}, | ||||
|             None, | ||||
|         ), | ||||
|     ], | ||||
|     ids=[ | ||||
|         'no_hook_no_ext_types', | ||||
|         'only_hook', | ||||
|         'hook_and_ext_types', | ||||
|         'hook_and_ext_types_w_null', | ||||
|     ] | ||||
| ) | ||||
| def test_pld_limiting_usage( | ||||
|     limit_plds_args: tuple[dict, Exception|None], | ||||
| ): | ||||
|     ''' | ||||
|     Verify `dec_hook()` and `ext_types` need to either both be | ||||
|     provided or we raise a explanator type-error. | ||||
| 
 | ||||
|     ''' | ||||
|     kwargs, maybe_err = limit_plds_args | ||||
|     async def main(): | ||||
|         async with tractor.open_nursery() as an:  # just to open runtime | ||||
| 
 | ||||
|             # XXX SHOULD NEVER WORK outside an ipc ctx scope! | ||||
|             try: | ||||
|                 with limit_plds(**kwargs): | ||||
|                     pass | ||||
|             except RuntimeError: | ||||
|                 pass | ||||
| 
 | ||||
|             p: tractor.Portal = await an.start_actor( | ||||
|                 'sub', | ||||
|                 enable_modules=[__name__], | ||||
|             ) | ||||
|             async with ( | ||||
|                 p.open_context( | ||||
|                     sleep_forever_in_sub | ||||
|                 ) as (ctx, first), | ||||
|             ): | ||||
|                 try: | ||||
|                     with limit_plds(**kwargs): | ||||
|                         pass | ||||
|                 except maybe_err as exc: | ||||
|                     assert type(exc) is maybe_err | ||||
|                     pass | ||||
| 
 | ||||
| 
 | ||||
| def chk_codec_applied( | ||||
|     expect_codec: MsgCodec|None, | ||||
|     enter_value: MsgCodec|None = None, | ||||
| 
 | ||||
| ) -> MsgCodec: | ||||
|     ''' | ||||
|     buncha sanity checks ensuring that the IPC channel's | ||||
|     context-vars are set to the expected codec and that are | ||||
|     ctx-var wrapper APIs match the same. | ||||
| 
 | ||||
|     ''' | ||||
|     # TODO: play with tricyle again, bc this is supposed to work | ||||
|     # the way we want? | ||||
|     # | ||||
|     # TreeVar | ||||
|     # task: trio.Task = trio.lowlevel.current_task() | ||||
|     # curr_codec = _ctxvar_MsgCodec.get_in(task) | ||||
| 
 | ||||
|     # ContextVar | ||||
|     # task_ctx: Context = task.context | ||||
|     # assert _ctxvar_MsgCodec in task_ctx | ||||
|     # curr_codec: MsgCodec = task.context[_ctxvar_MsgCodec] | ||||
|     if expect_codec is None: | ||||
|         assert enter_value is None | ||||
|         return | ||||
| 
 | ||||
|     # NOTE: currently we use this! | ||||
|     # RunVar | ||||
|     curr_codec: MsgCodec = current_codec() | ||||
|     last_read_codec = _ctxvar_MsgCodec.get() | ||||
|     # assert curr_codec is last_read_codec | ||||
| 
 | ||||
|     assert ( | ||||
|         (same_codec := expect_codec) is | ||||
|         # returned from `mk_codec()` | ||||
| 
 | ||||
|         # yielded value from `apply_codec()` | ||||
| 
 | ||||
|         # read from current task's `contextvars.Context` | ||||
|         curr_codec is | ||||
|         last_read_codec | ||||
| 
 | ||||
|         # the default `msgspec` settings | ||||
|         is not _codec._def_msgspec_codec | ||||
|         is not _codec._def_tractor_codec | ||||
|     ) | ||||
| 
 | ||||
|     if enter_value: | ||||
|         assert enter_value is same_codec | ||||
| 
 | ||||
| 
 | ||||
| @tractor.context | ||||
| async def send_back_values( | ||||
|     ctx: Context, | ||||
|     rent_pld_spec_type_strs: list[str], | ||||
|     add_hooks: bool, | ||||
| 
 | ||||
| ) -> None: | ||||
|     ''' | ||||
|     Setup up a custom codec to load instances of `NamespacePath` | ||||
|     and ensure we can round trip a func ref with our parent. | ||||
| 
 | ||||
|     ''' | ||||
|     uid: tuple = tractor.current_actor().uid | ||||
| 
 | ||||
|     # init state in sub-actor should be default | ||||
|     chk_codec_applied( | ||||
|         expect_codec=_codec._def_tractor_codec, | ||||
|     ) | ||||
| 
 | ||||
|     # load pld spec from input str | ||||
|     rent_pld_spec = _exts.dec_type_union( | ||||
|         rent_pld_spec_type_strs, | ||||
|         mods=[ | ||||
|             importlib.import_module(__name__), | ||||
|         ], | ||||
|     ) | ||||
|     rent_pld_spec_types: set[Type] = _codec.unpack_spec_types( | ||||
|         rent_pld_spec, | ||||
|     ) | ||||
| 
 | ||||
|     # ONLY add ext-hooks if the rent specified a non-std type! | ||||
|     add_hooks: bool = ( | ||||
|         NamespacePath in rent_pld_spec_types | ||||
|         and | ||||
|         add_hooks | ||||
|     ) | ||||
| 
 | ||||
|     # same as on parent side config. | ||||
|     nsp_codec: MsgCodec|None = None | ||||
|     if add_hooks: | ||||
|         nsp_codec = mk_codec( | ||||
|             enc_hook=enc_nsp, | ||||
|             ext_types=[NamespacePath], | ||||
|         ) | ||||
| 
 | ||||
|     with ( | ||||
|         maybe_apply_codec(nsp_codec) as codec, | ||||
|         limit_plds( | ||||
|             rent_pld_spec, | ||||
|             dec_hook=dec_nsp if add_hooks else None, | ||||
|             ext_types=[NamespacePath]  if add_hooks else None, | ||||
|         ) as pld_dec, | ||||
|     ): | ||||
|         # ?XXX? SHOULD WE NOT be swapping the global codec since it | ||||
|         # breaks `Context.started()` roundtripping checks?? | ||||
|         chk_codec_applied( | ||||
|             expect_codec=nsp_codec, | ||||
|             enter_value=codec, | ||||
|         ) | ||||
| 
 | ||||
|         # ?TODO, mismatch case(s)? | ||||
|         # | ||||
|         # ensure pld spec matches on both sides | ||||
|         ctx_pld_dec: MsgDec = ctx._pld_rx._pld_dec | ||||
|         assert pld_dec is ctx_pld_dec | ||||
|         child_pld_spec: Type = pld_dec.spec | ||||
|         child_pld_spec_types: set[Type] = _codec.unpack_spec_types( | ||||
|             child_pld_spec, | ||||
|         ) | ||||
|         assert ( | ||||
|             child_pld_spec_types.issuperset( | ||||
|                 rent_pld_spec_types | ||||
|             ) | ||||
|         ) | ||||
| 
 | ||||
|         # ?TODO, try loop for each of the types in pld-superset? | ||||
|         # | ||||
|         # for send_value in [ | ||||
|         #     nsp, | ||||
|         #     str(nsp), | ||||
|         #     None, | ||||
|         # ]: | ||||
|         nsp = NamespacePath.from_ref(ex_func) | ||||
|         try: | ||||
|             print( | ||||
|                 f'{uid}: attempting to `.started({nsp})`\n' | ||||
|                 f'\n' | ||||
|                 f'rent_pld_spec: {rent_pld_spec}\n' | ||||
|                 f'child_pld_spec: {child_pld_spec}\n' | ||||
|                 f'codec: {codec}\n' | ||||
|             ) | ||||
|             # await tractor.pause() | ||||
|             await ctx.started(nsp) | ||||
| 
 | ||||
|         except tractor.MsgTypeError as _mte: | ||||
|             mte = _mte | ||||
| 
 | ||||
|             # false -ve case | ||||
|             if add_hooks: | ||||
|                 raise RuntimeError( | ||||
|                     f'EXPECTED to `.started()` value given spec ??\n\n' | ||||
|                     f'child_pld_spec -> {child_pld_spec}\n' | ||||
|                     f'value = {nsp}: {type(nsp)}\n' | ||||
|                 ) | ||||
| 
 | ||||
|             # true -ve case | ||||
|             raise mte | ||||
| 
 | ||||
|         # TODO: maybe we should add our own wrapper error so as to | ||||
|         # be interchange-lib agnostic? | ||||
|         # -[ ] the error type is wtv is raised from the hook so we | ||||
|         #   could also require a type-class of errors for | ||||
|         #   indicating whether the hook-failure can be handled by | ||||
|         #   a nasty-dialog-unprot sub-sys? | ||||
|         except TypeError as typerr: | ||||
|             # false -ve | ||||
|             if add_hooks: | ||||
|                 raise RuntimeError('Should have been able to send `nsp`??') | ||||
| 
 | ||||
|             # true -ve | ||||
|             print('Failed to send `nsp` due to no ext hooks set!') | ||||
|             raise typerr | ||||
| 
 | ||||
|         # now try sending a set of valid and invalid plds to ensure | ||||
|         # the pld spec is respected. | ||||
|         sent: list[Any] = [] | ||||
|         async with ctx.open_stream() as ipc: | ||||
|             print( | ||||
|                 f'{uid}: streaming all pld types to rent..' | ||||
|             ) | ||||
| 
 | ||||
|             # for send_value, expect_send in iter_send_val_items: | ||||
|             for send_value in [ | ||||
|                 nsp, | ||||
|                 str(nsp), | ||||
|                 None, | ||||
|             ]: | ||||
|                 send_type: Type = type(send_value) | ||||
|                 print( | ||||
|                     f'{uid}: SENDING NEXT pld\n' | ||||
|                     f'send_type: {send_type}\n' | ||||
|                     f'send_value: {send_value}\n' | ||||
|                 ) | ||||
|                 try: | ||||
|                     await ipc.send(send_value) | ||||
|                     sent.append(send_value) | ||||
| 
 | ||||
|                 except ValidationError as valerr: | ||||
|                     print(f'{uid} FAILED TO SEND {send_value}!') | ||||
| 
 | ||||
|                     # false -ve | ||||
|                     if add_hooks: | ||||
|                         raise RuntimeError( | ||||
|                             f'EXPECTED to roundtrip value given spec:\n' | ||||
|                             f'rent_pld_spec -> {rent_pld_spec}\n' | ||||
|                             f'child_pld_spec -> {child_pld_spec}\n' | ||||
|                             f'value = {send_value}: {send_type}\n' | ||||
|                         ) | ||||
| 
 | ||||
|                     # true -ve | ||||
|                     raise valerr | ||||
|                     # continue | ||||
| 
 | ||||
|             else: | ||||
|                 print( | ||||
|                     f'{uid}: finished sending all values\n' | ||||
|                     'Should be exiting stream block!\n' | ||||
|                 ) | ||||
| 
 | ||||
|         print(f'{uid}: exited streaming block!') | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
| @cm | ||||
| def maybe_apply_codec(codec: MsgCodec|None) -> MsgCodec|None: | ||||
|     if codec is None: | ||||
|         yield None | ||||
|         return | ||||
| 
 | ||||
|     with apply_codec(codec) as codec: | ||||
|         yield codec | ||||
| 
 | ||||
| 
 | ||||
| @pytest.mark.parametrize( | ||||
|     'pld_spec', | ||||
|     [ | ||||
|         Any, | ||||
|         NamespacePath, | ||||
|         NamespacePath|None,  # the "maybe" spec Bo | ||||
|     ], | ||||
|     ids=[ | ||||
|         'any_type', | ||||
|         'only_nsp_ext', | ||||
|         'maybe_nsp_ext', | ||||
|     ] | ||||
| ) | ||||
| @pytest.mark.parametrize( | ||||
|     'add_hooks', | ||||
|     [ | ||||
|         True, | ||||
|         False, | ||||
|     ], | ||||
|     ids=[ | ||||
|         'use_codec_hooks', | ||||
|         'no_codec_hooks', | ||||
|     ], | ||||
| ) | ||||
| def test_ext_types_over_ipc( | ||||
|     debug_mode: bool, | ||||
|     pld_spec: Union[Type], | ||||
|     add_hooks: bool, | ||||
| ): | ||||
|     ''' | ||||
|     Ensure we can support extension types coverted using | ||||
|     `enc/dec_hook()`s passed to the `.msg.limit_plds()` API | ||||
|     and that sane errors happen when we try do the same without | ||||
|     the codec hooks. | ||||
| 
 | ||||
|     ''' | ||||
|     pld_types: set[Type] = _codec.unpack_spec_types(pld_spec) | ||||
| 
 | ||||
|     async def main(): | ||||
| 
 | ||||
|         # sanity check the default pld-spec beforehand | ||||
|         chk_codec_applied( | ||||
|             expect_codec=_codec._def_tractor_codec, | ||||
|         ) | ||||
| 
 | ||||
|         # extension type we want to send as msg payload | ||||
|         nsp = NamespacePath.from_ref(ex_func) | ||||
| 
 | ||||
|         # ^NOTE, 2 cases: | ||||
|         # - codec hooks noto added -> decode nsp as `str` | ||||
|         # - codec with hooks -> decode nsp as `NamespacePath` | ||||
|         nsp_codec: MsgCodec|None = None | ||||
|         if ( | ||||
|             NamespacePath in pld_types | ||||
|             and | ||||
|             add_hooks | ||||
|         ): | ||||
|             nsp_codec = mk_codec( | ||||
|                 enc_hook=enc_nsp, | ||||
|                 ext_types=[NamespacePath], | ||||
|             ) | ||||
| 
 | ||||
|         async with tractor.open_nursery( | ||||
|             debug_mode=debug_mode, | ||||
|         ) as an: | ||||
|             p: tractor.Portal = await an.start_actor( | ||||
|                 'sub', | ||||
|                 enable_modules=[__name__], | ||||
|             ) | ||||
|             with ( | ||||
|                 maybe_apply_codec(nsp_codec) as codec, | ||||
|             ): | ||||
|                 chk_codec_applied( | ||||
|                     expect_codec=nsp_codec, | ||||
|                     enter_value=codec, | ||||
|                 ) | ||||
|                 rent_pld_spec_type_strs: list[str] = _exts.enc_type_union(pld_spec) | ||||
| 
 | ||||
|                 # XXX should raise an mte (`MsgTypeError`) | ||||
|                 # when `add_hooks == False` bc the input | ||||
|                 # `expect_ipc_send` kwarg has a nsp which can't be | ||||
|                 # serialized! | ||||
|                 # | ||||
|                 # TODO:can we ensure this happens from the | ||||
|                 # `Return`-side (aka the sub) as well? | ||||
|                 try: | ||||
|                     ctx: tractor.Context | ||||
|                     ipc: tractor.MsgStream | ||||
|                     async with ( | ||||
| 
 | ||||
|                         # XXX should raise an mte (`MsgTypeError`) | ||||
|                         # when `add_hooks == False`.. | ||||
|                         p.open_context( | ||||
|                             send_back_values, | ||||
|                             # expect_debug=debug_mode, | ||||
|                             rent_pld_spec_type_strs=rent_pld_spec_type_strs, | ||||
|                             add_hooks=add_hooks, | ||||
|                             # expect_ipc_send=expect_ipc_send, | ||||
|                         ) as (ctx, first), | ||||
| 
 | ||||
|                         ctx.open_stream() as ipc, | ||||
|                     ): | ||||
|                         with ( | ||||
|                             limit_plds( | ||||
|                                 pld_spec, | ||||
|                                 dec_hook=dec_nsp if add_hooks else None, | ||||
|                                 ext_types=[NamespacePath]  if add_hooks else None, | ||||
|                             ) as pld_dec, | ||||
|                         ): | ||||
|                             ctx_pld_dec: MsgDec = ctx._pld_rx._pld_dec | ||||
|                             assert pld_dec is ctx_pld_dec | ||||
| 
 | ||||
|                             # if ( | ||||
|                             #     not add_hooks | ||||
|                             #     and | ||||
|                             #     NamespacePath in  | ||||
|                             # ): | ||||
|                             #     pytest.fail('ctx should fail to open without custom enc_hook!?') | ||||
| 
 | ||||
|                             await ipc.send(nsp) | ||||
|                             nsp_rt = await ipc.receive() | ||||
| 
 | ||||
|                             assert nsp_rt == nsp | ||||
|                             assert nsp_rt.load_ref() is ex_func | ||||
| 
 | ||||
|                 # this test passes bc we can go no further! | ||||
|                 except MsgTypeError as mte: | ||||
|                     # if not add_hooks: | ||||
|                     #     # teardown nursery | ||||
|                     #     await p.cancel_actor() | ||||
|                         # return | ||||
| 
 | ||||
|                     raise mte | ||||
| 
 | ||||
|             await p.cancel_actor() | ||||
| 
 | ||||
|     if ( | ||||
|         NamespacePath in pld_types | ||||
|         and | ||||
|         add_hooks | ||||
|     ): | ||||
|         trio.run(main) | ||||
| 
 | ||||
|     else: | ||||
|         with pytest.raises( | ||||
|             expected_exception=tractor.RemoteActorError, | ||||
|         ) as excinfo: | ||||
|             trio.run(main) | ||||
| 
 | ||||
|         exc = excinfo.value | ||||
|         # bc `.started(nsp: NamespacePath)` will raise | ||||
|         assert exc.boxed_type is TypeError | ||||
| 
 | ||||
| 
 | ||||
| # def chk_pld_type( | ||||
| #     payload_spec: Type[Struct]|Any, | ||||
| #     pld: Any, | ||||
| 
 | ||||
| #     expect_roundtrip: bool|None = None, | ||||
| 
 | ||||
| # ) -> bool: | ||||
| 
 | ||||
| #     pld_val_type: Type = type(pld) | ||||
| 
 | ||||
| #     # TODO: verify that the overridden subtypes | ||||
| #     # DO NOT have modified type-annots from original! | ||||
| #     # 'Start',  .pld: FuncSpec | ||||
| #     # 'StartAck',  .pld: IpcCtxSpec | ||||
| #     # 'Stop',  .pld: UNSEt | ||||
| #     # 'Error',  .pld: ErrorData | ||||
| 
 | ||||
| #     codec: MsgCodec = mk_codec( | ||||
| #         # NOTE: this ONLY accepts `PayloadMsg.pld` fields of a specified | ||||
| #         # type union. | ||||
| #         ipc_pld_spec=payload_spec, | ||||
| #     ) | ||||
| 
 | ||||
| #     # make a one-off dec to compare with our `MsgCodec` instance | ||||
| #     # which does the below `mk_msg_spec()` call internally | ||||
| #     ipc_msg_spec: Union[Type[Struct]] | ||||
| #     msg_types: list[PayloadMsg[payload_spec]] | ||||
| #     ( | ||||
| #         ipc_msg_spec, | ||||
| #         msg_types, | ||||
| #     ) = mk_msg_spec( | ||||
| #         payload_type_union=payload_spec, | ||||
| #     ) | ||||
| #     _enc = msgpack.Encoder() | ||||
| #     _dec = msgpack.Decoder( | ||||
| #         type=ipc_msg_spec or Any,  # like `PayloadMsg[Any]` | ||||
| #     ) | ||||
| 
 | ||||
| #     assert ( | ||||
| #         payload_spec | ||||
| #         == | ||||
| #         codec.pld_spec | ||||
| #     ) | ||||
| 
 | ||||
| #     # assert codec.dec == dec | ||||
| #     # | ||||
| #     # ^-XXX-^ not sure why these aren't "equal" but when cast | ||||
| #     # to `str` they seem to match ?? .. kk | ||||
| 
 | ||||
| #     assert ( | ||||
| #         str(ipc_msg_spec) | ||||
| #         == | ||||
| #         str(codec.msg_spec) | ||||
| #         == | ||||
| #         str(_dec.type) | ||||
| #         == | ||||
| #         str(codec.dec.type) | ||||
| #     ) | ||||
| 
 | ||||
| #     # verify the boxed-type for all variable payload-type msgs. | ||||
| #     if not msg_types: | ||||
| #         breakpoint() | ||||
| 
 | ||||
| #     roundtrip: bool|None = None | ||||
| #     pld_spec_msg_names: list[str] = [ | ||||
| #         td.__name__ for td in _payload_msgs | ||||
| #     ] | ||||
| #     for typedef in msg_types: | ||||
| 
 | ||||
| #         skip_runtime_msg: bool = typedef.__name__ not in pld_spec_msg_names | ||||
| #         if skip_runtime_msg: | ||||
| #             continue | ||||
| 
 | ||||
| #         pld_field = structs.fields(typedef)[1] | ||||
| #         assert pld_field.type is payload_spec # TODO-^ does this need to work to get all subtypes to adhere? | ||||
| 
 | ||||
| #         kwargs: dict[str, Any] = { | ||||
| #             'cid': '666', | ||||
| #             'pld': pld, | ||||
| #         } | ||||
| #         enc_msg: PayloadMsg = typedef(**kwargs) | ||||
| 
 | ||||
| #         _wire_bytes: bytes = _enc.encode(enc_msg) | ||||
| #         wire_bytes: bytes = codec.enc.encode(enc_msg) | ||||
| #         assert _wire_bytes == wire_bytes | ||||
| 
 | ||||
| #         ve: ValidationError|None = None | ||||
| #         try: | ||||
| #             dec_msg = codec.dec.decode(wire_bytes) | ||||
| #             _dec_msg = _dec.decode(wire_bytes) | ||||
| 
 | ||||
| #             # decoded msg and thus payload should be exactly same! | ||||
| #             assert (roundtrip := ( | ||||
| #                 _dec_msg | ||||
| #                 == | ||||
| #                 dec_msg | ||||
| #                 == | ||||
| #                 enc_msg | ||||
| #             )) | ||||
| 
 | ||||
| #             if ( | ||||
| #                 expect_roundtrip is not None | ||||
| #                 and expect_roundtrip != roundtrip | ||||
| #             ): | ||||
| #                 breakpoint() | ||||
| 
 | ||||
| #             assert ( | ||||
| #                 pld | ||||
| #                 == | ||||
| #                 dec_msg.pld | ||||
| #                 == | ||||
| #                 enc_msg.pld | ||||
| #             ) | ||||
| #             # assert (roundtrip := (_dec_msg == enc_msg)) | ||||
| 
 | ||||
| #         except ValidationError as _ve: | ||||
| #             ve = _ve | ||||
| #             roundtrip: bool = False | ||||
| #             if pld_val_type is payload_spec: | ||||
| #                 raise ValueError( | ||||
| #                    'Got `ValidationError` despite type-var match!?\n' | ||||
| #                     f'pld_val_type: {pld_val_type}\n' | ||||
| #                     f'payload_type: {payload_spec}\n' | ||||
| #                 ) from ve | ||||
| 
 | ||||
| #             else: | ||||
| #                 # ow we good cuz the pld spec mismatched. | ||||
| #                 print( | ||||
| #                     'Got expected `ValidationError` since,\n' | ||||
| #                     f'{pld_val_type} is not {payload_spec}\n' | ||||
| #                 ) | ||||
| #         else: | ||||
| #             if ( | ||||
| #                 payload_spec is not Any | ||||
| #                 and | ||||
| #                 pld_val_type is not payload_spec | ||||
| #             ): | ||||
| #                 raise ValueError( | ||||
| #                    'DID NOT `ValidationError` despite expected type match!?\n' | ||||
| #                     f'pld_val_type: {pld_val_type}\n' | ||||
| #                     f'payload_type: {payload_spec}\n' | ||||
| #                 ) | ||||
| 
 | ||||
| #     # full code decode should always be attempted! | ||||
| #     if roundtrip is None: | ||||
| #         breakpoint() | ||||
| 
 | ||||
| #     return roundtrip | ||||
| 
 | ||||
| 
 | ||||
| # ?TODO? maybe remove since covered in the newer `test_pldrx_limiting` | ||||
| # via end-2-end testing of all this? | ||||
| # -[ ] IOW do we really NEED this lowlevel unit testing? | ||||
| # | ||||
| # def test_limit_msgspec( | ||||
| #     debug_mode: bool, | ||||
| # ): | ||||
| #     ''' | ||||
| #     Internals unit testing to verify that type-limiting an IPC ctx's | ||||
| #     msg spec with `Pldrx.limit_plds()` results in various | ||||
| #     encapsulated `msgspec` object settings and state. | ||||
| 
 | ||||
| #     ''' | ||||
| #     async def main(): | ||||
| #         async with tractor.open_root_actor( | ||||
| #             debug_mode=debug_mode, | ||||
| #         ): | ||||
| #             # ensure we can round-trip a boxing `PayloadMsg` | ||||
| #             assert chk_pld_type( | ||||
| #                 payload_spec=Any, | ||||
| #                 pld=None, | ||||
| #                 expect_roundtrip=True, | ||||
| #             ) | ||||
| 
 | ||||
| #             # verify that a mis-typed payload value won't decode | ||||
| #             assert not chk_pld_type( | ||||
| #                 payload_spec=int, | ||||
| #                 pld='doggy', | ||||
| #             ) | ||||
| 
 | ||||
| #             # parametrize the boxed `.pld` type as a custom-struct | ||||
| #             # and ensure that parametrization propagates | ||||
| #             # to all payload-msg-spec-able subtypes! | ||||
| #             class CustomPayload(Struct): | ||||
| #                 name: str | ||||
| #                 value: Any | ||||
| 
 | ||||
| #             assert not chk_pld_type( | ||||
| #                 payload_spec=CustomPayload, | ||||
| #                 pld='doggy', | ||||
| #             ) | ||||
| 
 | ||||
| #             assert chk_pld_type( | ||||
| #                 payload_spec=CustomPayload, | ||||
| #                 pld=CustomPayload(name='doggy', value='urmom') | ||||
| #             ) | ||||
| 
 | ||||
| #             # yah, we can `.pause_from_sync()` now! | ||||
| #             # breakpoint() | ||||
| 
 | ||||
| #     trio.run(main) | ||||
|  | @ -0,0 +1,167 @@ | |||
| """ | ||||
| Shared mem primitives and APIs. | ||||
| 
 | ||||
| """ | ||||
| import uuid | ||||
| 
 | ||||
| # import numpy | ||||
| import pytest | ||||
| import trio | ||||
| import tractor | ||||
| from tractor._shm import ( | ||||
|     open_shm_list, | ||||
|     attach_shm_list, | ||||
| ) | ||||
| 
 | ||||
| 
 | ||||
| @tractor.context | ||||
| async def child_attach_shml_alot( | ||||
|     ctx: tractor.Context, | ||||
|     shm_key: str, | ||||
| ) -> None: | ||||
| 
 | ||||
|     await ctx.started(shm_key) | ||||
| 
 | ||||
|     # now try to attach a boatload of times in a loop.. | ||||
|     for _ in range(1000): | ||||
|         shml = attach_shm_list( | ||||
|             key=shm_key, | ||||
|             readonly=False, | ||||
|         ) | ||||
|         assert shml.shm.name == shm_key | ||||
|         await trio.sleep(0.001) | ||||
| 
 | ||||
| 
 | ||||
| def test_child_attaches_alot(): | ||||
|     async def main(): | ||||
|         async with tractor.open_nursery() as an: | ||||
| 
 | ||||
|             # allocate writeable list in parent | ||||
|             key = f'shml_{uuid.uuid4()}' | ||||
|             shml = open_shm_list( | ||||
|                 key=key, | ||||
|             ) | ||||
| 
 | ||||
|             portal = await an.start_actor( | ||||
|                 'shm_attacher', | ||||
|                 enable_modules=[__name__], | ||||
|             ) | ||||
| 
 | ||||
|             async with ( | ||||
|                 portal.open_context( | ||||
|                     child_attach_shml_alot, | ||||
|                     shm_key=shml.key, | ||||
|                 ) as (ctx, start_val), | ||||
|             ): | ||||
|                 assert start_val == key | ||||
|                 await ctx.result() | ||||
| 
 | ||||
|             await portal.cancel_actor() | ||||
| 
 | ||||
|     trio.run(main) | ||||
| 
 | ||||
| 
 | ||||
| @tractor.context | ||||
| async def child_read_shm_list( | ||||
|     ctx: tractor.Context, | ||||
|     shm_key: str, | ||||
|     use_str: bool, | ||||
|     frame_size: int, | ||||
| ) -> None: | ||||
| 
 | ||||
|     # attach in child | ||||
|     shml = attach_shm_list( | ||||
|         key=shm_key, | ||||
|         # dtype=str if use_str else float, | ||||
|     ) | ||||
|     await ctx.started(shml.key) | ||||
| 
 | ||||
|     async with ctx.open_stream() as stream: | ||||
|         async for i in stream: | ||||
|             print(f'(child): reading shm list index: {i}') | ||||
| 
 | ||||
|             if use_str: | ||||
|                 expect = str(float(i)) | ||||
|             else: | ||||
|                 expect = float(i) | ||||
| 
 | ||||
|             if frame_size == 1: | ||||
|                 val = shml[i] | ||||
|                 assert expect == val | ||||
|                 print(f'(child): reading value: {val}') | ||||
|             else: | ||||
|                 frame = shml[i - frame_size:i] | ||||
|                 print(f'(child): reading frame: {frame}') | ||||
| 
 | ||||
| 
 | ||||
| @pytest.mark.parametrize( | ||||
|     'use_str', | ||||
|     [False, True], | ||||
|     ids=lambda i: f'use_str_values={i}', | ||||
| ) | ||||
| @pytest.mark.parametrize( | ||||
|     'frame_size', | ||||
|     [1, 2**6, 2**10], | ||||
|     ids=lambda i: f'frame_size={i}', | ||||
| ) | ||||
| def test_parent_writer_child_reader( | ||||
|     use_str: bool, | ||||
|     frame_size: int, | ||||
| ): | ||||
| 
 | ||||
|     async def main(): | ||||
|         async with tractor.open_nursery( | ||||
|             # debug_mode=True, | ||||
|         ) as an: | ||||
| 
 | ||||
|             portal = await an.start_actor( | ||||
|                 'shm_reader', | ||||
|                 enable_modules=[__name__], | ||||
|                 debug_mode=True, | ||||
|             ) | ||||
| 
 | ||||
|             # allocate writeable list in parent | ||||
|             key = 'shm_list' | ||||
|             seq_size = int(2 * 2 ** 10) | ||||
|             shml = open_shm_list( | ||||
|                 key=key, | ||||
|                 size=seq_size, | ||||
|                 dtype=str if use_str else float, | ||||
|                 readonly=False, | ||||
|             ) | ||||
| 
 | ||||
|             async with ( | ||||
|                 portal.open_context( | ||||
|                     child_read_shm_list, | ||||
|                     shm_key=key, | ||||
|                     use_str=use_str, | ||||
|                     frame_size=frame_size, | ||||
|                 ) as (ctx, sent), | ||||
| 
 | ||||
|                 ctx.open_stream() as stream, | ||||
|             ): | ||||
| 
 | ||||
|                 assert sent == key | ||||
| 
 | ||||
|                 for i in range(seq_size): | ||||
| 
 | ||||
|                     val = float(i) | ||||
|                     if use_str: | ||||
|                         val = str(val) | ||||
| 
 | ||||
|                     # print(f'(parent): writing {val}') | ||||
|                     shml[i] = val | ||||
| 
 | ||||
|                     # only on frame fills do we | ||||
|                     # signal to the child that a frame's | ||||
|                     # worth is ready. | ||||
|                     if (i % frame_size) == 0: | ||||
|                         print(f'(parent): signalling frame full on {val}') | ||||
|                         await stream.send(i) | ||||
|                 else: | ||||
|                     print(f'(parent): signalling final frame on {val}') | ||||
|                     await stream.send(i) | ||||
| 
 | ||||
|             await portal.cancel_actor() | ||||
| 
 | ||||
|     trio.run(main) | ||||
|  | @ -67,4 +67,4 @@ from ._root import ( | |||
| from ._ipc import Channel as Channel | ||||
| from ._portal import Portal as Portal | ||||
| from ._runtime import Actor as Actor | ||||
| from . import hilevel as hilevel | ||||
| # from . import hilevel as hilevel | ||||
|  |  | |||
|  | @ -47,6 +47,9 @@ from functools import partial | |||
| import inspect | ||||
| from pprint import pformat | ||||
| import textwrap | ||||
| from types import ( | ||||
|     UnionType, | ||||
| ) | ||||
| from typing import ( | ||||
|     Any, | ||||
|     AsyncGenerator, | ||||
|  | @ -79,6 +82,7 @@ from .msg import ( | |||
|     MsgType, | ||||
|     NamespacePath, | ||||
|     PayloadT, | ||||
|     Return, | ||||
|     Started, | ||||
|     Stop, | ||||
|     Yield, | ||||
|  | @ -242,11 +246,13 @@ class Context: | |||
|     # a drain loop? | ||||
|     # _res_scope: trio.CancelScope|None = None | ||||
| 
 | ||||
|     _outcome_msg: Return|Error|ContextCancelled = Unresolved | ||||
| 
 | ||||
|     # on a clean exit there should be a final value | ||||
|     # delivered from the far end "callee" task, so | ||||
|     # this value is only set on one side. | ||||
|     # _result: Any | int = None | ||||
|     _result: Any|Unresolved = Unresolved | ||||
|     _result: PayloadT|Unresolved = Unresolved | ||||
| 
 | ||||
|     # if the local "caller"  task errors this value is always set | ||||
|     # to the error that was captured in the | ||||
|  | @ -1196,9 +1202,11 @@ class Context: | |||
| 
 | ||||
|         ''' | ||||
|         __tracebackhide__: bool = hide_tb | ||||
|         assert self._portal, ( | ||||
|             '`Context.wait_for_result()` can not be called from callee side!' | ||||
|         ) | ||||
|         if not self._portal: | ||||
|             raise RuntimeError( | ||||
|                 'Invalid usage of `Context.wait_for_result()`!\n' | ||||
|                 'Not valid on child-side IPC ctx!\n' | ||||
|             ) | ||||
|         if self._final_result_is_set(): | ||||
|             return self._result | ||||
| 
 | ||||
|  | @ -1219,6 +1227,8 @@ class Context: | |||
|             # since every message should be delivered via the normal | ||||
|             # `._deliver_msg()` route which will appropriately set | ||||
|             # any `.maybe_error`. | ||||
|             outcome_msg: Return|Error|ContextCancelled | ||||
|             drained_msgs: list[MsgType] | ||||
|             ( | ||||
|                 outcome_msg, | ||||
|                 drained_msgs, | ||||
|  | @ -1226,11 +1236,19 @@ class Context: | |||
|                 ctx=self, | ||||
|                 hide_tb=hide_tb, | ||||
|             ) | ||||
| 
 | ||||
|             drained_status: str = ( | ||||
|                 'Ctx drained to final outcome msg\n\n' | ||||
|                 f'{outcome_msg}\n' | ||||
|             ) | ||||
| 
 | ||||
|             # ?XXX, should already be set in `._deliver_msg()` right? | ||||
|             if self._outcome_msg is not Unresolved: | ||||
|                 # from .devx import _debug | ||||
|                 # await _debug.pause() | ||||
|                 assert self._outcome_msg is outcome_msg | ||||
|             else: | ||||
|                 self._outcome_msg = outcome_msg | ||||
| 
 | ||||
|             if drained_msgs: | ||||
|                 drained_status += ( | ||||
|                     '\n' | ||||
|  | @ -1738,7 +1756,6 @@ class Context: | |||
| 
 | ||||
|                 f'{structfmt(msg)}\n' | ||||
|             ) | ||||
| 
 | ||||
|             # NOTE: if an error is deteced we should always still | ||||
|             # send it through the feeder-mem-chan and expect | ||||
|             # it to be raised by any context (stream) consumer | ||||
|  | @ -1750,6 +1767,21 @@ class Context: | |||
|             # normally the task that should get cancelled/error | ||||
|             # from some remote fault! | ||||
|             send_chan.send_nowait(msg) | ||||
|             match msg: | ||||
|                 case Stop(): | ||||
|                     if (stream := self._stream): | ||||
|                         stream._stop_msg = msg | ||||
| 
 | ||||
|                 case Return(): | ||||
|                     if not self._outcome_msg: | ||||
|                         log.warning( | ||||
|                             f'Setting final outcome msg AFTER ' | ||||
|                             f'`._rx_chan.send()`??\n' | ||||
|                             f'\n' | ||||
|                             f'{msg}' | ||||
|                         ) | ||||
|                         self._outcome_msg = msg | ||||
| 
 | ||||
|             return True | ||||
| 
 | ||||
|         except trio.BrokenResourceError: | ||||
|  | @ -2006,7 +2038,7 @@ async def open_context_from_portal( | |||
|             # the dialog, the `Error` msg should be raised from the `msg` | ||||
|             # handling block below. | ||||
|             try: | ||||
|                 started_msg, first = await ctx._pld_rx.recv_msg_w_pld( | ||||
|                 started_msg, first = await ctx._pld_rx.recv_msg( | ||||
|                     ipc=ctx, | ||||
|                     expect_msg=Started, | ||||
|                     passthrough_non_pld_msgs=False, | ||||
|  | @ -2371,7 +2403,8 @@ async def open_context_from_portal( | |||
|             # displaying `ContextCancelled` traces where the | ||||
|             # cause of crash/exit IS due to something in | ||||
|             # user/app code on either end of the context. | ||||
|             and not rxchan._closed | ||||
|             and | ||||
|             not rxchan._closed | ||||
|         ): | ||||
|             # XXX NOTE XXX: and again as per above, we mask any | ||||
|             # `trio.Cancelled` raised here so as to NOT mask | ||||
|  | @ -2430,6 +2463,7 @@ async def open_context_from_portal( | |||
|         # FINALLY, remove the context from runtime tracking and | ||||
|         # exit! | ||||
|         log.runtime( | ||||
|         # log.cancel( | ||||
|             f'De-allocating IPC ctx opened with {ctx.side!r} peer \n' | ||||
|             f'uid: {uid}\n' | ||||
|             f'cid: {ctx.cid}\n' | ||||
|  | @ -2485,7 +2519,6 @@ def mk_context( | |||
|         _caller_info=caller_info, | ||||
|         **kwargs, | ||||
|     ) | ||||
|     pld_rx._ctx = ctx | ||||
|     ctx._result = Unresolved | ||||
|     return ctx | ||||
| 
 | ||||
|  | @ -2548,7 +2581,14 @@ def context( | |||
|     name: str | ||||
|     param: Type | ||||
|     for name, param in annots.items(): | ||||
|         if param is Context: | ||||
|         if ( | ||||
|             param is Context | ||||
|             or ( | ||||
|                 isinstance(param, UnionType) | ||||
|                 and | ||||
|                 Context in param.__args__ | ||||
|             ) | ||||
|         ): | ||||
|             ctx_var_name: str = name | ||||
|             break | ||||
|     else: | ||||
|  |  | |||
|  | @ -432,9 +432,13 @@ class RemoteActorError(Exception): | |||
|         Error type boxed by last actor IPC hop. | ||||
| 
 | ||||
|         ''' | ||||
|         if self._boxed_type is None: | ||||
|         if ( | ||||
|             self._boxed_type is None | ||||
|             and | ||||
|             (ipc_msg := self._ipc_msg) | ||||
|         ): | ||||
|             self._boxed_type = get_err_type( | ||||
|                 self._ipc_msg.boxed_type_str | ||||
|                 ipc_msg.boxed_type_str | ||||
|             ) | ||||
| 
 | ||||
|         return self._boxed_type | ||||
|  | @ -1143,6 +1147,8 @@ def unpack_error( | |||
|     which is the responsibilitiy of the caller. | ||||
| 
 | ||||
|     ''' | ||||
|     # XXX, apparently we pass all sorts of msgs here? | ||||
|     # kinda odd but seems like maybe they shouldn't be? | ||||
|     if not isinstance(msg, Error): | ||||
|         return None | ||||
| 
 | ||||
|  |  | |||
|  | @ -184,7 +184,7 @@ class Portal: | |||
|                 ( | ||||
|                     self._final_result_msg, | ||||
|                     self._final_result_pld, | ||||
|                 ) = await self._expect_result_ctx._pld_rx.recv_msg_w_pld( | ||||
|                 ) = await self._expect_result_ctx._pld_rx.recv_msg( | ||||
|                     ipc=self._expect_result_ctx, | ||||
|                     expect_msg=Return, | ||||
|                 ) | ||||
|  |  | |||
|  | @ -649,6 +649,10 @@ async def _invoke( | |||
|                 ) | ||||
|                 # set and shuttle final result to "parent"-side task. | ||||
|                 ctx._result = res | ||||
|                 log.runtime( | ||||
|                     f'Sending result msg and exiting {ctx.side!r}\n' | ||||
|                     f'{return_msg}\n' | ||||
|                 ) | ||||
|                 await chan.send(return_msg) | ||||
| 
 | ||||
|             # NOTE: this happens IFF `ctx._scope.cancel()` is | ||||
|  |  | |||
|  | @ -836,8 +836,10 @@ class Actor: | |||
|             )] | ||||
|         except KeyError: | ||||
|             report: str = ( | ||||
|                 'Ignoring invalid IPC ctx msg!\n\n' | ||||
|                 f'<=? {uid}\n\n' | ||||
|                 'Ignoring invalid IPC msg!?\n' | ||||
|                 f'Ctx seems to not/no-longer exist??\n' | ||||
|                 f'\n' | ||||
|                 f'<=? {uid}\n' | ||||
|                 f'  |_{pretty_struct.pformat(msg)}\n' | ||||
|             ) | ||||
|             match msg: | ||||
|  |  | |||
|  | @ -0,0 +1,833 @@ | |||
| # tractor: structured concurrent "actors". | ||||
| # Copyright 2018-eternity Tyler Goodlet. | ||||
| 
 | ||||
| # This program is free software: you can redistribute it and/or modify | ||||
| # it under the terms of the GNU Affero General Public License as published by | ||||
| # the Free Software Foundation, either version 3 of the License, or | ||||
| # (at your option) any later version. | ||||
| 
 | ||||
| # This program is distributed in the hope that it will be useful, | ||||
| # but WITHOUT ANY WARRANTY; without even the implied warranty of | ||||
| # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the | ||||
| # GNU Affero General Public License for more details. | ||||
| 
 | ||||
| # You should have received a copy of the GNU Affero General Public License | ||||
| # along with this program.  If not, see <https://www.gnu.org/licenses/>. | ||||
| 
 | ||||
| """ | ||||
| SC friendly shared memory management geared at real-time | ||||
| processing. | ||||
| 
 | ||||
| Support for ``numpy`` compatible array-buffers is provided but is | ||||
| considered optional within the context of this runtime-library. | ||||
| 
 | ||||
| """ | ||||
| from __future__ import annotations | ||||
| from sys import byteorder | ||||
| import time | ||||
| from typing import Optional | ||||
| from multiprocessing import shared_memory as shm | ||||
| from multiprocessing.shared_memory import ( | ||||
|     SharedMemory, | ||||
|     ShareableList, | ||||
| ) | ||||
| 
 | ||||
| from msgspec import Struct | ||||
| import tractor | ||||
| 
 | ||||
| from .log import get_logger | ||||
| 
 | ||||
| 
 | ||||
| _USE_POSIX = getattr(shm, '_USE_POSIX', False) | ||||
| if _USE_POSIX: | ||||
|     from _posixshmem import shm_unlink | ||||
| 
 | ||||
| 
 | ||||
| try: | ||||
|     import numpy as np | ||||
|     from numpy.lib import recfunctions as rfn | ||||
|     import nptyping | ||||
| except ImportError: | ||||
|     pass | ||||
| 
 | ||||
| 
 | ||||
| log = get_logger(__name__) | ||||
| 
 | ||||
| 
 | ||||
| def disable_mantracker(): | ||||
|     ''' | ||||
|     Disable all ``multiprocessing``` "resource tracking" machinery since | ||||
|     it's an absolute multi-threaded mess of non-SC madness. | ||||
| 
 | ||||
|     ''' | ||||
|     from multiprocessing import resource_tracker as mantracker | ||||
| 
 | ||||
|     # Tell the "resource tracker" thing to fuck off. | ||||
|     class ManTracker(mantracker.ResourceTracker): | ||||
|         def register(self, name, rtype): | ||||
|             pass | ||||
| 
 | ||||
|         def unregister(self, name, rtype): | ||||
|             pass | ||||
| 
 | ||||
|         def ensure_running(self): | ||||
|             pass | ||||
| 
 | ||||
|     # "know your land and know your prey" | ||||
|     # https://www.dailymotion.com/video/x6ozzco | ||||
|     mantracker._resource_tracker = ManTracker() | ||||
|     mantracker.register = mantracker._resource_tracker.register | ||||
|     mantracker.ensure_running = mantracker._resource_tracker.ensure_running | ||||
|     mantracker.unregister = mantracker._resource_tracker.unregister | ||||
|     mantracker.getfd = mantracker._resource_tracker.getfd | ||||
| 
 | ||||
| 
 | ||||
| disable_mantracker() | ||||
| 
 | ||||
| 
 | ||||
| class SharedInt: | ||||
|     ''' | ||||
|     Wrapper around a single entry shared memory array which | ||||
|     holds an ``int`` value used as an index counter. | ||||
| 
 | ||||
|     ''' | ||||
|     def __init__( | ||||
|         self, | ||||
|         shm: SharedMemory, | ||||
|     ) -> None: | ||||
|         self._shm = shm | ||||
| 
 | ||||
|     @property | ||||
|     def value(self) -> int: | ||||
|         return int.from_bytes(self._shm.buf, byteorder) | ||||
| 
 | ||||
|     @value.setter | ||||
|     def value(self, value) -> None: | ||||
|         self._shm.buf[:] = value.to_bytes(self._shm.size, byteorder) | ||||
| 
 | ||||
|     def destroy(self) -> None: | ||||
|         if _USE_POSIX: | ||||
|             # We manually unlink to bypass all the "resource tracker" | ||||
|             # nonsense meant for non-SC systems. | ||||
|             name = self._shm.name | ||||
|             try: | ||||
|                 shm_unlink(name) | ||||
|             except FileNotFoundError: | ||||
|                 # might be a teardown race here? | ||||
|                 log.warning(f'Shm for {name} already unlinked?') | ||||
| 
 | ||||
| 
 | ||||
| class NDToken(Struct, frozen=True): | ||||
|     ''' | ||||
|     Internal represenation of a shared memory ``numpy`` array "token" | ||||
|     which can be used to key and load a system (OS) wide shm entry | ||||
|     and correctly read the array by type signature. | ||||
| 
 | ||||
|     This type is msg safe. | ||||
| 
 | ||||
|     ''' | ||||
|     shm_name: str  # this servers as a "key" value | ||||
|     shm_first_index_name: str | ||||
|     shm_last_index_name: str | ||||
|     dtype_descr: tuple | ||||
|     size: int  # in struct-array index / row terms | ||||
| 
 | ||||
|     # TODO: use nptyping here on dtypes | ||||
|     @property | ||||
|     def dtype(self) -> list[tuple[str, str, tuple[int, ...]]]: | ||||
|         return np.dtype( | ||||
|             list( | ||||
|                 map(tuple, self.dtype_descr) | ||||
|             ) | ||||
|         ).descr | ||||
| 
 | ||||
|     def as_msg(self): | ||||
|         return self.to_dict() | ||||
| 
 | ||||
|     @classmethod | ||||
|     def from_msg(cls, msg: dict) -> NDToken: | ||||
|         if isinstance(msg, NDToken): | ||||
|             return msg | ||||
| 
 | ||||
|         # TODO: native struct decoding | ||||
|         # return _token_dec.decode(msg) | ||||
| 
 | ||||
|         msg['dtype_descr'] = tuple(map(tuple, msg['dtype_descr'])) | ||||
|         return NDToken(**msg) | ||||
| 
 | ||||
| 
 | ||||
| # _token_dec = msgspec.msgpack.Decoder(NDToken) | ||||
| 
 | ||||
| # TODO: this api? | ||||
| # _known_tokens = tractor.ActorVar('_shm_tokens', {}) | ||||
| # _known_tokens = tractor.ContextStack('_known_tokens', ) | ||||
| # _known_tokens = trio.RunVar('shms', {}) | ||||
| 
 | ||||
| # TODO: this should maybe be provided via | ||||
| # a `.trionics.maybe_open_context()` wrapper factory? | ||||
| # process-local store of keys to tokens | ||||
| _known_tokens: dict[str, NDToken] = {} | ||||
| 
 | ||||
| 
 | ||||
| def get_shm_token(key: str) -> NDToken | None: | ||||
|     ''' | ||||
|     Convenience func to check if a token | ||||
|     for the provided key is known by this process. | ||||
| 
 | ||||
|     Returns either the ``numpy`` token or a string for a shared list. | ||||
| 
 | ||||
|     ''' | ||||
|     return _known_tokens.get(key) | ||||
| 
 | ||||
| 
 | ||||
| def _make_token( | ||||
|     key: str, | ||||
|     size: int, | ||||
|     dtype: np.dtype, | ||||
| 
 | ||||
| ) -> NDToken: | ||||
|     ''' | ||||
|     Create a serializable token that can be used | ||||
|     to access a shared array. | ||||
| 
 | ||||
|     ''' | ||||
|     return NDToken( | ||||
|         shm_name=key, | ||||
|         shm_first_index_name=key + "_first", | ||||
|         shm_last_index_name=key + "_last", | ||||
|         dtype_descr=tuple(np.dtype(dtype).descr), | ||||
|         size=size, | ||||
|     ) | ||||
| 
 | ||||
| 
 | ||||
| class ShmArray: | ||||
|     ''' | ||||
|     A shared memory ``numpy.ndarray`` API. | ||||
| 
 | ||||
|     An underlying shared memory buffer is allocated based on | ||||
|     a user specified ``numpy.ndarray``. This fixed size array | ||||
|     can be read and written to by pushing data both onto the "front" | ||||
|     or "back" of a set index range. The indexes for the "first" and | ||||
|     "last" index are themselves stored in shared memory (accessed via | ||||
|     ``SharedInt`` interfaces) values such that multiple processes can | ||||
|     interact with the same array using a synchronized-index. | ||||
| 
 | ||||
|     ''' | ||||
|     def __init__( | ||||
|         self, | ||||
|         shmarr: np.ndarray, | ||||
|         first: SharedInt, | ||||
|         last: SharedInt, | ||||
|         shm: SharedMemory, | ||||
|         # readonly: bool = True, | ||||
|     ) -> None: | ||||
|         self._array = shmarr | ||||
| 
 | ||||
|         # indexes for first and last indices corresponding | ||||
|         # to fille data | ||||
|         self._first = first | ||||
|         self._last = last | ||||
| 
 | ||||
|         self._len = len(shmarr) | ||||
|         self._shm = shm | ||||
|         self._post_init: bool = False | ||||
| 
 | ||||
|         # pushing data does not write the index (aka primary key) | ||||
|         self._write_fields: list[str] | None = None | ||||
|         dtype = shmarr.dtype | ||||
|         if dtype.fields: | ||||
|             self._write_fields = list(shmarr.dtype.fields.keys())[1:] | ||||
| 
 | ||||
|     # TODO: ringbuf api? | ||||
| 
 | ||||
|     @property | ||||
|     def _token(self) -> NDToken: | ||||
|         return NDToken( | ||||
|             shm_name=self._shm.name, | ||||
|             shm_first_index_name=self._first._shm.name, | ||||
|             shm_last_index_name=self._last._shm.name, | ||||
|             dtype_descr=tuple(self._array.dtype.descr), | ||||
|             size=self._len, | ||||
|         ) | ||||
| 
 | ||||
|     @property | ||||
|     def token(self) -> dict: | ||||
|         """Shared memory token that can be serialized and used by | ||||
|         another process to attach to this array. | ||||
|         """ | ||||
|         return self._token.as_msg() | ||||
| 
 | ||||
|     @property | ||||
|     def index(self) -> int: | ||||
|         return self._last.value % self._len | ||||
| 
 | ||||
|     @property | ||||
|     def array(self) -> np.ndarray: | ||||
|         ''' | ||||
|         Return an up-to-date ``np.ndarray`` view of the | ||||
|         so-far-written data to the underlying shm buffer. | ||||
| 
 | ||||
|         ''' | ||||
|         a = self._array[self._first.value:self._last.value] | ||||
| 
 | ||||
|         # first, last = self._first.value, self._last.value | ||||
|         # a = self._array[first:last] | ||||
| 
 | ||||
|         # TODO: eventually comment this once we've not seen it in the | ||||
|         # wild in a long time.. | ||||
|         # XXX: race where first/last indexes cause a reader | ||||
|         # to load an empty array.. | ||||
|         if len(a) == 0 and self._post_init: | ||||
|             raise RuntimeError('Empty array race condition hit!?') | ||||
|             # breakpoint() | ||||
| 
 | ||||
|         return a | ||||
| 
 | ||||
|     def ustruct( | ||||
|         self, | ||||
|         fields: Optional[list[str]] = None, | ||||
| 
 | ||||
|         # type that all field values will be cast to | ||||
|         # in the returned view. | ||||
|         common_dtype: np.dtype = float, | ||||
| 
 | ||||
|     ) -> np.ndarray: | ||||
| 
 | ||||
|         array = self._array | ||||
| 
 | ||||
|         if fields: | ||||
|             selection = array[fields] | ||||
|             # fcount = len(fields) | ||||
|         else: | ||||
|             selection = array | ||||
|             # fcount = len(array.dtype.fields) | ||||
| 
 | ||||
|         # XXX: manual ``.view()`` attempt that also doesn't work. | ||||
|         # uview = selection.view( | ||||
|         #     dtype='<f16', | ||||
|         # ).reshape(-1, 4, order='A') | ||||
| 
 | ||||
|         # assert len(selection) == len(uview) | ||||
| 
 | ||||
|         u = rfn.structured_to_unstructured( | ||||
|             selection, | ||||
|             # dtype=float, | ||||
|             copy=True, | ||||
|         ) | ||||
| 
 | ||||
|         # unstruct = np.ndarray(u.shape, dtype=a.dtype, buffer=shm.buf) | ||||
|         # array[:] = a[:] | ||||
|         return u | ||||
|         # return ShmArray( | ||||
|         #     shmarr=u, | ||||
|         #     first=self._first, | ||||
|         #     last=self._last, | ||||
|         #     shm=self._shm | ||||
|         # ) | ||||
| 
 | ||||
|     def last( | ||||
|         self, | ||||
|         length: int = 1, | ||||
| 
 | ||||
|     ) -> np.ndarray: | ||||
|         ''' | ||||
|         Return the last ``length``'s worth of ("row") entries from the | ||||
|         array. | ||||
| 
 | ||||
|         ''' | ||||
|         return self.array[-length:] | ||||
| 
 | ||||
|     def push( | ||||
|         self, | ||||
|         data: np.ndarray, | ||||
| 
 | ||||
|         field_map: Optional[dict[str, str]] = None, | ||||
|         prepend: bool = False, | ||||
|         update_first: bool = True, | ||||
|         start: int | None = None, | ||||
| 
 | ||||
|     ) -> int: | ||||
|         ''' | ||||
|         Ring buffer like "push" to append data | ||||
|         into the buffer and return updated "last" index. | ||||
| 
 | ||||
|         NB: no actual ring logic yet to give a "loop around" on overflow | ||||
|         condition, lel. | ||||
| 
 | ||||
|         ''' | ||||
|         length = len(data) | ||||
| 
 | ||||
|         if prepend: | ||||
|             index = (start or self._first.value) - length | ||||
| 
 | ||||
|             if index < 0: | ||||
|                 raise ValueError( | ||||
|                     f'Array size of {self._len} was overrun during prepend.\n' | ||||
|                     f'You have passed {abs(index)} too many datums.' | ||||
|                 ) | ||||
| 
 | ||||
|         else: | ||||
|             index = start if start is not None else self._last.value | ||||
| 
 | ||||
|         end = index + length | ||||
| 
 | ||||
|         if field_map: | ||||
|             src_names, dst_names = zip(*field_map.items()) | ||||
|         else: | ||||
|             dst_names = src_names = self._write_fields | ||||
| 
 | ||||
|         try: | ||||
|             self._array[ | ||||
|                 list(dst_names) | ||||
|             ][index:end] = data[list(src_names)][:] | ||||
| 
 | ||||
|             # NOTE: there was a race here between updating | ||||
|             # the first and last indices and when the next reader | ||||
|             # tries to access ``.array`` (which due to the index | ||||
|             # overlap will be empty). Pretty sure we've fixed it now | ||||
|             # but leaving this here as a reminder. | ||||
|             if ( | ||||
|                 prepend | ||||
|                 and update_first | ||||
|                 and length | ||||
|             ): | ||||
|                 assert index < self._first.value | ||||
| 
 | ||||
|             if ( | ||||
|                 index < self._first.value | ||||
|                 and update_first | ||||
|             ): | ||||
|                 assert prepend, 'prepend=True not passed but index decreased?' | ||||
|                 self._first.value = index | ||||
| 
 | ||||
|             elif not prepend: | ||||
|                 self._last.value = end | ||||
| 
 | ||||
|             self._post_init = True | ||||
|             return end | ||||
| 
 | ||||
|         except ValueError as err: | ||||
|             if field_map: | ||||
|                 raise | ||||
| 
 | ||||
|             # should raise if diff detected | ||||
|             self.diff_err_fields(data) | ||||
|             raise err | ||||
| 
 | ||||
|     def diff_err_fields( | ||||
|         self, | ||||
|         data: np.ndarray, | ||||
|     ) -> None: | ||||
|         # reraise with any field discrepancy | ||||
|         our_fields, their_fields = ( | ||||
|             set(self._array.dtype.fields), | ||||
|             set(data.dtype.fields), | ||||
|         ) | ||||
| 
 | ||||
|         only_in_ours = our_fields - their_fields | ||||
|         only_in_theirs = their_fields - our_fields | ||||
| 
 | ||||
|         if only_in_ours: | ||||
|             raise TypeError( | ||||
|                 f"Input array is missing field(s): {only_in_ours}" | ||||
|             ) | ||||
|         elif only_in_theirs: | ||||
|             raise TypeError( | ||||
|                 f"Input array has unknown field(s): {only_in_theirs}" | ||||
|             ) | ||||
| 
 | ||||
|     # TODO: support "silent" prepends that don't update ._first.value? | ||||
|     def prepend( | ||||
|         self, | ||||
|         data: np.ndarray, | ||||
|     ) -> int: | ||||
|         end = self.push(data, prepend=True) | ||||
|         assert end | ||||
| 
 | ||||
|     def close(self) -> None: | ||||
|         self._first._shm.close() | ||||
|         self._last._shm.close() | ||||
|         self._shm.close() | ||||
| 
 | ||||
|     def destroy(self) -> None: | ||||
|         if _USE_POSIX: | ||||
|             # We manually unlink to bypass all the "resource tracker" | ||||
|             # nonsense meant for non-SC systems. | ||||
|             shm_unlink(self._shm.name) | ||||
| 
 | ||||
|         self._first.destroy() | ||||
|         self._last.destroy() | ||||
| 
 | ||||
|     def flush(self) -> None: | ||||
|         # TODO: flush to storage backend like markestore? | ||||
|         ... | ||||
| 
 | ||||
| 
 | ||||
| def open_shm_ndarray( | ||||
|     size: int, | ||||
|     key: str | None = None, | ||||
|     dtype: np.dtype | None = None, | ||||
|     append_start_index: int | None = None, | ||||
|     readonly: bool = False, | ||||
| 
 | ||||
| ) -> ShmArray: | ||||
|     ''' | ||||
|     Open a memory shared ``numpy`` using the standard library. | ||||
| 
 | ||||
|     This call unlinks (aka permanently destroys) the buffer on teardown | ||||
|     and thus should be used from the parent-most accessor (process). | ||||
| 
 | ||||
|     ''' | ||||
|     # create new shared mem segment for which we | ||||
|     # have write permission | ||||
|     a = np.zeros(size, dtype=dtype) | ||||
|     a['index'] = np.arange(len(a)) | ||||
| 
 | ||||
|     shm = SharedMemory( | ||||
|         name=key, | ||||
|         create=True, | ||||
|         size=a.nbytes | ||||
|     ) | ||||
|     array = np.ndarray( | ||||
|         a.shape, | ||||
|         dtype=a.dtype, | ||||
|         buffer=shm.buf | ||||
|     ) | ||||
|     array[:] = a[:] | ||||
|     array.setflags(write=int(not readonly)) | ||||
| 
 | ||||
|     token = _make_token( | ||||
|         key=key, | ||||
|         size=size, | ||||
|         dtype=dtype, | ||||
|     ) | ||||
| 
 | ||||
|     # create single entry arrays for storing an first and last indices | ||||
|     first = SharedInt( | ||||
|         shm=SharedMemory( | ||||
|             name=token.shm_first_index_name, | ||||
|             create=True, | ||||
|             size=4,  # std int | ||||
|         ) | ||||
|     ) | ||||
| 
 | ||||
|     last = SharedInt( | ||||
|         shm=SharedMemory( | ||||
|             name=token.shm_last_index_name, | ||||
|             create=True, | ||||
|             size=4,  # std int | ||||
|         ) | ||||
|     ) | ||||
| 
 | ||||
|     # Start the "real-time" append-updated (or "pushed-to") section | ||||
|     # after some start index: ``append_start_index``. This allows appending | ||||
|     # from a start point in the array which isn't the 0 index and looks | ||||
|     # something like, | ||||
|     # ------------------------- | ||||
|     # |              |        i | ||||
|     # _________________________ | ||||
|     # <-------------> <-------> | ||||
|     #  history         real-time | ||||
|     # | ||||
|     # Once fully "prepended", the history section will leave the | ||||
|     # ``ShmArray._start.value: int = 0`` and the yet-to-be written | ||||
|     # real-time section will start at ``ShmArray.index: int``. | ||||
| 
 | ||||
|     # this sets the index to nearly 2/3rds into the the length of | ||||
|     # the buffer leaving at least a "days worth of second samples" | ||||
|     # for the real-time section. | ||||
|     if append_start_index is None: | ||||
|         append_start_index = round(size * 0.616) | ||||
| 
 | ||||
|     last.value = first.value = append_start_index | ||||
| 
 | ||||
|     shmarr = ShmArray( | ||||
|         array, | ||||
|         first, | ||||
|         last, | ||||
|         shm, | ||||
|     ) | ||||
| 
 | ||||
|     assert shmarr._token == token | ||||
|     _known_tokens[key] = shmarr.token | ||||
| 
 | ||||
|     # "unlink" created shm on process teardown by | ||||
|     # pushing teardown calls onto actor context stack | ||||
|     stack = tractor.current_actor().lifetime_stack | ||||
|     stack.callback(shmarr.close) | ||||
|     stack.callback(shmarr.destroy) | ||||
| 
 | ||||
|     return shmarr | ||||
| 
 | ||||
| 
 | ||||
| def attach_shm_ndarray( | ||||
|     token: tuple[str, str, tuple[str, str]], | ||||
|     readonly: bool = True, | ||||
| 
 | ||||
| ) -> ShmArray: | ||||
|     ''' | ||||
|     Attach to an existing shared memory array previously | ||||
|     created by another process using ``open_shared_array``. | ||||
| 
 | ||||
|     No new shared mem is allocated but wrapper types for read/write | ||||
|     access are constructed. | ||||
| 
 | ||||
|     ''' | ||||
|     token = NDToken.from_msg(token) | ||||
|     key = token.shm_name | ||||
| 
 | ||||
|     if key in _known_tokens: | ||||
|         assert NDToken.from_msg(_known_tokens[key]) == token, "WTF" | ||||
| 
 | ||||
|     # XXX: ugh, looks like due to the ``shm_open()`` C api we can't | ||||
|     # actually place files in a subdir, see discussion here: | ||||
|     # https://stackoverflow.com/a/11103289 | ||||
| 
 | ||||
|     # attach to array buffer and view as per dtype | ||||
|     _err: Optional[Exception] = None | ||||
|     for _ in range(3): | ||||
|         try: | ||||
|             shm = SharedMemory( | ||||
|                 name=key, | ||||
|                 create=False, | ||||
|             ) | ||||
|             break | ||||
|         except OSError as oserr: | ||||
|             _err = oserr | ||||
|             time.sleep(0.1) | ||||
|     else: | ||||
|         if _err: | ||||
|             raise _err | ||||
| 
 | ||||
|     shmarr = np.ndarray( | ||||
|         (token.size,), | ||||
|         dtype=token.dtype, | ||||
|         buffer=shm.buf | ||||
|     ) | ||||
|     shmarr.setflags(write=int(not readonly)) | ||||
| 
 | ||||
|     first = SharedInt( | ||||
|         shm=SharedMemory( | ||||
|             name=token.shm_first_index_name, | ||||
|             create=False, | ||||
|             size=4,  # std int | ||||
|         ), | ||||
|     ) | ||||
|     last = SharedInt( | ||||
|         shm=SharedMemory( | ||||
|             name=token.shm_last_index_name, | ||||
|             create=False, | ||||
|             size=4,  # std int | ||||
|         ), | ||||
|     ) | ||||
| 
 | ||||
|     # make sure we can read | ||||
|     first.value | ||||
| 
 | ||||
|     sha = ShmArray( | ||||
|         shmarr, | ||||
|         first, | ||||
|         last, | ||||
|         shm, | ||||
|     ) | ||||
|     # read test | ||||
|     sha.array | ||||
| 
 | ||||
|     # Stash key -> token knowledge for future queries | ||||
|     # via `maybe_opepn_shm_array()` but only after we know | ||||
|     # we can attach. | ||||
|     if key not in _known_tokens: | ||||
|         _known_tokens[key] = token | ||||
| 
 | ||||
|     # "close" attached shm on actor teardown | ||||
|     tractor.current_actor().lifetime_stack.callback(sha.close) | ||||
| 
 | ||||
|     return sha | ||||
| 
 | ||||
| 
 | ||||
| def maybe_open_shm_ndarray( | ||||
|     key: str,  # unique identifier for segment | ||||
|     size: int, | ||||
|     dtype: np.dtype | None = None, | ||||
|     append_start_index: int = 0, | ||||
|     readonly: bool = True, | ||||
| 
 | ||||
| ) -> tuple[ShmArray, bool]: | ||||
|     ''' | ||||
|     Attempt to attach to a shared memory block using a "key" lookup | ||||
|     to registered blocks in the users overall "system" registry | ||||
|     (presumes you don't have the block's explicit token). | ||||
| 
 | ||||
|     This function is meant to solve the problem of discovering whether | ||||
|     a shared array token has been allocated or discovered by the actor | ||||
|     running in **this** process. Systems where multiple actors may seek | ||||
|     to access a common block can use this function to attempt to acquire | ||||
|     a token as discovered by the actors who have previously stored | ||||
|     a "key" -> ``NDToken`` map in an actor local (aka python global) | ||||
|     variable. | ||||
| 
 | ||||
|     If you know the explicit ``NDToken`` for your memory segment instead | ||||
|     use ``attach_shm_array``. | ||||
| 
 | ||||
|     ''' | ||||
|     try: | ||||
|         # see if we already know this key | ||||
|         token = _known_tokens[key] | ||||
|         return ( | ||||
|             attach_shm_ndarray( | ||||
|                 token=token, | ||||
|                 readonly=readonly, | ||||
|             ), | ||||
|             False,  # not newly opened | ||||
|         ) | ||||
|     except KeyError: | ||||
|         log.warning(f"Could not find {key} in shms cache") | ||||
|         if dtype: | ||||
|             token = _make_token( | ||||
|                 key, | ||||
|                 size=size, | ||||
|                 dtype=dtype, | ||||
|             ) | ||||
|         else: | ||||
| 
 | ||||
|             try: | ||||
|                 return ( | ||||
|                     attach_shm_ndarray( | ||||
|                         token=token, | ||||
|                         readonly=readonly, | ||||
|                     ), | ||||
|                     False, | ||||
|                 ) | ||||
|             except FileNotFoundError: | ||||
|                 log.warning(f"Could not attach to shm with token {token}") | ||||
| 
 | ||||
|         # This actor does not know about memory | ||||
|         # associated with the provided "key". | ||||
|         # Attempt to open a block and expect | ||||
|         # to fail if a block has been allocated | ||||
|         # on the OS by someone else. | ||||
|         return ( | ||||
|             open_shm_ndarray( | ||||
|                 key=key, | ||||
|                 size=size, | ||||
|                 dtype=dtype, | ||||
|                 append_start_index=append_start_index, | ||||
|                 readonly=readonly, | ||||
|             ), | ||||
|             True, | ||||
|         ) | ||||
| 
 | ||||
| 
 | ||||
| class ShmList(ShareableList): | ||||
|     ''' | ||||
|     Carbon copy of ``.shared_memory.ShareableList`` with a few | ||||
|     enhancements: | ||||
| 
 | ||||
|     - readonly mode via instance var flag  `._readonly: bool` | ||||
|     - ``.__getitem__()`` accepts ``slice`` inputs | ||||
|     - exposes the underlying buffer "name" as a ``.key: str`` | ||||
| 
 | ||||
|     ''' | ||||
|     def __init__( | ||||
|         self, | ||||
|         sequence: list | None = None, | ||||
|         *, | ||||
|         name: str | None = None, | ||||
|         readonly: bool = True | ||||
| 
 | ||||
|     ) -> None: | ||||
|         self._readonly = readonly | ||||
|         self._key = name | ||||
|         return super().__init__( | ||||
|             sequence=sequence, | ||||
|             name=name, | ||||
|         ) | ||||
| 
 | ||||
|     @property | ||||
|     def key(self) -> str: | ||||
|         return self._key | ||||
| 
 | ||||
|     @property | ||||
|     def readonly(self) -> bool: | ||||
|         return self._readonly | ||||
| 
 | ||||
|     def __setitem__( | ||||
|         self, | ||||
|         position, | ||||
|         value, | ||||
| 
 | ||||
|     ) -> None: | ||||
| 
 | ||||
|         # mimick ``numpy`` error | ||||
|         if self._readonly: | ||||
|             raise ValueError('assignment destination is read-only') | ||||
| 
 | ||||
|         return super().__setitem__(position, value) | ||||
| 
 | ||||
|     def __getitem__( | ||||
|         self, | ||||
|         indexish, | ||||
|     ) -> list: | ||||
| 
 | ||||
|         # NOTE: this is a non-writeable view (copy?) of the buffer | ||||
|         # in a new list instance. | ||||
|         if isinstance(indexish, slice): | ||||
|             return list(self)[indexish] | ||||
| 
 | ||||
|         return super().__getitem__(indexish) | ||||
| 
 | ||||
|     # TODO: should we offer a `.array` and `.push()` equivalent | ||||
|     # to the `ShmArray`? | ||||
|     # currently we have the following limitations: | ||||
|     # - can't write slices of input using traditional slice-assign | ||||
|     #   syntax due to the ``ShareableList.__setitem__()`` implementation. | ||||
|     # - ``list(shmlist)`` returns a non-mutable copy instead of | ||||
|     #   a writeable view which would be handier numpy-style ops. | ||||
| 
 | ||||
| 
 | ||||
| def open_shm_list( | ||||
|     key: str, | ||||
|     sequence: list | None = None, | ||||
|     size: int = int(2 ** 10), | ||||
|     dtype: float | int | bool | str | bytes | None = float, | ||||
|     readonly: bool = True, | ||||
| 
 | ||||
| ) -> ShmList: | ||||
| 
 | ||||
|     if sequence is None: | ||||
|         default = { | ||||
|             float: 0., | ||||
|             int: 0, | ||||
|             bool: True, | ||||
|             str: 'doggy', | ||||
|             None: None, | ||||
|         }[dtype] | ||||
|         sequence = [default] * size | ||||
| 
 | ||||
|     shml = ShmList( | ||||
|         sequence=sequence, | ||||
|         name=key, | ||||
|         readonly=readonly, | ||||
|     ) | ||||
| 
 | ||||
|     # "close" attached shm on actor teardown | ||||
|     try: | ||||
|         actor = tractor.current_actor() | ||||
|         actor.lifetime_stack.callback(shml.shm.close) | ||||
|         actor.lifetime_stack.callback(shml.shm.unlink) | ||||
|     except RuntimeError: | ||||
|         log.warning('tractor runtime not active, skipping teardown steps') | ||||
| 
 | ||||
|     return shml | ||||
| 
 | ||||
| 
 | ||||
| def attach_shm_list( | ||||
|     key: str, | ||||
|     readonly: bool = False, | ||||
| 
 | ||||
| ) -> ShmList: | ||||
| 
 | ||||
|     return ShmList( | ||||
|         name=key, | ||||
|         readonly=readonly, | ||||
|     ) | ||||
|  | @ -45,9 +45,11 @@ from .trionics import ( | |||
|     BroadcastReceiver, | ||||
| ) | ||||
| from tractor.msg import ( | ||||
|     # Return, | ||||
|     # Stop, | ||||
|     Error, | ||||
|     Return, | ||||
|     Stop, | ||||
|     MsgType, | ||||
|     PayloadT, | ||||
|     Yield, | ||||
| ) | ||||
| 
 | ||||
|  | @ -70,8 +72,7 @@ class MsgStream(trio.abc.Channel): | |||
|     A bidirectional message stream for receiving logically sequenced | ||||
|     values over an inter-actor IPC `Channel`. | ||||
| 
 | ||||
|     This is the type returned to a local task which entered either | ||||
|     `Portal.open_stream_from()` or `Context.open_stream()`. | ||||
| 
 | ||||
| 
 | ||||
|     Termination rules: | ||||
| 
 | ||||
|  | @ -94,6 +95,9 @@ class MsgStream(trio.abc.Channel): | |||
|         self._rx_chan = rx_chan | ||||
|         self._broadcaster = _broadcaster | ||||
| 
 | ||||
|         # any actual IPC msg which is effectively an `EndOfStream` | ||||
|         self._stop_msg: bool|Stop = False | ||||
| 
 | ||||
|         # flag to denote end of stream | ||||
|         self._eoc: bool|trio.EndOfChannel = False | ||||
|         self._closed: bool|trio.ClosedResourceError = False | ||||
|  | @ -125,16 +129,67 @@ class MsgStream(trio.abc.Channel): | |||
|     def receive_nowait( | ||||
|         self, | ||||
|         expect_msg: MsgType = Yield, | ||||
|     ): | ||||
|     ) -> PayloadT: | ||||
|         ctx: Context = self._ctx | ||||
|         return ctx._pld_rx.recv_pld_nowait( | ||||
|         ( | ||||
|             msg, | ||||
|             pld, | ||||
|         ) = ctx._pld_rx.recv_msg_nowait( | ||||
|             ipc=self, | ||||
|             expect_msg=expect_msg, | ||||
|         ) | ||||
| 
 | ||||
|         # ?TODO, maybe factor this into a hyper-common `unwrap_pld()` | ||||
|         # | ||||
|         match msg: | ||||
| 
 | ||||
|             # XXX, these never seems to ever hit? cool? | ||||
|             case Stop(): | ||||
|                 log.cancel( | ||||
|                     f'Msg-stream was ended via stop msg\n' | ||||
|                     f'{msg}' | ||||
|                 ) | ||||
|             case Error(): | ||||
|                 log.error( | ||||
|                     f'Msg-stream was ended via error msg\n' | ||||
|                     f'{msg}' | ||||
|                 ) | ||||
| 
 | ||||
|             # XXX NOTE, always set any final result on the ctx to | ||||
|             # avoid teardown race conditions where previously this msg | ||||
|             # would be consumed silently (by `.aclose()` doing its | ||||
|             # own "msg drain loop" but WITHOUT those `drained: lists[MsgType]` | ||||
|             # being post-close-processed! | ||||
|             # | ||||
|             # !!TODO, see the equiv todo-comment in `.receive()` | ||||
|             # around the `if drained:` where we should prolly | ||||
|             # ACTUALLY be doing this post-close processing?? | ||||
|             # | ||||
|             case Return(pld=pld): | ||||
|                 log.warning( | ||||
|                     f'Msg-stream final result msg for IPC ctx?\n' | ||||
|                     f'{msg}' | ||||
|                 ) | ||||
|                 # XXX TODO, this **should be covered** by higher | ||||
|                 # scoped runtime-side method calls such as | ||||
|                 # `Context._deliver_msg()`, so you should never | ||||
|                 # really see the warning above or else something | ||||
|                 # racy/out-of-order is likely going on between | ||||
|                 # actor-runtime-side push tasks and the user-app-side | ||||
|                 # consume tasks! | ||||
|                 # -[ ] figure out that set of race cases and fix! | ||||
|                 # -[ ] possibly return the `msg` given an input | ||||
|                 #     arg-flag is set so we can process the `Return` | ||||
|                 #     from the `.aclose()` caller? | ||||
|                 # | ||||
|                 # breakpoint()  # to debug this RACE CASE! | ||||
|                 ctx._result = pld | ||||
|                 ctx._outcome_msg = msg | ||||
| 
 | ||||
|         return pld | ||||
| 
 | ||||
|     async def receive( | ||||
|         self, | ||||
| 
 | ||||
|         hide_tb: bool = False, | ||||
|     ): | ||||
|         ''' | ||||
|  | @ -154,7 +209,7 @@ class MsgStream(trio.abc.Channel): | |||
|         #     except trio.EndOfChannel: | ||||
|         #         raise StopAsyncIteration | ||||
|         # | ||||
|         # see ``.aclose()`` for notes on the old behaviour prior to | ||||
|         # see `.aclose()` for notes on the old behaviour prior to | ||||
|         # introducing this | ||||
|         if self._eoc: | ||||
|             raise self._eoc | ||||
|  | @ -165,7 +220,11 @@ class MsgStream(trio.abc.Channel): | |||
|         src_err: Exception|None = None  # orig tb | ||||
|         try: | ||||
|             ctx: Context = self._ctx | ||||
|             return await ctx._pld_rx.recv_pld(ipc=self) | ||||
|             pld = await ctx._pld_rx.recv_pld( | ||||
|                 ipc=self, | ||||
|                 expect_msg=Yield, | ||||
|             ) | ||||
|             return pld | ||||
| 
 | ||||
|         # XXX: the stream terminates on either of: | ||||
|         # - `self._rx_chan.receive()` raising  after manual closure | ||||
|  | @ -174,7 +233,7 @@ class MsgStream(trio.abc.Channel): | |||
|         # - via a `Stop`-msg received from remote peer task. | ||||
|         #   NOTE | ||||
|         #   |_ previously this was triggered by calling | ||||
|         #   ``._rx_chan.aclose()`` on the send side of the channel | ||||
|         #   `._rx_chan.aclose()` on the send side of the channel | ||||
|         #   inside `Actor._deliver_ctx_payload()`, but now the 'stop' | ||||
|         #   message handling gets delegated to `PldRFx.recv_pld()` | ||||
|         #   internals. | ||||
|  | @ -198,11 +257,14 @@ class MsgStream(trio.abc.Channel): | |||
|         # terminated and signal this local iterator to stop | ||||
|         drained: list[Exception|dict] = await self.aclose() | ||||
|         if drained: | ||||
|             # ?TODO? pass these to the `._ctx._drained_msgs: deque` | ||||
|             # and then iterate them as part of any `.wait_for_result()` call? | ||||
|             # | ||||
|             # from .devx import pause | ||||
|             # await pause() | ||||
|         #  ^^^^^^^^TODO? pass these to the `._ctx._drained_msgs: | ||||
|         #  deque` and then iterate them as part of any | ||||
|         #  `.wait_for_result()` call? | ||||
|         # | ||||
|         # -[ ] move the match-case processing from | ||||
|         #     `.receive_nowait()` instead to right here, use it from | ||||
|         #     a for msg in drained:` post-proc loop? | ||||
|         # | ||||
|             log.warning( | ||||
|                 'Drained context msgs during closure\n\n' | ||||
|                 f'{drained}' | ||||
|  | @ -265,9 +327,6 @@ class MsgStream(trio.abc.Channel): | |||
|          - more or less we try to maintain adherance to trio's `.aclose()` semantics: | ||||
|            https://trio.readthedocs.io/en/stable/reference-io.html#trio.abc.AsyncResource.aclose | ||||
|         ''' | ||||
| 
 | ||||
|         # rx_chan = self._rx_chan | ||||
| 
 | ||||
|         # XXX NOTE XXX | ||||
|         # it's SUPER IMPORTANT that we ensure we don't DOUBLE | ||||
|         # DRAIN msgs on closure so avoid getting stuck handing on | ||||
|  | @ -279,15 +338,16 @@ class MsgStream(trio.abc.Channel): | |||
|             # this stream has already been closed so silently succeed as | ||||
|             # per ``trio.AsyncResource`` semantics. | ||||
|             # https://trio.readthedocs.io/en/stable/reference-io.html#trio.abc.AsyncResource.aclose | ||||
|             # import tractor | ||||
|             # await tractor.pause() | ||||
|             return [] | ||||
| 
 | ||||
|         ctx: Context = self._ctx | ||||
|         drained: list[Exception|dict] = [] | ||||
|         while not drained: | ||||
|             try: | ||||
|                 maybe_final_msg = self.receive_nowait( | ||||
|                     # allow_msgs=[Yield, Return], | ||||
|                     expect_msg=Yield, | ||||
|                 maybe_final_msg: Yield|Return = self.receive_nowait( | ||||
|                     expect_msg=Yield|Return, | ||||
|                 ) | ||||
|                 if maybe_final_msg: | ||||
|                     log.debug( | ||||
|  | @ -372,8 +432,10 @@ class MsgStream(trio.abc.Channel): | |||
|         #         await rx_chan.aclose() | ||||
| 
 | ||||
|         if not self._eoc: | ||||
|             this_side: str = self._ctx.side | ||||
|             peer_side: str = self._ctx.peer_side | ||||
|             message: str = ( | ||||
|                 f'Stream self-closed by {self._ctx.side!r}-side before EoC\n' | ||||
|                 f'Stream self-closed by {this_side!r}-side before EoC from {peer_side!r}\n' | ||||
|                 # } bc a stream is a "scope"/msging-phase inside an IPC | ||||
|                 f'x}}>\n' | ||||
|                 f'  |_{self}\n' | ||||
|  | @ -381,9 +443,19 @@ class MsgStream(trio.abc.Channel): | |||
|             log.cancel(message) | ||||
|             self._eoc = trio.EndOfChannel(message) | ||||
| 
 | ||||
|             if ( | ||||
|                 (rx_chan := self._rx_chan) | ||||
|                 and | ||||
|                 (stats := rx_chan.statistics()).tasks_waiting_receive | ||||
|             ): | ||||
|                 log.cancel( | ||||
|                     f'Msg-stream is closing but there is still reader tasks,\n' | ||||
|                     f'{stats}\n' | ||||
|                 ) | ||||
| 
 | ||||
|         # ?XXX WAIT, why do we not close the local mem chan `._rx_chan` XXX? | ||||
|         # => NO, DEFINITELY NOT! <= | ||||
|         # if we're a bi-dir ``MsgStream`` BECAUSE this same | ||||
|         # if we're a bi-dir `MsgStream` BECAUSE this same | ||||
|         # core-msg-loop mem recv-chan is used to deliver the | ||||
|         # potential final result from the surrounding inter-actor | ||||
|         # `Context` so we don't want to close it until that | ||||
|  |  | |||
|  | @ -26,6 +26,9 @@ import os | |||
| import pathlib | ||||
| 
 | ||||
| import tractor | ||||
| from tractor.devx._debug import ( | ||||
|     BoxedMaybeException, | ||||
| ) | ||||
| from .pytest import ( | ||||
|     tractor_test as tractor_test | ||||
| ) | ||||
|  | @ -98,12 +101,13 @@ async def expect_ctxc( | |||
|     ''' | ||||
|     if yay: | ||||
|         try: | ||||
|             yield | ||||
|             yield (maybe_exc := BoxedMaybeException()) | ||||
|             raise RuntimeError('Never raised ctxc?') | ||||
|         except tractor.ContextCancelled: | ||||
|         except tractor.ContextCancelled as ctxc: | ||||
|             maybe_exc.value = ctxc | ||||
|             if reraise: | ||||
|                 raise | ||||
|             else: | ||||
|                 return | ||||
|     else: | ||||
|         yield | ||||
|         yield (maybe_exc := BoxedMaybeException()) | ||||
|  |  | |||
|  | @ -33,6 +33,7 @@ from ._codec import ( | |||
| 
 | ||||
|     apply_codec as apply_codec, | ||||
|     mk_codec as mk_codec, | ||||
|     mk_dec as mk_dec, | ||||
|     MsgCodec as MsgCodec, | ||||
|     MsgDec as MsgDec, | ||||
|     current_codec as current_codec, | ||||
|  |  | |||
|  | @ -61,6 +61,7 @@ from tractor.msg.pretty_struct import Struct | |||
| from tractor.msg.types import ( | ||||
|     mk_msg_spec, | ||||
|     MsgType, | ||||
|     PayloadMsg, | ||||
| ) | ||||
| from tractor.log import get_logger | ||||
| 
 | ||||
|  | @ -80,6 +81,7 @@ class MsgDec(Struct): | |||
| 
 | ||||
|     ''' | ||||
|     _dec: msgpack.Decoder | ||||
|     # _ext_types_box: Struct|None = None | ||||
| 
 | ||||
|     @property | ||||
|     def dec(self) -> msgpack.Decoder: | ||||
|  | @ -179,23 +181,126 @@ class MsgDec(Struct): | |||
| 
 | ||||
| 
 | ||||
| def mk_dec( | ||||
|     spec: Union[Type[Struct]]|Any = Any, | ||||
|     spec: Union[Type[Struct]]|Type|None, | ||||
| 
 | ||||
|     # NOTE, required for ad-hoc type extensions to the underlying | ||||
|     # serialization proto (which is default `msgpack`), | ||||
|     # https://jcristharif.com/msgspec/extending.html#mapping-to-from-native-types | ||||
|     dec_hook: Callable|None = None, | ||||
|     ext_types: list[Type]|None = None, | ||||
| 
 | ||||
| ) -> MsgDec: | ||||
|     ''' | ||||
|     Create an IPC msg decoder, normally used as the | ||||
|     `PayloadMsg.pld: PayloadT` field decoder inside a `PldRx`. | ||||
|     Create an IPC msg decoder, a slightly higher level wrapper around | ||||
|     a `msgspec.msgpack.Decoder` which provides, | ||||
| 
 | ||||
|     - easier introspection of the underlying type spec via | ||||
|       the `.spec` and `.spec_str` attrs, | ||||
|     - `.hook` access to the `Decoder.dec_hook()`, | ||||
|     - automatic custom extension-types decode support when | ||||
|       `dec_hook()` is provided such that any `PayloadMsg.pld` tagged | ||||
|       as a type from from `ext_types` (presuming the `MsgCodec.encode()` also used | ||||
|       a `.enc_hook()`) is processed and constructed by a `PldRx` implicitily. | ||||
| 
 | ||||
|     NOTE, as mentioned a `MsgDec` is normally used for `PayloadMsg.pld: PayloadT` field | ||||
|     decoding inside an IPC-ctx-oriented `PldRx`. | ||||
| 
 | ||||
|     ''' | ||||
|     if ( | ||||
|         spec is None | ||||
|         and | ||||
|         ext_types is None | ||||
|     ): | ||||
|         raise TypeError( | ||||
|             f'MIssing type-`spec` for msg decoder!\n' | ||||
|             f'\n' | ||||
|             f'`spec=None` is **only** permitted is if custom extension types ' | ||||
|             f'are provided via `ext_types`, meaning it must be non-`None`.\n' | ||||
|             f'\n' | ||||
|             f'In this case it is presumed that only the `ext_types`, ' | ||||
|             f'which much be handled by a paired `dec_hook()`, ' | ||||
|             f'will be permitted within the payload type-`spec`!\n' | ||||
|             f'\n' | ||||
|             f'spec = {spec!r}\n' | ||||
|             f'dec_hook = {dec_hook!r}\n' | ||||
|             f'ext_types = {ext_types!r}\n' | ||||
|         ) | ||||
| 
 | ||||
|     if dec_hook: | ||||
|         if ext_types is None: | ||||
|             raise TypeError( | ||||
|                 f'If extending the serializable types with a custom decode hook (`dec_hook()`), ' | ||||
|                 f'you must also provide the expected type set that the hook will handle ' | ||||
|                 f'via a `ext_types: Union[Type]|None = None` argument!\n' | ||||
|                 f'\n' | ||||
|                 f'dec_hook = {dec_hook!r}\n' | ||||
|                 f'ext_types = {ext_types!r}\n' | ||||
|             ) | ||||
| 
 | ||||
|         # XXX, i *thought* we would require a boxing struct as per docs, | ||||
|         # https://jcristharif.com/msgspec/extending.html#mapping-to-from-native-types | ||||
|         # |_ see comment, | ||||
|         #  > Note that typed deserialization is required for | ||||
|         #  > successful roundtripping here, so we pass `MyMessage` to | ||||
|         #  > `Decoder`. | ||||
|         # | ||||
|         # BUT, turns out as long as you spec a union with `Raw` it | ||||
|         # will work? kk B) | ||||
|         # | ||||
|         # maybe_box_struct = mk_boxed_ext_struct(ext_types) | ||||
|         spec = Raw | Union[*ext_types] | ||||
| 
 | ||||
|     return MsgDec( | ||||
|         _dec=msgpack.Decoder( | ||||
|             type=spec,  # like `MsgType[Any]` | ||||
|             dec_hook=dec_hook, | ||||
|         ) | ||||
|         ), | ||||
|     ) | ||||
| 
 | ||||
| 
 | ||||
| # TODO? remove since didn't end up needing this? | ||||
| def mk_boxed_ext_struct( | ||||
|     ext_types: list[Type], | ||||
| ) -> Struct: | ||||
|     # NOTE, originally was to wrap non-msgpack-supported "extension | ||||
|     # types" in a field-typed boxing struct, see notes around the | ||||
|     # `dec_hook()` branch in `mk_dec()`. | ||||
|     ext_types_union = Union[*ext_types] | ||||
|     repr_ext_types_union: str = ( | ||||
|         str(ext_types_union) | ||||
|         or | ||||
|         "|".join(ext_types) | ||||
|     ) | ||||
|     BoxedExtType = msgspec.defstruct( | ||||
|         f'BoxedExts[{repr_ext_types_union}]', | ||||
|         fields=[ | ||||
|             ('boxed', ext_types_union), | ||||
|         ], | ||||
|     ) | ||||
|     return BoxedExtType | ||||
| 
 | ||||
| 
 | ||||
| def unpack_spec_types( | ||||
|     spec: Union[Type]|Type, | ||||
| ) -> set[Type]: | ||||
|     ''' | ||||
|     Given an input type-`spec`, either a lone type | ||||
|     or a `Union` of types (like `str|int|MyThing`), | ||||
|     return a set of individual types. | ||||
| 
 | ||||
|     When `spec` is not a type-union returns `{spec,}`. | ||||
| 
 | ||||
|     ''' | ||||
|     spec_subtypes: set[Union[Type]] = set( | ||||
|          getattr( | ||||
|              spec, | ||||
|              '__args__', | ||||
|              {spec,}, | ||||
|          ) | ||||
|     ) | ||||
|     return spec_subtypes | ||||
| 
 | ||||
| 
 | ||||
| def mk_msgspec_table( | ||||
|     dec: msgpack.Decoder, | ||||
|     msg: MsgType|None = None, | ||||
|  | @ -273,6 +378,8 @@ class MsgCodec(Struct): | |||
|     _dec: msgpack.Decoder | ||||
|     _pld_spec: Type[Struct]|Raw|Any | ||||
| 
 | ||||
|     # _ext_types_box: Struct|None = None | ||||
| 
 | ||||
|     def __repr__(self) -> str: | ||||
|         speclines: str = textwrap.indent( | ||||
|             pformat_msgspec(codec=self), | ||||
|  | @ -339,12 +446,15 @@ class MsgCodec(Struct): | |||
| 
 | ||||
|     def encode( | ||||
|         self, | ||||
|         py_obj: Any, | ||||
|         py_obj: Any|PayloadMsg, | ||||
| 
 | ||||
|         use_buf: bool = False, | ||||
|         # ^-XXX-^ uhh why am i getting this? | ||||
|         # |_BufferError: Existing exports of data: object cannot be re-sized | ||||
| 
 | ||||
|         as_ext_type: bool = False, | ||||
|         hide_tb: bool = True, | ||||
| 
 | ||||
|     ) -> bytes: | ||||
|         ''' | ||||
|         Encode input python objects to `msgpack` bytes for | ||||
|  | @ -354,11 +464,46 @@ class MsgCodec(Struct): | |||
|         https://jcristharif.com/msgspec/perf-tips.html#reusing-an-output-buffer | ||||
| 
 | ||||
|         ''' | ||||
|         __tracebackhide__: bool = hide_tb | ||||
|         if use_buf: | ||||
|             self._enc.encode_into(py_obj, self._buf) | ||||
|             return self._buf | ||||
|         else: | ||||
|             return self._enc.encode(py_obj) | ||||
| 
 | ||||
|         return self._enc.encode(py_obj) | ||||
|         # try: | ||||
|         #     return self._enc.encode(py_obj) | ||||
|         # except TypeError as typerr: | ||||
|         #     typerr.add_note( | ||||
|         #         '|_src error from `msgspec`' | ||||
|         #         # f'|_{self._enc.encode!r}' | ||||
|         #     ) | ||||
|         #     raise typerr | ||||
| 
 | ||||
|         # TODO! REMOVE once i'm confident we won't ever need it! | ||||
|         # | ||||
|         # box: Struct = self._ext_types_box | ||||
|         # if ( | ||||
|         #     as_ext_type | ||||
|         #     or | ||||
|         #     ( | ||||
|         #         # XXX NOTE, auto-detect if the input type | ||||
|         #         box | ||||
|         #         and | ||||
|         #         (ext_types := unpack_spec_types( | ||||
|         #             spec=box.__annotations__['boxed']) | ||||
|         #         ) | ||||
|         #     ) | ||||
|         # ): | ||||
|         #     match py_obj: | ||||
|         #         # case PayloadMsg(pld=pld) if ( | ||||
|         #         #     type(pld) in ext_types | ||||
|         #         # ): | ||||
|         #         #     py_obj.pld = box(boxed=py_obj) | ||||
|         #         #     breakpoint() | ||||
|         #         case _ if ( | ||||
|         #             type(py_obj) in ext_types | ||||
|         #         ): | ||||
|         #             py_obj = box(boxed=py_obj) | ||||
| 
 | ||||
|     @property | ||||
|     def dec(self) -> msgpack.Decoder: | ||||
|  | @ -378,21 +523,30 @@ class MsgCodec(Struct): | |||
|         return self._dec.decode(msg) | ||||
| 
 | ||||
| 
 | ||||
| # [x] TODO: a sub-decoder system as well? => No! | ||||
| # ?TODO? time to remove this finally? | ||||
| # | ||||
| # -[x] TODO: a sub-decoder system as well? | ||||
| # => No! already re-architected to include a "payload-receiver" | ||||
| #   now found in `._ops`. | ||||
| # | ||||
| # -[x] do we still want to try and support the sub-decoder with | ||||
| # `.Raw` technique in the case that the `Generic` approach gives | ||||
| # future grief? | ||||
| # => NO, since we went with the `PldRx` approach instead B) | ||||
| # => well YES but NO, since we went with the `PldRx` approach | ||||
| #   instead! | ||||
| # | ||||
| # IF however you want to see the code that was staged for this | ||||
| # from wayyy back, see the pure removal commit. | ||||
| 
 | ||||
| 
 | ||||
| def mk_codec( | ||||
|     # struct type unions set for `Decoder` | ||||
|     # https://jcristharif.com/msgspec/structs.html#tagged-unions | ||||
|     ipc_pld_spec: Union[Type[Struct]]|Any = Any, | ||||
|     ipc_pld_spec: Union[Type[Struct]]|Any|Raw = Raw, | ||||
|     # tagged-struct-types-union set for `Decoder`ing of payloads, as | ||||
|     # per https://jcristharif.com/msgspec/structs.html#tagged-unions. | ||||
|     # NOTE that the default `Raw` here **is very intentional** since | ||||
|     # the `PldRx._pld_dec: MsgDec` is responsible for per ipc-ctx-task | ||||
|     # decoding of msg-specs defined by the user as part of **their** | ||||
|     # `tractor` "app's" type-limited IPC msg-spec. | ||||
| 
 | ||||
|     # TODO: offering a per-msg(-field) type-spec such that | ||||
|     # the fields can be dynamically NOT decoded and left as `Raw` | ||||
|  | @ -405,13 +559,18 @@ def mk_codec( | |||
| 
 | ||||
|     libname: str = 'msgspec', | ||||
| 
 | ||||
|     # proxy as `Struct(**kwargs)` for ad-hoc type extensions | ||||
|     # settings for encoding-to-send extension-types, | ||||
|     # https://jcristharif.com/msgspec/extending.html#mapping-to-from-native-types | ||||
|     # ------ - ------ | ||||
|     dec_hook: Callable|None = None, | ||||
|     # dec_hook: Callable|None = None, | ||||
|     enc_hook: Callable|None = None, | ||||
|     # ------ - ------ | ||||
|     ext_types: list[Type]|None = None, | ||||
| 
 | ||||
|     # optionally provided msg-decoder from which we pull its, | ||||
|     # |_.dec_hook() | ||||
|     # |_.type | ||||
|     ext_dec: MsgDec|None = None | ||||
|     # | ||||
|     # ?TODO? other params we might want to support | ||||
|     # Encoder: | ||||
|     # write_buffer_size=write_buffer_size, | ||||
|     # | ||||
|  | @ -425,26 +584,44 @@ def mk_codec( | |||
|     `msgspec` ;). | ||||
| 
 | ||||
|     ''' | ||||
|     # (manually) generate a msg-payload-spec for all relevant | ||||
|     # god-boxing-msg subtypes, parameterizing the `PayloadMsg.pld: PayloadT` | ||||
|     # for the decoder such that all sub-type msgs in our SCIPP | ||||
|     # will automatically decode to a type-"limited" payload (`Struct`) | ||||
|     # object (set). | ||||
|     pld_spec = ipc_pld_spec | ||||
|     if enc_hook: | ||||
|         if not ext_types: | ||||
|             raise TypeError( | ||||
|                 f'If extending the serializable types with a custom encode hook (`enc_hook()`), ' | ||||
|                 f'you must also provide the expected type set that the hook will handle ' | ||||
|                 f'via a `ext_types: Union[Type]|None = None` argument!\n' | ||||
|                 f'\n' | ||||
|                 f'enc_hook = {enc_hook!r}\n' | ||||
|                 f'ext_types = {ext_types!r}\n' | ||||
|             ) | ||||
| 
 | ||||
|     dec_hook: Callable|None = None | ||||
|     if ext_dec: | ||||
|         dec: msgspec.Decoder = ext_dec.dec | ||||
|         dec_hook = dec.dec_hook | ||||
|         pld_spec |= dec.type | ||||
|         if ext_types: | ||||
|             pld_spec |= Union[*ext_types] | ||||
| 
 | ||||
|     # (manually) generate a msg-spec (how appropes) for all relevant | ||||
|     # payload-boxing-struct-msg-types, parameterizing the | ||||
|     # `PayloadMsg.pld: PayloadT` for the decoder such that all msgs | ||||
|     # in our SC-RPC-protocol will automatically decode to | ||||
|     # a type-"limited" payload (`Struct`) object (set). | ||||
|     ( | ||||
|         ipc_msg_spec, | ||||
|         msg_types, | ||||
|     ) = mk_msg_spec( | ||||
|         payload_type_union=ipc_pld_spec, | ||||
|         payload_type_union=pld_spec, | ||||
|     ) | ||||
|     assert len(ipc_msg_spec.__args__) == len(msg_types) | ||||
|     assert ipc_msg_spec | ||||
| 
 | ||||
|     # TODO: use this shim instead? | ||||
|     # bc.. unification, err somethin? | ||||
|     # dec: MsgDec = mk_dec( | ||||
|     #     spec=ipc_msg_spec, | ||||
|     #     dec_hook=dec_hook, | ||||
|     # ) | ||||
|     msg_spec_types: set[Type] = unpack_spec_types(ipc_msg_spec) | ||||
|     assert ( | ||||
|         len(ipc_msg_spec.__args__) == len(msg_types) | ||||
|         and | ||||
|         len(msg_spec_types) == len(msg_types) | ||||
|     ) | ||||
| 
 | ||||
|     dec = msgpack.Decoder( | ||||
|         type=ipc_msg_spec, | ||||
|  | @ -453,22 +630,29 @@ def mk_codec( | |||
|     enc = msgpack.Encoder( | ||||
|        enc_hook=enc_hook, | ||||
|     ) | ||||
| 
 | ||||
|     codec = MsgCodec( | ||||
|         _enc=enc, | ||||
|         _dec=dec, | ||||
|         _pld_spec=ipc_pld_spec, | ||||
|         _pld_spec=pld_spec, | ||||
|     ) | ||||
| 
 | ||||
|     # sanity on expected backend support | ||||
|     assert codec.lib.__name__ == libname | ||||
| 
 | ||||
|     return codec | ||||
| 
 | ||||
| 
 | ||||
| # instance of the default `msgspec.msgpack` codec settings, i.e. | ||||
| # no custom structs, hooks or other special types. | ||||
| _def_msgspec_codec: MsgCodec = mk_codec(ipc_pld_spec=Any) | ||||
| # | ||||
| # XXX NOTE XXX, this will break our `Context.start()` call! | ||||
| # | ||||
| # * by default we roundtrip the started pld-`value` and if you apply | ||||
| #   this codec (globally anyway with `apply_codec()`) then the | ||||
| #   `roundtripped` value will include a non-`.pld: Raw` which will | ||||
| #   then type-error on the consequent `._ops.validte_payload_msg()`.. | ||||
| # | ||||
| _def_msgspec_codec: MsgCodec = mk_codec( | ||||
|     ipc_pld_spec=Any, | ||||
| ) | ||||
| 
 | ||||
| # The built-in IPC `Msg` spec. | ||||
| # Our composing "shuttle" protocol which allows `tractor`-app code | ||||
|  | @ -476,13 +660,13 @@ _def_msgspec_codec: MsgCodec = mk_codec(ipc_pld_spec=Any) | |||
| # https://jcristharif.com/msgspec/supported-types.html | ||||
| # | ||||
| _def_tractor_codec: MsgCodec = mk_codec( | ||||
|     # TODO: use this for debug mode locking prot? | ||||
|     # ipc_pld_spec=Any, | ||||
|     ipc_pld_spec=Raw, | ||||
|     ipc_pld_spec=Raw,  # XXX should be default righ!? | ||||
| ) | ||||
| # TODO: IDEALLY provides for per-`trio.Task` specificity of the | ||||
| 
 | ||||
| # -[x] TODO, IDEALLY provides for per-`trio.Task` specificity of the | ||||
| # IPC msging codec used by the transport layer when doing | ||||
| # `Channel.send()/.recv()` of wire data. | ||||
| # => impled as our `PldRx` which is `Context` scoped B) | ||||
| 
 | ||||
| # ContextVar-TODO: DIDN'T WORK, kept resetting in every new task to default!? | ||||
| # _ctxvar_MsgCodec: ContextVar[MsgCodec] = ContextVar( | ||||
|  | @ -559,17 +743,6 @@ def apply_codec( | |||
|     ) | ||||
|     token: Token = var.set(codec) | ||||
| 
 | ||||
|     # ?TODO? for TreeVar approach which copies from the | ||||
|     # cancel-scope of the prior value, NOT the prior task | ||||
|     # See the docs: | ||||
|     # - https://tricycle.readthedocs.io/en/latest/reference.html#tree-variables | ||||
|     # - https://github.com/oremanj/tricycle/blob/master/tricycle/_tests/test_tree_var.py | ||||
|     #   ^- see docs for @cm `.being()` API | ||||
|     # with _ctxvar_MsgCodec.being(codec): | ||||
|     #     new = _ctxvar_MsgCodec.get() | ||||
|     #     assert new is codec | ||||
|     #     yield codec | ||||
| 
 | ||||
|     try: | ||||
|         yield var.get() | ||||
|     finally: | ||||
|  | @ -580,6 +753,19 @@ def apply_codec( | |||
|         ) | ||||
|         assert var.get() is orig | ||||
| 
 | ||||
|     # ?TODO? for TreeVar approach which copies from the | ||||
|     # cancel-scope of the prior value, NOT the prior task | ||||
|     # | ||||
|     # See the docs: | ||||
|     # - https://tricycle.readthedocs.io/en/latest/reference.html#tree-variables | ||||
|     # - https://github.com/oremanj/tricycle/blob/master/tricycle/_tests/test_tree_var.py | ||||
|     #   ^- see docs for @cm `.being()` API | ||||
|     # | ||||
|     # with _ctxvar_MsgCodec.being(codec): | ||||
|     #     new = _ctxvar_MsgCodec.get() | ||||
|     #     assert new is codec | ||||
|     #     yield codec | ||||
| 
 | ||||
| 
 | ||||
| def current_codec() -> MsgCodec: | ||||
|     ''' | ||||
|  | @ -599,6 +785,7 @@ def limit_msg_spec( | |||
|     # -> related to the `MsgCodec._payload_decs` stuff above.. | ||||
|     # tagged_structs: list[Struct]|None = None, | ||||
| 
 | ||||
|     hide_tb: bool = True, | ||||
|     **codec_kwargs, | ||||
| 
 | ||||
| ) -> MsgCodec: | ||||
|  | @ -609,7 +796,7 @@ def limit_msg_spec( | |||
|     for all IPC contexts in use by the current `trio.Task`. | ||||
| 
 | ||||
|     ''' | ||||
|     __tracebackhide__: bool = True | ||||
|     __tracebackhide__: bool = hide_tb | ||||
|     curr_codec: MsgCodec = current_codec() | ||||
|     msgspec_codec: MsgCodec = mk_codec( | ||||
|         ipc_pld_spec=payload_spec, | ||||
|  |  | |||
|  | @ -0,0 +1,94 @@ | |||
| # tractor: structured concurrent "actors". | ||||
| # Copyright 2018-eternity Tyler Goodlet. | ||||
| 
 | ||||
| # This program is free software: you can redistribute it and/or modify | ||||
| # it under the terms of the GNU Affero General Public License as published by | ||||
| # the Free Software Foundation, either version 3 of the License, or | ||||
| # (at your option) any later version. | ||||
| 
 | ||||
| # This program is distributed in the hope that it will be useful, | ||||
| # but WITHOUT ANY WARRANTY; without even the implied warranty of | ||||
| # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the | ||||
| # GNU Affero General Public License for more details. | ||||
| 
 | ||||
| # You should have received a copy of the GNU Affero General Public License | ||||
| # along with this program.  If not, see <https://www.gnu.org/licenses/>. | ||||
| 
 | ||||
| ''' | ||||
| Type-extension-utils for codec-ing (python) objects not | ||||
| covered by the `msgspec.msgpack` protocol. | ||||
| 
 | ||||
| See the various API docs from `msgspec`. | ||||
| 
 | ||||
| extending from native types, | ||||
| - https://jcristharif.com/msgspec/extending.html#mapping-to-from-native-types | ||||
| 
 | ||||
| converters, | ||||
| - https://jcristharif.com/msgspec/converters.html | ||||
| - https://jcristharif.com/msgspec/api.html#msgspec.convert | ||||
| 
 | ||||
| `Raw` fields, | ||||
| - https://jcristharif.com/msgspec/api.html#raw | ||||
| - support for `.convert()` and `Raw`, | ||||
|   |_ https://jcristharif.com/msgspec/changelog.html | ||||
| 
 | ||||
| ''' | ||||
| from types import ( | ||||
|     ModuleType, | ||||
| ) | ||||
| import typing | ||||
| from typing import ( | ||||
|     Type, | ||||
|     Union, | ||||
| ) | ||||
| 
 | ||||
| def dec_type_union( | ||||
|     type_names: list[str], | ||||
|     mods: list[ModuleType] = [] | ||||
| ) -> Type|Union[Type]: | ||||
|     ''' | ||||
|     Look up types by name, compile into a list and then create and | ||||
|     return a `typing.Union` from the full set. | ||||
| 
 | ||||
|     ''' | ||||
|     # import importlib | ||||
|     types: list[Type] = [] | ||||
|     for type_name in type_names: | ||||
|         for mod in [ | ||||
|             typing, | ||||
|             # importlib.import_module(__name__), | ||||
|         ] + mods: | ||||
|             if type_ref := getattr( | ||||
|                 mod, | ||||
|                 type_name, | ||||
|                 False, | ||||
|             ): | ||||
|                 types.append(type_ref) | ||||
| 
 | ||||
|     # special case handling only.. | ||||
|     # ipc_pld_spec: Union[Type] = eval( | ||||
|     #     pld_spec_str, | ||||
|     #     {},  # globals | ||||
|     #     {'typing': typing},  # locals | ||||
|     # ) | ||||
| 
 | ||||
|     return Union[*types] | ||||
| 
 | ||||
| 
 | ||||
| def enc_type_union( | ||||
|     union_or_type: Union[Type]|Type, | ||||
| ) -> list[str]: | ||||
|     ''' | ||||
|     Encode a type-union or single type to a list of type-name-strings | ||||
|     ready for IPC interchange. | ||||
| 
 | ||||
|     ''' | ||||
|     type_strs: list[str] = [] | ||||
|     for typ in getattr( | ||||
|         union_or_type, | ||||
|         '__args__', | ||||
|         {union_or_type,}, | ||||
|     ): | ||||
|         type_strs.append(typ.__qualname__) | ||||
| 
 | ||||
|     return type_strs | ||||
|  | @ -50,7 +50,9 @@ from tractor._exceptions import ( | |||
|     _mk_recv_mte, | ||||
|     pack_error, | ||||
| ) | ||||
| from tractor._state import current_ipc_ctx | ||||
| from tractor._state import ( | ||||
|     current_ipc_ctx, | ||||
| ) | ||||
| from ._codec import ( | ||||
|     mk_dec, | ||||
|     MsgDec, | ||||
|  | @ -78,7 +80,7 @@ if TYPE_CHECKING: | |||
| log = get_logger(__name__) | ||||
| 
 | ||||
| 
 | ||||
| _def_any_pldec: MsgDec[Any] = mk_dec() | ||||
| _def_any_pldec: MsgDec[Any] = mk_dec(spec=Any) | ||||
| 
 | ||||
| 
 | ||||
| class PldRx(Struct): | ||||
|  | @ -108,33 +110,11 @@ class PldRx(Struct): | |||
|     # TODO: better to bind it here? | ||||
|     # _rx_mc: trio.MemoryReceiveChannel | ||||
|     _pld_dec: MsgDec | ||||
|     _ctx: Context|None = None | ||||
|     _ipc: Context|MsgStream|None = None | ||||
| 
 | ||||
|     @property | ||||
|     def pld_dec(self) -> MsgDec: | ||||
|         return self._pld_dec | ||||
| 
 | ||||
|     # TODO: a better name? | ||||
|     # -[ ] when would this be used as it avoids needingn to pass the | ||||
|     #   ipc prim to every method | ||||
|     @cm | ||||
|     def wraps_ipc( | ||||
|         self, | ||||
|         ipc_prim: Context|MsgStream, | ||||
| 
 | ||||
|     ) -> PldRx: | ||||
|         ''' | ||||
|         Apply this payload receiver to an IPC primitive type, one | ||||
|         of `Context` or `MsgStream`. | ||||
| 
 | ||||
|         ''' | ||||
|         self._ipc = ipc_prim | ||||
|         try: | ||||
|             yield self | ||||
|         finally: | ||||
|             self._ipc = None | ||||
| 
 | ||||
|     @cm | ||||
|     def limit_plds( | ||||
|         self, | ||||
|  | @ -148,6 +128,10 @@ class PldRx(Struct): | |||
|         exit. | ||||
| 
 | ||||
|         ''' | ||||
|         # TODO, ensure we pull the current `MsgCodec`'s custom | ||||
|         # dec/enc_hook settings as well ? | ||||
|         # -[ ] see `._codec.mk_codec()` inputs | ||||
|         # | ||||
|         orig_dec: MsgDec = self._pld_dec | ||||
|         limit_dec: MsgDec = mk_dec( | ||||
|             spec=spec, | ||||
|  | @ -163,7 +147,7 @@ class PldRx(Struct): | |||
|     def dec(self) -> msgpack.Decoder: | ||||
|         return self._pld_dec.dec | ||||
| 
 | ||||
|     def recv_pld_nowait( | ||||
|     def recv_msg_nowait( | ||||
|         self, | ||||
|         # TODO: make this `MsgStream` compat as well, see above^ | ||||
|         # ipc_prim: Context|MsgStream, | ||||
|  | @ -174,34 +158,95 @@ class PldRx(Struct): | |||
|         hide_tb: bool = False, | ||||
|         **dec_pld_kwargs, | ||||
| 
 | ||||
|     ) -> Any|Raw: | ||||
|     ) -> tuple[ | ||||
|         MsgType[PayloadT], | ||||
|         PayloadT, | ||||
|     ]: | ||||
|         ''' | ||||
|         Attempt to non-blocking receive a message from the `._rx_chan` and | ||||
|         unwrap it's payload delivering the pair to the caller. | ||||
| 
 | ||||
|         ''' | ||||
|         __tracebackhide__: bool = hide_tb | ||||
| 
 | ||||
|         msg: MsgType = ( | ||||
|             ipc_msg | ||||
|             or | ||||
| 
 | ||||
|             # sync-rx msg from underlying IPC feeder (mem-)chan | ||||
|             ipc._rx_chan.receive_nowait() | ||||
|         ) | ||||
|         return self.decode_pld( | ||||
|         pld: PayloadT = self.decode_pld( | ||||
|             msg, | ||||
|             ipc=ipc, | ||||
|             expect_msg=expect_msg, | ||||
|             hide_tb=hide_tb, | ||||
|             **dec_pld_kwargs, | ||||
|         ) | ||||
|         return ( | ||||
|             msg, | ||||
|             pld, | ||||
|         ) | ||||
| 
 | ||||
|     async def recv_msg( | ||||
|         self, | ||||
|         ipc: Context|MsgStream, | ||||
|         expect_msg: MsgType, | ||||
| 
 | ||||
|         # NOTE: ONLY for handling `Stop`-msgs that arrive during | ||||
|         # a call to `drain_to_final_msg()` above! | ||||
|         passthrough_non_pld_msgs: bool = True, | ||||
|         hide_tb: bool = True, | ||||
| 
 | ||||
|         **decode_pld_kwargs, | ||||
| 
 | ||||
|     ) -> tuple[MsgType, PayloadT]: | ||||
|         ''' | ||||
|         Retrieve the next avail IPC msg, decode its payload, and | ||||
|         return the (msg, pld) pair. | ||||
| 
 | ||||
|         ''' | ||||
|         __tracebackhide__: bool = hide_tb | ||||
|         msg: MsgType = await ipc._rx_chan.receive() | ||||
|         match msg: | ||||
|             case Return()|Error(): | ||||
|                 log.runtime( | ||||
|                     f'Rxed final outcome msg\n' | ||||
|                     f'{msg}\n' | ||||
|                 ) | ||||
|             case Stop(): | ||||
|                 log.runtime( | ||||
|                     f'Rxed stream stopped msg\n' | ||||
|                     f'{msg}\n' | ||||
|                 ) | ||||
|                 if passthrough_non_pld_msgs: | ||||
|                     return msg, None | ||||
| 
 | ||||
|         # TODO: is there some way we can inject the decoded | ||||
|         # payload into an existing output buffer for the original | ||||
|         # msg instance? | ||||
|         pld: PayloadT = self.decode_pld( | ||||
|             msg, | ||||
|             ipc=ipc, | ||||
|             expect_msg=expect_msg, | ||||
|             hide_tb=hide_tb, | ||||
| 
 | ||||
|             **decode_pld_kwargs, | ||||
|         ) | ||||
|         return ( | ||||
|             msg, | ||||
|             pld, | ||||
|         ) | ||||
| 
 | ||||
|     async def recv_pld( | ||||
|         self, | ||||
|         ipc: Context|MsgStream, | ||||
|         ipc_msg: MsgType|None = None, | ||||
|         ipc_msg: MsgType[PayloadT]|None = None, | ||||
|         expect_msg: Type[MsgType]|None = None, | ||||
|         hide_tb: bool = True, | ||||
| 
 | ||||
|         **dec_pld_kwargs, | ||||
| 
 | ||||
|     ) -> Any|Raw: | ||||
|     ) -> PayloadT: | ||||
|         ''' | ||||
|         Receive a `MsgType`, then decode and return its `.pld` field. | ||||
| 
 | ||||
|  | @ -213,6 +258,13 @@ class PldRx(Struct): | |||
|             # async-rx msg from underlying IPC feeder (mem-)chan | ||||
|             await ipc._rx_chan.receive() | ||||
|         ) | ||||
|         if ( | ||||
|             type(msg) is Return | ||||
|         ): | ||||
|             log.info( | ||||
|                 f'Rxed final result msg\n' | ||||
|                 f'{msg}\n' | ||||
|             ) | ||||
|         return self.decode_pld( | ||||
|             msg=msg, | ||||
|             ipc=ipc, | ||||
|  | @ -401,45 +453,6 @@ class PldRx(Struct): | |||
|             __tracebackhide__: bool = False | ||||
|             raise | ||||
| 
 | ||||
|     dec_msg = decode_pld | ||||
| 
 | ||||
|     async def recv_msg_w_pld( | ||||
|         self, | ||||
|         ipc: Context|MsgStream, | ||||
|         expect_msg: MsgType, | ||||
| 
 | ||||
|         # NOTE: generally speaking only for handling `Stop`-msgs that | ||||
|         # arrive during a call to `drain_to_final_msg()` above! | ||||
|         passthrough_non_pld_msgs: bool = True, | ||||
|         hide_tb: bool = True, | ||||
|         **kwargs, | ||||
| 
 | ||||
|     ) -> tuple[MsgType, PayloadT]: | ||||
|         ''' | ||||
|         Retrieve the next avail IPC msg, decode it's payload, and return | ||||
|         the pair of refs. | ||||
| 
 | ||||
|         ''' | ||||
|         __tracebackhide__: bool = hide_tb | ||||
|         msg: MsgType = await ipc._rx_chan.receive() | ||||
| 
 | ||||
|         if passthrough_non_pld_msgs: | ||||
|             match msg: | ||||
|                 case Stop(): | ||||
|                     return msg, None | ||||
| 
 | ||||
|         # TODO: is there some way we can inject the decoded | ||||
|         # payload into an existing output buffer for the original | ||||
|         # msg instance? | ||||
|         pld: PayloadT = self.decode_pld( | ||||
|             msg, | ||||
|             ipc=ipc, | ||||
|             expect_msg=expect_msg, | ||||
|             hide_tb=hide_tb, | ||||
|             **kwargs, | ||||
|         ) | ||||
|         return msg, pld | ||||
| 
 | ||||
| 
 | ||||
| @cm | ||||
| def limit_plds( | ||||
|  | @ -455,11 +468,16 @@ def limit_plds( | |||
| 
 | ||||
|     ''' | ||||
|     __tracebackhide__: bool = True | ||||
|     curr_ctx: Context|None = current_ipc_ctx() | ||||
|     if curr_ctx is None: | ||||
|         raise RuntimeError( | ||||
|             'No IPC `Context` is active !?\n' | ||||
|             'Did you open `limit_plds()` from outside ' | ||||
|             'a `Portal.open_context()` scope-block?' | ||||
|         ) | ||||
|     try: | ||||
|         curr_ctx: Context = current_ipc_ctx() | ||||
|         rx: PldRx = curr_ctx._pld_rx | ||||
|         orig_pldec: MsgDec = rx.pld_dec | ||||
| 
 | ||||
|         with rx.limit_plds( | ||||
|             spec=spec, | ||||
|             **dec_kwargs, | ||||
|  | @ -469,6 +487,11 @@ def limit_plds( | |||
|                 f'{pldec}\n' | ||||
|             ) | ||||
|             yield pldec | ||||
| 
 | ||||
|     except BaseException: | ||||
|         __tracebackhide__: bool = False | ||||
|         raise | ||||
| 
 | ||||
|     finally: | ||||
|         log.runtime( | ||||
|             'Reverted to previous payload-decoder\n\n' | ||||
|  | @ -522,8 +545,8 @@ async def maybe_limit_plds( | |||
| async def drain_to_final_msg( | ||||
|     ctx: Context, | ||||
| 
 | ||||
|     hide_tb: bool = True, | ||||
|     msg_limit: int = 6, | ||||
|     hide_tb: bool = True, | ||||
| 
 | ||||
| ) -> tuple[ | ||||
|     Return|None, | ||||
|  | @ -552,8 +575,8 @@ async def drain_to_final_msg( | |||
|     even after ctx closure and the `.open_context()` block exit. | ||||
| 
 | ||||
|     ''' | ||||
|     __tracebackhide__: bool = hide_tb | ||||
|     raise_overrun: bool = not ctx._allow_overruns | ||||
|     parent_never_opened_stream: bool = ctx._stream is None | ||||
| 
 | ||||
|     # wait for a final context result by collecting (but | ||||
|     # basically ignoring) any bi-dir-stream msgs still in transit | ||||
|  | @ -562,13 +585,14 @@ async def drain_to_final_msg( | |||
|     result_msg: Return|Error|None = None | ||||
|     while not ( | ||||
|         ctx.maybe_error | ||||
|         and not ctx._final_result_is_set() | ||||
|         and | ||||
|         not ctx._final_result_is_set() | ||||
|     ): | ||||
|         try: | ||||
|             # receive all msgs, scanning for either a final result | ||||
|             # or error; the underlying call should never raise any | ||||
|             # remote error directly! | ||||
|             msg, pld = await ctx._pld_rx.recv_msg_w_pld( | ||||
|             msg, pld = await ctx._pld_rx.recv_msg( | ||||
|                 ipc=ctx, | ||||
|                 expect_msg=Return, | ||||
|                 raise_error=False, | ||||
|  | @ -615,6 +639,11 @@ async def drain_to_final_msg( | |||
|                     ) | ||||
|                     __tracebackhide__: bool = False | ||||
| 
 | ||||
|             else: | ||||
|                 log.cancel( | ||||
|                     f'IPC ctx cancelled externally during result drain ?\n' | ||||
|                     f'{ctx}' | ||||
|                 ) | ||||
|             # CASE 2: mask the local cancelled-error(s) | ||||
|             # only when we are sure the remote error is | ||||
|             # the source cause of this local task's | ||||
|  | @ -646,17 +675,24 @@ async def drain_to_final_msg( | |||
|             case Yield(): | ||||
|                 pre_result_drained.append(msg) | ||||
|                 if ( | ||||
|                     (ctx._stream.closed | ||||
|                      and (reason := 'stream was already closed') | ||||
|                     ) | ||||
|                     or (ctx.cancel_acked | ||||
|                         and (reason := 'ctx cancelled other side') | ||||
|                     ) | ||||
|                     or (ctx._cancel_called | ||||
|                         and (reason := 'ctx called `.cancel()`') | ||||
|                     ) | ||||
|                     or (len(pre_result_drained) > msg_limit | ||||
|                         and (reason := f'"yield" limit={msg_limit}') | ||||
|                     not parent_never_opened_stream | ||||
|                     and ( | ||||
|                         (ctx._stream.closed | ||||
|                          and | ||||
|                          (reason := 'stream was already closed') | ||||
|                         ) or | ||||
|                         (ctx.cancel_acked | ||||
|                             and | ||||
|                             (reason := 'ctx cancelled other side') | ||||
|                         ) | ||||
|                         or (ctx._cancel_called | ||||
|                             and | ||||
|                             (reason := 'ctx called `.cancel()`') | ||||
|                         ) | ||||
|                         or (len(pre_result_drained) > msg_limit | ||||
|                             and | ||||
|                             (reason := f'"yield" limit={msg_limit}') | ||||
|                         ) | ||||
|                     ) | ||||
|                 ): | ||||
|                     log.cancel( | ||||
|  | @ -674,7 +710,7 @@ async def drain_to_final_msg( | |||
|                 # drain up to the `msg_limit` hoping to get | ||||
|                 # a final result or error/ctxc. | ||||
|                 else: | ||||
|                     log.warning( | ||||
|                     report: str = ( | ||||
|                         'Ignoring "yield" msg during `ctx.result()` drain..\n' | ||||
|                         f'<= {ctx.chan.uid}\n' | ||||
|                         f'  |_{ctx._nsf}()\n\n' | ||||
|  | @ -683,6 +719,14 @@ async def drain_to_final_msg( | |||
| 
 | ||||
|                         f'{pretty_struct.pformat(msg)}\n' | ||||
|                     ) | ||||
|                     if parent_never_opened_stream: | ||||
|                         report = ( | ||||
|                             f'IPC ctx never opened stream on {ctx.side!r}-side!\n' | ||||
|                             f'\n' | ||||
|                             # f'{ctx}\n' | ||||
|                         ) + report | ||||
| 
 | ||||
|                     log.warning(report) | ||||
|                     continue | ||||
| 
 | ||||
|             # stream terminated, but no result yet.. | ||||
|  | @ -774,6 +818,7 @@ async def drain_to_final_msg( | |||
|             f'{ctx.outcome}\n' | ||||
|         ) | ||||
| 
 | ||||
|     __tracebackhide__: bool = hide_tb | ||||
|     return ( | ||||
|         result_msg, | ||||
|         pre_result_drained, | ||||
|  |  | |||
|  | @ -599,15 +599,15 @@ def mk_msg_spec( | |||
|         Msg[payload_type_union], | ||||
|         Generic[PayloadT], | ||||
|     ) | ||||
|     defstruct_bases: tuple = ( | ||||
|         Msg, # [payload_type_union], | ||||
|         # Generic[PayloadT], | ||||
|         # ^-XXX-^: not allowed? lul.. | ||||
|     ) | ||||
|     # defstruct_bases: tuple = ( | ||||
|     #     Msg, # [payload_type_union], | ||||
|     #     # Generic[PayloadT], | ||||
|     #     # ^-XXX-^: not allowed? lul.. | ||||
|     # ) | ||||
|     ipc_msg_types: list[Msg] = [] | ||||
| 
 | ||||
|     idx_msg_types: list[Msg] = [] | ||||
|     defs_msg_types: list[Msg] = [] | ||||
|     # defs_msg_types: list[Msg] = [] | ||||
|     nc_msg_types: list[Msg] = [] | ||||
| 
 | ||||
|     for msgtype in __msg_types__: | ||||
|  | @ -625,7 +625,7 @@ def mk_msg_spec( | |||
|         # TODO: wait why do we need the dynamic version here? | ||||
|         # XXX ANSWER XXX -> BC INHERITANCE.. don't work w generics.. | ||||
|         # | ||||
|         # NOTE previously bc msgtypes WERE NOT inheritting | ||||
|         # NOTE previously bc msgtypes WERE NOT inheriting | ||||
|         # directly the `Generic[PayloadT]` type, the manual method | ||||
|         # of generic-paraming with `.__class_getitem__()` wasn't | ||||
|         # working.. | ||||
|  | @ -662,38 +662,35 @@ def mk_msg_spec( | |||
| 
 | ||||
|         # with `msgspec.structs.defstruct` | ||||
|         # XXX ALSO DOESN'T WORK | ||||
|         defstruct_msgtype = defstruct( | ||||
|             name=msgtype.__name__, | ||||
|             fields=[ | ||||
|                 ('cid', str), | ||||
|         # defstruct_msgtype = defstruct( | ||||
|         #     name=msgtype.__name__, | ||||
|         #     fields=[ | ||||
|         #         ('cid', str), | ||||
| 
 | ||||
|                 # XXX doesn't seem to work.. | ||||
|                 # ('pld', PayloadT), | ||||
| 
 | ||||
|                 ('pld', payload_type_union), | ||||
|             ], | ||||
|             bases=defstruct_bases, | ||||
|         ) | ||||
|         defs_msg_types.append(defstruct_msgtype) | ||||
|         #         # XXX doesn't seem to work.. | ||||
|         #         # ('pld', PayloadT), | ||||
| 
 | ||||
|         #         ('pld', payload_type_union), | ||||
|         #     ], | ||||
|         #     bases=defstruct_bases, | ||||
|         # ) | ||||
|         # defs_msg_types.append(defstruct_msgtype) | ||||
|         # assert index_paramed_msg_type == manual_paramed_msg_subtype | ||||
| 
 | ||||
|         # paramed_msg_type = manual_paramed_msg_subtype | ||||
| 
 | ||||
|         # ipc_payload_msgs_type_union |= index_paramed_msg_type | ||||
| 
 | ||||
|     idx_spec: Union[Type[Msg]] = Union[*idx_msg_types] | ||||
|     def_spec: Union[Type[Msg]] = Union[*defs_msg_types] | ||||
|     # def_spec: Union[Type[Msg]] = Union[*defs_msg_types] | ||||
|     nc_spec: Union[Type[Msg]] = Union[*nc_msg_types] | ||||
| 
 | ||||
|     specs: dict[str, Union[Type[Msg]]] = { | ||||
|         'indexed_generics': idx_spec, | ||||
|         'defstruct': def_spec, | ||||
|         # 'defstruct': def_spec, | ||||
|         'types_new_class': nc_spec, | ||||
|     } | ||||
|     msgtypes_table: dict[str, list[Msg]] = { | ||||
|         'indexed_generics': idx_msg_types, | ||||
|         'defstruct': defs_msg_types, | ||||
|         # 'defstruct': defs_msg_types, | ||||
|         'types_new_class': nc_msg_types, | ||||
|     } | ||||
| 
 | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue