diff --git a/tractor/_discovery.py b/tractor/_discovery.py index 1c3cbff0..f6f4b9d9 100644 --- a/tractor/_discovery.py +++ b/tractor/_discovery.py @@ -75,7 +75,7 @@ async def get_registry( # TODO: try to look pre-existing connection from # `Actor._peers` and use it instead? async with ( - _connect_chan(host, port) as chan, + _connect_chan((host, port)) as chan, open_portal(chan) as regstr_ptl, ): yield regstr_ptl @@ -93,7 +93,7 @@ async def get_root( assert host is not None async with ( - _connect_chan(host, port) as chan, + _connect_chan((host, port)) as chan, open_portal(chan, **kwargs) as portal, ): yield portal @@ -187,7 +187,7 @@ async def maybe_open_portal( pass if sockaddr: - async with _connect_chan(*sockaddr) as chan: + async with _connect_chan(sockaddr) as chan: async with open_portal(chan) as portal: yield portal else: @@ -310,6 +310,6 @@ async def wait_for_actor( # TODO: offer multi-portal yields in multi-homed case? sockaddr: tuple[str, int] = sockaddrs[-1] - async with _connect_chan(*sockaddr) as chan: + async with _connect_chan(sockaddr) as chan: async with open_portal(chan) as portal: yield portal diff --git a/tractor/_root.py b/tractor/_root.py index 35639c15..40682a7a 100644 --- a/tractor/_root.py +++ b/tractor/_root.py @@ -271,7 +271,7 @@ async def open_root_actor( # be better to eventually have a "discovery" protocol # with basic handshake instead? with trio.move_on_after(timeout): - async with _connect_chan(*addr): + async with _connect_chan(addr): ponged_addrs.append(addr) except OSError: diff --git a/tractor/_runtime.py b/tractor/_runtime.py index 2c8dbbd9..eaab31b6 100644 --- a/tractor/_runtime.py +++ b/tractor/_runtime.py @@ -1040,10 +1040,7 @@ class Actor: # Connect back to the parent actor and conduct initial # handshake. From this point on if we error, we # attempt to ship the exception back to the parent. - chan = Channel( - destaddr=parent_addr, - ) - await chan.connect() + chan = await Channel.from_destaddr(parent_addr) # TODO: move this into a `Channel.handshake()`? # Initial handshake: swap names. diff --git a/tractor/ipc/__init__.py b/tractor/ipc/__init__.py index 4f0cd2b4..0c8e09ca 100644 --- a/tractor/ipc/__init__.py +++ b/tractor/ipc/__init__.py @@ -13,20 +13,26 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . - - import platform -from ._transport import MsgTransport as MsgTransport +from ._transport import ( + AddressType as AddressType, + MsgType as MsgType, + MsgTransport as MsgTransport, + MsgpackTransport as MsgpackTransport +) -from ._tcp import ( - get_stream_addrs as get_stream_addrs, - MsgpackTCPStream as MsgpackTCPStream +from ._tcp import MsgpackTCPStream as MsgpackTCPStream +from ._uds import MsgpackUDSStream as MsgpackUDSStream + +from ._types import ( + transport_from_destaddr as transport_from_destaddr, + transport_from_stream as transport_from_stream, + AddressTypes as AddressTypes ) from ._chan import ( _connect_chan as _connect_chan, - get_msg_transport as get_msg_transport, Channel as Channel ) diff --git a/tractor/ipc/_chan.py b/tractor/ipc/_chan.py index 1b6ba29f..ee259371 100644 --- a/tractor/ipc/_chan.py +++ b/tractor/ipc/_chan.py @@ -29,15 +29,15 @@ from pprint import pformat import typing from typing import ( Any, - Type ) import trio from tractor.ipc._transport import MsgTransport -from tractor.ipc._tcp import ( - MsgpackTCPStream, - get_stream_addrs +from tractor.ipc._types import ( + transport_from_destaddr, + transport_from_stream, + AddressTypes ) from tractor.log import get_logger from tractor._exceptions import ( @@ -52,17 +52,6 @@ log = get_logger(__name__) _is_windows = platform.system() == 'Windows' -def get_msg_transport( - - key: tuple[str, str], - -) -> Type[MsgTransport]: - - return { - ('msgpack', 'tcp'): MsgpackTCPStream, - }[key] - - class Channel: ''' An inter-process channel for communication between (remote) actors. @@ -77,10 +66,8 @@ class Channel: def __init__( self, - destaddr: tuple[str, int]|None, - - msg_transport_type_key: tuple[str, str] = ('msgpack', 'tcp'), - + destaddr: AddressTypes|None = None, + transport: MsgTransport|None = None, # TODO: optional reconnection support? # auto_reconnect: bool = False, # on_reconnect: typing.Callable[..., typing.Awaitable] = None, @@ -90,13 +77,11 @@ class Channel: # self._recon_seq = on_reconnect # self._autorecon = auto_reconnect - self._destaddr = destaddr - self._transport_key = msg_transport_type_key - # Either created in ``.connect()`` or passed in by # user in ``.from_stream()``. - self._stream: trio.SocketStream|None = None - self._transport: MsgTransport|None = None + self._transport: MsgTransport|None = transport + + self._destaddr = destaddr if destaddr else self._transport.raddr # set after handshake - always uid of far end self.uid: tuple[str, str]|None = None @@ -110,6 +95,10 @@ class Channel: # runtime. self._cancel_called: bool = False + @property + def stream(self) -> trio.abc.Stream | None: + return self._transport.stream if self._transport else None + @property def msgstream(self) -> MsgTransport: log.info( @@ -124,52 +113,31 @@ class Channel: @classmethod def from_stream( cls, - stream: trio.SocketStream, - **kwargs, - + stream: trio.abc.Stream, ) -> Channel: - - src, dst = get_stream_addrs(stream) - chan = Channel( - destaddr=dst, - **kwargs, + transport_cls = transport_from_stream(stream) + return Channel( + transport=transport_cls(stream) ) - # set immediately here from provided instance - chan._stream: trio.SocketStream = stream - chan.set_msg_transport(stream) - return chan + @classmethod + async def from_destaddr( + cls, + destaddr: AddressTypes, + **kwargs + ) -> Channel: + transport_cls = transport_from_destaddr(destaddr) + transport = await transport_cls.connect_to(destaddr, **kwargs) - def set_msg_transport( - self, - stream: trio.SocketStream, - type_key: tuple[str, str]|None = None, - - # XXX optionally provided codec pair for `msgspec`: - # https://jcristharif.com/msgspec/extending.html#mapping-to-from-native-types - codec: MsgCodec|None = None, - - ) -> MsgTransport: - type_key = ( - type_key - or - self._transport_key + log.transport( + f'Opened channel[{type(transport)}]: {transport.laddr} -> {transport.raddr}' ) - # get transport type, then - self._transport = get_msg_transport( - type_key - # instantiate an instance of the msg-transport - )( - stream, - codec=codec, - ) - return self._transport + return Channel(transport=transport) @cm def apply_codec( self, codec: MsgCodec, - ) -> None: ''' Temporarily override the underlying IPC msg codec for @@ -189,7 +157,7 @@ class Channel: return '' return repr( - self._transport.stream.socket._sock + self._transport ).replace( # type: ignore "socket.socket", "Channel", @@ -203,30 +171,6 @@ class Channel: def raddr(self) -> tuple[str, int]|None: return self._transport.raddr if self._transport else None - async def connect( - self, - destaddr: tuple[Any, ...] | None = None, - **kwargs - - ) -> MsgTransport: - - if self.connected(): - raise RuntimeError("channel is already connected?") - - destaddr = destaddr or self._destaddr - assert isinstance(destaddr, tuple) - - stream = await trio.open_tcp_stream( - *destaddr, - **kwargs - ) - transport = self.set_msg_transport(stream) - - log.transport( - f'Opened channel[{type(transport)}]: {self.laddr} -> {self.raddr}' - ) - return transport - # TODO: something like, # `pdbp.hideframe_on(errors=[MsgTypeError])` # instead of the `try/except` hack we have rn.. @@ -388,17 +332,14 @@ class Channel: @acm async def _connect_chan( - host: str, - port: int - + destaddr: AddressTypes ) -> typing.AsyncGenerator[Channel, None]: ''' Create and connect a channel with disconnect on context manager teardown. ''' - chan = Channel((host, port)) - await chan.connect() + chan = await Channel.from_destaddr(destaddr) yield chan with trio.CancelScope(shield=True): await chan.aclose() diff --git a/tractor/ipc/_tcp.py b/tractor/ipc/_tcp.py index 3ce0b4ea..71265f38 100644 --- a/tractor/ipc/_tcp.py +++ b/tractor/ipc/_tcp.py @@ -18,388 +18,75 @@ TCP implementation of tractor.ipc._transport.MsgTransport protocol ''' from __future__ import annotations -from collections.abc import ( - AsyncGenerator, - AsyncIterator, -) -import struct -from typing import ( - Any, - Callable, -) -import msgspec -from tricycle import BufferedReceiveStream import trio +from tractor.msg import MsgCodec from tractor.log import get_logger -from tractor._exceptions import ( - MsgTypeError, - TransportClosed, - _mk_send_mte, - _mk_recv_mte, -) -from tractor.msg import ( - _ctxvar_MsgCodec, - # _codec, XXX see `self._codec` sanity/debug checks - MsgCodec, - types as msgtypes, - pretty_struct, -) -from tractor.ipc import MsgTransport +from tractor.ipc._transport import MsgpackTransport log = get_logger(__name__) -def get_stream_addrs( - stream: trio.SocketStream -) -> tuple[ - tuple[str, int], # local - tuple[str, int], # remote -]: - ''' - Return the `trio` streaming transport prot's socket-addrs for - both the local and remote sides as a pair. - - ''' - # rn, should both be IP sockets - lsockname = stream.socket.getsockname() - rsockname = stream.socket.getpeername() - return ( - tuple(lsockname[:2]), - tuple(rsockname[:2]), - ) - - # TODO: typing oddity.. not sure why we have to inherit here, but it # seems to be an issue with `get_msg_transport()` returning # a `Type[Protocol]`; probably should make a `mypy` issue? -class MsgpackTCPStream(MsgTransport): +class MsgpackTCPStream(MsgpackTransport): ''' A ``trio.SocketStream`` delivering ``msgpack`` formatted data using the ``msgspec`` codec lib. ''' + address_type = tuple[str, int] layer_key: int = 4 name_key: str = 'tcp' - # TODO: better naming for this? - # -[ ] check how libp2p does naming for such things? - codec_key: str = 'msgpack' + # def __init__( + # self, + # stream: trio.SocketStream, + # prefix_size: int = 4, + # codec: CodecType = None, - def __init__( - self, - stream: trio.SocketStream, - prefix_size: int = 4, - - # XXX optionally provided codec pair for `msgspec`: - # https://jcristharif.com/msgspec/extending.html#mapping-to-from-native-types - # - # TODO: define this as a `Codec` struct which can be - # overriden dynamically by the application/runtime? - codec: tuple[ - Callable[[Any], Any]|None, # coder - Callable[[type, Any], Any]|None, # decoder - ]|None = None, - - ) -> None: - - self.stream = stream - assert self.stream.socket - - # should both be IP sockets - self._laddr, self._raddr = get_stream_addrs(stream) - - # create read loop instance - self._aiter_pkts = self._iter_packets() - self._send_lock = trio.StrictFIFOLock() - - # public i guess? - self.drained: list[dict] = [] - - self.recv_stream = BufferedReceiveStream( - transport_stream=stream - ) - self.prefix_size = prefix_size - - # allow for custom IPC msg interchange format - # dynamic override Bo - self._task = trio.lowlevel.current_task() - - # XXX for ctxvar debug only! - # self._codec: MsgCodec = ( - # codec - # or - # _codec._ctxvar_MsgCodec.get() - # ) - - async def _iter_packets(self) -> AsyncGenerator[dict, None]: - ''' - Yield `bytes`-blob decoded packets from the underlying TCP - stream using the current task's `MsgCodec`. - - This is a streaming routine implemented as an async generator - func (which was the original design, but could be changed?) - and is allocated by a `.__call__()` inside `.__init__()` where - it is assigned to the `._aiter_pkts` attr. - - ''' - decodes_failed: int = 0 - - while True: - try: - header: bytes = await self.recv_stream.receive_exactly(4) - except ( - ValueError, - ConnectionResetError, - - # not sure entirely why we need this but without it we - # seem to be getting racy failures here on - # arbiter/registry name subs.. - trio.BrokenResourceError, - - ) as trans_err: - - loglevel = 'transport' - match trans_err: - # case ( - # ConnectionResetError() - # ): - # loglevel = 'transport' - - # peer actor (graceful??) TCP EOF but `tricycle` - # seems to raise a 0-bytes-read? - case ValueError() if ( - 'unclean EOF' in trans_err.args[0] - ): - pass - - # peer actor (task) prolly shutdown quickly due - # to cancellation - case trio.BrokenResourceError() if ( - 'Connection reset by peer' in trans_err.args[0] - ): - pass - - # unless the disconnect condition falls under "a - # normal operation breakage" we usualy console warn - # about it. - case _: - loglevel: str = 'warning' - - - raise TransportClosed( - message=( - f'IPC transport already closed by peer\n' - f'x)> {type(trans_err)}\n' - f' |_{self}\n' - ), - loglevel=loglevel, - ) from trans_err - - # XXX definitely can happen if transport is closed - # manually by another `trio.lowlevel.Task` in the - # same actor; we use this in some simulated fault - # testing for ex, but generally should never happen - # under normal operation! - # - # NOTE: as such we always re-raise this error from the - # RPC msg loop! - except trio.ClosedResourceError as closure_err: - raise TransportClosed( - message=( - f'IPC transport already manually closed locally?\n' - f'x)> {type(closure_err)} \n' - f' |_{self}\n' - ), - loglevel='error', - raise_on_report=( - closure_err.args[0] == 'another task closed this fd' - or - closure_err.args[0] in ['another task closed this fd'] - ), - ) from closure_err - - # graceful TCP EOF disconnect - if header == b'': - raise TransportClosed( - message=( - f'IPC transport already gracefully closed\n' - f')>\n' - f'|_{self}\n' - ), - loglevel='transport', - # cause=??? # handy or no? - ) - - size: int - size, = struct.unpack(" None: - ''' - Send a msgpack encoded py-object-blob-as-msg over TCP. - - If `strict_types == True` then a `MsgTypeError` will be raised on any - invalid msg type - - ''' - __tracebackhide__: bool = hide_tb - - # XXX see `trio._sync.AsyncContextManagerMixin` for details - # on the `.acquire()`/`.release()` sequencing.. - async with self._send_lock: - - # NOTE: lookup the `trio.Task.context`'s var for - # the current `MsgCodec`. - codec: MsgCodec = _ctxvar_MsgCodec.get() - - # XXX for ctxvar debug only! - # if self._codec.pld_spec != codec.pld_spec: - # self._codec = codec - # log.runtime( - # f'Using new codec in {self}.send()\n' - # f'codec: {self._codec}\n\n' - # f'msg: {msg}\n' - # ) - - if type(msg) not in msgtypes.__msg_types__: - if strict_types: - raise _mk_send_mte( - msg, - codec=codec, - ) - else: - log.warning( - 'Sending non-`Msg`-spec msg?\n\n' - f'{msg}\n' - ) - - try: - bytes_data: bytes = codec.encode(msg) - except TypeError as _err: - typerr = _err - msgtyperr: MsgTypeError = _mk_send_mte( - msg, - codec=codec, - message=( - f'IPC-msg-spec violation in\n\n' - f'{pretty_struct.Struct.pformat(msg)}' - ), - src_type_error=typerr, - ) - raise msgtyperr from typerr - - # supposedly the fastest says, - # https://stackoverflow.com/a/54027962 - size: bytes = struct.pack(" - # except BaseException as _err: - # err = _err - # if not isinstance(err, MsgTypeError): - # __tracebackhide__: bool = False - # raise - - @property - def laddr(self) -> tuple[str, int]: - return self._laddr - - @property - def raddr(self) -> tuple[str, int]: - return self._raddr - - async def recv(self) -> Any: - return await self._aiter_pkts.asend(None) - - async def drain(self) -> AsyncIterator[dict]: - ''' - Drain the stream's remaining messages sent from - the far end until the connection is closed by - the peer. - - ''' - try: - async for msg in self._iter_packets(): - self.drained.append(msg) - except TransportClosed: - for msg in self.drained: - yield msg - - def __aiter__(self): - return self._aiter_pkts + # ) -> None: + # super().__init__( + # stream, + # prefix_size=prefix_size, + # codec=codec + # ) def connected(self) -> bool: return self.stream.socket.fileno() != -1 + + @classmethod + async def connect_to( + cls, + destaddr: tuple[str, int], + prefix_size: int = 4, + codec: MsgCodec|None = None, + **kwargs + ) -> MsgpackTCPStream: + stream = await trio.open_tcp_stream( + *destaddr, + **kwargs + ) + return MsgpackTCPStream( + stream, + prefix_size=prefix_size, + codec=codec + ) + + @classmethod + def get_stream_addrs( + cls, + stream: trio.SocketStream + ) -> tuple[ + tuple[str, int], + tuple[str, int] + ]: + lsockname = stream.socket.getsockname() + rsockname = stream.socket.getpeername() + return ( + tuple(lsockname[:2]), + tuple(rsockname[:2]), + ) diff --git a/tractor/ipc/_transport.py b/tractor/ipc/_transport.py index 64453c89..70ba2088 100644 --- a/tractor/ipc/_transport.py +++ b/tractor/ipc/_transport.py @@ -18,24 +18,56 @@ typing.Protocol based generic msg API, implement this class to add backends for tractor.ipc.Channel ''' -import trio +from __future__ import annotations from typing import ( runtime_checkable, + Type, Protocol, TypeVar, + ClassVar ) -from collections.abc import AsyncIterator +from collections.abc import ( + AsyncGenerator, + AsyncIterator, +) +import struct +from typing import ( + Any, + Callable, +) + +import trio +import msgspec +from tricycle import BufferedReceiveStream + +from tractor.log import get_logger +from tractor._exceptions import ( + MsgTypeError, + TransportClosed, + _mk_send_mte, + _mk_recv_mte, +) +from tractor.msg import ( + _ctxvar_MsgCodec, + # _codec, XXX see `self._codec` sanity/debug checks + MsgCodec, + types as msgtypes, + pretty_struct, +) + +log = get_logger(__name__) # from tractor.msg.types import MsgType # ?TODO? this should be our `Union[*msgtypes.__spec__]` alias now right..? # => BLEH, except can't bc prots must inherit typevar or param-spec # vars.. +AddressType = TypeVar('AddressType') MsgType = TypeVar('MsgType') @runtime_checkable -class MsgTransport(Protocol[MsgType]): +class MsgTransport(Protocol[AddressType, MsgType]): # # ^-TODO-^ consider using a generic def and indexing with our # eventual msg definition/types? @@ -43,9 +75,7 @@ class MsgTransport(Protocol[MsgType]): stream: trio.abc.Stream drained: list[MsgType] - - def __init__(self, stream: trio.abc.Stream) -> None: - ... + address_type: ClassVar[Type[AddressType]] # XXX: should this instead be called `.sendall()`? async def send(self, msg: MsgType) -> None: @@ -66,9 +96,345 @@ class MsgTransport(Protocol[MsgType]): ... @property - def laddr(self) -> tuple[str, int]: + def laddr(self) -> AddressType: ... @property - def raddr(self) -> tuple[str, int]: + def raddr(self) -> AddressType: ... + + @classmethod + async def connect_to( + cls, + destaddr: AddressType, + **kwargs + ) -> MsgTransport: + ... + + @classmethod + def get_stream_addrs( + cls, + stream: trio.abc.Stream + ) -> tuple[ + AddressType, # local + AddressType # remote + ]: + ''' + Return the `trio` streaming transport prot's addrs for both + the local and remote sides as a pair. + + ''' + ... + + +class MsgpackTransport(MsgTransport): + + # TODO: better naming for this? + # -[ ] check how libp2p does naming for such things? + codec_key: str = 'msgpack' + + def __init__( + self, + stream: trio.abc.Stream, + prefix_size: int = 4, + + # XXX optionally provided codec pair for `msgspec`: + # https://jcristharif.com/msgspec/extending.html#mapping-to-from-native-types + # + # TODO: define this as a `Codec` struct which can be + # overriden dynamically by the application/runtime? + codec: MsgCodec = None, + + ) -> None: + self.stream = stream + self._laddr, self._raddr = self.get_stream_addrs(stream) + + # create read loop instance + self._aiter_pkts = self._iter_packets() + self._send_lock = trio.StrictFIFOLock() + + # public i guess? + self.drained: list[dict] = [] + + self.recv_stream = BufferedReceiveStream( + transport_stream=stream + ) + self.prefix_size = prefix_size + + # allow for custom IPC msg interchange format + # dynamic override Bo + self._task = trio.lowlevel.current_task() + + # XXX for ctxvar debug only! + # self._codec: MsgCodec = ( + # codec + # or + # _codec._ctxvar_MsgCodec.get() + # ) + + async def _iter_packets(self) -> AsyncGenerator[dict, None]: + ''' + Yield `bytes`-blob decoded packets from the underlying TCP + stream using the current task's `MsgCodec`. + + This is a streaming routine implemented as an async generator + func (which was the original design, but could be changed?) + and is allocated by a `.__call__()` inside `.__init__()` where + it is assigned to the `._aiter_pkts` attr. + + ''' + decodes_failed: int = 0 + + while True: + try: + header: bytes = await self.recv_stream.receive_exactly(4) + except ( + ValueError, + ConnectionResetError, + + # not sure entirely why we need this but without it we + # seem to be getting racy failures here on + # arbiter/registry name subs.. + trio.BrokenResourceError, + + ) as trans_err: + + loglevel = 'transport' + match trans_err: + # case ( + # ConnectionResetError() + # ): + # loglevel = 'transport' + + # peer actor (graceful??) TCP EOF but `tricycle` + # seems to raise a 0-bytes-read? + case ValueError() if ( + 'unclean EOF' in trans_err.args[0] + ): + pass + + # peer actor (task) prolly shutdown quickly due + # to cancellation + case trio.BrokenResourceError() if ( + 'Connection reset by peer' in trans_err.args[0] + ): + pass + + # unless the disconnect condition falls under "a + # normal operation breakage" we usualy console warn + # about it. + case _: + loglevel: str = 'warning' + + + raise TransportClosed( + message=( + f'IPC transport already closed by peer\n' + f'x)> {type(trans_err)}\n' + f' |_{self}\n' + ), + loglevel=loglevel, + ) from trans_err + + # XXX definitely can happen if transport is closed + # manually by another `trio.lowlevel.Task` in the + # same actor; we use this in some simulated fault + # testing for ex, but generally should never happen + # under normal operation! + # + # NOTE: as such we always re-raise this error from the + # RPC msg loop! + except trio.ClosedResourceError as closure_err: + raise TransportClosed( + message=( + f'IPC transport already manually closed locally?\n' + f'x)> {type(closure_err)} \n' + f' |_{self}\n' + ), + loglevel='error', + raise_on_report=( + closure_err.args[0] == 'another task closed this fd' + or + closure_err.args[0] in ['another task closed this fd'] + ), + ) from closure_err + + # graceful TCP EOF disconnect + if header == b'': + raise TransportClosed( + message=( + f'IPC transport already gracefully closed\n' + f')>\n' + f'|_{self}\n' + ), + loglevel='transport', + # cause=??? # handy or no? + ) + + size: int + size, = struct.unpack(" None: + ''' + Send a msgpack encoded py-object-blob-as-msg over TCP. + + If `strict_types == True` then a `MsgTypeError` will be raised on any + invalid msg type + + ''' + __tracebackhide__: bool = hide_tb + + # XXX see `trio._sync.AsyncContextManagerMixin` for details + # on the `.acquire()`/`.release()` sequencing.. + async with self._send_lock: + + # NOTE: lookup the `trio.Task.context`'s var for + # the current `MsgCodec`. + codec: MsgCodec = _ctxvar_MsgCodec.get() + + # XXX for ctxvar debug only! + # if self._codec.pld_spec != codec.pld_spec: + # self._codec = codec + # log.runtime( + # f'Using new codec in {self}.send()\n' + # f'codec: {self._codec}\n\n' + # f'msg: {msg}\n' + # ) + + if type(msg) not in msgtypes.__msg_types__: + if strict_types: + raise _mk_send_mte( + msg, + codec=codec, + ) + else: + log.warning( + 'Sending non-`Msg`-spec msg?\n\n' + f'{msg}\n' + ) + + try: + bytes_data: bytes = codec.encode(msg) + except TypeError as _err: + typerr = _err + msgtyperr: MsgTypeError = _mk_send_mte( + msg, + codec=codec, + message=( + f'IPC-msg-spec violation in\n\n' + f'{pretty_struct.Struct.pformat(msg)}' + ), + src_type_error=typerr, + ) + raise msgtyperr from typerr + + # supposedly the fastest says, + # https://stackoverflow.com/a/54027962 + size: bytes = struct.pack(" + # except BaseException as _err: + # err = _err + # if not isinstance(err, MsgTypeError): + # __tracebackhide__: bool = False + # raise + + async def recv(self) -> Any: + return await self._aiter_pkts.asend(None) + + async def drain(self) -> AsyncIterator[dict]: + ''' + Drain the stream's remaining messages sent from + the far end until the connection is closed by + the peer. + + ''' + try: + async for msg in self._iter_packets(): + self.drained.append(msg) + except TransportClosed: + for msg in self.drained: + yield msg + + def __aiter__(self): + return self._aiter_pkts + + @property + def laddr(self) -> AddressType: + return self._laddr + + @property + def raddr(self) -> AddressType: + return self._raddr diff --git a/tractor/ipc/_types.py b/tractor/ipc/_types.py new file mode 100644 index 00000000..93c5e3c9 --- /dev/null +++ b/tractor/ipc/_types.py @@ -0,0 +1,101 @@ +# tractor: structured concurrent "actors". +# Copyright 2018-eternity Tyler Goodlet. + +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. + +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. + +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . +from typing import Type, Union + +import trio +import socket + +from ._transport import MsgTransport +from ._tcp import MsgpackTCPStream +from ._uds import MsgpackUDSStream + + +# manually updated list of all supported codec+transport types +_msg_transports = { + ('msgpack', 'tcp'): MsgpackTCPStream, + ('msgpack', 'uds'): MsgpackUDSStream +} + + +# all different address py types we use +AddressTypes = Union[ + tuple([ + cls.address_type + for key, cls in _msg_transports.items() + ]) +] + + +def transport_from_destaddr( + destaddr: AddressTypes, + codec_key: str = 'msgpack', +) -> Type[MsgTransport]: + ''' + Given a destination address and a desired codec, find the + corresponding `MsgTransport` type. + + ''' + match destaddr: + case str(): + return MsgpackUDSStream + + case tuple(): + if ( + len(destaddr) == 2 + and + isinstance(destaddr[0], str) + and + isinstance(destaddr[1], int) + ): + return MsgpackTCPStream + + raise NotImplementedError( + f'No known transport for address {destaddr}' + ) + + +def transport_from_stream( + stream: trio.abc.Stream, + codec_key: str = 'msgpack' +) -> Type[MsgTransport]: + ''' + Given an arbitrary `trio.abc.Stream` and a desired codec, + find the corresponding `MsgTransport` type. + + ''' + transport = None + if isinstance(stream, trio.SocketStream): + sock = stream.socket + match sock.family: + case socket.AF_INET | socket.AF_INET6: + transport = 'tcp' + + case socket.AF_UNIX: + transport = 'uds' + + case _: + raise NotImplementedError( + f'Unsupported socket family: {sock.family}' + ) + + if not transport: + raise NotImplementedError( + f'Could not figure out transport type for stream type {type(stream)}' + ) + + key = (codec_key, transport) + + return _msg_transports[key] diff --git a/tractor/ipc/_uds.py b/tractor/ipc/_uds.py new file mode 100644 index 00000000..3b848898 --- /dev/null +++ b/tractor/ipc/_uds.py @@ -0,0 +1,84 @@ +# tractor: structured concurrent "actors". +# Copyright 2018-eternity Tyler Goodlet. + +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. + +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. + +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . +''' +Unix Domain Socket implementation of tractor.ipc._transport.MsgTransport protocol + +''' +from __future__ import annotations + +import trio + +from tractor.msg import MsgCodec +from tractor.log import get_logger +from tractor.ipc._transport import MsgpackTransport + + +log = get_logger(__name__) + + +class MsgpackUDSStream(MsgpackTransport): + ''' + A ``trio.SocketStream`` delivering ``msgpack`` formatted data + using the ``msgspec`` codec lib. + + ''' + address_type = str + layer_key: int = 7 + name_key: str = 'uds' + + # def __init__( + # self, + # stream: trio.SocketStream, + # prefix_size: int = 4, + # codec: CodecType = None, + + # ) -> None: + # super().__init__( + # stream, + # prefix_size=prefix_size, + # codec=codec + # ) + + def connected(self) -> bool: + return self.stream.socket.fileno() != -1 + + @classmethod + async def connect_to( + cls, + filename: str, + prefix_size: int = 4, + codec: MsgCodec|None = None, + **kwargs + ) -> MsgpackUDSStream: + stream = await trio.open_unix_socket( + filename, + **kwargs + ) + return MsgpackUDSStream( + stream, + prefix_size=prefix_size, + codec=codec + ) + + @classmethod + def get_stream_addrs( + cls, + stream: trio.SocketStream + ) -> tuple[str, str]: + return ( + stream.socket.getsockname(), + stream.socket.getpeername(), + )