diff --git a/piker/brokers/__init__.py b/piker/brokers/__init__.py index 54afc783..8ae3f1dc 100644 --- a/piker/brokers/__init__.py +++ b/piker/brokers/__init__.py @@ -1,3 +1,23 @@ """ Broker clients, daemons and general back end machinery. """ +from importlib import import_module +from types import ModuleType + +__brokers__ = [ + 'questrade', + 'robinhood', +] + + +def get_brokermod(brokername: str) -> ModuleType: + """Return the imported broker module by name. + """ + return import_module('.' + brokername, 'piker.brokers') + + +def iter_brokermods(): + """Iterate all built-in broker modules. + """ + for name in __brokers__: + yield get_brokermod(name) diff --git a/piker/brokers/core.py b/piker/brokers/core.py index 6d9c7dbd..0290e0cf 100644 --- a/piker/brokers/core.py +++ b/piker/brokers/core.py @@ -8,7 +8,6 @@ from typing import AsyncContextManager import trio -from .questrade import QuestradeError from ..log import get_logger log = get_logger('broker.core') @@ -100,8 +99,8 @@ async def poll_tickers( delay = sleeptime - tot if delay <= 0: log.warn( - f"Took {req_time} (request) + {proc_time} (processing) = {tot}" - f" secs (> {sleeptime}) for processing quotes?") + f"Took {req_time} (request) + {proc_time} (processing) " + f"= {tot} secs (> {sleeptime}) for processing quotes?") else: log.debug(f"Sleeping for {delay}") await trio.sleep(delay) diff --git a/piker/brokers/questrade.py b/piker/brokers/questrade.py index 312b2fa9..7127ae75 100644 --- a/piker/brokers/questrade.py +++ b/piker/brokers/questrade.py @@ -17,7 +17,7 @@ from ..log import get_logger, colorize_json import asks asks.init('trio') -log = get_logger('questrade') +log = get_logger(__name__) _refresh_token_ep = 'https://login.questrade.com/oauth2/' _version = 'v1' @@ -165,8 +165,8 @@ class Client: return quotes - async def symbols(self, tickers): - """Return quotes for each ticker in ``tickers``. + async def symbol_data(self, tickers: [str]): + """Return symbol data for ``tickers``. """ t2ids = await self.tickers2ids(tickers) ids = ','.join(map(str, t2ids.values())) diff --git a/piker/brokers/robinhood.py b/piker/brokers/robinhood.py index 405f1fa0..42ba2982 100644 --- a/piker/brokers/robinhood.py +++ b/piker/brokers/robinhood.py @@ -4,14 +4,15 @@ Robinhood API backend. from functools import partial from async_generator import asynccontextmanager +# TODO: move to urllib3/requests once supported import asks from ..log import get_logger from ._util import resproc from ..calc import percent_change -log = get_logger('robinhood') - +asks.init('trio') +log = get_logger(__name__) _service_ep = 'https://api.robinhood.com' @@ -43,15 +44,25 @@ class Client: self._sess.base_location = _service_ep self.api = _API(self._sess) - async def quote(self, symbols: [str]): - results = (await self.api.quotes(','.join(symbols)))['results'] - return {quote['symbol'] if quote else sym: quote - for sym, quote in zip(symbols, results)} + def _zip_in_order(self, symbols: [str], results_dict: dict): + return {quote.get('symbol', sym) if quote else sym: quote + for sym, quote in zip(symbols, results_dict)} - async def symbols(self, tickers: [str]): - """Placeholder for the watchlist calling code... + async def quote(self, symbols: [str]): + """Retrieve quotes for a list of ``symbols``. """ - return {} + return self._zip_in_order( + symbols, + (await self.api.quotes(','.join(symbols)))['results'] + ) + + async def symbol_data(self, symbols: [str]): + """Retrieve symbol data via the ``fundmentals`` endpoint. + """ + return self._zip_in_order( + symbols, + (await self.api.fundamentals(','.join(symbols)))['results'] + ) @asynccontextmanager diff --git a/piker/cli.py b/piker/cli.py index 43540213..daefd1e1 100644 --- a/piker/cli.py +++ b/piker/cli.py @@ -2,14 +2,13 @@ Console interface to broker client/daemons. """ from functools import partial -from importlib import import_module import click import trio import pandas as pd from .log import get_console_log, colorize_json, get_logger -from .brokers import core +from .brokers import core, get_brokermod log = get_logger('cli') DEFAULT_BROKER = 'robinhood' @@ -44,7 +43,7 @@ def api(meth, kwargs, loglevel, broker, keys): """client for testing broker API methods with pretty printing of output. """ log = get_console_log(loglevel) - brokermod = import_module('.' + broker, 'piker.brokers') + brokermod = get_brokermod(broker) _kwargs = {} for kwarg in kwargs: @@ -77,11 +76,11 @@ def api(meth, kwargs, loglevel, broker, keys): @click.option('--loglevel', '-l', default='warning', help='Logging level') @click.option('--df-output', '-df', flag_value=True, help='Ouput in `pandas.DataFrame` format') -@click.argument('tickers', nargs=-1) +@click.argument('tickers', nargs=-1, required=True) def quote(loglevel, broker, tickers, df_output): """client for testing broker API methods with pretty printing of output. """ - brokermod = import_module('.' + broker, 'piker.brokers') + brokermod = get_brokermod(broker) quotes = run(partial(core.quote, brokermod, tickers), loglevel=loglevel) if not quotes: log.error(f"No quotes could be found for {tickers}?") @@ -111,7 +110,7 @@ def watch(loglevel, broker, rate, name): """ from .ui.watchlist import _async_main log = get_console_log(loglevel) # activate console logging - brokermod = import_module('.' + broker, 'piker.brokers') + brokermod = get_brokermod(broker) watchlists = { 'cannabis': [ @@ -119,7 +118,7 @@ def watch(loglevel, broker, rate, name): 'CBW.VN', 'TRST.CN', 'VFF.TO', 'ACB.TO', 'ABCN.VN', 'APH.TO', 'MARI.CN', 'WMD.VN', 'LEAF.TO', 'THCX.VN', 'WEED.TO', 'NINE.VN', 'RTI.VN', 'SNN.CN', 'ACB.TO', - 'OGI.VN', 'IMH.VN', 'FIRE.VN', 'EAT.CN', 'NUU.VN', + 'OGI.VN', 'IMH.VN', 'FIRE.VN', 'EAT.CN', 'WMD.VN', 'HEMP.VN', 'CALI.CN', 'RQB.CN', 'MPX.CN', 'SEED.TO', 'HMJR.TO', 'CMED.TO', 'PAS.VN', 'CRON', diff --git a/piker/ui/watchlist.py b/piker/ui/watchlist.py index 4aa3b268..b6bc410c 100644 --- a/piker/ui/watchlist.py +++ b/piker/ui/watchlist.py @@ -393,11 +393,11 @@ async def _async_main(name, tickers, brokermod, rate): async with brokermod.get_client() as client: async with trio.open_nursery() as nursery: # get long term data including last days close price - sd = await client.symbols(tickers) + sd = await client.symbol_data(tickers) nursery.start_soon( partial(poll_tickers, client, brokermod.quoter, tickers, queue, - rate=rate) + rate=rate) ) # get first quotes response diff --git a/requirements.txt b/requirements.txt index 26f8be53..64f8ca32 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1,2 @@ # matham's next-gen async port of kivy -git+git://github.com/matham/kivy.git@async-loop +git+git://github.com/matham/kivy.git@async-loop#egg=kivy diff --git a/tests/test_cli.py b/tests/test_cli.py new file mode 100644 index 00000000..a9c9077e --- /dev/null +++ b/tests/test_cli.py @@ -0,0 +1,87 @@ +""" +CLI testing, dawg. +""" +import json +import subprocess +import pytest + + +def run(cmd): + """Run cmd and check for zero return code. + """ + cp = subprocess.run(cmd.split()) + cp.check_returncode() + return cp + + +def verify_keys(tickers, quotes_dict): + """Verify all ticker names are keys in ``quotes_dict``. + """ + for key, quote in quotes_dict.items(): + assert key in tickers + + +@pytest.fixture +def nyse_tickers(): + """List of well known NYSE ticker symbols. + """ + return ('TD', 'CRON', 'TSLA', 'AAPL') + + +def test_known_quotes(capfd, nyse_tickers): + """Verify quotes are dumped to the console as json. + """ + run(f"piker quote {' '.join(nyse_tickers)}") + + # verify output can be parsed as json + out, err = capfd.readouterr() + quotes_dict = json.loads(out) + verify_keys(nyse_tickers, quotes_dict) + + +@pytest.mark.parametrize( + 'multiple_tickers', + [True, False] +) +def test_quotes_ticker_not_found( + capfd, caplog, nyse_tickers, multiple_tickers +): + """Verify that if a ticker can't be found it's quote value is + ``None`` and a warning log message is emitted to the console. + """ + bad_ticker = ('doggy',) + tickers = bad_ticker + nyse_tickers if multiple_tickers else bad_ticker + + run(f"piker quote {' '.join(tickers)}") + + out, err = capfd.readouterr() + if out: + # verify output can be parsed as json + quotes_dict = json.loads(out) + verify_keys(tickers, quotes_dict) + # check for warning log message when some quotes are found + warnmsg = f'Could not find symbol {bad_ticker[0]}' + assert warnmsg in err + else: + # when no quotes are found we should get an error message + errmsg = f'No quotes could be found for {bad_ticker}' + assert errmsg in err + + +def test_api_method(nyse_tickers, capfd): + """Ensure a low level api method can be called via CLI. + """ + run(f"piker api quotes symbols={','.join(nyse_tickers)}") + out, err = capfd.readouterr() + quotes_dict = json.loads(out) + assert isinstance(quotes_dict, dict) + + +def test_api_method_not_found(nyse_tickers, capfd): + """Ensure an error messages is printed when an API method isn't found. + """ + bad_meth = 'doggy' + run(f"piker api {bad_meth} names={' '.join(nyse_tickers)}") + out, err = capfd.readouterr() + assert 'null' in out + assert f'No api method `{bad_meth}` could be found?' in err