mirror of https://github.com/skygpu/skynet.git
commit
563efed3e9
|
@ -0,0 +1,23 @@
|
|||
from pytorch/pytorch:latest
|
||||
|
||||
run apt-get update && apt-get install -y git curl
|
||||
|
||||
run pip install --upgrade \
|
||||
diffusers[torch] \
|
||||
transformers \
|
||||
huggingface_hub \
|
||||
pyTelegramBotAPI \
|
||||
pymongo \
|
||||
pdbpp
|
||||
|
||||
run mkdir /scripts
|
||||
run mkdir /outputs
|
||||
run mkdir /inputs
|
||||
|
||||
env HF_HOME /hf_home
|
||||
|
||||
run mkdir /hf_home
|
||||
|
||||
workdir /scripts
|
||||
|
||||
env PYTORCH_CUDA_ALLOC_CONF max_split_size_mb:128
|
|
@ -0,0 +1,11 @@
|
|||
docker run \
|
||||
-it \
|
||||
--rm \
|
||||
--gpus=all \
|
||||
--env HF_TOKEN='' \
|
||||
--env DB_USER='' \
|
||||
--env DB_PASS='' \
|
||||
--mount type=bind,source="$(pwd)"/outputs,target=/outputs \
|
||||
--mount type=bind,source="$(pwd)"/hf_home,target=/hf_home \
|
||||
--mount type=bind,source="$(pwd)"/scripts,target=/scripts \
|
||||
skynet-art-bot:0.1a3 $1
|
|
@ -0,0 +1,9 @@
|
|||
docker run
|
||||
-d \
|
||||
--rm \
|
||||
-p 27017:27017 \
|
||||
--name mongodb-skynet \
|
||||
--mount type=bind,source="$(pwd)"/mongodb,target=/data/db \
|
||||
-e MONGO_INITDB_ROOT_USERNAME="" \
|
||||
-e MONGO_INITDB_ROOT_PASSWORD="" \
|
||||
mongo
|
|
@ -0,0 +1,382 @@
|
|||
#!/usr/bin/python
|
||||
|
||||
import os
|
||||
|
||||
import logging
|
||||
import random
|
||||
|
||||
from torch.multiprocessing import spawn
|
||||
|
||||
import telebot
|
||||
from telebot.types import InputFile
|
||||
|
||||
import sys
|
||||
import uuid
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from torch.multiprocessing.spawn import ProcessRaisedException
|
||||
from diffusers import StableDiffusionPipeline, EulerAncestralDiscreteScheduler
|
||||
|
||||
from huggingface_hub import login
|
||||
from datetime import datetime
|
||||
|
||||
from pymongo import MongoClient
|
||||
|
||||
|
||||
db_user = os.environ['DB_USER']
|
||||
db_pass = os.environ['DB_PASS']
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
MEM_FRACTION = .33
|
||||
|
||||
HELP_TEXT = '''
|
||||
test art bot v0.1a3
|
||||
|
||||
commands work on a user per user basis!
|
||||
config is individual to each user!
|
||||
|
||||
/txt2img {prompt} - request an image based on a prompt
|
||||
|
||||
/redo - redo last primpt
|
||||
|
||||
/cool - list of cool words to use
|
||||
|
||||
/stats - user statistics
|
||||
|
||||
/config step {number} - set amount of iterations
|
||||
/config seed {number} - set the seed, deterministic results!
|
||||
/config size {width} {height} - set size in pixels
|
||||
/config guidance {number} - prompt text importance
|
||||
'''
|
||||
|
||||
COOL_WORDS = [
|
||||
'cyberpunk',
|
||||
'soviet propaganda poster',
|
||||
'rastafari',
|
||||
'cannabis',
|
||||
'art deco',
|
||||
'H R Giger Necronom IV',
|
||||
'dimethyltryptamine',
|
||||
'lysergic',
|
||||
'slut',
|
||||
'psilocybin',
|
||||
'trippy',
|
||||
'lucy in the sky with diamonds',
|
||||
'fractal',
|
||||
'da vinci',
|
||||
'pencil illustration',
|
||||
'blueprint',
|
||||
'internal diagram',
|
||||
'baroque',
|
||||
'the last judgment',
|
||||
'michelangelo'
|
||||
]
|
||||
|
||||
GROUP_ID = -889553587
|
||||
|
||||
MIN_STEP = 1
|
||||
MAX_STEP = 100
|
||||
MAX_SIZE = (512, 656)
|
||||
MAX_GUIDANCE = 20
|
||||
|
||||
DEFAULT_SIZE = (512, 512)
|
||||
DEFAULT_GUIDANCE = 7.5
|
||||
DEFAULT_STEP = 75
|
||||
DEFAULT_CREDITS = 10
|
||||
|
||||
|
||||
def generate_image(i, prompt, name, step, size, guidance, seed):
|
||||
assert torch.cuda.is_available()
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.set_per_process_memory_fraction(MEM_FRACTION)
|
||||
pipe = StableDiffusionPipeline.from_pretrained(
|
||||
"runwayml/stable-diffusion-v1-5",
|
||||
torch_dtype=torch.float16,
|
||||
revision="fp16",
|
||||
safety_checker=None
|
||||
)
|
||||
pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
|
||||
pipe = pipe.to("cuda")
|
||||
w, h = size
|
||||
print(f'generating image... of size {w, h}')
|
||||
image = pipe(
|
||||
prompt,
|
||||
width=w,
|
||||
height=h,
|
||||
guidance_scale=guidance, num_inference_steps=step,
|
||||
generator=torch.Generator("cuda").manual_seed(seed)
|
||||
).images[0]
|
||||
image.save(f'/outputs/{name}.png')
|
||||
print('saved')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
API_TOKEN = '5880619053:AAFge2UfObw1kCn9Kb7AAyqaIHs_HgM0Fx0'
|
||||
|
||||
bot = telebot.TeleBot(API_TOKEN)
|
||||
db_client = MongoClient(f'mongodb://{db_user}:{db_pass}@localhost:27017')
|
||||
|
||||
rr_id = 0
|
||||
|
||||
tgdb = db_client.get_database('telegram')
|
||||
|
||||
collections = tgdb.list_collection_names()
|
||||
|
||||
if 'users' in collections:
|
||||
tg_users = tgdb.get_collection('users')
|
||||
# tg_users.delete_many({})
|
||||
|
||||
else:
|
||||
tg_users = tgdb.create_collection('users')
|
||||
|
||||
# db functions
|
||||
|
||||
def get_user(uid: int):
|
||||
return tg_users.find_one({'uid': uid})
|
||||
|
||||
|
||||
def new_user(uid: int):
|
||||
if get_user(uid):
|
||||
raise ValueError('User already present on db')
|
||||
|
||||
res = tg_users.insert_one({
|
||||
'generated': 0,
|
||||
'uid': uid,
|
||||
'credits': DEFAULT_CREDITS,
|
||||
'joined': datetime.utcnow().isoformat(),
|
||||
'last_prompt': None,
|
||||
'config': {
|
||||
'step': DEFAULT_STEP,
|
||||
'size': DEFAULT_SIZE,
|
||||
'seed': None,
|
||||
'guidance': DEFAULT_GUIDANCE
|
||||
}
|
||||
})
|
||||
|
||||
assert res.acknowledged
|
||||
|
||||
return get_user(uid)
|
||||
|
||||
|
||||
def get_or_create_user(uid: int):
|
||||
db_user = get_user(uid)
|
||||
|
||||
if not db_user:
|
||||
db_user = new_user(uid)
|
||||
|
||||
return db_user
|
||||
|
||||
|
||||
def update_user(uid: int, updt_cmd: dict):
|
||||
user = get_user(uid)
|
||||
if not user:
|
||||
raise ValueError('User not present on db')
|
||||
|
||||
return tg_users.find_one_and_update(
|
||||
{'uid': uid}, updt_cmd)
|
||||
|
||||
# bot handler
|
||||
|
||||
def img_for_user_with_prompt(
|
||||
uid: int,
|
||||
prompt: str, step: int, size: tuple[int, int], guidance: int, seed: int
|
||||
):
|
||||
name = uuid.uuid4()
|
||||
|
||||
spawn(
|
||||
generate_image,
|
||||
args=(prompt, name, step, size, guidance, seed))
|
||||
|
||||
logging.info(f'done generating. got {name}, sending...')
|
||||
|
||||
if len(prompt) > 256:
|
||||
reply_txt = f'prompt: \"{prompt[:256]}...\"\n(full prompt too big to show on reply...)\n'
|
||||
|
||||
else:
|
||||
reply_txt = f'prompt: \"{prompt}\"\n'
|
||||
|
||||
reply_txt += f'seed: {seed}\n'
|
||||
reply_txt += f'iterations: {step}\n'
|
||||
reply_txt += f'size: {size}\n'
|
||||
reply_txt += f'guidance: {guidance}\n'
|
||||
reply_txt += f'stable-diff v1.5 uncensored\n'
|
||||
reply_txt += f'euler ancestral discrete'
|
||||
|
||||
return reply_txt, name
|
||||
|
||||
@bot.message_handler(commands=['help'])
|
||||
def send_help(message):
|
||||
bot.reply_to(message, HELP_TEXT)
|
||||
|
||||
@bot.message_handler(commands=['cool'])
|
||||
def send_help(message):
|
||||
bot.reply_to(message, '\n'.join(COOL_WORDS))
|
||||
|
||||
@bot.message_handler(commands=['txt2img'])
|
||||
def send_txt2img(message):
|
||||
# check msg comes from testing group
|
||||
chat = message.chat
|
||||
if chat.type != 'group' and chat.id != GROUP_ID:
|
||||
return
|
||||
|
||||
prompt = ' '.join(message.text.split(' ')[1:])
|
||||
|
||||
if len(prompt) == 0:
|
||||
bot.reply_to(message, 'empty text prompt ignored.')
|
||||
return
|
||||
|
||||
user = message.from_user
|
||||
db_user = get_or_create_user(user.id)
|
||||
|
||||
logging.info(f"{user.first_name} ({user.id}) on chat {chat.id} txt2img: {prompt}")
|
||||
|
||||
user_conf = db_user['config']
|
||||
|
||||
step = user_conf['step']
|
||||
size = user_conf['size']
|
||||
seed = user_conf['seed'] if user_conf['seed'] else random.randint(0, 999999999)
|
||||
guidance = user_conf['guidance']
|
||||
|
||||
try:
|
||||
reply_txt, name = img_for_user_with_prompt(
|
||||
user.id, prompt, step, size, guidance, seed)
|
||||
|
||||
update_user(
|
||||
user.id,
|
||||
{'$set': {
|
||||
'generated': db_user['generated'] + 1,
|
||||
'last_prompt': prompt
|
||||
}})
|
||||
|
||||
bot.send_photo(
|
||||
chat.id,
|
||||
caption=f'sent by: {user.first_name}\n' + reply_txt,
|
||||
photo=InputFile(f'/outputs/{name}.png'))
|
||||
|
||||
except BaseException as e:
|
||||
logging.error(e)
|
||||
bot.reply_to(message, 'that command caused an error :(\nchange settings and try again (?')
|
||||
|
||||
@bot.message_handler(commands=['redo'])
|
||||
def redo_txt2img(message):
|
||||
# check msg comes from testing group
|
||||
chat = message.chat
|
||||
if chat.type != 'group' and chat.id != GROUP_ID:
|
||||
return
|
||||
|
||||
user = message.from_user
|
||||
db_user = get_or_create_user(user.id)
|
||||
|
||||
prompt = db_user['last_prompt']
|
||||
|
||||
if not prompt:
|
||||
bot.reply_to(message, 'do a /txt2img command first silly!')
|
||||
return
|
||||
|
||||
user_conf = db_user['config']
|
||||
|
||||
step = user_conf['step']
|
||||
size = user_conf['size']
|
||||
seed = user_conf['seed'] if user_conf['seed'] else random.randint(0, 999999999)
|
||||
guidance = user_conf['guidance']
|
||||
|
||||
logging.info(f"{user.first_name} ({user.id}) on chat {chat.id} redo: {prompt}")
|
||||
|
||||
try:
|
||||
reply_txt, name = img_for_user_with_prompt(
|
||||
user.id, prompt, step, size, guidance, seed)
|
||||
|
||||
update_user(
|
||||
user.id,
|
||||
{'$set': {
|
||||
'generated': db_user['generated'] + 1,
|
||||
}})
|
||||
|
||||
bot.send_photo(
|
||||
chat.id,
|
||||
caption=f'sent by: {user.first_name}\n' + reply_txt,
|
||||
photo=InputFile(f'/outputs/{name}.png'))
|
||||
|
||||
except BaseException as e:
|
||||
logging.error(e)
|
||||
bot.reply_to(message, 'that command caused an error :(\nchange settings and try again (?')
|
||||
|
||||
@bot.message_handler(commands=['config'])
|
||||
def set_config(message):
|
||||
params = message.text.split(' ')
|
||||
|
||||
if len(params) < 3:
|
||||
bot.reply_to(message, 'wrong msg format')
|
||||
|
||||
else:
|
||||
user = message.from_user
|
||||
chat = message.chat
|
||||
db_user = get_user(user.id)
|
||||
|
||||
if not db_user:
|
||||
db_user = new_user(user.id)
|
||||
|
||||
try:
|
||||
attr = params[1]
|
||||
|
||||
if attr == 'step':
|
||||
val = int(params[2])
|
||||
val = max(min(val, MAX_STEP), MIN_STEP)
|
||||
res = update_user(user.id, {'$set': {'config.step': val}})
|
||||
|
||||
elif attr == 'size':
|
||||
max_w, max_h = MAX_SIZE
|
||||
w = max(min(int(params[2]), max_w), 16)
|
||||
h = max(min(int(params[3]), max_h), 16)
|
||||
|
||||
val = (w, h)
|
||||
|
||||
if (w % 8 != 0) or (h % 8 != 0):
|
||||
bot.reply_to(message, 'size must be divisible by 8!')
|
||||
return
|
||||
|
||||
res = update_user(user.id, {'$set': {'config.size': val}})
|
||||
|
||||
elif attr == 'seed':
|
||||
val = params[2]
|
||||
if val == 'auto':
|
||||
val = None
|
||||
else:
|
||||
val = int(params[2])
|
||||
|
||||
res = update_user(user.id, {'$set': {'config.seed': val}})
|
||||
|
||||
elif attr == 'guidance':
|
||||
val = float(params[2])
|
||||
val = max(min(val, MAX_GUIDANCE), 0)
|
||||
res = update_user(user.id, {'$set': {'config.guidance': val}})
|
||||
|
||||
bot.reply_to(message, f"config updated! {attr} to {val}")
|
||||
|
||||
except ValueError:
|
||||
bot.reply_to(message, f"\"{val}\" is not a number silly")
|
||||
|
||||
@bot.message_handler(commands=['stats'])
|
||||
def user_stats(message):
|
||||
user = message.from_user
|
||||
db_user = get_user(user.id)
|
||||
|
||||
if not db_user:
|
||||
db_user = new_user(user.id)
|
||||
|
||||
joined_date_str = datetime.fromisoformat(db_user['joined']).strftime('%B the %dth %Y, %H:%M:%S')
|
||||
|
||||
user_stats_str = f'generated: {db_user["generated"]}\n'
|
||||
user_stats_str += f'joined: {joined_date_str}\n'
|
||||
user_stats_str += f'credits: {db_user["credits"]}\n'
|
||||
|
||||
bot.reply_to(
|
||||
message, user_stats_str)
|
||||
|
||||
|
||||
login(token=os.environ['HF_TOKEN'])
|
||||
bot.infinity_polling()
|
|
@ -0,0 +1,30 @@
|
|||
import os
|
||||
import sys
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from diffusers import StableDiffusionPipeline
|
||||
|
||||
from huggingface_hub import login
|
||||
|
||||
assert torch.cuda.is_available()
|
||||
|
||||
login(token=os.environ['HF_TOKEN'])
|
||||
|
||||
pipe = StableDiffusionPipeline.from_pretrained(
|
||||
"runwayml/stable-diffusion-v1-5",
|
||||
torch_dtype=torch.float16,
|
||||
revision="fp16"
|
||||
)
|
||||
pipe = pipe.to("cuda")
|
||||
|
||||
prompt = sys.argv[1]
|
||||
image = pipe(
|
||||
prompt,
|
||||
width=640,
|
||||
height=640,
|
||||
guidance_scale=7.5, num_inference_steps=120
|
||||
).images[0]
|
||||
|
||||
image.save("/outputs/img.png")
|
Loading…
Reference in New Issue