mirror of https://github.com/skygpu/skynet.git
368 lines
11 KiB
Python
368 lines
11 KiB
Python
#!/usr/bin/python
|
|
|
|
import io
|
|
import json
|
|
import logging
|
|
|
|
from datetime import datetime, timedelta
|
|
|
|
from PIL import Image
|
|
from telebot.types import CallbackQuery, Message
|
|
|
|
from skynet.frontend import validate_user_config_request, perform_auto_conf
|
|
from skynet.constants import *
|
|
|
|
|
|
def create_handler_context(frontend: 'SkynetTelegramFrontend'):
|
|
|
|
bot = frontend.bot
|
|
cleos = frontend.cleos
|
|
db_call = frontend.db_call
|
|
work_request = frontend.work_request
|
|
|
|
ipfs_node = frontend.ipfs_node
|
|
|
|
# generic / simple handlers
|
|
|
|
@bot.message_handler(commands=['help'])
|
|
async def send_help(message):
|
|
splt_msg = message.text.split(' ')
|
|
|
|
if len(splt_msg) == 1:
|
|
await bot.reply_to(message, HELP_TEXT)
|
|
|
|
else:
|
|
param = splt_msg[1]
|
|
if param in HELP_TOPICS:
|
|
await bot.reply_to(message, HELP_TOPICS[param])
|
|
|
|
else:
|
|
await bot.reply_to(message, HELP_UNKWNOWN_PARAM)
|
|
|
|
@bot.message_handler(commands=['cool'])
|
|
async def send_cool_words(message):
|
|
await bot.reply_to(message, '\n'.join(COOL_WORDS))
|
|
|
|
@bot.message_handler(commands=['queue'])
|
|
async def queue(message):
|
|
an_hour_ago = datetime.now() - timedelta(hours=1)
|
|
queue = await cleos.aget_table(
|
|
'gpu.scd', 'gpu.scd', 'queue',
|
|
index_position=2,
|
|
key_type='i64',
|
|
sort='desc',
|
|
lower_bound=int(an_hour_ago.timestamp())
|
|
)
|
|
await bot.reply_to(
|
|
message, f'Total requests on skynet queue: {len(queue)}')
|
|
|
|
|
|
@bot.message_handler(commands=['config'])
|
|
async def set_config(message):
|
|
user = message.from_user.id
|
|
try:
|
|
attr, val, reply_txt = validate_user_config_request(
|
|
message.text)
|
|
|
|
logging.info(f'user config update: {attr} to {val}')
|
|
await db_call('update_user_config', user, attr, val)
|
|
logging.info('done')
|
|
|
|
except BaseException as e:
|
|
reply_txt = str(e)
|
|
|
|
finally:
|
|
await bot.reply_to(message, reply_txt)
|
|
|
|
@bot.message_handler(commands=['stats'])
|
|
async def user_stats(message):
|
|
user = message.from_user.id
|
|
|
|
await db_call('get_or_create_user', user)
|
|
generated, joined, role = await db_call('get_user_stats', user)
|
|
|
|
stats_str = f'generated: {generated}\n'
|
|
stats_str += f'joined: {joined}\n'
|
|
stats_str += f'role: {role}\n'
|
|
|
|
await bot.reply_to(
|
|
message, stats_str)
|
|
|
|
@bot.message_handler(commands=['donate'])
|
|
async def donation_info(message):
|
|
await bot.reply_to(
|
|
message, DONATION_INFO)
|
|
|
|
@bot.message_handler(commands=['say'])
|
|
async def say(message):
|
|
chat = message.chat
|
|
user = message.from_user
|
|
|
|
if (chat.type == 'group') or (user.id != 383385940):
|
|
return
|
|
|
|
await bot.send_message(GROUP_ID, message.text[4:])
|
|
|
|
|
|
# generic txt2img handler
|
|
|
|
async def _generic_txt2img(message_or_query):
|
|
if isinstance(message_or_query, CallbackQuery):
|
|
query = message_or_query
|
|
message = query.message
|
|
user = query.from_user
|
|
chat = query.message.chat
|
|
|
|
else:
|
|
message = message_or_query
|
|
user = message.from_user
|
|
chat = message.chat
|
|
|
|
if chat.type == 'private':
|
|
return
|
|
|
|
reply_id = None
|
|
if chat.type == 'group' and chat.id == GROUP_ID:
|
|
reply_id = message.message_id
|
|
|
|
user_row = await db_call('get_or_create_user', user.id)
|
|
|
|
# init new msg
|
|
init_msg = 'started processing txt2img request...'
|
|
status_msg = await bot.reply_to(message, init_msg)
|
|
await db_call(
|
|
'new_user_request', user.id, message.id, status_msg.id, status=init_msg)
|
|
|
|
prompt = ' '.join(message.text.split(' ')[1:])
|
|
|
|
if len(prompt) == 0:
|
|
await bot.edit_message_text(
|
|
'Empty text prompt ignored.',
|
|
chat_id=status_msg.chat.id,
|
|
message_id=status_msg.id
|
|
)
|
|
await db_call('update_user_request', status_msg.id, 'Empty text prompt ignored.')
|
|
return
|
|
|
|
logging.info(f'mid: {message.id}')
|
|
|
|
user_config = {**user_row}
|
|
del user_config['id']
|
|
|
|
if user_config['autoconf']:
|
|
user_config = perform_auto_conf(user_config)
|
|
|
|
params = {
|
|
'prompt': prompt,
|
|
**user_config
|
|
}
|
|
|
|
await db_call(
|
|
'update_user_stats', user.id, 'txt2img', last_prompt=prompt)
|
|
|
|
success = await work_request(user, status_msg, 'txt2img', params)
|
|
|
|
if success:
|
|
await db_call('increment_generated', user.id)
|
|
|
|
|
|
# generic img2img handler
|
|
|
|
async def _generic_img2img(message_or_query):
|
|
if isinstance(message_or_query, CallbackQuery):
|
|
query = message_or_query
|
|
message = query.message
|
|
user = query.from_user
|
|
chat = query.message.chat
|
|
|
|
else:
|
|
message = message_or_query
|
|
user = message.from_user
|
|
chat = message.chat
|
|
|
|
if chat.type == 'private':
|
|
return
|
|
|
|
reply_id = None
|
|
if chat.type == 'group' and chat.id == GROUP_ID:
|
|
reply_id = message.message_id
|
|
|
|
user_row = await db_call('get_or_create_user', user.id)
|
|
|
|
# init new msg
|
|
init_msg = 'started processing txt2img request...'
|
|
status_msg = await bot.reply_to(message, init_msg)
|
|
await db_call(
|
|
'new_user_request', user.id, message.id, status_msg.id, status=init_msg)
|
|
|
|
if not message.caption.startswith('/img2img'):
|
|
await bot.reply_to(
|
|
message,
|
|
'For image to image you need to add /img2img to the beggining of your caption'
|
|
)
|
|
return
|
|
|
|
prompt = ' '.join(message.caption.split(' ')[1:])
|
|
|
|
if len(prompt) == 0:
|
|
await bot.reply_to(message, 'Empty text prompt ignored.')
|
|
return
|
|
|
|
file_id = message.photo[-1].file_id
|
|
file_path = (await bot.get_file(file_id)).file_path
|
|
image_raw = await bot.download_file(file_path)
|
|
|
|
user_config = {**user_row}
|
|
del user_config['id']
|
|
if user_config['autoconf']:
|
|
user_config = perform_auto_conf(user_config)
|
|
|
|
with Image.open(io.BytesIO(image_raw)) as image:
|
|
w, h = image.size
|
|
|
|
if w > user_config['width'] or h > user_config['height']:
|
|
logging.warning(f'user sent img of size {image.size}')
|
|
image.thumbnail(
|
|
(user_config['width'], user_config['height']))
|
|
logging.warning(f'resized it to {image.size}')
|
|
|
|
image_loc = 'ipfs-staging/image.png'
|
|
image.save(image_loc, format='PNG')
|
|
|
|
ipfs_info = await ipfs_node.add(image_loc)
|
|
ipfs_hash = ipfs_info['Hash']
|
|
await ipfs_node.pin(ipfs_hash)
|
|
|
|
logging.info(f'published input image {ipfs_hash} on ipfs')
|
|
|
|
logging.info(f'mid: {message.id}')
|
|
|
|
params = {
|
|
'prompt': prompt,
|
|
**user_config
|
|
}
|
|
|
|
await db_call(
|
|
'update_user_stats',
|
|
user.id,
|
|
'img2img',
|
|
last_file=file_id,
|
|
last_prompt=prompt,
|
|
last_binary=ipfs_hash
|
|
)
|
|
|
|
success = await work_request(
|
|
user, status_msg, 'img2img', params,
|
|
file_id=file_id,
|
|
inputs=ipfs_hash
|
|
)
|
|
|
|
if success:
|
|
await db_call('increment_generated', user.id)
|
|
|
|
|
|
# generic redo handler
|
|
|
|
async def _redo(message_or_query):
|
|
is_query = False
|
|
if isinstance(message_or_query, CallbackQuery):
|
|
is_query = True
|
|
query = message_or_query
|
|
message = query.message
|
|
user = query.from_user
|
|
chat = query.message.chat
|
|
|
|
elif isinstance(message_or_query, Message):
|
|
message = message_or_query
|
|
user = message.from_user
|
|
chat = message.chat
|
|
|
|
if chat.type == 'private':
|
|
return
|
|
|
|
init_msg = 'started processing redo request...'
|
|
if is_query:
|
|
status_msg = await bot.send_message(chat.id, init_msg)
|
|
|
|
else:
|
|
status_msg = await bot.reply_to(message, init_msg)
|
|
|
|
method = await db_call('get_last_method_of', user.id)
|
|
prompt = await db_call('get_last_prompt_of', user.id)
|
|
|
|
file_id = None
|
|
binary = ''
|
|
if method == 'img2img':
|
|
file_id = await db_call('get_last_file_of', user.id)
|
|
binary = await db_call('get_last_binary_of', user.id)
|
|
|
|
if not prompt:
|
|
await bot.reply_to(
|
|
message,
|
|
'no last prompt found, do a txt2img cmd first!'
|
|
)
|
|
return
|
|
|
|
|
|
user_row = await db_call('get_or_create_user', user.id)
|
|
await db_call(
|
|
'new_user_request', user.id, message.id, status_msg.id, status=init_msg)
|
|
user_config = {**user_row}
|
|
del user_config['id']
|
|
if user_config['autoconf']:
|
|
user_config = perform_auto_conf(user_config)
|
|
|
|
params = {
|
|
'prompt': prompt,
|
|
**user_config
|
|
}
|
|
|
|
success = await work_request(
|
|
user, status_msg, 'redo', params,
|
|
file_id=file_id,
|
|
inputs=binary
|
|
)
|
|
|
|
if success:
|
|
await db_call('increment_generated', user.id)
|
|
|
|
|
|
# "proxy" handlers just request routers
|
|
|
|
@bot.message_handler(commands=['txt2img'])
|
|
async def send_txt2img(message):
|
|
await _generic_txt2img(message)
|
|
|
|
@bot.message_handler(func=lambda message: True, content_types=[
|
|
'photo', 'document'])
|
|
async def send_img2img(message):
|
|
await _generic_img2img(message)
|
|
|
|
@bot.message_handler(commands=['img2img'])
|
|
async def img2img_missing_image(message):
|
|
await bot.reply_to(
|
|
message,
|
|
'seems you tried to do an img2img command without sending image'
|
|
)
|
|
|
|
@bot.message_handler(commands=['redo'])
|
|
async def redo(message):
|
|
await _redo(message)
|
|
|
|
@bot.callback_query_handler(func=lambda call: True)
|
|
async def callback_query(call):
|
|
msg = json.loads(call.data)
|
|
logging.info(call.data)
|
|
method = msg.get('method')
|
|
match method:
|
|
case 'redo':
|
|
await _redo(call)
|
|
|
|
|
|
# catch all handler for things we dont support
|
|
|
|
@bot.message_handler(func=lambda message: True)
|
|
async def echo_message(message):
|
|
if message.text[0] == '/':
|
|
await bot.reply_to(message, UNKNOWN_CMD_TEXT)
|