mirror of https://github.com/skygpu/skynet.git
Start testing inpainting mode
parent 1e40c05da6
commit 8d35e5ed9a
@@ -20,7 +20,7 @@ def skynet(*args, **kwargs):
 
 
 @click.command()
-@click.option('--model', '-m', default='midj')
+@click.option('--model', '-m', default=list(MODELS.keys())[-1])
 @click.option(
     '--prompt', '-p', default='a red old tractor in a sunny wheat field')
 @click.option('--output', '-o', default='output.png')
@@ -39,7 +39,7 @@ def txt2img(*args, **kwargs):
     utils.txt2img(hf_token, **kwargs)
 
 @click.command()
-@click.option('--model', '-m', default=list(MODELS.keys())[0])
+@click.option('--model', '-m', default=list(MODELS.keys())[-2])
 @click.option(
     '--prompt', '-p', default='a red old tractor in a sunny wheat field')
 @click.option('--input', '-i', default='input.png')
@@ -68,7 +68,7 @@ def img2img(model, prompt, input, output, strength, guidance, steps, seed):
 
 
 @click.command()
-@click.option('--model', '-m', default=list(MODELS.keys())[-1])
+@click.option('--model', '-m', default=list(MODELS.keys())[-3])
 @click.option(
     '--prompt', '-p', default='a red old tractor in a sunny wheat field')
 @click.option('--input', '-i', default='input.png')
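Note (not part of the diff): the new CLI defaults index into the reordered MODELS dict introduced further down in this commit, so the negative indices resolve roughly as follows:

    from skynet.constants import MODELS

    keys = list(MODELS.keys())
    keys[-1]  # 'stabilityai/stable-diffusion-xl-base-1.0'          -> txt2img default
    keys[-2]  # 'prompthero/openjourney'                            -> img2img default
    keys[-3]  # 'diffusers/stable-diffusion-xl-1.0-inpainting-0.1'  -> inpaint default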
@@ -4,34 +4,108 @@ VERSION = '0.1a12'
 
 DOCKER_RUNTIME_CUDA = 'skynet:runtime-cuda'
 
-MODELS = {
-    'prompthero/openjourney': {'short': 'midj', 'mem': 6, 'size': {'w': 512, 'h': 512}},
-    'runwayml/stable-diffusion-v1-5': {'short': 'stable', 'mem': 6, 'size': {'w': 512, 'h': 512}},
-    'stabilityai/stable-diffusion-2-1-base': {'short': 'stable2', 'mem': 6, 'size': {'w': 512, 'h': 512}},
-    'snowkidy/stable-diffusion-xl-base-0.9': {'short': 'stablexl0.9', 'mem': 8.3, 'size': {'w': 1024, 'h': 1024}},
-    'Linaqruf/anything-v3.0': {'short': 'hdanime', 'mem': 6, 'size': {'w': 512, 'h': 512}},
-    'hakurei/waifu-diffusion': {'short': 'waifu', 'mem': 6, 'size': {'w': 512, 'h': 512}},
-    'nitrosocke/Ghibli-Diffusion': {'short': 'ghibli', 'mem': 6, 'size': {'w': 512, 'h': 512}},
-    'dallinmackay/Van-Gogh-diffusion': {'short': 'van-gogh', 'mem': 6, 'size': {'w': 512, 'h': 512}},
-    'lambdalabs/sd-pokemon-diffusers': {'short': 'pokemon', 'mem': 6, 'size': {'w': 512, 'h': 512}},
-    'Envvi/Inkpunk-Diffusion': {'short': 'ink', 'mem': 6, 'size': {'w': 512, 'h': 512}},
-    'nousr/robo-diffusion': {'short': 'robot', 'mem': 6, 'size': {'w': 512, 'h': 512}},
-
-    # -1 is always inpaint default
-    'diffusers/stable-diffusion-xl-1.0-inpainting-0.1': {'short': 'stablexl-inpainting', 'mem': 8.3, 'size': {'w': 1024, 'h': 1024}},
-
-    # default is always last
-    'stabilityai/stable-diffusion-xl-base-1.0': {'short': 'stablexl', 'mem': 8.3, 'size': {'w': 1024, 'h': 1024}},
+import msgspec
+from typing import Literal
+
+
+class Size(msgspec.Struct):
+    w: int
+    h: int
+
+
+class ModelDesc(msgspec.Struct):
+    short: str
+    mem: float
+    size: Size
+    tags: list[Literal['txt2img', 'img2img', 'inpaint']]
+
+
+MODELS: dict[str, ModelDesc] = {
+    'runwayml/stable-diffusion-v1-5': ModelDesc(
+        short='stable',
+        mem=6,
+        size=Size(w=512, h=512),
+        tags=['txt2img']
+    ),
+    'stabilityai/stable-diffusion-2-1-base': ModelDesc(
+        short='stable2',
+        mem=6,
+        size=Size(w=512, h=512),
+        tags=['txt2img']
+    ),
+    'snowkidy/stable-diffusion-xl-base-0.9': ModelDesc(
+        short='stablexl0.9',
+        mem=8.3,
+        size=Size(w=1024, h=1024),
+        tags=['txt2img']
+    ),
+    'Linaqruf/anything-v3.0': ModelDesc(
+        short='hdanime',
+        mem=6,
+        size=Size(w=512, h=512),
+        tags=['txt2img']
+    ),
+    'hakurei/waifu-diffusion': ModelDesc(
+        short='waifu',
+        mem=6,
+        size=Size(w=512, h=512),
+        tags=['txt2img']
+    ),
+    'nitrosocke/Ghibli-Diffusion': ModelDesc(
+        short='ghibli',
+        mem=6,
+        size=Size(w=512, h=512),
+        tags=['txt2img']
+    ),
+    'dallinmackay/Van-Gogh-diffusion': ModelDesc(
+        short='van-gogh',
+        mem=6,
+        size=Size(w=512, h=512),
+        tags=['txt2img']
+    ),
+    'lambdalabs/sd-pokemon-diffusers': ModelDesc(
+        short='pokemon',
+        mem=6,
+        size=Size(w=512, h=512),
+        tags=['txt2img']
+    ),
+    'Envvi/Inkpunk-Diffusion': ModelDesc(
+        short='ink',
+        mem=6,
+        size=Size(w=512, h=512),
+        tags=['txt2img']
+    ),
+    'nousr/robo-diffusion': ModelDesc(
+        short='robot',
+        mem=6,
+        size=Size(w=512, h=512),
+        tags=['txt2img']
+    ),
+    'diffusers/stable-diffusion-xl-1.0-inpainting-0.1': ModelDesc(
+        short='stablexl-inpainting',
+        mem=8.3,
+        size=Size(w=1024, h=1024),
+        tags=['inpaint']
+    ),
+    'prompthero/openjourney': ModelDesc(
+        short='midj',
+        mem=6,
+        size=Size(w=512, h=512),
+        tags=['txt2img', 'img2img']
+    ),
+    'stabilityai/stable-diffusion-xl-base-1.0': ModelDesc(
+        short='stablexl',
+        mem=8.3,
+        size=Size(w=1024, h=1024),
+        tags=['txt2img']
+    ),
 }
 
 SHORT_NAMES = [
-    model_info['short']
+    model_info.short
     for model_info in MODELS.values()
 ]
 
 def get_model_by_shortname(short: str):
     for model, info in MODELS.items():
-        if short == info['short']:
+        if short == info.short:
             return model
 
 N = '\n'
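For orientation, a small usage sketch (not part of the diff) of how the new msgspec structs are consumed elsewhere in this commit; attribute access replaces the old dict lookups:

    from skynet.constants import MODELS, get_model_by_shortname

    name = get_model_by_shortname('stablexl-inpainting')
    desc = MODELS[name]
    # ModelDesc fields as declared above
    print(desc.short, desc.mem, desc.size.w, desc.size.h, desc.tags)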
@@ -169,9 +243,7 @@ DEFAULT_UPSCALER = None
 
 DEFAULT_CONFIG_PATH = 'skynet.toml'
 
-DEFAULT_INITAL_MODELS = [
-    'stabilityai/stable-diffusion-xl-base-1.0'
-]
+DEFAULT_INITAL_MODEL = list(MODELS.keys())[-1]
 
 DATE_FORMAT = '%B the %dth %Y, %H:%M:%S'
 
@@ -13,7 +13,7 @@ from diffusers import DiffusionPipeline
 import trio
 import torch
 
-from skynet.constants import DEFAULT_INITAL_MODELS, MODELS
+from skynet.constants import DEFAULT_INITAL_MODEL, MODELS
 from skynet.dgpu.errors import DGPUComputeError, DGPUInferenceCancelled
 
 from skynet.utils import crop_image, convert_from_cv2_to_image, convert_from_image_to_cv2, convert_from_img_to_bytes, init_upscaler, pipeline_for
@@ -21,26 +21,34 @@ from skynet.utils import crop_image, convert_from_cv2_to_image, convert_from_image_to_cv2, convert_from_img_to_bytes, init_upscaler, pipeline_for
 
 def prepare_params_for_diffuse(
     params: dict,
-    input_type: str,
-    binary = None
+    mode: str,
+    inputs: list[bytes]
 ):
     _params = {}
-    if binary != None:
-        match input_type:
-            case 'png':
-                image = crop_image(
-                    binary, params['width'], params['height'])
-
-                _params['image'] = image
-                _params['strength'] = float(params['strength'])
-
-            case 'none':
-                ...
-
-            case _:
-                raise DGPUComputeError(f'Unknown input_type {input_type}')
-
-    else:
-        _params['width'] = int(params['width'])
-        _params['height'] = int(params['height'])
+    match mode:
+        case 'inpaint':
+            image = crop_image(
+                inputs[0], params['width'], params['height'])
+
+            mask = crop_image(
+                inputs[1], params['width'], params['height'])
+
+            _params['image'] = image
+            _params['strength'] = float(params['strength'])
+
+        case 'img2img':
+            image = crop_image(
+                inputs[0], params['width'], params['height'])
+
+            _params['image'] = image
+            _params['strength'] = float(params['strength'])
+
+        case 'txt2img':
+            ...
+
+        case _:
+            raise DGPUComputeError(f'Unknown input_type {input_type}')
+
+    _params['width'] = int(params['width'])
+    _params['height'] = int(params['height'])
 
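A rough usage sketch of the reworked helper, assuming it is imported from the module patched above (the input blobs below are illustrative; in the daemon they come from IPFS). compute_one passes the request method as mode and the fetched binaries as inputs:

    # illustrative input blobs, paths taken from the CLI defaults in this commit
    image_bytes = open('input.png', 'rb').read()
    mask_bytes = open('mask.png', 'rb').read()

    params = {'width': 1024, 'height': 1024, 'strength': '0.5'}

    # inpaint consumes two inputs: the image and its mask
    arguments = prepare_params_for_diffuse(params, 'inpaint', [image_bytes, mask_bytes])

    # txt2img needs no inputs at all
    arguments = prepare_params_for_diffuse(params, 'txt2img', [])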
@@ -58,94 +66,52 @@ class SkynetMM:
 
     def __init__(self, config: dict):
         self.upscaler = init_upscaler()
-        self.initial_models = (
-            config['initial_models']
-            if 'initial_models' in config else DEFAULT_INITAL_MODELS
-        )
 
         self.cache_dir = None
         if 'hf_home' in config:
             self.cache_dir = config['hf_home']
 
-        self._models = {}
-        for model in self.initial_models:
-            self.load_model(model, False, force=True)
+        self.load_model(DEFAULT_INITAL_MODEL, 'txt2img')
 
     def log_debug_info(self):
         logging.info('memory summary:')
         logging.info('\n' + torch.cuda.memory_summary())
 
-    def is_model_loaded(self, model_name: str, image: bool):
-        for model_key, model_data in self._models.items():
-            if (model_key == model_name and
-                model_data['image'] == image):
-                return True
+    def is_model_loaded(self, name: str, mode: str):
+        if (name == self._model_name and
+            mode == self._model_mode):
+            return True
 
         return False
 
     def load_model(
         self,
-        model_name: str,
-        image: bool,
-        force=False
+        name: str,
+        mode: str
     ):
         logging.info(f'loading model {model_name}...')
-        if force or len(self._models.keys()) == 0:
-            pipe = pipeline_for(
-                model_name, image=image, cache_dir=self.cache_dir)
-
-            self._models[model_name] = {
-                'pipe': pipe,
-                'generated': 0,
-                'image': image
-            }
-
-        else:
-            least_used = list(self._models.keys())[0]
-
-            for model in self._models:
-                if self._models[
-                    least_used]['generated'] > self._models[model]['generated']:
-                    least_used = model
-
-            del self._models[least_used]
-
-            logging.info(f'swapping model {least_used} for {model_name}...')
-
-            gc.collect()
-            torch.cuda.empty_cache()
-
-            pipe = pipeline_for(
-                model_name, image=image, cache_dir=self.cache_dir)
-
-            self._models[model_name] = {
-                'pipe': pipe,
-                'generated': 0,
-                'image': image
-            }
-
-        logging.info(f'loaded model {model_name}')
-        return pipe
-
-    def get_model(self, model_name: str, image: bool) -> DiffusionPipeline:
-        if model_name not in MODELS:
+        self._model_mode = mode
+        self._model_name = name
+
+        gc.collect()
+        torch.cuda.empty_cache()
+
+        self._model = pipeline_for(
+            name, mode, cache_dir=self.cache_dir)
+
+    def get_model(self, name: str, mode: str) -> DiffusionPipeline:
+        if name not in MODELS:
             raise DGPUComputeError(f'Unknown model {model_name}')
 
-        if not self.is_model_loaded(model_name, image):
-            pipe = self.load_model(model_name, image=image)
-
-        else:
-            pipe = self._models[model_name]['pipe']
-
-        return pipe
+        if not self.is_model_loaded(name, mode):
+            self.load_model(name, mode)
 
     def compute_one(
         self,
         request_id: int,
         method: str,
         params: dict,
-        input_type: str = 'png',
-        binary: bytes | None = None
+        inputs: list[bytes] = []
     ):
         def maybe_cancel_work(step, *args, **kwargs):
             if self._should_cancel:
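The model manager now keeps a single pipeline (self._model) instead of a dict of cached pipes; asking for a different model or mode swaps it in place. A hedged sketch of that flow, assuming a minimal config dict (the hf_home path is illustrative):

    mm = SkynetMM({'hf_home': '/tmp/hf'})  # __init__ loads DEFAULT_INITAL_MODEL in 'txt2img' mode

    # same model and mode: is_model_loaded() is True, nothing is reloaded
    mm.get_model('stabilityai/stable-diffusion-xl-base-1.0', 'txt2img')

    # different mode: gc runs and pipeline_for() builds a replacement pipe
    mm.get_model('diffusers/stable-diffusion-xl-1.0-inpainting-0.1', 'inpaint')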
@@ -164,17 +130,16 @@ class SkynetMM:
         output_hash = None
         try:
             match method:
-                case 'diffuse':
+                case 'txt2img' | 'img2img' | 'inpaint':
                     arguments = prepare_params_for_diffuse(
-                        params, input_type, binary=binary)
+                        params, method, inputs)
                     prompt, guidance, step, seed, upscaler, extra_params = arguments
-                    model = self.get_model(
+                    self.get_model(
                         params['model'],
-                        'image' in extra_params,
-                        'mask_image' in extra_params
+                        method
                     )
 
-                    output = model(
+                    output = self._model(
                         prompt,
                         guidance_scale=guidance,
                         num_inference_steps=step,
@@ -117,22 +117,7 @@ class SkynetDGPUDaemon:
 
         return app
 
-    async def serve_forever(self):
-        try:
-            while True:
-                if self.auto_withdraw:
-                    await self.conn.maybe_withdraw_all()
-
-                queue = self._snap['queue']
-
-                random.shuffle(queue)
-                queue = sorted(
-                    queue,
-                    key=lambda req: convert_reward_to_int(req['reward']),
-                    reverse=True
-                )
-
-                for req in queue:
+    async def maybe_serve_one(self, req):
         rid = req['id']
 
         # parse request
@@ -142,23 +127,26 @@ class SkynetDGPUDaemon:
         # if model not known
         if model not in MODELS:
             logging.warning(f'Unknown model {model}')
-            continue
+            return False
 
         # if whitelist enabled and model not in it continue
         if (len(self.model_whitelist) > 0 and
             not model in self.model_whitelist):
-            continue
+            return False
 
         # if blacklist contains model skip
         if model in self.model_blacklist:
-            continue
+            return False
 
         my_results = [res['id'] for res in self._snap['my_results']]
         if rid not in my_results and rid in self._snap['requests']:
             statuses = self._snap['requests'][rid]
 
             if len(statuses) == 0:
-                binary, input_type = await self.conn.get_input_data(req['binary_data'])
+                inputs = [
+                    await self.conn.get_input_data(_input)
+                    for _input in req['binary_data'].split(',')
+                ]
 
                 hash_str = (
                     str(req['nonce'])
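Requests can now carry several IPFS CIDs packed into one comma-separated binary_data field; the daemon splits it and fetches each piece. A minimal sketch mirroring the list comprehension above, with hypothetical CIDs:

    async def fetch_inputs(conn, req):
        # req['binary_data'] might be 'QmImageCid,QmMaskCid' (hypothetical CIDs)
        return [
            await conn.get_input_data(cid)
            for cid in req['binary_data'].split(',')
        ]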
@@ -195,8 +183,7 @@ class SkynetDGPUDaemon:
                             self.mm.compute_one,
                             rid,
                             body['method'], body['params'],
-                            input_type=input_type,
-                            binary=binary
+                            inputs=inputs
                         )
                     )
 
@@ -215,11 +202,30 @@ class SkynetDGPUDaemon:
                     await self.conn.cancel_work(rid, str(e))
 
                 finally:
-                    break
+                    return True
 
         else:
             logging.info(f'request {rid} already beign worked on, skip...')
 
+    async def serve_forever(self):
+        try:
+            while True:
+                if self.auto_withdraw:
+                    await self.conn.maybe_withdraw_all()
+
+                queue = self._snap['queue']
+
+                random.shuffle(queue)
+                queue = sorted(
+                    queue,
+                    key=lambda req: convert_reward_to_int(req['reward']),
+                    reverse=True
+                )
+
+                for req in queue:
+                    if (await self.maybe_serve_one(req)):
+                        break
+
                 await trio.sleep(1)
 
         except KeyboardInterrupt:
@@ -267,46 +267,15 @@ class SkynetGPUConnector:
 
         return file_cid
 
-    async def get_input_data(self, ipfs_hash: str) -> tuple[bytes, str]:
-        input_type = 'none'
-
-        if ipfs_hash == '':
-            return b'', input_type
-
-        results = {}
+    async def get_input_data(self, ipfs_hash: str) -> Image:
         ipfs_link = f'https://{self.ipfs_domain}/ipfs/{ipfs_hash}'
-        ipfs_link_legacy = ipfs_link + '/image.png'
 
-        async with trio.open_nursery() as n:
-            async def get_and_set_results(link: str):
-                res = await get_ipfs_file(link, timeout=1)
-                logging.info(f'got response from {link}')
-                if not res or res.status_code != 200:
-                    logging.warning(f'couldn\'t get ipfs binary data at {link}!')
-
-                else:
-                    try:
-                        # attempt to decode as image
-                        results[link] = Image.open(io.BytesIO(res.raw))
-                        input_type = 'png'
-                        n.cancel_scope.cancel()
-
-                    except UnidentifiedImageError:
-                        logging.warning(f'couldn\'t get ipfs binary data at {link}!')
-
-            n.start_soon(
-                get_and_set_results, ipfs_link)
-            n.start_soon(
-                get_and_set_results, ipfs_link_legacy)
-
-        input_data = None
-        if ipfs_link_legacy in results:
-            input_data = results[ipfs_link_legacy]
-
-        if ipfs_link in results:
-            input_data = results[ipfs_link]
-
-        if input_data == None:
-            raise DGPUComputeError('Couldn\'t gather input data from ipfs')
-
-        return input_data, input_type
+        res = await get_ipfs_file(link, timeout=1)
+        logging.info(f'got response from {link}')
+        if not res or res.status_code != 200:
+            logging.warning(f'couldn\'t get ipfs binary data at {link}!')
+
+        # attempt to decode as image
+        input_data = Image.open(io.BytesIO(res.raw))
+
+        return input_data
@@ -39,7 +39,7 @@ def validate_user_config_request(req: str):
             case 'model' | 'algo':
                 attr = 'model'
                 val = params[2]
-                shorts = [model_info['short'] for model_info in MODELS.values()]
+                shorts = [model_info.short for model_info in MODELS.values()]
                 if val not in shorts:
                     raise ConfigUnknownAlgorithm(f'no model named {val}')
 
@@ -112,20 +112,10 @@ def validate_user_config_request(req: str):
 
 
 def perform_auto_conf(config: dict) -> dict:
-    model = config['model']
-    prefered_size_w = 512
-    prefered_size_h = 512
-
-    if 'xl' in model:
-        prefered_size_w = 1024
-        prefered_size_h = 1024
-
-    else:
-        prefered_size_w = 512
-        prefered_size_h = 512
+    model = MODELS[config['model']]
 
     config['step'] = random.randint(20, 35)
-    config['width'] = prefered_size_w
-    config['height'] = prefered_size_h
+    config['width'] = model.size.w
+    config['height'] = model.size.h
 
     return config
 
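With ModelDesc carrying the native resolution, auto-configuration now reads it from the model entry instead of guessing from the model name. A short sketch:

    config = perform_auto_conf({'model': 'stabilityai/stable-diffusion-xl-base-1.0'})
    # config['width'] == 1024, config['height'] == 1024, config['step'] in 20..35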
@@ -116,7 +116,7 @@ class SkynetTelegramFrontend:
         method: str,
         params: dict,
         file_id: str | None = None,
-        binary_data: str = ''
+        inputs: list[str] = []
     ) -> bool:
         if params['seed'] == None:
             params['seed'] = random.randint(0, 0xFFFFFFFF)
@@ -148,7 +148,7 @@ class SkynetTelegramFrontend:
             {
                 'user': Name(self.account),
                 'request_body': body,
-                'binary_data': binary_data,
+                'binary_data': inputs.joint(','),
                 'reward': asset_from_str(reward),
                 'min_verification': 1
             },
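As submitted, inputs.joint(',') will raise AttributeError at runtime (Python lists have no joint method); given that the daemon side splits binary_data on commas, the likely intent is a comma-joined string, for example:

    inputs = ['QmImageCid', 'QmMaskCid']  # hypothetical CIDs
    binary_data = ','.join(inputs)        # -> 'QmImageCid,QmMaskCid'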
@@ -181,7 +181,7 @@ class SkynetTelegramFrontend:
         request_id, nonce = out.split(':')
 
         request_hash = sha256(
-            (nonce + body + binary_data).encode('utf-8')).hexdigest().upper()
+            (nonce + body + inputs.join(',')).encode('utf-8')).hexdigest().upper()
 
         request_id = int(request_id)
 
@@ -241,11 +241,8 @@ class SkynetTelegramFrontend:
             user, params, tx_hash, worker, reward, self.explorer_domain)
 
         # attempt to get the image and send it
-        results = {}
         ipfs_link = f'https://{self.ipfs_domain}/ipfs/{ipfs_hash}'
-        ipfs_link_legacy = ipfs_link + '/image.png'
 
-        async def get_and_set_results(link: str):
         res = await get_ipfs_file(link)
         logging.info(f'got response from {link}')
         if not res or res.status_code != 200:
@@ -264,24 +261,9 @@ class SkynetTelegramFrontend:
                 image.save(tmp_buf, format='PNG')
                 png_img = tmp_buf.getvalue()
 
-                results[link] = png_img
-
         except UnidentifiedImageError:
             logging.warning(f'couldn\'t get ipfs binary data at {link}!')
 
-        tasks = [
-            get_and_set_results(ipfs_link),
-            get_and_set_results(ipfs_link_legacy)
-        ]
-        await asyncio.gather(*tasks)
-
-        png_img = None
-        if ipfs_link_legacy in results:
-            png_img = results[ipfs_link_legacy]
-
-        if ipfs_link in results:
-            png_img = results[ipfs_link]
-
         if not png_img:
             await self.update_status_message(
                 status_msg,
@@ -66,8 +66,7 @@ def convert_from_bytes_and_crop(raw: bytes, max_w: int, max_h: int) -> Image:
 def pipeline_for(
     model: str,
     mem_fraction: float = 1.0,
-    image: bool = False,
-    inpainting: bool = False,
+    mode: str = [],
     cache_dir: str | None = None
 ) -> DiffusionPipeline:
 
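pipeline_for now takes a mode string in place of the image/inpainting flags (the mutable [] default looks like a leftover; every caller in this commit passes a mode explicitly). The call shape used by SkynetMM.load_model above:

    pipe = pipeline_for(
        'diffusers/stable-diffusion-xl-1.0-inpainting-0.1',
        'inpaint',
        cache_dir=None)  # SkynetMM passes config['hf_home'] here when set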
@@ -85,14 +84,14 @@ def pipeline_for(
 
     model_info = MODELS[model]
 
-    req_mem = model_info['mem']
+    req_mem = model_info.mem
     mem_gb = torch.cuda.mem_get_info()[1] / (10**9)
     mem_gb *= mem_fraction
     over_mem = mem_gb < req_mem
     if over_mem:
         logging.warn(f'model requires {req_mem} but card has {mem_gb}, model will run slower..')
 
-    shortname = model_info['short']
+    shortname = model_info.short
 
     params = {
         'safety_checker': None,
@@ -107,12 +106,13 @@ def pipeline_for(
 
     torch.cuda.set_per_process_memory_fraction(mem_fraction)
 
-    if inpainting:
-        pipe = AutoPipelineForInpainting.from_pretrained(
-            model, **params)
+    if 'inpaint' in mode:
+        pipe_class = AutoPipelineForInpainting
 
     else:
-        pipe = DiffusionPipeline.from_pretrained(
-            model, **params)
+        pipe_class = DiffusionPipeline
+
+    pipe = AutoPipelineForInpainting.from_pretrained(
+        model, **params)
 
     pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(
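Note that, as committed, both branches still end up calling AutoPipelineForInpainting.from_pretrained and pipe_class goes unused; the selection only takes effect if the load goes through the variable, presumably:

    pipe = pipe_class.from_pretrained(
        model, **params)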
@@ -121,7 +121,7 @@ def pipeline_for(
     pipe.enable_xformers_memory_efficient_attention()
 
     if over_mem:
-        if not image:
+        if 'img2img' not in mode:
             pipe.enable_vae_slicing()
             pipe.enable_vae_tiling()
 
@@ -140,7 +140,7 @@ def pipeline_for(
 
 def txt2img(
     hf_token: str,
-    model: str = 'prompthero/openjourney',
+    model: str = list(MODELS.keys())[-1],
     prompt: str = 'a red old tractor in a sunny wheat field',
     output: str = 'output.png',
     width: int = 512, height: int = 512,
@@ -166,7 +166,7 @@ def txt2img(
 
 def img2img(
     hf_token: str,
-    model: str = 'prompthero/openjourney',
+    model: str = list(MODELS.keys())[-2],
     prompt: str = 'a red old tractor in a sunny wheat field',
     img_path: str = 'input.png',
     output: str = 'output.png',
@@ -181,7 +181,7 @@ def img2img(
     model_info = MODELS[model]
 
     with open(img_path, 'rb') as img_file:
-        input_img = convert_from_bytes_and_crop(img_file.read(), model_info['size']['w'], model_info['size']['h'])
+        input_img = convert_from_bytes_and_crop(img_file.read(), model_info.size.w, model_info.size.h)
 
     seed = seed if seed else random.randint(0, 2 ** 64)
     prompt = prompt
@@ -198,7 +198,7 @@ def img2img(
 
 def inpaint(
     hf_token: str,
-    model: str = 'diffusers/stable-diffusion-xl-1.0-inpainting-0.1',
+    model: str = list(MODELS.keys())[-3],
     prompt: str = 'a red old tractor in a sunny wheat field',
     img_path: str = 'input.png',
     mask_path: str = 'mask.png',
@@ -214,10 +214,10 @@ def inpaint(
     model_info = MODELS[model]
 
     with open(img_path, 'rb') as img_file:
-        input_img = convert_from_bytes_and_crop(img_file.read(), model_info['size']['w'], model_info['size']['h'])
+        input_img = convert_from_bytes_and_crop(img_file.read(), model_info.size.w, model_info.size.h)
 
     with open(mask_path, 'rb') as mask_file:
-        mask_img = convert_from_bytes_and_crop(mask_file.read(), model_info['size']['w'], model_info['size']['h'])
+        mask_img = convert_from_bytes_and_crop(mask_file.read(), model_info.size.w, model_info.size.h)
 
     seed = seed if seed else random.randint(0, 2 ** 64)
     prompt = prompt
@@ -0,0 +1,98 @@
+
+from skynet.config import *
+
+async def test_txt2img():
+    req = {
+        'id': 0,
+        'body': json.dumps({
+            "method": "txt2img",
+            "params": {
+                "prompt": "Kronos God Realistic 4k",
+                "model": list(MODELS.keys())[-1],
+                "step": 21,
+                "width": 1024,
+                "height": 1024,
+                "seed": 168402949,
+                "guidance": "7.5"
+            }
+        }),
+        'inputs': [],
+    }
+
+    config = load_skynet_toml(file_path=config_path)
+    hf_token = load_key(config, 'skynet.dgpu.hf_token')
+    hf_home = load_key(config, 'skynet.dgpu.hf_home')
+    set_hf_vars(hf_token, hf_home)
+
+    assert 'skynet' in config
+    assert 'dgpu' in config['skynet']
+
+    mm = SkynetMM(config['skynet']['dgpu'])
+
+    mm.maybe_serve_one(req)
+
+
+async def test_img2img():
+    req = {
+        'id': 0,
+        'body': json.dumps({
+            "method": "img2img",
+            "params": {
+                "prompt": "Kronos God Realistic 4k",
+                "model": list(MODELS.keys())[-2],
+                "step": 21,
+                "width": 1024,
+                "height": 1024,
+                "seed": 168402949,
+                "guidance": "7.5",
+                "strength": "0.5"
+            }
+        }),
+        'inputs': ['QmZcGdXXVQfpco1G3tr2CGFBtv8xVsCwcwuq9gnJBWDymi'],
+    }
+
+    config = load_skynet_toml(file_path=config_path)
+    hf_token = load_key(config, 'skynet.dgpu.hf_token')
+    hf_home = load_key(config, 'skynet.dgpu.hf_home')
+    set_hf_vars(hf_token, hf_home)
+
+    assert 'skynet' in config
+    assert 'dgpu' in config['skynet']
+
+    mm = SkynetMM(config['skynet']['dgpu'])
+
+    mm.maybe_serve_one(req)
+
+async def test_inpaint():
+    req = {
+        'id': 0,
+        'body': json.dumps({
+            "method": "inpaint",
+            "params": {
+                "prompt": "a black panther on a sunny roof",
+                "model": list(MODELS.keys())[-3],
+                "step": 21,
+                "width": 1024,
+                "height": 1024,
+                "seed": 168402949,
+                "guidance": "7.5",
+                "strength": "0.5"
+            }
+        }),
+        'inputs': [
+            'QmZcGdXXVQfpco1G3tr2CGFBtv8xVsCwcwuq9gnJBWDymi',
+            'Qmccx1aXNmq5mZDS3YviUhgGHXWhQeHvca3AgA7MDjj2hR'
+        ],
+    }
+
+    config = load_skynet_toml(file_path=config_path)
+    hf_token = load_key(config, 'skynet.dgpu.hf_token')
+    hf_home = load_key(config, 'skynet.dgpu.hf_home')
+    set_hf_vars(hf_token, hf_home)
+
+    assert 'skynet' in config
+    assert 'dgpu' in config['skynet']
+
+    mm = SkynetMM(config['skynet']['dgpu'])
+
+    mm.maybe_serve_one(req)