mirror of https://github.com/skygpu/skynet.git
First decupled architecture, still working on integrating tractor gpu workers
parent
91e0693e65
commit
74d2426793
|
@ -0,0 +1,3 @@
|
||||||
|
hf_home
|
||||||
|
inputs
|
||||||
|
outputs
|
|
@ -1,2 +1,5 @@
|
||||||
|
.python-version
|
||||||
hf_home
|
hf_home
|
||||||
outputs
|
outputs
|
||||||
|
**/__pycache__
|
||||||
|
*.egg-info
|
||||||
|
|
31
Dockerfile
31
Dockerfile
|
@ -1,31 +0,0 @@
|
||||||
from pytorch/pytorch:latest
|
|
||||||
|
|
||||||
env DEBIAN_FRONTEND=noninteractive
|
|
||||||
|
|
||||||
run apt-get update && apt-get install -y git wget
|
|
||||||
|
|
||||||
run conda install xformers -c xformers/label/dev
|
|
||||||
|
|
||||||
run pip install --upgrade \
|
|
||||||
diffusers[torch] \
|
|
||||||
accelerate \
|
|
||||||
transformers \
|
|
||||||
huggingface_hub \
|
|
||||||
pyTelegramBotAPI \
|
|
||||||
pymongo \
|
|
||||||
scipy \
|
|
||||||
pdbpp
|
|
||||||
|
|
||||||
env NVIDIA_VISIBLE_DEVICES=all
|
|
||||||
|
|
||||||
run mkdir /scripts
|
|
||||||
run mkdir /outputs
|
|
||||||
run mkdir /inputs
|
|
||||||
|
|
||||||
env HF_HOME /hf_home
|
|
||||||
|
|
||||||
run mkdir /hf_home
|
|
||||||
|
|
||||||
workdir /scripts
|
|
||||||
|
|
||||||
env PYTORCH_CUDA_ALLOC_CONF max_split_size_mb:128
|
|
|
@ -0,0 +1,13 @@
|
||||||
|
from python:3.10.0
|
||||||
|
|
||||||
|
env DEBIAN_FRONTEND=noninteractive
|
||||||
|
|
||||||
|
workdir /skynet
|
||||||
|
|
||||||
|
copy requirements.* ./
|
||||||
|
|
||||||
|
run pip install \
|
||||||
|
-r requirements.txt \
|
||||||
|
-r requirements.test.txt
|
||||||
|
|
||||||
|
workdir /scripts
|
|
@ -0,0 +1,23 @@
|
||||||
|
from nvidia/cuda:11.7.0-devel-ubuntu20.04
|
||||||
|
from python:3.10.0
|
||||||
|
|
||||||
|
env DEBIAN_FRONTEND=noninteractive
|
||||||
|
|
||||||
|
workdir /skynet
|
||||||
|
|
||||||
|
copy requirements.* .
|
||||||
|
|
||||||
|
run pip install -U pip ninja
|
||||||
|
run pip install -r requirements.cuda.0.txt
|
||||||
|
run pip install -v -r requirements.cuda.1.txt
|
||||||
|
|
||||||
|
run pip install \
|
||||||
|
-r requirements.txt \
|
||||||
|
-r requirements.test.txt
|
||||||
|
|
||||||
|
env NVIDIA_VISIBLE_DEVICES=all
|
||||||
|
env HF_HOME /hf_home
|
||||||
|
|
||||||
|
env PYTORCH_CUDA_ALLOC_CONF max_split_size_mb:128
|
||||||
|
|
||||||
|
workdir /scripts
|
|
@ -0,0 +1,11 @@
|
||||||
|
A menos que sea especificamente indicado en el cabezal del archivo, se reservan
|
||||||
|
todos los derechos sobre este codigo por parte de:
|
||||||
|
|
||||||
|
Guillermo Rodriguez, guillermor@fing.edu.uy
|
||||||
|
|
||||||
|
ENGLISH LICENSE:
|
||||||
|
|
||||||
|
Unless specifically indicated in the file header, all rights to this code are
|
||||||
|
reserved by:
|
||||||
|
|
||||||
|
Guillermo Rodriguez, guillermor@.edu.uy
|
|
@ -0,0 +1,47 @@
|
||||||
|
create db in postgres:
|
||||||
|
|
||||||
|
```sql
|
||||||
|
CREATE USER skynet WITH PASSWORD 'password';
|
||||||
|
CREATE DATABASE skynet_art_bot;
|
||||||
|
GRANT ALL PRIVILEGES ON DATABASE skynet_art_bot TO skynet;
|
||||||
|
|
||||||
|
CREATE SCHEMA IF NOT EXISTS skynet;
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS skynet.user(
|
||||||
|
id SERIAL PRIMARY KEY NOT NULL,
|
||||||
|
tg_id INT,
|
||||||
|
wp_id VARCHAR(128),
|
||||||
|
mx_id VARCHAR(128),
|
||||||
|
ig_id VARCHAR(128),
|
||||||
|
generated INT NOT NULL,
|
||||||
|
joined DATE NOT NULL,
|
||||||
|
last_prompt TEXT,
|
||||||
|
role VARCHAR(128) NOT NULL
|
||||||
|
);
|
||||||
|
ALTER TABLE skynet.user
|
||||||
|
ADD CONSTRAINT tg_unique
|
||||||
|
UNIQUE (tg_id);
|
||||||
|
ALTER TABLE skynet.user
|
||||||
|
ADD CONSTRAINT wp_unique
|
||||||
|
UNIQUE (wp_id);
|
||||||
|
ALTER TABLE skynet.user
|
||||||
|
ADD CONSTRAINT mx_unique
|
||||||
|
UNIQUE (mx_id);
|
||||||
|
ALTER TABLE skynet.user
|
||||||
|
ADD CONSTRAINT ig_unique
|
||||||
|
UNIQUE (ig_id);
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS skynet.user_config(
|
||||||
|
id SERIAL NOT NULL,
|
||||||
|
algo VARCHAR(128) NOT NULL,
|
||||||
|
step INT NOT NULL,
|
||||||
|
width INT NOT NULL,
|
||||||
|
height INT NOT NULL,
|
||||||
|
seed INT,
|
||||||
|
guidance INT NOT NULL,
|
||||||
|
upscaler VARCHAR(128)
|
||||||
|
);
|
||||||
|
ALTER TABLE skynet.user_config
|
||||||
|
ADD FOREIGN KEY(id)
|
||||||
|
REFERENCES skynet.user(id);
|
||||||
|
```
|
|
@ -0,0 +1,7 @@
|
||||||
|
docker build \
|
||||||
|
-t skynet:runtime-cuda \
|
||||||
|
-f Dockerfile.runtime-cuda .
|
||||||
|
|
||||||
|
docker build \
|
||||||
|
-t skynet:runtime \
|
||||||
|
-f Dockerfile.runtime .
|
|
@ -0,0 +1,2 @@
|
||||||
|
[pytest]
|
||||||
|
trio_mode = true
|
|
@ -0,0 +1,8 @@
|
||||||
|
pdbpp
|
||||||
|
scipy
|
||||||
|
accelerate
|
||||||
|
transformers
|
||||||
|
huggingface_hub
|
||||||
|
diffusers[torch]
|
||||||
|
torch==1.13.0+cu117
|
||||||
|
--extra-index-url https://download.pytorch.org/whl/cu117
|
|
@ -0,0 +1 @@
|
||||||
|
git+https://github.com/facebookresearch/xformers.git@main#egg=xformers
|
|
@ -0,0 +1,2 @@
|
||||||
|
pytest
|
||||||
|
pytest-trio
|
|
@ -0,0 +1,8 @@
|
||||||
|
trio
|
||||||
|
pynng
|
||||||
|
triopg
|
||||||
|
aiohttp
|
||||||
|
msgspec
|
||||||
|
trio_asyncio
|
||||||
|
|
||||||
|
git+https://github.com/goodboy/tractor.git@master#egg=tractor
|
14
run-bot.sh
14
run-bot.sh
|
@ -1,14 +0,0 @@
|
||||||
mkdir -p outputs
|
|
||||||
mkdir -p hf_home
|
|
||||||
|
|
||||||
docker run \
|
|
||||||
-it \
|
|
||||||
--rm \
|
|
||||||
--gpus=all \
|
|
||||||
--env HF_TOKEN='' \
|
|
||||||
--env DB_USER='skynet' \
|
|
||||||
--env DB_PASS='nnf01nmf091d0i' \
|
|
||||||
--mount type=bind,source="$(pwd)"/outputs,target=/outputs \
|
|
||||||
--mount type=bind,source="$(pwd)"/hf_home,target=/hf_home \
|
|
||||||
--mount type=bind,source="$(pwd)"/scripts,target=/scripts \
|
|
||||||
skynet:dif python telegram-bot-dev.py
|
|
|
@ -1,9 +0,0 @@
|
||||||
docker run
|
|
||||||
-d \
|
|
||||||
--rm \
|
|
||||||
-p 27017:27017 \
|
|
||||||
--name mongodb-skynet \
|
|
||||||
--mount type=bind,source="$(pwd)"/mongodb,target=/data/db \
|
|
||||||
-e MONGO_INITDB_ROOT_USERNAME="" \
|
|
||||||
-e MONGO_INITDB_ROOT_PASSWORD="" \
|
|
||||||
mongo
|
|
|
@ -1,537 +0,0 @@
|
||||||
#!/usr/bin/python
|
|
||||||
|
|
||||||
import os
|
|
||||||
|
|
||||||
import logging
|
|
||||||
import random
|
|
||||||
|
|
||||||
from torch.multiprocessing import spawn
|
|
||||||
|
|
||||||
import telebot
|
|
||||||
from telebot.types import InputFile
|
|
||||||
|
|
||||||
import sys
|
|
||||||
import uuid
|
|
||||||
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from torch.multiprocessing.spawn import ProcessRaisedException
|
|
||||||
from diffusers import (
|
|
||||||
StableDiffusionPipeline,
|
|
||||||
EulerAncestralDiscreteScheduler
|
|
||||||
)
|
|
||||||
|
|
||||||
from huggingface_hub import login
|
|
||||||
from datetime import datetime
|
|
||||||
|
|
||||||
from pymongo import MongoClient
|
|
||||||
|
|
||||||
from typing import Tuple, Optional
|
|
||||||
|
|
||||||
db_user = os.environ['DB_USER']
|
|
||||||
db_pass = os.environ['DB_PASS']
|
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
|
||||||
|
|
||||||
MEM_FRACTION = .33
|
|
||||||
|
|
||||||
ALGOS = {
|
|
||||||
'stable': 'runwayml/stable-diffusion-v1-5',
|
|
||||||
'midj': 'prompthero/openjourney',
|
|
||||||
'hdanime': 'Linaqruf/anything-v3.0',
|
|
||||||
'waifu': 'hakurei/waifu-diffusion',
|
|
||||||
'ghibli': 'nitrosocke/Ghibli-Diffusion',
|
|
||||||
'van-gogh': 'dallinmackay/Van-Gogh-diffusion',
|
|
||||||
'pokemon': 'lambdalabs/sd-pokemon-diffusers',
|
|
||||||
'ink': 'Envvi/Inkpunk-Diffusion',
|
|
||||||
'robot': 'nousr/robo-diffusion'
|
|
||||||
}
|
|
||||||
|
|
||||||
N = '\n'
|
|
||||||
HELP_TEXT = f'''
|
|
||||||
test art bot v0.1a4
|
|
||||||
|
|
||||||
commands work on a user per user basis!
|
|
||||||
config is individual to each user!
|
|
||||||
|
|
||||||
/txt2img TEXT - request an image based on a prompt
|
|
||||||
|
|
||||||
/redo - redo last prompt
|
|
||||||
|
|
||||||
/cool - list of cool words to use
|
|
||||||
/stats - user statistics
|
|
||||||
/donate - see donation info
|
|
||||||
|
|
||||||
/config algo NAME - select AI to use one of:
|
|
||||||
|
|
||||||
{N.join(ALGOS.keys())}
|
|
||||||
|
|
||||||
/config step NUMBER - set amount of iterations
|
|
||||||
/config seed NUMBER - set the seed, deterministic results!
|
|
||||||
/config size WIDTH HEIGHT - set size in pixels
|
|
||||||
/config guidance NUMBER - prompt text importance
|
|
||||||
'''
|
|
||||||
|
|
||||||
UNKNOWN_CMD_TEXT = 'unknown command! try sending \"/help\"'
|
|
||||||
|
|
||||||
DONATION_INFO = '0xf95335682DF281FFaB7E104EB87B69625d9622B6\ngoal: 25/650usd'
|
|
||||||
|
|
||||||
COOL_WORDS = [
|
|
||||||
'cyberpunk',
|
|
||||||
'soviet propaganda poster',
|
|
||||||
'rastafari',
|
|
||||||
'cannabis',
|
|
||||||
'art deco',
|
|
||||||
'H R Giger Necronom IV',
|
|
||||||
'dimethyltryptamine',
|
|
||||||
'lysergic',
|
|
||||||
'slut',
|
|
||||||
'psilocybin',
|
|
||||||
'trippy',
|
|
||||||
'lucy in the sky with diamonds',
|
|
||||||
'fractal',
|
|
||||||
'da vinci',
|
|
||||||
'pencil illustration',
|
|
||||||
'blueprint',
|
|
||||||
'internal diagram',
|
|
||||||
'baroque',
|
|
||||||
'the last judgment',
|
|
||||||
'michelangelo'
|
|
||||||
]
|
|
||||||
|
|
||||||
GROUP_ID = -1001541979235
|
|
||||||
|
|
||||||
MP_ENABLED_ROLES = ['god']
|
|
||||||
|
|
||||||
MIN_STEP = 1
|
|
||||||
MAX_STEP = 100
|
|
||||||
MAX_SIZE = (512, 656)
|
|
||||||
MAX_GUIDANCE = 20
|
|
||||||
|
|
||||||
DEFAULT_SIZE = (512, 512)
|
|
||||||
DEFAULT_GUIDANCE = 7.5
|
|
||||||
DEFAULT_STEP = 75
|
|
||||||
DEFAULT_CREDITS = 10
|
|
||||||
DEFAULT_ALGO = 'stable'
|
|
||||||
DEFAULT_ROLE = 'pleb'
|
|
||||||
DEFAULT_UPSCALER = None
|
|
||||||
|
|
||||||
rr_total = 1
|
|
||||||
rr_id = 0
|
|
||||||
request_counter = 0
|
|
||||||
|
|
||||||
def its_my_turn():
|
|
||||||
global request_counter, rr_total, rr_id
|
|
||||||
my_turn = request_counter % rr_total == rr_id
|
|
||||||
logging.info(f'new request {request_counter}, turn: {my_turn} rr_total: {rr_total}, rr_id {rr_id}')
|
|
||||||
request_counter += 1
|
|
||||||
return my_turn
|
|
||||||
|
|
||||||
def round_robined(func):
|
|
||||||
def rr_wrapper(*args, **kwargs):
|
|
||||||
if not its_my_turn():
|
|
||||||
return
|
|
||||||
|
|
||||||
func(*args, **kwargs)
|
|
||||||
|
|
||||||
return rr_wrapper
|
|
||||||
|
|
||||||
|
|
||||||
def generate_image(
|
|
||||||
i: int,
|
|
||||||
prompt: str,
|
|
||||||
name: str,
|
|
||||||
step: int,
|
|
||||||
size: Tuple[int, int],
|
|
||||||
guidance: int,
|
|
||||||
seed: int,
|
|
||||||
algo: str,
|
|
||||||
upscaler: Optional[str]
|
|
||||||
):
|
|
||||||
assert torch.cuda.is_available()
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
torch.cuda.set_per_process_memory_fraction(MEM_FRACTION)
|
|
||||||
with torch.no_grad():
|
|
||||||
if algo == 'stable':
|
|
||||||
pipe = StableDiffusionPipeline.from_pretrained(
|
|
||||||
'runwayml/stable-diffusion-v1-5',
|
|
||||||
torch_dtype=torch.float16,
|
|
||||||
revision="fp16",
|
|
||||||
safety_checker=None
|
|
||||||
)
|
|
||||||
|
|
||||||
else:
|
|
||||||
pipe = StableDiffusionPipeline.from_pretrained(
|
|
||||||
ALGOS[algo],
|
|
||||||
torch_dtype=torch.float16,
|
|
||||||
safety_checker=None
|
|
||||||
)
|
|
||||||
|
|
||||||
pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
|
|
||||||
pipe = pipe.to("cuda")
|
|
||||||
w, h = size
|
|
||||||
print(f'generating image... of size {w, h}')
|
|
||||||
image = pipe(
|
|
||||||
prompt,
|
|
||||||
width=w,
|
|
||||||
height=h,
|
|
||||||
guidance_scale=guidance, num_inference_steps=step,
|
|
||||||
generator=torch.Generator("cuda").manual_seed(seed)
|
|
||||||
).images[0]
|
|
||||||
|
|
||||||
if upscaler == 'x4':
|
|
||||||
pipe = StableDiffusionPipeline.from_pretrained(
|
|
||||||
'stabilityai/stable-diffusion-x4-upscaler',
|
|
||||||
revision="fp16",
|
|
||||||
torch_dtype=torch.float16
|
|
||||||
)
|
|
||||||
image = pipe(prompt=prompt, image=image).images[0]
|
|
||||||
|
|
||||||
|
|
||||||
image.save(f'/outputs/{name}.png')
|
|
||||||
print('saved')
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
|
|
||||||
API_TOKEN = '5880619053:AAFge2UfObw1kCn9Kb7AAyqaIHs_HgM0Fx0'
|
|
||||||
|
|
||||||
bot = telebot.TeleBot(API_TOKEN)
|
|
||||||
db_client = MongoClient(
|
|
||||||
host=['ancap.tech:64000'],
|
|
||||||
username=db_user,
|
|
||||||
password=db_pass)
|
|
||||||
|
|
||||||
tgdb = db_client.get_database('telegram')
|
|
||||||
|
|
||||||
collections = tgdb.list_collection_names()
|
|
||||||
|
|
||||||
if 'users' in collections:
|
|
||||||
tg_users = tgdb.get_collection('users')
|
|
||||||
# tg_users.delete_many({})
|
|
||||||
|
|
||||||
else:
|
|
||||||
tg_users = tgdb.create_collection('users')
|
|
||||||
|
|
||||||
# db functions
|
|
||||||
|
|
||||||
def get_user(uid: int):
|
|
||||||
return tg_users.find_one({'uid': uid})
|
|
||||||
|
|
||||||
|
|
||||||
def new_user(uid: int):
|
|
||||||
if get_user(uid):
|
|
||||||
raise ValueError('User already present on db')
|
|
||||||
|
|
||||||
res = tg_users.insert_one({
|
|
||||||
'generated': 0,
|
|
||||||
'uid': uid,
|
|
||||||
'credits': DEFAULT_CREDITS,
|
|
||||||
'joined': datetime.utcnow().isoformat(),
|
|
||||||
'last_prompt': None,
|
|
||||||
'role': DEFAULT_ROLE,
|
|
||||||
'config': {
|
|
||||||
'algo': DEFAULT_ALGO,
|
|
||||||
'step': DEFAULT_STEP,
|
|
||||||
'size': DEFAULT_SIZE,
|
|
||||||
'seed': None,
|
|
||||||
'guidance': DEFAULT_GUIDANCE,
|
|
||||||
'upscaler': DEFAULT_UPSCALER
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
assert res.acknowledged
|
|
||||||
|
|
||||||
return get_user(uid)
|
|
||||||
|
|
||||||
def migrate_user(db_user):
|
|
||||||
# new: user roles
|
|
||||||
if 'role' not in db_user:
|
|
||||||
res = tg_users.find_one_and_update(
|
|
||||||
{'uid': db_user['uid']}, {'$set': {'role': DEFAULT_ROLE}})
|
|
||||||
|
|
||||||
# new: algo selection
|
|
||||||
if 'algo' not in db_user['config']:
|
|
||||||
res = tg_users.find_one_and_update(
|
|
||||||
{'uid': db_user['uid']}, {'$set': {'config.algo': DEFAULT_ALGO}})
|
|
||||||
|
|
||||||
# new: upscaler selection
|
|
||||||
if 'upscaler' not in db_user['config']:
|
|
||||||
res = tg_users.find_one_and_update(
|
|
||||||
{'uid': db_user['uid']}, {'$set': {'config.upscaler': DEFAULT_UPSCALER}})
|
|
||||||
|
|
||||||
return get_user(db_user['uid'])
|
|
||||||
|
|
||||||
def get_or_create_user(uid: int):
|
|
||||||
db_user = get_user(uid)
|
|
||||||
|
|
||||||
if not db_user:
|
|
||||||
db_user = new_user(uid)
|
|
||||||
|
|
||||||
logging.info(f'req from: {uid}')
|
|
||||||
|
|
||||||
return migrate_user(db_user)
|
|
||||||
|
|
||||||
def update_user(uid: int, updt_cmd: dict):
|
|
||||||
user = get_user(uid)
|
|
||||||
if not user:
|
|
||||||
raise ValueError('User not present on db')
|
|
||||||
|
|
||||||
return tg_users.find_one_and_update(
|
|
||||||
{'uid': uid}, updt_cmd)
|
|
||||||
|
|
||||||
|
|
||||||
# bot handler
|
|
||||||
def img_for_user_with_prompt(
|
|
||||||
uid: int,
|
|
||||||
prompt: str, step: int, size: Tuple[int, int], guidance: int, seed: int,
|
|
||||||
algo: str, upscaler: Optional[str]
|
|
||||||
):
|
|
||||||
name = uuid.uuid4()
|
|
||||||
|
|
||||||
spawn(
|
|
||||||
generate_image,
|
|
||||||
args=(prompt, name, step, size, guidance, seed, algo, upscaler))
|
|
||||||
|
|
||||||
logging.info(f'done generating. got {name}, sending...')
|
|
||||||
|
|
||||||
if len(prompt) > 256:
|
|
||||||
reply_txt = f'prompt: \"{prompt[:256]}...\"\n(full prompt too big to show on reply...)\n'
|
|
||||||
|
|
||||||
else:
|
|
||||||
reply_txt = f'prompt: \"{prompt}\"\n'
|
|
||||||
|
|
||||||
reply_txt += f'seed: {seed}\n'
|
|
||||||
reply_txt += f'iterations: {step}\n'
|
|
||||||
reply_txt += f'size: {size}\n'
|
|
||||||
reply_txt += f'guidance: {guidance}\n'
|
|
||||||
reply_txt += f'algo: {ALGOS[algo]}\n'
|
|
||||||
reply_txt += f'euler ancestral discrete'
|
|
||||||
|
|
||||||
return reply_txt, name
|
|
||||||
|
|
||||||
@bot.message_handler(commands=['help'])
|
|
||||||
@round_robined
|
|
||||||
def send_help(message):
|
|
||||||
bot.reply_to(message, HELP_TEXT)
|
|
||||||
|
|
||||||
@bot.message_handler(commands=['cool'])
|
|
||||||
@round_robined
|
|
||||||
def send_cool_words(message):
|
|
||||||
bot.reply_to(message, '\n'.join(COOL_WORDS))
|
|
||||||
|
|
||||||
@bot.message_handler(commands=['txt2img'])
|
|
||||||
@round_robined
|
|
||||||
def send_txt2img(message):
|
|
||||||
chat = message.chat
|
|
||||||
user = message.from_user
|
|
||||||
db_user = get_or_create_user(user.id)
|
|
||||||
|
|
||||||
if ((chat.type != 'group' and chat.id != GROUP_ID) and
|
|
||||||
(db_user['role'] not in MP_ENABLED_ROLES)):
|
|
||||||
return
|
|
||||||
|
|
||||||
prompt = ' '.join(message.text.split(' ')[1:])
|
|
||||||
|
|
||||||
if len(prompt) == 0:
|
|
||||||
bot.reply_to(message, 'empty text prompt ignored.')
|
|
||||||
return
|
|
||||||
|
|
||||||
logging.info(f"{user.first_name} ({user.id}) on chat {chat.id} txt2img: {prompt}")
|
|
||||||
|
|
||||||
user_conf = db_user['config']
|
|
||||||
|
|
||||||
algo = user_conf['algo']
|
|
||||||
step = user_conf['step']
|
|
||||||
size = user_conf['size']
|
|
||||||
seed = user_conf['seed'] if user_conf['seed'] else random.randint(0, 999999999)
|
|
||||||
guidance = user_conf['guidance']
|
|
||||||
upscaler = user_conf['upscaler']
|
|
||||||
|
|
||||||
try:
|
|
||||||
reply_txt, name = img_for_user_with_prompt(
|
|
||||||
user.id, prompt, step, size, guidance, seed, algo, upscaler)
|
|
||||||
|
|
||||||
update_user(
|
|
||||||
user.id,
|
|
||||||
{'$set': {
|
|
||||||
'generated': db_user['generated'] + 1,
|
|
||||||
'last_prompt': prompt
|
|
||||||
}})
|
|
||||||
|
|
||||||
bot.send_photo(
|
|
||||||
chat.id,
|
|
||||||
caption=f'sent by: {user.first_name}\n' + reply_txt,
|
|
||||||
photo=InputFile(f'/outputs/{name}.png'))
|
|
||||||
|
|
||||||
except BaseException as e:
|
|
||||||
logging.error(e)
|
|
||||||
bot.reply_to(message, 'that command caused an error :(\nchange settings and try again (?')
|
|
||||||
|
|
||||||
@bot.message_handler(commands=['redo'])
|
|
||||||
@round_robined
|
|
||||||
def redo_txt2img(message):
|
|
||||||
# check msg comes from testing group
|
|
||||||
chat = message.chat
|
|
||||||
user = message.from_user
|
|
||||||
db_user = get_or_create_user(user.id)
|
|
||||||
|
|
||||||
if ((chat.type != 'group' and chat.id != GROUP_ID) and
|
|
||||||
(db_user['role'] not in MP_ENABLED_ROLES)):
|
|
||||||
return
|
|
||||||
|
|
||||||
prompt = db_user['last_prompt']
|
|
||||||
|
|
||||||
if not prompt:
|
|
||||||
bot.reply_to(message, 'do a /txt2img command first silly!')
|
|
||||||
return
|
|
||||||
|
|
||||||
user_conf = db_user['config']
|
|
||||||
|
|
||||||
algo = user_conf['algo']
|
|
||||||
step = user_conf['step']
|
|
||||||
size = user_conf['size']
|
|
||||||
seed = user_conf['seed'] if user_conf['seed'] else random.randint(0, 999999999)
|
|
||||||
guidance = user_conf['guidance']
|
|
||||||
upscaler = user_conf['upscaler']
|
|
||||||
|
|
||||||
logging.info(f"{user.first_name} ({user.id}) on chat {chat.id} redo: {prompt}")
|
|
||||||
|
|
||||||
try:
|
|
||||||
reply_txt, name = img_for_user_with_prompt(
|
|
||||||
user.id, prompt, step, size, guidance, seed, algo, upscaler)
|
|
||||||
|
|
||||||
update_user(
|
|
||||||
user.id,
|
|
||||||
{'$set': {
|
|
||||||
'generated': db_user['generated'] + 1,
|
|
||||||
}})
|
|
||||||
|
|
||||||
bot.send_photo(
|
|
||||||
chat.id,
|
|
||||||
caption=f'sent by: {user.first_name}\n' + reply_txt,
|
|
||||||
photo=InputFile(f'/outputs/{name}.png'))
|
|
||||||
|
|
||||||
except BaseException as e:
|
|
||||||
logging.error(e)
|
|
||||||
bot.reply_to(message, 'that command caused an error :(\nchange settings and try again (?')
|
|
||||||
|
|
||||||
@bot.message_handler(commands=['config'])
|
|
||||||
@round_robined
|
|
||||||
def set_config(message):
|
|
||||||
logging.info(f'config req on chat: {message.chat.id}')
|
|
||||||
|
|
||||||
params = message.text.split(' ')
|
|
||||||
|
|
||||||
if len(params) < 3:
|
|
||||||
bot.reply_to(message, 'wrong msg format')
|
|
||||||
|
|
||||||
else:
|
|
||||||
user = message.from_user
|
|
||||||
chat = message.chat
|
|
||||||
db_user = get_or_create_user(user.id)
|
|
||||||
|
|
||||||
try:
|
|
||||||
attr = params[1]
|
|
||||||
|
|
||||||
if attr == 'algo':
|
|
||||||
val = params[2]
|
|
||||||
assert val in ALGOS
|
|
||||||
res = update_user(user.id, {'$set': {'config.algo': val}})
|
|
||||||
|
|
||||||
elif attr == 'step':
|
|
||||||
val = int(params[2])
|
|
||||||
val = max(min(val, MAX_STEP), MIN_STEP)
|
|
||||||
res = update_user(user.id, {'$set': {'config.step': val}})
|
|
||||||
|
|
||||||
elif attr == 'size':
|
|
||||||
max_w, max_h = MAX_SIZE
|
|
||||||
w = max(min(int(params[2]), max_w), 16)
|
|
||||||
h = max(min(int(params[3]), max_h), 16)
|
|
||||||
|
|
||||||
val = (w, h)
|
|
||||||
|
|
||||||
if (w % 8 != 0) or (h % 8 != 0):
|
|
||||||
bot.reply_to(message, 'size must be divisible by 8!')
|
|
||||||
return
|
|
||||||
|
|
||||||
res = update_user(user.id, {'$set': {'config.size': val}})
|
|
||||||
|
|
||||||
elif attr == 'seed':
|
|
||||||
val = params[2]
|
|
||||||
if val == 'auto':
|
|
||||||
val = None
|
|
||||||
else:
|
|
||||||
val = int(params[2])
|
|
||||||
|
|
||||||
res = update_user(user.id, {'$set': {'config.seed': val}})
|
|
||||||
|
|
||||||
elif attr == 'guidance':
|
|
||||||
val = float(params[2])
|
|
||||||
val = max(min(val, MAX_GUIDANCE), 0)
|
|
||||||
res = update_user(user.id, {'$set': {'config.guidance': val}})
|
|
||||||
|
|
||||||
elif attr == 'upscaler':
|
|
||||||
val = params[2]
|
|
||||||
if val == 'off':
|
|
||||||
val = None
|
|
||||||
|
|
||||||
res = update_user(user.id, {'$set': {'config.upscaler': val}})
|
|
||||||
|
|
||||||
else:
|
|
||||||
bot.reply_to(message, f'\"{attr}\" not a parameter')
|
|
||||||
|
|
||||||
bot.reply_to(message, f'config updated! {attr} to {val}')
|
|
||||||
|
|
||||||
except ValueError:
|
|
||||||
bot.reply_to(message, f'\"{val}\" is not a number silly')
|
|
||||||
|
|
||||||
except AssertionError:
|
|
||||||
bot.reply_to(message, f'no algo named {val}')
|
|
||||||
|
|
||||||
@bot.message_handler(commands=['stats'])
|
|
||||||
@round_robined
|
|
||||||
def user_stats(message):
|
|
||||||
user = message.from_user
|
|
||||||
db_user = get_or_create_user(user.id)
|
|
||||||
migrate_user(db_user)
|
|
||||||
|
|
||||||
joined_date_str = datetime.fromisoformat(db_user['joined']).strftime('%B the %dth %Y, %H:%M:%S')
|
|
||||||
|
|
||||||
user_stats_str = f'generated: {db_user["generated"]}\n'
|
|
||||||
user_stats_str += f'joined: {joined_date_str}\n'
|
|
||||||
user_stats_str += f'credits: {db_user["credits"]}\n'
|
|
||||||
user_stats_str += f'role: {db_user["role"]}\n'
|
|
||||||
|
|
||||||
bot.reply_to(
|
|
||||||
message, user_stats_str)
|
|
||||||
|
|
||||||
@bot.message_handler(commands=['donate'])
|
|
||||||
@round_robined
|
|
||||||
def donation_info(message):
|
|
||||||
bot.reply_to(
|
|
||||||
message, DONATION_INFO)
|
|
||||||
|
|
||||||
@bot.message_handler(commands=['say'])
|
|
||||||
@round_robined
|
|
||||||
def say(message):
|
|
||||||
chat = message.chat
|
|
||||||
user = message.from_user
|
|
||||||
db_user = get_or_create_user(user.id)
|
|
||||||
|
|
||||||
if (chat.type == 'group') or (db_user['role'] not in MP_ENABLED_ROLES):
|
|
||||||
return
|
|
||||||
|
|
||||||
bot.send_message(GROUP_ID, message.text[4:])
|
|
||||||
|
|
||||||
@bot.message_handler(func=lambda message: True)
|
|
||||||
@round_robined
|
|
||||||
def echo_message(message):
|
|
||||||
if message.text[0] == '/':
|
|
||||||
bot.reply_to(message, UNKNOWN_CMD_TEXT)
|
|
||||||
|
|
||||||
|
|
||||||
login(token=os.environ['HF_TOKEN'])
|
|
||||||
|
|
||||||
bot.infinity_polling()
|
|
|
@ -0,0 +1,11 @@
|
||||||
|
from setuptools import setup, find_packages
|
||||||
|
|
||||||
|
setup(
|
||||||
|
name='skynet-bot',
|
||||||
|
version='0.1.0a6',
|
||||||
|
description='Decentralized compute platform',
|
||||||
|
author='Guillermo Rodriguez',
|
||||||
|
author_email='guillermo@telos.net',
|
||||||
|
packages=find_packages(),
|
||||||
|
install_requires=[]
|
||||||
|
)
|
|
@ -0,0 +1,2 @@
|
||||||
|
#!/usr/bin/python
|
||||||
|
|
|
@ -0,0 +1,246 @@
|
||||||
|
#!/usr/bin/python
|
||||||
|
|
||||||
|
import json
|
||||||
|
import uuid
|
||||||
|
import base64
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from uuid import UUID
|
||||||
|
from functools import partial
|
||||||
|
from collections import OrderedDict
|
||||||
|
|
||||||
|
import trio
|
||||||
|
import pynng
|
||||||
|
import trio_asyncio
|
||||||
|
|
||||||
|
from .db import *
|
||||||
|
from .types import *
|
||||||
|
from .constants import *
|
||||||
|
|
||||||
|
|
||||||
|
class SkynetDGPUOffline(BaseException):
|
||||||
|
...
|
||||||
|
|
||||||
|
class SkynetDGPUOverloaded(BaseException):
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
async def rpc_service(sock, dgpu_bus, db_pool):
|
||||||
|
nodes = OrderedDict()
|
||||||
|
wip_reqs = {}
|
||||||
|
fin_reqs = {}
|
||||||
|
|
||||||
|
def are_all_workers_busy():
|
||||||
|
for nid, info in nodes.items():
|
||||||
|
if info['task'] == None:
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
next_worker = 0
|
||||||
|
def get_next_worker():
|
||||||
|
nonlocal next_worker
|
||||||
|
|
||||||
|
if len(nodes) == 0:
|
||||||
|
raise SkynetDGPUOffline
|
||||||
|
|
||||||
|
if are_all_workers_busy():
|
||||||
|
raise SkynetDGPUOverloaded
|
||||||
|
|
||||||
|
next_worker += 1
|
||||||
|
|
||||||
|
if next_worker >= len(nodes):
|
||||||
|
next_worker = 0
|
||||||
|
|
||||||
|
nid = list(nodes.keys())[next_worker]
|
||||||
|
return nid
|
||||||
|
|
||||||
|
async def dgpu_image_streamer():
|
||||||
|
nonlocal wip_reqs, fin_reqs
|
||||||
|
while True:
|
||||||
|
msg = await dgpu_bus.arecv_msg()
|
||||||
|
rid = UUID(bytes=msg.bytes[:16]).hex
|
||||||
|
img = msg.bytes[16:].hex()
|
||||||
|
fin_reqs[rid] = img
|
||||||
|
event = wip_reqs[rid]
|
||||||
|
event.set()
|
||||||
|
del wip_reqs[rid]
|
||||||
|
|
||||||
|
async def dgpu_stream_one_img(req: ImageGenRequest):
|
||||||
|
nonlocal wip_reqs, fin_reqs, next_worker
|
||||||
|
nid = get_next_worker()
|
||||||
|
logging.info(f'dgpu_stream_one_img {next_worker} {nid}')
|
||||||
|
rid = uuid.uuid4().hex
|
||||||
|
event = trio.Event()
|
||||||
|
wip_reqs[rid] = event
|
||||||
|
|
||||||
|
nodes[nid]['task'] = rid
|
||||||
|
|
||||||
|
dgpu_req = DGPUBusRequest(
|
||||||
|
rid=rid,
|
||||||
|
nid=nid,
|
||||||
|
task='diffuse',
|
||||||
|
params=req.to_dict())
|
||||||
|
|
||||||
|
logging.info(f'dgpu_bus req: {dgpu_req}')
|
||||||
|
|
||||||
|
await dgpu_bus.asend(
|
||||||
|
json.dumps(dgpu_req.to_dict()).encode())
|
||||||
|
|
||||||
|
await event.wait()
|
||||||
|
|
||||||
|
nodes[nid]['task'] = None
|
||||||
|
|
||||||
|
img = fin_reqs[rid]
|
||||||
|
del fin_reqs[rid]
|
||||||
|
|
||||||
|
logging.info(f'done streaming {img}')
|
||||||
|
|
||||||
|
return rid, img
|
||||||
|
|
||||||
|
async def handle_user_request(rpc_ctx, req):
|
||||||
|
try:
|
||||||
|
async with db_pool.acquire() as conn:
|
||||||
|
user = await get_or_create_user(conn, req.uid)
|
||||||
|
|
||||||
|
result = {}
|
||||||
|
|
||||||
|
match req.method:
|
||||||
|
case 'txt2img':
|
||||||
|
logging.info('txt2img')
|
||||||
|
user_config = {**(await get_user_config(conn, user))}
|
||||||
|
del user_config['id']
|
||||||
|
prompt = req.params['prompt']
|
||||||
|
req = ImageGenRequest(
|
||||||
|
prompt=prompt,
|
||||||
|
**user_config
|
||||||
|
)
|
||||||
|
rid, img = await dgpu_stream_one_img(req)
|
||||||
|
result = {
|
||||||
|
'id': rid,
|
||||||
|
'img': img
|
||||||
|
}
|
||||||
|
|
||||||
|
case 'redo':
|
||||||
|
logging.info('redo')
|
||||||
|
user_config = await get_user_config(conn, user)
|
||||||
|
prompt = await get_last_prompt_of(conn, user)
|
||||||
|
req = ImageGenRequest(
|
||||||
|
prompt=prompt,
|
||||||
|
**user_config
|
||||||
|
)
|
||||||
|
rid, img = await dgpu_stream_one_img(req)
|
||||||
|
result = {
|
||||||
|
'id': rid,
|
||||||
|
'img': img
|
||||||
|
}
|
||||||
|
|
||||||
|
case 'config':
|
||||||
|
logging.info('config')
|
||||||
|
if req.params['attr'] in CONFIG_ATTRS:
|
||||||
|
await update_user_config(
|
||||||
|
conn, user, req.params['attr'], req.params['val'])
|
||||||
|
|
||||||
|
case 'stats':
|
||||||
|
logging.info('stats')
|
||||||
|
generated, joined, role = await get_user_stats(conn, user)
|
||||||
|
|
||||||
|
result = {
|
||||||
|
'generated': generated,
|
||||||
|
'joined': joined.strftime(DATE_FORMAT),
|
||||||
|
'role': role
|
||||||
|
}
|
||||||
|
|
||||||
|
case _:
|
||||||
|
logging.warn('unknown method')
|
||||||
|
|
||||||
|
except SkynetDGPUOffline:
|
||||||
|
result = {
|
||||||
|
'error': 'skynet_dgpu_offline'
|
||||||
|
}
|
||||||
|
|
||||||
|
except SkynetDGPUOverloaded:
|
||||||
|
result = {
|
||||||
|
'error': 'skynet_dgpu_overloaded',
|
||||||
|
'nodes': len(nodes)
|
||||||
|
}
|
||||||
|
|
||||||
|
except BaseException as e:
|
||||||
|
logging.error(e)
|
||||||
|
raise e
|
||||||
|
# result = {
|
||||||
|
# 'error': 'skynet_internal_error'
|
||||||
|
# }
|
||||||
|
|
||||||
|
await rpc_ctx.asend(
|
||||||
|
json.dumps(
|
||||||
|
SkynetRPCResponse(result=result).to_dict()).encode())
|
||||||
|
|
||||||
|
|
||||||
|
async with trio.open_nursery() as n:
|
||||||
|
n.start_soon(dgpu_image_streamer)
|
||||||
|
while True:
|
||||||
|
ctx = sock.new_context()
|
||||||
|
msg = await ctx.arecv_msg()
|
||||||
|
content = msg.bytes.decode()
|
||||||
|
req = SkynetRPCRequest(**json.loads(content))
|
||||||
|
|
||||||
|
logging.info(req)
|
||||||
|
|
||||||
|
if req.method == 'dgpu_online':
|
||||||
|
nodes[req.uid] = {
|
||||||
|
'task': None
|
||||||
|
}
|
||||||
|
logging.info(f'dgpu online: {req.uid}')
|
||||||
|
|
||||||
|
|
||||||
|
elif req.method == 'dgpu_offline':
|
||||||
|
i = nodes.values().index(req.uid)
|
||||||
|
del nodes[req.uid]
|
||||||
|
|
||||||
|
if i < next_worker:
|
||||||
|
next_worker -= 1
|
||||||
|
logging.info(f'dgpu offline: {req.uid}')
|
||||||
|
|
||||||
|
else:
|
||||||
|
n.start_soon(
|
||||||
|
handle_user_request, ctx, req)
|
||||||
|
continue
|
||||||
|
|
||||||
|
await ctx.asend(
|
||||||
|
json.dumps(
|
||||||
|
SkynetRPCResponse(
|
||||||
|
result={'ok': {}}).to_dict()).encode())
|
||||||
|
|
||||||
|
|
||||||
|
async def run_skynet(
|
||||||
|
db_user: str,
|
||||||
|
db_pass: str,
|
||||||
|
db_host: str = DB_HOST,
|
||||||
|
rpc_address: str = DEFAULT_RPC_ADDR,
|
||||||
|
dgpu_address: str = DEFAULT_DGPU_ADDR,
|
||||||
|
task_status = trio.TASK_STATUS_IGNORED
|
||||||
|
):
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
logging.info('skynet is starting')
|
||||||
|
|
||||||
|
async with (
|
||||||
|
trio.open_nursery() as n,
|
||||||
|
open_database_connection(
|
||||||
|
db_user, db_pass, db_host) as db_pool
|
||||||
|
):
|
||||||
|
logging.info('connected to db.')
|
||||||
|
with (
|
||||||
|
pynng.Rep0(listen=rpc_address) as rpc_sock,
|
||||||
|
pynng.Bus0(listen=dgpu_address) as dgpu_bus
|
||||||
|
):
|
||||||
|
n.start_soon(
|
||||||
|
rpc_service, rpc_sock, dgpu_bus, db_pool)
|
||||||
|
task_status.started()
|
||||||
|
|
||||||
|
try:
|
||||||
|
await trio.sleep_forever()
|
||||||
|
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
...
|
||||||
|
|
|
@ -0,0 +1,129 @@
|
||||||
|
#!/usr/bin/python
|
||||||
|
|
||||||
|
API_TOKEN = '5880619053:AAFge2UfObw1kCn9Kb7AAyqaIHs_HgM0Fx0'
|
||||||
|
|
||||||
|
DB_HOST = 'ancap.tech:34508'
|
||||||
|
|
||||||
|
ALGOS = {
|
||||||
|
'stable': 'runwayml/stable-diffusion-v1-5',
|
||||||
|
'midj': 'prompthero/openjourney',
|
||||||
|
'hdanime': 'Linaqruf/anything-v3.0',
|
||||||
|
'waifu': 'hakurei/waifu-diffusion',
|
||||||
|
'ghibli': 'nitrosocke/Ghibli-Diffusion',
|
||||||
|
'van-gogh': 'dallinmackay/Van-Gogh-diffusion',
|
||||||
|
'pokemon': 'lambdalabs/sd-pokemon-diffusers',
|
||||||
|
'ink': 'Envvi/Inkpunk-Diffusion',
|
||||||
|
'robot': 'nousr/robo-diffusion'
|
||||||
|
}
|
||||||
|
|
||||||
|
N = '\n'
|
||||||
|
HELP_TEXT = f'''
|
||||||
|
test art bot v0.1a4
|
||||||
|
|
||||||
|
commands work on a user per user basis!
|
||||||
|
config is individual to each user!
|
||||||
|
|
||||||
|
/txt2img TEXT - request an image based on a prompt
|
||||||
|
|
||||||
|
/redo - re ont
|
||||||
|
|
||||||
|
/help step - get info on step config option
|
||||||
|
/help guidance - get info on guidance config option
|
||||||
|
|
||||||
|
/cool - list of cool words to use
|
||||||
|
/stats - user statistics
|
||||||
|
/donate - see donation info
|
||||||
|
|
||||||
|
/config algo NAME - select AI to use one of:
|
||||||
|
|
||||||
|
{N.join(ALGOS.keys())}
|
||||||
|
|
||||||
|
/config step NUMBER - set amount of iterations
|
||||||
|
/config seed NUMBER - set the seed, deterministic results!
|
||||||
|
/config size WIDTH HEIGHT - set size in pixels
|
||||||
|
/config guidance NUMBER - prompt text importance
|
||||||
|
'''
|
||||||
|
|
||||||
|
UNKNOWN_CMD_TEXT = 'unknown command! try sending \"/help\"'
|
||||||
|
|
||||||
|
DONATION_INFO = '0xf95335682DF281FFaB7E104EB87B69625d9622B6\ngoal: 25/650usd'
|
||||||
|
|
||||||
|
COOL_WORDS = [
|
||||||
|
'cyberpunk',
|
||||||
|
'soviet propaganda poster',
|
||||||
|
'rastafari',
|
||||||
|
'cannabis',
|
||||||
|
'art deco',
|
||||||
|
'H R Giger Necronom IV',
|
||||||
|
'dimethyltryptamine',
|
||||||
|
'lysergic',
|
||||||
|
'slut',
|
||||||
|
'psilocybin',
|
||||||
|
'trippy',
|
||||||
|
'lucy in the sky with diamonds',
|
||||||
|
'fractal',
|
||||||
|
'da vinci',
|
||||||
|
'pencil illustration',
|
||||||
|
'blueprint',
|
||||||
|
'internal diagram',
|
||||||
|
'baroque',
|
||||||
|
'the last judgment',
|
||||||
|
'michelangelo'
|
||||||
|
]
|
||||||
|
|
||||||
|
HELP_STEP = '''
|
||||||
|
diffusion models are iterative processes – a repeated cycle that starts with a\
|
||||||
|
random noise generated from text input. With each step, some noise is removed\
|
||||||
|
, resulting in a higher-quality image over time. The repetition stops when the\
|
||||||
|
desired number of steps completes.
|
||||||
|
|
||||||
|
around 25 sampling steps are usually enough to achieve high-quality images. Us\
|
||||||
|
ing more may produce a slightly different picture, but not necessarily better \
|
||||||
|
quality.
|
||||||
|
'''
|
||||||
|
|
||||||
|
HELP_GUIDANCE = '''
|
||||||
|
the guidance scale is a parameter that controls how much the image generation\
|
||||||
|
process follows the text prompt. The higher the value, the more image sticks\
|
||||||
|
to a given text input.
|
||||||
|
'''
|
||||||
|
|
||||||
|
HELP_UNKWNOWN_PARAM = 'don\'t have any info on that.'
|
||||||
|
|
||||||
|
GROUP_ID = -1001541979235
|
||||||
|
|
||||||
|
MP_ENABLED_ROLES = ['god']
|
||||||
|
|
||||||
|
MIN_STEP = 1
|
||||||
|
MAX_STEP = 100
|
||||||
|
MAX_WIDTH = 512
|
||||||
|
MAX_HEIGHT = 656
|
||||||
|
MAX_GUIDANCE = 20
|
||||||
|
|
||||||
|
DEFAULT_SEED = None
|
||||||
|
DEFAULT_WIDTH = 512
|
||||||
|
DEFAULT_HEIGHT = 512
|
||||||
|
DEFAULT_GUIDANCE = 7.5
|
||||||
|
DEFAULT_STEP = 35
|
||||||
|
DEFAULT_CREDITS = 10
|
||||||
|
DEFAULT_ALGO = 'midj'
|
||||||
|
DEFAULT_ROLE = 'pleb'
|
||||||
|
DEFAULT_UPSCALER = None
|
||||||
|
|
||||||
|
DEFAULT_RPC_ADDR = 'tcp://127.0.0.1:41000'
|
||||||
|
|
||||||
|
DEFAULT_DGPU_ADDR = 'tcp://127.0.0.1:41069'
|
||||||
|
DEFAULT_DGPU_MAX_TASKS = 3
|
||||||
|
DEFAULT_INITAL_ALGOS = ['midj', 'stable', 'ink']
|
||||||
|
|
||||||
|
DATE_FORMAT = '%B the %dth %Y, %H:%M:%S'
|
||||||
|
|
||||||
|
CONFIG_ATTRS = [
|
||||||
|
'algo',
|
||||||
|
'step',
|
||||||
|
'width',
|
||||||
|
'height',
|
||||||
|
'seed',
|
||||||
|
'guidance',
|
||||||
|
'upscaler'
|
||||||
|
]
|
|
@ -0,0 +1,146 @@
|
||||||
|
#!/usr/bin/python
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from datetime import datetime
|
||||||
|
from contextlib import asynccontextmanager as acm
|
||||||
|
|
||||||
|
import trio
|
||||||
|
import triopg
|
||||||
|
|
||||||
|
from .constants import *
|
||||||
|
|
||||||
|
|
||||||
|
def try_decode_uid(uid: str):
|
||||||
|
try:
|
||||||
|
proto, uid = uid.split('+')
|
||||||
|
uid = int(uid)
|
||||||
|
return proto, uid
|
||||||
|
|
||||||
|
except ValueError:
|
||||||
|
logging.warning(f'got non numeric uid?: {uid}')
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
|
||||||
|
@acm
|
||||||
|
async def open_database_connection(
|
||||||
|
db_user: str,
|
||||||
|
db_pass: str,
|
||||||
|
db_host: str = DB_HOST,
|
||||||
|
):
|
||||||
|
async with triopg.create_pool(
|
||||||
|
dsn=f'postgres://{db_user}:{db_pass}@{db_host}/skynet_art_bot'
|
||||||
|
) as conn:
|
||||||
|
yield conn
|
||||||
|
|
||||||
|
|
||||||
|
async def get_user(conn, uid: str):
|
||||||
|
if isinstance(uid, str):
|
||||||
|
proto, uid = try_decode_uid(uid)
|
||||||
|
|
||||||
|
match proto:
|
||||||
|
case 'tg':
|
||||||
|
stmt = await conn.prepare(
|
||||||
|
'SELECT * FROM skynet.user WHERE tg_id = $1')
|
||||||
|
user = await stmt.fetchval(uid)
|
||||||
|
|
||||||
|
case _:
|
||||||
|
user = None
|
||||||
|
|
||||||
|
return user
|
||||||
|
|
||||||
|
else: # asumme is our uid
|
||||||
|
stmt = await conn.prepare(
|
||||||
|
'SELECT * FROM skynet.user WHERE id = $1')
|
||||||
|
return await stmt.fetchval(uid)
|
||||||
|
|
||||||
|
|
||||||
|
async def get_user_config(conn, user: int):
|
||||||
|
stmt = await conn.prepare(
|
||||||
|
'SELECT * FROM skynet.user_config WHERE id = $1')
|
||||||
|
return (await stmt.fetch(user))[0]
|
||||||
|
|
||||||
|
|
||||||
|
async def get_last_prompt_of(conn, user: int):
|
||||||
|
stms = await conn.prepare(
|
||||||
|
'SELECT last_prompt FROM skynet.user WHERE id = $1')
|
||||||
|
return await stmt.fetchval(user)
|
||||||
|
|
||||||
|
|
||||||
|
async def new_user(conn, uid: str):
|
||||||
|
if await get_user(conn, uid):
|
||||||
|
raise ValueError('User already present on db')
|
||||||
|
|
||||||
|
logging.info(f'new user! {uid}')
|
||||||
|
|
||||||
|
tg_id = None
|
||||||
|
date = datetime.utcnow()
|
||||||
|
|
||||||
|
proto, pid = try_decode_uid(uid)
|
||||||
|
|
||||||
|
match proto:
|
||||||
|
case 'tg':
|
||||||
|
tg_id = pid
|
||||||
|
|
||||||
|
async with conn.transaction():
|
||||||
|
stmt = await conn.prepare('''
|
||||||
|
INSERT INTO skynet.user(
|
||||||
|
tg_id, generated, joined, last_prompt, role)
|
||||||
|
|
||||||
|
VALUES($1, $2, $3, $4, $5)
|
||||||
|
''')
|
||||||
|
await stmt.fetch(
|
||||||
|
tg_id, 0, date, None, DEFAULT_ROLE
|
||||||
|
)
|
||||||
|
|
||||||
|
new_uid = await get_user(conn, uid)
|
||||||
|
|
||||||
|
stmt = await conn.prepare('''
|
||||||
|
INSERT INTO skynet.user_config(
|
||||||
|
id, algo, step, width, height, seed, guidance, upscaler)
|
||||||
|
|
||||||
|
VALUES($1, $2, $3, $4, $5, $6, $7, $8)
|
||||||
|
''')
|
||||||
|
user = await stmt.fetch(
|
||||||
|
new_uid,
|
||||||
|
DEFAULT_ALGO,
|
||||||
|
DEFAULT_STEP,
|
||||||
|
DEFAULT_WIDTH,
|
||||||
|
DEFAULT_HEIGHT,
|
||||||
|
DEFAULT_SEED,
|
||||||
|
DEFAULT_GUIDANCE,
|
||||||
|
DEFAULT_UPSCALER
|
||||||
|
)
|
||||||
|
|
||||||
|
return new_uid
|
||||||
|
|
||||||
|
|
||||||
|
async def get_or_create_user(conn, uid: str):
|
||||||
|
user = await get_user(conn, uid)
|
||||||
|
|
||||||
|
if not user:
|
||||||
|
user = await new_user(conn, uid)
|
||||||
|
|
||||||
|
return user
|
||||||
|
|
||||||
|
async def update_user(conn, user: int, attr: str, val):
|
||||||
|
...
|
||||||
|
|
||||||
|
async def update_user_config(conn, user: int, attr: str, val):
|
||||||
|
stmt = await conn.prepare(f'''
|
||||||
|
UPDATE skynet.user_config
|
||||||
|
SET {attr} = $2
|
||||||
|
WHERE id = $1
|
||||||
|
''')
|
||||||
|
await stmt.fetch(user, val)
|
||||||
|
|
||||||
|
|
||||||
|
async def get_user_stats(conn, user: int):
|
||||||
|
stmt = await conn.prepare('''
|
||||||
|
SELECT generated,joined,role FROM skynet.user
|
||||||
|
WHERE id = $1
|
||||||
|
''')
|
||||||
|
records = await stmt.fetch(user)
|
||||||
|
assert len(records) == 1
|
||||||
|
record = records[0]
|
||||||
|
return record
|
|
@ -0,0 +1,121 @@
|
||||||
|
#!/usr/bin/python
|
||||||
|
|
||||||
|
import trio
|
||||||
|
import json
|
||||||
|
import uuid
|
||||||
|
import logging
|
||||||
|
|
||||||
|
import pynng
|
||||||
|
import tractor
|
||||||
|
|
||||||
|
from . import gpu
|
||||||
|
from .gpu import open_gpu_worker
|
||||||
|
from .types import *
|
||||||
|
from .constants import *
|
||||||
|
from .frontend import rpc_call
|
||||||
|
|
||||||
|
|
||||||
|
async def open_dgpu_node(
|
||||||
|
rpc_address: str = DEFAULT_RPC_ADDR,
|
||||||
|
dgpu_address: str = DEFAULT_DGPU_ADDR,
|
||||||
|
dgpu_max_tasks: int = DEFAULT_DGPU_MAX_TASKS,
|
||||||
|
initial_algos: str = DEFAULT_INITAL_ALGOS
|
||||||
|
):
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
|
||||||
|
name = uuid.uuid4()
|
||||||
|
workers = initial_algos.copy()
|
||||||
|
tasks = [None for _ in range(dgpu_max_tasks)]
|
||||||
|
|
||||||
|
portal_map: dict[int, tractor.Portal]
|
||||||
|
contexts: dict[int, tractor.Context]
|
||||||
|
|
||||||
|
def get_next_worker(need_algo: str):
|
||||||
|
nonlocal workers, tasks
|
||||||
|
for task, algo in zip(workers, tasks):
|
||||||
|
if need_algo == algo and not task:
|
||||||
|
return workers.index(need_algo)
|
||||||
|
|
||||||
|
return tasks.index(None)
|
||||||
|
|
||||||
|
async def gpu_streamer(
|
||||||
|
ctx: tractor.Context,
|
||||||
|
nid: int
|
||||||
|
):
|
||||||
|
nonlocal tasks
|
||||||
|
async with ctx.open_stream() as stream:
|
||||||
|
async for img in stream:
|
||||||
|
tasks[nid]['res'] = img
|
||||||
|
tasks[nid]['event'].set()
|
||||||
|
|
||||||
|
async def gpu_compute_one(ireq: ImageGenRequest):
|
||||||
|
wid = get_next_worker(ireq.algo)
|
||||||
|
event = trio.Event()
|
||||||
|
|
||||||
|
workers[wid] = ireq.algo
|
||||||
|
tasks[wid] = {
|
||||||
|
'res': None, 'event': event}
|
||||||
|
|
||||||
|
await contexts[i].send(ireq)
|
||||||
|
|
||||||
|
await event.wait()
|
||||||
|
|
||||||
|
img = tasks[wid]['res']
|
||||||
|
tasks[wid] = None
|
||||||
|
return img
|
||||||
|
|
||||||
|
|
||||||
|
with (
|
||||||
|
pynng.Req0(dial=rpc_address) as rpc_sock,
|
||||||
|
pynng.Bus0(dial=dgpu_address) as dgpu_sock
|
||||||
|
):
|
||||||
|
async def _rpc_call(*args, **kwargs):
|
||||||
|
return await rpc_call(rpc_sock, *args, **kwargs)
|
||||||
|
|
||||||
|
async def _process_dgpu_req(req: DGPUBusRequest):
|
||||||
|
img = await gpu_compute_one(
|
||||||
|
ImageGenRequest(**req.params))
|
||||||
|
await dgpu_sock.asend(
|
||||||
|
bytes.fromhex(req.rid) + img)
|
||||||
|
|
||||||
|
res = await _rpc_call(
|
||||||
|
name.hex, 'dgpu_online', {'max_tasks': dgpu_max_tasks})
|
||||||
|
logging.info(res)
|
||||||
|
assert 'ok' in res.result
|
||||||
|
|
||||||
|
async with (
|
||||||
|
tractor.open_actor_cluster(
|
||||||
|
modules=['skynet_bot.gpu'],
|
||||||
|
count=dgpu_max_tasks,
|
||||||
|
names=[i for i in range(dgpu_max_tasks)]
|
||||||
|
) as portal_map,
|
||||||
|
trio.open_nursery() as n
|
||||||
|
):
|
||||||
|
logging.info(f'starting {dgpu_max_tasks} gpu workers')
|
||||||
|
async with tractor.gather_contexts((
|
||||||
|
ctx.open_context(
|
||||||
|
open_gpu_worker, algo, 1.0 / dgpu_max_tasks)
|
||||||
|
)) as contexts:
|
||||||
|
contexts = {i: ctx for i, ctx in enumerate(contexts)}
|
||||||
|
for i, ctx in contexts.items():
|
||||||
|
n.start_soon(
|
||||||
|
gpu_streamer, ctx, i)
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
msg = await dgpu_sock.arecv()
|
||||||
|
req = DGPUBusRequest(
|
||||||
|
**json.loads(msg.decode()))
|
||||||
|
|
||||||
|
if req.nid != name.hex:
|
||||||
|
continue
|
||||||
|
|
||||||
|
logging.info(f'dgpu: {name}, req: {req}')
|
||||||
|
n.start_soon(
|
||||||
|
_process_dgpu_req, req)
|
||||||
|
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
...
|
||||||
|
|
||||||
|
res = await _rpc_call(name.hex, 'dgpu_offline')
|
||||||
|
logging.info(res)
|
||||||
|
assert 'ok' in res.result
|
|
@ -0,0 +1,50 @@
|
||||||
|
#!/usr/bin/python
|
||||||
|
|
||||||
|
import json
|
||||||
|
|
||||||
|
from typing import Union
|
||||||
|
from contextlib import contextmanager as cm
|
||||||
|
|
||||||
|
import pynng
|
||||||
|
|
||||||
|
from ..types import SkynetRPCRequest, SkynetRPCResponse
|
||||||
|
from ..constants import DEFAULT_RPC_ADDR
|
||||||
|
|
||||||
|
|
||||||
|
class ConfigUnknownAttribute(BaseException):
|
||||||
|
...
|
||||||
|
|
||||||
|
class ConfigUnknownAlgorithm(BaseException):
|
||||||
|
...
|
||||||
|
|
||||||
|
class ConfigUnknownUpscaler(BaseException):
|
||||||
|
...
|
||||||
|
|
||||||
|
class ConfigSizeDivisionByEight(BaseException):
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
async def rpc_call(
|
||||||
|
sock,
|
||||||
|
uid: Union[int, str],
|
||||||
|
method: str,
|
||||||
|
params: dict = {}
|
||||||
|
):
|
||||||
|
req = SkynetRPCRequest(
|
||||||
|
uid=uid,
|
||||||
|
method=method,
|
||||||
|
params=params
|
||||||
|
)
|
||||||
|
await sock.asend(
|
||||||
|
json.dumps(
|
||||||
|
req.to_dict()).encode())
|
||||||
|
|
||||||
|
return SkynetRPCResponse(
|
||||||
|
**json.loads(
|
||||||
|
(await sock.arecv_msg()).bytes.decode()))
|
||||||
|
|
||||||
|
|
||||||
|
@cm
|
||||||
|
def open_skynet_rpc(rpc_address: str = DEFAULT_RPC_ADDR):
|
||||||
|
with pynng.Req0(dial=rpc_address) as rpc_sock:
|
||||||
|
yield rpc_sock
|
|
@ -0,0 +1,164 @@
|
||||||
|
#!/usr/bin/python
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
import pynng
|
||||||
|
|
||||||
|
from telebot.async_telebot import AsyncTeleBot
|
||||||
|
from trio_asyncio import aio_as_trio
|
||||||
|
|
||||||
|
from ..constants import *
|
||||||
|
|
||||||
|
from . import *
|
||||||
|
|
||||||
|
|
||||||
|
PREFIX = 'tg'
|
||||||
|
|
||||||
|
|
||||||
|
async def run_skynet_telegram(tg_token: str):
|
||||||
|
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
bot = AsyncTeleBot(tg_token)
|
||||||
|
|
||||||
|
with open_skynet_rpc() as rpc_sock:
|
||||||
|
|
||||||
|
async def _rpc_call(
|
||||||
|
uid: int,
|
||||||
|
method: str,
|
||||||
|
params: dict
|
||||||
|
):
|
||||||
|
return await rpc_call(
|
||||||
|
rpc_sock, f'{PREFIX}+{uid}', method, params)
|
||||||
|
|
||||||
|
@bot.message_handler(commands=['help'])
|
||||||
|
async def send_help(message):
|
||||||
|
await bot.reply_to(message, HELP_TEXT)
|
||||||
|
|
||||||
|
@bot.message_handler(commands=['cool'])
|
||||||
|
async def send_cool_words(message):
|
||||||
|
await bot.reply_to(message, '\n'.join(COOL_WORDS))
|
||||||
|
|
||||||
|
@bot.message_handler(commands=['txt2img'])
|
||||||
|
async def send_txt2img(message):
|
||||||
|
resp = await _rpc_call(
|
||||||
|
message.from_user.id,
|
||||||
|
'txt2img',
|
||||||
|
{}
|
||||||
|
)
|
||||||
|
|
||||||
|
@bot.message_handler(commands=['redo'])
|
||||||
|
async def redo_txt2img(message):
|
||||||
|
resp = await _rpc_call(
|
||||||
|
message.from_user.id,
|
||||||
|
'redo',
|
||||||
|
{}
|
||||||
|
)
|
||||||
|
|
||||||
|
@bot.message_handler(commands=['config'])
|
||||||
|
async def set_config(message):
|
||||||
|
params = message.text.split(' ')
|
||||||
|
|
||||||
|
rpc_params = {}
|
||||||
|
|
||||||
|
if len(params) < 3:
|
||||||
|
bot.reply_to(message, 'wrong msg format')
|
||||||
|
|
||||||
|
else:
|
||||||
|
|
||||||
|
try:
|
||||||
|
attr = params[1]
|
||||||
|
|
||||||
|
if attr == 'algo':
|
||||||
|
val = params[2]
|
||||||
|
if val not in ALGOS:
|
||||||
|
raise ConfigUnknownAlgorithm
|
||||||
|
|
||||||
|
elif attr == 'step':
|
||||||
|
val = int(params[2])
|
||||||
|
val = max(min(val, MAX_STEP), MIN_STEP)
|
||||||
|
|
||||||
|
elif attr == 'width':
|
||||||
|
val = max(min(int(params[2]), MAX_WIDTH), 16)
|
||||||
|
if val % 8 != 0:
|
||||||
|
raise ConfigSizeDivisionByEight
|
||||||
|
|
||||||
|
elif attr == 'height':
|
||||||
|
val = max(min(int(params[2]), MAX_HEIGHT), 16)
|
||||||
|
if val % 8 != 0:
|
||||||
|
raise ConfigSizeDivisionByEight
|
||||||
|
|
||||||
|
elif attr == 'seed':
|
||||||
|
val = params[2]
|
||||||
|
if val == 'auto':
|
||||||
|
val = None
|
||||||
|
else:
|
||||||
|
val = int(params[2])
|
||||||
|
|
||||||
|
elif attr == 'guidance':
|
||||||
|
val = float(params[2])
|
||||||
|
val = max(min(val, MAX_GUIDANCE), 0)
|
||||||
|
|
||||||
|
elif attr == 'upscaler':
|
||||||
|
val = params[2]
|
||||||
|
if val == 'off':
|
||||||
|
val = None
|
||||||
|
elif val != 'x4':
|
||||||
|
raise ConfigUnknownUpscaler
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise ConfigUnknownAttribute
|
||||||
|
|
||||||
|
resp = await _rpc_call(
|
||||||
|
message.from_user.id,
|
||||||
|
'config', {'attr': attr, 'val': val})
|
||||||
|
|
||||||
|
reply_txt = f'config updated! {attr} to {val}'
|
||||||
|
|
||||||
|
except ConfigUnknownAlgorithm:
|
||||||
|
reply_txt = f'no algo named {val}'
|
||||||
|
|
||||||
|
except ConfigUnknownAttribute:
|
||||||
|
reply_txt = f'\"{attr}\" not a configurable parameter'
|
||||||
|
|
||||||
|
except ConfigUnknownUpscaler:
|
||||||
|
reply_txt = f'\"{val}\" is not a valid upscaler'
|
||||||
|
|
||||||
|
except ConfigSizeDivisionByEight:
|
||||||
|
reply_txt = 'size must be divisible by 8!'
|
||||||
|
|
||||||
|
except ValueError:
|
||||||
|
reply_txt = f'\"{val}\" is not a number silly'
|
||||||
|
|
||||||
|
await bot.reply_to(message, reply_txt)
|
||||||
|
|
||||||
|
@bot.message_handler(commands=['stats'])
|
||||||
|
async def user_stats(message):
|
||||||
|
resp = await _rpc_call(
|
||||||
|
message.from_user.id,
|
||||||
|
'stats',
|
||||||
|
{}
|
||||||
|
)
|
||||||
|
stats = resp.result
|
||||||
|
|
||||||
|
stats_str = f'generated: {stats["generated"]}\n'
|
||||||
|
stats_str += f'joined: {stats["joined"]}\n'
|
||||||
|
stats_str += f'role: {stats["role"]}\n'
|
||||||
|
|
||||||
|
await bot.reply_to(
|
||||||
|
message, stats_str)
|
||||||
|
|
||||||
|
@bot.message_handler(commands=['donate'])
|
||||||
|
async def donation_info(message):
|
||||||
|
await bot.reply_to(
|
||||||
|
message, DONATION_INFO)
|
||||||
|
|
||||||
|
|
||||||
|
@bot.message_handler(func=lambda message: True)
|
||||||
|
async def echo_message(message):
|
||||||
|
if message.text[0] == '/':
|
||||||
|
await bot.reply_to(message, UNKNOWN_CMD_TEXT)
|
||||||
|
|
||||||
|
|
||||||
|
await aio_as_trio(bot.infinity_polling())
|
|
@ -0,0 +1,75 @@
|
||||||
|
#!/usr/bin/python
|
||||||
|
|
||||||
|
import io
|
||||||
|
import random
|
||||||
|
import logging
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import tractor
|
||||||
|
|
||||||
|
from diffusers import (
|
||||||
|
StableDiffusionPipeline,
|
||||||
|
EulerAncestralDiscreteScheduler
|
||||||
|
)
|
||||||
|
|
||||||
|
from .types import ImageGenRequest
|
||||||
|
from .constants import ALGOS
|
||||||
|
|
||||||
|
|
||||||
|
def pipeline_for(algo: str, mem_fraction: float):
|
||||||
|
assert torch.cuda.is_available()
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
torch.cuda.set_per_process_memory_fraction(mem_fraction)
|
||||||
|
torch.backends.cuda.matmul.allow_tf32 = True
|
||||||
|
torch.backends.cudnn.allow_tf32 = True
|
||||||
|
|
||||||
|
params = {
|
||||||
|
'torch_dtype': torch.float16,
|
||||||
|
'safety_checker': None
|
||||||
|
}
|
||||||
|
|
||||||
|
if algo == 'stable':
|
||||||
|
params['revision'] = 'fp16'
|
||||||
|
|
||||||
|
pipe = StableDiffusionPipeline.from_pretrained(
|
||||||
|
ALGOS[algo], **params)
|
||||||
|
|
||||||
|
pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(
|
||||||
|
pipe.scheduler.config)
|
||||||
|
|
||||||
|
return pipe.to("cuda")
|
||||||
|
|
||||||
|
@tractor.context
|
||||||
|
async def open_gpu_worker(
|
||||||
|
ctx: tractor.Context,
|
||||||
|
start_algo: str,
|
||||||
|
mem_fraction: float
|
||||||
|
):
|
||||||
|
current_algo = start_algo
|
||||||
|
with torch.no_grad():
|
||||||
|
pipe = pipeline_for(current_algo, mem_fraction)
|
||||||
|
await ctx.started()
|
||||||
|
|
||||||
|
async with ctx.open_stream() as bus:
|
||||||
|
async for ireq in bus:
|
||||||
|
if ireq.algo != current_algo:
|
||||||
|
current_algo = ireq.algo
|
||||||
|
pipe = pipeline_for(current_algo, mem_fraction)
|
||||||
|
|
||||||
|
seed = ireq.seed if ireq.seed else random.randint(0, 2 ** 64)
|
||||||
|
image = pipe(
|
||||||
|
ireq.prompt,
|
||||||
|
width=ireq.width,
|
||||||
|
height=ireq.height,
|
||||||
|
guidance_scale=ireq.guidance,
|
||||||
|
num_inference_steps=ireq.step,
|
||||||
|
generator=torch.Generator("cuda").manual_seed(seed)
|
||||||
|
).images[0]
|
||||||
|
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
# convert PIL.Image to BytesIO
|
||||||
|
img_bytes = io.BytesIO()
|
||||||
|
image.save(img_bytes, format='PNG')
|
||||||
|
await bus.send(img_bytes.getvalue())
|
||||||
|
|
|
@ -0,0 +1,109 @@
|
||||||
|
# piker: trading gear for hackers
|
||||||
|
# Copyright (C) Guillermo Rodriguez (in stewardship for piker0)
|
||||||
|
|
||||||
|
# This program is free software: you can redistribute it and/or modify
|
||||||
|
# it under the terms of the GNU Affero General Public License as published by
|
||||||
|
# the Free Software Foundation, either version 3 of the License, or
|
||||||
|
# (at your option) any later version.
|
||||||
|
|
||||||
|
# This program is distributed in the hope that it will be useful,
|
||||||
|
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||||
|
# GNU Affero General Public License for more details.
|
||||||
|
|
||||||
|
# You should have received a copy of the GNU Affero General Public License
|
||||||
|
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
|
"""
|
||||||
|
Built-in (extension) types.
|
||||||
|
"""
|
||||||
|
import sys
|
||||||
|
from typing import Optional, Union
|
||||||
|
from pprint import pformat
|
||||||
|
|
||||||
|
import msgspec
|
||||||
|
|
||||||
|
|
||||||
|
class Struct(
|
||||||
|
msgspec.Struct,
|
||||||
|
|
||||||
|
# https://jcristharif.com/msgspec/structs.html#tagged-unions
|
||||||
|
# tag='pikerstruct',
|
||||||
|
# tag=True,
|
||||||
|
):
|
||||||
|
'''
|
||||||
|
A "human friendlier" (aka repl buddy) struct subtype.
|
||||||
|
'''
|
||||||
|
def to_dict(self) -> dict:
|
||||||
|
return {
|
||||||
|
f: getattr(self, f)
|
||||||
|
for f in self.__struct_fields__
|
||||||
|
}
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
# only turn on pprint when we detect a python REPL
|
||||||
|
# at runtime B)
|
||||||
|
if (
|
||||||
|
hasattr(sys, 'ps1')
|
||||||
|
# TODO: check if we're in pdb
|
||||||
|
):
|
||||||
|
return self.pformat()
|
||||||
|
|
||||||
|
return super().__repr__()
|
||||||
|
|
||||||
|
def pformat(self) -> str:
|
||||||
|
return f'Struct({pformat(self.to_dict())})'
|
||||||
|
|
||||||
|
def copy(
|
||||||
|
self,
|
||||||
|
update: Optional[dict] = None,
|
||||||
|
|
||||||
|
) -> msgspec.Struct:
|
||||||
|
'''
|
||||||
|
Validate-typecast all self defined fields, return a copy of us
|
||||||
|
with all such fields.
|
||||||
|
This is kinda like the default behaviour in `pydantic.BaseModel`.
|
||||||
|
'''
|
||||||
|
if update:
|
||||||
|
for k, v in update.items():
|
||||||
|
setattr(self, k, v)
|
||||||
|
|
||||||
|
# roundtrip serialize to validate
|
||||||
|
return msgspec.msgpack.Decoder(
|
||||||
|
type=type(self)
|
||||||
|
).decode(
|
||||||
|
msgspec.msgpack.Encoder().encode(self)
|
||||||
|
)
|
||||||
|
|
||||||
|
def typecast(
|
||||||
|
self,
|
||||||
|
# fields: Optional[list[str]] = None,
|
||||||
|
) -> None:
|
||||||
|
for fname, ftype in self.__annotations__.items():
|
||||||
|
setattr(self, fname, ftype(getattr(self, fname)))
|
||||||
|
|
||||||
|
# proto
|
||||||
|
|
||||||
|
class SkynetRPCRequest(Struct):
|
||||||
|
uid: Union[str, int] # user unique id
|
||||||
|
method: str # rpc method name
|
||||||
|
params: dict # variable params
|
||||||
|
|
||||||
|
class SkynetRPCResponse(Struct):
|
||||||
|
result: dict
|
||||||
|
|
||||||
|
class ImageGenRequest(Struct):
|
||||||
|
prompt: str
|
||||||
|
step: int
|
||||||
|
width: int
|
||||||
|
height: int
|
||||||
|
guidance: int
|
||||||
|
seed: Optional[int]
|
||||||
|
algo: str
|
||||||
|
upscaler: Optional[str]
|
||||||
|
|
||||||
|
class DGPUBusRequest(Struct):
|
||||||
|
rid: str # req id
|
||||||
|
nid: str # node id
|
||||||
|
task: str
|
||||||
|
params: dict
|
|
@ -0,0 +1,8 @@
|
||||||
|
docker run \
|
||||||
|
-it \
|
||||||
|
--rm \
|
||||||
|
--mount type=bind,source="$(pwd)",target=/skynet \
|
||||||
|
skynet:runtime-cuda \
|
||||||
|
bash -c \
|
||||||
|
"cd /skynet && pip install -e . && \
|
||||||
|
pytest tests/test_dgpu.py --log-cli-level=info"
|
|
@ -0,0 +1,55 @@
|
||||||
|
#!/usr/bin/python
|
||||||
|
|
||||||
|
import time
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
|
||||||
|
import trio
|
||||||
|
import pynng
|
||||||
|
import trio_asyncio
|
||||||
|
|
||||||
|
from skynet_bot.dgpu import open_dgpu_node
|
||||||
|
from skynet_bot.types import *
|
||||||
|
from skynet_bot.brain import run_skynet
|
||||||
|
from skynet_bot.constants import *
|
||||||
|
from skynet_bot.frontend import open_skynet_rpc, rpc_call
|
||||||
|
|
||||||
|
|
||||||
|
def test_dgpu_simple():
|
||||||
|
async def main():
|
||||||
|
async with trio.open_nursery() as n:
|
||||||
|
await n.start(
|
||||||
|
run_skynet,
|
||||||
|
'skynet', '3GbZd6UojbD8V7zWpeFn', 'ancap.tech:34508')
|
||||||
|
|
||||||
|
await trio.sleep(2)
|
||||||
|
|
||||||
|
for i in range(3):
|
||||||
|
n.start_soon(open_dgpu_node)
|
||||||
|
|
||||||
|
await trio.sleep(1)
|
||||||
|
start = time.time()
|
||||||
|
async def request_img():
|
||||||
|
with pynng.Req0(dial=DEFAULT_RPC_ADDR) as rpc_sock:
|
||||||
|
res = await rpc_call(
|
||||||
|
rpc_sock, 'tg+1', 'txt2img', {
|
||||||
|
'prompt': 'test',
|
||||||
|
'step': 28,
|
||||||
|
'width': 512, 'height': 512,
|
||||||
|
'guidance': 7.5,
|
||||||
|
'seed': None,
|
||||||
|
'algo': 'stable',
|
||||||
|
'upscaler': None
|
||||||
|
})
|
||||||
|
|
||||||
|
logging.info(res)
|
||||||
|
|
||||||
|
async with trio.open_nursery() as inner_n:
|
||||||
|
for i in range(3):
|
||||||
|
inner_n.start_soon(request_img)
|
||||||
|
|
||||||
|
logging.info(f'time elapsed: {time.time() - start}')
|
||||||
|
n.cancel_scope.cancel()
|
||||||
|
|
||||||
|
|
||||||
|
trio_asyncio.run(main)
|
|
@ -0,0 +1,22 @@
|
||||||
|
#!/usr/bin/python
|
||||||
|
|
||||||
|
import trio
|
||||||
|
import trio_asyncio
|
||||||
|
|
||||||
|
from skynet_bot.brain import run_skynet
|
||||||
|
from skynet_bot.frontend import open_skynet_rpc
|
||||||
|
from skynet_bot.frontend.telegram import run_skynet_telegram
|
||||||
|
|
||||||
|
|
||||||
|
def test_run_tg_bot():
|
||||||
|
async def main():
|
||||||
|
async with trio.open_nursery() as n:
|
||||||
|
await n.start(
|
||||||
|
run_skynet,
|
||||||
|
'skynet', '3GbZd6UojbD8V7zWpeFn', 'ancap.tech:34508')
|
||||||
|
n.start_soon(
|
||||||
|
run_skynet_telegram,
|
||||||
|
'5853245787:AAFEmv3EjJ_qJ8d_vmOpi6o6HFHUf8a0uCQ')
|
||||||
|
|
||||||
|
|
||||||
|
trio_asyncio.run(main)
|
Loading…
Reference in New Issue