diff --git a/skynet/utils.py b/skynet/utils.py index 2837118..177f62b 100644 --- a/skynet/utils.py +++ b/skynet/utils.py @@ -77,6 +77,10 @@ def pipeline_for(model: str, mem_fraction: float = 1.0, image=False) -> Diffusio 'safety_checker': None } + if model == 'snowkidy/stable-diffusion-xl-base-0.9': + # TODO: figure out what this does + params['addition_embed_type'] = None + if model == 'runwayml/stable-diffusion-v1-5': params['revision'] = 'fp16'