Add `NoBsWs.connected()` predicate

epoch_index_backup
Tyler Goodlet 2023-01-05 17:28:10 -05:00
parent d2fec7016a
commit 61c4147b73
1 changed files with 40 additions and 13 deletions

View File

@ -18,12 +18,20 @@
ToOlS fOr CoPInG wITh "tHE wEB" protocols.
"""
from contextlib import asynccontextmanager, AsyncExitStack
from contextlib import (
asynccontextmanager,
AsyncExitStack,
)
from itertools import count
from types import ModuleType
from typing import Any, Optional, Callable, AsyncGenerator
from typing import (
Any,
Optional,
Callable,
AsyncGenerator,
Iterable,
)
import json
import sys
import trio
import trio_websocket
@ -44,9 +52,12 @@ log = get_logger(__name__)
class NoBsWs:
"""Make ``trio_websocket`` sockets stay up no matter the bs.
'''
Make ``trio_websocket`` sockets stay up no matter the bs.
"""
You can provide a ``fixture`` async-context-manager which will be
enter/exitted around each reconnect operation.
'''
recon_errors = (
ConnectionClosed,
DisconnectionTimeout,
@ -68,10 +79,16 @@ class NoBsWs:
self._stack = stack
self._ws: 'WebSocketConnection' = None # noqa
# TODO: is there some method we can call
# on the underlying `._ws` to get this?
self._connected: bool = False
async def _connect(
self,
tries: int = 1000,
) -> None:
self._connected = False
while True:
try:
await self._stack.aclose()
@ -96,6 +113,8 @@ class NoBsWs:
assert ret is None
log.info(f'Connection success: {self.url}')
self._connected = True
return self._ws
except self.recon_errors as err:
@ -105,11 +124,15 @@ class NoBsWs:
f'{type(err)}...retry attempt {i}'
)
await trio.sleep(0.5)
self._connected = False
continue
else:
log.exception('ws connection fail...')
raise last_err
def connected(self) -> bool:
return self._connected
async def send_msg(
self,
data: Any,
@ -161,6 +184,7 @@ async def open_autorecon_ws(
'''
JSONRPC response-request style machinery for transparent multiplexing of msgs
over a NoBsWs.
'''
@ -170,6 +194,7 @@ class JSONRPCResult(Struct):
result: Optional[dict] = None
error: Optional[dict] = None
@asynccontextmanager
async def open_jsonrpc_session(
url: str,
@ -220,15 +245,16 @@ async def open_jsonrpc_session(
async def recv_task():
'''
receives every ws message and stores it in its corresponding result
field, then sets the event to wakeup original sender tasks.
also recieves responses to requests originated from the server side.
'''
receives every ws message and stores it in its corresponding
result field, then sets the event to wakeup original sender
tasks. also recieves responses to requests originated from
the server side.
'''
async for msg in ws:
match msg:
case {
'result': result,
'result': _,
'id': mid,
} if res_entry := rpc_results.get(mid):
@ -239,7 +265,9 @@ async def open_jsonrpc_session(
'result': _,
'id': mid,
} if not rpc_results.get(mid):
log.warning(f'Wasn\'t expecting ws msg: {json.dumps(msg, indent=4)}')
log.warning(
f'Unexpected ws msg: {json.dumps(msg, indent=4)}'
)
case {
'method': _,
@ -259,7 +287,6 @@ async def open_jsonrpc_session(
case _:
log.warning(f'Unhandled JSON-RPC msg!?\n{msg}')
n.start_soon(recv_task)
yield json_rpc
n.cancel_scope.cancel()