Fix stablexl pipeline

pull/5/head
Guillermo Rodriguez 2023-07-27 11:30:11 -03:00
parent 4082adf184
commit 440bb015cd
No known key found for this signature in database
GPG Key ID: EC3AB66D5D83B392
2 changed files with 12 additions and 6 deletions

View File

@ -8,7 +8,7 @@ MODELS = {
'prompthero/openjourney': { 'short': 'midj'},
'runwayml/stable-diffusion-v1-5': { 'short': 'stable'},
'stabilityai/stable-diffusion-2-1-base': { 'short': 'stable2'},
'snowkidy/stable-diffusion-xl-base-1.0': { 'short': 'stablexl'},
'stabilityai/stable-diffusion-xl-base-1.0': { 'short': 'stablexl'},
'Linaqruf/anything-v3.0': { 'short': 'hdanime'},
'hakurei/waifu-diffusion': { 'short': 'waifu'},
'nitrosocke/Ghibli-Diffusion': { 'short': 'ghibli'},

View File

@ -15,6 +15,8 @@ from PIL import Image
from basicsr.archs.rrdbnet_arch import RRDBNet
from diffusers import (
DiffusionPipeline,
StableDiffusionXLPipeline,
StableDiffusionXLImg2ImgPipeline,
StableDiffusionPipeline,
StableDiffusionImg2ImgPipeline,
EulerAncestralDiscreteScheduler
@ -80,12 +82,16 @@ def pipeline_for(model: str, mem_fraction: float = 1.0, image=False) -> Diffusio
if model == 'runwayml/stable-diffusion-v1-5':
params['revision'] = 'fp16'
if image:
pipe_class = StableDiffusionImg2ImgPipeline
elif model == 'snowkidy/stable-diffusion-xl-base-0.9':
pipe_class = DiffusionPipeline
if model == 'stabilityai/stable-diffusion-xl-base-1.0':
if image:
pipe_class = StableDiffusionXLImg2ImgPipeline
else:
pipe_class = StableDiffusionXLPipeline
else:
pipe_class = StableDiffusionPipeline
if image:
pipe_class = StableDiffusionImg2ImgPipeline
else:
pipe_class = StableDiffusionPipeline
pipe = pipe_class.from_pretrained(
model, **params)