mirror of https://github.com/skygpu/skynet.git
Added auto-download through hf for the upscaler
parent
8828fa13fc
commit
a4e40ba662
|
@ -114,7 +114,7 @@ def compute_one(
|
|||
inputs: list[bytes] = [],
|
||||
should_cancel = None
|
||||
):
|
||||
total_steps = params['step']
|
||||
total_steps = params['step'] if 'step' in params else 1
|
||||
def inference_step_wakeup(*args, **kwargs):
|
||||
'''This is a callback function that gets invoked every inference step,
|
||||
we need to raise an exception here if we need to cancel work
|
||||
|
|
|
@ -124,7 +124,8 @@ async def maybe_serve_one(
|
|||
request_hash = sha256(hash_str.encode('utf-8')).hexdigest()
|
||||
logging.info(f'calculated request hash: {request_hash}')
|
||||
|
||||
total_step = body['params']['step']
|
||||
params = body['params']
|
||||
total_step = params['step'] if 'step' in params else 1
|
||||
model = body['params']['model']
|
||||
mode = body['method']
|
||||
|
||||
|
@ -152,7 +153,7 @@ async def maybe_serve_one(
|
|||
compute_one,
|
||||
model,
|
||||
rid,
|
||||
mode, body['params'],
|
||||
mode, params,
|
||||
inputs=inputs,
|
||||
should_cancel=conn.should_cancel_work,
|
||||
)
|
||||
|
|
|
@ -81,10 +81,11 @@ class WorkerMonitor:
|
|||
|
||||
for req in requests:
|
||||
# Build a columns widget for the request row
|
||||
prompt = req['prompt'] if 'prompt' in req else 'UPSCALE'
|
||||
columns = urwid.Columns([
|
||||
('fixed', 5, urwid.Text(f"#{req['id']}")), # e.g. "#12"
|
||||
('weight', 3, urwid.Text(req['model'])),
|
||||
('weight', 3, urwid.Text(req['prompt'])),
|
||||
('weight', 3, urwid.Text(prompt)),
|
||||
('fixed', 13, urwid.Text(req['user'])),
|
||||
('fixed', 13, urwid.Text(req['reward'])),
|
||||
], dividechars=1)
|
||||
|
|
|
@ -21,8 +21,9 @@ from diffusers import (
|
|||
AutoPipelineForInpainting,
|
||||
EulerAncestralDiscreteScheduler,
|
||||
)
|
||||
from huggingface_hub import login
|
||||
from huggingface_hub import login, hf_hub_download
|
||||
|
||||
from skynet.config import load_skynet_toml
|
||||
from skynet.constants import MODELS
|
||||
|
||||
# Hack to fix a changed import in torchvision 0.17+, which otherwise breaks
|
||||
|
@ -40,7 +41,6 @@ from basicsr.archs.rrdbnet_arch import RRDBNet
|
|||
from realesrgan import RealESRGANer
|
||||
|
||||
|
||||
|
||||
def convert_from_cv2_to_image(img: np.ndarray) -> Image:
|
||||
# return Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
|
||||
return Image.fromarray(img)
|
||||
|
@ -285,7 +285,14 @@ def inpaint(
|
|||
image.save(output)
|
||||
|
||||
|
||||
def init_upscaler(model_path: str = 'hf_home/RealESRGAN_x4plus.pth'):
|
||||
def init_upscaler():
|
||||
config = load_skynet_toml().dgpu
|
||||
model_path = hf_hub_download(
|
||||
'leonelhs/realesrgan',
|
||||
'RealESRGAN_x4plus.pth',
|
||||
token=config.hf_token,
|
||||
cache_dir=config.hf_home
|
||||
)
|
||||
return RealESRGANer(
|
||||
scale=4,
|
||||
model_path=model_path,
|
||||
|
@ -303,12 +310,11 @@ def init_upscaler(model_path: str = 'hf_home/RealESRGAN_x4plus.pth'):
|
|||
|
||||
def upscale(
|
||||
img_path: str = 'input.png',
|
||||
output: str = 'output.png',
|
||||
model_path: str = 'hf_home/RealESRGAN_x4plus.pth'
|
||||
output: str = 'output.png'
|
||||
):
|
||||
input_img = Image.open(img_path).convert('RGB')
|
||||
|
||||
upscaler = init_upscaler(model_path=model_path)
|
||||
upscaler = init_upscaler()
|
||||
|
||||
up_img, _ = upscaler.enhance(
|
||||
convert_from_image_to_cv2(input_img), outscale=4)
|
||||
|
|
Loading…
Reference in New Issue