mirror of https://github.com/skygpu/skynet.git
Simplify pipeline_for function and add the infra needed for diferent io/types than png
parent
ee1fdcc557
commit
3d2069d151
|
@ -15,7 +15,6 @@ Pillow = '^10.0.1'
|
||||||
docker = '^6.1.3'
|
docker = '^6.1.3'
|
||||||
py-leap = {git = 'https://github.com/guilledk/py-leap.git', rev = 'v0.1a14'}
|
py-leap = {git = 'https://github.com/guilledk/py-leap.git', rev = 'v0.1a14'}
|
||||||
toml = "^0.10.2"
|
toml = "^0.10.2"
|
||||||
tractor = {git = "https://github.com/goodboy/tractor.git"}
|
|
||||||
|
|
||||||
[tool.poetry.group.frontend]
|
[tool.poetry.group.frontend]
|
||||||
optional = true
|
optional = true
|
||||||
|
|
|
@ -85,7 +85,7 @@ def download():
|
||||||
hf_token = load_key(config, 'skynet.dgpu.hf_token')
|
hf_token = load_key(config, 'skynet.dgpu.hf_token')
|
||||||
hf_home = load_key(config, 'skynet.dgpu.hf_home')
|
hf_home = load_key(config, 'skynet.dgpu.hf_home')
|
||||||
set_hf_vars(hf_token, hf_home)
|
set_hf_vars(hf_token, hf_home)
|
||||||
utils.download_all_models(hf_token)
|
utils.download_all_models(hf_token, hf_home)
|
||||||
|
|
||||||
@skynet.command()
|
@skynet.command()
|
||||||
@click.option(
|
@click.option(
|
||||||
|
@ -120,21 +120,21 @@ def enqueue(
|
||||||
|
|
||||||
cleos = CLEOS(None, None, url=node_url, remote=node_url)
|
cleos = CLEOS(None, None, url=node_url, remote=node_url)
|
||||||
|
|
||||||
|
binary = kwargs['binary_data']
|
||||||
|
if not kwargs['strength']:
|
||||||
|
if binary:
|
||||||
|
raise ValueError('strength -Z param required if binary data passed')
|
||||||
|
|
||||||
|
del kwargs['strength']
|
||||||
|
|
||||||
|
else:
|
||||||
|
kwargs['strength'] = float(kwargs['strength'])
|
||||||
|
|
||||||
async def enqueue_n_jobs():
|
async def enqueue_n_jobs():
|
||||||
for i in range(jobs):
|
for i in range(jobs):
|
||||||
if not kwargs['seed']:
|
if not kwargs['seed']:
|
||||||
kwargs['seed'] = random.randint(0, 10e9)
|
kwargs['seed'] = random.randint(0, 10e9)
|
||||||
|
|
||||||
binary = kwargs['binary_data']
|
|
||||||
if not kwargs['strength']:
|
|
||||||
if binary:
|
|
||||||
raise ValueError('strength -Z param required if binary data passed')
|
|
||||||
|
|
||||||
del kwargs['strength']
|
|
||||||
|
|
||||||
else:
|
|
||||||
kwargs['strength'] = float(kwargs['strength'])
|
|
||||||
|
|
||||||
req = json.dumps({
|
req = json.dumps({
|
||||||
'method': 'diffuse',
|
'method': 'diffuse',
|
||||||
'params': kwargs
|
'params': kwargs
|
||||||
|
|
|
@ -5,18 +5,20 @@ VERSION = '0.1a12'
|
||||||
DOCKER_RUNTIME_CUDA = 'skynet:runtime-cuda'
|
DOCKER_RUNTIME_CUDA = 'skynet:runtime-cuda'
|
||||||
|
|
||||||
MODELS = {
|
MODELS = {
|
||||||
'prompthero/openjourney': {'short': 'midj', 'mem': 8},
|
'prompthero/openjourney': {'short': 'midj', 'mem': 6},
|
||||||
'runwayml/stable-diffusion-v1-5': {'short': 'stable', 'mem': 8},
|
'runwayml/stable-diffusion-v1-5': {'short': 'stable', 'mem': 6},
|
||||||
'stabilityai/stable-diffusion-2-1-base': {'short': 'stable2', 'mem': 8},
|
'stabilityai/stable-diffusion-2-1-base': {'short': 'stable2', 'mem': 6},
|
||||||
'snowkidy/stable-diffusion-xl-base-0.9': {'short': 'stablexl0.9', 'mem': 24},
|
'snowkidy/stable-diffusion-xl-base-0.9': {'short': 'stablexl0.9', 'mem': 8.3},
|
||||||
'stabilityai/stable-diffusion-xl-base-1.0': {'short': 'stablexl', 'mem': 24},
|
'Linaqruf/anything-v3.0': {'short': 'hdanime', 'mem': 6},
|
||||||
'Linaqruf/anything-v3.0': {'short': 'hdanime', 'mem': 8},
|
'hakurei/waifu-diffusion': {'short': 'waifu', 'mem': 6},
|
||||||
'hakurei/waifu-diffusion': {'short': 'waifu', 'mem': 8},
|
'nitrosocke/Ghibli-Diffusion': {'short': 'ghibli', 'mem': 6},
|
||||||
'nitrosocke/Ghibli-Diffusion': {'short': 'ghibli', 'mem': 8},
|
'dallinmackay/Van-Gogh-diffusion': {'short': 'van-gogh', 'mem': 6},
|
||||||
'dallinmackay/Van-Gogh-diffusion': {'short': 'van-gogh', 'mem': 8},
|
'lambdalabs/sd-pokemon-diffusers': {'short': 'pokemon', 'mem': 6},
|
||||||
'lambdalabs/sd-pokemon-diffusers': {'short': 'pokemon', 'mem': 8},
|
'Envvi/Inkpunk-Diffusion': {'short': 'ink', 'mem': 6},
|
||||||
'Envvi/Inkpunk-Diffusion': {'short': 'ink', 'mem': 8},
|
'nousr/robo-diffusion': {'short': 'robot', 'mem': 6},
|
||||||
'nousr/robo-diffusion': {'short': 'robot', 'mem': 8}
|
|
||||||
|
# default is always last
|
||||||
|
'stabilityai/stable-diffusion-xl-base-1.0': {'short': 'stablexl', 'mem': 8.3},
|
||||||
}
|
}
|
||||||
|
|
||||||
SHORT_NAMES = [
|
SHORT_NAMES = [
|
||||||
|
@ -158,7 +160,7 @@ DEFAULT_GUIDANCE = 7.5
|
||||||
DEFAULT_STRENGTH = 0.5
|
DEFAULT_STRENGTH = 0.5
|
||||||
DEFAULT_STEP = 28
|
DEFAULT_STEP = 28
|
||||||
DEFAULT_CREDITS = 10
|
DEFAULT_CREDITS = 10
|
||||||
DEFAULT_MODEL = list(MODELS.keys())[4]
|
DEFAULT_MODEL = list(MODELS.keys())[-1]
|
||||||
DEFAULT_ROLE = 'pleb'
|
DEFAULT_ROLE = 'pleb'
|
||||||
DEFAULT_UPSCALER = None
|
DEFAULT_UPSCALER = None
|
||||||
|
|
||||||
|
|
|
@ -1,165 +0,0 @@
|
||||||
#!/usr/bin/python
|
|
||||||
|
|
||||||
import gc
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
|
|
||||||
from hashlib import sha256
|
|
||||||
from diffusers import DiffusionPipeline
|
|
||||||
|
|
||||||
import trio
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from skynet.constants import DEFAULT_INITAL_MODELS, MODELS
|
|
||||||
from skynet.dgpu.errors import DGPUComputeError
|
|
||||||
|
|
||||||
from skynet.utils import convert_from_bytes_and_crop, convert_from_cv2_to_image, convert_from_image_to_cv2, convert_from_img_to_bytes, init_upscaler, pipeline_for
|
|
||||||
|
|
||||||
|
|
||||||
def prepare_params_for_diffuse(
|
|
||||||
params: dict,
|
|
||||||
binary: bytes | None = None
|
|
||||||
):
|
|
||||||
image = None
|
|
||||||
if binary:
|
|
||||||
image = convert_from_bytes_and_crop(binary, 512, 512)
|
|
||||||
|
|
||||||
_params = {}
|
|
||||||
if image:
|
|
||||||
_params['image'] = image
|
|
||||||
_params['strength'] = float(params['strength'])
|
|
||||||
|
|
||||||
else:
|
|
||||||
_params['width'] = int(params['width'])
|
|
||||||
_params['height'] = int(params['height'])
|
|
||||||
|
|
||||||
return (
|
|
||||||
params['prompt'],
|
|
||||||
float(params['guidance']),
|
|
||||||
int(params['step']),
|
|
||||||
torch.manual_seed(int(params['seed'])),
|
|
||||||
params['upscaler'] if 'upscaler' in params else None,
|
|
||||||
_params
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
_models = {}
|
|
||||||
|
|
||||||
def is_model_loaded(model_name: str, image: bool):
|
|
||||||
for model_key, model_data in _models.items():
|
|
||||||
if (model_key == model_name and
|
|
||||||
model_data['image'] == image):
|
|
||||||
return True
|
|
||||||
|
|
||||||
return False
|
|
||||||
|
|
||||||
def load_model(
|
|
||||||
model_name: str,
|
|
||||||
image: bool,
|
|
||||||
force=False
|
|
||||||
):
|
|
||||||
logging.info(f'loading model {model_name}...')
|
|
||||||
if force or len(_models.keys()) == 0:
|
|
||||||
pipe = pipeline_for(
|
|
||||||
model_name, image=image)
|
|
||||||
|
|
||||||
_models[model_name] = {
|
|
||||||
'pipe': pipe,
|
|
||||||
'generated': 0,
|
|
||||||
'image': image
|
|
||||||
}
|
|
||||||
|
|
||||||
else:
|
|
||||||
least_used = list(_models.keys())[0]
|
|
||||||
|
|
||||||
for model in _models:
|
|
||||||
if _models[
|
|
||||||
least_used]['generated'] > _models[model]['generated']:
|
|
||||||
least_used = model
|
|
||||||
|
|
||||||
del _models[least_used]
|
|
||||||
|
|
||||||
logging.info(f'swapping model {least_used} for {model_name}...')
|
|
||||||
|
|
||||||
gc.collect()
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
|
|
||||||
pipe = pipeline_for(
|
|
||||||
model_name, image=image)
|
|
||||||
|
|
||||||
_models[model_name] = {
|
|
||||||
'pipe': pipe,
|
|
||||||
'generated': 0,
|
|
||||||
'image': image
|
|
||||||
}
|
|
||||||
|
|
||||||
logging.info(f'loaded model {model_name}')
|
|
||||||
return pipe
|
|
||||||
|
|
||||||
def get_model(model_name: str, image: bool) -> DiffusionPipeline:
|
|
||||||
if model_name not in MODELS:
|
|
||||||
raise DGPUComputeError(f'Unknown model {model_name}')
|
|
||||||
|
|
||||||
if not is_model_loaded(model_name, image):
|
|
||||||
pipe = load_model(model_name, image=image)
|
|
||||||
|
|
||||||
else:
|
|
||||||
pipe = _models[model_name]['pipe']
|
|
||||||
|
|
||||||
return pipe
|
|
||||||
|
|
||||||
def _static_compute_one(kwargs: dict):
|
|
||||||
request_id: int = kwargs['request_id']
|
|
||||||
method: str = kwargs['method']
|
|
||||||
params: dict = kwargs['params']
|
|
||||||
binary: bytes | None = kwargs['binary']
|
|
||||||
|
|
||||||
def _checkpoint(*args, **kwargs):
|
|
||||||
trio.from_thread.run(trio.sleep, 0)
|
|
||||||
|
|
||||||
try:
|
|
||||||
match method:
|
|
||||||
case 'diffuse':
|
|
||||||
image = None
|
|
||||||
|
|
||||||
arguments = prepare_params_for_diffuse(params, binary)
|
|
||||||
prompt, guidance, step, seed, upscaler, extra_params = arguments
|
|
||||||
model = get_model(params['model'], 'image' in extra_params)
|
|
||||||
|
|
||||||
image = model(
|
|
||||||
prompt,
|
|
||||||
guidance_scale=guidance,
|
|
||||||
num_inference_steps=step,
|
|
||||||
generator=seed,
|
|
||||||
callback=_checkpoint,
|
|
||||||
callback_steps=1,
|
|
||||||
**extra_params
|
|
||||||
).images[0]
|
|
||||||
|
|
||||||
if upscaler == 'x4':
|
|
||||||
upscaler = init_upscaler()
|
|
||||||
input_img = image.convert('RGB')
|
|
||||||
up_img, _ = upscaler.enhance(
|
|
||||||
convert_from_image_to_cv2(input_img), outscale=4)
|
|
||||||
|
|
||||||
image = convert_from_cv2_to_image(up_img)
|
|
||||||
|
|
||||||
img_raw = convert_from_img_to_bytes(image)
|
|
||||||
img_sha = sha256(img_raw).hexdigest()
|
|
||||||
|
|
||||||
return img_sha, img_raw
|
|
||||||
|
|
||||||
case _:
|
|
||||||
raise DGPUComputeError('Unsupported compute method')
|
|
||||||
|
|
||||||
except BaseException as e:
|
|
||||||
logging.error(e)
|
|
||||||
raise DGPUComputeError(str(e))
|
|
||||||
|
|
||||||
finally:
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
|
|
||||||
|
|
||||||
async def _tractor_static_compute_one(**kwargs):
|
|
||||||
return await trio.to_thread.run_sync(
|
|
||||||
_static_compute_one, kwargs)
|
|
|
@ -3,10 +3,11 @@
|
||||||
# Skynet Memory Manager
|
# Skynet Memory Manager
|
||||||
|
|
||||||
import gc
|
import gc
|
||||||
import json
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from hashlib import sha256
|
from hashlib import sha256
|
||||||
|
import zipfile
|
||||||
|
from PIL import Image
|
||||||
from diffusers import DiffusionPipeline
|
from diffusers import DiffusionPipeline
|
||||||
|
|
||||||
import trio
|
import trio
|
||||||
|
@ -15,22 +16,29 @@ import torch
|
||||||
from skynet.constants import DEFAULT_INITAL_MODELS, MODELS
|
from skynet.constants import DEFAULT_INITAL_MODELS, MODELS
|
||||||
from skynet.dgpu.errors import DGPUComputeError, DGPUInferenceCancelled
|
from skynet.dgpu.errors import DGPUComputeError, DGPUInferenceCancelled
|
||||||
|
|
||||||
from skynet.utils import convert_from_bytes_and_crop, convert_from_cv2_to_image, convert_from_image_to_cv2, convert_from_img_to_bytes, init_upscaler, pipeline_for
|
from skynet.utils import crop_image, convert_from_cv2_to_image, convert_from_image_to_cv2, convert_from_img_to_bytes, init_upscaler, pipeline_for
|
||||||
|
|
||||||
from ._mp_compute import _static_compute_one, _tractor_static_compute_one
|
|
||||||
|
|
||||||
def prepare_params_for_diffuse(
|
def prepare_params_for_diffuse(
|
||||||
params: dict,
|
params: dict,
|
||||||
binary: bytes | None = None
|
input_type: str,
|
||||||
|
binary = None
|
||||||
):
|
):
|
||||||
image = None
|
|
||||||
if binary:
|
|
||||||
image = convert_from_bytes_and_crop(binary, 512, 512)
|
|
||||||
|
|
||||||
_params = {}
|
_params = {}
|
||||||
if image:
|
if binary != None:
|
||||||
_params['image'] = image
|
match input_type:
|
||||||
_params['strength'] = float(params['strength'])
|
case 'png':
|
||||||
|
image = crop_image(
|
||||||
|
binary, params['width'], params['height'])
|
||||||
|
|
||||||
|
_params['image'] = image
|
||||||
|
_params['strength'] = float(params['strength'])
|
||||||
|
|
||||||
|
case 'none':
|
||||||
|
...
|
||||||
|
|
||||||
|
case _:
|
||||||
|
raise DGPUComputeError(f'Unknown input_type {input_type}')
|
||||||
|
|
||||||
else:
|
else:
|
||||||
_params['width'] = int(params['width'])
|
_params['width'] = int(params['width'])
|
||||||
|
@ -136,6 +144,7 @@ class SkynetMM:
|
||||||
request_id: int,
|
request_id: int,
|
||||||
method: str,
|
method: str,
|
||||||
params: dict,
|
params: dict,
|
||||||
|
input_type: str = 'png',
|
||||||
binary: bytes | None = None
|
binary: bytes | None = None
|
||||||
):
|
):
|
||||||
def maybe_cancel_work(step, *args, **kwargs):
|
def maybe_cancel_work(step, *args, **kwargs):
|
||||||
|
@ -147,16 +156,21 @@ class SkynetMM:
|
||||||
|
|
||||||
maybe_cancel_work(0)
|
maybe_cancel_work(0)
|
||||||
|
|
||||||
|
output_type = 'png'
|
||||||
|
if 'output_type' in params:
|
||||||
|
output_type = params['output_type']
|
||||||
|
|
||||||
|
output = None
|
||||||
|
output_hash = None
|
||||||
try:
|
try:
|
||||||
match method:
|
match method:
|
||||||
case 'diffuse':
|
case 'diffuse':
|
||||||
image = None
|
arguments = prepare_params_for_diffuse(
|
||||||
|
params, input_type, binary=binary)
|
||||||
arguments = prepare_params_for_diffuse(params, binary)
|
|
||||||
prompt, guidance, step, seed, upscaler, extra_params = arguments
|
prompt, guidance, step, seed, upscaler, extra_params = arguments
|
||||||
model = self.get_model(params['model'], 'image' in extra_params)
|
model = self.get_model(params['model'], 'image' in extra_params)
|
||||||
|
|
||||||
image = model(
|
output = model(
|
||||||
prompt,
|
prompt,
|
||||||
guidance_scale=guidance,
|
guidance_scale=guidance,
|
||||||
num_inference_steps=step,
|
num_inference_steps=step,
|
||||||
|
@ -166,17 +180,22 @@ class SkynetMM:
|
||||||
**extra_params
|
**extra_params
|
||||||
).images[0]
|
).images[0]
|
||||||
|
|
||||||
if upscaler == 'x4':
|
output_binary = b''
|
||||||
input_img = image.convert('RGB')
|
match output_type:
|
||||||
up_img, _ = self.upscaler.enhance(
|
case 'png':
|
||||||
convert_from_image_to_cv2(input_img), outscale=4)
|
if upscaler == 'x4':
|
||||||
|
input_img = output.convert('RGB')
|
||||||
|
up_img, _ = self.upscaler.enhance(
|
||||||
|
convert_from_image_to_cv2(input_img), outscale=4)
|
||||||
|
|
||||||
image = convert_from_cv2_to_image(up_img)
|
output = convert_from_cv2_to_image(up_img)
|
||||||
|
|
||||||
img_raw = convert_from_img_to_bytes(image)
|
output_binary = convert_from_img_to_bytes(output)
|
||||||
img_sha = sha256(img_raw).hexdigest()
|
|
||||||
|
|
||||||
return img_sha, img_raw
|
case _:
|
||||||
|
raise DGPUComputeError(f'Unsupported output type: {output_type}')
|
||||||
|
|
||||||
|
output_hash = sha256(output_binary).hexdigest()
|
||||||
|
|
||||||
case _:
|
case _:
|
||||||
raise DGPUComputeError('Unsupported compute method')
|
raise DGPUComputeError('Unsupported compute method')
|
||||||
|
@ -187,3 +206,5 @@ class SkynetMM:
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
return output_hash, output
|
||||||
|
|
|
@ -9,10 +9,10 @@ from hashlib import sha256
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
|
||||||
import trio
|
import trio
|
||||||
import tractor
|
from skynet.constants import MODELS
|
||||||
|
|
||||||
from skynet.dgpu.errors import *
|
from skynet.dgpu.errors import *
|
||||||
from skynet.dgpu.compute import SkynetMM, _tractor_static_compute_one
|
from skynet.dgpu.compute import SkynetMM
|
||||||
from skynet.dgpu.network import SkynetGPUConnector
|
from skynet.dgpu.network import SkynetGPUConnector
|
||||||
|
|
||||||
|
|
||||||
|
@ -97,6 +97,11 @@ class SkynetDGPUDaemon:
|
||||||
body = json.loads(req['body'])
|
body = json.loads(req['body'])
|
||||||
model = body['params']['model']
|
model = body['params']['model']
|
||||||
|
|
||||||
|
# if model not known
|
||||||
|
if model not in MODELS:
|
||||||
|
logging.warning(f'Unknown model {model}')
|
||||||
|
continue
|
||||||
|
|
||||||
# if whitelist enabled and model not in it continue
|
# if whitelist enabled and model not in it continue
|
||||||
if (len(self.model_whitelist) > 0 and
|
if (len(self.model_whitelist) > 0 and
|
||||||
not model in self.model_whitelist):
|
not model in self.model_whitelist):
|
||||||
|
@ -111,7 +116,7 @@ class SkynetDGPUDaemon:
|
||||||
statuses = self._snap['requests'][rid]
|
statuses = self._snap['requests'][rid]
|
||||||
|
|
||||||
if len(statuses) == 0:
|
if len(statuses) == 0:
|
||||||
binary = await self.conn.get_input_data(req['binary_data'])
|
binary, input_type = await self.conn.get_input_data(req['binary_data'])
|
||||||
|
|
||||||
hash_str = (
|
hash_str = (
|
||||||
str(req['nonce'])
|
str(req['nonce'])
|
||||||
|
@ -134,46 +139,31 @@ class SkynetDGPUDaemon:
|
||||||
|
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
|
output_type = 'png'
|
||||||
|
if 'output_type' in body['params']:
|
||||||
|
output_type = body['params']['output_type']
|
||||||
|
|
||||||
|
output = None
|
||||||
|
output_hash = None
|
||||||
match self.backend:
|
match self.backend:
|
||||||
case 'sync-on-thread':
|
case 'sync-on-thread':
|
||||||
self.mm._should_cancel = self.should_cancel_work
|
self.mm._should_cancel = self.should_cancel_work
|
||||||
img_sha, img_raw = await trio.to_thread.run_sync(
|
output_hash, output = await trio.to_thread.run_sync(
|
||||||
partial(
|
partial(
|
||||||
self.mm.compute_one,
|
self.mm.compute_one,
|
||||||
rid,
|
rid,
|
||||||
body['method'], body['params'], binary=binary
|
body['method'], body['params'],
|
||||||
)
|
input_type=input_type,
|
||||||
)
|
|
||||||
|
|
||||||
case 'tractor':
|
|
||||||
async def _should_cancel_oracle():
|
|
||||||
while True:
|
|
||||||
await trio.sleep(1)
|
|
||||||
if (await self.should_cancel_work(rid)):
|
|
||||||
raise DGPUInferenceCancelled
|
|
||||||
|
|
||||||
async with (
|
|
||||||
trio.open_nursery() as trio_n,
|
|
||||||
tractor.open_nursery() as tractor_n
|
|
||||||
):
|
|
||||||
trio_n.start_soon(_should_cancel_oracle)
|
|
||||||
portal = await tractor_n.run_in_actor(
|
|
||||||
_tractor_static_compute_one,
|
|
||||||
name='tractor-cuda-mp',
|
|
||||||
request_id=rid,
|
|
||||||
method=body['method'],
|
|
||||||
params=body['params'],
|
|
||||||
binary=binary
|
binary=binary
|
||||||
)
|
)
|
||||||
img_sha, img_raw = await portal.result()
|
)
|
||||||
trio_n.cancel_scope.cancel()
|
|
||||||
|
|
||||||
case _:
|
case _:
|
||||||
raise DGPUComputeError(f'Unsupported backend {self.backend}')
|
raise DGPUComputeError(f'Unsupported backend {self.backend}')
|
||||||
|
|
||||||
ipfs_hash = await self.conn.publish_on_ipfs(img_raw)
|
ipfs_hash = await self.conn.publish_on_ipfs(output, typ=output_type)
|
||||||
|
|
||||||
await self.conn.submit_work(rid, request_hash, img_sha, ipfs_hash)
|
await self.conn.submit_work(rid, request_hash, output_hash, ipfs_hash)
|
||||||
|
|
||||||
except BaseException as e:
|
except BaseException as e:
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
|
|
|
@ -9,17 +9,19 @@ from pathlib import Path
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
|
||||||
import asks
|
import asks
|
||||||
|
import numpy
|
||||||
import trio
|
import trio
|
||||||
import anyio
|
import anyio
|
||||||
|
import torch
|
||||||
|
|
||||||
from PIL import Image, UnidentifiedImageError
|
from PIL import Image, UnidentifiedImageError
|
||||||
|
|
||||||
from leap.cleos import CLEOS
|
from leap.cleos import CLEOS
|
||||||
from leap.sugar import Checksum256, Name, asset_from_str
|
from leap.sugar import Checksum256, Name, asset_from_str
|
||||||
from skynet.constants import DEFAULT_DOMAIN
|
|
||||||
|
|
||||||
from skynet.dgpu.errors import DGPUComputeError
|
|
||||||
from skynet.ipfs import AsyncIPFSHTTP, get_ipfs_file
|
from skynet.ipfs import AsyncIPFSHTTP, get_ipfs_file
|
||||||
|
from skynet.dgpu.errors import DGPUComputeError
|
||||||
|
from skynet.constants import DEFAULT_DOMAIN
|
||||||
|
|
||||||
|
|
||||||
REQUEST_UPDATE_TIME = 3
|
REQUEST_UPDATE_TIME = 3
|
||||||
|
@ -235,11 +237,19 @@ class SkynetGPUConnector:
|
||||||
)
|
)
|
||||||
|
|
||||||
# IPFS helpers
|
# IPFS helpers
|
||||||
async def publish_on_ipfs(self, raw_img: bytes):
|
async def publish_on_ipfs(self, raw, typ: str = 'png'):
|
||||||
Path('ipfs-staging').mkdir(exist_ok=True)
|
Path('ipfs-staging').mkdir(exist_ok=True)
|
||||||
logging.info('publish_on_ipfs')
|
logging.info('publish_on_ipfs')
|
||||||
img = Image.open(io.BytesIO(raw_img))
|
|
||||||
img.save('ipfs-staging/image.png')
|
target_file = ''
|
||||||
|
match typ:
|
||||||
|
case 'png':
|
||||||
|
raw: Image
|
||||||
|
target_file = 'ipfs-staging/image.png'
|
||||||
|
raw.save(target_file)
|
||||||
|
|
||||||
|
case _:
|
||||||
|
raise ValueError(f'Unsupported output type: {typ}')
|
||||||
|
|
||||||
if self.ipfs_gateway_url:
|
if self.ipfs_gateway_url:
|
||||||
# check peer connections, reconnect to skynet gateway if not
|
# check peer connections, reconnect to skynet gateway if not
|
||||||
|
@ -248,16 +258,18 @@ class SkynetGPUConnector:
|
||||||
if gateway_id not in [p['Peer'] for p in peers]:
|
if gateway_id not in [p['Peer'] for p in peers]:
|
||||||
await self.ipfs_client.connect(self.ipfs_gateway_url)
|
await self.ipfs_client.connect(self.ipfs_gateway_url)
|
||||||
|
|
||||||
file_info = await self.ipfs_client.add(Path('ipfs-staging/image.png'))
|
file_info = await self.ipfs_client.add(Path(target_file))
|
||||||
file_cid = file_info['Hash']
|
file_cid = file_info['Hash']
|
||||||
|
|
||||||
await self.ipfs_client.pin(file_cid)
|
await self.ipfs_client.pin(file_cid)
|
||||||
|
|
||||||
return file_cid
|
return file_cid
|
||||||
|
|
||||||
async def get_input_data(self, ipfs_hash: str) -> bytes:
|
async def get_input_data(self, ipfs_hash: str) -> tuple[bytes, str]:
|
||||||
|
input_type = 'none'
|
||||||
|
|
||||||
if ipfs_hash == '':
|
if ipfs_hash == '':
|
||||||
return b''
|
return b'', input_type
|
||||||
|
|
||||||
results = {}
|
results = {}
|
||||||
ipfs_link = f'https://ipfs.{DEFAULT_DOMAIN}/ipfs/{ipfs_hash}'
|
ipfs_link = f'https://ipfs.{DEFAULT_DOMAIN}/ipfs/{ipfs_hash}'
|
||||||
|
@ -272,9 +284,10 @@ class SkynetGPUConnector:
|
||||||
|
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
with Image.open(io.BytesIO(res.raw)):
|
# attempt to decode as image
|
||||||
results[link] = res.raw
|
results[link] = Image.open(io.BytesIO(res.raw))
|
||||||
n.cancel_scope.cancel()
|
input_type = 'png'
|
||||||
|
n.cancel_scope.cancel()
|
||||||
|
|
||||||
except UnidentifiedImageError:
|
except UnidentifiedImageError:
|
||||||
logging.warning(f'couldn\'t get ipfs binary data at {link}!')
|
logging.warning(f'couldn\'t get ipfs binary data at {link}!')
|
||||||
|
@ -284,14 +297,14 @@ class SkynetGPUConnector:
|
||||||
n.start_soon(
|
n.start_soon(
|
||||||
get_and_set_results, ipfs_link_legacy)
|
get_and_set_results, ipfs_link_legacy)
|
||||||
|
|
||||||
png_img = None
|
input_data = None
|
||||||
if ipfs_link_legacy in results:
|
if ipfs_link_legacy in results:
|
||||||
png_img = results[ipfs_link_legacy]
|
input_data = results[ipfs_link_legacy]
|
||||||
|
|
||||||
if ipfs_link in results:
|
if ipfs_link in results:
|
||||||
png_img = results[ipfs_link]
|
input_data = results[ipfs_link]
|
||||||
|
|
||||||
if not png_img:
|
if input_data == None:
|
||||||
raise DGPUComputeError('Couldn\'t gather input data from ipfs')
|
raise DGPUComputeError('Couldn\'t gather input data from ipfs')
|
||||||
|
|
||||||
return png_img
|
return input_data, input_type
|
||||||
|
|
|
@ -18,15 +18,10 @@ from PIL import Image
|
||||||
from basicsr.archs.rrdbnet_arch import RRDBNet
|
from basicsr.archs.rrdbnet_arch import RRDBNet
|
||||||
from diffusers import (
|
from diffusers import (
|
||||||
DiffusionPipeline,
|
DiffusionPipeline,
|
||||||
StableDiffusionXLPipeline,
|
|
||||||
StableDiffusionXLImg2ImgPipeline,
|
|
||||||
StableDiffusionPipeline,
|
|
||||||
StableDiffusionImg2ImgPipeline,
|
|
||||||
EulerAncestralDiscreteScheduler
|
EulerAncestralDiscreteScheduler
|
||||||
)
|
)
|
||||||
from realesrgan import RealESRGANer
|
from realesrgan import RealESRGANer
|
||||||
from huggingface_hub import login
|
from huggingface_hub import login
|
||||||
from torch.distributions import weibull
|
|
||||||
import trio
|
import trio
|
||||||
|
|
||||||
from .constants import MODELS
|
from .constants import MODELS
|
||||||
|
@ -56,11 +51,10 @@ def convert_from_img_to_bytes(image: Image, fmt='PNG') -> bytes:
|
||||||
return byte_arr.getvalue()
|
return byte_arr.getvalue()
|
||||||
|
|
||||||
|
|
||||||
def convert_from_bytes_and_crop(raw: bytes, max_w: int, max_h: int) -> Image:
|
def crop_image(image: Image, max_w: int, max_h: int) -> Image:
|
||||||
image = convert_from_bytes_to_img(raw)
|
|
||||||
w, h = image.size
|
w, h = image.size
|
||||||
if w > max_w or h > max_h:
|
if w > max_w or h > max_h:
|
||||||
image.thumbnail((512, 512))
|
image.thumbnail((max_w, max_h))
|
||||||
|
|
||||||
return image.convert('RGB')
|
return image.convert('RGB')
|
||||||
|
|
||||||
|
@ -74,7 +68,6 @@ def pipeline_for(
|
||||||
|
|
||||||
assert torch.cuda.is_available()
|
assert torch.cuda.is_available()
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
torch.cuda.set_per_process_memory_fraction(mem_fraction)
|
|
||||||
torch.backends.cuda.matmul.allow_tf32 = True
|
torch.backends.cuda.matmul.allow_tf32 = True
|
||||||
torch.backends.cudnn.allow_tf32 = True
|
torch.backends.cudnn.allow_tf32 = True
|
||||||
|
|
||||||
|
@ -89,6 +82,7 @@ def pipeline_for(
|
||||||
|
|
||||||
req_mem = model_info['mem']
|
req_mem = model_info['mem']
|
||||||
mem_gb = torch.cuda.mem_get_info()[1] / (10**9)
|
mem_gb = torch.cuda.mem_get_info()[1] / (10**9)
|
||||||
|
mem_gb *= mem_fraction
|
||||||
over_mem = mem_gb < req_mem
|
over_mem = mem_gb < req_mem
|
||||||
if over_mem:
|
if over_mem:
|
||||||
logging.warn(f'model requires {req_mem} but card has {mem_gb}, model will run slower..')
|
logging.warn(f'model requires {req_mem} but card has {mem_gb}, model will run slower..')
|
||||||
|
@ -96,26 +90,19 @@ def pipeline_for(
|
||||||
shortname = model_info['short']
|
shortname = model_info['short']
|
||||||
|
|
||||||
params = {
|
params = {
|
||||||
'torch_dtype': torch.float16,
|
|
||||||
'safety_checker': None,
|
'safety_checker': None,
|
||||||
'cache_dir': cache_dir
|
'torch_dtype': torch.float16,
|
||||||
|
'cache_dir': cache_dir,
|
||||||
|
'variant': 'fp16'
|
||||||
}
|
}
|
||||||
|
|
||||||
if shortname == 'stable':
|
match shortname:
|
||||||
params['revision'] = 'fp16'
|
case 'stable':
|
||||||
|
params['revision'] = 'fp16'
|
||||||
|
|
||||||
if 'xl' in shortname:
|
torch.cuda.set_per_process_memory_fraction(mem_fraction)
|
||||||
if image:
|
|
||||||
pipe_class = StableDiffusionXLImg2ImgPipeline
|
|
||||||
else:
|
|
||||||
pipe_class = StableDiffusionXLPipeline
|
|
||||||
else:
|
|
||||||
if image:
|
|
||||||
pipe_class = StableDiffusionImg2ImgPipeline
|
|
||||||
else:
|
|
||||||
pipe_class = StableDiffusionPipeline
|
|
||||||
|
|
||||||
pipe = pipe_class.from_pretrained(
|
pipe = DiffusionPipeline.from_pretrained(
|
||||||
model, **params)
|
model, **params)
|
||||||
|
|
||||||
pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(
|
pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(
|
||||||
|
@ -151,12 +138,6 @@ def txt2img(
|
||||||
steps: int = 28,
|
steps: int = 28,
|
||||||
seed: Optional[int] = None
|
seed: Optional[int] = None
|
||||||
):
|
):
|
||||||
assert torch.cuda.is_available()
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
torch.cuda.set_per_process_memory_fraction(1.0)
|
|
||||||
torch.backends.cuda.matmul.allow_tf32 = True
|
|
||||||
torch.backends.cudnn.allow_tf32 = True
|
|
||||||
|
|
||||||
login(token=hf_token)
|
login(token=hf_token)
|
||||||
pipe = pipeline_for(model)
|
pipe = pipeline_for(model)
|
||||||
|
|
||||||
|
@ -184,12 +165,6 @@ def img2img(
|
||||||
steps: int = 28,
|
steps: int = 28,
|
||||||
seed: Optional[int] = None
|
seed: Optional[int] = None
|
||||||
):
|
):
|
||||||
assert torch.cuda.is_available()
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
torch.cuda.set_per_process_memory_fraction(1.0)
|
|
||||||
torch.backends.cuda.matmul.allow_tf32 = True
|
|
||||||
torch.backends.cudnn.allow_tf32 = True
|
|
||||||
|
|
||||||
login(token=hf_token)
|
login(token=hf_token)
|
||||||
pipe = pipeline_for(model, image=True)
|
pipe = pipeline_for(model, image=True)
|
||||||
|
|
||||||
|
@ -230,12 +205,6 @@ def upscale(
|
||||||
output: str = 'output.png',
|
output: str = 'output.png',
|
||||||
model_path: str = 'weights/RealESRGAN_x4plus.pth'
|
model_path: str = 'weights/RealESRGAN_x4plus.pth'
|
||||||
):
|
):
|
||||||
assert torch.cuda.is_available()
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
torch.cuda.set_per_process_memory_fraction(1.0)
|
|
||||||
torch.backends.cuda.matmul.allow_tf32 = True
|
|
||||||
torch.backends.cudnn.allow_tf32 = True
|
|
||||||
|
|
||||||
input_img = Image.open(img_path).convert('RGB')
|
input_img = Image.open(img_path).convert('RGB')
|
||||||
|
|
||||||
upscaler = init_upscaler(model_path=model_path)
|
upscaler = init_upscaler(model_path=model_path)
|
||||||
|
@ -258,7 +227,7 @@ async def download_upscaler():
|
||||||
f.write(response.content)
|
f.write(response.content)
|
||||||
print('done')
|
print('done')
|
||||||
|
|
||||||
def download_all_models(hf_token: str):
|
def download_all_models(hf_token: str, hf_home: str):
|
||||||
assert torch.cuda.is_available()
|
assert torch.cuda.is_available()
|
||||||
|
|
||||||
trio.run(download_upscaler)
|
trio.run(download_upscaler)
|
||||||
|
@ -266,6 +235,4 @@ def download_all_models(hf_token: str):
|
||||||
login(token=hf_token)
|
login(token=hf_token)
|
||||||
for model in MODELS:
|
for model in MODELS:
|
||||||
print(f'DOWNLOADING {model.upper()}')
|
print(f'DOWNLOADING {model.upper()}')
|
||||||
pipeline_for(model)
|
pipeline_for(model, cache_dir=hf_home)
|
||||||
print(f'DOWNLOADING IMAGE {model.upper()}')
|
|
||||||
pipeline_for(model, image=True)
|
|
||||||
|
|
Loading…
Reference in New Issue