From c8a0a390a67eb3725ddca4dd1521d75e43383841 Mon Sep 17 00:00:00 2001 From: Guillermo Rodriguez Date: Thu, 8 Jun 2023 21:25:07 -0300 Subject: [PATCH] Fix for img2img mode on new worker system --- skynet/cli.py | 2 +- skynet/config.py | 2 ++ skynet/dgpu/compute.py | 2 +- skynet/utils.py | 6 ++++-- 4 files changed, 8 insertions(+), 4 deletions(-) diff --git a/skynet/cli.py b/skynet/cli.py index a9f240a..63b3d92 100644 --- a/skynet/cli.py +++ b/skynet/cli.py @@ -50,7 +50,7 @@ def txt2img(*args, **kwargs): utils.txt2img(hf_token, **kwargs) @click.command() -@click.option('--model', '-m', default='midj') +@click.option('--model', '-m', default=list(MODELS.keys())[0]) @click.option( '--prompt', '-p', default='a red old tractor in a sunny wheat field') @click.option('--input', '-i', default='input.png') diff --git a/skynet/config.py b/skynet/config.py index fc8f2d9..d068295 100644 --- a/skynet/config.py +++ b/skynet/config.py @@ -34,6 +34,7 @@ def init_env_from_config( sub_config = config['skynet.dgpu'] if 'hf_token' in sub_config: hf_token = sub_config['hf_token'] + os.environ['HF_TOKEN'] = hf_token if 'HF_HOME' in os.environ: hf_home = os.environ['HF_HOME'] @@ -42,6 +43,7 @@ def init_env_from_config( sub_config = config['skynet.dgpu'] if 'hf_home' in sub_config: hf_home = sub_config['hf_home'] + os.environ['HF_HOME'] = hf_home if 'TG_TOKEN' in os.environ: tg_token = os.environ['TG_TOKEN'] diff --git a/skynet/dgpu/compute.py b/skynet/dgpu/compute.py index a51072d..069af47 100644 --- a/skynet/dgpu/compute.py +++ b/skynet/dgpu/compute.py @@ -133,7 +133,7 @@ class SkynetMM: arguments = prepare_params_for_diffuse(params, binary) prompt, guidance, step, seed, upscaler, extra_params = arguments - model = self.get_model(params['model'], 'image' in params) + model = self.get_model(params['model'], 'image' in extra_params) image = model( prompt, diff --git a/skynet/utils.py b/skynet/utils.py index e4bd04b..2837118 100644 --- a/skynet/utils.py +++ b/skynet/utils.py @@ -55,6 +55,8 @@ def convert_from_bytes_and_crop(raw: bytes, max_w: int, max_h: int) -> Image: if w > max_w or h > max_h: image.thumbnail((512, 512)) + return image.convert('RGB') + def pipeline_for(model: str, mem_fraction: float = 1.0, image=False) -> DiffusionPipeline: assert torch.cuda.is_available() @@ -147,7 +149,8 @@ def img2img( login(token=hf_token) pipe = pipeline_for(model, image=True) - input_img = Image.open(img_path).convert('RGB') + with open(img_path, 'rb') as img_file: + input_img = convert_from_bytes_and_crop(img_file.read(), 512, 512) seed = seed if seed else random.randint(0, 2 ** 64) prompt = prompt @@ -162,7 +165,6 @@ def img2img( image.save(output) - def init_upscaler(model_path: str = 'weights/RealESRGAN_x4plus.pth'): return RealESRGANer( scale=4,