mirror of https://github.com/skygpu/skynet.git
Add inpainting cli
parent
7274fb017d
commit
18ca8c573a
|
@ -66,6 +66,37 @@ def img2img(model, prompt, input, output, strength, guidance, steps, seed):
|
||||||
seed=seed
|
seed=seed
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@click.command()
|
||||||
|
@click.option('--model', '-m', default=list(MODELS.keys())[-1])
|
||||||
|
@click.option(
|
||||||
|
'--prompt', '-p', default='a red old tractor in a sunny wheat field')
|
||||||
|
@click.option('--input', '-i', default='input.png')
|
||||||
|
@click.option('--mask', '-M', default='mask.png')
|
||||||
|
@click.option('--output', '-o', default='output.png')
|
||||||
|
@click.option('--strength', '-Z', default=1.0)
|
||||||
|
@click.option('--guidance', '-g', default=10.0)
|
||||||
|
@click.option('--steps', '-s', default=26)
|
||||||
|
@click.option('--seed', '-S', default=None)
|
||||||
|
def inpaint(model, prompt, input, mask, output, strength, guidance, steps, seed):
|
||||||
|
from . import utils
|
||||||
|
config = load_skynet_toml()
|
||||||
|
hf_token = load_key(config, 'skynet.dgpu.hf_token')
|
||||||
|
hf_home = load_key(config, 'skynet.dgpu.hf_home')
|
||||||
|
set_hf_vars(hf_token, hf_home)
|
||||||
|
utils.inpaint(
|
||||||
|
hf_token,
|
||||||
|
model=model,
|
||||||
|
prompt=prompt,
|
||||||
|
img_path=input,
|
||||||
|
mask_path=mask,
|
||||||
|
output=output,
|
||||||
|
strength=strength,
|
||||||
|
guidance=guidance,
|
||||||
|
steps=steps,
|
||||||
|
seed=seed
|
||||||
|
)
|
||||||
|
|
||||||
@click.command()
|
@click.command()
|
||||||
@click.option('--input', '-i', default='input.png')
|
@click.option('--input', '-i', default='input.png')
|
||||||
@click.option('--output', '-o', default='output.png')
|
@click.option('--output', '-o', default='output.png')
|
||||||
|
|
|
@ -5,20 +5,23 @@ VERSION = '0.1a12'
|
||||||
DOCKER_RUNTIME_CUDA = 'skynet:runtime-cuda'
|
DOCKER_RUNTIME_CUDA = 'skynet:runtime-cuda'
|
||||||
|
|
||||||
MODELS = {
|
MODELS = {
|
||||||
'prompthero/openjourney': {'short': 'midj', 'mem': 6},
|
'prompthero/openjourney': {'short': 'midj', 'mem': 6, 'size': {'w': 512, 'h': 512}},
|
||||||
'runwayml/stable-diffusion-v1-5': {'short': 'stable', 'mem': 6},
|
'runwayml/stable-diffusion-v1-5': {'short': 'stable', 'mem': 6, 'size': {'w': 512, 'h': 512}},
|
||||||
'stabilityai/stable-diffusion-2-1-base': {'short': 'stable2', 'mem': 6},
|
'stabilityai/stable-diffusion-2-1-base': {'short': 'stable2', 'mem': 6, 'size': {'w': 512, 'h': 512}},
|
||||||
'snowkidy/stable-diffusion-xl-base-0.9': {'short': 'stablexl0.9', 'mem': 8.3},
|
'snowkidy/stable-diffusion-xl-base-0.9': {'short': 'stablexl0.9', 'mem': 8.3, 'size': {'w': 1024, 'h': 1024}},
|
||||||
'Linaqruf/anything-v3.0': {'short': 'hdanime', 'mem': 6},
|
'Linaqruf/anything-v3.0': {'short': 'hdanime', 'mem': 6, 'size': {'w': 512, 'h': 512}},
|
||||||
'hakurei/waifu-diffusion': {'short': 'waifu', 'mem': 6},
|
'hakurei/waifu-diffusion': {'short': 'waifu', 'mem': 6, 'size': {'w': 512, 'h': 512}},
|
||||||
'nitrosocke/Ghibli-Diffusion': {'short': 'ghibli', 'mem': 6},
|
'nitrosocke/Ghibli-Diffusion': {'short': 'ghibli', 'mem': 6, 'size': {'w': 512, 'h': 512}},
|
||||||
'dallinmackay/Van-Gogh-diffusion': {'short': 'van-gogh', 'mem': 6},
|
'dallinmackay/Van-Gogh-diffusion': {'short': 'van-gogh', 'mem': 6, 'size': {'w': 512, 'h': 512}},
|
||||||
'lambdalabs/sd-pokemon-diffusers': {'short': 'pokemon', 'mem': 6},
|
'lambdalabs/sd-pokemon-diffusers': {'short': 'pokemon', 'mem': 6, 'size': {'w': 512, 'h': 512}},
|
||||||
'Envvi/Inkpunk-Diffusion': {'short': 'ink', 'mem': 6},
|
'Envvi/Inkpunk-Diffusion': {'short': 'ink', 'mem': 6, 'size': {'w': 512, 'h': 512}},
|
||||||
'nousr/robo-diffusion': {'short': 'robot', 'mem': 6},
|
'nousr/robo-diffusion': {'short': 'robot', 'mem': 6, 'size': {'w': 512, 'h': 512}},
|
||||||
|
|
||||||
|
# -1 is always inpaint default
|
||||||
|
'diffusers/stable-diffusion-xl-1.0-inpainting-0.1': {'short': 'stablexl-inpainting': 'mem': 8.3, 'size': {'w': 1024, 'h': 1024}},
|
||||||
|
|
||||||
# default is always last
|
# default is always last
|
||||||
'stabilityai/stable-diffusion-xl-base-1.0': {'short': 'stablexl', 'mem': 8.3},
|
'stabilityai/stable-diffusion-xl-base-1.0': {'short': 'stablexl', 'mem': 8.3, 'size': {'w': 1024, 'h': 1024}},
|
||||||
}
|
}
|
||||||
|
|
||||||
SHORT_NAMES = [
|
SHORT_NAMES = [
|
||||||
|
|
|
@ -168,7 +168,11 @@ class SkynetMM:
|
||||||
arguments = prepare_params_for_diffuse(
|
arguments = prepare_params_for_diffuse(
|
||||||
params, input_type, binary=binary)
|
params, input_type, binary=binary)
|
||||||
prompt, guidance, step, seed, upscaler, extra_params = arguments
|
prompt, guidance, step, seed, upscaler, extra_params = arguments
|
||||||
model = self.get_model(params['model'], 'image' in extra_params)
|
model = self.get_model(
|
||||||
|
params['model'],
|
||||||
|
'image' in extra_params,
|
||||||
|
'mask_image' in extra_params
|
||||||
|
)
|
||||||
|
|
||||||
output = model(
|
output = model(
|
||||||
prompt,
|
prompt,
|
||||||
|
|
|
@ -63,6 +63,7 @@ def pipeline_for(
|
||||||
model: str,
|
model: str,
|
||||||
mem_fraction: float = 1.0,
|
mem_fraction: float = 1.0,
|
||||||
image: bool = False,
|
image: bool = False,
|
||||||
|
inpainting: bool = False,
|
||||||
cache_dir: str | None = None
|
cache_dir: str | None = None
|
||||||
) -> DiffusionPipeline:
|
) -> DiffusionPipeline:
|
||||||
|
|
||||||
|
@ -102,6 +103,11 @@ def pipeline_for(
|
||||||
|
|
||||||
torch.cuda.set_per_process_memory_fraction(mem_fraction)
|
torch.cuda.set_per_process_memory_fraction(mem_fraction)
|
||||||
|
|
||||||
|
if inpainting:
|
||||||
|
pipe = AutoPipelineForInpainting.from_pretrained(
|
||||||
|
model, **params)
|
||||||
|
|
||||||
|
else:
|
||||||
pipe = DiffusionPipeline.from_pretrained(
|
pipe = DiffusionPipeline.from_pretrained(
|
||||||
model, **params)
|
model, **params)
|
||||||
|
|
||||||
|
@ -168,8 +174,10 @@ def img2img(
|
||||||
login(token=hf_token)
|
login(token=hf_token)
|
||||||
pipe = pipeline_for(model, image=True)
|
pipe = pipeline_for(model, image=True)
|
||||||
|
|
||||||
|
model_info = MODELS[model]
|
||||||
|
|
||||||
with open(img_path, 'rb') as img_file:
|
with open(img_path, 'rb') as img_file:
|
||||||
input_img = convert_from_bytes_and_crop(img_file.read(), 512, 512)
|
input_img = convert_from_bytes_and_crop(img_file.read(), model_info['size']['w'], model_info['size']['h'])
|
||||||
|
|
||||||
seed = seed if seed else random.randint(0, 2 ** 64)
|
seed = seed if seed else random.randint(0, 2 ** 64)
|
||||||
prompt = prompt
|
prompt = prompt
|
||||||
|
@ -184,6 +192,43 @@ def img2img(
|
||||||
image.save(output)
|
image.save(output)
|
||||||
|
|
||||||
|
|
||||||
|
def inpaint(
|
||||||
|
hf_token: str,
|
||||||
|
model: str = 'diffusers/stable-diffusion-xl-1.0-inpainting-0.1',
|
||||||
|
prompt: str = 'a red old tractor in a sunny wheat field',
|
||||||
|
img_path: str = 'input.png',
|
||||||
|
mask_path: str = 'mask.png',
|
||||||
|
output: str = 'output.png',
|
||||||
|
strength: float = 1.0,
|
||||||
|
guidance: float = 10,
|
||||||
|
steps: int = 28,
|
||||||
|
seed: Optional[int] = None
|
||||||
|
):
|
||||||
|
login(token=hf_token)
|
||||||
|
pipe = pipeline_for(model, image=True)
|
||||||
|
|
||||||
|
model_info = MODELS[model]
|
||||||
|
|
||||||
|
with open(img_path, 'rb') as img_file:
|
||||||
|
input_img = convert_from_bytes_and_crop(img_file.read(), model_info['size']['w'], model_info['size']['h'])
|
||||||
|
|
||||||
|
with open(mask_path, 'rb') as mask_file:
|
||||||
|
mask_img = convert_from_bytes_and_crop(mask_file.read(), model_info['size']['w'], model_info['size']['h'])
|
||||||
|
|
||||||
|
seed = seed if seed else random.randint(0, 2 ** 64)
|
||||||
|
prompt = prompt
|
||||||
|
image = pipe(
|
||||||
|
prompt,
|
||||||
|
image=input_img,
|
||||||
|
mask_image=mask_img
|
||||||
|
strength=strength,
|
||||||
|
guidance_scale=guidance, num_inference_steps=steps,
|
||||||
|
generator=torch.Generator("cuda").manual_seed(seed)
|
||||||
|
).images[0]
|
||||||
|
|
||||||
|
image.save(output)
|
||||||
|
|
||||||
|
|
||||||
def init_upscaler(model_path: str = 'weights/RealESRGAN_x4plus.pth'):
|
def init_upscaler(model_path: str = 'weights/RealESRGAN_x4plus.pth'):
|
||||||
return RealESRGANer(
|
return RealESRGANer(
|
||||||
scale=4,
|
scale=4,
|
||||||
|
|
Loading…
Reference in New Issue