From 91edb2aa56a718d10c884e5f635fd3d491403b6d Mon Sep 17 00:00:00 2001 From: Guillermo Rodriguez Date: Tue, 6 Jun 2023 12:27:40 -0300 Subject: [PATCH] Frontend db model name related fixes, and gpu worker fixes when swapping models --- skynet/constants.py | 10 ++++++++-- skynet/db/functions.py | 6 +++--- skynet/dgpu/compute.py | 7 ++++--- skynet/dgpu/daemon.py | 2 +- skynet/frontend/__init__.py | 5 +++-- skynet/frontend/telegram/handlers.py | 7 ------- 6 files changed, 19 insertions(+), 18 deletions(-) diff --git a/skynet/constants.py b/skynet/constants.py index 7e41d7a..a1e13f4 100644 --- a/skynet/constants.py +++ b/skynet/constants.py @@ -16,6 +16,11 @@ MODELS = { 'nousr/robo-diffusion': { 'short': 'robot'} } +SHORT_NAMES = [ + model_info['short'] + for model_info in MODELS.values() +] + def get_model_by_shortname(short: str): for model, info in MODELS.items(): if short == info['short']: @@ -40,8 +45,9 @@ config is individual to each user! /donate - see donation info /config algo NAME - select AI to use one of: +/config model NAME - select AI to use one of: -{N.join(MODELS.keys())} +{N.join(SHORT_NAMES)} /config step NUMBER - set amount of iterations /config seed NUMBER - set the seed, deterministic results! @@ -114,7 +120,7 @@ DEFAULT_GUIDANCE = 7.5 DEFAULT_STRENGTH = 0.5 DEFAULT_STEP = 35 DEFAULT_CREDITS = 10 -DEFAULT_ALGO = 'midj' +DEFAULT_MODEL = list(MODELS.keys())[0] DEFAULT_ROLE = 'pleb' DEFAULT_UPSCALER = None diff --git a/skynet/db/functions.py b/skynet/db/functions.py index c97bcf5..ac75a97 100644 --- a/skynet/db/functions.py +++ b/skynet/db/functions.py @@ -35,7 +35,7 @@ CREATE TABLE IF NOT EXISTS skynet.user( CREATE TABLE IF NOT EXISTS skynet.user_config( id BIGSERIAL NOT NULL, - algo VARCHAR(128) NOT NULL, + model VARCHAR(512) NOT NULL, step INT NOT NULL, width INT NOT NULL, height INT NOT NULL, @@ -278,13 +278,13 @@ async def new_user(conn, uid: int): stmt = await conn.prepare(''' INSERT INTO skynet.user_config( - id, algo, step, width, height, guidance, strength, upscaler) + id, model, step, width, height, guidance, strength, upscaler) VALUES($1, $2, $3, $4, $5, $6, $7, $8) ''') resp = await stmt.fetch( uid, - DEFAULT_ALGO, + DEFAULT_MODEL, DEFAULT_STEP, DEFAULT_WIDTH, DEFAULT_HEIGHT, diff --git a/skynet/dgpu/compute.py b/skynet/dgpu/compute.py index ce34910..a51072d 100644 --- a/skynet/dgpu/compute.py +++ b/skynet/dgpu/compute.py @@ -6,6 +6,7 @@ import gc from hashlib import sha256 import json import logging +from diffusers import DiffusionPipeline import torch from skynet.constants import DEFAULT_INITAL_MODELS, MODELS @@ -107,7 +108,7 @@ class SkynetMM: logging.info(f'loaded model {model_name}') return pipe - def get_model(self, model_name: str, image: bool): + def get_model(self, model_name: str, image: bool) -> DiffusionPipeline: if model_name not in MODELS: raise DGPUComputeError(f'Unknown model {model_name}') @@ -115,7 +116,7 @@ class SkynetMM: pipe = self.load_model(model_name, image=image) else: - pipe = self._models[model_name] + pipe = self._models[model_name]['pipe'] return pipe @@ -134,7 +135,7 @@ class SkynetMM: prompt, guidance, step, seed, upscaler, extra_params = arguments model = self.get_model(params['model'], 'image' in params) - image = model['pipe']( + image = model( prompt, guidance_scale=guidance, num_inference_steps=step, diff --git a/skynet/dgpu/daemon.py b/skynet/dgpu/daemon.py index 4897f43..216a800 100644 --- a/skynet/dgpu/daemon.py +++ b/skynet/dgpu/daemon.py @@ -42,7 +42,7 @@ class SkynetDGPUDaemon: if rid not in my_results: statuses = await self.conn.get_status_by_request_id(rid) - if len(statuses) < req['min_verification']: + if len(statuses) == 0: # parse request body = json.loads(req['body']) diff --git a/skynet/frontend/__init__.py b/skynet/frontend/__init__.py index 46ebd9f..42145a3 100644 --- a/skynet/frontend/__init__.py +++ b/skynet/frontend/__init__.py @@ -30,11 +30,12 @@ def validate_user_config_request(req: str): attr = params[1] match attr: - case 'algo': + case 'model' | 'algo': + attr = 'model' val = params[2] shorts = [model_info['short'] for model_info in MODELS.values()] if val not in shorts: - raise ConfigUnknownAlgorithm(f'no algo named {val}') + raise ConfigUnknownAlgorithm(f'no model named {val}') val = get_model_by_shortname(val) diff --git a/skynet/frontend/telegram/handlers.py b/skynet/frontend/telegram/handlers.py index 17d3213..7e77880 100644 --- a/skynet/frontend/telegram/handlers.py +++ b/skynet/frontend/telegram/handlers.py @@ -151,9 +151,6 @@ def create_handler_context(frontend: 'SkynetTelegramFrontend'): **user_config } - params['model'] = get_model_by_shortname(params['algo']) - del params['algo'] - await db_call( 'update_user_stats', user.id, 'txt2img', last_prompt=prompt) @@ -230,8 +227,6 @@ def create_handler_context(frontend: 'SkynetTelegramFrontend'): 'prompt': prompt, **user_config } - params['model'] = get_model_by_shortname(params['algo']) - del params['algo'] await db_call( 'update_user_stats', @@ -302,8 +297,6 @@ def create_handler_context(frontend: 'SkynetTelegramFrontend'): 'prompt': prompt, **user_config } - params['model'] = get_model_by_shortname(params['algo']) - del params['algo'] await work_request( user, status_msg, 'redo', params,