diff --git a/piker/ipc.py b/piker/ipc.py index 1deb2b75..ff2d2649 100644 --- a/piker/ipc.py +++ b/piker/ipc.py @@ -1,13 +1,13 @@ """ Inter-process comms abstractions """ -from typing import Coroutine +from typing import Coroutine, Tuple import msgpack import trio from .log import get_logger -log = get_logger('broker.core') +log = get_logger('ipc') class StreamQueue: @@ -15,32 +15,43 @@ class StreamQueue: """ def __init__(self, stream): self.stream = stream - self.peer = stream.socket.getpeername() self._agen = self._iter_packets() + self._laddr = self.stream.socket.getsockname()[:2] + self._raddr = self.stream.socket.getpeername()[:2] + self._send_lock = trio.Lock() async def _iter_packets(self): """Yield packets from the underlying stream. """ - unpacker = msgpack.Unpacker(raw=False) + unpacker = msgpack.Unpacker(raw=False, use_list=False) while True: try: data = await self.stream.receive_some(2**10) log.trace(f"Data is {data}") except trio.BrokenStreamError: - log.error(f"Stream connection {self.peer} broke") + log.error(f"Stream connection {self.raddr} broke") return if data == b'': - log.debug("Stream connection was closed") + log.debug(f"Stream connection {self.raddr} was closed") return unpacker.feed(data) for packet in unpacker: yield packet + @property + def laddr(self): + return self._laddr + + @property + def raddr(self): + return self._raddr + async def put(self, data): - return await self.stream.send_all( - msgpack.dumps(data, use_bin_type=True)) + async with self._send_lock: + return await self.stream.send_all( + msgpack.dumps(data, use_bin_type=True)) async def get(self): return await self._agen.asend(None) @@ -49,25 +60,48 @@ class StreamQueue: return self._agen -class Client: - """The most basic client. +class Channel: + """A channel to actors in other processes. Use this to talk to any micro-service daemon or other client(s) over a - TCP socket managed by ``trio``. + a transport managed by ``trio``. """ def __init__( - self, sockaddr: tuple, - on_reconnect: Coroutine, - auto_reconnect: bool = True, - ): - self.sockaddr = sockaddr + self, + destaddr: tuple = None, + on_reconnect: Coroutine = None, + auto_reconnect: bool = False, + stream: trio.SocketStream = None, # expected to be active + ) -> None: self._recon_seq = on_reconnect self._autorecon = auto_reconnect - self.squeue = None + self.squeue = StreamQueue(stream) if stream else None + if self.squeue and destaddr: + raise ValueError( + f"A stream was provided with local addr {self.laddr}" + ) + self._destaddr = destaddr or self.squeue.raddr - async def connect(self, sockaddr: tuple = None, **kwargs): - sockaddr = sockaddr or self.sockaddr - stream = await trio.open_tcp_stream(*sockaddr, **kwargs) + def __repr__(self): + if self.squeue: + return repr( + self.squeue.stream.socket._sock).replace( + "socket.socket", "Channel") + return object.__repr__(self) + + @property + def laddr(self): + return self.squeue.laddr if self.squeue else (None, None) + + @property + def raddr(self): + return self.squeue.raddr if self.squeue else (None, None) + + async def connect(self, destaddr: Tuple[str, int] = None, **kwargs): + if self.squeue is not None: + raise RuntimeError("channel is already connected?") + destaddr = destaddr or self._destaddr + stream = await trio.open_tcp_stream(*destaddr, **kwargs) self.squeue = StreamQueue(stream) return stream @@ -77,21 +111,25 @@ class Client: async def recv(self): try: return await self.squeue.get() - except trio.BrokenStreamError as err: + except trio.BrokenStreamError: if self._autorecon: await self._reconnect() return await self.recv() async def aclose(self, *args): await self.squeue.stream.aclose() + self.squeue = None async def __aenter__(self): - await self.connect(self.sockaddr) + await self.connect() return self async def __aexit__(self, *args): await self.aclose(*args) + async def __aiter__(self): + return self.aiter_recv() + async def _reconnect(self): """Handle connection failures by polling until a reconnect can be established. @@ -109,13 +147,15 @@ class Client: else: log.warn("Stream connection re-established!") # run any reconnection sequence - await self._recon_seq(self) + on_recon = self._recon_seq + if on_recon: + await on_recon(self) break except (OSError, ConnectionRefusedError): if not down: down = True log.warn( - f"Connection to {self.sockaddr} went down, waiting" + f"Connection to {self.raddr} went down, waiting" " for re-establishment") await trio.sleep(1) @@ -126,9 +166,15 @@ class Client: try: async for item in self.squeue: yield item - except trio.BrokenStreamError as err: + # sent = yield item + # if sent is not None: + # # optimization, passing None through all the + # # time is pointless + # await self.squeue.put(sent) + except trio.BrokenStreamError: if not self._autorecon: raise + self.squeue = None if self._autorecon: # attempt reconnect await self._reconnect() continue