diff --git a/piker/brokers/config.py b/piker/brokers/config.py index 22cbfe9c..f1c8a90d 100644 --- a/piker/brokers/config.py +++ b/piker/brokers/config.py @@ -1,7 +1,7 @@ """ Broker configuration mgmt. """ -from os import path, makedirs +import os import configparser import click from ..log import get_logger @@ -9,28 +9,46 @@ from ..log import get_logger log = get_logger('broker-config') _config_dir = click.get_app_dir('piker') -_broker_conf_path = path.join(_config_dir, 'brokers.ini') +_file_name = 'brokers.ini' -def load(path: str = None) -> (configparser.ConfigParser, str): +def _override_config_dir( + path: str +) -> None: + global _config_dir + _config_dir = path + + +def get_broker_conf_path(): + return os.path.join(_config_dir, _file_name) + + +def load( + path: str = None +) -> (configparser.ConfigParser, str): """Load broker config. """ - path = path or _broker_conf_path + path = path or get_broker_conf_path() config = configparser.ConfigParser() - read = config.read(path) + config.read(path) log.debug(f"Read config file {path}") return config, path -def write(config: configparser.ConfigParser) -> None: +def write( + config: configparser.ConfigParser, + path: str = None, +) -> None: """Write broker config to disk. Create a ``brokers.ini`` file if one does not exist. """ - if not path.isdir(_config_dir): + path = path or get_broker_conf_path() + dirname = os.path.dirname(path) + if not os.path.isdir(dirname): log.debug(f"Creating config dir {_config_dir}") - makedirs(_config_dir) + os.makedirs(dirname) - log.debug(f"Writing config file {_broker_conf_path}") - with open(_broker_conf_path, 'w') as cf: + log.debug(f"Writing config file {path}") + with open(path, 'w') as cf: return config.write(cf)