Minor fixes to upscaler cli tool

pull/2/head
Guillermo Rodriguez 2022-12-19 12:42:31 -03:00
parent 6bc555f0d6
commit 896b0f684b
No known key found for this signature in database
GPG Key ID: EC3AB66D5D83B392
2 changed files with 8 additions and 7 deletions

View File

@ -56,7 +56,7 @@ def pipeline_for(algo: str, mem_fraction: float = 1.0):
pipe.enable_vae_slicing() pipe.enable_vae_slicing()
return pipe.to("cuda") return pipe.to('cuda')
class DGPUComputeError(BaseException): class DGPUComputeError(BaseException):

View File

@ -64,18 +64,19 @@ def upscale(
torch.backends.cudnn.allow_tf32 = True torch.backends.cudnn.allow_tf32 = True
login(token=hf_token) login(token=hf_token)
params = {
'torch_dtype': torch.float16,
'safety_checker': None
}
pipe = StableDiffusionUpscalePipeline.from_pretrained( pipe = StableDiffusionUpscalePipeline.from_pretrained(
'stabilityai/stable-diffusion-x4-upscaler', **params) 'stabilityai/stable-diffusion-x4-upscaler',
revision="fp16", torch_dtype=torch.float16)
pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(
pipe.scheduler.config)
pipe = pipe.to('cuda')
prompt = prompt prompt = prompt
image = pipe( image = pipe(
prompt, prompt,
image=Image.open(img_path) image=Image.open(img_path).convert("RGB"),
num_inference_steps=steps
).images[0] ).images[0]
image.save(output) image.save(output)