mirror of https://github.com/skygpu/skynet.git
add and finalize buttons
parent
2440fe32db
commit
53ed74e9a3
|
@ -29,6 +29,7 @@ from .bot import DiscordBot
|
||||||
|
|
||||||
from .utils import *
|
from .utils import *
|
||||||
from .handlers import create_handler_context
|
from .handlers import create_handler_context
|
||||||
|
from .ui import SkynetView
|
||||||
|
|
||||||
|
|
||||||
class SkynetDiscordFrontend:
|
class SkynetDiscordFrontend:
|
||||||
|
@ -305,4 +306,4 @@ class SkynetDiscordFrontend:
|
||||||
|
|
||||||
embed.set_image(url=ipfs_link)
|
embed.set_image(url=ipfs_link)
|
||||||
embed.add_field(name='Parameters:', value=caption)
|
embed.add_field(name='Parameters:', value=caption)
|
||||||
await send(embed=embed)
|
await send(embed=embed, view=SkynetView(self))
|
||||||
|
|
|
@ -57,8 +57,8 @@ class DiscordBot(commands.Bot):
|
||||||
elif message.author == self.user:
|
elif message.author == self.user:
|
||||||
return
|
return
|
||||||
await self.process_commands(message)
|
await self.process_commands(message)
|
||||||
await asyncio.sleep(3)
|
# await asyncio.sleep(3)
|
||||||
await message.channel.send('', view=SkynetView(self.bot))
|
# await message.channel.send('', view=SkynetView(self.bot))
|
||||||
|
|
||||||
async def on_command_error(self, ctx, error):
|
async def on_command_error(self, ctx, error):
|
||||||
if isinstance(error, commands.MissingRequiredArgument):
|
if isinstance(error, commands.MissingRequiredArgument):
|
||||||
|
|
|
@ -11,6 +11,7 @@ from PIL import Image
|
||||||
|
|
||||||
from skynet.frontend import validate_user_config_request
|
from skynet.frontend import validate_user_config_request
|
||||||
from skynet.constants import *
|
from skynet.constants import *
|
||||||
|
from .ui import SkynetView
|
||||||
|
|
||||||
|
|
||||||
def create_handler_context(frontend: 'SkynetDiscordFrontend'):
|
def create_handler_context(frontend: 'SkynetDiscordFrontend'):
|
||||||
|
@ -38,26 +39,26 @@ def create_handler_context(frontend: 'SkynetDiscordFrontend'):
|
||||||
reply_txt = str(e)
|
reply_txt = str(e)
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
await ctx.reply(content=reply_txt)
|
await ctx.reply(content=reply_txt, view=SkynetView(frontend))
|
||||||
|
|
||||||
@bot.command(name='helper', help='Responds with a help')
|
@bot.command(name='helper', help='Responds with a help')
|
||||||
async def helper(ctx):
|
async def helper(ctx):
|
||||||
splt_msg = ctx.message.content.split(' ')
|
splt_msg = ctx.message.content.split(' ')
|
||||||
|
|
||||||
if len(splt_msg) == 1:
|
if len(splt_msg) == 1:
|
||||||
await ctx.reply(content=HELP_TEXT)
|
await ctx.reply(content=HELP_TEXT, view=SkynetView(frontend))
|
||||||
|
|
||||||
else:
|
else:
|
||||||
param = splt_msg[1]
|
param = splt_msg[1]
|
||||||
if param in HELP_TOPICS:
|
if param in HELP_TOPICS:
|
||||||
await ctx.reply(content=HELP_TOPICS[param])
|
await ctx.reply(content=HELP_TOPICS[param], view=SkynetView(frontend))
|
||||||
|
|
||||||
else:
|
else:
|
||||||
await ctx.reply(content=HELP_UNKWNOWN_PARAM)
|
await ctx.reply(content=HELP_UNKWNOWN_PARAM, view=SkynetView(frontend))
|
||||||
|
|
||||||
@bot.command(name='cool', help='Display a list of cool prompt words')
|
@bot.command(name='cool', help='Display a list of cool prompt words')
|
||||||
async def send_cool_words(ctx):
|
async def send_cool_words(ctx):
|
||||||
await ctx.reply(content='\n'.join(CLEAN_COOL_WORDS))
|
await ctx.reply(content='\n'.join(CLEAN_COOL_WORDS), view=SkynetView(frontend))
|
||||||
|
|
||||||
@bot.command(name='txt2img', help='Responds with an image')
|
@bot.command(name='txt2img', help='Responds with an image')
|
||||||
async def send_txt2img(ctx):
|
async def send_txt2img(ctx):
|
||||||
|
@ -94,7 +95,7 @@ def create_handler_context(frontend: 'SkynetDiscordFrontend'):
|
||||||
await db_call(
|
await db_call(
|
||||||
'update_user_stats', user.id, 'txt2img', last_prompt=prompt)
|
'update_user_stats', user.id, 'txt2img', last_prompt=prompt)
|
||||||
|
|
||||||
ec = await work_request(user.name, status_msg, 'txt2img', params, ctx)
|
ec = await work_request(user, status_msg, 'txt2img', params, ctx)
|
||||||
|
|
||||||
if ec == 0:
|
if ec == 0:
|
||||||
await db_call('increment_generated', user.id)
|
await db_call('increment_generated', user.id)
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
import discord
|
import discord
|
||||||
import logging
|
import logging
|
||||||
from skynet.constants import *
|
from skynet.constants import *
|
||||||
|
from skynet.frontend import validate_user_config_request
|
||||||
|
|
||||||
|
|
||||||
class SkynetView(discord.ui.View):
|
class SkynetView(discord.ui.View):
|
||||||
|
@ -8,8 +9,11 @@ class SkynetView(discord.ui.View):
|
||||||
def __init__(self, bot):
|
def __init__(self, bot):
|
||||||
self.bot = bot
|
self.bot = bot
|
||||||
super().__init__(timeout=None)
|
super().__init__(timeout=None)
|
||||||
self.add_item(Txt2ImgButton('Txt2Img', discord.ButtonStyle.green, self.bot))
|
self.add_item(RedoButton('redo', discord.ButtonStyle.green, self.bot))
|
||||||
self.add_item(HelpButton('Help', discord.ButtonStyle.grey))
|
self.add_item(Txt2ImgButton('txt2img', discord.ButtonStyle.green, self.bot))
|
||||||
|
self.add_item(ConfigButton('config', discord.ButtonStyle.grey, self.bot))
|
||||||
|
self.add_item(HelpButton('help', discord.ButtonStyle.grey, self.bot))
|
||||||
|
self.add_item(CoolButton('cool', discord.ButtonStyle.gray, self.bot))
|
||||||
|
|
||||||
|
|
||||||
class Txt2ImgButton(discord.ui.Button):
|
class Txt2ImgButton(discord.ui.Button):
|
||||||
|
@ -21,7 +25,7 @@ class Txt2ImgButton(discord.ui.Button):
|
||||||
async def callback(self, interaction):
|
async def callback(self, interaction):
|
||||||
db_call = self.bot.db_call
|
db_call = self.bot.db_call
|
||||||
work_request = self.bot.work_request
|
work_request = self.bot.work_request
|
||||||
msg = await grab('Text Prompt:', interaction)
|
msg = await grab('Enter your prompt:', interaction)
|
||||||
# grab user from msg
|
# grab user from msg
|
||||||
user = msg.author
|
user = msg.author
|
||||||
user_row = await db_call('get_or_create_user', user.id)
|
user_row = await db_call('get_or_create_user', user.id)
|
||||||
|
@ -60,9 +64,94 @@ class Txt2ImgButton(discord.ui.Button):
|
||||||
await db_call('increment_generated', user.id)
|
await db_call('increment_generated', user.id)
|
||||||
|
|
||||||
|
|
||||||
|
class RedoButton(discord.ui.Button):
|
||||||
|
|
||||||
|
def __init__(self, label: str, style: discord.ButtonStyle, bot):
|
||||||
|
self.bot = bot
|
||||||
|
super().__init__(label=label, style=style)
|
||||||
|
|
||||||
|
async def callback(self, interaction):
|
||||||
|
db_call = self.bot.db_call
|
||||||
|
work_request = self.bot.work_request
|
||||||
|
init_msg = 'started processing redo request...'
|
||||||
|
await interaction.response.send_message(init_msg)
|
||||||
|
status_msg = await interaction.original_response()
|
||||||
|
user = interaction.user
|
||||||
|
|
||||||
|
method = await db_call('get_last_method_of', user.id)
|
||||||
|
prompt = await db_call('get_last_prompt_of', user.id)
|
||||||
|
|
||||||
|
file_id = None
|
||||||
|
binary = ''
|
||||||
|
if method == 'img2img':
|
||||||
|
file_id = await db_call('get_last_file_of', user.id)
|
||||||
|
binary = await db_call('get_last_binary_of', user.id)
|
||||||
|
|
||||||
|
if not prompt:
|
||||||
|
await interaction.response.edit_message(
|
||||||
|
'no last prompt found, do a txt2img cmd first!'
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
user_row = await db_call('get_or_create_user', user.id)
|
||||||
|
await db_call(
|
||||||
|
'new_user_request', user.id, interaction.id, status_msg.id, status=init_msg)
|
||||||
|
user_config = {**user_row}
|
||||||
|
del user_config['id']
|
||||||
|
|
||||||
|
params = {
|
||||||
|
'prompt': prompt,
|
||||||
|
**user_config
|
||||||
|
}
|
||||||
|
await work_request(
|
||||||
|
user, status_msg, 'redo', params, interaction,
|
||||||
|
file_id=file_id,
|
||||||
|
binary_data=binary
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ConfigButton(discord.ui.Button):
|
||||||
|
|
||||||
|
def __init__(self, label: str, style: discord.ButtonStyle, bot):
|
||||||
|
self.bot = bot
|
||||||
|
super().__init__(label=label, style=style)
|
||||||
|
|
||||||
|
async def callback(self, interaction):
|
||||||
|
db_call = self.bot.db_call
|
||||||
|
msg = await grab('What params do you want to change? (format: <param> <value>)', interaction)
|
||||||
|
|
||||||
|
user = interaction.user
|
||||||
|
try:
|
||||||
|
attr, val, reply_txt = validate_user_config_request(
|
||||||
|
'/config ' + msg.content)
|
||||||
|
|
||||||
|
logging.info(f'user config update: {attr} to {val}')
|
||||||
|
await db_call('update_user_config', user.id, attr, val)
|
||||||
|
logging.info('done')
|
||||||
|
|
||||||
|
except BaseException as e:
|
||||||
|
reply_txt = str(e)
|
||||||
|
|
||||||
|
finally:
|
||||||
|
await msg.reply(content=reply_txt, view=SkynetView(self.bot))
|
||||||
|
|
||||||
|
|
||||||
|
class CoolButton(discord.ui.Button):
|
||||||
|
|
||||||
|
def __init__(self, label: str, style: discord.ButtonStyle, bot):
|
||||||
|
self.bot = bot
|
||||||
|
super().__init__(label=label, style=style)
|
||||||
|
|
||||||
|
async def callback(self, interaction):
|
||||||
|
await interaction.response.send_message(
|
||||||
|
content='\n'.join(CLEAN_COOL_WORDS),
|
||||||
|
view=SkynetView(self.bot))
|
||||||
|
|
||||||
|
|
||||||
class HelpButton(discord.ui.Button):
|
class HelpButton(discord.ui.Button):
|
||||||
|
|
||||||
def __init__(self, label:str, style:discord.ButtonStyle):
|
def __init__(self, label: str, style: discord.ButtonStyle, bot):
|
||||||
|
self.bot = bot
|
||||||
super().__init__(label=label, style=style)
|
super().__init__(label=label, style=style)
|
||||||
|
|
||||||
async def callback(self, interaction):
|
async def callback(self, interaction):
|
||||||
|
@ -71,14 +160,14 @@ class HelpButton(discord.ui.Button):
|
||||||
param = msg.content
|
param = msg.content
|
||||||
|
|
||||||
if param == 'a':
|
if param == 'a':
|
||||||
await msg.reply(content=HELP_TEXT)
|
await msg.reply(content=HELP_TEXT, view=SkynetView(self.bot))
|
||||||
|
|
||||||
else:
|
else:
|
||||||
if param in HELP_TOPICS:
|
if param in HELP_TOPICS:
|
||||||
await msg.reply(content=HELP_TOPICS[param])
|
await msg.reply(content=HELP_TOPICS[param], view=SkynetView(self.bot))
|
||||||
|
|
||||||
else:
|
else:
|
||||||
await msg.reply(content=HELP_UNKWNOWN_PARAM)
|
await msg.reply(content=HELP_UNKWNOWN_PARAM, view=SkynetView(self.bot))
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue