Add `NoBsWs.connected()` predicate
parent
d2fec7016a
commit
61c4147b73
|
@ -18,12 +18,20 @@
|
|||
ToOlS fOr CoPInG wITh "tHE wEB" protocols.
|
||||
|
||||
"""
|
||||
from contextlib import asynccontextmanager, AsyncExitStack
|
||||
from contextlib import (
|
||||
asynccontextmanager,
|
||||
AsyncExitStack,
|
||||
)
|
||||
from itertools import count
|
||||
from types import ModuleType
|
||||
from typing import Any, Optional, Callable, AsyncGenerator
|
||||
from typing import (
|
||||
Any,
|
||||
Optional,
|
||||
Callable,
|
||||
AsyncGenerator,
|
||||
Iterable,
|
||||
)
|
||||
import json
|
||||
import sys
|
||||
|
||||
import trio
|
||||
import trio_websocket
|
||||
|
@ -44,9 +52,12 @@ log = get_logger(__name__)
|
|||
|
||||
|
||||
class NoBsWs:
|
||||
"""Make ``trio_websocket`` sockets stay up no matter the bs.
|
||||
'''
|
||||
Make ``trio_websocket`` sockets stay up no matter the bs.
|
||||
|
||||
"""
|
||||
You can provide a ``fixture`` async-context-manager which will be
|
||||
enter/exitted around each reconnect operation.
|
||||
'''
|
||||
recon_errors = (
|
||||
ConnectionClosed,
|
||||
DisconnectionTimeout,
|
||||
|
@ -68,10 +79,16 @@ class NoBsWs:
|
|||
self._stack = stack
|
||||
self._ws: 'WebSocketConnection' = None # noqa
|
||||
|
||||
# TODO: is there some method we can call
|
||||
# on the underlying `._ws` to get this?
|
||||
self._connected: bool = False
|
||||
|
||||
async def _connect(
|
||||
self,
|
||||
tries: int = 1000,
|
||||
) -> None:
|
||||
|
||||
self._connected = False
|
||||
while True:
|
||||
try:
|
||||
await self._stack.aclose()
|
||||
|
@ -96,6 +113,8 @@ class NoBsWs:
|
|||
assert ret is None
|
||||
|
||||
log.info(f'Connection success: {self.url}')
|
||||
|
||||
self._connected = True
|
||||
return self._ws
|
||||
|
||||
except self.recon_errors as err:
|
||||
|
@ -105,11 +124,15 @@ class NoBsWs:
|
|||
f'{type(err)}...retry attempt {i}'
|
||||
)
|
||||
await trio.sleep(0.5)
|
||||
self._connected = False
|
||||
continue
|
||||
else:
|
||||
log.exception('ws connection fail...')
|
||||
raise last_err
|
||||
|
||||
def connected(self) -> bool:
|
||||
return self._connected
|
||||
|
||||
async def send_msg(
|
||||
self,
|
||||
data: Any,
|
||||
|
@ -161,6 +184,7 @@ async def open_autorecon_ws(
|
|||
'''
|
||||
JSONRPC response-request style machinery for transparent multiplexing of msgs
|
||||
over a NoBsWs.
|
||||
|
||||
'''
|
||||
|
||||
|
||||
|
@ -170,6 +194,7 @@ class JSONRPCResult(Struct):
|
|||
result: Optional[dict] = None
|
||||
error: Optional[dict] = None
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def open_jsonrpc_session(
|
||||
url: str,
|
||||
|
@ -220,15 +245,16 @@ async def open_jsonrpc_session(
|
|||
|
||||
async def recv_task():
|
||||
'''
|
||||
receives every ws message and stores it in its corresponding result
|
||||
field, then sets the event to wakeup original sender tasks.
|
||||
also recieves responses to requests originated from the server side.
|
||||
'''
|
||||
receives every ws message and stores it in its corresponding
|
||||
result field, then sets the event to wakeup original sender
|
||||
tasks. also recieves responses to requests originated from
|
||||
the server side.
|
||||
|
||||
'''
|
||||
async for msg in ws:
|
||||
match msg:
|
||||
case {
|
||||
'result': result,
|
||||
'result': _,
|
||||
'id': mid,
|
||||
} if res_entry := rpc_results.get(mid):
|
||||
|
||||
|
@ -239,7 +265,9 @@ async def open_jsonrpc_session(
|
|||
'result': _,
|
||||
'id': mid,
|
||||
} if not rpc_results.get(mid):
|
||||
log.warning(f'Wasn\'t expecting ws msg: {json.dumps(msg, indent=4)}')
|
||||
log.warning(
|
||||
f'Unexpected ws msg: {json.dumps(msg, indent=4)}'
|
||||
)
|
||||
|
||||
case {
|
||||
'method': _,
|
||||
|
@ -259,7 +287,6 @@ async def open_jsonrpc_session(
|
|||
case _:
|
||||
log.warning(f'Unhandled JSON-RPC msg!?\n{msg}')
|
||||
|
||||
|
||||
n.start_soon(recv_task)
|
||||
yield json_rpc
|
||||
n.cancel_scope.cancel()
|
||||
|
|
Loading…
Reference in New Issue