Add `NoBsWs.connected()` predicate

misc_brokerd_backend_repairs
Tyler Goodlet 2023-01-05 17:28:10 -05:00
parent a146ad9e69
commit cf6e44cb9c
1 changed files with 40 additions and 13 deletions

View File

@ -18,16 +18,24 @@
ToOlS fOr CoPInG wITh "tHE wEB" protocols. ToOlS fOr CoPInG wITh "tHE wEB" protocols.
""" """
from contextlib import asynccontextmanager, AsyncExitStack from contextlib import (
asynccontextmanager,
AsyncExitStack,
)
from itertools import count from itertools import count
from types import ModuleType from types import ModuleType
from typing import Any, Optional, Callable, AsyncGenerator from typing import (
Any,
Optional,
Callable,
AsyncGenerator,
Iterable,
)
import json import json
import sys
import trio import trio
import trio_websocket import trio_websocket
from wsproto.utilities import LocalProtocolError from wsproto.utilities import LocalProtocolError
from trio_websocket._impl import ( from trio_websocket._impl import (
ConnectionClosed, ConnectionClosed,
DisconnectionTimeout, DisconnectionTimeout,
@ -44,9 +52,12 @@ log = get_logger(__name__)
class NoBsWs: class NoBsWs:
"""Make ``trio_websocket`` sockets stay up no matter the bs. '''
Make ``trio_websocket`` sockets stay up no matter the bs.
""" You can provide a ``fixture`` async-context-manager which will be
enter/exitted around each reconnect operation.
'''
recon_errors = ( recon_errors = (
ConnectionClosed, ConnectionClosed,
DisconnectionTimeout, DisconnectionTimeout,
@ -68,10 +79,16 @@ class NoBsWs:
self._stack = stack self._stack = stack
self._ws: 'WebSocketConnection' = None # noqa self._ws: 'WebSocketConnection' = None # noqa
# TODO: is there some method we can call
# on the underlying `._ws` to get this?
self._connected: bool = False
async def _connect( async def _connect(
self, self,
tries: int = 1000, tries: int = 1000,
) -> None: ) -> None:
self._connected = False
while True: while True:
try: try:
await self._stack.aclose() await self._stack.aclose()
@ -96,6 +113,8 @@ class NoBsWs:
assert ret is None assert ret is None
log.info(f'Connection success: {self.url}') log.info(f'Connection success: {self.url}')
self._connected = True
return self._ws return self._ws
except self.recon_errors as err: except self.recon_errors as err:
@ -105,11 +124,15 @@ class NoBsWs:
f'{type(err)}...retry attempt {i}' f'{type(err)}...retry attempt {i}'
) )
await trio.sleep(0.5) await trio.sleep(0.5)
self._connected = False
continue continue
else: else:
log.exception('ws connection fail...') log.exception('ws connection fail...')
raise last_err raise last_err
def connected(self) -> bool:
return self._connected
async def send_msg( async def send_msg(
self, self,
data: Any, data: Any,
@ -161,6 +184,7 @@ async def open_autorecon_ws(
''' '''
JSONRPC response-request style machinery for transparent multiplexing of msgs JSONRPC response-request style machinery for transparent multiplexing of msgs
over a NoBsWs. over a NoBsWs.
''' '''
@ -170,6 +194,7 @@ class JSONRPCResult(Struct):
result: Optional[dict] = None result: Optional[dict] = None
error: Optional[dict] = None error: Optional[dict] = None
@asynccontextmanager @asynccontextmanager
async def open_jsonrpc_session( async def open_jsonrpc_session(
url: str, url: str,
@ -220,15 +245,16 @@ async def open_jsonrpc_session(
async def recv_task(): async def recv_task():
''' '''
receives every ws message and stores it in its corresponding result receives every ws message and stores it in its corresponding
field, then sets the event to wakeup original sender tasks. result field, then sets the event to wakeup original sender
also recieves responses to requests originated from the server side. tasks. also recieves responses to requests originated from
''' the server side.
'''
async for msg in ws: async for msg in ws:
match msg: match msg:
case { case {
'result': result, 'result': _,
'id': mid, 'id': mid,
} if res_entry := rpc_results.get(mid): } if res_entry := rpc_results.get(mid):
@ -239,7 +265,9 @@ async def open_jsonrpc_session(
'result': _, 'result': _,
'id': mid, 'id': mid,
} if not rpc_results.get(mid): } if not rpc_results.get(mid):
log.warning(f'Wasn\'t expecting ws msg: {json.dumps(msg, indent=4)}') log.warning(
f'Unexpected ws msg: {json.dumps(msg, indent=4)}'
)
case { case {
'method': _, 'method': _,
@ -259,7 +287,6 @@ async def open_jsonrpc_session(
case _: case _:
log.warning(f'Unhandled JSON-RPC msg!?\n{msg}') log.warning(f'Unhandled JSON-RPC msg!?\n{msg}')
n.start_soon(recv_task) n.start_soon(recv_task)
yield json_rpc yield json_rpc
n.cancel_scope.cancel() n.cancel_scope.cancel()