519 lines
		
	
	
		
			15 KiB
		
	
	
	
		
			Python
		
	
	
			
		
		
	
	
			519 lines
		
	
	
		
			15 KiB
		
	
	
	
		
			Python
		
	
	
| # piker: trading gear for hackers
 | |
| # Copyright (C) Tyler Goodlet (in stewardship for pikers)
 | |
| 
 | |
| # This program is free software: you can redistribute it and/or modify
 | |
| # it under the terms of the GNU Affero General Public License as published by
 | |
| # the Free Software Foundation, either version 3 of the License, or
 | |
| # (at your option) any later version.
 | |
| 
 | |
| # This program is distributed in the hope that it will be useful,
 | |
| # but WITHOUT ANY WARRANTY; without even the implied warranty of
 | |
| # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 | |
| # GNU Affero General Public License for more details.
 | |
| 
 | |
| # You should have received a copy of the GNU Affero General Public License
 | |
| # along with this program.  If not, see <https://www.gnu.org/licenses/>.
 | |
| 
 | |
| """
 | |
| ToOlS fOr CoPInG wITh "tHE wEB" protocols.
 | |
| 
 | |
| """
 | |
| from __future__ import annotations
 | |
| from contextlib import (
 | |
|     asynccontextmanager as acm,
 | |
| )
 | |
| from itertools import count
 | |
| from functools import partial
 | |
| from types import ModuleType
 | |
| from typing import (
 | |
|     Any,
 | |
|     Optional,
 | |
|     Callable,
 | |
|     AsyncContextManager,
 | |
|     AsyncGenerator,
 | |
|     Iterable,
 | |
| )
 | |
| import json
 | |
| 
 | |
| import trio
 | |
| from trio_typing import TaskStatus
 | |
| from trio_websocket import (
 | |
|     WebSocketConnection,
 | |
|     open_websocket_url,
 | |
| )
 | |
| from wsproto.utilities import LocalProtocolError
 | |
| from trio_websocket._impl import (
 | |
|     ConnectionClosed,
 | |
|     DisconnectionTimeout,
 | |
|     ConnectionRejected,
 | |
|     HandshakeError,
 | |
|     ConnectionTimeout,
 | |
| )
 | |
| 
 | |
| from piker.types import Struct
 | |
| from ._util import log
 | |
| 
 | |
| 
 | |
class NoBsWs:
    '''
    Keep a ``trio_websocket`` connection alive regardless of the bs.

    Thin shim over a ``WebSocketConnection``: client code sends
    through it and receives from it, while a background relay task
    (``_reconnect_forever()``) owns the real socket and swaps in a
    fresh one whenever connectivity problems occur.

    NOTE: never construct this type directly; get one from the
    ``open_autorecon_ws()`` factory below.

    '''
    # every error type which triggers a connection reset instead of
    # propagating to client code (servers QoS us for all sorts of
    # reasons..)
    recon_errors = (
        ConnectionClosed,
        DisconnectionTimeout,
        ConnectionRejected,
        HandshakeError,
        ConnectionTimeout,
        LocalProtocolError,
    )

    def __init__(
        self,
        url: str,
        rxchan: trio.MemoryReceiveChannel,
        msg_recv_timeout: float,

        serializer: ModuleType = json
    ):
        self.url = url

        # raw frames relayed by the bg task arrive on this chan
        self._rx = rxchan

        # per-recv deadline (seconds) read by the relay task
        self._timeout = msg_recv_timeout

        # codec callables taken from the serializer module; this
        # assumes a ``json``-like ``dumps()``/``loads()`` API.
        self._dumps: Callable = serializer.dumps
        self._loads: Callable = serializer.loads

        # set by the relay task once the socket is connected (and any
        # fixture entered); swapped for a new event on reconnects.
        self._connected: trio.Event = trio.Event()

        # both (re)assigned by the relay task for each connection
        self._ws: WebSocketConnection | None = None
        self._cs: trio.CancelScope | None = None

    def connected(self) -> bool:
        '''
        Report whether the current "connected" event has been set.

        '''
        return self._connected.is_set()

    async def reset(self) -> None:
        '''
        Force a reconnect: install a fresh "connected" event, cancel
        the relay task's per-connection scope, then block until the
        relay task signals that a new connection is up.

        '''
        fresh = trio.Event()
        self._connected = fresh
        self._cs.cancel()
        await fresh.wait()

    async def send_msg(
        self,
        data: Any,
    ) -> None:
        '''
        Serialize ``data`` and send it over the current socket,
        resetting the connection and retrying whenever one of
        ``recon_errors`` is raised.

        '''
        while True:
            try:
                payload: Any = self._dumps(data)
                return await self._ws.send_message(payload)
            except self.recon_errors:
                await self.reset()

    async def recv_msg(self) -> Any:
        '''
        Wait for the next relayed frame and return it decoded.

        '''
        return self._loads(await self._rx.receive())

    def __aiter__(self):
        # the shim is its own async iterator
        return self

    async def __anext__(self):
        # each step yields one decoded frame via ``recv_msg()``
        return await self.recv_msg()

    def set_recv_timeout(
        self,
        timeout: float,
    ) -> None:
        '''
        Change the per-recv deadline; the relay task reads it at the
        start of every receive attempt.

        '''
        self._timeout = timeout
 | |
| 
 | |
| 
 | |
async def _reconnect_forever(
    url: str,
    snd: trio.MemorySendChannel,
    nobsws: NoBsWs,
    reset_after: int,  # max consecutive msg recv timeouts before reset

    fixture: Callable[[NoBsWs], AsyncContextManager] | None = None,
    task_status: TaskStatus = trio.TASK_STATUS_IGNORED,

) -> None:
    '''
    Background task which (re)connects a websocket to ``url`` until
    ``snd`` is closed (or this task is cancelled), relaying every
    received frame into ``snd`` and (re)entering the optional user
    ``fixture`` on each new connection.

    ``nobsws`` is updated in place with the live socket (``._ws``),
    the per-connection cancel scope (``._cs``) and the ``._connected``
    event, which is set once a connection (and fixture) is up.

    '''
    # TODO: can we just report "where" in the call stack
    # the client code is using the ws stream?
    # Maybe we can just drop this since it's already in the log msg
    # prefix?
    if fixture is not None:
        src_mod: str = fixture.__module__
    else:
        src_mod: str = 'unknown'

    async def proxy_msgs(
        ws: WebSocketConnection,
        pcs: trio.CancelScope,  # parent cancel scope
    ):
        '''
        Receive (under the ``nobsws._timeout`` deadline) all msgs from
        the underlying websocket and relay them to the (calling)
        parent task via the ``trio`` mem chan.

        Cancels ``pcs`` and returns on any reconnect-worthy error or
        after more than ``reset_after`` consecutive recv timeouts.

        '''
        # back-to-back recv timeouts; reset whenever a msg arrives.
        timeouts: int = 0

        while True:
            msg: Any = None
            try:
                with trio.move_on_after(
                    # can be dynamically changed by user code
                    nobsws._timeout,
                ) as recv_cs:
                    msg = await ws.get_message()

            # NOTE: handled outside the deadline scope so the backoff
            # sleep can't be cut short (which would skip the reset).
            except nobsws.recon_errors:
                log.exception(
                    f'{src_mod}\n'
                    f'{url} connection bail with:'
                )
                await trio.sleep(0.5)
                pcs.cancel()

                # go back to reconnect loop in parent task
                return

            if recv_cs.cancelled_caught:
                timeouts += 1
                if timeouts > reset_after:
                    log.error(
                        f'{src_mod}\n'
                        'WS feed seems down and slow af.. reconnecting\n'
                    )
                    pcs.cancel()

                    # go back to reconnect loop in parent task
                    return

                continue

            # feed is alive; only *consecutive* timeouts count.
            timeouts = 0

            # NOTE: forward outside the recv deadline so consumer
            # backpressure (a full mem chan) is never mistaken for
            # feed latency; cancelling a blocked ``send()`` would
            # also silently drop the msg.
            await snd.send(msg)

    async def open_fixture(
        fixture: Callable[[NoBsWs], AsyncContextManager],
        nobsws: NoBsWs,
        task_status: TaskStatus = trio.TASK_STATUS_IGNORED,
    ):
        '''
        Open user provided `@acm` and sleep until any connection
        reset occurs.

        '''
        async with fixture(nobsws) as ret:
            assert ret is None
            task_status.started()
            await trio.sleep_forever()

    nobsws._connected = trio.Event()
    task_status.started()

    while not snd._closed:
        log.info(
            f'{src_mod}\n'
            f'{url} trying (RE)CONNECT'
        )

        # reset per attempt so a failed handshake never leaves us
        # inspecting an unbound or stale scope from a prior connection.
        cs: trio.CancelScope | None = None
        ws: WebSocketConnection
        try:
            async with (
                trio.open_nursery() as n,
                open_websocket_url(url) as ws,
            ):
                cs = nobsws._cs = n.cancel_scope
                nobsws._ws = ws
                log.info(
                    f'{src_mod}\n'
                    f'Connection success: {url}'
                )

                # begin relay loop to forward msgs
                n.start_soon(
                    proxy_msgs,
                    ws,
                    cs,
                )

                if fixture is not None:
                    log.info(
                        f'{src_mod}\n'
                        f'Entering fixture: {fixture}'
                    )

                    # TODO: should we return an explicit sub-cs
                    # from this fixture task?
                    await n.start(
                        open_fixture,
                        fixture,
                        nobsws,
                    )

                # indicate to wrapper / opener that we are up and block
                # to let tasks run **inside** the ws open block above.
                nobsws._connected.set()
                await trio.sleep_forever()

        except HandshakeError:
            log.exception('Retrying connection')
            # back off instead of hammering the server in a tight loop
            await trio.sleep(0.5)

        # ws & nursery block ends

        # Only replace the event once it has been "used up" (set). If
        # ``NoBsWs.reset()`` cancelled us, it already installed a new
        # unset event which it is awaiting; that exact event must be
        # the one set on reconnect or the resetter blocks forever.
        if nobsws._connected.is_set():
            nobsws._connected = trio.Event()

        if (
            cs is not None
            and cs.cancelled_caught
        ):
            log.cancel(
                f'{url} connection cancelled!'
            )
            assert (
                nobsws._connected
                and not nobsws._connected.is_set()
            )

        # -> from here, move to next reconnect attempt iteration
        # in the while loop above Bp

    else:
        # no exception is active here so ``.exception()`` would only
        # emit a bogus "NoneType: None" traceback.
        log.warning(
            f'{src_mod}\n'
            'ws connection closed by client...'
        )
 | |
| 
 | |
| 
 | |
@acm
async def open_autorecon_ws(
    url: str,

    # called as ``fixture(nobsws)`` on each (re)connect; the returned
    # acm must yield ``None`` and stays entered until the next reset.
    fixture: Callable[[NoBsWs], AsyncContextManager] | None = None,

    # time in sec between msgs received before
    # we presume connection might need a reset.
    msg_recv_timeout: float = 16,

    # count of the number of above timeouts before connection reset
    reset_after: int = 3,

) -> AsyncGenerator[NoBsWs, None]:
    '''
    An auto-reconnect websocket (wrapper API) around
    ``trio_websocket.open_websocket_url()`` providing automatic
    re-connection on network errors, msg latency and thus roaming.

    Here we implement a re-connect websocket interface where a bg
    nursery runs ``WebSocketConnection.get_message()`` in a loop
    and restarts the full http(s) handshake on catches of certain
    connectivity errors, or some user defined recv timeout.

    You can provide a ``fixture`` async-context-manager factory which
    will be entered/exited around each connection reset; eg. for
    (re)requesting subscriptions without requiring streaming setup
    code to rerun.

    Yields a connected ``NoBsWs``; the bg relay task is cancelled on
    exit.

    '''
    snd: trio.MemorySendChannel
    rcv: trio.MemoryReceiveChannel
    snd, rcv = trio.open_memory_channel(616)

    async with trio.open_nursery() as n:
        nobsws = NoBsWs(
            url,
            rcv,
            msg_recv_timeout=msg_recv_timeout,
        )
        await n.start(
            partial(
                _reconnect_forever,
                url,
                snd,
                nobsws,
                fixture=fixture,
                reset_after=reset_after,
            )
        )

        # block until the relay task reports the first connection
        await nobsws._connected.wait()
        assert nobsws._cs
        assert nobsws.connected()

        try:
            yield nobsws
        finally:
            n.cancel_scope.cancel()
 | |
| 
 | |
| 
 | |
# ------------------------------------------------------------------
# JSONRPC response-request style machinery for transparent multiplexing
# of msgs over a `NoBsWs`.
# ------------------------------------------------------------------


class JSONRPCResult(Struct):
    '''
    One JSON-RPC 2.0 response frame: the request ``id`` it answers
    plus either a ``result`` or an ``error`` payload.

    '''
    id: int
    jsonrpc: str = '2.0'
    result: dict | None = None
    error: dict | None = None
 | |
| 
 | |
| 
 | |
| @acm
 | |
| async def open_jsonrpc_session(
 | |
|     url: str,
 | |
|     start_id: int = 0,
 | |
|     response_type: type = JSONRPCResult,
 | |
|     msg_recv_timeout: float = float('inf'),
 | |
|     # ^NOTE, since only `deribit` is using this jsonrpc stuff atm
 | |
|     # and options mkts are generally "slow moving"..
 | |
|     #
 | |
|     # FURTHER if we break the underlying ws connection then since we
 | |
|     # don't pass a `fixture` to the task that manages `NoBsWs`, i.e.
 | |
|     # `_reconnect_forever()`, the jsonrpc "transport pipe" get's
 | |
|     # broken and never restored with wtv init sequence is required to
 | |
|     # re-establish a working req-resp session.
 | |
| 
 | |
|     # request_type: Optional[type] = None,
 | |
|     # request_hook: Optional[Callable] = None,
 | |
|     # error_hook: Optional[Callable] = None,
 | |
| ) -> Callable[[str, dict], dict]:
 | |
| 
 | |
|     # NOTE, store all request msgs so we can raise errors on the
 | |
|     # caller side!
 | |
|     req_msgs: dict[int, dict] = {}
 | |
| 
 | |
|     async with (
 | |
|         trio.open_nursery() as n,
 | |
|         open_autorecon_ws(
 | |
|             url=url,
 | |
|             msg_recv_timeout=msg_recv_timeout,
 | |
|         ) as ws
 | |
|     ):
 | |
|         rpc_id: Iterable[int] = count(start_id)
 | |
|         rpc_results: dict[int, dict] = {}
 | |
| 
 | |
|         async def json_rpc(
 | |
|             method: str,
 | |
|             params: dict,
 | |
|         ) -> dict:
 | |
|             '''
 | |
|             perform a json rpc call and wait for the result, raise exception in
 | |
|             case of error field present on response
 | |
|             '''
 | |
|             nonlocal req_msgs
 | |
| 
 | |
|             req_id: int = next(rpc_id)
 | |
|             msg = {
 | |
|                 'jsonrpc': '2.0',
 | |
|                 'id': req_id,
 | |
|                 'method': method,
 | |
|                 'params': params
 | |
|             }
 | |
|             _id = msg['id']
 | |
| 
 | |
|             result = rpc_results[_id] = {
 | |
|                 'result': None,
 | |
|                 'error': None,
 | |
|                 'event': trio.Event(),  # signal caller resp arrived
 | |
|             }
 | |
|             req_msgs[_id] = msg
 | |
| 
 | |
|             await ws.send_msg(msg)
 | |
| 
 | |
|             # wait for reponse before unblocking requester code
 | |
|             await rpc_results[_id]['event'].wait()
 | |
| 
 | |
|             if (maybe_result := result['result']):
 | |
|                 ret = maybe_result
 | |
|                 del rpc_results[_id]
 | |
| 
 | |
|             else:
 | |
|                 err = result['error']
 | |
|                 raise Exception(
 | |
|                     f'JSONRPC request failed\n'
 | |
|                     f'req: {msg}\n'
 | |
|                     f'resp: {err}\n'
 | |
|                 )
 | |
| 
 | |
|             if ret.error is not None:
 | |
|                 raise Exception(json.dumps(ret.error, indent=4))
 | |
| 
 | |
|             return ret
 | |
| 
 | |
|         async def recv_task():
 | |
|             '''
 | |
|             receives every ws message and stores it in its corresponding
 | |
|             result field, then sets the event to wakeup original sender
 | |
|             tasks. also recieves responses to requests originated from
 | |
|             the server side.
 | |
| 
 | |
|             '''
 | |
|             nonlocal req_msgs
 | |
|             async for msg in ws:
 | |
|                 match msg:
 | |
|                     case {
 | |
|                         'result': _,
 | |
|                         'id': mid,
 | |
|                     } if res_entry := rpc_results.get(mid):
 | |
| 
 | |
|                         res_entry['result'] = response_type(**msg)
 | |
|                         res_entry['event'].set()
 | |
| 
 | |
|                     case {
 | |
|                         'result': _,
 | |
|                         'id': mid,
 | |
|                     } if not rpc_results.get(mid):
 | |
|                         log.warning(
 | |
|                             f'Unexpected ws msg: {json.dumps(msg, indent=4)}'
 | |
|                         )
 | |
| 
 | |
|                     case {
 | |
|                         'method': _,
 | |
|                         'params': _,
 | |
|                     }:
 | |
|                         log.debug(f'Recieved\n{msg}')
 | |
|                         # if request_hook:
 | |
|                         #     await request_hook(request_type(**msg))
 | |
| 
 | |
|                     case {
 | |
|                         'error': error
 | |
|                     }:
 | |
|                         # if error_hook:
 | |
|                         #     await error_hook(response_type(**msg))
 | |
| 
 | |
|                         # retreive orig request msg, set error
 | |
|                         # response in original "result" msg,
 | |
|                         # THEN FINALLY set the event to signal caller
 | |
|                         # to raise the error in the parent task.
 | |
|                         req_id: int = error['id']
 | |
|                         req_msg: dict = req_msgs[req_id]
 | |
|                         result: dict = rpc_results[req_id]
 | |
|                         result['error'] = error
 | |
|                         result['event'].set()
 | |
|                         log.error(
 | |
|                             f'JSONRPC request failed\n'
 | |
|                             f'req: {req_msg}\n'
 | |
|                             f'resp: {error}\n'
 | |
|                         )
 | |
| 
 | |
|                     case _:
 | |
|                         log.warning(f'Unhandled JSON-RPC msg!?\n{msg}')
 | |
| 
 | |
|         n.start_soon(recv_task)
 | |
|         yield json_rpc
 | |
|         n.cancel_scope.cancel()
 |