mirror of https://github.com/skygpu/skynet.git
				
				
				
			Frontend db model name related fixes, and gpu worker fixes when swapping models
							parent
							
								
									aa41c08d2f
								
							
						
					
					
						commit
						91edb2aa56
					
				| 
						 | 
				
			
			@ -16,6 +16,11 @@ MODELS = {
 | 
			
		|||
    'nousr/robo-diffusion':            { 'short': 'robot'}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
SHORT_NAMES = [
 | 
			
		||||
    model_info['short']
 | 
			
		||||
    for model_info in MODELS.values()
 | 
			
		||||
]
 | 
			
		||||
 | 
			
		||||
def get_model_by_shortname(short: str):
 | 
			
		||||
    for model, info in MODELS.items():
 | 
			
		||||
        if short == info['short']:
 | 
			
		||||
| 
						 | 
				
			
			@ -40,8 +45,9 @@ config is individual to each user!
 | 
			
		|||
/donate - see donation info
 | 
			
		||||
 | 
			
		||||
/config algo NAME - select AI to use one of:
 | 
			
		||||
/config model NAME - select AI to use one of:
 | 
			
		||||
 | 
			
		||||
{N.join(MODELS.keys())}
 | 
			
		||||
{N.join(SHORT_NAMES)}
 | 
			
		||||
 | 
			
		||||
/config step NUMBER - set amount of iterations
 | 
			
		||||
/config seed NUMBER - set the seed, deterministic results!
 | 
			
		||||
| 
						 | 
				
			
			@ -114,7 +120,7 @@ DEFAULT_GUIDANCE = 7.5
 | 
			
		|||
DEFAULT_STRENGTH = 0.5
 | 
			
		||||
DEFAULT_STEP = 35
 | 
			
		||||
DEFAULT_CREDITS = 10
 | 
			
		||||
DEFAULT_ALGO = 'midj'
 | 
			
		||||
DEFAULT_MODEL = list(MODELS.keys())[0]
 | 
			
		||||
DEFAULT_ROLE = 'pleb'
 | 
			
		||||
DEFAULT_UPSCALER = None
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -35,7 +35,7 @@ CREATE TABLE IF NOT EXISTS skynet.user(
 | 
			
		|||
 | 
			
		||||
CREATE TABLE IF NOT EXISTS skynet.user_config(
 | 
			
		||||
    id BIGSERIAL NOT NULL,
 | 
			
		||||
    algo VARCHAR(128) NOT NULL,
 | 
			
		||||
    model VARCHAR(512) NOT NULL,
 | 
			
		||||
    step INT NOT NULL,
 | 
			
		||||
    width INT NOT NULL,
 | 
			
		||||
    height INT NOT NULL,
 | 
			
		||||
| 
						 | 
				
			
			@ -278,13 +278,13 @@ async def new_user(conn, uid: int):
 | 
			
		|||
 | 
			
		||||
        stmt = await conn.prepare('''
 | 
			
		||||
            INSERT INTO skynet.user_config(
 | 
			
		||||
                id, algo, step, width, height, guidance, strength, upscaler)
 | 
			
		||||
                id, model, step, width, height, guidance, strength, upscaler)
 | 
			
		||||
 | 
			
		||||
            VALUES($1, $2, $3, $4, $5, $6, $7, $8)
 | 
			
		||||
        ''')
 | 
			
		||||
        resp = await stmt.fetch(
 | 
			
		||||
            uid,
 | 
			
		||||
            DEFAULT_ALGO,
 | 
			
		||||
            DEFAULT_MODEL,
 | 
			
		||||
            DEFAULT_STEP,
 | 
			
		||||
            DEFAULT_WIDTH,
 | 
			
		||||
            DEFAULT_HEIGHT,
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -6,6 +6,7 @@ import gc
 | 
			
		|||
from hashlib import sha256
 | 
			
		||||
import json
 | 
			
		||||
import logging
 | 
			
		||||
from diffusers import DiffusionPipeline
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
from skynet.constants import DEFAULT_INITAL_MODELS, MODELS
 | 
			
		||||
| 
						 | 
				
			
			@ -107,7 +108,7 @@ class SkynetMM:
 | 
			
		|||
        logging.info(f'loaded model {model_name}')
 | 
			
		||||
        return pipe
 | 
			
		||||
 | 
			
		||||
    def get_model(self, model_name: str, image: bool):
 | 
			
		||||
    def get_model(self, model_name: str, image: bool) -> DiffusionPipeline:
 | 
			
		||||
        if model_name not in MODELS:
 | 
			
		||||
            raise DGPUComputeError(f'Unknown model {model_name}')
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -115,7 +116,7 @@ class SkynetMM:
 | 
			
		|||
            pipe = self.load_model(model_name, image=image)
 | 
			
		||||
 | 
			
		||||
        else:
 | 
			
		||||
            pipe = self._models[model_name]
 | 
			
		||||
            pipe = self._models[model_name]['pipe']
 | 
			
		||||
 | 
			
		||||
        return pipe
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -134,7 +135,7 @@ class SkynetMM:
 | 
			
		|||
                    prompt, guidance, step, seed, upscaler, extra_params = arguments
 | 
			
		||||
                    model = self.get_model(params['model'], 'image' in params)
 | 
			
		||||
 | 
			
		||||
                    image = model['pipe'](
 | 
			
		||||
                    image = model(
 | 
			
		||||
                        prompt,
 | 
			
		||||
                        guidance_scale=guidance,
 | 
			
		||||
                        num_inference_steps=step,
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -42,7 +42,7 @@ class SkynetDGPUDaemon:
 | 
			
		|||
                    if rid not in my_results:
 | 
			
		||||
                        statuses = await self.conn.get_status_by_request_id(rid)
 | 
			
		||||
 | 
			
		||||
                        if len(statuses) < req['min_verification']:
 | 
			
		||||
                        if len(statuses) == 0:
 | 
			
		||||
 | 
			
		||||
                            # parse request
 | 
			
		||||
                            body = json.loads(req['body'])
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -30,11 +30,12 @@ def validate_user_config_request(req: str):
 | 
			
		|||
            attr = params[1]
 | 
			
		||||
 | 
			
		||||
            match attr:
 | 
			
		||||
                case 'algo':
 | 
			
		||||
                case 'model' | 'algo':
 | 
			
		||||
                    attr = 'model'
 | 
			
		||||
                    val = params[2]
 | 
			
		||||
                    shorts = [model_info['short'] for model_info in MODELS.values()]
 | 
			
		||||
                    if val not in shorts:
 | 
			
		||||
                        raise ConfigUnknownAlgorithm(f'no algo named {val}')
 | 
			
		||||
                        raise ConfigUnknownAlgorithm(f'no model named {val}')
 | 
			
		||||
 | 
			
		||||
                    val = get_model_by_shortname(val)
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -151,9 +151,6 @@ def create_handler_context(frontend: 'SkynetTelegramFrontend'):
 | 
			
		|||
            **user_config
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        params['model'] = get_model_by_shortname(params['algo'])
 | 
			
		||||
        del params['algo']
 | 
			
		||||
 | 
			
		||||
        await db_call(
 | 
			
		||||
            'update_user_stats', user.id, 'txt2img', last_prompt=prompt)
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -230,8 +227,6 @@ def create_handler_context(frontend: 'SkynetTelegramFrontend'):
 | 
			
		|||
            'prompt': prompt,
 | 
			
		||||
            **user_config
 | 
			
		||||
        }
 | 
			
		||||
        params['model'] = get_model_by_shortname(params['algo'])
 | 
			
		||||
        del params['algo']
 | 
			
		||||
 | 
			
		||||
        await db_call(
 | 
			
		||||
            'update_user_stats',
 | 
			
		||||
| 
						 | 
				
			
			@ -302,8 +297,6 @@ def create_handler_context(frontend: 'SkynetTelegramFrontend'):
 | 
			
		|||
            'prompt': prompt,
 | 
			
		||||
            **user_config
 | 
			
		||||
        }
 | 
			
		||||
        params['model'] = get_model_by_shortname(params['algo'])
 | 
			
		||||
        del params['algo']
 | 
			
		||||
 | 
			
		||||
        await work_request(
 | 
			
		||||
            user, status_msg, 'redo', params,
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue