mirror of https://github.com/skygpu/skynet.git
				
				
				
			Simplify pipeline_for function and add the infra needed for diferent io/types than png
							parent
							
								
									ee1fdcc557
								
							
						
					
					
						commit
						3d2069d151
					
				| 
						 | 
					@ -15,7 +15,6 @@ Pillow = '^10.0.1'
 | 
				
			||||||
docker = '^6.1.3'
 | 
					docker = '^6.1.3'
 | 
				
			||||||
py-leap = {git = 'https://github.com/guilledk/py-leap.git', rev = 'v0.1a14'}
 | 
					py-leap = {git = 'https://github.com/guilledk/py-leap.git', rev = 'v0.1a14'}
 | 
				
			||||||
toml = "^0.10.2"
 | 
					toml = "^0.10.2"
 | 
				
			||||||
tractor = {git = "https://github.com/goodboy/tractor.git"}
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
[tool.poetry.group.frontend]
 | 
					[tool.poetry.group.frontend]
 | 
				
			||||||
optional = true
 | 
					optional = true
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -85,7 +85,7 @@ def download():
 | 
				
			||||||
    hf_token = load_key(config, 'skynet.dgpu.hf_token')
 | 
					    hf_token = load_key(config, 'skynet.dgpu.hf_token')
 | 
				
			||||||
    hf_home = load_key(config, 'skynet.dgpu.hf_home')
 | 
					    hf_home = load_key(config, 'skynet.dgpu.hf_home')
 | 
				
			||||||
    set_hf_vars(hf_token, hf_home)
 | 
					    set_hf_vars(hf_token, hf_home)
 | 
				
			||||||
    utils.download_all_models(hf_token)
 | 
					    utils.download_all_models(hf_token, hf_home)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@skynet.command()
 | 
					@skynet.command()
 | 
				
			||||||
@click.option(
 | 
					@click.option(
 | 
				
			||||||
| 
						 | 
					@ -120,21 +120,21 @@ def enqueue(
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    cleos = CLEOS(None, None, url=node_url, remote=node_url)
 | 
					    cleos = CLEOS(None, None, url=node_url, remote=node_url)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    binary = kwargs['binary_data']
 | 
				
			||||||
 | 
					    if not kwargs['strength']:
 | 
				
			||||||
 | 
					        if binary:
 | 
				
			||||||
 | 
					            raise ValueError('strength -Z param required if binary data passed')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        del kwargs['strength']
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    else:
 | 
				
			||||||
 | 
					        kwargs['strength'] = float(kwargs['strength'])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    async def enqueue_n_jobs():
 | 
					    async def enqueue_n_jobs():
 | 
				
			||||||
        for i in range(jobs):
 | 
					        for i in range(jobs):
 | 
				
			||||||
            if not kwargs['seed']:
 | 
					            if not kwargs['seed']:
 | 
				
			||||||
                kwargs['seed'] = random.randint(0, 10e9)
 | 
					                kwargs['seed'] = random.randint(0, 10e9)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            binary = kwargs['binary_data']
 | 
					 | 
				
			||||||
            if not kwargs['strength']:
 | 
					 | 
				
			||||||
                if binary:
 | 
					 | 
				
			||||||
                    raise ValueError('strength -Z param required if binary data passed')
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                del kwargs['strength']
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            else:
 | 
					 | 
				
			||||||
                kwargs['strength'] = float(kwargs['strength'])
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            req = json.dumps({
 | 
					            req = json.dumps({
 | 
				
			||||||
                'method': 'diffuse',
 | 
					                'method': 'diffuse',
 | 
				
			||||||
                'params': kwargs
 | 
					                'params': kwargs
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -5,18 +5,20 @@ VERSION = '0.1a12'
 | 
				
			||||||
DOCKER_RUNTIME_CUDA = 'skynet:runtime-cuda'
 | 
					DOCKER_RUNTIME_CUDA = 'skynet:runtime-cuda'
 | 
				
			||||||
 | 
					
 | 
				
			||||||
MODELS = {
 | 
					MODELS = {
 | 
				
			||||||
    'prompthero/openjourney':                   {'short': 'midj',        'mem': 8},
 | 
					    'prompthero/openjourney':                   {'short': 'midj',        'mem': 6},
 | 
				
			||||||
    'runwayml/stable-diffusion-v1-5':           {'short': 'stable',      'mem': 8},
 | 
					    'runwayml/stable-diffusion-v1-5':           {'short': 'stable',      'mem': 6},
 | 
				
			||||||
    'stabilityai/stable-diffusion-2-1-base':    {'short': 'stable2',     'mem': 8},
 | 
					    'stabilityai/stable-diffusion-2-1-base':    {'short': 'stable2',     'mem': 6},
 | 
				
			||||||
    'snowkidy/stable-diffusion-xl-base-0.9':    {'short': 'stablexl0.9', 'mem': 24},
 | 
					    'snowkidy/stable-diffusion-xl-base-0.9':    {'short': 'stablexl0.9', 'mem': 8.3},
 | 
				
			||||||
    'stabilityai/stable-diffusion-xl-base-1.0': {'short': 'stablexl',    'mem': 24},
 | 
					    'Linaqruf/anything-v3.0':                   {'short': 'hdanime',     'mem': 6},
 | 
				
			||||||
    'Linaqruf/anything-v3.0':                   {'short': 'hdanime',     'mem': 8},
 | 
					    'hakurei/waifu-diffusion':                  {'short': 'waifu',       'mem': 6},
 | 
				
			||||||
    'hakurei/waifu-diffusion':                  {'short': 'waifu',       'mem': 8},
 | 
					    'nitrosocke/Ghibli-Diffusion':              {'short': 'ghibli',      'mem': 6},
 | 
				
			||||||
    'nitrosocke/Ghibli-Diffusion':              {'short': 'ghibli',      'mem': 8},
 | 
					    'dallinmackay/Van-Gogh-diffusion':          {'short': 'van-gogh',    'mem': 6},
 | 
				
			||||||
    'dallinmackay/Van-Gogh-diffusion':          {'short': 'van-gogh',    'mem': 8},
 | 
					    'lambdalabs/sd-pokemon-diffusers':          {'short': 'pokemon',     'mem': 6},
 | 
				
			||||||
    'lambdalabs/sd-pokemon-diffusers':          {'short': 'pokemon',     'mem': 8},
 | 
					    'Envvi/Inkpunk-Diffusion':                  {'short': 'ink',         'mem': 6},
 | 
				
			||||||
    'Envvi/Inkpunk-Diffusion':                  {'short': 'ink',         'mem': 8},
 | 
					    'nousr/robo-diffusion':                     {'short': 'robot',       'mem': 6},
 | 
				
			||||||
    'nousr/robo-diffusion':                     {'short': 'robot',       'mem': 8}
 | 
					
 | 
				
			||||||
 | 
					    # default is always last
 | 
				
			||||||
 | 
					    'stabilityai/stable-diffusion-xl-base-1.0': {'short': 'stablexl',    'mem': 8.3},
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
SHORT_NAMES = [
 | 
					SHORT_NAMES = [
 | 
				
			||||||
| 
						 | 
					@ -158,7 +160,7 @@ DEFAULT_GUIDANCE = 7.5
 | 
				
			||||||
DEFAULT_STRENGTH = 0.5
 | 
					DEFAULT_STRENGTH = 0.5
 | 
				
			||||||
DEFAULT_STEP = 28
 | 
					DEFAULT_STEP = 28
 | 
				
			||||||
DEFAULT_CREDITS = 10
 | 
					DEFAULT_CREDITS = 10
 | 
				
			||||||
DEFAULT_MODEL = list(MODELS.keys())[4]
 | 
					DEFAULT_MODEL = list(MODELS.keys())[-1]
 | 
				
			||||||
DEFAULT_ROLE = 'pleb'
 | 
					DEFAULT_ROLE = 'pleb'
 | 
				
			||||||
DEFAULT_UPSCALER = None
 | 
					DEFAULT_UPSCALER = None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1,165 +0,0 @@
 | 
				
			||||||
#!/usr/bin/python
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
import gc
 | 
					 | 
				
			||||||
import json
 | 
					 | 
				
			||||||
import logging
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
from hashlib import sha256
 | 
					 | 
				
			||||||
from diffusers import DiffusionPipeline
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
import trio
 | 
					 | 
				
			||||||
import torch
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
from skynet.constants import DEFAULT_INITAL_MODELS, MODELS
 | 
					 | 
				
			||||||
from skynet.dgpu.errors import DGPUComputeError
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
from skynet.utils import convert_from_bytes_and_crop, convert_from_cv2_to_image, convert_from_image_to_cv2, convert_from_img_to_bytes, init_upscaler, pipeline_for
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def prepare_params_for_diffuse(
 | 
					 | 
				
			||||||
    params: dict,
 | 
					 | 
				
			||||||
    binary: bytes | None = None
 | 
					 | 
				
			||||||
):
 | 
					 | 
				
			||||||
    image = None
 | 
					 | 
				
			||||||
    if binary:
 | 
					 | 
				
			||||||
        image = convert_from_bytes_and_crop(binary, 512, 512)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    _params = {}
 | 
					 | 
				
			||||||
    if image:
 | 
					 | 
				
			||||||
        _params['image'] = image
 | 
					 | 
				
			||||||
        _params['strength'] = float(params['strength'])
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    else:
 | 
					 | 
				
			||||||
        _params['width'] = int(params['width'])
 | 
					 | 
				
			||||||
        _params['height'] = int(params['height'])
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    return (
 | 
					 | 
				
			||||||
        params['prompt'],
 | 
					 | 
				
			||||||
        float(params['guidance']),
 | 
					 | 
				
			||||||
        int(params['step']),
 | 
					 | 
				
			||||||
        torch.manual_seed(int(params['seed'])),
 | 
					 | 
				
			||||||
        params['upscaler'] if 'upscaler' in params else None,
 | 
					 | 
				
			||||||
        _params
 | 
					 | 
				
			||||||
    )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
_models = {}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def is_model_loaded(model_name: str, image: bool):
 | 
					 | 
				
			||||||
    for model_key, model_data in _models.items():
 | 
					 | 
				
			||||||
        if (model_key == model_name and
 | 
					 | 
				
			||||||
            model_data['image'] == image):
 | 
					 | 
				
			||||||
            return True
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    return False
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def load_model(
 | 
					 | 
				
			||||||
    model_name: str,
 | 
					 | 
				
			||||||
    image: bool,
 | 
					 | 
				
			||||||
    force=False
 | 
					 | 
				
			||||||
):
 | 
					 | 
				
			||||||
    logging.info(f'loading model {model_name}...')
 | 
					 | 
				
			||||||
    if force or len(_models.keys()) == 0:
 | 
					 | 
				
			||||||
        pipe = pipeline_for(
 | 
					 | 
				
			||||||
            model_name, image=image)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        _models[model_name] = {
 | 
					 | 
				
			||||||
            'pipe': pipe,
 | 
					 | 
				
			||||||
            'generated': 0,
 | 
					 | 
				
			||||||
            'image': image
 | 
					 | 
				
			||||||
        }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    else:
 | 
					 | 
				
			||||||
        least_used = list(_models.keys())[0]
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        for model in _models:
 | 
					 | 
				
			||||||
            if _models[
 | 
					 | 
				
			||||||
                least_used]['generated'] > _models[model]['generated']:
 | 
					 | 
				
			||||||
                least_used = model
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        del _models[least_used]
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        logging.info(f'swapping model {least_used} for {model_name}...')
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        gc.collect()
 | 
					 | 
				
			||||||
        torch.cuda.empty_cache()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        pipe = pipeline_for(
 | 
					 | 
				
			||||||
            model_name, image=image)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        _models[model_name] = {
 | 
					 | 
				
			||||||
            'pipe': pipe,
 | 
					 | 
				
			||||||
            'generated': 0,
 | 
					 | 
				
			||||||
            'image': image
 | 
					 | 
				
			||||||
        }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    logging.info(f'loaded model {model_name}')
 | 
					 | 
				
			||||||
    return pipe
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def get_model(model_name: str, image: bool) -> DiffusionPipeline:
 | 
					 | 
				
			||||||
    if model_name not in MODELS:
 | 
					 | 
				
			||||||
        raise DGPUComputeError(f'Unknown model {model_name}')
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    if not is_model_loaded(model_name, image):
 | 
					 | 
				
			||||||
        pipe = load_model(model_name, image=image)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    else:
 | 
					 | 
				
			||||||
        pipe = _models[model_name]['pipe']
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    return pipe
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def _static_compute_one(kwargs: dict):
 | 
					 | 
				
			||||||
    request_id: int = kwargs['request_id']
 | 
					 | 
				
			||||||
    method: str = kwargs['method']
 | 
					 | 
				
			||||||
    params: dict = kwargs['params']
 | 
					 | 
				
			||||||
    binary: bytes | None = kwargs['binary']
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def _checkpoint(*args, **kwargs):
 | 
					 | 
				
			||||||
        trio.from_thread.run(trio.sleep, 0)    
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    try:
 | 
					 | 
				
			||||||
        match method:
 | 
					 | 
				
			||||||
            case 'diffuse':
 | 
					 | 
				
			||||||
                image = None
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                arguments = prepare_params_for_diffuse(params, binary)
 | 
					 | 
				
			||||||
                prompt, guidance, step, seed, upscaler, extra_params = arguments
 | 
					 | 
				
			||||||
                model = get_model(params['model'], 'image' in extra_params)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                image = model(
 | 
					 | 
				
			||||||
                    prompt,
 | 
					 | 
				
			||||||
                    guidance_scale=guidance,
 | 
					 | 
				
			||||||
                    num_inference_steps=step,
 | 
					 | 
				
			||||||
                    generator=seed,
 | 
					 | 
				
			||||||
                    callback=_checkpoint,
 | 
					 | 
				
			||||||
                    callback_steps=1,
 | 
					 | 
				
			||||||
                    **extra_params
 | 
					 | 
				
			||||||
                ).images[0]
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                if upscaler == 'x4':
 | 
					 | 
				
			||||||
                    upscaler = init_upscaler()
 | 
					 | 
				
			||||||
                    input_img = image.convert('RGB')
 | 
					 | 
				
			||||||
                    up_img, _ = upscaler.enhance(
 | 
					 | 
				
			||||||
                        convert_from_image_to_cv2(input_img), outscale=4)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                    image = convert_from_cv2_to_image(up_img)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                img_raw = convert_from_img_to_bytes(image)
 | 
					 | 
				
			||||||
                img_sha = sha256(img_raw).hexdigest()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                return img_sha, img_raw
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            case _:
 | 
					 | 
				
			||||||
                raise DGPUComputeError('Unsupported compute method')
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    except BaseException as e:
 | 
					 | 
				
			||||||
        logging.error(e)
 | 
					 | 
				
			||||||
        raise DGPUComputeError(str(e))
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    finally:
 | 
					 | 
				
			||||||
        torch.cuda.empty_cache()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
async def _tractor_static_compute_one(**kwargs):
 | 
					 | 
				
			||||||
    return await trio.to_thread.run_sync(
 | 
					 | 
				
			||||||
        _static_compute_one, kwargs)
 | 
					 | 
				
			||||||
| 
						 | 
					@ -3,10 +3,11 @@
 | 
				
			||||||
# Skynet Memory Manager
 | 
					# Skynet Memory Manager
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import gc
 | 
					import gc
 | 
				
			||||||
import json
 | 
					 | 
				
			||||||
import logging
 | 
					import logging
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from hashlib import sha256
 | 
					from hashlib import sha256
 | 
				
			||||||
 | 
					import zipfile
 | 
				
			||||||
 | 
					from PIL import Image
 | 
				
			||||||
from diffusers import DiffusionPipeline
 | 
					from diffusers import DiffusionPipeline
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import trio
 | 
					import trio
 | 
				
			||||||
| 
						 | 
					@ -15,22 +16,29 @@ import torch
 | 
				
			||||||
from skynet.constants import DEFAULT_INITAL_MODELS, MODELS
 | 
					from skynet.constants import DEFAULT_INITAL_MODELS, MODELS
 | 
				
			||||||
from skynet.dgpu.errors import DGPUComputeError, DGPUInferenceCancelled
 | 
					from skynet.dgpu.errors import DGPUComputeError, DGPUInferenceCancelled
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from skynet.utils import convert_from_bytes_and_crop, convert_from_cv2_to_image, convert_from_image_to_cv2, convert_from_img_to_bytes, init_upscaler, pipeline_for
 | 
					from skynet.utils import crop_image, convert_from_cv2_to_image, convert_from_image_to_cv2, convert_from_img_to_bytes, init_upscaler, pipeline_for
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from ._mp_compute import _static_compute_one, _tractor_static_compute_one
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
def prepare_params_for_diffuse(
 | 
					def prepare_params_for_diffuse(
 | 
				
			||||||
    params: dict,
 | 
					    params: dict,
 | 
				
			||||||
    binary: bytes | None = None
 | 
					    input_type: str,
 | 
				
			||||||
 | 
					    binary = None
 | 
				
			||||||
):
 | 
					):
 | 
				
			||||||
    image = None
 | 
					 | 
				
			||||||
    if binary:
 | 
					 | 
				
			||||||
        image = convert_from_bytes_and_crop(binary, 512, 512)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    _params = {}
 | 
					    _params = {}
 | 
				
			||||||
    if image:
 | 
					    if binary != None:
 | 
				
			||||||
        _params['image'] = image
 | 
					        match input_type:
 | 
				
			||||||
        _params['strength'] = float(params['strength'])
 | 
					            case 'png':
 | 
				
			||||||
 | 
					                image = crop_image(
 | 
				
			||||||
 | 
					                    binary, params['width'], params['height'])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                _params['image'] = image
 | 
				
			||||||
 | 
					                _params['strength'] = float(params['strength'])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            case 'none':
 | 
				
			||||||
 | 
					                ...
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            case _:
 | 
				
			||||||
 | 
					                raise DGPUComputeError(f'Unknown input_type {input_type}')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    else:
 | 
					    else:
 | 
				
			||||||
        _params['width'] = int(params['width'])
 | 
					        _params['width'] = int(params['width'])
 | 
				
			||||||
| 
						 | 
					@ -136,6 +144,7 @@ class SkynetMM:
 | 
				
			||||||
        request_id: int,
 | 
					        request_id: int,
 | 
				
			||||||
        method: str,
 | 
					        method: str,
 | 
				
			||||||
        params: dict,
 | 
					        params: dict,
 | 
				
			||||||
 | 
					        input_type: str = 'png',
 | 
				
			||||||
        binary: bytes | None = None
 | 
					        binary: bytes | None = None
 | 
				
			||||||
    ):
 | 
					    ):
 | 
				
			||||||
        def maybe_cancel_work(step, *args, **kwargs):
 | 
					        def maybe_cancel_work(step, *args, **kwargs):
 | 
				
			||||||
| 
						 | 
					@ -147,16 +156,21 @@ class SkynetMM:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        maybe_cancel_work(0)
 | 
					        maybe_cancel_work(0)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        output_type = 'png'
 | 
				
			||||||
 | 
					        if 'output_type' in params:
 | 
				
			||||||
 | 
					            output_type = params['output_type']
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        output = None
 | 
				
			||||||
 | 
					        output_hash = None
 | 
				
			||||||
        try:
 | 
					        try:
 | 
				
			||||||
            match method:
 | 
					            match method:
 | 
				
			||||||
                case 'diffuse':
 | 
					                case 'diffuse':
 | 
				
			||||||
                    image = None
 | 
					                    arguments = prepare_params_for_diffuse(
 | 
				
			||||||
 | 
					                        params, input_type, binary=binary)
 | 
				
			||||||
                    arguments = prepare_params_for_diffuse(params, binary)
 | 
					 | 
				
			||||||
                    prompt, guidance, step, seed, upscaler, extra_params = arguments
 | 
					                    prompt, guidance, step, seed, upscaler, extra_params = arguments
 | 
				
			||||||
                    model = self.get_model(params['model'], 'image' in extra_params)
 | 
					                    model = self.get_model(params['model'], 'image' in extra_params)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                    image = model(
 | 
					                    output = model(
 | 
				
			||||||
                        prompt,
 | 
					                        prompt,
 | 
				
			||||||
                        guidance_scale=guidance,
 | 
					                        guidance_scale=guidance,
 | 
				
			||||||
                        num_inference_steps=step,
 | 
					                        num_inference_steps=step,
 | 
				
			||||||
| 
						 | 
					@ -166,17 +180,22 @@ class SkynetMM:
 | 
				
			||||||
                        **extra_params
 | 
					                        **extra_params
 | 
				
			||||||
                    ).images[0]
 | 
					                    ).images[0]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                    if upscaler == 'x4':
 | 
					                    output_binary = b''
 | 
				
			||||||
                        input_img = image.convert('RGB')
 | 
					                    match output_type:
 | 
				
			||||||
                        up_img, _ = self.upscaler.enhance(
 | 
					                        case 'png':
 | 
				
			||||||
                            convert_from_image_to_cv2(input_img), outscale=4)
 | 
					                            if upscaler == 'x4':
 | 
				
			||||||
 | 
					                                input_img = output.convert('RGB')
 | 
				
			||||||
 | 
					                                up_img, _ = self.upscaler.enhance(
 | 
				
			||||||
 | 
					                                    convert_from_image_to_cv2(input_img), outscale=4)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                        image = convert_from_cv2_to_image(up_img)
 | 
					                                output = convert_from_cv2_to_image(up_img)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                    img_raw = convert_from_img_to_bytes(image)
 | 
					                            output_binary = convert_from_img_to_bytes(output)
 | 
				
			||||||
                    img_sha = sha256(img_raw).hexdigest()
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
                    return img_sha, img_raw
 | 
					                        case _:
 | 
				
			||||||
 | 
					                            raise DGPUComputeError(f'Unsupported output type: {output_type}')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                    output_hash = sha256(output_binary).hexdigest()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                case _:
 | 
					                case _:
 | 
				
			||||||
                    raise DGPUComputeError('Unsupported compute method')
 | 
					                    raise DGPUComputeError('Unsupported compute method')
 | 
				
			||||||
| 
						 | 
					@ -187,3 +206,5 @@ class SkynetMM:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        finally:
 | 
					        finally:
 | 
				
			||||||
            torch.cuda.empty_cache()
 | 
					            torch.cuda.empty_cache()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return output_hash, output
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -9,10 +9,10 @@ from hashlib import sha256
 | 
				
			||||||
from functools import partial
 | 
					from functools import partial
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import trio
 | 
					import trio
 | 
				
			||||||
import tractor
 | 
					from skynet.constants import MODELS
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from skynet.dgpu.errors import *
 | 
					from skynet.dgpu.errors import *
 | 
				
			||||||
from skynet.dgpu.compute import SkynetMM, _tractor_static_compute_one
 | 
					from skynet.dgpu.compute import SkynetMM
 | 
				
			||||||
from skynet.dgpu.network import SkynetGPUConnector
 | 
					from skynet.dgpu.network import SkynetGPUConnector
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -97,6 +97,11 @@ class SkynetDGPUDaemon:
 | 
				
			||||||
                    body = json.loads(req['body'])
 | 
					                    body = json.loads(req['body'])
 | 
				
			||||||
                    model = body['params']['model']
 | 
					                    model = body['params']['model']
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                    # if model not known
 | 
				
			||||||
 | 
					                    if model not in MODELS:
 | 
				
			||||||
 | 
					                        logging.warning(f'Unknown model {model}')
 | 
				
			||||||
 | 
					                        continue
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                    # if whitelist enabled and model not in it continue
 | 
					                    # if whitelist enabled and model not in it continue
 | 
				
			||||||
                    if (len(self.model_whitelist) > 0 and
 | 
					                    if (len(self.model_whitelist) > 0 and
 | 
				
			||||||
                        not model in self.model_whitelist):
 | 
					                        not model in self.model_whitelist):
 | 
				
			||||||
| 
						 | 
					@ -111,7 +116,7 @@ class SkynetDGPUDaemon:
 | 
				
			||||||
                        statuses = self._snap['requests'][rid]
 | 
					                        statuses = self._snap['requests'][rid]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                        if len(statuses) == 0:
 | 
					                        if len(statuses) == 0:
 | 
				
			||||||
                            binary = await self.conn.get_input_data(req['binary_data'])
 | 
					                            binary, input_type = await self.conn.get_input_data(req['binary_data'])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                            hash_str = (
 | 
					                            hash_str = (
 | 
				
			||||||
                                str(req['nonce'])
 | 
					                                str(req['nonce'])
 | 
				
			||||||
| 
						 | 
					@ -134,46 +139,31 @@ class SkynetDGPUDaemon:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                            else:
 | 
					                            else:
 | 
				
			||||||
                                try:
 | 
					                                try:
 | 
				
			||||||
 | 
					                                    output_type = 'png'
 | 
				
			||||||
 | 
					                                    if 'output_type' in body['params']:
 | 
				
			||||||
 | 
					                                        output_type = body['params']['output_type']
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                                    output = None
 | 
				
			||||||
 | 
					                                    output_hash = None
 | 
				
			||||||
                                    match self.backend:
 | 
					                                    match self.backend:
 | 
				
			||||||
                                        case 'sync-on-thread':
 | 
					                                        case 'sync-on-thread':
 | 
				
			||||||
                                            self.mm._should_cancel = self.should_cancel_work
 | 
					                                            self.mm._should_cancel = self.should_cancel_work
 | 
				
			||||||
                                            img_sha, img_raw = await trio.to_thread.run_sync(
 | 
					                                            output_hash, output = await trio.to_thread.run_sync(
 | 
				
			||||||
                                                partial(
 | 
					                                                partial(
 | 
				
			||||||
                                                    self.mm.compute_one,
 | 
					                                                    self.mm.compute_one,
 | 
				
			||||||
                                                    rid,
 | 
					                                                    rid,
 | 
				
			||||||
                                                    body['method'], body['params'], binary=binary
 | 
					                                                    body['method'], body['params'],
 | 
				
			||||||
                                                )
 | 
					                                                    input_type=input_type,
 | 
				
			||||||
                                            )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                                        case 'tractor':
 | 
					 | 
				
			||||||
                                            async def _should_cancel_oracle():
 | 
					 | 
				
			||||||
                                                while True:
 | 
					 | 
				
			||||||
                                                    await trio.sleep(1)
 | 
					 | 
				
			||||||
                                                    if (await self.should_cancel_work(rid)):
 | 
					 | 
				
			||||||
                                                        raise DGPUInferenceCancelled
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                                            async with (
 | 
					 | 
				
			||||||
                                                trio.open_nursery() as trio_n,
 | 
					 | 
				
			||||||
                                                tractor.open_nursery() as tractor_n
 | 
					 | 
				
			||||||
                                            ):
 | 
					 | 
				
			||||||
                                                trio_n.start_soon(_should_cancel_oracle)
 | 
					 | 
				
			||||||
                                                portal = await tractor_n.run_in_actor(
 | 
					 | 
				
			||||||
                                                    _tractor_static_compute_one,
 | 
					 | 
				
			||||||
                                                    name='tractor-cuda-mp',
 | 
					 | 
				
			||||||
                                                    request_id=rid,
 | 
					 | 
				
			||||||
                                                    method=body['method'],
 | 
					 | 
				
			||||||
                                                    params=body['params'],
 | 
					 | 
				
			||||||
                                                    binary=binary
 | 
					                                                    binary=binary
 | 
				
			||||||
                                                )
 | 
					                                                )
 | 
				
			||||||
                                                img_sha, img_raw = await portal.result()
 | 
					                                            )
 | 
				
			||||||
                                                trio_n.cancel_scope.cancel()
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
                                        case _:
 | 
					                                        case _:
 | 
				
			||||||
                                            raise DGPUComputeError(f'Unsupported backend {self.backend}')
 | 
					                                            raise DGPUComputeError(f'Unsupported backend {self.backend}')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                                    ipfs_hash = await self.conn.publish_on_ipfs(img_raw)
 | 
					                                    ipfs_hash = await self.conn.publish_on_ipfs(output, typ=output_type)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                                    await self.conn.submit_work(rid, request_hash, img_sha, ipfs_hash)
 | 
					                                    await self.conn.submit_work(rid, request_hash, output_hash, ipfs_hash)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                                except BaseException as e:
 | 
					                                except BaseException as e:
 | 
				
			||||||
                                    traceback.print_exc()
 | 
					                                    traceback.print_exc()
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -9,17 +9,19 @@ from pathlib import Path
 | 
				
			||||||
from functools import partial
 | 
					from functools import partial
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import asks
 | 
					import asks
 | 
				
			||||||
 | 
					import numpy
 | 
				
			||||||
import trio
 | 
					import trio
 | 
				
			||||||
import anyio
 | 
					import anyio
 | 
				
			||||||
 | 
					import torch
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from PIL import Image, UnidentifiedImageError
 | 
					from PIL import Image, UnidentifiedImageError
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from leap.cleos import CLEOS
 | 
					from leap.cleos import CLEOS
 | 
				
			||||||
from leap.sugar import Checksum256, Name, asset_from_str
 | 
					from leap.sugar import Checksum256, Name, asset_from_str
 | 
				
			||||||
from skynet.constants import DEFAULT_DOMAIN
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
from skynet.dgpu.errors import DGPUComputeError
 | 
					 | 
				
			||||||
from skynet.ipfs import AsyncIPFSHTTP, get_ipfs_file
 | 
					from skynet.ipfs import AsyncIPFSHTTP, get_ipfs_file
 | 
				
			||||||
 | 
					from skynet.dgpu.errors import DGPUComputeError
 | 
				
			||||||
 | 
					from skynet.constants import DEFAULT_DOMAIN
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
REQUEST_UPDATE_TIME = 3
 | 
					REQUEST_UPDATE_TIME = 3
 | 
				
			||||||
| 
						 | 
					@ -235,11 +237,19 @@ class SkynetGPUConnector:
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # IPFS helpers
 | 
					    # IPFS helpers
 | 
				
			||||||
    async def publish_on_ipfs(self, raw_img: bytes):
 | 
					    async def publish_on_ipfs(self, raw, typ: str = 'png'):
 | 
				
			||||||
        Path('ipfs-staging').mkdir(exist_ok=True)
 | 
					        Path('ipfs-staging').mkdir(exist_ok=True)
 | 
				
			||||||
        logging.info('publish_on_ipfs')
 | 
					        logging.info('publish_on_ipfs')
 | 
				
			||||||
        img = Image.open(io.BytesIO(raw_img))
 | 
					
 | 
				
			||||||
        img.save('ipfs-staging/image.png')
 | 
					        target_file = ''
 | 
				
			||||||
 | 
					        match typ:
 | 
				
			||||||
 | 
					            case 'png':
 | 
				
			||||||
 | 
					                raw: Image
 | 
				
			||||||
 | 
					                target_file = 'ipfs-staging/image.png'
 | 
				
			||||||
 | 
					                raw.save(target_file)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            case _:
 | 
				
			||||||
 | 
					                raise ValueError(f'Unsupported output type: {typ}')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        if self.ipfs_gateway_url:
 | 
					        if self.ipfs_gateway_url:
 | 
				
			||||||
            # check peer connections, reconnect to skynet gateway if not
 | 
					            # check peer connections, reconnect to skynet gateway if not
 | 
				
			||||||
| 
						 | 
					@ -248,16 +258,18 @@ class SkynetGPUConnector:
 | 
				
			||||||
            if gateway_id not in [p['Peer'] for p in peers]:
 | 
					            if gateway_id not in [p['Peer'] for p in peers]:
 | 
				
			||||||
                await self.ipfs_client.connect(self.ipfs_gateway_url)
 | 
					                await self.ipfs_client.connect(self.ipfs_gateway_url)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        file_info = await self.ipfs_client.add(Path('ipfs-staging/image.png'))
 | 
					        file_info = await self.ipfs_client.add(Path(target_file))
 | 
				
			||||||
        file_cid = file_info['Hash']
 | 
					        file_cid = file_info['Hash']
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        await self.ipfs_client.pin(file_cid)
 | 
					        await self.ipfs_client.pin(file_cid)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        return file_cid
 | 
					        return file_cid
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    async def get_input_data(self, ipfs_hash: str) -> bytes:
 | 
					    async def get_input_data(self, ipfs_hash: str) -> tuple[bytes, str]:
 | 
				
			||||||
 | 
					        input_type = 'none'
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        if ipfs_hash == '':
 | 
					        if ipfs_hash == '':
 | 
				
			||||||
            return b''
 | 
					            return b'', input_type
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        results = {}
 | 
					        results = {}
 | 
				
			||||||
        ipfs_link = f'https://ipfs.{DEFAULT_DOMAIN}/ipfs/{ipfs_hash}'
 | 
					        ipfs_link = f'https://ipfs.{DEFAULT_DOMAIN}/ipfs/{ipfs_hash}'
 | 
				
			||||||
| 
						 | 
					@ -272,9 +284,10 @@ class SkynetGPUConnector:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                else:
 | 
					                else:
 | 
				
			||||||
                    try:
 | 
					                    try:
 | 
				
			||||||
                        with Image.open(io.BytesIO(res.raw)):
 | 
					                        # attempt to decode as image
 | 
				
			||||||
                            results[link] = res.raw
 | 
					                        results[link] = Image.open(io.BytesIO(res.raw))
 | 
				
			||||||
                            n.cancel_scope.cancel()
 | 
					                        input_type = 'png'
 | 
				
			||||||
 | 
					                        n.cancel_scope.cancel()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                    except UnidentifiedImageError:
 | 
					                    except UnidentifiedImageError:
 | 
				
			||||||
                        logging.warning(f'couldn\'t get ipfs binary data at {link}!')
 | 
					                        logging.warning(f'couldn\'t get ipfs binary data at {link}!')
 | 
				
			||||||
| 
						 | 
					@ -284,14 +297,14 @@ class SkynetGPUConnector:
 | 
				
			||||||
            n.start_soon(
 | 
					            n.start_soon(
 | 
				
			||||||
                get_and_set_results, ipfs_link_legacy)
 | 
					                get_and_set_results, ipfs_link_legacy)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        png_img = None
 | 
					        input_data = None
 | 
				
			||||||
        if ipfs_link_legacy in results:
 | 
					        if ipfs_link_legacy in results:
 | 
				
			||||||
            png_img = results[ipfs_link_legacy]
 | 
					            input_data = results[ipfs_link_legacy]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        if ipfs_link in results:
 | 
					        if ipfs_link in results:
 | 
				
			||||||
            png_img = results[ipfs_link]
 | 
					            input_data = results[ipfs_link]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        if not png_img:
 | 
					        if input_data == None:
 | 
				
			||||||
            raise DGPUComputeError('Couldn\'t gather input data from ipfs')
 | 
					            raise DGPUComputeError('Couldn\'t gather input data from ipfs')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        return png_img
 | 
					        return input_data, input_type
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -18,15 +18,10 @@ from PIL import Image
 | 
				
			||||||
from basicsr.archs.rrdbnet_arch import RRDBNet
 | 
					from basicsr.archs.rrdbnet_arch import RRDBNet
 | 
				
			||||||
from diffusers import (
 | 
					from diffusers import (
 | 
				
			||||||
    DiffusionPipeline,
 | 
					    DiffusionPipeline,
 | 
				
			||||||
    StableDiffusionXLPipeline,
 | 
					 | 
				
			||||||
    StableDiffusionXLImg2ImgPipeline,
 | 
					 | 
				
			||||||
    StableDiffusionPipeline,
 | 
					 | 
				
			||||||
    StableDiffusionImg2ImgPipeline,
 | 
					 | 
				
			||||||
    EulerAncestralDiscreteScheduler
 | 
					    EulerAncestralDiscreteScheduler
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
from realesrgan import RealESRGANer
 | 
					from realesrgan import RealESRGANer
 | 
				
			||||||
from huggingface_hub import login
 | 
					from huggingface_hub import login
 | 
				
			||||||
from torch.distributions import weibull
 | 
					 | 
				
			||||||
import trio
 | 
					import trio
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from .constants import MODELS
 | 
					from .constants import MODELS
 | 
				
			||||||
| 
						 | 
					@ -56,11 +51,10 @@ def convert_from_img_to_bytes(image: Image, fmt='PNG') -> bytes:
 | 
				
			||||||
    return byte_arr.getvalue()
 | 
					    return byte_arr.getvalue()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def convert_from_bytes_and_crop(raw: bytes, max_w: int, max_h: int) -> Image:
 | 
					def crop_image(image: Image, max_w: int, max_h: int) -> Image:
 | 
				
			||||||
    image = convert_from_bytes_to_img(raw)
 | 
					 | 
				
			||||||
    w, h = image.size
 | 
					    w, h = image.size
 | 
				
			||||||
    if w > max_w or h > max_h:
 | 
					    if w > max_w or h > max_h:
 | 
				
			||||||
        image.thumbnail((512, 512))
 | 
					        image.thumbnail((max_w, max_h))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    return image.convert('RGB')
 | 
					    return image.convert('RGB')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -74,7 +68,6 @@ def pipeline_for(
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    assert torch.cuda.is_available()
 | 
					    assert torch.cuda.is_available()
 | 
				
			||||||
    torch.cuda.empty_cache()
 | 
					    torch.cuda.empty_cache()
 | 
				
			||||||
    torch.cuda.set_per_process_memory_fraction(mem_fraction)
 | 
					 | 
				
			||||||
    torch.backends.cuda.matmul.allow_tf32 = True
 | 
					    torch.backends.cuda.matmul.allow_tf32 = True
 | 
				
			||||||
    torch.backends.cudnn.allow_tf32 = True
 | 
					    torch.backends.cudnn.allow_tf32 = True
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -89,6 +82,7 @@ def pipeline_for(
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    req_mem = model_info['mem']
 | 
					    req_mem = model_info['mem']
 | 
				
			||||||
    mem_gb = torch.cuda.mem_get_info()[1] / (10**9)
 | 
					    mem_gb = torch.cuda.mem_get_info()[1] / (10**9)
 | 
				
			||||||
 | 
					    mem_gb *= mem_fraction
 | 
				
			||||||
    over_mem = mem_gb < req_mem
 | 
					    over_mem = mem_gb < req_mem
 | 
				
			||||||
    if over_mem:
 | 
					    if over_mem:
 | 
				
			||||||
        logging.warn(f'model requires {req_mem} but card has {mem_gb}, model will run slower..')
 | 
					        logging.warn(f'model requires {req_mem} but card has {mem_gb}, model will run slower..')
 | 
				
			||||||
| 
						 | 
					@ -96,26 +90,19 @@ def pipeline_for(
 | 
				
			||||||
    shortname = model_info['short']
 | 
					    shortname = model_info['short']
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    params = {
 | 
					    params = {
 | 
				
			||||||
        'torch_dtype': torch.float16,
 | 
					 | 
				
			||||||
        'safety_checker': None,
 | 
					        'safety_checker': None,
 | 
				
			||||||
        'cache_dir': cache_dir
 | 
					        'torch_dtype': torch.float16,
 | 
				
			||||||
 | 
					        'cache_dir': cache_dir,
 | 
				
			||||||
 | 
					        'variant': 'fp16'
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    if shortname == 'stable':
 | 
					    match shortname:
 | 
				
			||||||
        params['revision'] = 'fp16'
 | 
					        case 'stable':
 | 
				
			||||||
 | 
					            params['revision'] = 'fp16'
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    if 'xl' in shortname:
 | 
					    torch.cuda.set_per_process_memory_fraction(mem_fraction)
 | 
				
			||||||
        if image:
 | 
					 | 
				
			||||||
            pipe_class = StableDiffusionXLImg2ImgPipeline
 | 
					 | 
				
			||||||
        else:
 | 
					 | 
				
			||||||
            pipe_class = StableDiffusionXLPipeline
 | 
					 | 
				
			||||||
    else:
 | 
					 | 
				
			||||||
        if image:
 | 
					 | 
				
			||||||
            pipe_class = StableDiffusionImg2ImgPipeline
 | 
					 | 
				
			||||||
        else:
 | 
					 | 
				
			||||||
            pipe_class = StableDiffusionPipeline
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    pipe = pipe_class.from_pretrained(
 | 
					    pipe = DiffusionPipeline.from_pretrained(
 | 
				
			||||||
        model, **params)
 | 
					        model, **params)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(
 | 
					    pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(
 | 
				
			||||||
| 
						 | 
					@ -151,12 +138,6 @@ def txt2img(
 | 
				
			||||||
    steps: int = 28,
 | 
					    steps: int = 28,
 | 
				
			||||||
    seed: Optional[int] = None
 | 
					    seed: Optional[int] = None
 | 
				
			||||||
):
 | 
					):
 | 
				
			||||||
    assert torch.cuda.is_available()
 | 
					 | 
				
			||||||
    torch.cuda.empty_cache()
 | 
					 | 
				
			||||||
    torch.cuda.set_per_process_memory_fraction(1.0)
 | 
					 | 
				
			||||||
    torch.backends.cuda.matmul.allow_tf32 = True
 | 
					 | 
				
			||||||
    torch.backends.cudnn.allow_tf32 = True
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    login(token=hf_token)
 | 
					    login(token=hf_token)
 | 
				
			||||||
    pipe = pipeline_for(model)
 | 
					    pipe = pipeline_for(model)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -184,12 +165,6 @@ def img2img(
 | 
				
			||||||
    steps: int = 28,
 | 
					    steps: int = 28,
 | 
				
			||||||
    seed: Optional[int] = None
 | 
					    seed: Optional[int] = None
 | 
				
			||||||
):
 | 
					):
 | 
				
			||||||
    assert torch.cuda.is_available()
 | 
					 | 
				
			||||||
    torch.cuda.empty_cache()
 | 
					 | 
				
			||||||
    torch.cuda.set_per_process_memory_fraction(1.0)
 | 
					 | 
				
			||||||
    torch.backends.cuda.matmul.allow_tf32 = True
 | 
					 | 
				
			||||||
    torch.backends.cudnn.allow_tf32 = True
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    login(token=hf_token)
 | 
					    login(token=hf_token)
 | 
				
			||||||
    pipe = pipeline_for(model, image=True)
 | 
					    pipe = pipeline_for(model, image=True)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -230,12 +205,6 @@ def upscale(
 | 
				
			||||||
    output: str = 'output.png',
 | 
					    output: str = 'output.png',
 | 
				
			||||||
    model_path: str = 'weights/RealESRGAN_x4plus.pth'
 | 
					    model_path: str = 'weights/RealESRGAN_x4plus.pth'
 | 
				
			||||||
):
 | 
					):
 | 
				
			||||||
    assert torch.cuda.is_available()
 | 
					 | 
				
			||||||
    torch.cuda.empty_cache()
 | 
					 | 
				
			||||||
    torch.cuda.set_per_process_memory_fraction(1.0)
 | 
					 | 
				
			||||||
    torch.backends.cuda.matmul.allow_tf32 = True
 | 
					 | 
				
			||||||
    torch.backends.cudnn.allow_tf32 = True
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    input_img = Image.open(img_path).convert('RGB')
 | 
					    input_img = Image.open(img_path).convert('RGB')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    upscaler = init_upscaler(model_path=model_path)
 | 
					    upscaler = init_upscaler(model_path=model_path)
 | 
				
			||||||
| 
						 | 
					@ -258,7 +227,7 @@ async def download_upscaler():
 | 
				
			||||||
        f.write(response.content)
 | 
					        f.write(response.content)
 | 
				
			||||||
    print('done')
 | 
					    print('done')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def download_all_models(hf_token: str):
 | 
					def download_all_models(hf_token: str, hf_home: str):
 | 
				
			||||||
    assert torch.cuda.is_available()
 | 
					    assert torch.cuda.is_available()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    trio.run(download_upscaler)
 | 
					    trio.run(download_upscaler)
 | 
				
			||||||
| 
						 | 
					@ -266,6 +235,4 @@ def download_all_models(hf_token: str):
 | 
				
			||||||
    login(token=hf_token)
 | 
					    login(token=hf_token)
 | 
				
			||||||
    for model in MODELS:
 | 
					    for model in MODELS:
 | 
				
			||||||
        print(f'DOWNLOADING {model.upper()}')
 | 
					        print(f'DOWNLOADING {model.upper()}')
 | 
				
			||||||
        pipeline_for(model)
 | 
					        pipeline_for(model, cache_dir=hf_home)
 | 
				
			||||||
        print(f'DOWNLOADING IMAGE {model.upper()}')
 | 
					 | 
				
			||||||
        pipeline_for(model, image=True)
 | 
					 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue