Frontend db model name related fixes, and gpu worker fixes when swapping models

add-txt2txt-models
Guillermo Rodriguez 2023-06-06 12:27:40 -03:00
parent aa41c08d2f
commit 91edb2aa56
No known key found for this signature in database
GPG Key ID: EC3AB66D5D83B392
6 changed files with 19 additions and 18 deletions

View File

@ -16,6 +16,11 @@ MODELS = {
'nousr/robo-diffusion': { 'short': 'robot'}
}
SHORT_NAMES = [
model_info['short']
for model_info in MODELS.values()
]
def get_model_by_shortname(short: str):
for model, info in MODELS.items():
if short == info['short']:
@ -40,8 +45,9 @@ config is individual to each user!
/donate - see donation info
/config algo NAME - select AI to use one of:
/config model NAME - select AI to use one of:
{N.join(MODELS.keys())}
{N.join(SHORT_NAMES)}
/config step NUMBER - set amount of iterations
/config seed NUMBER - set the seed, deterministic results!
@ -114,7 +120,7 @@ DEFAULT_GUIDANCE = 7.5
DEFAULT_STRENGTH = 0.5
DEFAULT_STEP = 35
DEFAULT_CREDITS = 10
DEFAULT_ALGO = 'midj'
DEFAULT_MODEL = list(MODELS.keys())[0]
DEFAULT_ROLE = 'pleb'
DEFAULT_UPSCALER = None

View File

@ -35,7 +35,7 @@ CREATE TABLE IF NOT EXISTS skynet.user(
CREATE TABLE IF NOT EXISTS skynet.user_config(
id BIGSERIAL NOT NULL,
algo VARCHAR(128) NOT NULL,
model VARCHAR(512) NOT NULL,
step INT NOT NULL,
width INT NOT NULL,
height INT NOT NULL,
@ -278,13 +278,13 @@ async def new_user(conn, uid: int):
stmt = await conn.prepare('''
INSERT INTO skynet.user_config(
id, algo, step, width, height, guidance, strength, upscaler)
id, model, step, width, height, guidance, strength, upscaler)
VALUES($1, $2, $3, $4, $5, $6, $7, $8)
''')
resp = await stmt.fetch(
uid,
DEFAULT_ALGO,
DEFAULT_MODEL,
DEFAULT_STEP,
DEFAULT_WIDTH,
DEFAULT_HEIGHT,

View File

@ -6,6 +6,7 @@ import gc
from hashlib import sha256
import json
import logging
from diffusers import DiffusionPipeline
import torch
from skynet.constants import DEFAULT_INITAL_MODELS, MODELS
@ -107,7 +108,7 @@ class SkynetMM:
logging.info(f'loaded model {model_name}')
return pipe
def get_model(self, model_name: str, image: bool):
def get_model(self, model_name: str, image: bool) -> DiffusionPipeline:
if model_name not in MODELS:
raise DGPUComputeError(f'Unknown model {model_name}')
@ -115,7 +116,7 @@ class SkynetMM:
pipe = self.load_model(model_name, image=image)
else:
pipe = self._models[model_name]
pipe = self._models[model_name]['pipe']
return pipe
@ -134,7 +135,7 @@ class SkynetMM:
prompt, guidance, step, seed, upscaler, extra_params = arguments
model = self.get_model(params['model'], 'image' in params)
image = model['pipe'](
image = model(
prompt,
guidance_scale=guidance,
num_inference_steps=step,

View File

@ -42,7 +42,7 @@ class SkynetDGPUDaemon:
if rid not in my_results:
statuses = await self.conn.get_status_by_request_id(rid)
if len(statuses) < req['min_verification']:
if len(statuses) == 0:
# parse request
body = json.loads(req['body'])

View File

@ -30,11 +30,12 @@ def validate_user_config_request(req: str):
attr = params[1]
match attr:
case 'algo':
case 'model' | 'algo':
attr = 'model'
val = params[2]
shorts = [model_info['short'] for model_info in MODELS.values()]
if val not in shorts:
raise ConfigUnknownAlgorithm(f'no algo named {val}')
raise ConfigUnknownAlgorithm(f'no model named {val}')
val = get_model_by_shortname(val)

View File

@ -151,9 +151,6 @@ def create_handler_context(frontend: 'SkynetTelegramFrontend'):
**user_config
}
params['model'] = get_model_by_shortname(params['algo'])
del params['algo']
await db_call(
'update_user_stats', user.id, 'txt2img', last_prompt=prompt)
@ -230,8 +227,6 @@ def create_handler_context(frontend: 'SkynetTelegramFrontend'):
'prompt': prompt,
**user_config
}
params['model'] = get_model_by_shortname(params['algo'])
del params['algo']
await db_call(
'update_user_stats',
@ -302,8 +297,6 @@ def create_handler_context(frontend: 'SkynetTelegramFrontend'):
'prompt': prompt,
**user_config
}
params['model'] = get_model_by_shortname(params['algo'])
del params['algo']
await work_request(
user, status_msg, 'redo', params,