From cb92aed51cd167506c77993638e201d88d8428a6 Mon Sep 17 00:00:00 2001 From: Guillermo Rodriguez Date: Wed, 21 Dec 2022 11:53:50 -0300 Subject: [PATCH] Finish telegram frontend integration Correctly update user stats after a txt2img request Swap tg_id to BIGINT on database Refactor help topic system to be more generic Fix skynet brain cli entry point Fix multi request rpc session, we werent creating a context on the req side Fix redo method DGPU now returns image metadata Improve error messages brain to frontend Add version as constant Add telegram dep to requirements --- requirements.txt | 1 + setup.py | 4 +- skynet/brain.py | 70 ++++++++++++++++++++-------- skynet/cli.py | 14 +++--- skynet/constants.py | 22 +++++---- skynet/db.py | 29 ++++++++++-- skynet/dgpu.py | 15 ++++-- skynet/frontend/__init__.py | 7 +-- skynet/frontend/telegram.py | 92 ++++++++++++++++++++++++++++++++----- 9 files changed, 194 insertions(+), 60 deletions(-) diff --git a/requirements.txt b/requirements.txt index b1034c9..650f6ab 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,3 +5,4 @@ aiohttp msgspec pyOpenSSL trio_asyncio +pyTelegramBotAPI diff --git a/setup.py b/setup.py index 4781c43..0a822be 100644 --- a/setup.py +++ b/setup.py @@ -1,8 +1,10 @@ from setuptools import setup, find_packages +from skynet.constants import VERSION + setup( name='skynet', - version='0.1.0a6', + version=VERSION, description='Decentralized compute platform', author='Guillermo Rodriguez', author_email='guillermo@telos.net', diff --git a/skynet/brain.py b/skynet/brain.py index 7bd3cf7..91ed253 100644 --- a/skynet/brain.py +++ b/skynet/brain.py @@ -4,6 +4,7 @@ import json import uuid import base64 import logging +import traceback from uuid import UUID from pathlib import Path @@ -90,10 +91,10 @@ async def open_rpc_service(sock, dgpu_bus, db_pool, tls_whitelist, tls_key): logging.info(f'pre next_worker: {next_worker}') if next_worker == None: - raise SkynetDGPUOffline + raise SkynetDGPUOffline('No workers connected, try again later') if are_all_workers_busy(): - raise SkynetDGPUOverloaded + raise SkynetDGPUOverloaded('All workers are busy at the moment') nid = list(nodes.keys())[next_worker] @@ -175,6 +176,8 @@ async def open_rpc_service(sock, dgpu_bus, db_pool, tls_whitelist, tls_key): with trio.move_on_after(30): await img_event.wait() + logging.info(f'img event: {ack_event.is_set()}') + if not img_event.is_set(): disconnect_node(nid) raise SkynetDGPUComputeError('30 seconds timeout while processing request') @@ -187,7 +190,7 @@ async def open_rpc_service(sock, dgpu_bus, db_pool, tls_whitelist, tls_key): if 'error' in img_resp.params: raise SkynetDGPUComputeError(img_resp.params['error']) - return rid, img_resp.params['img'] + return rid, img_resp.params['img'], img_resp.params['meta'] async def handle_user_request(rpc_ctx, req): try: @@ -202,39 +205,54 @@ async def open_rpc_service(sock, dgpu_bus, db_pool, tls_whitelist, tls_key): user_config = {**(await get_user_config(conn, user))} del user_config['id'] prompt = req.params['prompt'] - user_config= { - key : req.params.get(key, val) - for key, val in user_config.items() - } req = ImageGenRequest( prompt=prompt, **user_config ) - rid, img = await dgpu_stream_one_img(req) + rid, img, meta = await dgpu_stream_one_img(req) + logging.info(f'done streaming {rid}') result = { 'id': rid, - 'img': img + 'img': img, + 'meta': meta } + await update_user_stats(conn, user, last_prompt=prompt) + logging.info('updated user stats.') + case 'redo': logging.info('redo') - user_config = await get_user_config(conn, user) + user_config = {**(await get_user_config(conn, user))} + del user_config['id'] prompt = await get_last_prompt_of(conn, user) - req = ImageGenRequest( - prompt=prompt, - **user_config - ) - rid, img = await dgpu_stream_one_img(req) - result = { - 'id': rid, - 'img': img - } + + if prompt: + req = ImageGenRequest( + prompt=prompt, + **user_config + ) + rid, img, meta = await dgpu_stream_one_img(req) + result = { + 'id': rid, + 'img': img, + 'meta': meta + } + await update_user_stats(conn, user) + logging.info('updated user stats.') + + else: + result = { + 'error': 'skynet_no_last_prompt', + 'message': 'No prompt to redo, do txt2img first' + } case 'config': logging.info('config') if req.params['attr'] in CONFIG_ATTRS: + logging.info(f'update: {req.params}') await update_user_config( conn, user, req.params['attr'], req.params['val']) + logging.info('done') case 'stats': logging.info('stats') @@ -255,9 +273,10 @@ async def open_rpc_service(sock, dgpu_bus, db_pool, tls_whitelist, tls_key): 'message': str(e) } - except SkynetDGPUOverloaded: + except SkynetDGPUOverloaded as e: result = { 'error': 'skynet_dgpu_overloaded', + 'message': str(e), 'nodes': len(nodes) } @@ -266,14 +285,23 @@ async def open_rpc_service(sock, dgpu_bus, db_pool, tls_whitelist, tls_key): 'error': 'skynet_dgpu_compute_error', 'message': str(e) } + except BaseException as e: + traceback.print_exception(type(e), e, e.__traceback__) + result = { + 'error': 'skynet_internal_error', + 'message': str(e) + } resp = SkynetRPCResponse(result=result) if security: resp.sign(tls_key, 'skynet') + logging.info('sending response') await rpc_ctx.asend( json.dumps(resp.to_dict()).encode()) + rpc_ctx.close() + logging.info('done') async def request_service(n): nonlocal next_worker @@ -329,6 +357,8 @@ async def open_rpc_service(sock, dgpu_bus, db_pool, tls_whitelist, tls_key): await ctx.asend( json.dumps(resp.to_dict()).encode()) + ctx.close() + async with trio.open_nursery() as n: n.start_soon(dgpu_image_streamer) diff --git a/skynet/cli.py b/skynet/cli.py index 50360b1..cfb786e 100644 --- a/skynet/cli.py +++ b/skynet/cli.py @@ -8,10 +8,12 @@ from functools import partial import trio import click +import trio_asyncio from . import utils from .dgpu import open_dgpu_node from .brain import run_skynet +from .constants import ALGOS from .frontend.telegram import run_skynet_telegram @@ -61,16 +63,16 @@ def run(*args, **kwargs): @click.option( '--host', '-h', default='localhost:5432') @click.option( - '--pass', '-p', default='password') -def skynet( + '--passwd', '-p', default='password') +def brain( loglevel: str, host: str, - passw: str + passwd: str ): async def _run_skynet(): async with run_skynet( db_host=host, - db_pass=passw + db_pass=passwd ): await trio.sleep_forever() @@ -86,13 +88,13 @@ def skynet( @click.option( '--cert', '-c', default='whitelist/dgpu') @click.option( - '--algos', '-a', default=None) + '--algos', '-a', default=json.dumps(['midj'])) def dgpu( loglevel: str, uid: str, key: str, cert: str, - algos: Optional[str] + algos: str ): trio.run( partial( diff --git a/skynet/constants.py b/skynet/constants.py index 5e7d767..73f6fd8 100644 --- a/skynet/constants.py +++ b/skynet/constants.py @@ -1,8 +1,10 @@ #!/usr/bin/python +VERSION = '0.1a6' + DOCKER_RUNTIME_CUDA = 'skynet:runtime-cuda' -DB_HOST = 'ancap.tech:34508' +DB_HOST = 'localhost:5432' DB_USER = 'skynet' DB_PASS = 'password' DB_NAME = 'skynet' @@ -21,7 +23,7 @@ ALGOS = { N = '\n' HELP_TEXT = f''' -test art bot v0.1a4 +test art bot v{VERSION} commands work on a user per user basis! config is individual to each user! @@ -47,7 +49,7 @@ config is individual to each user! /config guidance NUMBER - prompt text importance ''' -UNKNOWN_CMD_TEXT = 'unknown command! try sending \"/help\"' +UNKNOWN_CMD_TEXT = 'Unknown command! Try sending \"/help\"' DONATION_INFO = '0xf95335682DF281FFaB7E104EB87B69625d9622B6\ngoal: 25/650usd' @@ -74,22 +76,24 @@ COOL_WORDS = [ 'michelangelo' ] -HELP_STEP = ''' -diffusion models are iterative processes – a repeated cycle that starts with a\ +HELP_TOPICS = { + 'step': ''' +Diffusion models are iterative processes – a repeated cycle that starts with a\ random noise generated from text input. With each step, some noise is removed\ , resulting in a higher-quality image over time. The repetition stops when the\ desired number of steps completes. -around 25 sampling steps are usually enough to achieve high-quality images. Us\ +Around 25 sampling steps are usually enough to achieve high-quality images. Us\ ing more may produce a slightly different picture, but not necessarily better \ quality. -''' +''', -HELP_GUIDANCE = ''' -the guidance scale is a parameter that controls how much the image generation\ +'guidance': ''' +The guidance scale is a parameter that controls how much the image generation\ process follows the text prompt. The higher the value, the more image sticks\ to a given text input. ''' +} HELP_UNKWNOWN_PARAM = 'don\'t have any info on that.' diff --git a/skynet/db.py b/skynet/db.py index 7745e56..f803397 100644 --- a/skynet/db.py +++ b/skynet/db.py @@ -2,6 +2,7 @@ import logging +from typing import Optional from datetime import datetime from contextlib import asynccontextmanager as acm @@ -19,7 +20,7 @@ CREATE SCHEMA IF NOT EXISTS skynet; CREATE TABLE IF NOT EXISTS skynet.user( id SERIAL PRIMARY KEY NOT NULL, - tg_id INT, + tg_id BIGINT, wp_id VARCHAR(128), mx_id VARCHAR(128), ig_id VARCHAR(128), @@ -47,7 +48,7 @@ CREATE TABLE IF NOT EXISTS skynet.user_config( step INT NOT NULL, width INT NOT NULL, height INT NOT NULL, - seed INT, + seed BIGINT, guidance INT NOT NULL, upscaler VARCHAR(128) ); @@ -124,7 +125,7 @@ async def get_user_config(conn, user: int): async def get_last_prompt_of(conn, user: int): - stms = await conn.prepare( + stmt = await conn.prepare( 'SELECT last_prompt FROM skynet.user WHERE id = $1') return await stmt.fetchval(user) @@ -198,7 +199,12 @@ async def get_or_create_user(conn, uid: str): return user async def update_user(conn, user: int, attr: str, val): - ... + stmt = await conn.prepare(f''' + UPDATE skynet.user + SET {attr} = $2 + WHERE id = $1 + ''') + await stmt.fetch(user, val) async def update_user_config(conn, user: int, attr: str, val): stmt = await conn.prepare(f''' @@ -218,3 +224,18 @@ async def get_user_stats(conn, user: int): assert len(records) == 1 record = records[0] return record + +async def update_user_stats( + conn, + user: int, + last_prompt: Optional[str] = None +): + stmt = await conn.prepare(''' + UPDATE skynet.user + SET generated = generated + 1 + WHERE id = $1 + ''') + await stmt.fetch(user) + + if last_prompt: + await update_user(conn, user, 'last_prompt', last_prompt) diff --git a/skynet/dgpu.py b/skynet/dgpu.py index 7b44d6c..4beb8f3 100644 --- a/skynet/dgpu.py +++ b/skynet/dgpu.py @@ -109,7 +109,6 @@ async def open_dgpu_node( 'generated': 0 } - seed = ireq.seed if ireq.seed else random.randint(0, 2 ** 64) try: image = models[ireq.algo]['pipe']( ireq.prompt, @@ -117,7 +116,7 @@ async def open_dgpu_node( height=ireq.height, guidance_scale=ireq.guidance, num_inference_steps=ireq.step, - generator=torch.Generator("cuda").manual_seed(seed) + generator=torch.Generator("cuda").manual_seed(ireq.seed) ).images[0] return image.tobytes() @@ -207,12 +206,18 @@ async def open_dgpu_node( logging.info(f'sent ack, processing {req.rid}...') try: - img = await gpu_compute_one( - ImageGenRequest(**req.params)) + img_req = ImageGenRequest(**req.params) + if not img_req.seed: + img_req.seed = random.randint(0, 2 ** 64) + + img = await gpu_compute_one(img_req) img_resp = DGPUBusResponse( rid=req.rid, nid=req.nid, - params={'img': base64.b64encode(img).hex()} + params={ + 'img': base64.b64encode(img).hex(), + 'meta': img_req.to_dict() + } ) except DGPUComputeError as e: diff --git a/skynet/frontend/__init__.py b/skynet/frontend/__init__.py index 4eaf918..ceb47eb 100644 --- a/skynet/frontend/__init__.py +++ b/skynet/frontend/__init__.py @@ -90,13 +90,14 @@ async def open_skynet_rpc( if security: req.sign(tls_key, cert_name) - await sock.asend( + ctx = sock.new_context() + await ctx.asend( json.dumps( req.to_dict()).encode()) resp = SkynetRPCResponse( - **json.loads( - (await sock.arecv_msg()).bytes.decode())) + **json.loads((await ctx.arecv()).decode())) + ctx.close() if security: resp.verify(skynet_cert) diff --git a/skynet/frontend/telegram.py b/skynet/frontend/telegram.py index 6f217b3..e3d073b 100644 --- a/skynet/frontend/telegram.py +++ b/skynet/frontend/telegram.py @@ -1,14 +1,19 @@ #!/usr/bin/python +import io +import base64 import logging from datetime import datetime import pynng -from telebot.async_telebot import AsyncTeleBot +from PIL import Image from trio_asyncio import aio_as_trio +from telebot.types import InputFile +from telebot.async_telebot import AsyncTeleBot + from ..constants import * from . import * @@ -17,6 +22,17 @@ from . import * PREFIX = 'tg' +def prepare_metainfo_caption(meta: dict) -> str: + meta_str = f'prompt: \"{meta["prompt"]}\"\n' + meta_str += f'seed: {meta["seed"]}\n' + meta_str += f'step: {meta["step"]}\n' + meta_str += f'guidance: {meta["guidance"]}\n' + meta_str += f'algo: \"{meta["algo"]}\"\n' + meta_str += f'sampler: k_euler_ancestral\n' + meta_str += f'skynet v{VERSION}' + return meta_str + + async def run_skynet_telegram( tg_token: str, key_name: str = 'telegram-frontend', @@ -26,24 +42,35 @@ async def run_skynet_telegram( logging.basicConfig(level=logging.INFO) bot = AsyncTeleBot(tg_token) - with open_skynet_rpc( + async with open_skynet_rpc( 'skynet-telegram-0', security=True, - cert_name=cert, - key_name=key + cert_name=cert_name, + key_name=key_name ) as rpc_call: async def _rpc_call( uid: int, method: str, - params: dict + params: dict = {} ): return await rpc_call( method, params, uid=f'{PREFIX}+{uid}') @bot.message_handler(commands=['help']) async def send_help(message): - await bot.reply_to(message, HELP_TEXT) + splt_msg = message.text.split(' ') + + if len(splt_msg) == 1: + await bot.reply_to(message, HELP_TEXT) + + else: + param = splt_msg[1] + if param in HELP_TOPICS: + await bot.reply_to(message, HELP_TOPICS[param]) + + else: + await bot.reply_to(message, HELP_UNKWNOWN_PARAM) @bot.message_handler(commands=['cool']) async def send_cool_words(message): @@ -51,19 +78,60 @@ async def run_skynet_telegram( @bot.message_handler(commands=['txt2img']) async def send_txt2img(message): + prompt = ' '.join(message.text.split(' ')[1:]) + + if len(prompt) == 0: + await bot.reply_to(message, 'Empty text prompt ignored.') + return + + logging.info(f'mid: {message.id}') resp = await _rpc_call( message.from_user.id, 'txt2img', - {} + {'prompt': prompt} ) + logging.info(f'resp to {message.id} arrived') + + resp_txt = '' + if 'error' in resp.result: + resp_txt = resp.result['message'] + + else: + logging.info(resp.result['id']) + img_raw = base64.b64decode(bytes.fromhex(resp.result['img'])) + img = Image.frombytes('RGB', (512, 512), img_raw) + + await bot.send_photo( + message.chat.id, + caption=prepare_metainfo_caption(resp.result['meta']), + photo=img, + reply_to_message_id=message.id + ) + return + + await bot.reply_to(message, resp_txt) @bot.message_handler(commands=['redo']) async def redo_txt2img(message): - resp = await _rpc_call( - message.from_user.id, - 'redo', - {} - ) + resp = await _rpc_call(message.from_user.id, 'redo') + + resp_txt = '' + if 'error' in resp.result: + resp_txt = resp.result['message'] + + else: + img_raw = base64.b64decode(bytes.fromhex(resp.result['img'])) + img = Image.frombytes('RGB', (512, 512), img_raw) + + await bot.send_photo( + message.chat.id, + caption=prepare_metainfo_caption(resp.result['meta']), + photo=img, + reply_to_message_id=message.id + ) + return + + await bot.reply_to(message, resp_txt) @bot.message_handler(commands=['config']) async def set_config(message):