mirror of https://github.com/skygpu/skynet.git
Frontend db model name related fixes, and gpu worker fixes when swapping models
parent
aa41c08d2f
commit
91edb2aa56
|
@ -16,6 +16,11 @@ MODELS = {
|
|||
'nousr/robo-diffusion': { 'short': 'robot'}
|
||||
}
|
||||
|
||||
SHORT_NAMES = [
|
||||
model_info['short']
|
||||
for model_info in MODELS.values()
|
||||
]
|
||||
|
||||
def get_model_by_shortname(short: str):
|
||||
for model, info in MODELS.items():
|
||||
if short == info['short']:
|
||||
|
@ -40,8 +45,9 @@ config is individual to each user!
|
|||
/donate - see donation info
|
||||
|
||||
/config algo NAME - select AI to use one of:
|
||||
/config model NAME - select AI to use one of:
|
||||
|
||||
{N.join(MODELS.keys())}
|
||||
{N.join(SHORT_NAMES)}
|
||||
|
||||
/config step NUMBER - set amount of iterations
|
||||
/config seed NUMBER - set the seed, deterministic results!
|
||||
|
@ -114,7 +120,7 @@ DEFAULT_GUIDANCE = 7.5
|
|||
DEFAULT_STRENGTH = 0.5
|
||||
DEFAULT_STEP = 35
|
||||
DEFAULT_CREDITS = 10
|
||||
DEFAULT_ALGO = 'midj'
|
||||
DEFAULT_MODEL = list(MODELS.keys())[0]
|
||||
DEFAULT_ROLE = 'pleb'
|
||||
DEFAULT_UPSCALER = None
|
||||
|
||||
|
|
|
@ -35,7 +35,7 @@ CREATE TABLE IF NOT EXISTS skynet.user(
|
|||
|
||||
CREATE TABLE IF NOT EXISTS skynet.user_config(
|
||||
id BIGSERIAL NOT NULL,
|
||||
algo VARCHAR(128) NOT NULL,
|
||||
model VARCHAR(512) NOT NULL,
|
||||
step INT NOT NULL,
|
||||
width INT NOT NULL,
|
||||
height INT NOT NULL,
|
||||
|
@ -278,13 +278,13 @@ async def new_user(conn, uid: int):
|
|||
|
||||
stmt = await conn.prepare('''
|
||||
INSERT INTO skynet.user_config(
|
||||
id, algo, step, width, height, guidance, strength, upscaler)
|
||||
id, model, step, width, height, guidance, strength, upscaler)
|
||||
|
||||
VALUES($1, $2, $3, $4, $5, $6, $7, $8)
|
||||
''')
|
||||
resp = await stmt.fetch(
|
||||
uid,
|
||||
DEFAULT_ALGO,
|
||||
DEFAULT_MODEL,
|
||||
DEFAULT_STEP,
|
||||
DEFAULT_WIDTH,
|
||||
DEFAULT_HEIGHT,
|
||||
|
|
|
@ -6,6 +6,7 @@ import gc
|
|||
from hashlib import sha256
|
||||
import json
|
||||
import logging
|
||||
from diffusers import DiffusionPipeline
|
||||
|
||||
import torch
|
||||
from skynet.constants import DEFAULT_INITAL_MODELS, MODELS
|
||||
|
@ -107,7 +108,7 @@ class SkynetMM:
|
|||
logging.info(f'loaded model {model_name}')
|
||||
return pipe
|
||||
|
||||
def get_model(self, model_name: str, image: bool):
|
||||
def get_model(self, model_name: str, image: bool) -> DiffusionPipeline:
|
||||
if model_name not in MODELS:
|
||||
raise DGPUComputeError(f'Unknown model {model_name}')
|
||||
|
||||
|
@ -115,7 +116,7 @@ class SkynetMM:
|
|||
pipe = self.load_model(model_name, image=image)
|
||||
|
||||
else:
|
||||
pipe = self._models[model_name]
|
||||
pipe = self._models[model_name]['pipe']
|
||||
|
||||
return pipe
|
||||
|
||||
|
@ -134,7 +135,7 @@ class SkynetMM:
|
|||
prompt, guidance, step, seed, upscaler, extra_params = arguments
|
||||
model = self.get_model(params['model'], 'image' in params)
|
||||
|
||||
image = model['pipe'](
|
||||
image = model(
|
||||
prompt,
|
||||
guidance_scale=guidance,
|
||||
num_inference_steps=step,
|
||||
|
|
|
@ -42,7 +42,7 @@ class SkynetDGPUDaemon:
|
|||
if rid not in my_results:
|
||||
statuses = await self.conn.get_status_by_request_id(rid)
|
||||
|
||||
if len(statuses) < req['min_verification']:
|
||||
if len(statuses) == 0:
|
||||
|
||||
# parse request
|
||||
body = json.loads(req['body'])
|
||||
|
|
|
@ -30,11 +30,12 @@ def validate_user_config_request(req: str):
|
|||
attr = params[1]
|
||||
|
||||
match attr:
|
||||
case 'algo':
|
||||
case 'model' | 'algo':
|
||||
attr = 'model'
|
||||
val = params[2]
|
||||
shorts = [model_info['short'] for model_info in MODELS.values()]
|
||||
if val not in shorts:
|
||||
raise ConfigUnknownAlgorithm(f'no algo named {val}')
|
||||
raise ConfigUnknownAlgorithm(f'no model named {val}')
|
||||
|
||||
val = get_model_by_shortname(val)
|
||||
|
||||
|
|
|
@ -151,9 +151,6 @@ def create_handler_context(frontend: 'SkynetTelegramFrontend'):
|
|||
**user_config
|
||||
}
|
||||
|
||||
params['model'] = get_model_by_shortname(params['algo'])
|
||||
del params['algo']
|
||||
|
||||
await db_call(
|
||||
'update_user_stats', user.id, 'txt2img', last_prompt=prompt)
|
||||
|
||||
|
@ -230,8 +227,6 @@ def create_handler_context(frontend: 'SkynetTelegramFrontend'):
|
|||
'prompt': prompt,
|
||||
**user_config
|
||||
}
|
||||
params['model'] = get_model_by_shortname(params['algo'])
|
||||
del params['algo']
|
||||
|
||||
await db_call(
|
||||
'update_user_stats',
|
||||
|
@ -302,8 +297,6 @@ def create_handler_context(frontend: 'SkynetTelegramFrontend'):
|
|||
'prompt': prompt,
|
||||
**user_config
|
||||
}
|
||||
params['model'] = get_model_by_shortname(params['algo'])
|
||||
del params['algo']
|
||||
|
||||
await work_request(
|
||||
user, status_msg, 'redo', params,
|
||||
|
|
Loading…
Reference in New Issue