First attempt at adding flux models, update all deps, upgrade to cuda 12, add custom pipe sys

pull/44/head
Guillermo Rodriguez 2025-01-17 11:38:52 -03:00
parent 00dcccf2bb
commit 07b211514d
No known key found for this signature in database
GPG Key ID: 002CC5F1E6BDA53E
12 changed files with 1352 additions and 696 deletions

View File

@ -0,0 +1,45 @@
from nvidia/cuda:12.4.1-devel-ubuntu22.04
from python:3.12
env DEBIAN_FRONTEND=noninteractive
run apt-get update && apt-get install -y \
git \
llvm \
ffmpeg \
libsm6 \
libxext6 \
ninja-build
# env CC /usr/bin/clang
# env CXX /usr/bin/clang++
#
# # install llvm10 as required by llvm-lite
# run git clone https://github.com/llvm/llvm-project.git -b llvmorg-10.0.1
# workdir /llvm-project
# # this adds a commit from 12.0.0 that fixes build on newer compilers
# run git cherry-pick -n b498303066a63a203d24f739b2d2e0e56dca70d1
# run cmake -S llvm -B build -G Ninja -DCMAKE_BUILD_TYPE=Release
# run ninja -C build install # -j8
run curl -sSL https://install.python-poetry.org | python3 -
env PATH "/root/.local/bin:$PATH"
copy . /skynet
workdir /skynet
env POETRY_VIRTUALENVS_PATH /skynet/.venv
run poetry install --with=cuda -v
workdir /root/target
env PYTORCH_CUDA_ALLOC_CONF max_split_size_mb:128
env NVIDIA_VISIBLE_DEVICES=all
copy docker/entrypoint.sh /entrypoint.sh
entrypoint ["/entrypoint.sh"]
cmd ["skynet", "--help"]

View File

@ -1,7 +1,7 @@
docker build \
-t guilledk/skynet:runtime-cuda-py311 \
-f docker/Dockerfile.runtime+cuda-py311 .
-t guilledk/skynet:runtime-cuda-py312 \
-f docker/Dockerfile.runtime+cuda-py312 .
docker build \
-t guilledk/skynet:runtime-cuda \
-f docker/Dockerfile.runtime+cuda-py311 .
# docker build \
# -t guilledk/skynet:runtime-cuda \
# -f docker/Dockerfile.runtime+cuda-py311 .

1383
poetry.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -1,21 +1,31 @@
[tool.poetry]
name = 'skynet'
version = '0.1a12'
version = '0.1a13'
description = 'Decentralized compute platform'
authors = ['Guillermo Rodriguez <guillermo@telos.net>']
license = 'AGPL'
readme = 'README.md'
[tool.poetry.dependencies]
python = '>=3.10,<3.12'
python = '>=3.10,<3.13'
pytz = '^2023.3.post1'
trio = '^0.22.2'
asks = '^3.0.0'
Pillow = '^10.0.1'
docker = '^6.1.3'
py-leap = {git = 'https://github.com/guilledk/py-leap.git', rev = 'v0.1a14'}
py-leap = {git = 'https://github.com/guilledk/py-leap.git', rev = 'v0.1a32'}
toml = '^0.10.2'
msgspec = "^0.19.0"
numpy = "<2.1"
gguf = "^0.14.0"
protobuf = "^5.29.3"
zstandard = "^0.23.0"
diskcache = "^5.6.3"
bitsandbytes = "^0.45.0"
hqq = "^0.2.2"
optimum-quanto = "^0.2.6"
basicsr = "^1.4.2"
realesrgan = "^0.3.0"
[tool.poetry.group.frontend]
optional = true
@ -39,26 +49,24 @@ pytest-trio = "^0.8.0"
optional = true
[tool.poetry.group.cuda.dependencies]
torch = {version = '2.0.1+cu118', source = 'torch'}
scipy = {version = '^1.11.2'}
numba = {version = '0.57.0'}
torch = {version = '2.5.1+cu121', source = 'torch'}
scipy = {version = '1.15.1'}
numba = {version = '0.60.0'}
quart = {version = '^0.19.3'}
triton = {version = '2.0.0', source = 'torch'}
basicsr = {version = '^1.4.2'}
xformers = {version = '^0.0.22'}
triton = {version = '3.1.0', source = 'torch'}
xformers = {version = '^0.0.29'}
hypercorn = {version = '^0.14.4'}
diffusers = {version = '^0.21.2'}
realesrgan = {version = '^0.3.0'}
diffusers = {version = '0.32.1'}
quart-trio = {version = '^0.11.0'}
torchvision = {version = '0.15.2+cu118', source = 'torch'}
accelerate = {version = '^0.23.0'}
transformers = {version = '^4.33.2'}
huggingface-hub = {version = '^0.17.3'}
torchvision = {version = '0.20.1+cu121', source = 'torch'}
accelerate = {version = '0.34.0'}
transformers = {version = '4.48.0'}
huggingface-hub = {version = '^0.27.1'}
invisible-watermark = {version = '^0.2.0'}
[[tool.poetry.source]]
name = 'torch'
url = 'https://download.pytorch.org/whl/cu118'
url = 'https://download.pytorch.org/whl/cu121'
priority = 'explicit'
[build-system]

View File

@ -8,7 +8,7 @@ from functools import partial
import click
from leap.sugar import Name, asset_from_str
from leap.protocol import Name, Asset
from .config import *
from .constants import *
@ -178,7 +178,7 @@ def enqueue(
'user': Name(account),
'request_body': req,
'binary_data': binary,
'reward': asset_from_str(reward),
'reward': Asset.from_str(reward),
'min_verification': 1
},
account, key, permission,

View File

@ -78,8 +78,20 @@ MODELS: dict[str, ModelDesc] = {
size=Size(w=512, h=512),
tags=['txt2img']
),
'black-forest-labs/FLUX.1-schnell': ModelDesc(
short='flux',
mem=24,
size=Size(w=1024, h=1024),
tags=['txt2img']
),
'black-forest-labs/FLUX.1-Fill-dev': ModelDesc(
short='flux-inpaint',
mem=24,
size=Size(w=1024, h=1024),
tags=['inpaint']
),
'diffusers/stable-diffusion-xl-1.0-inpainting-0.1': ModelDesc(
short='stablexl-inpainting',
short='stablexl-inpaint',
mem=8.3,
size=Size(w=1024, h=1024),
tags=['inpaint']

View File

@ -18,7 +18,6 @@ from skynet.dgpu.errors import DGPUComputeError, DGPUInferenceCancelled
from skynet.utils import crop_image, convert_from_cv2_to_image, convert_from_image_to_cv2, convert_from_img_to_bytes, init_upscaler, pipeline_for
def prepare_params_for_diffuse(
params: dict,
mode: str,
@ -35,7 +34,11 @@ def prepare_params_for_diffuse(
_params['image'] = image
_params['mask_image'] = mask
_params['strength'] = float(params['strength'])
if 'flux' in params['model'].lower():
_params['max_sequence_length'] = 512
else:
_params['strength'] = float(params['strength'])
case 'img2img':
image = crop_image(
@ -66,8 +69,6 @@ def prepare_params_for_diffuse(
class SkynetMM:
def __init__(self, config: dict):
self.upscaler = init_upscaler()
self.cache_dir = None
if 'hf_home' in config:
self.cache_dir = config['hf_home']
@ -88,30 +89,28 @@ class SkynetMM:
return False
def load_model(
self,
name: str,
mode: str
):
logging.info(f'loading model {name}...')
self._model_mode = mode
self._model_name = name
def unload_model(self):
if getattr(self, '_model', None):
del self._model
gc.collect()
torch.cuda.empty_cache()
self._model_name = ''
self._model_mode = ''
def load_model(
self,
name: str,
mode: str
):
logging.info(f'loading model {name}...')
self.unload_model()
self._model = pipeline_for(
name, mode, cache_dir=self.cache_dir)
self._model_mode = mode
self._model_name = name
def get_model(self, name: str, mode: str) -> DiffusionPipeline:
if name not in MODELS:
raise DGPUComputeError(f'Unknown model {model_name}')
if not self.is_model_loaded(name, mode):
self.load_model(name, mode)
def compute_one(
self,
@ -127,6 +126,8 @@ class SkynetMM:
logging.warn(f'cancelling work at step {step}')
raise DGPUInferenceCancelled()
return {}
maybe_cancel_work(0)
output_type = 'png'
@ -136,23 +137,29 @@ class SkynetMM:
output = None
output_hash = None
try:
name = params['model']
match method:
case 'diffuse' | 'txt2img' | 'img2img' | 'inpaint':
if not self.is_model_loaded(name, method):
self.load_model(name, method)
arguments = prepare_params_for_diffuse(
params, method, inputs)
prompt, guidance, step, seed, upscaler, extra_params = arguments
self.get_model(
params['model'],
method
)
if 'flux' in name.lower():
extra_params['callback_on_step_end'] = maybe_cancel_work
else:
extra_params['callback'] = maybe_cancel_work
extra_params['callback_steps'] = 1
output = self._model(
prompt,
guidance_scale=guidance,
num_inference_steps=step,
generator=seed,
callback=maybe_cancel_work,
callback_steps=1,
**extra_params
).images[0]
@ -161,7 +168,7 @@ class SkynetMM:
case 'png':
if upscaler == 'x4':
input_img = output.convert('RGB')
up_img, _ = self.upscaler.enhance(
up_img, _ = init_upscaler().enhance(
convert_from_image_to_cv2(input_img), outscale=4)
output = convert_from_cv2_to_image(up_img)
@ -173,6 +180,22 @@ class SkynetMM:
output_hash = sha256(output_binary).hexdigest()
case 'upscale':
if self._model_mode != 'upscale':
self.unload_model()
self._model = init_upscaler()
self._model_mode = 'upscale'
self._model_name = 'realesrgan'
input_img = inputs[0].convert('RGB')
up_img, _ = self._model.enhance(
convert_from_image_to_cv2(input_img), outscale=4)
output = convert_from_cv2_to_image(up_img)
output_binary = convert_from_img_to_bytes(output)
output_hash = sha256(output_binary).hexdigest()
case _:
raise DGPUComputeError('Unsupported compute method')

View File

@ -125,7 +125,7 @@ class SkynetDGPUDaemon:
model = body['params']['model']
# if model not known
if model not in MODELS:
if model != 'RealESRGAN_x4plus' and model not in MODELS:
logging.warning(f'Unknown model {model}')
return False
@ -143,11 +143,17 @@ class SkynetDGPUDaemon:
statuses = self._snap['requests'][rid]
if len(statuses) == 0:
inputs = [
await self.conn.get_input_data(_input)
for _input in req['binary_data'].split(',')
if _input
]
inputs = []
for _input in req['binary_data'].split(','):
if _input:
for _ in range(3):
try:
img = await self.conn.get_input_data(_input)
inputs.append(img)
break
except:
...
hash_str = (
str(req['nonce'])

View File

@ -15,7 +15,7 @@ import anyio
from PIL import Image, UnidentifiedImageError
from leap.cleos import CLEOS
from leap.sugar import Checksum256, Name, asset_from_str
from leap.protocol import Asset
from skynet.constants import DEFAULT_IPFS_DOMAIN
from skynet.ipfs import AsyncIPFSHTTP, get_ipfs_file
@ -24,6 +24,225 @@ from skynet.dgpu.errors import DGPUComputeError
REQUEST_UPDATE_TIME = 3
gpu_abi = {
"version": "eosio::abi/1.2",
"types": [],
"structs": [
{
"name": "account",
"base": "",
"fields": [
{"name": "user", "type": "name"},
{"name": "balance", "type": "asset"},
{"name": "nonce", "type": "uint64"}
]
},
{
"name": "card",
"base": "",
"fields": [
{"name": "id", "type": "uint64"},
{"name": "owner", "type": "name"},
{"name": "card_name", "type": "string"},
{"name": "version", "type": "string"},
{"name": "total_memory", "type": "uint64"},
{"name": "mp_count", "type": "uint32"},
{"name": "extra", "type": "string"}
]
},
{
"name": "clean",
"base": "",
"fields": []
},
{
"name": "config",
"base": "",
"fields": [
{"name": "token_contract", "type": "name"},
{"name": "token_symbol", "type": "symbol"}
]
},
{
"name": "dequeue",
"base": "",
"fields": [
{"name": "user", "type": "name"},
{"name": "request_id", "type": "uint64"}
]
},
{
"name": "enqueue",
"base": "",
"fields": [
{"name": "user", "type": "name"},
{"name": "request_body", "type": "string"},
{"name": "binary_data", "type": "string"},
{"name": "reward", "type": "asset"},
{"name": "min_verification", "type": "uint32"}
]
},
{
"name": "gcfgstruct",
"base": "",
"fields": [
{"name": "token_contract", "type": "name"},
{"name": "token_symbol", "type": "symbol"}
]
},
{
"name": "submit",
"base": "",
"fields": [
{"name": "worker", "type": "name"},
{"name": "request_id", "type": "uint64"},
{"name": "request_hash", "type": "checksum256"},
{"name": "result_hash", "type": "checksum256"},
{"name": "ipfs_hash", "type": "string"}
]
},
{
"name": "withdraw",
"base": "",
"fields": [
{"name": "user", "type": "name"},
{"name": "quantity", "type": "asset"}
]
},
{
"name": "work_request_struct",
"base": "",
"fields": [
{"name": "id", "type": "uint64"},
{"name": "user", "type": "name"},
{"name": "reward", "type": "asset"},
{"name": "min_verification", "type": "uint32"},
{"name": "nonce", "type": "uint64"},
{"name": "body", "type": "string"},
{"name": "binary_data", "type": "string"},
{"name": "timestamp", "type": "time_point_sec"}
]
},
{
"name": "work_result_struct",
"base": "",
"fields": [
{"name": "id", "type": "uint64"},
{"name": "request_id", "type": "uint64"},
{"name": "user", "type": "name"},
{"name": "worker", "type": "name"},
{"name": "result_hash", "type": "checksum256"},
{"name": "ipfs_hash", "type": "string"},
{"name": "submited", "type": "time_point_sec"}
]
},
{
"name": "workbegin",
"base": "",
"fields": [
{"name": "worker", "type": "name"},
{"name": "request_id", "type": "uint64"},
{"name": "max_workers", "type": "uint32"}
]
},
{
"name": "workcancel",
"base": "",
"fields": [
{"name": "worker", "type": "name"},
{"name": "request_id", "type": "uint64"},
{"name": "reason", "type": "string"}
]
},
{
"name": "worker",
"base": "",
"fields": [
{"name": "account", "type": "name"},
{"name": "joined", "type": "time_point_sec"},
{"name": "left", "type": "time_point_sec"},
{"name": "url", "type": "string"}
]
},
{
"name": "worker_status_struct",
"base": "",
"fields": [
{"name": "worker", "type": "name"},
{"name": "status", "type": "string"},
{"name": "started", "type": "time_point_sec"}
]
}
],
"actions": [
{"name": "clean", "type": "clean", "ricardian_contract": ""},
{"name": "config", "type": "config", "ricardian_contract": ""},
{"name": "dequeue", "type": "dequeue", "ricardian_contract": ""},
{"name": "enqueue", "type": "enqueue", "ricardian_contract": ""},
{"name": "submit", "type": "submit", "ricardian_contract": ""},
{"name": "withdraw", "type": "withdraw", "ricardian_contract": ""},
{"name": "workbegin", "type": "workbegin", "ricardian_contract": ""},
{"name": "workcancel", "type": "workcancel", "ricardian_contract": ""}
],
"tables": [
{
"name": "cards",
"index_type": "i64",
"key_names": [],
"key_types": [],
"type": "card"
},
{
"name": "gcfgstruct",
"index_type": "i64",
"key_names": [],
"key_types": [],
"type": "gcfgstruct"
},
{
"name": "queue",
"index_type": "i64",
"key_names": [],
"key_types": [],
"type": "work_request_struct"
},
{
"name": "results",
"index_type": "i64",
"key_names": [],
"key_types": [],
"type": "work_result_struct"
},
{
"name": "status",
"index_type": "i64",
"key_names": [],
"key_types": [],
"type": "worker_status_struct"
},
{
"name": "users",
"index_type": "i64",
"key_names": [],
"key_types": [],
"type": "account"
},
{
"name": "workers",
"index_type": "i64",
"key_names": [],
"key_types": [],
"type": "worker"
}
],
"ricardian_clauses": [],
"error_messages": [],
"abi_extensions": [],
"variants": [],
"action_results": []
}
async def failable(fn: partial, ret_fail=None):
try:
@ -35,22 +254,22 @@ async def failable(fn: partial, ret_fail=None):
asks.errors.RequestTimeout,
asks.errors.BadHttpResponse,
anyio.BrokenResourceError
):
) as e:
return ret_fail
class SkynetGPUConnector:
def __init__(self, config: dict):
self.account = Name(config['account'])
self.account = config['account']
self.permission = config['permission']
self.key = config['key']
self.node_url = config['node_url']
self.hyperion_url = config['hyperion_url']
self.cleos = CLEOS(
None, None, self.node_url, remote=self.node_url)
self.cleos = CLEOS(endpoint=self.node_url)
self.cleos.load_abi('gpu.scd', gpu_abi)
self.ipfs_gateway_url = None
if 'ipfs_gateway_url' in config:
@ -151,11 +370,11 @@ class SkynetGPUConnector:
self.cleos.a_push_action,
'gpu.scd',
'workbegin',
{
list({
'worker': self.account,
'request_id': request_id,
'max_workers': 2
},
}.values()),
self.account, self.key,
permission=self.permission
)
@ -168,11 +387,11 @@ class SkynetGPUConnector:
self.cleos.a_push_action,
'gpu.scd',
'workcancel',
{
list({
'worker': self.account,
'request_id': request_id,
'reason': reason
},
}.values()),
self.account, self.key,
permission=self.permission
)
@ -191,10 +410,10 @@ class SkynetGPUConnector:
self.cleos.a_push_action,
'gpu.scd',
'withdraw',
{
list({
'user': self.account,
'quantity': asset_from_str(balance)
},
'quantity': Asset.from_str(balance)
}.values()),
self.account, self.key,
permission=self.permission
)
@ -226,13 +445,13 @@ class SkynetGPUConnector:
self.cleos.a_push_action,
'gpu.scd',
'submit',
{
list({
'worker': self.account,
'request_id': request_id,
'request_hash': Checksum256(request_hash),
'result_hash': Checksum256(result_hash),
'request_hash': request_hash,
'result_hash': result_hash,
'ipfs_hash': ipfs_hash
},
}.values()),
self.account, self.key,
permission=self.permission
)

View File

@ -0,0 +1,50 @@
#!/usr/bin/python
import torch
from diffusers import (
DiffusionPipeline,
FluxPipeline,
FluxTransformer2DModel
)
from transformers import T5EncoderModel, BitsAndBytesConfig
from huggingface_hub import hf_hub_download
__model = {
'name': 'black-forest-labs/FLUX.1-schnell'
}
def pipeline_for(
model: str,
mode: str,
mem_fraction: float = 1.0,
cache_dir: str | None = None
) -> DiffusionPipeline:
qonfig = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
)
params = {
'torch_dtype': torch.bfloat16,
'cache_dir': cache_dir,
'device_map': 'balanced',
'max_memory': {'cpu': '10GiB', 0: '11GiB'}
# 'max_memory': {0: '11GiB'}
}
text_encoder = T5EncoderModel.from_pretrained(
'black-forest-labs/FLUX.1-schnell',
subfolder="text_encoder_2",
torch_dtype=torch.bfloat16,
quantization_config=qonfig
)
params['text_encoder_2'] = text_encoder
pipe = FluxPipeline.from_pretrained(
model, **params)
pipe.vae.enable_tiling()
pipe.vae.enable_slicing()
return pipe

View File

@ -0,0 +1,55 @@
#!/usr/bin/python
import torch
from diffusers import (
DiffusionPipeline,
FluxFillPipeline,
FluxTransformer2DModel
)
from transformers import T5EncoderModel, BitsAndBytesConfig
__model = {
'name': 'black-forest-labs/FLUX.1-Fill-dev'
}
def pipeline_for(
model: str,
mode: str,
mem_fraction: float = 1.0,
cache_dir: str | None = None
) -> DiffusionPipeline:
qonfig = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
)
params = {
'torch_dtype': torch.bfloat16,
'cache_dir': cache_dir,
'device_map': 'balanced',
'max_memory': {'cpu': '10GiB', 0: '11GiB'}
# 'max_memory': {0: '11GiB'}
}
text_encoder = T5EncoderModel.from_pretrained(
'sayakpaul/FLUX.1-Fill-dev-nf4',
subfolder="text_encoder_2",
torch_dtype=torch.bfloat16,
quantization_config=qonfig
)
params['text_encoder_2'] = text_encoder
transformer = FluxTransformer2DModel.from_pretrained(
'sayakpaul/FLUX.1-Fill-dev-nf4',
subfolder="transformer",
torch_dtype=torch.bfloat16,
quantization_config=qonfig
)
pipe = FluxFillPipeline.from_pretrained(
model, **params)
pipe.vae.enable_tiling()
pipe.vae.enable_slicing()
return pipe

View File

@ -6,29 +6,42 @@ import sys
import time
import random
import logging
import importlib
from typing import Optional
from pathlib import Path
import asks
import trio
import torch
import numpy as np
from PIL import Image
from basicsr.archs.rrdbnet_arch import RRDBNet
from diffusers import (
DiffusionPipeline,
AutoPipelineForText2Image,
AutoPipelineForImage2Image,
AutoPipelineForInpainting,
EulerAncestralDiscreteScheduler
EulerAncestralDiscreteScheduler,
)
from realesrgan import RealESRGANer
from huggingface_hub import login
import trio
from .constants import MODELS
# Hack to fix a changed import in torchvision 0.17+, which otherwise breaks
# basicsr; see https://github.com/AUTOMATIC1111/stable-diffusion-webui/issues/13985
try:
import torchvision.transforms.functional_tensor # noqa: F401
except ImportError:
try:
import torchvision.transforms.functional as functional
sys.modules["torchvision.transforms.functional_tensor"] = functional
except ImportError:
pass # shrug...
from basicsr.archs.rrdbnet_arch import RRDBNet
from realesrgan import RealESRGANer
def time_ms():
return int(time.time() * 1000)
@ -72,6 +85,7 @@ def pipeline_for(
cache_dir: str | None = None
) -> DiffusionPipeline:
logging.info(f'pipeline_for {model} {mode}')
assert torch.cuda.is_available()
torch.cuda.empty_cache()
torch.backends.cuda.matmul.allow_tf32 = True
@ -85,21 +99,35 @@ def pipeline_for(
torch.use_deterministic_algorithms(True)
model_info = MODELS[model]
shortname = model_info.short
# disable for compat with "diffuse" method
# assert mode in model_info.tags
# default to checking if custom pipeline exist and return that if not, attempt generic
try:
normalized_shortname = shortname.replace('-', '_')
custom_pipeline = importlib.import_module(f'skynet.dgpu.pipes.{normalized_shortname}')
assert custom_pipeline.__model['name'] == model
return custom_pipeline.pipeline_for(model, mode, mem_fraction=mem_fraction, cache_dir=cache_dir)
except ImportError:
...
req_mem = model_info.mem
mem_gb = torch.cuda.mem_get_info()[1] / (10**9)
mem_gb *= mem_fraction
over_mem = mem_gb < req_mem
if over_mem:
logging.warn(f'model requires {req_mem} but card has {mem_gb}, model will run slower..')
shortname = model_info.short
params = {
'safety_checker': None,
'torch_dtype': torch.float16,
'cache_dir': cache_dir,
'variant': 'fp16'
'variant': 'fp16',
}
match shortname:
@ -108,6 +136,7 @@ def pipeline_for(
torch.cuda.set_per_process_memory_fraction(mem_fraction)
pipe_class = DiffusionPipeline
match mode:
case 'inpaint':
pipe_class = AutoPipelineForInpainting
@ -115,7 +144,7 @@ def pipeline_for(
case 'img2img':
pipe_class = AutoPipelineForImage2Image
case 'txt2img' | 'diffuse':
case 'txt2img':
pipe_class = AutoPipelineForText2Image
pipe = pipe_class.from_pretrained(
@ -124,20 +153,20 @@ def pipeline_for(
pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(
pipe.scheduler.config)
pipe.enable_xformers_memory_efficient_attention()
# pipe.enable_xformers_memory_efficient_attention()
if over_mem:
if mode == 'txt2img':
pipe.enable_vae_slicing()
pipe.enable_vae_tiling()
pipe.vae.enable_tiling()
pipe.vae.enable_slicing()
pipe.enable_model_cpu_offload()
else:
if sys.version_info[1] < 11:
# torch.compile only supported on python < 3.11
pipe.unet = torch.compile(
pipe.unet, mode='reduce-overhead', fullgraph=True)
# if sys.version_info[1] < 11:
# # torch.compile only supported on python < 3.11
# pipe.unet = torch.compile(
# pipe.unet, mode='reduce-overhead', fullgraph=True)
pipe = pipe.to('cuda')
@ -155,7 +184,7 @@ def txt2img(
seed: Optional[int] = None
):
login(token=hf_token)
pipe = pipeline_for(model)
pipe = pipeline_for(model, 'txt2img')
seed = seed if seed else random.randint(0, 2 ** 64)
prompt = prompt
@ -182,7 +211,7 @@ def img2img(
seed: Optional[int] = None
):
login(token=hf_token)
pipe = pipeline_for(model, image=True)
pipe = pipeline_for(model, 'img2img')
model_info = MODELS[model]
@ -215,7 +244,7 @@ def inpaint(
seed: Optional[int] = None
):
login(token=hf_token)
pipe = pipeline_for(model, image=True, inpainting=True)
pipe = pipeline_for(model, 'inpaint')
model_info = MODELS[model]
@ -225,21 +254,25 @@ def inpaint(
with open(mask_path, 'rb') as mask_file:
mask_img = convert_from_bytes_and_crop(mask_file.read(), model_info.size.w, model_info.size.h)
var_params = {}
if 'flux' not in model.lower():
var_params['strength'] = strength
seed = seed if seed else random.randint(0, 2 ** 64)
prompt = prompt
image = pipe(
prompt,
image=input_img,
mask_image=mask_img,
strength=strength,
guidance_scale=guidance, num_inference_steps=steps,
generator=torch.Generator("cuda").manual_seed(seed)
generator=torch.Generator("cuda").manual_seed(seed),
**var_params
).images[0]
image.save(output)
def init_upscaler(model_path: str = 'weights/RealESRGAN_x4plus.pth'):
def init_upscaler(model_path: str = 'hf_home/RealESRGAN_x4plus.pth'):
return RealESRGANer(
scale=4,
model_path=model_path,
@ -258,7 +291,7 @@ def init_upscaler(model_path: str = 'weights/RealESRGAN_x4plus.pth'):
def upscale(
img_path: str = 'input.png',
output: str = 'output.png',
model_path: str = 'weights/RealESRGAN_x4plus.pth'
model_path: str = 'hf_home/RealESRGAN_x4plus.pth'
):
input_img = Image.open(img_path).convert('RGB')
@ -269,25 +302,3 @@ def upscale(
image = convert_from_cv2_to_image(up_img)
image.save(output)
async def download_upscaler():
print('downloading upscaler...')
weights_path = Path('weights')
weights_path.mkdir(exist_ok=True)
upscaler_url = 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth'
save_path = weights_path / 'RealESRGAN_x4plus.pth'
response = await asks.get(upscaler_url)
with open(save_path, 'wb') as f:
f.write(response.content)
print('done')
def download_all_models(hf_token: str, hf_home: str):
assert torch.cuda.is_available()
trio.run(download_upscaler)
login(token=hf_token)
for model in MODELS:
print(f'DOWNLOADING {model.upper()}')
pipeline_for(model, cache_dir=hf_home)