Add token-from-user toggles to token auth methods
parent
130553b8df
commit
77548d2ee6
|
@ -4,6 +4,10 @@ Broker clients, daemons and general back end machinery.
|
||||||
from importlib import import_module
|
from importlib import import_module
|
||||||
from types import ModuleType
|
from types import ModuleType
|
||||||
|
|
||||||
|
# TODO: move to urllib3/requests once supported
|
||||||
|
import asks
|
||||||
|
asks.init('trio')
|
||||||
|
|
||||||
__brokers__ = [
|
__brokers__ = [
|
||||||
'questrade',
|
'questrade',
|
||||||
'robinhood',
|
'robinhood',
|
||||||
|
|
|
@ -77,15 +77,15 @@ class BrokerFeed:
|
||||||
|
|
||||||
|
|
||||||
@tractor.msg.pub(tasks=['stock', 'option'])
|
@tractor.msg.pub(tasks=['stock', 'option'])
|
||||||
async def stream_quotes(
|
async def stream_requests(
|
||||||
get_topics: typing.Callable,
|
get_topics: typing.Callable,
|
||||||
get_quotes: Coroutine,
|
get_quotes: Coroutine,
|
||||||
feed: BrokerFeed,
|
feed: BrokerFeed,
|
||||||
rate: int = 3, # delay between quote requests
|
rate: int = 3, # delay between quote requests
|
||||||
diff_cached: bool = True, # only deliver "new" quotes to the queue
|
diff_cached: bool = True, # only deliver "new" quotes to the queue
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Stream quotes for a sequence of tickers at the given ``rate``
|
"""Stream requests for quotes for a set of symbols at the given
|
||||||
per second.
|
``rate`` (per second).
|
||||||
|
|
||||||
A stock-broker client ``get_quotes()`` async context manager must be
|
A stock-broker client ``get_quotes()`` async context manager must be
|
||||||
provided which returns an async quote retrieval function.
|
provided which returns an async quote retrieval function.
|
||||||
|
@ -307,7 +307,7 @@ async def start_quote_stream(
|
||||||
# push initial smoke quote response for client initialization
|
# push initial smoke quote response for client initialization
|
||||||
await ctx.send_yield(payload)
|
await ctx.send_yield(payload)
|
||||||
|
|
||||||
await stream_quotes(
|
await stream_requests(
|
||||||
|
|
||||||
# pub required kwargs
|
# pub required kwargs
|
||||||
task_name=feed_type,
|
task_name=feed_type,
|
||||||
|
|
|
@ -12,6 +12,7 @@ from typing import List, Tuple, Dict, Any, Iterator, NamedTuple
|
||||||
import trio
|
import trio
|
||||||
from async_generator import asynccontextmanager
|
from async_generator import asynccontextmanager
|
||||||
import wrapt
|
import wrapt
|
||||||
|
import asks
|
||||||
|
|
||||||
from ..calc import humanize, percent_change
|
from ..calc import humanize, percent_change
|
||||||
from . import config
|
from . import config
|
||||||
|
@ -19,10 +20,6 @@ from ._util import resproc, BrokerError
|
||||||
from ..log import get_logger, colorize_json
|
from ..log import get_logger, colorize_json
|
||||||
from .._async_utils import async_lifo_cache
|
from .._async_utils import async_lifo_cache
|
||||||
|
|
||||||
# TODO: move to urllib3/requests once supported
|
|
||||||
import asks
|
|
||||||
asks.init('trio')
|
|
||||||
|
|
||||||
log = get_logger(__name__)
|
log = get_logger(__name__)
|
||||||
|
|
||||||
_use_practice_account = False
|
_use_practice_account = False
|
||||||
|
@ -218,11 +215,13 @@ class Client:
|
||||||
self._sess = asks.Session()
|
self._sess = asks.Session()
|
||||||
self.api = _API(self)
|
self.api = _API(self)
|
||||||
self._conf = config
|
self._conf = config
|
||||||
self._is_practise_account = _use_practice_account
|
self._is_practice = _use_practice_account or (
|
||||||
|
config['questrade'].get('is_practice', False)
|
||||||
|
)
|
||||||
self._auth_ep = _refresh_token_ep.format(
|
self._auth_ep = _refresh_token_ep.format(
|
||||||
'practice' if _use_practice_account else '')
|
'practice' if self._is_practice else '')
|
||||||
self.access_data = {}
|
self.access_data = {}
|
||||||
self._reload_config(config)
|
self._reload_config(config=config)
|
||||||
self._symbol_cache: Dict[str, int] = {}
|
self._symbol_cache: Dict[str, int] = {}
|
||||||
self._optids2contractinfo = {}
|
self._optids2contractinfo = {}
|
||||||
self._contract2ids = {}
|
self._contract2ids = {}
|
||||||
|
@ -234,7 +233,10 @@ class Client:
|
||||||
self._mutex = trio.StrictFIFOLock()
|
self._mutex = trio.StrictFIFOLock()
|
||||||
|
|
||||||
def _reload_config(self, config=None, **kwargs):
|
def _reload_config(self, config=None, **kwargs):
|
||||||
self._conf = config or get_config(**kwargs)
|
if config:
|
||||||
|
self._conf = config
|
||||||
|
else:
|
||||||
|
self._conf, _ = get_config(**kwargs)
|
||||||
self.access_data = dict(self._conf['questrade'])
|
self.access_data = dict(self._conf['questrade'])
|
||||||
|
|
||||||
def write_config(self):
|
def write_config(self):
|
||||||
|
@ -243,7 +245,11 @@ class Client:
|
||||||
self._conf['questrade'] = self.access_data
|
self._conf['questrade'] = self.access_data
|
||||||
config.write(self._conf)
|
config.write(self._conf)
|
||||||
|
|
||||||
async def ensure_access(self, force_refresh: bool = False) -> dict:
|
async def ensure_access(
|
||||||
|
self,
|
||||||
|
force_refresh: bool = False,
|
||||||
|
ask_user: bool = True,
|
||||||
|
) -> dict:
|
||||||
"""Acquire a new token set (``access_token`` and ``refresh_token``).
|
"""Acquire a new token set (``access_token`` and ``refresh_token``).
|
||||||
|
|
||||||
Checks if the locally cached (file system) ``access_token`` has expired
|
Checks if the locally cached (file system) ``access_token`` has expired
|
||||||
|
@ -295,14 +301,16 @@ class Client:
|
||||||
|
|
||||||
elif msg == 'Bad Request':
|
elif msg == 'Bad Request':
|
||||||
# likely config ``refresh_token`` is expired but
|
# likely config ``refresh_token`` is expired but
|
||||||
# may be updated in the config file via another
|
# may be updated in the config file via
|
||||||
# piker process
|
# another actor
|
||||||
self._reload_config()
|
self._reload_config()
|
||||||
try:
|
try:
|
||||||
data = await self.api._new_auth_token(
|
data = await self.api._new_auth_token(
|
||||||
self.access_data['refresh_token'])
|
self.access_data['refresh_token'])
|
||||||
except BrokerError as qterr:
|
except BrokerError as qterr:
|
||||||
if get_err_msg(qterr) == 'Bad Request':
|
if get_err_msg(qterr) == 'Bad Request' and (
|
||||||
|
ask_user
|
||||||
|
):
|
||||||
# actually expired; get new from user
|
# actually expired; get new from user
|
||||||
self._reload_config(force_from_user=True)
|
self._reload_config(force_from_user=True)
|
||||||
data = await self.api._new_auth_token(
|
data = await self.api._new_auth_token(
|
||||||
|
@ -518,8 +526,8 @@ def _token_from_user(conf: 'configparser.ConfigParser') -> None:
|
||||||
|
|
||||||
|
|
||||||
def get_config(
|
def get_config(
|
||||||
force_from_user: bool = False,
|
|
||||||
config_path: str = None,
|
config_path: str = None,
|
||||||
|
force_from_user: bool = False,
|
||||||
) -> "configparser.ConfigParser":
|
) -> "configparser.ConfigParser":
|
||||||
"""Load the broker config from disk.
|
"""Load the broker config from disk.
|
||||||
|
|
||||||
|
@ -531,24 +539,27 @@ def get_config(
|
||||||
"""
|
"""
|
||||||
log.debug("Reloading access config data")
|
log.debug("Reloading access config data")
|
||||||
conf, path = config.load(config_path)
|
conf, path = config.load(config_path)
|
||||||
if not conf.has_section('questrade'):
|
if force_from_user:
|
||||||
log.warn(
|
|
||||||
f"No valid refresh token could be found in {path}")
|
|
||||||
elif force_from_user:
|
|
||||||
log.warn(f"Forcing manual token auth from user")
|
log.warn(f"Forcing manual token auth from user")
|
||||||
_token_from_user(conf)
|
_token_from_user(conf)
|
||||||
|
|
||||||
return conf
|
return conf, path
|
||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def get_client(**kwargs) -> Client:
|
async def get_client(
|
||||||
|
config_path: str = None,
|
||||||
|
ask_user: bool = True
|
||||||
|
) -> Client:
|
||||||
"""Spawn a broker client for making requests to the API service.
|
"""Spawn a broker client for making requests to the API service.
|
||||||
"""
|
"""
|
||||||
conf = get_config(config_path=kwargs.get('config_path'))
|
conf, path = get_config(config_path)
|
||||||
|
if not conf.has_section('questrade'):
|
||||||
|
raise ValueError(
|
||||||
|
f"No `questrade` section could be found in {path}")
|
||||||
log.debug(f"Loaded config:\n{colorize_json(dict(conf['questrade']))}")
|
log.debug(f"Loaded config:\n{colorize_json(dict(conf['questrade']))}")
|
||||||
client = Client(conf, **kwargs)
|
client = Client(conf)
|
||||||
await client.ensure_access()
|
await client.ensure_access(ask_user=ask_user)
|
||||||
try:
|
try:
|
||||||
log.debug("Check time to ensure access token is valid")
|
log.debug("Check time to ensure access token is valid")
|
||||||
# XXX: the `time()` end point requires acc_read Oauth access.
|
# XXX: the `time()` end point requires acc_read Oauth access.
|
||||||
|
@ -560,7 +571,7 @@ async def get_client(**kwargs) -> Client:
|
||||||
# access token is likely no good
|
# access token is likely no good
|
||||||
log.warn(f"Access tokens {client.access_data} seem"
|
log.warn(f"Access tokens {client.access_data} seem"
|
||||||
f" expired, forcing refresh")
|
f" expired, forcing refresh")
|
||||||
await client.ensure_access(force_refresh=True)
|
await client.ensure_access(force_refresh=True, ask_user=ask_user)
|
||||||
await client.api.time()
|
await client.api.time()
|
||||||
try:
|
try:
|
||||||
yield client
|
yield client
|
||||||
|
|
|
@ -9,15 +9,14 @@ from functools import partial
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
from async_generator import asynccontextmanager
|
from async_generator import asynccontextmanager
|
||||||
# TODO: move to urllib3/requests once supported
|
|
||||||
import asks
|
import asks
|
||||||
|
|
||||||
from ..log import get_logger
|
from ..log import get_logger
|
||||||
from ._util import resproc, BrokerError
|
from ._util import resproc, BrokerError
|
||||||
from ..calc import percent_change
|
from ..calc import percent_change
|
||||||
|
|
||||||
asks.init('trio')
|
|
||||||
log = get_logger(__name__)
|
log = get_logger(__name__)
|
||||||
|
|
||||||
_service_ep = 'https://api.robinhood.com'
|
_service_ep = 'https://api.robinhood.com'
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue