mirror of https://github.com/skygpu/skynet.git
				
				
				
			Initial commit
						commit
						66f21ff9de
					
				| 
						 | 
				
			
			@ -0,0 +1,23 @@
 | 
			
		|||
from pytorch/pytorch:latest
 | 
			
		||||
 | 
			
		||||
run apt-get update && apt-get install -y git curl
 | 
			
		||||
 | 
			
		||||
run pip install --upgrade \
 | 
			
		||||
    diffusers[torch] \
 | 
			
		||||
    transformers \
 | 
			
		||||
    huggingface_hub \
 | 
			
		||||
    pyTelegramBotAPI \
 | 
			
		||||
    pymongo \
 | 
			
		||||
    pdbpp
 | 
			
		||||
 | 
			
		||||
run mkdir /scripts
 | 
			
		||||
run mkdir /outputs
 | 
			
		||||
run mkdir /inputs
 | 
			
		||||
 | 
			
		||||
env HF_HOME /hf_home
 | 
			
		||||
 | 
			
		||||
run mkdir /hf_home
 | 
			
		||||
 | 
			
		||||
workdir /scripts
 | 
			
		||||
 | 
			
		||||
env PYTORCH_CUDA_ALLOC_CONF max_split_size_mb:128
 | 
			
		||||
| 
						 | 
				
			
			@ -0,0 +1,11 @@
 | 
			
		|||
docker run \
 | 
			
		||||
    -it \
 | 
			
		||||
    --rm \
 | 
			
		||||
    --gpus=all \
 | 
			
		||||
    --env HF_TOKEN='' \
 | 
			
		||||
    --env DB_USER='' \
 | 
			
		||||
    --env DB_PASS='' \
 | 
			
		||||
    --mount type=bind,source="$(pwd)"/outputs,target=/outputs \
 | 
			
		||||
    --mount type=bind,source="$(pwd)"/hf_home,target=/hf_home \
 | 
			
		||||
    --mount type=bind,source="$(pwd)"/scripts,target=/scripts \
 | 
			
		||||
    skynet-art-bot:0.1a3 $1
 | 
			
		||||
| 
						 | 
				
			
			@ -0,0 +1,9 @@
 | 
			
		|||
docker run
 | 
			
		||||
    -d \
 | 
			
		||||
    --rm \
 | 
			
		||||
    -p 27017:27017 \
 | 
			
		||||
	--name mongodb-skynet \
 | 
			
		||||
    --mount type=bind,source="$(pwd)"/mongodb,target=/data/db \
 | 
			
		||||
	-e MONGO_INITDB_ROOT_USERNAME="" \
 | 
			
		||||
	-e MONGO_INITDB_ROOT_PASSWORD="" \
 | 
			
		||||
    mongo	
 | 
			
		||||
| 
						 | 
				
			
			@ -0,0 +1,382 @@
 | 
			
		|||
#!/usr/bin/python
 | 
			
		||||
 | 
			
		||||
import os
 | 
			
		||||
 | 
			
		||||
import logging
 | 
			
		||||
import random
 | 
			
		||||
 | 
			
		||||
from torch.multiprocessing import spawn
 | 
			
		||||
 | 
			
		||||
import telebot
 | 
			
		||||
from telebot.types import InputFile
 | 
			
		||||
 | 
			
		||||
import sys
 | 
			
		||||
import uuid
 | 
			
		||||
 | 
			
		||||
from pathlib import Path
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
from torch.multiprocessing.spawn import ProcessRaisedException
 | 
			
		||||
from diffusers import StableDiffusionPipeline, EulerAncestralDiscreteScheduler
 | 
			
		||||
 | 
			
		||||
from huggingface_hub import login
 | 
			
		||||
from datetime import datetime
 | 
			
		||||
 | 
			
		||||
from pymongo import MongoClient
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
db_user = os.environ['DB_USER']
 | 
			
		||||
db_pass = os.environ['DB_PASS']
 | 
			
		||||
 | 
			
		||||
logging.basicConfig(level=logging.INFO)
 | 
			
		||||
 | 
			
		||||
MEM_FRACTION = .33
 | 
			
		||||
 | 
			
		||||
HELP_TEXT = '''
 | 
			
		||||
test art bot v0.1a3
 | 
			
		||||
 | 
			
		||||
commands work on a user per user basis!
 | 
			
		||||
config is individual to each user!
 | 
			
		||||
 | 
			
		||||
/txt2img {prompt} - request an image based on a prompt
 | 
			
		||||
 | 
			
		||||
/redo - redo last primpt
 | 
			
		||||
 | 
			
		||||
/cool - list of cool words to use
 | 
			
		||||
 | 
			
		||||
/stats - user statistics
 | 
			
		||||
 | 
			
		||||
/config step {number} - set amount of iterations
 | 
			
		||||
/config seed {number} - set the seed, deterministic results!
 | 
			
		||||
/config size {width} {height} - set size in pixels
 | 
			
		||||
/config guidance {number} - prompt text importance
 | 
			
		||||
'''
 | 
			
		||||
 | 
			
		||||
COOL_WORDS = [
 | 
			
		||||
    'cyberpunk',
 | 
			
		||||
    'soviet propaganda poster',
 | 
			
		||||
    'rastafari',
 | 
			
		||||
    'cannabis',
 | 
			
		||||
    'art deco',
 | 
			
		||||
    'H R Giger Necronom IV',
 | 
			
		||||
    'dimethyltryptamine',
 | 
			
		||||
    'lysergic',
 | 
			
		||||
    'slut',
 | 
			
		||||
    'psilocybin',
 | 
			
		||||
    'trippy',
 | 
			
		||||
    'lucy in the sky with diamonds',
 | 
			
		||||
    'fractal',
 | 
			
		||||
    'da vinci',
 | 
			
		||||
    'pencil illustration',
 | 
			
		||||
    'blueprint',
 | 
			
		||||
    'internal diagram',
 | 
			
		||||
    'baroque',
 | 
			
		||||
    'the last judgment',
 | 
			
		||||
    'michelangelo'
 | 
			
		||||
]
 | 
			
		||||
 | 
			
		||||
GROUP_ID = -889553587
 | 
			
		||||
 | 
			
		||||
MIN_STEP = 1
 | 
			
		||||
MAX_STEP = 100
 | 
			
		||||
MAX_SIZE = (512, 656)
 | 
			
		||||
MAX_GUIDANCE = 20
 | 
			
		||||
 | 
			
		||||
DEFAULT_SIZE = (512, 512)
 | 
			
		||||
DEFAULT_GUIDANCE = 7.5
 | 
			
		||||
DEFAULT_STEP = 75
 | 
			
		||||
DEFAULT_CREDITS = 10
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def generate_image(i, prompt, name, step, size, guidance, seed):
 | 
			
		||||
    assert torch.cuda.is_available()
 | 
			
		||||
    torch.cuda.empty_cache()
 | 
			
		||||
    torch.cuda.set_per_process_memory_fraction(MEM_FRACTION)
 | 
			
		||||
    pipe = StableDiffusionPipeline.from_pretrained(
 | 
			
		||||
        "runwayml/stable-diffusion-v1-5",
 | 
			
		||||
        torch_dtype=torch.float16,
 | 
			
		||||
        revision="fp16",
 | 
			
		||||
        safety_checker=None
 | 
			
		||||
    )
 | 
			
		||||
    pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
 | 
			
		||||
    pipe = pipe.to("cuda")
 | 
			
		||||
    w, h = size
 | 
			
		||||
    print(f'generating image... of size {w, h}')
 | 
			
		||||
    image = pipe(
 | 
			
		||||
        prompt,
 | 
			
		||||
        width=w,
 | 
			
		||||
        height=h,
 | 
			
		||||
        guidance_scale=guidance, num_inference_steps=step,
 | 
			
		||||
        generator=torch.Generator("cuda").manual_seed(seed)
 | 
			
		||||
    ).images[0]
 | 
			
		||||
    image.save(f'/outputs/{name}.png')
 | 
			
		||||
    print('saved')
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == '__main__':
 | 
			
		||||
 | 
			
		||||
    API_TOKEN = '5880619053:AAFge2UfObw1kCn9Kb7AAyqaIHs_HgM0Fx0'
 | 
			
		||||
 | 
			
		||||
    bot = telebot.TeleBot(API_TOKEN)
 | 
			
		||||
    db_client = MongoClient(f'mongodb://{db_user}:{db_pass}@localhost:27017')
 | 
			
		||||
 | 
			
		||||
    rr_id = 0
 | 
			
		||||
 | 
			
		||||
    tgdb = db_client.get_database('telegram')
 | 
			
		||||
 | 
			
		||||
    collections = tgdb.list_collection_names()
 | 
			
		||||
 | 
			
		||||
    if 'users' in collections:
 | 
			
		||||
        tg_users = tgdb.get_collection('users')
 | 
			
		||||
        # tg_users.delete_many({})
 | 
			
		||||
 | 
			
		||||
    else:
 | 
			
		||||
        tg_users = tgdb.create_collection('users')
 | 
			
		||||
 | 
			
		||||
    # db functions
 | 
			
		||||
 | 
			
		||||
    def get_user(uid: int):
 | 
			
		||||
        return tg_users.find_one({'uid': uid})
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
    def new_user(uid: int):
 | 
			
		||||
        if get_user(uid):
 | 
			
		||||
            raise ValueError('User already present on db')
 | 
			
		||||
 | 
			
		||||
        res = tg_users.insert_one({
 | 
			
		||||
            'generated': 0,
 | 
			
		||||
            'uid': uid,
 | 
			
		||||
            'credits': DEFAULT_CREDITS,
 | 
			
		||||
            'joined': datetime.utcnow().isoformat(),
 | 
			
		||||
            'last_prompt': None,
 | 
			
		||||
            'config': {
 | 
			
		||||
                'step': DEFAULT_STEP,
 | 
			
		||||
                'size': DEFAULT_SIZE,
 | 
			
		||||
                'seed': None,
 | 
			
		||||
                'guidance': DEFAULT_GUIDANCE
 | 
			
		||||
            }
 | 
			
		||||
        })
 | 
			
		||||
 | 
			
		||||
        assert res.acknowledged
 | 
			
		||||
 | 
			
		||||
        return get_user(uid)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
    def get_or_create_user(uid: int):
 | 
			
		||||
        db_user = get_user(uid)
 | 
			
		||||
 | 
			
		||||
        if not db_user:
 | 
			
		||||
            db_user = new_user(uid)
 | 
			
		||||
 | 
			
		||||
        return db_user
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
    def update_user(uid: int, updt_cmd: dict):
 | 
			
		||||
        user = get_user(uid)
 | 
			
		||||
        if not user:
 | 
			
		||||
            raise ValueError('User not present on db')
 | 
			
		||||
 | 
			
		||||
        return tg_users.find_one_and_update(
 | 
			
		||||
            {'uid': uid}, updt_cmd)
 | 
			
		||||
 | 
			
		||||
    # bot handler
 | 
			
		||||
 | 
			
		||||
    def img_for_user_with_prompt(
 | 
			
		||||
        uid: int,
 | 
			
		||||
        prompt: str, step: int, size: tuple[int, int], guidance: int, seed: int
 | 
			
		||||
    ):
 | 
			
		||||
        name = uuid.uuid4()
 | 
			
		||||
 | 
			
		||||
        spawn(
 | 
			
		||||
            generate_image,
 | 
			
		||||
            args=(prompt, name, step, size, guidance, seed))
 | 
			
		||||
 | 
			
		||||
        logging.info(f'done generating. got {name}, sending...')
 | 
			
		||||
 | 
			
		||||
        if len(prompt) > 256:
 | 
			
		||||
            reply_txt = f'prompt: \"{prompt[:256]}...\"\n(full prompt too big to show on reply...)\n'
 | 
			
		||||
 | 
			
		||||
        else:
 | 
			
		||||
            reply_txt = f'prompt: \"{prompt}\"\n'
 | 
			
		||||
 | 
			
		||||
        reply_txt +=  f'seed: {seed}\n'
 | 
			
		||||
        reply_txt +=  f'iterations: {step}\n'
 | 
			
		||||
        reply_txt +=  f'size: {size}\n'
 | 
			
		||||
        reply_txt +=  f'guidance: {guidance}\n'
 | 
			
		||||
        reply_txt +=  f'stable-diff v1.5 uncensored\n'
 | 
			
		||||
        reply_txt +=  f'euler ancestral discrete'
 | 
			
		||||
 | 
			
		||||
        return reply_txt, name
 | 
			
		||||
 | 
			
		||||
    @bot.message_handler(commands=['help'])
 | 
			
		||||
    def send_help(message):
 | 
			
		||||
        bot.reply_to(message, HELP_TEXT)
 | 
			
		||||
 | 
			
		||||
    @bot.message_handler(commands=['cool'])
 | 
			
		||||
    def send_help(message):
 | 
			
		||||
        bot.reply_to(message, '\n'.join(COOL_WORDS))
 | 
			
		||||
 | 
			
		||||
    @bot.message_handler(commands=['txt2img'])
 | 
			
		||||
    def send_txt2img(message):
 | 
			
		||||
        # check msg comes from testing group
 | 
			
		||||
        chat = message.chat
 | 
			
		||||
        if chat.type != 'group' and chat.id != GROUP_ID:
 | 
			
		||||
            return
 | 
			
		||||
 | 
			
		||||
        prompt = ' '.join(message.text.split(' ')[1:])
 | 
			
		||||
 | 
			
		||||
        if len(prompt) == 0:
 | 
			
		||||
            bot.reply_to(message, 'empty text prompt ignored.')
 | 
			
		||||
            return
 | 
			
		||||
 | 
			
		||||
        user = message.from_user
 | 
			
		||||
        db_user = get_or_create_user(user.id)
 | 
			
		||||
 | 
			
		||||
        logging.info(f"{user.first_name} ({user.id}) on chat {chat.id} txt2img: {prompt}")
 | 
			
		||||
 | 
			
		||||
        user_conf = db_user['config']
 | 
			
		||||
 | 
			
		||||
        step = user_conf['step']
 | 
			
		||||
        size = user_conf['size']
 | 
			
		||||
        seed = user_conf['seed'] if user_conf['seed'] else random.randint(0, 999999999)
 | 
			
		||||
        guidance = user_conf['guidance']
 | 
			
		||||
 | 
			
		||||
        try:
 | 
			
		||||
            reply_txt, name = img_for_user_with_prompt(
 | 
			
		||||
                user.id, prompt, step, size, guidance, seed)
 | 
			
		||||
 | 
			
		||||
            update_user(
 | 
			
		||||
                user.id,
 | 
			
		||||
                {'$set': {
 | 
			
		||||
                    'generated': db_user['generated'] + 1,
 | 
			
		||||
                    'last_prompt': prompt
 | 
			
		||||
                    }})
 | 
			
		||||
 | 
			
		||||
            bot.send_photo(
 | 
			
		||||
                chat.id,
 | 
			
		||||
                caption=f'sent by: {user.first_name}\n' + reply_txt,
 | 
			
		||||
                photo=InputFile(f'/outputs/{name}.png'))
 | 
			
		||||
 | 
			
		||||
        except BaseException as e:
 | 
			
		||||
            logging.error(e)
 | 
			
		||||
            bot.reply_to(message, 'that command caused an error :(\nchange settings and try again (?')
 | 
			
		||||
 | 
			
		||||
    @bot.message_handler(commands=['redo'])
 | 
			
		||||
    def redo_txt2img(message):
 | 
			
		||||
        # check msg comes from testing group
 | 
			
		||||
        chat = message.chat
 | 
			
		||||
        if chat.type != 'group' and chat.id != GROUP_ID:
 | 
			
		||||
            return
 | 
			
		||||
 | 
			
		||||
        user = message.from_user
 | 
			
		||||
        db_user = get_or_create_user(user.id)
 | 
			
		||||
 | 
			
		||||
        prompt = db_user['last_prompt']
 | 
			
		||||
 | 
			
		||||
        if not prompt:
 | 
			
		||||
            bot.reply_to(message, 'do a /txt2img command first silly!')
 | 
			
		||||
            return
 | 
			
		||||
 | 
			
		||||
        user_conf = db_user['config']
 | 
			
		||||
 | 
			
		||||
        step = user_conf['step']
 | 
			
		||||
        size = user_conf['size']
 | 
			
		||||
        seed = user_conf['seed'] if user_conf['seed'] else random.randint(0, 999999999)
 | 
			
		||||
        guidance = user_conf['guidance']
 | 
			
		||||
 | 
			
		||||
        logging.info(f"{user.first_name} ({user.id}) on chat {chat.id} redo: {prompt}")
 | 
			
		||||
 | 
			
		||||
        try:
 | 
			
		||||
            reply_txt, name = img_for_user_with_prompt(
 | 
			
		||||
                user.id, prompt, step, size, guidance, seed)
 | 
			
		||||
 | 
			
		||||
            update_user(
 | 
			
		||||
                user.id,
 | 
			
		||||
                {'$set': {
 | 
			
		||||
                    'generated': db_user['generated'] + 1,
 | 
			
		||||
                    }})
 | 
			
		||||
 | 
			
		||||
            bot.send_photo(
 | 
			
		||||
                chat.id,
 | 
			
		||||
                caption=f'sent by: {user.first_name}\n' + reply_txt,
 | 
			
		||||
                photo=InputFile(f'/outputs/{name}.png'))
 | 
			
		||||
 | 
			
		||||
        except BaseException as e:
 | 
			
		||||
            logging.error(e)
 | 
			
		||||
            bot.reply_to(message, 'that command caused an error :(\nchange settings and try again (?')
 | 
			
		||||
 | 
			
		||||
    @bot.message_handler(commands=['config'])
 | 
			
		||||
    def set_config(message):
 | 
			
		||||
        params = message.text.split(' ')
 | 
			
		||||
 | 
			
		||||
        if len(params) < 3:
 | 
			
		||||
            bot.reply_to(message, 'wrong msg format')
 | 
			
		||||
 | 
			
		||||
        else:
 | 
			
		||||
            user = message.from_user
 | 
			
		||||
            chat = message.chat
 | 
			
		||||
            db_user = get_user(user.id)
 | 
			
		||||
 | 
			
		||||
            if not db_user:
 | 
			
		||||
                db_user = new_user(user.id)
 | 
			
		||||
 | 
			
		||||
            try:
 | 
			
		||||
                attr = params[1]
 | 
			
		||||
 | 
			
		||||
                if attr == 'step':
 | 
			
		||||
                    val = int(params[2])
 | 
			
		||||
                    val = max(min(val, MAX_STEP), MIN_STEP)
 | 
			
		||||
                    res = update_user(user.id, {'$set': {'config.step': val}})
 | 
			
		||||
 | 
			
		||||
                elif attr  == 'size':
 | 
			
		||||
                    max_w, max_h = MAX_SIZE
 | 
			
		||||
                    w = max(min(int(params[2]), max_w), 16)
 | 
			
		||||
                    h = max(min(int(params[3]), max_h), 16)
 | 
			
		||||
 | 
			
		||||
                    val = (w, h)
 | 
			
		||||
 | 
			
		||||
                    if (w % 8 != 0) or (h % 8 != 0):
 | 
			
		||||
                        bot.reply_to(message, 'size must be divisible by 8!')
 | 
			
		||||
                        return
 | 
			
		||||
 | 
			
		||||
                    res = update_user(user.id, {'$set': {'config.size': val}})
 | 
			
		||||
 | 
			
		||||
                elif attr == 'seed':
 | 
			
		||||
                    val = params[2]
 | 
			
		||||
                    if val == 'auto':
 | 
			
		||||
                        val = None
 | 
			
		||||
                    else:
 | 
			
		||||
                        val = int(params[2])
 | 
			
		||||
 | 
			
		||||
                    res = update_user(user.id, {'$set': {'config.seed': val}})
 | 
			
		||||
 | 
			
		||||
                elif attr == 'guidance':
 | 
			
		||||
                    val = float(params[2])
 | 
			
		||||
                    val = max(min(val, MAX_GUIDANCE), 0)
 | 
			
		||||
                    res = update_user(user.id, {'$set': {'config.guidance': val}})
 | 
			
		||||
 | 
			
		||||
                bot.reply_to(message, f"config updated! {attr} to {val}")
 | 
			
		||||
 | 
			
		||||
            except ValueError:
 | 
			
		||||
                bot.reply_to(message, f"\"{val}\" is not a number silly")
 | 
			
		||||
 | 
			
		||||
    @bot.message_handler(commands=['stats'])
 | 
			
		||||
    def user_stats(message):
 | 
			
		||||
        user = message.from_user
 | 
			
		||||
        db_user = get_user(user.id)
 | 
			
		||||
 | 
			
		||||
        if not db_user:
 | 
			
		||||
            db_user = new_user(user.id)
 | 
			
		||||
 | 
			
		||||
        joined_date_str = datetime.fromisoformat(db_user['joined']).strftime('%B the %dth %Y, %H:%M:%S')
 | 
			
		||||
 | 
			
		||||
        user_stats_str = f'generated: {db_user["generated"]}\n'
 | 
			
		||||
        user_stats_str += f'joined: {joined_date_str}\n'
 | 
			
		||||
        user_stats_str += f'credits: {db_user["credits"]}\n'
 | 
			
		||||
 | 
			
		||||
        bot.reply_to(
 | 
			
		||||
            message, user_stats_str)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
    login(token=os.environ['HF_TOKEN'])
 | 
			
		||||
    bot.infinity_polling()
 | 
			
		||||
| 
						 | 
				
			
			@ -0,0 +1,30 @@
 | 
			
		|||
import os
 | 
			
		||||
import sys
 | 
			
		||||
 | 
			
		||||
from pathlib import Path
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
from diffusers import StableDiffusionPipeline
 | 
			
		||||
 | 
			
		||||
from huggingface_hub import login
 | 
			
		||||
 | 
			
		||||
assert torch.cuda.is_available()
 | 
			
		||||
 | 
			
		||||
login(token=os.environ['HF_TOKEN'])
 | 
			
		||||
 | 
			
		||||
pipe = StableDiffusionPipeline.from_pretrained(
 | 
			
		||||
    "runwayml/stable-diffusion-v1-5",
 | 
			
		||||
    torch_dtype=torch.float16,
 | 
			
		||||
    revision="fp16"
 | 
			
		||||
)
 | 
			
		||||
pipe = pipe.to("cuda")
 | 
			
		||||
 | 
			
		||||
prompt = sys.argv[1]
 | 
			
		||||
image = pipe(
 | 
			
		||||
    prompt,
 | 
			
		||||
    width=640,
 | 
			
		||||
    height=640,
 | 
			
		||||
    guidance_scale=7.5, num_inference_steps=120
 | 
			
		||||
).images[0]
 | 
			
		||||
 | 
			
		||||
image.save("/outputs/img.png")
 | 
			
		||||
		Loading…
	
		Reference in New Issue