diff --git a/piker/cli.py b/piker/cli.py index e3f89b8d..86c42643 100644 --- a/piker/cli.py +++ b/piker/cli.py @@ -6,14 +6,15 @@ from importlib import import_module import os from collections import defaultdict import json -import ast import click import trio import pandas as pd + from .log import get_console_log, colorize_json, get_logger -from .brokers import core, get_brokermod +from . import watchlists as wl +from .brokers import core log = get_logger('cli') DEFAULT_BROKER = 'robinhood' @@ -149,14 +150,6 @@ def watch(loglevel, broker, rate, name): trio.run(_async_main, name, watchlists[name], brokermod) -def json_sorted_writer(watchlist, open_file): - for key in watchlist: - watchlist[key].sort() - s = set(watchlist[key]) - watchlist[key] = list(s) - json.dump(watchlist, open_file, sort_keys=True) - - @cli.group() @click.option('--loglevel', '-l', default='warning', help='Logging level') @click.pass_context @@ -165,21 +158,8 @@ def watchlists(ctx, loglevel): """ # import pdb; pdb.set_trace() get_console_log(loglevel) # activate console logging - - ctx.obj = {} - - if not os.path.isdir(_config_dir): - log.debug(f"Creating config dir {_config_dir}") - os.makedirs(_config_dir) - - if os.path.isfile(_watchlists_data_path): - f = open(_watchlists_data_path, 'r') - if not os.stat(_watchlists_data_path).st_size == 0: - ctx.obj = json.load(f) - f.close() - else: - f = open(_watchlists_data_path, 'w') - f.close() + wl.make_config_dir(_config_dir) + ctx.obj = wl.ensure_watchlists(_watchlists_data_path) @watchlists.command(help='show watchlist') @@ -191,15 +171,22 @@ def show(ctx, name): watchlist if name is None else watchlist[name])) +@watchlists.command(help='load passed in watchlist') +@click.argument('data', nargs=1, required=True) +@click.pass_context +def load(ctx, data): + try: + wl.load_watchlists(data, _watchlists_data_path) + except (json.JSONDecodeError, IndexError): + click.echo('You must pass in a text respresentation of a json object. Try again.') + + @watchlists.command(help='add a new watchlist') @click.argument('name', nargs=1, required=True) @click.pass_context def new(ctx, name): watchlist = ctx.obj - f = open(_watchlists_data_path, 'w') - watchlist.setdefault(name, []) - json_sorted_writer(watchlist, f) - f.close() + wl.new_group(name, watchlist, _watchlists_data_path) @watchlists.command(help='add ticker to watchlist') @@ -208,11 +195,7 @@ def new(ctx, name): @click.pass_context def add(ctx, name, ticker_name): watchlist = ctx.obj - f = open(_watchlists_data_path, 'w') - if name in watchlist: - watchlist[name].append(str(ticker_name).upper()) - json_sorted_writer(watchlist, f) - f.close() + wl.add_ticker(name, ticker_name, watchlist, _watchlists_data_path) @watchlists.command(help='remove ticker from watchlist') @@ -221,11 +204,7 @@ def add(ctx, name, ticker_name): @click.pass_context def remove(ctx, name, ticker_name): watchlist = ctx.obj - f = open(_watchlists_data_path, 'w') - if name in watchlist: - watchlist[name].remove(str(ticker_name).upper()) - json_sorted_writer(watchlist, f) - f.close() + wl.remove_ticker(name, ticker_name, watchlist, _watchlists_data_path) @watchlists.command(help='delete watchlist') @@ -233,11 +212,7 @@ def remove(ctx, name, ticker_name): @click.pass_context def delete(ctx, name): watchlist = ctx.obj - f = open(_watchlists_data_path, 'w') - if name in watchlist: - del watchlist[name] - json_sorted_writer(watchlist, f) - f.close() + wl.delete_group(name, watchlist, _watchlists_data_path) @watchlists.command(help='merge a watchlist from another user') @@ -245,14 +220,7 @@ def delete(ctx, name): @click.pass_context def merge(ctx, watchlist_to_merge): watchlist = ctx.obj - f = open(_watchlists_data_path, 'w') - merged_watchlist = defaultdict(list) - watchlist_to_merge = ast.literal_eval(watchlist_to_merge) - for d in (watchlist, watchlist_to_merge): - for key, value in d.items(): - merged_watchlist[key].extend(value) - json_sorted_writer(merged_watchlist, f) - f.close() + wl.merge_watchlist(watchlist_to_merge, watchlist, _watchlists_data_path) @watchlists.command(help='dump a text respresentation of a watchlist to console') @@ -260,12 +228,4 @@ def merge(ctx, watchlist_to_merge): @click.pass_context def dump(ctx, name): watchlist = ctx.obj - f = open(_watchlists_data_path, 'r') print(json.dumps(watchlist)) - f.close() - - -@watchlists.command(help='purge watchlists and remove json file') -def purge(): - # import pdb; pdb.set_trace() - os.remove(_watchlists_data_path) diff --git a/piker/watchlists.py b/piker/watchlists.py new file mode 100644 index 00000000..6084a094 --- /dev/null +++ b/piker/watchlists.py @@ -0,0 +1,63 @@ +import os +import json +import ast +from collections import defaultdict + + +def write_sorted_json(watchlist, path): + for key in watchlist: + watchlist[key].sort() + s = set(watchlist[key]) + watchlist[key] = list(s) + with open(path, 'w') as f: + json.dump(watchlist, f, sort_keys=True) + + +def make_config_dir(dir_path): + if not os.path.isdir(dir_path): + log.debug(f"Creating config dir {dir_path}") + os.makedirs(dir_path) + + +def ensure_watchlists(file_path): + mode = 'r' if os.path.isfile(file_path) else 'w' + with open(file_path, mode) as f: + data = json.load(f) if not os.stat(file_path).st_size == 0 else {} + return data + + +def load_watchlists(watchlist, path): + watchlist = json.loads(watchlist) + write_sorted_json(watchlist, path) + + +def new_group(name, watchlist, path): + watchlist.setdefault(name, []) + write_sorted_json(watchlist, path) + + +def add_ticker(name, ticker_name, watchlist, path): + if name in watchlist: + watchlist[name].append(str(ticker_name).upper()) + write_sorted_json(watchlist, path) + + +def remove_ticker(name, ticker_name, watchlist, path): + if name in watchlist: + watchlist[name].remove(str(ticker_name).upper()) + write_sorted_json(watchlist, path) + + +def delete_group(name, watchlist, path): + if name in watchlist: + del watchlist[name] + write_sorted_json(watchlist, path) + + +def merge_watchlist(watchlist_to_merge, watchlist, path): + merged_watchlist = defaultdict(list) + watchlist_to_merge = ast.literal_eval(watchlist_to_merge) + for d in (watchlist, watchlist_to_merge): + for key, value in d.items(): + merged_watchlist[key].extend(value) + write_sorted_json(merged_watchlist, path)