Added auto-download through hf for the upscaler

pull/47/head
Guillermo Rodriguez 2025-02-08 00:18:38 -03:00
parent 8828fa13fc
commit a4e40ba662
No known key found for this signature in database
GPG Key ID: 002CC5F1E6BDA53E
4 changed files with 18 additions and 10 deletions

View File

@ -114,7 +114,7 @@ def compute_one(
inputs: list[bytes] = [], inputs: list[bytes] = [],
should_cancel = None should_cancel = None
): ):
total_steps = params['step'] total_steps = params['step'] if 'step' in params else 1
def inference_step_wakeup(*args, **kwargs): def inference_step_wakeup(*args, **kwargs):
'''This is a callback function that gets invoked every inference step, '''This is a callback function that gets invoked every inference step,
we need to raise an exception here if we need to cancel work we need to raise an exception here if we need to cancel work

View File

@ -124,7 +124,8 @@ async def maybe_serve_one(
request_hash = sha256(hash_str.encode('utf-8')).hexdigest() request_hash = sha256(hash_str.encode('utf-8')).hexdigest()
logging.info(f'calculated request hash: {request_hash}') logging.info(f'calculated request hash: {request_hash}')
total_step = body['params']['step'] params = body['params']
total_step = params['step'] if 'step' in params else 1
model = body['params']['model'] model = body['params']['model']
mode = body['method'] mode = body['method']
@ -152,7 +153,7 @@ async def maybe_serve_one(
compute_one, compute_one,
model, model,
rid, rid,
mode, body['params'], mode, params,
inputs=inputs, inputs=inputs,
should_cancel=conn.should_cancel_work, should_cancel=conn.should_cancel_work,
) )

View File

@ -81,10 +81,11 @@ class WorkerMonitor:
for req in requests: for req in requests:
# Build a columns widget for the request row # Build a columns widget for the request row
prompt = req['prompt'] if 'prompt' in req else 'UPSCALE'
columns = urwid.Columns([ columns = urwid.Columns([
('fixed', 5, urwid.Text(f"#{req['id']}")), # e.g. "#12" ('fixed', 5, urwid.Text(f"#{req['id']}")), # e.g. "#12"
('weight', 3, urwid.Text(req['model'])), ('weight', 3, urwid.Text(req['model'])),
('weight', 3, urwid.Text(req['prompt'])), ('weight', 3, urwid.Text(prompt)),
('fixed', 13, urwid.Text(req['user'])), ('fixed', 13, urwid.Text(req['user'])),
('fixed', 13, urwid.Text(req['reward'])), ('fixed', 13, urwid.Text(req['reward'])),
], dividechars=1) ], dividechars=1)

View File

@ -21,8 +21,9 @@ from diffusers import (
AutoPipelineForInpainting, AutoPipelineForInpainting,
EulerAncestralDiscreteScheduler, EulerAncestralDiscreteScheduler,
) )
from huggingface_hub import login from huggingface_hub import login, hf_hub_download
from skynet.config import load_skynet_toml
from skynet.constants import MODELS from skynet.constants import MODELS
# Hack to fix a changed import in torchvision 0.17+, which otherwise breaks # Hack to fix a changed import in torchvision 0.17+, which otherwise breaks
@ -40,7 +41,6 @@ from basicsr.archs.rrdbnet_arch import RRDBNet
from realesrgan import RealESRGANer from realesrgan import RealESRGANer
def convert_from_cv2_to_image(img: np.ndarray) -> Image: def convert_from_cv2_to_image(img: np.ndarray) -> Image:
# return Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB)) # return Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
return Image.fromarray(img) return Image.fromarray(img)
@ -285,7 +285,14 @@ def inpaint(
image.save(output) image.save(output)
def init_upscaler(model_path: str = 'hf_home/RealESRGAN_x4plus.pth'): def init_upscaler():
config = load_skynet_toml().dgpu
model_path = hf_hub_download(
'leonelhs/realesrgan',
'RealESRGAN_x4plus.pth',
token=config.hf_token,
cache_dir=config.hf_home
)
return RealESRGANer( return RealESRGANer(
scale=4, scale=4,
model_path=model_path, model_path=model_path,
@ -303,12 +310,11 @@ def init_upscaler(model_path: str = 'hf_home/RealESRGAN_x4plus.pth'):
def upscale( def upscale(
img_path: str = 'input.png', img_path: str = 'input.png',
output: str = 'output.png', output: str = 'output.png'
model_path: str = 'hf_home/RealESRGAN_x4plus.pth'
): ):
input_img = Image.open(img_path).convert('RGB') input_img = Image.open(img_path).convert('RGB')
upscaler = init_upscaler(model_path=model_path) upscaler = init_upscaler()
up_img, _ = upscaler.enhance( up_img, _ = upscaler.enhance(
convert_from_image_to_cv2(input_img), outscale=4) convert_from_image_to_cv2(input_img), outscale=4)