Frontend db model name related fixes, and gpu worker fixes when swapping models

add-txt2txt-models
Guillermo Rodriguez 2023-06-06 12:27:40 -03:00
parent aa41c08d2f
commit 91edb2aa56
No known key found for this signature in database
GPG Key ID: EC3AB66D5D83B392
6 changed files with 19 additions and 18 deletions

View File

@ -16,6 +16,11 @@ MODELS = {
'nousr/robo-diffusion': { 'short': 'robot'} 'nousr/robo-diffusion': { 'short': 'robot'}
} }
SHORT_NAMES = [
model_info['short']
for model_info in MODELS.values()
]
def get_model_by_shortname(short: str): def get_model_by_shortname(short: str):
for model, info in MODELS.items(): for model, info in MODELS.items():
if short == info['short']: if short == info['short']:
@ -40,8 +45,9 @@ config is individual to each user!
/donate - see donation info /donate - see donation info
/config algo NAME - select AI to use one of: /config algo NAME - select AI to use one of:
/config model NAME - select AI to use one of:
{N.join(MODELS.keys())} {N.join(SHORT_NAMES)}
/config step NUMBER - set amount of iterations /config step NUMBER - set amount of iterations
/config seed NUMBER - set the seed, deterministic results! /config seed NUMBER - set the seed, deterministic results!
@ -114,7 +120,7 @@ DEFAULT_GUIDANCE = 7.5
DEFAULT_STRENGTH = 0.5 DEFAULT_STRENGTH = 0.5
DEFAULT_STEP = 35 DEFAULT_STEP = 35
DEFAULT_CREDITS = 10 DEFAULT_CREDITS = 10
DEFAULT_ALGO = 'midj' DEFAULT_MODEL = list(MODELS.keys())[0]
DEFAULT_ROLE = 'pleb' DEFAULT_ROLE = 'pleb'
DEFAULT_UPSCALER = None DEFAULT_UPSCALER = None

View File

@ -35,7 +35,7 @@ CREATE TABLE IF NOT EXISTS skynet.user(
CREATE TABLE IF NOT EXISTS skynet.user_config( CREATE TABLE IF NOT EXISTS skynet.user_config(
id BIGSERIAL NOT NULL, id BIGSERIAL NOT NULL,
algo VARCHAR(128) NOT NULL, model VARCHAR(512) NOT NULL,
step INT NOT NULL, step INT NOT NULL,
width INT NOT NULL, width INT NOT NULL,
height INT NOT NULL, height INT NOT NULL,
@ -278,13 +278,13 @@ async def new_user(conn, uid: int):
stmt = await conn.prepare(''' stmt = await conn.prepare('''
INSERT INTO skynet.user_config( INSERT INTO skynet.user_config(
id, algo, step, width, height, guidance, strength, upscaler) id, model, step, width, height, guidance, strength, upscaler)
VALUES($1, $2, $3, $4, $5, $6, $7, $8) VALUES($1, $2, $3, $4, $5, $6, $7, $8)
''') ''')
resp = await stmt.fetch( resp = await stmt.fetch(
uid, uid,
DEFAULT_ALGO, DEFAULT_MODEL,
DEFAULT_STEP, DEFAULT_STEP,
DEFAULT_WIDTH, DEFAULT_WIDTH,
DEFAULT_HEIGHT, DEFAULT_HEIGHT,

View File

@ -6,6 +6,7 @@ import gc
from hashlib import sha256 from hashlib import sha256
import json import json
import logging import logging
from diffusers import DiffusionPipeline
import torch import torch
from skynet.constants import DEFAULT_INITAL_MODELS, MODELS from skynet.constants import DEFAULT_INITAL_MODELS, MODELS
@ -107,7 +108,7 @@ class SkynetMM:
logging.info(f'loaded model {model_name}') logging.info(f'loaded model {model_name}')
return pipe return pipe
def get_model(self, model_name: str, image: bool): def get_model(self, model_name: str, image: bool) -> DiffusionPipeline:
if model_name not in MODELS: if model_name not in MODELS:
raise DGPUComputeError(f'Unknown model {model_name}') raise DGPUComputeError(f'Unknown model {model_name}')
@ -115,7 +116,7 @@ class SkynetMM:
pipe = self.load_model(model_name, image=image) pipe = self.load_model(model_name, image=image)
else: else:
pipe = self._models[model_name] pipe = self._models[model_name]['pipe']
return pipe return pipe
@ -134,7 +135,7 @@ class SkynetMM:
prompt, guidance, step, seed, upscaler, extra_params = arguments prompt, guidance, step, seed, upscaler, extra_params = arguments
model = self.get_model(params['model'], 'image' in params) model = self.get_model(params['model'], 'image' in params)
image = model['pipe']( image = model(
prompt, prompt,
guidance_scale=guidance, guidance_scale=guidance,
num_inference_steps=step, num_inference_steps=step,

View File

@ -42,7 +42,7 @@ class SkynetDGPUDaemon:
if rid not in my_results: if rid not in my_results:
statuses = await self.conn.get_status_by_request_id(rid) statuses = await self.conn.get_status_by_request_id(rid)
if len(statuses) < req['min_verification']: if len(statuses) == 0:
# parse request # parse request
body = json.loads(req['body']) body = json.loads(req['body'])

View File

@ -30,11 +30,12 @@ def validate_user_config_request(req: str):
attr = params[1] attr = params[1]
match attr: match attr:
case 'algo': case 'model' | 'algo':
attr = 'model'
val = params[2] val = params[2]
shorts = [model_info['short'] for model_info in MODELS.values()] shorts = [model_info['short'] for model_info in MODELS.values()]
if val not in shorts: if val not in shorts:
raise ConfigUnknownAlgorithm(f'no algo named {val}') raise ConfigUnknownAlgorithm(f'no model named {val}')
val = get_model_by_shortname(val) val = get_model_by_shortname(val)

View File

@ -151,9 +151,6 @@ def create_handler_context(frontend: 'SkynetTelegramFrontend'):
**user_config **user_config
} }
params['model'] = get_model_by_shortname(params['algo'])
del params['algo']
await db_call( await db_call(
'update_user_stats', user.id, 'txt2img', last_prompt=prompt) 'update_user_stats', user.id, 'txt2img', last_prompt=prompt)
@ -230,8 +227,6 @@ def create_handler_context(frontend: 'SkynetTelegramFrontend'):
'prompt': prompt, 'prompt': prompt,
**user_config **user_config
} }
params['model'] = get_model_by_shortname(params['algo'])
del params['algo']
await db_call( await db_call(
'update_user_stats', 'update_user_stats',
@ -302,8 +297,6 @@ def create_handler_context(frontend: 'SkynetTelegramFrontend'):
'prompt': prompt, 'prompt': prompt,
**user_config **user_config
} }
params['model'] = get_model_by_shortname(params['algo'])
del params['algo']
await work_request( await work_request(
user, status_msg, 'redo', params, user, status_msg, 'redo', params,