add and finalize buttons

pull/11/head
Konstantine Tsafatinos 2023-07-20 20:54:59 -04:00
parent 2440fe32db
commit 53ed74e9a3
4 changed files with 111 additions and 20 deletions

View File

@ -29,6 +29,7 @@ from .bot import DiscordBot
from .utils import * from .utils import *
from .handlers import create_handler_context from .handlers import create_handler_context
from .ui import SkynetView
class SkynetDiscordFrontend: class SkynetDiscordFrontend:
@ -260,7 +261,7 @@ class SkynetDiscordFrontend:
# attempt to get the image and send it # attempt to get the image and send it
ipfs_link = f'https://ipfs.{DEFAULT_DOMAIN}/ipfs/{ipfs_hash}/image.png' ipfs_link = f'https://ipfs.{DEFAULT_DOMAIN}/ipfs/{ipfs_hash}/image.png'
resp = await get_ipfs_file(ipfs_link) resp = await get_ipfs_file(ipfs_link)
caption, embed = generate_reply_caption( caption, embed = generate_reply_caption(
user, params, tx_hash, worker, reward) user, params, tx_hash, worker, reward)
@ -305,4 +306,4 @@ class SkynetDiscordFrontend:
embed.set_image(url=ipfs_link) embed.set_image(url=ipfs_link)
embed.add_field(name='Parameters:', value=caption) embed.add_field(name='Parameters:', value=caption)
await send(embed=embed) await send(embed=embed, view=SkynetView(self))

View File

@ -57,8 +57,8 @@ class DiscordBot(commands.Bot):
elif message.author == self.user: elif message.author == self.user:
return return
await self.process_commands(message) await self.process_commands(message)
await asyncio.sleep(3) # await asyncio.sleep(3)
await message.channel.send('', view=SkynetView(self.bot)) # await message.channel.send('', view=SkynetView(self.bot))
async def on_command_error(self, ctx, error): async def on_command_error(self, ctx, error):
if isinstance(error, commands.MissingRequiredArgument): if isinstance(error, commands.MissingRequiredArgument):

View File

@ -11,6 +11,7 @@ from PIL import Image
from skynet.frontend import validate_user_config_request from skynet.frontend import validate_user_config_request
from skynet.constants import * from skynet.constants import *
from .ui import SkynetView
def create_handler_context(frontend: 'SkynetDiscordFrontend'): def create_handler_context(frontend: 'SkynetDiscordFrontend'):
@ -38,26 +39,26 @@ def create_handler_context(frontend: 'SkynetDiscordFrontend'):
reply_txt = str(e) reply_txt = str(e)
finally: finally:
await ctx.reply(content=reply_txt) await ctx.reply(content=reply_txt, view=SkynetView(frontend))
@bot.command(name='helper', help='Responds with a help') @bot.command(name='helper', help='Responds with a help')
async def helper(ctx): async def helper(ctx):
splt_msg = ctx.message.content.split(' ') splt_msg = ctx.message.content.split(' ')
if len(splt_msg) == 1: if len(splt_msg) == 1:
await ctx.reply(content=HELP_TEXT) await ctx.reply(content=HELP_TEXT, view=SkynetView(frontend))
else: else:
param = splt_msg[1] param = splt_msg[1]
if param in HELP_TOPICS: if param in HELP_TOPICS:
await ctx.reply(content=HELP_TOPICS[param]) await ctx.reply(content=HELP_TOPICS[param], view=SkynetView(frontend))
else: else:
await ctx.reply(content=HELP_UNKWNOWN_PARAM) await ctx.reply(content=HELP_UNKWNOWN_PARAM, view=SkynetView(frontend))
@bot.command(name='cool', help='Display a list of cool prompt words') @bot.command(name='cool', help='Display a list of cool prompt words')
async def send_cool_words(ctx): async def send_cool_words(ctx):
await ctx.reply(content='\n'.join(CLEAN_COOL_WORDS)) await ctx.reply(content='\n'.join(CLEAN_COOL_WORDS), view=SkynetView(frontend))
@bot.command(name='txt2img', help='Responds with an image') @bot.command(name='txt2img', help='Responds with an image')
async def send_txt2img(ctx): async def send_txt2img(ctx):
@ -94,7 +95,7 @@ def create_handler_context(frontend: 'SkynetDiscordFrontend'):
await db_call( await db_call(
'update_user_stats', user.id, 'txt2img', last_prompt=prompt) 'update_user_stats', user.id, 'txt2img', last_prompt=prompt)
ec = await work_request(user.name, status_msg, 'txt2img', params, ctx) ec = await work_request(user, status_msg, 'txt2img', params, ctx)
if ec == 0: if ec == 0:
await db_call('increment_generated', user.id) await db_call('increment_generated', user.id)

View File

@ -1,6 +1,7 @@
import discord import discord
import logging import logging
from skynet.constants import * from skynet.constants import *
from skynet.frontend import validate_user_config_request
class SkynetView(discord.ui.View): class SkynetView(discord.ui.View):
@ -8,20 +9,23 @@ class SkynetView(discord.ui.View):
def __init__(self, bot): def __init__(self, bot):
self.bot = bot self.bot = bot
super().__init__(timeout=None) super().__init__(timeout=None)
self.add_item(Txt2ImgButton('Txt2Img', discord.ButtonStyle.green, self.bot)) self.add_item(RedoButton('redo', discord.ButtonStyle.green, self.bot))
self.add_item(HelpButton('Help', discord.ButtonStyle.grey)) self.add_item(Txt2ImgButton('txt2img', discord.ButtonStyle.green, self.bot))
self.add_item(ConfigButton('config', discord.ButtonStyle.grey, self.bot))
self.add_item(HelpButton('help', discord.ButtonStyle.grey, self.bot))
self.add_item(CoolButton('cool', discord.ButtonStyle.gray, self.bot))
class Txt2ImgButton(discord.ui.Button): class Txt2ImgButton(discord.ui.Button):
def __init__(self, label:str, style:discord.ButtonStyle, bot): def __init__(self, label: str, style: discord.ButtonStyle, bot):
self.bot = bot self.bot = bot
super().__init__(label=label, style = style) super().__init__(label=label, style=style)
async def callback(self, interaction): async def callback(self, interaction):
db_call = self.bot.db_call db_call = self.bot.db_call
work_request = self.bot.work_request work_request = self.bot.work_request
msg = await grab('Text Prompt:', interaction) msg = await grab('Enter your prompt:', interaction)
# grab user from msg # grab user from msg
user = msg.author user = msg.author
user_row = await db_call('get_or_create_user', user.id) user_row = await db_call('get_or_create_user', user.id)
@ -60,10 +64,95 @@ class Txt2ImgButton(discord.ui.Button):
await db_call('increment_generated', user.id) await db_call('increment_generated', user.id)
class RedoButton(discord.ui.Button):
def __init__(self, label: str, style: discord.ButtonStyle, bot):
self.bot = bot
super().__init__(label=label, style=style)
async def callback(self, interaction):
db_call = self.bot.db_call
work_request = self.bot.work_request
init_msg = 'started processing redo request...'
await interaction.response.send_message(init_msg)
status_msg = await interaction.original_response()
user = interaction.user
method = await db_call('get_last_method_of', user.id)
prompt = await db_call('get_last_prompt_of', user.id)
file_id = None
binary = ''
if method == 'img2img':
file_id = await db_call('get_last_file_of', user.id)
binary = await db_call('get_last_binary_of', user.id)
if not prompt:
await interaction.response.edit_message(
'no last prompt found, do a txt2img cmd first!'
)
return
user_row = await db_call('get_or_create_user', user.id)
await db_call(
'new_user_request', user.id, interaction.id, status_msg.id, status=init_msg)
user_config = {**user_row}
del user_config['id']
params = {
'prompt': prompt,
**user_config
}
await work_request(
user, status_msg, 'redo', params, interaction,
file_id=file_id,
binary_data=binary
)
class ConfigButton(discord.ui.Button):
def __init__(self, label: str, style: discord.ButtonStyle, bot):
self.bot = bot
super().__init__(label=label, style=style)
async def callback(self, interaction):
db_call = self.bot.db_call
msg = await grab('What params do you want to change? (format: <param> <value>)', interaction)
user = interaction.user
try:
attr, val, reply_txt = validate_user_config_request(
'/config ' + msg.content)
logging.info(f'user config update: {attr} to {val}')
await db_call('update_user_config', user.id, attr, val)
logging.info('done')
except BaseException as e:
reply_txt = str(e)
finally:
await msg.reply(content=reply_txt, view=SkynetView(self.bot))
class CoolButton(discord.ui.Button):
def __init__(self, label: str, style: discord.ButtonStyle, bot):
self.bot = bot
super().__init__(label=label, style=style)
async def callback(self, interaction):
await interaction.response.send_message(
content='\n'.join(CLEAN_COOL_WORDS),
view=SkynetView(self.bot))
class HelpButton(discord.ui.Button): class HelpButton(discord.ui.Button):
def __init__(self, label:str, style:discord.ButtonStyle): def __init__(self, label: str, style: discord.ButtonStyle, bot):
super().__init__(label=label, style = style) self.bot = bot
super().__init__(label=label, style=style)
async def callback(self, interaction): async def callback(self, interaction):
msg = await grab('What would you like help with? (a for all)', interaction) msg = await grab('What would you like help with? (a for all)', interaction)
@ -71,14 +160,14 @@ class HelpButton(discord.ui.Button):
param = msg.content param = msg.content
if param == 'a': if param == 'a':
await msg.reply(content=HELP_TEXT) await msg.reply(content=HELP_TEXT, view=SkynetView(self.bot))
else: else:
if param in HELP_TOPICS: if param in HELP_TOPICS:
await msg.reply(content=HELP_TOPICS[param]) await msg.reply(content=HELP_TOPICS[param], view=SkynetView(self.bot))
else: else:
await msg.reply(content=HELP_UNKWNOWN_PARAM) await msg.reply(content=HELP_UNKWNOWN_PARAM, view=SkynetView(self.bot))