Add strength parameter for img2img

pull/4/head v0.1a7
Guillermo Rodriguez 2023-01-18 07:04:08 -03:00
parent 8427165a76
commit aaecd41fb6
No known key found for this signature in database
GPG Key ID: EC3AB66D5D83B392
7 changed files with 47 additions and 3 deletions

View File

@ -111,6 +111,7 @@ DEFAULT_SEED = None
DEFAULT_WIDTH = 512 DEFAULT_WIDTH = 512
DEFAULT_HEIGHT = 512 DEFAULT_HEIGHT = 512
DEFAULT_GUIDANCE = 7.5 DEFAULT_GUIDANCE = 7.5
DEFAULT_STRENGTH = 0.5
DEFAULT_STEP = 35 DEFAULT_STEP = 35
DEFAULT_CREDITS = 10 DEFAULT_CREDITS = 10
DEFAULT_ALGO = 'midj' DEFAULT_ALGO = 'midj'

View File

@ -49,7 +49,8 @@ CREATE TABLE IF NOT EXISTS skynet.user_config(
width INT NOT NULL, width INT NOT NULL,
height INT NOT NULL, height INT NOT NULL,
seed BIGINT, seed BIGINT,
guidance INT NOT NULL, guidance REAL NOT NULL,
strength REAL NOT NULL,
upscaler VARCHAR(128) upscaler VARCHAR(128)
); );
ALTER TABLE skynet.user_config ALTER TABLE skynet.user_config
@ -173,9 +174,9 @@ async def new_user(conn, uid: str):
stmt = await conn.prepare(''' stmt = await conn.prepare('''
INSERT INTO skynet.user_config( INSERT INTO skynet.user_config(
id, algo, step, width, height, seed, guidance, upscaler) id, algo, step, width, height, seed, guidance, strength, upscaler)
VALUES($1, $2, $3, $4, $5, $6, $7, $8) VALUES($1, $2, $3, $4, $5, $6, $7, $8, $9)
ON CONFLICT DO NOTHING ON CONFLICT DO NOTHING
''') ''')
user = await stmt.fetch( user = await stmt.fetch(
@ -186,6 +187,7 @@ async def new_user(conn, uid: str):
DEFAULT_HEIGHT, DEFAULT_HEIGHT,
DEFAULT_SEED, DEFAULT_SEED,
DEFAULT_GUIDANCE, DEFAULT_GUIDANCE,
DEFAULT_STRENGTH,
DEFAULT_UPSCALER DEFAULT_UPSCALER
) )

View File

@ -160,6 +160,7 @@ async def open_dgpu_node(
_params = {} _params = {}
if ireq.image: if ireq.image:
_params['image'] = image _params['image'] = image
_params['strength'] = ireq.strength
else: else:
_params['width'] = int(ireq.width) _params['width'] = int(ireq.width)

View File

@ -152,6 +152,10 @@ def validate_user_config_request(req: str):
val = float(params[2]) val = float(params[2])
val = max(min(val, MAX_GUIDANCE), 0) val = max(min(val, MAX_GUIDANCE), 0)
case 'strength':
val = float(params[2])
val = max(min(val, 0.99), 0.01)
case 'upscaler': case 'upscaler':
val = params[2] val = params[2]
if val == 'off': if val == 'off':

View File

@ -37,6 +37,8 @@ def prepare_metainfo_caption(tguser, meta: dict) -> str:
meta_str += f'seed: {meta["seed"]}\n' meta_str += f'seed: {meta["seed"]}\n'
meta_str += f'step: {meta["step"]}\n' meta_str += f'step: {meta["step"]}\n'
meta_str += f'guidance: {meta["guidance"]}\n' meta_str += f'guidance: {meta["guidance"]}\n'
if meta['strength']:
meta_str += f'strength: {meta["strength"]}'
meta_str += f'algo: \"{meta["algo"]}\"\n' meta_str += f'algo: \"{meta["algo"]}\"\n'
if meta['upscaler']: if meta['upscaler']:
meta_str += f'upscaler: \"{meta["upscaler"]}\"\n' meta_str += f'upscaler: \"{meta["upscaler"]}\"\n'

33
skynet/models.py 100644
View File

@ -0,0 +1,33 @@
class ModelStore:
def __init__(
self,
max_models: int = 2
):
self.max_models = max_models
self._models = {}
def get(self, model_name: str):
if model_name in self._models:
return self._models[model_name]['pipe']
if len(self._models) == max_models:
least_used = list(self._models.keys())[0]
for model in self._models:
if self._models[least_used]['generated'] > self._models[model]['generated']:
least_used = model
del self._models[least_used]
gc.collect()
pipe = pipeline_for(model_name)
self._models[model_name] = {
'pipe': pipe,
'generated': 0
}
return pipe

View File

@ -23,6 +23,7 @@ class DiffusionParameters(Struct):
width: int width: int
height: int height: int
guidance: float guidance: float
strength: float
seed: Optional[int] seed: Optional[int]
image: bool # if true indicates a bytestream is next msg image: bool # if true indicates a bytestream is next msg
upscaler: Optional[str] upscaler: Optional[str]