Add aiter api to NoBsWs and rework cryptofeed relay to not be OOPy

size_in_shm_token
Guillermo Rodriguez 2022-08-23 22:21:27 -03:00
parent 6669ba6590
commit 34fb497eb4
No known key found for this signature in database
GPG Key ID: EC3AB66D5D83B392
3 changed files with 279 additions and 254 deletions

View File

@ -26,7 +26,7 @@ from contextlib import asynccontextmanager as acm, AsyncExitStack
from itertools import count from itertools import count
from functools import partial from functools import partial
from datetime import datetime from datetime import datetime
from typing import Any, List, Dict, Optional, Iterable from typing import Any, List, Dict, Optional, Iterable, Callable
import pendulum import pendulum
import asks import asks
@ -43,7 +43,7 @@ from .._util import resproc
from piker import config from piker import config
from piker.log import get_logger from piker.log import get_logger
from tractor.trionics import broadcast_receiver, BroadcastReceiver from tractor.trionics import broadcast_receiver, BroadcastReceiver, maybe_open_context
from tractor import to_asyncio from tractor import to_asyncio
from cryptofeed import FeedHandler from cryptofeed import FeedHandler
@ -192,6 +192,7 @@ def get_config() -> dict[str, Any]:
section = conf.get('deribit') section = conf.get('deribit')
# TODO: document why we send this, basically because logging params for cryptofeed
conf['log'] = {} conf['log'] = {}
conf['log']['disabled'] = True conf['log']['disabled'] = True
@ -203,7 +204,7 @@ def get_config() -> dict[str, Any]:
class Client: class Client:
def __init__(self, n: Nursery, ws: NoBsWs) -> None: def __init__(self, json_rpc: Callable) -> None:
self._pairs: dict[str, Any] = None self._pairs: dict[str, Any] = None
config = get_config().get('deribit', {}) config = get_config().get('deribit', {})
@ -216,137 +217,12 @@ class Client:
self._key_id = None self._key_id = None
self._key_secret = None self._key_secret = None
self._ws = ws self.json_rpc = json_rpc
self._n = n
self._rpc_id: Iterable = count(0)
self._rpc_results: Dict[int, Dict] = {}
self._expiry_time: int = float('inf')
self._access_token: Optional[str] = None
self._refresh_token: Optional[str] = None
self.feeds = CryptoFeedRelay()
@property @property
def currencies(self): def currencies(self):
return ['btc', 'eth', 'sol', 'usd'] return ['btc', 'eth', 'sol', 'usd']
def _next_json_body(self, method: str, params: Dict):
"""get the typical json rpc 2.0 msg body and increment the req id
"""
return {
'jsonrpc': '2.0',
'id': next(self._rpc_id),
'method': method,
'params': params
}
async def start_rpc(self):
"""launch message receiver
"""
self._n.start_soon(self._recv_task)
# if we have client creds launch auth loop
if self._key_id is not None:
await self._n.start(self._auth_loop)
async def _recv_task(self):
"""receives every ws message and stores it in its corresponding result
field, then sets the event to wakeup original sender tasks.
"""
while True:
msg = JSONRPCResult(**(await self._ws.recv_msg()))
if msg.id not in self._rpc_results:
# in case this message wasn't beign accounted for store it
self._rpc_results[msg.id] = {
'result': None,
'event': trio.Event()
}
self._rpc_results[msg.id]['result'] = msg
self._rpc_results[msg.id]['event'].set()
async def json_rpc(self, method: str, params: Dict) -> Dict:
"""perform a json rpc call and wait for the result, raise exception in
case of error field present on response
"""
msg = self._next_json_body(method, params)
_id = msg['id']
self._rpc_results[_id] = {
'result': None,
'event': trio.Event()
}
await self._ws.send_msg(msg)
await self._rpc_results[_id]['event'].wait()
ret = self._rpc_results[_id]['result']
del self._rpc_results[_id]
if ret.error is not None:
raise Exception(json.dumps(ret.error, indent=4))
return ret
async def _auth_loop(
self,
task_status: TaskStatus = trio.TASK_STATUS_IGNORED
):
"""Background task that adquires a first access token and then will
refresh the access token while the nursery isn't cancelled.
https://docs.deribit.com/?python#authentication-2
"""
renew_time = 10
access_scope = 'trade:read_write'
self._expiry_time = time.time()
got_access = False
while True:
if time.time() - self._expiry_time < renew_time:
# if we are close to token expiry time
if self._refresh_token != None:
# if we have a refresh token already dont need to send
# secret
params = {
'grant_type': 'refresh_token',
'refresh_token': self._refresh_token,
'scope': access_scope
}
else:
# we don't have refresh token, send secret to initialize
params = {
'grant_type': 'client_credentials',
'client_id': self._key_id,
'client_secret': self._key_secret,
'scope': access_scope
}
resp = await self.json_rpc('public/auth', params)
result = resp.result
self._expiry_time = time.time() + result['expires_in']
self._refresh_token = result['refresh_token']
if 'access_token' in result:
self._access_token = result['access_token']
if not got_access:
# first time this loop runs we must indicate task is
# started, we have auth
got_access = True
task_status.started()
else:
await trio.sleep(renew_time / 2)
async def get_balances(self, kind: str = 'option') -> dict[str, float]: async def get_balances(self, kind: str = 'option') -> dict[str, float]:
"""Return the set of positions for this account """Return the set of positions for this account
by symbol. by symbol.
@ -539,40 +415,156 @@ async def get_client() -> Client:
trio.open_nursery() as n, trio.open_nursery() as n,
open_autorecon_ws(_testnet_ws_url) as ws open_autorecon_ws(_testnet_ws_url) as ws
): ):
client = Client(n, ws)
await client.start_rpc() _rpc_id: Iterable = count(0)
_rpc_results: Dict[int, Dict] = {}
_expiry_time: int = float('inf')
_access_token: Optional[str] = None
_refresh_token: Optional[str] = None
def _next_json_body(method: str, params: Dict):
"""get the typical json rpc 2.0 msg body and increment the req id
"""
return {
'jsonrpc': '2.0',
'id': next(_rpc_id),
'method': method,
'params': params
}
async def json_rpc(method: str, params: Dict) -> Dict:
"""perform a json rpc call and wait for the result, raise exception in
case of error field present on response
"""
msg = _next_json_body(method, params)
_id = msg['id']
_rpc_results[_id] = {
'result': None,
'event': trio.Event()
}
await ws.send_msg(msg)
await _rpc_results[_id]['event'].wait()
ret = _rpc_results[_id]['result']
del _rpc_results[_id]
if ret.error is not None:
raise Exception(json.dumps(ret.error, indent=4))
return ret
async def _recv_task():
"""receives every ws message and stores it in its corresponding result
field, then sets the event to wakeup original sender tasks.
"""
async for msg in ws:
msg = JSONRPCResult(**msg)
if msg.id not in _rpc_results:
# in case this message wasn't beign accounted for store it
_rpc_results[msg.id] = {
'result': None,
'event': trio.Event()
}
_rpc_results[msg.id]['result'] = msg
_rpc_results[msg.id]['event'].set()
client = Client(json_rpc)
async def _auth_loop(
task_status: TaskStatus = trio.TASK_STATUS_IGNORED
):
"""Background task that adquires a first access token and then will
refresh the access token while the nursery isn't cancelled.
https://docs.deribit.com/?python#authentication-2
"""
renew_time = 10
access_scope = 'trade:read_write'
_expiry_time = time.time()
got_access = False
nonlocal _refresh_token
nonlocal _access_token
while True:
if time.time() - _expiry_time < renew_time:
# if we are close to token expiry time
if _refresh_token != None:
# if we have a refresh token already dont need to send
# secret
params = {
'grant_type': 'refresh_token',
'refresh_token': _refresh_token,
'scope': access_scope
}
else:
# we don't have refresh token, send secret to initialize
params = {
'grant_type': 'client_credentials',
'client_id': client._key_id,
'client_secret': client._key_secret,
'scope': access_scope
}
resp = await json_rpc('public/auth', params)
result = resp.result
_expiry_time = time.time() + result['expires_in']
_refresh_token = result['refresh_token']
if 'access_token' in result:
_access_token = result['access_token']
if not got_access:
# first time this loop runs we must indicate task is
# started, we have auth
got_access = True
task_status.started()
else:
await trio.sleep(renew_time / 2)
n.start_soon(_recv_task)
# if we have client creds launch auth loop
if client._key_id is not None:
await n.start(_auth_loop)
await client.cache_symbols() await client.cache_symbols()
yield client yield client
await client.feeds.stop()
class CryptoFeedRelay: @acm
async def open_feed_handler():
fh = FeedHandler(config=get_config())
yield fh
await to_asyncio.run_task(fh.stop_async)
def __init__(self):
self._fh = FeedHandler(config=get_config())
self._price_streams: dict[str, BroadcastReceiver] = {} @acm
self._order_stream: Optional[BroadcastReceiver] = None async def maybe_open_feed_handler() -> trio.abc.ReceiveStream:
async with maybe_open_context(
acm_func=open_feed_handler,
key='feedhandler',
) as (cache_hit, fh):
yield fh
self._loop = None
async def stop(self):
await to_asyncio.run_task(
partial(self._fh.stop_async, loop=self._loop))
@acm @acm
async def open_price_feed( async def open_price_feed(
self, instrument: str
instruments: List[str]
) -> trio.abc.ReceiveStream: ) -> trio.abc.ReceiveStream:
inst_str = ','.join(instruments)
instruments = [piker_sym_to_cb_sym(i) for i in instruments]
if inst_str in self._price_streams: # XXX: hangs when going into this ctx mngr
# TODO: a good value for maxlen? async with maybe_open_feed_handler() as fh:
yield broadcast_receiver(self._price_streams[inst_str], 10)
else:
async def relay( async def relay(
from_trio: asyncio.Queue, from_trio: asyncio.Queue,
to_trio: trio.abc.SendChannel, to_trio: trio.abc.SendChannel,
@ -603,20 +595,19 @@ class CryptoFeedRelay:
] ]
})) }))
self._fh.add_feed( fh.add_feed(
DERIBIT, DERIBIT,
channels=[TRADES, L1_BOOK], channels=[TRADES, L1_BOOK],
symbols=instruments, symbols=[instrument],
callbacks={ callbacks={
TRADES: _trade, TRADES: _trade,
L1_BOOK: _l1 L1_BOOK: _l1
}) })
if not self._fh.running: if not fh.running:
self._fh.run( fh.run(
start_loop=False, start_loop=False,
install_signal_handlers=False) install_signal_handlers=False)
self._loop = asyncio.get_event_loop()
# sync with trio # sync with trio
to_trio.send_nowait(None) to_trio.send_nowait(None)
@ -630,22 +621,34 @@ class CryptoFeedRelay:
async with to_asyncio.open_channel_from( async with to_asyncio.open_channel_from(
relay relay
) as (first, chan): ) as (first, chan):
self._price_streams[inst_str] = chan yield chan
yield self._price_streams[inst_str]
@acm
async def maybe_open_price_feed(
instrument: str
) -> trio.abc.ReceiveStream:
# TODO: add a predicate to maybe_open_context
async with maybe_open_context(
acm_func=open_price_feed,
kwargs={
'instrument': instrument
},
key=f'{instrument}-price',
) as (cache_hit, feed):
if cache_hit:
yield broadcast_receiver(feed, 10)
else:
yield feed
@acm @acm
async def open_order_feed( async def open_order_feed(
self, instrument: List[str]
instruments: List[str]
) -> trio.abc.ReceiveStream: ) -> trio.abc.ReceiveStream:
inst_str = ','.join(instruments) async with maybe_open_feed_handler() as fh:
instruments = [piker_sym_to_cb_sym(i) for i in instruments]
if self._order_stream:
yield broadcast_receiver(self._order_streams[inst_str], 10)
else:
async def relay( async def relay(
from_trio: asyncio.Queue, from_trio: asyncio.Queue,
to_trio: trio.abc.SendChannel, to_trio: trio.abc.SendChannel,
@ -656,20 +659,19 @@ class CryptoFeedRelay:
async def _order_info(data: dict, receipt_timestamp): async def _order_info(data: dict, receipt_timestamp):
breakpoint() breakpoint()
self._fh.add_feed( fh.add_feed(
DERIBIT, DERIBIT,
channels=[FILLS, ORDER_INFO], channels=[FILLS, ORDER_INFO],
symbols=instruments, symbols=[instrument],
callbacks={ callbacks={
FILLS: _fill, FILLS: _fill,
ORDER_INFO: _order_info, ORDER_INFO: _order_info,
}) })
if not self._fh.running: if not fh.running:
self._fh.run( fh.run(
start_loop=False, start_loop=False,
install_signal_handlers=False) install_signal_handlers=False)
self._loop = asyncio.get_event_loop()
# sync with trio # sync with trio
to_trio.send_nowait(None) to_trio.send_nowait(None)
@ -683,5 +685,22 @@ class CryptoFeedRelay:
async with to_asyncio.open_channel_from( async with to_asyncio.open_channel_from(
relay relay
) as (first, chan): ) as (first, chan):
self._order_stream = chan yield chan
yield self._order_stream
@acm
async def maybe_open_order_feed(
instrument: str
) -> trio.abc.ReceiveStream:
# TODO: add a predicate to maybe_open_context
async with maybe_open_context(
acm_func=open_order_feed,
kwargs={
'instrument': instrument
},
key=f'{instrument}-order',
) as (cache_hit, feed):
if cache_hit:
yield broadcast_receiver(feed, 10)
else:
yield feed

View File

@ -48,7 +48,8 @@ from cryptofeed.symbols import Symbol
from .api import ( from .api import (
Client, Trade, Client, Trade,
get_config, get_config,
str_to_cb_sym, piker_sym_to_cb_sym, cb_sym_to_deribit_inst str_to_cb_sym, piker_sym_to_cb_sym, cb_sym_to_deribit_inst,
maybe_open_price_feed
) )
_spawn_kwargs = { _spawn_kwargs = {
@ -144,8 +145,7 @@ async def stream_quotes(
nsym = piker_sym_to_cb_sym(sym) nsym = piker_sym_to_cb_sym(sym)
async with client.feeds.open_price_feed( async with maybe_open_price_feed(sym) as stream:
symbols) as stream:
cache = await client.cache_symbols() cache = await client.cache_symbols()

View File

@ -123,6 +123,12 @@ class NoBsWs:
except self.recon_errors: except self.recon_errors:
await self._connect() await self._connect()
def __aiter__(self):
return self
async def __anext__(self):
return await self.recv_msg()
@asynccontextmanager @asynccontextmanager
async def open_autorecon_ws( async def open_autorecon_ws(