From 395f0c8e4aabef0e7df65f71629af593f39d101b Mon Sep 17 00:00:00 2001 From: Tyler Goodlet Date: Sat, 9 Feb 2019 21:39:22 -0500 Subject: [PATCH] Synchronize Questrade token refreshing per client Questrade's API is half baked and can't handle concurrency. It allows multiple concurrent requests to most endpoints *except* for the auth endpoint used to refresh tokens: https://www.questrade.com/api/documentation/security I've gone through extensive dialogue with their API team and despite making what I think are very good arguments for doing the request serialization on the server side, they decided that I should instead do the "locking" on the client side. Frankly it doesn't seem like they have that competent an engineering department as it took me a long time to explain the issue even though it's rather trivial and probably not that hard to fix; maybe it's better this way. This adds a few things to ensure more reliable token refreshes on expiry: - add a `@refresh_token_on_err` decorator which can be used on `_API` methods that should refresh tokens on failure - decorate most endpoints with this *except* for the auth ep - add locking logic for the troublesome scenario as follows: * every time a request is sent out set a "request in progress" event variable that can be used to determine when no requests are currently outstanding * every time the auth end point is hit in order to refresh tokens set an event that locks out other tasks from making requests * only allow hitting the auth endpoint when there are no "requests in progress" using the first event * mutex all auth endpoint requests; there can only be one outstanding - don't hit the accounts endpoint at client startup; we want to eventually support keys from multiple accounts and you can disable account info per key and just share the market data function --- piker/brokers/questrade.py | 362 +++++++++++++++++++++---------------- 1 file changed, 207 insertions(+), 155 deletions(-) diff --git a/piker/brokers/questrade.py b/piker/brokers/questrade.py index c3fb7586..2ac9d1ed 100644 --- a/piker/brokers/questrade.py +++ b/piker/brokers/questrade.py @@ -1,6 +1,8 @@ """ Questrade API backend. """ +from __future__ import annotations +import inspect import time from datetime import datetime from functools import partial @@ -9,6 +11,7 @@ from typing import List, Tuple, Dict, Any, Iterator, NamedTuple import trio from async_generator import asynccontextmanager +import wrapt from ..calc import humanize, percent_change from . import config @@ -40,56 +43,121 @@ class ContractsKey(NamedTuple): expiry: datetime +def refresh_token_on_err(tries=3): + """`_API` method decorator which locks the client and refreshes tokens + before unlocking access to the API again. + + QT's service end can't handle concurrent requests to multiple + endpoints reliably without choking up and confusing their interal + servers. + """ + + @wrapt.decorator + async def wrapper(wrapped, api, args, kwargs): + assert inspect.iscoroutinefunction(wrapped) + client = api.client + + if not client._has_access.is_set(): + log.warning("WAITING ON ACCESS LOCK") + await client._has_access.wait() + + for i in range(1, tries): + try: + try: + client._request_not_in_progress.clear() + return await wrapped(*args, **kwargs) + finally: + client._request_not_in_progress.set() + except (QuestradeError, BrokerError) as qterr: + if "Access token is invalid" not in str(qterr.args[0]): + raise + # TODO: this will crash when run from a sub-actor since + # STDIN can't be acquired. The right way to handle this + # is to make a request to the parent actor (i.e. + # spawner of this) to call this + # `client.ensure_access()` locally thus blocking until + # the user provides an API key on the "client side" + log.warning(f"Tokens are invalid refreshing try {i}..") + await client.ensure_access(force_refresh=True) + if i == tries - 1: + raise + return wrapper + + class _API: """Questrade API endpoints exposed as methods and wrapped with an http session. """ - def __init__(self, session: asks.Session): - self._sess = session + def __init__( + self, + client: Client, + ): + self.client = client + self._sess: asks.Session = client._sess - async def _request(self, path: str, params=None) -> dict: + @refresh_token_on_err() + async def _get(self, path: str, params=None) -> dict: + """Get an endpoint "reliably" by ensuring access on failure. + """ resp = await self._sess.get(path=f'/{path}', params=params) return resproc(resp, log) + async def _new_auth_token(self, refresh_token: str) -> dict: + """Request a new api authorization ``refresh_token``. + + Gain api access using either a user provided or existing token. + See the instructions:: + + http://www.questrade.com/api/documentation/getting-started + http://www.questrade.com/api/documentation/security + """ + resp = await self._sess.get( + _refresh_token_ep + 'token', + params={'grant_type': 'refresh_token', + 'refresh_token': refresh_token} + ) + return resproc(resp, log) + async def accounts(self) -> dict: - return await self._request('accounts') + return await self._get('accounts') async def time(self) -> dict: - return await self._request('time') + return await self._get('time') async def markets(self) -> dict: - return await self._request('markets') + return await self._get('markets') async def search(self, prefix: str) -> dict: - return await self._request( + return await self._get( 'symbols/search', params={'prefix': prefix}) async def symbols(self, ids: str = '', names: str = '') -> dict: log.debug(f"Symbol lookup for {ids or names}") - return await self._request( + return await self._get( 'symbols', params={'ids': ids, 'names': names}) async def quotes(self, ids: str) -> dict: - quotes = (await self._request( + quotes = (await self._get( 'markets/quotes', params={'ids': ids}))['quotes'] for quote in quotes: quote['key'] = quote['symbol'] return quotes async def candles(self, id: str, start: str, end, interval) -> dict: - return await self._request(f'markets/candles/{id}', params={}) + return await self._get(f'markets/candles/{id}', params={}) async def balances(self, id: str) -> dict: - return await self._request(f'accounts/{id}/balances') + return await self._get(f'accounts/{id}/balances') async def postions(self, id: str) -> dict: - return await self._request(f'accounts/{id}/positions') + return await self._get(f'accounts/{id}/positions') async def option_contracts(self, symbol_id: str) -> dict: "Retrieve all option contract API ids with expiry -> strike prices." - contracts = await self._request(f'symbols/{symbol_id}/options') + contracts = await self._get(f'symbols/{symbol_id}/options') return contracts['optionChain'] + @refresh_token_on_err() async def option_quotes( self, contracts: Dict[ContractsKey, Dict[int, dict]] = {}, @@ -107,7 +175,8 @@ class _API: ] resp = await self._sess.post( path=f'/markets/quotes/options', - # XXX: b'{"code":1024,"message":"The size of the array requested is not valid: optionIds"}' + # XXX: b'{"code":1024,"message":"The size of the array requested + # is not valid: optionIds"}' # ^ what I get when trying to use too many ids manually... json={'filters': filters, 'optionIds': option_ids} ) @@ -122,48 +191,24 @@ class Client: """ def __init__(self, config: configparser.ConfigParser): self._sess = asks.Session() - self.api = _API(self._sess) + self.api = _API(self) self._conf = config self.access_data = {} self._reload_config(config) self._symbol_cache: Dict[str, int] = {} self._optids2contractinfo = {} self._contract2ids = {} + # for blocking during token refresh + self._has_access = trio.Event() + self._has_access.set() + self._request_not_in_progress = trio.Event() + self._request_not_in_progress.set() + self._mutex = trio.StrictFIFOLock() def _reload_config(self, config=None, **kwargs): - log.warn("Reloading access config data") self._conf = config or get_config(**kwargs) self.access_data = dict(self._conf['questrade']) - async def _new_auth_token(self) -> dict: - """Request a new api authorization ``refresh_token``. - - Gain api access using either a user provided or existing token. - See the instructions:: - - http://www.questrade.com/api/documentation/getting-started - http://www.questrade.com/api/documentation/security - """ - resp = await self._sess.get( - _refresh_token_ep + 'token', - params={'grant_type': 'refresh_token', - 'refresh_token': self.access_data['refresh_token']} - ) - data = resproc(resp, log) - self.access_data.update(data) - - return data - - def _prep_sess(self) -> None: - """Fill http session with auth headers and a base url. - """ - data = self.access_data - # set access token header for the session - self._sess.headers.update({ - 'Authorization': (f"{data['token_type']} {data['access_token']}")}) - # set base API url (asks shorthand) - self._sess.base_location = self.access_data['api_server'] + _version - async def _revoke_auth_token(self) -> None: """Revoke api access for the current token. """ @@ -175,8 +220,14 @@ class Client: ) return resp + def write_config(self): + """Save access creds to config file. + """ + self._conf['questrade'] = self.access_data + config.write(self._conf) + async def ensure_access(self, force_refresh: bool = False) -> dict: - """Acquire new ``access_token`` and/or ``refresh_token`` if necessary. + """Acquire a new token set (``access_token`` and ``refresh_token``). Checks if the locally cached (file system) ``access_token`` has expired (based on a ``expires_at`` time stamp stored in the brokers.ini config) @@ -185,47 +236,90 @@ class Client: ``refresh_token`` has expired a new one needs to be provided by the user. """ - access_token = self.access_data.get('access_token') - expires = float(self.access_data.get('expires_at', 0)) - expires_stamp = datetime.fromtimestamp( - expires).strftime('%Y-%m-%d %H:%M:%S') - if not access_token or (expires < time.time()) or force_refresh: - log.debug( - f"Refreshing access token {access_token} which expired at" - f" {expires_stamp}") - try: - data = await self._new_auth_token() - except BrokerError as qterr: - if "We're making some changes" in str(qterr.args[0]): - # API service is down - raise QuestradeError("API is down for maintenance") - elif qterr.args[0].decode() == 'Bad Request': - # likely config ``refresh_token`` is expired but may - # be updated in the config file via another piker process - self._reload_config() + # wait for ongoing requests to clear (API can't handle + # concurrent endpoint requests alongside a token refresh) + await self._request_not_in_progress.wait() + + # block api access to tall other tasks + # XXX: this is limitation of the API when using a single + # token whereby their service can't handle concurrent requests + # to differnet end points (particularly the auth ep) which + # causes hangs and premature token invalidation issues. + self._has_access.clear() + try: + # don't allow simultaneous token refresh requests + async with self._mutex: + access_token = self.access_data.get('access_token') + expires = float(self.access_data.get('expires_at', 0)) + expires_stamp = datetime.fromtimestamp( + expires).strftime('%Y-%m-%d %H:%M:%S') + if not access_token or ( + expires < time.time() + ) or force_refresh: + log.info("REFRESHING TOKENS!") + log.debug( + f"Refreshing access token {access_token} which expired" + f" at {expires_stamp}") try: - data = await self._new_auth_token() + data = await self.api._new_auth_token( + self.access_data['refresh_token']) except BrokerError as qterr: - if qterr.args[0].decode() == 'Bad Request': - # actually expired; get new from user - self._reload_config(force_from_user=True) - data = await self._new_auth_token() + + def get_err_msg(err): + # handle str and bytes... + msg = err.args[0] + return msg.decode() if msg.isascii() else msg + + msg = get_err_msg(qterr) + + if "We're making some changes" in msg: + # API service is down + raise QuestradeError("API is down for maintenance") + + elif msg == 'Bad Request': + # likely config ``refresh_token`` is expired but + # may be updated in the config file via another + # piker process + self._reload_config() + try: + data = await self.api._new_auth_token( + self.access_data['refresh_token']) + except BrokerError as qterr: + if get_err_msg(qterr) == 'Bad Request': + # actually expired; get new from user + self._reload_config(force_from_user=True) + data = await self.api._new_auth_token( + self.access_data['refresh_token']) + else: + raise QuestradeError(qterr) else: - raise QuestradeError(qterr) + raise qterr + + self.access_data.update(data) + log.debug(f"Updated tokens:\n{data}") + # store an absolute access token expiry time + self.access_data['expires_at'] = time.time() + float( + data['expires_in']) + + # write to config to disk + self.write_config() else: - raise qterr + log.info( + f"\nCurrent access token {access_token} expires at" + f" {expires_stamp}\n") - # store absolute token expiry time - self.access_data['expires_at'] = time.time() + float( - data['expires_in']) - # write to config on disk - write_conf(self) - else: - log.debug(f"\nCurrent access token {access_token} expires at" - f" {expires_stamp}\n") + # set access token header for the session + data = self.access_data + self._sess.headers.update({ + 'Authorization': + (f"{data['token_type']} {data['access_token']}")} + ) + # set base API url (asks shorthand) + self._sess.base_location = data['api_server'] + _version + finally: + self._has_access.set() - self._prep_sess() - return self.access_data + return data async def tickers2ids(self, tickers): """Helper routine that take a sequence of ticker symbols and returns @@ -407,54 +501,53 @@ def _token_from_user(conf: 'configparser.ConfigParser') -> None: conf['questrade'] = {'refresh_token': refresh_token} -def get_config(force_from_user=False) -> "configparser.ConfigParser": - conf, path = config.load() - if not conf.has_section('questrade') or ( - not conf['questrade'].get('refresh_token') or ( - force_from_user) - ): +def get_config( + force_from_user: bool = False, + config_path: str = None, +) -> "configparser.ConfigParser": + """Load the broker config from disk. + + By default this is the file: + + ~/.config/piker/brokers.ini + + though may be different depending on your OS. + """ + log.debug("Reloading access config data") + conf, path = config.load(config_path) + if not conf.has_section('questrade'): log.warn( f"No valid refresh token could be found in {path}") + elif force_from_user: + log.warn(f"Forcing manual token auth from user") _token_from_user(conf) return conf -def write_conf(client): - """Save access creds to config file. - """ - client._conf['questrade'] = client.access_data - config.write(client._conf) - - @asynccontextmanager async def get_client() -> Client: - """Spawn a broker client. - - A client must adhere to the method calls in ``piker.broker.core``. + """Spawn a broker client for making requests to the API service. """ conf = get_config() log.debug(f"Loaded config:\n{colorize_json(dict(conf['questrade']))}") client = Client(conf) await client.ensure_access() - try: log.debug("Check time to ensure access token is valid") - try: - # await client.api.time() - await client.quote(['RY.TO']) - except Exception: - # access token is likely no good - log.warn(f"Access token {client.access_data['access_token']} seems" - f" expired, forcing refresh") - await client.ensure_access(force_refresh=True) - await client.api.time() - - accounts = await client.api.accounts() - log.info(f"Available accounts:\n{colorize_json(accounts)}") + await client.api.time() + except Exception: + # access token is likely no good + log.warn(f"Access tokens {client.access_data} seem" + f" expired, forcing refresh") + await client.ensure_access(force_refresh=True) + await client.api.time() + try: yield client - finally: - write_conf(client) + except trio.Cancelled: + # only write config if we didn't bail out + client.write_config() + raise async def stock_quoter(client: Client, tickers: List[str]): @@ -480,27 +573,7 @@ async def stock_quoter(client: Client, tickers: List[str]): return {} ids = await get_symbol_id_seq(tuple(tickers)) - - try: - quotes_resp = await client.api.quotes(ids=ids) - except (QuestradeError, BrokerError) as qterr: - if "Access token is invalid" not in str(qterr.args[0]): - raise - # out-of-process piker actor may have - # renewed already.. - client._reload_config() - try: - quotes_resp = await client.api.quotes(ids=ids) - except BrokerError as qterr: - if "Access token is invalid" in str(qterr.args[0]): - # TODO: this will crash when run from a sub-actor since - # STDIN can't be acquired. The right way to handle this - # is to make a request to the parent actor (i.e. - # spawner of this) to call this - # `client.ensure_access()` locally thus blocking until - # the user provides an API key on the "client side" - await client.ensure_access(force_refresh=True) - quotes_resp = await client.api.quotes(ids=ids) + quotes_resp = await client.api.quotes(ids=ids) # post-processing for quote in quotes_resp: @@ -543,28 +616,7 @@ async def option_quoter(client: Client, tickers: List[str]): """ contracts = await get_contract_by_date( tuple(symbol_date_pairs)) - try: - quotes = await client.option_chains(contracts) - except (QuestradeError, BrokerError) as qterr: - if "Access token is invalid" not in str(qterr.args[0]): - raise - # out-of-process piker actor may have - # renewed already.. - client._reload_config() - try: - quotes = await client.option_chains(contracts) - except BrokerError as qterr: - if "Access token is invalid" in str(qterr.args[0]): - # TODO: this will crash when run from a sub-actor since - # STDIN can't be acquired. The right way to handle this - # is to make a request to the parent actor (i.e. - # spawner of this) to call this - # `client.ensure_access()` locally thus blocking until - # the user provides an API key on the "client side" - await client.ensure_access(force_refresh=True) - quotes = await client.option_chains(contracts) - - return quotes + return await client.option_chains(contracts) return get_quote