diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 0000000..d611665 --- /dev/null +++ b/.dockerignore @@ -0,0 +1,3 @@ +hf_home +inputs +outputs diff --git a/.gitignore b/.gitignore index 56d3f90..c60d49a 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,5 @@ +.python-version hf_home outputs +**/__pycache__ +*.egg-info diff --git a/Dockerfile b/Dockerfile deleted file mode 100644 index f9f5af7..0000000 --- a/Dockerfile +++ /dev/null @@ -1,31 +0,0 @@ -from pytorch/pytorch:latest - -env DEBIAN_FRONTEND=noninteractive - -run apt-get update && apt-get install -y git wget - -run conda install xformers -c xformers/label/dev - -run pip install --upgrade \ - diffusers[torch] \ - accelerate \ - transformers \ - huggingface_hub \ - pyTelegramBotAPI \ - pymongo \ - scipy \ - pdbpp - -env NVIDIA_VISIBLE_DEVICES=all - -run mkdir /scripts -run mkdir /outputs -run mkdir /inputs - -env HF_HOME /hf_home - -run mkdir /hf_home - -workdir /scripts - -env PYTORCH_CUDA_ALLOC_CONF max_split_size_mb:128 diff --git a/Dockerfile.runtime b/Dockerfile.runtime new file mode 100644 index 0000000..84b38b6 --- /dev/null +++ b/Dockerfile.runtime @@ -0,0 +1,13 @@ +from python:3.10.0 + +env DEBIAN_FRONTEND=noninteractive + +workdir /skynet + +copy requirements.* ./ + +run pip install \ + -r requirements.txt \ + -r requirements.test.txt + +workdir /scripts diff --git a/Dockerfile.runtime-cuda b/Dockerfile.runtime-cuda new file mode 100644 index 0000000..c39aefc --- /dev/null +++ b/Dockerfile.runtime-cuda @@ -0,0 +1,23 @@ +from nvidia/cuda:11.7.0-devel-ubuntu20.04 +from python:3.10.0 + +env DEBIAN_FRONTEND=noninteractive + +workdir /skynet + +copy requirements.* . + +run pip install -U pip ninja +run pip install -r requirements.cuda.0.txt +run pip install -v -r requirements.cuda.1.txt + +run pip install \ + -r requirements.txt \ + -r requirements.test.txt + +env NVIDIA_VISIBLE_DEVICES=all +env HF_HOME /hf_home + +env PYTORCH_CUDA_ALLOC_CONF max_split_size_mb:128 + +workdir /scripts diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..0fb76a9 --- /dev/null +++ b/LICENSE @@ -0,0 +1,11 @@ +A menos que sea especificamente indicado en el cabezal del archivo, se reservan +todos los derechos sobre este codigo por parte de: + +Guillermo Rodriguez, guillermor@fing.edu.uy + +ENGLISH LICENSE: + +Unless specifically indicated in the file header, all rights to this code are +reserved by: + +Guillermo Rodriguez, guillermor@.edu.uy diff --git a/README.md b/README.md new file mode 100644 index 0000000..a4e8a19 --- /dev/null +++ b/README.md @@ -0,0 +1,47 @@ +create db in postgres: + +```sql +CREATE USER skynet WITH PASSWORD 'password'; +CREATE DATABASE skynet_art_bot; +GRANT ALL PRIVILEGES ON DATABASE skynet_art_bot TO skynet; + +CREATE SCHEMA IF NOT EXISTS skynet; + +CREATE TABLE IF NOT EXISTS skynet.user( + id SERIAL PRIMARY KEY NOT NULL, + tg_id INT, + wp_id VARCHAR(128), + mx_id VARCHAR(128), + ig_id VARCHAR(128), + generated INT NOT NULL, + joined DATE NOT NULL, + last_prompt TEXT, + role VARCHAR(128) NOT NULL +); +ALTER TABLE skynet.user + ADD CONSTRAINT tg_unique + UNIQUE (tg_id); +ALTER TABLE skynet.user + ADD CONSTRAINT wp_unique + UNIQUE (wp_id); +ALTER TABLE skynet.user + ADD CONSTRAINT mx_unique + UNIQUE (mx_id); +ALTER TABLE skynet.user + ADD CONSTRAINT ig_unique + UNIQUE (ig_id); + +CREATE TABLE IF NOT EXISTS skynet.user_config( + id SERIAL NOT NULL, + algo VARCHAR(128) NOT NULL, + step INT NOT NULL, + width INT NOT NULL, + height INT NOT NULL, + seed INT, + guidance INT NOT NULL, + upscaler VARCHAR(128) +); +ALTER TABLE skynet.user_config + ADD FOREIGN KEY(id) + REFERENCES skynet.user(id); +``` diff --git a/build_docker.sh b/build_docker.sh new file mode 100755 index 0000000..72e37f1 --- /dev/null +++ b/build_docker.sh @@ -0,0 +1,7 @@ +docker build \ + -t skynet:runtime-cuda \ + -f Dockerfile.runtime-cuda . + +docker build \ + -t skynet:runtime \ + -f Dockerfile.runtime . diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 0000000..5f4a13a --- /dev/null +++ b/pytest.ini @@ -0,0 +1,2 @@ +[pytest] +trio_mode = true diff --git a/requirements.cuda.0.txt b/requirements.cuda.0.txt new file mode 100644 index 0000000..e91ed18 --- /dev/null +++ b/requirements.cuda.0.txt @@ -0,0 +1,8 @@ +pdbpp +scipy +accelerate +transformers +huggingface_hub +diffusers[torch] +torch==1.13.0+cu117 +--extra-index-url https://download.pytorch.org/whl/cu117 diff --git a/requirements.cuda.1.txt b/requirements.cuda.1.txt new file mode 100644 index 0000000..b9f2703 --- /dev/null +++ b/requirements.cuda.1.txt @@ -0,0 +1 @@ +git+https://github.com/facebookresearch/xformers.git@main#egg=xformers diff --git a/requirements.test.txt b/requirements.test.txt new file mode 100644 index 0000000..5f0802d --- /dev/null +++ b/requirements.test.txt @@ -0,0 +1,2 @@ +pytest +pytest-trio diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..7831fd3 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,8 @@ +trio +pynng +triopg +aiohttp +msgspec +trio_asyncio + +git+https://github.com/goodboy/tractor.git@master#egg=tractor diff --git a/run-bot.sh b/run-bot.sh deleted file mode 100755 index c575c39..0000000 --- a/run-bot.sh +++ /dev/null @@ -1,14 +0,0 @@ -mkdir -p outputs -mkdir -p hf_home - -docker run \ - -it \ - --rm \ - --gpus=all \ - --env HF_TOKEN='' \ - --env DB_USER='skynet' \ - --env DB_PASS='nnf01nmf091d0i' \ - --mount type=bind,source="$(pwd)"/outputs,target=/outputs \ - --mount type=bind,source="$(pwd)"/hf_home,target=/hf_home \ - --mount type=bind,source="$(pwd)"/scripts,target=/scripts \ - skynet:dif python telegram-bot-dev.py diff --git a/run-mongo.sh b/run-mongo.sh deleted file mode 100755 index b3adc3d..0000000 --- a/run-mongo.sh +++ /dev/null @@ -1,9 +0,0 @@ -docker run - -d \ - --rm \ - -p 27017:27017 \ - --name mongodb-skynet \ - --mount type=bind,source="$(pwd)"/mongodb,target=/data/db \ - -e MONGO_INITDB_ROOT_USERNAME="" \ - -e MONGO_INITDB_ROOT_PASSWORD="" \ - mongo diff --git a/scripts/telegram-bot-dev.py b/scripts/telegram-bot-dev.py deleted file mode 100644 index e49cd6f..0000000 --- a/scripts/telegram-bot-dev.py +++ /dev/null @@ -1,537 +0,0 @@ -#!/usr/bin/python - -import os - -import logging -import random - -from torch.multiprocessing import spawn - -import telebot -from telebot.types import InputFile - -import sys -import uuid - -from pathlib import Path - -import torch -from torch.multiprocessing.spawn import ProcessRaisedException -from diffusers import ( - StableDiffusionPipeline, - EulerAncestralDiscreteScheduler -) - -from huggingface_hub import login -from datetime import datetime - -from pymongo import MongoClient - -from typing import Tuple, Optional - -db_user = os.environ['DB_USER'] -db_pass = os.environ['DB_PASS'] - -logging.basicConfig(level=logging.INFO) - -MEM_FRACTION = .33 - -ALGOS = { - 'stable': 'runwayml/stable-diffusion-v1-5', - 'midj': 'prompthero/openjourney', - 'hdanime': 'Linaqruf/anything-v3.0', - 'waifu': 'hakurei/waifu-diffusion', - 'ghibli': 'nitrosocke/Ghibli-Diffusion', - 'van-gogh': 'dallinmackay/Van-Gogh-diffusion', - 'pokemon': 'lambdalabs/sd-pokemon-diffusers', - 'ink': 'Envvi/Inkpunk-Diffusion', - 'robot': 'nousr/robo-diffusion' -} - -N = '\n' -HELP_TEXT = f''' -test art bot v0.1a4 - -commands work on a user per user basis! -config is individual to each user! - -/txt2img TEXT - request an image based on a prompt - -/redo - redo last prompt - -/cool - list of cool words to use -/stats - user statistics -/donate - see donation info - -/config algo NAME - select AI to use one of: - -{N.join(ALGOS.keys())} - -/config step NUMBER - set amount of iterations -/config seed NUMBER - set the seed, deterministic results! -/config size WIDTH HEIGHT - set size in pixels -/config guidance NUMBER - prompt text importance -''' - -UNKNOWN_CMD_TEXT = 'unknown command! try sending \"/help\"' - -DONATION_INFO = '0xf95335682DF281FFaB7E104EB87B69625d9622B6\ngoal: 25/650usd' - -COOL_WORDS = [ - 'cyberpunk', - 'soviet propaganda poster', - 'rastafari', - 'cannabis', - 'art deco', - 'H R Giger Necronom IV', - 'dimethyltryptamine', - 'lysergic', - 'slut', - 'psilocybin', - 'trippy', - 'lucy in the sky with diamonds', - 'fractal', - 'da vinci', - 'pencil illustration', - 'blueprint', - 'internal diagram', - 'baroque', - 'the last judgment', - 'michelangelo' -] - -GROUP_ID = -1001541979235 - -MP_ENABLED_ROLES = ['god'] - -MIN_STEP = 1 -MAX_STEP = 100 -MAX_SIZE = (512, 656) -MAX_GUIDANCE = 20 - -DEFAULT_SIZE = (512, 512) -DEFAULT_GUIDANCE = 7.5 -DEFAULT_STEP = 75 -DEFAULT_CREDITS = 10 -DEFAULT_ALGO = 'stable' -DEFAULT_ROLE = 'pleb' -DEFAULT_UPSCALER = None - -rr_total = 1 -rr_id = 0 -request_counter = 0 - -def its_my_turn(): - global request_counter, rr_total, rr_id - my_turn = request_counter % rr_total == rr_id - logging.info(f'new request {request_counter}, turn: {my_turn} rr_total: {rr_total}, rr_id {rr_id}') - request_counter += 1 - return my_turn - -def round_robined(func): - def rr_wrapper(*args, **kwargs): - if not its_my_turn(): - return - - func(*args, **kwargs) - - return rr_wrapper - - -def generate_image( - i: int, - prompt: str, - name: str, - step: int, - size: Tuple[int, int], - guidance: int, - seed: int, - algo: str, - upscaler: Optional[str] -): - assert torch.cuda.is_available() - torch.cuda.empty_cache() - torch.cuda.set_per_process_memory_fraction(MEM_FRACTION) - with torch.no_grad(): - if algo == 'stable': - pipe = StableDiffusionPipeline.from_pretrained( - 'runwayml/stable-diffusion-v1-5', - torch_dtype=torch.float16, - revision="fp16", - safety_checker=None - ) - - else: - pipe = StableDiffusionPipeline.from_pretrained( - ALGOS[algo], - torch_dtype=torch.float16, - safety_checker=None - ) - - pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config) - pipe = pipe.to("cuda") - w, h = size - print(f'generating image... of size {w, h}') - image = pipe( - prompt, - width=w, - height=h, - guidance_scale=guidance, num_inference_steps=step, - generator=torch.Generator("cuda").manual_seed(seed) - ).images[0] - - if upscaler == 'x4': - pipe = StableDiffusionPipeline.from_pretrained( - 'stabilityai/stable-diffusion-x4-upscaler', - revision="fp16", - torch_dtype=torch.float16 - ) - image = pipe(prompt=prompt, image=image).images[0] - - - image.save(f'/outputs/{name}.png') - print('saved') - - -if __name__ == '__main__': - - API_TOKEN = '5880619053:AAFge2UfObw1kCn9Kb7AAyqaIHs_HgM0Fx0' - - bot = telebot.TeleBot(API_TOKEN) - db_client = MongoClient( - host=['ancap.tech:64000'], - username=db_user, - password=db_pass) - - tgdb = db_client.get_database('telegram') - - collections = tgdb.list_collection_names() - - if 'users' in collections: - tg_users = tgdb.get_collection('users') - # tg_users.delete_many({}) - - else: - tg_users = tgdb.create_collection('users') - - # db functions - - def get_user(uid: int): - return tg_users.find_one({'uid': uid}) - - - def new_user(uid: int): - if get_user(uid): - raise ValueError('User already present on db') - - res = tg_users.insert_one({ - 'generated': 0, - 'uid': uid, - 'credits': DEFAULT_CREDITS, - 'joined': datetime.utcnow().isoformat(), - 'last_prompt': None, - 'role': DEFAULT_ROLE, - 'config': { - 'algo': DEFAULT_ALGO, - 'step': DEFAULT_STEP, - 'size': DEFAULT_SIZE, - 'seed': None, - 'guidance': DEFAULT_GUIDANCE, - 'upscaler': DEFAULT_UPSCALER - } - }) - - assert res.acknowledged - - return get_user(uid) - - def migrate_user(db_user): - # new: user roles - if 'role' not in db_user: - res = tg_users.find_one_and_update( - {'uid': db_user['uid']}, {'$set': {'role': DEFAULT_ROLE}}) - - # new: algo selection - if 'algo' not in db_user['config']: - res = tg_users.find_one_and_update( - {'uid': db_user['uid']}, {'$set': {'config.algo': DEFAULT_ALGO}}) - - # new: upscaler selection - if 'upscaler' not in db_user['config']: - res = tg_users.find_one_and_update( - {'uid': db_user['uid']}, {'$set': {'config.upscaler': DEFAULT_UPSCALER}}) - - return get_user(db_user['uid']) - - def get_or_create_user(uid: int): - db_user = get_user(uid) - - if not db_user: - db_user = new_user(uid) - - logging.info(f'req from: {uid}') - - return migrate_user(db_user) - - def update_user(uid: int, updt_cmd: dict): - user = get_user(uid) - if not user: - raise ValueError('User not present on db') - - return tg_users.find_one_and_update( - {'uid': uid}, updt_cmd) - - - # bot handler - def img_for_user_with_prompt( - uid: int, - prompt: str, step: int, size: Tuple[int, int], guidance: int, seed: int, - algo: str, upscaler: Optional[str] - ): - name = uuid.uuid4() - - spawn( - generate_image, - args=(prompt, name, step, size, guidance, seed, algo, upscaler)) - - logging.info(f'done generating. got {name}, sending...') - - if len(prompt) > 256: - reply_txt = f'prompt: \"{prompt[:256]}...\"\n(full prompt too big to show on reply...)\n' - - else: - reply_txt = f'prompt: \"{prompt}\"\n' - - reply_txt += f'seed: {seed}\n' - reply_txt += f'iterations: {step}\n' - reply_txt += f'size: {size}\n' - reply_txt += f'guidance: {guidance}\n' - reply_txt += f'algo: {ALGOS[algo]}\n' - reply_txt += f'euler ancestral discrete' - - return reply_txt, name - - @bot.message_handler(commands=['help']) - @round_robined - def send_help(message): - bot.reply_to(message, HELP_TEXT) - - @bot.message_handler(commands=['cool']) - @round_robined - def send_cool_words(message): - bot.reply_to(message, '\n'.join(COOL_WORDS)) - - @bot.message_handler(commands=['txt2img']) - @round_robined - def send_txt2img(message): - chat = message.chat - user = message.from_user - db_user = get_or_create_user(user.id) - - if ((chat.type != 'group' and chat.id != GROUP_ID) and - (db_user['role'] not in MP_ENABLED_ROLES)): - return - - prompt = ' '.join(message.text.split(' ')[1:]) - - if len(prompt) == 0: - bot.reply_to(message, 'empty text prompt ignored.') - return - - logging.info(f"{user.first_name} ({user.id}) on chat {chat.id} txt2img: {prompt}") - - user_conf = db_user['config'] - - algo = user_conf['algo'] - step = user_conf['step'] - size = user_conf['size'] - seed = user_conf['seed'] if user_conf['seed'] else random.randint(0, 999999999) - guidance = user_conf['guidance'] - upscaler = user_conf['upscaler'] - - try: - reply_txt, name = img_for_user_with_prompt( - user.id, prompt, step, size, guidance, seed, algo, upscaler) - - update_user( - user.id, - {'$set': { - 'generated': db_user['generated'] + 1, - 'last_prompt': prompt - }}) - - bot.send_photo( - chat.id, - caption=f'sent by: {user.first_name}\n' + reply_txt, - photo=InputFile(f'/outputs/{name}.png')) - - except BaseException as e: - logging.error(e) - bot.reply_to(message, 'that command caused an error :(\nchange settings and try again (?') - - @bot.message_handler(commands=['redo']) - @round_robined - def redo_txt2img(message): - # check msg comes from testing group - chat = message.chat - user = message.from_user - db_user = get_or_create_user(user.id) - - if ((chat.type != 'group' and chat.id != GROUP_ID) and - (db_user['role'] not in MP_ENABLED_ROLES)): - return - - prompt = db_user['last_prompt'] - - if not prompt: - bot.reply_to(message, 'do a /txt2img command first silly!') - return - - user_conf = db_user['config'] - - algo = user_conf['algo'] - step = user_conf['step'] - size = user_conf['size'] - seed = user_conf['seed'] if user_conf['seed'] else random.randint(0, 999999999) - guidance = user_conf['guidance'] - upscaler = user_conf['upscaler'] - - logging.info(f"{user.first_name} ({user.id}) on chat {chat.id} redo: {prompt}") - - try: - reply_txt, name = img_for_user_with_prompt( - user.id, prompt, step, size, guidance, seed, algo, upscaler) - - update_user( - user.id, - {'$set': { - 'generated': db_user['generated'] + 1, - }}) - - bot.send_photo( - chat.id, - caption=f'sent by: {user.first_name}\n' + reply_txt, - photo=InputFile(f'/outputs/{name}.png')) - - except BaseException as e: - logging.error(e) - bot.reply_to(message, 'that command caused an error :(\nchange settings and try again (?') - - @bot.message_handler(commands=['config']) - @round_robined - def set_config(message): - logging.info(f'config req on chat: {message.chat.id}') - - params = message.text.split(' ') - - if len(params) < 3: - bot.reply_to(message, 'wrong msg format') - - else: - user = message.from_user - chat = message.chat - db_user = get_or_create_user(user.id) - - try: - attr = params[1] - - if attr == 'algo': - val = params[2] - assert val in ALGOS - res = update_user(user.id, {'$set': {'config.algo': val}}) - - elif attr == 'step': - val = int(params[2]) - val = max(min(val, MAX_STEP), MIN_STEP) - res = update_user(user.id, {'$set': {'config.step': val}}) - - elif attr == 'size': - max_w, max_h = MAX_SIZE - w = max(min(int(params[2]), max_w), 16) - h = max(min(int(params[3]), max_h), 16) - - val = (w, h) - - if (w % 8 != 0) or (h % 8 != 0): - bot.reply_to(message, 'size must be divisible by 8!') - return - - res = update_user(user.id, {'$set': {'config.size': val}}) - - elif attr == 'seed': - val = params[2] - if val == 'auto': - val = None - else: - val = int(params[2]) - - res = update_user(user.id, {'$set': {'config.seed': val}}) - - elif attr == 'guidance': - val = float(params[2]) - val = max(min(val, MAX_GUIDANCE), 0) - res = update_user(user.id, {'$set': {'config.guidance': val}}) - - elif attr == 'upscaler': - val = params[2] - if val == 'off': - val = None - - res = update_user(user.id, {'$set': {'config.upscaler': val}}) - - else: - bot.reply_to(message, f'\"{attr}\" not a parameter') - - bot.reply_to(message, f'config updated! {attr} to {val}') - - except ValueError: - bot.reply_to(message, f'\"{val}\" is not a number silly') - - except AssertionError: - bot.reply_to(message, f'no algo named {val}') - - @bot.message_handler(commands=['stats']) - @round_robined - def user_stats(message): - user = message.from_user - db_user = get_or_create_user(user.id) - migrate_user(db_user) - - joined_date_str = datetime.fromisoformat(db_user['joined']).strftime('%B the %dth %Y, %H:%M:%S') - - user_stats_str = f'generated: {db_user["generated"]}\n' - user_stats_str += f'joined: {joined_date_str}\n' - user_stats_str += f'credits: {db_user["credits"]}\n' - user_stats_str += f'role: {db_user["role"]}\n' - - bot.reply_to( - message, user_stats_str) - - @bot.message_handler(commands=['donate']) - @round_robined - def donation_info(message): - bot.reply_to( - message, DONATION_INFO) - - @bot.message_handler(commands=['say']) - @round_robined - def say(message): - chat = message.chat - user = message.from_user - db_user = get_or_create_user(user.id) - - if (chat.type == 'group') or (db_user['role'] not in MP_ENABLED_ROLES): - return - - bot.send_message(GROUP_ID, message.text[4:]) - - @bot.message_handler(func=lambda message: True) - @round_robined - def echo_message(message): - if message.text[0] == '/': - bot.reply_to(message, UNKNOWN_CMD_TEXT) - - - login(token=os.environ['HF_TOKEN']) - - bot.infinity_polling() diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..f48893c --- /dev/null +++ b/setup.py @@ -0,0 +1,11 @@ +from setuptools import setup, find_packages + +setup( + name='skynet-bot', + version='0.1.0a6', + description='Decentralized compute platform', + author='Guillermo Rodriguez', + author_email='guillermo@telos.net', + packages=find_packages(), + install_requires=[] +) diff --git a/skynet_bot/__init__.py b/skynet_bot/__init__.py new file mode 100644 index 0000000..8d5063a --- /dev/null +++ b/skynet_bot/__init__.py @@ -0,0 +1,2 @@ +#!/usr/bin/python + diff --git a/skynet_bot/brain.py b/skynet_bot/brain.py new file mode 100644 index 0000000..12563e2 --- /dev/null +++ b/skynet_bot/brain.py @@ -0,0 +1,246 @@ +#!/usr/bin/python + +import json +import uuid +import base64 +import logging + +from uuid import UUID +from functools import partial +from collections import OrderedDict + +import trio +import pynng +import trio_asyncio + +from .db import * +from .types import * +from .constants import * + + +class SkynetDGPUOffline(BaseException): + ... + +class SkynetDGPUOverloaded(BaseException): + ... + + +async def rpc_service(sock, dgpu_bus, db_pool): + nodes = OrderedDict() + wip_reqs = {} + fin_reqs = {} + + def are_all_workers_busy(): + for nid, info in nodes.items(): + if info['task'] == None: + return False + + return True + + next_worker = 0 + def get_next_worker(): + nonlocal next_worker + + if len(nodes) == 0: + raise SkynetDGPUOffline + + if are_all_workers_busy(): + raise SkynetDGPUOverloaded + + next_worker += 1 + + if next_worker >= len(nodes): + next_worker = 0 + + nid = list(nodes.keys())[next_worker] + return nid + + async def dgpu_image_streamer(): + nonlocal wip_reqs, fin_reqs + while True: + msg = await dgpu_bus.arecv_msg() + rid = UUID(bytes=msg.bytes[:16]).hex + img = msg.bytes[16:].hex() + fin_reqs[rid] = img + event = wip_reqs[rid] + event.set() + del wip_reqs[rid] + + async def dgpu_stream_one_img(req: ImageGenRequest): + nonlocal wip_reqs, fin_reqs, next_worker + nid = get_next_worker() + logging.info(f'dgpu_stream_one_img {next_worker} {nid}') + rid = uuid.uuid4().hex + event = trio.Event() + wip_reqs[rid] = event + + nodes[nid]['task'] = rid + + dgpu_req = DGPUBusRequest( + rid=rid, + nid=nid, + task='diffuse', + params=req.to_dict()) + + logging.info(f'dgpu_bus req: {dgpu_req}') + + await dgpu_bus.asend( + json.dumps(dgpu_req.to_dict()).encode()) + + await event.wait() + + nodes[nid]['task'] = None + + img = fin_reqs[rid] + del fin_reqs[rid] + + logging.info(f'done streaming {img}') + + return rid, img + + async def handle_user_request(rpc_ctx, req): + try: + async with db_pool.acquire() as conn: + user = await get_or_create_user(conn, req.uid) + + result = {} + + match req.method: + case 'txt2img': + logging.info('txt2img') + user_config = {**(await get_user_config(conn, user))} + del user_config['id'] + prompt = req.params['prompt'] + req = ImageGenRequest( + prompt=prompt, + **user_config + ) + rid, img = await dgpu_stream_one_img(req) + result = { + 'id': rid, + 'img': img + } + + case 'redo': + logging.info('redo') + user_config = await get_user_config(conn, user) + prompt = await get_last_prompt_of(conn, user) + req = ImageGenRequest( + prompt=prompt, + **user_config + ) + rid, img = await dgpu_stream_one_img(req) + result = { + 'id': rid, + 'img': img + } + + case 'config': + logging.info('config') + if req.params['attr'] in CONFIG_ATTRS: + await update_user_config( + conn, user, req.params['attr'], req.params['val']) + + case 'stats': + logging.info('stats') + generated, joined, role = await get_user_stats(conn, user) + + result = { + 'generated': generated, + 'joined': joined.strftime(DATE_FORMAT), + 'role': role + } + + case _: + logging.warn('unknown method') + + except SkynetDGPUOffline: + result = { + 'error': 'skynet_dgpu_offline' + } + + except SkynetDGPUOverloaded: + result = { + 'error': 'skynet_dgpu_overloaded', + 'nodes': len(nodes) + } + + except BaseException as e: + logging.error(e) + raise e + # result = { + # 'error': 'skynet_internal_error' + # } + + await rpc_ctx.asend( + json.dumps( + SkynetRPCResponse(result=result).to_dict()).encode()) + + + async with trio.open_nursery() as n: + n.start_soon(dgpu_image_streamer) + while True: + ctx = sock.new_context() + msg = await ctx.arecv_msg() + content = msg.bytes.decode() + req = SkynetRPCRequest(**json.loads(content)) + + logging.info(req) + + if req.method == 'dgpu_online': + nodes[req.uid] = { + 'task': None + } + logging.info(f'dgpu online: {req.uid}') + + + elif req.method == 'dgpu_offline': + i = nodes.values().index(req.uid) + del nodes[req.uid] + + if i < next_worker: + next_worker -= 1 + logging.info(f'dgpu offline: {req.uid}') + + else: + n.start_soon( + handle_user_request, ctx, req) + continue + + await ctx.asend( + json.dumps( + SkynetRPCResponse( + result={'ok': {}}).to_dict()).encode()) + + +async def run_skynet( + db_user: str, + db_pass: str, + db_host: str = DB_HOST, + rpc_address: str = DEFAULT_RPC_ADDR, + dgpu_address: str = DEFAULT_DGPU_ADDR, + task_status = trio.TASK_STATUS_IGNORED +): + logging.basicConfig(level=logging.INFO) + logging.info('skynet is starting') + + async with ( + trio.open_nursery() as n, + open_database_connection( + db_user, db_pass, db_host) as db_pool + ): + logging.info('connected to db.') + with ( + pynng.Rep0(listen=rpc_address) as rpc_sock, + pynng.Bus0(listen=dgpu_address) as dgpu_bus + ): + n.start_soon( + rpc_service, rpc_sock, dgpu_bus, db_pool) + task_status.started() + + try: + await trio.sleep_forever() + + except KeyboardInterrupt: + ... + diff --git a/skynet_bot/constants.py b/skynet_bot/constants.py new file mode 100644 index 0000000..a7b21ae --- /dev/null +++ b/skynet_bot/constants.py @@ -0,0 +1,129 @@ +#!/usr/bin/python + +API_TOKEN = '5880619053:AAFge2UfObw1kCn9Kb7AAyqaIHs_HgM0Fx0' + +DB_HOST = 'ancap.tech:34508' + +ALGOS = { + 'stable': 'runwayml/stable-diffusion-v1-5', + 'midj': 'prompthero/openjourney', + 'hdanime': 'Linaqruf/anything-v3.0', + 'waifu': 'hakurei/waifu-diffusion', + 'ghibli': 'nitrosocke/Ghibli-Diffusion', + 'van-gogh': 'dallinmackay/Van-Gogh-diffusion', + 'pokemon': 'lambdalabs/sd-pokemon-diffusers', + 'ink': 'Envvi/Inkpunk-Diffusion', + 'robot': 'nousr/robo-diffusion' +} + +N = '\n' +HELP_TEXT = f''' +test art bot v0.1a4 + +commands work on a user per user basis! +config is individual to each user! + +/txt2img TEXT - request an image based on a prompt + +/redo - re ont + +/help step - get info on step config option +/help guidance - get info on guidance config option + +/cool - list of cool words to use +/stats - user statistics +/donate - see donation info + +/config algo NAME - select AI to use one of: + +{N.join(ALGOS.keys())} + +/config step NUMBER - set amount of iterations +/config seed NUMBER - set the seed, deterministic results! +/config size WIDTH HEIGHT - set size in pixels +/config guidance NUMBER - prompt text importance +''' + +UNKNOWN_CMD_TEXT = 'unknown command! try sending \"/help\"' + +DONATION_INFO = '0xf95335682DF281FFaB7E104EB87B69625d9622B6\ngoal: 25/650usd' + +COOL_WORDS = [ + 'cyberpunk', + 'soviet propaganda poster', + 'rastafari', + 'cannabis', + 'art deco', + 'H R Giger Necronom IV', + 'dimethyltryptamine', + 'lysergic', + 'slut', + 'psilocybin', + 'trippy', + 'lucy in the sky with diamonds', + 'fractal', + 'da vinci', + 'pencil illustration', + 'blueprint', + 'internal diagram', + 'baroque', + 'the last judgment', + 'michelangelo' +] + +HELP_STEP = ''' +diffusion models are iterative processes – a repeated cycle that starts with a\ + random noise generated from text input. With each step, some noise is removed\ +, resulting in a higher-quality image over time. The repetition stops when the\ + desired number of steps completes. + +around 25 sampling steps are usually enough to achieve high-quality images. Us\ +ing more may produce a slightly different picture, but not necessarily better \ +quality. +''' + +HELP_GUIDANCE = ''' +the guidance scale is a parameter that controls how much the image generation\ + process follows the text prompt. The higher the value, the more image sticks\ + to a given text input. +''' + +HELP_UNKWNOWN_PARAM = 'don\'t have any info on that.' + +GROUP_ID = -1001541979235 + +MP_ENABLED_ROLES = ['god'] + +MIN_STEP = 1 +MAX_STEP = 100 +MAX_WIDTH = 512 +MAX_HEIGHT = 656 +MAX_GUIDANCE = 20 + +DEFAULT_SEED = None +DEFAULT_WIDTH = 512 +DEFAULT_HEIGHT = 512 +DEFAULT_GUIDANCE = 7.5 +DEFAULT_STEP = 35 +DEFAULT_CREDITS = 10 +DEFAULT_ALGO = 'midj' +DEFAULT_ROLE = 'pleb' +DEFAULT_UPSCALER = None + +DEFAULT_RPC_ADDR = 'tcp://127.0.0.1:41000' + +DEFAULT_DGPU_ADDR = 'tcp://127.0.0.1:41069' +DEFAULT_DGPU_MAX_TASKS = 3 +DEFAULT_INITAL_ALGOS = ['midj', 'stable', 'ink'] + +DATE_FORMAT = '%B the %dth %Y, %H:%M:%S' + +CONFIG_ATTRS = [ + 'algo', + 'step', + 'width', + 'height', + 'seed', + 'guidance', + 'upscaler' +] diff --git a/skynet_bot/db.py b/skynet_bot/db.py new file mode 100644 index 0000000..d5c94d7 --- /dev/null +++ b/skynet_bot/db.py @@ -0,0 +1,146 @@ +#!/usr/bin/python + +import logging + +from datetime import datetime +from contextlib import asynccontextmanager as acm + +import trio +import triopg + +from .constants import * + + +def try_decode_uid(uid: str): + try: + proto, uid = uid.split('+') + uid = int(uid) + return proto, uid + + except ValueError: + logging.warning(f'got non numeric uid?: {uid}') + return None, None + + +@acm +async def open_database_connection( + db_user: str, + db_pass: str, + db_host: str = DB_HOST, +): + async with triopg.create_pool( + dsn=f'postgres://{db_user}:{db_pass}@{db_host}/skynet_art_bot' + ) as conn: + yield conn + + +async def get_user(conn, uid: str): + if isinstance(uid, str): + proto, uid = try_decode_uid(uid) + + match proto: + case 'tg': + stmt = await conn.prepare( + 'SELECT * FROM skynet.user WHERE tg_id = $1') + user = await stmt.fetchval(uid) + + case _: + user = None + + return user + + else: # asumme is our uid + stmt = await conn.prepare( + 'SELECT * FROM skynet.user WHERE id = $1') + return await stmt.fetchval(uid) + + +async def get_user_config(conn, user: int): + stmt = await conn.prepare( + 'SELECT * FROM skynet.user_config WHERE id = $1') + return (await stmt.fetch(user))[0] + + +async def get_last_prompt_of(conn, user: int): + stms = await conn.prepare( + 'SELECT last_prompt FROM skynet.user WHERE id = $1') + return await stmt.fetchval(user) + + +async def new_user(conn, uid: str): + if await get_user(conn, uid): + raise ValueError('User already present on db') + + logging.info(f'new user! {uid}') + + tg_id = None + date = datetime.utcnow() + + proto, pid = try_decode_uid(uid) + + match proto: + case 'tg': + tg_id = pid + + async with conn.transaction(): + stmt = await conn.prepare(''' + INSERT INTO skynet.user( + tg_id, generated, joined, last_prompt, role) + + VALUES($1, $2, $3, $4, $5) + ''') + await stmt.fetch( + tg_id, 0, date, None, DEFAULT_ROLE + ) + + new_uid = await get_user(conn, uid) + + stmt = await conn.prepare(''' + INSERT INTO skynet.user_config( + id, algo, step, width, height, seed, guidance, upscaler) + + VALUES($1, $2, $3, $4, $5, $6, $7, $8) + ''') + user = await stmt.fetch( + new_uid, + DEFAULT_ALGO, + DEFAULT_STEP, + DEFAULT_WIDTH, + DEFAULT_HEIGHT, + DEFAULT_SEED, + DEFAULT_GUIDANCE, + DEFAULT_UPSCALER + ) + + return new_uid + + +async def get_or_create_user(conn, uid: str): + user = await get_user(conn, uid) + + if not user: + user = await new_user(conn, uid) + + return user + +async def update_user(conn, user: int, attr: str, val): + ... + +async def update_user_config(conn, user: int, attr: str, val): + stmt = await conn.prepare(f''' + UPDATE skynet.user_config + SET {attr} = $2 + WHERE id = $1 + ''') + await stmt.fetch(user, val) + + +async def get_user_stats(conn, user: int): + stmt = await conn.prepare(''' + SELECT generated,joined,role FROM skynet.user + WHERE id = $1 + ''') + records = await stmt.fetch(user) + assert len(records) == 1 + record = records[0] + return record diff --git a/skynet_bot/dgpu.py b/skynet_bot/dgpu.py new file mode 100644 index 0000000..016aa53 --- /dev/null +++ b/skynet_bot/dgpu.py @@ -0,0 +1,121 @@ +#!/usr/bin/python + +import trio +import json +import uuid +import logging + +import pynng +import tractor + +from . import gpu +from .gpu import open_gpu_worker +from .types import * +from .constants import * +from .frontend import rpc_call + + +async def open_dgpu_node( + rpc_address: str = DEFAULT_RPC_ADDR, + dgpu_address: str = DEFAULT_DGPU_ADDR, + dgpu_max_tasks: int = DEFAULT_DGPU_MAX_TASKS, + initial_algos: str = DEFAULT_INITAL_ALGOS +): + logging.basicConfig(level=logging.INFO) + + name = uuid.uuid4() + workers = initial_algos.copy() + tasks = [None for _ in range(dgpu_max_tasks)] + + portal_map: dict[int, tractor.Portal] + contexts: dict[int, tractor.Context] + + def get_next_worker(need_algo: str): + nonlocal workers, tasks + for task, algo in zip(workers, tasks): + if need_algo == algo and not task: + return workers.index(need_algo) + + return tasks.index(None) + + async def gpu_streamer( + ctx: tractor.Context, + nid: int + ): + nonlocal tasks + async with ctx.open_stream() as stream: + async for img in stream: + tasks[nid]['res'] = img + tasks[nid]['event'].set() + + async def gpu_compute_one(ireq: ImageGenRequest): + wid = get_next_worker(ireq.algo) + event = trio.Event() + + workers[wid] = ireq.algo + tasks[wid] = { + 'res': None, 'event': event} + + await contexts[i].send(ireq) + + await event.wait() + + img = tasks[wid]['res'] + tasks[wid] = None + return img + + + with ( + pynng.Req0(dial=rpc_address) as rpc_sock, + pynng.Bus0(dial=dgpu_address) as dgpu_sock + ): + async def _rpc_call(*args, **kwargs): + return await rpc_call(rpc_sock, *args, **kwargs) + + async def _process_dgpu_req(req: DGPUBusRequest): + img = await gpu_compute_one( + ImageGenRequest(**req.params)) + await dgpu_sock.asend( + bytes.fromhex(req.rid) + img) + + res = await _rpc_call( + name.hex, 'dgpu_online', {'max_tasks': dgpu_max_tasks}) + logging.info(res) + assert 'ok' in res.result + + async with ( + tractor.open_actor_cluster( + modules=['skynet_bot.gpu'], + count=dgpu_max_tasks, + names=[i for i in range(dgpu_max_tasks)] + ) as portal_map, + trio.open_nursery() as n + ): + logging.info(f'starting {dgpu_max_tasks} gpu workers') + async with tractor.gather_contexts(( + ctx.open_context( + open_gpu_worker, algo, 1.0 / dgpu_max_tasks) + )) as contexts: + contexts = {i: ctx for i, ctx in enumerate(contexts)} + for i, ctx in contexts.items(): + n.start_soon( + gpu_streamer, ctx, i) + try: + while True: + msg = await dgpu_sock.arecv() + req = DGPUBusRequest( + **json.loads(msg.decode())) + + if req.nid != name.hex: + continue + + logging.info(f'dgpu: {name}, req: {req}') + n.start_soon( + _process_dgpu_req, req) + + except KeyboardInterrupt: + ... + + res = await _rpc_call(name.hex, 'dgpu_offline') + logging.info(res) + assert 'ok' in res.result diff --git a/skynet_bot/frontend/__init__.py b/skynet_bot/frontend/__init__.py new file mode 100644 index 0000000..7211eb5 --- /dev/null +++ b/skynet_bot/frontend/__init__.py @@ -0,0 +1,50 @@ +#!/usr/bin/python + +import json + +from typing import Union +from contextlib import contextmanager as cm + +import pynng + +from ..types import SkynetRPCRequest, SkynetRPCResponse +from ..constants import DEFAULT_RPC_ADDR + + +class ConfigUnknownAttribute(BaseException): + ... + +class ConfigUnknownAlgorithm(BaseException): + ... + +class ConfigUnknownUpscaler(BaseException): + ... + +class ConfigSizeDivisionByEight(BaseException): + ... + + +async def rpc_call( + sock, + uid: Union[int, str], + method: str, + params: dict = {} +): + req = SkynetRPCRequest( + uid=uid, + method=method, + params=params + ) + await sock.asend( + json.dumps( + req.to_dict()).encode()) + + return SkynetRPCResponse( + **json.loads( + (await sock.arecv_msg()).bytes.decode())) + + +@cm +def open_skynet_rpc(rpc_address: str = DEFAULT_RPC_ADDR): + with pynng.Req0(dial=rpc_address) as rpc_sock: + yield rpc_sock diff --git a/skynet_bot/frontend/telegram.py b/skynet_bot/frontend/telegram.py new file mode 100644 index 0000000..8affa29 --- /dev/null +++ b/skynet_bot/frontend/telegram.py @@ -0,0 +1,164 @@ +#!/usr/bin/python + +import logging + +from datetime import datetime + +import pynng + +from telebot.async_telebot import AsyncTeleBot +from trio_asyncio import aio_as_trio + +from ..constants import * + +from . import * + + +PREFIX = 'tg' + + +async def run_skynet_telegram(tg_token: str): + + logging.basicConfig(level=logging.INFO) + bot = AsyncTeleBot(tg_token) + + with open_skynet_rpc() as rpc_sock: + + async def _rpc_call( + uid: int, + method: str, + params: dict + ): + return await rpc_call( + rpc_sock, f'{PREFIX}+{uid}', method, params) + + @bot.message_handler(commands=['help']) + async def send_help(message): + await bot.reply_to(message, HELP_TEXT) + + @bot.message_handler(commands=['cool']) + async def send_cool_words(message): + await bot.reply_to(message, '\n'.join(COOL_WORDS)) + + @bot.message_handler(commands=['txt2img']) + async def send_txt2img(message): + resp = await _rpc_call( + message.from_user.id, + 'txt2img', + {} + ) + + @bot.message_handler(commands=['redo']) + async def redo_txt2img(message): + resp = await _rpc_call( + message.from_user.id, + 'redo', + {} + ) + + @bot.message_handler(commands=['config']) + async def set_config(message): + params = message.text.split(' ') + + rpc_params = {} + + if len(params) < 3: + bot.reply_to(message, 'wrong msg format') + + else: + + try: + attr = params[1] + + if attr == 'algo': + val = params[2] + if val not in ALGOS: + raise ConfigUnknownAlgorithm + + elif attr == 'step': + val = int(params[2]) + val = max(min(val, MAX_STEP), MIN_STEP) + + elif attr == 'width': + val = max(min(int(params[2]), MAX_WIDTH), 16) + if val % 8 != 0: + raise ConfigSizeDivisionByEight + + elif attr == 'height': + val = max(min(int(params[2]), MAX_HEIGHT), 16) + if val % 8 != 0: + raise ConfigSizeDivisionByEight + + elif attr == 'seed': + val = params[2] + if val == 'auto': + val = None + else: + val = int(params[2]) + + elif attr == 'guidance': + val = float(params[2]) + val = max(min(val, MAX_GUIDANCE), 0) + + elif attr == 'upscaler': + val = params[2] + if val == 'off': + val = None + elif val != 'x4': + raise ConfigUnknownUpscaler + + else: + raise ConfigUnknownAttribute + + resp = await _rpc_call( + message.from_user.id, + 'config', {'attr': attr, 'val': val}) + + reply_txt = f'config updated! {attr} to {val}' + + except ConfigUnknownAlgorithm: + reply_txt = f'no algo named {val}' + + except ConfigUnknownAttribute: + reply_txt = f'\"{attr}\" not a configurable parameter' + + except ConfigUnknownUpscaler: + reply_txt = f'\"{val}\" is not a valid upscaler' + + except ConfigSizeDivisionByEight: + reply_txt = 'size must be divisible by 8!' + + except ValueError: + reply_txt = f'\"{val}\" is not a number silly' + + await bot.reply_to(message, reply_txt) + + @bot.message_handler(commands=['stats']) + async def user_stats(message): + resp = await _rpc_call( + message.from_user.id, + 'stats', + {} + ) + stats = resp.result + + stats_str = f'generated: {stats["generated"]}\n' + stats_str += f'joined: {stats["joined"]}\n' + stats_str += f'role: {stats["role"]}\n' + + await bot.reply_to( + message, stats_str) + + @bot.message_handler(commands=['donate']) + async def donation_info(message): + await bot.reply_to( + message, DONATION_INFO) + + + @bot.message_handler(func=lambda message: True) + async def echo_message(message): + if message.text[0] == '/': + await bot.reply_to(message, UNKNOWN_CMD_TEXT) + + + await aio_as_trio(bot.infinity_polling()) diff --git a/skynet_bot/gpu.py b/skynet_bot/gpu.py new file mode 100644 index 0000000..b805bab --- /dev/null +++ b/skynet_bot/gpu.py @@ -0,0 +1,75 @@ +#!/usr/bin/python + +import io +import random +import logging + +import torch +import tractor + +from diffusers import ( + StableDiffusionPipeline, + EulerAncestralDiscreteScheduler +) + +from .types import ImageGenRequest +from .constants import ALGOS + + +def pipeline_for(algo: str, mem_fraction: float): + assert torch.cuda.is_available() + torch.cuda.empty_cache() + torch.cuda.set_per_process_memory_fraction(mem_fraction) + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + + params = { + 'torch_dtype': torch.float16, + 'safety_checker': None + } + + if algo == 'stable': + params['revision'] = 'fp16' + + pipe = StableDiffusionPipeline.from_pretrained( + ALGOS[algo], **params) + + pipe.scheduler = EulerAncestralDiscreteScheduler.from_config( + pipe.scheduler.config) + + return pipe.to("cuda") + +@tractor.context +async def open_gpu_worker( + ctx: tractor.Context, + start_algo: str, + mem_fraction: float +): + current_algo = start_algo + with torch.no_grad(): + pipe = pipeline_for(current_algo, mem_fraction) + await ctx.started() + + async with ctx.open_stream() as bus: + async for ireq in bus: + if ireq.algo != current_algo: + current_algo = ireq.algo + pipe = pipeline_for(current_algo, mem_fraction) + + seed = ireq.seed if ireq.seed else random.randint(0, 2 ** 64) + image = pipe( + ireq.prompt, + width=ireq.width, + height=ireq.height, + guidance_scale=ireq.guidance, + num_inference_steps=ireq.step, + generator=torch.Generator("cuda").manual_seed(seed) + ).images[0] + + torch.cuda.empty_cache() + + # convert PIL.Image to BytesIO + img_bytes = io.BytesIO() + image.save(img_bytes, format='PNG') + await bus.send(img_bytes.getvalue()) + diff --git a/skynet_bot/types.py b/skynet_bot/types.py new file mode 100644 index 0000000..4332229 --- /dev/null +++ b/skynet_bot/types.py @@ -0,0 +1,109 @@ +# piker: trading gear for hackers +# Copyright (C) Guillermo Rodriguez (in stewardship for piker0) + +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. + +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. + +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + +""" +Built-in (extension) types. +""" +import sys +from typing import Optional, Union +from pprint import pformat + +import msgspec + + +class Struct( + msgspec.Struct, + + # https://jcristharif.com/msgspec/structs.html#tagged-unions + # tag='pikerstruct', + # tag=True, +): + ''' + A "human friendlier" (aka repl buddy) struct subtype. + ''' + def to_dict(self) -> dict: + return { + f: getattr(self, f) + for f in self.__struct_fields__ + } + + def __repr__(self): + # only turn on pprint when we detect a python REPL + # at runtime B) + if ( + hasattr(sys, 'ps1') + # TODO: check if we're in pdb + ): + return self.pformat() + + return super().__repr__() + + def pformat(self) -> str: + return f'Struct({pformat(self.to_dict())})' + + def copy( + self, + update: Optional[dict] = None, + + ) -> msgspec.Struct: + ''' + Validate-typecast all self defined fields, return a copy of us + with all such fields. + This is kinda like the default behaviour in `pydantic.BaseModel`. + ''' + if update: + for k, v in update.items(): + setattr(self, k, v) + + # roundtrip serialize to validate + return msgspec.msgpack.Decoder( + type=type(self) + ).decode( + msgspec.msgpack.Encoder().encode(self) + ) + + def typecast( + self, + # fields: Optional[list[str]] = None, + ) -> None: + for fname, ftype in self.__annotations__.items(): + setattr(self, fname, ftype(getattr(self, fname))) + +# proto + +class SkynetRPCRequest(Struct): + uid: Union[str, int] # user unique id + method: str # rpc method name + params: dict # variable params + +class SkynetRPCResponse(Struct): + result: dict + +class ImageGenRequest(Struct): + prompt: str + step: int + width: int + height: int + guidance: int + seed: Optional[int] + algo: str + upscaler: Optional[str] + +class DGPUBusRequest(Struct): + rid: str # req id + nid: str # node id + task: str + params: dict diff --git a/test.sh b/test.sh new file mode 100755 index 0000000..4cbef69 --- /dev/null +++ b/test.sh @@ -0,0 +1,8 @@ +docker run \ + -it \ + --rm \ + --mount type=bind,source="$(pwd)",target=/skynet \ + skynet:runtime-cuda \ + bash -c \ + "cd /skynet && pip install -e . && \ + pytest tests/test_dgpu.py --log-cli-level=info" diff --git a/tests/test_dgpu.py b/tests/test_dgpu.py new file mode 100644 index 0000000..421eb18 --- /dev/null +++ b/tests/test_dgpu.py @@ -0,0 +1,55 @@ +#!/usr/bin/python + +import time +import json +import logging + +import trio +import pynng +import trio_asyncio + +from skynet_bot.dgpu import open_dgpu_node +from skynet_bot.types import * +from skynet_bot.brain import run_skynet +from skynet_bot.constants import * +from skynet_bot.frontend import open_skynet_rpc, rpc_call + + +def test_dgpu_simple(): + async def main(): + async with trio.open_nursery() as n: + await n.start( + run_skynet, + 'skynet', '3GbZd6UojbD8V7zWpeFn', 'ancap.tech:34508') + + await trio.sleep(2) + + for i in range(3): + n.start_soon(open_dgpu_node) + + await trio.sleep(1) + start = time.time() + async def request_img(): + with pynng.Req0(dial=DEFAULT_RPC_ADDR) as rpc_sock: + res = await rpc_call( + rpc_sock, 'tg+1', 'txt2img', { + 'prompt': 'test', + 'step': 28, + 'width': 512, 'height': 512, + 'guidance': 7.5, + 'seed': None, + 'algo': 'stable', + 'upscaler': None + }) + + logging.info(res) + + async with trio.open_nursery() as inner_n: + for i in range(3): + inner_n.start_soon(request_img) + + logging.info(f'time elapsed: {time.time() - start}') + n.cancel_scope.cancel() + + + trio_asyncio.run(main) diff --git a/tests/test_telegram.py b/tests/test_telegram.py new file mode 100644 index 0000000..fe99566 --- /dev/null +++ b/tests/test_telegram.py @@ -0,0 +1,22 @@ +#!/usr/bin/python + +import trio +import trio_asyncio + +from skynet_bot.brain import run_skynet +from skynet_bot.frontend import open_skynet_rpc +from skynet_bot.frontend.telegram import run_skynet_telegram + + +def test_run_tg_bot(): + async def main(): + async with trio.open_nursery() as n: + await n.start( + run_skynet, + 'skynet', '3GbZd6UojbD8V7zWpeFn', 'ancap.tech:34508') + n.start_soon( + run_skynet_telegram, + '5853245787:AAFEmv3EjJ_qJ8d_vmOpi6o6HFHUf8a0uCQ') + + + trio_asyncio.run(main)