Add token-from-user toggles to token auth methods

kivy_mainline_and_py3.8
Tyler Goodlet 2019-02-25 20:11:45 -05:00
parent 130553b8df
commit 77548d2ee6
4 changed files with 43 additions and 29 deletions

View File

@ -4,6 +4,10 @@ Broker clients, daemons and general back end machinery.
from importlib import import_module from importlib import import_module
from types import ModuleType from types import ModuleType
# TODO: move to urllib3/requests once supported
import asks
asks.init('trio')
__brokers__ = [ __brokers__ = [
'questrade', 'questrade',
'robinhood', 'robinhood',

View File

@ -77,15 +77,15 @@ class BrokerFeed:
@tractor.msg.pub(tasks=['stock', 'option']) @tractor.msg.pub(tasks=['stock', 'option'])
async def stream_quotes( async def stream_requests(
get_topics: typing.Callable, get_topics: typing.Callable,
get_quotes: Coroutine, get_quotes: Coroutine,
feed: BrokerFeed, feed: BrokerFeed,
rate: int = 3, # delay between quote requests rate: int = 3, # delay between quote requests
diff_cached: bool = True, # only deliver "new" quotes to the queue diff_cached: bool = True, # only deliver "new" quotes to the queue
) -> None: ) -> None:
"""Stream quotes for a sequence of tickers at the given ``rate`` """Stream requests for quotes for a set of symbols at the given
per second. ``rate`` (per second).
A stock-broker client ``get_quotes()`` async context manager must be A stock-broker client ``get_quotes()`` async context manager must be
provided which returns an async quote retrieval function. provided which returns an async quote retrieval function.
@ -307,7 +307,7 @@ async def start_quote_stream(
# push initial smoke quote response for client initialization # push initial smoke quote response for client initialization
await ctx.send_yield(payload) await ctx.send_yield(payload)
await stream_quotes( await stream_requests(
# pub required kwargs # pub required kwargs
task_name=feed_type, task_name=feed_type,

View File

@ -12,6 +12,7 @@ from typing import List, Tuple, Dict, Any, Iterator, NamedTuple
import trio import trio
from async_generator import asynccontextmanager from async_generator import asynccontextmanager
import wrapt import wrapt
import asks
from ..calc import humanize, percent_change from ..calc import humanize, percent_change
from . import config from . import config
@ -19,10 +20,6 @@ from ._util import resproc, BrokerError
from ..log import get_logger, colorize_json from ..log import get_logger, colorize_json
from .._async_utils import async_lifo_cache from .._async_utils import async_lifo_cache
# TODO: move to urllib3/requests once supported
import asks
asks.init('trio')
log = get_logger(__name__) log = get_logger(__name__)
_use_practice_account = False _use_practice_account = False
@ -218,11 +215,13 @@ class Client:
self._sess = asks.Session() self._sess = asks.Session()
self.api = _API(self) self.api = _API(self)
self._conf = config self._conf = config
self._is_practise_account = _use_practice_account self._is_practice = _use_practice_account or (
config['questrade'].get('is_practice', False)
)
self._auth_ep = _refresh_token_ep.format( self._auth_ep = _refresh_token_ep.format(
'practice' if _use_practice_account else '') 'practice' if self._is_practice else '')
self.access_data = {} self.access_data = {}
self._reload_config(config) self._reload_config(config=config)
self._symbol_cache: Dict[str, int] = {} self._symbol_cache: Dict[str, int] = {}
self._optids2contractinfo = {} self._optids2contractinfo = {}
self._contract2ids = {} self._contract2ids = {}
@ -234,7 +233,10 @@ class Client:
self._mutex = trio.StrictFIFOLock() self._mutex = trio.StrictFIFOLock()
def _reload_config(self, config=None, **kwargs): def _reload_config(self, config=None, **kwargs):
self._conf = config or get_config(**kwargs) if config:
self._conf = config
else:
self._conf, _ = get_config(**kwargs)
self.access_data = dict(self._conf['questrade']) self.access_data = dict(self._conf['questrade'])
def write_config(self): def write_config(self):
@ -243,7 +245,11 @@ class Client:
self._conf['questrade'] = self.access_data self._conf['questrade'] = self.access_data
config.write(self._conf) config.write(self._conf)
async def ensure_access(self, force_refresh: bool = False) -> dict: async def ensure_access(
self,
force_refresh: bool = False,
ask_user: bool = True,
) -> dict:
"""Acquire a new token set (``access_token`` and ``refresh_token``). """Acquire a new token set (``access_token`` and ``refresh_token``).
Checks if the locally cached (file system) ``access_token`` has expired Checks if the locally cached (file system) ``access_token`` has expired
@ -295,14 +301,16 @@ class Client:
elif msg == 'Bad Request': elif msg == 'Bad Request':
# likely config ``refresh_token`` is expired but # likely config ``refresh_token`` is expired but
# may be updated in the config file via another # may be updated in the config file via
# piker process # another actor
self._reload_config() self._reload_config()
try: try:
data = await self.api._new_auth_token( data = await self.api._new_auth_token(
self.access_data['refresh_token']) self.access_data['refresh_token'])
except BrokerError as qterr: except BrokerError as qterr:
if get_err_msg(qterr) == 'Bad Request': if get_err_msg(qterr) == 'Bad Request' and (
ask_user
):
# actually expired; get new from user # actually expired; get new from user
self._reload_config(force_from_user=True) self._reload_config(force_from_user=True)
data = await self.api._new_auth_token( data = await self.api._new_auth_token(
@ -518,8 +526,8 @@ def _token_from_user(conf: 'configparser.ConfigParser') -> None:
def get_config( def get_config(
force_from_user: bool = False,
config_path: str = None, config_path: str = None,
force_from_user: bool = False,
) -> "configparser.ConfigParser": ) -> "configparser.ConfigParser":
"""Load the broker config from disk. """Load the broker config from disk.
@ -531,24 +539,27 @@ def get_config(
""" """
log.debug("Reloading access config data") log.debug("Reloading access config data")
conf, path = config.load(config_path) conf, path = config.load(config_path)
if not conf.has_section('questrade'): if force_from_user:
log.warn(
f"No valid refresh token could be found in {path}")
elif force_from_user:
log.warn(f"Forcing manual token auth from user") log.warn(f"Forcing manual token auth from user")
_token_from_user(conf) _token_from_user(conf)
return conf return conf, path
@asynccontextmanager @asynccontextmanager
async def get_client(**kwargs) -> Client: async def get_client(
config_path: str = None,
ask_user: bool = True
) -> Client:
"""Spawn a broker client for making requests to the API service. """Spawn a broker client for making requests to the API service.
""" """
conf = get_config(config_path=kwargs.get('config_path')) conf, path = get_config(config_path)
if not conf.has_section('questrade'):
raise ValueError(
f"No `questrade` section could be found in {path}")
log.debug(f"Loaded config:\n{colorize_json(dict(conf['questrade']))}") log.debug(f"Loaded config:\n{colorize_json(dict(conf['questrade']))}")
client = Client(conf, **kwargs) client = Client(conf)
await client.ensure_access() await client.ensure_access(ask_user=ask_user)
try: try:
log.debug("Check time to ensure access token is valid") log.debug("Check time to ensure access token is valid")
# XXX: the `time()` end point requires acc_read Oauth access. # XXX: the `time()` end point requires acc_read Oauth access.
@ -560,7 +571,7 @@ async def get_client(**kwargs) -> Client:
# access token is likely no good # access token is likely no good
log.warn(f"Access tokens {client.access_data} seem" log.warn(f"Access tokens {client.access_data} seem"
f" expired, forcing refresh") f" expired, forcing refresh")
await client.ensure_access(force_refresh=True) await client.ensure_access(force_refresh=True, ask_user=ask_user)
await client.api.time() await client.api.time()
try: try:
yield client yield client

View File

@ -9,15 +9,14 @@ from functools import partial
from typing import List from typing import List
from async_generator import asynccontextmanager from async_generator import asynccontextmanager
# TODO: move to urllib3/requests once supported
import asks import asks
from ..log import get_logger from ..log import get_logger
from ._util import resproc, BrokerError from ._util import resproc, BrokerError
from ..calc import percent_change from ..calc import percent_change
asks.init('trio')
log = get_logger(__name__) log = get_logger(__name__)
_service_ep = 'https://api.robinhood.com' _service_ep = 'https://api.robinhood.com'