mirror of https://github.com/skygpu/skynet.git
				
				
				
			
							parent
							
								
									8427165a76
								
							
						
					
					
						commit
						aaecd41fb6
					
				| 
						 | 
					@ -111,6 +111,7 @@ DEFAULT_SEED = None
 | 
				
			||||||
DEFAULT_WIDTH = 512
 | 
					DEFAULT_WIDTH = 512
 | 
				
			||||||
DEFAULT_HEIGHT = 512
 | 
					DEFAULT_HEIGHT = 512
 | 
				
			||||||
DEFAULT_GUIDANCE = 7.5
 | 
					DEFAULT_GUIDANCE = 7.5
 | 
				
			||||||
 | 
					DEFAULT_STRENGTH = 0.5
 | 
				
			||||||
DEFAULT_STEP = 35
 | 
					DEFAULT_STEP = 35
 | 
				
			||||||
DEFAULT_CREDITS = 10
 | 
					DEFAULT_CREDITS = 10
 | 
				
			||||||
DEFAULT_ALGO = 'midj'
 | 
					DEFAULT_ALGO = 'midj'
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -49,7 +49,8 @@ CREATE TABLE IF NOT EXISTS skynet.user_config(
 | 
				
			||||||
    width INT NOT NULL,
 | 
					    width INT NOT NULL,
 | 
				
			||||||
    height INT NOT NULL,
 | 
					    height INT NOT NULL,
 | 
				
			||||||
    seed BIGINT,
 | 
					    seed BIGINT,
 | 
				
			||||||
    guidance INT NOT NULL,
 | 
					    guidance REAL NOT NULL,
 | 
				
			||||||
 | 
					    strength REAL NOT NULL,
 | 
				
			||||||
    upscaler VARCHAR(128)
 | 
					    upscaler VARCHAR(128)
 | 
				
			||||||
);
 | 
					);
 | 
				
			||||||
ALTER TABLE skynet.user_config
 | 
					ALTER TABLE skynet.user_config
 | 
				
			||||||
| 
						 | 
					@ -173,9 +174,9 @@ async def new_user(conn, uid: str):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        stmt = await conn.prepare('''
 | 
					        stmt = await conn.prepare('''
 | 
				
			||||||
            INSERT INTO skynet.user_config(
 | 
					            INSERT INTO skynet.user_config(
 | 
				
			||||||
                id, algo, step, width, height, seed, guidance, upscaler)
 | 
					                id, algo, step, width, height, seed, guidance, strength, upscaler)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            VALUES($1, $2, $3, $4, $5, $6, $7, $8)
 | 
					            VALUES($1, $2, $3, $4, $5, $6, $7, $8, $9)
 | 
				
			||||||
            ON CONFLICT DO NOTHING
 | 
					            ON CONFLICT DO NOTHING
 | 
				
			||||||
        ''')
 | 
					        ''')
 | 
				
			||||||
        user = await stmt.fetch(
 | 
					        user = await stmt.fetch(
 | 
				
			||||||
| 
						 | 
					@ -186,6 +187,7 @@ async def new_user(conn, uid: str):
 | 
				
			||||||
            DEFAULT_HEIGHT,
 | 
					            DEFAULT_HEIGHT,
 | 
				
			||||||
            DEFAULT_SEED,
 | 
					            DEFAULT_SEED,
 | 
				
			||||||
            DEFAULT_GUIDANCE,
 | 
					            DEFAULT_GUIDANCE,
 | 
				
			||||||
 | 
					            DEFAULT_STRENGTH,
 | 
				
			||||||
            DEFAULT_UPSCALER
 | 
					            DEFAULT_UPSCALER
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -160,6 +160,7 @@ async def open_dgpu_node(
 | 
				
			||||||
        _params = {}
 | 
					        _params = {}
 | 
				
			||||||
        if ireq.image:
 | 
					        if ireq.image:
 | 
				
			||||||
            _params['image'] = image
 | 
					            _params['image'] = image
 | 
				
			||||||
 | 
					            _params['strength'] = ireq.strength
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        else:
 | 
					        else:
 | 
				
			||||||
            _params['width'] = int(ireq.width)
 | 
					            _params['width'] = int(ireq.width)
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -152,6 +152,10 @@ def validate_user_config_request(req: str):
 | 
				
			||||||
                    val = float(params[2])
 | 
					                    val = float(params[2])
 | 
				
			||||||
                    val = max(min(val, MAX_GUIDANCE), 0)
 | 
					                    val = max(min(val, MAX_GUIDANCE), 0)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                case 'strength':
 | 
				
			||||||
 | 
					                    val = float(params[2])
 | 
				
			||||||
 | 
					                    val = max(min(val, 0.99), 0.01)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                case 'upscaler':
 | 
					                case 'upscaler':
 | 
				
			||||||
                    val = params[2]
 | 
					                    val = params[2]
 | 
				
			||||||
                    if val == 'off':
 | 
					                    if val == 'off':
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -37,6 +37,8 @@ def prepare_metainfo_caption(tguser, meta: dict) -> str:
 | 
				
			||||||
    meta_str += f'seed: {meta["seed"]}\n'
 | 
					    meta_str += f'seed: {meta["seed"]}\n'
 | 
				
			||||||
    meta_str += f'step: {meta["step"]}\n'
 | 
					    meta_str += f'step: {meta["step"]}\n'
 | 
				
			||||||
    meta_str += f'guidance: {meta["guidance"]}\n'
 | 
					    meta_str += f'guidance: {meta["guidance"]}\n'
 | 
				
			||||||
 | 
					    if meta['strength']:
 | 
				
			||||||
 | 
					        meta_str += f'strength: {meta["strength"]}'
 | 
				
			||||||
    meta_str += f'algo: \"{meta["algo"]}\"\n'
 | 
					    meta_str += f'algo: \"{meta["algo"]}\"\n'
 | 
				
			||||||
    if meta['upscaler']:
 | 
					    if meta['upscaler']:
 | 
				
			||||||
        meta_str += f'upscaler: \"{meta["upscaler"]}\"\n'
 | 
					        meta_str += f'upscaler: \"{meta["upscaler"]}\"\n'
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -0,0 +1,33 @@
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class ModelStore:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __init__(
 | 
				
			||||||
 | 
					        self,
 | 
				
			||||||
 | 
					        max_models: int = 2
 | 
				
			||||||
 | 
					    ):
 | 
				
			||||||
 | 
					        self.max_models = max_models
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self._models = {}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def get(self, model_name: str):
 | 
				
			||||||
 | 
					        if model_name in self._models:
 | 
				
			||||||
 | 
					            return self._models[model_name]['pipe']
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if len(self._models) == max_models:
 | 
				
			||||||
 | 
					            least_used = list(self._models.keys())[0]
 | 
				
			||||||
 | 
					            for model in self._models:
 | 
				
			||||||
 | 
					                if self._models[least_used]['generated'] > self._models[model]['generated']:
 | 
				
			||||||
 | 
					                    least_used = model
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            del self._models[least_used]
 | 
				
			||||||
 | 
					            gc.collect()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        pipe = pipeline_for(model_name)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self._models[model_name] = {
 | 
				
			||||||
 | 
					            'pipe': pipe,
 | 
				
			||||||
 | 
					            'generated': 0
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return pipe
 | 
				
			||||||
| 
						 | 
					@ -23,6 +23,7 @@ class DiffusionParameters(Struct):
 | 
				
			||||||
    width: int
 | 
					    width: int
 | 
				
			||||||
    height: int
 | 
					    height: int
 | 
				
			||||||
    guidance: float
 | 
					    guidance: float
 | 
				
			||||||
 | 
					    strength: float
 | 
				
			||||||
    seed: Optional[int]
 | 
					    seed: Optional[int]
 | 
				
			||||||
    image: bool  # if true indicates a bytestream is next msg
 | 
					    image: bool  # if true indicates a bytestream is next msg
 | 
				
			||||||
    upscaler: Optional[str]
 | 
					    upscaler: Optional[str]
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue