Add upscaler

pull/2/head
Guillermo Rodriguez 2022-12-24 10:39:40 -03:00
parent 1fc2020ed5
commit a9c237b538
No known key found for this signature in database
GPG Key ID: EC3AB66D5D83B392
10 changed files with 178 additions and 105 deletions

View File

@ -3,6 +3,9 @@ from python:3.10.0
env DEBIAN_FRONTEND=noninteractive
run apt-get update && \
apt-get install -y ffmpeg libsm6 libxext6
workdir /skynet
copy requirements.cuda* ./
@ -27,3 +30,4 @@ env HF_HOME /hf_home
copy scripts scripts
copy tests tests

View File

@ -1,5 +1,7 @@
scipy
triton
basicsr
realesrgan
accelerate
transformers
huggingface_hub

View File

@ -1,5 +1,6 @@
trio
pynng
numpy
triopg
aiohttp
msgspec

View File

@ -204,11 +204,8 @@ async def open_rpc_service(sock, dgpu_bus, db_pool, tls_whitelist, tls_key):
logging.info('txt2img')
user_config = {**(await get_user_config(conn, user))}
del user_config['id']
prompt = req.params['prompt']
req = ImageGenRequest(
prompt=prompt,
**user_config
)
user_config.update((k, req.params[k]) for k in req.params)
req = ImageGenRequest(**user_config)
rid, img, meta = await dgpu_stream_one_img(req)
logging.info(f'done streaming {rid}')
result = {
@ -217,7 +214,7 @@ async def open_rpc_service(sock, dgpu_bus, db_pool, tls_whitelist, tls_key):
'meta': meta
}
await update_user_stats(conn, user, last_prompt=prompt)
await update_user_stats(conn, user, last_prompt=user_config['prompt'])
logging.info('updated user stats.')
case 'redo':

View File

@ -38,19 +38,14 @@ def txt2img(*args, **kwargs):
utils.txt2img(os.environ['HF_TOKEN'], **kwargs)
@click.command()
@click.option(
'--prompt', '-p', default='a red old tractor in a sunny wheat field')
@click.option('--input', '-i', default='input.png')
@click.option('--output', '-o', default='output.png')
@click.option('--steps', '-s', default=26)
def upscale(prompt, input, output, steps):
assert 'HF_TOKEN' in os.environ
@click.option('--model', '-m', default='weights/RealESRGAN_x4plus.pth')
def upscale(input, output, model):
utils.upscale(
os.environ['HF_TOKEN'],
prompt=prompt,
img_path=input,
output=output,
steps=steps)
model_path=model)
@skynet.group()

View File

@ -26,37 +26,34 @@ from diffusers import (
StableDiffusionPipeline,
EulerAncestralDiscreteScheduler
)
from realesrgan import RealESRGANer
from basicsr.archs.rrdbnet_arch import RRDBNet
from diffusers.models import UNet2DConditionModel
from .utils import (
pipeline_for,
convert_from_cv2_to_image, convert_from_image_to_cv2
)
from .structs import *
from .constants import *
from .frontend import open_skynet_rpc
def pipeline_for(algo: str, mem_fraction: float = 1.0):
    '''Load the Stable Diffusion pipeline registered under ``algo`` on CUDA.

    The model id is looked up in ``ALGOS``. The pipeline is loaded in
    half precision with the safety checker disabled, its scheduler is
    swapped for an Euler ancestral one built from the loaded scheduler
    config, and VAE slicing is enabled before moving it to the GPU.

    :param algo: key into ``ALGOS``; ``'stable'`` additionally requests
        the ``fp16`` revision.
    :param mem_fraction: value forwarded to
        ``torch.cuda.set_per_process_memory_fraction``.
    :returns: the configured pipeline, moved to ``'cuda'``.
    :raises AssertionError: if CUDA is not available.
    '''
    assert torch.cuda.is_available()

    # Global torch/CUDA setup: free cached blocks, cap this process'
    # share of GPU memory and allow TF32 in matmul / cuDNN kernels.
    torch.cuda.empty_cache()
    torch.cuda.set_per_process_memory_fraction(mem_fraction)
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

    load_kwargs = dict(torch_dtype=torch.float16, safety_checker=None)
    if algo == 'stable':
        load_kwargs['revision'] = 'fp16'

    pipeline = StableDiffusionPipeline.from_pretrained(
        ALGOS[algo], **load_kwargs)

    scheduler_config = pipeline.scheduler.config
    pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(
        scheduler_config)

    pipeline.enable_vae_slicing()

    return pipeline.to('cuda')
def init_upscaler(model_path: str = 'weights/RealESRGAN_x4plus.pth'):
    '''Build the x4 Real-ESRGAN upscaler used by the dgpu node.

    :param model_path: path to the weights file handed to
        ``RealESRGANer``.
    :returns: a ``RealESRGANer`` wrapping an ``RRDBNet`` backbone,
        configured with ``scale=4`` and ``half=True``.
    '''
    # Backbone network; its scale matches the upscaler scale below.
    backbone = RRDBNet(
        num_in_ch=3,
        num_out_ch=3,
        num_feat=64,
        num_block=23,
        num_grow_ch=32,
        scale=4
    )

    upscaler = RealESRGANer(
        scale=4,
        model_path=model_path,
        dni_weight=None,
        model=backbone,
        half=True
    )
    return upscaler
class DGPUComputeError(BaseException):
@ -79,6 +76,7 @@ async def open_dgpu_node(
logging.info(f'loading models...')
upscaler = init_upscaler()
initial_algos = (
initial_algos
if initial_algos else DEFAULT_INITAL_ALGOS
@ -91,8 +89,8 @@ async def open_dgpu_node(
}
logging.info(f'loaded {algo}.')
logging.info('memory summary:\n')
logging.info(torch.cuda.memory_summary())
logging.info('memory summary:')
logging.info('\n' + torch.cuda.memory_summary())
async def gpu_compute_one(ireq: ImageGenRequest):
if ireq.algo not in models:
@ -118,6 +116,15 @@ async def open_dgpu_node(
num_inference_steps=ireq.step,
generator=torch.Generator("cuda").manual_seed(ireq.seed)
).images[0]
if ireq.upscaler == 'x4':
logging.info('performing upscale...')
up_img, _ = upscaler.enhance(
convert_from_image_to_cv2(image), outscale=4)
image = convert_from_cv2_to_image(up_img)
logging.info('done')
return image.tobytes()
except BaseException as e:

View File

@ -117,49 +117,50 @@ def validate_user_config_request(req: str):
try:
attr = params[1]
if attr == 'algo':
val = params[2]
if val not in ALGOS:
raise ConfigUnknownAlgorithm(f'no algo named {val}')
match attr:
case 'algo':
val = params[2]
if val not in ALGOS:
raise ConfigUnknownAlgorithm(f'no algo named {val}')
elif attr == 'step':
val = int(params[2])
val = max(min(val, MAX_STEP), MIN_STEP)
elif attr == 'width':
val = max(min(int(params[2]), MAX_WIDTH), 16)
if val % 8 != 0:
raise ConfigSizeDivisionByEight(
'size must be divisible by 8!')
elif attr == 'height':
val = max(min(int(params[2]), MAX_HEIGHT), 16)
if val % 8 != 0:
raise ConfigSizeDivisionByEight(
'size must be divisible by 8!')
elif attr == 'seed':
val = params[2]
if val == 'auto':
val = None
else:
case 'step':
val = int(params[2])
val = max(min(val, MAX_STEP), MIN_STEP)
elif attr == 'guidance':
val = float(params[2])
val = max(min(val, MAX_GUIDANCE), 0)
case 'width':
val = max(min(int(params[2]), MAX_WIDTH), 16)
if val % 8 != 0:
raise ConfigSizeDivisionByEight(
'size must be divisible by 8!')
elif attr == 'upscaler':
val = params[2]
if val == 'off':
val = None
elif val != 'x4':
raise ConfigUnknownUpscaler(
f'\"{val}\" is not a valid upscaler')
case 'height':
val = max(min(int(params[2]), MAX_HEIGHT), 16)
if val % 8 != 0:
raise ConfigSizeDivisionByEight(
'size must be divisible by 8!')
else:
raise ConfigUnknownAttribute(
f'\"{attr}\" not a configurable parameter')
case 'seed':
val = params[2]
if val == 'auto':
val = None
else:
val = int(params[2])
case 'guidance':
val = float(params[2])
val = max(min(val, MAX_GUIDANCE), 0)
case 'upscaler':
val = params[2]
if val == 'off':
val = None
elif val != 'x4':
raise ConfigUnknownUpscaler(
f'\"{val}\" is not a valid upscaler')
case _:
raise ConfigUnknownAttribute(
f'\"{attr}\" not a configurable parameter')
return attr, val, f'config updated! {attr} to {val}'

View File

@ -26,13 +26,7 @@ from pprint import pformat
import msgspec
class Struct(
msgspec.Struct,
# https://jcristharif.com/msgspec/structs.html#tagged-unions
# tag='pikerstruct',
# tag=True,
):
class Struct(msgspec.Struct):
'''
A "human friendlier" (aka repl buddy) struct subtype.
'''
@ -88,7 +82,7 @@ class Struct(
from OpenSSL.crypto import PKey, X509, verify, sign
class AuthenticatedStruct(Struct):
class AuthenticatedStruct(Struct, kw_only=True):
cert: Optional[str] = None
sig: Optional[str] = None

View File

@ -6,16 +6,55 @@ from typing import Optional
from pathlib import Path
import torch
import numpy as np
from PIL import Image
from basicsr.archs.rrdbnet_arch import RRDBNet
from diffusers import (
StableDiffusionPipeline,
StableDiffusionUpscalePipeline,
EulerAncestralDiscreteScheduler
)
from realesrgan import RealESRGANer
from huggingface_hub import login
from .dgpu import pipeline_for
from .constants import ALGOS
def convert_from_cv2_to_image(img: np.ndarray) -> Image.Image:
    '''Wrap a numpy array as a PIL image via ``Image.fromarray``.

    :param img: array holding the image data.
    :returns: the PIL image built from ``img``.
    '''
    # No cv2.cvtColor(img, cv2.COLOR_BGR2RGB) swap is applied here: the
    # channel order of ``img`` is passed through unchanged.
    # NOTE(review): presumably callers already hold RGB data - confirm.
    return Image.fromarray(img)
def convert_from_image_to_cv2(img: Image.Image) -> np.ndarray:
    '''Expose a PIL image as a numpy array via ``np.asarray``.

    :param img: PIL image to convert.
    :returns: numpy array view/copy of the image data.
    '''
    # No cv2.cvtColor(..., cv2.COLOR_RGB2BGR) swap is applied here: the
    # channel order of ``img`` is passed through unchanged.
    # NOTE(review): presumably the consumer accepts RGB ordering - confirm.
    return np.asarray(img)
def pipeline_for(algo: str, mem_fraction: float = 1.0):
    '''Load the Stable Diffusion pipeline registered under ``algo`` on CUDA.

    The model id is looked up in ``ALGOS``. The pipeline is loaded in
    half precision with the safety checker disabled, its scheduler is
    swapped for an Euler ancestral one built from the loaded scheduler
    config, and VAE slicing is enabled before moving it to the GPU.

    :param algo: key into ``ALGOS``; ``'stable'`` additionally requests
        the ``fp16`` revision.
    :param mem_fraction: value forwarded to
        ``torch.cuda.set_per_process_memory_fraction``.
    :returns: the configured pipeline, moved to ``'cuda'``.
    :raises AssertionError: if CUDA is not available.
    '''
    assert torch.cuda.is_available()

    # Global torch/CUDA setup: free cached blocks, cap this process'
    # share of GPU memory and allow TF32 in matmul / cuDNN kernels.
    torch.cuda.empty_cache()
    torch.cuda.set_per_process_memory_fraction(mem_fraction)
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

    load_kwargs = dict(torch_dtype=torch.float16, safety_checker=None)
    if algo == 'stable':
        load_kwargs['revision'] = 'fp16'

    pipeline = StableDiffusionPipeline.from_pretrained(
        ALGOS[algo], **load_kwargs)

    scheduler_config = pipeline.scheduler.config
    pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(
        scheduler_config)

    pipeline.enable_vae_slicing()

    return pipeline.to('cuda')
def txt2img(
@ -51,11 +90,9 @@ def txt2img(
def upscale(
hf_token: str,
prompt: str = 'a red old tractor in a sunny wheat field',
img_path: str = 'input.png',
output: str = 'output.png',
steps: int = 28
model_path: str = 'weights/RealESRGAN_x4plus.pth'
):
assert torch.cuda.is_available()
torch.cuda.empty_cache()
@ -63,20 +100,26 @@ def upscale(
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
login(token=hf_token)
input_img = Image.open(img_path).convert('RGB')
pipe = StableDiffusionUpscalePipeline.from_pretrained(
'stabilityai/stable-diffusion-x4-upscaler',
revision="fp16", torch_dtype=torch.float16)
pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(
pipe.scheduler.config)
pipe = pipe.to('cuda')
upscaler = RealESRGANer(
scale=4,
model_path=model_path,
dni_weight=None,
model=RRDBNet(
num_in_ch=3,
num_out_ch=3,
num_feat=64,
num_block=23,
num_grow_ch=32,
scale=4
),
half=True)
up_img, _ = upscaler.enhance(
convert_from_image_to_cv2(input_img), outscale=4)
image = convert_from_cv2_to_image(up_img)
prompt = prompt
image = pipe(
prompt,
image=Image.open(img_path).convert("RGB"),
num_inference_steps=steps
).images[0]
image.save(output)

View File

@ -6,6 +6,7 @@ import json
import base64
import logging
from typing import Optional
from hashlib import sha256
from functools import partial
@ -42,7 +43,8 @@ async def check_request_img(
uid: int = 0,
width: int = 512,
height: int = 512,
expect_unique=True
expect_unique = True,
upscaler: Optional[str] = None
):
global _images
@ -60,12 +62,16 @@ async def check_request_img(
'guidance': 7.5,
'seed': None,
'algo': list(ALGOS.keys())[i],
'upscaler': None
'upscaler': upscaler
})
if 'error' in res.result:
raise SkynetDGPUComputeError(json.dumps(res.result))
if upscaler == 'x4':
width *= 4
height *= 4
img_raw = base64.b64decode(bytes.fromhex(res.result['img']))
img_sha = sha256(img_raw).hexdigest()
img = Image.frombytes(
@ -80,6 +86,8 @@ async def check_request_img(
assert len(img_raw) > 100000
return img
@pytest.mark.parametrize(
'dgpu_workers', [(1, ['midj'])], indirect=True)
@ -123,6 +131,27 @@ async def test_dgpu_workers(dgpu_workers):
await check_request_img(1)
@pytest.mark.parametrize(
    'dgpu_workers', [(1, ['midj'])], indirect=True)
async def test_dgpu_worker_upscale(dgpu_workers):
    '''Request one image with the ``x4`` upscaler enabled on a single
    dgpu worker and check the decoded image is 2048x2048.
    '''
    async with open_skynet_rpc(
        'test-ctx',
        security=True,
        cert_name='whitelist/testing',
        key_name='testing'
    ) as test_rpc:
        # Block until one dgpu worker is reported before requesting work.
        await wait_for_dgpus(test_rpc, 1)
        # NOTE(review): logged at error level, presumably as a temporary
        # debugging marker - confirm whether it should stay.
        logging.error('UPSCALE')
        img = await check_request_img(0, upscaler='x4')
        # check_request_img's default 512x512 request scaled by 4.
        assert img.size == (2048, 2048)
@pytest.mark.parametrize(
'dgpu_workers', [(2, ['midj'])], indirect=True)
async def test_dgpu_workers_two(dgpu_workers):