mirror of https://github.com/skygpu/skynet.git
498 lines
14 KiB
Python
498 lines
14 KiB
Python
#!/usr/bin/python
|
|
|
|
import io
|
|
import zlib
|
|
import random
|
|
import logging
|
|
import asyncio
|
|
import traceback
|
|
|
|
from hashlib import sha256
|
|
from datetime import datetime
|
|
|
|
import asks
|
|
import docker
|
|
|
|
from PIL import Image
|
|
from leap.cleos import CLEOS, default_nodeos_image
|
|
from leap.sugar import get_container, collect_stdout
|
|
from leap.hyperion import HyperionAPI
|
|
from trio_asyncio import aio_as_trio
|
|
from telebot.types import (
|
|
InputFile, InputMediaPhoto, InlineKeyboardButton, InlineKeyboardMarkup
|
|
)
|
|
|
|
from telebot.types import CallbackQuery
|
|
from telebot.async_telebot import AsyncTeleBot, ExceptionHandler
|
|
from telebot.formatting import hlink
|
|
|
|
from ..db import open_new_database, open_database_connection
|
|
from ..constants import *
|
|
|
|
from . import *
|
|
|
|
|
|
class SKYExceptionHandler(ExceptionHandler):
    """Exception handler for the Skynet Telegram bot.

    Prints the full traceback of any exception telebot hands to it, so
    errors raised inside bot handlers are visible on stderr.
    """

    # ``handle`` is a staticmethod so it works both when telebot is given the
    # class itself and when it is given an instance. The previous signature
    # had no ``self`` and would raise TypeError when invoked on an instance.
    @staticmethod
    def handle(exception):
        """Print the traceback of *exception* to stderr.

        The exception object passed in is used directly, rather than
        ``sys.exc_info()``, so the right traceback is printed even when this
        is called outside of an ``except`` block.
        """
        traceback.print_exception(
            type(exception), exception, exception.__traceback__)
|
|
|
|
|
|
def build_redo_menu():
    """Return an inline keyboard containing a single "Redo" button.

    The button's callback data is the JSON object ``{"method": "redo"}``.
    The bot's callback query handler dispatches on that ``method`` field.
    """
    markup = InlineKeyboardMarkup()
    redo_payload = json.dumps({'method': 'redo'})
    markup.add(InlineKeyboardButton("Redo", callback_data=redo_payload))
    return markup
|
|
|
|
|
|
def prepare_metainfo_caption(tguser, meta: dict) -> str:
    """Build the HTML "parameter info" section of a result caption.

    Args:
        tguser: Telegram user who made the request. ``username``,
            ``first_name`` and ``id`` are read.
        meta: Request parameters. ``prompt``, ``seed``, ``step``,
            ``guidance`` and ``algo`` are required. ``strength`` and
            ``upscaler`` are optional and only shown when present and truthy.

    Returns:
        A newline-separated string intended for Telegram's
        ``parse_mode='HTML'``.
    """
    import html

    # Truncate the raw prompt to 256 chars (presumably to keep the caption
    # within Telegram's caption size limit), then escape it. Free user text
    # containing ``<``, ``>`` or ``&`` would otherwise break Telegram's HTML
    # entity parsing or inject markup.
    prompt = html.escape(meta['prompt'][:256], quote=False)

    if tguser.username:
        user = f'@{tguser.username}'
    else:
        user = f'{tguser.first_name} id: {tguser.id}'
    # first_name is arbitrary user-controlled text, so escape it as well
    user = html.escape(user, quote=False)

    lines = [
        f'<u>by {user}</u>',
        f'<code>prompt:</code> {prompt}',
        f'<code>seed: {meta["seed"]}</code>',
        f'<code>step: {meta["step"]}</code>',
        f'<code>guidance: {meta["guidance"]}</code>',
    ]
    if meta.get('strength'):
        lines.append(f'<code>strength: {meta["strength"]}</code>')

    lines.append(f'<code>algo: {meta["algo"]}</code>')

    if meta.get('upscaler'):
        lines.append(f'<code>upscaler: {meta["upscaler"]}</code>')

    lines.append(f'<b><u>Made with Skynet {VERSION}</u></b>')
    lines.append('<b>JOIN THE SWARM: @skynetgpu</b>')

    return '\n'.join(lines)
|
|
|
|
|
|
def generate_reply_caption(
    tguser,  # telegram user
    params: dict,
    ipfs_hash: str,
    tx_hash: str,
    *,
    ipfs_gateway: str = 'http://test1.us.telos.net:8080',
    explorer_url: str = 'http://test1.us.telos.net:42001',
):
    """Build the HTML caption sent along with a finished image.

    Args:
        tguser: Telegram user who made the request.
        params: Request parameters, rendered by ``prepare_metainfo_caption``.
        ipfs_hash: IPFS directory hash containing ``image.png``.
        tx_hash: Hash of the on-chain submit transaction.
        ipfs_gateway: Base URL of the IPFS HTTP gateway used for the image link.
        explorer_url: Base URL of the transaction explorer.

    Returns:
        The caption: a bold IPFS link, an italic explorer link and the
        parameter info, one per line. The caption is also logged at INFO.
    """
    ipfs_link = hlink(
        'Get your image on IPFS',
        f'{ipfs_gateway}/ipfs/{ipfs_hash}/image.png'
    )
    explorer_link = hlink(
        'SKYNET Transaction Explorer',
        f'{explorer_url}/v2/explore/transaction/{tx_hash}'
    )

    meta_info = prepare_metainfo_caption(tguser, params)

    final_msg = '\n'.join([
        f'<b>{ipfs_link}</b>',
        f'<i>{explorer_link}</i>',
        meta_info,
    ])

    logging.info(final_msg)

    return final_msg
|
|
|
|
|
|
async def get_global_config(cleos):
    """Return the first row of the ``telos.gpu`` contract's ``config`` table."""
    rows = await cleos.aget_table('telos.gpu', 'telos.gpu', 'config')
    first_row = rows[0]
    return first_row
|
|
|
|
async def get_user_nonce(cleos, user: str):
    """Return the ``nonce`` of *user*'s row in the ``telos.gpu`` ``users`` table.

    The row is looked up with ``index_position=1`` and ``key_type='name'``,
    using *user* as both the lower and the upper bound. The first row
    returned is used.
    """
    lookup = dict(
        index_position=1,
        key_type='name',
        lower_bound=user,
        upper_bound=user,
    )
    rows = await cleos.aget_table('telos.gpu', 'telos.gpu', 'users', **lookup)
    user_row = rows[0]
    return user_row['nonce']
|
|
|
|
async def _download_input_image_hex(bot, file_path: str, max_side: int = 512) -> str:
    """Download the user's input image and return its bytes hex-encoded.

    Images larger than ``max_side`` on either side are shrunk with
    ``Image.thumbnail`` (aspect ratio kept) and re-encoded. This ensures the
    resized pixels, not the original download, are what gets sent.
    """
    image_raw = await bot.download_file(file_path)
    with Image.open(io.BytesIO(image_raw)) as image:
        logging.info(f'user sent img of size {image.size}')
        w, h = image.size
        if w > max_side or h > max_side:
            # remember the source format before resizing; fall back to PNG
            fmt = image.format or 'PNG'
            image.thumbnail((max_side, max_side))
            logging.warning(f'resized it to {image.size}')
            buf = io.BytesIO()
            image.save(buf, format=fmt)
            image_raw = buf.getvalue()
    return image_raw.hex()


async def _wait_for_submit(hyperion, account, request_hash, after,
                           attempts: int = 60, delay: float = 1):
    """Poll Hyperion for the ``telos.gpu:submit`` action matching ``request_hash``.

    Returns ``(tx_hash, ipfs_hash)`` for the first matching action. Returns
    ``(None, None)`` if nothing matched after ``attempts`` polls spaced
    ``delay`` seconds apart.
    """
    for _ in range(attempts):
        submits = await hyperion.aget_actions(
            account=account,
            filter='telos.gpu:submit',
            sort='desc',
            after=after
        )
        for action in submits['actions']:
            data = action['act']['data']
            if data['request_hash'] == request_hash:
                return action['trx_id'], data['ipfs_hash']
        await asyncio.sleep(delay)
    return None, None


async def _fetch_ipfs_image(ipfs_link: str, attempts: int = 10, timeout: float = 2):
    """GET ``ipfs_link``, retrying on timeouts and non-200 responses.

    Returns the first 200 response. Otherwise returns the last response
    received, or ``None`` if every attempt timed out.
    """
    resp = None
    for _ in range(attempts):
        try:
            resp = await asks.get(ipfs_link, timeout=timeout)
        except asks.errors.RequestTimeout:
            logging.warning('timeout...')
            continue
        if resp.status_code == 200:
            break
    return resp


async def work_request(
    bot, cleos, hyperion,
    message, user, chat,
    account: str,
    permission: str,
    params: dict,
    file_id: str | None = None,
    file_path: str | None = None
):
    """Enqueue a diffusion request on-chain, wait for a worker and send the result.

    Flow:
      1. Fill in a random seed when ``params['seed']`` is missing or None
         (``params`` is mutated in place).
      2. For img2img (``file_id`` set), download the input image, shrink it
         to fit 512x512 and hex-encode it as the request binary.
      3. Push ``telos.gpu::enqueue`` signed by ``account@permission``. Then
         derive the request hash as ``sha256(str(nonce) + body)``, uppercased.
      4. Poll Hyperion (up to ~60 s) for the matching ``telos.gpu:submit``.
      5. Fetch ``image.png`` from IPFS and send it to ``chat``. If that
         fails, reply with the caption only.

    Args:
        bot: AsyncTeleBot used to talk to Telegram.
        cleos: CLEOS client used to push the action and read tables.
        hyperion: HyperionAPI client used to poll for submit actions.
        message: Telegram message being answered.
        user: Telegram user the request belongs to (used in the caption).
        chat: Chat the result image is sent to.
        account: Chain account that signs ``enqueue``.
        permission: Permission of ``account`` used for signing.
        params: Diffusion parameters, serialized into the request body.
        file_id: Telegram file id of the input photo (img2img only).
        file_path: Telegram file path of the input photo (img2img only).
    """
    from datetime import timezone

    if params.get('seed') is None:
        # randint needs int bounds: the former float literal 9e18 is rejected
        # by Python 3.12+
        params['seed'] = random.randint(0, 9 * 10**18)

    body = json.dumps({
        'method': 'diffuse',
        'params': params
    })
    # Chain / Hyperion timestamps are UTC. A naive *local* time would shift
    # the ``after`` filter on hosts not running in UTC.
    request_time = datetime.now(timezone.utc).replace(tzinfo=None).isoformat()

    # txt2img requests carry no input image (previously ``binary`` was left
    # unbound on this path)
    binary = ''
    if file_id:
        binary = await _download_input_image_hex(bot, file_path)

    # NOTE(review): push_action is called without await, so it presumably
    # blocks the event loop while it runs.
    ec, out = cleos.push_action(
        'telos.gpu', 'enqueue', [account, body, binary, '20.0000 GPU'],
        f'{account}@{permission}'
    )
    out = collect_stdout(out)
    if ec != 0:
        await bot.reply_to(message, out)
        return

    nonce = await get_user_nonce(cleos, account)
    request_hash = sha256(
        (str(nonce) + body).encode('utf-8')).hexdigest().upper()

    request_id = int(out)
    logging.info(f'{request_id} enqueued.')

    tx_hash, ipfs_hash = await _wait_for_submit(
        hyperion, account, request_hash, request_time)

    if not ipfs_hash:
        await bot.reply_to(message, 'timeout processing request')
        return

    # attempt to get the image and send it
    ipfs_link = f'http://test1.us.telos.net:8080/ipfs/{ipfs_hash}/image.png'
    logging.info(f'attempting to get image at {ipfs_link}')
    resp = await _fetch_ipfs_image(ipfs_link)

    status_code = resp.status_code if resp is not None else None
    logging.info(f'status_code: {status_code}')

    caption = generate_reply_caption(
        user, params, ipfs_hash, tx_hash)

    if status_code != 200:
        # image could not be fetched: send the caption (with links) only
        await bot.reply_to(
            message,
            caption,
            reply_markup=build_redo_menu(),
            parse_mode='HTML'
        )

    elif file_id:  # img2img
        # sendMediaGroup accepts neither reply_markup nor a call-level
        # parse_mode. The parse mode goes on the captioned item, and albums
        # cannot carry an inline keyboard.
        await bot.send_media_group(
            chat.id,
            media=[
                InputMediaPhoto(file_id),
                InputMediaPhoto(
                    resp.raw,
                    caption=caption,
                    parse_mode='HTML'
                )
            ]
        )

    else:  # txt2img
        await bot.send_photo(
            chat.id,
            caption=caption,
            photo=resp.raw,
            reply_markup=build_redo_menu(),
            parse_mode='HTML'
        )
|
|
|
|
|
|
async def run_skynet_telegram(
|
|
tg_token: str,
|
|
account: str,
|
|
permission: str,
|
|
node_url: str,
|
|
hyperion_url: str,
|
|
db_host: str,
|
|
db_user: str,
|
|
db_pass: str,
|
|
key: str = None
|
|
):
|
|
dclient = docker.from_env()
|
|
vtestnet = get_container(
|
|
dclient,
|
|
default_nodeos_image(),
|
|
force_unique=True,
|
|
detach=True,
|
|
network='host',
|
|
remove=True)
|
|
|
|
cleos = CLEOS(dclient, vtestnet, url=node_url, remote=node_url)
|
|
hyperion = HyperionAPI(hyperion_url)
|
|
|
|
logging.basicConfig(level=logging.INFO)
|
|
|
|
if key:
|
|
cleos.setup_wallet(key)
|
|
|
|
bot = AsyncTeleBot(tg_token, exception_handler=SKYExceptionHandler)
|
|
logging.info(f'tg_token: {tg_token}')
|
|
|
|
async with open_database_connection(
|
|
db_user, db_pass, db_host
|
|
) as db_call:
|
|
|
|
@bot.message_handler(commands=['help'])
|
|
async def send_help(message):
|
|
splt_msg = message.text.split(' ')
|
|
|
|
if len(splt_msg) == 1:
|
|
await bot.reply_to(message, HELP_TEXT)
|
|
|
|
else:
|
|
param = splt_msg[1]
|
|
if param in HELP_TOPICS:
|
|
await bot.reply_to(message, HELP_TOPICS[param])
|
|
|
|
else:
|
|
await bot.reply_to(message, HELP_UNKWNOWN_PARAM)
|
|
|
|
@bot.message_handler(commands=['cool'])
|
|
async def send_cool_words(message):
|
|
await bot.reply_to(message, '\n'.join(COOL_WORDS))
|
|
|
|
@bot.message_handler(commands=['txt2img'])
|
|
async def send_txt2img(message):
|
|
user = message.from_user
|
|
chat = message.chat
|
|
reply_id = None
|
|
if chat.type == 'group' and chat.id == GROUP_ID:
|
|
reply_id = message.message_id
|
|
|
|
prompt = ' '.join(message.text.split(' ')[1:])
|
|
|
|
if len(prompt) == 0:
|
|
await bot.reply_to(message, 'Empty text prompt ignored.')
|
|
return
|
|
|
|
logging.info(f'mid: {message.id}')
|
|
|
|
user_row = await db_call('get_or_create_user', user.id)
|
|
user_config = {**user_row}
|
|
del user_config['id']
|
|
|
|
params = {
|
|
'prompt': prompt,
|
|
**user_config
|
|
}
|
|
|
|
await db_call('update_user_stats', user.id, last_prompt=prompt)
|
|
|
|
await work_request(
|
|
bot, cleos, hyperion,
|
|
message, user, chat,
|
|
account, permission, params
|
|
)
|
|
|
|
@bot.message_handler(func=lambda message: True, content_types=['photo'])
|
|
async def send_img2img(message):
|
|
user = message.from_user
|
|
chat = message.chat
|
|
reply_id = None
|
|
if chat.type == 'group' and chat.id == GROUP_ID:
|
|
reply_id = message.message_id
|
|
|
|
if not message.caption.startswith('/img2img'):
|
|
await bot.reply_to(
|
|
message,
|
|
'For image to image you need to add /img2img to the beggining of your caption'
|
|
)
|
|
return
|
|
|
|
prompt = ' '.join(message.caption.split(' ')[1:])
|
|
|
|
if len(prompt) == 0:
|
|
await bot.reply_to(message, 'Empty text prompt ignored.')
|
|
return
|
|
|
|
file_id = message.photo[-1].file_id
|
|
file_path = (await bot.get_file(file_id)).file_path
|
|
|
|
logging.info(f'mid: {message.id}')
|
|
|
|
user_row = await db_call('get_or_create_user', user.id)
|
|
user_config = {**user_row}
|
|
del user_config['id']
|
|
|
|
params = {
|
|
'prompt': prompt,
|
|
**user_config
|
|
}
|
|
|
|
await db_call('update_user_stats', user.id, last_prompt=prompt)
|
|
|
|
await work_request(
|
|
bot, cleos, hyperion,
|
|
message, user, chat,
|
|
account, permission, params,
|
|
file_id=file_id, file_path=file_path
|
|
)
|
|
|
|
|
|
@bot.message_handler(commands=['img2img'])
|
|
async def img2img_missing_image(message):
|
|
await bot.reply_to(
|
|
message,
|
|
'seems you tried to do an img2img command without sending image'
|
|
)
|
|
|
|
async def _redo(message_or_query):
|
|
if isinstance(message_or_query, CallbackQuery):
|
|
query = message_or_query
|
|
message = query.message
|
|
user = query.from_user
|
|
chat = query.message.chat
|
|
|
|
else:
|
|
message = message_or_query
|
|
user = message.from_user
|
|
chat = message.chat
|
|
|
|
reply_id = None
|
|
if chat.type == 'group' and chat.id == GROUP_ID:
|
|
reply_id = message.message_id
|
|
|
|
prompt = await db_call('get_last_prompt_of', user.id)
|
|
|
|
if not prompt:
|
|
await bot.reply_to(
|
|
message,
|
|
'no last prompt found, do a txt2img cmd first!'
|
|
)
|
|
return
|
|
|
|
|
|
user_row = await db_call('get_or_create_user', user.id)
|
|
user_config = {**user_row}
|
|
del user_config['id']
|
|
|
|
params = {
|
|
'prompt': prompt,
|
|
**user_config
|
|
}
|
|
|
|
await work_request(
|
|
bot, cleos, hyperion,
|
|
message, user, chat,
|
|
account, permission, params
|
|
)
|
|
|
|
@bot.message_handler(commands=['redo'])
|
|
async def redo(message):
|
|
await _redo(message)
|
|
|
|
@bot.message_handler(commands=['config'])
|
|
async def set_config(message):
|
|
user = message.from_user.id
|
|
try:
|
|
attr, val, reply_txt = validate_user_config_request(
|
|
message.text)
|
|
|
|
logging.info(f'user config update: {attr} to {val}')
|
|
await db_call('update_user_config', user, attr, val)
|
|
logging.info('done')
|
|
|
|
except BaseException as e:
|
|
reply_txt = str(e)
|
|
|
|
finally:
|
|
await bot.reply_to(message, reply_txt)
|
|
|
|
@bot.message_handler(commands=['stats'])
|
|
async def user_stats(message):
|
|
user = message.from_user.id
|
|
|
|
generated, joined, role = await db_call('get_user_stats', user)
|
|
|
|
stats_str = f'generated: {generated}\n'
|
|
stats_str += f'joined: {joined}\n'
|
|
stats_str += f'role: {role}\n'
|
|
|
|
await bot.reply_to(
|
|
message, stats_str)
|
|
|
|
@bot.message_handler(commands=['donate'])
|
|
async def donation_info(message):
|
|
await bot.reply_to(
|
|
message, DONATION_INFO)
|
|
|
|
@bot.message_handler(commands=['say'])
|
|
async def say(message):
|
|
chat = message.chat
|
|
user = message.from_user
|
|
|
|
if (chat.type == 'group') or (user.id != 383385940):
|
|
return
|
|
|
|
await bot.send_message(GROUP_ID, message.text[4:])
|
|
|
|
|
|
@bot.message_handler(func=lambda message: True)
|
|
async def echo_message(message):
|
|
if message.text[0] == '/':
|
|
await bot.reply_to(message, UNKNOWN_CMD_TEXT)
|
|
|
|
@bot.callback_query_handler(func=lambda call: True)
|
|
async def callback_query(call):
|
|
msg = json.loads(call.data)
|
|
logging.info(call.data)
|
|
method = msg.get('method')
|
|
match method:
|
|
case 'redo':
|
|
await _redo(call)
|
|
|
|
try:
|
|
await bot.infinity_polling()
|
|
|
|
except KeyboardInterrupt:
|
|
...
|
|
|
|
finally:
|
|
vtestnet.stop()
|