Simplify pipeline_for function and add the infra needed for IO types other than PNG

pull/26/head
Guillermo Rodriguez 2023-10-08 18:00:18 -03:00
parent ee1fdcc557
commit 3d2069d151
8 changed files with 132 additions and 305 deletions

pyproject.toml

@@ -15,7 +15,6 @@ Pillow = '^10.0.1'
 docker = '^6.1.3'
 py-leap = {git = 'https://github.com/guilledk/py-leap.git', rev = 'v0.1a14'}
 toml = "^0.10.2"
-tractor = {git = "https://github.com/goodboy/tractor.git"}
 
 [tool.poetry.group.frontend]
 optional = true

skynet/cli.py

@@ -85,7 +85,7 @@ def download():
     hf_token = load_key(config, 'skynet.dgpu.hf_token')
     hf_home = load_key(config, 'skynet.dgpu.hf_home')
     set_hf_vars(hf_token, hf_home)
-    utils.download_all_models(hf_token)
+    utils.download_all_models(hf_token, hf_home)
 
 @skynet.command()
 @click.option(
@@ -120,21 +120,21 @@ def enqueue(
     cleos = CLEOS(None, None, url=node_url, remote=node_url)
 
+    binary = kwargs['binary_data']
+    if not kwargs['strength']:
+        if binary:
+            raise ValueError('strength -Z param required if binary data passed')
+
+        del kwargs['strength']
+
+    else:
+        kwargs['strength'] = float(kwargs['strength'])
+
     async def enqueue_n_jobs():
         for i in range(jobs):
             if not kwargs['seed']:
                 kwargs['seed'] = random.randint(0, 10e9)
 
-            binary = kwargs['binary_data']
-            if not kwargs['strength']:
-                if binary:
-                    raise ValueError('strength -Z param required if binary data passed')
-
-                del kwargs['strength']
-
-            else:
-                kwargs['strength'] = float(kwargs['strength'])
-
             req = json.dumps({
                 'method': 'diffuse',
                 'params': kwargs
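Note: hoisting the strength/binary check out of enqueue_n_jobs makes a missing -Z with binary input fail once, up front. It also appears to fix a latent bug in the old layout: `del kwargs['strength']` ran inside the per-job loop, so with jobs > 1 the second iteration would look up the already-deleted key and raise KeyError. A minimal sketch of that failure mode (illustrative values, not from the commit):

    # old layout: validation ran once per job
    kwargs = {'strength': None, 'binary_data': b''}
    for i in range(2):
        if not kwargs['strength']:   # KeyError on i == 1: key was deleted
            del kwargs['strength']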

skynet/constants.py

@@ -5,18 +5,20 @@ VERSION = '0.1a12'
 DOCKER_RUNTIME_CUDA = 'skynet:runtime-cuda'
 
 MODELS = {
-    'prompthero/openjourney': {'short': 'midj', 'mem': 8},
-    'runwayml/stable-diffusion-v1-5': {'short': 'stable', 'mem': 8},
-    'stabilityai/stable-diffusion-2-1-base': {'short': 'stable2', 'mem': 8},
-    'snowkidy/stable-diffusion-xl-base-0.9': {'short': 'stablexl0.9', 'mem': 24},
-    'stabilityai/stable-diffusion-xl-base-1.0': {'short': 'stablexl', 'mem': 24},
-    'Linaqruf/anything-v3.0': {'short': 'hdanime', 'mem': 8},
-    'hakurei/waifu-diffusion': {'short': 'waifu', 'mem': 8},
-    'nitrosocke/Ghibli-Diffusion': {'short': 'ghibli', 'mem': 8},
-    'dallinmackay/Van-Gogh-diffusion': {'short': 'van-gogh', 'mem': 8},
-    'lambdalabs/sd-pokemon-diffusers': {'short': 'pokemon', 'mem': 8},
-    'Envvi/Inkpunk-Diffusion': {'short': 'ink', 'mem': 8},
-    'nousr/robo-diffusion': {'short': 'robot', 'mem': 8}
+    'prompthero/openjourney': {'short': 'midj', 'mem': 6},
+    'runwayml/stable-diffusion-v1-5': {'short': 'stable', 'mem': 6},
+    'stabilityai/stable-diffusion-2-1-base': {'short': 'stable2', 'mem': 6},
+    'snowkidy/stable-diffusion-xl-base-0.9': {'short': 'stablexl0.9', 'mem': 8.3},
+    'Linaqruf/anything-v3.0': {'short': 'hdanime', 'mem': 6},
+    'hakurei/waifu-diffusion': {'short': 'waifu', 'mem': 6},
+    'nitrosocke/Ghibli-Diffusion': {'short': 'ghibli', 'mem': 6},
+    'dallinmackay/Van-Gogh-diffusion': {'short': 'van-gogh', 'mem': 6},
+    'lambdalabs/sd-pokemon-diffusers': {'short': 'pokemon', 'mem': 6},
+    'Envvi/Inkpunk-Diffusion': {'short': 'ink', 'mem': 6},
+    'nousr/robo-diffusion': {'short': 'robot', 'mem': 6},
+
+    # default is always last
+    'stabilityai/stable-diffusion-xl-base-1.0': {'short': 'stablexl', 'mem': 8.3},
 }
 
 SHORT_NAMES = [
@@ -158,7 +160,7 @@ DEFAULT_GUIDANCE = 7.5
 DEFAULT_STRENGTH = 0.5
 DEFAULT_STEP = 28
 DEFAULT_CREDITS = 10
-DEFAULT_MODEL = list(MODELS.keys())[4]
+DEFAULT_MODEL = list(MODELS.keys())[-1]
 DEFAULT_ROLE = 'pleb'
 DEFAULT_UPSCALER = None
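Note: MODELS is a plain dict, and Python dicts preserve insertion order (guaranteed since 3.7), so moving the SDXL entry to the end is what makes the new `# default is always last` comment hold:

    # sketch: resolving the default model from the reordered dict
    DEFAULT_MODEL = list(MODELS.keys())[-1]
    assert DEFAULT_MODEL == 'stabilityai/stable-diffusion-xl-base-1.0'

The default model itself is unchanged (the old `[4]` index pointed at the same SDXL entry), but a positional index silently changed meaning whenever an entry was added or reordered; indexing from the end pins the default to the documented slot.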

skynet/dgpu/_mp_compute.py (deleted)

@@ -1,165 +0,0 @@
-#!/usr/bin/python
-
-import gc
-import json
-import logging
-
-from hashlib import sha256
-
-from diffusers import DiffusionPipeline
-
-import trio
-import torch
-
-from skynet.constants import DEFAULT_INITAL_MODELS, MODELS
-from skynet.dgpu.errors import DGPUComputeError
-from skynet.utils import convert_from_bytes_and_crop, convert_from_cv2_to_image, convert_from_image_to_cv2, convert_from_img_to_bytes, init_upscaler, pipeline_for
-
-
-def prepare_params_for_diffuse(
-    params: dict,
-    binary: bytes | None = None
-):
-    image = None
-    if binary:
-        image = convert_from_bytes_and_crop(binary, 512, 512)
-
-    _params = {}
-    if image:
-        _params['image'] = image
-        _params['strength'] = float(params['strength'])
-
-    else:
-        _params['width'] = int(params['width'])
-        _params['height'] = int(params['height'])
-
-    return (
-        params['prompt'],
-        float(params['guidance']),
-        int(params['step']),
-        torch.manual_seed(int(params['seed'])),
-        params['upscaler'] if 'upscaler' in params else None,
-        _params
-    )
-
-
-_models = {}
-
-def is_model_loaded(model_name: str, image: bool):
-    for model_key, model_data in _models.items():
-        if (model_key == model_name and
-            model_data['image'] == image):
-            return True
-
-    return False
-
-def load_model(
-    model_name: str,
-    image: bool,
-    force=False
-):
-    logging.info(f'loading model {model_name}...')
-    if force or len(_models.keys()) == 0:
-        pipe = pipeline_for(
-            model_name, image=image)
-
-        _models[model_name] = {
-            'pipe': pipe,
-            'generated': 0,
-            'image': image
-        }
-
-    else:
-        least_used = list(_models.keys())[0]
-        for model in _models:
-            if _models[
-                least_used]['generated'] > _models[model]['generated']:
-                least_used = model
-
-        del _models[least_used]
-
-        logging.info(f'swapping model {least_used} for {model_name}...')
-
-        gc.collect()
-        torch.cuda.empty_cache()
-
-        pipe = pipeline_for(
-            model_name, image=image)
-
-        _models[model_name] = {
-            'pipe': pipe,
-            'generated': 0,
-            'image': image
-        }
-
-    logging.info(f'loaded model {model_name}')
-    return pipe
-
-def get_model(model_name: str, image: bool) -> DiffusionPipeline:
-    if model_name not in MODELS:
-        raise DGPUComputeError(f'Unknown model {model_name}')
-
-    if not is_model_loaded(model_name, image):
-        pipe = load_model(model_name, image=image)
-
-    else:
-        pipe = _models[model_name]['pipe']
-
-    return pipe
-
-def _static_compute_one(kwargs: dict):
-    request_id: int = kwargs['request_id']
-    method: str = kwargs['method']
-    params: dict = kwargs['params']
-    binary: bytes | None = kwargs['binary']
-
-    def _checkpoint(*args, **kwargs):
-        trio.from_thread.run(trio.sleep, 0)
-
-    try:
-        match method:
-            case 'diffuse':
-                image = None
-
-                arguments = prepare_params_for_diffuse(params, binary)
-                prompt, guidance, step, seed, upscaler, extra_params = arguments
-                model = get_model(params['model'], 'image' in extra_params)
-
-                image = model(
-                    prompt,
-                    guidance_scale=guidance,
-                    num_inference_steps=step,
-                    generator=seed,
-                    callback=_checkpoint,
-                    callback_steps=1,
-                    **extra_params
-                ).images[0]
-
-                if upscaler == 'x4':
-                    upscaler = init_upscaler()
-                    input_img = image.convert('RGB')
-                    up_img, _ = upscaler.enhance(
-                        convert_from_image_to_cv2(input_img), outscale=4)
-
-                    image = convert_from_cv2_to_image(up_img)
-
-                img_raw = convert_from_img_to_bytes(image)
-                img_sha = sha256(img_raw).hexdigest()
-
-                return img_sha, img_raw
-
-            case _:
-                raise DGPUComputeError('Unsupported compute method')
-
-    except BaseException as e:
-        logging.error(e)
-        raise DGPUComputeError(str(e))
-
-    finally:
-        torch.cuda.empty_cache()
-
-
-async def _tractor_static_compute_one(**kwargs):
-    return await trio.to_thread.run_sync(
-        _static_compute_one, kwargs)

skynet/dgpu/compute.py

@@ -3,10 +3,11 @@
 # Skynet Memory Manager
 
 import gc
-import json
 import logging
 
 from hashlib import sha256
+import zipfile
+
+from PIL import Image
 
 from diffusers import DiffusionPipeline
 
 import trio
@@ -15,22 +16,29 @@ import torch
 from skynet.constants import DEFAULT_INITAL_MODELS, MODELS
 from skynet.dgpu.errors import DGPUComputeError, DGPUInferenceCancelled
-from skynet.utils import convert_from_bytes_and_crop, convert_from_cv2_to_image, convert_from_image_to_cv2, convert_from_img_to_bytes, init_upscaler, pipeline_for
-from ._mp_compute import _static_compute_one, _tractor_static_compute_one
+from skynet.utils import crop_image, convert_from_cv2_to_image, convert_from_image_to_cv2, convert_from_img_to_bytes, init_upscaler, pipeline_for
 
 
 def prepare_params_for_diffuse(
     params: dict,
-    binary: bytes | None = None
+    input_type: str,
+    binary = None
 ):
-    image = None
-    if binary:
-        image = convert_from_bytes_and_crop(binary, 512, 512)
-
     _params = {}
-    if image:
-        _params['image'] = image
-        _params['strength'] = float(params['strength'])
+    if binary != None:
+        match input_type:
+            case 'png':
+                image = crop_image(
+                    binary, params['width'], params['height'])
+
+                _params['image'] = image
+                _params['strength'] = float(params['strength'])
+
+            case 'none':
+                ...
+
+            case _:
+                raise DGPUComputeError(f'Unknown input_type {input_type}')
 
     else:
         _params['width'] = int(params['width'])
@@ -136,6 +144,7 @@ class SkynetMM:
         request_id: int,
         method: str,
         params: dict,
+        input_type: str = 'png',
         binary: bytes | None = None
     ):
         def maybe_cancel_work(step, *args, **kwargs):
@@ -147,16 +156,21 @@ class SkynetMM:
         maybe_cancel_work(0)
 
+        output_type = 'png'
+        if 'output_type' in params:
+            output_type = params['output_type']
+
+        output = None
+        output_hash = None
         try:
             match method:
                 case 'diffuse':
-                    image = None
-
-                    arguments = prepare_params_for_diffuse(params, binary)
+                    arguments = prepare_params_for_diffuse(
+                        params, input_type, binary=binary)
                     prompt, guidance, step, seed, upscaler, extra_params = arguments
                     model = self.get_model(params['model'], 'image' in extra_params)
 
-                    image = model(
+                    output = model(
                         prompt,
                         guidance_scale=guidance,
                         num_inference_steps=step,
@@ -166,17 +180,22 @@ class SkynetMM:
                         **extra_params
                     ).images[0]
 
-                    if upscaler == 'x4':
-                        input_img = image.convert('RGB')
-                        up_img, _ = self.upscaler.enhance(
-                            convert_from_image_to_cv2(input_img), outscale=4)
-
-                        image = convert_from_cv2_to_image(up_img)
-
-                    img_raw = convert_from_img_to_bytes(image)
-                    img_sha = sha256(img_raw).hexdigest()
-
-                    return img_sha, img_raw
+                    output_binary = b''
+                    match output_type:
+                        case 'png':
+                            if upscaler == 'x4':
+                                input_img = output.convert('RGB')
+                                up_img, _ = self.upscaler.enhance(
+                                    convert_from_image_to_cv2(input_img), outscale=4)
+                                output = convert_from_cv2_to_image(up_img)
+
+                            output_binary = convert_from_img_to_bytes(output)
+
+                        case _:
+                            raise DGPUComputeError(f'Unsupported output type: {output_type}')
+
+                    output_hash = sha256(output_binary).hexdigest()
 
                 case _:
                     raise DGPUComputeError('Unsupported compute method')
@@ -187,3 +206,5 @@ class SkynetMM:
         finally:
             torch.cuda.empty_cache()
+
+        return output_hash, output
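Note: compute_one now returns `(output_hash, output)`, where `output` is the PIL image itself; encoding to bytes happens per output_type inside the match, and only the bytes are hashed. A hedged sketch of the new call shape (`mm`, `rid`, `params` are hypothetical names for an instance and its inputs):

    # input_type: 'png' for img2img inputs, 'none' for pure txt2img
    # binary: the decoded PIL.Image handed back by get_input_data
    # (the `bytes | None` annotation kept above no longer matches that)
    output_hash, output = mm.compute_one(
        rid, 'diffuse', params,
        input_type=input_type,
        binary=binary,
    )
    # output: PIL.Image; output_hash: sha256 hex of its PNG encoding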

skynet/dgpu/daemon.py

@@ -9,10 +9,10 @@ from hashlib import sha256
 from functools import partial
 
 import trio
-import tractor
+
+from skynet.constants import MODELS
 
 from skynet.dgpu.errors import *
-from skynet.dgpu.compute import SkynetMM, _tractor_static_compute_one
+from skynet.dgpu.compute import SkynetMM
 from skynet.dgpu.network import SkynetGPUConnector
@@ -97,6 +97,11 @@ class SkynetDGPUDaemon:
             body = json.loads(req['body'])
             model = body['params']['model']
 
+            # if model not known
+            if model not in MODELS:
+                logging.warning(f'Unknown model {model}')
+                continue
+
             # if whitelist enabled and model not in it continue
             if (len(self.model_whitelist) > 0 and
                 not model in self.model_whitelist):
@@ -111,7 +116,7 @@ class SkynetDGPUDaemon:
             statuses = self._snap['requests'][rid]
 
             if len(statuses) == 0:
-                binary = await self.conn.get_input_data(req['binary_data'])
+                binary, input_type = await self.conn.get_input_data(req['binary_data'])
 
                 hash_str = (
                     str(req['nonce'])
@@ -134,46 +139,31 @@
                 else:
                     try:
+                        output_type = 'png'
+                        if 'output_type' in body['params']:
+                            output_type = body['params']['output_type']
+
+                        output = None
+                        output_hash = None
                         match self.backend:
                             case 'sync-on-thread':
                                 self.mm._should_cancel = self.should_cancel_work
-                                img_sha, img_raw = await trio.to_thread.run_sync(
+                                output_hash, output = await trio.to_thread.run_sync(
                                     partial(
                                         self.mm.compute_one,
                                         rid,
-                                        body['method'], body['params'], binary=binary
-                                    )
-                                )
-
-                            case 'tractor':
-                                async def _should_cancel_oracle():
-                                    while True:
-                                        await trio.sleep(1)
-                                        if (await self.should_cancel_work(rid)):
-                                            raise DGPUInferenceCancelled
-
-                                async with (
-                                    trio.open_nursery() as trio_n,
-                                    tractor.open_nursery() as tractor_n
-                                ):
-                                    trio_n.start_soon(_should_cancel_oracle)
-                                    portal = await tractor_n.run_in_actor(
-                                        _tractor_static_compute_one,
-                                        name='tractor-cuda-mp',
-                                        request_id=rid,
-                                        method=body['method'],
-                                        params=body['params'],
-                                        binary=binary
-                                    )
-                                    img_sha, img_raw = await portal.result()
-                                    trio_n.cancel_scope.cancel()
+                                        body['method'], body['params'],
+                                        input_type=input_type,
+                                        binary=binary
+                                    )
+                                )
 
                             case _:
                                 raise DGPUComputeError(f'Unsupported backend {self.backend}')
 
-                        ipfs_hash = await self.conn.publish_on_ipfs(img_raw)
-                        await self.conn.submit_work(rid, request_hash, img_sha, ipfs_hash)
+                        ipfs_hash = await self.conn.publish_on_ipfs(output, typ=output_type)
+                        await self.conn.submit_work(rid, request_hash, output_hash, ipfs_hash)
 
                     except BaseException as e:
                         traceback.print_exc()
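Note: with the tractor backend removed, 'sync-on-thread' is the only backend left and the job path reduces to a straight line. A condensed sketch of the flow after this change (error handling and status bookkeeping elided; names as in the diff):

    from functools import partial
    import trio

    # fetch input and detect its type, run inference off the event loop,
    # then publish and submit the typed output
    binary, input_type = await self.conn.get_input_data(req['binary_data'])
    output_hash, output = await trio.to_thread.run_sync(
        partial(
            self.mm.compute_one,
            rid, body['method'], body['params'],
            input_type=input_type, binary=binary))
    ipfs_hash = await self.conn.publish_on_ipfs(output, typ=output_type)
    await self.conn.submit_work(rid, request_hash, output_hash, ipfs_hash)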

skynet/dgpu/network.py

@@ -9,17 +9,19 @@ from pathlib import Path
 from functools import partial
 
 import asks
+import numpy
 import trio
 import anyio
+import torch
 
 from PIL import Image, UnidentifiedImageError
 
 from leap.cleos import CLEOS
 from leap.sugar import Checksum256, Name, asset_from_str
-from skynet.constants import DEFAULT_DOMAIN
-
-from skynet.dgpu.errors import DGPUComputeError
 
 from skynet.ipfs import AsyncIPFSHTTP, get_ipfs_file
+from skynet.dgpu.errors import DGPUComputeError
+from skynet.constants import DEFAULT_DOMAIN
 
 
 REQUEST_UPDATE_TIME = 3
@@ -235,11 +237,19 @@
         )
 
     # IPFS helpers
-    async def publish_on_ipfs(self, raw_img: bytes):
+    async def publish_on_ipfs(self, raw, typ: str = 'png'):
         Path('ipfs-staging').mkdir(exist_ok=True)
         logging.info('publish_on_ipfs')
-        img = Image.open(io.BytesIO(raw_img))
-        img.save('ipfs-staging/image.png')
+
+        target_file = ''
+        match typ:
+            case 'png':
+                raw: Image
+                target_file = 'ipfs-staging/image.png'
+                raw.save(target_file)
+
+            case _:
+                raise ValueError(f'Unsupported output type: {typ}')
 
         if self.ipfs_gateway_url:
             # check peer connections, reconnect to skynet gateway if not
@@ -248,16 +258,18 @@
             if gateway_id not in [p['Peer'] for p in peers]:
                 await self.ipfs_client.connect(self.ipfs_gateway_url)
 
-        file_info = await self.ipfs_client.add(Path('ipfs-staging/image.png'))
+        file_info = await self.ipfs_client.add(Path(target_file))
         file_cid = file_info['Hash']
 
         await self.ipfs_client.pin(file_cid)
 
         return file_cid
 
-    async def get_input_data(self, ipfs_hash: str) -> bytes:
+    async def get_input_data(self, ipfs_hash: str) -> tuple[bytes, str]:
+        input_type = 'none'
+
         if ipfs_hash == '':
-            return b''
+            return b'', input_type
 
         results = {}
         ipfs_link = f'https://ipfs.{DEFAULT_DOMAIN}/ipfs/{ipfs_hash}'
@@ -272,9 +284,10 @@
             else:
                 try:
-                    with Image.open(io.BytesIO(res.raw)):
-                        results[link] = res.raw
-                        n.cancel_scope.cancel()
+                    # attempt to decode as image
+                    results[link] = Image.open(io.BytesIO(res.raw))
+                    input_type = 'png'
+                    n.cancel_scope.cancel()
 
                 except UnidentifiedImageError:
                     logging.warning(f'couldn\'t get ipfs binary data at {link}!')
@@ -284,14 +297,14 @@
             n.start_soon(
                 get_and_set_results, ipfs_link_legacy)
 
-        png_img = None
+        input_data = None
         if ipfs_link_legacy in results:
-            png_img = results[ipfs_link_legacy]
+            input_data = results[ipfs_link_legacy]
 
         if ipfs_link in results:
-            png_img = results[ipfs_link]
+            input_data = results[ipfs_link]
 
-        if not png_img:
+        if input_data == None:
             raise DGPUComputeError('Couldn\'t gather input data from ipfs')
 
-        return png_img
+        return input_data, input_type
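Note: get_input_data now resolves both the payload and its type: `(b'', 'none')` for an empty hash, and a decoded PIL.Image with 'png' when the fetched bytes parse as an image (so the declared `tuple[bytes, str]` return type is only literal for the empty case). A hedged consumer sketch, with `conn` as a hypothetical connector instance:

    # unpack both values at every call site
    binary, input_type = await conn.get_input_data(ipfs_hash)
    match input_type:
        case 'none':
            pass                  # txt2img: no input image
        case 'png':
            w, h = binary.size    # binary is a decoded PIL.Image here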

skynet/utils.py

@@ -18,15 +18,10 @@ from PIL import Image
 from basicsr.archs.rrdbnet_arch import RRDBNet
 from diffusers import (
     DiffusionPipeline,
-    StableDiffusionXLPipeline,
-    StableDiffusionXLImg2ImgPipeline,
-    StableDiffusionPipeline,
-    StableDiffusionImg2ImgPipeline,
     EulerAncestralDiscreteScheduler
 )
 from realesrgan import RealESRGANer
 from huggingface_hub import login
-from torch.distributions import weibull
 import trio
 
 from .constants import MODELS
@@ -56,11 +51,10 @@ def convert_from_img_to_bytes(image: Image, fmt='PNG') -> bytes:
     return byte_arr.getvalue()
 
 
-def convert_from_bytes_and_crop(raw: bytes, max_w: int, max_h: int) -> Image:
-    image = convert_from_bytes_to_img(raw)
+def crop_image(image: Image, max_w: int, max_h: int) -> Image:
     w, h = image.size
     if w > max_w or h > max_h:
-        image.thumbnail((512, 512))
+        image.thumbnail((max_w, max_h))
 
     return image.convert('RGB')
@@ -74,7 +68,6 @@ def pipeline_for(
     assert torch.cuda.is_available()
     torch.cuda.empty_cache()
-    torch.cuda.set_per_process_memory_fraction(mem_fraction)
     torch.backends.cuda.matmul.allow_tf32 = True
     torch.backends.cudnn.allow_tf32 = True
@@ -89,6 +82,7 @@ def pipeline_for(
     req_mem = model_info['mem']
 
     mem_gb = torch.cuda.mem_get_info()[1] / (10**9)
+    mem_gb *= mem_fraction
     over_mem = mem_gb < req_mem
     if over_mem:
         logging.warn(f'model requires {req_mem} but card has {mem_gb}, model will run slower..')
@@ -96,26 +90,19 @@ def pipeline_for(
     shortname = model_info['short']
 
     params = {
-        'torch_dtype': torch.float16,
         'safety_checker': None,
-        'cache_dir': cache_dir
+        'torch_dtype': torch.float16,
+        'cache_dir': cache_dir,
+        'variant': 'fp16'
     }
 
-    if shortname == 'stable':
-        params['revision'] = 'fp16'
+    match shortname:
+        case 'stable':
+            params['revision'] = 'fp16'
 
-    if 'xl' in shortname:
-        if image:
-            pipe_class = StableDiffusionXLImg2ImgPipeline
-        else:
-            pipe_class = StableDiffusionXLPipeline
-    else:
-        if image:
-            pipe_class = StableDiffusionImg2ImgPipeline
-        else:
-            pipe_class = StableDiffusionPipeline
+    torch.cuda.set_per_process_memory_fraction(mem_fraction)
 
-    pipe = pipe_class.from_pretrained(
+    pipe = DiffusionPipeline.from_pretrained(
         model, **params)
 
     pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(
@@ -151,12 +138,6 @@ def txt2img(
     steps: int = 28,
     seed: Optional[int] = None
 ):
-    assert torch.cuda.is_available()
-    torch.cuda.empty_cache()
-    torch.cuda.set_per_process_memory_fraction(1.0)
-    torch.backends.cuda.matmul.allow_tf32 = True
-    torch.backends.cudnn.allow_tf32 = True
-
     login(token=hf_token)
     pipe = pipeline_for(model)
@@ -184,12 +165,6 @@ def img2img(
     steps: int = 28,
     seed: Optional[int] = None
 ):
-    assert torch.cuda.is_available()
-    torch.cuda.empty_cache()
-    torch.cuda.set_per_process_memory_fraction(1.0)
-    torch.backends.cuda.matmul.allow_tf32 = True
-    torch.backends.cudnn.allow_tf32 = True
-
     login(token=hf_token)
     pipe = pipeline_for(model, image=True)
@@ -230,12 +205,6 @@ def upscale(
     output: str = 'output.png',
     model_path: str = 'weights/RealESRGAN_x4plus.pth'
 ):
-    assert torch.cuda.is_available()
-    torch.cuda.empty_cache()
-    torch.cuda.set_per_process_memory_fraction(1.0)
-    torch.backends.cuda.matmul.allow_tf32 = True
-    torch.backends.cudnn.allow_tf32 = True
-
     input_img = Image.open(img_path).convert('RGB')
 
     upscaler = init_upscaler(model_path=model_path)
@@ -258,7 +227,7 @@ async def download_upscaler():
         f.write(response.content)
     print('done')
 
-def download_all_models(hf_token: str):
+def download_all_models(hf_token: str, hf_home: str):
     assert torch.cuda.is_available()
 
     trio.run(download_upscaler)
@@ -266,6 +235,4 @@ def download_all_models(hf_token: str):
     login(token=hf_token)
     for model in MODELS:
         print(f'DOWNLOADING {model.upper()}')
-        pipeline_for(model)
-        print(f'DOWNLOADING IMAGE {model.upper()}')
-        pipeline_for(model, image=True)
+        pipeline_for(model, cache_dir=hf_home)