From 8bd255717ef65ceaf45bad6a080cbdc3d8e1b3b3 Mon Sep 17 00:00:00 2001 From: Guillermo Rodriguez Date: Mon, 5 Dec 2022 19:36:21 -0300 Subject: [PATCH] Add roles and start working on algo selection, clean up rr code --- run-bot.sh | 2 +- scripts/telegram-bot-dev.py | 158 ++++++++++++++++++++++++------------ 2 files changed, 109 insertions(+), 51 deletions(-) diff --git a/run-bot.sh b/run-bot.sh index acc606b..00081b8 100755 --- a/run-bot.sh +++ b/run-bot.sh @@ -4,7 +4,7 @@ docker run \ --gpus=all \ --env HF_TOKEN='' \ --env DB_USER='skynet' \ - --env DB_PASS='password' \ + --env DB_PASS='nnf01nmf091d0i' \ --mount type=bind,source="$(pwd)"/outputs,target=/outputs \ --mount type=bind,source="$(pwd)"/hf_home,target=/hf_home \ --mount type=bind,source="$(pwd)"/scripts,target=/scripts \ diff --git a/scripts/telegram-bot-dev.py b/scripts/telegram-bot-dev.py index 62cba42..8937737 100644 --- a/scripts/telegram-bot-dev.py +++ b/scripts/telegram-bot-dev.py @@ -17,13 +17,17 @@ from pathlib import Path import torch from torch.multiprocessing.spawn import ProcessRaisedException -from diffusers import StableDiffusionPipeline, EulerAncestralDiscreteScheduler +from diffusers import ( + FlaxStableDiffusionPipeline, + StableDiffusionPipeline, + EulerAncestralDiscreteScheduler +) from huggingface_hub import login from datetime import datetime from pymongo import MongoClient - +# import jax.numpy as jnp db_user = os.environ['DB_USER'] db_pass = os.environ['DB_PASS'] @@ -75,7 +79,11 @@ COOL_WORDS = [ 'michelangelo' ] -GROUP_ID = -889553587 +GROUP_ID = -1001541979235 + +ALGOS = ['stable', 'midj'] + +MP_ENABLED_ROLES = ['god'] MIN_STEP = 1 MAX_STEP = 100 @@ -86,6 +94,8 @@ DEFAULT_SIZE = (512, 512) DEFAULT_GUIDANCE = 7.5 DEFAULT_STEP = 75 DEFAULT_CREDITS = 10 +DEFAULT_ALGO = 'stable' +DEFAULT_ROLE = 'pleb' rr_total = 2 rr_id = 0 @@ -98,18 +108,47 @@ def its_my_turn(): request_counter += 1 return my_turn +def round_robined(func): + def rr_wrapper(*args, **kwargs): + if not its_my_turn(): + return -def generate_image(i, prompt, name, step, size, guidance, seed): + func(*args, **kwargs) + + return rr_wrapper + + +def generate_image( + i: int, + prompt: str, + name: str, + step: int, + size: tuple[int, int], + guidance: int, + seed: int, + algo: str +): assert torch.cuda.is_available() torch.cuda.empty_cache() torch.cuda.set_per_process_memory_fraction(MEM_FRACTION) - pipe = StableDiffusionPipeline.from_pretrained( - "runwayml/stable-diffusion-v1-5", - torch_dtype=torch.float16, - revision="fp16", - safety_checker=None - ) - pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config) + + if algo == 'stable': + pipe = StableDiffusionPipeline.from_pretrained( + 'runwayml/stable-diffusion-v1-5', + torch_dtype=torch.float16, + revision="fp16", + safety_checker=None + ) + pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config) + + elif algo == 'midj': + import jax as jnp + pipeline, params = FlaxStableDiffusionPipeline.from_pretrained( + 'flax/midjourney-v4-diffusion', + revision="bf16", + dtype= jnp.bfloat16, + ) + pipe = pipe.to("cuda") w, h = size print(f'generating image... of size {w, h}') @@ -161,7 +200,9 @@ if __name__ == '__main__': 'credits': DEFAULT_CREDITS, 'joined': datetime.utcnow().isoformat(), 'last_prompt': None, + 'role': DEFAULT_ROLE, 'config': { + 'algo': DEFAULT_ALGO, 'step': DEFAULT_STEP, 'size': DEFAULT_SIZE, 'seed': None, @@ -173,6 +214,18 @@ if __name__ == '__main__': return get_user(uid) + def migrate_user(db_user): + # new: user roles + if 'role' not in db_user: + res = tg_users.find_one_and_update( + {'uid': db_user['uid']}, {'$set': {'role': DEFAULT_ROLE}}) + + # new: ai selection + if 'algo' not in db_user['config']: + res = tg_users.find_one_and_update( + {'uid': db_user['uid']}, {'$set': {'config.algo': DEFAULT_ALGO}}) + + return get_user(db_user['uid']) def get_or_create_user(uid: int): db_user = get_user(uid) @@ -180,8 +233,9 @@ if __name__ == '__main__': if not db_user: db_user = new_user(uid) - return db_user + logging.info(f'req from: {uid}') + return migrate_user(db_user) def update_user(uid: int, updt_cmd: dict): user = get_user(uid) @@ -195,13 +249,13 @@ if __name__ == '__main__': # bot handler def img_for_user_with_prompt( uid: int, - prompt: str, step: int, size: tuple[int, int], guidance: int, seed: int + prompt: str, step: int, size: tuple[int, int], guidance: int, seed: int, algo: str ): name = uuid.uuid4() spawn( generate_image, - args=(prompt, name, step, size, guidance, seed)) + args=(prompt, name, step, size, guidance, seed, algo)) logging.info(f'done generating. got {name}, sending...') @@ -221,23 +275,24 @@ if __name__ == '__main__': return reply_txt, name @bot.message_handler(commands=['help']) + @round_robined def send_help(message): - if its_my_turn(): - bot.reply_to(message, HELP_TEXT) + bot.reply_to(message, HELP_TEXT) @bot.message_handler(commands=['cool']) + @round_robined def send_cool_words(message): - if its_my_turn(): - bot.reply_to(message, '\n'.join(COOL_WORDS)) + bot.reply_to(message, '\n'.join(COOL_WORDS)) @bot.message_handler(commands=['txt2img']) + @round_robined def send_txt2img(message): - if not its_my_turn(): - return - - # check msg comes from testing group chat = message.chat - if chat.type != 'group' and chat.id != GROUP_ID: + user = message.from_user + db_user = get_or_create_user(user.id) + + if ((chat.type != 'group' and chat.id != GROUP_ID) and + (db_user['role'] not in MP_ENABLED_ROLES)): return prompt = ' '.join(message.text.split(' ')[1:]) @@ -246,13 +301,11 @@ if __name__ == '__main__': bot.reply_to(message, 'empty text prompt ignored.') return - user = message.from_user - db_user = get_or_create_user(user.id) - logging.info(f"{user.first_name} ({user.id}) on chat {chat.id} txt2img: {prompt}") user_conf = db_user['config'] + algo = user_conf['algo'] step = user_conf['step'] size = user_conf['size'] seed = user_conf['seed'] if user_conf['seed'] else random.randint(0, 999999999) @@ -260,7 +313,7 @@ if __name__ == '__main__': try: reply_txt, name = img_for_user_with_prompt( - user.id, prompt, step, size, guidance, seed) + user.id, prompt, step, size, guidance, seed, algo) update_user( user.id, @@ -279,18 +332,17 @@ if __name__ == '__main__': bot.reply_to(message, 'that command caused an error :(\nchange settings and try again (?') @bot.message_handler(commands=['redo']) + @round_robined def redo_txt2img(message): - if not its_my_turn(): - return - # check msg comes from testing group chat = message.chat - if chat.type != 'group' and chat.id != GROUP_ID: - return - user = message.from_user db_user = get_or_create_user(user.id) + if ((chat.type != 'group' and chat.id != GROUP_ID) and + (db_user['role'] not in MP_ENABLED_ROLES)): + return + prompt = db_user['last_prompt'] if not prompt: @@ -299,6 +351,7 @@ if __name__ == '__main__': user_conf = db_user['config'] + algo = user_conf['algo'] step = user_conf['step'] size = user_conf['size'] seed = user_conf['seed'] if user_conf['seed'] else random.randint(0, 999999999) @@ -308,7 +361,7 @@ if __name__ == '__main__': try: reply_txt, name = img_for_user_with_prompt( - user.id, prompt, step, size, guidance, seed) + user.id, prompt, step, size, guidance, seed, algo) update_user( user.id, @@ -326,9 +379,9 @@ if __name__ == '__main__': bot.reply_to(message, 'that command caused an error :(\nchange settings and try again (?') @bot.message_handler(commands=['config']) + @round_robined def set_config(message): - if not its_my_turn(): - return + logging.info(f'config req on chat: {message.chat.id}') params = message.text.split(' ') @@ -338,15 +391,17 @@ if __name__ == '__main__': else: user = message.from_user chat = message.chat - db_user = get_user(user.id) - - if not db_user: - db_user = new_user(user.id) + db_user = get_or_create_user(user.id) try: attr = params[1] - if attr == 'step': + if attr == 'algo': + val = params[2] + assert val in ALGOS + res = update_user(user.id, {'$set': {'config.algo': val}}) + + elif attr == 'step': val = int(params[2]) val = max(min(val, MAX_STEP), MIN_STEP) res = update_user(user.id, {'$set': {'config.step': val}}) @@ -378,27 +433,30 @@ if __name__ == '__main__': val = max(min(val, MAX_GUIDANCE), 0) res = update_user(user.id, {'$set': {'config.guidance': val}}) - bot.reply_to(message, f"config updated! {attr} to {val}") + else: + bot.reply_to(message, f'\"{attr}\" not a parameter') + + bot.reply_to(message, f'config updated! {attr} to {val}') except ValueError: - bot.reply_to(message, f"\"{val}\" is not a number silly") + bot.reply_to(message, f'\"{val}\" is not a number silly') + + except AssertionError: + bot.reply_to(message, f'no algo named {val}') @bot.message_handler(commands=['stats']) + @round_robined def user_stats(message): - if not its_my_turn(): - return - user = message.from_user - db_user = get_user(user.id) - - if not db_user: - db_user = new_user(user.id) + db_user = get_or_create_user(user.id) + migrate_user(db_user) joined_date_str = datetime.fromisoformat(db_user['joined']).strftime('%B the %dth %Y, %H:%M:%S') user_stats_str = f'generated: {db_user["generated"]}\n' user_stats_str += f'joined: {joined_date_str}\n' user_stats_str += f'credits: {db_user["credits"]}\n' + user_stats_str += f'role: {db_user["role"]}\n' bot.reply_to( message, user_stats_str)