From e63d395d5c48c58498a35fda043b825d2cd8e794 Mon Sep 17 00:00:00 2001 From: Guillermo Rodriguez Date: Sun, 28 May 2023 20:17:55 -0300 Subject: [PATCH] Frontend DB fixes and starting to add img2img --- skynet/cli.py | 6 +- skynet/constants.py | 2 +- skynet/db/functions.py | 48 +++------ skynet/frontend/__init__.py | 7 +- skynet/frontend/telegram.py | 207 ++++++++++++++++++++++-------------- 5 files changed, 155 insertions(+), 115 deletions(-) diff --git a/skynet/cli.py b/skynet/cli.py index 9c30c15..6e4db55 100644 --- a/skynet/cli.py +++ b/skynet/cli.py @@ -329,7 +329,7 @@ def dgpu( @run.command() -@click.option('--loglevel', '-l', default='warning', help='logging level') +@click.option('--loglevel', '-l', default='INFO', help='logging level') @click.option( '--account', '-a', default='telegram') @click.option( @@ -357,6 +357,7 @@ def telegram( db_user: str, db_pass: str ): + logging.basicConfig(level=loglevel) key, account, permission = load_account_info( key, account, permission) @@ -422,3 +423,6 @@ def pinner(loglevel, container): ipfs_node.pin(cid) cleanup_pinned(now) + + except KeyboardInterrupt: + ... diff --git a/skynet/constants.py b/skynet/constants.py index b590bcb..486edd3 100644 --- a/skynet/constants.py +++ b/skynet/constants.py @@ -102,7 +102,7 @@ MAX_WIDTH = 512 MAX_HEIGHT = 656 MAX_GUIDANCE = 20 -DEFAULT_SEED = 0 +DEFAULT_SEED = None DEFAULT_WIDTH = 512 DEFAULT_HEIGHT = 512 DEFAULT_GUIDANCE = 7.5 diff --git a/skynet/db/functions.py b/skynet/db/functions.py index da35ae0..2b8e192 100644 --- a/skynet/db/functions.py +++ b/skynet/db/functions.py @@ -26,27 +26,11 @@ CREATE SCHEMA IF NOT EXISTS skynet; CREATE TABLE IF NOT EXISTS skynet.user( id SERIAL PRIMARY KEY NOT NULL, - tg_id BIGINT, - wp_id VARCHAR(128), - mx_id VARCHAR(128), - ig_id VARCHAR(128), generated INT NOT NULL, - joined DATE NOT NULL, + joined TIMESTAMP NOT NULL, last_prompt TEXT, role VARCHAR(128) NOT NULL ); -ALTER TABLE skynet.user - ADD CONSTRAINT tg_unique - UNIQUE (tg_id); -ALTER TABLE skynet.user - ADD CONSTRAINT wp_unique - UNIQUE (wp_id); -ALTER TABLE skynet.user - ADD CONSTRAINT mx_unique - UNIQUE (mx_id); -ALTER TABLE skynet.user - ADD CONSTRAINT ig_unique - UNIQUE (ig_id); CREATE TABLE IF NOT EXISTS skynet.user_config( id SERIAL NOT NULL, @@ -54,7 +38,7 @@ CREATE TABLE IF NOT EXISTS skynet.user_config( step INT NOT NULL, width INT NOT NULL, height INT NOT NULL, - seed BIGINT NOT NULL, + seed BIGINT, guidance REAL NOT NULL, strength REAL NOT NULL, upscaler VARCHAR(128) @@ -177,16 +161,19 @@ async def open_database_connection( yield _db_call -async def get_user(conn, uid: int): - stmt = await conn.prepare( - 'SELECT * FROM skynet.user WHERE id = $1') - return await stmt.fetchval(uid) - - async def get_user_config(conn, user: int): stmt = await conn.prepare( 'SELECT * FROM skynet.user_config WHERE id = $1') - return (await stmt.fetch(user))[0] + conf = await stmt.fetch(user) + if len(conf) == 1: + return conf[0] + + else: + return None + + +async def get_user(conn, uid: int): + return await get_user_config(conn, uid) async def get_last_prompt_of(conn, user: int): @@ -208,7 +195,6 @@ async def new_user(conn, uid: int): id, generated, joined, last_prompt, role) VALUES($1, $2, $3, $4, $5) - ON CONFLICT DO NOTHING ''') await stmt.fetch( uid, 0, date, None, DEFAULT_ROLE @@ -216,18 +202,16 @@ async def new_user(conn, uid: int): stmt = await conn.prepare(''' INSERT INTO skynet.user_config( - id, algo, step, width, height, seed, guidance, strength, upscaler) + id, algo, step, width, height, guidance, strength, upscaler) - VALUES($1, $2, $3, $4, $5, $6, $7, $8, $9) - ON CONFLICT DO NOTHING + VALUES($1, $2, $3, $4, $5, $6, $7, $8) ''') - user = await stmt.fetch( - new_uid, + resp = await stmt.fetch( + uid, DEFAULT_ALGO, DEFAULT_STEP, DEFAULT_WIDTH, DEFAULT_HEIGHT, - DEFAULT_SEED, DEFAULT_GUIDANCE, DEFAULT_STRENGTH, DEFAULT_UPSCALER diff --git a/skynet/frontend/__init__.py b/skynet/frontend/__init__.py index 19ea716..290b6b3 100644 --- a/skynet/frontend/__init__.py +++ b/skynet/frontend/__init__.py @@ -84,7 +84,12 @@ def validate_user_config_request(req: str): raise ConfigUnknownAttribute( f'\"{attr}\" not a configurable parameter') - return attr, val, f'config updated! {attr} to {val}' + display_val = val + if attr == 'seed': + if not val: + display_val = 'Random' + + return attr, val, f'config updated! {attr} to {display_val}' except ValueError: raise ValueError(f'\"{val}\" is not a number silly') diff --git a/skynet/frontend/telegram.py b/skynet/frontend/telegram.py index 0fcf67a..c20de06 100644 --- a/skynet/frontend/telegram.py +++ b/skynet/frontend/telegram.py @@ -2,12 +2,15 @@ import io import zlib +import random import logging import asyncio +import traceback from hashlib import sha256 from datetime import datetime +import asks import docker from PIL import Image @@ -18,7 +21,9 @@ from trio_asyncio import aio_as_trio from telebot.types import ( InputFile, InputMediaPhoto, InlineKeyboardButton, InlineKeyboardMarkup ) -from telebot.async_telebot import AsyncTeleBot + +from telebot.types import CallbackQuery +from telebot.async_telebot import AsyncTeleBot, ExceptionHandler from telebot.formatting import hlink from ..db import open_new_database, open_database_connection @@ -27,7 +32,11 @@ from ..constants import * from . import * -PREFIX = 'tg' +class SKYExceptionHandler(ExceptionHandler): + + def handle(exception): + traceback.print_exc() + def build_redo_menu(): btn_redo = InlineKeyboardButton("Redo", callback_data=json.dumps({'method': 'redo'})) @@ -113,20 +122,36 @@ async def get_user_nonce(cleos, user: str): async def work_request( bot, cleos, hyperion, - message, + message, user, chat, account: str, permission: str, - params: dict + params: dict, + file_id: str | None = None, + file_path: str | None = None ): + if params['seed'] == None: + params['seed'] = random.randint(0, 9e18) + body = json.dumps({ 'method': 'diffuse', 'params': params }) - user = message.from_user - chat = message.chat request_time = datetime.now().isoformat() + + if file_id: + image_raw = await bot.download_file(file_path) + image = Image.open(io.BytesIO(image_raw)) + w, h = image.size + logging.info(f'user sent img of size {image.size}') + + if w > 512 or h > 512: + image.thumbnail((512, 512)) + logging.warning(f'resized it to {image.size}') + + binary = image_raw.hex() + ec, out = cleos.push_action( - 'telos.gpu', 'enqueue', [account, body, '', '20.0000 GPU'], f'{account}@{permission}' + 'telos.gpu', 'enqueue', [account, body, binary, '20.0000 GPU'], f'{account}@{permission}' ) out = collect_stdout(out) if ec != 0: @@ -168,13 +193,54 @@ async def work_request( await bot.reply_to(message, 'timeout processing request') return - await bot.reply_to( - message, - generate_reply_caption( - user, params, ipfs_hash, tx_hash), - reply_markup=build_redo_menu(), - parse_mode='HTML' - ) + # attempt to get the image and send it + ipfs_link = f'http://test1.us.telos.net:8080/ipfs/{ipfs_hash}/image.png' + logging.info(f'attempting to get image at {ipfs_link}') + resp = None + for i in range(10): + try: + resp = await asks.get(ipfs_link, timeout=2) + + except asks.errors.RequestTimeout: + logging.warning('timeout...') + ... + + logging.info(f'status_code: {resp.status_code}') + + caption = generate_reply_caption( + user, params, ipfs_hash, tx_hash) + + if resp.status_code != 200: + await bot.reply_to( + message, + caption, + reply_markup=build_redo_menu(), + parse_mode='HTML' + ) + + else: + if file_id: # img2img + await bot.send_media_group( + chat.id, + media=[ + InputMediaPhoto(file_id), + InputMediaPhoto( + resp.raw, + caption=caption + ) + ], + reply_markup=build_redo_menu(), + parse_mode='HTML' + ) + + else: # txt2img + await bot.send_photo( + chat.id, + caption=caption, + photo=resp.raw, + reply_markup=build_redo_menu(), + parse_mode='HTML' + ) async def run_skynet_telegram( @@ -205,7 +271,7 @@ async def run_skynet_telegram( if key: cleos.setup_wallet(key) - bot = AsyncTeleBot(tg_token) + bot = AsyncTeleBot(tg_token, exception_handler=SKYExceptionHandler) logging.info(f'tg_token: {tg_token}') async with open_database_connection( @@ -233,7 +299,7 @@ async def run_skynet_telegram( @bot.message_handler(commands=['txt2img']) async def send_txt2img(message): - user = message.from_user.id + user = message.from_user chat = message.chat reply_id = None if chat.type == 'group' and chat.id == GROUP_ID: @@ -247,8 +313,8 @@ async def run_skynet_telegram( logging.info(f'mid: {message.id}') - await db_call('get_or_create_user', user) - user_config = {**(await db_call('get_user_config', user))} + user_row = await db_call('get_or_create_user', user.id) + user_config = {**user_row} del user_config['id'] params = { @@ -256,22 +322,22 @@ async def run_skynet_telegram( **user_config } - await db_call('update_user_stats', user, last_prompt=prompt) + await db_call('update_user_stats', user.id, last_prompt=prompt) await work_request( bot, cleos, hyperion, - message, account, permission, params) + message, user, chat, + account, permission, params + ) @bot.message_handler(func=lambda message: True, content_types=['photo']) async def send_img2img(message): - user = message.from_user.id + user = message.from_user chat = message.chat reply_id = None if chat.type == 'group' and chat.id == GROUP_ID: reply_id = message.message_id - user_id = f'tg+{message.from_user.id}' - if not message.caption.startswith('/img2img'): await bot.reply_to( message, @@ -287,62 +353,26 @@ async def run_skynet_telegram( file_id = message.photo[-1].file_id file_path = (await bot.get_file(file_id)).file_path - file_raw = await bot.download_file(file_path) logging.info(f'mid: {message.id}') - user = await db_call('get_or_create_user', user_id) - user_config = {**(await db_call('get_user_config', user))} + user_row = await db_call('get_or_create_user', user.id) + user_config = {**user_row} del user_config['id'] - req = json.dumps({ - 'method': 'diffuse', - 'params': { - 'prompt': prompt, - **user_config - } - }) + params = { + 'prompt': prompt, + **user_config + } - ec, out = cleos.push_action( - 'telos.gpu', 'enqueue', [account, req, file_raw.hex()], f'{account}@{permission}' + await db_call('update_user_stats', user.id, last_prompt=prompt) + + await work_request( + bot, cleos, hyperion, + message, user, chat, + account, permission, params, + file_id=file_id, file_path=file_path ) - if ec != 0: - await bot.reply_to(message, out) - return - - request_id = int(out) - logging.info(f'{request_id} enqueued.') - - ipfs_hash = None - sha_hash = None - for i in range(60): - result = cleos.get_table( - 'telos.gpu', 'telos.gpu', 'results', - index_position=2, - key_type='i64', - lower_bound=request_id, - upper_bound=request_id - ) - if len(results) > 0: - ipfs_hash = result[0]['ipfs_hash'] - sha_hash = result[0]['result_hash'] - break - else: - await asyncio.sleep(1) - - if not ipfs_hash: - await bot.reply_to(message, 'timeout processing request') - - ipfs_link = f'https://ipfs.io/ipfs/{ipfs_hash}/image.png' - - await bot.reply_to( - message, - ipfs_link + '\n' + - prepare_metainfo_caption(user, result['meta']['meta']), - reply_to_message_id=reply_id, - reply_markup=build_redo_menu() - ) - return @bot.message_handler(commands=['img2img']) @@ -352,17 +382,23 @@ async def run_skynet_telegram( 'seems you tried to do an img2img command without sending image' ) - @bot.message_handler(commands=['redo']) - async def redo(message): - user = message.from_user.id - chat = message.chat + async def _redo(message_or_query): + if isinstance(message_or_query, CallbackQuery): + query = message_or_query + message = query.message + user = query.from_user + chat = query.message.chat + + else: + message = message_or_query + user = message.from_user + chat = message.chat + reply_id = None if chat.type == 'group' and chat.id == GROUP_ID: reply_id = message.message_id - user_config = {**(await db_call('get_user_config', user))} - del user_config['id'] - prompt = await db_call('get_last_prompt_of', user) + prompt = await db_call('get_last_prompt_of', user.id) if not prompt: await bot.reply_to( @@ -371,6 +407,11 @@ async def run_skynet_telegram( ) return + + user_row = await db_call('get_or_create_user', user.id) + user_config = {**user_row} + del user_config['id'] + params = { 'prompt': prompt, **user_config @@ -378,7 +419,13 @@ async def run_skynet_telegram( await work_request( bot, cleos, hyperion, - message, account, permission, params) + message, user, chat, + account, permission, params + ) + + @bot.message_handler(commands=['redo']) + async def redo(message): + await _redo(message) @bot.message_handler(commands=['config']) async def set_config(message):