mirror of https://github.com/skygpu/skynet.git
				
				
				
			First attempt at adding flux models, update all deps, upgrade to cuda 12, add custom pipe sys
							parent
							
								
									00dcccf2bb
								
							
						
					
					
						commit
						07b211514d
					
				| 
						 | 
					@ -0,0 +1,45 @@
 | 
				
			||||||
 | 
					from nvidia/cuda:12.4.1-devel-ubuntu22.04
 | 
				
			||||||
 | 
					from python:3.12
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					env DEBIAN_FRONTEND=noninteractive
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					run apt-get update && apt-get install -y \
 | 
				
			||||||
 | 
					    git \
 | 
				
			||||||
 | 
					    llvm \
 | 
				
			||||||
 | 
					    ffmpeg \
 | 
				
			||||||
 | 
					    libsm6 \
 | 
				
			||||||
 | 
					    libxext6 \
 | 
				
			||||||
 | 
					    ninja-build
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# env CC /usr/bin/clang
 | 
				
			||||||
 | 
					# env CXX /usr/bin/clang++
 | 
				
			||||||
 | 
					# 
 | 
				
			||||||
 | 
					# # install llvm10 as required by llvm-lite
 | 
				
			||||||
 | 
					# run git clone https://github.com/llvm/llvm-project.git -b llvmorg-10.0.1
 | 
				
			||||||
 | 
					# workdir /llvm-project
 | 
				
			||||||
 | 
					# # this adds a commit from 12.0.0 that fixes build on newer compilers
 | 
				
			||||||
 | 
					# run git cherry-pick -n b498303066a63a203d24f739b2d2e0e56dca70d1
 | 
				
			||||||
 | 
					# run cmake -S llvm -B build -G Ninja -DCMAKE_BUILD_TYPE=Release
 | 
				
			||||||
 | 
					# run ninja -C build install  # -j8
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					run curl -sSL https://install.python-poetry.org | python3 -
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					env PATH "/root/.local/bin:$PATH"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					copy . /skynet
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					workdir /skynet
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					env POETRY_VIRTUALENVS_PATH /skynet/.venv
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					run poetry install --with=cuda -v
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					workdir /root/target
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					env PYTORCH_CUDA_ALLOC_CONF max_split_size_mb:128
 | 
				
			||||||
 | 
					env NVIDIA_VISIBLE_DEVICES=all
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					copy docker/entrypoint.sh /entrypoint.sh
 | 
				
			||||||
 | 
					entrypoint ["/entrypoint.sh"]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					cmd ["skynet", "--help"]
 | 
				
			||||||
| 
						 | 
					@ -1,7 +1,7 @@
 | 
				
			||||||
docker build \
 | 
					docker build \
 | 
				
			||||||
    -t guilledk/skynet:runtime-cuda-py311 \
 | 
					    -t guilledk/skynet:runtime-cuda-py312 \
 | 
				
			||||||
    -f docker/Dockerfile.runtime+cuda-py311 .
 | 
					    -f docker/Dockerfile.runtime+cuda-py312 .
 | 
				
			||||||
 | 
					
 | 
				
			||||||
docker build \
 | 
					# docker build \
 | 
				
			||||||
    -t guilledk/skynet:runtime-cuda \
 | 
					#     -t guilledk/skynet:runtime-cuda \
 | 
				
			||||||
    -f docker/Dockerfile.runtime+cuda-py311 .
 | 
					#     -f docker/Dockerfile.runtime+cuda-py311 .
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							| 
						 | 
					@ -1,21 +1,31 @@
 | 
				
			||||||
[tool.poetry]
 | 
					[tool.poetry]
 | 
				
			||||||
name = 'skynet'
 | 
					name = 'skynet'
 | 
				
			||||||
version = '0.1a12'
 | 
					version = '0.1a13'
 | 
				
			||||||
description = 'Decentralized compute platform'
 | 
					description = 'Decentralized compute platform'
 | 
				
			||||||
authors = ['Guillermo Rodriguez <guillermo@telos.net>']
 | 
					authors = ['Guillermo Rodriguez <guillermo@telos.net>']
 | 
				
			||||||
license = 'AGPL'
 | 
					license = 'AGPL'
 | 
				
			||||||
readme = 'README.md'
 | 
					readme = 'README.md'
 | 
				
			||||||
 | 
					
 | 
				
			||||||
[tool.poetry.dependencies]
 | 
					[tool.poetry.dependencies]
 | 
				
			||||||
python = '>=3.10,<3.12'
 | 
					python = '>=3.10,<3.13'
 | 
				
			||||||
pytz = '^2023.3.post1'
 | 
					pytz = '^2023.3.post1'
 | 
				
			||||||
trio = '^0.22.2'
 | 
					trio = '^0.22.2'
 | 
				
			||||||
asks = '^3.0.0'
 | 
					asks = '^3.0.0'
 | 
				
			||||||
Pillow = '^10.0.1'
 | 
					Pillow = '^10.0.1'
 | 
				
			||||||
docker = '^6.1.3'
 | 
					docker = '^6.1.3'
 | 
				
			||||||
py-leap = {git = 'https://github.com/guilledk/py-leap.git', rev = 'v0.1a14'}
 | 
					py-leap = {git = 'https://github.com/guilledk/py-leap.git', rev = 'v0.1a32'}
 | 
				
			||||||
toml = '^0.10.2'
 | 
					toml = '^0.10.2'
 | 
				
			||||||
msgspec = "^0.19.0"
 | 
					msgspec = "^0.19.0"
 | 
				
			||||||
 | 
					numpy = "<2.1"
 | 
				
			||||||
 | 
					gguf = "^0.14.0"
 | 
				
			||||||
 | 
					protobuf = "^5.29.3"
 | 
				
			||||||
 | 
					zstandard = "^0.23.0"
 | 
				
			||||||
 | 
					diskcache = "^5.6.3"
 | 
				
			||||||
 | 
					bitsandbytes = "^0.45.0"
 | 
				
			||||||
 | 
					hqq = "^0.2.2"
 | 
				
			||||||
 | 
					optimum-quanto = "^0.2.6"
 | 
				
			||||||
 | 
					basicsr = "^1.4.2"
 | 
				
			||||||
 | 
					realesrgan = "^0.3.0"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
[tool.poetry.group.frontend]
 | 
					[tool.poetry.group.frontend]
 | 
				
			||||||
optional = true
 | 
					optional = true
 | 
				
			||||||
| 
						 | 
					@ -39,26 +49,24 @@ pytest-trio = "^0.8.0"
 | 
				
			||||||
optional = true
 | 
					optional = true
 | 
				
			||||||
 | 
					
 | 
				
			||||||
[tool.poetry.group.cuda.dependencies]
 | 
					[tool.poetry.group.cuda.dependencies]
 | 
				
			||||||
torch = {version = '2.0.1+cu118', source = 'torch'}
 | 
					torch = {version = '2.5.1+cu121', source = 'torch'}
 | 
				
			||||||
scipy = {version = '^1.11.2'}
 | 
					scipy = {version = '1.15.1'}
 | 
				
			||||||
numba = {version = '0.57.0'}
 | 
					numba = {version = '0.60.0'}
 | 
				
			||||||
quart = {version = '^0.19.3'}
 | 
					quart = {version = '^0.19.3'}
 | 
				
			||||||
triton = {version = '2.0.0', source = 'torch'}
 | 
					triton = {version = '3.1.0', source = 'torch'}
 | 
				
			||||||
basicsr = {version = '^1.4.2'}
 | 
					xformers = {version = '^0.0.29'}
 | 
				
			||||||
xformers = {version = '^0.0.22'}
 | 
					 | 
				
			||||||
hypercorn = {version = '^0.14.4'}
 | 
					hypercorn = {version = '^0.14.4'}
 | 
				
			||||||
diffusers = {version = '^0.21.2'}
 | 
					diffusers = {version = '0.32.1'}
 | 
				
			||||||
realesrgan = {version = '^0.3.0'}
 | 
					 | 
				
			||||||
quart-trio = {version = '^0.11.0'}
 | 
					quart-trio = {version = '^0.11.0'}
 | 
				
			||||||
torchvision = {version = '0.15.2+cu118', source = 'torch'}
 | 
					torchvision = {version = '0.20.1+cu121', source = 'torch'}
 | 
				
			||||||
accelerate = {version = '^0.23.0'}
 | 
					accelerate = {version = '0.34.0'}
 | 
				
			||||||
transformers = {version = '^4.33.2'}
 | 
					transformers = {version = '4.48.0'}
 | 
				
			||||||
huggingface-hub = {version = '^0.17.3'}
 | 
					huggingface-hub = {version = '^0.27.1'}
 | 
				
			||||||
invisible-watermark = {version = '^0.2.0'}
 | 
					invisible-watermark = {version = '^0.2.0'}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
[[tool.poetry.source]]
 | 
					[[tool.poetry.source]]
 | 
				
			||||||
name = 'torch'
 | 
					name = 'torch'
 | 
				
			||||||
url = 'https://download.pytorch.org/whl/cu118'
 | 
					url = 'https://download.pytorch.org/whl/cu121'
 | 
				
			||||||
priority = 'explicit'
 | 
					priority = 'explicit'
 | 
				
			||||||
 | 
					
 | 
				
			||||||
[build-system]
 | 
					[build-system]
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -8,7 +8,7 @@ from functools import partial
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import click
 | 
					import click
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from leap.sugar import Name, asset_from_str
 | 
					from leap.protocol import Name, Asset
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from .config import *
 | 
					from .config import *
 | 
				
			||||||
from .constants import *
 | 
					from .constants import *
 | 
				
			||||||
| 
						 | 
					@ -178,7 +178,7 @@ def enqueue(
 | 
				
			||||||
                    'user': Name(account),
 | 
					                    'user': Name(account),
 | 
				
			||||||
                    'request_body': req,
 | 
					                    'request_body': req,
 | 
				
			||||||
                    'binary_data': binary,
 | 
					                    'binary_data': binary,
 | 
				
			||||||
                    'reward': asset_from_str(reward),
 | 
					                    'reward': Asset.from_str(reward),
 | 
				
			||||||
                    'min_verification': 1
 | 
					                    'min_verification': 1
 | 
				
			||||||
                },
 | 
					                },
 | 
				
			||||||
                account, key, permission,
 | 
					                account, key, permission,
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -78,8 +78,20 @@ MODELS: dict[str, ModelDesc] = {
 | 
				
			||||||
        size=Size(w=512, h=512),
 | 
					        size=Size(w=512, h=512),
 | 
				
			||||||
        tags=['txt2img']
 | 
					        tags=['txt2img']
 | 
				
			||||||
    ),
 | 
					    ),
 | 
				
			||||||
 | 
					    'black-forest-labs/FLUX.1-schnell': ModelDesc(
 | 
				
			||||||
 | 
					        short='flux',
 | 
				
			||||||
 | 
					        mem=24,
 | 
				
			||||||
 | 
					        size=Size(w=1024, h=1024),
 | 
				
			||||||
 | 
					        tags=['txt2img']
 | 
				
			||||||
 | 
					    ),
 | 
				
			||||||
 | 
					    'black-forest-labs/FLUX.1-Fill-dev': ModelDesc(
 | 
				
			||||||
 | 
					        short='flux-inpaint',
 | 
				
			||||||
 | 
					        mem=24,
 | 
				
			||||||
 | 
					        size=Size(w=1024, h=1024),
 | 
				
			||||||
 | 
					        tags=['inpaint']
 | 
				
			||||||
 | 
					    ),
 | 
				
			||||||
    'diffusers/stable-diffusion-xl-1.0-inpainting-0.1': ModelDesc(
 | 
					    'diffusers/stable-diffusion-xl-1.0-inpainting-0.1': ModelDesc(
 | 
				
			||||||
        short='stablexl-inpainting',
 | 
					        short='stablexl-inpaint',
 | 
				
			||||||
        mem=8.3,
 | 
					        mem=8.3,
 | 
				
			||||||
        size=Size(w=1024, h=1024),
 | 
					        size=Size(w=1024, h=1024),
 | 
				
			||||||
        tags=['inpaint']
 | 
					        tags=['inpaint']
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -18,7 +18,6 @@ from skynet.dgpu.errors import DGPUComputeError, DGPUInferenceCancelled
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from skynet.utils import crop_image, convert_from_cv2_to_image, convert_from_image_to_cv2, convert_from_img_to_bytes, init_upscaler, pipeline_for
 | 
					from skynet.utils import crop_image, convert_from_cv2_to_image, convert_from_image_to_cv2, convert_from_img_to_bytes, init_upscaler, pipeline_for
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					 | 
				
			||||||
def prepare_params_for_diffuse(
 | 
					def prepare_params_for_diffuse(
 | 
				
			||||||
    params: dict,
 | 
					    params: dict,
 | 
				
			||||||
    mode: str,
 | 
					    mode: str,
 | 
				
			||||||
| 
						 | 
					@ -35,6 +34,10 @@ def prepare_params_for_diffuse(
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            _params['image'] = image
 | 
					            _params['image'] = image
 | 
				
			||||||
            _params['mask_image'] = mask
 | 
					            _params['mask_image'] = mask
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            if 'flux' in params['model'].lower():
 | 
				
			||||||
 | 
					                _params['max_sequence_length'] = 512
 | 
				
			||||||
 | 
					            else:
 | 
				
			||||||
                _params['strength'] = float(params['strength'])
 | 
					                _params['strength'] = float(params['strength'])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        case 'img2img':
 | 
					        case 'img2img':
 | 
				
			||||||
| 
						 | 
					@ -66,8 +69,6 @@ def prepare_params_for_diffuse(
 | 
				
			||||||
class SkynetMM:
 | 
					class SkynetMM:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def __init__(self, config: dict):
 | 
					    def __init__(self, config: dict):
 | 
				
			||||||
        self.upscaler = init_upscaler()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        self.cache_dir = None
 | 
					        self.cache_dir = None
 | 
				
			||||||
        if 'hf_home' in config:
 | 
					        if 'hf_home' in config:
 | 
				
			||||||
            self.cache_dir = config['hf_home']
 | 
					            self.cache_dir = config['hf_home']
 | 
				
			||||||
| 
						 | 
					@ -88,30 +89,28 @@ class SkynetMM:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        return False
 | 
					        return False
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def load_model(
 | 
					    def unload_model(self):
 | 
				
			||||||
        self,
 | 
					 | 
				
			||||||
        name: str,
 | 
					 | 
				
			||||||
        mode: str
 | 
					 | 
				
			||||||
    ):
 | 
					 | 
				
			||||||
        logging.info(f'loading model {name}...')
 | 
					 | 
				
			||||||
        self._model_mode = mode
 | 
					 | 
				
			||||||
        self._model_name = name
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        if getattr(self, '_model', None):
 | 
					        if getattr(self, '_model', None):
 | 
				
			||||||
            del self._model
 | 
					            del self._model
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        gc.collect()
 | 
					        gc.collect()
 | 
				
			||||||
        torch.cuda.empty_cache()
 | 
					        torch.cuda.empty_cache()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self._model_name = ''
 | 
				
			||||||
 | 
					        self._model_mode = ''
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def load_model(
 | 
				
			||||||
 | 
					        self,
 | 
				
			||||||
 | 
					        name: str,
 | 
				
			||||||
 | 
					        mode: str
 | 
				
			||||||
 | 
					    ):
 | 
				
			||||||
 | 
					        logging.info(f'loading model {name}...')
 | 
				
			||||||
 | 
					        self.unload_model()
 | 
				
			||||||
        self._model = pipeline_for(
 | 
					        self._model = pipeline_for(
 | 
				
			||||||
            name, mode, cache_dir=self.cache_dir)
 | 
					            name, mode, cache_dir=self.cache_dir)
 | 
				
			||||||
 | 
					        self._model_mode = mode
 | 
				
			||||||
 | 
					        self._model_name = name
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def get_model(self, name: str, mode: str) -> DiffusionPipeline:
 | 
					 | 
				
			||||||
        if name not in MODELS:
 | 
					 | 
				
			||||||
            raise DGPUComputeError(f'Unknown model {model_name}')
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        if not self.is_model_loaded(name, mode):
 | 
					 | 
				
			||||||
            self.load_model(name, mode)
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def compute_one(
 | 
					    def compute_one(
 | 
				
			||||||
        self,
 | 
					        self,
 | 
				
			||||||
| 
						 | 
					@ -127,6 +126,8 @@ class SkynetMM:
 | 
				
			||||||
                    logging.warn(f'cancelling work at step {step}')
 | 
					                    logging.warn(f'cancelling work at step {step}')
 | 
				
			||||||
                    raise DGPUInferenceCancelled()
 | 
					                    raise DGPUInferenceCancelled()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            return {}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        maybe_cancel_work(0)
 | 
					        maybe_cancel_work(0)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        output_type = 'png'
 | 
					        output_type = 'png'
 | 
				
			||||||
| 
						 | 
					@ -136,23 +137,29 @@ class SkynetMM:
 | 
				
			||||||
        output = None
 | 
					        output = None
 | 
				
			||||||
        output_hash = None
 | 
					        output_hash = None
 | 
				
			||||||
        try:
 | 
					        try:
 | 
				
			||||||
 | 
					            name = params['model']
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            match method:
 | 
					            match method:
 | 
				
			||||||
                case 'diffuse' | 'txt2img' | 'img2img' | 'inpaint':
 | 
					                case 'diffuse' | 'txt2img' | 'img2img' | 'inpaint':
 | 
				
			||||||
 | 
					                    if not self.is_model_loaded(name, method):
 | 
				
			||||||
 | 
					                        self.load_model(name, method)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                    arguments = prepare_params_for_diffuse(
 | 
					                    arguments = prepare_params_for_diffuse(
 | 
				
			||||||
                        params, method, inputs)
 | 
					                        params, method, inputs)
 | 
				
			||||||
                    prompt, guidance, step, seed, upscaler, extra_params = arguments
 | 
					                    prompt, guidance, step, seed, upscaler, extra_params = arguments
 | 
				
			||||||
                    self.get_model(
 | 
					
 | 
				
			||||||
                        params['model'],
 | 
					                    if 'flux' in name.lower():
 | 
				
			||||||
                        method
 | 
					                        extra_params['callback_on_step_end'] = maybe_cancel_work
 | 
				
			||||||
                    )
 | 
					
 | 
				
			||||||
 | 
					                    else:
 | 
				
			||||||
 | 
					                        extra_params['callback'] = maybe_cancel_work
 | 
				
			||||||
 | 
					                        extra_params['callback_steps'] = 1
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                    output = self._model(
 | 
					                    output = self._model(
 | 
				
			||||||
                        prompt,
 | 
					                        prompt,
 | 
				
			||||||
                        guidance_scale=guidance,
 | 
					                        guidance_scale=guidance,
 | 
				
			||||||
                        num_inference_steps=step,
 | 
					                        num_inference_steps=step,
 | 
				
			||||||
                        generator=seed,
 | 
					                        generator=seed,
 | 
				
			||||||
                        callback=maybe_cancel_work,
 | 
					 | 
				
			||||||
                        callback_steps=1,
 | 
					 | 
				
			||||||
                        **extra_params
 | 
					                        **extra_params
 | 
				
			||||||
                    ).images[0]
 | 
					                    ).images[0]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -161,7 +168,7 @@ class SkynetMM:
 | 
				
			||||||
                        case 'png':
 | 
					                        case 'png':
 | 
				
			||||||
                            if upscaler == 'x4':
 | 
					                            if upscaler == 'x4':
 | 
				
			||||||
                                input_img = output.convert('RGB')
 | 
					                                input_img = output.convert('RGB')
 | 
				
			||||||
                                up_img, _ = self.upscaler.enhance(
 | 
					                                up_img, _ = init_upscaler().enhance(
 | 
				
			||||||
                                    convert_from_image_to_cv2(input_img), outscale=4)
 | 
					                                    convert_from_image_to_cv2(input_img), outscale=4)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                                output = convert_from_cv2_to_image(up_img)
 | 
					                                output = convert_from_cv2_to_image(up_img)
 | 
				
			||||||
| 
						 | 
					@ -173,6 +180,22 @@ class SkynetMM:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                    output_hash = sha256(output_binary).hexdigest()
 | 
					                    output_hash = sha256(output_binary).hexdigest()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                case 'upscale':
 | 
				
			||||||
 | 
					                    if self._model_mode != 'upscale':
 | 
				
			||||||
 | 
					                        self.unload_model()
 | 
				
			||||||
 | 
					                        self._model = init_upscaler()
 | 
				
			||||||
 | 
					                        self._model_mode = 'upscale'
 | 
				
			||||||
 | 
					                        self._model_name = 'realesrgan'
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                    input_img = inputs[0].convert('RGB')
 | 
				
			||||||
 | 
					                    up_img, _ = self._model.enhance(
 | 
				
			||||||
 | 
					                        convert_from_image_to_cv2(input_img), outscale=4)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                    output = convert_from_cv2_to_image(up_img)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                    output_binary = convert_from_img_to_bytes(output)
 | 
				
			||||||
 | 
					                    output_hash = sha256(output_binary).hexdigest()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                case _:
 | 
					                case _:
 | 
				
			||||||
                    raise DGPUComputeError('Unsupported compute method')
 | 
					                    raise DGPUComputeError('Unsupported compute method')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -125,7 +125,7 @@ class SkynetDGPUDaemon:
 | 
				
			||||||
        model = body['params']['model']
 | 
					        model = body['params']['model']
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # if model not known
 | 
					        # if model not known
 | 
				
			||||||
        if model not in MODELS:
 | 
					        if model != 'RealESRGAN_x4plus' and model not in MODELS:
 | 
				
			||||||
            logging.warning(f'Unknown model {model}')
 | 
					            logging.warning(f'Unknown model {model}')
 | 
				
			||||||
            return False
 | 
					            return False
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -143,11 +143,17 @@ class SkynetDGPUDaemon:
 | 
				
			||||||
            statuses = self._snap['requests'][rid]
 | 
					            statuses = self._snap['requests'][rid]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            if len(statuses) == 0:
 | 
					            if len(statuses) == 0:
 | 
				
			||||||
                inputs = [
 | 
					                inputs = []
 | 
				
			||||||
                    await self.conn.get_input_data(_input)
 | 
					                for _input in req['binary_data'].split(','):
 | 
				
			||||||
                    for _input in req['binary_data'].split(',')
 | 
					                    if _input:
 | 
				
			||||||
                    if _input
 | 
					                        for _ in range(3):
 | 
				
			||||||
                ]
 | 
					                            try:
 | 
				
			||||||
 | 
					                                img = await self.conn.get_input_data(_input)
 | 
				
			||||||
 | 
					                                inputs.append(img)
 | 
				
			||||||
 | 
					                                break
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                            except:
 | 
				
			||||||
 | 
					                                ...
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                hash_str = (
 | 
					                hash_str = (
 | 
				
			||||||
                    str(req['nonce'])
 | 
					                    str(req['nonce'])
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -15,7 +15,7 @@ import anyio
 | 
				
			||||||
from PIL import Image, UnidentifiedImageError
 | 
					from PIL import Image, UnidentifiedImageError
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from leap.cleos import CLEOS
 | 
					from leap.cleos import CLEOS
 | 
				
			||||||
from leap.sugar import Checksum256, Name, asset_from_str
 | 
					from leap.protocol import Asset
 | 
				
			||||||
from skynet.constants import DEFAULT_IPFS_DOMAIN
 | 
					from skynet.constants import DEFAULT_IPFS_DOMAIN
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from skynet.ipfs import AsyncIPFSHTTP, get_ipfs_file
 | 
					from skynet.ipfs import AsyncIPFSHTTP, get_ipfs_file
 | 
				
			||||||
| 
						 | 
					@ -24,6 +24,225 @@ from skynet.dgpu.errors import DGPUComputeError
 | 
				
			||||||
 | 
					
 | 
				
			||||||
REQUEST_UPDATE_TIME = 3
 | 
					REQUEST_UPDATE_TIME = 3
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					gpu_abi = {
 | 
				
			||||||
 | 
					    "version": "eosio::abi/1.2",
 | 
				
			||||||
 | 
					    "types": [],
 | 
				
			||||||
 | 
					    "structs": [
 | 
				
			||||||
 | 
					        {
 | 
				
			||||||
 | 
					            "name": "account",
 | 
				
			||||||
 | 
					            "base": "",
 | 
				
			||||||
 | 
					            "fields": [
 | 
				
			||||||
 | 
					                {"name": "user", "type": "name"},
 | 
				
			||||||
 | 
					                {"name": "balance", "type": "asset"},
 | 
				
			||||||
 | 
					                {"name": "nonce", "type": "uint64"}
 | 
				
			||||||
 | 
					            ]
 | 
				
			||||||
 | 
					        },
 | 
				
			||||||
 | 
					        {
 | 
				
			||||||
 | 
					            "name": "card",
 | 
				
			||||||
 | 
					            "base": "",
 | 
				
			||||||
 | 
					            "fields": [
 | 
				
			||||||
 | 
					                {"name": "id", "type": "uint64"},
 | 
				
			||||||
 | 
					                {"name": "owner", "type": "name"},
 | 
				
			||||||
 | 
					                {"name": "card_name", "type": "string"},
 | 
				
			||||||
 | 
					                {"name": "version", "type": "string"},
 | 
				
			||||||
 | 
					                {"name": "total_memory", "type": "uint64"},
 | 
				
			||||||
 | 
					                {"name": "mp_count", "type": "uint32"},
 | 
				
			||||||
 | 
					                {"name": "extra", "type": "string"}
 | 
				
			||||||
 | 
					            ]
 | 
				
			||||||
 | 
					        },
 | 
				
			||||||
 | 
					        {
 | 
				
			||||||
 | 
					            "name": "clean",
 | 
				
			||||||
 | 
					            "base": "",
 | 
				
			||||||
 | 
					            "fields": []
 | 
				
			||||||
 | 
					        },
 | 
				
			||||||
 | 
					        {
 | 
				
			||||||
 | 
					            "name": "config",
 | 
				
			||||||
 | 
					            "base": "",
 | 
				
			||||||
 | 
					            "fields": [
 | 
				
			||||||
 | 
					                {"name": "token_contract", "type": "name"},
 | 
				
			||||||
 | 
					                {"name": "token_symbol", "type": "symbol"}
 | 
				
			||||||
 | 
					            ]
 | 
				
			||||||
 | 
					        },
 | 
				
			||||||
 | 
					        {
 | 
				
			||||||
 | 
					            "name": "dequeue",
 | 
				
			||||||
 | 
					            "base": "",
 | 
				
			||||||
 | 
					            "fields": [
 | 
				
			||||||
 | 
					                {"name": "user", "type": "name"},
 | 
				
			||||||
 | 
					                {"name": "request_id", "type": "uint64"}
 | 
				
			||||||
 | 
					            ]
 | 
				
			||||||
 | 
					        },
 | 
				
			||||||
 | 
					        {
 | 
				
			||||||
 | 
					            "name": "enqueue",
 | 
				
			||||||
 | 
					            "base": "",
 | 
				
			||||||
 | 
					            "fields": [
 | 
				
			||||||
 | 
					                {"name": "user", "type": "name"},
 | 
				
			||||||
 | 
					                {"name": "request_body", "type": "string"},
 | 
				
			||||||
 | 
					                {"name": "binary_data", "type": "string"},
 | 
				
			||||||
 | 
					                {"name": "reward", "type": "asset"},
 | 
				
			||||||
 | 
					                {"name": "min_verification", "type": "uint32"}
 | 
				
			||||||
 | 
					            ]
 | 
				
			||||||
 | 
					        },
 | 
				
			||||||
 | 
					        {
 | 
				
			||||||
 | 
					            "name": "gcfgstruct",
 | 
				
			||||||
 | 
					            "base": "",
 | 
				
			||||||
 | 
					            "fields": [
 | 
				
			||||||
 | 
					                {"name": "token_contract", "type": "name"},
 | 
				
			||||||
 | 
					                {"name": "token_symbol", "type": "symbol"}
 | 
				
			||||||
 | 
					            ]
 | 
				
			||||||
 | 
					        },
 | 
				
			||||||
 | 
					        {
 | 
				
			||||||
 | 
					            "name": "submit",
 | 
				
			||||||
 | 
					            "base": "",
 | 
				
			||||||
 | 
					            "fields": [
 | 
				
			||||||
 | 
					                {"name": "worker", "type": "name"},
 | 
				
			||||||
 | 
					                {"name": "request_id", "type": "uint64"},
 | 
				
			||||||
 | 
					                {"name": "request_hash", "type": "checksum256"},
 | 
				
			||||||
 | 
					                {"name": "result_hash", "type": "checksum256"},
 | 
				
			||||||
 | 
					                {"name": "ipfs_hash", "type": "string"}
 | 
				
			||||||
 | 
					            ]
 | 
				
			||||||
 | 
					        },
 | 
				
			||||||
 | 
					        {
 | 
				
			||||||
 | 
					            "name": "withdraw",
 | 
				
			||||||
 | 
					            "base": "",
 | 
				
			||||||
 | 
					            "fields": [
 | 
				
			||||||
 | 
					                {"name": "user", "type": "name"},
 | 
				
			||||||
 | 
					                {"name": "quantity", "type": "asset"}
 | 
				
			||||||
 | 
					            ]
 | 
				
			||||||
 | 
					        },
 | 
				
			||||||
 | 
					        {
 | 
				
			||||||
 | 
					            "name": "work_request_struct",
 | 
				
			||||||
 | 
					            "base": "",
 | 
				
			||||||
 | 
					            "fields": [
 | 
				
			||||||
 | 
					                {"name": "id", "type": "uint64"},
 | 
				
			||||||
 | 
					                {"name": "user", "type": "name"},
 | 
				
			||||||
 | 
					                {"name": "reward", "type": "asset"},
 | 
				
			||||||
 | 
					                {"name": "min_verification", "type": "uint32"},
 | 
				
			||||||
 | 
					                {"name": "nonce", "type": "uint64"},
 | 
				
			||||||
 | 
					                {"name": "body", "type": "string"},
 | 
				
			||||||
 | 
					                {"name": "binary_data", "type": "string"},
 | 
				
			||||||
 | 
					                {"name": "timestamp", "type": "time_point_sec"}
 | 
				
			||||||
 | 
					            ]
 | 
				
			||||||
 | 
					        },
 | 
				
			||||||
 | 
					        {
 | 
				
			||||||
 | 
					            "name": "work_result_struct",
 | 
				
			||||||
 | 
					            "base": "",
 | 
				
			||||||
 | 
					            "fields": [
 | 
				
			||||||
 | 
					                {"name": "id", "type": "uint64"},
 | 
				
			||||||
 | 
					                {"name": "request_id", "type": "uint64"},
 | 
				
			||||||
 | 
					                {"name": "user", "type": "name"},
 | 
				
			||||||
 | 
					                {"name": "worker", "type": "name"},
 | 
				
			||||||
 | 
					                {"name": "result_hash", "type": "checksum256"},
 | 
				
			||||||
 | 
					                {"name": "ipfs_hash", "type": "string"},
 | 
				
			||||||
 | 
					                {"name": "submited", "type": "time_point_sec"}
 | 
				
			||||||
 | 
					            ]
 | 
				
			||||||
 | 
					        },
 | 
				
			||||||
 | 
					        {
 | 
				
			||||||
 | 
					            "name": "workbegin",
 | 
				
			||||||
 | 
					            "base": "",
 | 
				
			||||||
 | 
					            "fields": [
 | 
				
			||||||
 | 
					                {"name": "worker", "type": "name"},
 | 
				
			||||||
 | 
					                {"name": "request_id", "type": "uint64"},
 | 
				
			||||||
 | 
					                {"name": "max_workers", "type": "uint32"}
 | 
				
			||||||
 | 
					            ]
 | 
				
			||||||
 | 
					        },
 | 
				
			||||||
 | 
					        {
 | 
				
			||||||
 | 
					            "name": "workcancel",
 | 
				
			||||||
 | 
					            "base": "",
 | 
				
			||||||
 | 
					            "fields": [
 | 
				
			||||||
 | 
					                {"name": "worker", "type": "name"},
 | 
				
			||||||
 | 
					                {"name": "request_id", "type": "uint64"},
 | 
				
			||||||
 | 
					                {"name": "reason", "type": "string"}
 | 
				
			||||||
 | 
					            ]
 | 
				
			||||||
 | 
					        },
 | 
				
			||||||
 | 
					        {
 | 
				
			||||||
 | 
					            "name": "worker",
 | 
				
			||||||
 | 
					            "base": "",
 | 
				
			||||||
 | 
					            "fields": [
 | 
				
			||||||
 | 
					                {"name": "account", "type": "name"},
 | 
				
			||||||
 | 
					                {"name": "joined", "type": "time_point_sec"},
 | 
				
			||||||
 | 
					                {"name": "left", "type": "time_point_sec"},
 | 
				
			||||||
 | 
					                {"name": "url", "type": "string"}
 | 
				
			||||||
 | 
					            ]
 | 
				
			||||||
 | 
					        },
 | 
				
			||||||
 | 
					        {
 | 
				
			||||||
 | 
					            "name": "worker_status_struct",
 | 
				
			||||||
 | 
					            "base": "",
 | 
				
			||||||
 | 
					            "fields": [
 | 
				
			||||||
 | 
					                {"name": "worker", "type": "name"},
 | 
				
			||||||
 | 
					                {"name": "status", "type": "string"},
 | 
				
			||||||
 | 
					                {"name": "started", "type": "time_point_sec"}
 | 
				
			||||||
 | 
					            ]
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					    ],
 | 
				
			||||||
 | 
					    "actions": [
 | 
				
			||||||
 | 
					        {"name": "clean", "type": "clean", "ricardian_contract": ""},
 | 
				
			||||||
 | 
					        {"name": "config", "type": "config", "ricardian_contract": ""},
 | 
				
			||||||
 | 
					        {"name": "dequeue", "type": "dequeue", "ricardian_contract": ""},
 | 
				
			||||||
 | 
					        {"name": "enqueue", "type": "enqueue", "ricardian_contract": ""},
 | 
				
			||||||
 | 
					        {"name": "submit", "type": "submit", "ricardian_contract": ""},
 | 
				
			||||||
 | 
					        {"name": "withdraw", "type": "withdraw", "ricardian_contract": ""},
 | 
				
			||||||
 | 
					        {"name": "workbegin", "type": "workbegin", "ricardian_contract": ""},
 | 
				
			||||||
 | 
					        {"name": "workcancel", "type": "workcancel", "ricardian_contract": ""}
 | 
				
			||||||
 | 
					    ],
 | 
				
			||||||
 | 
					    "tables": [
 | 
				
			||||||
 | 
					        {
 | 
				
			||||||
 | 
					            "name": "cards",
 | 
				
			||||||
 | 
					            "index_type": "i64",
 | 
				
			||||||
 | 
					            "key_names": [],
 | 
				
			||||||
 | 
					            "key_types": [],
 | 
				
			||||||
 | 
					            "type": "card"
 | 
				
			||||||
 | 
					        },
 | 
				
			||||||
 | 
					        {
 | 
				
			||||||
 | 
					            "name": "gcfgstruct",
 | 
				
			||||||
 | 
					            "index_type": "i64",
 | 
				
			||||||
 | 
					            "key_names": [],
 | 
				
			||||||
 | 
					            "key_types": [],
 | 
				
			||||||
 | 
					            "type": "gcfgstruct"
 | 
				
			||||||
 | 
					        },
 | 
				
			||||||
 | 
					        {
 | 
				
			||||||
 | 
					            "name": "queue",
 | 
				
			||||||
 | 
					            "index_type": "i64",
 | 
				
			||||||
 | 
					            "key_names": [],
 | 
				
			||||||
 | 
					            "key_types": [],
 | 
				
			||||||
 | 
					            "type": "work_request_struct"
 | 
				
			||||||
 | 
					        },
 | 
				
			||||||
 | 
					        {
 | 
				
			||||||
 | 
					            "name": "results",
 | 
				
			||||||
 | 
					            "index_type": "i64",
 | 
				
			||||||
 | 
					            "key_names": [],
 | 
				
			||||||
 | 
					            "key_types": [],
 | 
				
			||||||
 | 
					            "type": "work_result_struct"
 | 
				
			||||||
 | 
					        },
 | 
				
			||||||
 | 
					        {
 | 
				
			||||||
 | 
					            "name": "status",
 | 
				
			||||||
 | 
					            "index_type": "i64",
 | 
				
			||||||
 | 
					            "key_names": [],
 | 
				
			||||||
 | 
					            "key_types": [],
 | 
				
			||||||
 | 
					            "type": "worker_status_struct"
 | 
				
			||||||
 | 
					        },
 | 
				
			||||||
 | 
					        {
 | 
				
			||||||
 | 
					            "name": "users",
 | 
				
			||||||
 | 
					            "index_type": "i64",
 | 
				
			||||||
 | 
					            "key_names": [],
 | 
				
			||||||
 | 
					            "key_types": [],
 | 
				
			||||||
 | 
					            "type": "account"
 | 
				
			||||||
 | 
					        },
 | 
				
			||||||
 | 
					        {
 | 
				
			||||||
 | 
					            "name": "workers",
 | 
				
			||||||
 | 
					            "index_type": "i64",
 | 
				
			||||||
 | 
					            "key_names": [],
 | 
				
			||||||
 | 
					            "key_types": [],
 | 
				
			||||||
 | 
					            "type": "worker"
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					    ],
 | 
				
			||||||
 | 
					    "ricardian_clauses": [],
 | 
				
			||||||
 | 
					    "error_messages": [],
 | 
				
			||||||
 | 
					    "abi_extensions": [],
 | 
				
			||||||
 | 
					    "variants": [],
 | 
				
			||||||
 | 
					    "action_results": []
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
async def failable(fn: partial, ret_fail=None):
 | 
					async def failable(fn: partial, ret_fail=None):
 | 
				
			||||||
    try:
 | 
					    try:
 | 
				
			||||||
| 
						 | 
					@ -35,22 +254,22 @@ async def failable(fn: partial, ret_fail=None):
 | 
				
			||||||
        asks.errors.RequestTimeout,
 | 
					        asks.errors.RequestTimeout,
 | 
				
			||||||
        asks.errors.BadHttpResponse,
 | 
					        asks.errors.BadHttpResponse,
 | 
				
			||||||
        anyio.BrokenResourceError
 | 
					        anyio.BrokenResourceError
 | 
				
			||||||
    ):
 | 
					    ) as e:
 | 
				
			||||||
        return ret_fail
 | 
					        return ret_fail
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class SkynetGPUConnector:
 | 
					class SkynetGPUConnector:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def __init__(self, config: dict):
 | 
					    def __init__(self, config: dict):
 | 
				
			||||||
        self.account = Name(config['account'])
 | 
					        self.account = config['account']
 | 
				
			||||||
        self.permission = config['permission']
 | 
					        self.permission = config['permission']
 | 
				
			||||||
        self.key = config['key']
 | 
					        self.key = config['key']
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        self.node_url = config['node_url']
 | 
					        self.node_url = config['node_url']
 | 
				
			||||||
        self.hyperion_url = config['hyperion_url']
 | 
					        self.hyperion_url = config['hyperion_url']
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        self.cleos = CLEOS(
 | 
					        self.cleos = CLEOS(endpoint=self.node_url)
 | 
				
			||||||
            None, None, self.node_url, remote=self.node_url)
 | 
					        self.cleos.load_abi('gpu.scd', gpu_abi)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        self.ipfs_gateway_url = None
 | 
					        self.ipfs_gateway_url = None
 | 
				
			||||||
        if 'ipfs_gateway_url' in config:
 | 
					        if 'ipfs_gateway_url' in config:
 | 
				
			||||||
| 
						 | 
					@ -151,11 +370,11 @@ class SkynetGPUConnector:
 | 
				
			||||||
                self.cleos.a_push_action,
 | 
					                self.cleos.a_push_action,
 | 
				
			||||||
                'gpu.scd',
 | 
					                'gpu.scd',
 | 
				
			||||||
                'workbegin',
 | 
					                'workbegin',
 | 
				
			||||||
                {
 | 
					                list({
 | 
				
			||||||
                    'worker': self.account,
 | 
					                    'worker': self.account,
 | 
				
			||||||
                    'request_id': request_id,
 | 
					                    'request_id': request_id,
 | 
				
			||||||
                    'max_workers': 2
 | 
					                    'max_workers': 2
 | 
				
			||||||
                },
 | 
					                }.values()),
 | 
				
			||||||
                self.account, self.key,
 | 
					                self.account, self.key,
 | 
				
			||||||
                permission=self.permission
 | 
					                permission=self.permission
 | 
				
			||||||
            )
 | 
					            )
 | 
				
			||||||
| 
						 | 
					@ -168,11 +387,11 @@ class SkynetGPUConnector:
 | 
				
			||||||
                self.cleos.a_push_action,
 | 
					                self.cleos.a_push_action,
 | 
				
			||||||
                'gpu.scd',
 | 
					                'gpu.scd',
 | 
				
			||||||
                'workcancel',
 | 
					                'workcancel',
 | 
				
			||||||
                {
 | 
					                list({
 | 
				
			||||||
                    'worker': self.account,
 | 
					                    'worker': self.account,
 | 
				
			||||||
                    'request_id': request_id,
 | 
					                    'request_id': request_id,
 | 
				
			||||||
                    'reason': reason
 | 
					                    'reason': reason
 | 
				
			||||||
                },
 | 
					                }.values()),
 | 
				
			||||||
                self.account, self.key,
 | 
					                self.account, self.key,
 | 
				
			||||||
                permission=self.permission
 | 
					                permission=self.permission
 | 
				
			||||||
            )
 | 
					            )
 | 
				
			||||||
| 
						 | 
					@ -191,10 +410,10 @@ class SkynetGPUConnector:
 | 
				
			||||||
                    self.cleos.a_push_action,
 | 
					                    self.cleos.a_push_action,
 | 
				
			||||||
                    'gpu.scd',
 | 
					                    'gpu.scd',
 | 
				
			||||||
                    'withdraw',
 | 
					                    'withdraw',
 | 
				
			||||||
                    {
 | 
					                    list({
 | 
				
			||||||
                        'user': self.account,
 | 
					                        'user': self.account,
 | 
				
			||||||
                        'quantity': asset_from_str(balance)
 | 
					                        'quantity': Asset.from_str(balance)
 | 
				
			||||||
                    },
 | 
					                    }.values()),
 | 
				
			||||||
                    self.account, self.key,
 | 
					                    self.account, self.key,
 | 
				
			||||||
                    permission=self.permission
 | 
					                    permission=self.permission
 | 
				
			||||||
                )
 | 
					                )
 | 
				
			||||||
| 
						 | 
					@ -226,13 +445,13 @@ class SkynetGPUConnector:
 | 
				
			||||||
                self.cleos.a_push_action,
 | 
					                self.cleos.a_push_action,
 | 
				
			||||||
                'gpu.scd',
 | 
					                'gpu.scd',
 | 
				
			||||||
                'submit',
 | 
					                'submit',
 | 
				
			||||||
                {
 | 
					                list({
 | 
				
			||||||
                    'worker': self.account,
 | 
					                    'worker': self.account,
 | 
				
			||||||
                    'request_id': request_id,
 | 
					                    'request_id': request_id,
 | 
				
			||||||
                    'request_hash': Checksum256(request_hash),
 | 
					                    'request_hash': request_hash,
 | 
				
			||||||
                    'result_hash': Checksum256(result_hash),
 | 
					                    'result_hash': result_hash,
 | 
				
			||||||
                    'ipfs_hash': ipfs_hash
 | 
					                    'ipfs_hash': ipfs_hash
 | 
				
			||||||
                },
 | 
					                }.values()),
 | 
				
			||||||
                self.account, self.key,
 | 
					                self.account, self.key,
 | 
				
			||||||
                permission=self.permission
 | 
					                permission=self.permission
 | 
				
			||||||
            )
 | 
					            )
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -0,0 +1,50 @@
 | 
				
			||||||
 | 
					#!/usr/bin/python
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import torch
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from diffusers import (
 | 
				
			||||||
 | 
					    DiffusionPipeline,
 | 
				
			||||||
 | 
					    FluxPipeline,
 | 
				
			||||||
 | 
					    FluxTransformer2DModel
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					from transformers import T5EncoderModel, BitsAndBytesConfig
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from huggingface_hub import hf_hub_download
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					__model = {
 | 
				
			||||||
 | 
					    'name': 'black-forest-labs/FLUX.1-schnell'
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def pipeline_for(
 | 
				
			||||||
 | 
					    model: str,
 | 
				
			||||||
 | 
					    mode: str,
 | 
				
			||||||
 | 
					    mem_fraction: float = 1.0,
 | 
				
			||||||
 | 
					    cache_dir: str | None = None
 | 
				
			||||||
 | 
					) -> DiffusionPipeline:
 | 
				
			||||||
 | 
					    qonfig = BitsAndBytesConfig(
 | 
				
			||||||
 | 
					        load_in_4bit=True,
 | 
				
			||||||
 | 
					        bnb_4bit_quant_type="nf4",
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    params = {
 | 
				
			||||||
 | 
					        'torch_dtype': torch.bfloat16,
 | 
				
			||||||
 | 
					        'cache_dir': cache_dir,
 | 
				
			||||||
 | 
					        'device_map': 'balanced',
 | 
				
			||||||
 | 
					        'max_memory': {'cpu': '10GiB', 0: '11GiB'}
 | 
				
			||||||
 | 
					        # 'max_memory': {0: '11GiB'}
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    text_encoder = T5EncoderModel.from_pretrained(
 | 
				
			||||||
 | 
					        'black-forest-labs/FLUX.1-schnell',
 | 
				
			||||||
 | 
					        subfolder="text_encoder_2",
 | 
				
			||||||
 | 
					        torch_dtype=torch.bfloat16,
 | 
				
			||||||
 | 
					        quantization_config=qonfig
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    params['text_encoder_2'] = text_encoder
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    pipe = FluxPipeline.from_pretrained(
 | 
				
			||||||
 | 
					        model, **params)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    pipe.vae.enable_tiling()
 | 
				
			||||||
 | 
					    pipe.vae.enable_slicing()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    return pipe
 | 
				
			||||||
| 
						 | 
					@ -0,0 +1,55 @@
 | 
				
			||||||
 | 
					#!/usr/bin/python
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import torch
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from diffusers import (
 | 
				
			||||||
 | 
					    DiffusionPipeline,
 | 
				
			||||||
 | 
					    FluxFillPipeline,
 | 
				
			||||||
 | 
					    FluxTransformer2DModel
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					from transformers import T5EncoderModel, BitsAndBytesConfig
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					__model = {
 | 
				
			||||||
 | 
					    'name': 'black-forest-labs/FLUX.1-Fill-dev'
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def pipeline_for(
 | 
				
			||||||
 | 
					    model: str,
 | 
				
			||||||
 | 
					    mode: str,
 | 
				
			||||||
 | 
					    mem_fraction: float = 1.0,
 | 
				
			||||||
 | 
					    cache_dir: str | None = None
 | 
				
			||||||
 | 
					) -> DiffusionPipeline:
 | 
				
			||||||
 | 
					    qonfig = BitsAndBytesConfig(
 | 
				
			||||||
 | 
					        load_in_4bit=True,
 | 
				
			||||||
 | 
					        bnb_4bit_quant_type="nf4",
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    params = {
 | 
				
			||||||
 | 
					        'torch_dtype': torch.bfloat16,
 | 
				
			||||||
 | 
					        'cache_dir': cache_dir,
 | 
				
			||||||
 | 
					        'device_map': 'balanced',
 | 
				
			||||||
 | 
					        'max_memory': {'cpu': '10GiB', 0: '11GiB'}
 | 
				
			||||||
 | 
					        # 'max_memory': {0: '11GiB'}
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    text_encoder = T5EncoderModel.from_pretrained(
 | 
				
			||||||
 | 
					        'sayakpaul/FLUX.1-Fill-dev-nf4',
 | 
				
			||||||
 | 
					        subfolder="text_encoder_2",
 | 
				
			||||||
 | 
					        torch_dtype=torch.bfloat16,
 | 
				
			||||||
 | 
					        quantization_config=qonfig
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    params['text_encoder_2'] = text_encoder
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    transformer = FluxTransformer2DModel.from_pretrained(
 | 
				
			||||||
 | 
					        'sayakpaul/FLUX.1-Fill-dev-nf4',
 | 
				
			||||||
 | 
					        subfolder="transformer",
 | 
				
			||||||
 | 
					        torch_dtype=torch.bfloat16,
 | 
				
			||||||
 | 
					        quantization_config=qonfig
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    pipe = FluxFillPipeline.from_pretrained(
 | 
				
			||||||
 | 
					        model, **params)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    pipe.vae.enable_tiling()
 | 
				
			||||||
 | 
					    pipe.vae.enable_slicing()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    return pipe
 | 
				
			||||||
							
								
								
									
										101
									
								
								skynet/utils.py
								
								
								
								
							
							
						
						
									
										101
									
								
								skynet/utils.py
								
								
								
								
							| 
						 | 
					@ -6,29 +6,42 @@ import sys
 | 
				
			||||||
import time
 | 
					import time
 | 
				
			||||||
import random
 | 
					import random
 | 
				
			||||||
import logging
 | 
					import logging
 | 
				
			||||||
 | 
					import importlib
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from typing import Optional
 | 
					from typing import Optional
 | 
				
			||||||
from pathlib import Path
 | 
					from pathlib import Path
 | 
				
			||||||
import asks
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import trio
 | 
				
			||||||
import torch
 | 
					import torch
 | 
				
			||||||
import numpy as np
 | 
					import numpy as np
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from PIL import Image
 | 
					from PIL import Image
 | 
				
			||||||
from basicsr.archs.rrdbnet_arch import RRDBNet
 | 
					 | 
				
			||||||
from diffusers import (
 | 
					from diffusers import (
 | 
				
			||||||
    DiffusionPipeline,
 | 
					    DiffusionPipeline,
 | 
				
			||||||
    AutoPipelineForText2Image,
 | 
					    AutoPipelineForText2Image,
 | 
				
			||||||
    AutoPipelineForImage2Image,
 | 
					    AutoPipelineForImage2Image,
 | 
				
			||||||
    AutoPipelineForInpainting,
 | 
					    AutoPipelineForInpainting,
 | 
				
			||||||
    EulerAncestralDiscreteScheduler
 | 
					    EulerAncestralDiscreteScheduler,
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
from realesrgan import RealESRGANer
 | 
					 | 
				
			||||||
from huggingface_hub import login
 | 
					from huggingface_hub import login
 | 
				
			||||||
import trio
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
from .constants import MODELS
 | 
					from .constants import MODELS
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# Hack to fix a changed import in torchvision 0.17+, which otherwise breaks
 | 
				
			||||||
 | 
					# basicsr; see https://github.com/AUTOMATIC1111/stable-diffusion-webui/issues/13985
 | 
				
			||||||
 | 
					try:
 | 
				
			||||||
 | 
					    import torchvision.transforms.functional_tensor  # noqa: F401
 | 
				
			||||||
 | 
					except ImportError:
 | 
				
			||||||
 | 
					    try:
 | 
				
			||||||
 | 
					        import torchvision.transforms.functional as functional
 | 
				
			||||||
 | 
					        sys.modules["torchvision.transforms.functional_tensor"] = functional
 | 
				
			||||||
 | 
					    except ImportError:
 | 
				
			||||||
 | 
					        pass  # shrug...
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from basicsr.archs.rrdbnet_arch import RRDBNet
 | 
				
			||||||
 | 
					from realesrgan import RealESRGANer
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def time_ms():
 | 
					def time_ms():
 | 
				
			||||||
    return int(time.time() * 1000)
 | 
					    return int(time.time() * 1000)
 | 
				
			||||||
| 
						 | 
					@ -72,6 +85,7 @@ def pipeline_for(
 | 
				
			||||||
    cache_dir: str | None = None
 | 
					    cache_dir: str | None = None
 | 
				
			||||||
) -> DiffusionPipeline:
 | 
					) -> DiffusionPipeline:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    logging.info(f'pipeline_for {model} {mode}')
 | 
				
			||||||
    assert torch.cuda.is_available()
 | 
					    assert torch.cuda.is_available()
 | 
				
			||||||
    torch.cuda.empty_cache()
 | 
					    torch.cuda.empty_cache()
 | 
				
			||||||
    torch.backends.cuda.matmul.allow_tf32 = True
 | 
					    torch.backends.cuda.matmul.allow_tf32 = True
 | 
				
			||||||
| 
						 | 
					@ -85,21 +99,35 @@ def pipeline_for(
 | 
				
			||||||
    torch.use_deterministic_algorithms(True)
 | 
					    torch.use_deterministic_algorithms(True)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    model_info = MODELS[model]
 | 
					    model_info = MODELS[model]
 | 
				
			||||||
 | 
					    shortname = model_info.short
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # disable for compat with "diffuse" method
 | 
				
			||||||
 | 
					    # assert mode in model_info.tags
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # default to checking if custom pipeline exist and return that if not, attempt generic
 | 
				
			||||||
 | 
					    try:
 | 
				
			||||||
 | 
					        normalized_shortname = shortname.replace('-', '_')
 | 
				
			||||||
 | 
					        custom_pipeline = importlib.import_module(f'skynet.dgpu.pipes.{normalized_shortname}')
 | 
				
			||||||
 | 
					        assert custom_pipeline.__model['name'] == model
 | 
				
			||||||
 | 
					        return custom_pipeline.pipeline_for(model, mode, mem_fraction=mem_fraction, cache_dir=cache_dir)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    except ImportError:
 | 
				
			||||||
 | 
					        ...
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    req_mem = model_info.mem
 | 
					    req_mem = model_info.mem
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    mem_gb = torch.cuda.mem_get_info()[1] / (10**9)
 | 
					    mem_gb = torch.cuda.mem_get_info()[1] / (10**9)
 | 
				
			||||||
    mem_gb *= mem_fraction
 | 
					    mem_gb *= mem_fraction
 | 
				
			||||||
    over_mem = mem_gb < req_mem
 | 
					    over_mem = mem_gb < req_mem
 | 
				
			||||||
    if over_mem:
 | 
					    if over_mem:
 | 
				
			||||||
        logging.warn(f'model requires {req_mem} but card has {mem_gb}, model will run slower..')
 | 
					        logging.warn(f'model requires {req_mem} but card has {mem_gb}, model will run slower..')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    shortname = model_info.short
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    params = {
 | 
					    params = {
 | 
				
			||||||
        'safety_checker': None,
 | 
					        'safety_checker': None,
 | 
				
			||||||
        'torch_dtype': torch.float16,
 | 
					        'torch_dtype': torch.float16,
 | 
				
			||||||
        'cache_dir': cache_dir,
 | 
					        'cache_dir': cache_dir,
 | 
				
			||||||
        'variant': 'fp16'
 | 
					        'variant': 'fp16',
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    match shortname:
 | 
					    match shortname:
 | 
				
			||||||
| 
						 | 
					@ -108,6 +136,7 @@ def pipeline_for(
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    torch.cuda.set_per_process_memory_fraction(mem_fraction)
 | 
					    torch.cuda.set_per_process_memory_fraction(mem_fraction)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    pipe_class = DiffusionPipeline
 | 
				
			||||||
    match mode:
 | 
					    match mode:
 | 
				
			||||||
        case 'inpaint':
 | 
					        case 'inpaint':
 | 
				
			||||||
            pipe_class = AutoPipelineForInpainting
 | 
					            pipe_class = AutoPipelineForInpainting
 | 
				
			||||||
| 
						 | 
					@ -115,7 +144,7 @@ def pipeline_for(
 | 
				
			||||||
        case 'img2img':
 | 
					        case 'img2img':
 | 
				
			||||||
            pipe_class = AutoPipelineForImage2Image
 | 
					            pipe_class = AutoPipelineForImage2Image
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        case 'txt2img' | 'diffuse':
 | 
					        case 'txt2img':
 | 
				
			||||||
            pipe_class = AutoPipelineForText2Image
 | 
					            pipe_class = AutoPipelineForText2Image
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    pipe = pipe_class.from_pretrained(
 | 
					    pipe = pipe_class.from_pretrained(
 | 
				
			||||||
| 
						 | 
					@ -124,20 +153,20 @@ def pipeline_for(
 | 
				
			||||||
    pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(
 | 
					    pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(
 | 
				
			||||||
        pipe.scheduler.config)
 | 
					        pipe.scheduler.config)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    pipe.enable_xformers_memory_efficient_attention()
 | 
					    # pipe.enable_xformers_memory_efficient_attention()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    if over_mem:
 | 
					    if over_mem:
 | 
				
			||||||
        if mode == 'txt2img':
 | 
					        if mode == 'txt2img':
 | 
				
			||||||
            pipe.enable_vae_slicing()
 | 
					            pipe.vae.enable_tiling()
 | 
				
			||||||
            pipe.enable_vae_tiling()
 | 
					            pipe.vae.enable_slicing()
 | 
				
			||||||
        
 | 
					        
 | 
				
			||||||
        pipe.enable_model_cpu_offload()
 | 
					        pipe.enable_model_cpu_offload()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    else:
 | 
					    else:
 | 
				
			||||||
        if sys.version_info[1] < 11:
 | 
					        # if sys.version_info[1] < 11:
 | 
				
			||||||
            # torch.compile only supported on python < 3.11
 | 
					        #     # torch.compile only supported on python < 3.11
 | 
				
			||||||
            pipe.unet = torch.compile(
 | 
					        #     pipe.unet = torch.compile(
 | 
				
			||||||
                pipe.unet, mode='reduce-overhead', fullgraph=True)
 | 
					        #         pipe.unet, mode='reduce-overhead', fullgraph=True)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        pipe = pipe.to('cuda')
 | 
					        pipe = pipe.to('cuda')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -155,7 +184,7 @@ def txt2img(
 | 
				
			||||||
    seed: Optional[int] = None
 | 
					    seed: Optional[int] = None
 | 
				
			||||||
):
 | 
					):
 | 
				
			||||||
    login(token=hf_token)
 | 
					    login(token=hf_token)
 | 
				
			||||||
    pipe = pipeline_for(model)
 | 
					    pipe = pipeline_for(model, 'txt2img')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    seed = seed if seed else random.randint(0, 2 ** 64)
 | 
					    seed = seed if seed else random.randint(0, 2 ** 64)
 | 
				
			||||||
    prompt = prompt
 | 
					    prompt = prompt
 | 
				
			||||||
| 
						 | 
					@ -182,7 +211,7 @@ def img2img(
 | 
				
			||||||
    seed: Optional[int] = None
 | 
					    seed: Optional[int] = None
 | 
				
			||||||
):
 | 
					):
 | 
				
			||||||
    login(token=hf_token)
 | 
					    login(token=hf_token)
 | 
				
			||||||
    pipe = pipeline_for(model, image=True)
 | 
					    pipe = pipeline_for(model, 'img2img')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    model_info = MODELS[model]
 | 
					    model_info = MODELS[model]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -215,7 +244,7 @@ def inpaint(
 | 
				
			||||||
    seed: Optional[int] = None
 | 
					    seed: Optional[int] = None
 | 
				
			||||||
):
 | 
					):
 | 
				
			||||||
    login(token=hf_token)
 | 
					    login(token=hf_token)
 | 
				
			||||||
    pipe = pipeline_for(model, image=True, inpainting=True)
 | 
					    pipe = pipeline_for(model, 'inpaint')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    model_info = MODELS[model]
 | 
					    model_info = MODELS[model]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -225,21 +254,25 @@ def inpaint(
 | 
				
			||||||
    with open(mask_path, 'rb') as mask_file:
 | 
					    with open(mask_path, 'rb') as mask_file:
 | 
				
			||||||
        mask_img = convert_from_bytes_and_crop(mask_file.read(), model_info.size.w, model_info.size.h)
 | 
					        mask_img = convert_from_bytes_and_crop(mask_file.read(), model_info.size.w, model_info.size.h)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    var_params = {}
 | 
				
			||||||
 | 
					    if 'flux' not in model.lower():
 | 
				
			||||||
 | 
					        var_params['strength'] = strength
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    seed = seed if seed else random.randint(0, 2 ** 64)
 | 
					    seed = seed if seed else random.randint(0, 2 ** 64)
 | 
				
			||||||
    prompt = prompt
 | 
					    prompt = prompt
 | 
				
			||||||
    image = pipe(
 | 
					    image = pipe(
 | 
				
			||||||
        prompt,
 | 
					        prompt,
 | 
				
			||||||
        image=input_img,
 | 
					        image=input_img,
 | 
				
			||||||
        mask_image=mask_img,
 | 
					        mask_image=mask_img,
 | 
				
			||||||
        strength=strength,
 | 
					 | 
				
			||||||
        guidance_scale=guidance, num_inference_steps=steps,
 | 
					        guidance_scale=guidance, num_inference_steps=steps,
 | 
				
			||||||
        generator=torch.Generator("cuda").manual_seed(seed)
 | 
					        generator=torch.Generator("cuda").manual_seed(seed),
 | 
				
			||||||
 | 
					        **var_params
 | 
				
			||||||
    ).images[0]
 | 
					    ).images[0]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    image.save(output)
 | 
					    image.save(output)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def init_upscaler(model_path: str = 'weights/RealESRGAN_x4plus.pth'):
 | 
					def init_upscaler(model_path: str = 'hf_home/RealESRGAN_x4plus.pth'):
 | 
				
			||||||
    return RealESRGANer(
 | 
					    return RealESRGANer(
 | 
				
			||||||
        scale=4,
 | 
					        scale=4,
 | 
				
			||||||
        model_path=model_path,
 | 
					        model_path=model_path,
 | 
				
			||||||
| 
						 | 
					@ -258,7 +291,7 @@ def init_upscaler(model_path: str = 'weights/RealESRGAN_x4plus.pth'):
 | 
				
			||||||
def upscale(
 | 
					def upscale(
 | 
				
			||||||
    img_path: str = 'input.png',
 | 
					    img_path: str = 'input.png',
 | 
				
			||||||
    output: str = 'output.png',
 | 
					    output: str = 'output.png',
 | 
				
			||||||
    model_path: str = 'weights/RealESRGAN_x4plus.pth'
 | 
					    model_path: str = 'hf_home/RealESRGAN_x4plus.pth'
 | 
				
			||||||
):
 | 
					):
 | 
				
			||||||
    input_img = Image.open(img_path).convert('RGB')
 | 
					    input_img = Image.open(img_path).convert('RGB')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -269,25 +302,3 @@ def upscale(
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    image = convert_from_cv2_to_image(up_img)
 | 
					    image = convert_from_cv2_to_image(up_img)
 | 
				
			||||||
    image.save(output)
 | 
					    image.save(output)
 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
async def download_upscaler():
 | 
					 | 
				
			||||||
    print('downloading upscaler...')
 | 
					 | 
				
			||||||
    weights_path = Path('weights')
 | 
					 | 
				
			||||||
    weights_path.mkdir(exist_ok=True)
 | 
					 | 
				
			||||||
    upscaler_url = 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth'
 | 
					 | 
				
			||||||
    save_path = weights_path / 'RealESRGAN_x4plus.pth'
 | 
					 | 
				
			||||||
    response = await asks.get(upscaler_url)
 | 
					 | 
				
			||||||
    with open(save_path, 'wb') as f:
 | 
					 | 
				
			||||||
        f.write(response.content)
 | 
					 | 
				
			||||||
    print('done')
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def download_all_models(hf_token: str, hf_home: str):
 | 
					 | 
				
			||||||
    assert torch.cuda.is_available()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    trio.run(download_upscaler)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    login(token=hf_token)
 | 
					 | 
				
			||||||
    for model in MODELS:
 | 
					 | 
				
			||||||
        print(f'DOWNLOADING {model.upper()}')
 | 
					 | 
				
			||||||
        pipeline_for(model, cache_dir=hf_home)
 | 
					 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue