mirror of https://github.com/skygpu/skynet.git
Frontend db model name related fixes, and gpu worker fixes when swapping models
parent
aa41c08d2f
commit
91edb2aa56
|
@ -16,6 +16,11 @@ MODELS = {
|
||||||
'nousr/robo-diffusion': { 'short': 'robot'}
|
'nousr/robo-diffusion': { 'short': 'robot'}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
SHORT_NAMES = [
|
||||||
|
model_info['short']
|
||||||
|
for model_info in MODELS.values()
|
||||||
|
]
|
||||||
|
|
||||||
def get_model_by_shortname(short: str):
|
def get_model_by_shortname(short: str):
|
||||||
for model, info in MODELS.items():
|
for model, info in MODELS.items():
|
||||||
if short == info['short']:
|
if short == info['short']:
|
||||||
|
@ -40,8 +45,9 @@ config is individual to each user!
|
||||||
/donate - see donation info
|
/donate - see donation info
|
||||||
|
|
||||||
/config algo NAME - select AI to use one of:
|
/config algo NAME - select AI to use one of:
|
||||||
|
/config model NAME - select AI to use one of:
|
||||||
|
|
||||||
{N.join(MODELS.keys())}
|
{N.join(SHORT_NAMES)}
|
||||||
|
|
||||||
/config step NUMBER - set amount of iterations
|
/config step NUMBER - set amount of iterations
|
||||||
/config seed NUMBER - set the seed, deterministic results!
|
/config seed NUMBER - set the seed, deterministic results!
|
||||||
|
@ -114,7 +120,7 @@ DEFAULT_GUIDANCE = 7.5
|
||||||
DEFAULT_STRENGTH = 0.5
|
DEFAULT_STRENGTH = 0.5
|
||||||
DEFAULT_STEP = 35
|
DEFAULT_STEP = 35
|
||||||
DEFAULT_CREDITS = 10
|
DEFAULT_CREDITS = 10
|
||||||
DEFAULT_ALGO = 'midj'
|
DEFAULT_MODEL = list(MODELS.keys())[0]
|
||||||
DEFAULT_ROLE = 'pleb'
|
DEFAULT_ROLE = 'pleb'
|
||||||
DEFAULT_UPSCALER = None
|
DEFAULT_UPSCALER = None
|
||||||
|
|
||||||
|
|
|
@ -35,7 +35,7 @@ CREATE TABLE IF NOT EXISTS skynet.user(
|
||||||
|
|
||||||
CREATE TABLE IF NOT EXISTS skynet.user_config(
|
CREATE TABLE IF NOT EXISTS skynet.user_config(
|
||||||
id BIGSERIAL NOT NULL,
|
id BIGSERIAL NOT NULL,
|
||||||
algo VARCHAR(128) NOT NULL,
|
model VARCHAR(512) NOT NULL,
|
||||||
step INT NOT NULL,
|
step INT NOT NULL,
|
||||||
width INT NOT NULL,
|
width INT NOT NULL,
|
||||||
height INT NOT NULL,
|
height INT NOT NULL,
|
||||||
|
@ -278,13 +278,13 @@ async def new_user(conn, uid: int):
|
||||||
|
|
||||||
stmt = await conn.prepare('''
|
stmt = await conn.prepare('''
|
||||||
INSERT INTO skynet.user_config(
|
INSERT INTO skynet.user_config(
|
||||||
id, algo, step, width, height, guidance, strength, upscaler)
|
id, model, step, width, height, guidance, strength, upscaler)
|
||||||
|
|
||||||
VALUES($1, $2, $3, $4, $5, $6, $7, $8)
|
VALUES($1, $2, $3, $4, $5, $6, $7, $8)
|
||||||
''')
|
''')
|
||||||
resp = await stmt.fetch(
|
resp = await stmt.fetch(
|
||||||
uid,
|
uid,
|
||||||
DEFAULT_ALGO,
|
DEFAULT_MODEL,
|
||||||
DEFAULT_STEP,
|
DEFAULT_STEP,
|
||||||
DEFAULT_WIDTH,
|
DEFAULT_WIDTH,
|
||||||
DEFAULT_HEIGHT,
|
DEFAULT_HEIGHT,
|
||||||
|
|
|
@ -6,6 +6,7 @@ import gc
|
||||||
from hashlib import sha256
|
from hashlib import sha256
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
from diffusers import DiffusionPipeline
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from skynet.constants import DEFAULT_INITAL_MODELS, MODELS
|
from skynet.constants import DEFAULT_INITAL_MODELS, MODELS
|
||||||
|
@ -107,7 +108,7 @@ class SkynetMM:
|
||||||
logging.info(f'loaded model {model_name}')
|
logging.info(f'loaded model {model_name}')
|
||||||
return pipe
|
return pipe
|
||||||
|
|
||||||
def get_model(self, model_name: str, image: bool):
|
def get_model(self, model_name: str, image: bool) -> DiffusionPipeline:
|
||||||
if model_name not in MODELS:
|
if model_name not in MODELS:
|
||||||
raise DGPUComputeError(f'Unknown model {model_name}')
|
raise DGPUComputeError(f'Unknown model {model_name}')
|
||||||
|
|
||||||
|
@ -115,7 +116,7 @@ class SkynetMM:
|
||||||
pipe = self.load_model(model_name, image=image)
|
pipe = self.load_model(model_name, image=image)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
pipe = self._models[model_name]
|
pipe = self._models[model_name]['pipe']
|
||||||
|
|
||||||
return pipe
|
return pipe
|
||||||
|
|
||||||
|
@ -134,7 +135,7 @@ class SkynetMM:
|
||||||
prompt, guidance, step, seed, upscaler, extra_params = arguments
|
prompt, guidance, step, seed, upscaler, extra_params = arguments
|
||||||
model = self.get_model(params['model'], 'image' in params)
|
model = self.get_model(params['model'], 'image' in params)
|
||||||
|
|
||||||
image = model['pipe'](
|
image = model(
|
||||||
prompt,
|
prompt,
|
||||||
guidance_scale=guidance,
|
guidance_scale=guidance,
|
||||||
num_inference_steps=step,
|
num_inference_steps=step,
|
||||||
|
|
|
@ -42,7 +42,7 @@ class SkynetDGPUDaemon:
|
||||||
if rid not in my_results:
|
if rid not in my_results:
|
||||||
statuses = await self.conn.get_status_by_request_id(rid)
|
statuses = await self.conn.get_status_by_request_id(rid)
|
||||||
|
|
||||||
if len(statuses) < req['min_verification']:
|
if len(statuses) == 0:
|
||||||
|
|
||||||
# parse request
|
# parse request
|
||||||
body = json.loads(req['body'])
|
body = json.loads(req['body'])
|
||||||
|
|
|
@ -30,11 +30,12 @@ def validate_user_config_request(req: str):
|
||||||
attr = params[1]
|
attr = params[1]
|
||||||
|
|
||||||
match attr:
|
match attr:
|
||||||
case 'algo':
|
case 'model' | 'algo':
|
||||||
|
attr = 'model'
|
||||||
val = params[2]
|
val = params[2]
|
||||||
shorts = [model_info['short'] for model_info in MODELS.values()]
|
shorts = [model_info['short'] for model_info in MODELS.values()]
|
||||||
if val not in shorts:
|
if val not in shorts:
|
||||||
raise ConfigUnknownAlgorithm(f'no algo named {val}')
|
raise ConfigUnknownAlgorithm(f'no model named {val}')
|
||||||
|
|
||||||
val = get_model_by_shortname(val)
|
val = get_model_by_shortname(val)
|
||||||
|
|
||||||
|
|
|
@ -151,9 +151,6 @@ def create_handler_context(frontend: 'SkynetTelegramFrontend'):
|
||||||
**user_config
|
**user_config
|
||||||
}
|
}
|
||||||
|
|
||||||
params['model'] = get_model_by_shortname(params['algo'])
|
|
||||||
del params['algo']
|
|
||||||
|
|
||||||
await db_call(
|
await db_call(
|
||||||
'update_user_stats', user.id, 'txt2img', last_prompt=prompt)
|
'update_user_stats', user.id, 'txt2img', last_prompt=prompt)
|
||||||
|
|
||||||
|
@ -230,8 +227,6 @@ def create_handler_context(frontend: 'SkynetTelegramFrontend'):
|
||||||
'prompt': prompt,
|
'prompt': prompt,
|
||||||
**user_config
|
**user_config
|
||||||
}
|
}
|
||||||
params['model'] = get_model_by_shortname(params['algo'])
|
|
||||||
del params['algo']
|
|
||||||
|
|
||||||
await db_call(
|
await db_call(
|
||||||
'update_user_stats',
|
'update_user_stats',
|
||||||
|
@ -302,8 +297,6 @@ def create_handler_context(frontend: 'SkynetTelegramFrontend'):
|
||||||
'prompt': prompt,
|
'prompt': prompt,
|
||||||
**user_config
|
**user_config
|
||||||
}
|
}
|
||||||
params['model'] = get_model_by_shortname(params['algo'])
|
|
||||||
del params['algo']
|
|
||||||
|
|
||||||
await work_request(
|
await work_request(
|
||||||
user, status_msg, 'redo', params,
|
user, status_msg, 'redo', params,
|
||||||
|
|
Loading…
Reference in New Issue