mirror of https://github.com/skygpu/skynet.git
Add more ais, and start upscaler config and pipeline
parent
3a9e612695
commit
91e0693e65
|
@ -1,11 +1,14 @@
|
|||
from pytorch/pytorch:1.11.0-cuda11.3-cudnn8-runtime
|
||||
from pytorch/pytorch:latest
|
||||
|
||||
env DEBIAN_FRONTEND=noninteractive
|
||||
|
||||
run apt-get update && apt-get install -y git wget
|
||||
|
||||
run conda install xformers -c xformers/label/dev
|
||||
|
||||
run pip install --upgrade \
|
||||
diffusers[torch] \
|
||||
accelerate \
|
||||
transformers \
|
||||
huggingface_hub \
|
||||
pyTelegramBotAPI \
|
||||
|
@ -13,6 +16,8 @@ run pip install --upgrade \
|
|||
scipy \
|
||||
pdbpp
|
||||
|
||||
env NVIDIA_VISIBLE_DEVICES=all
|
||||
|
||||
run mkdir /scripts
|
||||
run mkdir /outputs
|
||||
run mkdir /inputs
|
||||
|
|
|
@ -27,7 +27,7 @@ from datetime import datetime
|
|||
|
||||
from pymongo import MongoClient
|
||||
|
||||
from typing import Tuple
|
||||
from typing import Tuple, Optional
|
||||
|
||||
db_user = os.environ['DB_USER']
|
||||
db_pass = os.environ['DB_PASS']
|
||||
|
@ -39,9 +39,13 @@ MEM_FRACTION = .33
|
|||
ALGOS = {
|
||||
'stable': 'runwayml/stable-diffusion-v1-5',
|
||||
'midj': 'prompthero/openjourney',
|
||||
'hdanime': 'Linaqruf/anything-v3.0',
|
||||
'waifu': 'hakurei/waifu-diffusion',
|
||||
'ghibli': 'nitrosocke/Ghibli-Diffusion',
|
||||
'van-gogh': 'dallinmackay/Van-Gogh-diffusion',
|
||||
'pokemon': 'lambdalabs/sd-pokemon-diffusers'
|
||||
'pokemon': 'lambdalabs/sd-pokemon-diffusers',
|
||||
'ink': 'Envvi/Inkpunk-Diffusion',
|
||||
'robot': 'nousr/robo-diffusion'
|
||||
}
|
||||
|
||||
N = '\n'
|
||||
|
@ -111,6 +115,7 @@ DEFAULT_STEP = 75
|
|||
DEFAULT_CREDITS = 10
|
||||
DEFAULT_ALGO = 'stable'
|
||||
DEFAULT_ROLE = 'pleb'
|
||||
DEFAULT_UPSCALER = None
|
||||
|
||||
rr_total = 1
|
||||
rr_id = 0
|
||||
|
@ -141,38 +146,49 @@ def generate_image(
|
|||
size: Tuple[int, int],
|
||||
guidance: int,
|
||||
seed: int,
|
||||
algo: str
|
||||
algo: str,
|
||||
upscaler: Optional[str]
|
||||
):
|
||||
assert torch.cuda.is_available()
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.set_per_process_memory_fraction(MEM_FRACTION)
|
||||
with torch.no_grad():
|
||||
if algo == 'stable':
|
||||
pipe = StableDiffusionPipeline.from_pretrained(
|
||||
'runwayml/stable-diffusion-v1-5',
|
||||
torch_dtype=torch.float16,
|
||||
revision="fp16",
|
||||
safety_checker=None
|
||||
)
|
||||
|
||||
if algo == 'stable':
|
||||
pipe = StableDiffusionPipeline.from_pretrained(
|
||||
'runwayml/stable-diffusion-v1-5',
|
||||
torch_dtype=torch.float16,
|
||||
revision="fp16",
|
||||
safety_checker=None
|
||||
)
|
||||
else:
|
||||
pipe = StableDiffusionPipeline.from_pretrained(
|
||||
ALGOS[algo],
|
||||
torch_dtype=torch.float16,
|
||||
safety_checker=None
|
||||
)
|
||||
|
||||
pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
|
||||
pipe = pipe.to("cuda")
|
||||
w, h = size
|
||||
print(f'generating image... of size {w, h}')
|
||||
image = pipe(
|
||||
prompt,
|
||||
width=w,
|
||||
height=h,
|
||||
guidance_scale=guidance, num_inference_steps=step,
|
||||
generator=torch.Generator("cuda").manual_seed(seed)
|
||||
).images[0]
|
||||
|
||||
if upscaler == 'x4':
|
||||
pipe = StableDiffusionPipeline.from_pretrained(
|
||||
'stabilityai/stable-diffusion-x4-upscaler',
|
||||
revision="fp16",
|
||||
torch_dtype=torch.float16
|
||||
)
|
||||
image = pipe(prompt=prompt, image=image).images[0]
|
||||
|
||||
else:
|
||||
pipe = StableDiffusionPipeline.from_pretrained(
|
||||
ALGOS[algo],
|
||||
torch_dtype=torch.float16,
|
||||
safety_checker=None
|
||||
)
|
||||
|
||||
pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
|
||||
pipe = pipe.to("cuda")
|
||||
w, h = size
|
||||
print(f'generating image... of size {w, h}')
|
||||
image = pipe(
|
||||
prompt,
|
||||
width=w,
|
||||
height=h,
|
||||
guidance_scale=guidance, num_inference_steps=step,
|
||||
generator=torch.Generator("cuda").manual_seed(seed)
|
||||
).images[0]
|
||||
image.save(f'/outputs/{name}.png')
|
||||
print('saved')
|
||||
|
||||
|
@ -220,7 +236,8 @@ if __name__ == '__main__':
|
|||
'step': DEFAULT_STEP,
|
||||
'size': DEFAULT_SIZE,
|
||||
'seed': None,
|
||||
'guidance': DEFAULT_GUIDANCE
|
||||
'guidance': DEFAULT_GUIDANCE,
|
||||
'upscaler': DEFAULT_UPSCALER
|
||||
}
|
||||
})
|
||||
|
||||
|
@ -234,11 +251,16 @@ if __name__ == '__main__':
|
|||
res = tg_users.find_one_and_update(
|
||||
{'uid': db_user['uid']}, {'$set': {'role': DEFAULT_ROLE}})
|
||||
|
||||
# new: ai selection
|
||||
# new: algo selection
|
||||
if 'algo' not in db_user['config']:
|
||||
res = tg_users.find_one_and_update(
|
||||
{'uid': db_user['uid']}, {'$set': {'config.algo': DEFAULT_ALGO}})
|
||||
|
||||
# new: upscaler selection
|
||||
if 'upscaler' not in db_user['config']:
|
||||
res = tg_users.find_one_and_update(
|
||||
{'uid': db_user['uid']}, {'$set': {'config.upscaler': DEFAULT_UPSCALER}})
|
||||
|
||||
return get_user(db_user['uid'])
|
||||
|
||||
def get_or_create_user(uid: int):
|
||||
|
@ -263,13 +285,14 @@ if __name__ == '__main__':
|
|||
# bot handler
|
||||
def img_for_user_with_prompt(
|
||||
uid: int,
|
||||
prompt: str, step: int, size: Tuple[int, int], guidance: int, seed: int, algo: str
|
||||
prompt: str, step: int, size: Tuple[int, int], guidance: int, seed: int,
|
||||
algo: str, upscaler: Optional[str]
|
||||
):
|
||||
name = uuid.uuid4()
|
||||
|
||||
spawn(
|
||||
generate_image,
|
||||
args=(prompt, name, step, size, guidance, seed, algo))
|
||||
args=(prompt, name, step, size, guidance, seed, algo, upscaler))
|
||||
|
||||
logging.info(f'done generating. got {name}, sending...')
|
||||
|
||||
|
@ -324,10 +347,11 @@ if __name__ == '__main__':
|
|||
size = user_conf['size']
|
||||
seed = user_conf['seed'] if user_conf['seed'] else random.randint(0, 999999999)
|
||||
guidance = user_conf['guidance']
|
||||
upscaler = user_conf['upscaler']
|
||||
|
||||
try:
|
||||
reply_txt, name = img_for_user_with_prompt(
|
||||
user.id, prompt, step, size, guidance, seed, algo)
|
||||
user.id, prompt, step, size, guidance, seed, algo, upscaler)
|
||||
|
||||
update_user(
|
||||
user.id,
|
||||
|
@ -370,12 +394,13 @@ if __name__ == '__main__':
|
|||
size = user_conf['size']
|
||||
seed = user_conf['seed'] if user_conf['seed'] else random.randint(0, 999999999)
|
||||
guidance = user_conf['guidance']
|
||||
upscaler = user_conf['upscaler']
|
||||
|
||||
logging.info(f"{user.first_name} ({user.id}) on chat {chat.id} redo: {prompt}")
|
||||
|
||||
try:
|
||||
reply_txt, name = img_for_user_with_prompt(
|
||||
user.id, prompt, step, size, guidance, seed, algo)
|
||||
user.id, prompt, step, size, guidance, seed, algo, upscaler)
|
||||
|
||||
update_user(
|
||||
user.id,
|
||||
|
@ -447,6 +472,13 @@ if __name__ == '__main__':
|
|||
val = max(min(val, MAX_GUIDANCE), 0)
|
||||
res = update_user(user.id, {'$set': {'config.guidance': val}})
|
||||
|
||||
elif attr == 'upscaler':
|
||||
val = params[2]
|
||||
if val == 'off':
|
||||
val = None
|
||||
|
||||
res = update_user(user.id, {'$set': {'config.upscaler': val}})
|
||||
|
||||
else:
|
||||
bot.reply_to(message, f'\"{attr}\" not a parameter')
|
||||
|
||||
|
@ -501,4 +533,5 @@ if __name__ == '__main__':
|
|||
|
||||
|
||||
login(token=os.environ['HF_TOKEN'])
|
||||
|
||||
bot.infinity_polling()
|
||||
|
|
Loading…
Reference in New Issue