Synchronize Questrade token refreshing per client

Questrade's API is half baked and can't handle concurrency.
It allows multiple concurrent requests to most endpoints *except*
for the auth endpoint used to refresh tokens:

    https://www.questrade.com/api/documentation/security

I've gone through extensive dialogue with their API team and despite
making what I think are very good arguments for doing the request
serialization on the server side, they decided that I should instead
do the "locking" on the client side. Frankly it doesn't seem like they
have that competent an engineering department as it took me a long time
to explain the issue even though it's rather trivial and probably not
that hard to fix; maybe it's better this way.

This adds a few things to ensure more reliable token refreshes on
expiry:

- add a `@refresh_token_on_err` decorator which can be used on `_API`
  methods that should refresh tokens on failure
- decorate most endpoints with this *except* for the auth ep
- add locking logic for the troublesome scenario as follows:
  * every time a request is sent out set a "request in progress" event
    variable that can be used to determine when no requests are currently
    outstanding
  * every time the auth end point is hit in order to refresh tokens set
    an event that locks out other tasks from making requests
  * only allow hitting the auth endpoint when there are no "requests in
    progress" using the first event
  * mutex all auth endpoint requests; there can only be one outstanding

- don't hit the accounts endpoint at client startup; we want to
  eventually support keys from multiple accounts and you can disable
  account info per key and just share the market data function
kivy_mainline_and_py3.8
Tyler Goodlet 2019-02-09 21:39:22 -05:00
parent f6230dd6df
commit 395f0c8e4a
1 changed files with 207 additions and 155 deletions

View File

@ -1,6 +1,8 @@
""" """
Questrade API backend. Questrade API backend.
""" """
from __future__ import annotations
import inspect
import time import time
from datetime import datetime from datetime import datetime
from functools import partial from functools import partial
@ -9,6 +11,7 @@ from typing import List, Tuple, Dict, Any, Iterator, NamedTuple
import trio import trio
from async_generator import asynccontextmanager from async_generator import asynccontextmanager
import wrapt
from ..calc import humanize, percent_change from ..calc import humanize, percent_change
from . import config from . import config
@ -40,56 +43,121 @@ class ContractsKey(NamedTuple):
expiry: datetime expiry: datetime
def refresh_token_on_err(tries=3):
"""`_API` method decorator which locks the client and refreshes tokens
before unlocking access to the API again.
QT's service end can't handle concurrent requests to multiple
endpoints reliably without choking up and confusing their interal
servers.
"""
@wrapt.decorator
async def wrapper(wrapped, api, args, kwargs):
assert inspect.iscoroutinefunction(wrapped)
client = api.client
if not client._has_access.is_set():
log.warning("WAITING ON ACCESS LOCK")
await client._has_access.wait()
for i in range(1, tries):
try:
try:
client._request_not_in_progress.clear()
return await wrapped(*args, **kwargs)
finally:
client._request_not_in_progress.set()
except (QuestradeError, BrokerError) as qterr:
if "Access token is invalid" not in str(qterr.args[0]):
raise
# TODO: this will crash when run from a sub-actor since
# STDIN can't be acquired. The right way to handle this
# is to make a request to the parent actor (i.e.
# spawner of this) to call this
# `client.ensure_access()` locally thus blocking until
# the user provides an API key on the "client side"
log.warning(f"Tokens are invalid refreshing try {i}..")
await client.ensure_access(force_refresh=True)
if i == tries - 1:
raise
return wrapper
class _API: class _API:
"""Questrade API endpoints exposed as methods and wrapped with an """Questrade API endpoints exposed as methods and wrapped with an
http session. http session.
""" """
def __init__(self, session: asks.Session): def __init__(
self._sess = session self,
client: Client,
):
self.client = client
self._sess: asks.Session = client._sess
async def _request(self, path: str, params=None) -> dict: @refresh_token_on_err()
async def _get(self, path: str, params=None) -> dict:
"""Get an endpoint "reliably" by ensuring access on failure.
"""
resp = await self._sess.get(path=f'/{path}', params=params) resp = await self._sess.get(path=f'/{path}', params=params)
return resproc(resp, log) return resproc(resp, log)
async def _new_auth_token(self, refresh_token: str) -> dict:
"""Request a new api authorization ``refresh_token``.
Gain api access using either a user provided or existing token.
See the instructions::
http://www.questrade.com/api/documentation/getting-started
http://www.questrade.com/api/documentation/security
"""
resp = await self._sess.get(
_refresh_token_ep + 'token',
params={'grant_type': 'refresh_token',
'refresh_token': refresh_token}
)
return resproc(resp, log)
async def accounts(self) -> dict: async def accounts(self) -> dict:
return await self._request('accounts') return await self._get('accounts')
async def time(self) -> dict: async def time(self) -> dict:
return await self._request('time') return await self._get('time')
async def markets(self) -> dict: async def markets(self) -> dict:
return await self._request('markets') return await self._get('markets')
async def search(self, prefix: str) -> dict: async def search(self, prefix: str) -> dict:
return await self._request( return await self._get(
'symbols/search', params={'prefix': prefix}) 'symbols/search', params={'prefix': prefix})
async def symbols(self, ids: str = '', names: str = '') -> dict: async def symbols(self, ids: str = '', names: str = '') -> dict:
log.debug(f"Symbol lookup for {ids or names}") log.debug(f"Symbol lookup for {ids or names}")
return await self._request( return await self._get(
'symbols', params={'ids': ids, 'names': names}) 'symbols', params={'ids': ids, 'names': names})
async def quotes(self, ids: str) -> dict: async def quotes(self, ids: str) -> dict:
quotes = (await self._request( quotes = (await self._get(
'markets/quotes', params={'ids': ids}))['quotes'] 'markets/quotes', params={'ids': ids}))['quotes']
for quote in quotes: for quote in quotes:
quote['key'] = quote['symbol'] quote['key'] = quote['symbol']
return quotes return quotes
async def candles(self, id: str, start: str, end, interval) -> dict: async def candles(self, id: str, start: str, end, interval) -> dict:
return await self._request(f'markets/candles/{id}', params={}) return await self._get(f'markets/candles/{id}', params={})
async def balances(self, id: str) -> dict: async def balances(self, id: str) -> dict:
return await self._request(f'accounts/{id}/balances') return await self._get(f'accounts/{id}/balances')
async def postions(self, id: str) -> dict: async def postions(self, id: str) -> dict:
return await self._request(f'accounts/{id}/positions') return await self._get(f'accounts/{id}/positions')
async def option_contracts(self, symbol_id: str) -> dict: async def option_contracts(self, symbol_id: str) -> dict:
"Retrieve all option contract API ids with expiry -> strike prices." "Retrieve all option contract API ids with expiry -> strike prices."
contracts = await self._request(f'symbols/{symbol_id}/options') contracts = await self._get(f'symbols/{symbol_id}/options')
return contracts['optionChain'] return contracts['optionChain']
@refresh_token_on_err()
async def option_quotes( async def option_quotes(
self, self,
contracts: Dict[ContractsKey, Dict[int, dict]] = {}, contracts: Dict[ContractsKey, Dict[int, dict]] = {},
@ -107,7 +175,8 @@ class _API:
] ]
resp = await self._sess.post( resp = await self._sess.post(
path=f'/markets/quotes/options', path=f'/markets/quotes/options',
# XXX: b'{"code":1024,"message":"The size of the array requested is not valid: optionIds"}' # XXX: b'{"code":1024,"message":"The size of the array requested
# is not valid: optionIds"}'
# ^ what I get when trying to use too many ids manually... # ^ what I get when trying to use too many ids manually...
json={'filters': filters, 'optionIds': option_ids} json={'filters': filters, 'optionIds': option_ids}
) )
@ -122,48 +191,24 @@ class Client:
""" """
def __init__(self, config: configparser.ConfigParser): def __init__(self, config: configparser.ConfigParser):
self._sess = asks.Session() self._sess = asks.Session()
self.api = _API(self._sess) self.api = _API(self)
self._conf = config self._conf = config
self.access_data = {} self.access_data = {}
self._reload_config(config) self._reload_config(config)
self._symbol_cache: Dict[str, int] = {} self._symbol_cache: Dict[str, int] = {}
self._optids2contractinfo = {} self._optids2contractinfo = {}
self._contract2ids = {} self._contract2ids = {}
# for blocking during token refresh
self._has_access = trio.Event()
self._has_access.set()
self._request_not_in_progress = trio.Event()
self._request_not_in_progress.set()
self._mutex = trio.StrictFIFOLock()
def _reload_config(self, config=None, **kwargs): def _reload_config(self, config=None, **kwargs):
log.warn("Reloading access config data")
self._conf = config or get_config(**kwargs) self._conf = config or get_config(**kwargs)
self.access_data = dict(self._conf['questrade']) self.access_data = dict(self._conf['questrade'])
async def _new_auth_token(self) -> dict:
"""Request a new api authorization ``refresh_token``.
Gain api access using either a user provided or existing token.
See the instructions::
http://www.questrade.com/api/documentation/getting-started
http://www.questrade.com/api/documentation/security
"""
resp = await self._sess.get(
_refresh_token_ep + 'token',
params={'grant_type': 'refresh_token',
'refresh_token': self.access_data['refresh_token']}
)
data = resproc(resp, log)
self.access_data.update(data)
return data
def _prep_sess(self) -> None:
"""Fill http session with auth headers and a base url.
"""
data = self.access_data
# set access token header for the session
self._sess.headers.update({
'Authorization': (f"{data['token_type']} {data['access_token']}")})
# set base API url (asks shorthand)
self._sess.base_location = self.access_data['api_server'] + _version
async def _revoke_auth_token(self) -> None: async def _revoke_auth_token(self) -> None:
"""Revoke api access for the current token. """Revoke api access for the current token.
""" """
@ -175,8 +220,14 @@ class Client:
) )
return resp return resp
def write_config(self):
"""Save access creds to config file.
"""
self._conf['questrade'] = self.access_data
config.write(self._conf)
async def ensure_access(self, force_refresh: bool = False) -> dict: async def ensure_access(self, force_refresh: bool = False) -> dict:
"""Acquire new ``access_token`` and/or ``refresh_token`` if necessary. """Acquire a new token set (``access_token`` and ``refresh_token``).
Checks if the locally cached (file system) ``access_token`` has expired Checks if the locally cached (file system) ``access_token`` has expired
(based on a ``expires_at`` time stamp stored in the brokers.ini config) (based on a ``expires_at`` time stamp stored in the brokers.ini config)
@ -185,47 +236,90 @@ class Client:
``refresh_token`` has expired a new one needs to be provided by the ``refresh_token`` has expired a new one needs to be provided by the
user. user.
""" """
access_token = self.access_data.get('access_token') # wait for ongoing requests to clear (API can't handle
expires = float(self.access_data.get('expires_at', 0)) # concurrent endpoint requests alongside a token refresh)
expires_stamp = datetime.fromtimestamp( await self._request_not_in_progress.wait()
expires).strftime('%Y-%m-%d %H:%M:%S')
if not access_token or (expires < time.time()) or force_refresh: # block api access to tall other tasks
log.debug( # XXX: this is limitation of the API when using a single
f"Refreshing access token {access_token} which expired at" # token whereby their service can't handle concurrent requests
f" {expires_stamp}") # to differnet end points (particularly the auth ep) which
try: # causes hangs and premature token invalidation issues.
data = await self._new_auth_token() self._has_access.clear()
except BrokerError as qterr: try:
if "We're making some changes" in str(qterr.args[0]): # don't allow simultaneous token refresh requests
# API service is down async with self._mutex:
raise QuestradeError("API is down for maintenance") access_token = self.access_data.get('access_token')
elif qterr.args[0].decode() == 'Bad Request': expires = float(self.access_data.get('expires_at', 0))
# likely config ``refresh_token`` is expired but may expires_stamp = datetime.fromtimestamp(
# be updated in the config file via another piker process expires).strftime('%Y-%m-%d %H:%M:%S')
self._reload_config() if not access_token or (
expires < time.time()
) or force_refresh:
log.info("REFRESHING TOKENS!")
log.debug(
f"Refreshing access token {access_token} which expired"
f" at {expires_stamp}")
try: try:
data = await self._new_auth_token() data = await self.api._new_auth_token(
self.access_data['refresh_token'])
except BrokerError as qterr: except BrokerError as qterr:
if qterr.args[0].decode() == 'Bad Request':
# actually expired; get new from user def get_err_msg(err):
self._reload_config(force_from_user=True) # handle str and bytes...
data = await self._new_auth_token() msg = err.args[0]
return msg.decode() if msg.isascii() else msg
msg = get_err_msg(qterr)
if "We're making some changes" in msg:
# API service is down
raise QuestradeError("API is down for maintenance")
elif msg == 'Bad Request':
# likely config ``refresh_token`` is expired but
# may be updated in the config file via another
# piker process
self._reload_config()
try:
data = await self.api._new_auth_token(
self.access_data['refresh_token'])
except BrokerError as qterr:
if get_err_msg(qterr) == 'Bad Request':
# actually expired; get new from user
self._reload_config(force_from_user=True)
data = await self.api._new_auth_token(
self.access_data['refresh_token'])
else:
raise QuestradeError(qterr)
else: else:
raise QuestradeError(qterr) raise qterr
self.access_data.update(data)
log.debug(f"Updated tokens:\n{data}")
# store an absolute access token expiry time
self.access_data['expires_at'] = time.time() + float(
data['expires_in'])
# write to config to disk
self.write_config()
else: else:
raise qterr log.info(
f"\nCurrent access token {access_token} expires at"
f" {expires_stamp}\n")
# store absolute token expiry time # set access token header for the session
self.access_data['expires_at'] = time.time() + float( data = self.access_data
data['expires_in']) self._sess.headers.update({
# write to config on disk 'Authorization':
write_conf(self) (f"{data['token_type']} {data['access_token']}")}
else: )
log.debug(f"\nCurrent access token {access_token} expires at" # set base API url (asks shorthand)
f" {expires_stamp}\n") self._sess.base_location = data['api_server'] + _version
finally:
self._has_access.set()
self._prep_sess() return data
return self.access_data
async def tickers2ids(self, tickers): async def tickers2ids(self, tickers):
"""Helper routine that take a sequence of ticker symbols and returns """Helper routine that take a sequence of ticker symbols and returns
@ -407,54 +501,53 @@ def _token_from_user(conf: 'configparser.ConfigParser') -> None:
conf['questrade'] = {'refresh_token': refresh_token} conf['questrade'] = {'refresh_token': refresh_token}
def get_config(force_from_user=False) -> "configparser.ConfigParser": def get_config(
conf, path = config.load() force_from_user: bool = False,
if not conf.has_section('questrade') or ( config_path: str = None,
not conf['questrade'].get('refresh_token') or ( ) -> "configparser.ConfigParser":
force_from_user) """Load the broker config from disk.
):
By default this is the file:
~/.config/piker/brokers.ini
though may be different depending on your OS.
"""
log.debug("Reloading access config data")
conf, path = config.load(config_path)
if not conf.has_section('questrade'):
log.warn( log.warn(
f"No valid refresh token could be found in {path}") f"No valid refresh token could be found in {path}")
elif force_from_user:
log.warn(f"Forcing manual token auth from user")
_token_from_user(conf) _token_from_user(conf)
return conf return conf
def write_conf(client):
"""Save access creds to config file.
"""
client._conf['questrade'] = client.access_data
config.write(client._conf)
@asynccontextmanager @asynccontextmanager
async def get_client() -> Client: async def get_client() -> Client:
"""Spawn a broker client. """Spawn a broker client for making requests to the API service.
A client must adhere to the method calls in ``piker.broker.core``.
""" """
conf = get_config() conf = get_config()
log.debug(f"Loaded config:\n{colorize_json(dict(conf['questrade']))}") log.debug(f"Loaded config:\n{colorize_json(dict(conf['questrade']))}")
client = Client(conf) client = Client(conf)
await client.ensure_access() await client.ensure_access()
try: try:
log.debug("Check time to ensure access token is valid") log.debug("Check time to ensure access token is valid")
try: await client.api.time()
# await client.api.time() except Exception:
await client.quote(['RY.TO']) # access token is likely no good
except Exception: log.warn(f"Access tokens {client.access_data} seem"
# access token is likely no good f" expired, forcing refresh")
log.warn(f"Access token {client.access_data['access_token']} seems" await client.ensure_access(force_refresh=True)
f" expired, forcing refresh") await client.api.time()
await client.ensure_access(force_refresh=True) try:
await client.api.time()
accounts = await client.api.accounts()
log.info(f"Available accounts:\n{colorize_json(accounts)}")
yield client yield client
finally: except trio.Cancelled:
write_conf(client) # only write config if we didn't bail out
client.write_config()
raise
async def stock_quoter(client: Client, tickers: List[str]): async def stock_quoter(client: Client, tickers: List[str]):
@ -480,27 +573,7 @@ async def stock_quoter(client: Client, tickers: List[str]):
return {} return {}
ids = await get_symbol_id_seq(tuple(tickers)) ids = await get_symbol_id_seq(tuple(tickers))
quotes_resp = await client.api.quotes(ids=ids)
try:
quotes_resp = await client.api.quotes(ids=ids)
except (QuestradeError, BrokerError) as qterr:
if "Access token is invalid" not in str(qterr.args[0]):
raise
# out-of-process piker actor may have
# renewed already..
client._reload_config()
try:
quotes_resp = await client.api.quotes(ids=ids)
except BrokerError as qterr:
if "Access token is invalid" in str(qterr.args[0]):
# TODO: this will crash when run from a sub-actor since
# STDIN can't be acquired. The right way to handle this
# is to make a request to the parent actor (i.e.
# spawner of this) to call this
# `client.ensure_access()` locally thus blocking until
# the user provides an API key on the "client side"
await client.ensure_access(force_refresh=True)
quotes_resp = await client.api.quotes(ids=ids)
# post-processing # post-processing
for quote in quotes_resp: for quote in quotes_resp:
@ -543,28 +616,7 @@ async def option_quoter(client: Client, tickers: List[str]):
""" """
contracts = await get_contract_by_date( contracts = await get_contract_by_date(
tuple(symbol_date_pairs)) tuple(symbol_date_pairs))
try: return await client.option_chains(contracts)
quotes = await client.option_chains(contracts)
except (QuestradeError, BrokerError) as qterr:
if "Access token is invalid" not in str(qterr.args[0]):
raise
# out-of-process piker actor may have
# renewed already..
client._reload_config()
try:
quotes = await client.option_chains(contracts)
except BrokerError as qterr:
if "Access token is invalid" in str(qterr.args[0]):
# TODO: this will crash when run from a sub-actor since
# STDIN can't be acquired. The right way to handle this
# is to make a request to the parent actor (i.e.
# spawner of this) to call this
# `client.ensure_access()` locally thus blocking until
# the user provides an API key on the "client side"
await client.ensure_access(force_refresh=True)
quotes = await client.option_chains(contracts)
return quotes
return get_quote return get_quote