From 34fb497eb4f720fd213682a79289b2e6a79d69eb Mon Sep 17 00:00:00 2001 From: Guillermo Rodriguez Date: Tue, 23 Aug 2022 22:21:27 -0300 Subject: [PATCH] Add aiter api to NoBsWs and rework cryptofeed relay to not be OOPy --- piker/brokers/deribit/api.py | 521 ++++++++++++++++++---------------- piker/brokers/deribit/feed.py | 6 +- piker/data/_web_bs.py | 6 + 3 files changed, 279 insertions(+), 254 deletions(-) diff --git a/piker/brokers/deribit/api.py b/piker/brokers/deribit/api.py index 32db7c01..0ab5e2be 100644 --- a/piker/brokers/deribit/api.py +++ b/piker/brokers/deribit/api.py @@ -26,7 +26,7 @@ from contextlib import asynccontextmanager as acm, AsyncExitStack from itertools import count from functools import partial from datetime import datetime -from typing import Any, List, Dict, Optional, Iterable +from typing import Any, List, Dict, Optional, Iterable, Callable import pendulum import asks @@ -43,7 +43,7 @@ from .._util import resproc from piker import config from piker.log import get_logger -from tractor.trionics import broadcast_receiver, BroadcastReceiver +from tractor.trionics import broadcast_receiver, BroadcastReceiver, maybe_open_context from tractor import to_asyncio from cryptofeed import FeedHandler @@ -192,6 +192,7 @@ def get_config() -> dict[str, Any]: section = conf.get('deribit') + # TODO: document why we send this, basically because logging params for cryptofeed conf['log'] = {} conf['log']['disabled'] = True @@ -203,7 +204,7 @@ def get_config() -> dict[str, Any]: class Client: - def __init__(self, n: Nursery, ws: NoBsWs) -> None: + def __init__(self, json_rpc: Callable) -> None: self._pairs: dict[str, Any] = None config = get_config().get('deribit', {}) @@ -216,137 +217,12 @@ class Client: self._key_id = None self._key_secret = None - self._ws = ws - self._n = n - - self._rpc_id: Iterable = count(0) - self._rpc_results: Dict[int, Dict] = {} - - self._expiry_time: int = float('inf') - self._access_token: Optional[str] = None - self._refresh_token: Optional[str] = None - - self.feeds = CryptoFeedRelay() + self.json_rpc = json_rpc @property def currencies(self): return ['btc', 'eth', 'sol', 'usd'] - def _next_json_body(self, method: str, params: Dict): - """get the typical json rpc 2.0 msg body and increment the req id - """ - return { - 'jsonrpc': '2.0', - 'id': next(self._rpc_id), - 'method': method, - 'params': params - } - - async def start_rpc(self): - """launch message receiver - """ - self._n.start_soon(self._recv_task) - - # if we have client creds launch auth loop - if self._key_id is not None: - await self._n.start(self._auth_loop) - - async def _recv_task(self): - """receives every ws message and stores it in its corresponding result - field, then sets the event to wakeup original sender tasks. - """ - while True: - msg = JSONRPCResult(**(await self._ws.recv_msg())) - - if msg.id not in self._rpc_results: - # in case this message wasn't beign accounted for store it - self._rpc_results[msg.id] = { - 'result': None, - 'event': trio.Event() - } - - self._rpc_results[msg.id]['result'] = msg - self._rpc_results[msg.id]['event'].set() - - async def json_rpc(self, method: str, params: Dict) -> Dict: - """perform a json rpc call and wait for the result, raise exception in - case of error field present on response - """ - msg = self._next_json_body(method, params) - _id = msg['id'] - - self._rpc_results[_id] = { - 'result': None, - 'event': trio.Event() - } - - await self._ws.send_msg(msg) - - await self._rpc_results[_id]['event'].wait() - - ret = self._rpc_results[_id]['result'] - - del self._rpc_results[_id] - - if ret.error is not None: - raise Exception(json.dumps(ret.error, indent=4)) - - return ret - - async def _auth_loop( - self, - task_status: TaskStatus = trio.TASK_STATUS_IGNORED - ): - """Background task that adquires a first access token and then will - refresh the access token while the nursery isn't cancelled. - - https://docs.deribit.com/?python#authentication-2 - """ - renew_time = 10 - access_scope = 'trade:read_write' - self._expiry_time = time.time() - got_access = False - - while True: - if time.time() - self._expiry_time < renew_time: - # if we are close to token expiry time - - if self._refresh_token != None: - # if we have a refresh token already dont need to send - # secret - params = { - 'grant_type': 'refresh_token', - 'refresh_token': self._refresh_token, - 'scope': access_scope - } - - else: - # we don't have refresh token, send secret to initialize - params = { - 'grant_type': 'client_credentials', - 'client_id': self._key_id, - 'client_secret': self._key_secret, - 'scope': access_scope - } - - resp = await self.json_rpc('public/auth', params) - result = resp.result - - self._expiry_time = time.time() + result['expires_in'] - self._refresh_token = result['refresh_token'] - - if 'access_token' in result: - self._access_token = result['access_token'] - - if not got_access: - # first time this loop runs we must indicate task is - # started, we have auth - got_access = True - task_status.started() - - else: - await trio.sleep(renew_time / 2) - async def get_balances(self, kind: str = 'option') -> dict[str, float]: """Return the set of positions for this account by symbol. @@ -539,149 +415,292 @@ async def get_client() -> Client: trio.open_nursery() as n, open_autorecon_ws(_testnet_ws_url) as ws ): - client = Client(n, ws) - await client.start_rpc() + + _rpc_id: Iterable = count(0) + _rpc_results: Dict[int, Dict] = {} + + _expiry_time: int = float('inf') + _access_token: Optional[str] = None + _refresh_token: Optional[str] = None + + def _next_json_body(method: str, params: Dict): + """get the typical json rpc 2.0 msg body and increment the req id + """ + return { + 'jsonrpc': '2.0', + 'id': next(_rpc_id), + 'method': method, + 'params': params + } + + async def json_rpc(method: str, params: Dict) -> Dict: + """perform a json rpc call and wait for the result, raise exception in + case of error field present on response + """ + msg = _next_json_body(method, params) + _id = msg['id'] + + _rpc_results[_id] = { + 'result': None, + 'event': trio.Event() + } + + await ws.send_msg(msg) + + await _rpc_results[_id]['event'].wait() + + ret = _rpc_results[_id]['result'] + + del _rpc_results[_id] + + if ret.error is not None: + raise Exception(json.dumps(ret.error, indent=4)) + + return ret + + async def _recv_task(): + """receives every ws message and stores it in its corresponding result + field, then sets the event to wakeup original sender tasks. + """ + async for msg in ws: + msg = JSONRPCResult(**msg) + + if msg.id not in _rpc_results: + # in case this message wasn't beign accounted for store it + _rpc_results[msg.id] = { + 'result': None, + 'event': trio.Event() + } + + _rpc_results[msg.id]['result'] = msg + _rpc_results[msg.id]['event'].set() + + client = Client(json_rpc) + + async def _auth_loop( + task_status: TaskStatus = trio.TASK_STATUS_IGNORED + ): + """Background task that adquires a first access token and then will + refresh the access token while the nursery isn't cancelled. + + https://docs.deribit.com/?python#authentication-2 + """ + renew_time = 10 + access_scope = 'trade:read_write' + _expiry_time = time.time() + got_access = False + nonlocal _refresh_token + nonlocal _access_token + + while True: + if time.time() - _expiry_time < renew_time: + # if we are close to token expiry time + + if _refresh_token != None: + # if we have a refresh token already dont need to send + # secret + params = { + 'grant_type': 'refresh_token', + 'refresh_token': _refresh_token, + 'scope': access_scope + } + + else: + # we don't have refresh token, send secret to initialize + params = { + 'grant_type': 'client_credentials', + 'client_id': client._key_id, + 'client_secret': client._key_secret, + 'scope': access_scope + } + + resp = await json_rpc('public/auth', params) + result = resp.result + + _expiry_time = time.time() + result['expires_in'] + _refresh_token = result['refresh_token'] + + if 'access_token' in result: + _access_token = result['access_token'] + + if not got_access: + # first time this loop runs we must indicate task is + # started, we have auth + got_access = True + task_status.started() + + else: + await trio.sleep(renew_time / 2) + + n.start_soon(_recv_task) + # if we have client creds launch auth loop + if client._key_id is not None: + await n.start(_auth_loop) + await client.cache_symbols() yield client - await client.feeds.stop() -class CryptoFeedRelay: +@acm +async def open_feed_handler(): + fh = FeedHandler(config=get_config()) + yield fh + await to_asyncio.run_task(fh.stop_async) - def __init__(self): - self._fh = FeedHandler(config=get_config()) - self._price_streams: dict[str, BroadcastReceiver] = {} - self._order_stream: Optional[BroadcastReceiver] = None +@acm +async def maybe_open_feed_handler() -> trio.abc.ReceiveStream: + async with maybe_open_context( + acm_func=open_feed_handler, + key='feedhandler', + ) as (cache_hit, fh): + yield fh - self._loop = None - async def stop(self): - await to_asyncio.run_task( - partial(self._fh.stop_async, loop=self._loop)) +@acm +async def open_price_feed( + instrument: str +) -> trio.abc.ReceiveStream: - @acm - async def open_price_feed( - self, - instruments: List[str] - ) -> trio.abc.ReceiveStream: - inst_str = ','.join(instruments) - instruments = [piker_sym_to_cb_sym(i) for i in instruments] + # XXX: hangs when going into this ctx mngr + async with maybe_open_feed_handler() as fh: - if inst_str in self._price_streams: - # TODO: a good value for maxlen? - yield broadcast_receiver(self._price_streams[inst_str], 10) + async def relay( + from_trio: asyncio.Queue, + to_trio: trio.abc.SendChannel, + ) -> None: + async def _trade(data: dict, receipt_timestamp): + to_trio.send_nowait(('trade', { + 'symbol': cb_sym_to_deribit_inst( + str_to_cb_sym(data.symbol)).lower(), + 'last': data, + 'broker_ts': time.time(), + 'data': data.to_dict(), + 'receipt': receipt_timestamp + })) + async def _l1(data: dict, receipt_timestamp): + to_trio.send_nowait(('l1', { + 'symbol': cb_sym_to_deribit_inst( + str_to_cb_sym(data.symbol)).lower(), + 'ticks': [ + {'type': 'bid', + 'price': float(data.bid_price), 'size': float(data.bid_size)}, + {'type': 'bsize', + 'price': float(data.bid_price), 'size': float(data.bid_size)}, + {'type': 'ask', + 'price': float(data.ask_price), 'size': float(data.ask_size)}, + {'type': 'asize', + 'price': float(data.ask_price), 'size': float(data.ask_size)} + ] + })) + + fh.add_feed( + DERIBIT, + channels=[TRADES, L1_BOOK], + symbols=[instrument], + callbacks={ + TRADES: _trade, + L1_BOOK: _l1 + }) + + if not fh.running: + fh.run( + start_loop=False, + install_signal_handlers=False) + + # sync with trio + to_trio.send_nowait(None) + + try: + await asyncio.sleep(float('inf')) + + except asyncio.exceptions.CancelledError: + ... + + async with to_asyncio.open_channel_from( + relay + ) as (first, chan): + yield chan + + +@acm +async def maybe_open_price_feed( + instrument: str +) -> trio.abc.ReceiveStream: + + # TODO: add a predicate to maybe_open_context + async with maybe_open_context( + acm_func=open_price_feed, + kwargs={ + 'instrument': instrument + }, + key=f'{instrument}-price', + ) as (cache_hit, feed): + if cache_hit: + yield broadcast_receiver(feed, 10) else: - async def relay( - from_trio: asyncio.Queue, - to_trio: trio.abc.SendChannel, - ) -> None: - async def _trade(data: dict, receipt_timestamp): - to_trio.send_nowait(('trade', { - 'symbol': cb_sym_to_deribit_inst( - str_to_cb_sym(data.symbol)).lower(), - 'last': data, - 'broker_ts': time.time(), - 'data': data.to_dict(), - 'receipt': receipt_timestamp - })) + yield feed - async def _l1(data: dict, receipt_timestamp): - to_trio.send_nowait(('l1', { - 'symbol': cb_sym_to_deribit_inst( - str_to_cb_sym(data.symbol)).lower(), - 'ticks': [ - {'type': 'bid', - 'price': float(data.bid_price), 'size': float(data.bid_size)}, - {'type': 'bsize', - 'price': float(data.bid_price), 'size': float(data.bid_size)}, - {'type': 'ask', - 'price': float(data.ask_price), 'size': float(data.ask_size)}, - {'type': 'asize', - 'price': float(data.ask_price), 'size': float(data.ask_size)} - ] - })) +@acm +async def open_order_feed( + instrument: List[str] +) -> trio.abc.ReceiveStream: - self._fh.add_feed( - DERIBIT, - channels=[TRADES, L1_BOOK], - symbols=instruments, - callbacks={ - TRADES: _trade, - L1_BOOK: _l1 - }) + async with maybe_open_feed_handler() as fh: - if not self._fh.running: - self._fh.run( - start_loop=False, - install_signal_handlers=False) - self._loop = asyncio.get_event_loop() + async def relay( + from_trio: asyncio.Queue, + to_trio: trio.abc.SendChannel, + ) -> None: + async def _fill(data: dict, receipt_timestamp): + breakpoint() - # sync with trio - to_trio.send_nowait(None) + async def _order_info(data: dict, receipt_timestamp): + breakpoint() - try: - await asyncio.sleep(float('inf')) + fh.add_feed( + DERIBIT, + channels=[FILLS, ORDER_INFO], + symbols=[instrument], + callbacks={ + FILLS: _fill, + ORDER_INFO: _order_info, + }) - except asyncio.exceptions.CancelledError: - ... + if not fh.running: + fh.run( + start_loop=False, + install_signal_handlers=False) - async with to_asyncio.open_channel_from( - relay - ) as (first, chan): - self._price_streams[inst_str] = chan - yield self._price_streams[inst_str] + # sync with trio + to_trio.send_nowait(None) - @acm - async def open_order_feed( - self, - instruments: List[str] - ) -> trio.abc.ReceiveStream: + try: + await asyncio.sleep(float('inf')) - inst_str = ','.join(instruments) - instruments = [piker_sym_to_cb_sym(i) for i in instruments] + except asyncio.exceptions.CancelledError: + ... - if self._order_stream: - yield broadcast_receiver(self._order_streams[inst_str], 10) + async with to_asyncio.open_channel_from( + relay + ) as (first, chan): + yield chan +@acm +async def maybe_open_order_feed( + instrument: str +) -> trio.abc.ReceiveStream: + + # TODO: add a predicate to maybe_open_context + async with maybe_open_context( + acm_func=open_order_feed, + kwargs={ + 'instrument': instrument + }, + key=f'{instrument}-order', + ) as (cache_hit, feed): + if cache_hit: + yield broadcast_receiver(feed, 10) else: - async def relay( - from_trio: asyncio.Queue, - to_trio: trio.abc.SendChannel, - ) -> None: - async def _fill(data: dict, receipt_timestamp): - breakpoint() - - async def _order_info(data: dict, receipt_timestamp): - breakpoint() - - self._fh.add_feed( - DERIBIT, - channels=[FILLS, ORDER_INFO], - symbols=instruments, - callbacks={ - FILLS: _fill, - ORDER_INFO: _order_info, - }) - - if not self._fh.running: - self._fh.run( - start_loop=False, - install_signal_handlers=False) - self._loop = asyncio.get_event_loop() - - # sync with trio - to_trio.send_nowait(None) - - try: - await asyncio.sleep(float('inf')) - - except asyncio.exceptions.CancelledError: - ... - - async with to_asyncio.open_channel_from( - relay - ) as (first, chan): - self._order_stream = chan - yield self._order_stream + yield feed diff --git a/piker/brokers/deribit/feed.py b/piker/brokers/deribit/feed.py index b3daed7d..a45861f7 100644 --- a/piker/brokers/deribit/feed.py +++ b/piker/brokers/deribit/feed.py @@ -48,7 +48,8 @@ from cryptofeed.symbols import Symbol from .api import ( Client, Trade, get_config, - str_to_cb_sym, piker_sym_to_cb_sym, cb_sym_to_deribit_inst + str_to_cb_sym, piker_sym_to_cb_sym, cb_sym_to_deribit_inst, + maybe_open_price_feed ) _spawn_kwargs = { @@ -144,8 +145,7 @@ async def stream_quotes( nsym = piker_sym_to_cb_sym(sym) - async with client.feeds.open_price_feed( - symbols) as stream: + async with maybe_open_price_feed(sym) as stream: cache = await client.cache_symbols() diff --git a/piker/data/_web_bs.py b/piker/data/_web_bs.py index 64d447df..78e82dfd 100644 --- a/piker/data/_web_bs.py +++ b/piker/data/_web_bs.py @@ -123,6 +123,12 @@ class NoBsWs: except self.recon_errors: await self._connect() + def __aiter__(self): + return self + + async def __anext__(self): + return await self.recv_msg() + @asynccontextmanager async def open_autorecon_ws(