mirror of https://github.com/skygpu/skynet.git
Start setting HF env vars from config
parent
d7ccbe7023
commit
47d9f59dbe
|
@ -34,6 +34,8 @@ def txt2img(*args, **kwargs):
|
||||||
|
|
||||||
config = load_skynet_ini()
|
config = load_skynet_ini()
|
||||||
hf_token = load_key(config, 'skynet.dgpu', 'hf_token')
|
hf_token = load_key(config, 'skynet.dgpu', 'hf_token')
|
||||||
|
hf_home = load_key(config, 'skynet.dgpu', 'hf_home')
|
||||||
|
set_hf_vars(hf_token, hf_home)
|
||||||
utils.txt2img(hf_token, **kwargs)
|
utils.txt2img(hf_token, **kwargs)
|
||||||
|
|
||||||
@click.command()
|
@click.command()
|
||||||
|
@ -50,6 +52,8 @@ def img2img(model, prompt, input, output, strength, guidance, steps, seed):
|
||||||
from . import utils
|
from . import utils
|
||||||
config = load_skynet_ini()
|
config = load_skynet_ini()
|
||||||
hf_token = load_key(config, 'skynet.dgpu', 'hf_token')
|
hf_token = load_key(config, 'skynet.dgpu', 'hf_token')
|
||||||
|
hf_home = load_key(config, 'skynet.dgpu', 'hf_home')
|
||||||
|
set_hf_vars(hf_token, hf_home)
|
||||||
utils.img2img(
|
utils.img2img(
|
||||||
hf_token,
|
hf_token,
|
||||||
model=model,
|
model=model,
|
||||||
|
@ -79,6 +83,8 @@ def download():
|
||||||
from . import utils
|
from . import utils
|
||||||
config = load_skynet_ini()
|
config = load_skynet_ini()
|
||||||
hf_token = load_key(config, 'skynet.dgpu', 'hf_token')
|
hf_token = load_key(config, 'skynet.dgpu', 'hf_token')
|
||||||
|
hf_home = load_key(config, 'skynet.dgpu', 'hf_home')
|
||||||
|
set_hf_vars(hf_token, hf_home)
|
||||||
utils.download_all_models(hf_token)
|
utils.download_all_models(hf_token)
|
||||||
|
|
||||||
@skynet.command()
|
@skynet.command()
|
||||||
|
@ -329,6 +335,9 @@ def dgpu(
|
||||||
logging.basicConfig(level=loglevel)
|
logging.basicConfig(level=loglevel)
|
||||||
|
|
||||||
config = load_skynet_ini(file_path=config_path)
|
config = load_skynet_ini(file_path=config_path)
|
||||||
|
hf_token = load_key(config, 'skynet.dgpu', 'hf_token')
|
||||||
|
hf_home = load_key(config, 'skynet.dgpu', 'hf_home')
|
||||||
|
set_hf_vars(hf_token, hf_home)
|
||||||
|
|
||||||
assert 'skynet.dgpu' in config
|
assert 'skynet.dgpu' in config
|
||||||
|
|
||||||
|
|
|
@ -1,5 +1,7 @@
|
||||||
#!/usr/bin/python
|
#!/usr/bin/python
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
from configparser import ConfigParser
|
from configparser import ConfigParser
|
||||||
|
|
||||||
from .constants import DEFAULT_CONFIG_PATH
|
from .constants import DEFAULT_CONFIG_PATH
|
||||||
|
@ -28,3 +30,7 @@ def load_key(config: ConfigParser, section: str, key: str) -> str:
|
||||||
raise ConfigParsingError(f'key \"{key}\" not in {conf_keys}')
|
raise ConfigParsingError(f'key \"{key}\" not in {conf_keys}')
|
||||||
|
|
||||||
return str(config[section][key])
|
return str(config[section][key])
|
||||||
|
|
||||||
|
def set_hf_vars(hf_token: str, hf_home: str):
|
||||||
|
os.environ['HF_TOKEN'] = hf_token
|
||||||
|
os.environ['HF_HOME'] = hf_home
|
||||||
|
|
Loading…
Reference in New Issue