Add strength parameter for img2img

pull/4/head v0.1a7
Guillermo Rodriguez 2023-01-18 07:04:08 -03:00
parent 8427165a76
commit aaecd41fb6
No known key found for this signature in database
GPG Key ID: EC3AB66D5D83B392
7 changed files with 47 additions and 3 deletions

View File

@ -111,6 +111,7 @@ DEFAULT_SEED = None
DEFAULT_WIDTH = 512
DEFAULT_HEIGHT = 512
DEFAULT_GUIDANCE = 7.5
DEFAULT_STRENGTH = 0.5
DEFAULT_STEP = 35
DEFAULT_CREDITS = 10
DEFAULT_ALGO = 'midj'

View File

@ -49,7 +49,8 @@ CREATE TABLE IF NOT EXISTS skynet.user_config(
width INT NOT NULL,
height INT NOT NULL,
seed BIGINT,
guidance INT NOT NULL,
guidance REAL NOT NULL,
strength REAL NOT NULL,
upscaler VARCHAR(128)
);
ALTER TABLE skynet.user_config
@ -173,9 +174,9 @@ async def new_user(conn, uid: str):
stmt = await conn.prepare('''
INSERT INTO skynet.user_config(
id, algo, step, width, height, seed, guidance, upscaler)
id, algo, step, width, height, seed, guidance, strength, upscaler)
VALUES($1, $2, $3, $4, $5, $6, $7, $8)
VALUES($1, $2, $3, $4, $5, $6, $7, $8, $9)
ON CONFLICT DO NOTHING
''')
user = await stmt.fetch(
@ -186,6 +187,7 @@ async def new_user(conn, uid: str):
DEFAULT_HEIGHT,
DEFAULT_SEED,
DEFAULT_GUIDANCE,
DEFAULT_STRENGTH,
DEFAULT_UPSCALER
)

View File

@ -160,6 +160,7 @@ async def open_dgpu_node(
_params = {}
if ireq.image:
_params['image'] = image
_params['strength'] = ireq.strength
else:
_params['width'] = int(ireq.width)

View File

@ -152,6 +152,10 @@ def validate_user_config_request(req: str):
val = float(params[2])
val = max(min(val, MAX_GUIDANCE), 0)
case 'strength':
val = float(params[2])
val = max(min(val, 0.99), 0.01)
case 'upscaler':
val = params[2]
if val == 'off':

View File

@ -37,6 +37,8 @@ def prepare_metainfo_caption(tguser, meta: dict) -> str:
meta_str += f'seed: {meta["seed"]}\n'
meta_str += f'step: {meta["step"]}\n'
meta_str += f'guidance: {meta["guidance"]}\n'
if meta['strength']:
meta_str += f'strength: {meta["strength"]}'
meta_str += f'algo: \"{meta["algo"]}\"\n'
if meta['upscaler']:
meta_str += f'upscaler: \"{meta["upscaler"]}\"\n'

33
skynet/models.py 100644
View File

@ -0,0 +1,33 @@
class ModelStore:
def __init__(
self,
max_models: int = 2
):
self.max_models = max_models
self._models = {}
def get(self, model_name: str):
if model_name in self._models:
return self._models[model_name]['pipe']
if len(self._models) == max_models:
least_used = list(self._models.keys())[0]
for model in self._models:
if self._models[least_used]['generated'] > self._models[model]['generated']:
least_used = model
del self._models[least_used]
gc.collect()
pipe = pipeline_for(model_name)
self._models[model_name] = {
'pipe': pipe,
'generated': 0
}
return pipe

View File

@ -23,6 +23,7 @@ class DiffusionParameters(Struct):
width: int
height: int
guidance: float
strength: float
seed: Optional[int]
image: bool # if true indicates a bytestream is next msg
upscaler: Optional[str]