From 27fe05c3e7e85bfcb058f6350f4dc6ba75f78f64 Mon Sep 17 00:00:00 2001 From: Guillermo Rodriguez Date: Sat, 3 Jun 2023 20:17:56 -0300 Subject: [PATCH] Vast improvement to telegram frontedn --- skynet/cli.py | 17 +- skynet/db/functions.py | 60 ++- skynet/dgpu.py | 21 +- skynet/frontend/__init__.py | 6 - skynet/frontend/telegram.py | 566 --------------------------- skynet/frontend/telegram/__init__.py | 274 +++++++++++++ skynet/frontend/telegram/handlers.py | 345 ++++++++++++++++ skynet/frontend/telegram/utils.py | 113 ++++++ 8 files changed, 810 insertions(+), 592 deletions(-) delete mode 100644 skynet/frontend/telegram.py create mode 100644 skynet/frontend/telegram/__init__.py create mode 100644 skynet/frontend/telegram/handlers.py create mode 100644 skynet/frontend/telegram/utils.py diff --git a/skynet/cli.py b/skynet/cli.py index 9c1e6f4..8a2d8ac 100644 --- a/skynet/cli.py +++ b/skynet/cli.py @@ -26,7 +26,7 @@ from .ipfs import open_ipfs_node from .config import * from .nodeos import open_cleos, open_nodeos from .constants import * -from .frontend.telegram import run_skynet_telegram +from .frontend.telegram import SkynetTelegramFrontend @click.group() @@ -318,7 +318,7 @@ def nodeos(): @click.option( '--ipfs-url', '-n', default=DEFAULT_IPFS_REMOTE) @click.option( - '--algos', '-A', default=json.dumps(['midj', 'ink'])) + '--algos', '-A', default=json.dumps(['midj'])) def dgpu( loglevel: str, account: str, @@ -383,8 +383,9 @@ def telegram( key, account, permission) _, _, tg_token, cfg = init_env_from_config() - asyncio.run( - run_skynet_telegram( + + async def _async_main(): + frontend = SkynetTelegramFrontend( tg_token, account, permission, @@ -393,7 +394,13 @@ def telegram( db_host, db_user, db_pass, remote_ipfs_node=ipfs_url, key=key - )) + ) + + async with frontend.open(): + await frontend.bot.infinity_polling() + + + asyncio.run(_async_main()) class IPFSHTTP: diff --git a/skynet/db/functions.py b/skynet/db/functions.py index 3a8b760..8a56484 100644 --- a/skynet/db/functions.py +++ b/skynet/db/functions.py @@ -6,7 +6,6 @@ import string import logging import importlib -from typing import Optional from datetime import datetime from contextlib import contextmanager as cm from contextlib import asynccontextmanager as acm @@ -15,7 +14,6 @@ import docker import asyncpg import psycopg2 -from asyncpg.exceptions import UndefinedColumnError from psycopg2.extensions import ISOLATION_LEVEL_AUTOCOMMIT from ..constants import * @@ -49,6 +47,17 @@ CREATE TABLE IF NOT EXISTS skynet.user_config( ALTER TABLE skynet.user_config ADD FOREIGN KEY(id) REFERENCES skynet.user(id); + +CREATE TABLE IF NOT EXISTS skynet.user_requests( + id SERIAL NOT NULL, + user_id SERIAL NOT NULL, + sent TIMESTAMP NOT NULL, + status TEXT NOT NULL, + status_msg SERIAL PRIMARY KEY NOT NULL +); +ALTER TABLE skynet.user_requests + ADD FOREIGN KEY(user_id) + REFERENCES skynet.user(id); ''' @@ -199,6 +208,53 @@ async def get_last_binary_of(conn, user: int): return await stmt.fetchval(user) +async def get_user_request(conn, mid: int): + stmt = await conn.prepare( + 'SELECT * FROM skynet.user_requests WHERE id = $1') + return await stmt.fetch(mid) + +async def get_user_request_by_sid(conn, sid: int): + stmt = await conn.prepare( + 'SELECT * FROM skynet.user_requests WHERE status_msg = $1') + return (await stmt.fetch(sid))[0] + +async def new_user_request( + conn, user: int, mid: int, + status_msg: int, + status: str = 'started processing request...' +): + date = datetime.utcnow() + async with conn.transaction(): + stmt = await conn.prepare(''' + INSERT INTO skynet.user_requests( + id, user_id, sent, status, status_msg + ) + + VALUES($1, $2, $3, $4, $5) + ''') + await stmt.fetch(mid, user, date, status, status_msg) + +async def update_user_request( + conn, mid: int, status: str +): + stmt = await conn.prepare(f''' + UPDATE skynet.user_requests + SET status = $2 + WHERE id = $1 + ''') + await stmt.fetch(mid, status) + +async def update_user_request_by_sid( + conn, sid: int, status: str +): + stmt = await conn.prepare(f''' + UPDATE skynet.user_requests + SET status = $2 + WHERE status_msg = $1 + ''') + await stmt.fetch(sid, status) + + async def new_user(conn, uid: int): if await get_user(conn, uid): raise ValueError('User already present on db') diff --git a/skynet/dgpu.py b/skynet/dgpu.py index e3b882d..763fdd1 100644 --- a/skynet/dgpu.py +++ b/skynet/dgpu.py @@ -4,8 +4,8 @@ import gc import io import json import time -import random import logging +import traceback from PIL import Image from typing import List, Optional @@ -15,19 +15,13 @@ import trio import asks import torch -from leap.cleos import CLEOS, default_nodeos_image +from leap.cleos import CLEOS from leap.sugar import * -from diffusers import ( - StableDiffusionPipeline, - StableDiffusionImg2ImgPipeline, - EulerAncestralDiscreteScheduler -) from realesrgan import RealESRGANer from basicsr.archs.rrdbnet_arch import RRDBNet -from diffusers.models import UNet2DConditionModel -from .ipfs import IPFSDocker, open_ipfs_node, get_ipfs_file +from .ipfs import open_ipfs_node, get_ipfs_file from .utils import * from .constants import * @@ -125,7 +119,7 @@ async def open_dgpu_node( logging.info(f'binext: {len(binext) if binext else 0} bytes') if binext: _params['image'] = image - _params['strength'] = params['strength'] + _params['strength'] = float(Decimal(params['strength'])) else: _params['width'] = int(params['width']) @@ -135,7 +129,7 @@ async def open_dgpu_node( image = models[algo]['pipe']( params['prompt'], **_params, - guidance_scale=params['guidance'], + guidance_scale=float(Decimal(params['guidance'])), num_inference_steps=int(params['step']), generator=torch.manual_seed(int(params['seed'])) ).images[0] @@ -246,7 +240,7 @@ async def open_dgpu_node( async def maybe_withdraw_all(): logging.info('maybe_withdraw_all') - balance = get_worker_balance() + balance = await get_worker_balance() if not balance: return @@ -323,7 +317,7 @@ async def open_dgpu_node( try: while True: if auto_withdraw: - maybe_withdraw_all() + await maybe_withdraw_all() queue = await get_work_requests_last_hour() @@ -371,6 +365,7 @@ async def open_dgpu_node( break except BaseException as e: + traceback.print_exc() await cancel_work(rid, str(e)) break diff --git a/skynet/frontend/__init__.py b/skynet/frontend/__init__.py index 290b6b3..6a2557e 100644 --- a/skynet/frontend/__init__.py +++ b/skynet/frontend/__init__.py @@ -1,11 +1,5 @@ #!/usr/bin/python -import json - -from typing import Union, Optional -from pathlib import Path -from contextlib import contextmanager as cm - from ..constants import * diff --git a/skynet/frontend/telegram.py b/skynet/frontend/telegram.py deleted file mode 100644 index e6e8b76..0000000 --- a/skynet/frontend/telegram.py +++ /dev/null @@ -1,566 +0,0 @@ -#!/usr/bin/python - -import io -import zlib -import random -import logging -import asyncio -import traceback - -from decimal import Decimal -from hashlib import sha256 -from datetime import datetime, timedelta - -import asks -import docker - -from PIL import Image -from leap.cleos import CLEOS -from leap.sugar import * -from leap.hyperion import HyperionAPI -from trio_asyncio import aio_as_trio -from telebot.types import ( - InputFile, InputMediaPhoto, InlineKeyboardButton, InlineKeyboardMarkup -) - -from telebot.types import CallbackQuery -from telebot.async_telebot import AsyncTeleBot, ExceptionHandler -from telebot.formatting import hlink - -from ..db import open_new_database, open_database_connection -from ..ipfs import open_ipfs_node, get_ipfs_file -from ..constants import * - -from . import * - - -class SKYExceptionHandler(ExceptionHandler): - - def handle(exception): - traceback.print_exc() - - -def build_redo_menu(): - btn_redo = InlineKeyboardButton("Redo", callback_data=json.dumps({'method': 'redo'})) - inline_keyboard = InlineKeyboardMarkup() - inline_keyboard.add(btn_redo) - return inline_keyboard - - -def prepare_metainfo_caption(tguser, worker: str, reward: str, meta: dict) -> str: - prompt = meta["prompt"] - if len(prompt) > 256: - prompt = prompt[:256] - - if tguser.username: - user = f'@{tguser.username}' - else: - user = f'{tguser.first_name} id: {tguser.id}' - - meta_str = f'by {user}\n' - meta_str += f'performed by {worker}\n' - meta_str += f'reward: {reward}\n' - - meta_str += f'prompt: {prompt}\n' - meta_str += f'seed: {meta["seed"]}\n' - meta_str += f'step: {meta["step"]}\n' - meta_str += f'guidance: {meta["guidance"]}\n' - if meta['strength']: - meta_str += f'strength: {meta["strength"]}\n' - meta_str += f'algo: {meta["algo"]}\n' - if meta['upscaler']: - meta_str += f'upscaler: {meta["upscaler"]}\n' - - meta_str += f'Made with Skynet v{VERSION}\n' - meta_str += f'JOIN THE SWARM: @skynetgpu' - return meta_str - - -def generate_reply_caption( - tguser, # telegram user - params: dict, - ipfs_hash: str, - tx_hash: str, - worker: str, - reward: str -): - ipfs_link = hlink( - 'Get your image on IPFS', - f'https://ipfs.ancap.tech/ipfs/{ipfs_hash}/image.png' - ) - explorer_link = hlink( - 'SKYNET Transaction Explorer', - f'https://skynet.ancap.tech/v2/explore/transaction/{tx_hash}' - ) - - meta_info = prepare_metainfo_caption(tguser, worker, reward, params) - - final_msg = '\n'.join([ - 'Worker finished your task!', - ipfs_link, - explorer_link, - f'PARAMETER INFO:\n{meta_info}' - ]) - - final_msg = '\n'.join([ - f'{ipfs_link}', - f'{explorer_link}', - f'{meta_info}' - ]) - - logging.info(final_msg) - - return final_msg - - -async def get_global_config(cleos): - return (await cleos.aget_table( - 'telos.gpu', 'telos.gpu', 'config'))[0] - -async def get_user_nonce(cleos, user: str): - return (await cleos.aget_table( - 'telos.gpu', 'telos.gpu', 'users', - index_position=1, - key_type='name', - lower_bound=user, - upper_bound=user - ))[0]['nonce'] - -async def work_request( - bot, cleos, hyperion, - message, user, chat, - account: str, - permission: str, - private_key: str, - params: dict, - file_id: str | None = None, - binary_data: str = '' -): - if params['seed'] == None: - params['seed'] = random.randint(0, 0xFFFFFFFF) - - sanitized_params = {} - for key, val in params.items(): - if isinstance(val, Decimal): - val = int(val) - - sanitized_params[key] = val - - body = json.dumps({ - 'method': 'diffuse', - 'params': sanitized_params - }) - request_time = datetime.now().isoformat() - - reward = '20.0000 GPU' - res = await cleos.a_push_action( - 'telos.gpu', - 'enqueue', - { - 'user': Name(account), - 'request_body': body, - 'binary_data': binary_data, - 'reward': asset_from_str(reward) - }, - account, private_key, permission=permission - ) - - if 'code' in res: - await bot.reply_to(message, json.dumps(res, indent=4)) - return - - out = collect_stdout(res) - - request_id, nonce = out.split(':') - - request_hash = sha256( - (nonce + body + binary_data).encode('utf-8')).hexdigest().upper() - - request_id = int(request_id) - logging.info(f'{request_id} enqueued.') - - config = await get_global_config(cleos) - - tx_hash = None - ipfs_hash = None - for i in range(60): - submits = await hyperion.aget_actions( - account=account, - filter='telos.gpu:submit', - sort='desc', - after=request_time - ) - actions = [ - action - for action in submits['actions'] - if action[ - 'act']['data']['request_hash'] == request_hash - ] - if len(actions) > 0: - tx_hash = actions[0]['trx_id'] - data = actions[0]['act']['data'] - ipfs_hash = data['ipfs_hash'] - worker = data['worker'] - logging.info('Found matching submit!') - break - - await asyncio.sleep(1) - - if not ipfs_hash: - await bot.reply_to(message, 'timeout processing request') - return - - # attempt to get the image and send it - ipfs_link = f'https://ipfs.ancap.tech/ipfs/{ipfs_hash}/image.png' - resp = await get_ipfs_file(ipfs_link) - - caption = generate_reply_caption( - user, params, ipfs_hash, tx_hash, worker, reward) - - if not resp or resp.status_code != 200: - logging.error(f'couldn\'t get ipfs hosted image at {ipfs_link}!') - await bot.reply_to( - message, - caption, - reply_markup=build_redo_menu(), - parse_mode='HTML' - ) - - else: - logging.info(f'succes! sending generated image') - if file_id: # img2img - await bot.send_media_group( - chat.id, - media=[ - InputMediaPhoto(file_id), - InputMediaPhoto( - resp.raw, - caption=caption, - parse_mode='HTML' - ) - ], - ) - - else: # txt2img - await bot.send_photo( - chat.id, - caption=caption, - photo=resp.raw, - reply_markup=build_redo_menu(), - parse_mode='HTML' - ) - - -async def run_skynet_telegram( - tg_token: str, - account: str, - permission: str, - node_url: str, - hyperion_url: str, - db_host: str, - db_user: str, - db_pass: str, - remote_ipfs_node: str, - key: str = None -): - cleos = CLEOS(None, None, url=node_url, remote=node_url) - hyperion = HyperionAPI(hyperion_url) - - logging.basicConfig(level=logging.INFO) - - bot = AsyncTeleBot(tg_token, exception_handler=SKYExceptionHandler) - logging.info(f'tg_token: {tg_token}') - - with open_ipfs_node() as ipfs_node: - ipfs_node.connect(remote_ipfs_node) - async with open_database_connection( - db_user, db_pass, db_host - ) as db_call: - - @bot.message_handler(commands=['help']) - async def send_help(message): - splt_msg = message.text.split(' ') - - if len(splt_msg) == 1: - await bot.reply_to(message, HELP_TEXT) - - else: - param = splt_msg[1] - if param in HELP_TOPICS: - await bot.reply_to(message, HELP_TOPICS[param]) - - else: - await bot.reply_to(message, HELP_UNKWNOWN_PARAM) - - @bot.message_handler(commands=['cool']) - async def send_cool_words(message): - await bot.reply_to(message, '\n'.join(COOL_WORDS)) - - async def _generic_txt2img(message_or_query): - if isinstance(message_or_query, CallbackQuery): - query = message_or_query - message = query.message - user = query.from_user - chat = query.message.chat - - else: - message = message_or_query - user = message.from_user - chat = message.chat - - reply_id = None - if chat.type == 'group' and chat.id == GROUP_ID: - reply_id = message.message_id - - prompt = ' '.join(message.text.split(' ')[1:]) - - if len(prompt) == 0: - await bot.reply_to(message, 'Empty text prompt ignored.') - return - - logging.info(f'mid: {message.id}') - - user_row = await db_call('get_or_create_user', user.id) - user_config = {**user_row} - del user_config['id'] - - params = { - 'prompt': prompt, - **user_config - } - - await db_call( - 'update_user_stats', user.id, 'txt2img', last_prompt=prompt) - - ec = await work_request( - bot, cleos, hyperion, - message, user, chat, - account, permission, key, params - ) - - if ec == 0: - await db_call('increment_generated', user.id) - - async def _generic_img2img(message_or_query): - if isinstance(message_or_query, CallbackQuery): - query = message_or_query - message = query.message - user = query.from_user - chat = query.message.chat - - else: - message = message_or_query - user = message.from_user - chat = message.chat - - reply_id = None - if chat.type == 'group' and chat.id == GROUP_ID: - reply_id = message.message_id - - if not message.caption.startswith('/img2img'): - await bot.reply_to( - message, - 'For image to image you need to add /img2img to the beggining of your caption' - ) - return - - prompt = ' '.join(message.caption.split(' ')[1:]) - - if len(prompt) == 0: - await bot.reply_to(message, 'Empty text prompt ignored.') - return - - file_id = message.photo[-1].file_id - file_path = (await bot.get_file(file_id)).file_path - image_raw = await bot.download_file(file_path) - - with Image.open(io.BytesIO(image_raw)) as image: - w, h = image.size - - if w > 512 or h > 512: - logging.warning(f'user sent img of size {image.size}') - image.thumbnail((512, 512)) - logging.warning(f'resized it to {image.size}') - - image.save(f'ipfs-docker-staging/image.png', format='PNG') - - ipfs_hash = ipfs_node.add('image.png') - ipfs_node.pin(ipfs_hash) - - logging.info(f'published input image {ipfs_hash} on ipfs') - - logging.info(f'mid: {message.id}') - - user_row = await db_call('get_or_create_user', user.id) - user_config = {**user_row} - del user_config['id'] - - params = { - 'prompt': prompt, - **user_config - } - - await db_call( - 'update_user_stats', - user.id, - 'img2img', - last_file=file_id, - last_prompt=prompt, - last_binary=ipfs_hash - ) - - ec = await work_request( - bot, cleos, hyperion, - message, user, chat, - account, permission, key, params, - file_id=file_id, - binary_data=ipfs_hash - ) - - if ec == 0: - await db_call('increment_generated', user.id) - - @bot.message_handler(commands=['txt2img']) - async def send_txt2img(message): - await _generic_txt2img(message) - - @bot.message_handler(func=lambda message: True, content_types=[ - 'photo', 'document']) - async def send_img2img(message): - await _generic_img2img(message) - - @bot.message_handler(commands=['img2img']) - async def img2img_missing_image(message): - await bot.reply_to( - message, - 'seems you tried to do an img2img command without sending image' - ) - - async def _redo(message_or_query): - if isinstance(message_or_query, CallbackQuery): - query = message_or_query - message = query.message - user = query.from_user - chat = query.message.chat - - else: - message = message_or_query - user = message.from_user - chat = message.chat - - method = await db_call('get_last_method_of', user.id) - prompt = await db_call('get_last_prompt_of', user.id) - - file_id = None - binary = '' - if method == 'img2img': - file_id = await db_call('get_last_file_of', user.id) - binary = await db_call('get_last_binary_of', user.id) - - if not prompt: - await bot.reply_to( - message, - 'no last prompt found, do a txt2img cmd first!' - ) - return - - - user_row = await db_call('get_or_create_user', user.id) - user_config = {**user_row} - del user_config['id'] - - params = { - 'prompt': prompt, - **user_config - } - - await work_request( - bot, cleos, hyperion, - message, user, chat, - account, permission, key, params, - file_id=file_id, - binary_data=binary - ) - - @bot.message_handler(commands=['queue']) - async def queue(message): - an_hour_ago = datetime.now() - timedelta(hours=1) - queue = await cleos.aget_table( - 'telos.gpu', 'telos.gpu', 'queue', - index_position=2, - key_type='i64', - sort='desc', - lower_bound=int(an_hour_ago.timestamp()) - ) - await bot.reply_to( - message, f'Total requests on skynet queue: {len(queue)}') - - @bot.message_handler(commands=['redo']) - async def redo(message): - await _redo(message) - - @bot.message_handler(commands=['config']) - async def set_config(message): - user = message.from_user.id - try: - attr, val, reply_txt = validate_user_config_request( - message.text) - - logging.info(f'user config update: {attr} to {val}') - await db_call('update_user_config', user, attr, val) - logging.info('done') - - except BaseException as e: - reply_txt = str(e) - - finally: - await bot.reply_to(message, reply_txt) - - @bot.message_handler(commands=['stats']) - async def user_stats(message): - user = message.from_user.id - - await db_call('get_or_create_user', user) - generated, joined, role = await db_call('get_user_stats', user) - - stats_str = f'generated: {generated}\n' - stats_str += f'joined: {joined}\n' - stats_str += f'role: {role}\n' - - await bot.reply_to( - message, stats_str) - - @bot.message_handler(commands=['donate']) - async def donation_info(message): - await bot.reply_to( - message, DONATION_INFO) - - @bot.message_handler(commands=['say']) - async def say(message): - chat = message.chat - user = message.from_user - - if (chat.type == 'group') or (user.id != 383385940): - return - - await bot.send_message(GROUP_ID, message.text[4:]) - - @bot.message_handler(func=lambda message: True) - async def echo_message(message): - if message.text[0] == '/': - await bot.reply_to(message, UNKNOWN_CMD_TEXT) - - @bot.callback_query_handler(func=lambda call: True) - async def callback_query(call): - msg = json.loads(call.data) - logging.info(call.data) - method = msg.get('method') - match method: - case 'redo': - await _redo(call) - - try: - await bot.infinity_polling() - - except KeyboardInterrupt: - ... diff --git a/skynet/frontend/telegram/__init__.py b/skynet/frontend/telegram/__init__.py new file mode 100644 index 0000000..ecd66fd --- /dev/null +++ b/skynet/frontend/telegram/__init__.py @@ -0,0 +1,274 @@ +#!/usr/bin/python + +import random +import logging +import asyncio + +from decimal import Decimal +from hashlib import sha256 +from datetime import datetime +from contextlib import ExitStack, AsyncExitStack +from contextlib import asynccontextmanager as acm + +from leap.cleos import CLEOS +from leap.sugar import Name, asset_from_str, collect_stdout +from leap.hyperion import HyperionAPI +from telebot.asyncio_helper import ApiTelegramException +from telebot.types import InputMediaPhoto + +from telebot.types import CallbackQuery +from telebot.async_telebot import AsyncTeleBot + +from skynet.db import open_new_database, open_database_connection +from skynet.ipfs import open_ipfs_node, get_ipfs_file +from skynet.constants import * + +from . import * + +from .utils import * +from .handlers import create_handler_context + + +class SkynetTelegramFrontend: + + def __init__( + self, + token: str, + account: str, + permission: str, + node_url: str, + hyperion_url: str, + db_host: str, + db_user: str, + db_pass: str, + remote_ipfs_node: str, + key: str + ): + self.token = token + self.account = account + self.permission = permission + self.node_url = node_url + self.hyperion_url = hyperion_url + self.db_host = db_host + self.db_user = db_user + self.db_pass = db_pass + self.remote_ipfs_node = remote_ipfs_node + self.key = key + + self.bot = AsyncTeleBot(token, exception_handler=SKYExceptionHandler) + self.cleos = CLEOS(None, None, url=node_url, remote=node_url) + self.hyperion = HyperionAPI(hyperion_url) + + self._exit_stack = ExitStack() + self._async_exit_stack = AsyncExitStack() + + async def start(self): + self.ipfs_node = self._exit_stack.enter_context( + open_ipfs_node()) + + self.ipfs_node.connect(self.remote_ipfs_node) + logging.info( + f'connected to remote ipfs node: {self.remote_ipfs_node}') + + self.db_call = await self._async_exit_stack.enter_async_context( + open_database_connection( + self.db_user, self.db_pass, self.db_host)) + + create_handler_context(self) + + async def stop(self): + await self._async_exit_stack.aclose() + self._exit_stack.close() + + @acm + async def open(self): + await self.start() + yield self + await self.stop() + + async def update_status_message( + self, status_msg, new_text: str, **kwargs + ): + await self.db_call( + 'update_user_request_by_sid', status_msg.id, new_text) + return await self.bot.edit_message_text( + new_text, + chat_id=status_msg.chat.id, + message_id=status_msg.id, + **kwargs + ) + + async def append_status_message( + self, status_msg, add_text: str, **kwargs + ): + request = await self.db_call('get_user_request_by_sid', status_msg.id) + await self.update_status_message( + status_msg, + request['status'] + add_text, + **kwargs + ) + + async def work_request( + self, + user, + status_msg, + method: str, + params: dict, + file_id: str | None = None, + binary_data: str = '' + ): + if params['seed'] == None: + params['seed'] = random.randint(0, 0xFFFFFFFF) + + sanitized_params = {} + for key, val in params.items(): + if isinstance(val, Decimal): + val = str(val) + + sanitized_params[key] = val + + body = json.dumps({ + 'method': 'diffuse', + 'params': sanitized_params + }) + request_time = datetime.now().isoformat() + + await self.update_status_message( + status_msg, + f'processing a \'{method}\' request by {tg_user_pretty(user)}\n' + f'[{timestamp_pretty()}] broadcasting transaction to chain...', + parse_mode='HTML' + ) + + reward = '20.0000 GPU' + res = await self.cleos.a_push_action( + 'telos.gpu', + 'enqueue', + { + 'user': Name(self.account), + 'request_body': body, + 'binary_data': binary_data, + 'reward': asset_from_str(reward) + }, + self.account, self.key, permission=self.permission + ) + + if 'code' in res or 'statusCode' in res: + logging.error(json.dumps(res, indent=4)) + await self.update_status_message( + status_msg, + 'skynet has suffered an internal error trying to fill this request') + return + + enqueue_tx_id = res['transaction_id'] + enqueue_tx_link = hlink( + 'Your request on Skynet Explorer', + f'https://skynet.ancap.tech/v2/explore/transaction/{enqueue_tx_id}' + ) + + await self.append_status_message( + status_msg, + f' broadcasted!\n' + f'{enqueue_tx_link}\n' + f'[{timestamp_pretty()}] workers are processing request...', + parse_mode='HTML' + ) + + out = collect_stdout(res) + + request_id, nonce = out.split(':') + + request_hash = sha256( + (nonce + body + binary_data).encode('utf-8')).hexdigest().upper() + + request_id = int(request_id) + + logging.info(f'{request_id} enqueued.') + + tx_hash = None + ipfs_hash = None + for i in range(60): + submits = await self.hyperion.aget_actions( + account=self.account, + filter='telos.gpu:submit', + sort='desc', + after=request_time + ) + actions = [ + action + for action in submits['actions'] + if action[ + 'act']['data']['request_hash'] == request_hash + ] + if len(actions) > 0: + tx_hash = actions[0]['trx_id'] + data = actions[0]['act']['data'] + ipfs_hash = data['ipfs_hash'] + worker = data['worker'] + logging.info('Found matching submit!') + break + + await asyncio.sleep(1) + + if not ipfs_hash: + await self.update_status_message( + status_msg, + '\n[{timestamp_pretty()}] timeout processing request', + parse_mode='HTML' + ) + return + + tx_link = hlink( + 'Your result on Skynet Explorer', + f'https://skynet.ancap.tech/v2/explore/transaction/{tx_hash}' + ) + + await self.append_status_message( + status_msg, + f' request processed!\n' + f'{tx_link}\n' + f'[{timestamp_pretty()}] trying to download image...\n', + parse_mode='HTML' + ) + + # attempt to get the image and send it + ipfs_link = f'https://ipfs.ancap.tech/ipfs/{ipfs_hash}/image.png' + resp = await get_ipfs_file(ipfs_link) + + caption = generate_reply_caption( + user, params, ipfs_hash, tx_hash, worker, reward) + + if not resp or resp.status_code != 200: + logging.error(f'couldn\'t get ipfs hosted image at {ipfs_link}!') + await self.update_status_message( + status_msg, + caption, + reply_markup=build_redo_menu(), + parse_mode='HTML' + ) + + else: + logging.info(f'success! sending generated image') + await self.bot.delete_message( + chat_id=status_msg.chat.id, message_id=status_msg.id) + if file_id: # img2img + await self.bot.send_media_group( + status_msg.chat.id, + media=[ + InputMediaPhoto(file_id), + InputMediaPhoto( + resp.raw, + caption=caption, + parse_mode='HTML' + ) + ], + ) + + else: # txt2img + await self.bot.send_photo( + status_msg.chat.id, + caption=caption, + photo=resp.raw, + reply_markup=build_redo_menu(), + parse_mode='HTML' + ) diff --git a/skynet/frontend/telegram/handlers.py b/skynet/frontend/telegram/handlers.py new file mode 100644 index 0000000..7e77880 --- /dev/null +++ b/skynet/frontend/telegram/handlers.py @@ -0,0 +1,345 @@ +#!/usr/bin/python + +import io +import json +import logging + +from datetime import datetime, timedelta + +from PIL import Image +from telebot.types import CallbackQuery, Message + +from skynet.frontend import validate_user_config_request +from skynet.constants import * + + +def create_handler_context(frontend: 'SkynetTelegramFrontend'): + + bot = frontend.bot + cleos = frontend.cleos + db_call = frontend.db_call + work_request = frontend.work_request + + ipfs_node = frontend.ipfs_node + + # generic / simple handlers + + @bot.message_handler(commands=['help']) + async def send_help(message): + splt_msg = message.text.split(' ') + + if len(splt_msg) == 1: + await bot.reply_to(message, HELP_TEXT) + + else: + param = splt_msg[1] + if param in HELP_TOPICS: + await bot.reply_to(message, HELP_TOPICS[param]) + + else: + await bot.reply_to(message, HELP_UNKWNOWN_PARAM) + + @bot.message_handler(commands=['cool']) + async def send_cool_words(message): + await bot.reply_to(message, '\n'.join(COOL_WORDS)) + + @bot.message_handler(commands=['queue']) + async def queue(message): + an_hour_ago = datetime.now() - timedelta(hours=1) + queue = await cleos.aget_table( + 'telos.gpu', 'telos.gpu', 'queue', + index_position=2, + key_type='i64', + sort='desc', + lower_bound=int(an_hour_ago.timestamp()) + ) + await bot.reply_to( + message, f'Total requests on skynet queue: {len(queue)}') + + + @bot.message_handler(commands=['config']) + async def set_config(message): + user = message.from_user.id + try: + attr, val, reply_txt = validate_user_config_request( + message.text) + + logging.info(f'user config update: {attr} to {val}') + await db_call('update_user_config', user, attr, val) + logging.info('done') + + except BaseException as e: + reply_txt = str(e) + + finally: + await bot.reply_to(message, reply_txt) + + @bot.message_handler(commands=['stats']) + async def user_stats(message): + user = message.from_user.id + + await db_call('get_or_create_user', user) + generated, joined, role = await db_call('get_user_stats', user) + + stats_str = f'generated: {generated}\n' + stats_str += f'joined: {joined}\n' + stats_str += f'role: {role}\n' + + await bot.reply_to( + message, stats_str) + + @bot.message_handler(commands=['donate']) + async def donation_info(message): + await bot.reply_to( + message, DONATION_INFO) + + @bot.message_handler(commands=['say']) + async def say(message): + chat = message.chat + user = message.from_user + + if (chat.type == 'group') or (user.id != 383385940): + return + + await bot.send_message(GROUP_ID, message.text[4:]) + + + # generic txt2img handler + + async def _generic_txt2img(message_or_query): + if isinstance(message_or_query, CallbackQuery): + query = message_or_query + message = query.message + user = query.from_user + chat = query.message.chat + + else: + message = message_or_query + user = message.from_user + chat = message.chat + + reply_id = None + if chat.type == 'group' and chat.id == GROUP_ID: + reply_id = message.message_id + + user_row = await db_call('get_or_create_user', user.id) + + # init new msg + init_msg = 'started processing txt2img request...' + status_msg = await bot.reply_to(message, init_msg) + await db_call( + 'new_user_request', user.id, message.id, status_msg.id, status=init_msg) + + prompt = ' '.join(message.text.split(' ')[1:]) + + if len(prompt) == 0: + await bot.edit_message_text( + 'Empty text prompt ignored.', + chat_id=status_msg.chat.id, + message_id=status_msg.id + ) + await db_call('update_user_request', status_msg.id, 'Empty text prompt ignored.') + return + + logging.info(f'mid: {message.id}') + + user_config = {**user_row} + del user_config['id'] + + params = { + 'prompt': prompt, + **user_config + } + + await db_call( + 'update_user_stats', user.id, 'txt2img', last_prompt=prompt) + + ec = await work_request(user, status_msg, 'txt2img', params) + + if ec == 0: + await db_call('increment_generated', user.id) + + + # generic img2img handler + + async def _generic_img2img(message_or_query): + if isinstance(message_or_query, CallbackQuery): + query = message_or_query + message = query.message + user = query.from_user + chat = query.message.chat + + else: + message = message_or_query + user = message.from_user + chat = message.chat + + reply_id = None + if chat.type == 'group' and chat.id == GROUP_ID: + reply_id = message.message_id + + user_row = await db_call('get_or_create_user', user.id) + + # init new msg + init_msg = 'started processing txt2img request...' + status_msg = await bot.reply_to(message, init_msg) + await db_call( + 'new_user_request', user.id, message.id, status_msg.id, status=init_msg) + + if not message.caption.startswith('/img2img'): + await bot.reply_to( + message, + 'For image to image you need to add /img2img to the beggining of your caption' + ) + return + + prompt = ' '.join(message.caption.split(' ')[1:]) + + if len(prompt) == 0: + await bot.reply_to(message, 'Empty text prompt ignored.') + return + + file_id = message.photo[-1].file_id + file_path = (await bot.get_file(file_id)).file_path + image_raw = await bot.download_file(file_path) + + with Image.open(io.BytesIO(image_raw)) as image: + w, h = image.size + + if w > 512 or h > 512: + logging.warning(f'user sent img of size {image.size}') + image.thumbnail((512, 512)) + logging.warning(f'resized it to {image.size}') + + image.save(f'ipfs-docker-staging/image.png', format='PNG') + + ipfs_hash = ipfs_node.add('image.png') + ipfs_node.pin(ipfs_hash) + + logging.info(f'published input image {ipfs_hash} on ipfs') + + logging.info(f'mid: {message.id}') + + user_config = {**user_row} + del user_config['id'] + + params = { + 'prompt': prompt, + **user_config + } + + await db_call( + 'update_user_stats', + user.id, + 'img2img', + last_file=file_id, + last_prompt=prompt, + last_binary=ipfs_hash + ) + + ec = await work_request( + user, status_msg, 'img2img', params, + file_id=file_id, + binary_data=ipfs_hash + ) + + if ec == 0: + await db_call('increment_generated', user.id) + + + # generic redo handler + + async def _redo(message_or_query): + is_query = False + if isinstance(message_or_query, CallbackQuery): + is_query = True + query = message_or_query + message = query.message + user = query.from_user + chat = query.message.chat + + elif isinstance(message_or_query, Message): + message = message_or_query + user = message.from_user + chat = message.chat + + init_msg = 'started processing redo request...' + if is_query: + status_msg = await bot.send_message(chat.id, init_msg) + + else: + status_msg = await bot.reply_to(message, init_msg) + + method = await db_call('get_last_method_of', user.id) + prompt = await db_call('get_last_prompt_of', user.id) + + file_id = None + binary = '' + if method == 'img2img': + file_id = await db_call('get_last_file_of', user.id) + binary = await db_call('get_last_binary_of', user.id) + + if not prompt: + await bot.reply_to( + message, + 'no last prompt found, do a txt2img cmd first!' + ) + return + + + user_row = await db_call('get_or_create_user', user.id) + await db_call( + 'new_user_request', user.id, message.id, status_msg.id, status=init_msg) + user_config = {**user_row} + del user_config['id'] + + params = { + 'prompt': prompt, + **user_config + } + + await work_request( + user, status_msg, 'redo', params, + file_id=file_id, + binary_data=binary + ) + + + # "proxy" handlers just request routers + + @bot.message_handler(commands=['txt2img']) + async def send_txt2img(message): + await _generic_txt2img(message) + + @bot.message_handler(func=lambda message: True, content_types=[ + 'photo', 'document']) + async def send_img2img(message): + await _generic_img2img(message) + + @bot.message_handler(commands=['img2img']) + async def img2img_missing_image(message): + await bot.reply_to( + message, + 'seems you tried to do an img2img command without sending image' + ) + + @bot.message_handler(commands=['redo']) + async def redo(message): + await _redo(message) + + @bot.callback_query_handler(func=lambda call: True) + async def callback_query(call): + msg = json.loads(call.data) + logging.info(call.data) + method = msg.get('method') + match method: + case 'redo': + await _redo(call) + + + # catch all handler for things we dont support + + @bot.message_handler(func=lambda message: True) + async def echo_message(message): + if message.text[0] == '/': + await bot.reply_to(message, UNKNOWN_CMD_TEXT) diff --git a/skynet/frontend/telegram/utils.py b/skynet/frontend/telegram/utils.py new file mode 100644 index 0000000..682cd02 --- /dev/null +++ b/skynet/frontend/telegram/utils.py @@ -0,0 +1,113 @@ +#!/usr/bin/python + +import json +import logging +import traceback + +from datetime import datetime, timezone + +from telebot.types import InlineKeyboardButton, InlineKeyboardMarkup +from telebot.async_telebot import ExceptionHandler +from telebot.formatting import hlink + +from skynet.constants import * + + +def timestamp_pretty(): + return datetime.now(timezone.utc).strftime('%H:%M:%S') + + +def tg_user_pretty(tguser): + if tguser.username: + return f'@{tguser.username}' + else: + return f'{tguser.first_name} id: {tguser.id}' + + +class SKYExceptionHandler(ExceptionHandler): + + def handle(exception): + traceback.print_exc() + + +def build_redo_menu(): + btn_redo = InlineKeyboardButton("Redo", callback_data=json.dumps({'method': 'redo'})) + inline_keyboard = InlineKeyboardMarkup() + inline_keyboard.add(btn_redo) + return inline_keyboard + + +def prepare_metainfo_caption(tguser, worker: str, reward: str, meta: dict) -> str: + prompt = meta["prompt"] + if len(prompt) > 256: + prompt = prompt[:256] + + + meta_str = f'by {tg_user_pretty(tguser)}\n' + meta_str += f'performed by {worker}\n' + meta_str += f'reward: {reward}\n' + + meta_str += f'prompt: {prompt}\n' + meta_str += f'seed: {meta["seed"]}\n' + meta_str += f'step: {meta["step"]}\n' + meta_str += f'guidance: {meta["guidance"]}\n' + if meta['strength']: + meta_str += f'strength: {meta["strength"]}\n' + meta_str += f'algo: {meta["algo"]}\n' + if meta['upscaler']: + meta_str += f'upscaler: {meta["upscaler"]}\n' + + meta_str += f'Made with Skynet v{VERSION}\n' + meta_str += f'JOIN THE SWARM: @skynetgpu' + return meta_str + + +def generate_reply_caption( + tguser, # telegram user + params: dict, + ipfs_hash: str, + tx_hash: str, + worker: str, + reward: str +): + ipfs_link = hlink( + 'Get your image on IPFS', + f'https://ipfs.ancap.tech/ipfs/{ipfs_hash}/image.png' + ) + explorer_link = hlink( + 'SKYNET Transaction Explorer', + f'https://skynet.ancap.tech/v2/explore/transaction/{tx_hash}' + ) + + meta_info = prepare_metainfo_caption(tguser, worker, reward, params) + + final_msg = '\n'.join([ + 'Worker finished your task!', + ipfs_link, + explorer_link, + f'PARAMETER INFO:\n{meta_info}' + ]) + + final_msg = '\n'.join([ + f'{ipfs_link}', + f'{explorer_link}', + f'{meta_info}' + ]) + + logging.info(final_msg) + + return final_msg + + +async def get_global_config(cleos): + return (await cleos.aget_table( + 'telos.gpu', 'telos.gpu', 'config'))[0] + +async def get_user_nonce(cleos, user: str): + return (await cleos.aget_table( + 'telos.gpu', 'telos.gpu', 'users', + index_position=1, + key_type='name', + lower_bound=user, + upper_bound=user + ))[0]['nonce']