mirror of https://github.com/skygpu/skynet.git
				
				
				
			Added auto-download through hf for the upscaler
							parent
							
								
									8828fa13fc
								
							
						
					
					
						commit
						a4e40ba662
					
				| 
						 | 
				
			
			@ -114,7 +114,7 @@ def compute_one(
 | 
			
		|||
    inputs: list[bytes] = [],
 | 
			
		||||
    should_cancel = None
 | 
			
		||||
):
 | 
			
		||||
    total_steps = params['step']
 | 
			
		||||
    total_steps = params['step'] if 'step' in params else 1
 | 
			
		||||
    def inference_step_wakeup(*args, **kwargs):
 | 
			
		||||
        '''This is a callback function that gets invoked every inference step,
 | 
			
		||||
        we need to raise an exception here if we need to cancel work
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -124,7 +124,8 @@ async def maybe_serve_one(
 | 
			
		|||
    request_hash = sha256(hash_str.encode('utf-8')).hexdigest()
 | 
			
		||||
    logging.info(f'calculated request hash: {request_hash}')
 | 
			
		||||
 | 
			
		||||
    total_step = body['params']['step']
 | 
			
		||||
    params = body['params']
 | 
			
		||||
    total_step = params['step'] if 'step' in params else 1
 | 
			
		||||
    model = body['params']['model']
 | 
			
		||||
    mode = body['method']
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -152,7 +153,7 @@ async def maybe_serve_one(
 | 
			
		|||
                            compute_one,
 | 
			
		||||
                            model,
 | 
			
		||||
                            rid,
 | 
			
		||||
                            mode, body['params'],
 | 
			
		||||
                            mode, params,
 | 
			
		||||
                            inputs=inputs,
 | 
			
		||||
                            should_cancel=conn.should_cancel_work,
 | 
			
		||||
                        )
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -81,10 +81,11 @@ class WorkerMonitor:
 | 
			
		|||
 | 
			
		||||
        for req in requests:
 | 
			
		||||
            # Build a columns widget for the request row
 | 
			
		||||
            prompt = req['prompt'] if 'prompt' in req else 'UPSCALE'
 | 
			
		||||
            columns = urwid.Columns([
 | 
			
		||||
                ('fixed', 5,  urwid.Text(f"#{req['id']}")),   # e.g. "#12"
 | 
			
		||||
                ('weight', 3, urwid.Text(req['model'])),
 | 
			
		||||
                ('weight', 3, urwid.Text(req['prompt'])),
 | 
			
		||||
                ('weight', 3, urwid.Text(prompt)),
 | 
			
		||||
                ('fixed', 13, urwid.Text(req['user'])),
 | 
			
		||||
                ('fixed', 13, urwid.Text(req['reward'])),
 | 
			
		||||
            ], dividechars=1)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -21,8 +21,9 @@ from diffusers import (
 | 
			
		|||
    AutoPipelineForInpainting,
 | 
			
		||||
    EulerAncestralDiscreteScheduler,
 | 
			
		||||
)
 | 
			
		||||
from huggingface_hub import login
 | 
			
		||||
from huggingface_hub import login, hf_hub_download
 | 
			
		||||
 | 
			
		||||
from skynet.config import load_skynet_toml
 | 
			
		||||
from skynet.constants import MODELS
 | 
			
		||||
 | 
			
		||||
# Hack to fix a changed import in torchvision 0.17+, which otherwise breaks
 | 
			
		||||
| 
						 | 
				
			
			@ -40,7 +41,6 @@ from basicsr.archs.rrdbnet_arch import RRDBNet
 | 
			
		|||
from realesrgan import RealESRGANer
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def convert_from_cv2_to_image(img: np.ndarray) -> Image:
 | 
			
		||||
    # return Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
 | 
			
		||||
    return Image.fromarray(img)
 | 
			
		||||
| 
						 | 
				
			
			@ -285,7 +285,14 @@ def inpaint(
 | 
			
		|||
    image.save(output)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def init_upscaler(model_path: str = 'hf_home/RealESRGAN_x4plus.pth'):
 | 
			
		||||
def init_upscaler():
 | 
			
		||||
    config = load_skynet_toml().dgpu
 | 
			
		||||
    model_path = hf_hub_download(
 | 
			
		||||
        'leonelhs/realesrgan',
 | 
			
		||||
        'RealESRGAN_x4plus.pth',
 | 
			
		||||
        token=config.hf_token,
 | 
			
		||||
        cache_dir=config.hf_home
 | 
			
		||||
    )
 | 
			
		||||
    return RealESRGANer(
 | 
			
		||||
        scale=4,
 | 
			
		||||
        model_path=model_path,
 | 
			
		||||
| 
						 | 
				
			
			@ -303,12 +310,11 @@ def init_upscaler(model_path: str = 'hf_home/RealESRGAN_x4plus.pth'):
 | 
			
		|||
 | 
			
		||||
def upscale(
 | 
			
		||||
    img_path: str = 'input.png',
 | 
			
		||||
    output: str = 'output.png',
 | 
			
		||||
    model_path: str = 'hf_home/RealESRGAN_x4plus.pth'
 | 
			
		||||
    output: str = 'output.png'
 | 
			
		||||
):
 | 
			
		||||
    input_img = Image.open(img_path).convert('RGB')
 | 
			
		||||
 | 
			
		||||
    upscaler = init_upscaler(model_path=model_path)
 | 
			
		||||
    upscaler = init_upscaler()
 | 
			
		||||
 | 
			
		||||
    up_img, _ = upscaler.enhance(
 | 
			
		||||
        convert_from_image_to_cv2(input_img), outscale=4)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue