Added auto-download through hf for the upscaler

pull/47/head
Guillermo Rodriguez 2025-02-08 00:18:38 -03:00
parent 8828fa13fc
commit a4e40ba662
No known key found for this signature in database
GPG Key ID: 002CC5F1E6BDA53E
4 changed files with 18 additions and 10 deletions

View File

@ -114,7 +114,7 @@ def compute_one(
inputs: list[bytes] = [],
should_cancel = None
):
total_steps = params['step']
total_steps = params['step'] if 'step' in params else 1
def inference_step_wakeup(*args, **kwargs):
'''This is a callback function that gets invoked every inference step,
we need to raise an exception here if we need to cancel work

View File

@ -124,7 +124,8 @@ async def maybe_serve_one(
request_hash = sha256(hash_str.encode('utf-8')).hexdigest()
logging.info(f'calculated request hash: {request_hash}')
total_step = body['params']['step']
params = body['params']
total_step = params['step'] if 'step' in params else 1
model = body['params']['model']
mode = body['method']
@ -152,7 +153,7 @@ async def maybe_serve_one(
compute_one,
model,
rid,
mode, body['params'],
mode, params,
inputs=inputs,
should_cancel=conn.should_cancel_work,
)

View File

@ -81,10 +81,11 @@ class WorkerMonitor:
for req in requests:
# Build a columns widget for the request row
prompt = req['prompt'] if 'prompt' in req else 'UPSCALE'
columns = urwid.Columns([
('fixed', 5, urwid.Text(f"#{req['id']}")), # e.g. "#12"
('weight', 3, urwid.Text(req['model'])),
('weight', 3, urwid.Text(req['prompt'])),
('weight', 3, urwid.Text(prompt)),
('fixed', 13, urwid.Text(req['user'])),
('fixed', 13, urwid.Text(req['reward'])),
], dividechars=1)

View File

@ -21,8 +21,9 @@ from diffusers import (
AutoPipelineForInpainting,
EulerAncestralDiscreteScheduler,
)
from huggingface_hub import login
from huggingface_hub import login, hf_hub_download
from skynet.config import load_skynet_toml
from skynet.constants import MODELS
# Hack to fix a changed import in torchvision 0.17+, which otherwise breaks
@ -40,7 +41,6 @@ from basicsr.archs.rrdbnet_arch import RRDBNet
from realesrgan import RealESRGANer
def convert_from_cv2_to_image(img: np.ndarray) -> Image:
# return Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
return Image.fromarray(img)
@ -285,7 +285,14 @@ def inpaint(
image.save(output)
def init_upscaler(model_path: str = 'hf_home/RealESRGAN_x4plus.pth'):
def init_upscaler():
config = load_skynet_toml().dgpu
model_path = hf_hub_download(
'leonelhs/realesrgan',
'RealESRGAN_x4plus.pth',
token=config.hf_token,
cache_dir=config.hf_home
)
return RealESRGANer(
scale=4,
model_path=model_path,
@ -303,12 +310,11 @@ def init_upscaler(model_path: str = 'hf_home/RealESRGAN_x4plus.pth'):
def upscale(
img_path: str = 'input.png',
output: str = 'output.png',
model_path: str = 'hf_home/RealESRGAN_x4plus.pth'
output: str = 'output.png'
):
input_img = Image.open(img_path).convert('RGB')
upscaler = init_upscaler(model_path=model_path)
upscaler = init_upscaler()
up_img, _ = upscaler.enhance(
convert_from_image_to_cv2(input_img), outscale=4)