mirror of https://github.com/skygpu/skynet.git
Added auto-download through hf for the upscaler
parent
8828fa13fc
commit
a4e40ba662
|
@ -114,7 +114,7 @@ def compute_one(
|
||||||
inputs: list[bytes] = [],
|
inputs: list[bytes] = [],
|
||||||
should_cancel = None
|
should_cancel = None
|
||||||
):
|
):
|
||||||
total_steps = params['step']
|
total_steps = params['step'] if 'step' in params else 1
|
||||||
def inference_step_wakeup(*args, **kwargs):
|
def inference_step_wakeup(*args, **kwargs):
|
||||||
'''This is a callback function that gets invoked every inference step,
|
'''This is a callback function that gets invoked every inference step,
|
||||||
we need to raise an exception here if we need to cancel work
|
we need to raise an exception here if we need to cancel work
|
||||||
|
|
|
@ -124,7 +124,8 @@ async def maybe_serve_one(
|
||||||
request_hash = sha256(hash_str.encode('utf-8')).hexdigest()
|
request_hash = sha256(hash_str.encode('utf-8')).hexdigest()
|
||||||
logging.info(f'calculated request hash: {request_hash}')
|
logging.info(f'calculated request hash: {request_hash}')
|
||||||
|
|
||||||
total_step = body['params']['step']
|
params = body['params']
|
||||||
|
total_step = params['step'] if 'step' in params else 1
|
||||||
model = body['params']['model']
|
model = body['params']['model']
|
||||||
mode = body['method']
|
mode = body['method']
|
||||||
|
|
||||||
|
@ -152,7 +153,7 @@ async def maybe_serve_one(
|
||||||
compute_one,
|
compute_one,
|
||||||
model,
|
model,
|
||||||
rid,
|
rid,
|
||||||
mode, body['params'],
|
mode, params,
|
||||||
inputs=inputs,
|
inputs=inputs,
|
||||||
should_cancel=conn.should_cancel_work,
|
should_cancel=conn.should_cancel_work,
|
||||||
)
|
)
|
||||||
|
|
|
@ -81,10 +81,11 @@ class WorkerMonitor:
|
||||||
|
|
||||||
for req in requests:
|
for req in requests:
|
||||||
# Build a columns widget for the request row
|
# Build a columns widget for the request row
|
||||||
|
prompt = req['prompt'] if 'prompt' in req else 'UPSCALE'
|
||||||
columns = urwid.Columns([
|
columns = urwid.Columns([
|
||||||
('fixed', 5, urwid.Text(f"#{req['id']}")), # e.g. "#12"
|
('fixed', 5, urwid.Text(f"#{req['id']}")), # e.g. "#12"
|
||||||
('weight', 3, urwid.Text(req['model'])),
|
('weight', 3, urwid.Text(req['model'])),
|
||||||
('weight', 3, urwid.Text(req['prompt'])),
|
('weight', 3, urwid.Text(prompt)),
|
||||||
('fixed', 13, urwid.Text(req['user'])),
|
('fixed', 13, urwid.Text(req['user'])),
|
||||||
('fixed', 13, urwid.Text(req['reward'])),
|
('fixed', 13, urwid.Text(req['reward'])),
|
||||||
], dividechars=1)
|
], dividechars=1)
|
||||||
|
|
|
@ -21,8 +21,9 @@ from diffusers import (
|
||||||
AutoPipelineForInpainting,
|
AutoPipelineForInpainting,
|
||||||
EulerAncestralDiscreteScheduler,
|
EulerAncestralDiscreteScheduler,
|
||||||
)
|
)
|
||||||
from huggingface_hub import login
|
from huggingface_hub import login, hf_hub_download
|
||||||
|
|
||||||
|
from skynet.config import load_skynet_toml
|
||||||
from skynet.constants import MODELS
|
from skynet.constants import MODELS
|
||||||
|
|
||||||
# Hack to fix a changed import in torchvision 0.17+, which otherwise breaks
|
# Hack to fix a changed import in torchvision 0.17+, which otherwise breaks
|
||||||
|
@ -40,7 +41,6 @@ from basicsr.archs.rrdbnet_arch import RRDBNet
|
||||||
from realesrgan import RealESRGANer
|
from realesrgan import RealESRGANer
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def convert_from_cv2_to_image(img: np.ndarray) -> Image:
|
def convert_from_cv2_to_image(img: np.ndarray) -> Image:
|
||||||
# return Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
|
# return Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
|
||||||
return Image.fromarray(img)
|
return Image.fromarray(img)
|
||||||
|
@ -285,7 +285,14 @@ def inpaint(
|
||||||
image.save(output)
|
image.save(output)
|
||||||
|
|
||||||
|
|
||||||
def init_upscaler(model_path: str = 'hf_home/RealESRGAN_x4plus.pth'):
|
def init_upscaler():
|
||||||
|
config = load_skynet_toml().dgpu
|
||||||
|
model_path = hf_hub_download(
|
||||||
|
'leonelhs/realesrgan',
|
||||||
|
'RealESRGAN_x4plus.pth',
|
||||||
|
token=config.hf_token,
|
||||||
|
cache_dir=config.hf_home
|
||||||
|
)
|
||||||
return RealESRGANer(
|
return RealESRGANer(
|
||||||
scale=4,
|
scale=4,
|
||||||
model_path=model_path,
|
model_path=model_path,
|
||||||
|
@ -303,12 +310,11 @@ def init_upscaler(model_path: str = 'hf_home/RealESRGAN_x4plus.pth'):
|
||||||
|
|
||||||
def upscale(
|
def upscale(
|
||||||
img_path: str = 'input.png',
|
img_path: str = 'input.png',
|
||||||
output: str = 'output.png',
|
output: str = 'output.png'
|
||||||
model_path: str = 'hf_home/RealESRGAN_x4plus.pth'
|
|
||||||
):
|
):
|
||||||
input_img = Image.open(img_path).convert('RGB')
|
input_img = Image.open(img_path).convert('RGB')
|
||||||
|
|
||||||
upscaler = init_upscaler(model_path=model_path)
|
upscaler = init_upscaler()
|
||||||
|
|
||||||
up_img, _ = upscaler.enhance(
|
up_img, _ = upscaler.enhance(
|
||||||
convert_from_image_to_cv2(input_img), outscale=4)
|
convert_from_image_to_cv2(input_img), outscale=4)
|
||||||
|
|
Loading…
Reference in New Issue