mirror of https://github.com/skygpu/skynet.git
First decupled architecture, still working on integrating tractor gpu workers
parent
91e0693e65
commit
74d2426793
|
@ -0,0 +1,3 @@
|
|||
hf_home
|
||||
inputs
|
||||
outputs
|
|
@ -1,2 +1,5 @@
|
|||
.python-version
|
||||
hf_home
|
||||
outputs
|
||||
**/__pycache__
|
||||
*.egg-info
|
||||
|
|
31
Dockerfile
31
Dockerfile
|
@ -1,31 +0,0 @@
|
|||
from pytorch/pytorch:latest
|
||||
|
||||
env DEBIAN_FRONTEND=noninteractive
|
||||
|
||||
run apt-get update && apt-get install -y git wget
|
||||
|
||||
run conda install xformers -c xformers/label/dev
|
||||
|
||||
run pip install --upgrade \
|
||||
diffusers[torch] \
|
||||
accelerate \
|
||||
transformers \
|
||||
huggingface_hub \
|
||||
pyTelegramBotAPI \
|
||||
pymongo \
|
||||
scipy \
|
||||
pdbpp
|
||||
|
||||
env NVIDIA_VISIBLE_DEVICES=all
|
||||
|
||||
run mkdir /scripts
|
||||
run mkdir /outputs
|
||||
run mkdir /inputs
|
||||
|
||||
env HF_HOME /hf_home
|
||||
|
||||
run mkdir /hf_home
|
||||
|
||||
workdir /scripts
|
||||
|
||||
env PYTORCH_CUDA_ALLOC_CONF max_split_size_mb:128
|
|
@ -0,0 +1,13 @@
|
|||
from python:3.10.0
|
||||
|
||||
env DEBIAN_FRONTEND=noninteractive
|
||||
|
||||
workdir /skynet
|
||||
|
||||
copy requirements.* ./
|
||||
|
||||
run pip install \
|
||||
-r requirements.txt \
|
||||
-r requirements.test.txt
|
||||
|
||||
workdir /scripts
|
|
@ -0,0 +1,23 @@
|
|||
from nvidia/cuda:11.7.0-devel-ubuntu20.04
|
||||
from python:3.10.0
|
||||
|
||||
env DEBIAN_FRONTEND=noninteractive
|
||||
|
||||
workdir /skynet
|
||||
|
||||
copy requirements.* .
|
||||
|
||||
run pip install -U pip ninja
|
||||
run pip install -r requirements.cuda.0.txt
|
||||
run pip install -v -r requirements.cuda.1.txt
|
||||
|
||||
run pip install \
|
||||
-r requirements.txt \
|
||||
-r requirements.test.txt
|
||||
|
||||
env NVIDIA_VISIBLE_DEVICES=all
|
||||
env HF_HOME /hf_home
|
||||
|
||||
env PYTORCH_CUDA_ALLOC_CONF max_split_size_mb:128
|
||||
|
||||
workdir /scripts
|
|
@ -0,0 +1,11 @@
|
|||
A menos que sea especificamente indicado en el cabezal del archivo, se reservan
|
||||
todos los derechos sobre este codigo por parte de:
|
||||
|
||||
Guillermo Rodriguez, guillermor@fing.edu.uy
|
||||
|
||||
ENGLISH LICENSE:
|
||||
|
||||
Unless specifically indicated in the file header, all rights to this code are
|
||||
reserved by:
|
||||
|
||||
Guillermo Rodriguez, guillermor@.edu.uy
|
|
@ -0,0 +1,47 @@
|
|||
create db in postgres:
|
||||
|
||||
```sql
|
||||
CREATE USER skynet WITH PASSWORD 'password';
|
||||
CREATE DATABASE skynet_art_bot;
|
||||
GRANT ALL PRIVILEGES ON DATABASE skynet_art_bot TO skynet;
|
||||
|
||||
CREATE SCHEMA IF NOT EXISTS skynet;
|
||||
|
||||
CREATE TABLE IF NOT EXISTS skynet.user(
|
||||
id SERIAL PRIMARY KEY NOT NULL,
|
||||
tg_id INT,
|
||||
wp_id VARCHAR(128),
|
||||
mx_id VARCHAR(128),
|
||||
ig_id VARCHAR(128),
|
||||
generated INT NOT NULL,
|
||||
joined DATE NOT NULL,
|
||||
last_prompt TEXT,
|
||||
role VARCHAR(128) NOT NULL
|
||||
);
|
||||
ALTER TABLE skynet.user
|
||||
ADD CONSTRAINT tg_unique
|
||||
UNIQUE (tg_id);
|
||||
ALTER TABLE skynet.user
|
||||
ADD CONSTRAINT wp_unique
|
||||
UNIQUE (wp_id);
|
||||
ALTER TABLE skynet.user
|
||||
ADD CONSTRAINT mx_unique
|
||||
UNIQUE (mx_id);
|
||||
ALTER TABLE skynet.user
|
||||
ADD CONSTRAINT ig_unique
|
||||
UNIQUE (ig_id);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS skynet.user_config(
|
||||
id SERIAL NOT NULL,
|
||||
algo VARCHAR(128) NOT NULL,
|
||||
step INT NOT NULL,
|
||||
width INT NOT NULL,
|
||||
height INT NOT NULL,
|
||||
seed INT,
|
||||
guidance INT NOT NULL,
|
||||
upscaler VARCHAR(128)
|
||||
);
|
||||
ALTER TABLE skynet.user_config
|
||||
ADD FOREIGN KEY(id)
|
||||
REFERENCES skynet.user(id);
|
||||
```
|
|
@ -0,0 +1,7 @@
|
|||
docker build \
|
||||
-t skynet:runtime-cuda \
|
||||
-f Dockerfile.runtime-cuda .
|
||||
|
||||
docker build \
|
||||
-t skynet:runtime \
|
||||
-f Dockerfile.runtime .
|
|
@ -0,0 +1,2 @@
|
|||
[pytest]
|
||||
trio_mode = true
|
|
@ -0,0 +1,8 @@
|
|||
pdbpp
|
||||
scipy
|
||||
accelerate
|
||||
transformers
|
||||
huggingface_hub
|
||||
diffusers[torch]
|
||||
torch==1.13.0+cu117
|
||||
--extra-index-url https://download.pytorch.org/whl/cu117
|
|
@ -0,0 +1 @@
|
|||
git+https://github.com/facebookresearch/xformers.git@main#egg=xformers
|
|
@ -0,0 +1,2 @@
|
|||
pytest
|
||||
pytest-trio
|
|
@ -0,0 +1,8 @@
|
|||
trio
|
||||
pynng
|
||||
triopg
|
||||
aiohttp
|
||||
msgspec
|
||||
trio_asyncio
|
||||
|
||||
git+https://github.com/goodboy/tractor.git@master#egg=tractor
|
14
run-bot.sh
14
run-bot.sh
|
@ -1,14 +0,0 @@
|
|||
mkdir -p outputs
|
||||
mkdir -p hf_home
|
||||
|
||||
docker run \
|
||||
-it \
|
||||
--rm \
|
||||
--gpus=all \
|
||||
--env HF_TOKEN='' \
|
||||
--env DB_USER='skynet' \
|
||||
--env DB_PASS='nnf01nmf091d0i' \
|
||||
--mount type=bind,source="$(pwd)"/outputs,target=/outputs \
|
||||
--mount type=bind,source="$(pwd)"/hf_home,target=/hf_home \
|
||||
--mount type=bind,source="$(pwd)"/scripts,target=/scripts \
|
||||
skynet:dif python telegram-bot-dev.py
|
|
@ -1,9 +0,0 @@
|
|||
docker run
|
||||
-d \
|
||||
--rm \
|
||||
-p 27017:27017 \
|
||||
--name mongodb-skynet \
|
||||
--mount type=bind,source="$(pwd)"/mongodb,target=/data/db \
|
||||
-e MONGO_INITDB_ROOT_USERNAME="" \
|
||||
-e MONGO_INITDB_ROOT_PASSWORD="" \
|
||||
mongo
|
|
@ -1,537 +0,0 @@
|
|||
#!/usr/bin/python
|
||||
|
||||
import os
|
||||
|
||||
import logging
|
||||
import random
|
||||
|
||||
from torch.multiprocessing import spawn
|
||||
|
||||
import telebot
|
||||
from telebot.types import InputFile
|
||||
|
||||
import sys
|
||||
import uuid
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from torch.multiprocessing.spawn import ProcessRaisedException
|
||||
from diffusers import (
|
||||
StableDiffusionPipeline,
|
||||
EulerAncestralDiscreteScheduler
|
||||
)
|
||||
|
||||
from huggingface_hub import login
|
||||
from datetime import datetime
|
||||
|
||||
from pymongo import MongoClient
|
||||
|
||||
from typing import Tuple, Optional
|
||||
|
||||
db_user = os.environ['DB_USER']
|
||||
db_pass = os.environ['DB_PASS']
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
MEM_FRACTION = .33
|
||||
|
||||
ALGOS = {
|
||||
'stable': 'runwayml/stable-diffusion-v1-5',
|
||||
'midj': 'prompthero/openjourney',
|
||||
'hdanime': 'Linaqruf/anything-v3.0',
|
||||
'waifu': 'hakurei/waifu-diffusion',
|
||||
'ghibli': 'nitrosocke/Ghibli-Diffusion',
|
||||
'van-gogh': 'dallinmackay/Van-Gogh-diffusion',
|
||||
'pokemon': 'lambdalabs/sd-pokemon-diffusers',
|
||||
'ink': 'Envvi/Inkpunk-Diffusion',
|
||||
'robot': 'nousr/robo-diffusion'
|
||||
}
|
||||
|
||||
N = '\n'
|
||||
HELP_TEXT = f'''
|
||||
test art bot v0.1a4
|
||||
|
||||
commands work on a user per user basis!
|
||||
config is individual to each user!
|
||||
|
||||
/txt2img TEXT - request an image based on a prompt
|
||||
|
||||
/redo - redo last prompt
|
||||
|
||||
/cool - list of cool words to use
|
||||
/stats - user statistics
|
||||
/donate - see donation info
|
||||
|
||||
/config algo NAME - select AI to use one of:
|
||||
|
||||
{N.join(ALGOS.keys())}
|
||||
|
||||
/config step NUMBER - set amount of iterations
|
||||
/config seed NUMBER - set the seed, deterministic results!
|
||||
/config size WIDTH HEIGHT - set size in pixels
|
||||
/config guidance NUMBER - prompt text importance
|
||||
'''
|
||||
|
||||
UNKNOWN_CMD_TEXT = 'unknown command! try sending \"/help\"'
|
||||
|
||||
DONATION_INFO = '0xf95335682DF281FFaB7E104EB87B69625d9622B6\ngoal: 25/650usd'
|
||||
|
||||
COOL_WORDS = [
|
||||
'cyberpunk',
|
||||
'soviet propaganda poster',
|
||||
'rastafari',
|
||||
'cannabis',
|
||||
'art deco',
|
||||
'H R Giger Necronom IV',
|
||||
'dimethyltryptamine',
|
||||
'lysergic',
|
||||
'slut',
|
||||
'psilocybin',
|
||||
'trippy',
|
||||
'lucy in the sky with diamonds',
|
||||
'fractal',
|
||||
'da vinci',
|
||||
'pencil illustration',
|
||||
'blueprint',
|
||||
'internal diagram',
|
||||
'baroque',
|
||||
'the last judgment',
|
||||
'michelangelo'
|
||||
]
|
||||
|
||||
GROUP_ID = -1001541979235
|
||||
|
||||
MP_ENABLED_ROLES = ['god']
|
||||
|
||||
MIN_STEP = 1
|
||||
MAX_STEP = 100
|
||||
MAX_SIZE = (512, 656)
|
||||
MAX_GUIDANCE = 20
|
||||
|
||||
DEFAULT_SIZE = (512, 512)
|
||||
DEFAULT_GUIDANCE = 7.5
|
||||
DEFAULT_STEP = 75
|
||||
DEFAULT_CREDITS = 10
|
||||
DEFAULT_ALGO = 'stable'
|
||||
DEFAULT_ROLE = 'pleb'
|
||||
DEFAULT_UPSCALER = None
|
||||
|
||||
rr_total = 1
|
||||
rr_id = 0
|
||||
request_counter = 0
|
||||
|
||||
def its_my_turn():
|
||||
global request_counter, rr_total, rr_id
|
||||
my_turn = request_counter % rr_total == rr_id
|
||||
logging.info(f'new request {request_counter}, turn: {my_turn} rr_total: {rr_total}, rr_id {rr_id}')
|
||||
request_counter += 1
|
||||
return my_turn
|
||||
|
||||
def round_robined(func):
|
||||
def rr_wrapper(*args, **kwargs):
|
||||
if not its_my_turn():
|
||||
return
|
||||
|
||||
func(*args, **kwargs)
|
||||
|
||||
return rr_wrapper
|
||||
|
||||
|
||||
def generate_image(
|
||||
i: int,
|
||||
prompt: str,
|
||||
name: str,
|
||||
step: int,
|
||||
size: Tuple[int, int],
|
||||
guidance: int,
|
||||
seed: int,
|
||||
algo: str,
|
||||
upscaler: Optional[str]
|
||||
):
|
||||
assert torch.cuda.is_available()
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.set_per_process_memory_fraction(MEM_FRACTION)
|
||||
with torch.no_grad():
|
||||
if algo == 'stable':
|
||||
pipe = StableDiffusionPipeline.from_pretrained(
|
||||
'runwayml/stable-diffusion-v1-5',
|
||||
torch_dtype=torch.float16,
|
||||
revision="fp16",
|
||||
safety_checker=None
|
||||
)
|
||||
|
||||
else:
|
||||
pipe = StableDiffusionPipeline.from_pretrained(
|
||||
ALGOS[algo],
|
||||
torch_dtype=torch.float16,
|
||||
safety_checker=None
|
||||
)
|
||||
|
||||
pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
|
||||
pipe = pipe.to("cuda")
|
||||
w, h = size
|
||||
print(f'generating image... of size {w, h}')
|
||||
image = pipe(
|
||||
prompt,
|
||||
width=w,
|
||||
height=h,
|
||||
guidance_scale=guidance, num_inference_steps=step,
|
||||
generator=torch.Generator("cuda").manual_seed(seed)
|
||||
).images[0]
|
||||
|
||||
if upscaler == 'x4':
|
||||
pipe = StableDiffusionPipeline.from_pretrained(
|
||||
'stabilityai/stable-diffusion-x4-upscaler',
|
||||
revision="fp16",
|
||||
torch_dtype=torch.float16
|
||||
)
|
||||
image = pipe(prompt=prompt, image=image).images[0]
|
||||
|
||||
|
||||
image.save(f'/outputs/{name}.png')
|
||||
print('saved')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
API_TOKEN = '5880619053:AAFge2UfObw1kCn9Kb7AAyqaIHs_HgM0Fx0'
|
||||
|
||||
bot = telebot.TeleBot(API_TOKEN)
|
||||
db_client = MongoClient(
|
||||
host=['ancap.tech:64000'],
|
||||
username=db_user,
|
||||
password=db_pass)
|
||||
|
||||
tgdb = db_client.get_database('telegram')
|
||||
|
||||
collections = tgdb.list_collection_names()
|
||||
|
||||
if 'users' in collections:
|
||||
tg_users = tgdb.get_collection('users')
|
||||
# tg_users.delete_many({})
|
||||
|
||||
else:
|
||||
tg_users = tgdb.create_collection('users')
|
||||
|
||||
# db functions
|
||||
|
||||
def get_user(uid: int):
|
||||
return tg_users.find_one({'uid': uid})
|
||||
|
||||
|
||||
def new_user(uid: int):
|
||||
if get_user(uid):
|
||||
raise ValueError('User already present on db')
|
||||
|
||||
res = tg_users.insert_one({
|
||||
'generated': 0,
|
||||
'uid': uid,
|
||||
'credits': DEFAULT_CREDITS,
|
||||
'joined': datetime.utcnow().isoformat(),
|
||||
'last_prompt': None,
|
||||
'role': DEFAULT_ROLE,
|
||||
'config': {
|
||||
'algo': DEFAULT_ALGO,
|
||||
'step': DEFAULT_STEP,
|
||||
'size': DEFAULT_SIZE,
|
||||
'seed': None,
|
||||
'guidance': DEFAULT_GUIDANCE,
|
||||
'upscaler': DEFAULT_UPSCALER
|
||||
}
|
||||
})
|
||||
|
||||
assert res.acknowledged
|
||||
|
||||
return get_user(uid)
|
||||
|
||||
def migrate_user(db_user):
|
||||
# new: user roles
|
||||
if 'role' not in db_user:
|
||||
res = tg_users.find_one_and_update(
|
||||
{'uid': db_user['uid']}, {'$set': {'role': DEFAULT_ROLE}})
|
||||
|
||||
# new: algo selection
|
||||
if 'algo' not in db_user['config']:
|
||||
res = tg_users.find_one_and_update(
|
||||
{'uid': db_user['uid']}, {'$set': {'config.algo': DEFAULT_ALGO}})
|
||||
|
||||
# new: upscaler selection
|
||||
if 'upscaler' not in db_user['config']:
|
||||
res = tg_users.find_one_and_update(
|
||||
{'uid': db_user['uid']}, {'$set': {'config.upscaler': DEFAULT_UPSCALER}})
|
||||
|
||||
return get_user(db_user['uid'])
|
||||
|
||||
def get_or_create_user(uid: int):
|
||||
db_user = get_user(uid)
|
||||
|
||||
if not db_user:
|
||||
db_user = new_user(uid)
|
||||
|
||||
logging.info(f'req from: {uid}')
|
||||
|
||||
return migrate_user(db_user)
|
||||
|
||||
def update_user(uid: int, updt_cmd: dict):
|
||||
user = get_user(uid)
|
||||
if not user:
|
||||
raise ValueError('User not present on db')
|
||||
|
||||
return tg_users.find_one_and_update(
|
||||
{'uid': uid}, updt_cmd)
|
||||
|
||||
|
||||
# bot handler
|
||||
def img_for_user_with_prompt(
|
||||
uid: int,
|
||||
prompt: str, step: int, size: Tuple[int, int], guidance: int, seed: int,
|
||||
algo: str, upscaler: Optional[str]
|
||||
):
|
||||
name = uuid.uuid4()
|
||||
|
||||
spawn(
|
||||
generate_image,
|
||||
args=(prompt, name, step, size, guidance, seed, algo, upscaler))
|
||||
|
||||
logging.info(f'done generating. got {name}, sending...')
|
||||
|
||||
if len(prompt) > 256:
|
||||
reply_txt = f'prompt: \"{prompt[:256]}...\"\n(full prompt too big to show on reply...)\n'
|
||||
|
||||
else:
|
||||
reply_txt = f'prompt: \"{prompt}\"\n'
|
||||
|
||||
reply_txt += f'seed: {seed}\n'
|
||||
reply_txt += f'iterations: {step}\n'
|
||||
reply_txt += f'size: {size}\n'
|
||||
reply_txt += f'guidance: {guidance}\n'
|
||||
reply_txt += f'algo: {ALGOS[algo]}\n'
|
||||
reply_txt += f'euler ancestral discrete'
|
||||
|
||||
return reply_txt, name
|
||||
|
||||
@bot.message_handler(commands=['help'])
|
||||
@round_robined
|
||||
def send_help(message):
|
||||
bot.reply_to(message, HELP_TEXT)
|
||||
|
||||
@bot.message_handler(commands=['cool'])
|
||||
@round_robined
|
||||
def send_cool_words(message):
|
||||
bot.reply_to(message, '\n'.join(COOL_WORDS))
|
||||
|
||||
@bot.message_handler(commands=['txt2img'])
|
||||
@round_robined
|
||||
def send_txt2img(message):
|
||||
chat = message.chat
|
||||
user = message.from_user
|
||||
db_user = get_or_create_user(user.id)
|
||||
|
||||
if ((chat.type != 'group' and chat.id != GROUP_ID) and
|
||||
(db_user['role'] not in MP_ENABLED_ROLES)):
|
||||
return
|
||||
|
||||
prompt = ' '.join(message.text.split(' ')[1:])
|
||||
|
||||
if len(prompt) == 0:
|
||||
bot.reply_to(message, 'empty text prompt ignored.')
|
||||
return
|
||||
|
||||
logging.info(f"{user.first_name} ({user.id}) on chat {chat.id} txt2img: {prompt}")
|
||||
|
||||
user_conf = db_user['config']
|
||||
|
||||
algo = user_conf['algo']
|
||||
step = user_conf['step']
|
||||
size = user_conf['size']
|
||||
seed = user_conf['seed'] if user_conf['seed'] else random.randint(0, 999999999)
|
||||
guidance = user_conf['guidance']
|
||||
upscaler = user_conf['upscaler']
|
||||
|
||||
try:
|
||||
reply_txt, name = img_for_user_with_prompt(
|
||||
user.id, prompt, step, size, guidance, seed, algo, upscaler)
|
||||
|
||||
update_user(
|
||||
user.id,
|
||||
{'$set': {
|
||||
'generated': db_user['generated'] + 1,
|
||||
'last_prompt': prompt
|
||||
}})
|
||||
|
||||
bot.send_photo(
|
||||
chat.id,
|
||||
caption=f'sent by: {user.first_name}\n' + reply_txt,
|
||||
photo=InputFile(f'/outputs/{name}.png'))
|
||||
|
||||
except BaseException as e:
|
||||
logging.error(e)
|
||||
bot.reply_to(message, 'that command caused an error :(\nchange settings and try again (?')
|
||||
|
||||
@bot.message_handler(commands=['redo'])
|
||||
@round_robined
|
||||
def redo_txt2img(message):
|
||||
# check msg comes from testing group
|
||||
chat = message.chat
|
||||
user = message.from_user
|
||||
db_user = get_or_create_user(user.id)
|
||||
|
||||
if ((chat.type != 'group' and chat.id != GROUP_ID) and
|
||||
(db_user['role'] not in MP_ENABLED_ROLES)):
|
||||
return
|
||||
|
||||
prompt = db_user['last_prompt']
|
||||
|
||||
if not prompt:
|
||||
bot.reply_to(message, 'do a /txt2img command first silly!')
|
||||
return
|
||||
|
||||
user_conf = db_user['config']
|
||||
|
||||
algo = user_conf['algo']
|
||||
step = user_conf['step']
|
||||
size = user_conf['size']
|
||||
seed = user_conf['seed'] if user_conf['seed'] else random.randint(0, 999999999)
|
||||
guidance = user_conf['guidance']
|
||||
upscaler = user_conf['upscaler']
|
||||
|
||||
logging.info(f"{user.first_name} ({user.id}) on chat {chat.id} redo: {prompt}")
|
||||
|
||||
try:
|
||||
reply_txt, name = img_for_user_with_prompt(
|
||||
user.id, prompt, step, size, guidance, seed, algo, upscaler)
|
||||
|
||||
update_user(
|
||||
user.id,
|
||||
{'$set': {
|
||||
'generated': db_user['generated'] + 1,
|
||||
}})
|
||||
|
||||
bot.send_photo(
|
||||
chat.id,
|
||||
caption=f'sent by: {user.first_name}\n' + reply_txt,
|
||||
photo=InputFile(f'/outputs/{name}.png'))
|
||||
|
||||
except BaseException as e:
|
||||
logging.error(e)
|
||||
bot.reply_to(message, 'that command caused an error :(\nchange settings and try again (?')
|
||||
|
||||
@bot.message_handler(commands=['config'])
|
||||
@round_robined
|
||||
def set_config(message):
|
||||
logging.info(f'config req on chat: {message.chat.id}')
|
||||
|
||||
params = message.text.split(' ')
|
||||
|
||||
if len(params) < 3:
|
||||
bot.reply_to(message, 'wrong msg format')
|
||||
|
||||
else:
|
||||
user = message.from_user
|
||||
chat = message.chat
|
||||
db_user = get_or_create_user(user.id)
|
||||
|
||||
try:
|
||||
attr = params[1]
|
||||
|
||||
if attr == 'algo':
|
||||
val = params[2]
|
||||
assert val in ALGOS
|
||||
res = update_user(user.id, {'$set': {'config.algo': val}})
|
||||
|
||||
elif attr == 'step':
|
||||
val = int(params[2])
|
||||
val = max(min(val, MAX_STEP), MIN_STEP)
|
||||
res = update_user(user.id, {'$set': {'config.step': val}})
|
||||
|
||||
elif attr == 'size':
|
||||
max_w, max_h = MAX_SIZE
|
||||
w = max(min(int(params[2]), max_w), 16)
|
||||
h = max(min(int(params[3]), max_h), 16)
|
||||
|
||||
val = (w, h)
|
||||
|
||||
if (w % 8 != 0) or (h % 8 != 0):
|
||||
bot.reply_to(message, 'size must be divisible by 8!')
|
||||
return
|
||||
|
||||
res = update_user(user.id, {'$set': {'config.size': val}})
|
||||
|
||||
elif attr == 'seed':
|
||||
val = params[2]
|
||||
if val == 'auto':
|
||||
val = None
|
||||
else:
|
||||
val = int(params[2])
|
||||
|
||||
res = update_user(user.id, {'$set': {'config.seed': val}})
|
||||
|
||||
elif attr == 'guidance':
|
||||
val = float(params[2])
|
||||
val = max(min(val, MAX_GUIDANCE), 0)
|
||||
res = update_user(user.id, {'$set': {'config.guidance': val}})
|
||||
|
||||
elif attr == 'upscaler':
|
||||
val = params[2]
|
||||
if val == 'off':
|
||||
val = None
|
||||
|
||||
res = update_user(user.id, {'$set': {'config.upscaler': val}})
|
||||
|
||||
else:
|
||||
bot.reply_to(message, f'\"{attr}\" not a parameter')
|
||||
|
||||
bot.reply_to(message, f'config updated! {attr} to {val}')
|
||||
|
||||
except ValueError:
|
||||
bot.reply_to(message, f'\"{val}\" is not a number silly')
|
||||
|
||||
except AssertionError:
|
||||
bot.reply_to(message, f'no algo named {val}')
|
||||
|
||||
@bot.message_handler(commands=['stats'])
|
||||
@round_robined
|
||||
def user_stats(message):
|
||||
user = message.from_user
|
||||
db_user = get_or_create_user(user.id)
|
||||
migrate_user(db_user)
|
||||
|
||||
joined_date_str = datetime.fromisoformat(db_user['joined']).strftime('%B the %dth %Y, %H:%M:%S')
|
||||
|
||||
user_stats_str = f'generated: {db_user["generated"]}\n'
|
||||
user_stats_str += f'joined: {joined_date_str}\n'
|
||||
user_stats_str += f'credits: {db_user["credits"]}\n'
|
||||
user_stats_str += f'role: {db_user["role"]}\n'
|
||||
|
||||
bot.reply_to(
|
||||
message, user_stats_str)
|
||||
|
||||
@bot.message_handler(commands=['donate'])
|
||||
@round_robined
|
||||
def donation_info(message):
|
||||
bot.reply_to(
|
||||
message, DONATION_INFO)
|
||||
|
||||
@bot.message_handler(commands=['say'])
|
||||
@round_robined
|
||||
def say(message):
|
||||
chat = message.chat
|
||||
user = message.from_user
|
||||
db_user = get_or_create_user(user.id)
|
||||
|
||||
if (chat.type == 'group') or (db_user['role'] not in MP_ENABLED_ROLES):
|
||||
return
|
||||
|
||||
bot.send_message(GROUP_ID, message.text[4:])
|
||||
|
||||
@bot.message_handler(func=lambda message: True)
|
||||
@round_robined
|
||||
def echo_message(message):
|
||||
if message.text[0] == '/':
|
||||
bot.reply_to(message, UNKNOWN_CMD_TEXT)
|
||||
|
||||
|
||||
login(token=os.environ['HF_TOKEN'])
|
||||
|
||||
bot.infinity_polling()
|
|
@ -0,0 +1,11 @@
|
|||
from setuptools import setup, find_packages
|
||||
|
||||
setup(
|
||||
name='skynet-bot',
|
||||
version='0.1.0a6',
|
||||
description='Decentralized compute platform',
|
||||
author='Guillermo Rodriguez',
|
||||
author_email='guillermo@telos.net',
|
||||
packages=find_packages(),
|
||||
install_requires=[]
|
||||
)
|
|
@ -0,0 +1,2 @@
|
|||
#!/usr/bin/python
|
||||
|
|
@ -0,0 +1,246 @@
|
|||
#!/usr/bin/python
|
||||
|
||||
import json
|
||||
import uuid
|
||||
import base64
|
||||
import logging
|
||||
|
||||
from uuid import UUID
|
||||
from functools import partial
|
||||
from collections import OrderedDict
|
||||
|
||||
import trio
|
||||
import pynng
|
||||
import trio_asyncio
|
||||
|
||||
from .db import *
|
||||
from .types import *
|
||||
from .constants import *
|
||||
|
||||
|
||||
class SkynetDGPUOffline(BaseException):
|
||||
...
|
||||
|
||||
class SkynetDGPUOverloaded(BaseException):
|
||||
...
|
||||
|
||||
|
||||
async def rpc_service(sock, dgpu_bus, db_pool):
|
||||
nodes = OrderedDict()
|
||||
wip_reqs = {}
|
||||
fin_reqs = {}
|
||||
|
||||
def are_all_workers_busy():
|
||||
for nid, info in nodes.items():
|
||||
if info['task'] == None:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
next_worker = 0
|
||||
def get_next_worker():
|
||||
nonlocal next_worker
|
||||
|
||||
if len(nodes) == 0:
|
||||
raise SkynetDGPUOffline
|
||||
|
||||
if are_all_workers_busy():
|
||||
raise SkynetDGPUOverloaded
|
||||
|
||||
next_worker += 1
|
||||
|
||||
if next_worker >= len(nodes):
|
||||
next_worker = 0
|
||||
|
||||
nid = list(nodes.keys())[next_worker]
|
||||
return nid
|
||||
|
||||
async def dgpu_image_streamer():
|
||||
nonlocal wip_reqs, fin_reqs
|
||||
while True:
|
||||
msg = await dgpu_bus.arecv_msg()
|
||||
rid = UUID(bytes=msg.bytes[:16]).hex
|
||||
img = msg.bytes[16:].hex()
|
||||
fin_reqs[rid] = img
|
||||
event = wip_reqs[rid]
|
||||
event.set()
|
||||
del wip_reqs[rid]
|
||||
|
||||
async def dgpu_stream_one_img(req: ImageGenRequest):
|
||||
nonlocal wip_reqs, fin_reqs, next_worker
|
||||
nid = get_next_worker()
|
||||
logging.info(f'dgpu_stream_one_img {next_worker} {nid}')
|
||||
rid = uuid.uuid4().hex
|
||||
event = trio.Event()
|
||||
wip_reqs[rid] = event
|
||||
|
||||
nodes[nid]['task'] = rid
|
||||
|
||||
dgpu_req = DGPUBusRequest(
|
||||
rid=rid,
|
||||
nid=nid,
|
||||
task='diffuse',
|
||||
params=req.to_dict())
|
||||
|
||||
logging.info(f'dgpu_bus req: {dgpu_req}')
|
||||
|
||||
await dgpu_bus.asend(
|
||||
json.dumps(dgpu_req.to_dict()).encode())
|
||||
|
||||
await event.wait()
|
||||
|
||||
nodes[nid]['task'] = None
|
||||
|
||||
img = fin_reqs[rid]
|
||||
del fin_reqs[rid]
|
||||
|
||||
logging.info(f'done streaming {img}')
|
||||
|
||||
return rid, img
|
||||
|
||||
async def handle_user_request(rpc_ctx, req):
|
||||
try:
|
||||
async with db_pool.acquire() as conn:
|
||||
user = await get_or_create_user(conn, req.uid)
|
||||
|
||||
result = {}
|
||||
|
||||
match req.method:
|
||||
case 'txt2img':
|
||||
logging.info('txt2img')
|
||||
user_config = {**(await get_user_config(conn, user))}
|
||||
del user_config['id']
|
||||
prompt = req.params['prompt']
|
||||
req = ImageGenRequest(
|
||||
prompt=prompt,
|
||||
**user_config
|
||||
)
|
||||
rid, img = await dgpu_stream_one_img(req)
|
||||
result = {
|
||||
'id': rid,
|
||||
'img': img
|
||||
}
|
||||
|
||||
case 'redo':
|
||||
logging.info('redo')
|
||||
user_config = await get_user_config(conn, user)
|
||||
prompt = await get_last_prompt_of(conn, user)
|
||||
req = ImageGenRequest(
|
||||
prompt=prompt,
|
||||
**user_config
|
||||
)
|
||||
rid, img = await dgpu_stream_one_img(req)
|
||||
result = {
|
||||
'id': rid,
|
||||
'img': img
|
||||
}
|
||||
|
||||
case 'config':
|
||||
logging.info('config')
|
||||
if req.params['attr'] in CONFIG_ATTRS:
|
||||
await update_user_config(
|
||||
conn, user, req.params['attr'], req.params['val'])
|
||||
|
||||
case 'stats':
|
||||
logging.info('stats')
|
||||
generated, joined, role = await get_user_stats(conn, user)
|
||||
|
||||
result = {
|
||||
'generated': generated,
|
||||
'joined': joined.strftime(DATE_FORMAT),
|
||||
'role': role
|
||||
}
|
||||
|
||||
case _:
|
||||
logging.warn('unknown method')
|
||||
|
||||
except SkynetDGPUOffline:
|
||||
result = {
|
||||
'error': 'skynet_dgpu_offline'
|
||||
}
|
||||
|
||||
except SkynetDGPUOverloaded:
|
||||
result = {
|
||||
'error': 'skynet_dgpu_overloaded',
|
||||
'nodes': len(nodes)
|
||||
}
|
||||
|
||||
except BaseException as e:
|
||||
logging.error(e)
|
||||
raise e
|
||||
# result = {
|
||||
# 'error': 'skynet_internal_error'
|
||||
# }
|
||||
|
||||
await rpc_ctx.asend(
|
||||
json.dumps(
|
||||
SkynetRPCResponse(result=result).to_dict()).encode())
|
||||
|
||||
|
||||
async with trio.open_nursery() as n:
|
||||
n.start_soon(dgpu_image_streamer)
|
||||
while True:
|
||||
ctx = sock.new_context()
|
||||
msg = await ctx.arecv_msg()
|
||||
content = msg.bytes.decode()
|
||||
req = SkynetRPCRequest(**json.loads(content))
|
||||
|
||||
logging.info(req)
|
||||
|
||||
if req.method == 'dgpu_online':
|
||||
nodes[req.uid] = {
|
||||
'task': None
|
||||
}
|
||||
logging.info(f'dgpu online: {req.uid}')
|
||||
|
||||
|
||||
elif req.method == 'dgpu_offline':
|
||||
i = nodes.values().index(req.uid)
|
||||
del nodes[req.uid]
|
||||
|
||||
if i < next_worker:
|
||||
next_worker -= 1
|
||||
logging.info(f'dgpu offline: {req.uid}')
|
||||
|
||||
else:
|
||||
n.start_soon(
|
||||
handle_user_request, ctx, req)
|
||||
continue
|
||||
|
||||
await ctx.asend(
|
||||
json.dumps(
|
||||
SkynetRPCResponse(
|
||||
result={'ok': {}}).to_dict()).encode())
|
||||
|
||||
|
||||
async def run_skynet(
|
||||
db_user: str,
|
||||
db_pass: str,
|
||||
db_host: str = DB_HOST,
|
||||
rpc_address: str = DEFAULT_RPC_ADDR,
|
||||
dgpu_address: str = DEFAULT_DGPU_ADDR,
|
||||
task_status = trio.TASK_STATUS_IGNORED
|
||||
):
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logging.info('skynet is starting')
|
||||
|
||||
async with (
|
||||
trio.open_nursery() as n,
|
||||
open_database_connection(
|
||||
db_user, db_pass, db_host) as db_pool
|
||||
):
|
||||
logging.info('connected to db.')
|
||||
with (
|
||||
pynng.Rep0(listen=rpc_address) as rpc_sock,
|
||||
pynng.Bus0(listen=dgpu_address) as dgpu_bus
|
||||
):
|
||||
n.start_soon(
|
||||
rpc_service, rpc_sock, dgpu_bus, db_pool)
|
||||
task_status.started()
|
||||
|
||||
try:
|
||||
await trio.sleep_forever()
|
||||
|
||||
except KeyboardInterrupt:
|
||||
...
|
||||
|
|
@ -0,0 +1,129 @@
|
|||
#!/usr/bin/python
|
||||
|
||||
API_TOKEN = '5880619053:AAFge2UfObw1kCn9Kb7AAyqaIHs_HgM0Fx0'
|
||||
|
||||
DB_HOST = 'ancap.tech:34508'
|
||||
|
||||
ALGOS = {
|
||||
'stable': 'runwayml/stable-diffusion-v1-5',
|
||||
'midj': 'prompthero/openjourney',
|
||||
'hdanime': 'Linaqruf/anything-v3.0',
|
||||
'waifu': 'hakurei/waifu-diffusion',
|
||||
'ghibli': 'nitrosocke/Ghibli-Diffusion',
|
||||
'van-gogh': 'dallinmackay/Van-Gogh-diffusion',
|
||||
'pokemon': 'lambdalabs/sd-pokemon-diffusers',
|
||||
'ink': 'Envvi/Inkpunk-Diffusion',
|
||||
'robot': 'nousr/robo-diffusion'
|
||||
}
|
||||
|
||||
N = '\n'
|
||||
HELP_TEXT = f'''
|
||||
test art bot v0.1a4
|
||||
|
||||
commands work on a user per user basis!
|
||||
config is individual to each user!
|
||||
|
||||
/txt2img TEXT - request an image based on a prompt
|
||||
|
||||
/redo - re ont
|
||||
|
||||
/help step - get info on step config option
|
||||
/help guidance - get info on guidance config option
|
||||
|
||||
/cool - list of cool words to use
|
||||
/stats - user statistics
|
||||
/donate - see donation info
|
||||
|
||||
/config algo NAME - select AI to use one of:
|
||||
|
||||
{N.join(ALGOS.keys())}
|
||||
|
||||
/config step NUMBER - set amount of iterations
|
||||
/config seed NUMBER - set the seed, deterministic results!
|
||||
/config size WIDTH HEIGHT - set size in pixels
|
||||
/config guidance NUMBER - prompt text importance
|
||||
'''
|
||||
|
||||
UNKNOWN_CMD_TEXT = 'unknown command! try sending \"/help\"'
|
||||
|
||||
DONATION_INFO = '0xf95335682DF281FFaB7E104EB87B69625d9622B6\ngoal: 25/650usd'
|
||||
|
||||
COOL_WORDS = [
|
||||
'cyberpunk',
|
||||
'soviet propaganda poster',
|
||||
'rastafari',
|
||||
'cannabis',
|
||||
'art deco',
|
||||
'H R Giger Necronom IV',
|
||||
'dimethyltryptamine',
|
||||
'lysergic',
|
||||
'slut',
|
||||
'psilocybin',
|
||||
'trippy',
|
||||
'lucy in the sky with diamonds',
|
||||
'fractal',
|
||||
'da vinci',
|
||||
'pencil illustration',
|
||||
'blueprint',
|
||||
'internal diagram',
|
||||
'baroque',
|
||||
'the last judgment',
|
||||
'michelangelo'
|
||||
]
|
||||
|
||||
HELP_STEP = '''
|
||||
diffusion models are iterative processes – a repeated cycle that starts with a\
|
||||
random noise generated from text input. With each step, some noise is removed\
|
||||
, resulting in a higher-quality image over time. The repetition stops when the\
|
||||
desired number of steps completes.
|
||||
|
||||
around 25 sampling steps are usually enough to achieve high-quality images. Us\
|
||||
ing more may produce a slightly different picture, but not necessarily better \
|
||||
quality.
|
||||
'''
|
||||
|
||||
HELP_GUIDANCE = '''
|
||||
the guidance scale is a parameter that controls how much the image generation\
|
||||
process follows the text prompt. The higher the value, the more image sticks\
|
||||
to a given text input.
|
||||
'''
|
||||
|
||||
HELP_UNKWNOWN_PARAM = 'don\'t have any info on that.'
|
||||
|
||||
GROUP_ID = -1001541979235
|
||||
|
||||
MP_ENABLED_ROLES = ['god']
|
||||
|
||||
MIN_STEP = 1
|
||||
MAX_STEP = 100
|
||||
MAX_WIDTH = 512
|
||||
MAX_HEIGHT = 656
|
||||
MAX_GUIDANCE = 20
|
||||
|
||||
DEFAULT_SEED = None
|
||||
DEFAULT_WIDTH = 512
|
||||
DEFAULT_HEIGHT = 512
|
||||
DEFAULT_GUIDANCE = 7.5
|
||||
DEFAULT_STEP = 35
|
||||
DEFAULT_CREDITS = 10
|
||||
DEFAULT_ALGO = 'midj'
|
||||
DEFAULT_ROLE = 'pleb'
|
||||
DEFAULT_UPSCALER = None
|
||||
|
||||
DEFAULT_RPC_ADDR = 'tcp://127.0.0.1:41000'
|
||||
|
||||
DEFAULT_DGPU_ADDR = 'tcp://127.0.0.1:41069'
|
||||
DEFAULT_DGPU_MAX_TASKS = 3
|
||||
DEFAULT_INITAL_ALGOS = ['midj', 'stable', 'ink']
|
||||
|
||||
DATE_FORMAT = '%B the %dth %Y, %H:%M:%S'
|
||||
|
||||
CONFIG_ATTRS = [
|
||||
'algo',
|
||||
'step',
|
||||
'width',
|
||||
'height',
|
||||
'seed',
|
||||
'guidance',
|
||||
'upscaler'
|
||||
]
|
|
@ -0,0 +1,146 @@
|
|||
#!/usr/bin/python
|
||||
|
||||
import logging
|
||||
|
||||
from datetime import datetime
|
||||
from contextlib import asynccontextmanager as acm
|
||||
|
||||
import trio
|
||||
import triopg
|
||||
|
||||
from .constants import *
|
||||
|
||||
|
||||
def try_decode_uid(uid: str):
|
||||
try:
|
||||
proto, uid = uid.split('+')
|
||||
uid = int(uid)
|
||||
return proto, uid
|
||||
|
||||
except ValueError:
|
||||
logging.warning(f'got non numeric uid?: {uid}')
|
||||
return None, None
|
||||
|
||||
|
||||
@acm
|
||||
async def open_database_connection(
|
||||
db_user: str,
|
||||
db_pass: str,
|
||||
db_host: str = DB_HOST,
|
||||
):
|
||||
async with triopg.create_pool(
|
||||
dsn=f'postgres://{db_user}:{db_pass}@{db_host}/skynet_art_bot'
|
||||
) as conn:
|
||||
yield conn
|
||||
|
||||
|
||||
async def get_user(conn, uid: str):
|
||||
if isinstance(uid, str):
|
||||
proto, uid = try_decode_uid(uid)
|
||||
|
||||
match proto:
|
||||
case 'tg':
|
||||
stmt = await conn.prepare(
|
||||
'SELECT * FROM skynet.user WHERE tg_id = $1')
|
||||
user = await stmt.fetchval(uid)
|
||||
|
||||
case _:
|
||||
user = None
|
||||
|
||||
return user
|
||||
|
||||
else: # asumme is our uid
|
||||
stmt = await conn.prepare(
|
||||
'SELECT * FROM skynet.user WHERE id = $1')
|
||||
return await stmt.fetchval(uid)
|
||||
|
||||
|
||||
async def get_user_config(conn, user: int):
|
||||
stmt = await conn.prepare(
|
||||
'SELECT * FROM skynet.user_config WHERE id = $1')
|
||||
return (await stmt.fetch(user))[0]
|
||||
|
||||
|
||||
async def get_last_prompt_of(conn, user: int):
|
||||
stms = await conn.prepare(
|
||||
'SELECT last_prompt FROM skynet.user WHERE id = $1')
|
||||
return await stmt.fetchval(user)
|
||||
|
||||
|
||||
async def new_user(conn, uid: str):
|
||||
if await get_user(conn, uid):
|
||||
raise ValueError('User already present on db')
|
||||
|
||||
logging.info(f'new user! {uid}')
|
||||
|
||||
tg_id = None
|
||||
date = datetime.utcnow()
|
||||
|
||||
proto, pid = try_decode_uid(uid)
|
||||
|
||||
match proto:
|
||||
case 'tg':
|
||||
tg_id = pid
|
||||
|
||||
async with conn.transaction():
|
||||
stmt = await conn.prepare('''
|
||||
INSERT INTO skynet.user(
|
||||
tg_id, generated, joined, last_prompt, role)
|
||||
|
||||
VALUES($1, $2, $3, $4, $5)
|
||||
''')
|
||||
await stmt.fetch(
|
||||
tg_id, 0, date, None, DEFAULT_ROLE
|
||||
)
|
||||
|
||||
new_uid = await get_user(conn, uid)
|
||||
|
||||
stmt = await conn.prepare('''
|
||||
INSERT INTO skynet.user_config(
|
||||
id, algo, step, width, height, seed, guidance, upscaler)
|
||||
|
||||
VALUES($1, $2, $3, $4, $5, $6, $7, $8)
|
||||
''')
|
||||
user = await stmt.fetch(
|
||||
new_uid,
|
||||
DEFAULT_ALGO,
|
||||
DEFAULT_STEP,
|
||||
DEFAULT_WIDTH,
|
||||
DEFAULT_HEIGHT,
|
||||
DEFAULT_SEED,
|
||||
DEFAULT_GUIDANCE,
|
||||
DEFAULT_UPSCALER
|
||||
)
|
||||
|
||||
return new_uid
|
||||
|
||||
|
||||
async def get_or_create_user(conn, uid: str):
|
||||
user = await get_user(conn, uid)
|
||||
|
||||
if not user:
|
||||
user = await new_user(conn, uid)
|
||||
|
||||
return user
|
||||
|
||||
async def update_user(conn, user: int, attr: str, val):
|
||||
...
|
||||
|
||||
async def update_user_config(conn, user: int, attr: str, val):
|
||||
stmt = await conn.prepare(f'''
|
||||
UPDATE skynet.user_config
|
||||
SET {attr} = $2
|
||||
WHERE id = $1
|
||||
''')
|
||||
await stmt.fetch(user, val)
|
||||
|
||||
|
||||
async def get_user_stats(conn, user: int):
|
||||
stmt = await conn.prepare('''
|
||||
SELECT generated,joined,role FROM skynet.user
|
||||
WHERE id = $1
|
||||
''')
|
||||
records = await stmt.fetch(user)
|
||||
assert len(records) == 1
|
||||
record = records[0]
|
||||
return record
|
|
@ -0,0 +1,121 @@
|
|||
#!/usr/bin/python
|
||||
|
||||
import trio
|
||||
import json
|
||||
import uuid
|
||||
import logging
|
||||
|
||||
import pynng
|
||||
import tractor
|
||||
|
||||
from . import gpu
|
||||
from .gpu import open_gpu_worker
|
||||
from .types import *
|
||||
from .constants import *
|
||||
from .frontend import rpc_call
|
||||
|
||||
|
||||
async def open_dgpu_node(
|
||||
rpc_address: str = DEFAULT_RPC_ADDR,
|
||||
dgpu_address: str = DEFAULT_DGPU_ADDR,
|
||||
dgpu_max_tasks: int = DEFAULT_DGPU_MAX_TASKS,
|
||||
initial_algos: str = DEFAULT_INITAL_ALGOS
|
||||
):
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
name = uuid.uuid4()
|
||||
workers = initial_algos.copy()
|
||||
tasks = [None for _ in range(dgpu_max_tasks)]
|
||||
|
||||
portal_map: dict[int, tractor.Portal]
|
||||
contexts: dict[int, tractor.Context]
|
||||
|
||||
def get_next_worker(need_algo: str):
|
||||
nonlocal workers, tasks
|
||||
for task, algo in zip(workers, tasks):
|
||||
if need_algo == algo and not task:
|
||||
return workers.index(need_algo)
|
||||
|
||||
return tasks.index(None)
|
||||
|
||||
async def gpu_streamer(
|
||||
ctx: tractor.Context,
|
||||
nid: int
|
||||
):
|
||||
nonlocal tasks
|
||||
async with ctx.open_stream() as stream:
|
||||
async for img in stream:
|
||||
tasks[nid]['res'] = img
|
||||
tasks[nid]['event'].set()
|
||||
|
||||
async def gpu_compute_one(ireq: ImageGenRequest):
|
||||
wid = get_next_worker(ireq.algo)
|
||||
event = trio.Event()
|
||||
|
||||
workers[wid] = ireq.algo
|
||||
tasks[wid] = {
|
||||
'res': None, 'event': event}
|
||||
|
||||
await contexts[i].send(ireq)
|
||||
|
||||
await event.wait()
|
||||
|
||||
img = tasks[wid]['res']
|
||||
tasks[wid] = None
|
||||
return img
|
||||
|
||||
|
||||
with (
|
||||
pynng.Req0(dial=rpc_address) as rpc_sock,
|
||||
pynng.Bus0(dial=dgpu_address) as dgpu_sock
|
||||
):
|
||||
async def _rpc_call(*args, **kwargs):
|
||||
return await rpc_call(rpc_sock, *args, **kwargs)
|
||||
|
||||
async def _process_dgpu_req(req: DGPUBusRequest):
|
||||
img = await gpu_compute_one(
|
||||
ImageGenRequest(**req.params))
|
||||
await dgpu_sock.asend(
|
||||
bytes.fromhex(req.rid) + img)
|
||||
|
||||
res = await _rpc_call(
|
||||
name.hex, 'dgpu_online', {'max_tasks': dgpu_max_tasks})
|
||||
logging.info(res)
|
||||
assert 'ok' in res.result
|
||||
|
||||
async with (
|
||||
tractor.open_actor_cluster(
|
||||
modules=['skynet_bot.gpu'],
|
||||
count=dgpu_max_tasks,
|
||||
names=[i for i in range(dgpu_max_tasks)]
|
||||
) as portal_map,
|
||||
trio.open_nursery() as n
|
||||
):
|
||||
logging.info(f'starting {dgpu_max_tasks} gpu workers')
|
||||
async with tractor.gather_contexts((
|
||||
ctx.open_context(
|
||||
open_gpu_worker, algo, 1.0 / dgpu_max_tasks)
|
||||
)) as contexts:
|
||||
contexts = {i: ctx for i, ctx in enumerate(contexts)}
|
||||
for i, ctx in contexts.items():
|
||||
n.start_soon(
|
||||
gpu_streamer, ctx, i)
|
||||
try:
|
||||
while True:
|
||||
msg = await dgpu_sock.arecv()
|
||||
req = DGPUBusRequest(
|
||||
**json.loads(msg.decode()))
|
||||
|
||||
if req.nid != name.hex:
|
||||
continue
|
||||
|
||||
logging.info(f'dgpu: {name}, req: {req}')
|
||||
n.start_soon(
|
||||
_process_dgpu_req, req)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
...
|
||||
|
||||
res = await _rpc_call(name.hex, 'dgpu_offline')
|
||||
logging.info(res)
|
||||
assert 'ok' in res.result
|
|
@ -0,0 +1,50 @@
|
|||
#!/usr/bin/python
|
||||
|
||||
import json
|
||||
|
||||
from typing import Union
|
||||
from contextlib import contextmanager as cm
|
||||
|
||||
import pynng
|
||||
|
||||
from ..types import SkynetRPCRequest, SkynetRPCResponse
|
||||
from ..constants import DEFAULT_RPC_ADDR
|
||||
|
||||
|
||||
class ConfigUnknownAttribute(BaseException):
|
||||
...
|
||||
|
||||
class ConfigUnknownAlgorithm(BaseException):
|
||||
...
|
||||
|
||||
class ConfigUnknownUpscaler(BaseException):
|
||||
...
|
||||
|
||||
class ConfigSizeDivisionByEight(BaseException):
|
||||
...
|
||||
|
||||
|
||||
async def rpc_call(
|
||||
sock,
|
||||
uid: Union[int, str],
|
||||
method: str,
|
||||
params: dict = {}
|
||||
):
|
||||
req = SkynetRPCRequest(
|
||||
uid=uid,
|
||||
method=method,
|
||||
params=params
|
||||
)
|
||||
await sock.asend(
|
||||
json.dumps(
|
||||
req.to_dict()).encode())
|
||||
|
||||
return SkynetRPCResponse(
|
||||
**json.loads(
|
||||
(await sock.arecv_msg()).bytes.decode()))
|
||||
|
||||
|
||||
@cm
|
||||
def open_skynet_rpc(rpc_address: str = DEFAULT_RPC_ADDR):
|
||||
with pynng.Req0(dial=rpc_address) as rpc_sock:
|
||||
yield rpc_sock
|
|
@ -0,0 +1,164 @@
|
|||
#!/usr/bin/python
|
||||
|
||||
import logging
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
import pynng
|
||||
|
||||
from telebot.async_telebot import AsyncTeleBot
|
||||
from trio_asyncio import aio_as_trio
|
||||
|
||||
from ..constants import *
|
||||
|
||||
from . import *
|
||||
|
||||
|
||||
PREFIX = 'tg'
|
||||
|
||||
|
||||
async def run_skynet_telegram(tg_token: str):
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
bot = AsyncTeleBot(tg_token)
|
||||
|
||||
with open_skynet_rpc() as rpc_sock:
|
||||
|
||||
async def _rpc_call(
|
||||
uid: int,
|
||||
method: str,
|
||||
params: dict
|
||||
):
|
||||
return await rpc_call(
|
||||
rpc_sock, f'{PREFIX}+{uid}', method, params)
|
||||
|
||||
@bot.message_handler(commands=['help'])
|
||||
async def send_help(message):
|
||||
await bot.reply_to(message, HELP_TEXT)
|
||||
|
||||
@bot.message_handler(commands=['cool'])
|
||||
async def send_cool_words(message):
|
||||
await bot.reply_to(message, '\n'.join(COOL_WORDS))
|
||||
|
||||
@bot.message_handler(commands=['txt2img'])
|
||||
async def send_txt2img(message):
|
||||
resp = await _rpc_call(
|
||||
message.from_user.id,
|
||||
'txt2img',
|
||||
{}
|
||||
)
|
||||
|
||||
@bot.message_handler(commands=['redo'])
|
||||
async def redo_txt2img(message):
|
||||
resp = await _rpc_call(
|
||||
message.from_user.id,
|
||||
'redo',
|
||||
{}
|
||||
)
|
||||
|
||||
@bot.message_handler(commands=['config'])
|
||||
async def set_config(message):
|
||||
params = message.text.split(' ')
|
||||
|
||||
rpc_params = {}
|
||||
|
||||
if len(params) < 3:
|
||||
bot.reply_to(message, 'wrong msg format')
|
||||
|
||||
else:
|
||||
|
||||
try:
|
||||
attr = params[1]
|
||||
|
||||
if attr == 'algo':
|
||||
val = params[2]
|
||||
if val not in ALGOS:
|
||||
raise ConfigUnknownAlgorithm
|
||||
|
||||
elif attr == 'step':
|
||||
val = int(params[2])
|
||||
val = max(min(val, MAX_STEP), MIN_STEP)
|
||||
|
||||
elif attr == 'width':
|
||||
val = max(min(int(params[2]), MAX_WIDTH), 16)
|
||||
if val % 8 != 0:
|
||||
raise ConfigSizeDivisionByEight
|
||||
|
||||
elif attr == 'height':
|
||||
val = max(min(int(params[2]), MAX_HEIGHT), 16)
|
||||
if val % 8 != 0:
|
||||
raise ConfigSizeDivisionByEight
|
||||
|
||||
elif attr == 'seed':
|
||||
val = params[2]
|
||||
if val == 'auto':
|
||||
val = None
|
||||
else:
|
||||
val = int(params[2])
|
||||
|
||||
elif attr == 'guidance':
|
||||
val = float(params[2])
|
||||
val = max(min(val, MAX_GUIDANCE), 0)
|
||||
|
||||
elif attr == 'upscaler':
|
||||
val = params[2]
|
||||
if val == 'off':
|
||||
val = None
|
||||
elif val != 'x4':
|
||||
raise ConfigUnknownUpscaler
|
||||
|
||||
else:
|
||||
raise ConfigUnknownAttribute
|
||||
|
||||
resp = await _rpc_call(
|
||||
message.from_user.id,
|
||||
'config', {'attr': attr, 'val': val})
|
||||
|
||||
reply_txt = f'config updated! {attr} to {val}'
|
||||
|
||||
except ConfigUnknownAlgorithm:
|
||||
reply_txt = f'no algo named {val}'
|
||||
|
||||
except ConfigUnknownAttribute:
|
||||
reply_txt = f'\"{attr}\" not a configurable parameter'
|
||||
|
||||
except ConfigUnknownUpscaler:
|
||||
reply_txt = f'\"{val}\" is not a valid upscaler'
|
||||
|
||||
except ConfigSizeDivisionByEight:
|
||||
reply_txt = 'size must be divisible by 8!'
|
||||
|
||||
except ValueError:
|
||||
reply_txt = f'\"{val}\" is not a number silly'
|
||||
|
||||
await bot.reply_to(message, reply_txt)
|
||||
|
||||
@bot.message_handler(commands=['stats'])
|
||||
async def user_stats(message):
|
||||
resp = await _rpc_call(
|
||||
message.from_user.id,
|
||||
'stats',
|
||||
{}
|
||||
)
|
||||
stats = resp.result
|
||||
|
||||
stats_str = f'generated: {stats["generated"]}\n'
|
||||
stats_str += f'joined: {stats["joined"]}\n'
|
||||
stats_str += f'role: {stats["role"]}\n'
|
||||
|
||||
await bot.reply_to(
|
||||
message, stats_str)
|
||||
|
||||
@bot.message_handler(commands=['donate'])
|
||||
async def donation_info(message):
|
||||
await bot.reply_to(
|
||||
message, DONATION_INFO)
|
||||
|
||||
|
||||
@bot.message_handler(func=lambda message: True)
|
||||
async def echo_message(message):
|
||||
if message.text[0] == '/':
|
||||
await bot.reply_to(message, UNKNOWN_CMD_TEXT)
|
||||
|
||||
|
||||
await aio_as_trio(bot.infinity_polling())
|
|
@ -0,0 +1,75 @@
|
|||
#!/usr/bin/python
|
||||
|
||||
import io
|
||||
import random
|
||||
import logging
|
||||
|
||||
import torch
|
||||
import tractor
|
||||
|
||||
from diffusers import (
|
||||
StableDiffusionPipeline,
|
||||
EulerAncestralDiscreteScheduler
|
||||
)
|
||||
|
||||
from .types import ImageGenRequest
|
||||
from .constants import ALGOS
|
||||
|
||||
|
||||
def pipeline_for(algo: str, mem_fraction: float):
|
||||
assert torch.cuda.is_available()
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.set_per_process_memory_fraction(mem_fraction)
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
torch.backends.cudnn.allow_tf32 = True
|
||||
|
||||
params = {
|
||||
'torch_dtype': torch.float16,
|
||||
'safety_checker': None
|
||||
}
|
||||
|
||||
if algo == 'stable':
|
||||
params['revision'] = 'fp16'
|
||||
|
||||
pipe = StableDiffusionPipeline.from_pretrained(
|
||||
ALGOS[algo], **params)
|
||||
|
||||
pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(
|
||||
pipe.scheduler.config)
|
||||
|
||||
return pipe.to("cuda")
|
||||
|
||||
@tractor.context
|
||||
async def open_gpu_worker(
|
||||
ctx: tractor.Context,
|
||||
start_algo: str,
|
||||
mem_fraction: float
|
||||
):
|
||||
current_algo = start_algo
|
||||
with torch.no_grad():
|
||||
pipe = pipeline_for(current_algo, mem_fraction)
|
||||
await ctx.started()
|
||||
|
||||
async with ctx.open_stream() as bus:
|
||||
async for ireq in bus:
|
||||
if ireq.algo != current_algo:
|
||||
current_algo = ireq.algo
|
||||
pipe = pipeline_for(current_algo, mem_fraction)
|
||||
|
||||
seed = ireq.seed if ireq.seed else random.randint(0, 2 ** 64)
|
||||
image = pipe(
|
||||
ireq.prompt,
|
||||
width=ireq.width,
|
||||
height=ireq.height,
|
||||
guidance_scale=ireq.guidance,
|
||||
num_inference_steps=ireq.step,
|
||||
generator=torch.Generator("cuda").manual_seed(seed)
|
||||
).images[0]
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# convert PIL.Image to BytesIO
|
||||
img_bytes = io.BytesIO()
|
||||
image.save(img_bytes, format='PNG')
|
||||
await bus.send(img_bytes.getvalue())
|
||||
|
|
@ -0,0 +1,109 @@
|
|||
# piker: trading gear for hackers
|
||||
# Copyright (C) Guillermo Rodriguez (in stewardship for piker0)
|
||||
|
||||
# This program is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Affero General Public License as published by
|
||||
# the Free Software Foundation, either version 3 of the License, or
|
||||
# (at your option) any later version.
|
||||
|
||||
# This program is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# GNU Affero General Public License for more details.
|
||||
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
"""
|
||||
Built-in (extension) types.
|
||||
"""
|
||||
import sys
|
||||
from typing import Optional, Union
|
||||
from pprint import pformat
|
||||
|
||||
import msgspec
|
||||
|
||||
|
||||
class Struct(
|
||||
msgspec.Struct,
|
||||
|
||||
# https://jcristharif.com/msgspec/structs.html#tagged-unions
|
||||
# tag='pikerstruct',
|
||||
# tag=True,
|
||||
):
|
||||
'''
|
||||
A "human friendlier" (aka repl buddy) struct subtype.
|
||||
'''
|
||||
def to_dict(self) -> dict:
|
||||
return {
|
||||
f: getattr(self, f)
|
||||
for f in self.__struct_fields__
|
||||
}
|
||||
|
||||
def __repr__(self):
|
||||
# only turn on pprint when we detect a python REPL
|
||||
# at runtime B)
|
||||
if (
|
||||
hasattr(sys, 'ps1')
|
||||
# TODO: check if we're in pdb
|
||||
):
|
||||
return self.pformat()
|
||||
|
||||
return super().__repr__()
|
||||
|
||||
def pformat(self) -> str:
|
||||
return f'Struct({pformat(self.to_dict())})'
|
||||
|
||||
def copy(
|
||||
self,
|
||||
update: Optional[dict] = None,
|
||||
|
||||
) -> msgspec.Struct:
|
||||
'''
|
||||
Validate-typecast all self defined fields, return a copy of us
|
||||
with all such fields.
|
||||
This is kinda like the default behaviour in `pydantic.BaseModel`.
|
||||
'''
|
||||
if update:
|
||||
for k, v in update.items():
|
||||
setattr(self, k, v)
|
||||
|
||||
# roundtrip serialize to validate
|
||||
return msgspec.msgpack.Decoder(
|
||||
type=type(self)
|
||||
).decode(
|
||||
msgspec.msgpack.Encoder().encode(self)
|
||||
)
|
||||
|
||||
def typecast(
|
||||
self,
|
||||
# fields: Optional[list[str]] = None,
|
||||
) -> None:
|
||||
for fname, ftype in self.__annotations__.items():
|
||||
setattr(self, fname, ftype(getattr(self, fname)))
|
||||
|
||||
# proto
|
||||
|
||||
class SkynetRPCRequest(Struct):
|
||||
uid: Union[str, int] # user unique id
|
||||
method: str # rpc method name
|
||||
params: dict # variable params
|
||||
|
||||
class SkynetRPCResponse(Struct):
|
||||
result: dict
|
||||
|
||||
class ImageGenRequest(Struct):
|
||||
prompt: str
|
||||
step: int
|
||||
width: int
|
||||
height: int
|
||||
guidance: int
|
||||
seed: Optional[int]
|
||||
algo: str
|
||||
upscaler: Optional[str]
|
||||
|
||||
class DGPUBusRequest(Struct):
|
||||
rid: str # req id
|
||||
nid: str # node id
|
||||
task: str
|
||||
params: dict
|
|
@ -0,0 +1,8 @@
|
|||
docker run \
|
||||
-it \
|
||||
--rm \
|
||||
--mount type=bind,source="$(pwd)",target=/skynet \
|
||||
skynet:runtime-cuda \
|
||||
bash -c \
|
||||
"cd /skynet && pip install -e . && \
|
||||
pytest tests/test_dgpu.py --log-cli-level=info"
|
|
@ -0,0 +1,55 @@
|
|||
#!/usr/bin/python
|
||||
|
||||
import time
|
||||
import json
|
||||
import logging
|
||||
|
||||
import trio
|
||||
import pynng
|
||||
import trio_asyncio
|
||||
|
||||
from skynet_bot.dgpu import open_dgpu_node
|
||||
from skynet_bot.types import *
|
||||
from skynet_bot.brain import run_skynet
|
||||
from skynet_bot.constants import *
|
||||
from skynet_bot.frontend import open_skynet_rpc, rpc_call
|
||||
|
||||
|
||||
def test_dgpu_simple():
|
||||
async def main():
|
||||
async with trio.open_nursery() as n:
|
||||
await n.start(
|
||||
run_skynet,
|
||||
'skynet', '3GbZd6UojbD8V7zWpeFn', 'ancap.tech:34508')
|
||||
|
||||
await trio.sleep(2)
|
||||
|
||||
for i in range(3):
|
||||
n.start_soon(open_dgpu_node)
|
||||
|
||||
await trio.sleep(1)
|
||||
start = time.time()
|
||||
async def request_img():
|
||||
with pynng.Req0(dial=DEFAULT_RPC_ADDR) as rpc_sock:
|
||||
res = await rpc_call(
|
||||
rpc_sock, 'tg+1', 'txt2img', {
|
||||
'prompt': 'test',
|
||||
'step': 28,
|
||||
'width': 512, 'height': 512,
|
||||
'guidance': 7.5,
|
||||
'seed': None,
|
||||
'algo': 'stable',
|
||||
'upscaler': None
|
||||
})
|
||||
|
||||
logging.info(res)
|
||||
|
||||
async with trio.open_nursery() as inner_n:
|
||||
for i in range(3):
|
||||
inner_n.start_soon(request_img)
|
||||
|
||||
logging.info(f'time elapsed: {time.time() - start}')
|
||||
n.cancel_scope.cancel()
|
||||
|
||||
|
||||
trio_asyncio.run(main)
|
|
@ -0,0 +1,22 @@
|
|||
#!/usr/bin/python
|
||||
|
||||
import trio
|
||||
import trio_asyncio
|
||||
|
||||
from skynet_bot.brain import run_skynet
|
||||
from skynet_bot.frontend import open_skynet_rpc
|
||||
from skynet_bot.frontend.telegram import run_skynet_telegram
|
||||
|
||||
|
||||
def test_run_tg_bot():
|
||||
async def main():
|
||||
async with trio.open_nursery() as n:
|
||||
await n.start(
|
||||
run_skynet,
|
||||
'skynet', '3GbZd6UojbD8V7zWpeFn', 'ancap.tech:34508')
|
||||
n.start_soon(
|
||||
run_skynet_telegram,
|
||||
'5853245787:AAFEmv3EjJ_qJ8d_vmOpi6o6HFHUf8a0uCQ')
|
||||
|
||||
|
||||
trio_asyncio.run(main)
|
Loading…
Reference in New Issue