Save tokens locally for use across runs
Store tokens in a local config file avoiding any refresh delay unless necessary when the current access token expires. Summary: - move draft main routine into the `brokers` package mod - start an api wrapper type - always write the current access tokens to the config on teardownkivy_mainline_and_py3.8
							parent
							
								
									e312fb6525
								
							
						
					
					
						commit
						570d879146
					
				| 
						 | 
					@ -0,0 +1,27 @@
 | 
				
			||||||
 | 
					"""
 | 
				
			||||||
 | 
					Broker client-daemons and general back end machinery.
 | 
				
			||||||
 | 
					"""
 | 
				
			||||||
 | 
					import sys
 | 
				
			||||||
 | 
					import trio
 | 
				
			||||||
 | 
					from .questrade import serve_forever
 | 
				
			||||||
 | 
					from ..log import get_console_log
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def main() -> None:
 | 
				
			||||||
 | 
					    log = get_console_log('INFO', name='questrade')
 | 
				
			||||||
 | 
					    argv = sys.argv[1:]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    refresh_token = None
 | 
				
			||||||
 | 
					    if argv:
 | 
				
			||||||
 | 
					        refresh_token = argv[0]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # main loop
 | 
				
			||||||
 | 
					    try:
 | 
				
			||||||
 | 
					        client = trio.run(serve_forever, refresh_token)
 | 
				
			||||||
 | 
					    except Exception as err:
 | 
				
			||||||
 | 
					        log.exception(err)
 | 
				
			||||||
 | 
					    else:
 | 
				
			||||||
 | 
					        log.info(
 | 
				
			||||||
 | 
					            f"\nLast refresh_token: {client.access_data['refresh_token']}\n"
 | 
				
			||||||
 | 
					            f"Last access_token: {client.access_data['access_token']}\n"
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
| 
						 | 
					@ -1,11 +1,11 @@
 | 
				
			||||||
"""
 | 
					"""
 | 
				
			||||||
Questrade API backend.
 | 
					Questrade API backend.
 | 
				
			||||||
"""
 | 
					"""
 | 
				
			||||||
# from ..config import load
 | 
					from . import config
 | 
				
			||||||
from ..log import get_logger, get_console_log
 | 
					from ..log import get_logger
 | 
				
			||||||
from pprint import pformat
 | 
					from pprint import pformat
 | 
				
			||||||
import sys
 | 
					import time
 | 
				
			||||||
import trio
 | 
					from async_generator import asynccontextmanager
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# TODO: move to urllib3/requests once supported
 | 
					# TODO: move to urllib3/requests once supported
 | 
				
			||||||
import asks
 | 
					import asks
 | 
				
			||||||
| 
						 | 
					@ -13,105 +13,174 @@ asks.init('trio')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
log = get_logger('questrade')
 | 
					log = get_logger('questrade')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
_refresh_token_ep = 'https://login.questrade.com/oauth2/token'
 | 
					_refresh_token_ep = 'https://login.questrade.com/oauth2/'
 | 
				
			||||||
_version = 'v1'
 | 
					_version = 'v1'
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class ResponseError(Exception):
 | 
					class QuestradeError(Exception):
 | 
				
			||||||
    "Non-200 OK response code"
 | 
					    "Non-200 OK response code"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def err_on_status(resp: asks.response_objects.Response) -> None:
 | 
					def resproc(
 | 
				
			||||||
 | 
					    resp: asks.response_objects.Response,
 | 
				
			||||||
 | 
					    return_json: bool = True
 | 
				
			||||||
 | 
					) -> asks.response_objects.Response:
 | 
				
			||||||
    """Raise error on non-200 OK response.
 | 
					    """Raise error on non-200 OK response.
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
 | 
					    data = resp.json()
 | 
				
			||||||
 | 
					    log.debug(f"Received json contents:\n{pformat(data)}\n")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    if not resp.status_code == 200:
 | 
					    if not resp.status_code == 200:
 | 
				
			||||||
        raise ResponseError(resp.body)
 | 
					        raise QuestradeError(resp.body)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    return data if return_json else resp
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class API:
 | 
				
			||||||
 | 
					    """Questrade API at its finest.
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    def __init__(self, session: asks.Session):
 | 
				
			||||||
 | 
					        self._sess = session
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    async def _request(self, path: str) -> dict:
 | 
				
			||||||
 | 
					        resp = await self._sess.get(path=f'/{path}')
 | 
				
			||||||
 | 
					        return resproc(resp)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    async def accounts(self):
 | 
				
			||||||
 | 
					        return await self._request('accounts')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    async def time(self):
 | 
				
			||||||
 | 
					        return await self._request('time')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class Client:
 | 
					class Client:
 | 
				
			||||||
    """API client suitable for use as a long running broker daemon.
 | 
					    """API client suitable for use as a long running broker daemon.
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
    def __init__(self, refresh_token: str):
 | 
					    def __init__(self, config: dict):
 | 
				
			||||||
        self._sess = asks.Session()
 | 
					        sess = self._sess = asks.Session()
 | 
				
			||||||
        self.data = {'refresh_token': refresh_token}
 | 
					        self.api = API(sess)
 | 
				
			||||||
        self.refresh_token = refresh_token
 | 
					        self.access_data = config
 | 
				
			||||||
 | 
					        self.user_data = {}
 | 
				
			||||||
 | 
					        self._conf = None  # possibly set in ``from_config`` factory
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @classmethod
 | 
					    @classmethod
 | 
				
			||||||
    async def from_token(cls, refresh_token: str):
 | 
					    async def from_config(cls, config):
 | 
				
			||||||
        client = cls(refresh_token)
 | 
					        client = cls(dict(config['questrade']))
 | 
				
			||||||
        await client.refresh_access()
 | 
					        client._conf = config
 | 
				
			||||||
 | 
					        await client.enable_access()
 | 
				
			||||||
        return client
 | 
					        return client
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    async def refresh_access(self) -> None:
 | 
					    async def _new_auth_token(self) -> dict:
 | 
				
			||||||
        """Acquire new ``refresh_token`` and ``access_token`` if necessary.
 | 
					        """Request a new api authorization ``refresh_token``.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        """
 | 
					        Gain api access using either a user provided or existing token.
 | 
				
			||||||
        resp = await self._sess.get(
 | 
					        See the instructions::
 | 
				
			||||||
            _refresh_token_ep,
 | 
					 | 
				
			||||||
            params={'grant_type': 'refresh_token',
 | 
					 | 
				
			||||||
                    'refresh_token': self.data['refresh_token']}
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
        err_on_status(resp)
 | 
					 | 
				
			||||||
        data = resp.json()
 | 
					 | 
				
			||||||
        self.data.update(data)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # set auth token for the session
 | 
					 | 
				
			||||||
        self._sess.headers.update(
 | 
					 | 
				
			||||||
            {'Authorization': f"{data['token_type']} {data['access_token']}"}
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
        # set base API url (asks shorthand)
 | 
					 | 
				
			||||||
        self._sess.base_location = data['api_server'] + _version
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    async def get_user_data(self) -> dict:
 | 
					 | 
				
			||||||
        """Get and store user data from the ``accounts`` endpoint.
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        resp = await self._sess.get(path='/accounts')
 | 
					 | 
				
			||||||
        err_on_status(resp)
 | 
					 | 
				
			||||||
        data = resp.json()
 | 
					 | 
				
			||||||
        self.data.update(data)
 | 
					 | 
				
			||||||
        return data
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
async def get_client(refresh_token: str = None) -> Client:
 | 
					 | 
				
			||||||
    """Gain api access using a user generated token.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    See the instructions::
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
        http://www.questrade.com/api/documentation/getting-started
 | 
					        http://www.questrade.com/api/documentation/getting-started
 | 
				
			||||||
    """
 | 
					        http://www.questrade.com/api/documentation/security
 | 
				
			||||||
    if refresh_token is None:
 | 
					        """
 | 
				
			||||||
        # sanitize?
 | 
					        resp = await self._sess.get(
 | 
				
			||||||
        refresh_token = input(
 | 
					            _refresh_token_ep + 'token',
 | 
				
			||||||
            "Questrade access token:")
 | 
					            params={'grant_type': 'refresh_token',
 | 
				
			||||||
 | 
					                    'refresh_token': self.access_data['refresh_token']}
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        data = resproc(resp)
 | 
				
			||||||
 | 
					        self.access_data.update(data)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    log.info("Waiting for initial API access...")
 | 
					        return data
 | 
				
			||||||
    return await Client.from_token(refresh_token)
 | 
					
 | 
				
			||||||
 | 
					    async def _prep_sess(self) -> None:
 | 
				
			||||||
 | 
					        """Fill http session with auth headers and a base url.
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        data = self.access_data
 | 
				
			||||||
 | 
					        # set access token header for the session
 | 
				
			||||||
 | 
					        self._sess.headers.update({
 | 
				
			||||||
 | 
					            'Authorization': (f"{data['token_type']} {data['access_token']}")})
 | 
				
			||||||
 | 
					        # set base API url (asks shorthand)
 | 
				
			||||||
 | 
					        self._sess.base_location = self.access_data['api_server'] + _version
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    async def _revoke_auth_token(self) -> None:
 | 
				
			||||||
 | 
					        """Revoke api access for the current token.
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        token = self.access_data['refresh_token']
 | 
				
			||||||
 | 
					        log.debug(f"Revoking token {token}")
 | 
				
			||||||
 | 
					        resp = await asks.post(
 | 
				
			||||||
 | 
					            _refresh_token_ep + 'revoke',
 | 
				
			||||||
 | 
					            headers={'token': token}
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        return resp
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    async def enable_access(self, force_refresh: bool = False) -> dict:
 | 
				
			||||||
 | 
					        """Acquire new ``refresh_token`` and/or ``access_token`` if necessary.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Only needs to be called if the locally stored ``refresh_token`` has
 | 
				
			||||||
 | 
					        expired (normally has a lifetime of 3 days). If ``false is set then
 | 
				
			||||||
 | 
					        refresh the access token instead of using the locally cached version.
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        access_token = self.access_data.get('access_token')
 | 
				
			||||||
 | 
					        expires = float(self.access_data.get('expires_at', 0))
 | 
				
			||||||
 | 
					        # expired_by = time.time() - float(self.ttl or 0)
 | 
				
			||||||
 | 
					        # if not access_token or (self.ttl is None) or (expires < time.time()):
 | 
				
			||||||
 | 
					        if not access_token or (expires < time.time()) or force_refresh:
 | 
				
			||||||
 | 
					            log.info(
 | 
				
			||||||
 | 
					                f"Access token {access_token} expired @ {expires}, "
 | 
				
			||||||
 | 
					                "refreshing...")
 | 
				
			||||||
 | 
					            data = await self._new_auth_token()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            # store absolute token expiry time
 | 
				
			||||||
 | 
					            self.access_data['expires_at'] = time.time() + float(
 | 
				
			||||||
 | 
					                data['expires_in'])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        await self._prep_sess()
 | 
				
			||||||
 | 
					        return self.access_data
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def get_config() -> "configparser.ConfigParser":
 | 
				
			||||||
 | 
					    conf, path = config.load()
 | 
				
			||||||
 | 
					    if not conf.has_section('questrade') or (
 | 
				
			||||||
 | 
					        not conf['questrade'].get('refresh_token')
 | 
				
			||||||
 | 
					    ):
 | 
				
			||||||
 | 
					        log.warn(
 | 
				
			||||||
 | 
					            f"No valid `questrade` refresh token could be found in {path}")
 | 
				
			||||||
 | 
					        # get from user
 | 
				
			||||||
 | 
					        refresh_token = input("Please provide your Questrade access token: ")
 | 
				
			||||||
 | 
					        conf['questrade'] = {'refresh_token': refresh_token}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    return conf
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@asynccontextmanager
 | 
				
			||||||
 | 
					async def get_client(refresh_token: str = None) -> Client:
 | 
				
			||||||
 | 
					    """Spawn a broker client.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    conf = get_config()
 | 
				
			||||||
 | 
					    log.debug(f"Loaded questrade config: {conf['questrade']}")
 | 
				
			||||||
 | 
					    log.info("Waiting on api access...")
 | 
				
			||||||
 | 
					    client = await Client.from_config(conf)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    try:
 | 
				
			||||||
 | 
					        try:  # do a test ping to ensure the access token works
 | 
				
			||||||
 | 
					            log.debug("Check time to ensure access token is valid")
 | 
				
			||||||
 | 
					            await client.api.time()
 | 
				
			||||||
 | 
					        except Exception as err:
 | 
				
			||||||
 | 
					            # access token is likely no good
 | 
				
			||||||
 | 
					            log.warn(f"Access token {client.access_data['access_token']} seems"
 | 
				
			||||||
 | 
					                     f" expired, forcing refresh")
 | 
				
			||||||
 | 
					            await client.enable_access(force_refresh=True)
 | 
				
			||||||
 | 
					            await client.api.time()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        yield client
 | 
				
			||||||
 | 
					    finally:
 | 
				
			||||||
 | 
					        # save access creds for next run
 | 
				
			||||||
 | 
					        conf['questrade'] = client.access_data
 | 
				
			||||||
 | 
					        config.write(conf)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
async def serve_forever(refresh_token: str = None) -> None:
 | 
					async def serve_forever(refresh_token: str = None) -> None:
 | 
				
			||||||
    """Start up a client and serve until terminated.
 | 
					    """Start up a client and serve until terminated.
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
    client = await get_client(refresh_token=refresh_token)
 | 
					    async with get_client(refresh_token) as client:
 | 
				
			||||||
    data = await client.get_user_data()
 | 
					        # pretty sure this doesn't work
 | 
				
			||||||
    log.info(pformat(data))
 | 
					        # await client._revoke_auth_token()
 | 
				
			||||||
    return client
 | 
					        return client
 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def main() -> None:
 | 
					 | 
				
			||||||
    log = get_console_log('INFO', name='questrade')
 | 
					 | 
				
			||||||
    argv = sys.argv[1:]
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    refresh_token = None
 | 
					 | 
				
			||||||
    if argv:
 | 
					 | 
				
			||||||
        refresh_token = argv[0]
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    # main loop
 | 
					 | 
				
			||||||
    try:
 | 
					 | 
				
			||||||
        client = trio.run(serve_forever, refresh_token)
 | 
					 | 
				
			||||||
    except Exception as err:
 | 
					 | 
				
			||||||
        log.exception(err)
 | 
					 | 
				
			||||||
    else:
 | 
					 | 
				
			||||||
        log.info(
 | 
					 | 
				
			||||||
            f"\nLast refresh_token: {client.data['refresh_token']}\n"
 | 
					 | 
				
			||||||
            f"Last access_token: {client.data['access_token']}"
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue