mirror of https://github.com/skygpu/skynet.git

First decoupled architecture, still working on integrating tractor gpu workers

parent f06a09b5bd
commit 66d997c039

@ -0,0 +1,3 @@
hf_home
inputs
outputs

@ -1,2 +1,5 @@
.python-version
hf_home
outputs
**/__pycache__
*.egg-info

Dockerfile
@ -1,31 +0,0 @@
from pytorch/pytorch:latest

env DEBIAN_FRONTEND=noninteractive

run apt-get update && apt-get install -y git wget

run conda install xformers -c xformers/label/dev

run pip install --upgrade \
    diffusers[torch] \
    accelerate \
    transformers \
    huggingface_hub \
    pyTelegramBotAPI \
    pymongo \
    scipy \
    pdbpp

env NVIDIA_VISIBLE_DEVICES=all

run mkdir /scripts
run mkdir /outputs
run mkdir /inputs

env HF_HOME /hf_home

run mkdir /hf_home

workdir /scripts

env PYTORCH_CUDA_ALLOC_CONF max_split_size_mb:128

@ -0,0 +1,13 @@
from python:3.10.0

env DEBIAN_FRONTEND=noninteractive

workdir /skynet

copy requirements.* ./

run pip install \
    -r requirements.txt \
    -r requirements.test.txt

workdir /scripts

@ -0,0 +1,23 @@
from nvidia/cuda:11.7.0-devel-ubuntu20.04
from python:3.10.0

env DEBIAN_FRONTEND=noninteractive

workdir /skynet

copy requirements.* .

run pip install -U pip ninja
run pip install -r requirements.cuda.0.txt
run pip install -v -r requirements.cuda.1.txt

run pip install \
    -r requirements.txt \
    -r requirements.test.txt

env NVIDIA_VISIBLE_DEVICES=all
env HF_HOME /hf_home

env PYTORCH_CUDA_ALLOC_CONF max_split_size_mb:128

workdir /scripts

@ -0,0 +1,11 @@
A menos que sea especificamente indicado en el cabezal del archivo, se reservan
todos los derechos sobre este codigo por parte de:

Guillermo Rodriguez, guillermor@fing.edu.uy

ENGLISH LICENSE:

Unless specifically indicated in the file header, all rights to this code are
reserved by:

Guillermo Rodriguez, guillermor@fing.edu.uy

@ -0,0 +1,47 @@
create db in postgres:

```sql
CREATE USER skynet WITH PASSWORD 'password';
CREATE DATABASE skynet_art_bot;
GRANT ALL PRIVILEGES ON DATABASE skynet_art_bot TO skynet;

CREATE SCHEMA IF NOT EXISTS skynet;

CREATE TABLE IF NOT EXISTS skynet.user(
   id SERIAL PRIMARY KEY NOT NULL,
   tg_id INT,
   wp_id VARCHAR(128),
   mx_id VARCHAR(128),
   ig_id VARCHAR(128),
   generated INT NOT NULL,
   joined DATE NOT NULL,
   last_prompt TEXT,
   role VARCHAR(128) NOT NULL
);
ALTER TABLE skynet.user
    ADD CONSTRAINT tg_unique
    UNIQUE (tg_id);
ALTER TABLE skynet.user
    ADD CONSTRAINT wp_unique
    UNIQUE (wp_id);
ALTER TABLE skynet.user
    ADD CONSTRAINT mx_unique
    UNIQUE (mx_id);
ALTER TABLE skynet.user
    ADD CONSTRAINT ig_unique
    UNIQUE (ig_id);

CREATE TABLE IF NOT EXISTS skynet.user_config(
    id SERIAL NOT NULL,
    algo VARCHAR(128) NOT NULL,
    step INT NOT NULL,
    width INT NOT NULL,
    height INT NOT NULL,
    seed INT,
    guidance INT NOT NULL,
    upscaler VARCHAR(128)
);
ALTER TABLE skynet.user_config
    ADD FOREIGN KEY(id)
    REFERENCES skynet.user(id);
```

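The service side connects to this database with a DSN of the form `postgres://skynet:<password>@<host>/skynet_art_bot`. A minimal connectivity check under that assumption (host, port and password below are placeholders, not values from this repo):

```python
# Sketch only: confirm the schema above is reachable via triopg.
# 'localhost:5432' and 'password' are placeholder assumptions.
import trio
import triopg


async def check_schema():
    async with triopg.create_pool(
        dsn='postgres://skynet:password@localhost:5432/skynet_art_bot'
    ) as pool:
        async with pool.acquire() as conn:
            rows = await conn.fetch(
                'SELECT id, algo, step, width, height FROM skynet.user_config')
            print(f'user_config rows: {len(rows)}')


trio.run(check_schema)
```
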
@ -0,0 +1,7 @@
docker build \
    -t skynet:runtime-cuda \
    -f Dockerfile.runtime-cuda .

docker build \
    -t skynet:runtime \
    -f Dockerfile.runtime .

@ -0,0 +1,2 @@
[pytest]
trio_mode = true

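With `trio_mode = true`, pytest-trio collects plain `async def` test functions and runs them on a trio event loop without per-test markers. A minimal sketch of such a test (the file name is illustrative):

```python
# test_smoke.py - illustrative; pytest-trio runs this because trio_mode = true
import trio


async def test_trio_checkpoint():
    # any trio awaitable works here; no @pytest.mark.trio needed in trio_mode
    await trio.sleep(0)
```
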
@ -0,0 +1,8 @@
pdbpp
scipy
accelerate
transformers
huggingface_hub
diffusers[torch]
torch==1.13.0+cu117
--extra-index-url https://download.pytorch.org/whl/cu117

@ -0,0 +1 @@
git+https://github.com/facebookresearch/xformers.git@main#egg=xformers

@ -0,0 +1,2 @@
pytest
pytest-trio

@ -0,0 +1,8 @@
trio
pynng
triopg
aiohttp
msgspec
trio_asyncio

git+https://github.com/goodboy/tractor.git@master#egg=tractor

run-bot.sh
@ -1,14 +0,0 @@
mkdir -p outputs
mkdir -p hf_home

docker run \
    -it \
    --rm \
    --gpus=all \
    --env HF_TOKEN='' \
    --env DB_USER='skynet' \
    --env DB_PASS='nnf01nmf091d0i' \
    --mount type=bind,source="$(pwd)"/outputs,target=/outputs \
    --mount type=bind,source="$(pwd)"/hf_home,target=/hf_home \
    --mount type=bind,source="$(pwd)"/scripts,target=/scripts \
    skynet:dif python telegram-bot-dev.py

@ -1,9 +0,0 @@
docker run
    -d \
    --rm \
    -p 27017:27017 \
    --name mongodb-skynet \
    --mount type=bind,source="$(pwd)"/mongodb,target=/data/db \
    -e MONGO_INITDB_ROOT_USERNAME="" \
    -e MONGO_INITDB_ROOT_PASSWORD="" \
    mongo

@ -1,537 +0,0 @@
#!/usr/bin/python

import os

import logging
import random

from torch.multiprocessing import spawn

import telebot
from telebot.types import InputFile

import sys
import uuid

from pathlib import Path

import torch
from torch.multiprocessing.spawn import ProcessRaisedException
from diffusers import (
    StableDiffusionPipeline,
    EulerAncestralDiscreteScheduler
)

from huggingface_hub import login
from datetime import datetime

from pymongo import MongoClient

from typing import Tuple, Optional

db_user = os.environ['DB_USER']
db_pass = os.environ['DB_PASS']

logging.basicConfig(level=logging.INFO)

MEM_FRACTION = .33

ALGOS = {
    'stable': 'runwayml/stable-diffusion-v1-5',
    'midj': 'prompthero/openjourney',
    'hdanime': 'Linaqruf/anything-v3.0',
    'waifu': 'hakurei/waifu-diffusion',
    'ghibli': 'nitrosocke/Ghibli-Diffusion',
    'van-gogh': 'dallinmackay/Van-Gogh-diffusion',
    'pokemon': 'lambdalabs/sd-pokemon-diffusers',
    'ink': 'Envvi/Inkpunk-Diffusion',
    'robot': 'nousr/robo-diffusion'
}

N = '\n'
HELP_TEXT = f'''
test art bot v0.1a4

commands work on a user per user basis!
config is individual to each user!

/txt2img TEXT - request an image based on a prompt

/redo - redo last prompt

/cool - list of cool words to use
/stats - user statistics
/donate - see donation info

/config algo NAME - select AI to use one of:

{N.join(ALGOS.keys())}

/config step NUMBER - set amount of iterations
/config seed NUMBER - set the seed, deterministic results!
/config size WIDTH HEIGHT - set size in pixels
/config guidance NUMBER - prompt text importance
'''

UNKNOWN_CMD_TEXT = 'unknown command! try sending \"/help\"'

DONATION_INFO = '0xf95335682DF281FFaB7E104EB87B69625d9622B6\ngoal: 25/650usd'

COOL_WORDS = [
    'cyberpunk',
    'soviet propaganda poster',
    'rastafari',
    'cannabis',
    'art deco',
    'H R Giger Necronom IV',
    'dimethyltryptamine',
    'lysergic',
    'slut',
    'psilocybin',
    'trippy',
    'lucy in the sky with diamonds',
    'fractal',
    'da vinci',
    'pencil illustration',
    'blueprint',
    'internal diagram',
    'baroque',
    'the last judgment',
    'michelangelo'
]

GROUP_ID = -1001541979235

MP_ENABLED_ROLES = ['god']

MIN_STEP = 1
MAX_STEP = 100
MAX_SIZE = (512, 656)
MAX_GUIDANCE = 20

DEFAULT_SIZE = (512, 512)
DEFAULT_GUIDANCE = 7.5
DEFAULT_STEP = 75
DEFAULT_CREDITS = 10
DEFAULT_ALGO = 'stable'
DEFAULT_ROLE = 'pleb'
DEFAULT_UPSCALER = None

rr_total = 1
rr_id = 0
request_counter = 0

def its_my_turn():
    global request_counter, rr_total, rr_id
    my_turn = request_counter % rr_total == rr_id
    logging.info(f'new request {request_counter}, turn: {my_turn} rr_total: {rr_total}, rr_id {rr_id}')
    request_counter += 1
    return my_turn

def round_robined(func):
    def rr_wrapper(*args, **kwargs):
        if not its_my_turn():
            return

        func(*args, **kwargs)

    return rr_wrapper


def generate_image(
    i: int,
    prompt: str,
    name: str,
    step: int,
    size: Tuple[int, int],
    guidance: int,
    seed: int,
    algo: str,
    upscaler: Optional[str]
):
    assert torch.cuda.is_available()
    torch.cuda.empty_cache()
    torch.cuda.set_per_process_memory_fraction(MEM_FRACTION)
    with torch.no_grad():
        if algo == 'stable':
            pipe = StableDiffusionPipeline.from_pretrained(
                'runwayml/stable-diffusion-v1-5',
                torch_dtype=torch.float16,
                revision="fp16",
                safety_checker=None
            )

        else:
            pipe = StableDiffusionPipeline.from_pretrained(
                ALGOS[algo],
                torch_dtype=torch.float16,
                safety_checker=None
            )

        pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
        pipe = pipe.to("cuda")
        w, h = size
        print(f'generating image... of size {w, h}')
        image = pipe(
            prompt,
            width=w,
            height=h,
            guidance_scale=guidance, num_inference_steps=step,
            generator=torch.Generator("cuda").manual_seed(seed)
        ).images[0]

        if upscaler == 'x4':
            pipe = StableDiffusionPipeline.from_pretrained(
                'stabilityai/stable-diffusion-x4-upscaler',
                revision="fp16",
                torch_dtype=torch.float16
            )
            image = pipe(prompt=prompt, image=image).images[0]


    image.save(f'/outputs/{name}.png')
    print('saved')


if __name__ == '__main__':

    API_TOKEN = '5880619053:AAFge2UfObw1kCn9Kb7AAyqaIHs_HgM0Fx0'

    bot = telebot.TeleBot(API_TOKEN)
    db_client = MongoClient(
        host=['ancap.tech:64000'],
        username=db_user,
        password=db_pass)

    tgdb = db_client.get_database('telegram')

    collections = tgdb.list_collection_names()

    if 'users' in collections:
        tg_users = tgdb.get_collection('users')
        # tg_users.delete_many({})

    else:
        tg_users = tgdb.create_collection('users')

    # db functions

    def get_user(uid: int):
        return tg_users.find_one({'uid': uid})


    def new_user(uid: int):
        if get_user(uid):
            raise ValueError('User already present on db')

        res = tg_users.insert_one({
            'generated': 0,
            'uid': uid,
            'credits': DEFAULT_CREDITS,
            'joined': datetime.utcnow().isoformat(),
            'last_prompt': None,
            'role': DEFAULT_ROLE,
            'config': {
                'algo': DEFAULT_ALGO,
                'step': DEFAULT_STEP,
                'size': DEFAULT_SIZE,
                'seed': None,
                'guidance': DEFAULT_GUIDANCE,
                'upscaler': DEFAULT_UPSCALER
            }
        })

        assert res.acknowledged

        return get_user(uid)

    def migrate_user(db_user):
        # new: user roles
        if 'role' not in db_user:
            res = tg_users.find_one_and_update(
                {'uid': db_user['uid']}, {'$set': {'role': DEFAULT_ROLE}})

        # new: algo selection
        if 'algo' not in db_user['config']:
            res = tg_users.find_one_and_update(
                {'uid': db_user['uid']}, {'$set': {'config.algo': DEFAULT_ALGO}})

        # new: upscaler selection
        if 'upscaler' not in db_user['config']:
            res = tg_users.find_one_and_update(
                {'uid': db_user['uid']}, {'$set': {'config.upscaler': DEFAULT_UPSCALER}})

        return get_user(db_user['uid'])

    def get_or_create_user(uid: int):
        db_user = get_user(uid)

        if not db_user:
            db_user = new_user(uid)

        logging.info(f'req from: {uid}')

        return migrate_user(db_user)

    def update_user(uid: int, updt_cmd: dict):
        user = get_user(uid)
        if not user:
            raise ValueError('User not present on db')

        return tg_users.find_one_and_update(
            {'uid': uid}, updt_cmd)


    # bot handler
    def img_for_user_with_prompt(
        uid: int,
        prompt: str, step: int, size: Tuple[int, int], guidance: int, seed: int,
        algo: str, upscaler: Optional[str]
    ):
        name = uuid.uuid4()

        spawn(
            generate_image,
            args=(prompt, name, step, size, guidance, seed, algo, upscaler))

        logging.info(f'done generating. got {name}, sending...')

        if len(prompt) > 256:
            reply_txt = f'prompt: \"{prompt[:256]}...\"\n(full prompt too big to show on reply...)\n'

        else:
            reply_txt = f'prompt: \"{prompt}\"\n'

        reply_txt +=  f'seed: {seed}\n'
        reply_txt +=  f'iterations: {step}\n'
        reply_txt +=  f'size: {size}\n'
        reply_txt +=  f'guidance: {guidance}\n'
        reply_txt +=  f'algo: {ALGOS[algo]}\n'
        reply_txt +=  f'euler ancestral discrete'

        return reply_txt, name

    @bot.message_handler(commands=['help'])
    @round_robined
    def send_help(message):
        bot.reply_to(message, HELP_TEXT)

    @bot.message_handler(commands=['cool'])
    @round_robined
    def send_cool_words(message):
        bot.reply_to(message, '\n'.join(COOL_WORDS))

    @bot.message_handler(commands=['txt2img'])
    @round_robined
    def send_txt2img(message):
        chat = message.chat
        user = message.from_user
        db_user = get_or_create_user(user.id)

        if ((chat.type != 'group' and chat.id != GROUP_ID) and
                (db_user['role'] not in MP_ENABLED_ROLES)):
            return

        prompt = ' '.join(message.text.split(' ')[1:])

        if len(prompt) == 0:
            bot.reply_to(message, 'empty text prompt ignored.')
            return

        logging.info(f"{user.first_name} ({user.id}) on chat {chat.id} txt2img: {prompt}")

        user_conf = db_user['config']

        algo = user_conf['algo']
        step = user_conf['step']
        size = user_conf['size']
        seed = user_conf['seed'] if user_conf['seed'] else random.randint(0, 999999999)
        guidance = user_conf['guidance']
        upscaler = user_conf['upscaler']

        try:
            reply_txt, name = img_for_user_with_prompt(
                user.id, prompt, step, size, guidance, seed, algo, upscaler)

            update_user(
                user.id,
                {'$set': {
                    'generated': db_user['generated'] + 1,
                    'last_prompt': prompt
                    }})

            bot.send_photo(
                chat.id,
                caption=f'sent by: {user.first_name}\n' + reply_txt,
                photo=InputFile(f'/outputs/{name}.png'))

        except BaseException as e:
            logging.error(e)
            bot.reply_to(message, 'that command caused an error :(\nchange settings and try again (?')

    @bot.message_handler(commands=['redo'])
    @round_robined
    def redo_txt2img(message):
        # check msg comes from testing group
        chat = message.chat
        user = message.from_user
        db_user = get_or_create_user(user.id)

        if ((chat.type != 'group' and chat.id != GROUP_ID) and
                (db_user['role'] not in MP_ENABLED_ROLES)):
            return

        prompt = db_user['last_prompt']

        if not prompt:
            bot.reply_to(message, 'do a /txt2img command first silly!')
            return

        user_conf = db_user['config']

        algo = user_conf['algo']
        step = user_conf['step']
        size = user_conf['size']
        seed = user_conf['seed'] if user_conf['seed'] else random.randint(0, 999999999)
        guidance = user_conf['guidance']
        upscaler = user_conf['upscaler']

        logging.info(f"{user.first_name} ({user.id}) on chat {chat.id} redo: {prompt}")

        try:
            reply_txt, name = img_for_user_with_prompt(
                user.id, prompt, step, size, guidance, seed, algo, upscaler)

            update_user(
                user.id,
                {'$set': {
                    'generated': db_user['generated'] + 1,
                    }})

            bot.send_photo(
                chat.id,
                caption=f'sent by: {user.first_name}\n' + reply_txt,
                photo=InputFile(f'/outputs/{name}.png'))

        except BaseException as e:
            logging.error(e)
            bot.reply_to(message, 'that command caused an error :(\nchange settings and try again (?')

    @bot.message_handler(commands=['config'])
    @round_robined
    def set_config(message):
        logging.info(f'config req on chat: {message.chat.id}')

        params = message.text.split(' ')

        if len(params) < 3:
            bot.reply_to(message, 'wrong msg format')

        else:
            user = message.from_user
            chat = message.chat
            db_user = get_or_create_user(user.id)

            try:
                attr = params[1]

                if attr == 'algo':
                    val = params[2]
                    assert val in ALGOS
                    res = update_user(user.id, {'$set': {'config.algo': val}})

                elif attr == 'step':
                    val = int(params[2])
                    val = max(min(val, MAX_STEP), MIN_STEP)
                    res = update_user(user.id, {'$set': {'config.step': val}})

                elif attr  == 'size':
                    max_w, max_h = MAX_SIZE
                    w = max(min(int(params[2]), max_w), 16)
                    h = max(min(int(params[3]), max_h), 16)

                    val = (w, h)

                    if (w % 8 != 0) or (h % 8 != 0):
                        bot.reply_to(message, 'size must be divisible by 8!')
                        return

                    res = update_user(user.id, {'$set': {'config.size': val}})

                elif attr == 'seed':
                    val = params[2]
                    if val == 'auto':
                        val = None
                    else:
                        val = int(params[2])

                    res = update_user(user.id, {'$set': {'config.seed': val}})

                elif attr == 'guidance':
                    val = float(params[2])
                    val = max(min(val, MAX_GUIDANCE), 0)
                    res = update_user(user.id, {'$set': {'config.guidance': val}})

                elif attr == 'upscaler':
                    val = params[2]
                    if val == 'off':
                        val = None

                    res = update_user(user.id, {'$set': {'config.upscaler': val}})

                else:
                    bot.reply_to(message, f'\"{attr}\" not a parameter')

                bot.reply_to(message, f'config updated! {attr} to {val}')

            except ValueError:
                bot.reply_to(message, f'\"{val}\" is not a number silly')

            except AssertionError:
                bot.reply_to(message, f'no algo named {val}')

    @bot.message_handler(commands=['stats'])
    @round_robined
    def user_stats(message):
        user = message.from_user
        db_user = get_or_create_user(user.id)
        migrate_user(db_user)

        joined_date_str = datetime.fromisoformat(db_user['joined']).strftime('%B the %dth %Y, %H:%M:%S')

        user_stats_str = f'generated: {db_user["generated"]}\n'
        user_stats_str += f'joined: {joined_date_str}\n'
        user_stats_str += f'credits: {db_user["credits"]}\n'
        user_stats_str += f'role: {db_user["role"]}\n'

        bot.reply_to(
            message, user_stats_str)

    @bot.message_handler(commands=['donate'])
    @round_robined
    def donation_info(message):
        bot.reply_to(
            message, DONATION_INFO)

    @bot.message_handler(commands=['say'])
    @round_robined
    def say(message):
        chat = message.chat
        user = message.from_user
        db_user = get_or_create_user(user.id)

        if (chat.type == 'group') or (db_user['role'] not in MP_ENABLED_ROLES):
            return

        bot.send_message(GROUP_ID, message.text[4:])

    @bot.message_handler(func=lambda message: True)
    @round_robined
    def echo_message(message):
        if message.text[0] == '/':
            bot.reply_to(message, UNKNOWN_CMD_TEXT)


    login(token=os.environ['HF_TOKEN'])

    bot.infinity_polling()

@ -0,0 +1,11 @@
from setuptools import setup, find_packages

setup(
    name='skynet-bot',
    version='0.1.0a6',
    description='Decentralized compute platform',
    author='Guillermo Rodriguez',
    author_email='guillermo@telos.net',
    packages=find_packages(),
    install_requires=[]
)

@ -0,0 +1,2 @@
#!/usr/bin/python

@ -0,0 +1,246 @@
#!/usr/bin/python

import json
import uuid
import base64
import logging

from uuid import UUID
from functools import partial
from collections import OrderedDict

import trio
import pynng
import trio_asyncio

from .db import *
from .types import *
from .constants import *


class SkynetDGPUOffline(BaseException):
    ...

class SkynetDGPUOverloaded(BaseException):
    ...


async def rpc_service(sock, dgpu_bus, db_pool):
    nodes = OrderedDict()
    wip_reqs = {}
    fin_reqs = {}

    def are_all_workers_busy():
        for nid, info in nodes.items():
            if info['task'] == None:
                return False

        return True

    next_worker = 0
    def get_next_worker():
        nonlocal next_worker

        if len(nodes) == 0:
            raise SkynetDGPUOffline

        if are_all_workers_busy():
            raise SkynetDGPUOverloaded

        next_worker += 1

        if next_worker >= len(nodes):
            next_worker = 0

        nid = list(nodes.keys())[next_worker]
        return nid

    async def dgpu_image_streamer():
        nonlocal wip_reqs, fin_reqs
        while True:
            msg = await dgpu_bus.arecv_msg()
            rid = UUID(bytes=msg.bytes[:16]).hex
            img = msg.bytes[16:].hex()
            fin_reqs[rid] = img
            event = wip_reqs[rid]
            event.set()
            del wip_reqs[rid]

    async def dgpu_stream_one_img(req: ImageGenRequest):
        nonlocal wip_reqs, fin_reqs, next_worker
        nid = get_next_worker()
        logging.info(f'dgpu_stream_one_img {next_worker} {nid}')
        rid = uuid.uuid4().hex
        event = trio.Event()
        wip_reqs[rid] = event

        nodes[nid]['task'] = rid

        dgpu_req = DGPUBusRequest(
            rid=rid,
            nid=nid,
            task='diffuse',
            params=req.to_dict())

        logging.info(f'dgpu_bus req: {dgpu_req}')

        await dgpu_bus.asend(
            json.dumps(dgpu_req.to_dict()).encode())

        await event.wait()

        nodes[nid]['task'] = None

        img = fin_reqs[rid]
        del fin_reqs[rid]

        logging.info(f'done streaming {img}')

        return rid, img

    async def handle_user_request(rpc_ctx, req):
        try:
            async with db_pool.acquire() as conn:
                user = await get_or_create_user(conn, req.uid)

                result = {}

                match req.method:
                    case 'txt2img':
                        logging.info('txt2img')
                        user_config = {**(await get_user_config(conn, user))}
                        del user_config['id']
                        prompt = req.params['prompt']
                        req = ImageGenRequest(
                            prompt=prompt,
                            **user_config
                        )
                        rid, img = await dgpu_stream_one_img(req)
                        result = {
                            'id': rid,
                            'img': img
                        }

                    case 'redo':
                        logging.info('redo')
                        user_config = await get_user_config(conn, user)
                        prompt = await get_last_prompt_of(conn, user)
                        req = ImageGenRequest(
                            prompt=prompt,
                            **user_config
                        )
                        rid, img = await dgpu_stream_one_img(req)
                        result = {
                            'id': rid,
                            'img': img
                        }

                    case 'config':
                        logging.info('config')
                        if req.params['attr'] in CONFIG_ATTRS:
                            await update_user_config(
                                conn, user, req.params['attr'], req.params['val'])

                    case 'stats':
                        logging.info('stats')
                        generated, joined, role = await get_user_stats(conn, user)

                        result = {
                            'generated': generated,
                            'joined': joined.strftime(DATE_FORMAT),
                            'role': role
                        }

                    case _:
                        logging.warn('unknown method')

        except SkynetDGPUOffline:
            result = {
                'error': 'skynet_dgpu_offline'
            }

        except SkynetDGPUOverloaded:
            result = {
                'error': 'skynet_dgpu_overloaded',
                'nodes': len(nodes)
            }

        except BaseException as e:
            logging.error(e)
            raise e
            # result = {
            #     'error': 'skynet_internal_error'
            # }

        await rpc_ctx.asend(
            json.dumps(
                SkynetRPCResponse(result=result).to_dict()).encode())


    async with trio.open_nursery() as n:
        n.start_soon(dgpu_image_streamer)
        while True:
            ctx = sock.new_context()
            msg = await ctx.arecv_msg()
            content = msg.bytes.decode()
            req = SkynetRPCRequest(**json.loads(content))

            logging.info(req)

            if req.method == 'dgpu_online':
                nodes[req.uid] = {
                    'task': None
                }
                logging.info(f'dgpu online: {req.uid}')


            elif req.method == 'dgpu_offline':
                i = list(nodes.keys()).index(req.uid)
                del nodes[req.uid]

                if i < next_worker:
                    next_worker -= 1
                logging.info(f'dgpu offline: {req.uid}')

            else:
                n.start_soon(
                    handle_user_request, ctx, req)
                continue

            await ctx.asend(
                json.dumps(
                    SkynetRPCResponse(
                        result={'ok': {}}).to_dict()).encode())


async def run_skynet(
    db_user: str,
    db_pass: str,
    db_host: str = DB_HOST,
    rpc_address: str = DEFAULT_RPC_ADDR,
    dgpu_address: str = DEFAULT_DGPU_ADDR,
    task_status = trio.TASK_STATUS_IGNORED
):
    logging.basicConfig(level=logging.INFO)
    logging.info('skynet is starting')

    async with (
        trio.open_nursery() as n,
        open_database_connection(
            db_user, db_pass, db_host) as db_pool
    ):
        logging.info('connected to db.')
        with (
            pynng.Rep0(listen=rpc_address) as rpc_sock,
            pynng.Bus0(listen=dgpu_address) as dgpu_bus
        ):
            n.start_soon(
                rpc_service, rpc_sock, dgpu_bus, db_pool)
            task_status.started()

            try:
                await trio.sleep_forever()

            except KeyboardInterrupt:
                ...

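`run_skynet` is written so it can be started either as a nursery task (reporting readiness through `task_status`) or standalone. A minimal standalone sketch, assuming the package imports as `skynet_bot.brain` and the postgres credentials created earlier:

```python
# Sketch only: the module path and credentials are assumptions, not repo facts.
import trio
from skynet_bot.brain import run_skynet

if __name__ == '__main__':
    # positional args map to db_user and db_pass; db_host defaults to DB_HOST
    trio.run(run_skynet, 'skynet', 'password')
```
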
@ -0,0 +1,129 @@
#!/usr/bin/python

API_TOKEN = '5880619053:AAFge2UfObw1kCn9Kb7AAyqaIHs_HgM0Fx0'

DB_HOST = 'ancap.tech:34508'

ALGOS = {
    'stable': 'runwayml/stable-diffusion-v1-5',
    'midj': 'prompthero/openjourney',
    'hdanime': 'Linaqruf/anything-v3.0',
    'waifu': 'hakurei/waifu-diffusion',
    'ghibli': 'nitrosocke/Ghibli-Diffusion',
    'van-gogh': 'dallinmackay/Van-Gogh-diffusion',
    'pokemon': 'lambdalabs/sd-pokemon-diffusers',
    'ink': 'Envvi/Inkpunk-Diffusion',
    'robot': 'nousr/robo-diffusion'
}

N = '\n'
HELP_TEXT = f'''
test art bot v0.1a4

commands work on a user per user basis!
config is individual to each user!

/txt2img TEXT - request an image based on a prompt

/redo - redo last prompt

/help step - get info on step config option
/help guidance - get info on guidance config option

/cool - list of cool words to use
/stats - user statistics
/donate - see donation info

/config algo NAME - select AI to use one of:

{N.join(ALGOS.keys())}

/config step NUMBER - set amount of iterations
/config seed NUMBER - set the seed, deterministic results!
/config size WIDTH HEIGHT - set size in pixels
/config guidance NUMBER - prompt text importance
'''

UNKNOWN_CMD_TEXT = 'unknown command! try sending \"/help\"'

DONATION_INFO = '0xf95335682DF281FFaB7E104EB87B69625d9622B6\ngoal: 25/650usd'

COOL_WORDS = [
    'cyberpunk',
    'soviet propaganda poster',
    'rastafari',
    'cannabis',
    'art deco',
    'H R Giger Necronom IV',
    'dimethyltryptamine',
    'lysergic',
    'slut',
    'psilocybin',
    'trippy',
    'lucy in the sky with diamonds',
    'fractal',
    'da vinci',
    'pencil illustration',
    'blueprint',
    'internal diagram',
    'baroque',
    'the last judgment',
    'michelangelo'
]

HELP_STEP = '''
diffusion models are iterative processes – a repeated cycle that starts with\
 random noise generated from the text input. With each step, some noise is removed\
, resulting in a higher-quality image over time. The repetition stops when the\
 desired number of steps is reached.

around 25 sampling steps are usually enough to achieve high-quality images. Us\
ing more may produce a slightly different picture, but not necessarily better \
quality.
'''

HELP_GUIDANCE = '''
the guidance scale is a parameter that controls how much the image generation\
 process follows the text prompt. The higher the value, the more the image sticks\
 to the given text input.
'''

HELP_UNKWNOWN_PARAM = 'don\'t have any info on that.'

GROUP_ID = -1001541979235

MP_ENABLED_ROLES = ['god']

MIN_STEP = 1
MAX_STEP = 100
MAX_WIDTH = 512
MAX_HEIGHT = 656
MAX_GUIDANCE = 20

DEFAULT_SEED = None
DEFAULT_WIDTH = 512
DEFAULT_HEIGHT = 512
DEFAULT_GUIDANCE = 7.5
DEFAULT_STEP = 35
DEFAULT_CREDITS = 10
DEFAULT_ALGO = 'midj'
DEFAULT_ROLE = 'pleb'
DEFAULT_UPSCALER = None

DEFAULT_RPC_ADDR = 'tcp://127.0.0.1:41000'

DEFAULT_DGPU_ADDR = 'tcp://127.0.0.1:41069'
DEFAULT_DGPU_MAX_TASKS = 3
DEFAULT_INITAL_ALGOS = ['midj', 'stable', 'ink']

DATE_FORMAT = '%B the %dth %Y, %H:%M:%S'

CONFIG_ATTRS = [
    'algo',
    'step',
    'width',
    'height',
    'seed',
    'guidance',
    'upscaler'
]

@ -0,0 +1,146 @@
#!/usr/bin/python

import logging

from datetime import datetime
from contextlib import asynccontextmanager as acm

import trio
import triopg

from .constants import *


def try_decode_uid(uid: str):
    try:
        proto, uid = uid.split('+')
        uid = int(uid)
        return proto, uid

    except ValueError:
        logging.warning(f'got non numeric uid?: {uid}')
        return None, None


@acm
async def open_database_connection(
    db_user: str,
    db_pass: str,
    db_host: str = DB_HOST,
):
    async with triopg.create_pool(
        dsn=f'postgres://{db_user}:{db_pass}@{db_host}/skynet_art_bot'
    ) as conn:
        yield conn


async def get_user(conn, uid: str):
    if isinstance(uid, str):
        proto, uid = try_decode_uid(uid)

        match proto:
            case 'tg':
                stmt = await conn.prepare(
                    'SELECT * FROM skynet.user WHERE tg_id = $1')
                user = await stmt.fetchval(uid)

            case _:
                user = None

        return user

    else:  # assume it's already our internal uid
        stmt = await conn.prepare(
            'SELECT * FROM skynet.user WHERE id = $1')
        return await stmt.fetchval(uid)


async def get_user_config(conn, user: int):
    stmt = await conn.prepare(
        'SELECT * FROM skynet.user_config WHERE id = $1')
    return (await stmt.fetch(user))[0]


async def get_last_prompt_of(conn, user: int):
    stmt = await conn.prepare(
        'SELECT last_prompt FROM skynet.user WHERE id = $1')
    return await stmt.fetchval(user)


async def new_user(conn, uid: str):
    if await get_user(conn, uid):
        raise ValueError('User already present on db')

    logging.info(f'new user! {uid}')

    tg_id = None
    date = datetime.utcnow()

    proto, pid = try_decode_uid(uid)

    match proto:
        case 'tg':
            tg_id = pid

    async with conn.transaction():
        stmt = await conn.prepare('''
            INSERT INTO skynet.user(
                tg_id, generated, joined, last_prompt, role)

            VALUES($1, $2, $3, $4, $5)
        ''')
        await stmt.fetch(
            tg_id, 0, date, None, DEFAULT_ROLE
        )

        new_uid = await get_user(conn, uid)

        stmt = await conn.prepare('''
            INSERT INTO skynet.user_config(
                id, algo, step, width, height, seed, guidance, upscaler)

            VALUES($1, $2, $3, $4, $5, $6, $7, $8)
        ''')
        user = await stmt.fetch(
            new_uid,
            DEFAULT_ALGO,
            DEFAULT_STEP,
            DEFAULT_WIDTH,
            DEFAULT_HEIGHT,
            DEFAULT_SEED,
            DEFAULT_GUIDANCE,
            DEFAULT_UPSCALER
        )

    return new_uid


async def get_or_create_user(conn, uid: str):
    user = await get_user(conn, uid)

    if not user:
        user = await new_user(conn, uid)

    return user

async def update_user(conn, user: int, attr: str, val):
    ...

async def update_user_config(conn, user: int, attr: str, val):
    stmt = await conn.prepare(f'''
        UPDATE skynet.user_config
        SET {attr} = $2
        WHERE id = $1
    ''')
    await stmt.fetch(user, val)


async def get_user_stats(conn, user: int):
    stmt = await conn.prepare('''
        SELECT generated,joined,role FROM skynet.user
        WHERE id = $1
    ''')
    records = await stmt.fetch(user)
    assert len(records) == 1
    record = records[0]
    return record

			@ -0,0 +1,121 @@
 | 
			
		|||
#!/usr/bin/python
import trio
import json
import uuid
import logging

import pynng
import tractor

from . import gpu
from .gpu import open_gpu_worker
from .types import *
from .constants import *
from .frontend import rpc_call


async def open_dgpu_node(
    rpc_address: str = DEFAULT_RPC_ADDR,
    dgpu_address: str = DEFAULT_DGPU_ADDR,
    dgpu_max_tasks: int = DEFAULT_DGPU_MAX_TASKS,
    initial_algos: list[str] = DEFAULT_INITAL_ALGOS
):
    logging.basicConfig(level=logging.INFO)

    name = uuid.uuid4()
    workers = initial_algos.copy()
    tasks = [None for _ in range(dgpu_max_tasks)]

    portal_map: dict[int, tractor.Portal]
    contexts: dict[int, tractor.Context]

    def get_next_worker(need_algo: str):
        # prefer an idle worker that already has the requested algo loaded,
        # else fall back to any idle slot
        nonlocal workers, tasks
        for algo, task in zip(workers, tasks):
            if need_algo == algo and not task:
                return workers.index(need_algo)

        return tasks.index(None)

    async def gpu_streamer(
        ctx: tractor.Context,
        nid: int
    ):
        nonlocal tasks
        async with ctx.open_stream() as stream:
            async for img in stream:
                tasks[nid]['res'] = img
                tasks[nid]['event'].set()

    async def gpu_compute_one(ireq: ImageGenRequest):
        wid = get_next_worker(ireq.algo)
        event = trio.Event()

        workers[wid] = ireq.algo
        tasks[wid] = {
            'res': None, 'event': event}

        await contexts[wid].send(ireq)

        await event.wait()

        img = tasks[wid]['res']
        tasks[wid] = None
        return img


    with (
        pynng.Req0(dial=rpc_address) as rpc_sock,
        pynng.Bus0(dial=dgpu_address) as dgpu_sock
    ):
        async def _rpc_call(*args, **kwargs):
            return await rpc_call(rpc_sock, *args, **kwargs)

        async def _process_dgpu_req(req: DGPUBusRequest):
            img = await gpu_compute_one(
                ImageGenRequest(**req.params))
            await dgpu_sock.asend(
                bytes.fromhex(req.rid) + img)

        res = await _rpc_call(
            name.hex, 'dgpu_online', {'max_tasks': dgpu_max_tasks})
        logging.info(res)
        assert 'ok' in res.result

        async with (
            tractor.open_actor_cluster(
                modules=['skynet_bot.gpu'],
                count=dgpu_max_tasks,
                names=[i for i in range(dgpu_max_tasks)]
            ) as portal_map,
            trio.open_nursery() as n
        ):
            logging.info(f'starting {dgpu_max_tasks} gpu workers')

            # assumed intent: open one gpu worker context per spawned portal,
            # splitting vram evenly across the cluster
            async with tractor.gather_contexts((
                portal.open_context(
                    open_gpu_worker, algo, 1.0 / dgpu_max_tasks)
                for portal, algo in zip(portal_map.values(), workers)
            )) as contexts:
                contexts = {i: ctx for i, ctx in enumerate(contexts)}
                for i, ctx in contexts.items():
                    n.start_soon(
                        gpu_streamer, ctx, i)
                try:
                    while True:
                        msg = await dgpu_sock.arecv()
                        req = DGPUBusRequest(
                            **json.loads(msg.decode()))

                        if req.nid != name.hex:
                            continue

                        logging.info(f'dgpu: {name}, req: {req}')
                        n.start_soon(
                            _process_dgpu_req, req)

                except KeyboardInterrupt:
                    ...

        res = await _rpc_call(name.hex, 'dgpu_offline')
        logging.info(res)
        assert 'ok' in res.result
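
For reference, a node wired up this way can be run on its own under trio, provided a skynet brain is already listening on the default RPC address; this is a minimal sketch reusing the skynet_bot.dgpu import path from the tests further down, nothing more:

# minimal sketch: launch a single dgpu node against an already-running brain
# (assumes DEFAULT_RPC_ADDR / DEFAULT_DGPU_ADDR in .constants point at live sockets)
import trio
from skynet_bot.dgpu import open_dgpu_node

if __name__ == '__main__':
    trio.run(open_dgpu_node)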
@ -0,0 +1,50 @@
#!/usr/bin/python

import json

from typing import Union
from contextlib import contextmanager as cm

import pynng

from ..types import SkynetRPCRequest, SkynetRPCResponse
from ..constants import DEFAULT_RPC_ADDR


class ConfigUnknownAttribute(Exception):
    ...

class ConfigUnknownAlgorithm(Exception):
    ...

class ConfigUnknownUpscaler(Exception):
    ...

class ConfigSizeDivisionByEight(Exception):
    ...


async def rpc_call(
    sock,
    uid: Union[int, str],
    method: str,
    params: dict = {}
):
    req = SkynetRPCRequest(
        uid=uid,
        method=method,
        params=params
    )
    await sock.asend(
        json.dumps(
            req.to_dict()).encode())

    return SkynetRPCResponse(
        **json.loads(
            (await sock.arecv_msg()).bytes.decode()))


@cm
def open_skynet_rpc(rpc_address: str = DEFAULT_RPC_ADDR):
    with pynng.Req0(dial=rpc_address) as rpc_sock:
        yield rpc_sock
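
A rough usage sketch of these helpers: the 'stats' method and the 'tg+<id>' uid format mirror what the telegram frontend below sends; the rest is illustrative only and assumes a brain is answering on DEFAULT_RPC_ADDR.

# illustrative only: query the brain over the req/rep socket with the helpers above
import trio
from skynet_bot.frontend import open_skynet_rpc, rpc_call

async def check_stats():
    with open_skynet_rpc() as sock:
        resp = await rpc_call(sock, 'tg+1', 'stats')
        print(resp.result)

trio.run(check_stats)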
@ -0,0 +1,164 @@
#!/usr/bin/python

import logging

from datetime import datetime

import pynng

from telebot.async_telebot import AsyncTeleBot
from trio_asyncio import aio_as_trio

from ..constants import *

from . import *


PREFIX = 'tg'


async def run_skynet_telegram(tg_token: str):

    logging.basicConfig(level=logging.INFO)
    bot = AsyncTeleBot(tg_token)

    with open_skynet_rpc() as rpc_sock:

        async def _rpc_call(
            uid: int,
            method: str,
            params: dict
        ):
            return await rpc_call(
                rpc_sock, f'{PREFIX}+{uid}', method, params)

        @bot.message_handler(commands=['help'])
        async def send_help(message):
            await bot.reply_to(message, HELP_TEXT)

        @bot.message_handler(commands=['cool'])
        async def send_cool_words(message):
            await bot.reply_to(message, '\n'.join(COOL_WORDS))

        @bot.message_handler(commands=['txt2img'])
        async def send_txt2img(message):
            resp = await _rpc_call(
                message.from_user.id,
                'txt2img',
                {}
            )

        @bot.message_handler(commands=['redo'])
        async def redo_txt2img(message):
            resp = await _rpc_call(
                message.from_user.id,
                'redo',
                {}
            )

        @bot.message_handler(commands=['config'])
        async def set_config(message):
            params = message.text.split(' ')

            rpc_params = {}

            if len(params) < 3:
                await bot.reply_to(message, 'wrong msg format')

            else:

                try:
                    attr = params[1]

                    if attr == 'algo':
                        val = params[2]
                        if val not in ALGOS:
                            raise ConfigUnknownAlgorithm

                    elif attr == 'step':
                        val = int(params[2])
                        val = max(min(val, MAX_STEP), MIN_STEP)

                    elif attr == 'width':
                        val = max(min(int(params[2]), MAX_WIDTH), 16)
                        if val % 8 != 0:
                            raise ConfigSizeDivisionByEight

                    elif attr == 'height':
                        val = max(min(int(params[2]), MAX_HEIGHT), 16)
                        if val % 8 != 0:
                            raise ConfigSizeDivisionByEight

                    elif attr == 'seed':
                        val = params[2]
                        if val == 'auto':
                            val = None
                        else:
                            val = int(params[2])

                    elif attr == 'guidance':
                        val = float(params[2])
                        val = max(min(val, MAX_GUIDANCE), 0)

                    elif attr == 'upscaler':
                        val = params[2]
                        if val == 'off':
                            val = None
                        elif val != 'x4':
                            raise ConfigUnknownUpscaler

                    else:
                        raise ConfigUnknownAttribute

                    resp = await _rpc_call(
                        message.from_user.id,
                        'config', {'attr': attr, 'val': val})

                    reply_txt = f'config updated! {attr} to {val}'

                except ConfigUnknownAlgorithm:
                    reply_txt = f'no algo named {val}'

                except ConfigUnknownAttribute:
                    reply_txt = f'"{attr}" is not a configurable parameter'

                except ConfigUnknownUpscaler:
                    reply_txt = f'"{val}" is not a valid upscaler'

                except ConfigSizeDivisionByEight:
                    reply_txt = 'size must be divisible by 8!'

                except ValueError:
                    reply_txt = f'"{params[2]}" is not a number silly'

                await bot.reply_to(message, reply_txt)

        @bot.message_handler(commands=['stats'])
        async def user_stats(message):
            resp = await _rpc_call(
                message.from_user.id,
                'stats',
                {}
            )
            stats = resp.result

            stats_str = f'generated: {stats["generated"]}\n'
            stats_str += f'joined: {stats["joined"]}\n'
            stats_str += f'role: {stats["role"]}\n'

            await bot.reply_to(
                message, stats_str)

        @bot.message_handler(commands=['donate'])
        async def donation_info(message):
            await bot.reply_to(
                message, DONATION_INFO)


        @bot.message_handler(func=lambda message: True)
        async def echo_message(message):
            if message.text[0] == '/':
                await bot.reply_to(message, UNKNOWN_CMD_TEXT)


        await aio_as_trio(bot.infinity_polling())
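
A minimal launch sketch for this frontend, assuming a valid bot token and a running brain; trio_asyncio.run is used because AsyncTeleBot polls on an asyncio loop bridged through aio_as_trio:

# minimal sketch: run only the telegram frontend
import trio_asyncio
from skynet_bot.frontend.telegram import run_skynet_telegram

TG_TOKEN = '...'  # hypothetical placeholder, supply a real bot token here

trio_asyncio.run(run_skynet_telegram, TG_TOKEN)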
@ -0,0 +1,75 @@
#!/usr/bin/python

import io
import random
import logging

import torch
import tractor

from diffusers import (
    StableDiffusionPipeline,
    EulerAncestralDiscreteScheduler
)

from .types import ImageGenRequest
from .constants import ALGOS


def pipeline_for(algo: str, mem_fraction: float):
    assert torch.cuda.is_available()
    torch.cuda.empty_cache()
    torch.cuda.set_per_process_memory_fraction(mem_fraction)
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

    params = {
        'torch_dtype': torch.float16,
        'safety_checker': None
    }

    if algo == 'stable':
        params['revision'] = 'fp16'

    pipe = StableDiffusionPipeline.from_pretrained(
        ALGOS[algo], **params)

    pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(
        pipe.scheduler.config)

    return pipe.to("cuda")


@tractor.context
async def open_gpu_worker(
    ctx: tractor.Context,
    start_algo: str,
    mem_fraction: float
):
    current_algo = start_algo
    with torch.no_grad():
        pipe = pipeline_for(current_algo, mem_fraction)
        await ctx.started()

        async with ctx.open_stream() as bus:
            async for ireq in bus:
                # hot-swap the pipeline if the request asks for another model
                if ireq.algo != current_algo:
                    current_algo = ireq.algo
                    pipe = pipeline_for(current_algo, mem_fraction)

                seed = ireq.seed if ireq.seed else random.randint(0, 2 ** 64 - 1)
                image = pipe(
                    ireq.prompt,
                    width=ireq.width,
                    height=ireq.height,
                    guidance_scale=ireq.guidance,
                    num_inference_steps=ireq.step,
                    generator=torch.Generator("cuda").manual_seed(seed)
                ).images[0]

                torch.cuda.empty_cache()

                # convert PIL.Image to raw PNG bytes
                img_bytes = io.BytesIO()
                image.save(img_bytes, format='PNG')
                await bus.send(img_bytes.getvalue())
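
pipeline_for can also be exercised outside the tractor worker, which is handy for smoke-testing a model on a single GPU; a sketch assuming ALGOS has a 'stable' entry and CUDA is available:

# sketch: one-off local generation without the actor machinery
import torch
from skynet_bot.gpu import pipeline_for

with torch.no_grad():
    pipe = pipeline_for('stable', 1.0)
    image = pipe(
        'a cyberpunk city at dusk',  # made-up prompt
        width=512, height=512,
        guidance_scale=7.5,
        num_inference_steps=28,
        generator=torch.Generator('cuda').manual_seed(42)
    ).images[0]
    image.save('out.png')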
@ -0,0 +1,109 @@
# piker: trading gear for hackers
# Copyright (C) Guillermo Rodriguez (in stewardship for piker0)

# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.

# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Affero General Public License for more details.

# You should have received a copy of the GNU Affero General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.

"""
Built-in (extension) types.

"""
import sys
from typing import Optional, Union
from pprint import pformat

import msgspec


class Struct(
    msgspec.Struct,

    # https://jcristharif.com/msgspec/structs.html#tagged-unions
    # tag='pikerstruct',
    # tag=True,
):
    '''
    A "human friendlier" (aka repl buddy) struct subtype.

    '''
    def to_dict(self) -> dict:
        return {
            f: getattr(self, f)
            for f in self.__struct_fields__
        }

    def __repr__(self):
        # only turn on pprint when we detect a python REPL
        # at runtime B)
        if (
            hasattr(sys, 'ps1')
            # TODO: check if we're in pdb
        ):
            return self.pformat()

        return super().__repr__()

    def pformat(self) -> str:
        return f'Struct({pformat(self.to_dict())})'

    def copy(
        self,
        update: Optional[dict] = None,

    ) -> msgspec.Struct:
        '''
        Validate-typecast all self defined fields, return a copy of us
        with all such fields.

        This is kinda like the default behaviour in `pydantic.BaseModel`.

        '''
        if update:
            for k, v in update.items():
                setattr(self, k, v)

        # round-trip serialize to validate
        return msgspec.msgpack.Decoder(
            type=type(self)
        ).decode(
            msgspec.msgpack.Encoder().encode(self)
        )

    def typecast(
        self,
        # fields: Optional[list[str]] = None,
    ) -> None:
        for fname, ftype in self.__annotations__.items():
            setattr(self, fname, ftype(getattr(self, fname)))


# proto

class SkynetRPCRequest(Struct):
    uid: Union[str, int]  # user unique id
    method: str  # rpc method name
    params: dict  # variable params

class SkynetRPCResponse(Struct):
    result: dict

class ImageGenRequest(Struct):
    prompt: str
    step: int
    width: int
    height: int
    guidance: float
    seed: Optional[int]
    algo: str
    upscaler: Optional[str]

class DGPUBusRequest(Struct):
    rid: str  # req id
    nid: str  # node id
    task: str
    params: dict
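
As a quick illustration of how these structs travel over the wire (mirroring what rpc_call does in the frontend): to_dict plus json round-trips a request; the payload values here are made up.

# sketch: serialize a request the way the rpc frontend does
import json
from skynet_bot.types import SkynetRPCRequest, SkynetRPCResponse

req = SkynetRPCRequest(uid='tg+1', method='txt2img', params={'prompt': 'test'})
wire = json.dumps(req.to_dict()).encode()  # bytes ready for sock.asend()

# a brain reply decodes straight into the response struct ('ok' key as asserted in dgpu.py)
resp = SkynetRPCResponse(**json.loads('{"result": {"ok": "1"}}'))
print(resp.result)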
@ -0,0 +1,8 @@
docker run \
    -it \
    --rm \
    --mount type=bind,source="$(pwd)",target=/skynet \
    skynet:runtime-cuda \
    bash -c \
        "cd /skynet && pip install -e . && \
        pytest tests/test_dgpu.py --log-cli-level=info"
@ -0,0 +1,55 @@
#!/usr/bin/python

import time
import json
import logging

import trio
import pynng
import trio_asyncio

from skynet_bot.dgpu import open_dgpu_node
from skynet_bot.types import *
from skynet_bot.brain import run_skynet
from skynet_bot.constants import *
from skynet_bot.frontend import open_skynet_rpc, rpc_call


def test_dgpu_simple():
    async def main():
        async with trio.open_nursery() as n:
            await n.start(
                run_skynet,
                'skynet', '3GbZd6UojbD8V7zWpeFn', 'ancap.tech:34508')

            await trio.sleep(2)

            for i in range(3):
                n.start_soon(open_dgpu_node)

            await trio.sleep(1)
            start = time.time()

            async def request_img():
                with pynng.Req0(dial=DEFAULT_RPC_ADDR) as rpc_sock:
                    res = await rpc_call(
                        rpc_sock, 'tg+1', 'txt2img', {
                            'prompt': 'test',
                            'step': 28,
                            'width': 512, 'height': 512,
                            'guidance': 7.5,
                            'seed': None,
                            'algo': 'stable',
                            'upscaler': None
                        })

                    logging.info(res)

            async with trio.open_nursery() as inner_n:
                for i in range(3):
                    inner_n.start_soon(request_img)

            logging.info(f'time elapsed: {time.time() - start}')
            n.cancel_scope.cancel()


    trio_asyncio.run(main)
@ -0,0 +1,22 @@
#!/usr/bin/python

import trio
import trio_asyncio

from skynet_bot.brain import run_skynet
from skynet_bot.frontend import open_skynet_rpc
from skynet_bot.frontend.telegram import run_skynet_telegram


def test_run_tg_bot():
    async def main():
        async with trio.open_nursery() as n:
            await n.start(
                run_skynet,
                'skynet', '3GbZd6UojbD8V7zWpeFn', 'ancap.tech:34508')
            n.start_soon(
                run_skynet_telegram,
                '5853245787:AAFEmv3EjJ_qJ8d_vmOpi6o6HFHUf8a0uCQ')


    trio_asyncio.run(main)