diff --git a/piker/brokers/kucoin.py b/piker/brokers/kucoin.py index 947d0a3b..65f47a40 100644 --- a/piker/brokers/kucoin.py +++ b/piker/brokers/kucoin.py @@ -19,6 +19,7 @@ Kucoin broker backend """ +from random import randint from typing import Any, Callable, Optional, Literal, AsyncGenerator from contextlib import asynccontextmanager as acm from datetime import datetime @@ -62,17 +63,6 @@ _ohlc_dtype = [ ] -def get_config() -> dict[str, dict]: - conf, path = config.load() - - section = conf.get("kucoin") - - if section is None: - log.warning("No config section found for kucoin in config") - - return section - - class KucoinMktPair(Struct, frozen=True): ''' Kucoin's pair format @@ -137,6 +127,12 @@ class KucoinTrade(Struct, frozen=True): time: float +class BrokerConfig(Struct, frozen=True): + key_id: str + key_secret: str + key_passphrase: str + + class KucoinTradeMsg(Struct, frozen=True): type: str topic: str @@ -144,6 +140,18 @@ class KucoinTradeMsg(Struct, frozen=True): data: list[KucoinTrade] +def get_config() -> BrokerConfig | None: + conf, path = config.load() + + section = conf.get("kucoin") + + if section is None: + log.warning("No config section found for kucoin in config") + return None + + return BrokerConfig(**section) + + class Client: def __init__(self) -> None: self._pairs: dict[str, KucoinMktPair] = {} @@ -153,25 +161,25 @@ class Client: self._key_passphrase: str self._authenticated: bool = False - config = get_config() + config: BrokerConfig | None = get_config() if ( config - and float("key_id" in config) - and ("key_secret" in config) - and ("key_passphrase" in config) + and float(config.key_id) + and config.key_secret + and config.key_passphrase ): self._authenticated = True - self._key_id = config["key_id"] - self._key_secret = config["key_secret"] - self._key_passphrase = config["key_passphrase"] + self._key_id = config.key_id + self._key_secret = config.key_secret + self._key_passphrase = config.key_passphrase def _gen_auth_req_headers( self, action: Literal["POST", "GET"], endpoint: str, api_v: str = "v2", - ) -> dict[str, str]: + ) -> dict[str, str | bytes]: ''' Generate authenticated request headers https://docs.kucoin.com/#authentication @@ -212,7 +220,7 @@ class Client: endpoint: str, api_v: str = "v2", headers: dict = {}, - ) -> dict[str, Any]: + ) -> Any: ''' Generic request wrapper for Kucoin API @@ -221,12 +229,14 @@ class Client: headers = self._gen_auth_req_headers(action, endpoint, api_v) api_url = f"https://api.kucoin.com/api/{api_v}{endpoint}" + res = await asks.request(action, api_url, headers=headers) if "data" in res.json(): return res.json()["data"] else: log.error(f'Error making request to {api_url} -> {res.json()["msg"]}') + return res.json()["msg"] async def _get_ws_token( self, @@ -237,14 +247,18 @@ class Client: ''' token_type = "private" if private else "public" - data = await self._request("POST", f"/bullet-{token_type}", "v1") + data: dict[str, Any] | None = await self._request( + "POST", + f"/bullet-{token_type}", + "v1" + ) - if "token" in data: - ping_interval = data["instanceServers"][0]["pingInterval"] + if data and "token" in data: + ping_interval: int = data["instanceServers"][0]["pingInterval"] return data["token"], ping_interval - else: + elif data: log.error( - f'Error making request for Kucoin ws token -> {res.json()["msg"]}' + f'Error making request for Kucoin ws token -> {data.json()["msg"]}' ) async def _get_pairs( @@ -297,8 +311,9 @@ class Client: # repack in dict form return {kucoin_sym_to_fqsn(item[0].name): item[0] for item in matches} - async def last_trades(self, sym: str) -> AccountResponse: + async def last_trades(self, sym: str) -> list[AccountTrade]: trades = await self._request("GET", f"/accounts/ledgers?currency={sym}", "v1") + trades = AccountResponse(**trades) return trades.items async def _get_bars( @@ -327,12 +342,19 @@ class Client: kucoin_sym = fqsn_to_kucoin_sym(fqsn, self._pairs) url = f"/market/candles?type={type}&symbol={kucoin_sym}&startAt={start_dt}&endAt={end_dt}" + bars = [] + for i in range(10): - bars = await self._request( - "GET", - url, - api_v="v1", - ) + res = await self._request( + "GET", + url, + api_v="v1", + ) + if not isinstance(res, list): + await trio.sleep(i + (randint(0, 1000) / 1000)) + else: + bars = res + break # Map to OHLC values to dict then to np array new_bars = []