Compare commits

...

17 Commits

Author SHA1 Message Date
Nelson Torres 3dcee16bf6 config refactor
only one get_config method for api class and cryptofeed feed handler
2025-01-29 15:44:33 -03:00
Nelson Torres 1f41b151d7 move constants to venue 2025-01-29 15:44:33 -03:00
Nelson Torres e8c196fd88 refactor redundant code 2025-01-29 15:44:33 -03:00
Nelson Torres 7cefb202fb name formatting fixes 2025-01-29 15:44:33 -03:00
Nelson Torres ddb4c0269f get_mkt_info cleanup 2025-01-29 15:44:33 -03:00
Nelson Torres 139f62f4de cache_symbols refactor 2025-01-29 15:44:33 -03:00
Nelson Torres 338e292002 json_rpc_auth_wrapper 2025-01-29 15:44:33 -03:00
Nelson Torres 13af0a90eb move object classes to venue 2025-01-29 15:44:33 -03:00
Nelson Torres e84781ca1e Added options symbols to get_assets 2025-01-29 15:44:33 -03:00
Nelson Torres 576b15e2c6 get_assets now uses public endpoint
It's better if the data is available through a public endpoint.
2025-01-29 15:44:33 -03:00
Nelson Torres b7e54571ea now using exch_info in search_symbols 2025-01-29 15:44:33 -03:00
Nelson Torres 1a295c0c21 Fix bs_fqme using venue and expiry 2025-01-29 15:44:33 -03:00
Nelson Torres f057a20bfa Added expiry property for OptionPair 2025-01-29 15:44:33 -03:00
Nelson Torres 9dba47902e No longer needed 2025-01-29 15:44:33 -03:00
Nelson Torres e0ecef04bb bs_mktid instead bs_fqme for deribits options 2025-01-29 15:44:33 -03:00
Nelson Torres 266347bcdb Fixed pair instrument name in search_symbols endpoint.
Fixed instrument in bars endpoint, for options in deribits bs_mktid instead bs_fqme.
Fixed the id is in msg.
2025-01-29 15:44:33 -03:00
Tyler Goodlet 051d43b559 data._web_bs: try to raise jsonrpc errors in parent task 2025-01-29 15:44:33 -03:00
4 changed files with 237 additions and 234 deletions

View File

@ -58,13 +58,20 @@ from cryptofeed.symbols import Symbol
# types for managing the cb callbacks. # types for managing the cb callbacks.
# from cryptofeed.types import L1Book # from cryptofeed.types import L1Book
from .venues import ( from .venues import (
_ws_url,
MarketType, MarketType,
PAIRTYPES, PAIRTYPES,
Pair, Pair,
OptionPair, OptionPair,
JSONRPCResult,
JSONRPCChannel,
KLinesResult,
Trade,
LastTradesResult,
) )
from piker.accounting import ( from piker.accounting import (
Asset, Asset,
digits_to_dec,
MktPair, MktPair,
) )
from piker.data import ( from piker.data import (
@ -89,60 +96,6 @@ _spawn_kwargs = {
} }
_url = 'https://www.deribit.com'
_ws_url = 'wss://www.deribit.com/ws/api/v2'
_testnet_ws_url = 'wss://test.deribit.com/ws/api/v2'
class JSONRPCResult(Struct):
id: int
usIn: int
usOut: int
usDiff: int
testnet: bool
jsonrpc: str = '2.0'
result: Optional[list[dict]] = None
error: Optional[dict] = None
class JSONRPCChannel(Struct):
method: str
params: dict
jsonrpc: str = '2.0'
class KLinesResult(Struct):
close: list[float]
cost: list[float]
high: list[float]
low: list[float]
open: list[float]
status: str
ticks: list[int]
volume: list[float]
class Trade(Struct):
trade_seq: int
trade_id: str
timestamp: int
tick_direction: int
price: float
mark_price: float
iv: float
instrument_name: str
index_price: float
direction: str
contracts: float
amount: float
combo_trade_id: Optional[int] = 0,
combo_id: Optional[str] = '',
block_trade_leg_count: Optional[int] = 0,
block_trade_id: Optional[str] = '',
class LastTradesResult(Struct):
trades: list[Trade]
has_more: bool
# convert datetime obj timestamp to unixtime in milliseconds # convert datetime obj timestamp to unixtime in milliseconds
def deribit_timestamp(when): def deribit_timestamp(when):
return int((when.timestamp() * 1000) + (when.microsecond / 1000)) return int((when.timestamp() * 1000) + (when.microsecond / 1000))
@ -233,34 +186,22 @@ def get_config() -> dict[str, Any]:
) )
section: dict = {} section: dict = {}
section = conf.get('deribit') section = conf.get('deribit')
section['log'] = {}
section['log']['filename'] = 'feedhandler.log'
section['log']['level'] = 'DEBUG'
section['log']['disabled'] = True
if section is None: if section is None:
log.warning(f'No config section found for deribit in {path}') log.warning(f'No config section found for deribit in {path}')
return {} return {}
conf_option = section.get('option', {})
section.clear # clear the dict to reuse it
section['deribit'] = {}
section['deribit']['key_id'] = conf_option.get('api_key')
section['deribit']['key_secret'] = conf_option.get('api_secret')
section['log'] = {}
section['log']['filename'] = 'feedhandler.log'
section['log']['level'] = 'DEBUG'
return section return section
def get_fh_config() -> dict[str, Any]:
conf_option = get_config().get('option', {})
conf_log = get_config().get('log', {})
return {
'log': {
'filename': conf_log.get('filename'),
'level': conf_log.get('level'),
'disabled': conf_log.get('disabled')
},
'deribit': {
'key_id': conf_option.get('api_key'),
'key_secret': conf_option.get('api_secret')
}
}
class Client: class Client:
@ -272,16 +213,45 @@ class Client:
) -> None: ) -> None:
self._pairs: ChainMap[str, Pair] = ChainMap() self._pairs: ChainMap[str, Pair] = ChainMap()
config = get_config().get('option', {}) config = get_config().get('deribit', {})
self._key_id = config.get('api_key') self._key_id = config.get('key_id')
self._key_secret = config.get('api_secret') self._key_secret = config.get('key_secret')
self.json_rpc = json_rpc self.json_rpc = json_rpc
@property self._auth_ts = None
def currencies(self): self._auth_renew_ts = 5 # seconds to renew auth
return ['btc', 'eth', 'sol', 'usd']
async def _json_rpc_auth_wrapper(self, *args, **kwargs) -> JSONRPCResult:
"""Background task that adquires a first access token and then will
refresh the access token.
https://docs.deribit.com/?python#authentication-2
"""
access_scope = 'trade:read_write'
current_ts = time.time()
if not self._auth_ts or current_ts - self._auth_ts < self._auth_renew_ts:
# if we are close to token expiry time
params = {
'grant_type': 'client_credentials',
'client_id': self._key_id,
'client_secret': self._key_secret,
'scope': access_scope
}
resp = await self.json_rpc('public/auth', params)
result = resp.result
self._auth_ts = time.time() + result['expires_in']
return await self.json_rpc(*args, **kwargs)
async def get_balances( async def get_balances(
self, self,
@ -293,7 +263,7 @@ class Client:
balances = {} balances = {}
for currency in self.currencies: for currency in self.currencies:
resp = await self.json_rpc( resp = await self._json_rpc_auth_wrapper(
'private/get_positions', params={ 'private/get_positions', params={
'currency': currency.upper(), 'currency': currency.upper(),
'kind': kind}) 'kind': kind})
@ -311,21 +281,28 @@ class Client:
by symbol. by symbol.
""" """
assets = {} assets = {}
resp = await self.json_rpc( resp = await self._json_rpc_auth_wrapper(
'private/get_account_summaries', 'public/get_currencies',
params={ params={}
'extended' : True
}
) )
summaries = resp.result['summaries'] currencies = resp.result
for summary in summaries: for currency in currencies:
currency = summary['currency'] name = currency['currency']
tx_tick = Decimal('1e-08') tx_tick = digits_to_dec(currency['fee_precision'])
atype='crypto_currency' atype='crypto_currency'
assets[currency] = Asset( assets[name] = Asset(
name=currency, name=name,
atype=atype, atype=atype,
tx_tick=tx_tick) tx_tick=tx_tick)
instruments = await self.symbol_info(currency=name)
for instrument in instruments:
pair = instruments[instrument]
assets[pair.symbol] = Asset(
name=pair.symbol,
atype=pair.venue,
tx_tick=pair.size_tick)
return assets return assets
async def get_mkt_pairs(self) -> dict[str, Pair]: async def get_mkt_pairs(self) -> dict[str, Pair]:
@ -351,7 +328,7 @@ class Client:
'type': 'limit', 'type': 'limit',
'price': price, 'price': price,
} }
resp = await self.json_rpc( resp = await self._json_rpc_auth_wrapper(
f'private/{action}', params) f'private/{action}', params)
return resp.result return resp.result
@ -359,7 +336,7 @@ class Client:
async def submit_cancel(self, oid: str): async def submit_cancel(self, oid: str):
"""Send cancel request for order id """Send cancel request for order id
""" """
resp = await self.json_rpc( resp = await self._json_rpc_auth_wrapper(
'private/cancel', {'order_id': oid}) 'private/cancel', {'order_id': oid})
return resp.result return resp.result
@ -367,7 +344,7 @@ class Client:
self, self,
sym: str | None = None, sym: str | None = None,
venue: MarketType | None = None, venue: MarketType = 'option',
expiry: str | None = None, expiry: str | None = None,
) -> dict[str, Pair] | Pair: ) -> dict[str, Pair] | Pair:
@ -381,7 +358,7 @@ class Client:
return cached_pair return cached_pair
if sym: if sym:
return pair_table[sym.lower()] return pair_table[sym]
else: else:
return self._pairs return self._pairs
@ -407,7 +384,7 @@ class Client:
'expired': str(expired).lower() 'expired': str(expired).lower()
} }
resp: JSONRPCResult = await self.json_rpc( resp: JSONRPCResult = await self._json_rpc_auth_wrapper(
'public/get_instruments', 'public/get_instruments',
params, params,
) )
@ -440,10 +417,32 @@ class Client:
async def cache_symbols( async def cache_symbols(
self, self,
) -> dict: venue: MarketType = 'option',
if not self._pairs: ) -> None:
self._pairs = await self.symbol_info() # lookup internal mkt-specific pair table to update
pair_table: dict[str, Pair] = self._pairs
# make API request(s)
mkt_pairs = await self.symbol_info()
if not mkt_pairs:
raise SymbolNotFound(f'No market pairs found!?:\n{resp}')
pairs_view_subtable: dict[str, Pair] = {}
for instrument in mkt_pairs:
pair_type: Type = PAIRTYPES[venue]
pair: Pair = pair_type(**mkt_pairs[instrument].to_dict())
pair_table[pair.symbol.upper()] = pair
# update an additional top-level-cross-venue-table
# `._pairs: ChainMap` for search B0
pairs_view_subtable[pair.bs_fqme] = pair
self._pairs.maps.append(pairs_view_subtable)
return self._pairs return self._pairs
@ -456,20 +455,15 @@ class Client:
Fuzzy search symbology set for pairs matching `pattern`. Fuzzy search symbology set for pairs matching `pattern`.
''' '''
pairs: dict[str, Pair] = await self.symbol_info() pairs: dict[str, Pair] = await self.exch_info()
matches: dict[str, Pair] = match_from_pairs(
return match_from_pairs(
pairs=pairs, pairs=pairs,
query=pattern.upper(), query=pattern.upper(),
score_cutoff=35, score_cutoff=35,
limit=limit limit=limit
) )
# repack in name-keyed table
return {
pair['instrument_name'].lower(): pair
for pair in matches.values()
}
async def bars( async def bars(
self, self,
mkt: MktPair, mkt: MktPair,
@ -481,7 +475,7 @@ class Client:
as_np: bool = True, as_np: bool = True,
) -> list[tuple] | np.ndarray: ) -> list[tuple] | np.ndarray:
instrument: str = mkt.bs_fqme instrument: str = mkt.bs_fqme.split('.')[0]
if end_dt is None: if end_dt is None:
end_dt = now('UTC') end_dt = now('UTC')
@ -494,7 +488,7 @@ class Client:
end_time = deribit_timestamp(end_dt) end_time = deribit_timestamp(end_dt)
# https://docs.deribit.com/#public-get_tradingview_chart_data # https://docs.deribit.com/#public-get_tradingview_chart_data
resp = await self.json_rpc( resp = await self._json_rpc_auth_wrapper(
'public/get_tradingview_chart_data', 'public/get_tradingview_chart_data',
params={ params={
'instrument_name': instrument.upper(), 'instrument_name': instrument.upper(),
@ -531,7 +525,7 @@ class Client:
instrument: str, instrument: str,
count: int = 10 count: int = 10
): ):
resp = await self.json_rpc( resp = await self._json_rpc_auth_wrapper(
'public/get_last_trades_by_instrument', 'public/get_last_trades_by_instrument',
params={ params={
'instrument_name': instrument, 'instrument_name': instrument,
@ -543,7 +537,8 @@ class Client:
@acm @acm
async def get_client( async def get_client(
is_brokercheck: bool = False is_brokercheck: bool = False,
venue: MarketType = 'option',
) -> Client: ) -> Client:
async with ( async with (
@ -553,68 +548,6 @@ async def get_client(
) as json_rpc ) as json_rpc
): ):
client = Client(json_rpc) client = Client(json_rpc)
_refresh_token: Optional[str] = None
_access_token: Optional[str] = None
async def _auth_loop(
task_status: TaskStatus = trio.TASK_STATUS_IGNORED
):
"""Background task that adquires a first access token and then will
refresh the access token while the nursery isn't cancelled.
https://docs.deribit.com/?python#authentication-2
"""
renew_time = 10
access_scope = 'trade:read_write'
_expiry_time = time.time()
got_access = False
nonlocal _refresh_token
nonlocal _access_token
while True:
if time.time() - _expiry_time < renew_time:
# if we are close to token expiry time
if _refresh_token != None:
# if we have a refresh token already dont need to send
# secret
params = {
'grant_type': 'refresh_token',
'refresh_token': _refresh_token,
'scope': access_scope
}
else:
# we don't have refresh token, send secret to initialize
params = {
'grant_type': 'client_credentials',
'client_id': client._key_id,
'client_secret': client._key_secret,
'scope': access_scope
}
resp = await json_rpc('public/auth', params)
result = resp.result
_expiry_time = time.time() + result['expires_in']
_refresh_token = result['refresh_token']
if 'access_token' in result:
_access_token = result['access_token']
if not got_access:
# first time this loop runs we must indicate task is
# started, we have auth
got_access = True
task_status.started()
else:
await trio.sleep(renew_time / 2)
# if we have client creds launch auth loop
if client._key_id is not None:
await n.start(_auth_loop)
await client.cache_symbols() await client.cache_symbols()
yield client yield client
n.cancel_scope.cancel() n.cancel_scope.cancel()
@ -622,7 +555,7 @@ async def get_client(
@acm @acm
async def open_feed_handler(): async def open_feed_handler():
fh = FeedHandler(config=get_fh_config()) fh = FeedHandler(config=get_config())
yield fh yield fh
await to_asyncio.run_task(fh.stop_async) await to_asyncio.run_task(fh.stop_async)

View File

@ -160,38 +160,26 @@ async def get_mkt_info(
assets: dict[str, Asset] = await client.get_assets() assets: dict[str, Asset] = await client.get_assets()
pair_str: str = mkt_ep.lower() pair_str: str = mkt_ep.lower()
# switch venue-mode depending on input pattern parsing
# since we want to use a particular endpoint (set) for
# pair info lookup!
client.mkt_mode = mkt_mode
pair: Pair = await client.exch_info( pair: Pair = await client.exch_info(
sym=pair_str, sym=pair_str,
) )
dst: Asset | None = assets.get(pair.bs_dst_asset) mkt_mode = pair.venue
if ( client.mkt_mode = mkt_mode
not dst
# TODO: a known asset DNE list?
# and pair.baseAsset == 'DEFI'
):
log.warning(
f'UNKNOWN {venue} asset {pair.base_currency} from,\n'
f'{pformat(pair.to_dict())}'
)
# XXX UNKNOWN missing "asset", though no idea why? dst: Asset | None = assets.get(pair.bs_dst_asset)
# maybe it's only avail in the margin venue(s): /dapi/ ? src: Asset | None = assets.get(pair.bs_src_asset)
return None
mkt = MktPair( mkt = MktPair(
dst=dst, dst=dst,
src=assets.get(pair.bs_src_asset), src=src,
price_tick=pair.price_tick, price_tick=pair.price_tick,
size_tick=pair.size_tick, size_tick=pair.size_tick,
bs_mktid=pair.symbol, bs_mktid=pair.symbol,
expiry=expiry, expiry=pair.expiry,
venue=venue, venue=mkt_mode,
broker='deribit', broker='deribit',
_atype=mkt_mode,
_fqme_without_src=True,
) )
return mkt, pair return mkt, pair
@ -210,7 +198,7 @@ async def stream_quotes(
# XXX: required to propagate ``tractor`` loglevel to piker logging # XXX: required to propagate ``tractor`` loglevel to piker logging
get_console_log(loglevel or tractor.current_actor().loglevel) get_console_log(loglevel or tractor.current_actor().loglevel)
sym = symbols[0] sym = symbols[0].split('.')[0]
init_msgs: list[FeedInit] = [] init_msgs: list[FeedInit] = []
@ -225,11 +213,11 @@ async def stream_quotes(
init_msgs.append( init_msgs.append(
FeedInit(mkt_info=mkt) FeedInit(mkt_info=mkt)
) )
nsym = piker_sym_to_cb_sym(sym.split('.')[0]) nsym = piker_sym_to_cb_sym(sym)
async with maybe_open_price_feed(sym) as stream: async with maybe_open_price_feed(sym) as stream:
cache = await client.cache_symbols() cache = client._pairs
last_trades = (await client.last_trades( last_trades = (await client.last_trades(
cb_sym_to_deribit_inst(nsym), count=1)).trades cb_sym_to_deribit_inst(nsym), count=1)).trades
@ -271,7 +259,7 @@ async def open_symbol_search(
async with open_cached_client('deribit') as client: async with open_cached_client('deribit') as client:
# load all symbols locally for fast search # load all symbols locally for fast search
cache = await client.cache_symbols() cache = client._pairs
await ctx.started() await ctx.started()
async with ctx.open_stream() as stream: async with ctx.open_stream() as stream:

View File

@ -19,6 +19,7 @@ Per market data-type definitions and schemas types.
""" """
from __future__ import annotations from __future__ import annotations
import pendulum
from typing import ( from typing import (
Literal, Literal,
) )
@ -66,29 +67,27 @@ class Pair(Struct, frozen=True, kw_only=True):
# dst # dst
base_currency: str # "BTC", base_currency: str # "BTC",
tick_size: float # 0.0001 tick_size: float # 0.0001 # [{'above_price': 0.005, 'tick_size': 0.0005}]
tick_size_steps: list[dict[str, str | int | float]] # [{'above_price': 0.005, 'tick_size': 0.0005}] tick_size_steps: list[dict[str, float]]
@property @property
def price_tick(self) -> Decimal: def price_tick(self) -> Decimal:
step_size: float = self.tick_size_steps[0].get('above_price') return Decimal(str(self.tick_size_steps[0]['above_price']))
return Decimal(step_size)
@property @property
def size_tick(self) -> Decimal: def size_tick(self) -> Decimal:
step_size: float = self.tick_size_steps[0].get('tick_size') return Decimal(str(self.tick_size))
return Decimal(step_size)
@property @property
def bs_fqme(self) -> str: def bs_fqme(self) -> str:
return self.symbol return f'{self.symbol}'
@property @property
def bs_mktid(self) -> str: def bs_mktid(self) -> str:
return f'{self.symbol}.{self.venue}' return f'{self.symbol}.{self.venue}'
class OptionPair(Pair, frozen=True, kw_only=True): class OptionPair(Pair, frozen=True):
taker_commission: float # 0.0003 taker_commission: float # 0.0003
strike: float # 5000.0 strike: float # 5000.0
@ -116,13 +115,18 @@ class OptionPair(Pair, frozen=True, kw_only=True):
# NOTE: see `.data._symcache.SymbologyCache.load()` for why # NOTE: see `.data._symcache.SymbologyCache.load()` for why
ns_path: str = 'piker.brokers.deribit:OptionPair' ns_path: str = 'piker.brokers.deribit:OptionPair'
@property
def expiry(self) -> str:
iso_date = pendulum.from_timestamp(self.expiration_timestamp / 1000).isoformat()
return iso_date
@property @property
def venue(self) -> str: def venue(self) -> str:
return 'OPTION' return 'option'
@property @property
def bs_fqme(self) -> str: def bs_fqme(self) -> str:
return f'{self.symbol}.OPTION' return f'{self.symbol}'
@property @property
def bs_src_asset(self) -> str: def bs_src_asset(self) -> str:
@ -130,13 +134,58 @@ class OptionPair(Pair, frozen=True, kw_only=True):
@property @property
def bs_dst_asset(self) -> str: def bs_dst_asset(self) -> str:
return f'{self.base_currency}' return f'{self.symbol}'
@property
def bs_mktid(self) -> str:
return f'{self.symbol}.{self.venue}'
PAIRTYPES: dict[MarketType, Pair] = { PAIRTYPES: dict[MarketType, Pair] = {
'option': OptionPair, 'option': OptionPair,
} }
class JSONRPCResult(Struct):
id: int
usIn: int
usOut: int
usDiff: int
testnet: bool
jsonrpc: str = '2.0'
error: Optional[dict] = None
result: Optional[list[dict]] = None
class JSONRPCChannel(Struct):
method: str
params: dict
jsonrpc: str = '2.0'
class KLinesResult(Struct):
low: list[float]
cost: list[float]
high: list[float]
open: list[float]
close: list[float]
ticks: list[int]
status: str
volume: list[float]
class Trade(Struct):
iv: float
price: float
amount: float
trade_id: str
contracts: float
direction: str
trade_seq: int
timestamp: int
mark_price: float
index_price: float
tick_direction: int
instrument_name: str
combo_id: Optional[str] = '',
combo_trade_id: Optional[int] = 0,
block_trade_id: Optional[str] = '',
block_trade_leg_count: Optional[int] = 0,
class LastTradesResult(Struct):
trades: list[Trade]
has_more: bool

View File

@ -273,7 +273,7 @@ async def _reconnect_forever(
nobsws._connected.set() nobsws._connected.set()
await trio.sleep_forever() await trio.sleep_forever()
except HandshakeError: except HandshakeError:
log.exception(f'Retrying connection') log.exception('Retrying connection')
# ws & nursery block ends # ws & nursery block ends
@ -359,8 +359,8 @@ async def open_autorecon_ws(
''' '''
JSONRPC response-request style machinery for transparent multiplexing of msgs JSONRPC response-request style machinery for transparent multiplexing
over a NoBsWs. of msgs over a NoBsWs.
''' '''
@ -377,16 +377,20 @@ async def open_jsonrpc_session(
url: str, url: str,
start_id: int = 0, start_id: int = 0,
response_type: type = JSONRPCResult, response_type: type = JSONRPCResult,
request_type: Optional[type] = None, # request_type: Optional[type] = None,
request_hook: Optional[Callable] = None, # request_hook: Optional[Callable] = None,
error_hook: Optional[Callable] = None, # error_hook: Optional[Callable] = None,
) -> Callable[[str, dict], dict]: ) -> Callable[[str, dict], dict]:
# NOTE, store all request msgs so we can raise errors on the
# caller side!
req_msgs: dict[int, dict] = {}
async with ( async with (
trio.open_nursery() as n, trio.open_nursery() as n,
open_autorecon_ws(url) as ws open_autorecon_ws(url) as ws
): ):
rpc_id: Iterable = count(start_id) rpc_id: Iterable[int] = count(start_id)
rpc_results: dict[int, dict] = {} rpc_results: dict[int, dict] = {}
async def json_rpc(method: str, params: dict) -> dict: async def json_rpc(method: str, params: dict) -> dict:
@ -394,26 +398,40 @@ async def open_jsonrpc_session(
perform a json rpc call and wait for the result, raise exception in perform a json rpc call and wait for the result, raise exception in
case of error field present on response case of error field present on response
''' '''
nonlocal req_msgs
req_id: int = next(rpc_id)
msg = { msg = {
'jsonrpc': '2.0', 'jsonrpc': '2.0',
'id': next(rpc_id), 'id': req_id,
'method': method, 'method': method,
'params': params 'params': params
} }
_id = msg['id'] _id = msg['id']
rpc_results[_id] = { result = rpc_results[_id] = {
'result': None, 'result': None,
'event': trio.Event() 'error': None,
'event': trio.Event(), # signal caller resp arrived
} }
req_msgs[_id] = msg
await ws.send_msg(msg) await ws.send_msg(msg)
# wait for reponse before unblocking requester code
await rpc_results[_id]['event'].wait() await rpc_results[_id]['event'].wait()
ret = rpc_results[_id]['result'] if (maybe_result := result['result']):
ret = maybe_result
del rpc_results[_id]
del rpc_results[_id] else:
err = result['error']
raise Exception(
f'JSONRPC request failed\n'
f'req: {msg}\n'
f'resp: {err}\n'
)
if ret.error is not None: if ret.error is not None:
raise Exception(json.dumps(ret.error, indent=4)) raise Exception(json.dumps(ret.error, indent=4))
@ -428,6 +446,7 @@ async def open_jsonrpc_session(
the server side. the server side.
''' '''
nonlocal req_msgs
async for msg in ws: async for msg in ws:
match msg: match msg:
case { case {
@ -451,15 +470,29 @@ async def open_jsonrpc_session(
'params': _, 'params': _,
}: }:
log.debug(f'Recieved\n{msg}') log.debug(f'Recieved\n{msg}')
if request_hook: # if request_hook:
await request_hook(request_type(**msg)) # await request_hook(request_type(**msg))
case { case {
'error': error 'error': error
}: }:
log.warning(f'Recieved\n{error}') # if error_hook:
if error_hook: # await error_hook(response_type(**msg))
await error_hook(response_type(**msg))
# retreive orig request msg, set error
# response in original "result" msg,
# THEN FINALLY set the event to signal caller
# to raise the error in the parent task.
req_id: int = msg['id']
req_msg: dict = req_msgs[req_id]
result: dict = rpc_results[req_id]
result['error'] = error
result['event'].set()
log.error(
f'JSONRPC request failed\n'
f'req: {req_msg}\n'
f'resp: {error}\n'
)
case _: case _:
log.warning(f'Unhandled JSON-RPC msg!?\n{msg}') log.warning(f'Unhandled JSON-RPC msg!?\n{msg}')