diff --git a/skynet/frontend/discord/__init__.py b/skynet/frontend/discord/__init__.py
index 29809e8..50878d3 100644
--- a/skynet/frontend/discord/__init__.py
+++ b/skynet/frontend/discord/__init__.py
@@ -29,6 +29,7 @@ from .bot import DiscordBot
from .utils import *
from .handlers import create_handler_context
+from .ui import SkynetView
class SkynetDiscordFrontend:
@@ -260,7 +261,7 @@ class SkynetDiscordFrontend:
# attempt to get the image and send it
ipfs_link = f'https://ipfs.{DEFAULT_DOMAIN}/ipfs/{ipfs_hash}/image.png'
resp = await get_ipfs_file(ipfs_link)
-
+
caption, embed = generate_reply_caption(
user, params, tx_hash, worker, reward)
@@ -305,4 +306,4 @@ class SkynetDiscordFrontend:
embed.set_image(url=ipfs_link)
embed.add_field(name='Parameters:', value=caption)
- await send(embed=embed)
+ await send(embed=embed, view=SkynetView(self))
diff --git a/skynet/frontend/discord/bot.py b/skynet/frontend/discord/bot.py
index cde4144..ac8744f 100644
--- a/skynet/frontend/discord/bot.py
+++ b/skynet/frontend/discord/bot.py
@@ -57,8 +57,8 @@ class DiscordBot(commands.Bot):
elif message.author == self.user:
return
await self.process_commands(message)
- await asyncio.sleep(3)
- await message.channel.send('', view=SkynetView(self.bot))
+ # await asyncio.sleep(3)
+ # await message.channel.send('', view=SkynetView(self.bot))
async def on_command_error(self, ctx, error):
if isinstance(error, commands.MissingRequiredArgument):
diff --git a/skynet/frontend/discord/handlers.py b/skynet/frontend/discord/handlers.py
index bc0cf56..dbf19f4 100644
--- a/skynet/frontend/discord/handlers.py
+++ b/skynet/frontend/discord/handlers.py
@@ -11,6 +11,7 @@ from PIL import Image
from skynet.frontend import validate_user_config_request
from skynet.constants import *
+from .ui import SkynetView
def create_handler_context(frontend: 'SkynetDiscordFrontend'):
@@ -38,26 +39,26 @@ def create_handler_context(frontend: 'SkynetDiscordFrontend'):
reply_txt = str(e)
finally:
- await ctx.reply(content=reply_txt)
+ await ctx.reply(content=reply_txt, view=SkynetView(frontend))
@bot.command(name='helper', help='Responds with a help')
async def helper(ctx):
splt_msg = ctx.message.content.split(' ')
if len(splt_msg) == 1:
- await ctx.reply(content=HELP_TEXT)
+ await ctx.reply(content=HELP_TEXT, view=SkynetView(frontend))
else:
param = splt_msg[1]
if param in HELP_TOPICS:
- await ctx.reply(content=HELP_TOPICS[param])
+ await ctx.reply(content=HELP_TOPICS[param], view=SkynetView(frontend))
else:
- await ctx.reply(content=HELP_UNKWNOWN_PARAM)
+ await ctx.reply(content=HELP_UNKWNOWN_PARAM, view=SkynetView(frontend))
@bot.command(name='cool', help='Display a list of cool prompt words')
async def send_cool_words(ctx):
- await ctx.reply(content='\n'.join(CLEAN_COOL_WORDS))
+ await ctx.reply(content='\n'.join(CLEAN_COOL_WORDS), view=SkynetView(frontend))
@bot.command(name='txt2img', help='Responds with an image')
async def send_txt2img(ctx):
@@ -94,7 +95,7 @@ def create_handler_context(frontend: 'SkynetDiscordFrontend'):
await db_call(
'update_user_stats', user.id, 'txt2img', last_prompt=prompt)
- ec = await work_request(user.name, status_msg, 'txt2img', params, ctx)
+ ec = await work_request(user, status_msg, 'txt2img', params, ctx)
if ec == 0:
await db_call('increment_generated', user.id)
diff --git a/skynet/frontend/discord/ui.py b/skynet/frontend/discord/ui.py
index cb54d5c..0f9452e 100644
--- a/skynet/frontend/discord/ui.py
+++ b/skynet/frontend/discord/ui.py
@@ -1,6 +1,7 @@
import discord
import logging
from skynet.constants import *
+from skynet.frontend import validate_user_config_request
class SkynetView(discord.ui.View):
@@ -8,20 +9,23 @@ class SkynetView(discord.ui.View):
def __init__(self, bot):
self.bot = bot
super().__init__(timeout=None)
- self.add_item(Txt2ImgButton('Txt2Img', discord.ButtonStyle.green, self.bot))
- self.add_item(HelpButton('Help', discord.ButtonStyle.grey))
+ self.add_item(RedoButton('redo', discord.ButtonStyle.green, self.bot))
+ self.add_item(Txt2ImgButton('txt2img', discord.ButtonStyle.green, self.bot))
+ self.add_item(ConfigButton('config', discord.ButtonStyle.grey, self.bot))
+ self.add_item(HelpButton('help', discord.ButtonStyle.grey, self.bot))
+ self.add_item(CoolButton('cool', discord.ButtonStyle.gray, self.bot))
class Txt2ImgButton(discord.ui.Button):
- def __init__(self, label:str, style:discord.ButtonStyle, bot):
+ def __init__(self, label: str, style: discord.ButtonStyle, bot):
self.bot = bot
- super().__init__(label=label, style = style)
+ super().__init__(label=label, style=style)
async def callback(self, interaction):
db_call = self.bot.db_call
work_request = self.bot.work_request
- msg = await grab('Text Prompt:', interaction)
+ msg = await grab('Enter your prompt:', interaction)
# grab user from msg
user = msg.author
user_row = await db_call('get_or_create_user', user.id)
@@ -60,10 +64,95 @@ class Txt2ImgButton(discord.ui.Button):
await db_call('increment_generated', user.id)
+class RedoButton(discord.ui.Button):
+
+ def __init__(self, label: str, style: discord.ButtonStyle, bot):
+ self.bot = bot
+ super().__init__(label=label, style=style)
+
+ async def callback(self, interaction):
+ db_call = self.bot.db_call
+ work_request = self.bot.work_request
+ init_msg = 'started processing redo request...'
+ await interaction.response.send_message(init_msg)
+ status_msg = await interaction.original_response()
+ user = interaction.user
+
+ method = await db_call('get_last_method_of', user.id)
+ prompt = await db_call('get_last_prompt_of', user.id)
+
+ file_id = None
+ binary = ''
+ if method == 'img2img':
+ file_id = await db_call('get_last_file_of', user.id)
+ binary = await db_call('get_last_binary_of', user.id)
+
+ if not prompt:
+ await interaction.response.edit_message(
+ 'no last prompt found, do a txt2img cmd first!'
+ )
+ return
+
+ user_row = await db_call('get_or_create_user', user.id)
+ await db_call(
+ 'new_user_request', user.id, interaction.id, status_msg.id, status=init_msg)
+ user_config = {**user_row}
+ del user_config['id']
+
+ params = {
+ 'prompt': prompt,
+ **user_config
+ }
+ await work_request(
+ user, status_msg, 'redo', params, interaction,
+ file_id=file_id,
+ binary_data=binary
+ )
+
+
+class ConfigButton(discord.ui.Button):
+
+ def __init__(self, label: str, style: discord.ButtonStyle, bot):
+ self.bot = bot
+ super().__init__(label=label, style=style)
+
+ async def callback(self, interaction):
+ db_call = self.bot.db_call
+ msg = await grab('What params do you want to change? (format: )', interaction)
+
+ user = interaction.user
+ try:
+ attr, val, reply_txt = validate_user_config_request(
+ '/config ' + msg.content)
+
+ logging.info(f'user config update: {attr} to {val}')
+ await db_call('update_user_config', user.id, attr, val)
+ logging.info('done')
+
+ except BaseException as e:
+ reply_txt = str(e)
+
+ finally:
+ await msg.reply(content=reply_txt, view=SkynetView(self.bot))
+
+
+class CoolButton(discord.ui.Button):
+
+ def __init__(self, label: str, style: discord.ButtonStyle, bot):
+ self.bot = bot
+ super().__init__(label=label, style=style)
+
+ async def callback(self, interaction):
+ await interaction.response.send_message(
+ content='\n'.join(CLEAN_COOL_WORDS),
+ view=SkynetView(self.bot))
+
+
class HelpButton(discord.ui.Button):
- def __init__(self, label:str, style:discord.ButtonStyle):
- super().__init__(label=label, style = style)
+ def __init__(self, label: str, style: discord.ButtonStyle, bot):
+ self.bot = bot
+ super().__init__(label=label, style=style)
async def callback(self, interaction):
msg = await grab('What would you like help with? (a for all)', interaction)
@@ -71,14 +160,14 @@ class HelpButton(discord.ui.Button):
param = msg.content
if param == 'a':
- await msg.reply(content=HELP_TEXT)
+ await msg.reply(content=HELP_TEXT, view=SkynetView(self.bot))
else:
if param in HELP_TOPICS:
- await msg.reply(content=HELP_TOPICS[param])
+ await msg.reply(content=HELP_TOPICS[param], view=SkynetView(self.bot))
else:
- await msg.reply(content=HELP_UNKWNOWN_PARAM)
+ await msg.reply(content=HELP_UNKWNOWN_PARAM, view=SkynetView(self.bot))