Start testing inpainting mode

txt2txt
Guillermo Rodriguez 2025-01-09 21:10:07 -03:00
parent 1e40c05da6
commit 8d35e5ed9a
No known key found for this signature in database
GPG Key ID: 002CC5F1E6BDA53E
9 changed files with 384 additions and 302 deletions

View File

@ -20,7 +20,7 @@ def skynet(*args, **kwargs):
@click.command() @click.command()
@click.option('--model', '-m', default='midj') @click.option('--model', '-m', default=list(MODELS.keys())[-1])
@click.option( @click.option(
'--prompt', '-p', default='a red old tractor in a sunny wheat field') '--prompt', '-p', default='a red old tractor in a sunny wheat field')
@click.option('--output', '-o', default='output.png') @click.option('--output', '-o', default='output.png')
@ -39,7 +39,7 @@ def txt2img(*args, **kwargs):
utils.txt2img(hf_token, **kwargs) utils.txt2img(hf_token, **kwargs)
@click.command() @click.command()
@click.option('--model', '-m', default=list(MODELS.keys())[0]) @click.option('--model', '-m', default=list(MODELS.keys())[-2])
@click.option( @click.option(
'--prompt', '-p', default='a red old tractor in a sunny wheat field') '--prompt', '-p', default='a red old tractor in a sunny wheat field')
@click.option('--input', '-i', default='input.png') @click.option('--input', '-i', default='input.png')
@ -68,7 +68,7 @@ def img2img(model, prompt, input, output, strength, guidance, steps, seed):
@click.command() @click.command()
@click.option('--model', '-m', default=list(MODELS.keys())[-1]) @click.option('--model', '-m', default=list(MODELS.keys())[-3])
@click.option( @click.option(
'--prompt', '-p', default='a red old tractor in a sunny wheat field') '--prompt', '-p', default='a red old tractor in a sunny wheat field')
@click.option('--input', '-i', default='input.png') @click.option('--input', '-i', default='input.png')

View File

@ -4,34 +4,108 @@ VERSION = '0.1a12'
DOCKER_RUNTIME_CUDA = 'skynet:runtime-cuda' DOCKER_RUNTIME_CUDA = 'skynet:runtime-cuda'
MODELS = { import msgspec
'prompthero/openjourney': {'short': 'midj', 'mem': 6, 'size': {'w': 512, 'h': 512}}, from typing import Literal
'runwayml/stable-diffusion-v1-5': {'short': 'stable', 'mem': 6, 'size': {'w': 512, 'h': 512}},
'stabilityai/stable-diffusion-2-1-base': {'short': 'stable2', 'mem': 6, 'size': {'w': 512, 'h': 512}},
'snowkidy/stable-diffusion-xl-base-0.9': {'short': 'stablexl0.9', 'mem': 8.3, 'size': {'w': 1024, 'h': 1024}},
'Linaqruf/anything-v3.0': {'short': 'hdanime', 'mem': 6, 'size': {'w': 512, 'h': 512}},
'hakurei/waifu-diffusion': {'short': 'waifu', 'mem': 6, 'size': {'w': 512, 'h': 512}},
'nitrosocke/Ghibli-Diffusion': {'short': 'ghibli', 'mem': 6, 'size': {'w': 512, 'h': 512}},
'dallinmackay/Van-Gogh-diffusion': {'short': 'van-gogh', 'mem': 6, 'size': {'w': 512, 'h': 512}},
'lambdalabs/sd-pokemon-diffusers': {'short': 'pokemon', 'mem': 6, 'size': {'w': 512, 'h': 512}},
'Envvi/Inkpunk-Diffusion': {'short': 'ink', 'mem': 6, 'size': {'w': 512, 'h': 512}},
'nousr/robo-diffusion': {'short': 'robot', 'mem': 6, 'size': {'w': 512, 'h': 512}},
# -1 is always inpaint default class Size(msgspec.Struct):
'diffusers/stable-diffusion-xl-1.0-inpainting-0.1': {'short': 'stablexl-inpainting', 'mem': 8.3, 'size': {'w': 1024, 'h': 1024}}, w: int
h: int
# default is always last class ModelDesc(msgspec.Struct):
'stabilityai/stable-diffusion-xl-base-1.0': {'short': 'stablexl', 'mem': 8.3, 'size': {'w': 1024, 'h': 1024}}, short: str
mem: float
size: Size
tags: list[Literal['txt2img', 'img2img', 'inpaint']]
MODELS: dict[str, ModelDesc] = {
'runwayml/stable-diffusion-v1-5': ModelDesc(
short='stable',
mem=6,
size=Size(w=512, h=512),
tags=['txt2img']
),
'stabilityai/stable-diffusion-2-1-base': ModelDesc(
short='stable2',
mem=6,
size=Size(w=512, h=512),
tags=['txt2img']
),
'snowkidy/stable-diffusion-xl-base-0.9': ModelDesc(
short='stablexl0.9',
mem=8.3,
size=Size(w=1024, h=1024),
tags=['txt2img']
),
'Linaqruf/anything-v3.0': ModelDesc(
short='hdanime',
mem=6,
size=Size(w=512, h=512),
tags=['txt2img']
),
'hakurei/waifu-diffusion': ModelDesc(
short='waifu',
mem=6,
size=Size(w=512, h=512),
tags=['txt2img']
),
'nitrosocke/Ghibli-Diffusion': ModelDesc(
short='ghibli',
mem=6,
size=Size(w=512, h=512),
tags=['txt2img']
),
'dallinmackay/Van-Gogh-diffusion': ModelDesc(
short='van-gogh',
mem=6,
size=Size(w=512, h=512),
tags=['txt2img']
),
'lambdalabs/sd-pokemon-diffusers': ModelDesc(
short='pokemon',
mem=6,
size=Size(w=512, h=512),
tags=['txt2img']
),
'Envvi/Inkpunk-Diffusion': ModelDesc(
short='ink',
mem=6,
size=Size(w=512, h=512),
tags=['txt2img']
),
'nousr/robo-diffusion': ModelDesc(
short='robot',
mem=6,
size=Size(w=512, h=512),
tags=['txt2img']
),
'diffusers/stable-diffusion-xl-1.0-inpainting-0.1': ModelDesc(
short='stablexl-inpainting',
mem=8.3,
size=Size(w=1024, h=1024),
tags=['inpaint']
),
'prompthero/openjourney': ModelDesc(
short='midj',
mem=6,
size=Size(w=512, h=512),
tags=['txt2img', 'img2img']
),
'stabilityai/stable-diffusion-xl-base-1.0': ModelDesc(
short='stablexl',
mem=8.3,
size=Size(w=1024, h=1024),
tags=['txt2img']
),
} }
SHORT_NAMES = [ SHORT_NAMES = [
model_info['short'] model_info.short
for model_info in MODELS.values() for model_info in MODELS.values()
] ]
def get_model_by_shortname(short: str): def get_model_by_shortname(short: str):
for model, info in MODELS.items(): for model, info in MODELS.items():
if short == info['short']: if short == info.short:
return model return model
N = '\n' N = '\n'
@ -169,9 +243,7 @@ DEFAULT_UPSCALER = None
DEFAULT_CONFIG_PATH = 'skynet.toml' DEFAULT_CONFIG_PATH = 'skynet.toml'
DEFAULT_INITAL_MODELS = [ DEFAULT_INITAL_MODEL = list(MODELS.keys())[-1]
'stabilityai/stable-diffusion-xl-base-1.0'
]
DATE_FORMAT = '%B the %dth %Y, %H:%M:%S' DATE_FORMAT = '%B the %dth %Y, %H:%M:%S'

View File

@ -13,7 +13,7 @@ from diffusers import DiffusionPipeline
import trio import trio
import torch import torch
from skynet.constants import DEFAULT_INITAL_MODELS, MODELS from skynet.constants import DEFAULT_INITAL_MODEL, MODELS
from skynet.dgpu.errors import DGPUComputeError, DGPUInferenceCancelled from skynet.dgpu.errors import DGPUComputeError, DGPUInferenceCancelled
from skynet.utils import crop_image, convert_from_cv2_to_image, convert_from_image_to_cv2, convert_from_img_to_bytes, init_upscaler, pipeline_for from skynet.utils import crop_image, convert_from_cv2_to_image, convert_from_image_to_cv2, convert_from_img_to_bytes, init_upscaler, pipeline_for
@ -21,28 +21,36 @@ from skynet.utils import crop_image, convert_from_cv2_to_image, convert_from_ima
def prepare_params_for_diffuse( def prepare_params_for_diffuse(
params: dict, params: dict,
input_type: str, mode: str,
binary = None inputs: list[bytes]
): ):
_params = {} _params = {}
if binary != None: match mode:
match input_type: case 'inpaint':
case 'png': image = crop_image(
image = crop_image( inputs[0], params['width'], params['height'])
binary, params['width'], params['height'])
_params['image'] = image mask = crop_image(
_params['strength'] = float(params['strength']) inputs[1], params['width'], params['height'])
case 'none': _params['image'] = image
... _params['strength'] = float(params['strength'])
case _: case 'img2img':
raise DGPUComputeError(f'Unknown input_type {input_type}') image = crop_image(
inputs[0], params['width'], params['height'])
else: _params['image'] = image
_params['width'] = int(params['width']) _params['strength'] = float(params['strength'])
_params['height'] = int(params['height'])
case 'txt2img':
...
case _:
raise DGPUComputeError(f'Unknown input_type {input_type}')
_params['width'] = int(params['width'])
_params['height'] = int(params['height'])
return ( return (
params['prompt'], params['prompt'],
@ -58,94 +66,52 @@ class SkynetMM:
def __init__(self, config: dict): def __init__(self, config: dict):
self.upscaler = init_upscaler() self.upscaler = init_upscaler()
self.initial_models = (
config['initial_models']
if 'initial_models' in config else DEFAULT_INITAL_MODELS
)
self.cache_dir = None self.cache_dir = None
if 'hf_home' in config: if 'hf_home' in config:
self.cache_dir = config['hf_home'] self.cache_dir = config['hf_home']
self._models = {} self.load_model(DEFAULT_INITAL_MODEL, 'txt2img')
for model in self.initial_models:
self.load_model(model, False, force=True)
def log_debug_info(self): def log_debug_info(self):
logging.info('memory summary:') logging.info('memory summary:')
logging.info('\n' + torch.cuda.memory_summary()) logging.info('\n' + torch.cuda.memory_summary())
def is_model_loaded(self, model_name: str, image: bool): def is_model_loaded(self, name: str, mode: str):
for model_key, model_data in self._models.items(): if (name == self._model_name and
if (model_key == model_name and mode == self._model_mode):
model_data['image'] == image): return True
return True
return False return False
def load_model( def load_model(
self, self,
model_name: str, name: str,
image: bool, mode: str
force=False
): ):
logging.info(f'loading model {model_name}...') logging.info(f'loading model {model_name}...')
if force or len(self._models.keys()) == 0: self._model_mode = mode
pipe = pipeline_for( self._model_name = name
model_name, image=image, cache_dir=self.cache_dir)
self._models[model_name] = { gc.collect()
'pipe': pipe, torch.cuda.empty_cache()
'generated': 0,
'image': image
}
else: self._model = pipeline_for(
least_used = list(self._models.keys())[0] name, mode, cache_dir=self.cache_dir)
for model in self._models: def get_model(self, name: str, mode: str) -> DiffusionPipeline:
if self._models[ if name not in MODELS:
least_used]['generated'] > self._models[model]['generated']:
least_used = model
del self._models[least_used]
logging.info(f'swapping model {least_used} for {model_name}...')
gc.collect()
torch.cuda.empty_cache()
pipe = pipeline_for(
model_name, image=image, cache_dir=self.cache_dir)
self._models[model_name] = {
'pipe': pipe,
'generated': 0,
'image': image
}
logging.info(f'loaded model {model_name}')
return pipe
def get_model(self, model_name: str, image: bool) -> DiffusionPipeline:
if model_name not in MODELS:
raise DGPUComputeError(f'Unknown model {model_name}') raise DGPUComputeError(f'Unknown model {model_name}')
if not self.is_model_loaded(model_name, image): if not self.is_model_loaded(name, mode):
pipe = self.load_model(model_name, image=image) self.load_model(name, mode)
else:
pipe = self._models[model_name]['pipe']
return pipe
def compute_one( def compute_one(
self, self,
request_id: int, request_id: int,
method: str, method: str,
params: dict, params: dict,
input_type: str = 'png', inputs: list[bytes] = []
binary: bytes | None = None
): ):
def maybe_cancel_work(step, *args, **kwargs): def maybe_cancel_work(step, *args, **kwargs):
if self._should_cancel: if self._should_cancel:
@ -164,17 +130,16 @@ class SkynetMM:
output_hash = None output_hash = None
try: try:
match method: match method:
case 'diffuse': case 'txt2img' | 'img2img' | 'inpaint':
arguments = prepare_params_for_diffuse( arguments = prepare_params_for_diffuse(
params, input_type, binary=binary) params, method, inputs)
prompt, guidance, step, seed, upscaler, extra_params = arguments prompt, guidance, step, seed, upscaler, extra_params = arguments
model = self.get_model( self.get_model(
params['model'], params['model'],
'image' in extra_params, method
'mask_image' in extra_params
) )
output = model( output = self._model(
prompt, prompt,
guidance_scale=guidance, guidance_scale=guidance,
num_inference_steps=step, num_inference_steps=step,

View File

@ -117,6 +117,96 @@ class SkynetDGPUDaemon:
return app return app
async def maybe_serve_one(self, req):
rid = req['id']
# parse request
body = json.loads(req['body'])
model = body['params']['model']
# if model not known
if model not in MODELS:
logging.warning(f'Unknown model {model}')
return False
# if whitelist enabled and model not in it continue
if (len(self.model_whitelist) > 0 and
not model in self.model_whitelist):
return False
# if blacklist contains model skip
if model in self.model_blacklist:
return False
my_results = [res['id'] for res in self._snap['my_results']]
if rid not in my_results and rid in self._snap['requests']:
statuses = self._snap['requests'][rid]
if len(statuses) == 0:
inputs = [
await self.conn.get_input_data(_input)
for _input in req['binary_data'].split(',')
]
hash_str = (
str(req['nonce'])
+
req['body']
+
req['binary_data']
)
logging.info(f'hashing: {hash_str}')
request_hash = sha256(hash_str.encode('utf-8')).hexdigest()
# TODO: validate request
# perform work
logging.info(f'working on {body}')
resp = await self.conn.begin_work(rid)
if 'code' in resp:
logging.info(f'probably being worked on already... skip.')
else:
try:
output_type = 'png'
if 'output_type' in body['params']:
output_type = body['params']['output_type']
output = None
output_hash = None
match self.backend:
case 'sync-on-thread':
self.mm._should_cancel = self.should_cancel_work
output_hash, output = await trio.to_thread.run_sync(
partial(
self.mm.compute_one,
rid,
body['method'], body['params'],
inputs=inputs
)
)
case _:
raise DGPUComputeError(f'Unsupported backend {self.backend}')
self._last_generation_ts = datetime.now().isoformat()
self._last_benchmark = self._benchmark
self._benchmark = []
ipfs_hash = await self.conn.publish_on_ipfs(output, typ=output_type)
await self.conn.submit_work(rid, request_hash, output_hash, ipfs_hash)
except BaseException as e:
traceback.print_exc()
await self.conn.cancel_work(rid, str(e))
finally:
return True
else:
logging.info(f'request {rid} already beign worked on, skip...')
async def serve_forever(self): async def serve_forever(self):
try: try:
while True: while True:
@ -133,92 +223,8 @@ class SkynetDGPUDaemon:
) )
for req in queue: for req in queue:
rid = req['id'] if (await self.maybe_serve_one(req)):
break
# parse request
body = json.loads(req['body'])
model = body['params']['model']
# if model not known
if model not in MODELS:
logging.warning(f'Unknown model {model}')
continue
# if whitelist enabled and model not in it continue
if (len(self.model_whitelist) > 0 and
not model in self.model_whitelist):
continue
# if blacklist contains model skip
if model in self.model_blacklist:
continue
my_results = [res['id'] for res in self._snap['my_results']]
if rid not in my_results and rid in self._snap['requests']:
statuses = self._snap['requests'][rid]
if len(statuses) == 0:
binary, input_type = await self.conn.get_input_data(req['binary_data'])
hash_str = (
str(req['nonce'])
+
req['body']
+
req['binary_data']
)
logging.info(f'hashing: {hash_str}')
request_hash = sha256(hash_str.encode('utf-8')).hexdigest()
# TODO: validate request
# perform work
logging.info(f'working on {body}')
resp = await self.conn.begin_work(rid)
if 'code' in resp:
logging.info(f'probably being worked on already... skip.')
else:
try:
output_type = 'png'
if 'output_type' in body['params']:
output_type = body['params']['output_type']
output = None
output_hash = None
match self.backend:
case 'sync-on-thread':
self.mm._should_cancel = self.should_cancel_work
output_hash, output = await trio.to_thread.run_sync(
partial(
self.mm.compute_one,
rid,
body['method'], body['params'],
input_type=input_type,
binary=binary
)
)
case _:
raise DGPUComputeError(f'Unsupported backend {self.backend}')
self._last_generation_ts = datetime.now().isoformat()
self._last_benchmark = self._benchmark
self._benchmark = []
ipfs_hash = await self.conn.publish_on_ipfs(output, typ=output_type)
await self.conn.submit_work(rid, request_hash, output_hash, ipfs_hash)
except BaseException as e:
traceback.print_exc()
await self.conn.cancel_work(rid, str(e))
finally:
break
else:
logging.info(f'request {rid} already beign worked on, skip...')
await trio.sleep(1) await trio.sleep(1)

View File

@ -267,46 +267,15 @@ class SkynetGPUConnector:
return file_cid return file_cid
async def get_input_data(self, ipfs_hash: str) -> tuple[bytes, str]: async def get_input_data(self, ipfs_hash: str) -> Image:
input_type = 'none'
if ipfs_hash == '':
return b'', input_type
results = {}
ipfs_link = f'https://{self.ipfs_domain}/ipfs/{ipfs_hash}' ipfs_link = f'https://{self.ipfs_domain}/ipfs/{ipfs_hash}'
ipfs_link_legacy = ipfs_link + '/image.png'
async with trio.open_nursery() as n: res = await get_ipfs_file(link, timeout=1)
async def get_and_set_results(link: str): logging.info(f'got response from {link}')
res = await get_ipfs_file(link, timeout=1) if not res or res.status_code != 200:
logging.info(f'got response from {link}') logging.warning(f'couldn\'t get ipfs binary data at {link}!')
if not res or res.status_code != 200:
logging.warning(f'couldn\'t get ipfs binary data at {link}!')
else: # attempt to decode as image
try: input_data = Image.open(io.BytesIO(res.raw))
# attempt to decode as image
results[link] = Image.open(io.BytesIO(res.raw))
input_type = 'png'
n.cancel_scope.cancel()
except UnidentifiedImageError: return input_data
logging.warning(f'couldn\'t get ipfs binary data at {link}!')
n.start_soon(
get_and_set_results, ipfs_link)
n.start_soon(
get_and_set_results, ipfs_link_legacy)
input_data = None
if ipfs_link_legacy in results:
input_data = results[ipfs_link_legacy]
if ipfs_link in results:
input_data = results[ipfs_link]
if input_data == None:
raise DGPUComputeError('Couldn\'t gather input data from ipfs')
return input_data, input_type

View File

@ -39,7 +39,7 @@ def validate_user_config_request(req: str):
case 'model' | 'algo': case 'model' | 'algo':
attr = 'model' attr = 'model'
val = params[2] val = params[2]
shorts = [model_info['short'] for model_info in MODELS.values()] shorts = [model_info.short for model_info in MODELS.values()]
if val not in shorts: if val not in shorts:
raise ConfigUnknownAlgorithm(f'no model named {val}') raise ConfigUnknownAlgorithm(f'no model named {val}')
@ -112,20 +112,10 @@ def validate_user_config_request(req: str):
def perform_auto_conf(config: dict) -> dict: def perform_auto_conf(config: dict) -> dict:
model = config['model'] model = MODELS[config['model']]
prefered_size_w = 512
prefered_size_h = 512
if 'xl' in model:
prefered_size_w = 1024
prefered_size_h = 1024
else:
prefered_size_w = 512
prefered_size_h = 512
config['step'] = random.randint(20, 35) config['step'] = random.randint(20, 35)
config['width'] = prefered_size_w config['width'] = model.size.w
config['height'] = prefered_size_h config['height'] = model.size.h
return config return config

View File

@ -116,7 +116,7 @@ class SkynetTelegramFrontend:
method: str, method: str,
params: dict, params: dict,
file_id: str | None = None, file_id: str | None = None,
binary_data: str = '' inputs: list[str] = []
) -> bool: ) -> bool:
if params['seed'] == None: if params['seed'] == None:
params['seed'] = random.randint(0, 0xFFFFFFFF) params['seed'] = random.randint(0, 0xFFFFFFFF)
@ -148,7 +148,7 @@ class SkynetTelegramFrontend:
{ {
'user': Name(self.account), 'user': Name(self.account),
'request_body': body, 'request_body': body,
'binary_data': binary_data, 'binary_data': inputs.joint(','),
'reward': asset_from_str(reward), 'reward': asset_from_str(reward),
'min_verification': 1 'min_verification': 1
}, },
@ -181,7 +181,7 @@ class SkynetTelegramFrontend:
request_id, nonce = out.split(':') request_id, nonce = out.split(':')
request_hash = sha256( request_hash = sha256(
(nonce + body + binary_data).encode('utf-8')).hexdigest().upper() (nonce + body + inputs.join(',')).encode('utf-8')).hexdigest().upper()
request_id = int(request_id) request_id = int(request_id)
@ -241,47 +241,29 @@ class SkynetTelegramFrontend:
user, params, tx_hash, worker, reward, self.explorer_domain) user, params, tx_hash, worker, reward, self.explorer_domain)
# attempt to get the image and send it # attempt to get the image and send it
results = {}
ipfs_link = f'https://{self.ipfs_domain}/ipfs/{ipfs_hash}' ipfs_link = f'https://{self.ipfs_domain}/ipfs/{ipfs_hash}'
ipfs_link_legacy = ipfs_link + '/image.png'
async def get_and_set_results(link: str): res = await get_ipfs_file(link)
res = await get_ipfs_file(link) logging.info(f'got response from {link}')
logging.info(f'got response from {link}') if not res or res.status_code != 200:
if not res or res.status_code != 200: logging.warning(f'couldn\'t get ipfs binary data at {link}!')
else:
try:
with Image.open(io.BytesIO(res.raw)) as image:
w, h = image.size
if w > TG_MAX_WIDTH or h > TG_MAX_HEIGHT:
logging.warning(f'result is of size {image.size}')
image.thumbnail((TG_MAX_WIDTH, TG_MAX_HEIGHT))
tmp_buf = io.BytesIO()
image.save(tmp_buf, format='PNG')
png_img = tmp_buf.getvalue()
except UnidentifiedImageError:
logging.warning(f'couldn\'t get ipfs binary data at {link}!') logging.warning(f'couldn\'t get ipfs binary data at {link}!')
else:
try:
with Image.open(io.BytesIO(res.raw)) as image:
w, h = image.size
if w > TG_MAX_WIDTH or h > TG_MAX_HEIGHT:
logging.warning(f'result is of size {image.size}')
image.thumbnail((TG_MAX_WIDTH, TG_MAX_HEIGHT))
tmp_buf = io.BytesIO()
image.save(tmp_buf, format='PNG')
png_img = tmp_buf.getvalue()
results[link] = png_img
except UnidentifiedImageError:
logging.warning(f'couldn\'t get ipfs binary data at {link}!')
tasks = [
get_and_set_results(ipfs_link),
get_and_set_results(ipfs_link_legacy)
]
await asyncio.gather(*tasks)
png_img = None
if ipfs_link_legacy in results:
png_img = results[ipfs_link_legacy]
if ipfs_link in results:
png_img = results[ipfs_link]
if not png_img: if not png_img:
await self.update_status_message( await self.update_status_message(
status_msg, status_msg,

View File

@ -66,8 +66,7 @@ def convert_from_bytes_and_crop(raw: bytes, max_w: int, max_h: int) -> Image:
def pipeline_for( def pipeline_for(
model: str, model: str,
mem_fraction: float = 1.0, mem_fraction: float = 1.0,
image: bool = False, mode: str = [],
inpainting: bool = False,
cache_dir: str | None = None cache_dir: str | None = None
) -> DiffusionPipeline: ) -> DiffusionPipeline:
@ -85,14 +84,14 @@ def pipeline_for(
model_info = MODELS[model] model_info = MODELS[model]
req_mem = model_info['mem'] req_mem = model_info.mem
mem_gb = torch.cuda.mem_get_info()[1] / (10**9) mem_gb = torch.cuda.mem_get_info()[1] / (10**9)
mem_gb *= mem_fraction mem_gb *= mem_fraction
over_mem = mem_gb < req_mem over_mem = mem_gb < req_mem
if over_mem: if over_mem:
logging.warn(f'model requires {req_mem} but card has {mem_gb}, model will run slower..') logging.warn(f'model requires {req_mem} but card has {mem_gb}, model will run slower..')
shortname = model_info['short'] shortname = model_info.short
params = { params = {
'safety_checker': None, 'safety_checker': None,
@ -107,13 +106,14 @@ def pipeline_for(
torch.cuda.set_per_process_memory_fraction(mem_fraction) torch.cuda.set_per_process_memory_fraction(mem_fraction)
if inpainting: if 'inpaint' in mode:
pipe = AutoPipelineForInpainting.from_pretrained( pipe_class = AutoPipelineForInpainting
model, **params)
else: else:
pipe = DiffusionPipeline.from_pretrained( pipe_class = DiffusionPipeline
model, **params)
pipe = AutoPipelineForInpainting.from_pretrained(
model, **params)
pipe.scheduler = EulerAncestralDiscreteScheduler.from_config( pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(
pipe.scheduler.config) pipe.scheduler.config)
@ -121,7 +121,7 @@ def pipeline_for(
pipe.enable_xformers_memory_efficient_attention() pipe.enable_xformers_memory_efficient_attention()
if over_mem: if over_mem:
if not image: if 'img2img' not in mode:
pipe.enable_vae_slicing() pipe.enable_vae_slicing()
pipe.enable_vae_tiling() pipe.enable_vae_tiling()
@ -140,7 +140,7 @@ def pipeline_for(
def txt2img( def txt2img(
hf_token: str, hf_token: str,
model: str = 'prompthero/openjourney', model: str = list(MODELS.keys())[-1],
prompt: str = 'a red old tractor in a sunny wheat field', prompt: str = 'a red old tractor in a sunny wheat field',
output: str = 'output.png', output: str = 'output.png',
width: int = 512, height: int = 512, width: int = 512, height: int = 512,
@ -166,7 +166,7 @@ def txt2img(
def img2img( def img2img(
hf_token: str, hf_token: str,
model: str = 'prompthero/openjourney', model: str = list(MODELS.keys())[-2],
prompt: str = 'a red old tractor in a sunny wheat field', prompt: str = 'a red old tractor in a sunny wheat field',
img_path: str = 'input.png', img_path: str = 'input.png',
output: str = 'output.png', output: str = 'output.png',
@ -181,7 +181,7 @@ def img2img(
model_info = MODELS[model] model_info = MODELS[model]
with open(img_path, 'rb') as img_file: with open(img_path, 'rb') as img_file:
input_img = convert_from_bytes_and_crop(img_file.read(), model_info['size']['w'], model_info['size']['h']) input_img = convert_from_bytes_and_crop(img_file.read(), model_info.size.w, model_info.size.h)
seed = seed if seed else random.randint(0, 2 ** 64) seed = seed if seed else random.randint(0, 2 ** 64)
prompt = prompt prompt = prompt
@ -198,7 +198,7 @@ def img2img(
def inpaint( def inpaint(
hf_token: str, hf_token: str,
model: str = 'diffusers/stable-diffusion-xl-1.0-inpainting-0.1', model: str = list(MODELS.keys())[-3],
prompt: str = 'a red old tractor in a sunny wheat field', prompt: str = 'a red old tractor in a sunny wheat field',
img_path: str = 'input.png', img_path: str = 'input.png',
mask_path: str = 'mask.png', mask_path: str = 'mask.png',
@ -214,10 +214,10 @@ def inpaint(
model_info = MODELS[model] model_info = MODELS[model]
with open(img_path, 'rb') as img_file: with open(img_path, 'rb') as img_file:
input_img = convert_from_bytes_and_crop(img_file.read(), model_info['size']['w'], model_info['size']['h']) input_img = convert_from_bytes_and_crop(img_file.read(), model_info.size.w, model_info.size.h)
with open(mask_path, 'rb') as mask_file: with open(mask_path, 'rb') as mask_file:
mask_img = convert_from_bytes_and_crop(mask_file.read(), model_info['size']['w'], model_info['size']['h']) mask_img = convert_from_bytes_and_crop(mask_file.read(), model_info.size.w, model_info.size.h)
seed = seed if seed else random.randint(0, 2 ** 64) seed = seed if seed else random.randint(0, 2 ** 64)
prompt = prompt prompt = prompt

98
tests/test_reqs.py 100644
View File

@ -0,0 +1,98 @@
from skynet.config import *
async def test_txt2img():
req = {
'id': 0,
'body': json.dumps({
"method": "txt2img",
"params": {
"prompt": "Kronos God Realistic 4k",
"model": list(MODELS.keys())[-1],
"step": 21,
"width": 1024,
"height": 1024,
"seed": 168402949,
"guidance": "7.5"
}
}),
'inputs': [],
}
config = load_skynet_toml(file_path=config_path)
hf_token = load_key(config, 'skynet.dgpu.hf_token')
hf_home = load_key(config, 'skynet.dgpu.hf_home')
set_hf_vars(hf_token, hf_home)
assert 'skynet' in config
assert 'dgpu' in config['skynet']
mm = SkynetMM(config['skynet']['dgpu'])
mm.maybe_serve_one(req)
async def test_img2img():
req = {
'id': 0,
'body': json.dumps({
"method": "img2img",
"params": {
"prompt": "Kronos God Realistic 4k",
"model": list(MODELS.keys())[-2],
"step": 21,
"width": 1024,
"height": 1024,
"seed": 168402949,
"guidance": "7.5",
"strength": "0.5"
}
}),
'inputs': ['QmZcGdXXVQfpco1G3tr2CGFBtv8xVsCwcwuq9gnJBWDymi'],
}
config = load_skynet_toml(file_path=config_path)
hf_token = load_key(config, 'skynet.dgpu.hf_token')
hf_home = load_key(config, 'skynet.dgpu.hf_home')
set_hf_vars(hf_token, hf_home)
assert 'skynet' in config
assert 'dgpu' in config['skynet']
mm = SkynetMM(config['skynet']['dgpu'])
mm.maybe_serve_one(req)
async def test_inpaint():
req = {
'id': 0,
'body': json.dumps({
"method": "inpaint",
"params": {
"prompt": "a black panther on a sunny roof",
"model": list(MODELS.keys())[-3],
"step": 21,
"width": 1024,
"height": 1024,
"seed": 168402949,
"guidance": "7.5",
"strength": "0.5"
}
}),
'inputs': [
'QmZcGdXXVQfpco1G3tr2CGFBtv8xVsCwcwuq9gnJBWDymi',
'Qmccx1aXNmq5mZDS3YviUhgGHXWhQeHvca3AgA7MDjj2hR'
],
}
config = load_skynet_toml(file_path=config_path)
hf_token = load_key(config, 'skynet.dgpu.hf_token')
hf_home = load_key(config, 'skynet.dgpu.hf_home')
set_hf_vars(hf_token, hf_home)
assert 'skynet' in config
assert 'dgpu' in config['skynet']
mm = SkynetMM(config['skynet']['dgpu'])
mm.maybe_serve_one(req)