diff --git a/piker/brokers/core.py b/piker/brokers/core.py index 511c216c..8b823808 100644 --- a/piker/brokers/core.py +++ b/piker/brokers/core.py @@ -6,7 +6,7 @@ import inspect from functools import partial import socket from types import ModuleType -from typing import Coroutine +from typing import Coroutine, Callable import msgpack import trio @@ -48,14 +48,14 @@ async def quote(brokermod: ModuleType, tickers: [str]) -> dict: return results -async def wait_for_network(get_quotes, sleep=1): +async def wait_for_network(net_func: Callable, sleep: int = 1) -> dict: """Wait until the network comes back up. """ down = False while True: try: with trio.move_on_after(1) as cancel_scope: - quotes = await get_quotes() + quotes = await net_func() if down: log.warn("Network is back up") return quotes @@ -69,35 +69,22 @@ async def wait_for_network(get_quotes, sleep=1): await trio.sleep(sleep) -class Disconnect(trio.Cancelled): - "Stream was closed" - - class StreamQueue: - """Stream wrapped as a queue that delivers json serialized "packets" - delimited by ``delim``. + """Stream wrapped as a queue that delivers ``msgpack`` serialized objects. """ - def __init__(self, stream, delim=b'\n'): + def __init__(self, stream): self.stream = stream - self._delim = delim self.peer = stream.socket.getpeername() self._agen = self._iter_packets() async def _iter_packets(self): """Yield packets from the underlying stream. """ - delim = self._delim - buff = b'' unpacker = msgpack.Unpacker(raw=False) while True: - packets = [] - try: - data = await self.stream.receive_some(2**10) - except trio.BrokenStreamError as err: - log.debug("Stream connection was broken") - return - + data = await self.stream.receive_some(2**10) log.trace(f"Data is {data}") + if data == b'': log.debug("Stream connection was closed") return @@ -117,7 +104,88 @@ class StreamQueue: return self._agen -async def poll_tickers( +class Client: + """The most basic client. + + Use this to talk to any micro-service daemon or other client(s) over a + TCP socket managed by ``trio``. + """ + def __init__( + self, sockaddr: tuple, + startup_seq: Coroutine, + auto_reconnect: bool = True, + ): + self._sockaddr = sockaddr + self._startup_seq = startup_seq + self._autorecon = auto_reconnect + self.stream = None + self.squeue = None + + async def connect(self, sockaddr: tuple = None, **kwargs): + sockaddr = sockaddr or self._sockaddr + stream = await trio.open_tcp_stream(*sockaddr, **kwargs) + self.squeue = StreamQueue(stream) + await self._startup_seq(self) + return stream + + async def send(self, item): + await self.squeue.put(item) + + async def recv(self): + try: + return await self.squeue.get() + except trio.BrokenStreamError as err: + if self._autorecon: + await self._reconnect() + return await self.recv() + + async def __aenter__(self): + await self.connect(self._sockaddr) + return self + + async def __aexit__(self, *args): + await self.squeue.stream.__aexit__() + self.stream = None + + async def _reconnect(self): + """Handle connection failures by polling until a reconnect can be + established. + """ + down = False + while True: + try: + with trio.move_on_after(3) as cancel_scope: + await self.connect() + cancelled = cancel_scope.cancelled_caught + if cancelled: + log.warn("Reconnect timed out after 3 seconds, retrying...") + continue + else: + log.warn("Stream connection re-established!") + break + except OSError: + if not down: + down = True + log.warn( + "Connection went down, waiting for re-establishment") + await trio.sleep(1) + + async def aiter_recv(self): + """Async iterate items from underlying stream. + """ + try: + async for item in self.squeue: + yield item + except trio.BrokenStreamError as err: + if not self._autorecon: + raise + if self._autorecon: # attempt reconnect + await self._reconnect() + async for item in self.aiter_recv(): + yield item + + +async def stream_quotes( brokermod: ModuleType, get_quotes: Coroutine, tickers2qs: {str: StreamQueue}, @@ -192,7 +260,7 @@ async def poll_tickers( await trio.sleep(delay) -async def start_quoter(stream): +async def start_quoter(stream: trio.SocketStream) -> None: """Handle per-broker quote stream subscriptions. Spawns new quoter tasks for each broker backend on-demand. @@ -255,7 +323,7 @@ async def start_quoter(stream): # task should begin on the next checkpoint/iteration log.info(f"Spawning quoter task for {brokermod.name}") nursery.start_soon( - poll_tickers, brokermod, get_quotes, tickers2qs) + stream_quotes, brokermod, get_quotes, tickers2qs) else: log.info(f"{queue.peer} was disconnected") nursery.cancel_scope.cancel() @@ -265,7 +333,7 @@ async def start_quoter(stream): await client.__aexit__() -async def _daemon_main(brokermod): +async def _daemon_main(brokermod: ModuleType) -> None: """Entry point for the broker daemon. """ async with trio.open_nursery() as nursery: