Start using msgspec for message serialization/deseraliazation

Add new pipeline_for_v2 that loads based on ModelParams struct
Fix cli to new protocol_v2
Fix worker code to new protocol_v2
Switch to pdbplus
Split cuda_utils and normal utils
protocol_v2
Guillermo Rodriguez 2023-10-16 07:13:38 -03:00
parent d18d59a0ab
commit 2c4a8661ef
No known key found for this signature in database
GPG Key ID: EC3AB66D5D83B392
13 changed files with 1518 additions and 1103 deletions

1323
poetry.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -11,10 +11,12 @@ python = '>=3.10,<3.12'
pytz = '^2023.3.post1'
trio = '^0.22.2'
asks = '^3.0.0'
toml = '^0.10.2'
Pillow = '^10.0.1'
docker = '^6.1.3'
py-leap = {git = 'https://github.com/guilledk/py-leap.git', rev = 'v0.1a14'}
toml = '^0.10.2'
ueosio = {git = 'https://github.com/EOSArgentina/ueosio.git', rev = '543ab0a8b4b515d4b34ff02f1af4252b34ebd554'}
py-leap = {git = 'https://github.com/guilledk/py-leap.git', rev = 'multi_push_action'}
msgspec = '^0.18.4'
[tool.poetry.group.frontend]
optional = true
@ -30,7 +32,7 @@ pyTelegramBotAPI = {version = '^4.14.0'}
optional = true
[tool.poetry.group.dev.dependencies]
pdbpp = {version = '^0.10.3'}
pdbplus = {version = '^1.5.0'}
pytest = {version = '^7.4.2'}
[tool.poetry.group.cuda]
@ -41,6 +43,7 @@ torch = {version = '2.0.1+cu118', source = 'torch'}
scipy = {version = '^1.11.2'}
numba = {version = '0.57.0'}
quart = {version = '^0.19.3'}
compel = {version = '^2.0.2'}
triton = {version = '2.0.0', source = 'torch'}
basicsr = {version = '^1.4.2'}
xformers = {version = '^0.0.22'}

View File

@ -19,6 +19,13 @@ auto_withdraw = true
non_compete = []
api_bind = '127.0.0.1:42690'
[[initial_models]]
name = 'stabilityai/stable-diffusion-xl-base-1.0'
pipe_fqn = 'diffusers.DiffusionPipeline'
[initial_models.setup]
variant = 'fp16'
# telegram bot config (optional)
[skynet.telegram]
account = 'telegram'

View File

@ -1,2 +1,3 @@
#!/usr/bin/python
import pdbp

View File

@ -8,7 +8,10 @@ from functools import partial
import click
from leap.sugar import Name, asset_from_str
from leap.sugar import Name, ListArgument, asset_from_str, symbol_from_str
import msgspec
from skynet.protocol import ComputeRequest, ParamsStruct, RequestRow
from .config import *
from .constants import *
@ -93,37 +96,49 @@ def download():
@click.option('--jobs', '-j', default=1)
@click.option('--model', '-m', default='stabilityai/stable-diffusion-xl-base-1.0')
@click.option(
'--prompt', '-p', default='a red old tractor in a sunny wheat field')
@click.option('--output', '-o', default='output.png')
@click.option('--width', '-w', default=1024)
@click.option('--height', '-h', default=1024)
'--prompt', '-p',
default='cyberpunk skynet terminator skull a post impressionist oil painting with muted colors authored by Paul Cézanne, Paul Gauguin, Vincent van Gogh, Georges Seurat')
@click.option('--guidance', '-g', default=10)
@click.option('--step', '-s', default=26)
@click.option('--width', '-w', default=1024)
@click.option('--height', '-h', default=1024)
@click.option('--seed', '-S', default=None)
@click.option('--upscaler', '-U', default='x4')
@click.option('--binary_data', '-b', default='')
@click.option('--input', '-i', multiple=True)
@click.option('--strength', '-Z', default=None)
def enqueue(
reward: str,
jobs: int,
model: str,
prompt: str,
guidance: float,
step: int,
**kwargs
):
import trio
from leap.cleos import CLEOS
config = load_skynet_toml()
logging.basicConfig(level='INFO')
key = load_key(config, 'skynet.user.key')
account = load_key(config, 'skynet.user.account')
permission = load_key(config, 'skynet.user.permission')
node_url = load_key(config, 'skynet.node_url')
contract = load_key(config, 'skynet.contract')
cleos = CLEOS(None, None, url=node_url, remote=node_url)
binary = kwargs['binary_data']
inputs = kwargs['input']
if len(inputs) > 0:
del kwargs['width']
del kwargs['height']
del kwargs['input']
if not kwargs['strength']:
if binary:
raise ValueError('strength -Z param required if binary data passed')
if len(inputs) > 0:
raise ValueError('strength -Z param required if input data passed')
del kwargs['strength']
@ -139,29 +154,45 @@ def enqueue(
seed = random.randint(0, 10e9)
_kwargs = kwargs.copy()
_kwargs['seed'] = seed
_kwargs['generator'] = seed
del _kwargs['seed']
req = json.dumps({
'method': 'diffuse',
'params': _kwargs
})
request = ComputeRequest(
method='diffuse',
params=ParamsStruct(
model=ModelParams(
name=model,
pipe_fqn='diffusers.DiffusionPipeline',
setup={'variant': 'fp16'}
),
runtime_args=[prompt],
runtime_kwargs={
'guidance_scale': guidance,
'num_inference_steps': step,
**_kwargs
}
)
)
req = msgspec.json.encode(request)
actions.append({
'account': 'telos.gpu',
'account': contract,
'name': 'enqueue',
'data': {
'user': Name(account),
'request_body': req,
'binary_data': binary,
'reward': asset_from_str(reward),
'min_verification': 1
},
'data': [
Name(account),
ListArgument(req, 'uint8'),
ListArgument(inputs, 'string'),
asset_from_str(reward),
1
],
'authorization': [{
'actor': account,
'permission': permission
}]
})
# breakpoint()
res = await cleos.a_push_actions(actions, key)
print(res)
@ -181,13 +212,14 @@ def clean(
account = load_key(config, 'skynet.user.account')
permission = load_key(config, 'skynet.user.permission')
node_url = load_key(config, 'skynet.node_url')
contract = load_key(config, 'skynet.contract')
logging.basicConfig(level=loglevel)
cleos = CLEOS(None, None, url=node_url, remote=node_url)
trio.run(
partial(
cleos.a_push_action,
'telos.gpu',
contract,
'clean',
{},
account, key, permission=permission
@ -199,33 +231,26 @@ def queue():
import requests
config = load_skynet_toml()
node_url = load_key(config, 'skynet.node_url')
contract = load_key(config, 'skynet.contract')
resp = requests.post(
f'{node_url}/v1/chain/get_table_rows',
json={
'code': 'telos.gpu',
'code': contract,
'table': 'queue',
'scope': 'telos.gpu',
'scope': contract,
'json': True
}
)
print(json.dumps(resp.json(), indent=4))
).json()
# process hex body
results = []
for row in resp['rows']:
req = row.copy()
req['body'] = json.loads(bytes.fromhex(req['body']).decode())
results.append(req)
print(json.dumps(results, indent=4))
@skynet.command()
@click.argument('request-id')
def status(request_id: int):
import requests
config = load_skynet_toml()
node_url = load_key(config, 'skynet.node_url')
resp = requests.post(
f'{node_url}/v1/chain/get_table_rows',
json={
'code': 'telos.gpu',
'table': 'status',
'scope': request_id,
'json': True
}
)
print(json.dumps(resp.json(), indent=4))
@skynet.command()
@click.argument('request-id')
@ -238,12 +263,13 @@ def dequeue(request_id: int):
account = load_key(config, 'skynet.user.account')
permission = load_key(config, 'skynet.user.permission')
node_url = load_key(config, 'skynet.node_url')
contract = load_key(config, 'skynet.contract')
cleos = CLEOS(None, None, url=node_url, remote=node_url)
res = trio.run(
partial(
cleos.a_push_action,
'telos.gpu',
contract,
'dequeue',
{
'user': Name(account),
@ -256,33 +282,39 @@ def dequeue(request_id: int):
@skynet.command()
@click.option(
'--token-contract', '-c', default='eosio.token')
@click.option(
'--token-symbol', '-S', default='4,GPU')
@click.argument(
'token-contract', required=True)
@click.argument(
'token-symbol', required=True)
@click.argument(
'nonce', required=True)
def config(
token_contract: str,
token_symbol: str
token_symbol: str,
nonce: int
):
import trio
from leap.cleos import CLEOS
logging.basicConfig(level='INFO')
config = load_skynet_toml()
key = load_key(config, 'skynet.user.key')
account = load_key(config, 'skynet.user.account')
permission = load_key(config, 'skynet.user.permission')
node_url = load_key(config, 'skynet.node_url')
contract = load_key(config, 'skynet.contract')
cleos = CLEOS(None, None, url=node_url, remote=node_url)
res = trio.run(
partial(
cleos.a_push_action,
'telos.gpu',
contract,
'config',
{
'token_contract': token_contract,
'token_symbol': token_symbol,
'token_contract': Name(token_contract),
'token_symbol': symbol_from_str(token_symbol),
'nonce': int(nonce)
},
account, key, permission=permission
)
@ -302,16 +334,17 @@ def deposit(quantity: str):
account = load_key(config, 'skynet.user.account')
permission = load_key(config, 'skynet.user.permission')
node_url = load_key(config, 'skynet.node_url')
contract = load_key(config, 'skynet.contract')
cleos = CLEOS(None, None, url=node_url, remote=node_url)
res = trio.run(
partial(
cleos.a_push_action,
'telos.gpu',
'eosio.token',
'transfer',
{
'sender': Name(account),
'recipient': Name('telos.gpu'),
'recipient': Name(contract),
'amount': asset_from_str(quantity),
'memo': f'{account} transferred {quantity} to telos.gpu'
},

View File

@ -1,5 +1,8 @@
#!/usr/bin/python
from skynet.protocol import ModelParams
VERSION = '0.1a12'
DOCKER_RUNTIME_CUDA = 'skynet:runtime-cuda'
@ -167,7 +170,11 @@ DEFAULT_UPSCALER = None
DEFAULT_CONFIG_PATH = 'skynet.toml'
DEFAULT_INITAL_MODELS = [
'stabilityai/stable-diffusion-xl-base-1.0'
ModelParams(
name='stabilityai/stable-diffusion-xl-base-1.0',
pipe_fqn='diffusers.DiffusionPipeline',
setup={'variant': 'fp16'}
)
]
DATE_FORMAT = '%B the %dth %Y, %H:%M:%S'

View File

@ -0,0 +1,298 @@
#!/usr/bin/python
from copy import deepcopy
import io
import os
import sys
import random
import logging
from typing import Any, Optional
from pathlib import Path
from importlib import import_module
import trio
import asks
import torch
import numpy as np
from PIL import Image
from basicsr.archs.rrdbnet_arch import RRDBNet
from diffusers import (
DiffusionPipeline,
EulerAncestralDiscreteScheduler
)
from realesrgan import RealESRGANer
from huggingface_hub import login
from skynet.protocol import ModelParams
from .constants import MODELS
def convert_from_cv2_to_image(img: np.ndarray) -> Image:
# return Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
return Image.fromarray(img)
def convert_from_image_to_cv2(img: Image) -> np.ndarray:
# return cv2.cvtColor(numpy.array(img), cv2.COLOR_RGB2BGR)
return np.asarray(img)
def convert_from_bytes_to_img(raw: bytes) -> Image:
return Image.open(io.BytesIO(raw))
def convert_from_img_to_bytes(image: Image, fmt='PNG') -> bytes:
byte_arr = io.BytesIO()
image.save(byte_arr, format=fmt)
return byte_arr.getvalue()
def crop_image(image: Image, max_w: int, max_h: int) -> Image:
w, h = image.size
if w > max_w or h > max_h:
image.thumbnail((max_w, max_h))
return image.convert('RGB')
def pipeline_for(
model: str,
mem_fraction: float = 1.0,
image: bool = False,
cache_dir: str | None = None
) -> DiffusionPipeline:
assert torch.cuda.is_available()
torch.cuda.empty_cache()
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
# full determinism
# https://huggingface.co/docs/diffusers/using-diffusers/reproducibility#deterministic-algorithms
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
torch.backends.cudnn.benchmark = False
torch.use_deterministic_algorithms(True)
model_info = MODELS[model]
req_mem = model_info['mem']
mem_gb = torch.cuda.mem_get_info()[1] / (10**9)
mem_gb *= mem_fraction
over_mem = mem_gb < req_mem
if over_mem:
logging.warn(f'model requires {req_mem} but card has {mem_gb}, model will run slower..')
shortname = model_info['short']
params = {
'safety_checker': None,
'torch_dtype': torch.float16,
'cache_dir': cache_dir,
'variant': 'fp16'
}
match shortname:
case 'stable':
params['revision'] = 'fp16'
torch.cuda.set_per_process_memory_fraction(mem_fraction)
pipe = DiffusionPipeline.from_pretrained(
model, **params)
pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(
pipe.scheduler.config)
pipe.enable_xformers_memory_efficient_attention()
if over_mem:
if not image:
pipe.enable_vae_slicing()
pipe.enable_vae_tiling()
pipe.enable_model_cpu_offload()
else:
if sys.version_info[1] < 11:
# torch.compile only supported on python < 3.11
pipe.unet = torch.compile(
pipe.unet, mode='reduce-overhead', fullgraph=True)
pipe = pipe.to('cuda')
return pipe
def pipeline_for_v2(
model: ModelParams,
mem_fraction: float = 1.0,
cache_dir: str | None = None
) -> Any:
mod_name, class_name = model.pipe_fqn.rsplit('.', 1)
mod = import_module(mod_name)
pipe_class = getattr(mod, class_name)
assert torch.cuda.is_available()
torch.cuda.empty_cache()
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
# full determinism
# https://huggingface.co/docs/diffusers/using-diffusers/reproducibility#deterministic-algorithms
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
torch.backends.cudnn.benchmark = False
torch.use_deterministic_algorithms(True)
model_info = MODELS[model.name]
req_mem = model_info['mem']
mem_gb = torch.cuda.mem_get_info()[1] / (10**9)
mem_gb *= mem_fraction
over_mem = mem_gb < req_mem
if over_mem:
logging.warn(f'model requires {req_mem} but card has {mem_gb}, model will run slower..')
torch.cuda.set_per_process_memory_fraction(mem_fraction)
setup_params = deepcopy(model.setup)
setup_params['safety_checker'] = None
setup_params['torch_dtype'] = torch.float16
setup_params['cache_dir'] = cache_dir
pipe = pipe_class.from_pretrained(model.name, **setup_params)
pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(
pipe.scheduler.config)
pipe.enable_xformers_memory_efficient_attention()
if over_mem:
if 'Img' not in model.pipe_fqn:
pipe.enable_vae_slicing()
pipe.enable_vae_tiling()
pipe.enable_model_cpu_offload()
else:
if sys.version_info[1] < 11:
# torch.compile only supported on python < 3.11
pipe.unet = torch.compile(
pipe.unet, mode='reduce-overhead', fullgraph=True)
pipe = pipe.to('cuda')
return pipe
def txt2img(
hf_token: str,
model: str = 'prompthero/openjourney',
prompt: str = 'a red old tractor in a sunny wheat field',
output: str = 'output.png',
width: int = 512, height: int = 512,
guidance: float = 10,
steps: int = 28,
seed: Optional[int] = None
):
login(token=hf_token)
pipe = pipeline_for(model)
seed = seed if seed else random.randint(0, 2 ** 64)
prompt = prompt
image = pipe(
prompt,
width=width,
height=height,
guidance_scale=guidance, num_inference_steps=steps,
generator=torch.Generator("cuda").manual_seed(seed)
).images[0]
image.save(output)
def img2img(
hf_token: str,
model: str = 'prompthero/openjourney',
prompt: str = 'a red old tractor in a sunny wheat field',
img_path: str = 'input.png',
output: str = 'output.png',
strength: float = 1.0,
guidance: float = 10,
steps: int = 28,
seed: Optional[int] = None
):
login(token=hf_token)
pipe = pipeline_for(model, image=True)
with open(img_path, 'rb') as img_file:
input_img = convert_from_bytes_and_crop(img_file.read(), 512, 512)
seed = seed if seed else random.randint(0, 2 ** 64)
prompt = prompt
image = pipe(
prompt,
image=input_img,
strength=strength,
guidance_scale=guidance, num_inference_steps=steps,
generator=torch.Generator("cuda").manual_seed(seed)
).images[0]
image.save(output)
def init_upscaler(model_path: str = 'weights/RealESRGAN_x4plus.pth'):
return RealESRGANer(
scale=4,
model_path=model_path,
dni_weight=None,
model=RRDBNet(
num_in_ch=3,
num_out_ch=3,
num_feat=64,
num_block=23,
num_grow_ch=32,
scale=4
),
half=True
)
def upscale(
img_path: str = 'input.png',
output: str = 'output.png',
model_path: str = 'weights/RealESRGAN_x4plus.pth'
):
input_img = Image.open(img_path).convert('RGB')
upscaler = init_upscaler(model_path=model_path)
up_img, _ = upscaler.enhance(
convert_from_image_to_cv2(input_img), outscale=4)
image = convert_from_cv2_to_image(up_img)
image.save(output)
async def download_upscaler():
print('downloading upscaler...')
weights_path = Path('weights')
weights_path.mkdir(exist_ok=True)
upscaler_url = 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth'
save_path = weights_path / 'RealESRGAN_x4plus.pth'
response = await asks.get(upscaler_url)
with open(save_path, 'wb') as f:
f.write(response.content)
print('done')
def download_all_models(hf_token: str, hf_home: str):
assert torch.cuda.is_available()
trio.run(download_upscaler)
login(token=hf_token)
for model in MODELS:
print(f'DOWNLOADING {model.upper()}')
pipeline_for(model, cache_dir=hf_home)

View File

@ -11,20 +11,22 @@ from skynet.dgpu.network import SkynetGPUConnector
async def open_dgpu_node(config: dict):
conn = SkynetGPUConnector({**config, **config['dgpu']})
mm = SkynetMM(config['dgpu'])
daemon = SkynetDGPUDaemon(mm, conn, config['dgpu'])
config = {**config, **config['dgpu']}
conn = SkynetGPUConnector(config)
mm = SkynetMM(config)
daemon = SkynetDGPUDaemon(mm, conn, config)
api = None
if 'api_bind' in config['dgpu']:
if 'api_bind' in config:
api_conf = Config()
api_conf.bind = [config['api_bind']]
api = await daemon.generate_api()
async with trio.open_nursery() as n:
n.start_soon(conn.data_updater_task)
await n.start(conn.data_updater_task)
if api:
n.start_soon(serve, api, api_conf)
await daemon.serve_forever()
n.cancel_scope.cancel()

View File

@ -5,10 +5,9 @@
import gc
import logging
from hashlib import sha256
from copy import deepcopy
from typing import Any
from PIL import Image
from diffusers import DiffusionPipeline
import trio
@ -16,53 +15,29 @@ import torch
from skynet.constants import DEFAULT_INITAL_MODELS, MODELS
from skynet.dgpu.errors import DGPUComputeError, DGPUInferenceCancelled
from skynet.protocol import ComputeRequest, ModelParams, ParamsStruct
from skynet.utils import crop_image, convert_from_cv2_to_image, convert_from_image_to_cv2, convert_from_img_to_bytes, init_upscaler, pipeline_for
from skynet.cuda_utils import (
init_upscaler,
pipeline_for_v2
)
def prepare_params_for_diffuse(
params: dict,
inputs: list[tuple[Any, str]],
):
_params = {}
def unpack_diffuse_params(params: ParamsStruct):
kwargs = deepcopy(params.runtime_kwargs)
if len(inputs) > 1:
raise DGPUComputeError('sorry binary_inputs > 1 not implemented yet')
if 'generator' in kwargs:
kwargs['generator'] = torch.manual_seed(int(kwargs['generator']))
if len(inputs) == 0:
binary, input_type = inputs[0]
match input_type:
case 'png':
image = crop_image(
binary, params['width'], params['height'])
_params['image'] = image
_params['strength'] = float(params['strength'])
case _:
raise DGPUComputeError(f'Unknown input_type {input_type}')
else:
_params['width'] = int(params['width'])
_params['height'] = int(params['height'])
return (
params['prompt'],
float(params['guidance']),
int(params['step']),
torch.manual_seed(int(params['seed'])),
params['upscaler'] if 'upscaler' in params else None,
_params
)
return params.runtime_args, kwargs
class SkynetMM:
def __init__(self, config: dict):
self.upscaler = init_upscaler()
self.initial_models = (
config['initial_models']
self.initial_models: list[ModelParams] = (
[ModelParams(**model) for model in config['initial_models']]
if 'initial_models' in config else DEFAULT_INITAL_MODELS
)
@ -78,35 +53,28 @@ class SkynetMM:
self._models = {}
for model in self.initial_models:
self.load_model(model, False, force=True)
self.load_model(model)
def log_debug_info(self):
logging.info('memory summary:')
logging.info('\n' + torch.cuda.memory_summary())
def is_model_loaded(self, model_name: str, image: bool):
for model_key, model_data in self._models.items():
if (model_key == model_name and
model_data['image'] == image):
return True
return False
def is_model_loaded(self, model: ModelParams):
return model.get_uid() in self._models
def load_model(
self,
model_name: str,
image: bool,
force=False
model: ModelParams
):
logging.info(f'loading model {model_name}...')
if force or len(self._models.keys()) == 0:
pipe = pipeline_for(
model_name, image=image, cache_dir=self.cache_dir)
logging.info(f'loading model {model.name}...')
if len(self._models.keys()) == 0:
pipe = pipeline_for_v2(
model, cache_dir=self.cache_dir)
self._models[model_name] = {
self._models[model.get_uid()] = {
'pipe': pipe,
'generated': 0,
'image': image
'params': model,
'generated': 0
}
else:
@ -119,42 +87,41 @@ class SkynetMM:
del self._models[least_used]
logging.info(f'swapping model {least_used} for {model_name}...')
logging.info(f'swapping model {least_used} for {model.get_uid()}...')
gc.collect()
torch.cuda.empty_cache()
pipe = pipeline_for(
model_name, image=image, cache_dir=self.cache_dir)
pipe = pipeline_for_v2(
model, cache_dir=self.cache_dir)
self._models[model_name] = {
self._models[model.get_uid()] = {
'pipe': pipe,
'generated': 0,
'image': image
'params': model,
'generated': 0
}
logging.info(f'loaded model {model_name}')
logging.info(f'loaded model {model.name}')
return pipe
def get_model(self, model_name: str, image: bool) -> DiffusionPipeline:
if model_name not in MODELS:
raise DGPUComputeError(f'Unknown model {model_name}')
def get_model(self, model: ModelParams) -> DiffusionPipeline:
if model.name not in MODELS:
raise DGPUComputeError(f'Unknown model {model.name}')
if not self.is_model_loaded(model_name, image):
pipe = self.load_model(model_name, image=image)
if not self.is_model_loaded(model):
pipe = self.load_model(model)
else:
pipe = self._models[model_name]['pipe']
pipe = self._models[model.get_uid()]['pipe']
return pipe
def compute_one(
self,
request_id: int,
method: str,
params: dict,
request: ComputeRequest,
inputs: list[tuple[Any, str]]
):
) -> list[tuple[bytes, str]]:
def maybe_cancel_work(step, *args, **kwargs):
if self._should_cancel:
should_raise = trio.from_thread.run(self._should_cancel, request_id)
@ -165,44 +132,24 @@ class SkynetMM:
maybe_cancel_work(0)
output_type = 'png'
if 'output_type' in params:
output_type = params['output_type']
if 'output_type' in request.params.runtime_kwargs:
output_type = request.params.runtime_kwargs['output_type']
output = None
output_hash = None
outputs = None
try:
match method:
match request.method:
case 'diffuse':
arguments = prepare_params_for_diffuse(params, inputs)
prompt, guidance, step, seed, upscaler, extra_params = arguments
model = self.get_model(params['model'], 'image' in extra_params)
model = self.get_model(request.params.model)
output = model(
prompt,
guidance_scale=guidance,
num_inference_steps=step,
generator=seed,
args, kwargs = unpack_diffuse_params(request.params)
outputs = model(
*args, **kwargs,
callback=maybe_cancel_work,
callback_steps=1,
**extra_params
).images[0]
callback_steps=1
)
output_binary = b''
match output_type:
case 'png':
if upscaler == 'x4':
input_img = output.convert('RGB')
up_img, _ = self.upscaler.enhance(
convert_from_image_to_cv2(input_img), outscale=4)
output = convert_from_cv2_to_image(up_img)
output_binary = convert_from_img_to_bytes(output)
case _:
raise DGPUComputeError(f'Unsupported output type: {output_type}')
output_hash = sha256(output_binary).hexdigest()
output = outputs.images[0]
case _:
raise DGPUComputeError('Unsupported compute method')
@ -214,4 +161,4 @@ class SkynetMM:
finally:
torch.cuda.empty_cache()
return output_hash, output
return [(output, output_type)]

View File

@ -1,8 +1,9 @@
#!/usr/bin/python
import json
import logging
import time
import random
import logging
import traceback
from hashlib import sha256
@ -19,6 +20,7 @@ from skynet.constants import MODELS, VERSION
from skynet.dgpu.errors import *
from skynet.dgpu.compute import SkynetMM
from skynet.dgpu.network import SkynetGPUConnector
from skynet.protocol import ComputeRequest, ModelParams, ParamsStruct, RequestRow
def convert_reward_to_int(reward_str):
@ -87,9 +89,12 @@ class SkynetDGPUDaemon:
async def should_cancel_work(self, request_id: int):
self._benchmark.append(time.time())
competitors = self.conn.get_competitors_for_request(request_id)
if competitors == None:
return True
status = self.conn.get_status_for_request(request_id)
competitors = [
s.worker
for s in status
if s.worker != self.account
]
return bool(self.non_compete & set(competitors))
async def generate_api(self):
@ -106,25 +111,37 @@ class SkynetDGPUDaemon:
return app
def find_best_requests(self) -> list[dict]:
def find_best_requests(self) -> list[tuple[RequestRow, ComputeRequest]]:
queue = self.conn.get_queue()
# for _ in range(3):
# random.shuffle(queue)
for _ in range(3):
random.shuffle(queue)
# queue = sorted(
# queue,
# key=lambda req: convert_reward_to_int(req['reward']),
# reverse=True
# )
queue = sorted(
queue,
key=lambda req: convert_reward_to_int(req.reward),
reverse=True
)
requests = []
for req in queue:
rid = req['nonce']
rid = req.nonce
# parse request
body = json.loads(req['body'])
model = body['params']['model']
try:
req_json = json.loads(req.body)
compute_request = ComputeRequest(**req_json)
compute_request.params = ParamsStruct(**req_json['params'])
compute_request.params.model = ModelParams(**req_json['params']['model'])
model = compute_request.params.model.name
except TypeError as e:
logging.warning(f'Couldn\'t parse request: {e}')
continue
except json.JSONDecodeError as e:
logging.warning(f'Couldn\'t parse request: {e}')
continue
# if model not known
if model not in MODELS:
@ -140,7 +157,7 @@ class SkynetDGPUDaemon:
if model in self.model_blacklist:
continue
my_results = [res['id'] for res in self.conn.get_my_results()]
my_results = [res.id for res in self.conn.get_my_results()]
# if this worker already on it
if rid in my_results:
@ -150,13 +167,17 @@ class SkynetDGPUDaemon:
if status == None:
continue
if self.non_compete & set(self.conn.get_competitors_for_request(rid)):
if self.non_compete & set([
s.worker
for s in status
if s.worker != self.account
]):
continue
if len(status) > self.max_concurrent:
continue
requests.append(req)
requests.append((req, compute_request))
return requests
@ -164,24 +185,26 @@ class SkynetDGPUDaemon:
# check worker is registered
me = self.conn.get_on_chain_worker_info(self.account)
if not me:
ec, out = await self.conn.register_worker()
if ec != 0:
res = await self.conn.register_worker()
if 'error' in res:
raise DGPUDaemonError(f'Couldn\'t register worker! {out}')
me = self.conn.get_on_chain_worker_info(self.account)
if not me:
raise DGPUDaemonError('Unknown error while registering')
# find if reported on chain gpus match local
found_difference = False
for i in range(self.mm.num_gpus):
chain_gpu = me['cards'][i]
chain_gpu = me.cards[i]
gpu = self.mm.gpus[i]
gpu_v = f'{gpu.major}.{gpu.minor}'
found_difference = gpu.name != chain_gpu['card_name']
found_difference = gpu_v != chain_gpu['version']
found_difference = gpu.total_memory != chain_gpu['total_memory']
found_difference = gpu.multi_processor_count != chain_gpu['mp_count']
found_difference = gpu.name != chain_gpu.card_name
found_difference = gpu_v != chain_gpu.version
found_difference = gpu.total_memory != chain_gpu.total_memory
found_difference = gpu.multi_processor_count != chain_gpu.mp_count
if found_difference:
break
@ -189,20 +212,24 @@ class SkynetDGPUDaemon:
if found_difference:
await self.conn.flush_cards()
for i, gpu in enumerate(self.mm.gpus):
ec, _ = await self.conn.add_card(
res = await self.conn.add_card(
gpu.name, f'{gpu.major}.{gpu.minor}',
gpu.total_memory, gpu.multi_processor_count,
'',
is_online
)
if ec != 0:
if 'error' in res:
raise DGPUDaemonError(f'error while reporting card {i}')
return found_difference
async def all_gpu_set_online_flag(self, is_online: bool):
for i, chain_gpu in enumerate(me['cards']):
if chain_gpu['is_online'] != is_online:
me = self.conn.get_on_chain_worker_info(self.account)
if not me:
raise DGPUDaemonError('Couldn\'t find worker info!')
for i, chain_gpu in enumerate(me.cards):
if chain_gpu.is_online != is_online:
await self.conn.toggle_card(i)
async def serve_forever(self):
@ -219,23 +246,24 @@ class SkynetDGPUDaemon:
requests = self.find_best_requests()
if len(requests) > 0:
request = requests[0]
rid = request['nonce']
body = json.loads(request['body'])
request, compute_request = requests[0]
rid = request.nonce
body = json.loads(request.body)
logging.info(f'trying to process req: {rid}')
inputs = await self.conn.get_inputs(request['binary_inputs'])
hash_str = (
str(request['nonce'])
hash_buf = (
str(request.nonce).encode()
+
request['body']
request.body.encode()
+
''.join([_in for _in in request['binary_inputs']])
b''.join([_in.encode() for _in in request.inputs])
)
logging.info(f'hashing: {hash_str}')
request_hash = sha256(hash_str.encode('utf-8')).hexdigest()
logging.info(f'hashing str of length {len(hash_buf)}')
request_hash = sha256(hash_buf).hexdigest()
# TODO: validate request
inputs = []
if len(request.inputs) > 0:
inputs = await self.conn.get_inputs(request.inputs)
# perform work
logging.info(f'working on {body}')
@ -247,19 +275,17 @@ class SkynetDGPUDaemon:
else:
try:
output_type = 'png'
if 'output_type' in body['params']:
output_type = body['params']['output_type']
if 'output_type' in compute_request.params.runtime_kwargs:
output_type = compute_request.params.runtime_kwargs['output_type']
output = None
output_hash = None
outputs = []
match self.backend:
case 'sync-on-thread':
self.mm._should_cancel = self.should_cancel_work
output_hash, output = await trio.to_thread.run_sync(
outputs = await trio.to_thread.run_sync(
partial(
self.mm.compute_one,
rid,
body['method'], body['params'],
rid, compute_request,
inputs=inputs
)
)
@ -271,9 +297,9 @@ class SkynetDGPUDaemon:
self._last_benchmark = self._benchmark
self._benchmark = []
ipfs_hash = await self.conn.publish_on_ipfs(output, typ=output_type)
outputs = await self.conn.publish_on_ipfs(outputs)
await self.conn.submit_work(rid, request_hash, output_hash, ipfs_hash)
await self.conn.submit_work(rid, request_hash, outputs)
except BaseException as e:
traceback.print_exc()

View File

@ -16,11 +16,18 @@ import anyio
from PIL import Image, UnidentifiedImageError
from leap.cleos import CLEOS
from leap.sugar import Checksum256, Name, asset_from_str
from leap.sugar import (
ListArgument,
Checksum256,
Name,
asset_from_str
)
from skynet.constants import DEFAULT_IPFS_DOMAIN
from skynet.ipfs import AsyncIPFSHTTP, get_ipfs_file
from skynet.dgpu.errors import DGPUComputeError
from skynet.protocol import CardStruct, ConfigRow, RequestRow, WorkerResultRow, WorkerRow, WorkerStatusStruct
REQUEST_UPDATE_TIME = 3
@ -93,99 +100,107 @@ class SkynetGPUConnector:
else:
return default
async def data_updater_task(self):
async def data_updater_task(self, task_status=trio.TASK_STATUS_IGNORED):
tasks = (
(self._get_work_requests_last_hour, 'queue'),
(self._find_my_results, 'my_results'),
(self._get_workers, 'workers')
)
while True:
async def _update():
async with trio.open_nursery() as n:
for task in tasks:
n.start_soon(self._cache_set, *task)
await trio.sleep(self._update_delta)
await _update()
def get_queue(self):
task_status.started()
while True:
await trio.sleep(self._update_delta)
await _update()
def get_queue(self) -> list[RequestRow]:
return self._cache_get('queue', default=[])
def get_my_results(self):
def get_my_results(self) -> list[WorkerResultRow]:
return self._cache_get('my_results', default=[])
def get_workers(self):
def get_workers(self) -> list[WorkerRow]:
return self._cache_get('workers', default=[])
def get_status_for_request(self, request_id: int) -> list[dict] | None:
request: dict | None = next((
req
def get_status_for_request(self, request_id: int) -> list[WorkerStatusStruct]:
return next((
[WorkerStatusStruct(**status) for status in req.status]
for req in self.get_queue()
if req['id'] == request_id), None)
if req.nonce == request_id), [])
if request:
return request['status']
async def _get_work_requests_last_hour(self) -> list[RequestRow]:
logging.debug('get_work_requests_last_hour')
result = []
for row in (
await failable(
partial(
self.cleos.aget_table,
self.contract, self.contract, 'queue',
index_position=2,
key_type='i64',
lower_bound=int(time.time()) - (60 * 60)
), ret_fail=[])
):
row = RequestRow(**row)
row.body = bytes.fromhex(row.body).decode()
result.append(row)
else:
return None
def get_competitors_for_request(self, request_id: int) -> list[str] | None:
status = self.get_status_for_request(request_id)
if not status:
return None
return result
async def _find_my_results(self) -> list[WorkerResultRow]:
logging.debug('find_my_results')
return [
s['worker']
for s in status
if s['worker'] != self.account
WorkerResultRow(**row)
for row in (
await failable(
partial(
self.cleos.aget_table,
self.contract, self.contract, 'results',
index_position=4,
key_type='name',
lower_bound=self.account,
upper_bound=self.account
)
)
)
]
async def _get_work_requests_last_hour(self) -> list[dict]:
logging.info('get_work_requests_last_hour')
return await failable(
partial(
self.cleos.aget_table,
self.contract, self.contract, 'queue',
index_position=2,
order='asc',
limit=1000
), ret_fail=[])
async def _find_my_results(self):
logging.info('find_my_results')
return await failable(
partial(
self.cleos.aget_table,
self.contract, self.contract, 'results',
index_position=4,
key_type='name',
lower_bound=self.account,
upper_bound=self.account
)
)
async def _get_workers(self) -> list[dict]:
logging.info('get_workers')
return await failable(
async def _get_workers(self) -> list[WorkerRow]:
logging.debug('get_workers')
worker_rows = await failable(
partial(
self.cleos.aget_table,
self.contract, self.contract, 'workers'
)
)
result = []
for row in worker_rows:
row['cards'] = [CardStruct(**card) for card in row['cards']]
result.append(WorkerRow(**row))
async def get_global_config(self):
logging.info('get_global_config')
return result
async def get_global_config(self) -> ConfigRow | None:
logging.debug('get_global_config')
rows = await failable(
partial(
self.cleos.aget_table,
'telos.gpu', 'telos.gpu', 'config'))
self.contract, self.contract, 'config'))
if rows:
return rows[0]
return ConfigRow(**rows[0])
else:
return None
async def get_worker_balance(self):
logging.info('get_worker_balance')
async def get_worker_balance(self) -> str | None:
logging.debug('get_worker_balance')
rows = await failable(
partial(
self.cleos.aget_table,
@ -201,14 +216,14 @@ class SkynetGPUConnector:
else:
return None
def get_on_chain_worker_info(self, worker: str):
def get_on_chain_worker_info(self, worker: str) -> WorkerRow | None:
return next((
w for w in self.get_workers()
if w['account'] == w
if w.account == worker
), None)
async def register_worker(self):
logging.info(f'registering worker')
logging.debug(f'registering worker')
return await failable(
partial(
self.cleos.a_push_action,
@ -217,7 +232,9 @@ class SkynetGPUConnector:
{
'account': self.account,
'url': self.worker_url
}
},
self.account, self.key,
permission=self.permission
)
)
@ -230,7 +247,7 @@ class SkynetGPUConnector:
extra: str,
is_online: bool
):
logging.info(f'adding card: {card_name} {version}')
logging.debug(f'adding card: {card_name} {version}')
return await failable(
partial(
self.cleos.a_push_action,
@ -244,34 +261,40 @@ class SkynetGPUConnector:
'mp_count': mp_count,
'extra': extra,
'is_online': is_online
}
},
self.account, self.key,
permission=self.permission
)
)
async def toggle_card(self, index: int):
logging.info(f'toggle card {index}')
logging.debug(f'toggle card {index}')
return await failable(
partial(
self.cleos.a_push_action,
self.contract,
'togglecard',
{'worker': self.account, 'index': index}
{'worker': self.account, 'index': index},
self.account, self.key,
permission=self.permission
)
)
async def flush_cards(self):
logging.info('flushing cards...')
logging.debug('flushing cards...')
return await failable(
partial(
self.cleos.a_push_action,
self.contract,
'flushcards',
{'worker': self.account}
{'worker': self.account},
self.account, self.key,
permission=self.permission
)
)
async def begin_work(self, request_id: int):
logging.info('begin_work')
logging.debug('begin_work')
return await failable(
partial(
self.cleos.a_push_action,
@ -288,7 +311,7 @@ class SkynetGPUConnector:
)
async def cancel_work(self, request_id: int, reason: str):
logging.info('cancel_work')
logging.debug('cancel_work')
return await failable(
partial(
self.cleos.a_push_action,
@ -305,7 +328,7 @@ class SkynetGPUConnector:
)
async def maybe_withdraw_all(self):
logging.info('maybe_withdraw_all')
logging.debug('maybe_withdraw_all')
balance = await self.get_worker_balance()
if not balance:
return
@ -330,10 +353,9 @@ class SkynetGPUConnector:
self,
request_id: int,
request_hash: str,
result_hash: str,
ipfs_hash: str
outputs: list[str]
):
logging.info('submit_work')
logging.debug('submit_work')
return await failable(
partial(
self.cleos.a_push_action,
@ -343,8 +365,7 @@ class SkynetGPUConnector:
'worker': self.account,
'request_id': request_id,
'request_hash': Checksum256(request_hash),
'result_hash': Checksum256(result_hash),
'ipfs_hash': ipfs_hash
'outputs': ListArgument(outputs, 'string')
},
self.account, self.key,
permission=self.permission
@ -352,19 +373,9 @@ class SkynetGPUConnector:
)
# IPFS helpers
async def publish_on_ipfs(self, raw, typ: str = 'png'):
async def publish_on_ipfs(self, outputs: list[tuple[bytes, str]]) -> list[str]:
Path('ipfs-staging').mkdir(exist_ok=True)
logging.info('publish_on_ipfs')
target_file = ''
match typ:
case 'png':
raw: Image
target_file = 'ipfs-staging/image.png'
raw.save(target_file)
case _:
raise ValueError(f'Unsupported output type: {typ}')
logging.debug('publish_on_ipfs')
if self.ipfs_gateway_url:
# check peer connections, reconnect to skynet gateway if not
@ -373,12 +384,32 @@ class SkynetGPUConnector:
if gateway_id not in [p['Peer'] for p in peers]:
await self.ipfs_client.connect(self.ipfs_gateway_url)
file_info = await self.ipfs_client.add(Path(target_file))
file_cid = file_info['Hash']
ipfs_outs = []
async def _publish_one(target: str):
file_info = await self.ipfs_client.add(Path(target))
file_cid = file_info['Hash']
await self.ipfs_client.pin(file_cid)
await self.ipfs_client.pin(file_cid)
logging.debug(f'published {file_cid}.')
return file_cid
ipfs_outs.append(file_cid)
async with trio.open_nursery() as n:
i = 0
for output, otype in outputs:
target_file = ''
match otype:
case 'png':
target_file = f'ipfs-staging/image-{i}.png'
output.save(target_file)
n.start_soon(_publish_one, target_file)
case _:
raise ValueError(f'Unsupported output type: {otype}')
i += 1
return ipfs_outs
async def get_input_data(self, ipfs_hash: str) -> tuple[bytes, str]:
results = {}
@ -389,7 +420,7 @@ class SkynetGPUConnector:
async with trio.open_nursery() as n:
async def get_and_set_results(link: str):
res = await get_ipfs_file(link, timeout=1)
logging.info(f'got response from {link}')
logging.debug(f'got response from {link}')
if not res or res.status_code != 200:
logging.warning(f'couldn\'t get ipfs binary data at {link}!')

83
skynet/protocol.py 100644
View File

@ -0,0 +1,83 @@
from msgspec import Struct
from skynet.utils import hash_dict
class ModelParams(Struct):
name: str
pipe_fqn: str
setup: dict
def get_uid(self) -> str:
return f'{self.pipe_fqn}:{self.name}-{hash_dict(self.setup)}'
class ParamsStruct(Struct):
model: ModelParams
runtime_args: list
runtime_kwargs: dict
class ComputeRequest(Struct):
method: str
params: ParamsStruct
# telos.gpu smart contract types
TimestampSec = int
class ConfigRow(Struct):
token_contract: str
token_symbol: str
nonce: int
class AccountRow(Struct):
user: str
balance: str
class CardStruct(Struct):
card_name: str
version: str
total_memory: int
mp_count: int
extra: str
is_online: bool
class WorkerRow(Struct):
account: str
joined: TimestampSec
left: TimestampSec
url: str
cards: list[CardStruct]
class WorkerStatusStruct(Struct):
worker: str
status: str
started: TimestampSec
class RequestRow(Struct):
nonce: int
user: str
reward: str
min_verification: int
body: str
inputs: list[str]
status: list[WorkerStatusStruct]
timestamp: TimestampSec
class WorkerResultRow(Struct):
id: int
request_id: int
user: str
worker: str
result_hash: str
ipfs_hash: str
submited: TimestampSec

View File

@ -1,238 +1,14 @@
#!/usr/bin/python
import io
import os
import sys
import time
import random
import logging
from typing import Optional
from pathlib import Path
import asks
import torch
import numpy as np
from PIL import Image
from basicsr.archs.rrdbnet_arch import RRDBNet
from diffusers import (
DiffusionPipeline,
EulerAncestralDiscreteScheduler
)
from realesrgan import RealESRGANer
from huggingface_hub import login
import trio
from .constants import MODELS
import json
import hashlib
def time_ms():
def hash_dict(d) -> str:
d_str = json.dumps(d, sort_keys=True)
return hashlib.sha256(d_str.encode('utf-8')).hexdigest()
def time_ms() -> int:
return int(time.time() * 1000)
def convert_from_cv2_to_image(img: np.ndarray) -> Image:
# return Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
return Image.fromarray(img)
def convert_from_image_to_cv2(img: Image) -> np.ndarray:
# return cv2.cvtColor(numpy.array(img), cv2.COLOR_RGB2BGR)
return np.asarray(img)
def convert_from_bytes_to_img(raw: bytes) -> Image:
return Image.open(io.BytesIO(raw))
def convert_from_img_to_bytes(image: Image, fmt='PNG') -> bytes:
byte_arr = io.BytesIO()
image.save(byte_arr, format=fmt)
return byte_arr.getvalue()
def crop_image(image: Image, max_w: int, max_h: int) -> Image:
w, h = image.size
if w > max_w or h > max_h:
image.thumbnail((max_w, max_h))
return image.convert('RGB')
def pipeline_for(
model: str,
mem_fraction: float = 1.0,
image: bool = False,
cache_dir: str | None = None
) -> DiffusionPipeline:
assert torch.cuda.is_available()
torch.cuda.empty_cache()
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
# full determinism
# https://huggingface.co/docs/diffusers/using-diffusers/reproducibility#deterministic-algorithms
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
torch.backends.cudnn.benchmark = False
torch.use_deterministic_algorithms(True)
model_info = MODELS[model]
req_mem = model_info['mem']
mem_gb = torch.cuda.mem_get_info()[1] / (10**9)
mem_gb *= mem_fraction
over_mem = mem_gb < req_mem
if over_mem:
logging.warn(f'model requires {req_mem} but card has {mem_gb}, model will run slower..')
shortname = model_info['short']
params = {
'safety_checker': None,
'torch_dtype': torch.float16,
'cache_dir': cache_dir,
'variant': 'fp16'
}
match shortname:
case 'stable':
params['revision'] = 'fp16'
torch.cuda.set_per_process_memory_fraction(mem_fraction)
pipe = DiffusionPipeline.from_pretrained(
model, **params)
pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(
pipe.scheduler.config)
pipe.enable_xformers_memory_efficient_attention()
if over_mem:
if not image:
pipe.enable_vae_slicing()
pipe.enable_vae_tiling()
pipe.enable_model_cpu_offload()
else:
if sys.version_info[1] < 11:
# torch.compile only supported on python < 3.11
pipe.unet = torch.compile(
pipe.unet, mode='reduce-overhead', fullgraph=True)
pipe = pipe.to('cuda')
return pipe
def txt2img(
hf_token: str,
model: str = 'prompthero/openjourney',
prompt: str = 'a red old tractor in a sunny wheat field',
output: str = 'output.png',
width: int = 512, height: int = 512,
guidance: float = 10,
steps: int = 28,
seed: Optional[int] = None
):
login(token=hf_token)
pipe = pipeline_for(model)
seed = seed if seed else random.randint(0, 2 ** 64)
prompt = prompt
image = pipe(
prompt,
width=width,
height=height,
guidance_scale=guidance, num_inference_steps=steps,
generator=torch.Generator("cuda").manual_seed(seed)
).images[0]
image.save(output)
def img2img(
hf_token: str,
model: str = 'prompthero/openjourney',
prompt: str = 'a red old tractor in a sunny wheat field',
img_path: str = 'input.png',
output: str = 'output.png',
strength: float = 1.0,
guidance: float = 10,
steps: int = 28,
seed: Optional[int] = None
):
login(token=hf_token)
pipe = pipeline_for(model, image=True)
with open(img_path, 'rb') as img_file:
input_img = convert_from_bytes_and_crop(img_file.read(), 512, 512)
seed = seed if seed else random.randint(0, 2 ** 64)
prompt = prompt
image = pipe(
prompt,
image=input_img,
strength=strength,
guidance_scale=guidance, num_inference_steps=steps,
generator=torch.Generator("cuda").manual_seed(seed)
).images[0]
image.save(output)
def init_upscaler(model_path: str = 'weights/RealESRGAN_x4plus.pth'):
return RealESRGANer(
scale=4,
model_path=model_path,
dni_weight=None,
model=RRDBNet(
num_in_ch=3,
num_out_ch=3,
num_feat=64,
num_block=23,
num_grow_ch=32,
scale=4
),
half=True
)
def upscale(
img_path: str = 'input.png',
output: str = 'output.png',
model_path: str = 'weights/RealESRGAN_x4plus.pth'
):
input_img = Image.open(img_path).convert('RGB')
upscaler = init_upscaler(model_path=model_path)
up_img, _ = upscaler.enhance(
convert_from_image_to_cv2(input_img), outscale=4)
image = convert_from_cv2_to_image(up_img)
image.save(output)
async def download_upscaler():
print('downloading upscaler...')
weights_path = Path('weights')
weights_path.mkdir(exist_ok=True)
upscaler_url = 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth'
save_path = weights_path / 'RealESRGAN_x4plus.pth'
response = await asks.get(upscaler_url)
with open(save_path, 'wb') as f:
f.write(response.content)
print('done')
def download_all_models(hf_token: str, hf_home: str):
assert torch.cuda.is_available()
trio.run(download_upscaler)
login(token=hf_token)
for model in MODELS:
print(f'DOWNLOADING {model.upper()}')
pipeline_for(model, cache_dir=hf_home)