diff --git a/piker/__init__.py b/piker/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/piker/brokers/__init__.py b/piker/brokers/__init__.py new file mode 100644 index 00000000..f2c6409c --- /dev/null +++ b/piker/brokers/__init__.py @@ -0,0 +1,27 @@ +""" +Broker client-daemons and general back end machinery. +""" +import sys +import trio +from .questrade import serve_forever +from ..log import get_console_log + + +def main() -> None: + log = get_console_log('INFO', name='questrade') + argv = sys.argv[1:] + + refresh_token = None + if argv: + refresh_token = argv[0] + + # main loop + try: + client = trio.run(serve_forever, refresh_token) + except Exception as err: + log.exception(err) + else: + log.info( + f"\nLast refresh_token: {client.access_data['refresh_token']}\n" + f"Last access_token: {client.access_data['access_token']}\n" + ) diff --git a/piker/brokers/config.py b/piker/brokers/config.py new file mode 100644 index 00000000..5d923adf --- /dev/null +++ b/piker/brokers/config.py @@ -0,0 +1,30 @@ +""" +Broker configuration mgmt. +""" +from os import path +import configparser +from ..log import get_logger + +log = get_logger('broker-config') + +_broker_conf_path = path.join(path.dirname(__file__), 'brokers.ini') + + +def load() -> (configparser.ConfigParser, str): + """Load broker config. + + Create a ``broker.ini`` file if one dne. + """ + config = configparser.ConfigParser() + # mode = 'r' if path.exists(_broker_conf_path) else 'a' + read = config.read(_broker_conf_path) + log.debug(f"Read config file {_broker_conf_path}") + return config, _broker_conf_path + + +def write(config: configparser.ConfigParser) -> None: + """Write broker config to disk. + """ + log.debug(f"Writing config file {_broker_conf_path}") + with open(_broker_conf_path, 'w') as cf: + return config.write(cf) diff --git a/piker/brokers/questrade.py b/piker/brokers/questrade.py new file mode 100644 index 00000000..bdb5e710 --- /dev/null +++ b/piker/brokers/questrade.py @@ -0,0 +1,186 @@ +""" +Questrade API backend. +""" +from . import config +from ..log import get_logger +from pprint import pformat +import time +from async_generator import asynccontextmanager + +# TODO: move to urllib3/requests once supported +import asks +asks.init('trio') + +log = get_logger('questrade') + +_refresh_token_ep = 'https://login.questrade.com/oauth2/' +_version = 'v1' + + +class QuestradeError(Exception): + "Non-200 OK response code" + + +def resproc( + resp: asks.response_objects.Response, + return_json: bool = True +) -> asks.response_objects.Response: + """Raise error on non-200 OK response. + """ + data = resp.json() + log.debug(f"Received json contents:\n{pformat(data)}\n") + + if not resp.status_code == 200: + raise QuestradeError(resp.body) + + return data if return_json else resp + + +class API: + """Questrade API at its finest. + """ + def __init__(self, session: asks.Session): + self._sess = session + + async def _request(self, path: str) -> dict: + resp = await self._sess.get(path=f'/{path}') + return resproc(resp) + + async def accounts(self): + return await self._request('accounts') + + async def time(self): + return await self._request('time') + + +class Client: + """API client suitable for use as a long running broker daemon. + """ + def __init__(self, config: dict): + sess = self._sess = asks.Session() + self.api = API(sess) + self.access_data = config + self.user_data = {} + self._conf = None # possibly set in ``from_config`` factory + + @classmethod + async def from_config(cls, config): + client = cls(dict(config['questrade'])) + client._conf = config + await client.enable_access() + return client + + async def _new_auth_token(self) -> dict: + """Request a new api authorization ``refresh_token``. + + Gain api access using either a user provided or existing token. + See the instructions:: + + http://www.questrade.com/api/documentation/getting-started + http://www.questrade.com/api/documentation/security + """ + resp = await self._sess.get( + _refresh_token_ep + 'token', + params={'grant_type': 'refresh_token', + 'refresh_token': self.access_data['refresh_token']} + ) + data = resproc(resp) + self.access_data.update(data) + + return data + + async def _prep_sess(self) -> None: + """Fill http session with auth headers and a base url. + """ + data = self.access_data + # set access token header for the session + self._sess.headers.update({ + 'Authorization': (f"{data['token_type']} {data['access_token']}")}) + # set base API url (asks shorthand) + self._sess.base_location = self.access_data['api_server'] + _version + + async def _revoke_auth_token(self) -> None: + """Revoke api access for the current token. + """ + token = self.access_data['refresh_token'] + log.debug(f"Revoking token {token}") + resp = await asks.post( + _refresh_token_ep + 'revoke', + headers={'token': token} + ) + return resp + + async def enable_access(self, force_refresh: bool = False) -> dict: + """Acquire new ``refresh_token`` and/or ``access_token`` if necessary. + + Only needs to be called if the locally stored ``refresh_token`` has + expired (normally has a lifetime of 3 days). If ``false is set then + refresh the access token instead of using the locally cached version. + """ + access_token = self.access_data.get('access_token') + expires = float(self.access_data.get('expires_at', 0)) + # expired_by = time.time() - float(self.ttl or 0) + # if not access_token or (self.ttl is None) or (expires < time.time()): + if not access_token or (expires < time.time()) or force_refresh: + log.info( + f"Access token {access_token} expired @ {expires}, " + "refreshing...") + data = await self._new_auth_token() + + # store absolute token expiry time + self.access_data['expires_at'] = time.time() + float( + data['expires_in']) + + await self._prep_sess() + return self.access_data + + +def get_config() -> "configparser.ConfigParser": + conf, path = config.load() + if not conf.has_section('questrade') or ( + not conf['questrade'].get('refresh_token') + ): + log.warn( + f"No valid `questrade` refresh token could be found in {path}") + # get from user + refresh_token = input("Please provide your Questrade access token: ") + conf['questrade'] = {'refresh_token': refresh_token} + + return conf + + +@asynccontextmanager +async def get_client(refresh_token: str = None) -> Client: + """Spawn a broker client. + + """ + conf = get_config() + log.debug(f"Loaded questrade config: {conf['questrade']}") + log.info("Waiting on api access...") + client = await Client.from_config(conf) + + try: + try: # do a test ping to ensure the access token works + log.debug("Check time to ensure access token is valid") + await client.api.time() + except Exception as err: + # access token is likely no good + log.warn(f"Access token {client.access_data['access_token']} seems" + f" expired, forcing refresh") + await client.enable_access(force_refresh=True) + await client.api.time() + + yield client + finally: + # save access creds for next run + conf['questrade'] = client.access_data + config.write(conf) + + +async def serve_forever(refresh_token: str = None) -> None: + """Start up a client and serve until terminated. + """ + async with get_client(refresh_token) as client: + # pretty sure this doesn't work + # await client._revoke_auth_token() + return client diff --git a/piker/log.py b/piker/log.py new file mode 100644 index 00000000..5433d69a --- /dev/null +++ b/piker/log.py @@ -0,0 +1,85 @@ +""" +Log like a forester! +(You can't usually find stupid suits in the forest) +""" +import sys +import logging +import colorlog + +_proj_name = 'piker' + +# Super sexy formatting thanks to ``colorlog``. +# (NOTE: we use the '{' format style) +# Here, `thin_white` is just the laymen's gray. +LOG_FORMAT = ( + "{bold_white}{thin_white}{asctime}{reset}" + " {bold_white}{thin_white}({reset}" + "{thin_white}{threadName}{reset}{bold_white}{thin_white})" + " {reset}{log_color}[{reset}{bold_log_color}{levelname}{reset}{log_color}]" + " {log_color}{name}" + " {thin_white}{filename}{log_color}:{reset}{thin_white}{lineno}{log_color}" + " {reset}{bold_white}{thin_white}{message}" +) +DATE_FORMAT = '%b %d %H:%M:%S' +LEVELS = { + 'GARBAGE': 1, + 'TRACE': 5, + 'PROFILE': 15, + 'QUIET': 1000, +} +STD_PALETTE = { + 'CRITICAL': 'red', + 'ERROR': 'red', + 'WARNING': 'yellow', + 'INFO': 'green', + 'DEBUG': 'purple', + 'TRACE': 'cyan', + 'GARBAGE': 'blue', +} +BOLD_PALETTE = { + 'bold': { + level: f"bold_{color}" for level, color in STD_PALETTE.items()} +} + + +def get_logger(name: str = None) -> logging.Logger: + '''Return the package log or a sub-log for `name` if provided. + ''' + log = rlog = logging.getLogger(_proj_name) + if name and name != _proj_name: + log = rlog.getChild(name) + log.level = rlog.level + return log + + +def get_console_log(level: str = None, name: str = None) -> logging.Logger: + '''Get the package logger and enable a handler which writes to stderr. + + Yeah yeah, i know we can use ``DictConfig``. You do it... + ''' + log = get_logger(name) # our root logger + + if level: + log.setLevel(level.upper() if not isinstance(level, int) else level) + + if not any( + handler.stream == sys.stderr for handler in log.handlers + if getattr(handler, 'stream', None) + ): + handler = logging.StreamHandler() + + # additional levels + for name, val in LEVELS.items(): + logging.addLevelName(val, name) + + formatter = colorlog.ColoredFormatter( + LOG_FORMAT, + datefmt=DATE_FORMAT, + log_colors=STD_PALETTE, + secondary_log_colors=BOLD_PALETTE, + style='{', + ) + handler.setFormatter(formatter) + log.addHandler(handler) + + return log diff --git a/setup.py b/setup.py index b022e0c6..ca2daa3e 100755 --- a/setup.py +++ b/setup.py @@ -28,7 +28,7 @@ setup( ], entry_points={ 'console_scripts': [ - 'pikerd = piker.brokers.questrade:main', + 'pikerd = piker.brokers:main', ] }, install_requires=['click', 'colorlog', 'trio', 'attrs'],