Add `NoBsWs.connected()` predicate
parent
a146ad9e69
commit
cf6e44cb9c
|
@ -18,16 +18,24 @@
|
||||||
ToOlS fOr CoPInG wITh "tHE wEB" protocols.
|
ToOlS fOr CoPInG wITh "tHE wEB" protocols.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
from contextlib import asynccontextmanager, AsyncExitStack
|
from contextlib import (
|
||||||
|
asynccontextmanager,
|
||||||
|
AsyncExitStack,
|
||||||
|
)
|
||||||
from itertools import count
|
from itertools import count
|
||||||
from types import ModuleType
|
from types import ModuleType
|
||||||
from typing import Any, Optional, Callable, AsyncGenerator
|
from typing import (
|
||||||
|
Any,
|
||||||
|
Optional,
|
||||||
|
Callable,
|
||||||
|
AsyncGenerator,
|
||||||
|
Iterable,
|
||||||
|
)
|
||||||
import json
|
import json
|
||||||
import sys
|
|
||||||
|
|
||||||
import trio
|
import trio
|
||||||
import trio_websocket
|
import trio_websocket
|
||||||
from wsproto.utilities import LocalProtocolError
|
from wsproto.utilities import LocalProtocolError
|
||||||
from trio_websocket._impl import (
|
from trio_websocket._impl import (
|
||||||
ConnectionClosed,
|
ConnectionClosed,
|
||||||
DisconnectionTimeout,
|
DisconnectionTimeout,
|
||||||
|
@ -44,9 +52,12 @@ log = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class NoBsWs:
|
class NoBsWs:
|
||||||
"""Make ``trio_websocket`` sockets stay up no matter the bs.
|
'''
|
||||||
|
Make ``trio_websocket`` sockets stay up no matter the bs.
|
||||||
|
|
||||||
"""
|
You can provide a ``fixture`` async-context-manager which will be
|
||||||
|
enter/exitted around each reconnect operation.
|
||||||
|
'''
|
||||||
recon_errors = (
|
recon_errors = (
|
||||||
ConnectionClosed,
|
ConnectionClosed,
|
||||||
DisconnectionTimeout,
|
DisconnectionTimeout,
|
||||||
|
@ -68,10 +79,16 @@ class NoBsWs:
|
||||||
self._stack = stack
|
self._stack = stack
|
||||||
self._ws: 'WebSocketConnection' = None # noqa
|
self._ws: 'WebSocketConnection' = None # noqa
|
||||||
|
|
||||||
|
# TODO: is there some method we can call
|
||||||
|
# on the underlying `._ws` to get this?
|
||||||
|
self._connected: bool = False
|
||||||
|
|
||||||
async def _connect(
|
async def _connect(
|
||||||
self,
|
self,
|
||||||
tries: int = 1000,
|
tries: int = 1000,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
|
||||||
|
self._connected = False
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
await self._stack.aclose()
|
await self._stack.aclose()
|
||||||
|
@ -96,6 +113,8 @@ class NoBsWs:
|
||||||
assert ret is None
|
assert ret is None
|
||||||
|
|
||||||
log.info(f'Connection success: {self.url}')
|
log.info(f'Connection success: {self.url}')
|
||||||
|
|
||||||
|
self._connected = True
|
||||||
return self._ws
|
return self._ws
|
||||||
|
|
||||||
except self.recon_errors as err:
|
except self.recon_errors as err:
|
||||||
|
@ -105,11 +124,15 @@ class NoBsWs:
|
||||||
f'{type(err)}...retry attempt {i}'
|
f'{type(err)}...retry attempt {i}'
|
||||||
)
|
)
|
||||||
await trio.sleep(0.5)
|
await trio.sleep(0.5)
|
||||||
|
self._connected = False
|
||||||
continue
|
continue
|
||||||
else:
|
else:
|
||||||
log.exception('ws connection fail...')
|
log.exception('ws connection fail...')
|
||||||
raise last_err
|
raise last_err
|
||||||
|
|
||||||
|
def connected(self) -> bool:
|
||||||
|
return self._connected
|
||||||
|
|
||||||
async def send_msg(
|
async def send_msg(
|
||||||
self,
|
self,
|
||||||
data: Any,
|
data: Any,
|
||||||
|
@ -161,6 +184,7 @@ async def open_autorecon_ws(
|
||||||
'''
|
'''
|
||||||
JSONRPC response-request style machinery for transparent multiplexing of msgs
|
JSONRPC response-request style machinery for transparent multiplexing of msgs
|
||||||
over a NoBsWs.
|
over a NoBsWs.
|
||||||
|
|
||||||
'''
|
'''
|
||||||
|
|
||||||
|
|
||||||
|
@ -170,6 +194,7 @@ class JSONRPCResult(Struct):
|
||||||
result: Optional[dict] = None
|
result: Optional[dict] = None
|
||||||
error: Optional[dict] = None
|
error: Optional[dict] = None
|
||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def open_jsonrpc_session(
|
async def open_jsonrpc_session(
|
||||||
url: str,
|
url: str,
|
||||||
|
@ -220,15 +245,16 @@ async def open_jsonrpc_session(
|
||||||
|
|
||||||
async def recv_task():
|
async def recv_task():
|
||||||
'''
|
'''
|
||||||
receives every ws message and stores it in its corresponding result
|
receives every ws message and stores it in its corresponding
|
||||||
field, then sets the event to wakeup original sender tasks.
|
result field, then sets the event to wakeup original sender
|
||||||
also recieves responses to requests originated from the server side.
|
tasks. also recieves responses to requests originated from
|
||||||
'''
|
the server side.
|
||||||
|
|
||||||
|
'''
|
||||||
async for msg in ws:
|
async for msg in ws:
|
||||||
match msg:
|
match msg:
|
||||||
case {
|
case {
|
||||||
'result': result,
|
'result': _,
|
||||||
'id': mid,
|
'id': mid,
|
||||||
} if res_entry := rpc_results.get(mid):
|
} if res_entry := rpc_results.get(mid):
|
||||||
|
|
||||||
|
@ -239,7 +265,9 @@ async def open_jsonrpc_session(
|
||||||
'result': _,
|
'result': _,
|
||||||
'id': mid,
|
'id': mid,
|
||||||
} if not rpc_results.get(mid):
|
} if not rpc_results.get(mid):
|
||||||
log.warning(f'Wasn\'t expecting ws msg: {json.dumps(msg, indent=4)}')
|
log.warning(
|
||||||
|
f'Unexpected ws msg: {json.dumps(msg, indent=4)}'
|
||||||
|
)
|
||||||
|
|
||||||
case {
|
case {
|
||||||
'method': _,
|
'method': _,
|
||||||
|
@ -259,7 +287,6 @@ async def open_jsonrpc_session(
|
||||||
case _:
|
case _:
|
||||||
log.warning(f'Unhandled JSON-RPC msg!?\n{msg}')
|
log.warning(f'Unhandled JSON-RPC msg!?\n{msg}')
|
||||||
|
|
||||||
|
|
||||||
n.start_soon(recv_task)
|
n.start_soon(recv_task)
|
||||||
yield json_rpc
|
yield json_rpc
|
||||||
n.cancel_scope.cancel()
|
n.cancel_scope.cancel()
|
||||||
|
|
Loading…
Reference in New Issue