diff --git a/skynet/frontend/discord/__init__.py b/skynet/frontend/discord/__init__.py index 29809e8..50878d3 100644 --- a/skynet/frontend/discord/__init__.py +++ b/skynet/frontend/discord/__init__.py @@ -29,6 +29,7 @@ from .bot import DiscordBot from .utils import * from .handlers import create_handler_context +from .ui import SkynetView class SkynetDiscordFrontend: @@ -260,7 +261,7 @@ class SkynetDiscordFrontend: # attempt to get the image and send it ipfs_link = f'https://ipfs.{DEFAULT_DOMAIN}/ipfs/{ipfs_hash}/image.png' resp = await get_ipfs_file(ipfs_link) - + caption, embed = generate_reply_caption( user, params, tx_hash, worker, reward) @@ -305,4 +306,4 @@ class SkynetDiscordFrontend: embed.set_image(url=ipfs_link) embed.add_field(name='Parameters:', value=caption) - await send(embed=embed) + await send(embed=embed, view=SkynetView(self)) diff --git a/skynet/frontend/discord/bot.py b/skynet/frontend/discord/bot.py index cde4144..ac8744f 100644 --- a/skynet/frontend/discord/bot.py +++ b/skynet/frontend/discord/bot.py @@ -57,8 +57,8 @@ class DiscordBot(commands.Bot): elif message.author == self.user: return await self.process_commands(message) - await asyncio.sleep(3) - await message.channel.send('', view=SkynetView(self.bot)) + # await asyncio.sleep(3) + # await message.channel.send('', view=SkynetView(self.bot)) async def on_command_error(self, ctx, error): if isinstance(error, commands.MissingRequiredArgument): diff --git a/skynet/frontend/discord/handlers.py b/skynet/frontend/discord/handlers.py index bc0cf56..dbf19f4 100644 --- a/skynet/frontend/discord/handlers.py +++ b/skynet/frontend/discord/handlers.py @@ -11,6 +11,7 @@ from PIL import Image from skynet.frontend import validate_user_config_request from skynet.constants import * +from .ui import SkynetView def create_handler_context(frontend: 'SkynetDiscordFrontend'): @@ -38,26 +39,26 @@ def create_handler_context(frontend: 'SkynetDiscordFrontend'): reply_txt = str(e) finally: - await ctx.reply(content=reply_txt) + await ctx.reply(content=reply_txt, view=SkynetView(frontend)) @bot.command(name='helper', help='Responds with a help') async def helper(ctx): splt_msg = ctx.message.content.split(' ') if len(splt_msg) == 1: - await ctx.reply(content=HELP_TEXT) + await ctx.reply(content=HELP_TEXT, view=SkynetView(frontend)) else: param = splt_msg[1] if param in HELP_TOPICS: - await ctx.reply(content=HELP_TOPICS[param]) + await ctx.reply(content=HELP_TOPICS[param], view=SkynetView(frontend)) else: - await ctx.reply(content=HELP_UNKWNOWN_PARAM) + await ctx.reply(content=HELP_UNKWNOWN_PARAM, view=SkynetView(frontend)) @bot.command(name='cool', help='Display a list of cool prompt words') async def send_cool_words(ctx): - await ctx.reply(content='\n'.join(CLEAN_COOL_WORDS)) + await ctx.reply(content='\n'.join(CLEAN_COOL_WORDS), view=SkynetView(frontend)) @bot.command(name='txt2img', help='Responds with an image') async def send_txt2img(ctx): @@ -94,7 +95,7 @@ def create_handler_context(frontend: 'SkynetDiscordFrontend'): await db_call( 'update_user_stats', user.id, 'txt2img', last_prompt=prompt) - ec = await work_request(user.name, status_msg, 'txt2img', params, ctx) + ec = await work_request(user, status_msg, 'txt2img', params, ctx) if ec == 0: await db_call('increment_generated', user.id) diff --git a/skynet/frontend/discord/ui.py b/skynet/frontend/discord/ui.py index cb54d5c..0f9452e 100644 --- a/skynet/frontend/discord/ui.py +++ b/skynet/frontend/discord/ui.py @@ -1,6 +1,7 @@ import discord import logging from skynet.constants import * +from skynet.frontend import validate_user_config_request class SkynetView(discord.ui.View): @@ -8,20 +9,23 @@ class SkynetView(discord.ui.View): def __init__(self, bot): self.bot = bot super().__init__(timeout=None) - self.add_item(Txt2ImgButton('Txt2Img', discord.ButtonStyle.green, self.bot)) - self.add_item(HelpButton('Help', discord.ButtonStyle.grey)) + self.add_item(RedoButton('redo', discord.ButtonStyle.green, self.bot)) + self.add_item(Txt2ImgButton('txt2img', discord.ButtonStyle.green, self.bot)) + self.add_item(ConfigButton('config', discord.ButtonStyle.grey, self.bot)) + self.add_item(HelpButton('help', discord.ButtonStyle.grey, self.bot)) + self.add_item(CoolButton('cool', discord.ButtonStyle.gray, self.bot)) class Txt2ImgButton(discord.ui.Button): - def __init__(self, label:str, style:discord.ButtonStyle, bot): + def __init__(self, label: str, style: discord.ButtonStyle, bot): self.bot = bot - super().__init__(label=label, style = style) + super().__init__(label=label, style=style) async def callback(self, interaction): db_call = self.bot.db_call work_request = self.bot.work_request - msg = await grab('Text Prompt:', interaction) + msg = await grab('Enter your prompt:', interaction) # grab user from msg user = msg.author user_row = await db_call('get_or_create_user', user.id) @@ -60,10 +64,95 @@ class Txt2ImgButton(discord.ui.Button): await db_call('increment_generated', user.id) +class RedoButton(discord.ui.Button): + + def __init__(self, label: str, style: discord.ButtonStyle, bot): + self.bot = bot + super().__init__(label=label, style=style) + + async def callback(self, interaction): + db_call = self.bot.db_call + work_request = self.bot.work_request + init_msg = 'started processing redo request...' + await interaction.response.send_message(init_msg) + status_msg = await interaction.original_response() + user = interaction.user + + method = await db_call('get_last_method_of', user.id) + prompt = await db_call('get_last_prompt_of', user.id) + + file_id = None + binary = '' + if method == 'img2img': + file_id = await db_call('get_last_file_of', user.id) + binary = await db_call('get_last_binary_of', user.id) + + if not prompt: + await interaction.response.edit_message( + 'no last prompt found, do a txt2img cmd first!' + ) + return + + user_row = await db_call('get_or_create_user', user.id) + await db_call( + 'new_user_request', user.id, interaction.id, status_msg.id, status=init_msg) + user_config = {**user_row} + del user_config['id'] + + params = { + 'prompt': prompt, + **user_config + } + await work_request( + user, status_msg, 'redo', params, interaction, + file_id=file_id, + binary_data=binary + ) + + +class ConfigButton(discord.ui.Button): + + def __init__(self, label: str, style: discord.ButtonStyle, bot): + self.bot = bot + super().__init__(label=label, style=style) + + async def callback(self, interaction): + db_call = self.bot.db_call + msg = await grab('What params do you want to change? (format: )', interaction) + + user = interaction.user + try: + attr, val, reply_txt = validate_user_config_request( + '/config ' + msg.content) + + logging.info(f'user config update: {attr} to {val}') + await db_call('update_user_config', user.id, attr, val) + logging.info('done') + + except BaseException as e: + reply_txt = str(e) + + finally: + await msg.reply(content=reply_txt, view=SkynetView(self.bot)) + + +class CoolButton(discord.ui.Button): + + def __init__(self, label: str, style: discord.ButtonStyle, bot): + self.bot = bot + super().__init__(label=label, style=style) + + async def callback(self, interaction): + await interaction.response.send_message( + content='\n'.join(CLEAN_COOL_WORDS), + view=SkynetView(self.bot)) + + class HelpButton(discord.ui.Button): - def __init__(self, label:str, style:discord.ButtonStyle): - super().__init__(label=label, style = style) + def __init__(self, label: str, style: discord.ButtonStyle, bot): + self.bot = bot + super().__init__(label=label, style=style) async def callback(self, interaction): msg = await grab('What would you like help with? (a for all)', interaction) @@ -71,14 +160,14 @@ class HelpButton(discord.ui.Button): param = msg.content if param == 'a': - await msg.reply(content=HELP_TEXT) + await msg.reply(content=HELP_TEXT, view=SkynetView(self.bot)) else: if param in HELP_TOPICS: - await msg.reply(content=HELP_TOPICS[param]) + await msg.reply(content=HELP_TOPICS[param], view=SkynetView(self.bot)) else: - await msg.reply(content=HELP_UNKWNOWN_PARAM) + await msg.reply(content=HELP_UNKWNOWN_PARAM, view=SkynetView(self.bot))