From a4e40ba6624cc7a27c761a54e02938e1d23f4ca9 Mon Sep 17 00:00:00 2001 From: Guillermo Rodriguez Date: Sat, 8 Feb 2025 00:18:38 -0300 Subject: [PATCH] Added auto-download through hf for the upscaler --- skynet/dgpu/compute.py | 2 +- skynet/dgpu/daemon.py | 5 +++-- skynet/dgpu/tui.py | 3 ++- skynet/dgpu/utils.py | 18 ++++++++++++------ 4 files changed, 18 insertions(+), 10 deletions(-) diff --git a/skynet/dgpu/compute.py b/skynet/dgpu/compute.py index 67bc7ea..3a0645a 100755 --- a/skynet/dgpu/compute.py +++ b/skynet/dgpu/compute.py @@ -114,7 +114,7 @@ def compute_one( inputs: list[bytes] = [], should_cancel = None ): - total_steps = params['step'] + total_steps = params['step'] if 'step' in params else 1 def inference_step_wakeup(*args, **kwargs): '''This is a callback function that gets invoked every inference step, we need to raise an exception here if we need to cancel work diff --git a/skynet/dgpu/daemon.py b/skynet/dgpu/daemon.py index 2fc2159..b23c652 100755 --- a/skynet/dgpu/daemon.py +++ b/skynet/dgpu/daemon.py @@ -124,7 +124,8 @@ async def maybe_serve_one( request_hash = sha256(hash_str.encode('utf-8')).hexdigest() logging.info(f'calculated request hash: {request_hash}') - total_step = body['params']['step'] + params = body['params'] + total_step = params['step'] if 'step' in params else 1 model = body['params']['model'] mode = body['method'] @@ -152,7 +153,7 @@ async def maybe_serve_one( compute_one, model, rid, - mode, body['params'], + mode, params, inputs=inputs, should_cancel=conn.should_cancel_work, ) diff --git a/skynet/dgpu/tui.py b/skynet/dgpu/tui.py index ed3c1a7..5403025 100644 --- a/skynet/dgpu/tui.py +++ b/skynet/dgpu/tui.py @@ -81,10 +81,11 @@ class WorkerMonitor: for req in requests: # Build a columns widget for the request row + prompt = req['prompt'] if 'prompt' in req else 'UPSCALE' columns = urwid.Columns([ ('fixed', 5, urwid.Text(f"#{req['id']}")), # e.g. "#12" ('weight', 3, urwid.Text(req['model'])), - ('weight', 3, urwid.Text(req['prompt'])), + ('weight', 3, urwid.Text(prompt)), ('fixed', 13, urwid.Text(req['user'])), ('fixed', 13, urwid.Text(req['reward'])), ], dividechars=1) diff --git a/skynet/dgpu/utils.py b/skynet/dgpu/utils.py index 2b3090e..a355be9 100755 --- a/skynet/dgpu/utils.py +++ b/skynet/dgpu/utils.py @@ -21,8 +21,9 @@ from diffusers import ( AutoPipelineForInpainting, EulerAncestralDiscreteScheduler, ) -from huggingface_hub import login +from huggingface_hub import login, hf_hub_download +from skynet.config import load_skynet_toml from skynet.constants import MODELS # Hack to fix a changed import in torchvision 0.17+, which otherwise breaks @@ -40,7 +41,6 @@ from basicsr.archs.rrdbnet_arch import RRDBNet from realesrgan import RealESRGANer - def convert_from_cv2_to_image(img: np.ndarray) -> Image: # return Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB)) return Image.fromarray(img) @@ -285,7 +285,14 @@ def inpaint( image.save(output) -def init_upscaler(model_path: str = 'hf_home/RealESRGAN_x4plus.pth'): +def init_upscaler(): + config = load_skynet_toml().dgpu + model_path = hf_hub_download( + 'leonelhs/realesrgan', + 'RealESRGAN_x4plus.pth', + token=config.hf_token, + cache_dir=config.hf_home + ) return RealESRGANer( scale=4, model_path=model_path, @@ -303,12 +310,11 @@ def init_upscaler(model_path: str = 'hf_home/RealESRGAN_x4plus.pth'): def upscale( img_path: str = 'input.png', - output: str = 'output.png', - model_path: str = 'hf_home/RealESRGAN_x4plus.pth' + output: str = 'output.png' ): input_img = Image.open(img_path).convert('RGB') - upscaler = init_upscaler(model_path=model_path) + upscaler = init_upscaler() up_img, _ = upscaler.enhance( convert_from_image_to_cv2(input_img), outscale=4)