mirror of https://github.com/skygpu/skynet.git
Add roles and start working on algo selection, clean up rr code
parent
318a21ac81
commit
8bd255717e
|
@ -4,7 +4,7 @@ docker run \
|
||||||
--gpus=all \
|
--gpus=all \
|
||||||
--env HF_TOKEN='' \
|
--env HF_TOKEN='' \
|
||||||
--env DB_USER='skynet' \
|
--env DB_USER='skynet' \
|
||||||
--env DB_PASS='password' \
|
--env DB_PASS='nnf01nmf091d0i' \
|
||||||
--mount type=bind,source="$(pwd)"/outputs,target=/outputs \
|
--mount type=bind,source="$(pwd)"/outputs,target=/outputs \
|
||||||
--mount type=bind,source="$(pwd)"/hf_home,target=/hf_home \
|
--mount type=bind,source="$(pwd)"/hf_home,target=/hf_home \
|
||||||
--mount type=bind,source="$(pwd)"/scripts,target=/scripts \
|
--mount type=bind,source="$(pwd)"/scripts,target=/scripts \
|
||||||
|
|
|
@ -17,13 +17,17 @@ from pathlib import Path
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch.multiprocessing.spawn import ProcessRaisedException
|
from torch.multiprocessing.spawn import ProcessRaisedException
|
||||||
from diffusers import StableDiffusionPipeline, EulerAncestralDiscreteScheduler
|
from diffusers import (
|
||||||
|
FlaxStableDiffusionPipeline,
|
||||||
|
StableDiffusionPipeline,
|
||||||
|
EulerAncestralDiscreteScheduler
|
||||||
|
)
|
||||||
|
|
||||||
from huggingface_hub import login
|
from huggingface_hub import login
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
from pymongo import MongoClient
|
from pymongo import MongoClient
|
||||||
|
# import jax.numpy as jnp
|
||||||
|
|
||||||
db_user = os.environ['DB_USER']
|
db_user = os.environ['DB_USER']
|
||||||
db_pass = os.environ['DB_PASS']
|
db_pass = os.environ['DB_PASS']
|
||||||
|
@ -75,7 +79,11 @@ COOL_WORDS = [
|
||||||
'michelangelo'
|
'michelangelo'
|
||||||
]
|
]
|
||||||
|
|
||||||
GROUP_ID = -889553587
|
GROUP_ID = -1001541979235
|
||||||
|
|
||||||
|
ALGOS = ['stable', 'midj']
|
||||||
|
|
||||||
|
MP_ENABLED_ROLES = ['god']
|
||||||
|
|
||||||
MIN_STEP = 1
|
MIN_STEP = 1
|
||||||
MAX_STEP = 100
|
MAX_STEP = 100
|
||||||
|
@ -86,6 +94,8 @@ DEFAULT_SIZE = (512, 512)
|
||||||
DEFAULT_GUIDANCE = 7.5
|
DEFAULT_GUIDANCE = 7.5
|
||||||
DEFAULT_STEP = 75
|
DEFAULT_STEP = 75
|
||||||
DEFAULT_CREDITS = 10
|
DEFAULT_CREDITS = 10
|
||||||
|
DEFAULT_ALGO = 'stable'
|
||||||
|
DEFAULT_ROLE = 'pleb'
|
||||||
|
|
||||||
rr_total = 2
|
rr_total = 2
|
||||||
rr_id = 0
|
rr_id = 0
|
||||||
|
@ -98,18 +108,47 @@ def its_my_turn():
|
||||||
request_counter += 1
|
request_counter += 1
|
||||||
return my_turn
|
return my_turn
|
||||||
|
|
||||||
|
def round_robined(func):
|
||||||
|
def rr_wrapper(*args, **kwargs):
|
||||||
|
if not its_my_turn():
|
||||||
|
return
|
||||||
|
|
||||||
def generate_image(i, prompt, name, step, size, guidance, seed):
|
func(*args, **kwargs)
|
||||||
|
|
||||||
|
return rr_wrapper
|
||||||
|
|
||||||
|
|
||||||
|
def generate_image(
|
||||||
|
i: int,
|
||||||
|
prompt: str,
|
||||||
|
name: str,
|
||||||
|
step: int,
|
||||||
|
size: tuple[int, int],
|
||||||
|
guidance: int,
|
||||||
|
seed: int,
|
||||||
|
algo: str
|
||||||
|
):
|
||||||
assert torch.cuda.is_available()
|
assert torch.cuda.is_available()
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
torch.cuda.set_per_process_memory_fraction(MEM_FRACTION)
|
torch.cuda.set_per_process_memory_fraction(MEM_FRACTION)
|
||||||
pipe = StableDiffusionPipeline.from_pretrained(
|
|
||||||
"runwayml/stable-diffusion-v1-5",
|
if algo == 'stable':
|
||||||
torch_dtype=torch.float16,
|
pipe = StableDiffusionPipeline.from_pretrained(
|
||||||
revision="fp16",
|
'runwayml/stable-diffusion-v1-5',
|
||||||
safety_checker=None
|
torch_dtype=torch.float16,
|
||||||
)
|
revision="fp16",
|
||||||
pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
|
safety_checker=None
|
||||||
|
)
|
||||||
|
pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
|
||||||
|
|
||||||
|
elif algo == 'midj':
|
||||||
|
import jax as jnp
|
||||||
|
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
|
||||||
|
'flax/midjourney-v4-diffusion',
|
||||||
|
revision="bf16",
|
||||||
|
dtype= jnp.bfloat16,
|
||||||
|
)
|
||||||
|
|
||||||
pipe = pipe.to("cuda")
|
pipe = pipe.to("cuda")
|
||||||
w, h = size
|
w, h = size
|
||||||
print(f'generating image... of size {w, h}')
|
print(f'generating image... of size {w, h}')
|
||||||
|
@ -161,7 +200,9 @@ if __name__ == '__main__':
|
||||||
'credits': DEFAULT_CREDITS,
|
'credits': DEFAULT_CREDITS,
|
||||||
'joined': datetime.utcnow().isoformat(),
|
'joined': datetime.utcnow().isoformat(),
|
||||||
'last_prompt': None,
|
'last_prompt': None,
|
||||||
|
'role': DEFAULT_ROLE,
|
||||||
'config': {
|
'config': {
|
||||||
|
'algo': DEFAULT_ALGO,
|
||||||
'step': DEFAULT_STEP,
|
'step': DEFAULT_STEP,
|
||||||
'size': DEFAULT_SIZE,
|
'size': DEFAULT_SIZE,
|
||||||
'seed': None,
|
'seed': None,
|
||||||
|
@ -173,6 +214,18 @@ if __name__ == '__main__':
|
||||||
|
|
||||||
return get_user(uid)
|
return get_user(uid)
|
||||||
|
|
||||||
|
def migrate_user(db_user):
|
||||||
|
# new: user roles
|
||||||
|
if 'role' not in db_user:
|
||||||
|
res = tg_users.find_one_and_update(
|
||||||
|
{'uid': db_user['uid']}, {'$set': {'role': DEFAULT_ROLE}})
|
||||||
|
|
||||||
|
# new: ai selection
|
||||||
|
if 'algo' not in db_user['config']:
|
||||||
|
res = tg_users.find_one_and_update(
|
||||||
|
{'uid': db_user['uid']}, {'$set': {'config.algo': DEFAULT_ALGO}})
|
||||||
|
|
||||||
|
return get_user(db_user['uid'])
|
||||||
|
|
||||||
def get_or_create_user(uid: int):
|
def get_or_create_user(uid: int):
|
||||||
db_user = get_user(uid)
|
db_user = get_user(uid)
|
||||||
|
@ -180,8 +233,9 @@ if __name__ == '__main__':
|
||||||
if not db_user:
|
if not db_user:
|
||||||
db_user = new_user(uid)
|
db_user = new_user(uid)
|
||||||
|
|
||||||
return db_user
|
logging.info(f'req from: {uid}')
|
||||||
|
|
||||||
|
return migrate_user(db_user)
|
||||||
|
|
||||||
def update_user(uid: int, updt_cmd: dict):
|
def update_user(uid: int, updt_cmd: dict):
|
||||||
user = get_user(uid)
|
user = get_user(uid)
|
||||||
|
@ -195,13 +249,13 @@ if __name__ == '__main__':
|
||||||
# bot handler
|
# bot handler
|
||||||
def img_for_user_with_prompt(
|
def img_for_user_with_prompt(
|
||||||
uid: int,
|
uid: int,
|
||||||
prompt: str, step: int, size: tuple[int, int], guidance: int, seed: int
|
prompt: str, step: int, size: tuple[int, int], guidance: int, seed: int, algo: str
|
||||||
):
|
):
|
||||||
name = uuid.uuid4()
|
name = uuid.uuid4()
|
||||||
|
|
||||||
spawn(
|
spawn(
|
||||||
generate_image,
|
generate_image,
|
||||||
args=(prompt, name, step, size, guidance, seed))
|
args=(prompt, name, step, size, guidance, seed, algo))
|
||||||
|
|
||||||
logging.info(f'done generating. got {name}, sending...')
|
logging.info(f'done generating. got {name}, sending...')
|
||||||
|
|
||||||
|
@ -221,23 +275,24 @@ if __name__ == '__main__':
|
||||||
return reply_txt, name
|
return reply_txt, name
|
||||||
|
|
||||||
@bot.message_handler(commands=['help'])
|
@bot.message_handler(commands=['help'])
|
||||||
|
@round_robined
|
||||||
def send_help(message):
|
def send_help(message):
|
||||||
if its_my_turn():
|
bot.reply_to(message, HELP_TEXT)
|
||||||
bot.reply_to(message, HELP_TEXT)
|
|
||||||
|
|
||||||
@bot.message_handler(commands=['cool'])
|
@bot.message_handler(commands=['cool'])
|
||||||
|
@round_robined
|
||||||
def send_cool_words(message):
|
def send_cool_words(message):
|
||||||
if its_my_turn():
|
bot.reply_to(message, '\n'.join(COOL_WORDS))
|
||||||
bot.reply_to(message, '\n'.join(COOL_WORDS))
|
|
||||||
|
|
||||||
@bot.message_handler(commands=['txt2img'])
|
@bot.message_handler(commands=['txt2img'])
|
||||||
|
@round_robined
|
||||||
def send_txt2img(message):
|
def send_txt2img(message):
|
||||||
if not its_my_turn():
|
|
||||||
return
|
|
||||||
|
|
||||||
# check msg comes from testing group
|
|
||||||
chat = message.chat
|
chat = message.chat
|
||||||
if chat.type != 'group' and chat.id != GROUP_ID:
|
user = message.from_user
|
||||||
|
db_user = get_or_create_user(user.id)
|
||||||
|
|
||||||
|
if ((chat.type != 'group' and chat.id != GROUP_ID) and
|
||||||
|
(db_user['role'] not in MP_ENABLED_ROLES)):
|
||||||
return
|
return
|
||||||
|
|
||||||
prompt = ' '.join(message.text.split(' ')[1:])
|
prompt = ' '.join(message.text.split(' ')[1:])
|
||||||
|
@ -246,13 +301,11 @@ if __name__ == '__main__':
|
||||||
bot.reply_to(message, 'empty text prompt ignored.')
|
bot.reply_to(message, 'empty text prompt ignored.')
|
||||||
return
|
return
|
||||||
|
|
||||||
user = message.from_user
|
|
||||||
db_user = get_or_create_user(user.id)
|
|
||||||
|
|
||||||
logging.info(f"{user.first_name} ({user.id}) on chat {chat.id} txt2img: {prompt}")
|
logging.info(f"{user.first_name} ({user.id}) on chat {chat.id} txt2img: {prompt}")
|
||||||
|
|
||||||
user_conf = db_user['config']
|
user_conf = db_user['config']
|
||||||
|
|
||||||
|
algo = user_conf['algo']
|
||||||
step = user_conf['step']
|
step = user_conf['step']
|
||||||
size = user_conf['size']
|
size = user_conf['size']
|
||||||
seed = user_conf['seed'] if user_conf['seed'] else random.randint(0, 999999999)
|
seed = user_conf['seed'] if user_conf['seed'] else random.randint(0, 999999999)
|
||||||
|
@ -260,7 +313,7 @@ if __name__ == '__main__':
|
||||||
|
|
||||||
try:
|
try:
|
||||||
reply_txt, name = img_for_user_with_prompt(
|
reply_txt, name = img_for_user_with_prompt(
|
||||||
user.id, prompt, step, size, guidance, seed)
|
user.id, prompt, step, size, guidance, seed, algo)
|
||||||
|
|
||||||
update_user(
|
update_user(
|
||||||
user.id,
|
user.id,
|
||||||
|
@ -279,18 +332,17 @@ if __name__ == '__main__':
|
||||||
bot.reply_to(message, 'that command caused an error :(\nchange settings and try again (?')
|
bot.reply_to(message, 'that command caused an error :(\nchange settings and try again (?')
|
||||||
|
|
||||||
@bot.message_handler(commands=['redo'])
|
@bot.message_handler(commands=['redo'])
|
||||||
|
@round_robined
|
||||||
def redo_txt2img(message):
|
def redo_txt2img(message):
|
||||||
if not its_my_turn():
|
|
||||||
return
|
|
||||||
|
|
||||||
# check msg comes from testing group
|
# check msg comes from testing group
|
||||||
chat = message.chat
|
chat = message.chat
|
||||||
if chat.type != 'group' and chat.id != GROUP_ID:
|
|
||||||
return
|
|
||||||
|
|
||||||
user = message.from_user
|
user = message.from_user
|
||||||
db_user = get_or_create_user(user.id)
|
db_user = get_or_create_user(user.id)
|
||||||
|
|
||||||
|
if ((chat.type != 'group' and chat.id != GROUP_ID) and
|
||||||
|
(db_user['role'] not in MP_ENABLED_ROLES)):
|
||||||
|
return
|
||||||
|
|
||||||
prompt = db_user['last_prompt']
|
prompt = db_user['last_prompt']
|
||||||
|
|
||||||
if not prompt:
|
if not prompt:
|
||||||
|
@ -299,6 +351,7 @@ if __name__ == '__main__':
|
||||||
|
|
||||||
user_conf = db_user['config']
|
user_conf = db_user['config']
|
||||||
|
|
||||||
|
algo = user_conf['algo']
|
||||||
step = user_conf['step']
|
step = user_conf['step']
|
||||||
size = user_conf['size']
|
size = user_conf['size']
|
||||||
seed = user_conf['seed'] if user_conf['seed'] else random.randint(0, 999999999)
|
seed = user_conf['seed'] if user_conf['seed'] else random.randint(0, 999999999)
|
||||||
|
@ -308,7 +361,7 @@ if __name__ == '__main__':
|
||||||
|
|
||||||
try:
|
try:
|
||||||
reply_txt, name = img_for_user_with_prompt(
|
reply_txt, name = img_for_user_with_prompt(
|
||||||
user.id, prompt, step, size, guidance, seed)
|
user.id, prompt, step, size, guidance, seed, algo)
|
||||||
|
|
||||||
update_user(
|
update_user(
|
||||||
user.id,
|
user.id,
|
||||||
|
@ -326,9 +379,9 @@ if __name__ == '__main__':
|
||||||
bot.reply_to(message, 'that command caused an error :(\nchange settings and try again (?')
|
bot.reply_to(message, 'that command caused an error :(\nchange settings and try again (?')
|
||||||
|
|
||||||
@bot.message_handler(commands=['config'])
|
@bot.message_handler(commands=['config'])
|
||||||
|
@round_robined
|
||||||
def set_config(message):
|
def set_config(message):
|
||||||
if not its_my_turn():
|
logging.info(f'config req on chat: {message.chat.id}')
|
||||||
return
|
|
||||||
|
|
||||||
params = message.text.split(' ')
|
params = message.text.split(' ')
|
||||||
|
|
||||||
|
@ -338,15 +391,17 @@ if __name__ == '__main__':
|
||||||
else:
|
else:
|
||||||
user = message.from_user
|
user = message.from_user
|
||||||
chat = message.chat
|
chat = message.chat
|
||||||
db_user = get_user(user.id)
|
db_user = get_or_create_user(user.id)
|
||||||
|
|
||||||
if not db_user:
|
|
||||||
db_user = new_user(user.id)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
attr = params[1]
|
attr = params[1]
|
||||||
|
|
||||||
if attr == 'step':
|
if attr == 'algo':
|
||||||
|
val = params[2]
|
||||||
|
assert val in ALGOS
|
||||||
|
res = update_user(user.id, {'$set': {'config.algo': val}})
|
||||||
|
|
||||||
|
elif attr == 'step':
|
||||||
val = int(params[2])
|
val = int(params[2])
|
||||||
val = max(min(val, MAX_STEP), MIN_STEP)
|
val = max(min(val, MAX_STEP), MIN_STEP)
|
||||||
res = update_user(user.id, {'$set': {'config.step': val}})
|
res = update_user(user.id, {'$set': {'config.step': val}})
|
||||||
|
@ -378,27 +433,30 @@ if __name__ == '__main__':
|
||||||
val = max(min(val, MAX_GUIDANCE), 0)
|
val = max(min(val, MAX_GUIDANCE), 0)
|
||||||
res = update_user(user.id, {'$set': {'config.guidance': val}})
|
res = update_user(user.id, {'$set': {'config.guidance': val}})
|
||||||
|
|
||||||
bot.reply_to(message, f"config updated! {attr} to {val}")
|
else:
|
||||||
|
bot.reply_to(message, f'\"{attr}\" not a parameter')
|
||||||
|
|
||||||
|
bot.reply_to(message, f'config updated! {attr} to {val}')
|
||||||
|
|
||||||
except ValueError:
|
except ValueError:
|
||||||
bot.reply_to(message, f"\"{val}\" is not a number silly")
|
bot.reply_to(message, f'\"{val}\" is not a number silly')
|
||||||
|
|
||||||
|
except AssertionError:
|
||||||
|
bot.reply_to(message, f'no algo named {val}')
|
||||||
|
|
||||||
@bot.message_handler(commands=['stats'])
|
@bot.message_handler(commands=['stats'])
|
||||||
|
@round_robined
|
||||||
def user_stats(message):
|
def user_stats(message):
|
||||||
if not its_my_turn():
|
|
||||||
return
|
|
||||||
|
|
||||||
user = message.from_user
|
user = message.from_user
|
||||||
db_user = get_user(user.id)
|
db_user = get_or_create_user(user.id)
|
||||||
|
migrate_user(db_user)
|
||||||
if not db_user:
|
|
||||||
db_user = new_user(user.id)
|
|
||||||
|
|
||||||
joined_date_str = datetime.fromisoformat(db_user['joined']).strftime('%B the %dth %Y, %H:%M:%S')
|
joined_date_str = datetime.fromisoformat(db_user['joined']).strftime('%B the %dth %Y, %H:%M:%S')
|
||||||
|
|
||||||
user_stats_str = f'generated: {db_user["generated"]}\n'
|
user_stats_str = f'generated: {db_user["generated"]}\n'
|
||||||
user_stats_str += f'joined: {joined_date_str}\n'
|
user_stats_str += f'joined: {joined_date_str}\n'
|
||||||
user_stats_str += f'credits: {db_user["credits"]}\n'
|
user_stats_str += f'credits: {db_user["credits"]}\n'
|
||||||
|
user_stats_str += f'role: {db_user["role"]}\n'
|
||||||
|
|
||||||
bot.reply_to(
|
bot.reply_to(
|
||||||
message, user_stats_str)
|
message, user_stats_str)
|
||||||
|
|
Loading…
Reference in New Issue