diff --git a/piker/cli.py b/piker/cli.py index daefd1e1..65377937 100644 --- a/piker/cli.py +++ b/piker/cli.py @@ -2,17 +2,26 @@ Console interface to broker client/daemons. """ from functools import partial +from importlib import import_module +import os +from collections import defaultdict +import json import click import trio import pandas as pd + from .log import get_console_log, colorize_json, get_logger +from . import watchlists as wl from .brokers import core, get_brokermod log = get_logger('cli') DEFAULT_BROKER = 'robinhood' +_config_dir = click.get_app_dir('piker') +_watchlists_data_path = os.path.join(_config_dir, 'watchlists.json') + def run(main, loglevel='info'): log = get_console_log(loglevel) @@ -112,7 +121,7 @@ def watch(loglevel, broker, rate, name): log = get_console_log(loglevel) # activate console logging brokermod = get_brokermod(broker) - watchlists = { + watchlists_base = { 'cannabis': [ 'EMH.VN', 'LEAF.TO', 'HVT.VN', 'HMMJ.TO', 'APH.TO', 'CBW.VN', 'TRST.CN', 'VFF.TO', 'ACB.TO', 'ABCN.VN', @@ -127,6 +136,8 @@ def watch(loglevel, broker, rate, name): 'pharma': ['ATE.VN'], 'indexes': ['SPY', 'DAX', 'QQQ', 'DIA'], } + watchlist_from_file = wl.ensure_watchlists(_watchlists_data_path) + watchlists = wl.merge_watchlist(watchlist_from_file, watchlists_base) # broker_conf_path = os.path.join( # click.get_app_dir('piker'), 'watchlists.json') # from piker.testing import _quote_streamer as brokermod @@ -135,3 +146,83 @@ def watch(loglevel, broker, rate, name): rate = broker_limit log.warn(f"Limiting {brokermod.__name__} query rate to {rate}/sec") trio.run(_async_main, name, watchlists[name], brokermod, rate) + # broker_conf_path = os.path.join( + # click.get_app_dir('piker'), 'watchlists.json') + # from piker.testing import _quote_streamer as brokermod + + +@cli.group() +@click.option('--loglevel', '-l', default='warning', help='Logging level') +@click.option('--config_dir', '-d', default=_watchlists_data_path, + help='Path to piker configuration directory') +@click.pass_context +def watchlists(ctx, loglevel, config_dir): + """Watchlists commands and operations + """ + get_console_log(loglevel) # activate console logging + wl.make_config_dir(_config_dir) + ctx.obj = {'path': config_dir, + 'watchlist': wl.ensure_watchlists(config_dir)} + + +@watchlists.command(help='show watchlist') +@click.argument('name', nargs=1, required=False) +@click.pass_context +def show(ctx, name): + watchlist = ctx.obj['watchlist'] + click.echo(colorize_json( + watchlist if name is None else watchlist[name])) + + +@watchlists.command(help='load passed in watchlist') +@click.argument('data', nargs=1, required=True) +@click.pass_context +def load(ctx, data): + try: + wl.write_sorted_json(json.loads(data), ctx.obj['path']) + except (json.JSONDecodeError, IndexError): + click.echo('You have passed an invalid text respresentation of a ' + 'JSON object. Try again.') + + +@watchlists.command(help='add ticker to watchlist') +@click.argument('name', nargs=1, required=True) +@click.argument('ticker_name', nargs=1, required=True) +@click.pass_context +def add(ctx, name, ticker_name): + watchlist = wl.add_ticker(name, ticker_name, + ctx.obj['watchlist']) + wl.write_sorted_json(watchlist, ctx.obj['path']) + + +@watchlists.command(help='remove ticker from watchlist') +@click.argument('name', nargs=1, required=True) +@click.argument('ticker_name', nargs=1, required=True) +@click.pass_context +def remove(ctx, name, ticker_name): + watchlist = wl.remove_ticker(name, ticker_name, ctx.obj['watchlist']) + wl.write_sorted_json(watchlist, ctx.obj['path']) + + +@watchlists.command(help='delete watchlist group') +@click.argument('name', nargs=1, required=True) +@click.pass_context +def delete(ctx, name): + watchlist = wl.delete_group(name, ctx.obj['watchlist']) + wl.write_sorted_json(watchlist, ctx.obj['path']) + + +@watchlists.command(help='merge a watchlist from another user') +@click.argument('watchlist_to_merge', nargs=1, required=True) +@click.pass_context +def merge(ctx, watchlist_to_merge): + merged_watchlist = wl.merge_watchlist(json.loads(watchlist_to_merge), + ctx.obj['watchlist']) + wl.write_sorted_json(merged_watchlist, ctx.obj['path']) + + +@watchlists.command(help='dump text respresentation of a watchlist to console') +@click.argument('name', nargs=1, required=False) +@click.pass_context +def dump(ctx, name): + click.echo(json.dumps(ctx.obj['watchlist'])) diff --git a/piker/watchlists.py b/piker/watchlists.py new file mode 100644 index 00000000..ae777904 --- /dev/null +++ b/piker/watchlists.py @@ -0,0 +1,52 @@ +import os +import json +from collections import defaultdict + +from .log import get_logger + +log = get_logger(__name__) + + +def write_sorted_json(watchlist, path): + for key in watchlist: + watchlist[key] = sorted(list(set(watchlist[key]))) + with open(path, 'w') as f: + json.dump(watchlist, f, sort_keys=True) + + +def make_config_dir(dir_path): + if not os.path.isdir(dir_path): + log.debug(f"Creating config dir {dir_path}") + os.makedirs(dir_path) + + +def ensure_watchlists(file_path): + mode = 'r' if os.path.isfile(file_path) else 'w' + with open(file_path, mode) as f: + return json.load(f) if not os.stat(file_path).st_size == 0 else {} + + +def add_ticker(name, ticker_name, watchlist): + watchlist.setdefault(name, []).append(str(ticker_name).upper()) + return watchlist + + +def remove_ticker(name, ticker_name, watchlist): + if name in watchlist: + watchlist[name].remove(str(ticker_name).upper()) + if watchlist[name] == []: + del watchlist[name] + return watchlist + + +def delete_group(name, watchlist): + watchlist.pop(name, None) + return watchlist + + +def merge_watchlist(watchlist_to_merge, watchlist): + merged_watchlist = defaultdict(list) + for d in (watchlist, watchlist_to_merge): + for key, value in d.items(): + merged_watchlist[key].extend(value) + return merged_watchlist diff --git a/tests/test_cli.py b/tests/test_cli.py index a9c9077e..10e5668f 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -4,12 +4,19 @@ CLI testing, dawg. import json import subprocess import pytest +import tempfile +import os.path +import logging + +import piker.watchlists as wl +import piker.cli as cli +from piker.log import colorize_json -def run(cmd): +def run(cmd, *args): """Run cmd and check for zero return code. """ - cp = subprocess.run(cmd.split()) + cp = subprocess.run(cmd.split() + list(args)) cp.check_returncode() return cp @@ -85,3 +92,116 @@ def test_api_method_not_found(nyse_tickers, capfd): out, err = capfd.readouterr() assert 'null' in out assert f'No api method `{bad_meth}` could be found?' in err + + +@pytest.fixture +def temp_dir(): + """Creates a path to a pretend config dir in a temporary directory for + testing. + """ + with tempfile.TemporaryDirectory() as tempdir: + yield os.path.join(tempdir, 'piker') + + +@pytest.fixture +def piker_dir(temp_dir): + wl.make_config_dir(temp_dir) + json_file_path = os.path.join(temp_dir, 'watchlists.json') + watchlist = { + 'dad': ['GM', 'TSLA', 'DOL.TO', 'CIM', 'SPY', 'SHOP.TO'], + 'pharma': ['ATE.VN'], + 'indexes': ['SPY', 'DAX', 'QQQ', 'DIA'], + } + wl.write_sorted_json(watchlist, json_file_path) + yield json_file_path + + +def test_show_watchlists(capfd, piker_dir): + """Ensure a watchlist is printed. + """ + expected_out = json.dumps({ + 'dad': ['CIM', 'DOL.TO', 'GM', 'SHOP.TO', 'SPY', 'TSLA'], + 'indexes': ['DAX', 'DIA', 'QQQ', 'SPY'], + 'pharma': ['ATE.VN'], + }, indent=4) + run(f"piker watchlists -d {piker_dir} show") + out, err = capfd.readouterr() + assert out.strip() == expected_out + + +def test_dump_watchlists(capfd, piker_dir): + """Ensure watchlist is dumped. + """ + expected_out = json.dumps({ + 'dad': ['CIM', 'DOL.TO', 'GM', 'SHOP.TO', 'SPY', 'TSLA'], + 'indexes': ['DAX', 'DIA', 'QQQ', 'SPY'], + 'pharma': ['ATE.VN'], + }) + run(f"piker watchlists -d {piker_dir} dump") + out, err = capfd.readouterr() + assert out.strip() == expected_out + + +def test_ticker_added_to_watchlists(capfd, piker_dir): + expected_out = { + 'dad': ['CIM', 'DOL.TO', 'GM', 'SHOP.TO', 'SPY', 'TSLA'], + 'indexes': ['DAX', 'DIA', 'QQQ', 'SPY'], + 'pharma': ['ATE.VN', 'CRACK'], + } + run(f"piker watchlists -d {piker_dir} add pharma CRACK") + out = wl.ensure_watchlists(piker_dir) + assert out == expected_out + + +def test_ticker_removed_from_watchlists(capfd, piker_dir): + expected_out = { + 'dad': ['CIM', 'DOL.TO', 'GM', 'SHOP.TO', 'SPY', 'TSLA'], + 'indexes': ['DAX', 'DIA', 'SPY'], + 'pharma': ['ATE.VN'], + } + run(f"piker watchlists -d {piker_dir} remove indexes QQQ") + out = wl.ensure_watchlists(piker_dir) + assert out == expected_out + + +def test_group_deleted_from_watchlists(capfd, piker_dir): + expected_out = { + 'dad': ['CIM', 'DOL.TO', 'GM', 'SHOP.TO', 'SPY', 'TSLA'], + 'indexes': ['DAX', 'DIA', 'QQQ', 'SPY'], + } + run(f"piker watchlists -d {piker_dir} delete pharma") + out = wl.ensure_watchlists(piker_dir) + assert out == expected_out + + +def test_watchlists_loaded(capfd, piker_dir): + expected_out = { + 'dad': ['CIM', 'DOL.TO', 'GM', 'SHOP.TO', 'SPY', 'TSLA'], + 'pharma': ['ATE.VN'], + } + expected_out_text = json.dumps(expected_out) + run(f"piker watchlists -d {piker_dir} load", expected_out_text) + out = wl.ensure_watchlists(piker_dir) + assert out == expected_out + + +def test_watchlists_are_merged(capfd, piker_dir): + orig_watchlist = { + 'dad': ['CIM', 'DOL.TO', 'GM', 'SHOP.TO', 'SPY', 'TSLA'], + 'indexes': ['DAX', 'DIA', 'QQQ', 'SPY'], + 'pharma': ['ATE.VN'], + } + list_to_merge = json.dumps({ + 'drugs': ['CRACK'], + 'pharma': ['ATE.VN', 'MALI', 'PERCOCET'] + }) + expected_out = { + 'dad': ['CIM', 'DOL.TO', 'GM', 'SHOP.TO', 'SPY', 'TSLA'], + 'indexes': ['DAX', 'DIA', 'QQQ', 'SPY'], + 'pharma': ['ATE.VN', 'MALI', 'PERCOCET'], + 'drugs': ['CRACK'] + } + wl.write_sorted_json(orig_watchlist, piker_dir) + run(f"piker watchlists -d {piker_dir} merge", list_to_merge) + out = wl.ensure_watchlists(piker_dir) + assert out == expected_out diff --git a/tests/test_watchlists.py b/tests/test_watchlists.py new file mode 100644 index 00000000..f198860f --- /dev/null +++ b/tests/test_watchlists.py @@ -0,0 +1,101 @@ +""" +Watchlists testing. +""" +import json +import pytest +import tempfile +import os.path +import logging + +import piker.watchlists as wl + + +@pytest.fixture +def temp_dir(): + """Creates a path to a pretend config dir in a temporary directory for + testing. + """ + with tempfile.TemporaryDirectory() as tempdir: + config_dir = os.path.join(tempdir, 'piker') + yield config_dir + + +@pytest.fixture +def piker_dir(temp_dir): + wl.make_config_dir(temp_dir) + yield os.path.join(temp_dir, 'watchlists.json') + + +def test_watchlist_is_sorted_no_dups_and_saved_to_file(piker_dir): + wl_temp = {'test': ['TEST.CN', 'AAA'], 'AA': ['TEST.CN', 'TEST.CN'], + 'AA': ['TEST.CN']} + wl_sort = {'AA': ['TEST.CN'], 'test': ['AAA', 'TEST.CN']} + wl.write_sorted_json(wl_temp, piker_dir) + temp_sorted = wl.ensure_watchlists(piker_dir) + assert temp_sorted == wl_sort + + +def test_watchlists_config_dir_created(caplog, temp_dir): + """Ensure that a config directory is created. + """ + with caplog.at_level(logging.DEBUG): + wl.make_config_dir(temp_dir) + assert len(caplog.records) == 1 + record = caplog.records[0] + assert record.levelname == 'DEBUG' + assert record.message == f"Creating config dir {temp_dir}" + assert os.path.isdir(temp_dir) + # Test that there is no error and that a log message is not generatd + # when trying to create a directory that already exists + with caplog.at_level(logging.DEBUG): + wl.make_config_dir(temp_dir) + # There should be no additional log message. + assert len(caplog.records) == 1 + + +def test_watchlist_is_read_from_file(piker_dir): + """Ensure json info is read from file or an empty dict is generated + and that text respresentation of a watchlist is saved to file. + """ + wl_temp = wl.ensure_watchlists(piker_dir) + assert wl_temp == {} + wl_temp2 = {"AA": ["TEST.CN"]} + wl.write_sorted_json(wl_temp2, piker_dir) + assert wl_temp2 == wl.ensure_watchlists(piker_dir) + + +def test_new_ticker_added(): + """Ensure that a new ticker is added to a watchlist for both cases. + """ + wl_temp = wl.add_ticker('test', 'TEST.CN', {'test': ['TEST2.CN']}) + assert len(wl_temp['test']) == 2 + wl_temp = wl.add_ticker('test2', 'TEST.CN', wl_temp) + assert wl_temp['test2'] + + +def test_ticker_is_removed(): + """Verify that passed in ticker is removed and that a group is removed + if no tickers left. + """ + wl_temp = {'test': ['TEST.CN', 'TEST2.CN'], 'test2': ['TEST.CN']} + wl_temp = wl.remove_ticker('test', 'TEST.CN', wl_temp) + wl_temp = wl.remove_ticker('test2', 'TEST.CN', wl_temp) + assert wl_temp == {'test': ['TEST2.CN']} + assert not wl_temp.get('test2') + + +def test_group_is_deleted(): + """Check that watchlist group is removed. + """ + wl_temp = {'test': ['TEST.CN']} + wl_temp = wl.delete_group('test', wl_temp) + assert not wl_temp.get('test') + + +def test_watchlist_is_merged(): + """Ensure that watchlist is merged. + """ + wl_temp = {'test': ['TEST.CN']} + wl_temp2 = {'test': ['TOAST'], "test2": ["TEST2.CN"]} + wl_temp3 = wl.merge_watchlist(wl_temp2, wl_temp) + assert wl_temp3 == {'test': ['TEST.CN', 'TOAST'], 'test2': ['TEST2.CN']}