519 lines
15 KiB
Python
519 lines
15 KiB
Python
# piker: trading gear for hackers
|
|
# Copyright (C) Tyler Goodlet (in stewardship for pikers)
|
|
|
|
# This program is free software: you can redistribute it and/or modify
|
|
# it under the terms of the GNU Affero General Public License as published by
|
|
# the Free Software Foundation, either version 3 of the License, or
|
|
# (at your option) any later version.
|
|
|
|
# This program is distributed in the hope that it will be useful,
|
|
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
# GNU Affero General Public License for more details.
|
|
|
|
# You should have received a copy of the GNU Affero General Public License
|
|
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
|
|
|
"""
|
|
ToOlS fOr CoPInG wITh "tHE wEB" protocols.
|
|
|
|
"""
|
|
from __future__ import annotations
|
|
from contextlib import (
|
|
asynccontextmanager as acm,
|
|
)
|
|
from itertools import count
|
|
from functools import partial
|
|
from types import ModuleType
|
|
from typing import (
|
|
Any,
|
|
Optional,
|
|
Callable,
|
|
AsyncContextManager,
|
|
AsyncGenerator,
|
|
Iterable,
|
|
)
|
|
import json
|
|
|
|
import trio
|
|
from trio_typing import TaskStatus
|
|
from trio_websocket import (
|
|
WebSocketConnection,
|
|
open_websocket_url,
|
|
)
|
|
from wsproto.utilities import LocalProtocolError
|
|
from trio_websocket._impl import (
|
|
ConnectionClosed,
|
|
DisconnectionTimeout,
|
|
ConnectionRejected,
|
|
HandshakeError,
|
|
ConnectionTimeout,
|
|
)
|
|
|
|
from piker.types import Struct
|
|
from ._util import log
|
|
|
|
|
|
class NoBsWs:
    '''
    Make ``trio_websocket`` sockets stay up no matter the bs.

    A shim interface that allows client code to stream from some
    ``WebSocketConnection`` but where any connectivity bs is handled
    automatically and entirely in the background.

    NOTE: this type should never be created directly but instead is
    provided via the ``open_autorecon_ws()`` factory below.

    '''
    # apparently we can QoS for all sorts of reasons..so catch em.
    recon_errors = (
        ConnectionClosed,
        DisconnectionTimeout,
        ConnectionRejected,
        HandshakeError,
        ConnectionTimeout,
        LocalProtocolError,
    )

    def __init__(
        self,
        url: str,
        rxchan: trio.MemoryReceiveChannel,
        msg_recv_timeout: float,

        serializer: ModuleType = json
    ):
        self.url = url
        # receive side of the mem chan fed by the bg relay task
        self._rx = rxchan
        # per-msg recv deadline (sec) read by the bg relay task
        self._timeout = msg_recv_timeout

        # signaling between caller and relay task which determines when
        # socket is connected (and subscribed).
        self._connected: trio.Event = trio.Event()

        # dynamically reset by the bg relay task
        self._ws: WebSocketConnection | None = None
        self._cs: trio.CancelScope | None = None

        # interchange codec methods
        # TODO: obviously the method API here may be different
        # for another interchange format..
        self._dumps: Callable = serializer.dumps
        self._loads: Callable = serializer.loads

    def connected(self) -> bool:
        '''
        Return whether the bg task has signalled a live connection.

        '''
        return self._connected.is_set()

    async def reset(self) -> None:
        '''
        Reset the underlying ws connection by cancelling
        the bg relay task and waiting for it to signal
        a new connection.

        If a reconnect is already pending (the connected event is not
        set) no new cancel is issued and we simply wait for that
        reconnect. This way concurrent callers (eg. multiple tasks
        whose ``send_msg()`` hit a connection error at the same time)
        all wait on the *same* event. Otherwise each caller would
        replace the event a previous caller is blocked on, and that
        caller would never be woken.

        '''
        if self._connected.is_set():
            self._connected = trio.Event()
            # `_cs` is always assigned before `_connected` is set by the
            # bg task, so it is non-`None` here.
            self._cs.cancel()

        await self._connected.wait()

    async def send_msg(
        self,
        data: Any,
    ) -> None:
        '''
        Serialize ``data`` and send it over the current connection.

        On any of ``recon_errors`` the connection is reset and the send
        is retried (indefinitely) on the new connection.

        '''
        # serialization doesn't depend on the connection so do it once
        msg: Any = self._dumps(data)
        while True:
            try:
                return await self._ws.send_message(msg)
            except self.recon_errors:
                await self.reset()

    async def recv_msg(self) -> Any:
        '''
        Receive the next relayed msg and decode it with the serializer.

        '''
        msg: Any = await self._rx.receive()
        data = self._loads(msg)
        return data

    def __aiter__(self):
        return self

    async def __anext__(self):
        return await self.recv_msg()

    def set_recv_timeout(
        self,
        timeout: float,
    ) -> None:
        '''
        Change the per-msg recv deadline used by the bg relay task.

        '''
        self._timeout = timeout
|
|
|
|
|
|
async def _reconnect_forever(
    url: str,
    snd: trio.MemorySendChannel,
    nobsws: NoBsWs,
    reset_after: int,  # msg recv timeout before reset attempt

    fixture: AsyncContextManager | None = None,
    task_status: TaskStatus = trio.TASK_STATUS_IGNORED,

) -> None:
    '''
    Background task which (re)connects a websocket to ``url`` in a loop
    for as long as ``snd`` is open.

    Each attempt opens a fresh nursery + ws. It publishes both on
    ``nobsws`` (``_cs``, ``_ws``), starts a relay task forwarding every
    received msg into ``snd``, optionally enters the user ``fixture``,
    and then sets ``nobsws._connected``.

    A reconnect is triggered by cancelling that nursery's scope. This
    happens either from the relay task (on a connection error or after
    more than ``reset_after`` consecutive recv timeouts) or from
    ``NoBsWs.reset()``.

    '''
    # TODO: can we just report "where" in the call stack
    # the client code is using the ws stream?
    # Maybe we can just drop this since it's already in the log msg
    # prefix?
    if fixture is not None:
        src_mod: str = fixture.__module__
    else:
        src_mod: str = 'unknown'

    async def proxy_msgs(
        ws: WebSocketConnection,
        pcs: trio.CancelScope,  # parent cancel scope
    ):
        '''
        Receive (under ``nobsws._timeout`` deadline) all msgs from the
        underlying websocket and relay them to the (calling) parent
        task via a ``trio`` mem chan.

        Cancels ``pcs`` (sending the parent back to its reconnect loop)
        on any ``recon_errors`` or after more than ``reset_after``
        consecutive recv timeouts.

        '''
        # number of *consecutive* msg recv timeouts; after more than
        # `reset_after` of them in a row, reset the connection.
        timeouts: int = 0

        while True:
            try:
                with trio.move_on_after(
                    # can be dynamically changed by user code
                    nobsws._timeout,
                ) as cs:
                    msg: Any = await ws.get_message()

            except nobsws.recon_errors:
                log.exception(
                    f'{src_mod}\n'
                    f'{url} connection bail with:'
                )
                # NOTE: outside the timeout scope so this delay (and the
                # cancel below) can't be skipped by a recv deadline.
                await trio.sleep(0.5)
                pcs.cancel()

                # go back to reconnect loop in parent task
                return

            if cs.cancelled_caught:
                timeouts += 1
                if timeouts > reset_after:
                    log.error(
                        f'{src_mod}\n'
                        'WS feed seems down and slow af.. reconnecting\n'
                    )
                    pcs.cancel()

                    # go back to reconnect loop in parent task
                    return

                continue

            # a msg arrived so the feed is alive; only consecutive
            # timeouts should count towards a reset.
            timeouts = 0

            # NOTE: deliver outside the recv timeout scope so that
            # back-pressure from a slow consumer can never cancel (and
            # thus silently drop) an already received msg.
            await snd.send(msg)

    async def open_fixture(
        fixture: AsyncContextManager,
        nobsws: NoBsWs,
        task_status: TaskStatus = trio.TASK_STATUS_IGNORED,
    ):
        '''
        Open user provided `@acm` and sleep until any connection
        reset occurs.

        '''
        async with fixture(nobsws) as ret:
            assert ret is None
            task_status.started()
            await trio.sleep_forever()

    # last_err = None
    nobsws._connected = trio.Event()
    task_status.started()

    while not snd._closed:
        log.info(
            f'{src_mod}\n'
            f'{url} trying (RE)CONNECT'
        )

        ws: WebSocketConnection
        # bound fresh per attempt: if the handshake fails the block
        # below is never entered and we must not read an unbound (or
        # the previous attempt's stale) scope.
        cs: trio.CancelScope | None = None
        try:
            async with (
                trio.open_nursery() as n,
                open_websocket_url(url) as ws,
            ):
                cs = nobsws._cs = n.cancel_scope
                nobsws._ws = ws
                log.info(
                    f'{src_mod}\n'
                    f'Connection success: {url}'
                )

                # begin relay loop to forward msgs
                n.start_soon(
                    proxy_msgs,
                    ws,
                    cs,
                )

                if fixture is not None:
                    log.info(
                        f'{src_mod}\n'
                        f'Entering fixture: {fixture}'
                    )

                    # TODO: should we return an explicit sub-cs
                    # from this fixture task?
                    await n.start(
                        open_fixture,
                        fixture,
                        nobsws,
                    )

                # indicate to wrapper / opener that we are up and block
                # to let tasks run **inside** the ws open block above.
                nobsws._connected.set()
                await trio.sleep_forever()

        except HandshakeError:
            log.exception('Retrying connection')
            # small backoff to avoid a hot retry loop against an
            # endpoint which is rejecting / timing-out our handshakes.
            await trio.sleep(0.5)

        # ws & nursery block ends

        # Only replace a *set* event. If `NoBsWs.reset()` triggered
        # this reconnect, it already installed a fresh (unset) event
        # which its caller is blocked on. Swapping that event out here
        # would leave the caller waiting forever.
        if nobsws._connected.is_set():
            nobsws._connected = trio.Event()

        if (
            cs is not None
            and cs.cancelled_caught
        ):
            log.cancel(
                f'{url} connection cancelled!'
            )
            # whoever cancelled us, the event the next attempt will
            # set must currently be unset.
            assert (
                nobsws._connected
                and not nobsws._connected.is_set()
            )

            # -> from here, move to next reconnect attempt iteration
            # in the while loop above Bp

        else:
            # NOTE: not inside an `except` block, so use `.error()`;
            # `.exception()` would just append a bogus "NoneType: None".
            log.error(
                f'{src_mod}\n'
                'ws connection closed by client...'
            )
|
|
|
|
|
|
@acm
async def open_autorecon_ws(
    url: str,

    fixture: AsyncContextManager | None = None,

    # time in sec between msgs received before
    # we presume connection might need a reset.
    msg_recv_timeout: float = 16,

    # number of the above timeouts tolerated before a connection reset
    reset_after: int = 3,

) -> AsyncGenerator[NoBsWs, None]:
    '''
    An auto-reconnect websocket (wrapper API) around
    ``trio_websocket.open_websocket_url()`` providing automatic
    re-connection on network errors, msg latency and thus roaming.

    Here we implement a re-connect websocket interface where a bg
    nursery runs ``WebSocketConnection.get_message()`` in a loop
    and restarts the full http(s) handshake on catches of certain
    connectivity errors, or some user defined recv timeout.

    You can provide a ``fixture`` async-context-manager which will be
    entered/exited around each connection reset; eg. for (re)requesting
    subscriptions without requiring streaming setup code to rerun.

    Yields a connected ``NoBsWs`` handle; the bg reconnect task is
    cancelled when the block exits.

    '''
    snd: trio.MemorySendChannel
    rcv: trio.MemoryReceiveChannel
    snd, rcv = trio.open_memory_channel(616)

    async with trio.open_nursery() as n:
        nobsws = NoBsWs(
            url,
            rcv,
            msg_recv_timeout=msg_recv_timeout,
        )
        await n.start(
            partial(
                _reconnect_forever,
                url,
                snd,
                nobsws,
                fixture=fixture,
                reset_after=reset_after,
            )
        )
        # don't hand out the handle until the first connection (and
        # fixture, if any) is fully up.
        await nobsws._connected.wait()
        assert nobsws._cs
        assert nobsws.connected()

        try:
            yield nobsws
        finally:
            n.cancel_scope.cancel()
|
|
|
|
|
|
'''
|
|
JSONRPC response-request style machinery for transparent multiplexing
|
|
of msgs over a `NoBsWs`.
|
|
|
|
'''
|
|
|
|
|
|
class JSONRPCResult(Struct):
    '''
    A JSON-RPC response msg as decoded by ``open_jsonrpc_session()``.

    '''
    # id of the request this msg responds to
    id: int
    # protocol version tag
    jsonrpc: str = '2.0'
    # payload of a successful call
    result: dict | None = None
    # error payload of a failed call
    error: dict | None = None
|
|
|
|
|
|
@acm
|
|
async def open_jsonrpc_session(
|
|
url: str,
|
|
start_id: int = 0,
|
|
response_type: type = JSONRPCResult,
|
|
msg_recv_timeout: float = float('inf'),
|
|
# ^NOTE, since only `deribit` is using this jsonrpc stuff atm
|
|
# and options mkts are generally "slow moving"..
|
|
#
|
|
# FURTHER if we break the underlying ws connection then since we
|
|
# don't pass a `fixture` to the task that manages `NoBsWs`, i.e.
|
|
# `_reconnect_forever()`, the jsonrpc "transport pipe" get's
|
|
# broken and never restored with wtv init sequence is required to
|
|
# re-establish a working req-resp session.
|
|
|
|
# request_type: Optional[type] = None,
|
|
# request_hook: Optional[Callable] = None,
|
|
# error_hook: Optional[Callable] = None,
|
|
) -> Callable[[str, dict], dict]:
|
|
|
|
# NOTE, store all request msgs so we can raise errors on the
|
|
# caller side!
|
|
req_msgs: dict[int, dict] = {}
|
|
|
|
async with (
|
|
trio.open_nursery() as n,
|
|
open_autorecon_ws(
|
|
url=url,
|
|
msg_recv_timeout=msg_recv_timeout,
|
|
) as ws
|
|
):
|
|
rpc_id: Iterable[int] = count(start_id)
|
|
rpc_results: dict[int, dict] = {}
|
|
|
|
async def json_rpc(
|
|
method: str,
|
|
params: dict,
|
|
) -> dict:
|
|
'''
|
|
perform a json rpc call and wait for the result, raise exception in
|
|
case of error field present on response
|
|
'''
|
|
nonlocal req_msgs
|
|
|
|
req_id: int = next(rpc_id)
|
|
msg = {
|
|
'jsonrpc': '2.0',
|
|
'id': req_id,
|
|
'method': method,
|
|
'params': params
|
|
}
|
|
_id = msg['id']
|
|
|
|
result = rpc_results[_id] = {
|
|
'result': None,
|
|
'error': None,
|
|
'event': trio.Event(), # signal caller resp arrived
|
|
}
|
|
req_msgs[_id] = msg
|
|
|
|
await ws.send_msg(msg)
|
|
|
|
# wait for reponse before unblocking requester code
|
|
await rpc_results[_id]['event'].wait()
|
|
|
|
if (maybe_result := result['result']):
|
|
ret = maybe_result
|
|
del rpc_results[_id]
|
|
|
|
else:
|
|
err = result['error']
|
|
raise Exception(
|
|
f'JSONRPC request failed\n'
|
|
f'req: {msg}\n'
|
|
f'resp: {err}\n'
|
|
)
|
|
|
|
if ret.error is not None:
|
|
raise Exception(json.dumps(ret.error, indent=4))
|
|
|
|
return ret
|
|
|
|
async def recv_task():
|
|
'''
|
|
receives every ws message and stores it in its corresponding
|
|
result field, then sets the event to wakeup original sender
|
|
tasks. also recieves responses to requests originated from
|
|
the server side.
|
|
|
|
'''
|
|
nonlocal req_msgs
|
|
async for msg in ws:
|
|
match msg:
|
|
case {
|
|
'result': _,
|
|
'id': mid,
|
|
} if res_entry := rpc_results.get(mid):
|
|
|
|
res_entry['result'] = response_type(**msg)
|
|
res_entry['event'].set()
|
|
|
|
case {
|
|
'result': _,
|
|
'id': mid,
|
|
} if not rpc_results.get(mid):
|
|
log.warning(
|
|
f'Unexpected ws msg: {json.dumps(msg, indent=4)}'
|
|
)
|
|
|
|
case {
|
|
'method': _,
|
|
'params': _,
|
|
}:
|
|
log.debug(f'Recieved\n{msg}')
|
|
# if request_hook:
|
|
# await request_hook(request_type(**msg))
|
|
|
|
case {
|
|
'error': error
|
|
}:
|
|
# if error_hook:
|
|
# await error_hook(response_type(**msg))
|
|
|
|
# retreive orig request msg, set error
|
|
# response in original "result" msg,
|
|
# THEN FINALLY set the event to signal caller
|
|
# to raise the error in the parent task.
|
|
req_id: int = error['id']
|
|
req_msg: dict = req_msgs[req_id]
|
|
result: dict = rpc_results[req_id]
|
|
result['error'] = error
|
|
result['event'].set()
|
|
log.error(
|
|
f'JSONRPC request failed\n'
|
|
f'req: {req_msg}\n'
|
|
f'resp: {error}\n'
|
|
)
|
|
|
|
case _:
|
|
log.warning(f'Unhandled JSON-RPC msg!?\n{msg}')
|
|
|
|
n.start_soon(recv_task)
|
|
yield json_rpc
|
|
n.cancel_scope.cancel()
|