mirror of https://github.com/skygpu/skynet.git
Add inpainting cli
parent
7274fb017d
commit
18ca8c573a
|
@ -66,6 +66,37 @@ def img2img(model, prompt, input, output, strength, guidance, steps, seed):
|
|||
seed=seed
|
||||
)
|
||||
|
||||
|
||||
@click.command()
|
||||
@click.option('--model', '-m', default=list(MODELS.keys())[-1])
|
||||
@click.option(
|
||||
'--prompt', '-p', default='a red old tractor in a sunny wheat field')
|
||||
@click.option('--input', '-i', default='input.png')
|
||||
@click.option('--mask', '-M', default='mask.png')
|
||||
@click.option('--output', '-o', default='output.png')
|
||||
@click.option('--strength', '-Z', default=1.0)
|
||||
@click.option('--guidance', '-g', default=10.0)
|
||||
@click.option('--steps', '-s', default=26)
|
||||
@click.option('--seed', '-S', default=None)
|
||||
def inpaint(model, prompt, input, mask, output, strength, guidance, steps, seed):
|
||||
from . import utils
|
||||
config = load_skynet_toml()
|
||||
hf_token = load_key(config, 'skynet.dgpu.hf_token')
|
||||
hf_home = load_key(config, 'skynet.dgpu.hf_home')
|
||||
set_hf_vars(hf_token, hf_home)
|
||||
utils.inpaint(
|
||||
hf_token,
|
||||
model=model,
|
||||
prompt=prompt,
|
||||
img_path=input,
|
||||
mask_path=mask,
|
||||
output=output,
|
||||
strength=strength,
|
||||
guidance=guidance,
|
||||
steps=steps,
|
||||
seed=seed
|
||||
)
|
||||
|
||||
@click.command()
|
||||
@click.option('--input', '-i', default='input.png')
|
||||
@click.option('--output', '-o', default='output.png')
|
||||
|
|
|
@ -5,20 +5,23 @@ VERSION = '0.1a12'
|
|||
DOCKER_RUNTIME_CUDA = 'skynet:runtime-cuda'
|
||||
|
||||
MODELS = {
|
||||
'prompthero/openjourney': {'short': 'midj', 'mem': 6},
|
||||
'runwayml/stable-diffusion-v1-5': {'short': 'stable', 'mem': 6},
|
||||
'stabilityai/stable-diffusion-2-1-base': {'short': 'stable2', 'mem': 6},
|
||||
'snowkidy/stable-diffusion-xl-base-0.9': {'short': 'stablexl0.9', 'mem': 8.3},
|
||||
'Linaqruf/anything-v3.0': {'short': 'hdanime', 'mem': 6},
|
||||
'hakurei/waifu-diffusion': {'short': 'waifu', 'mem': 6},
|
||||
'nitrosocke/Ghibli-Diffusion': {'short': 'ghibli', 'mem': 6},
|
||||
'dallinmackay/Van-Gogh-diffusion': {'short': 'van-gogh', 'mem': 6},
|
||||
'lambdalabs/sd-pokemon-diffusers': {'short': 'pokemon', 'mem': 6},
|
||||
'Envvi/Inkpunk-Diffusion': {'short': 'ink', 'mem': 6},
|
||||
'nousr/robo-diffusion': {'short': 'robot', 'mem': 6},
|
||||
'prompthero/openjourney': {'short': 'midj', 'mem': 6, 'size': {'w': 512, 'h': 512}},
|
||||
'runwayml/stable-diffusion-v1-5': {'short': 'stable', 'mem': 6, 'size': {'w': 512, 'h': 512}},
|
||||
'stabilityai/stable-diffusion-2-1-base': {'short': 'stable2', 'mem': 6, 'size': {'w': 512, 'h': 512}},
|
||||
'snowkidy/stable-diffusion-xl-base-0.9': {'short': 'stablexl0.9', 'mem': 8.3, 'size': {'w': 1024, 'h': 1024}},
|
||||
'Linaqruf/anything-v3.0': {'short': 'hdanime', 'mem': 6, 'size': {'w': 512, 'h': 512}},
|
||||
'hakurei/waifu-diffusion': {'short': 'waifu', 'mem': 6, 'size': {'w': 512, 'h': 512}},
|
||||
'nitrosocke/Ghibli-Diffusion': {'short': 'ghibli', 'mem': 6, 'size': {'w': 512, 'h': 512}},
|
||||
'dallinmackay/Van-Gogh-diffusion': {'short': 'van-gogh', 'mem': 6, 'size': {'w': 512, 'h': 512}},
|
||||
'lambdalabs/sd-pokemon-diffusers': {'short': 'pokemon', 'mem': 6, 'size': {'w': 512, 'h': 512}},
|
||||
'Envvi/Inkpunk-Diffusion': {'short': 'ink', 'mem': 6, 'size': {'w': 512, 'h': 512}},
|
||||
'nousr/robo-diffusion': {'short': 'robot', 'mem': 6, 'size': {'w': 512, 'h': 512}},
|
||||
|
||||
# -1 is always inpaint default
|
||||
'diffusers/stable-diffusion-xl-1.0-inpainting-0.1': {'short': 'stablexl-inpainting': 'mem': 8.3, 'size': {'w': 1024, 'h': 1024}},
|
||||
|
||||
# default is always last
|
||||
'stabilityai/stable-diffusion-xl-base-1.0': {'short': 'stablexl', 'mem': 8.3},
|
||||
'stabilityai/stable-diffusion-xl-base-1.0': {'short': 'stablexl', 'mem': 8.3, 'size': {'w': 1024, 'h': 1024}},
|
||||
}
|
||||
|
||||
SHORT_NAMES = [
|
||||
|
|
|
@ -168,7 +168,11 @@ class SkynetMM:
|
|||
arguments = prepare_params_for_diffuse(
|
||||
params, input_type, binary=binary)
|
||||
prompt, guidance, step, seed, upscaler, extra_params = arguments
|
||||
model = self.get_model(params['model'], 'image' in extra_params)
|
||||
model = self.get_model(
|
||||
params['model'],
|
||||
'image' in extra_params,
|
||||
'mask_image' in extra_params
|
||||
)
|
||||
|
||||
output = model(
|
||||
prompt,
|
||||
|
|
|
@ -63,6 +63,7 @@ def pipeline_for(
|
|||
model: str,
|
||||
mem_fraction: float = 1.0,
|
||||
image: bool = False,
|
||||
inpainting: bool = False,
|
||||
cache_dir: str | None = None
|
||||
) -> DiffusionPipeline:
|
||||
|
||||
|
@ -102,6 +103,11 @@ def pipeline_for(
|
|||
|
||||
torch.cuda.set_per_process_memory_fraction(mem_fraction)
|
||||
|
||||
if inpainting:
|
||||
pipe = AutoPipelineForInpainting.from_pretrained(
|
||||
model, **params)
|
||||
|
||||
else:
|
||||
pipe = DiffusionPipeline.from_pretrained(
|
||||
model, **params)
|
||||
|
||||
|
@ -168,8 +174,10 @@ def img2img(
|
|||
login(token=hf_token)
|
||||
pipe = pipeline_for(model, image=True)
|
||||
|
||||
model_info = MODELS[model]
|
||||
|
||||
with open(img_path, 'rb') as img_file:
|
||||
input_img = convert_from_bytes_and_crop(img_file.read(), 512, 512)
|
||||
input_img = convert_from_bytes_and_crop(img_file.read(), model_info['size']['w'], model_info['size']['h'])
|
||||
|
||||
seed = seed if seed else random.randint(0, 2 ** 64)
|
||||
prompt = prompt
|
||||
|
@ -184,6 +192,43 @@ def img2img(
|
|||
image.save(output)
|
||||
|
||||
|
||||
def inpaint(
|
||||
hf_token: str,
|
||||
model: str = 'diffusers/stable-diffusion-xl-1.0-inpainting-0.1',
|
||||
prompt: str = 'a red old tractor in a sunny wheat field',
|
||||
img_path: str = 'input.png',
|
||||
mask_path: str = 'mask.png',
|
||||
output: str = 'output.png',
|
||||
strength: float = 1.0,
|
||||
guidance: float = 10,
|
||||
steps: int = 28,
|
||||
seed: Optional[int] = None
|
||||
):
|
||||
login(token=hf_token)
|
||||
pipe = pipeline_for(model, image=True)
|
||||
|
||||
model_info = MODELS[model]
|
||||
|
||||
with open(img_path, 'rb') as img_file:
|
||||
input_img = convert_from_bytes_and_crop(img_file.read(), model_info['size']['w'], model_info['size']['h'])
|
||||
|
||||
with open(mask_path, 'rb') as mask_file:
|
||||
mask_img = convert_from_bytes_and_crop(mask_file.read(), model_info['size']['w'], model_info['size']['h'])
|
||||
|
||||
seed = seed if seed else random.randint(0, 2 ** 64)
|
||||
prompt = prompt
|
||||
image = pipe(
|
||||
prompt,
|
||||
image=input_img,
|
||||
mask_image=mask_img
|
||||
strength=strength,
|
||||
guidance_scale=guidance, num_inference_steps=steps,
|
||||
generator=torch.Generator("cuda").manual_seed(seed)
|
||||
).images[0]
|
||||
|
||||
image.save(output)
|
||||
|
||||
|
||||
def init_upscaler(model_path: str = 'weights/RealESRGAN_x4plus.pth'):
|
||||
return RealESRGANer(
|
||||
scale=4,
|
||||
|
|
Loading…
Reference in New Issue