diff --git a/Dockerfile.runtime+cuda b/Dockerfile.runtime+cuda
index 48520cd..f8e9e93 100644
--- a/Dockerfile.runtime+cuda
+++ b/Dockerfile.runtime+cuda
@@ -3,6 +3,9 @@ from python:3.10.0
 
 env DEBIAN_FRONTEND=noninteractive
 
+run apt-get update && \
+    apt-get install -y ffmpeg libsm6 libxext6
+
 workdir /skynet
 
 copy requirements.cuda* ./
@@ -27,3 +30,4 @@ env HF_HOME /hf_home
 
 copy scripts scripts
 copy tests tests
+
diff --git a/requirements.cuda.0.txt b/requirements.cuda.0.txt
index e31de88..e0f47c3 100644
--- a/requirements.cuda.0.txt
+++ b/requirements.cuda.0.txt
@@ -1,5 +1,7 @@
 scipy
 triton
+basicsr
+realesrgan
 accelerate
 transformers
 huggingface_hub
diff --git a/requirements.txt b/requirements.txt
index 650f6ab..8865f1a 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,5 +1,6 @@
 trio
 pynng
+numpy
 triopg
 aiohttp
 msgspec
diff --git a/skynet/brain.py b/skynet/brain.py
index 91ed253..1d8d8a7 100644
--- a/skynet/brain.py
+++ b/skynet/brain.py
@@ -204,11 +204,8 @@ async def open_rpc_service(sock, dgpu_bus, db_pool, tls_whitelist, tls_key):
                         logging.info('txt2img')
                         user_config = {**(await get_user_config(conn, user))}
                         del user_config['id']
-                        prompt = req.params['prompt']
-                        req = ImageGenRequest(
-                            prompt=prompt,
-                            **user_config
-                        )
+                        user_config.update((k, req.params[k]) for k in req.params)
+                        req = ImageGenRequest(**user_config)
                         rid, img, meta = await dgpu_stream_one_img(req)
                         logging.info(f'done streaming {rid}')
                         result = {
@@ -217,7 +214,7 @@
                             'meta': meta
                         }
 
-                        await update_user_stats(conn, user, last_prompt=prompt)
+                        await update_user_stats(conn, user, last_prompt=user_config['prompt'])
                         logging.info('updated user stats.')
 
                     case 'redo':
diff --git a/skynet/cli.py b/skynet/cli.py
index cfb786e..a2985d9 100644
--- a/skynet/cli.py
+++ b/skynet/cli.py
@@ -38,19 +38,14 @@ def txt2img(*args, **kwargs):
     utils.txt2img(os.environ['HF_TOKEN'], **kwargs)
 
 @click.command()
-@click.option(
-    '--prompt', '-p', default='a red old tractor in a sunny wheat field')
 @click.option('--input', '-i', default='input.png')
 @click.option('--output', '-o', default='output.png')
-@click.option('--steps', '-s', default=26)
-def upscale(prompt, input, output, steps):
-    assert 'HF_TOKEN' in os.environ
+@click.option('--model', '-m', default='weights/RealESRGAN_x4plus.pth')
+def upscale(input, output, model):
     utils.upscale(
-        os.environ['HF_TOKEN'],
-        prompt=prompt,
         img_path=input,
         output=output,
-        steps=steps)
+        model_path=model)
 
 
 @skynet.group()
diff --git a/skynet/dgpu.py b/skynet/dgpu.py
index 4beb8f3..1a04367 100644
--- a/skynet/dgpu.py
+++ b/skynet/dgpu.py
@@ -26,37 +26,34 @@ from diffusers import (
     StableDiffusionPipeline,
     EulerAncestralDiscreteScheduler
 )
+from realesrgan import RealESRGANer
+from basicsr.archs.rrdbnet_arch import RRDBNet
 from diffusers.models import UNet2DConditionModel
 
+from .utils import (
+    pipeline_for,
+    convert_from_cv2_to_image, convert_from_image_to_cv2
+)
 from .structs import *
 from .constants import *
 from .frontend import open_skynet_rpc
 
 
-def pipeline_for(algo: str, mem_fraction: float = 1.0):
-    assert torch.cuda.is_available()
-    torch.cuda.empty_cache()
-    torch.cuda.set_per_process_memory_fraction(mem_fraction)
-    torch.backends.cuda.matmul.allow_tf32 = True
-    torch.backends.cudnn.allow_tf32 = True
-
-    params = {
-        'torch_dtype': torch.float16,
-        'safety_checker': None
-    }
-
-    if algo == 'stable':
-        params['revision'] = 'fp16'
-
-    pipe = StableDiffusionPipeline.from_pretrained(
-        ALGOS[algo], **params)
-
-    pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(
-        pipe.scheduler.config)
-
-    pipe.enable_vae_slicing()
-
-    return pipe.to('cuda')
+def init_upscaler(model_path: str = 'weights/RealESRGAN_x4plus.pth'):
+    return RealESRGANer(
+        scale=4,
+        model_path=model_path,
+        dni_weight=None,
+        model=RRDBNet(
+            num_in_ch=3,
+            num_out_ch=3,
+            num_feat=64,
+            num_block=23,
+            num_grow_ch=32,
+            scale=4
+        ),
+        half=True
+    )
 
 
 class DGPUComputeError(BaseException):
@@ -79,6 +76,7 @@ async def open_dgpu_node(
 
     logging.info(f'loading models...')
 
+    upscaler = init_upscaler()
     initial_algos = (
         initial_algos
         if initial_algos else DEFAULT_INITAL_ALGOS
@@ -91,8 +89,8 @@
         }
         logging.info(f'loaded {algo}.')
 
-    logging.info('memory summary:\n')
-    logging.info(torch.cuda.memory_summary())
+    logging.info('memory summary:')
+    logging.info('\n' + torch.cuda.memory_summary())
 
     async def gpu_compute_one(ireq: ImageGenRequest):
         if ireq.algo not in models:
@@ -118,6 +116,15 @@
                 num_inference_steps=ireq.step,
                 generator=torch.Generator("cuda").manual_seed(ireq.seed)
             ).images[0]
+
+            if ireq.upscaler == 'x4':
+                logging.info('performing upscale...')
+                up_img, _ = upscaler.enhance(
+                    convert_from_image_to_cv2(image), outscale=4)
+
+                image = convert_from_cv2_to_image(up_img)
+                logging.info('done')
+
             return image.tobytes()
 
         except BaseException as e:
diff --git a/skynet/frontend/__init__.py b/skynet/frontend/__init__.py
index ceb47eb..f238186 100644
--- a/skynet/frontend/__init__.py
+++ b/skynet/frontend/__init__.py
@@ -117,49 +117,50 @@ def validate_user_config_request(req: str):
     try:
         attr = params[1]
 
-        if attr == 'algo':
-            val = params[2]
-            if val not in ALGOS:
-                raise ConfigUnknownAlgorithm(f'no algo named {val}')
+        match attr:
+            case 'algo':
+                val = params[2]
+                if val not in ALGOS:
+                    raise ConfigUnknownAlgorithm(f'no algo named {val}')
 
-        elif attr == 'step':
-            val = int(params[2])
-            val = max(min(val, MAX_STEP), MIN_STEP)
-
-        elif attr == 'width':
-            val = max(min(int(params[2]), MAX_WIDTH), 16)
-            if val % 8 != 0:
-                raise ConfigSizeDivisionByEight(
-                    'size must be divisible by 8!')
-
-        elif attr == 'height':
-            val = max(min(int(params[2]), MAX_HEIGHT), 16)
-            if val % 8 != 0:
-                raise ConfigSizeDivisionByEight(
-                    'size must be divisible by 8!')
-
-        elif attr == 'seed':
-            val = params[2]
-            if val == 'auto':
-                val = None
-            else:
+            case 'step':
                 val = int(params[2])
+                val = max(min(val, MAX_STEP), MIN_STEP)
 
-        elif attr == 'guidance':
-            val = float(params[2])
-            val = max(min(val, MAX_GUIDANCE), 0)
+            case 'width':
+                val = max(min(int(params[2]), MAX_WIDTH), 16)
+                if val % 8 != 0:
+                    raise ConfigSizeDivisionByEight(
+                        'size must be divisible by 8!')
 
-        elif attr == 'upscaler':
-            val = params[2]
-            if val == 'off':
-                val = None
-            elif val != 'x4':
-                raise ConfigUnknownUpscaler(
-                    f'\"{val}\" is not a valid upscaler')
+            case 'height':
+                val = max(min(int(params[2]), MAX_HEIGHT), 16)
+                if val % 8 != 0:
+                    raise ConfigSizeDivisionByEight(
+                        'size must be divisible by 8!')
 
-        else:
-            raise ConfigUnknownAttribute(
-                f'\"{attr}\" not a configurable parameter')
+            case 'seed':
+                val = params[2]
+                if val == 'auto':
+                    val = None
+                else:
+                    val = int(params[2])
+
+            case 'guidance':
+                val = float(params[2])
+                val = max(min(val, MAX_GUIDANCE), 0)
+
+            case 'upscaler':
+                val = params[2]
+                if val == 'off':
+                    val = None
+                elif val != 'x4':
+                    raise ConfigUnknownUpscaler(
+                        f'\"{val}\" is not a valid upscaler')
+
+            case _:
+                raise ConfigUnknownAttribute(
+                    f'\"{attr}\" not a configurable parameter')
 
         return attr, val, f'config updated! {attr} to {val}'
 
diff --git a/skynet/structs.py b/skynet/structs.py
index cc9f25f..f110b6b 100644
--- a/skynet/structs.py
+++ b/skynet/structs.py
@@ -26,13 +26,7 @@ from pprint import pformat
 import msgspec
 
 
-class Struct(
-    msgspec.Struct,
-
-    # https://jcristharif.com/msgspec/structs.html#tagged-unions
-    # tag='pikerstruct',
-    # tag=True,
-):
+class Struct(msgspec.Struct):
     '''
     A "human friendlier" (aka repl buddy) struct subtype.
     '''
@@ -88,7 +82,7 @@ class Struct(
 from OpenSSL.crypto import PKey, X509, verify, sign
 
 
-class AuthenticatedStruct(Struct):
+class AuthenticatedStruct(Struct, kw_only=True):
     cert: Optional[str] = None
     sig: Optional[str] = None
 
diff --git a/skynet/utils.py b/skynet/utils.py
index 06bba1d..c511453 100644
--- a/skynet/utils.py
+++ b/skynet/utils.py
@@ -6,16 +6,55 @@ from typing import Optional
 from pathlib import Path
 
 import torch
+import numpy as np
 
 from PIL import Image
+from basicsr.archs.rrdbnet_arch import RRDBNet
 from diffusers import (
     StableDiffusionPipeline,
     StableDiffusionUpscalePipeline,
     EulerAncestralDiscreteScheduler
 )
+from realesrgan import RealESRGANer
 from huggingface_hub import login
 
-from .dgpu import pipeline_for
+from .constants import ALGOS
+
+
+def convert_from_cv2_to_image(img: np.ndarray) -> Image:
+    # return Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
+    return Image.fromarray(img)
+
+
+def convert_from_image_to_cv2(img: Image) -> np.ndarray:
+    # return cv2.cvtColor(numpy.array(img), cv2.COLOR_RGB2BGR)
+    return np.asarray(img)
+
+
+def pipeline_for(algo: str, mem_fraction: float = 1.0):
+    assert torch.cuda.is_available()
+    torch.cuda.empty_cache()
+    torch.cuda.set_per_process_memory_fraction(mem_fraction)
+    torch.backends.cuda.matmul.allow_tf32 = True
+    torch.backends.cudnn.allow_tf32 = True
+
+    params = {
+        'torch_dtype': torch.float16,
+        'safety_checker': None
+    }
+
+    if algo == 'stable':
+        params['revision'] = 'fp16'
+
+    pipe = StableDiffusionPipeline.from_pretrained(
+        ALGOS[algo], **params)
+
+    pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(
+        pipe.scheduler.config)
+
+    pipe.enable_vae_slicing()
+
+    return pipe.to('cuda')
 
 
 def txt2img(
@@ -51,11 +90,9 @@
 
 
 def upscale(
-    hf_token: str,
-    prompt: str = 'a red old tractor in a sunny wheat field',
     img_path: str = 'input.png',
     output: str = 'output.png',
-    steps: int = 28
+    model_path: str = 'weights/RealESRGAN_x4plus.pth'
 ):
     assert torch.cuda.is_available()
     torch.cuda.empty_cache()
@@ -63,20 +100,26 @@
     torch.backends.cuda.matmul.allow_tf32 = True
     torch.backends.cudnn.allow_tf32 = True
 
-    login(token=hf_token)
+    input_img = Image.open(img_path).convert('RGB')
 
-    pipe = StableDiffusionUpscalePipeline.from_pretrained(
-        'stabilityai/stable-diffusion-x4-upscaler',
-        revision="fp16", torch_dtype=torch.float16)
-    pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(
-        pipe.scheduler.config)
-    pipe = pipe.to('cuda')
+    upscaler = RealESRGANer(
+        scale=4,
+        model_path=model_path,
+        dni_weight=None,
+        model=RRDBNet(
+            num_in_ch=3,
+            num_out_ch=3,
+            num_feat=64,
+            num_block=23,
+            num_grow_ch=32,
+            scale=4
+        ),
+        half=True)
+
+    up_img, _ = upscaler.enhance(
+        convert_from_image_to_cv2(input_img), outscale=4)
+
+    image = convert_from_cv2_to_image(up_img)
 
-    prompt = prompt
-    image = pipe(
-        prompt,
-        image=Image.open(img_path).convert("RGB"),
-        num_inference_steps=steps
-    ).images[0]
     image.save(output)
 
diff --git a/tests/test_dgpu.py b/tests/test_dgpu.py
index 28426a9..a580a27 100644
--- a/tests/test_dgpu.py
+++ b/tests/test_dgpu.py
@@ -6,6 +6,7 @@ import json
 import base64
 import logging
 
+from typing import Optional
 from hashlib import sha256
 from functools import partial
 
@@ -42,7 +43,8 @@ async def check_request_img(
     uid: int = 0,
     width: int = 512,
     height: int = 512,
-    expect_unique=True
+    expect_unique = True,
+    upscaler: Optional[str] = None
 ):
     global _images
 
@@ -60,12 +62,16 @@ async def check_request_img(
                 'guidance': 7.5,
                 'seed': None,
                 'algo': list(ALGOS.keys())[i],
-                'upscaler': None
+                'upscaler': upscaler
             })
 
     if 'error' in res.result:
        raise SkynetDGPUComputeError(json.dumps(res.result))
 
+    if upscaler == 'x4':
+        width *= 4
+        height *= 4
+
     img_raw = base64.b64decode(bytes.fromhex(res.result['img']))
     img_sha = sha256(img_raw).hexdigest()
     img = Image.frombytes(
@@ -80,6 +86,8 @@
 
     assert len(img_raw) > 100000
 
+    return img
+
 
 @pytest.mark.parametrize(
     'dgpu_workers', [(1, ['midj'])], indirect=True)
@@ -123,6 +131,27 @@ async def test_dgpu_workers(dgpu_workers):
         await check_request_img(1)
 
 
+@pytest.mark.parametrize(
+    'dgpu_workers', [(1, ['midj'])], indirect=True)
+async def test_dgpu_worker_upscale(dgpu_workers):
+    '''Generate a single image in a dgpu worker and upscale
+    it x4 with the worker's Real-ESRGAN model.
+    '''
+
+    async with open_skynet_rpc(
+        'test-ctx',
+        security=True,
+        cert_name='whitelist/testing',
+        key_name='testing'
+    ) as test_rpc:
+        await wait_for_dgpus(test_rpc, 1)
+        logging.error('UPSCALE')
+
+        img = await check_request_img(0, upscaler='x4')
+
+        assert img.size == (2048, 2048)
+
+
 @pytest.mark.parametrize(
     'dgpu_workers', [(2, ['midj'])], indirect=True)
 async def test_dgpu_workers_two(dgpu_workers):