# mirror of https://github.com/skygpu/skynet.git
import io
import os
import sys
import time
import random
import logging
import importlib

from typing import Optional
from contextlib import contextmanager

import torch
import diffusers
import numpy as np

from PIL import Image
from diffusers import (
    DiffusionPipeline,
    AutoPipelineForText2Image,
    AutoPipelineForImage2Image,
    AutoPipelineForInpainting,
    EulerAncestralDiscreteScheduler,
)
from huggingface_hub import login

from skynet.constants import MODELS

# Hack to fix a changed import in torchvision 0.17+, which otherwise breaks
# basicsr; see https://github.com/AUTOMATIC1111/stable-diffusion-webui/issues/13985
try:
    import torchvision.transforms.functional_tensor  # noqa: F401
except ImportError:
    try:
        import torchvision.transforms.functional as functional
        sys.modules["torchvision.transforms.functional_tensor"] = functional
    except ImportError:
        pass  # shrug...

from basicsr.archs.rrdbnet_arch import RRDBNet
from realesrgan import RealESRGANer

def convert_from_cv2_to_image(img: np.ndarray) -> Image.Image:
    # return Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
    return Image.fromarray(img)


def convert_from_image_to_cv2(img: Image.Image) -> np.ndarray:
    # return cv2.cvtColor(numpy.array(img), cv2.COLOR_RGB2BGR)
    return np.asarray(img)


def convert_from_bytes_to_img(raw: bytes) -> Image.Image:
    return Image.open(io.BytesIO(raw))


def convert_from_img_to_bytes(image: Image.Image, fmt='PNG') -> bytes:
    byte_arr = io.BytesIO()
    image.save(byte_arr, format=fmt)
    return byte_arr.getvalue()


def crop_image(image: Image.Image, max_w: int, max_h: int) -> Image.Image:
    w, h = image.size
    if w > max_w or h > max_h:
        image.thumbnail((max_w, max_h))

    return image.convert('RGB')


def convert_from_bytes_and_crop(raw: bytes, max_w: int, max_h: int) -> Image.Image:
    return crop_image(convert_from_bytes_to_img(raw), max_w, max_h)
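
# Usage sketch for the byte/PIL helpers above (the file name and the
# 1024x1024 bound are illustrative, not from skynet): round-trip raw PNG
# bytes through a size-bounded RGB image and back.
#
#   with open('input.png', 'rb') as f:
#       img = convert_from_bytes_and_crop(f.read(), 1024, 1024)
#   raw = convert_from_img_to_bytes(img)
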
# Stand-ins to silence diffusers' per-step progress output: pipelines call
# self.progress_bar() as a context manager and .update() it each step, so we
# monkey patch in a no-op version.
class DummyPB:
    def update(self):
        ...


@torch.compiler.disable
@contextmanager
def dummy_progress_bar(*args, **kwargs):
    yield DummyPB()


def monkey_patch_pipeline_disable_progress_bar(pipe):
    pipe.progress_bar = dummy_progress_bar

def pipeline_for(
    model: str,
    mode: str,
    mem_fraction: float = 1.0,
    cache_dir: str | None = None
) -> DiffusionPipeline:
    diffusers.utils.logging.disable_progress_bar()

    logging.info(f'pipeline_for {model} {mode}')
    # assert torch.cuda.is_available()
    torch.cuda.empty_cache()
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

    # full determinism
    # https://huggingface.co/docs/diffusers/using-diffusers/reproducibility#deterministic-algorithms
    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"

    torch.backends.cudnn.benchmark = False
    torch.use_deterministic_algorithms(True)

    model_info = MODELS[model]
    shortname = model_info.short

    # disabled for compat with the "diffuse" method
    # assert mode in model_info.tags

    # check if a custom pipeline module exists for this model and use it if
    # so; otherwise fall back to the generic auto-pipeline path below
    try:
        normalized_shortname = shortname.replace('-', '_')
        custom_pipeline = importlib.import_module(f'skynet.dgpu.pipes.{normalized_shortname}')
        assert custom_pipeline.__model['name'] == model
        pipe = custom_pipeline.pipeline_for(model, mode, mem_fraction=mem_fraction, cache_dir=cache_dir)
        monkey_patch_pipeline_disable_progress_bar(pipe)
        return pipe

    except ImportError:
        logging.info(f'didn\'t find a custom pipeline file for {shortname}')

    req_mem = model_info.mem

    mem_gb = torch.cuda.mem_get_info()[1] / (10**9)
    mem_gb *= mem_fraction
    over_mem = mem_gb < req_mem
    if over_mem:
        logging.warning(f'model requires {req_mem}GB but card has {mem_gb:.2f}GB, model will run slower..')

    params = {
        'torch_dtype': torch.float16,
        'cache_dir': cache_dir,
        'variant': 'fp16',
    }

    match shortname:
        case 'stable':
            params['revision'] = 'fp16'
            params['safety_checker'] = None

    torch.cuda.set_per_process_memory_fraction(mem_fraction)

    pipe_class = DiffusionPipeline
    match mode:
        case 'inpaint':
            pipe_class = AutoPipelineForInpainting

        case 'img2img':
            pipe_class = AutoPipelineForImage2Image

        case 'txt2img':
            pipe_class = AutoPipelineForText2Image

    pipe = pipe_class.from_pretrained(
        model, **params)

    pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(
        pipe.scheduler.config)

    # pipe.enable_xformers_memory_efficient_attention()

    if over_mem:
        if mode == 'txt2img':
            pipe.vae.enable_tiling()
            pipe.vae.enable_slicing()

        pipe.enable_model_cpu_offload()

    else:
        # if sys.version_info[1] < 11:
        #     # torch.compile only supported on python < 3.11
        #     pipe.unet = torch.compile(
        #         pipe.unet, mode='reduce-overhead', fullgraph=True)

        pipe = pipe.to('cuda')

    monkey_patch_pipeline_disable_progress_bar(pipe)

    return pipe

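
# Sketch of the custom-pipeline module contract probed by pipeline_for() above
# (inferred from the attribute accesses there, not from any skynet docs): a
# module at skynet/dgpu/pipes/<shortname>.py is expected to provide
#
#   __model = {'name': '<model name as keyed in MODELS>'}
#
#   def pipeline_for(model, mode, mem_fraction=1.0, cache_dir=None):
#       ...  # build and return a ready-to-use DiffusionPipeline
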
def txt2img(
    hf_token: str,
    model: str = list(MODELS.keys())[-1],
    prompt: str = 'a red old tractor in a sunny wheat field',
    output: str = 'output.png',
    width: int = 512, height: int = 512,
    guidance: float = 10,
    steps: int = 28,
    seed: Optional[int] = None
):
    login(token=hf_token)
    pipe = pipeline_for(model, 'txt2img')

    # honor an explicit seed of 0; cap randint at 2 ** 64 - 1, torch's max seed
    seed = seed if seed is not None else random.randint(0, 2 ** 64 - 1)
    image = pipe(
        prompt,
        width=width,
        height=height,
        guidance_scale=guidance, num_inference_steps=steps,
        generator=torch.Generator("cuda").manual_seed(seed)
    ).images[0]

    image.save(output)

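
# Usage sketch (token value hypothetical; writes output.png to the cwd):
#
#   txt2img('hf_...', prompt='a lighthouse at dusk', width=768, height=768)
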
def img2img(
    hf_token: str,
    model: str = list(MODELS.keys())[-2],
    prompt: str = 'a red old tractor in a sunny wheat field',
    img_path: str = 'input.png',
    output: str = 'output.png',
    strength: float = 1.0,
    guidance: float = 10,
    steps: int = 28,
    seed: Optional[int] = None
):
    login(token=hf_token)
    pipe = pipeline_for(model, 'img2img')

    model_info = MODELS[model]

    with open(img_path, 'rb') as img_file:
        input_img = convert_from_bytes_and_crop(img_file.read(), model_info.size.w, model_info.size.h)

    seed = seed if seed is not None else random.randint(0, 2 ** 64 - 1)
    image = pipe(
        prompt,
        image=input_img,
        strength=strength,
        guidance_scale=guidance, num_inference_steps=steps,
        generator=torch.Generator("cuda").manual_seed(seed)
    ).images[0]

    image.save(output)

def inpaint(
    hf_token: str,
    model: str = list(MODELS.keys())[-3],
    prompt: str = 'a red old tractor in a sunny wheat field',
    img_path: str = 'input.png',
    mask_path: str = 'mask.png',
    output: str = 'output.png',
    strength: float = 1.0,
    guidance: float = 10,
    steps: int = 28,
    seed: Optional[int] = None
):
    login(token=hf_token)
    pipe = pipeline_for(model, 'inpaint')

    model_info = MODELS[model]

    with open(img_path, 'rb') as img_file:
        input_img = convert_from_bytes_and_crop(img_file.read(), model_info.size.w, model_info.size.h)

    with open(mask_path, 'rb') as mask_file:
        mask_img = convert_from_bytes_and_crop(mask_file.read(), model_info.size.w, model_info.size.h)

    # skip the strength param for flux models
    var_params = {}
    if 'flux' not in model.lower():
        var_params['strength'] = strength

    seed = seed if seed is not None else random.randint(0, 2 ** 64 - 1)
    image = pipe(
        prompt,
        image=input_img,
        mask_image=mask_img,
        guidance_scale=guidance, num_inference_steps=steps,
        generator=torch.Generator("cuda").manual_seed(seed),
        **var_params
    ).images[0]

    image.save(output)

def init_upscaler(model_path: str = 'hf_home/RealESRGAN_x4plus.pth'):
    return RealESRGANer(
        scale=4,
        model_path=model_path,
        dni_weight=None,
        model=RRDBNet(
            num_in_ch=3,
            num_out_ch=3,
            num_feat=64,
            num_block=23,
            num_grow_ch=32,
            scale=4
        ),
        half=True
    )

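# The default model_path assumes the RealESRGAN_x4plus.pth weights were
# already downloaded (e.g. a release asset from
# https://github.com/xinntao/Real-ESRGAN) into hf_home/. Minimal sketch,
# paths hypothetical:
#
#   upscaler = init_upscaler('weights/RealESRGAN_x4plus.pth')
#   big, _ = upscaler.enhance(convert_from_image_to_cv2(img), outscale=4)
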
def upscale(
    img_path: str = 'input.png',
    output: str = 'output.png',
    model_path: str = 'hf_home/RealESRGAN_x4plus.pth'
):
    input_img = Image.open(img_path).convert('RGB')

    upscaler = init_upscaler(model_path=model_path)

    up_img, _ = upscaler.enhance(
        convert_from_image_to_cv2(input_img), outscale=4)

    image = convert_from_cv2_to_image(up_img)
    image.save(output)