Start setting HF env vars from config

pull/26/head
Guillermo Rodriguez 2023-10-05 14:15:21 -03:00
parent d7ccbe7023
commit 47d9f59dbe
No known key found for this signature in database
GPG Key ID: EC3AB66D5D83B392
2 changed files with 15 additions and 0 deletions

View File

@ -34,6 +34,8 @@ def txt2img(*args, **kwargs):
config = load_skynet_ini() config = load_skynet_ini()
hf_token = load_key(config, 'skynet.dgpu', 'hf_token') hf_token = load_key(config, 'skynet.dgpu', 'hf_token')
hf_home = load_key(config, 'skynet.dgpu', 'hf_home')
set_hf_vars(hf_token, hf_home)
utils.txt2img(hf_token, **kwargs) utils.txt2img(hf_token, **kwargs)
@click.command() @click.command()
@ -50,6 +52,8 @@ def img2img(model, prompt, input, output, strength, guidance, steps, seed):
from . import utils from . import utils
config = load_skynet_ini() config = load_skynet_ini()
hf_token = load_key(config, 'skynet.dgpu', 'hf_token') hf_token = load_key(config, 'skynet.dgpu', 'hf_token')
hf_home = load_key(config, 'skynet.dgpu', 'hf_home')
set_hf_vars(hf_token, hf_home)
utils.img2img( utils.img2img(
hf_token, hf_token,
model=model, model=model,
@ -79,6 +83,8 @@ def download():
from . import utils from . import utils
config = load_skynet_ini() config = load_skynet_ini()
hf_token = load_key(config, 'skynet.dgpu', 'hf_token') hf_token = load_key(config, 'skynet.dgpu', 'hf_token')
hf_home = load_key(config, 'skynet.dgpu', 'hf_home')
set_hf_vars(hf_token, hf_home)
utils.download_all_models(hf_token) utils.download_all_models(hf_token)
@skynet.command() @skynet.command()
@ -329,6 +335,9 @@ def dgpu(
logging.basicConfig(level=loglevel) logging.basicConfig(level=loglevel)
config = load_skynet_ini(file_path=config_path) config = load_skynet_ini(file_path=config_path)
hf_token = load_key(config, 'skynet.dgpu', 'hf_token')
hf_home = load_key(config, 'skynet.dgpu', 'hf_home')
set_hf_vars(hf_token, hf_home)
assert 'skynet.dgpu' in config assert 'skynet.dgpu' in config

View File

@ -1,5 +1,7 @@
#!/usr/bin/python #!/usr/bin/python
import os
from configparser import ConfigParser from configparser import ConfigParser
from .constants import DEFAULT_CONFIG_PATH from .constants import DEFAULT_CONFIG_PATH
@ -28,3 +30,7 @@ def load_key(config: ConfigParser, section: str, key: str) -> str:
raise ConfigParsingError(f'key \"{key}\" not in {conf_keys}') raise ConfigParsingError(f'key \"{key}\" not in {conf_keys}')
return str(config[section][key]) return str(config[section][key])
def set_hf_vars(hf_token: str, hf_home: str):
os.environ['HF_TOKEN'] = hf_token
os.environ['HF_HOME'] = hf_home