mirror of https://github.com/skygpu/skynet.git
				
				
				
			Add roles and start working on algo selection, clean up rr code
							parent
							
								
									31058c116a
								
							
						
					
					
						commit
						72149ae1b4
					
				| 
						 | 
				
			
			@ -4,7 +4,7 @@ docker run \
 | 
			
		|||
    --gpus=all \
 | 
			
		||||
    --env HF_TOKEN='' \
 | 
			
		||||
    --env DB_USER='skynet' \
 | 
			
		||||
    --env DB_PASS='password' \
 | 
			
		||||
    --env DB_PASS='nnf01nmf091d0i' \
 | 
			
		||||
    --mount type=bind,source="$(pwd)"/outputs,target=/outputs \
 | 
			
		||||
    --mount type=bind,source="$(pwd)"/hf_home,target=/hf_home \
 | 
			
		||||
    --mount type=bind,source="$(pwd)"/scripts,target=/scripts \
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -17,13 +17,17 @@ from pathlib import Path
 | 
			
		|||
 | 
			
		||||
import torch
 | 
			
		||||
from torch.multiprocessing.spawn import ProcessRaisedException
 | 
			
		||||
from diffusers import StableDiffusionPipeline, EulerAncestralDiscreteScheduler
 | 
			
		||||
from diffusers import (
 | 
			
		||||
    FlaxStableDiffusionPipeline,
 | 
			
		||||
    StableDiffusionPipeline,
 | 
			
		||||
    EulerAncestralDiscreteScheduler
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
from huggingface_hub import login
 | 
			
		||||
from datetime import datetime
 | 
			
		||||
 | 
			
		||||
from pymongo import MongoClient
 | 
			
		||||
 | 
			
		||||
#  import jax.numpy as jnp
 | 
			
		||||
 | 
			
		||||
db_user = os.environ['DB_USER']
 | 
			
		||||
db_pass = os.environ['DB_PASS']
 | 
			
		||||
| 
						 | 
				
			
			@ -75,7 +79,11 @@ COOL_WORDS = [
 | 
			
		|||
    'michelangelo'
 | 
			
		||||
]
 | 
			
		||||
 | 
			
		||||
GROUP_ID = -889553587
 | 
			
		||||
GROUP_ID = -1001541979235
 | 
			
		||||
 | 
			
		||||
ALGOS = ['stable', 'midj']
 | 
			
		||||
 | 
			
		||||
MP_ENABLED_ROLES = ['god']
 | 
			
		||||
 | 
			
		||||
MIN_STEP = 1
 | 
			
		||||
MAX_STEP = 100
 | 
			
		||||
| 
						 | 
				
			
			@ -86,6 +94,8 @@ DEFAULT_SIZE = (512, 512)
 | 
			
		|||
DEFAULT_GUIDANCE = 7.5
 | 
			
		||||
DEFAULT_STEP = 75
 | 
			
		||||
DEFAULT_CREDITS = 10
 | 
			
		||||
DEFAULT_ALGO = 'stable'
 | 
			
		||||
DEFAULT_ROLE = 'pleb'
 | 
			
		||||
 | 
			
		||||
rr_total = 2
 | 
			
		||||
rr_id = 0
 | 
			
		||||
| 
						 | 
				
			
			@ -98,18 +108,47 @@ def its_my_turn():
 | 
			
		|||
    request_counter += 1
 | 
			
		||||
    return my_turn
 | 
			
		||||
 | 
			
		||||
def round_robined(func):
 | 
			
		||||
    def rr_wrapper(*args, **kwargs):
 | 
			
		||||
        if not its_my_turn():
 | 
			
		||||
            return
 | 
			
		||||
 | 
			
		||||
def generate_image(i, prompt, name, step, size, guidance, seed):
 | 
			
		||||
        func(*args, **kwargs)
 | 
			
		||||
 | 
			
		||||
    return rr_wrapper
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def generate_image(
 | 
			
		||||
    i: int,
 | 
			
		||||
    prompt: str,
 | 
			
		||||
    name: str,
 | 
			
		||||
    step: int,
 | 
			
		||||
    size: tuple[int, int],
 | 
			
		||||
    guidance: int,
 | 
			
		||||
    seed: int,
 | 
			
		||||
    algo: str
 | 
			
		||||
):
 | 
			
		||||
    assert torch.cuda.is_available()
 | 
			
		||||
    torch.cuda.empty_cache()
 | 
			
		||||
    torch.cuda.set_per_process_memory_fraction(MEM_FRACTION)
 | 
			
		||||
    pipe = StableDiffusionPipeline.from_pretrained(
 | 
			
		||||
        "runwayml/stable-diffusion-v1-5",
 | 
			
		||||
        torch_dtype=torch.float16,
 | 
			
		||||
        revision="fp16",
 | 
			
		||||
        safety_checker=None
 | 
			
		||||
    )
 | 
			
		||||
    pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
 | 
			
		||||
 | 
			
		||||
    if algo == 'stable':
 | 
			
		||||
        pipe = StableDiffusionPipeline.from_pretrained(
 | 
			
		||||
            'runwayml/stable-diffusion-v1-5',
 | 
			
		||||
            torch_dtype=torch.float16,
 | 
			
		||||
            revision="fp16",
 | 
			
		||||
            safety_checker=None
 | 
			
		||||
        )
 | 
			
		||||
        pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
 | 
			
		||||
 | 
			
		||||
    elif algo == 'midj':
 | 
			
		||||
        import jax as jnp
 | 
			
		||||
        pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
 | 
			
		||||
            'flax/midjourney-v4-diffusion',
 | 
			
		||||
            revision="bf16",
 | 
			
		||||
            dtype= jnp.bfloat16,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    pipe = pipe.to("cuda")
 | 
			
		||||
    w, h = size
 | 
			
		||||
    print(f'generating image... of size {w, h}')
 | 
			
		||||
| 
						 | 
				
			
			@ -161,7 +200,9 @@ if __name__ == '__main__':
 | 
			
		|||
            'credits': DEFAULT_CREDITS,
 | 
			
		||||
            'joined': datetime.utcnow().isoformat(),
 | 
			
		||||
            'last_prompt': None,
 | 
			
		||||
            'role': DEFAULT_ROLE,
 | 
			
		||||
            'config': {
 | 
			
		||||
                'algo': DEFAULT_ALGO,
 | 
			
		||||
                'step': DEFAULT_STEP,
 | 
			
		||||
                'size': DEFAULT_SIZE,
 | 
			
		||||
                'seed': None,
 | 
			
		||||
| 
						 | 
				
			
			@ -173,6 +214,18 @@ if __name__ == '__main__':
 | 
			
		|||
 | 
			
		||||
        return get_user(uid)
 | 
			
		||||
 | 
			
		||||
    def migrate_user(db_user):
 | 
			
		||||
        # new: user roles
 | 
			
		||||
        if 'role' not in db_user:
 | 
			
		||||
            res = tg_users.find_one_and_update(
 | 
			
		||||
                {'uid': db_user['uid']}, {'$set': {'role': DEFAULT_ROLE}})
 | 
			
		||||
 | 
			
		||||
        # new: ai selection
 | 
			
		||||
        if 'algo' not in db_user['config']:
 | 
			
		||||
            res = tg_users.find_one_and_update(
 | 
			
		||||
                {'uid': db_user['uid']}, {'$set': {'config.algo': DEFAULT_ALGO}})
 | 
			
		||||
 | 
			
		||||
        return get_user(db_user['uid'])
 | 
			
		||||
 | 
			
		||||
    def get_or_create_user(uid: int):
 | 
			
		||||
        db_user = get_user(uid)
 | 
			
		||||
| 
						 | 
				
			
			@ -180,8 +233,9 @@ if __name__ == '__main__':
 | 
			
		|||
        if not db_user:
 | 
			
		||||
            db_user = new_user(uid)
 | 
			
		||||
 | 
			
		||||
        return db_user
 | 
			
		||||
        logging.info(f'req from: {uid}')
 | 
			
		||||
 | 
			
		||||
        return migrate_user(db_user)
 | 
			
		||||
 | 
			
		||||
    def update_user(uid: int, updt_cmd: dict):
 | 
			
		||||
        user = get_user(uid)
 | 
			
		||||
| 
						 | 
				
			
			@ -195,13 +249,13 @@ if __name__ == '__main__':
 | 
			
		|||
    # bot handler
 | 
			
		||||
    def img_for_user_with_prompt(
 | 
			
		||||
        uid: int,
 | 
			
		||||
        prompt: str, step: int, size: tuple[int, int], guidance: int, seed: int
 | 
			
		||||
        prompt: str, step: int, size: tuple[int, int], guidance: int, seed: int, algo: str
 | 
			
		||||
    ):
 | 
			
		||||
        name = uuid.uuid4()
 | 
			
		||||
 | 
			
		||||
        spawn(
 | 
			
		||||
            generate_image,
 | 
			
		||||
            args=(prompt, name, step, size, guidance, seed))
 | 
			
		||||
            args=(prompt, name, step, size, guidance, seed, algo))
 | 
			
		||||
 | 
			
		||||
        logging.info(f'done generating. got {name}, sending...')
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -221,23 +275,24 @@ if __name__ == '__main__':
 | 
			
		|||
        return reply_txt, name
 | 
			
		||||
 | 
			
		||||
    @bot.message_handler(commands=['help'])
 | 
			
		||||
    @round_robined
 | 
			
		||||
    def send_help(message):
 | 
			
		||||
        if its_my_turn():
 | 
			
		||||
            bot.reply_to(message, HELP_TEXT)
 | 
			
		||||
        bot.reply_to(message, HELP_TEXT)
 | 
			
		||||
 | 
			
		||||
    @bot.message_handler(commands=['cool'])
 | 
			
		||||
    @round_robined
 | 
			
		||||
    def send_cool_words(message):
 | 
			
		||||
        if its_my_turn():
 | 
			
		||||
            bot.reply_to(message, '\n'.join(COOL_WORDS))
 | 
			
		||||
        bot.reply_to(message, '\n'.join(COOL_WORDS))
 | 
			
		||||
 | 
			
		||||
    @bot.message_handler(commands=['txt2img'])
 | 
			
		||||
    @round_robined
 | 
			
		||||
    def send_txt2img(message):
 | 
			
		||||
        if not its_my_turn():
 | 
			
		||||
            return
 | 
			
		||||
 | 
			
		||||
        # check msg comes from testing group
 | 
			
		||||
        chat = message.chat
 | 
			
		||||
        if chat.type != 'group' and chat.id != GROUP_ID:
 | 
			
		||||
        user = message.from_user
 | 
			
		||||
        db_user = get_or_create_user(user.id)
 | 
			
		||||
 | 
			
		||||
        if ((chat.type != 'group' and chat.id != GROUP_ID) and
 | 
			
		||||
                (db_user['role'] not in MP_ENABLED_ROLES)):
 | 
			
		||||
            return
 | 
			
		||||
 | 
			
		||||
        prompt = ' '.join(message.text.split(' ')[1:])
 | 
			
		||||
| 
						 | 
				
			
			@ -246,13 +301,11 @@ if __name__ == '__main__':
 | 
			
		|||
            bot.reply_to(message, 'empty text prompt ignored.')
 | 
			
		||||
            return
 | 
			
		||||
 | 
			
		||||
        user = message.from_user
 | 
			
		||||
        db_user = get_or_create_user(user.id)
 | 
			
		||||
 | 
			
		||||
        logging.info(f"{user.first_name} ({user.id}) on chat {chat.id} txt2img: {prompt}")
 | 
			
		||||
 | 
			
		||||
        user_conf = db_user['config']
 | 
			
		||||
 | 
			
		||||
        algo = user_conf['algo']
 | 
			
		||||
        step = user_conf['step']
 | 
			
		||||
        size = user_conf['size']
 | 
			
		||||
        seed = user_conf['seed'] if user_conf['seed'] else random.randint(0, 999999999)
 | 
			
		||||
| 
						 | 
				
			
			@ -260,7 +313,7 @@ if __name__ == '__main__':
 | 
			
		|||
 | 
			
		||||
        try:
 | 
			
		||||
            reply_txt, name = img_for_user_with_prompt(
 | 
			
		||||
                user.id, prompt, step, size, guidance, seed)
 | 
			
		||||
                user.id, prompt, step, size, guidance, seed, algo)
 | 
			
		||||
 | 
			
		||||
            update_user(
 | 
			
		||||
                user.id,
 | 
			
		||||
| 
						 | 
				
			
			@ -279,18 +332,17 @@ if __name__ == '__main__':
 | 
			
		|||
            bot.reply_to(message, 'that command caused an error :(\nchange settings and try again (?')
 | 
			
		||||
 | 
			
		||||
    @bot.message_handler(commands=['redo'])
 | 
			
		||||
    @round_robined
 | 
			
		||||
    def redo_txt2img(message):
 | 
			
		||||
        if not its_my_turn():
 | 
			
		||||
            return
 | 
			
		||||
 | 
			
		||||
        # check msg comes from testing group
 | 
			
		||||
        chat = message.chat
 | 
			
		||||
        if chat.type != 'group' and chat.id != GROUP_ID:
 | 
			
		||||
            return
 | 
			
		||||
 | 
			
		||||
        user = message.from_user
 | 
			
		||||
        db_user = get_or_create_user(user.id)
 | 
			
		||||
 | 
			
		||||
        if ((chat.type != 'group' and chat.id != GROUP_ID) and
 | 
			
		||||
                (db_user['role'] not in MP_ENABLED_ROLES)):
 | 
			
		||||
            return
 | 
			
		||||
 | 
			
		||||
        prompt = db_user['last_prompt']
 | 
			
		||||
 | 
			
		||||
        if not prompt:
 | 
			
		||||
| 
						 | 
				
			
			@ -299,6 +351,7 @@ if __name__ == '__main__':
 | 
			
		|||
 | 
			
		||||
        user_conf = db_user['config']
 | 
			
		||||
 | 
			
		||||
        algo = user_conf['algo']
 | 
			
		||||
        step = user_conf['step']
 | 
			
		||||
        size = user_conf['size']
 | 
			
		||||
        seed = user_conf['seed'] if user_conf['seed'] else random.randint(0, 999999999)
 | 
			
		||||
| 
						 | 
				
			
			@ -308,7 +361,7 @@ if __name__ == '__main__':
 | 
			
		|||
 | 
			
		||||
        try:
 | 
			
		||||
            reply_txt, name = img_for_user_with_prompt(
 | 
			
		||||
                user.id, prompt, step, size, guidance, seed)
 | 
			
		||||
                user.id, prompt, step, size, guidance, seed, algo)
 | 
			
		||||
 | 
			
		||||
            update_user(
 | 
			
		||||
                user.id,
 | 
			
		||||
| 
						 | 
				
			
			@ -326,9 +379,9 @@ if __name__ == '__main__':
 | 
			
		|||
            bot.reply_to(message, 'that command caused an error :(\nchange settings and try again (?')
 | 
			
		||||
 | 
			
		||||
    @bot.message_handler(commands=['config'])
 | 
			
		||||
    @round_robined
 | 
			
		||||
    def set_config(message):
 | 
			
		||||
        if not its_my_turn():
 | 
			
		||||
            return
 | 
			
		||||
        logging.info(f'config req on chat: {message.chat.id}')
 | 
			
		||||
 | 
			
		||||
        params = message.text.split(' ')
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -338,15 +391,17 @@ if __name__ == '__main__':
 | 
			
		|||
        else:
 | 
			
		||||
            user = message.from_user
 | 
			
		||||
            chat = message.chat
 | 
			
		||||
            db_user = get_user(user.id)
 | 
			
		||||
 | 
			
		||||
            if not db_user:
 | 
			
		||||
                db_user = new_user(user.id)
 | 
			
		||||
            db_user = get_or_create_user(user.id)
 | 
			
		||||
 | 
			
		||||
            try:
 | 
			
		||||
                attr = params[1]
 | 
			
		||||
 | 
			
		||||
                if attr == 'step':
 | 
			
		||||
                if attr == 'algo':
 | 
			
		||||
                    val = params[2]
 | 
			
		||||
                    assert val in ALGOS
 | 
			
		||||
                    res = update_user(user.id, {'$set': {'config.algo': val}})
 | 
			
		||||
 | 
			
		||||
                elif attr == 'step':
 | 
			
		||||
                    val = int(params[2])
 | 
			
		||||
                    val = max(min(val, MAX_STEP), MIN_STEP)
 | 
			
		||||
                    res = update_user(user.id, {'$set': {'config.step': val}})
 | 
			
		||||
| 
						 | 
				
			
			@ -378,27 +433,30 @@ if __name__ == '__main__':
 | 
			
		|||
                    val = max(min(val, MAX_GUIDANCE), 0)
 | 
			
		||||
                    res = update_user(user.id, {'$set': {'config.guidance': val}})
 | 
			
		||||
 | 
			
		||||
                bot.reply_to(message, f"config updated! {attr} to {val}")
 | 
			
		||||
                else:
 | 
			
		||||
                    bot.reply_to(message, f'\"{attr}\" not a parameter')
 | 
			
		||||
 | 
			
		||||
                bot.reply_to(message, f'config updated! {attr} to {val}')
 | 
			
		||||
 | 
			
		||||
            except ValueError:
 | 
			
		||||
                bot.reply_to(message, f"\"{val}\" is not a number silly")
 | 
			
		||||
                bot.reply_to(message, f'\"{val}\" is not a number silly')
 | 
			
		||||
 | 
			
		||||
            except AssertionError:
 | 
			
		||||
                bot.reply_to(message, f'no algo named {val}')
 | 
			
		||||
 | 
			
		||||
    @bot.message_handler(commands=['stats'])
 | 
			
		||||
    @round_robined
 | 
			
		||||
    def user_stats(message):
 | 
			
		||||
        if not its_my_turn():
 | 
			
		||||
            return
 | 
			
		||||
 | 
			
		||||
        user = message.from_user
 | 
			
		||||
        db_user = get_user(user.id)
 | 
			
		||||
 | 
			
		||||
        if not db_user:
 | 
			
		||||
            db_user = new_user(user.id)
 | 
			
		||||
        db_user = get_or_create_user(user.id)
 | 
			
		||||
        migrate_user(db_user)
 | 
			
		||||
 | 
			
		||||
        joined_date_str = datetime.fromisoformat(db_user['joined']).strftime('%B the %dth %Y, %H:%M:%S')
 | 
			
		||||
 | 
			
		||||
        user_stats_str = f'generated: {db_user["generated"]}\n'
 | 
			
		||||
        user_stats_str += f'joined: {joined_date_str}\n'
 | 
			
		||||
        user_stats_str += f'credits: {db_user["credits"]}\n'
 | 
			
		||||
        user_stats_str += f'role: {db_user["role"]}\n'
 | 
			
		||||
 | 
			
		||||
        bot.reply_to(
 | 
			
		||||
            message, user_stats_str)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue