From 47d9f59dbe971b36eac87a8b26e3e56f73c4cdf4 Mon Sep 17 00:00:00 2001 From: Guillermo Rodriguez Date: Thu, 5 Oct 2023 14:15:21 -0300 Subject: [PATCH] Start setting HF env vars from config --- skynet/cli.py | 9 +++++++++ skynet/config.py | 6 ++++++ 2 files changed, 15 insertions(+) diff --git a/skynet/cli.py b/skynet/cli.py index d8bf166..d391af2 100755 --- a/skynet/cli.py +++ b/skynet/cli.py @@ -34,6 +34,8 @@ def txt2img(*args, **kwargs): config = load_skynet_ini() hf_token = load_key(config, 'skynet.dgpu', 'hf_token') + hf_home = load_key(config, 'skynet.dgpu', 'hf_home') + set_hf_vars(hf_token, hf_home) utils.txt2img(hf_token, **kwargs) @click.command() @@ -50,6 +52,8 @@ def img2img(model, prompt, input, output, strength, guidance, steps, seed): from . import utils config = load_skynet_ini() hf_token = load_key(config, 'skynet.dgpu', 'hf_token') + hf_home = load_key(config, 'skynet.dgpu', 'hf_home') + set_hf_vars(hf_token, hf_home) utils.img2img( hf_token, model=model, @@ -79,6 +83,8 @@ def download(): from . import utils config = load_skynet_ini() hf_token = load_key(config, 'skynet.dgpu', 'hf_token') + hf_home = load_key(config, 'skynet.dgpu', 'hf_home') + set_hf_vars(hf_token, hf_home) utils.download_all_models(hf_token) @skynet.command() @@ -329,6 +335,9 @@ def dgpu( logging.basicConfig(level=loglevel) config = load_skynet_ini(file_path=config_path) + hf_token = load_key(config, 'skynet.dgpu', 'hf_token') + hf_home = load_key(config, 'skynet.dgpu', 'hf_home') + set_hf_vars(hf_token, hf_home) assert 'skynet.dgpu' in config diff --git a/skynet/config.py b/skynet/config.py index 26669ae..83ec62f 100755 --- a/skynet/config.py +++ b/skynet/config.py @@ -1,5 +1,7 @@ #!/usr/bin/python +import os + from configparser import ConfigParser from .constants import DEFAULT_CONFIG_PATH @@ -28,3 +30,7 @@ def load_key(config: ConfigParser, section: str, key: str) -> str: raise ConfigParsingError(f'key \"{key}\" not in {conf_keys}') return str(config[section][key]) + +def set_hf_vars(hf_token: str, hf_home: str): + os.environ['HF_TOKEN'] = hf_token + os.environ['HF_HOME'] = hf_home