# skynet/skynet/frontend/discord/handlers.py

# 604 lines
# 19 KiB
# Python

#!/usr/bin/python
import io
import json
import logging
from datetime import datetime, timedelta
from PIL import Image
# from telebot.types import CallbackQuery, Message
from skynet.frontend import validate_user_config_request
from skynet.constants import *
from .ui import SkynetView
def _extract_prompt(content):
    """Return the prompt text that follows the command word in *content*.

    Everything after the first space is taken as the prompt. Surrounding
    whitespace is stripped so that a whitespace-only prompt counts as empty.
    """
    return content.partition(' ')[2].strip()


def _build_params(user_row, prompt):
    """Build the params dict sent with a work request.

    Copies every entry of *user_row* except ``id`` and places it next to
    the prompt. Raises ``KeyError`` when *user_row* has no ``id`` entry.
    """
    user_config = {**user_row}
    del user_config['id']
    return {'prompt': prompt, **user_config}


def create_handler_context(frontend: 'SkynetDiscordFrontend'):
    """Register the skynet command handlers on ``frontend.bot``.

    The handlers are closures over the frontend: they use its ``db_call``,
    ``work_request`` and ``ipfs_node`` attributes, and most replies attach
    a fresh ``SkynetView(frontend)``. Nothing is returned; registering the
    commands on the bot is the only effect.
    """
    bot = frontend.bot
    # not used by any active handler in this function
    cleos = frontend.cleos
    db_call = frontend.db_call
    work_request = frontend.work_request
    ipfs_node = frontend.ipfs_node

    @bot.command(name='config', help='Responds with the configuration')
    async def set_config(ctx):
        """Validate a config request, store it, and reply with the result."""
        user = ctx.author
        try:
            attr, val, reply_txt = validate_user_config_request(
                ctx.message.content)

            logging.info(f'user config update: {attr} to {val}')
            await db_call('update_user_config', user.id, attr, val)
            logging.info('done')

        # NOTE(review): kept as BaseException on purpose; narrowing it could
        # let validation errors escape without a reply if they do not derive
        # from Exception -- confirm their base class before changing this.
        except BaseException as e:
            reply_txt = str(e)

        await ctx.reply(content=reply_txt, view=SkynetView(frontend))

    # drop the command registered under 'help' so ours can take the name
    bot.remove_command('help')

    @bot.command(name='help', help='Responds with a help')
    async def send_help(ctx):
        """Send the general help text, or the text of the requested topic."""
        splt_msg = ctx.message.content.split(' ')

        if len(splt_msg) == 1:
            help_txt = HELP_TEXT
        elif splt_msg[1] in HELP_TOPICS:
            help_txt = HELP_TOPICS[splt_msg[1]]
        else:
            help_txt = HELP_UNKWNOWN_PARAM

        await ctx.send(content=f'```{help_txt}```', view=SkynetView(frontend))

    @bot.command(name='cool', help='Display a list of cool prompt words')
    async def send_cool_words(ctx):
        """Send CLEAN_COOL_WORDS, one word per line, in a code block."""
        clean_cool_word = '\n'.join(CLEAN_COOL_WORDS)
        await ctx.send(content=f'```{clean_cool_word}```', view=SkynetView(frontend))

    @bot.command(name='stats', help='See user statistics')
    async def user_stats(ctx):
        """Reply with the generated / joined / role stats of the author."""
        user = ctx.author

        await db_call('get_or_create_user', user.id)
        generated, joined, role = await db_call('get_user_stats', user.id)

        stats_str = (
            f'```generated: {generated}\n'
            f'joined: {joined}\n'
            f'role: {role}\n```'
        )

        await ctx.reply(stats_str, view=SkynetView(frontend))

    @bot.command(name='donate', help='See donate info')
    async def donation_info(ctx):
        """Reply with DONATION_INFO in a code block."""
        await ctx.reply(
            f'```\n{DONATION_INFO}```', view=SkynetView(frontend))

    @bot.command(name='txt2img', help='Responds with an image')
    async def send_txt2img(ctx):
        """Submit a txt2img work request for the prompt in the message."""
        # grab user from ctx
        user = ctx.author
        user_row = await db_call('get_or_create_user', user.id)

        # init new msg
        init_msg = 'started processing txt2img request...'
        status_msg = await ctx.send(init_msg)
        await db_call(
            'new_user_request', user.id, ctx.message.id, status_msg.id, status=init_msg)

        prompt = _extract_prompt(ctx.message.content)

        if not prompt:
            await status_msg.edit(content='Empty text prompt ignored.')
            await db_call('update_user_request', status_msg.id, 'Empty text prompt ignored.')
            return

        logging.info(f'mid: {ctx.message.id}')

        params = _build_params(user_row, prompt)

        await db_call(
            'update_user_stats', user.id, 'txt2img', last_prompt=prompt)

        success = await work_request(user, status_msg, 'txt2img', params, ctx)

        if success:
            await db_call('increment_generated', user.id)

    @bot.command(name='redo', help='Redo last request')
    async def redo(ctx):
        """Submit a 'redo' work request using the user's last stored prompt."""
        init_msg = 'started processing redo request...'
        status_msg = await ctx.send(init_msg)
        user = ctx.author

        method = await db_call('get_last_method_of', user.id)
        prompt = await db_call('get_last_prompt_of', user.id)

        # the stored file and binary are only looked up for img2img
        file_id = None
        binary = ''
        if method == 'img2img':
            file_id = await db_call('get_last_file_of', user.id)
            binary = await db_call('get_last_binary_of', user.id)

        if not prompt:
            await status_msg.edit(
                content='no last prompt found, do a txt2img cmd first!',
                view=SkynetView(frontend)
            )
            return

        user_row = await db_call('get_or_create_user', user.id)
        await db_call(
            'new_user_request', user.id, ctx.message.id, status_msg.id, status=init_msg)

        params = _build_params(user_row, prompt)

        success = await work_request(
            user, status_msg, 'redo', params, ctx,
            file_id=file_id,
            binary_data=binary
        )

        if success:
            await db_call('increment_generated', user.id)

    @bot.command(name='img2img', help='Responds with an image')
    async def send_img2img(ctx):
        """Submit an img2img work request using the last message attachment.

        The attachment is shrunk to fit 512x512 when larger, saved as PNG,
        added and pinned through ``ipfs_node``, and the resulting hash is
        passed to the work request as ``binary_data``.
        """
        user = ctx.author
        user_row = await db_call('get_or_create_user', user.id)

        # init new msg
        init_msg = 'started processing img2img request...'
        status_msg = await ctx.send(init_msg)
        await db_call(
            'new_user_request', user.id, ctx.message.id, status_msg.id, status=init_msg)

        if not ctx.message.content.startswith('/img2img'):
            await ctx.reply(
                'For image to image you need to add /img2img to the beggining of your caption'
            )
            return

        prompt = _extract_prompt(ctx.message.content)

        if not prompt:
            await ctx.reply('Empty text prompt ignored.')
            return

        attachments = ctx.message.attachments
        if not attachments:
            # indexing [-1] on an empty list would raise IndexError and
            # leave the user without any feedback
            await ctx.reply(
                'seems you tried to do an img2img command without sending image'
            )
            return

        attachment = attachments[-1]
        file_id = str(attachment.id)

        # file bytes
        image_raw = await attachment.read()
        with Image.open(io.BytesIO(image_raw)) as image:
            if max(image.size) > 512:
                logging.warning(f'user sent img of size {image.size}')
                image.thumbnail((512, 512))
                logging.warning(f'resized it to {image.size}')

            # NOTE(review): every request writes to this same path; confirm
            # that overlapping img2img requests cannot overwrite each other.
            image_loc = 'ipfs-staging/image.png'
            image.save(image_loc, format='PNG')

        ipfs_info = await ipfs_node.add(image_loc)
        ipfs_hash = ipfs_info['Hash']
        await ipfs_node.pin(ipfs_hash)

        logging.info(f'published input image {ipfs_hash} on ipfs')

        logging.info(f'mid: {ctx.message.id}')

        params = _build_params(user_row, prompt)

        await db_call(
            'update_user_stats',
            user.id,
            'img2img',
            last_file=file_id,
            last_prompt=prompt,
            last_binary=ipfs_hash
        )

        success = await work_request(
            user, status_msg, 'img2img', params, ctx,
            file_id=file_id,
            binary_data=ipfs_hash
        )

        if success:
            await db_call('increment_generated', user.id)
# TODO: DELETE BELOW
# user = 'testworker3'
# status_msg = 'status'
# params = {
# 'prompt': arg,
# 'seed': None,
# 'step': 35,
# 'guidance': 7.5,
# 'strength': 0.5,
# 'width': 512,
# 'height': 512,
# 'upscaler': None,
# 'model': 'prompthero/openjourney',
# }
#
# ec = await work_request(user, status_msg, 'txt2img', params, ctx)
# print(ec)
# if ec == 0:
# await db_call('increment_generated', user.id)
# response = f"This is your prompt: {arg}"
# await ctx.send(response)
# generic / simple handlers
# @bot.message_handler(commands=['help'])
# async def send_help(message):
# splt_msg = message.text.split(' ')
#
# if len(splt_msg) == 1:
# await bot.reply_to(message, HELP_TEXT)
#
# else:
# param = splt_msg[1]
# if param in HELP_TOPICS:
# await bot.reply_to(message, HELP_TOPICS[param])
#
# else:
# await bot.reply_to(message, HELP_UNKWNOWN_PARAM)
#
# @bot.message_handler(commands=['cool'])
# async def send_cool_words(message):
# await bot.reply_to(message, '\n'.join(COOL_WORDS))
#
# @bot.message_handler(commands=['queue'])
# async def queue(message):
# an_hour_ago = datetime.now() - timedelta(hours=1)
# queue = await cleos.aget_table(
# 'gpu.scd', 'gpu.scd', 'queue',
# index_position=2,
# key_type='i64',
# sort='desc',
# lower_bound=int(an_hour_ago.timestamp())
# )
# await bot.reply_to(
# message, f'Total requests on skynet queue: {len(queue)}')
# @bot.message_handler(commands=['config'])
# async def set_config(message):
# user = message.from_user.id
# try:
# attr, val, reply_txt = validate_user_config_request(
# message.text)
#
# logging.info(f'user config update: {attr} to {val}')
# await db_call('update_user_config', user, attr, val)
# logging.info('done')
#
# except BaseException as e:
# reply_txt = str(e)
#
# finally:
# await bot.reply_to(message, reply_txt)
#
# @bot.message_handler(commands=['stats'])
# async def user_stats(message):
# user = message.from_user.id
#
# await db_call('get_or_create_user', user)
# generated, joined, role = await db_call('get_user_stats', user)
#
# stats_str = f'generated: {generated}\n'
# stats_str += f'joined: {joined}\n'
# stats_str += f'role: {role}\n'
#
# await bot.reply_to(
# message, stats_str)
#
# @bot.message_handler(commands=['donate'])
# async def donation_info(message):
# await bot.reply_to(
# message, DONATION_INFO)
#
# @bot.message_handler(commands=['say'])
# async def say(message):
# chat = message.chat
# user = message.from_user
#
# if (chat.type == 'group') or (user.id != 383385940):
# return
#
# await bot.send_message(GROUP_ID, message.text[4:])
# generic txt2img handler
# async def _generic_txt2img(message_or_query):
# if isinstance(message_or_query, CallbackQuery):
# query = message_or_query
# message = query.message
# user = query.from_user
# chat = query.message.chat
#
# else:
# message = message_or_query
# user = message.from_user
# chat = message.chat
#
# reply_id = None
# if chat.type == 'group' and chat.id == GROUP_ID:
# reply_id = message.message_id
#
# user_row = await db_call('get_or_create_user', user.id)
#
# # init new msg
# init_msg = 'started processing txt2img request...'
# status_msg = await bot.reply_to(message, init_msg)
# await db_call(
# 'new_user_request', user.id, message.id, status_msg.id, status=init_msg)
#
# prompt = ' '.join(message.text.split(' ')[1:])
#
# if len(prompt) == 0:
# await bot.edit_message_text(
# 'Empty text prompt ignored.',
# chat_id=status_msg.chat.id,
# message_id=status_msg.id
# )
# await db_call('update_user_request', status_msg.id, 'Empty text prompt ignored.')
# return
#
# logging.info(f'mid: {message.id}')
#
# user_config = {**user_row}
# del user_config['id']
#
# params = {
# 'prompt': prompt,
# **user_config
# }
#
# await db_call(
# 'update_user_stats', user.id, 'txt2img', last_prompt=prompt)
#
# ec = await work_request(user, status_msg, 'txt2img', params)
# if ec == 0:
# await db_call('increment_generated', user.id)
#
#
# # generic img2img handler
#
# async def _generic_img2img(message_or_query):
# if isinstance(message_or_query, CallbackQuery):
# query = message_or_query
# message = query.message
# user = query.from_user
# chat = query.message.chat
#
# else:
# message = message_or_query
# user = message.from_user
# chat = message.chat
#
# reply_id = None
# if chat.type == 'group' and chat.id == GROUP_ID:
# reply_id = message.message_id
#
# user_row = await db_call('get_or_create_user', user.id)
#
# # init new msg
# init_msg = 'started processing txt2img request...'
# status_msg = await bot.reply_to(message, init_msg)
# await db_call(
# 'new_user_request', user.id, message.id, status_msg.id, status=init_msg)
#
# if not message.caption.startswith('/img2img'):
# await bot.reply_to(
# message,
# 'For image to image you need to add /img2img to the beggining of your caption'
# )
# return
#
# prompt = ' '.join(message.caption.split(' ')[1:])
#
# if len(prompt) == 0:
# await bot.reply_to(message, 'Empty text prompt ignored.')
# return
#
# file_id = message.photo[-1].file_id
# file_path = (await bot.get_file(file_id)).file_path
# image_raw = await bot.download_file(file_path)
# with Image.open(io.BytesIO(image_raw)) as image:
# w, h = image.size
#
# if w > 512 or h > 512:
# logging.warning(f'user sent img of size {image.size}')
# image.thumbnail((512, 512))
# logging.warning(f'resized it to {image.size}')
#
# image.save(f'ipfs-docker-staging/image.png', format='PNG')
#
# ipfs_hash = ipfs_node.add('image.png')
# ipfs_node.pin(ipfs_hash)
#
# logging.info(f'published input image {ipfs_hash} on ipfs')
#
# logging.info(f'mid: {message.id}')
#
# user_config = {**user_row}
# del user_config['id']
#
# params = {
# 'prompt': prompt,
# **user_config
# }
#
# await db_call(
# 'update_user_stats',
# user.id,
# 'img2img',
# last_file=file_id,
# last_prompt=prompt,
# last_binary=ipfs_hash
# )
#
# ec = await work_request(
# user, status_msg, 'img2img', params,
# file_id=file_id,
# binary_data=ipfs_hash
# )
#
# if ec == 0:
# await db_call('increment_generated', user.id)
#
# generic redo handler
# async def _redo(message_or_query):
# is_query = False
# if isinstance(message_or_query, CallbackQuery):
# is_query = True
# query = message_or_query
# message = query.message
# user = query.from_user
# chat = query.message.chat
#
# elif isinstance(message_or_query, Message):
# message = message_or_query
# user = message.from_user
# chat = message.chat
#
# init_msg = 'started processing redo request...'
# if is_query:
# status_msg = await bot.send_message(chat.id, init_msg)
#
# else:
# status_msg = await bot.reply_to(message, init_msg)
#
# method = await db_call('get_last_method_of', user.id)
# prompt = await db_call('get_last_prompt_of', user.id)
#
# file_id = None
# binary = ''
# if method == 'img2img':
# file_id = await db_call('get_last_file_of', user.id)
# binary = await db_call('get_last_binary_of', user.id)
#
# if not prompt:
# await bot.reply_to(
# message,
# 'no last prompt found, do a txt2img cmd first!'
# )
# return
#
#
# user_row = await db_call('get_or_create_user', user.id)
# await db_call(
# 'new_user_request', user.id, message.id, status_msg.id, status=init_msg)
# user_config = {**user_row}
# del user_config['id']
#
# params = {
# 'prompt': prompt,
# **user_config
# }
#
# await work_request(
# user, status_msg, 'redo', params,
# file_id=file_id,
# binary_data=binary
# )
# "proxy" handlers just request routers
# @bot.message_handler(commands=['txt2img'])
# async def send_txt2img(message):
# await _generic_txt2img(message)
#
# @bot.message_handler(func=lambda message: True, content_types=[
# 'photo', 'document'])
# async def send_img2img(message):
# await _generic_img2img(message)
#
# @bot.message_handler(commands=['img2img'])
# async def img2img_missing_image(message):
# await bot.reply_to(
# message,
# 'seems you tried to do an img2img command without sending image'
# )
#
# @bot.message_handler(commands=['redo'])
# async def redo(message):
# await _redo(message)
#
# @bot.callback_query_handler(func=lambda call: True)
# async def callback_query(call):
# msg = json.loads(call.data)
# logging.info(call.data)
# method = msg.get('method')
# match method:
# case 'redo':
# await _redo(call)
# catch all handler for things we dont support
# @bot.message_handler(func=lambda message: True)
# async def echo_message(message):
# if message.text[0] == '/':
# await bot.reply_to(message, UNKNOWN_CMD_TEXT)