mirror of https://github.com/skygpu/skynet.git
Fix for img2img mode on new worker system
parent
91edb2aa56
commit
c8a0a390a6
|
@ -50,7 +50,7 @@ def txt2img(*args, **kwargs):
|
||||||
utils.txt2img(hf_token, **kwargs)
|
utils.txt2img(hf_token, **kwargs)
|
||||||
|
|
||||||
@click.command()
|
@click.command()
|
||||||
@click.option('--model', '-m', default='midj')
|
@click.option('--model', '-m', default=list(MODELS.keys())[0])
|
||||||
@click.option(
|
@click.option(
|
||||||
'--prompt', '-p', default='a red old tractor in a sunny wheat field')
|
'--prompt', '-p', default='a red old tractor in a sunny wheat field')
|
||||||
@click.option('--input', '-i', default='input.png')
|
@click.option('--input', '-i', default='input.png')
|
||||||
|
|
|
@ -34,6 +34,7 @@ def init_env_from_config(
|
||||||
sub_config = config['skynet.dgpu']
|
sub_config = config['skynet.dgpu']
|
||||||
if 'hf_token' in sub_config:
|
if 'hf_token' in sub_config:
|
||||||
hf_token = sub_config['hf_token']
|
hf_token = sub_config['hf_token']
|
||||||
|
os.environ['HF_TOKEN'] = hf_token
|
||||||
|
|
||||||
if 'HF_HOME' in os.environ:
|
if 'HF_HOME' in os.environ:
|
||||||
hf_home = os.environ['HF_HOME']
|
hf_home = os.environ['HF_HOME']
|
||||||
|
@ -42,6 +43,7 @@ def init_env_from_config(
|
||||||
sub_config = config['skynet.dgpu']
|
sub_config = config['skynet.dgpu']
|
||||||
if 'hf_home' in sub_config:
|
if 'hf_home' in sub_config:
|
||||||
hf_home = sub_config['hf_home']
|
hf_home = sub_config['hf_home']
|
||||||
|
os.environ['HF_HOME'] = hf_home
|
||||||
|
|
||||||
if 'TG_TOKEN' in os.environ:
|
if 'TG_TOKEN' in os.environ:
|
||||||
tg_token = os.environ['TG_TOKEN']
|
tg_token = os.environ['TG_TOKEN']
|
||||||
|
|
|
@ -133,7 +133,7 @@ class SkynetMM:
|
||||||
|
|
||||||
arguments = prepare_params_for_diffuse(params, binary)
|
arguments = prepare_params_for_diffuse(params, binary)
|
||||||
prompt, guidance, step, seed, upscaler, extra_params = arguments
|
prompt, guidance, step, seed, upscaler, extra_params = arguments
|
||||||
model = self.get_model(params['model'], 'image' in params)
|
model = self.get_model(params['model'], 'image' in extra_params)
|
||||||
|
|
||||||
image = model(
|
image = model(
|
||||||
prompt,
|
prompt,
|
||||||
|
|
|
@ -55,6 +55,8 @@ def convert_from_bytes_and_crop(raw: bytes, max_w: int, max_h: int) -> Image:
|
||||||
if w > max_w or h > max_h:
|
if w > max_w or h > max_h:
|
||||||
image.thumbnail((512, 512))
|
image.thumbnail((512, 512))
|
||||||
|
|
||||||
|
return image.convert('RGB')
|
||||||
|
|
||||||
|
|
||||||
def pipeline_for(model: str, mem_fraction: float = 1.0, image=False) -> DiffusionPipeline:
|
def pipeline_for(model: str, mem_fraction: float = 1.0, image=False) -> DiffusionPipeline:
|
||||||
assert torch.cuda.is_available()
|
assert torch.cuda.is_available()
|
||||||
|
@ -147,7 +149,8 @@ def img2img(
|
||||||
login(token=hf_token)
|
login(token=hf_token)
|
||||||
pipe = pipeline_for(model, image=True)
|
pipe = pipeline_for(model, image=True)
|
||||||
|
|
||||||
input_img = Image.open(img_path).convert('RGB')
|
with open(img_path, 'rb') as img_file:
|
||||||
|
input_img = convert_from_bytes_and_crop(img_file.read(), 512, 512)
|
||||||
|
|
||||||
seed = seed if seed else random.randint(0, 2 ** 64)
|
seed = seed if seed else random.randint(0, 2 ** 64)
|
||||||
prompt = prompt
|
prompt = prompt
|
||||||
|
@ -162,7 +165,6 @@ def img2img(
|
||||||
image.save(output)
|
image.save(output)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def init_upscaler(model_path: str = 'weights/RealESRGAN_x4plus.pth'):
|
def init_upscaler(model_path: str = 'weights/RealESRGAN_x4plus.pth'):
|
||||||
return RealESRGANer(
|
return RealESRGANer(
|
||||||
scale=4,
|
scale=4,
|
||||||
|
|
Loading…
Reference in New Issue