From 77548d2ee6114846c0a1a577fce6b60e23eb10fb Mon Sep 17 00:00:00 2001 From: Tyler Goodlet Date: Mon, 25 Feb 2019 20:11:45 -0500 Subject: [PATCH] Add token-from-user toggles to token auth methods --- piker/brokers/__init__.py | 4 +++ piker/brokers/data.py | 8 +++--- piker/brokers/questrade.py | 57 +++++++++++++++++++++++--------------- piker/brokers/robinhood.py | 3 +- 4 files changed, 43 insertions(+), 29 deletions(-) diff --git a/piker/brokers/__init__.py b/piker/brokers/__init__.py index 2e12b3ea..ae14799c 100644 --- a/piker/brokers/__init__.py +++ b/piker/brokers/__init__.py @@ -4,6 +4,10 @@ Broker clients, daemons and general back end machinery. from importlib import import_module from types import ModuleType +# TODO: move to urllib3/requests once supported +import asks +asks.init('trio') + __brokers__ = [ 'questrade', 'robinhood', diff --git a/piker/brokers/data.py b/piker/brokers/data.py index 115b9ff7..0d2a04b8 100644 --- a/piker/brokers/data.py +++ b/piker/brokers/data.py @@ -77,15 +77,15 @@ class BrokerFeed: @tractor.msg.pub(tasks=['stock', 'option']) -async def stream_quotes( +async def stream_requests( get_topics: typing.Callable, get_quotes: Coroutine, feed: BrokerFeed, rate: int = 3, # delay between quote requests diff_cached: bool = True, # only deliver "new" quotes to the queue ) -> None: - """Stream quotes for a sequence of tickers at the given ``rate`` - per second. + """Stream requests for quotes for a set of symbols at the given + ``rate`` (per second). A stock-broker client ``get_quotes()`` async context manager must be provided which returns an async quote retrieval function. @@ -307,7 +307,7 @@ async def start_quote_stream( # push initial smoke quote response for client initialization await ctx.send_yield(payload) - await stream_quotes( + await stream_requests( # pub required kwargs task_name=feed_type, diff --git a/piker/brokers/questrade.py b/piker/brokers/questrade.py index 643962c9..01643a57 100644 --- a/piker/brokers/questrade.py +++ b/piker/brokers/questrade.py @@ -12,6 +12,7 @@ from typing import List, Tuple, Dict, Any, Iterator, NamedTuple import trio from async_generator import asynccontextmanager import wrapt +import asks from ..calc import humanize, percent_change from . import config @@ -19,10 +20,6 @@ from ._util import resproc, BrokerError from ..log import get_logger, colorize_json from .._async_utils import async_lifo_cache -# TODO: move to urllib3/requests once supported -import asks -asks.init('trio') - log = get_logger(__name__) _use_practice_account = False @@ -218,11 +215,13 @@ class Client: self._sess = asks.Session() self.api = _API(self) self._conf = config - self._is_practise_account = _use_practice_account + self._is_practice = _use_practice_account or ( + config['questrade'].get('is_practice', False) + ) self._auth_ep = _refresh_token_ep.format( - 'practice' if _use_practice_account else '') + 'practice' if self._is_practice else '') self.access_data = {} - self._reload_config(config) + self._reload_config(config=config) self._symbol_cache: Dict[str, int] = {} self._optids2contractinfo = {} self._contract2ids = {} @@ -234,7 +233,10 @@ class Client: self._mutex = trio.StrictFIFOLock() def _reload_config(self, config=None, **kwargs): - self._conf = config or get_config(**kwargs) + if config: + self._conf = config + else: + self._conf, _ = get_config(**kwargs) self.access_data = dict(self._conf['questrade']) def write_config(self): @@ -243,7 +245,11 @@ class Client: self._conf['questrade'] = self.access_data config.write(self._conf) - async def ensure_access(self, force_refresh: bool = False) -> dict: + async def ensure_access( + self, + force_refresh: bool = False, + ask_user: bool = True, + ) -> dict: """Acquire a new token set (``access_token`` and ``refresh_token``). Checks if the locally cached (file system) ``access_token`` has expired @@ -295,14 +301,16 @@ class Client: elif msg == 'Bad Request': # likely config ``refresh_token`` is expired but - # may be updated in the config file via another - # piker process + # may be updated in the config file via + # another actor self._reload_config() try: data = await self.api._new_auth_token( self.access_data['refresh_token']) except BrokerError as qterr: - if get_err_msg(qterr) == 'Bad Request': + if get_err_msg(qterr) == 'Bad Request' and ( + ask_user + ): # actually expired; get new from user self._reload_config(force_from_user=True) data = await self.api._new_auth_token( @@ -518,8 +526,8 @@ def _token_from_user(conf: 'configparser.ConfigParser') -> None: def get_config( - force_from_user: bool = False, config_path: str = None, + force_from_user: bool = False, ) -> "configparser.ConfigParser": """Load the broker config from disk. @@ -531,24 +539,27 @@ def get_config( """ log.debug("Reloading access config data") conf, path = config.load(config_path) - if not conf.has_section('questrade'): - log.warn( - f"No valid refresh token could be found in {path}") - elif force_from_user: + if force_from_user: log.warn(f"Forcing manual token auth from user") _token_from_user(conf) - return conf + return conf, path @asynccontextmanager -async def get_client(**kwargs) -> Client: +async def get_client( + config_path: str = None, + ask_user: bool = True +) -> Client: """Spawn a broker client for making requests to the API service. """ - conf = get_config(config_path=kwargs.get('config_path')) + conf, path = get_config(config_path) + if not conf.has_section('questrade'): + raise ValueError( + f"No `questrade` section could be found in {path}") log.debug(f"Loaded config:\n{colorize_json(dict(conf['questrade']))}") - client = Client(conf, **kwargs) - await client.ensure_access() + client = Client(conf) + await client.ensure_access(ask_user=ask_user) try: log.debug("Check time to ensure access token is valid") # XXX: the `time()` end point requires acc_read Oauth access. @@ -560,7 +571,7 @@ async def get_client(**kwargs) -> Client: # access token is likely no good log.warn(f"Access tokens {client.access_data} seem" f" expired, forcing refresh") - await client.ensure_access(force_refresh=True) + await client.ensure_access(force_refresh=True, ask_user=ask_user) await client.api.time() try: yield client diff --git a/piker/brokers/robinhood.py b/piker/brokers/robinhood.py index 4c17fff1..34be0627 100644 --- a/piker/brokers/robinhood.py +++ b/piker/brokers/robinhood.py @@ -9,15 +9,14 @@ from functools import partial from typing import List from async_generator import asynccontextmanager -# TODO: move to urllib3/requests once supported import asks from ..log import get_logger from ._util import resproc, BrokerError from ..calc import percent_change -asks.init('trio') log = get_logger(__name__) + _service_ep = 'https://api.robinhood.com'