From cf6e44cb9c7fc0f34368e38340ecdf1c55afb124 Mon Sep 17 00:00:00 2001 From: Tyler Goodlet Date: Thu, 5 Jan 2023 17:28:10 -0500 Subject: [PATCH] Add `NoBsWs.connected()` predicate --- piker/data/_web_bs.py | 53 ++++++++++++++++++++++++++++++++----------- 1 file changed, 40 insertions(+), 13 deletions(-) diff --git a/piker/data/_web_bs.py b/piker/data/_web_bs.py index 8af82d61..1577a678 100644 --- a/piker/data/_web_bs.py +++ b/piker/data/_web_bs.py @@ -18,16 +18,24 @@ ToOlS fOr CoPInG wITh "tHE wEB" protocols. """ -from contextlib import asynccontextmanager, AsyncExitStack +from contextlib import ( + asynccontextmanager, + AsyncExitStack, +) from itertools import count from types import ModuleType -from typing import Any, Optional, Callable, AsyncGenerator +from typing import ( + Any, + Optional, + Callable, + AsyncGenerator, + Iterable, +) import json -import sys import trio import trio_websocket -from wsproto.utilities import LocalProtocolError +from wsproto.utilities import LocalProtocolError from trio_websocket._impl import ( ConnectionClosed, DisconnectionTimeout, @@ -44,9 +52,12 @@ log = get_logger(__name__) class NoBsWs: - """Make ``trio_websocket`` sockets stay up no matter the bs. + ''' + Make ``trio_websocket`` sockets stay up no matter the bs. - """ + You can provide a ``fixture`` async-context-manager which will be + enter/exitted around each reconnect operation. + ''' recon_errors = ( ConnectionClosed, DisconnectionTimeout, @@ -68,10 +79,16 @@ class NoBsWs: self._stack = stack self._ws: 'WebSocketConnection' = None # noqa + # TODO: is there some method we can call + # on the underlying `._ws` to get this? + self._connected: bool = False + async def _connect( self, tries: int = 1000, ) -> None: + + self._connected = False while True: try: await self._stack.aclose() @@ -96,6 +113,8 @@ class NoBsWs: assert ret is None log.info(f'Connection success: {self.url}') + + self._connected = True return self._ws except self.recon_errors as err: @@ -105,11 +124,15 @@ class NoBsWs: f'{type(err)}...retry attempt {i}' ) await trio.sleep(0.5) + self._connected = False continue else: log.exception('ws connection fail...') raise last_err + def connected(self) -> bool: + return self._connected + async def send_msg( self, data: Any, @@ -161,6 +184,7 @@ async def open_autorecon_ws( ''' JSONRPC response-request style machinery for transparent multiplexing of msgs over a NoBsWs. + ''' @@ -170,6 +194,7 @@ class JSONRPCResult(Struct): result: Optional[dict] = None error: Optional[dict] = None + @asynccontextmanager async def open_jsonrpc_session( url: str, @@ -220,15 +245,16 @@ async def open_jsonrpc_session( async def recv_task(): ''' - receives every ws message and stores it in its corresponding result - field, then sets the event to wakeup original sender tasks. - also recieves responses to requests originated from the server side. - ''' + receives every ws message and stores it in its corresponding + result field, then sets the event to wakeup original sender + tasks. also recieves responses to requests originated from + the server side. + ''' async for msg in ws: match msg: case { - 'result': result, + 'result': _, 'id': mid, } if res_entry := rpc_results.get(mid): @@ -239,7 +265,9 @@ async def open_jsonrpc_session( 'result': _, 'id': mid, } if not rpc_results.get(mid): - log.warning(f'Wasn\'t expecting ws msg: {json.dumps(msg, indent=4)}') + log.warning( + f'Unexpected ws msg: {json.dumps(msg, indent=4)}' + ) case { 'method': _, @@ -259,7 +287,6 @@ async def open_jsonrpc_session( case _: log.warning(f'Unhandled JSON-RPC msg!?\n{msg}') - n.start_soon(recv_task) yield json_rpc n.cancel_scope.cancel()