diff --git a/skynet/cli.py b/skynet/cli.py
index 9c1e6f4..8a2d8ac 100644
--- a/skynet/cli.py
+++ b/skynet/cli.py
@@ -26,7 +26,7 @@ from .ipfs import open_ipfs_node
from .config import *
from .nodeos import open_cleos, open_nodeos
from .constants import *
-from .frontend.telegram import run_skynet_telegram
+from .frontend.telegram import SkynetTelegramFrontend
@click.group()
@@ -318,7 +318,7 @@ def nodeos():
@click.option(
'--ipfs-url', '-n', default=DEFAULT_IPFS_REMOTE)
@click.option(
- '--algos', '-A', default=json.dumps(['midj', 'ink']))
+ '--algos', '-A', default=json.dumps(['midj']))
def dgpu(
loglevel: str,
account: str,
@@ -383,8 +383,9 @@ def telegram(
key, account, permission)
_, _, tg_token, cfg = init_env_from_config()
- asyncio.run(
- run_skynet_telegram(
+
+ async def _async_main():
+ frontend = SkynetTelegramFrontend(
tg_token,
account,
permission,
@@ -393,7 +394,13 @@ def telegram(
db_host, db_user, db_pass,
remote_ipfs_node=ipfs_url,
key=key
- ))
+ )
+
+ async with frontend.open():
+ await frontend.bot.infinity_polling()
+
+
+ asyncio.run(_async_main())
class IPFSHTTP:
diff --git a/skynet/db/functions.py b/skynet/db/functions.py
index 3a8b760..8a56484 100644
--- a/skynet/db/functions.py
+++ b/skynet/db/functions.py
@@ -6,7 +6,6 @@ import string
import logging
import importlib
-from typing import Optional
from datetime import datetime
from contextlib import contextmanager as cm
from contextlib import asynccontextmanager as acm
@@ -15,7 +14,6 @@ import docker
import asyncpg
import psycopg2
-from asyncpg.exceptions import UndefinedColumnError
from psycopg2.extensions import ISOLATION_LEVEL_AUTOCOMMIT
from ..constants import *
@@ -49,6 +47,17 @@ CREATE TABLE IF NOT EXISTS skynet.user_config(
ALTER TABLE skynet.user_config
ADD FOREIGN KEY(id)
REFERENCES skynet.user(id);
+
+CREATE TABLE IF NOT EXISTS skynet.user_requests(
+ id SERIAL NOT NULL,
+ user_id SERIAL NOT NULL,
+ sent TIMESTAMP NOT NULL,
+ status TEXT NOT NULL,
+ status_msg SERIAL PRIMARY KEY NOT NULL
+);
+ALTER TABLE skynet.user_requests
+ ADD FOREIGN KEY(user_id)
+ REFERENCES skynet.user(id);
'''
@@ -199,6 +208,53 @@ async def get_last_binary_of(conn, user: int):
return await stmt.fetchval(user)
+async def get_user_request(conn, mid: int):
+ stmt = await conn.prepare(
+ 'SELECT * FROM skynet.user_requests WHERE id = $1')
+ return await stmt.fetch(mid)
+
+async def get_user_request_by_sid(conn, sid: int):
+ stmt = await conn.prepare(
+ 'SELECT * FROM skynet.user_requests WHERE status_msg = $1')
+ return (await stmt.fetch(sid))[0]
+
+async def new_user_request(
+ conn, user: int, mid: int,
+ status_msg: int,
+ status: str = 'started processing request...'
+):
+ date = datetime.utcnow()
+ async with conn.transaction():
+ stmt = await conn.prepare('''
+ INSERT INTO skynet.user_requests(
+ id, user_id, sent, status, status_msg
+ )
+
+ VALUES($1, $2, $3, $4, $5)
+ ''')
+ await stmt.fetch(mid, user, date, status, status_msg)
+
+async def update_user_request(
+ conn, mid: int, status: str
+):
+ stmt = await conn.prepare(f'''
+ UPDATE skynet.user_requests
+ SET status = $2
+ WHERE id = $1
+ ''')
+ await stmt.fetch(mid, status)
+
+async def update_user_request_by_sid(
+ conn, sid: int, status: str
+):
+ stmt = await conn.prepare(f'''
+ UPDATE skynet.user_requests
+ SET status = $2
+ WHERE status_msg = $1
+ ''')
+ await stmt.fetch(sid, status)
+
+
async def new_user(conn, uid: int):
if await get_user(conn, uid):
raise ValueError('User already present on db')
diff --git a/skynet/dgpu.py b/skynet/dgpu.py
index e3b882d..763fdd1 100644
--- a/skynet/dgpu.py
+++ b/skynet/dgpu.py
@@ -4,8 +4,8 @@ import gc
import io
import json
import time
-import random
import logging
+import traceback
from PIL import Image
from typing import List, Optional
@@ -15,19 +15,13 @@ import trio
import asks
import torch
-from leap.cleos import CLEOS, default_nodeos_image
+from leap.cleos import CLEOS
from leap.sugar import *
-from diffusers import (
- StableDiffusionPipeline,
- StableDiffusionImg2ImgPipeline,
- EulerAncestralDiscreteScheduler
-)
from realesrgan import RealESRGANer
from basicsr.archs.rrdbnet_arch import RRDBNet
-from diffusers.models import UNet2DConditionModel
-from .ipfs import IPFSDocker, open_ipfs_node, get_ipfs_file
+from .ipfs import open_ipfs_node, get_ipfs_file
from .utils import *
from .constants import *
@@ -125,7 +119,7 @@ async def open_dgpu_node(
logging.info(f'binext: {len(binext) if binext else 0} bytes')
if binext:
_params['image'] = image
- _params['strength'] = params['strength']
+ _params['strength'] = float(Decimal(params['strength']))
else:
_params['width'] = int(params['width'])
@@ -135,7 +129,7 @@ async def open_dgpu_node(
image = models[algo]['pipe'](
params['prompt'],
**_params,
- guidance_scale=params['guidance'],
+ guidance_scale=float(Decimal(params['guidance'])),
num_inference_steps=int(params['step']),
generator=torch.manual_seed(int(params['seed']))
).images[0]
@@ -246,7 +240,7 @@ async def open_dgpu_node(
async def maybe_withdraw_all():
logging.info('maybe_withdraw_all')
- balance = get_worker_balance()
+ balance = await get_worker_balance()
if not balance:
return
@@ -323,7 +317,7 @@ async def open_dgpu_node(
try:
while True:
if auto_withdraw:
- maybe_withdraw_all()
+ await maybe_withdraw_all()
queue = await get_work_requests_last_hour()
@@ -371,6 +365,7 @@ async def open_dgpu_node(
break
except BaseException as e:
+ traceback.print_exc()
await cancel_work(rid, str(e))
break
diff --git a/skynet/frontend/__init__.py b/skynet/frontend/__init__.py
index 290b6b3..6a2557e 100644
--- a/skynet/frontend/__init__.py
+++ b/skynet/frontend/__init__.py
@@ -1,11 +1,5 @@
#!/usr/bin/python
-import json
-
-from typing import Union, Optional
-from pathlib import Path
-from contextlib import contextmanager as cm
-
from ..constants import *
diff --git a/skynet/frontend/telegram.py b/skynet/frontend/telegram.py
deleted file mode 100644
index e6e8b76..0000000
--- a/skynet/frontend/telegram.py
+++ /dev/null
@@ -1,566 +0,0 @@
-#!/usr/bin/python
-
-import io
-import zlib
-import random
-import logging
-import asyncio
-import traceback
-
-from decimal import Decimal
-from hashlib import sha256
-from datetime import datetime, timedelta
-
-import asks
-import docker
-
-from PIL import Image
-from leap.cleos import CLEOS
-from leap.sugar import *
-from leap.hyperion import HyperionAPI
-from trio_asyncio import aio_as_trio
-from telebot.types import (
- InputFile, InputMediaPhoto, InlineKeyboardButton, InlineKeyboardMarkup
-)
-
-from telebot.types import CallbackQuery
-from telebot.async_telebot import AsyncTeleBot, ExceptionHandler
-from telebot.formatting import hlink
-
-from ..db import open_new_database, open_database_connection
-from ..ipfs import open_ipfs_node, get_ipfs_file
-from ..constants import *
-
-from . import *
-
-
-class SKYExceptionHandler(ExceptionHandler):
-
- def handle(exception):
- traceback.print_exc()
-
-
-def build_redo_menu():
- btn_redo = InlineKeyboardButton("Redo", callback_data=json.dumps({'method': 'redo'}))
- inline_keyboard = InlineKeyboardMarkup()
- inline_keyboard.add(btn_redo)
- return inline_keyboard
-
-
-def prepare_metainfo_caption(tguser, worker: str, reward: str, meta: dict) -> str:
- prompt = meta["prompt"]
- if len(prompt) > 256:
- prompt = prompt[:256]
-
- if tguser.username:
- user = f'@{tguser.username}'
- else:
- user = f'{tguser.first_name} id: {tguser.id}'
-
- meta_str = f'by {user}\n'
- meta_str += f'performed by {worker}\n'
- meta_str += f'reward: {reward}\n'
-
- meta_str += f'prompt:
{prompt}\n'
- meta_str += f'seed: {meta["seed"]}
\n'
- meta_str += f'step: {meta["step"]}
\n'
- meta_str += f'guidance: {meta["guidance"]}
\n'
- if meta['strength']:
- meta_str += f'strength: {meta["strength"]}
\n'
- meta_str += f'algo: {meta["algo"]}
\n'
- if meta['upscaler']:
- meta_str += f'upscaler: {meta["upscaler"]}
\n'
-
- meta_str += f'Made with Skynet v{VERSION}\n'
- meta_str += f'JOIN THE SWARM: @skynetgpu'
- return meta_str
-
-
-def generate_reply_caption(
- tguser, # telegram user
- params: dict,
- ipfs_hash: str,
- tx_hash: str,
- worker: str,
- reward: str
-):
- ipfs_link = hlink(
- 'Get your image on IPFS',
- f'https://ipfs.ancap.tech/ipfs/{ipfs_hash}/image.png'
- )
- explorer_link = hlink(
- 'SKYNET Transaction Explorer',
- f'https://skynet.ancap.tech/v2/explore/transaction/{tx_hash}'
- )
-
- meta_info = prepare_metainfo_caption(tguser, worker, reward, params)
-
- final_msg = '\n'.join([
- 'Worker finished your task!',
- ipfs_link,
- explorer_link,
- f'PARAMETER INFO:\n{meta_info}'
- ])
-
- final_msg = '\n'.join([
- f'{ipfs_link}',
- f'{explorer_link}',
- f'{meta_info}'
- ])
-
- logging.info(final_msg)
-
- return final_msg
-
-
-async def get_global_config(cleos):
- return (await cleos.aget_table(
- 'telos.gpu', 'telos.gpu', 'config'))[0]
-
-async def get_user_nonce(cleos, user: str):
- return (await cleos.aget_table(
- 'telos.gpu', 'telos.gpu', 'users',
- index_position=1,
- key_type='name',
- lower_bound=user,
- upper_bound=user
- ))[0]['nonce']
-
-async def work_request(
- bot, cleos, hyperion,
- message, user, chat,
- account: str,
- permission: str,
- private_key: str,
- params: dict,
- file_id: str | None = None,
- binary_data: str = ''
-):
- if params['seed'] == None:
- params['seed'] = random.randint(0, 0xFFFFFFFF)
-
- sanitized_params = {}
- for key, val in params.items():
- if isinstance(val, Decimal):
- val = int(val)
-
- sanitized_params[key] = val
-
- body = json.dumps({
- 'method': 'diffuse',
- 'params': sanitized_params
- })
- request_time = datetime.now().isoformat()
-
- reward = '20.0000 GPU'
- res = await cleos.a_push_action(
- 'telos.gpu',
- 'enqueue',
- {
- 'user': Name(account),
- 'request_body': body,
- 'binary_data': binary_data,
- 'reward': asset_from_str(reward)
- },
- account, private_key, permission=permission
- )
-
- if 'code' in res:
- await bot.reply_to(message, json.dumps(res, indent=4))
- return
-
- out = collect_stdout(res)
-
- request_id, nonce = out.split(':')
-
- request_hash = sha256(
- (nonce + body + binary_data).encode('utf-8')).hexdigest().upper()
-
- request_id = int(request_id)
- logging.info(f'{request_id} enqueued.')
-
- config = await get_global_config(cleos)
-
- tx_hash = None
- ipfs_hash = None
- for i in range(60):
- submits = await hyperion.aget_actions(
- account=account,
- filter='telos.gpu:submit',
- sort='desc',
- after=request_time
- )
- actions = [
- action
- for action in submits['actions']
- if action[
- 'act']['data']['request_hash'] == request_hash
- ]
- if len(actions) > 0:
- tx_hash = actions[0]['trx_id']
- data = actions[0]['act']['data']
- ipfs_hash = data['ipfs_hash']
- worker = data['worker']
- logging.info('Found matching submit!')
- break
-
- await asyncio.sleep(1)
-
- if not ipfs_hash:
- await bot.reply_to(message, 'timeout processing request')
- return
-
- # attempt to get the image and send it
- ipfs_link = f'https://ipfs.ancap.tech/ipfs/{ipfs_hash}/image.png'
- resp = await get_ipfs_file(ipfs_link)
-
- caption = generate_reply_caption(
- user, params, ipfs_hash, tx_hash, worker, reward)
-
- if not resp or resp.status_code != 200:
- logging.error(f'couldn\'t get ipfs hosted image at {ipfs_link}!')
- await bot.reply_to(
- message,
- caption,
- reply_markup=build_redo_menu(),
- parse_mode='HTML'
- )
-
- else:
- logging.info(f'succes! sending generated image')
- if file_id: # img2img
- await bot.send_media_group(
- chat.id,
- media=[
- InputMediaPhoto(file_id),
- InputMediaPhoto(
- resp.raw,
- caption=caption,
- parse_mode='HTML'
- )
- ],
- )
-
- else: # txt2img
- await bot.send_photo(
- chat.id,
- caption=caption,
- photo=resp.raw,
- reply_markup=build_redo_menu(),
- parse_mode='HTML'
- )
-
-
-async def run_skynet_telegram(
- tg_token: str,
- account: str,
- permission: str,
- node_url: str,
- hyperion_url: str,
- db_host: str,
- db_user: str,
- db_pass: str,
- remote_ipfs_node: str,
- key: str = None
-):
- cleos = CLEOS(None, None, url=node_url, remote=node_url)
- hyperion = HyperionAPI(hyperion_url)
-
- logging.basicConfig(level=logging.INFO)
-
- bot = AsyncTeleBot(tg_token, exception_handler=SKYExceptionHandler)
- logging.info(f'tg_token: {tg_token}')
-
- with open_ipfs_node() as ipfs_node:
- ipfs_node.connect(remote_ipfs_node)
- async with open_database_connection(
- db_user, db_pass, db_host
- ) as db_call:
-
- @bot.message_handler(commands=['help'])
- async def send_help(message):
- splt_msg = message.text.split(' ')
-
- if len(splt_msg) == 1:
- await bot.reply_to(message, HELP_TEXT)
-
- else:
- param = splt_msg[1]
- if param in HELP_TOPICS:
- await bot.reply_to(message, HELP_TOPICS[param])
-
- else:
- await bot.reply_to(message, HELP_UNKWNOWN_PARAM)
-
- @bot.message_handler(commands=['cool'])
- async def send_cool_words(message):
- await bot.reply_to(message, '\n'.join(COOL_WORDS))
-
- async def _generic_txt2img(message_or_query):
- if isinstance(message_or_query, CallbackQuery):
- query = message_or_query
- message = query.message
- user = query.from_user
- chat = query.message.chat
-
- else:
- message = message_or_query
- user = message.from_user
- chat = message.chat
-
- reply_id = None
- if chat.type == 'group' and chat.id == GROUP_ID:
- reply_id = message.message_id
-
- prompt = ' '.join(message.text.split(' ')[1:])
-
- if len(prompt) == 0:
- await bot.reply_to(message, 'Empty text prompt ignored.')
- return
-
- logging.info(f'mid: {message.id}')
-
- user_row = await db_call('get_or_create_user', user.id)
- user_config = {**user_row}
- del user_config['id']
-
- params = {
- 'prompt': prompt,
- **user_config
- }
-
- await db_call(
- 'update_user_stats', user.id, 'txt2img', last_prompt=prompt)
-
- ec = await work_request(
- bot, cleos, hyperion,
- message, user, chat,
- account, permission, key, params
- )
-
- if ec == 0:
- await db_call('increment_generated', user.id)
-
- async def _generic_img2img(message_or_query):
- if isinstance(message_or_query, CallbackQuery):
- query = message_or_query
- message = query.message
- user = query.from_user
- chat = query.message.chat
-
- else:
- message = message_or_query
- user = message.from_user
- chat = message.chat
-
- reply_id = None
- if chat.type == 'group' and chat.id == GROUP_ID:
- reply_id = message.message_id
-
- if not message.caption.startswith('/img2img'):
- await bot.reply_to(
- message,
- 'For image to image you need to add /img2img to the beggining of your caption'
- )
- return
-
- prompt = ' '.join(message.caption.split(' ')[1:])
-
- if len(prompt) == 0:
- await bot.reply_to(message, 'Empty text prompt ignored.')
- return
-
- file_id = message.photo[-1].file_id
- file_path = (await bot.get_file(file_id)).file_path
- image_raw = await bot.download_file(file_path)
-
- with Image.open(io.BytesIO(image_raw)) as image:
- w, h = image.size
-
- if w > 512 or h > 512:
- logging.warning(f'user sent img of size {image.size}')
- image.thumbnail((512, 512))
- logging.warning(f'resized it to {image.size}')
-
- image.save(f'ipfs-docker-staging/image.png', format='PNG')
-
- ipfs_hash = ipfs_node.add('image.png')
- ipfs_node.pin(ipfs_hash)
-
- logging.info(f'published input image {ipfs_hash} on ipfs')
-
- logging.info(f'mid: {message.id}')
-
- user_row = await db_call('get_or_create_user', user.id)
- user_config = {**user_row}
- del user_config['id']
-
- params = {
- 'prompt': prompt,
- **user_config
- }
-
- await db_call(
- 'update_user_stats',
- user.id,
- 'img2img',
- last_file=file_id,
- last_prompt=prompt,
- last_binary=ipfs_hash
- )
-
- ec = await work_request(
- bot, cleos, hyperion,
- message, user, chat,
- account, permission, key, params,
- file_id=file_id,
- binary_data=ipfs_hash
- )
-
- if ec == 0:
- await db_call('increment_generated', user.id)
-
- @bot.message_handler(commands=['txt2img'])
- async def send_txt2img(message):
- await _generic_txt2img(message)
-
- @bot.message_handler(func=lambda message: True, content_types=[
- 'photo', 'document'])
- async def send_img2img(message):
- await _generic_img2img(message)
-
- @bot.message_handler(commands=['img2img'])
- async def img2img_missing_image(message):
- await bot.reply_to(
- message,
- 'seems you tried to do an img2img command without sending image'
- )
-
- async def _redo(message_or_query):
- if isinstance(message_or_query, CallbackQuery):
- query = message_or_query
- message = query.message
- user = query.from_user
- chat = query.message.chat
-
- else:
- message = message_or_query
- user = message.from_user
- chat = message.chat
-
- method = await db_call('get_last_method_of', user.id)
- prompt = await db_call('get_last_prompt_of', user.id)
-
- file_id = None
- binary = ''
- if method == 'img2img':
- file_id = await db_call('get_last_file_of', user.id)
- binary = await db_call('get_last_binary_of', user.id)
-
- if not prompt:
- await bot.reply_to(
- message,
- 'no last prompt found, do a txt2img cmd first!'
- )
- return
-
-
- user_row = await db_call('get_or_create_user', user.id)
- user_config = {**user_row}
- del user_config['id']
-
- params = {
- 'prompt': prompt,
- **user_config
- }
-
- await work_request(
- bot, cleos, hyperion,
- message, user, chat,
- account, permission, key, params,
- file_id=file_id,
- binary_data=binary
- )
-
- @bot.message_handler(commands=['queue'])
- async def queue(message):
- an_hour_ago = datetime.now() - timedelta(hours=1)
- queue = await cleos.aget_table(
- 'telos.gpu', 'telos.gpu', 'queue',
- index_position=2,
- key_type='i64',
- sort='desc',
- lower_bound=int(an_hour_ago.timestamp())
- )
- await bot.reply_to(
- message, f'Total requests on skynet queue: {len(queue)}')
-
- @bot.message_handler(commands=['redo'])
- async def redo(message):
- await _redo(message)
-
- @bot.message_handler(commands=['config'])
- async def set_config(message):
- user = message.from_user.id
- try:
- attr, val, reply_txt = validate_user_config_request(
- message.text)
-
- logging.info(f'user config update: {attr} to {val}')
- await db_call('update_user_config', user, attr, val)
- logging.info('done')
-
- except BaseException as e:
- reply_txt = str(e)
-
- finally:
- await bot.reply_to(message, reply_txt)
-
- @bot.message_handler(commands=['stats'])
- async def user_stats(message):
- user = message.from_user.id
-
- await db_call('get_or_create_user', user)
- generated, joined, role = await db_call('get_user_stats', user)
-
- stats_str = f'generated: {generated}\n'
- stats_str += f'joined: {joined}\n'
- stats_str += f'role: {role}\n'
-
- await bot.reply_to(
- message, stats_str)
-
- @bot.message_handler(commands=['donate'])
- async def donation_info(message):
- await bot.reply_to(
- message, DONATION_INFO)
-
- @bot.message_handler(commands=['say'])
- async def say(message):
- chat = message.chat
- user = message.from_user
-
- if (chat.type == 'group') or (user.id != 383385940):
- return
-
- await bot.send_message(GROUP_ID, message.text[4:])
-
- @bot.message_handler(func=lambda message: True)
- async def echo_message(message):
- if message.text[0] == '/':
- await bot.reply_to(message, UNKNOWN_CMD_TEXT)
-
- @bot.callback_query_handler(func=lambda call: True)
- async def callback_query(call):
- msg = json.loads(call.data)
- logging.info(call.data)
- method = msg.get('method')
- match method:
- case 'redo':
- await _redo(call)
-
- try:
- await bot.infinity_polling()
-
- except KeyboardInterrupt:
- ...
diff --git a/skynet/frontend/telegram/__init__.py b/skynet/frontend/telegram/__init__.py
new file mode 100644
index 0000000..ecd66fd
--- /dev/null
+++ b/skynet/frontend/telegram/__init__.py
@@ -0,0 +1,274 @@
+#!/usr/bin/python
+
+import random
+import logging
+import asyncio
+
+from decimal import Decimal
+from hashlib import sha256
+from datetime import datetime
+from contextlib import ExitStack, AsyncExitStack
+from contextlib import asynccontextmanager as acm
+
+from leap.cleos import CLEOS
+from leap.sugar import Name, asset_from_str, collect_stdout
+from leap.hyperion import HyperionAPI
+from telebot.asyncio_helper import ApiTelegramException
+from telebot.types import InputMediaPhoto
+
+from telebot.types import CallbackQuery
+from telebot.async_telebot import AsyncTeleBot
+
+from skynet.db import open_new_database, open_database_connection
+from skynet.ipfs import open_ipfs_node, get_ipfs_file
+from skynet.constants import *
+
+from . import *
+
+from .utils import *
+from .handlers import create_handler_context
+
+
+class SkynetTelegramFrontend:
+
+ def __init__(
+ self,
+ token: str,
+ account: str,
+ permission: str,
+ node_url: str,
+ hyperion_url: str,
+ db_host: str,
+ db_user: str,
+ db_pass: str,
+ remote_ipfs_node: str,
+ key: str
+ ):
+ self.token = token
+ self.account = account
+ self.permission = permission
+ self.node_url = node_url
+ self.hyperion_url = hyperion_url
+ self.db_host = db_host
+ self.db_user = db_user
+ self.db_pass = db_pass
+ self.remote_ipfs_node = remote_ipfs_node
+ self.key = key
+
+ self.bot = AsyncTeleBot(token, exception_handler=SKYExceptionHandler)
+ self.cleos = CLEOS(None, None, url=node_url, remote=node_url)
+ self.hyperion = HyperionAPI(hyperion_url)
+
+ self._exit_stack = ExitStack()
+ self._async_exit_stack = AsyncExitStack()
+
+ async def start(self):
+ self.ipfs_node = self._exit_stack.enter_context(
+ open_ipfs_node())
+
+ self.ipfs_node.connect(self.remote_ipfs_node)
+ logging.info(
+ f'connected to remote ipfs node: {self.remote_ipfs_node}')
+
+ self.db_call = await self._async_exit_stack.enter_async_context(
+ open_database_connection(
+ self.db_user, self.db_pass, self.db_host))
+
+ create_handler_context(self)
+
+ async def stop(self):
+ await self._async_exit_stack.aclose()
+ self._exit_stack.close()
+
+ @acm
+ async def open(self):
+ await self.start()
+ yield self
+ await self.stop()
+
+ async def update_status_message(
+ self, status_msg, new_text: str, **kwargs
+ ):
+ await self.db_call(
+ 'update_user_request_by_sid', status_msg.id, new_text)
+ return await self.bot.edit_message_text(
+ new_text,
+ chat_id=status_msg.chat.id,
+ message_id=status_msg.id,
+ **kwargs
+ )
+
+ async def append_status_message(
+ self, status_msg, add_text: str, **kwargs
+ ):
+ request = await self.db_call('get_user_request_by_sid', status_msg.id)
+ await self.update_status_message(
+ status_msg,
+ request['status'] + add_text,
+ **kwargs
+ )
+
+ async def work_request(
+ self,
+ user,
+ status_msg,
+ method: str,
+ params: dict,
+ file_id: str | None = None,
+ binary_data: str = ''
+ ):
+ if params['seed'] == None:
+ params['seed'] = random.randint(0, 0xFFFFFFFF)
+
+ sanitized_params = {}
+ for key, val in params.items():
+ if isinstance(val, Decimal):
+ val = str(val)
+
+ sanitized_params[key] = val
+
+ body = json.dumps({
+ 'method': 'diffuse',
+ 'params': sanitized_params
+ })
+ request_time = datetime.now().isoformat()
+
+ await self.update_status_message(
+ status_msg,
+ f'processing a \'{method}\' request by {tg_user_pretty(user)}\n'
+ f'[{timestamp_pretty()}] broadcasting transaction to chain...',
+ parse_mode='HTML'
+ )
+
+ reward = '20.0000 GPU'
+ res = await self.cleos.a_push_action(
+ 'telos.gpu',
+ 'enqueue',
+ {
+ 'user': Name(self.account),
+ 'request_body': body,
+ 'binary_data': binary_data,
+ 'reward': asset_from_str(reward)
+ },
+ self.account, self.key, permission=self.permission
+ )
+
+ if 'code' in res or 'statusCode' in res:
+ logging.error(json.dumps(res, indent=4))
+ await self.update_status_message(
+ status_msg,
+ 'skynet has suffered an internal error trying to fill this request')
+ return
+
+ enqueue_tx_id = res['transaction_id']
+ enqueue_tx_link = hlink(
+ 'Your request on Skynet Explorer',
+ f'https://skynet.ancap.tech/v2/explore/transaction/{enqueue_tx_id}'
+ )
+
+ await self.append_status_message(
+ status_msg,
+ f' broadcasted!\n'
+ f'{enqueue_tx_link}\n'
+ f'[{timestamp_pretty()}] workers are processing request...',
+ parse_mode='HTML'
+ )
+
+ out = collect_stdout(res)
+
+ request_id, nonce = out.split(':')
+
+ request_hash = sha256(
+ (nonce + body + binary_data).encode('utf-8')).hexdigest().upper()
+
+ request_id = int(request_id)
+
+ logging.info(f'{request_id} enqueued.')
+
+ tx_hash = None
+ ipfs_hash = None
+ for i in range(60):
+ submits = await self.hyperion.aget_actions(
+ account=self.account,
+ filter='telos.gpu:submit',
+ sort='desc',
+ after=request_time
+ )
+ actions = [
+ action
+ for action in submits['actions']
+ if action[
+ 'act']['data']['request_hash'] == request_hash
+ ]
+ if len(actions) > 0:
+ tx_hash = actions[0]['trx_id']
+ data = actions[0]['act']['data']
+ ipfs_hash = data['ipfs_hash']
+ worker = data['worker']
+ logging.info('Found matching submit!')
+ break
+
+ await asyncio.sleep(1)
+
+ if not ipfs_hash:
+ await self.update_status_message(
+ status_msg,
+ '\n[{timestamp_pretty()}] timeout processing request',
+ parse_mode='HTML'
+ )
+ return
+
+ tx_link = hlink(
+ 'Your result on Skynet Explorer',
+ f'https://skynet.ancap.tech/v2/explore/transaction/{tx_hash}'
+ )
+
+ await self.append_status_message(
+ status_msg,
+ f' request processed!\n'
+ f'{tx_link}\n'
+ f'[{timestamp_pretty()}] trying to download image...\n',
+ parse_mode='HTML'
+ )
+
+ # attempt to get the image and send it
+ ipfs_link = f'https://ipfs.ancap.tech/ipfs/{ipfs_hash}/image.png'
+ resp = await get_ipfs_file(ipfs_link)
+
+ caption = generate_reply_caption(
+ user, params, ipfs_hash, tx_hash, worker, reward)
+
+ if not resp or resp.status_code != 200:
+ logging.error(f'couldn\'t get ipfs hosted image at {ipfs_link}!')
+ await self.update_status_message(
+ status_msg,
+ caption,
+ reply_markup=build_redo_menu(),
+ parse_mode='HTML'
+ )
+
+ else:
+ logging.info(f'success! sending generated image')
+ await self.bot.delete_message(
+ chat_id=status_msg.chat.id, message_id=status_msg.id)
+ if file_id: # img2img
+ await self.bot.send_media_group(
+ status_msg.chat.id,
+ media=[
+ InputMediaPhoto(file_id),
+ InputMediaPhoto(
+ resp.raw,
+ caption=caption,
+ parse_mode='HTML'
+ )
+ ],
+ )
+
+ else: # txt2img
+ await self.bot.send_photo(
+ status_msg.chat.id,
+ caption=caption,
+ photo=resp.raw,
+ reply_markup=build_redo_menu(),
+ parse_mode='HTML'
+ )
diff --git a/skynet/frontend/telegram/handlers.py b/skynet/frontend/telegram/handlers.py
new file mode 100644
index 0000000..7e77880
--- /dev/null
+++ b/skynet/frontend/telegram/handlers.py
@@ -0,0 +1,345 @@
+#!/usr/bin/python
+
+import io
+import json
+import logging
+
+from datetime import datetime, timedelta
+
+from PIL import Image
+from telebot.types import CallbackQuery, Message
+
+from skynet.frontend import validate_user_config_request
+from skynet.constants import *
+
+
+def create_handler_context(frontend: 'SkynetTelegramFrontend'):
+
+ bot = frontend.bot
+ cleos = frontend.cleos
+ db_call = frontend.db_call
+ work_request = frontend.work_request
+
+ ipfs_node = frontend.ipfs_node
+
+ # generic / simple handlers
+
+ @bot.message_handler(commands=['help'])
+ async def send_help(message):
+ splt_msg = message.text.split(' ')
+
+ if len(splt_msg) == 1:
+ await bot.reply_to(message, HELP_TEXT)
+
+ else:
+ param = splt_msg[1]
+ if param in HELP_TOPICS:
+ await bot.reply_to(message, HELP_TOPICS[param])
+
+ else:
+ await bot.reply_to(message, HELP_UNKWNOWN_PARAM)
+
+ @bot.message_handler(commands=['cool'])
+ async def send_cool_words(message):
+ await bot.reply_to(message, '\n'.join(COOL_WORDS))
+
+ @bot.message_handler(commands=['queue'])
+ async def queue(message):
+ an_hour_ago = datetime.now() - timedelta(hours=1)
+ queue = await cleos.aget_table(
+ 'telos.gpu', 'telos.gpu', 'queue',
+ index_position=2,
+ key_type='i64',
+ sort='desc',
+ lower_bound=int(an_hour_ago.timestamp())
+ )
+ await bot.reply_to(
+ message, f'Total requests on skynet queue: {len(queue)}')
+
+
+ @bot.message_handler(commands=['config'])
+ async def set_config(message):
+ user = message.from_user.id
+ try:
+ attr, val, reply_txt = validate_user_config_request(
+ message.text)
+
+ logging.info(f'user config update: {attr} to {val}')
+ await db_call('update_user_config', user, attr, val)
+ logging.info('done')
+
+ except BaseException as e:
+ reply_txt = str(e)
+
+ finally:
+ await bot.reply_to(message, reply_txt)
+
+ @bot.message_handler(commands=['stats'])
+ async def user_stats(message):
+ user = message.from_user.id
+
+ await db_call('get_or_create_user', user)
+ generated, joined, role = await db_call('get_user_stats', user)
+
+ stats_str = f'generated: {generated}\n'
+ stats_str += f'joined: {joined}\n'
+ stats_str += f'role: {role}\n'
+
+ await bot.reply_to(
+ message, stats_str)
+
+ @bot.message_handler(commands=['donate'])
+ async def donation_info(message):
+ await bot.reply_to(
+ message, DONATION_INFO)
+
+ @bot.message_handler(commands=['say'])
+ async def say(message):
+ chat = message.chat
+ user = message.from_user
+
+ if (chat.type == 'group') or (user.id != 383385940):
+ return
+
+ await bot.send_message(GROUP_ID, message.text[4:])
+
+
+ # generic txt2img handler
+
+ async def _generic_txt2img(message_or_query):
+ if isinstance(message_or_query, CallbackQuery):
+ query = message_or_query
+ message = query.message
+ user = query.from_user
+ chat = query.message.chat
+
+ else:
+ message = message_or_query
+ user = message.from_user
+ chat = message.chat
+
+ reply_id = None
+ if chat.type == 'group' and chat.id == GROUP_ID:
+ reply_id = message.message_id
+
+ user_row = await db_call('get_or_create_user', user.id)
+
+ # init new msg
+ init_msg = 'started processing txt2img request...'
+ status_msg = await bot.reply_to(message, init_msg)
+ await db_call(
+ 'new_user_request', user.id, message.id, status_msg.id, status=init_msg)
+
+ prompt = ' '.join(message.text.split(' ')[1:])
+
+ if len(prompt) == 0:
+ await bot.edit_message_text(
+ 'Empty text prompt ignored.',
+ chat_id=status_msg.chat.id,
+ message_id=status_msg.id
+ )
+ await db_call('update_user_request', status_msg.id, 'Empty text prompt ignored.')
+ return
+
+ logging.info(f'mid: {message.id}')
+
+ user_config = {**user_row}
+ del user_config['id']
+
+ params = {
+ 'prompt': prompt,
+ **user_config
+ }
+
+ await db_call(
+ 'update_user_stats', user.id, 'txt2img', last_prompt=prompt)
+
+ ec = await work_request(user, status_msg, 'txt2img', params)
+
+ if ec == 0:
+ await db_call('increment_generated', user.id)
+
+
+ # generic img2img handler
+
+ async def _generic_img2img(message_or_query):
+ if isinstance(message_or_query, CallbackQuery):
+ query = message_or_query
+ message = query.message
+ user = query.from_user
+ chat = query.message.chat
+
+ else:
+ message = message_or_query
+ user = message.from_user
+ chat = message.chat
+
+ reply_id = None
+ if chat.type == 'group' and chat.id == GROUP_ID:
+ reply_id = message.message_id
+
+ user_row = await db_call('get_or_create_user', user.id)
+
+ # init new msg
+ init_msg = 'started processing txt2img request...'
+ status_msg = await bot.reply_to(message, init_msg)
+ await db_call(
+ 'new_user_request', user.id, message.id, status_msg.id, status=init_msg)
+
+ if not message.caption.startswith('/img2img'):
+ await bot.reply_to(
+ message,
+ 'For image to image you need to add /img2img to the beggining of your caption'
+ )
+ return
+
+ prompt = ' '.join(message.caption.split(' ')[1:])
+
+ if len(prompt) == 0:
+ await bot.reply_to(message, 'Empty text prompt ignored.')
+ return
+
+ file_id = message.photo[-1].file_id
+ file_path = (await bot.get_file(file_id)).file_path
+ image_raw = await bot.download_file(file_path)
+
+ with Image.open(io.BytesIO(image_raw)) as image:
+ w, h = image.size
+
+ if w > 512 or h > 512:
+ logging.warning(f'user sent img of size {image.size}')
+ image.thumbnail((512, 512))
+ logging.warning(f'resized it to {image.size}')
+
+ image.save(f'ipfs-docker-staging/image.png', format='PNG')
+
+ ipfs_hash = ipfs_node.add('image.png')
+ ipfs_node.pin(ipfs_hash)
+
+ logging.info(f'published input image {ipfs_hash} on ipfs')
+
+ logging.info(f'mid: {message.id}')
+
+ user_config = {**user_row}
+ del user_config['id']
+
+ params = {
+ 'prompt': prompt,
+ **user_config
+ }
+
+ await db_call(
+ 'update_user_stats',
+ user.id,
+ 'img2img',
+ last_file=file_id,
+ last_prompt=prompt,
+ last_binary=ipfs_hash
+ )
+
+ ec = await work_request(
+ user, status_msg, 'img2img', params,
+ file_id=file_id,
+ binary_data=ipfs_hash
+ )
+
+ if ec == 0:
+ await db_call('increment_generated', user.id)
+
+
+ # generic redo handler
+
+ async def _redo(message_or_query):
+ is_query = False
+ if isinstance(message_or_query, CallbackQuery):
+ is_query = True
+ query = message_or_query
+ message = query.message
+ user = query.from_user
+ chat = query.message.chat
+
+ elif isinstance(message_or_query, Message):
+ message = message_or_query
+ user = message.from_user
+ chat = message.chat
+
+ init_msg = 'started processing redo request...'
+ if is_query:
+ status_msg = await bot.send_message(chat.id, init_msg)
+
+ else:
+ status_msg = await bot.reply_to(message, init_msg)
+
+ method = await db_call('get_last_method_of', user.id)
+ prompt = await db_call('get_last_prompt_of', user.id)
+
+ file_id = None
+ binary = ''
+ if method == 'img2img':
+ file_id = await db_call('get_last_file_of', user.id)
+ binary = await db_call('get_last_binary_of', user.id)
+
+ if not prompt:
+ await bot.reply_to(
+ message,
+ 'no last prompt found, do a txt2img cmd first!'
+ )
+ return
+
+
+ user_row = await db_call('get_or_create_user', user.id)
+ await db_call(
+ 'new_user_request', user.id, message.id, status_msg.id, status=init_msg)
+ user_config = {**user_row}
+ del user_config['id']
+
+ params = {
+ 'prompt': prompt,
+ **user_config
+ }
+
+ await work_request(
+ user, status_msg, 'redo', params,
+ file_id=file_id,
+ binary_data=binary
+ )
+
+
+ # "proxy" handlers just request routers
+
+ @bot.message_handler(commands=['txt2img'])
+ async def send_txt2img(message):
+ await _generic_txt2img(message)
+
+ @bot.message_handler(func=lambda message: True, content_types=[
+ 'photo', 'document'])
+ async def send_img2img(message):
+ await _generic_img2img(message)
+
+ @bot.message_handler(commands=['img2img'])
+ async def img2img_missing_image(message):
+ await bot.reply_to(
+ message,
+ 'seems you tried to do an img2img command without sending image'
+ )
+
+ @bot.message_handler(commands=['redo'])
+ async def redo(message):
+ await _redo(message)
+
+ @bot.callback_query_handler(func=lambda call: True)
+ async def callback_query(call):
+ msg = json.loads(call.data)
+ logging.info(call.data)
+ method = msg.get('method')
+ match method:
+ case 'redo':
+ await _redo(call)
+
+
+ # catch all handler for things we dont support
+
+ @bot.message_handler(func=lambda message: True)
+ async def echo_message(message):
+ if message.text[0] == '/':
+ await bot.reply_to(message, UNKNOWN_CMD_TEXT)
diff --git a/skynet/frontend/telegram/utils.py b/skynet/frontend/telegram/utils.py
new file mode 100644
index 0000000..682cd02
--- /dev/null
+++ b/skynet/frontend/telegram/utils.py
@@ -0,0 +1,113 @@
+#!/usr/bin/python
+
+import json
+import logging
+import traceback
+
+from datetime import datetime, timezone
+
+from telebot.types import InlineKeyboardButton, InlineKeyboardMarkup
+from telebot.async_telebot import ExceptionHandler
+from telebot.formatting import hlink
+
+from skynet.constants import *
+
+
+def timestamp_pretty():
+ return datetime.now(timezone.utc).strftime('%H:%M:%S')
+
+
+def tg_user_pretty(tguser):
+ if tguser.username:
+ return f'@{tguser.username}'
+ else:
+ return f'{tguser.first_name} id: {tguser.id}'
+
+
+class SKYExceptionHandler(ExceptionHandler):
+
+ def handle(exception):
+ traceback.print_exc()
+
+
+def build_redo_menu():
+ btn_redo = InlineKeyboardButton("Redo", callback_data=json.dumps({'method': 'redo'}))
+ inline_keyboard = InlineKeyboardMarkup()
+ inline_keyboard.add(btn_redo)
+ return inline_keyboard
+
+
+def prepare_metainfo_caption(tguser, worker: str, reward: str, meta: dict) -> str:
+ prompt = meta["prompt"]
+ if len(prompt) > 256:
+ prompt = prompt[:256]
+
+
+ meta_str = f'by {tg_user_pretty(tguser)}\n'
+ meta_str += f'performed by {worker}\n'
+ meta_str += f'reward: {reward}\n'
+
+ meta_str += f'prompt:
{prompt}\n'
+ meta_str += f'seed: {meta["seed"]}
\n'
+ meta_str += f'step: {meta["step"]}
\n'
+ meta_str += f'guidance: {meta["guidance"]}
\n'
+ if meta['strength']:
+ meta_str += f'strength: {meta["strength"]}
\n'
+ meta_str += f'algo: {meta["algo"]}
\n'
+ if meta['upscaler']:
+ meta_str += f'upscaler: {meta["upscaler"]}
\n'
+
+ meta_str += f'Made with Skynet v{VERSION}\n'
+ meta_str += f'JOIN THE SWARM: @skynetgpu'
+ return meta_str
+
+
+def generate_reply_caption(
+ tguser, # telegram user
+ params: dict,
+ ipfs_hash: str,
+ tx_hash: str,
+ worker: str,
+ reward: str
+):
+ ipfs_link = hlink(
+ 'Get your image on IPFS',
+ f'https://ipfs.ancap.tech/ipfs/{ipfs_hash}/image.png'
+ )
+ explorer_link = hlink(
+ 'SKYNET Transaction Explorer',
+ f'https://skynet.ancap.tech/v2/explore/transaction/{tx_hash}'
+ )
+
+ meta_info = prepare_metainfo_caption(tguser, worker, reward, params)
+
+ final_msg = '\n'.join([
+ 'Worker finished your task!',
+ ipfs_link,
+ explorer_link,
+ f'PARAMETER INFO:\n{meta_info}'
+ ])
+
+ final_msg = '\n'.join([
+ f'{ipfs_link}',
+ f'{explorer_link}',
+ f'{meta_info}'
+ ])
+
+ logging.info(final_msg)
+
+ return final_msg
+
+
+async def get_global_config(cleos):
+ return (await cleos.aget_table(
+ 'telos.gpu', 'telos.gpu', 'config'))[0]
+
+async def get_user_nonce(cleos, user: str):
+ return (await cleos.aget_table(
+ 'telos.gpu', 'telos.gpu', 'users',
+ index_position=1,
+ key_type='name',
+ lower_bound=user,
+ upper_bound=user
+ ))[0]['nonce']