Add upscaler

pull/2/head
Guillermo Rodriguez 2022-12-24 10:39:40 -03:00
parent 1fc2020ed5
commit a9c237b538
10 changed files with 178 additions and 105 deletions

View File

@@ -3,6 +3,9 @@ from python:3.10.0
 env DEBIAN_FRONTEND=noninteractive

+run apt-get update && \
+    apt-get install -y ffmpeg libsm6 libxext6
+
 workdir /skynet
 copy requirements.cuda* ./
@@ -27,3 +30,4 @@ env HF_HOME /hf_home
 copy scripts scripts
 copy tests tests

View File

@@ -1,5 +1,7 @@
 scipy
 triton
+basicsr
+realesrgan
 accelerate
 transformers
 huggingface_hub
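
The two new CUDA-side requirements supply exactly the pieces the rest of this commit imports: basicsr ships the RRDBNet backbone and realesrgan ships the RealESRGANer inference helper. A quick import sanity check (a sketch, not part of the commit):

from basicsr.archs.rrdbnet_arch import RRDBNet   # network architecture used by the upscaler
from realesrgan import RealESRGANer               # wraps weight loading and inference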

View File

@@ -1,5 +1,6 @@
 trio
 pynng
+numpy
 triopg
 aiohttp
 msgspec

View File

@@ -204,11 +204,8 @@ async def open_rpc_service(sock, dgpu_bus, db_pool, tls_whitelist, tls_key):
         logging.info('txt2img')
         user_config = {**(await get_user_config(conn, user))}
         del user_config['id']
-        prompt = req.params['prompt']
-        req = ImageGenRequest(
-            prompt=prompt,
-            **user_config
-        )
+        user_config.update((k, req.params[k]) for k in req.params)
+        req = ImageGenRequest(**user_config)
         rid, img, meta = await dgpu_stream_one_img(req)
         logging.info(f'done streaming {rid}')
         result = {
@@ -217,7 +214,7 @@ async def open_rpc_service(sock, dgpu_bus, db_pool, tls_whitelist, tls_key):
             'meta': meta
         }
-        await update_user_stats(conn, user, last_prompt=prompt)
+        await update_user_stats(conn, user, last_prompt=user_config['prompt'])
         logging.info('updated user stats.')

     case 'redo':
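
The hand-built ImageGenRequest is replaced with a dict merge: the user's stored defaults are fetched, then every field carried on the request overrides them, so the prompt now travels inside user_config like any other parameter. A minimal sketch of the semantics, with made-up values:

user_config = {'step': 26, 'width': 512, 'height': 512}      # stored defaults (illustrative)
req_params = {'prompt': 'a red old tractor', 'step': 30}      # fields sent with the request
user_config.update((k, req_params[k]) for k in req_params)    # request values win on overlap
assert user_config == {
    'step': 30, 'width': 512, 'height': 512,
    'prompt': 'a red old tractor'}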

View File

@@ -38,19 +38,14 @@ def txt2img(*args, **kwargs):
     utils.txt2img(os.environ['HF_TOKEN'], **kwargs)

 @click.command()
-@click.option(
-    '--prompt', '-p', default='a red old tractor in a sunny wheat field')
 @click.option('--input', '-i', default='input.png')
 @click.option('--output', '-o', default='output.png')
-@click.option('--steps', '-s', default=26)
-def upscale(prompt, input, output, steps):
-    assert 'HF_TOKEN' in os.environ
+@click.option('--model', '-m', default='weights/RealESRGAN_x4plus.pth')
+def upscale(input, output, model):
     utils.upscale(
-        os.environ['HF_TOKEN'],
-        prompt=prompt,
         img_path=input,
         output=output,
-        steps=steps)
+        model_path=model)

 @skynet.group()
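
The upscale subcommand no longer talks to Hugging Face, so it drops both HF_TOKEN and the prompt and simply forwards three paths to the helper. A sketch of the equivalent direct call, assuming the package is importable as skynet and using the CLI defaults from the options above:

from skynet import utils   # assumption: package installed/importable as 'skynet'

utils.upscale(
    img_path='input.png',
    output='output.png',
    model_path='weights/RealESRGAN_x4plus.pth')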

View File

@@ -26,37 +26,34 @@ from diffusers import (
     StableDiffusionPipeline,
     EulerAncestralDiscreteScheduler
 )
+from realesrgan import RealESRGANer
+from basicsr.archs.rrdbnet_arch import RRDBNet
 from diffusers.models import UNet2DConditionModel

+from .utils import (
+    pipeline_for,
+    convert_from_cv2_to_image, convert_from_image_to_cv2
+)
 from .structs import *
 from .constants import *
 from .frontend import open_skynet_rpc


-def pipeline_for(algo: str, mem_fraction: float = 1.0):
-    assert torch.cuda.is_available()
-    torch.cuda.empty_cache()
-    torch.cuda.set_per_process_memory_fraction(mem_fraction)
-    torch.backends.cuda.matmul.allow_tf32 = True
-    torch.backends.cudnn.allow_tf32 = True
-
-    params = {
-        'torch_dtype': torch.float16,
-        'safety_checker': None
-    }
-
-    if algo == 'stable':
-        params['revision'] = 'fp16'
-
-    pipe = StableDiffusionPipeline.from_pretrained(
-        ALGOS[algo], **params)
-    pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(
-        pipe.scheduler.config)
-    pipe.enable_vae_slicing()
-    return pipe.to('cuda')
+def init_upscaler(model_path: str = 'weights/RealESRGAN_x4plus.pth'):
+    return RealESRGANer(
+        scale=4,
+        model_path=model_path,
+        dni_weight=None,
+        model=RRDBNet(
+            num_in_ch=3,
+            num_out_ch=3,
+            num_feat=64,
+            num_block=23,
+            num_grow_ch=32,
+            scale=4
+        ),
+        half=True
+    )


 class DGPUComputeError(BaseException):
@@ -79,6 +76,7 @@ async def open_dgpu_node(
     logging.info(f'loading models...')

+    upscaler = init_upscaler()
     initial_algos = (
         initial_algos
         if initial_algos else DEFAULT_INITAL_ALGOS
@@ -91,8 +89,8 @@ async def open_dgpu_node(
         }
         logging.info(f'loaded {algo}.')

-    logging.info('memory summary:\n')
-    logging.info(torch.cuda.memory_summary())
+    logging.info('memory summary:')
+    logging.info('\n' + torch.cuda.memory_summary())

     async def gpu_compute_one(ireq: ImageGenRequest):
         if ireq.algo not in models:
@@ -118,6 +116,15 @@ async def open_dgpu_node(
                 num_inference_steps=ireq.step,
                 generator=torch.Generator("cuda").manual_seed(ireq.seed)
             ).images[0]
+
+            if ireq.upscaler == 'x4':
+                logging.info('performing upscale...')
+                up_img, _ = upscaler.enhance(
+                    convert_from_image_to_cv2(image), outscale=4)
+                image = convert_from_cv2_to_image(up_img)
+                logging.info('done')
+
             return image.tobytes()

         except BaseException as e:

View File

@@ -117,39 +117,40 @@ def validate_user_config_request(req: str):
     try:
         attr = params[1]

-        if attr == 'algo':
+        match attr:
+            case 'algo':
                 val = params[2]
                 if val not in ALGOS:
                     raise ConfigUnknownAlgorithm(f'no algo named {val}')

-        elif attr == 'step':
+            case 'step':
                 val = int(params[2])
                 val = max(min(val, MAX_STEP), MIN_STEP)

-        elif attr == 'width':
+            case 'width':
                 val = max(min(int(params[2]), MAX_WIDTH), 16)
                 if val % 8 != 0:
                     raise ConfigSizeDivisionByEight(
                         'size must be divisible by 8!')

-        elif attr == 'height':
+            case 'height':
                 val = max(min(int(params[2]), MAX_HEIGHT), 16)
                 if val % 8 != 0:
                     raise ConfigSizeDivisionByEight(
                         'size must be divisible by 8!')

-        elif attr == 'seed':
+            case 'seed':
                 val = params[2]
                 if val == 'auto':
                     val = None
                 else:
                     val = int(params[2])

-        elif attr == 'guidance':
+            case 'guidance':
                 val = float(params[2])
                 val = max(min(val, MAX_GUIDANCE), 0)

-        elif attr == 'upscaler':
+            case 'upscaler':
                 val = params[2]
                 if val == 'off':
                     val = None
@@ -157,7 +158,7 @@ def validate_user_config_request(req: str):
                     raise ConfigUnknownUpscaler(
                         f'\"{val}\" is not a valid upscaler')

-        else:
+            case _:
                 raise ConfigUnknownAttribute(
                     f'\"{attr}\" not a configurable parameter')

View File

@@ -26,13 +26,7 @@ from pprint import pformat
 import msgspec

-class Struct(
-    msgspec.Struct,
-
-    # https://jcristharif.com/msgspec/structs.html#tagged-unions
-    # tag='pikerstruct',
-    # tag=True,
-):
+class Struct(msgspec.Struct):
     '''
     A "human friendlier" (aka repl buddy) struct subtype.
     '''
@@ -88,7 +82,7 @@ class Struct(
 from OpenSSL.crypto import PKey, X509, verify, sign

-class AuthenticatedStruct(Struct):
+class AuthenticatedStruct(Struct, kw_only=True):
     cert: Optional[str] = None
     sig: Optional[str] = None
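
A minimal sketch of why kw_only=True helps here, assuming msgspec's keyword-only fields behave like dataclasses' (defaulted keyword-only fields on a base no longer block required fields on subclasses); SignedMessage is a hypothetical subclass for illustration only:

import msgspec
from typing import Optional

class AuthenticatedStruct(msgspec.Struct, kw_only=True):
    cert: Optional[str] = None
    sig: Optional[str] = None

class SignedMessage(AuthenticatedStruct):
    payload: str                       # required field allowed after the base's defaults

msg = SignedMessage('hello')           # cert/sig stay keyword-only and default to None
assert msg.cert is None and msg.sig is None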

View File

@@ -6,16 +6,55 @@ from typing import Optional
 from pathlib import Path

 import torch
+import numpy as np

 from PIL import Image
+from basicsr.archs.rrdbnet_arch import RRDBNet
 from diffusers import (
     StableDiffusionPipeline,
     StableDiffusionUpscalePipeline,
     EulerAncestralDiscreteScheduler
 )
+from realesrgan import RealESRGANer
 from huggingface_hub import login

-from .dgpu import pipeline_for
+from .constants import ALGOS
+
+
+def convert_from_cv2_to_image(img: np.ndarray) -> Image:
+    # return Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
+    return Image.fromarray(img)
+
+
+def convert_from_image_to_cv2(img: Image) -> np.ndarray:
+    # return cv2.cvtColor(numpy.array(img), cv2.COLOR_RGB2BGR)
+    return np.asarray(img)
+
+
+def pipeline_for(algo: str, mem_fraction: float = 1.0):
+    assert torch.cuda.is_available()
+    torch.cuda.empty_cache()
+    torch.cuda.set_per_process_memory_fraction(mem_fraction)
+    torch.backends.cuda.matmul.allow_tf32 = True
+    torch.backends.cudnn.allow_tf32 = True
+
+    params = {
+        'torch_dtype': torch.float16,
+        'safety_checker': None
+    }
+
+    if algo == 'stable':
+        params['revision'] = 'fp16'
+
+    pipe = StableDiffusionPipeline.from_pretrained(
+        ALGOS[algo], **params)
+    pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(
+        pipe.scheduler.config)
+    pipe.enable_vae_slicing()
+    return pipe.to('cuda')


 def txt2img(
@@ -51,11 +90,9 @@ def txt2img(

 def upscale(
-    hf_token: str,
-    prompt: str = 'a red old tractor in a sunny wheat field',
     img_path: str = 'input.png',
     output: str = 'output.png',
-    steps: int = 28
+    model_path: str = 'weights/RealESRGAN_x4plus.pth'
 ):
     assert torch.cuda.is_available()
     torch.cuda.empty_cache()
@@ -63,20 +100,26 @@ def upscale(
     torch.backends.cuda.matmul.allow_tf32 = True
     torch.backends.cudnn.allow_tf32 = True

-    login(token=hf_token)
-    pipe = StableDiffusionUpscalePipeline.from_pretrained(
-        'stabilityai/stable-diffusion-x4-upscaler',
-        revision="fp16", torch_dtype=torch.float16)
-    pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(
-        pipe.scheduler.config)
-    pipe = pipe.to('cuda')
-
-    prompt = prompt
-    image = pipe(
-        prompt,
-        image=Image.open(img_path).convert("RGB"),
-        num_inference_steps=steps
-    ).images[0]
+    input_img = Image.open(img_path).convert('RGB')
+
+    upscaler = RealESRGANer(
+        scale=4,
+        model_path=model_path,
+        dni_weight=None,
+        model=RRDBNet(
+            num_in_ch=3,
+            num_out_ch=3,
+            num_feat=64,
+            num_block=23,
+            num_grow_ch=32,
+            scale=4
+        ),
+        half=True)
+
+    up_img, _ = upscaler.enhance(
+        convert_from_image_to_cv2(input_img), outscale=4)
+
+    image = convert_from_cv2_to_image(up_img)

     image.save(output)
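
pipeline_for moves from the dgpu module into utils, so the import now runs one way (dgpu imports from utils rather than utils importing from dgpu), and upscale runs Real-ESRGAN locally instead of the hosted stable-diffusion-x4-upscaler, so it no longer needs a Hugging Face token. A usage sketch, assuming the package imports as skynet and using 'midj' only as an example ALGOS key from the tests:

from skynet.utils import pipeline_for, upscale

pipe = pipeline_for('midj')                          # fp16 StableDiffusionPipeline on CUDA
upscale(img_path='input.png', output='output.png')   # Real-ESRGAN x4, no HF token required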

View File

@@ -6,6 +6,7 @@ import json
 import base64
 import logging

+from typing import Optional
 from hashlib import sha256
 from functools import partial
@@ -42,7 +43,8 @@ async def check_request_img(
     uid: int = 0,
     width: int = 512,
     height: int = 512,
-    expect_unique=True
+    expect_unique = True,
+    upscaler: Optional[str] = None
 ):
     global _images
@@ -60,12 +62,16 @@ async def check_request_img(
             'guidance': 7.5,
             'seed': None,
             'algo': list(ALGOS.keys())[i],
-            'upscaler': None
+            'upscaler': upscaler
         })

     if 'error' in res.result:
         raise SkynetDGPUComputeError(json.dumps(res.result))

+    if upscaler == 'x4':
+        width *= 4
+        height *= 4
+
     img_raw = base64.b64decode(bytes.fromhex(res.result['img']))
     img_sha = sha256(img_raw).hexdigest()
     img = Image.frombytes(
@@ -80,6 +86,8 @@ async def check_request_img(
     assert len(img_raw) > 100000

+    return img
+

 @pytest.mark.parametrize(
     'dgpu_workers', [(1, ['midj'])], indirect=True)
@@ -123,6 +131,27 @@ async def test_dgpu_workers(dgpu_workers):
         await check_request_img(1)


+@pytest.mark.parametrize(
+    'dgpu_workers', [(1, ['midj'])], indirect=True)
+async def test_dgpu_worker_upscale(dgpu_workers):
+    '''Generate one image in a single dgpu worker and run it
+    through the x4 upscaler.
+    '''
+    async with open_skynet_rpc(
+        'test-ctx',
+        security=True,
+        cert_name='whitelist/testing',
+        key_name='testing'
+    ) as test_rpc:
+        await wait_for_dgpus(test_rpc, 1)
+
+        logging.error('UPSCALE')
+
+        img = await check_request_img(0, upscaler='x4')
+
+        assert img.size == (2048, 2048)
+
+
 @pytest.mark.parametrize(
     'dgpu_workers', [(2, ['midj'])], indirect=True)
 async def test_dgpu_workers_two(dgpu_workers):
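
The new assertion follows directly from the size arithmetic in check_request_img: a 512x512 request with the 'x4' upscaler comes back 4x larger per side, so the helper rebuilds the image from a (512*4) x (512*4) buffer and the test expects 2048x2048. A standalone sketch of that arithmetic (assuming the decoded buffer is 3-channel RGB, as the worker's image.tobytes() produces):

width = height = 512
upscaler = 'x4'
if upscaler == 'x4':
    width *= 4
    height *= 4
assert (width, height) == (2048, 2048)
assert width * height * 3 == 12_582_912   # raw RGB byte count of the upscaled image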