mirror of https://github.com/skygpu/skynet.git
Add upscaler
parent
1fc2020ed5
commit
a9c237b538
|
@ -3,6 +3,9 @@ from python:3.10.0
|
||||||
|
|
||||||
env DEBIAN_FRONTEND=noninteractive
|
env DEBIAN_FRONTEND=noninteractive
|
||||||
|
|
||||||
|
run apt-get update && \
|
||||||
|
apt-get install -y ffmpeg libsm6 libxext6
|
||||||
|
|
||||||
workdir /skynet
|
workdir /skynet
|
||||||
|
|
||||||
copy requirements.cuda* ./
|
copy requirements.cuda* ./
|
||||||
|
@ -27,3 +30,4 @@ env HF_HOME /hf_home
|
||||||
|
|
||||||
copy scripts scripts
|
copy scripts scripts
|
||||||
copy tests tests
|
copy tests tests
|
||||||
|
|
||||||
|
|
|
@ -1,5 +1,7 @@
|
||||||
scipy
|
scipy
|
||||||
triton
|
triton
|
||||||
|
basicsr
|
||||||
|
realesrgan
|
||||||
accelerate
|
accelerate
|
||||||
transformers
|
transformers
|
||||||
huggingface_hub
|
huggingface_hub
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
trio
|
trio
|
||||||
pynng
|
pynng
|
||||||
|
numpy
|
||||||
triopg
|
triopg
|
||||||
aiohttp
|
aiohttp
|
||||||
msgspec
|
msgspec
|
||||||
|
|
|
@ -204,11 +204,8 @@ async def open_rpc_service(sock, dgpu_bus, db_pool, tls_whitelist, tls_key):
|
||||||
logging.info('txt2img')
|
logging.info('txt2img')
|
||||||
user_config = {**(await get_user_config(conn, user))}
|
user_config = {**(await get_user_config(conn, user))}
|
||||||
del user_config['id']
|
del user_config['id']
|
||||||
prompt = req.params['prompt']
|
user_config.update((k, req.params[k]) for k in req.params)
|
||||||
req = ImageGenRequest(
|
req = ImageGenRequest(**user_config)
|
||||||
prompt=prompt,
|
|
||||||
**user_config
|
|
||||||
)
|
|
||||||
rid, img, meta = await dgpu_stream_one_img(req)
|
rid, img, meta = await dgpu_stream_one_img(req)
|
||||||
logging.info(f'done streaming {rid}')
|
logging.info(f'done streaming {rid}')
|
||||||
result = {
|
result = {
|
||||||
|
@ -217,7 +214,7 @@ async def open_rpc_service(sock, dgpu_bus, db_pool, tls_whitelist, tls_key):
|
||||||
'meta': meta
|
'meta': meta
|
||||||
}
|
}
|
||||||
|
|
||||||
await update_user_stats(conn, user, last_prompt=prompt)
|
await update_user_stats(conn, user, last_prompt=user_config['prompt'])
|
||||||
logging.info('updated user stats.')
|
logging.info('updated user stats.')
|
||||||
|
|
||||||
case 'redo':
|
case 'redo':
|
||||||
|
|
|
@ -38,19 +38,14 @@ def txt2img(*args, **kwargs):
|
||||||
utils.txt2img(os.environ['HF_TOKEN'], **kwargs)
|
utils.txt2img(os.environ['HF_TOKEN'], **kwargs)
|
||||||
|
|
||||||
@click.command()
|
@click.command()
|
||||||
@click.option(
|
|
||||||
'--prompt', '-p', default='a red old tractor in a sunny wheat field')
|
|
||||||
@click.option('--input', '-i', default='input.png')
|
@click.option('--input', '-i', default='input.png')
|
||||||
@click.option('--output', '-o', default='output.png')
|
@click.option('--output', '-o', default='output.png')
|
||||||
@click.option('--steps', '-s', default=26)
|
@click.option('--model', '-m', default='weights/RealESRGAN_x4plus.pth')
|
||||||
def upscale(prompt, input, output, steps):
|
def upscale(input, output, model):
|
||||||
assert 'HF_TOKEN' in os.environ
|
|
||||||
utils.upscale(
|
utils.upscale(
|
||||||
os.environ['HF_TOKEN'],
|
|
||||||
prompt=prompt,
|
|
||||||
img_path=input,
|
img_path=input,
|
||||||
output=output,
|
output=output,
|
||||||
steps=steps)
|
model_path=model)
|
||||||
|
|
||||||
|
|
||||||
@skynet.group()
|
@skynet.group()
|
||||||
|
|
|
@ -26,37 +26,34 @@ from diffusers import (
|
||||||
StableDiffusionPipeline,
|
StableDiffusionPipeline,
|
||||||
EulerAncestralDiscreteScheduler
|
EulerAncestralDiscreteScheduler
|
||||||
)
|
)
|
||||||
|
from realesrgan import RealESRGANer
|
||||||
|
from basicsr.archs.rrdbnet_arch import RRDBNet
|
||||||
from diffusers.models import UNet2DConditionModel
|
from diffusers.models import UNet2DConditionModel
|
||||||
|
|
||||||
|
from .utils import (
|
||||||
|
pipeline_for,
|
||||||
|
convert_from_cv2_to_image, convert_from_image_to_cv2
|
||||||
|
)
|
||||||
from .structs import *
|
from .structs import *
|
||||||
from .constants import *
|
from .constants import *
|
||||||
from .frontend import open_skynet_rpc
|
from .frontend import open_skynet_rpc
|
||||||
|
|
||||||
|
|
||||||
def pipeline_for(algo: str, mem_fraction: float = 1.0):
|
def init_upscaler(model_path: str = 'weights/RealESRGAN_x4plus.pth'):
|
||||||
assert torch.cuda.is_available()
|
return RealESRGANer(
|
||||||
torch.cuda.empty_cache()
|
scale=4,
|
||||||
torch.cuda.set_per_process_memory_fraction(mem_fraction)
|
model_path=model_path,
|
||||||
torch.backends.cuda.matmul.allow_tf32 = True
|
dni_weight=None,
|
||||||
torch.backends.cudnn.allow_tf32 = True
|
model=RRDBNet(
|
||||||
|
num_in_ch=3,
|
||||||
params = {
|
num_out_ch=3,
|
||||||
'torch_dtype': torch.float16,
|
num_feat=64,
|
||||||
'safety_checker': None
|
num_block=23,
|
||||||
}
|
num_grow_ch=32,
|
||||||
|
scale=4
|
||||||
if algo == 'stable':
|
),
|
||||||
params['revision'] = 'fp16'
|
half=True
|
||||||
|
)
|
||||||
pipe = StableDiffusionPipeline.from_pretrained(
|
|
||||||
ALGOS[algo], **params)
|
|
||||||
|
|
||||||
pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(
|
|
||||||
pipe.scheduler.config)
|
|
||||||
|
|
||||||
pipe.enable_vae_slicing()
|
|
||||||
|
|
||||||
return pipe.to('cuda')
|
|
||||||
|
|
||||||
|
|
||||||
class DGPUComputeError(BaseException):
|
class DGPUComputeError(BaseException):
|
||||||
|
@ -79,6 +76,7 @@ async def open_dgpu_node(
|
||||||
|
|
||||||
logging.info(f'loading models...')
|
logging.info(f'loading models...')
|
||||||
|
|
||||||
|
upscaler = init_upscaler()
|
||||||
initial_algos = (
|
initial_algos = (
|
||||||
initial_algos
|
initial_algos
|
||||||
if initial_algos else DEFAULT_INITAL_ALGOS
|
if initial_algos else DEFAULT_INITAL_ALGOS
|
||||||
|
@ -91,8 +89,8 @@ async def open_dgpu_node(
|
||||||
}
|
}
|
||||||
logging.info(f'loaded {algo}.')
|
logging.info(f'loaded {algo}.')
|
||||||
|
|
||||||
logging.info('memory summary:\n')
|
logging.info('memory summary:')
|
||||||
logging.info(torch.cuda.memory_summary())
|
logging.info('\n' + torch.cuda.memory_summary())
|
||||||
|
|
||||||
async def gpu_compute_one(ireq: ImageGenRequest):
|
async def gpu_compute_one(ireq: ImageGenRequest):
|
||||||
if ireq.algo not in models:
|
if ireq.algo not in models:
|
||||||
|
@ -118,6 +116,15 @@ async def open_dgpu_node(
|
||||||
num_inference_steps=ireq.step,
|
num_inference_steps=ireq.step,
|
||||||
generator=torch.Generator("cuda").manual_seed(ireq.seed)
|
generator=torch.Generator("cuda").manual_seed(ireq.seed)
|
||||||
).images[0]
|
).images[0]
|
||||||
|
|
||||||
|
if ireq.upscaler == 'x4':
|
||||||
|
logging.info('performing upscale...')
|
||||||
|
up_img, _ = upscaler.enhance(
|
||||||
|
convert_from_image_to_cv2(image), outscale=4)
|
||||||
|
|
||||||
|
image = convert_from_cv2_to_image(up_img)
|
||||||
|
logging.info('done')
|
||||||
|
|
||||||
return image.tobytes()
|
return image.tobytes()
|
||||||
|
|
||||||
except BaseException as e:
|
except BaseException as e:
|
||||||
|
|
|
@ -117,39 +117,40 @@ def validate_user_config_request(req: str):
|
||||||
try:
|
try:
|
||||||
attr = params[1]
|
attr = params[1]
|
||||||
|
|
||||||
if attr == 'algo':
|
match attr:
|
||||||
|
case 'algo':
|
||||||
val = params[2]
|
val = params[2]
|
||||||
if val not in ALGOS:
|
if val not in ALGOS:
|
||||||
raise ConfigUnknownAlgorithm(f'no algo named {val}')
|
raise ConfigUnknownAlgorithm(f'no algo named {val}')
|
||||||
|
|
||||||
elif attr == 'step':
|
case 'step':
|
||||||
val = int(params[2])
|
val = int(params[2])
|
||||||
val = max(min(val, MAX_STEP), MIN_STEP)
|
val = max(min(val, MAX_STEP), MIN_STEP)
|
||||||
|
|
||||||
elif attr == 'width':
|
case 'width':
|
||||||
val = max(min(int(params[2]), MAX_WIDTH), 16)
|
val = max(min(int(params[2]), MAX_WIDTH), 16)
|
||||||
if val % 8 != 0:
|
if val % 8 != 0:
|
||||||
raise ConfigSizeDivisionByEight(
|
raise ConfigSizeDivisionByEight(
|
||||||
'size must be divisible by 8!')
|
'size must be divisible by 8!')
|
||||||
|
|
||||||
elif attr == 'height':
|
case 'height':
|
||||||
val = max(min(int(params[2]), MAX_HEIGHT), 16)
|
val = max(min(int(params[2]), MAX_HEIGHT), 16)
|
||||||
if val % 8 != 0:
|
if val % 8 != 0:
|
||||||
raise ConfigSizeDivisionByEight(
|
raise ConfigSizeDivisionByEight(
|
||||||
'size must be divisible by 8!')
|
'size must be divisible by 8!')
|
||||||
|
|
||||||
elif attr == 'seed':
|
case 'seed':
|
||||||
val = params[2]
|
val = params[2]
|
||||||
if val == 'auto':
|
if val == 'auto':
|
||||||
val = None
|
val = None
|
||||||
else:
|
else:
|
||||||
val = int(params[2])
|
val = int(params[2])
|
||||||
|
|
||||||
elif attr == 'guidance':
|
case 'guidance':
|
||||||
val = float(params[2])
|
val = float(params[2])
|
||||||
val = max(min(val, MAX_GUIDANCE), 0)
|
val = max(min(val, MAX_GUIDANCE), 0)
|
||||||
|
|
||||||
elif attr == 'upscaler':
|
case 'upscaler':
|
||||||
val = params[2]
|
val = params[2]
|
||||||
if val == 'off':
|
if val == 'off':
|
||||||
val = None
|
val = None
|
||||||
|
@ -157,7 +158,7 @@ def validate_user_config_request(req: str):
|
||||||
raise ConfigUnknownUpscaler(
|
raise ConfigUnknownUpscaler(
|
||||||
f'\"{val}\" is not a valid upscaler')
|
f'\"{val}\" is not a valid upscaler')
|
||||||
|
|
||||||
else:
|
case _:
|
||||||
raise ConfigUnknownAttribute(
|
raise ConfigUnknownAttribute(
|
||||||
f'\"{attr}\" not a configurable parameter')
|
f'\"{attr}\" not a configurable parameter')
|
||||||
|
|
||||||
|
|
|
@ -26,13 +26,7 @@ from pprint import pformat
|
||||||
import msgspec
|
import msgspec
|
||||||
|
|
||||||
|
|
||||||
class Struct(
|
class Struct(msgspec.Struct):
|
||||||
msgspec.Struct,
|
|
||||||
|
|
||||||
# https://jcristharif.com/msgspec/structs.html#tagged-unions
|
|
||||||
# tag='pikerstruct',
|
|
||||||
# tag=True,
|
|
||||||
):
|
|
||||||
'''
|
'''
|
||||||
A "human friendlier" (aka repl buddy) struct subtype.
|
A "human friendlier" (aka repl buddy) struct subtype.
|
||||||
'''
|
'''
|
||||||
|
@ -88,7 +82,7 @@ class Struct(
|
||||||
from OpenSSL.crypto import PKey, X509, verify, sign
|
from OpenSSL.crypto import PKey, X509, verify, sign
|
||||||
|
|
||||||
|
|
||||||
class AuthenticatedStruct(Struct):
|
class AuthenticatedStruct(Struct, kw_only=True):
|
||||||
cert: Optional[str] = None
|
cert: Optional[str] = None
|
||||||
sig: Optional[str] = None
|
sig: Optional[str] = None
|
||||||
|
|
||||||
|
|
|
@ -6,16 +6,55 @@ from typing import Optional
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
from basicsr.archs.rrdbnet_arch import RRDBNet
|
||||||
from diffusers import (
|
from diffusers import (
|
||||||
StableDiffusionPipeline,
|
StableDiffusionPipeline,
|
||||||
StableDiffusionUpscalePipeline,
|
StableDiffusionUpscalePipeline,
|
||||||
EulerAncestralDiscreteScheduler
|
EulerAncestralDiscreteScheduler
|
||||||
)
|
)
|
||||||
|
from realesrgan import RealESRGANer
|
||||||
from huggingface_hub import login
|
from huggingface_hub import login
|
||||||
|
|
||||||
from .dgpu import pipeline_for
|
from .constants import ALGOS
|
||||||
|
|
||||||
|
|
||||||
|
def convert_from_cv2_to_image(img: np.ndarray) -> Image:
|
||||||
|
# return Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
|
||||||
|
return Image.fromarray(img)
|
||||||
|
|
||||||
|
|
||||||
|
def convert_from_image_to_cv2(img: Image) -> np.ndarray:
|
||||||
|
# return cv2.cvtColor(numpy.array(img), cv2.COLOR_RGB2BGR)
|
||||||
|
return np.asarray(img)
|
||||||
|
|
||||||
|
|
||||||
|
def pipeline_for(algo: str, mem_fraction: float = 1.0):
|
||||||
|
assert torch.cuda.is_available()
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
torch.cuda.set_per_process_memory_fraction(mem_fraction)
|
||||||
|
torch.backends.cuda.matmul.allow_tf32 = True
|
||||||
|
torch.backends.cudnn.allow_tf32 = True
|
||||||
|
|
||||||
|
params = {
|
||||||
|
'torch_dtype': torch.float16,
|
||||||
|
'safety_checker': None
|
||||||
|
}
|
||||||
|
|
||||||
|
if algo == 'stable':
|
||||||
|
params['revision'] = 'fp16'
|
||||||
|
|
||||||
|
pipe = StableDiffusionPipeline.from_pretrained(
|
||||||
|
ALGOS[algo], **params)
|
||||||
|
|
||||||
|
pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(
|
||||||
|
pipe.scheduler.config)
|
||||||
|
|
||||||
|
pipe.enable_vae_slicing()
|
||||||
|
|
||||||
|
return pipe.to('cuda')
|
||||||
|
|
||||||
|
|
||||||
def txt2img(
|
def txt2img(
|
||||||
|
@ -51,11 +90,9 @@ def txt2img(
|
||||||
|
|
||||||
|
|
||||||
def upscale(
|
def upscale(
|
||||||
hf_token: str,
|
|
||||||
prompt: str = 'a red old tractor in a sunny wheat field',
|
|
||||||
img_path: str = 'input.png',
|
img_path: str = 'input.png',
|
||||||
output: str = 'output.png',
|
output: str = 'output.png',
|
||||||
steps: int = 28
|
model_path: str = 'weights/RealESRGAN_x4plus.pth'
|
||||||
):
|
):
|
||||||
assert torch.cuda.is_available()
|
assert torch.cuda.is_available()
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
@ -63,20 +100,26 @@ def upscale(
|
||||||
torch.backends.cuda.matmul.allow_tf32 = True
|
torch.backends.cuda.matmul.allow_tf32 = True
|
||||||
torch.backends.cudnn.allow_tf32 = True
|
torch.backends.cudnn.allow_tf32 = True
|
||||||
|
|
||||||
login(token=hf_token)
|
input_img = Image.open(img_path).convert('RGB')
|
||||||
|
|
||||||
pipe = StableDiffusionUpscalePipeline.from_pretrained(
|
upscaler = RealESRGANer(
|
||||||
'stabilityai/stable-diffusion-x4-upscaler',
|
scale=4,
|
||||||
revision="fp16", torch_dtype=torch.float16)
|
model_path=model_path,
|
||||||
pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(
|
dni_weight=None,
|
||||||
pipe.scheduler.config)
|
model=RRDBNet(
|
||||||
pipe = pipe.to('cuda')
|
num_in_ch=3,
|
||||||
|
num_out_ch=3,
|
||||||
|
num_feat=64,
|
||||||
|
num_block=23,
|
||||||
|
num_grow_ch=32,
|
||||||
|
scale=4
|
||||||
|
),
|
||||||
|
half=True)
|
||||||
|
|
||||||
|
up_img, _ = upscaler.enhance(
|
||||||
|
convert_from_image_to_cv2(input_img), outscale=4)
|
||||||
|
|
||||||
|
image = convert_from_cv2_to_image(up_img)
|
||||||
|
|
||||||
prompt = prompt
|
|
||||||
image = pipe(
|
|
||||||
prompt,
|
|
||||||
image=Image.open(img_path).convert("RGB"),
|
|
||||||
num_inference_steps=steps
|
|
||||||
).images[0]
|
|
||||||
|
|
||||||
image.save(output)
|
image.save(output)
|
||||||
|
|
|
@ -6,6 +6,7 @@ import json
|
||||||
import base64
|
import base64
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
from hashlib import sha256
|
from hashlib import sha256
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
|
||||||
|
@ -42,7 +43,8 @@ async def check_request_img(
|
||||||
uid: int = 0,
|
uid: int = 0,
|
||||||
width: int = 512,
|
width: int = 512,
|
||||||
height: int = 512,
|
height: int = 512,
|
||||||
expect_unique=True
|
expect_unique = True,
|
||||||
|
upscaler: Optional[str] = None
|
||||||
):
|
):
|
||||||
global _images
|
global _images
|
||||||
|
|
||||||
|
@ -60,12 +62,16 @@ async def check_request_img(
|
||||||
'guidance': 7.5,
|
'guidance': 7.5,
|
||||||
'seed': None,
|
'seed': None,
|
||||||
'algo': list(ALGOS.keys())[i],
|
'algo': list(ALGOS.keys())[i],
|
||||||
'upscaler': None
|
'upscaler': upscaler
|
||||||
})
|
})
|
||||||
|
|
||||||
if 'error' in res.result:
|
if 'error' in res.result:
|
||||||
raise SkynetDGPUComputeError(json.dumps(res.result))
|
raise SkynetDGPUComputeError(json.dumps(res.result))
|
||||||
|
|
||||||
|
if upscaler == 'x4':
|
||||||
|
width *= 4
|
||||||
|
height *= 4
|
||||||
|
|
||||||
img_raw = base64.b64decode(bytes.fromhex(res.result['img']))
|
img_raw = base64.b64decode(bytes.fromhex(res.result['img']))
|
||||||
img_sha = sha256(img_raw).hexdigest()
|
img_sha = sha256(img_raw).hexdigest()
|
||||||
img = Image.frombytes(
|
img = Image.frombytes(
|
||||||
|
@ -80,6 +86,8 @@ async def check_request_img(
|
||||||
|
|
||||||
assert len(img_raw) > 100000
|
assert len(img_raw) > 100000
|
||||||
|
|
||||||
|
return img
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
'dgpu_workers', [(1, ['midj'])], indirect=True)
|
'dgpu_workers', [(1, ['midj'])], indirect=True)
|
||||||
|
@ -123,6 +131,27 @@ async def test_dgpu_workers(dgpu_workers):
|
||||||
await check_request_img(1)
|
await check_request_img(1)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
'dgpu_workers', [(1, ['midj'])], indirect=True)
|
||||||
|
async def test_dgpu_worker_upscale(dgpu_workers):
|
||||||
|
'''Generate two images in a single dgpu worker using
|
||||||
|
two different models.
|
||||||
|
'''
|
||||||
|
|
||||||
|
async with open_skynet_rpc(
|
||||||
|
'test-ctx',
|
||||||
|
security=True,
|
||||||
|
cert_name='whitelist/testing',
|
||||||
|
key_name='testing'
|
||||||
|
) as test_rpc:
|
||||||
|
await wait_for_dgpus(test_rpc, 1)
|
||||||
|
logging.error('UPSCALE')
|
||||||
|
|
||||||
|
img = await check_request_img(0, upscaler='x4')
|
||||||
|
|
||||||
|
assert img.size == (2048, 2048)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
'dgpu_workers', [(2, ['midj'])], indirect=True)
|
'dgpu_workers', [(2, ['midj'])], indirect=True)
|
||||||
async def test_dgpu_workers_two(dgpu_workers):
|
async def test_dgpu_workers_two(dgpu_workers):
|
||||||
|
|
Loading…
Reference in New Issue