mirror of https://github.com/skygpu/skynet.git
Add upscaler
parent
1fc2020ed5
commit
a9c237b538
|
@ -3,6 +3,9 @@ from python:3.10.0
|
|||
|
||||
env DEBIAN_FRONTEND=noninteractive
|
||||
|
||||
run apt-get update && \
|
||||
apt-get install -y ffmpeg libsm6 libxext6
|
||||
|
||||
workdir /skynet
|
||||
|
||||
copy requirements.cuda* ./
|
||||
|
@ -27,3 +30,4 @@ env HF_HOME /hf_home
|
|||
|
||||
copy scripts scripts
|
||||
copy tests tests
|
||||
|
||||
|
|
|
@ -1,5 +1,7 @@
|
|||
scipy
|
||||
triton
|
||||
basicsr
|
||||
realesrgan
|
||||
accelerate
|
||||
transformers
|
||||
huggingface_hub
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
trio
|
||||
pynng
|
||||
numpy
|
||||
triopg
|
||||
aiohttp
|
||||
msgspec
|
||||
|
|
|
@ -204,11 +204,8 @@ async def open_rpc_service(sock, dgpu_bus, db_pool, tls_whitelist, tls_key):
|
|||
logging.info('txt2img')
|
||||
user_config = {**(await get_user_config(conn, user))}
|
||||
del user_config['id']
|
||||
prompt = req.params['prompt']
|
||||
req = ImageGenRequest(
|
||||
prompt=prompt,
|
||||
**user_config
|
||||
)
|
||||
user_config.update((k, req.params[k]) for k in req.params)
|
||||
req = ImageGenRequest(**user_config)
|
||||
rid, img, meta = await dgpu_stream_one_img(req)
|
||||
logging.info(f'done streaming {rid}')
|
||||
result = {
|
||||
|
@ -217,7 +214,7 @@ async def open_rpc_service(sock, dgpu_bus, db_pool, tls_whitelist, tls_key):
|
|||
'meta': meta
|
||||
}
|
||||
|
||||
await update_user_stats(conn, user, last_prompt=prompt)
|
||||
await update_user_stats(conn, user, last_prompt=user_config['prompt'])
|
||||
logging.info('updated user stats.')
|
||||
|
||||
case 'redo':
|
||||
|
|
|
@ -38,19 +38,14 @@ def txt2img(*args, **kwargs):
|
|||
utils.txt2img(os.environ['HF_TOKEN'], **kwargs)
|
||||
|
||||
@click.command()
|
||||
@click.option(
|
||||
'--prompt', '-p', default='a red old tractor in a sunny wheat field')
|
||||
@click.option('--input', '-i', default='input.png')
|
||||
@click.option('--output', '-o', default='output.png')
|
||||
@click.option('--steps', '-s', default=26)
|
||||
def upscale(prompt, input, output, steps):
|
||||
assert 'HF_TOKEN' in os.environ
|
||||
@click.option('--model', '-m', default='weights/RealESRGAN_x4plus.pth')
|
||||
def upscale(input, output, model):
|
||||
utils.upscale(
|
||||
os.environ['HF_TOKEN'],
|
||||
prompt=prompt,
|
||||
img_path=input,
|
||||
output=output,
|
||||
steps=steps)
|
||||
model_path=model)
|
||||
|
||||
|
||||
@skynet.group()
|
||||
|
|
|
@ -26,37 +26,34 @@ from diffusers import (
|
|||
StableDiffusionPipeline,
|
||||
EulerAncestralDiscreteScheduler
|
||||
)
|
||||
from realesrgan import RealESRGANer
|
||||
from basicsr.archs.rrdbnet_arch import RRDBNet
|
||||
from diffusers.models import UNet2DConditionModel
|
||||
|
||||
from .utils import (
|
||||
pipeline_for,
|
||||
convert_from_cv2_to_image, convert_from_image_to_cv2
|
||||
)
|
||||
from .structs import *
|
||||
from .constants import *
|
||||
from .frontend import open_skynet_rpc
|
||||
|
||||
|
||||
def pipeline_for(algo: str, mem_fraction: float = 1.0):
|
||||
assert torch.cuda.is_available()
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.set_per_process_memory_fraction(mem_fraction)
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
torch.backends.cudnn.allow_tf32 = True
|
||||
|
||||
params = {
|
||||
'torch_dtype': torch.float16,
|
||||
'safety_checker': None
|
||||
}
|
||||
|
||||
if algo == 'stable':
|
||||
params['revision'] = 'fp16'
|
||||
|
||||
pipe = StableDiffusionPipeline.from_pretrained(
|
||||
ALGOS[algo], **params)
|
||||
|
||||
pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(
|
||||
pipe.scheduler.config)
|
||||
|
||||
pipe.enable_vae_slicing()
|
||||
|
||||
return pipe.to('cuda')
|
||||
def init_upscaler(model_path: str = 'weights/RealESRGAN_x4plus.pth'):
|
||||
return RealESRGANer(
|
||||
scale=4,
|
||||
model_path=model_path,
|
||||
dni_weight=None,
|
||||
model=RRDBNet(
|
||||
num_in_ch=3,
|
||||
num_out_ch=3,
|
||||
num_feat=64,
|
||||
num_block=23,
|
||||
num_grow_ch=32,
|
||||
scale=4
|
||||
),
|
||||
half=True
|
||||
)
|
||||
|
||||
|
||||
class DGPUComputeError(BaseException):
|
||||
|
@ -79,6 +76,7 @@ async def open_dgpu_node(
|
|||
|
||||
logging.info(f'loading models...')
|
||||
|
||||
upscaler = init_upscaler()
|
||||
initial_algos = (
|
||||
initial_algos
|
||||
if initial_algos else DEFAULT_INITAL_ALGOS
|
||||
|
@ -91,8 +89,8 @@ async def open_dgpu_node(
|
|||
}
|
||||
logging.info(f'loaded {algo}.')
|
||||
|
||||
logging.info('memory summary:\n')
|
||||
logging.info(torch.cuda.memory_summary())
|
||||
logging.info('memory summary:')
|
||||
logging.info('\n' + torch.cuda.memory_summary())
|
||||
|
||||
async def gpu_compute_one(ireq: ImageGenRequest):
|
||||
if ireq.algo not in models:
|
||||
|
@ -118,6 +116,15 @@ async def open_dgpu_node(
|
|||
num_inference_steps=ireq.step,
|
||||
generator=torch.Generator("cuda").manual_seed(ireq.seed)
|
||||
).images[0]
|
||||
|
||||
if ireq.upscaler == 'x4':
|
||||
logging.info('performing upscale...')
|
||||
up_img, _ = upscaler.enhance(
|
||||
convert_from_image_to_cv2(image), outscale=4)
|
||||
|
||||
image = convert_from_cv2_to_image(up_img)
|
||||
logging.info('done')
|
||||
|
||||
return image.tobytes()
|
||||
|
||||
except BaseException as e:
|
||||
|
|
|
@ -117,49 +117,50 @@ def validate_user_config_request(req: str):
|
|||
try:
|
||||
attr = params[1]
|
||||
|
||||
if attr == 'algo':
|
||||
val = params[2]
|
||||
if val not in ALGOS:
|
||||
raise ConfigUnknownAlgorithm(f'no algo named {val}')
|
||||
match attr:
|
||||
case 'algo':
|
||||
val = params[2]
|
||||
if val not in ALGOS:
|
||||
raise ConfigUnknownAlgorithm(f'no algo named {val}')
|
||||
|
||||
elif attr == 'step':
|
||||
val = int(params[2])
|
||||
val = max(min(val, MAX_STEP), MIN_STEP)
|
||||
|
||||
elif attr == 'width':
|
||||
val = max(min(int(params[2]), MAX_WIDTH), 16)
|
||||
if val % 8 != 0:
|
||||
raise ConfigSizeDivisionByEight(
|
||||
'size must be divisible by 8!')
|
||||
|
||||
elif attr == 'height':
|
||||
val = max(min(int(params[2]), MAX_HEIGHT), 16)
|
||||
if val % 8 != 0:
|
||||
raise ConfigSizeDivisionByEight(
|
||||
'size must be divisible by 8!')
|
||||
|
||||
elif attr == 'seed':
|
||||
val = params[2]
|
||||
if val == 'auto':
|
||||
val = None
|
||||
else:
|
||||
case 'step':
|
||||
val = int(params[2])
|
||||
val = max(min(val, MAX_STEP), MIN_STEP)
|
||||
|
||||
elif attr == 'guidance':
|
||||
val = float(params[2])
|
||||
val = max(min(val, MAX_GUIDANCE), 0)
|
||||
case 'width':
|
||||
val = max(min(int(params[2]), MAX_WIDTH), 16)
|
||||
if val % 8 != 0:
|
||||
raise ConfigSizeDivisionByEight(
|
||||
'size must be divisible by 8!')
|
||||
|
||||
elif attr == 'upscaler':
|
||||
val = params[2]
|
||||
if val == 'off':
|
||||
val = None
|
||||
elif val != 'x4':
|
||||
raise ConfigUnknownUpscaler(
|
||||
f'\"{val}\" is not a valid upscaler')
|
||||
case 'height':
|
||||
val = max(min(int(params[2]), MAX_HEIGHT), 16)
|
||||
if val % 8 != 0:
|
||||
raise ConfigSizeDivisionByEight(
|
||||
'size must be divisible by 8!')
|
||||
|
||||
else:
|
||||
raise ConfigUnknownAttribute(
|
||||
f'\"{attr}\" not a configurable parameter')
|
||||
case 'seed':
|
||||
val = params[2]
|
||||
if val == 'auto':
|
||||
val = None
|
||||
else:
|
||||
val = int(params[2])
|
||||
|
||||
case 'guidance':
|
||||
val = float(params[2])
|
||||
val = max(min(val, MAX_GUIDANCE), 0)
|
||||
|
||||
case 'upscaler':
|
||||
val = params[2]
|
||||
if val == 'off':
|
||||
val = None
|
||||
elif val != 'x4':
|
||||
raise ConfigUnknownUpscaler(
|
||||
f'\"{val}\" is not a valid upscaler')
|
||||
|
||||
case _:
|
||||
raise ConfigUnknownAttribute(
|
||||
f'\"{attr}\" not a configurable parameter')
|
||||
|
||||
return attr, val, f'config updated! {attr} to {val}'
|
||||
|
||||
|
|
|
@ -26,13 +26,7 @@ from pprint import pformat
|
|||
import msgspec
|
||||
|
||||
|
||||
class Struct(
|
||||
msgspec.Struct,
|
||||
|
||||
# https://jcristharif.com/msgspec/structs.html#tagged-unions
|
||||
# tag='pikerstruct',
|
||||
# tag=True,
|
||||
):
|
||||
class Struct(msgspec.Struct):
|
||||
'''
|
||||
A "human friendlier" (aka repl buddy) struct subtype.
|
||||
'''
|
||||
|
@ -88,7 +82,7 @@ class Struct(
|
|||
from OpenSSL.crypto import PKey, X509, verify, sign
|
||||
|
||||
|
||||
class AuthenticatedStruct(Struct):
|
||||
class AuthenticatedStruct(Struct, kw_only=True):
|
||||
cert: Optional[str] = None
|
||||
sig: Optional[str] = None
|
||||
|
||||
|
|
|
@ -6,16 +6,55 @@ from typing import Optional
|
|||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
from PIL import Image
|
||||
from basicsr.archs.rrdbnet_arch import RRDBNet
|
||||
from diffusers import (
|
||||
StableDiffusionPipeline,
|
||||
StableDiffusionUpscalePipeline,
|
||||
EulerAncestralDiscreteScheduler
|
||||
)
|
||||
from realesrgan import RealESRGANer
|
||||
from huggingface_hub import login
|
||||
|
||||
from .dgpu import pipeline_for
|
||||
from .constants import ALGOS
|
||||
|
||||
|
||||
def convert_from_cv2_to_image(img: np.ndarray) -> Image:
|
||||
# return Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
|
||||
return Image.fromarray(img)
|
||||
|
||||
|
||||
def convert_from_image_to_cv2(img: Image) -> np.ndarray:
|
||||
# return cv2.cvtColor(numpy.array(img), cv2.COLOR_RGB2BGR)
|
||||
return np.asarray(img)
|
||||
|
||||
|
||||
def pipeline_for(algo: str, mem_fraction: float = 1.0):
|
||||
assert torch.cuda.is_available()
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.set_per_process_memory_fraction(mem_fraction)
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
torch.backends.cudnn.allow_tf32 = True
|
||||
|
||||
params = {
|
||||
'torch_dtype': torch.float16,
|
||||
'safety_checker': None
|
||||
}
|
||||
|
||||
if algo == 'stable':
|
||||
params['revision'] = 'fp16'
|
||||
|
||||
pipe = StableDiffusionPipeline.from_pretrained(
|
||||
ALGOS[algo], **params)
|
||||
|
||||
pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(
|
||||
pipe.scheduler.config)
|
||||
|
||||
pipe.enable_vae_slicing()
|
||||
|
||||
return pipe.to('cuda')
|
||||
|
||||
|
||||
def txt2img(
|
||||
|
@ -51,11 +90,9 @@ def txt2img(
|
|||
|
||||
|
||||
def upscale(
|
||||
hf_token: str,
|
||||
prompt: str = 'a red old tractor in a sunny wheat field',
|
||||
img_path: str = 'input.png',
|
||||
output: str = 'output.png',
|
||||
steps: int = 28
|
||||
model_path: str = 'weights/RealESRGAN_x4plus.pth'
|
||||
):
|
||||
assert torch.cuda.is_available()
|
||||
torch.cuda.empty_cache()
|
||||
|
@ -63,20 +100,26 @@ def upscale(
|
|||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
torch.backends.cudnn.allow_tf32 = True
|
||||
|
||||
login(token=hf_token)
|
||||
input_img = Image.open(img_path).convert('RGB')
|
||||
|
||||
pipe = StableDiffusionUpscalePipeline.from_pretrained(
|
||||
'stabilityai/stable-diffusion-x4-upscaler',
|
||||
revision="fp16", torch_dtype=torch.float16)
|
||||
pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(
|
||||
pipe.scheduler.config)
|
||||
pipe = pipe.to('cuda')
|
||||
upscaler = RealESRGANer(
|
||||
scale=4,
|
||||
model_path=model_path,
|
||||
dni_weight=None,
|
||||
model=RRDBNet(
|
||||
num_in_ch=3,
|
||||
num_out_ch=3,
|
||||
num_feat=64,
|
||||
num_block=23,
|
||||
num_grow_ch=32,
|
||||
scale=4
|
||||
),
|
||||
half=True)
|
||||
|
||||
up_img, _ = upscaler.enhance(
|
||||
convert_from_image_to_cv2(input_img), outscale=4)
|
||||
|
||||
image = convert_from_cv2_to_image(up_img)
|
||||
|
||||
prompt = prompt
|
||||
image = pipe(
|
||||
prompt,
|
||||
image=Image.open(img_path).convert("RGB"),
|
||||
num_inference_steps=steps
|
||||
).images[0]
|
||||
|
||||
image.save(output)
|
||||
|
|
|
@ -6,6 +6,7 @@ import json
|
|||
import base64
|
||||
import logging
|
||||
|
||||
from typing import Optional
|
||||
from hashlib import sha256
|
||||
from functools import partial
|
||||
|
||||
|
@ -42,7 +43,8 @@ async def check_request_img(
|
|||
uid: int = 0,
|
||||
width: int = 512,
|
||||
height: int = 512,
|
||||
expect_unique=True
|
||||
expect_unique = True,
|
||||
upscaler: Optional[str] = None
|
||||
):
|
||||
global _images
|
||||
|
||||
|
@ -60,12 +62,16 @@ async def check_request_img(
|
|||
'guidance': 7.5,
|
||||
'seed': None,
|
||||
'algo': list(ALGOS.keys())[i],
|
||||
'upscaler': None
|
||||
'upscaler': upscaler
|
||||
})
|
||||
|
||||
if 'error' in res.result:
|
||||
raise SkynetDGPUComputeError(json.dumps(res.result))
|
||||
|
||||
if upscaler == 'x4':
|
||||
width *= 4
|
||||
height *= 4
|
||||
|
||||
img_raw = base64.b64decode(bytes.fromhex(res.result['img']))
|
||||
img_sha = sha256(img_raw).hexdigest()
|
||||
img = Image.frombytes(
|
||||
|
@ -80,6 +86,8 @@ async def check_request_img(
|
|||
|
||||
assert len(img_raw) > 100000
|
||||
|
||||
return img
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'dgpu_workers', [(1, ['midj'])], indirect=True)
|
||||
|
@ -123,6 +131,27 @@ async def test_dgpu_workers(dgpu_workers):
|
|||
await check_request_img(1)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'dgpu_workers', [(1, ['midj'])], indirect=True)
|
||||
async def test_dgpu_worker_upscale(dgpu_workers):
|
||||
'''Generate two images in a single dgpu worker using
|
||||
two different models.
|
||||
'''
|
||||
|
||||
async with open_skynet_rpc(
|
||||
'test-ctx',
|
||||
security=True,
|
||||
cert_name='whitelist/testing',
|
||||
key_name='testing'
|
||||
) as test_rpc:
|
||||
await wait_for_dgpus(test_rpc, 1)
|
||||
logging.error('UPSCALE')
|
||||
|
||||
img = await check_request_img(0, upscaler='x4')
|
||||
|
||||
assert img.size == (2048, 2048)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'dgpu_workers', [(2, ['midj'])], indirect=True)
|
||||
async def test_dgpu_workers_two(dgpu_workers):
|
||||
|
|
Loading…
Reference in New Issue