diff --git a/tractor/_ipc.py b/tractor/_ipc.py index f57d3bd..2b5df69 100644 --- a/tractor/_ipc.py +++ b/tractor/_ipc.py @@ -30,6 +30,7 @@ import struct import typing from typing import ( Any, + Callable, runtime_checkable, Protocol, Type, @@ -123,6 +124,16 @@ class MsgpackTCPStream(MsgTransport): stream: trio.SocketStream, prefix_size: int = 4, + # XXX optionally provided codec pair for `msgspec`: + # https://jcristharif.com/msgspec/extending.html#mapping-to-from-native-types + # + # TODO: define this as a `Codec` struct which can be + # overriden dynamically by the application/runtime. + codec: tuple[ + Callable[[Any], Any]|None, # coder + Callable[[type, Any], Any]|None, # decoder + ]|None = None, + ) -> None: self.stream = stream @@ -138,12 +149,18 @@ class MsgpackTCPStream(MsgTransport): # public i guess? self.drained: list[dict] = [] - self.recv_stream = BufferedReceiveStream(transport_stream=stream) + self.recv_stream = BufferedReceiveStream( + transport_stream=stream + ) self.prefix_size = prefix_size # TODO: struct aware messaging coders - self.encode = msgspec.msgpack.Encoder().encode - self.decode = msgspec.msgpack.Decoder().decode # dict[str, Any]) + self.encode = msgspec.msgpack.Encoder( + enc_hook=codec[0] if codec else None, + ).encode + self.decode = msgspec.msgpack.Decoder( + dec_hook=codec[1] if codec else None, + ).decode async def _iter_packets(self) -> AsyncGenerator[dict, None]: '''Yield packets from the underlying stream. @@ -349,9 +366,25 @@ class Channel: stream: trio.SocketStream, type_key: tuple[str, str]|None = None, + # XXX optionally provided codec pair for `msgspec`: + # https://jcristharif.com/msgspec/extending.html#mapping-to-from-native-types + codec: tuple[ + Callable[[Any], Any], # coder + Callable[[type, Any], Any], # decoder + ]|None = None, + ) -> MsgTransport: - type_key = type_key or self._transport_key - self._transport = get_msg_transport(type_key)(stream) + type_key = ( + type_key + or + self._transport_key + ) + self._transport = get_msg_transport( + type_key + )( + stream, + codec=codec, + ) return self._transport def __repr__(self) -> str: