Synchronize Questrade token refreshing per client

Questrade's API is half baked and can't handle concurrency.
It allows multiple concurrent requests to most endpoints *except*
for the auth endpoint used to refresh tokens:

    https://www.questrade.com/api/documentation/security

I've gone through extensive dialogue with their API team and despite
making what I think are very good arguments for doing the request
serialization on the server side, they decided that I should instead
do the "locking" on the client side. Frankly it doesn't seem like they
have that competent an engineering department as it took me a long time
to explain the issue even though it's rather trivial and probably not
that hard to fix; maybe it's better this way.

This adds a few things to ensure more reliable token refreshes on
expiry:

- add a `@refresh_token_on_err` decorator which can be used on `_API`
  methods that should refresh tokens on failure
- decorate most endpoints with this *except* for the auth ep
- add locking logic for the troublesome scenario as follows:
  * every time a request is sent out set a "request in progress" event
    variable that can be used to determine when no requests are currently
    outstanding
  * every time the auth end point is hit in order to refresh tokens set
    an event that locks out other tasks from making requests
  * only allow hitting the auth endpoint when there are no "requests in
    progress" using the first event
  * mutex all auth endpoint requests; there can only be one outstanding

- don't hit the accounts endpoint at client startup; we want to
  eventually support keys from multiple accounts and you can disable
  account info per key and just share the market data function
kivy_mainline_and_py3.8
Tyler Goodlet 2019-02-09 21:39:22 -05:00
parent f6230dd6df
commit 395f0c8e4a
1 changed files with 207 additions and 155 deletions

View File

@ -1,6 +1,8 @@
"""
Questrade API backend.
"""
from __future__ import annotations
import inspect
import time
from datetime import datetime
from functools import partial
@ -9,6 +11,7 @@ from typing import List, Tuple, Dict, Any, Iterator, NamedTuple
import trio
from async_generator import asynccontextmanager
import wrapt
from ..calc import humanize, percent_change
from . import config
@ -40,56 +43,121 @@ class ContractsKey(NamedTuple):
expiry: datetime
def refresh_token_on_err(tries=3):
"""`_API` method decorator which locks the client and refreshes tokens
before unlocking access to the API again.
QT's service end can't handle concurrent requests to multiple
endpoints reliably without choking up and confusing their interal
servers.
"""
@wrapt.decorator
async def wrapper(wrapped, api, args, kwargs):
assert inspect.iscoroutinefunction(wrapped)
client = api.client
if not client._has_access.is_set():
log.warning("WAITING ON ACCESS LOCK")
await client._has_access.wait()
for i in range(1, tries):
try:
try:
client._request_not_in_progress.clear()
return await wrapped(*args, **kwargs)
finally:
client._request_not_in_progress.set()
except (QuestradeError, BrokerError) as qterr:
if "Access token is invalid" not in str(qterr.args[0]):
raise
# TODO: this will crash when run from a sub-actor since
# STDIN can't be acquired. The right way to handle this
# is to make a request to the parent actor (i.e.
# spawner of this) to call this
# `client.ensure_access()` locally thus blocking until
# the user provides an API key on the "client side"
log.warning(f"Tokens are invalid refreshing try {i}..")
await client.ensure_access(force_refresh=True)
if i == tries - 1:
raise
return wrapper
class _API:
"""Questrade API endpoints exposed as methods and wrapped with an
http session.
"""
def __init__(self, session: asks.Session):
self._sess = session
def __init__(
self,
client: Client,
):
self.client = client
self._sess: asks.Session = client._sess
async def _request(self, path: str, params=None) -> dict:
@refresh_token_on_err()
async def _get(self, path: str, params=None) -> dict:
"""Get an endpoint "reliably" by ensuring access on failure.
"""
resp = await self._sess.get(path=f'/{path}', params=params)
return resproc(resp, log)
async def _new_auth_token(self, refresh_token: str) -> dict:
"""Request a new api authorization ``refresh_token``.
Gain api access using either a user provided or existing token.
See the instructions::
http://www.questrade.com/api/documentation/getting-started
http://www.questrade.com/api/documentation/security
"""
resp = await self._sess.get(
_refresh_token_ep + 'token',
params={'grant_type': 'refresh_token',
'refresh_token': refresh_token}
)
return resproc(resp, log)
async def accounts(self) -> dict:
return await self._request('accounts')
return await self._get('accounts')
async def time(self) -> dict:
return await self._request('time')
return await self._get('time')
async def markets(self) -> dict:
return await self._request('markets')
return await self._get('markets')
async def search(self, prefix: str) -> dict:
return await self._request(
return await self._get(
'symbols/search', params={'prefix': prefix})
async def symbols(self, ids: str = '', names: str = '') -> dict:
log.debug(f"Symbol lookup for {ids or names}")
return await self._request(
return await self._get(
'symbols', params={'ids': ids, 'names': names})
async def quotes(self, ids: str) -> dict:
quotes = (await self._request(
quotes = (await self._get(
'markets/quotes', params={'ids': ids}))['quotes']
for quote in quotes:
quote['key'] = quote['symbol']
return quotes
async def candles(self, id: str, start: str, end, interval) -> dict:
return await self._request(f'markets/candles/{id}', params={})
return await self._get(f'markets/candles/{id}', params={})
async def balances(self, id: str) -> dict:
return await self._request(f'accounts/{id}/balances')
return await self._get(f'accounts/{id}/balances')
async def postions(self, id: str) -> dict:
return await self._request(f'accounts/{id}/positions')
return await self._get(f'accounts/{id}/positions')
async def option_contracts(self, symbol_id: str) -> dict:
"Retrieve all option contract API ids with expiry -> strike prices."
contracts = await self._request(f'symbols/{symbol_id}/options')
contracts = await self._get(f'symbols/{symbol_id}/options')
return contracts['optionChain']
@refresh_token_on_err()
async def option_quotes(
self,
contracts: Dict[ContractsKey, Dict[int, dict]] = {},
@ -107,7 +175,8 @@ class _API:
]
resp = await self._sess.post(
path=f'/markets/quotes/options',
# XXX: b'{"code":1024,"message":"The size of the array requested is not valid: optionIds"}'
# XXX: b'{"code":1024,"message":"The size of the array requested
# is not valid: optionIds"}'
# ^ what I get when trying to use too many ids manually...
json={'filters': filters, 'optionIds': option_ids}
)
@ -122,48 +191,24 @@ class Client:
"""
def __init__(self, config: configparser.ConfigParser):
self._sess = asks.Session()
self.api = _API(self._sess)
self.api = _API(self)
self._conf = config
self.access_data = {}
self._reload_config(config)
self._symbol_cache: Dict[str, int] = {}
self._optids2contractinfo = {}
self._contract2ids = {}
# for blocking during token refresh
self._has_access = trio.Event()
self._has_access.set()
self._request_not_in_progress = trio.Event()
self._request_not_in_progress.set()
self._mutex = trio.StrictFIFOLock()
def _reload_config(self, config=None, **kwargs):
log.warn("Reloading access config data")
self._conf = config or get_config(**kwargs)
self.access_data = dict(self._conf['questrade'])
async def _new_auth_token(self) -> dict:
"""Request a new api authorization ``refresh_token``.
Gain api access using either a user provided or existing token.
See the instructions::
http://www.questrade.com/api/documentation/getting-started
http://www.questrade.com/api/documentation/security
"""
resp = await self._sess.get(
_refresh_token_ep + 'token',
params={'grant_type': 'refresh_token',
'refresh_token': self.access_data['refresh_token']}
)
data = resproc(resp, log)
self.access_data.update(data)
return data
def _prep_sess(self) -> None:
"""Fill http session with auth headers and a base url.
"""
data = self.access_data
# set access token header for the session
self._sess.headers.update({
'Authorization': (f"{data['token_type']} {data['access_token']}")})
# set base API url (asks shorthand)
self._sess.base_location = self.access_data['api_server'] + _version
async def _revoke_auth_token(self) -> None:
"""Revoke api access for the current token.
"""
@ -175,8 +220,14 @@ class Client:
)
return resp
def write_config(self):
"""Save access creds to config file.
"""
self._conf['questrade'] = self.access_data
config.write(self._conf)
async def ensure_access(self, force_refresh: bool = False) -> dict:
"""Acquire new ``access_token`` and/or ``refresh_token`` if necessary.
"""Acquire a new token set (``access_token`` and ``refresh_token``).
Checks if the locally cached (file system) ``access_token`` has expired
(based on a ``expires_at`` time stamp stored in the brokers.ini config)
@ -185,47 +236,90 @@ class Client:
``refresh_token`` has expired a new one needs to be provided by the
user.
"""
# wait for ongoing requests to clear (API can't handle
# concurrent endpoint requests alongside a token refresh)
await self._request_not_in_progress.wait()
# block api access to tall other tasks
# XXX: this is limitation of the API when using a single
# token whereby their service can't handle concurrent requests
# to differnet end points (particularly the auth ep) which
# causes hangs and premature token invalidation issues.
self._has_access.clear()
try:
# don't allow simultaneous token refresh requests
async with self._mutex:
access_token = self.access_data.get('access_token')
expires = float(self.access_data.get('expires_at', 0))
expires_stamp = datetime.fromtimestamp(
expires).strftime('%Y-%m-%d %H:%M:%S')
if not access_token or (expires < time.time()) or force_refresh:
if not access_token or (
expires < time.time()
) or force_refresh:
log.info("REFRESHING TOKENS!")
log.debug(
f"Refreshing access token {access_token} which expired at"
f" {expires_stamp}")
f"Refreshing access token {access_token} which expired"
f" at {expires_stamp}")
try:
data = await self._new_auth_token()
data = await self.api._new_auth_token(
self.access_data['refresh_token'])
except BrokerError as qterr:
if "We're making some changes" in str(qterr.args[0]):
def get_err_msg(err):
# handle str and bytes...
msg = err.args[0]
return msg.decode() if msg.isascii() else msg
msg = get_err_msg(qterr)
if "We're making some changes" in msg:
# API service is down
raise QuestradeError("API is down for maintenance")
elif qterr.args[0].decode() == 'Bad Request':
# likely config ``refresh_token`` is expired but may
# be updated in the config file via another piker process
elif msg == 'Bad Request':
# likely config ``refresh_token`` is expired but
# may be updated in the config file via another
# piker process
self._reload_config()
try:
data = await self._new_auth_token()
data = await self.api._new_auth_token(
self.access_data['refresh_token'])
except BrokerError as qterr:
if qterr.args[0].decode() == 'Bad Request':
if get_err_msg(qterr) == 'Bad Request':
# actually expired; get new from user
self._reload_config(force_from_user=True)
data = await self._new_auth_token()
data = await self.api._new_auth_token(
self.access_data['refresh_token'])
else:
raise QuestradeError(qterr)
else:
raise qterr
# store absolute token expiry time
self.access_data.update(data)
log.debug(f"Updated tokens:\n{data}")
# store an absolute access token expiry time
self.access_data['expires_at'] = time.time() + float(
data['expires_in'])
# write to config on disk
write_conf(self)
# write to config to disk
self.write_config()
else:
log.debug(f"\nCurrent access token {access_token} expires at"
log.info(
f"\nCurrent access token {access_token} expires at"
f" {expires_stamp}\n")
self._prep_sess()
return self.access_data
# set access token header for the session
data = self.access_data
self._sess.headers.update({
'Authorization':
(f"{data['token_type']} {data['access_token']}")}
)
# set base API url (asks shorthand)
self._sess.base_location = data['api_server'] + _version
finally:
self._has_access.set()
return data
async def tickers2ids(self, tickers):
"""Helper routine that take a sequence of ticker symbols and returns
@ -407,54 +501,53 @@ def _token_from_user(conf: 'configparser.ConfigParser') -> None:
conf['questrade'] = {'refresh_token': refresh_token}
def get_config(force_from_user=False) -> "configparser.ConfigParser":
conf, path = config.load()
if not conf.has_section('questrade') or (
not conf['questrade'].get('refresh_token') or (
force_from_user)
):
def get_config(
force_from_user: bool = False,
config_path: str = None,
) -> "configparser.ConfigParser":
"""Load the broker config from disk.
By default this is the file:
~/.config/piker/brokers.ini
though may be different depending on your OS.
"""
log.debug("Reloading access config data")
conf, path = config.load(config_path)
if not conf.has_section('questrade'):
log.warn(
f"No valid refresh token could be found in {path}")
elif force_from_user:
log.warn(f"Forcing manual token auth from user")
_token_from_user(conf)
return conf
def write_conf(client):
"""Save access creds to config file.
"""
client._conf['questrade'] = client.access_data
config.write(client._conf)
@asynccontextmanager
async def get_client() -> Client:
"""Spawn a broker client.
A client must adhere to the method calls in ``piker.broker.core``.
"""Spawn a broker client for making requests to the API service.
"""
conf = get_config()
log.debug(f"Loaded config:\n{colorize_json(dict(conf['questrade']))}")
client = Client(conf)
await client.ensure_access()
try:
log.debug("Check time to ensure access token is valid")
try:
# await client.api.time()
await client.quote(['RY.TO'])
await client.api.time()
except Exception:
# access token is likely no good
log.warn(f"Access token {client.access_data['access_token']} seems"
log.warn(f"Access tokens {client.access_data} seem"
f" expired, forcing refresh")
await client.ensure_access(force_refresh=True)
await client.api.time()
accounts = await client.api.accounts()
log.info(f"Available accounts:\n{colorize_json(accounts)}")
try:
yield client
finally:
write_conf(client)
except trio.Cancelled:
# only write config if we didn't bail out
client.write_config()
raise
async def stock_quoter(client: Client, tickers: List[str]):
@ -480,26 +573,6 @@ async def stock_quoter(client: Client, tickers: List[str]):
return {}
ids = await get_symbol_id_seq(tuple(tickers))
try:
quotes_resp = await client.api.quotes(ids=ids)
except (QuestradeError, BrokerError) as qterr:
if "Access token is invalid" not in str(qterr.args[0]):
raise
# out-of-process piker actor may have
# renewed already..
client._reload_config()
try:
quotes_resp = await client.api.quotes(ids=ids)
except BrokerError as qterr:
if "Access token is invalid" in str(qterr.args[0]):
# TODO: this will crash when run from a sub-actor since
# STDIN can't be acquired. The right way to handle this
# is to make a request to the parent actor (i.e.
# spawner of this) to call this
# `client.ensure_access()` locally thus blocking until
# the user provides an API key on the "client side"
await client.ensure_access(force_refresh=True)
quotes_resp = await client.api.quotes(ids=ids)
# post-processing
@ -543,28 +616,7 @@ async def option_quoter(client: Client, tickers: List[str]):
"""
contracts = await get_contract_by_date(
tuple(symbol_date_pairs))
try:
quotes = await client.option_chains(contracts)
except (QuestradeError, BrokerError) as qterr:
if "Access token is invalid" not in str(qterr.args[0]):
raise
# out-of-process piker actor may have
# renewed already..
client._reload_config()
try:
quotes = await client.option_chains(contracts)
except BrokerError as qterr:
if "Access token is invalid" in str(qterr.args[0]):
# TODO: this will crash when run from a sub-actor since
# STDIN can't be acquired. The right way to handle this
# is to make a request to the parent actor (i.e.
# spawner of this) to call this
# `client.ensure_access()` locally thus blocking until
# the user provides an API key on the "client side"
await client.ensure_access(force_refresh=True)
quotes = await client.option_chains(contracts)
return quotes
return await client.option_chains(contracts)
return get_quote