mirror of https://github.com/skygpu/skynet.git
				
				
				
			add and finalize buttons
							parent
							
								
									2440fe32db
								
							
						
					
					
						commit
						53ed74e9a3
					
				| 
						 | 
				
			
			@ -29,6 +29,7 @@ from .bot import DiscordBot
 | 
			
		|||
 | 
			
		||||
from .utils import *
 | 
			
		||||
from .handlers import create_handler_context
 | 
			
		||||
from .ui import SkynetView
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class SkynetDiscordFrontend:
 | 
			
		||||
| 
						 | 
				
			
			@ -260,7 +261,7 @@ class SkynetDiscordFrontend:
 | 
			
		|||
        # attempt to get the image and send it
 | 
			
		||||
        ipfs_link = f'https://ipfs.{DEFAULT_DOMAIN}/ipfs/{ipfs_hash}/image.png'
 | 
			
		||||
        resp = await get_ipfs_file(ipfs_link)
 | 
			
		||||
        
 | 
			
		||||
 | 
			
		||||
        caption, embed = generate_reply_caption(
 | 
			
		||||
            user, params, tx_hash, worker, reward)
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -305,4 +306,4 @@ class SkynetDiscordFrontend:
 | 
			
		|||
 | 
			
		||||
                embed.set_image(url=ipfs_link)
 | 
			
		||||
                embed.add_field(name='Parameters:', value=caption)
 | 
			
		||||
                await send(embed=embed)
 | 
			
		||||
                await send(embed=embed, view=SkynetView(self))
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -57,8 +57,8 @@ class DiscordBot(commands.Bot):
 | 
			
		|||
        elif message.author == self.user:
 | 
			
		||||
            return
 | 
			
		||||
        await self.process_commands(message)
 | 
			
		||||
        await asyncio.sleep(3)
 | 
			
		||||
        await message.channel.send('', view=SkynetView(self.bot))
 | 
			
		||||
        # await asyncio.sleep(3)
 | 
			
		||||
        # await message.channel.send('', view=SkynetView(self.bot))
 | 
			
		||||
 | 
			
		||||
    async def on_command_error(self, ctx, error):
 | 
			
		||||
        if isinstance(error, commands.MissingRequiredArgument):
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -11,6 +11,7 @@ from PIL import Image
 | 
			
		|||
 | 
			
		||||
from skynet.frontend import validate_user_config_request
 | 
			
		||||
from skynet.constants import *
 | 
			
		||||
from .ui import SkynetView
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def create_handler_context(frontend: 'SkynetDiscordFrontend'):
 | 
			
		||||
| 
						 | 
				
			
			@ -38,26 +39,26 @@ def create_handler_context(frontend: 'SkynetDiscordFrontend'):
 | 
			
		|||
            reply_txt = str(e)
 | 
			
		||||
 | 
			
		||||
        finally:
 | 
			
		||||
            await ctx.reply(content=reply_txt)
 | 
			
		||||
            await ctx.reply(content=reply_txt, view=SkynetView(frontend))
 | 
			
		||||
 | 
			
		||||
    @bot.command(name='helper', help='Responds with a help')
 | 
			
		||||
    async def helper(ctx):
 | 
			
		||||
        splt_msg = ctx.message.content.split(' ')
 | 
			
		||||
 | 
			
		||||
        if len(splt_msg) == 1:
 | 
			
		||||
            await ctx.reply(content=HELP_TEXT)
 | 
			
		||||
            await ctx.reply(content=HELP_TEXT, view=SkynetView(frontend))
 | 
			
		||||
 | 
			
		||||
        else:
 | 
			
		||||
            param = splt_msg[1]
 | 
			
		||||
            if param in HELP_TOPICS:
 | 
			
		||||
                await ctx.reply(content=HELP_TOPICS[param])
 | 
			
		||||
                await ctx.reply(content=HELP_TOPICS[param], view=SkynetView(frontend))
 | 
			
		||||
 | 
			
		||||
            else:
 | 
			
		||||
                await ctx.reply(content=HELP_UNKWNOWN_PARAM)
 | 
			
		||||
                await ctx.reply(content=HELP_UNKWNOWN_PARAM, view=SkynetView(frontend))
 | 
			
		||||
 | 
			
		||||
    @bot.command(name='cool', help='Display a list of cool prompt words')
 | 
			
		||||
    async def send_cool_words(ctx):
 | 
			
		||||
        await ctx.reply(content='\n'.join(CLEAN_COOL_WORDS))
 | 
			
		||||
        await ctx.reply(content='\n'.join(CLEAN_COOL_WORDS), view=SkynetView(frontend))
 | 
			
		||||
 | 
			
		||||
    @bot.command(name='txt2img', help='Responds with an image')
 | 
			
		||||
    async def send_txt2img(ctx):
 | 
			
		||||
| 
						 | 
				
			
			@ -94,7 +95,7 @@ def create_handler_context(frontend: 'SkynetDiscordFrontend'):
 | 
			
		|||
        await db_call(
 | 
			
		||||
            'update_user_stats', user.id, 'txt2img', last_prompt=prompt)
 | 
			
		||||
 | 
			
		||||
        ec = await work_request(user.name, status_msg, 'txt2img', params, ctx)
 | 
			
		||||
        ec = await work_request(user, status_msg, 'txt2img', params, ctx)
 | 
			
		||||
 | 
			
		||||
        if ec == 0:
 | 
			
		||||
            await db_call('increment_generated', user.id)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -1,6 +1,7 @@
 | 
			
		|||
import discord
 | 
			
		||||
import logging
 | 
			
		||||
from skynet.constants import *
 | 
			
		||||
from skynet.frontend import validate_user_config_request
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class SkynetView(discord.ui.View):
 | 
			
		||||
| 
						 | 
				
			
			@ -8,20 +9,23 @@ class SkynetView(discord.ui.View):
 | 
			
		|||
    def __init__(self, bot):
 | 
			
		||||
        self.bot = bot
 | 
			
		||||
        super().__init__(timeout=None)
 | 
			
		||||
        self.add_item(Txt2ImgButton('Txt2Img', discord.ButtonStyle.green, self.bot))
 | 
			
		||||
        self.add_item(HelpButton('Help', discord.ButtonStyle.grey))
 | 
			
		||||
        self.add_item(RedoButton('redo', discord.ButtonStyle.green, self.bot))
 | 
			
		||||
        self.add_item(Txt2ImgButton('txt2img', discord.ButtonStyle.green, self.bot))
 | 
			
		||||
        self.add_item(ConfigButton('config', discord.ButtonStyle.grey, self.bot))
 | 
			
		||||
        self.add_item(HelpButton('help', discord.ButtonStyle.grey, self.bot))
 | 
			
		||||
        self.add_item(CoolButton('cool', discord.ButtonStyle.gray, self.bot))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Txt2ImgButton(discord.ui.Button):
 | 
			
		||||
 | 
			
		||||
    def __init__(self, label:str, style:discord.ButtonStyle, bot):
 | 
			
		||||
    def __init__(self, label: str, style: discord.ButtonStyle, bot):
 | 
			
		||||
        self.bot = bot
 | 
			
		||||
        super().__init__(label=label, style = style)
 | 
			
		||||
        super().__init__(label=label, style=style)
 | 
			
		||||
 | 
			
		||||
    async def callback(self, interaction):
 | 
			
		||||
        db_call = self.bot.db_call
 | 
			
		||||
        work_request = self.bot.work_request
 | 
			
		||||
        msg = await grab('Text Prompt:', interaction)
 | 
			
		||||
        msg = await grab('Enter your prompt:', interaction)
 | 
			
		||||
        # grab user from msg
 | 
			
		||||
        user = msg.author
 | 
			
		||||
        user_row = await db_call('get_or_create_user', user.id)
 | 
			
		||||
| 
						 | 
				
			
			@ -60,10 +64,95 @@ class Txt2ImgButton(discord.ui.Button):
 | 
			
		|||
            await db_call('increment_generated', user.id)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class RedoButton(discord.ui.Button):
 | 
			
		||||
 | 
			
		||||
    def __init__(self, label: str, style: discord.ButtonStyle, bot):
 | 
			
		||||
        self.bot = bot
 | 
			
		||||
        super().__init__(label=label, style=style)
 | 
			
		||||
 | 
			
		||||
    async def callback(self, interaction):
 | 
			
		||||
        db_call = self.bot.db_call
 | 
			
		||||
        work_request = self.bot.work_request
 | 
			
		||||
        init_msg = 'started processing redo request...'
 | 
			
		||||
        await interaction.response.send_message(init_msg)
 | 
			
		||||
        status_msg = await interaction.original_response()
 | 
			
		||||
        user = interaction.user
 | 
			
		||||
 | 
			
		||||
        method = await db_call('get_last_method_of', user.id)
 | 
			
		||||
        prompt = await db_call('get_last_prompt_of', user.id)
 | 
			
		||||
 | 
			
		||||
        file_id = None
 | 
			
		||||
        binary = ''
 | 
			
		||||
        if method == 'img2img':
 | 
			
		||||
            file_id = await db_call('get_last_file_of', user.id)
 | 
			
		||||
            binary = await db_call('get_last_binary_of', user.id)
 | 
			
		||||
 | 
			
		||||
        if not prompt:
 | 
			
		||||
            await interaction.response.edit_message(
 | 
			
		||||
                'no last prompt found, do a txt2img cmd first!'
 | 
			
		||||
            )
 | 
			
		||||
            return
 | 
			
		||||
 | 
			
		||||
        user_row = await db_call('get_or_create_user', user.id)
 | 
			
		||||
        await db_call(
 | 
			
		||||
            'new_user_request', user.id, interaction.id, status_msg.id, status=init_msg)
 | 
			
		||||
        user_config = {**user_row}
 | 
			
		||||
        del user_config['id']
 | 
			
		||||
 | 
			
		||||
        params = {
 | 
			
		||||
            'prompt': prompt,
 | 
			
		||||
            **user_config
 | 
			
		||||
        }
 | 
			
		||||
        await work_request(
 | 
			
		||||
            user, status_msg, 'redo', params, interaction,
 | 
			
		||||
            file_id=file_id,
 | 
			
		||||
            binary_data=binary
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ConfigButton(discord.ui.Button):
 | 
			
		||||
 | 
			
		||||
    def __init__(self, label: str, style: discord.ButtonStyle, bot):
 | 
			
		||||
        self.bot = bot
 | 
			
		||||
        super().__init__(label=label, style=style)
 | 
			
		||||
 | 
			
		||||
    async def callback(self, interaction):
 | 
			
		||||
        db_call = self.bot.db_call
 | 
			
		||||
        msg = await grab('What params do you want to change? (format: <param> <value>)', interaction)
 | 
			
		||||
 | 
			
		||||
        user = interaction.user
 | 
			
		||||
        try:
 | 
			
		||||
            attr, val, reply_txt = validate_user_config_request(
 | 
			
		||||
                '/config ' + msg.content)
 | 
			
		||||
 | 
			
		||||
            logging.info(f'user config update: {attr} to {val}')
 | 
			
		||||
            await db_call('update_user_config', user.id, attr, val)
 | 
			
		||||
            logging.info('done')
 | 
			
		||||
 | 
			
		||||
        except BaseException as e:
 | 
			
		||||
            reply_txt = str(e)
 | 
			
		||||
 | 
			
		||||
        finally:
 | 
			
		||||
            await msg.reply(content=reply_txt, view=SkynetView(self.bot))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class CoolButton(discord.ui.Button):
 | 
			
		||||
 | 
			
		||||
    def __init__(self, label: str, style: discord.ButtonStyle, bot):
 | 
			
		||||
        self.bot = bot
 | 
			
		||||
        super().__init__(label=label, style=style)
 | 
			
		||||
 | 
			
		||||
    async def callback(self, interaction):
 | 
			
		||||
        await interaction.response.send_message(
 | 
			
		||||
            content='\n'.join(CLEAN_COOL_WORDS),
 | 
			
		||||
            view=SkynetView(self.bot))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class HelpButton(discord.ui.Button):
 | 
			
		||||
 | 
			
		||||
    def __init__(self, label:str, style:discord.ButtonStyle):
 | 
			
		||||
        super().__init__(label=label, style = style)
 | 
			
		||||
    def __init__(self, label: str, style: discord.ButtonStyle, bot):
 | 
			
		||||
        self.bot = bot
 | 
			
		||||
        super().__init__(label=label, style=style)
 | 
			
		||||
 | 
			
		||||
    async def callback(self, interaction):
 | 
			
		||||
        msg = await grab('What would you like help with? (a for all)', interaction)
 | 
			
		||||
| 
						 | 
				
			
			@ -71,14 +160,14 @@ class HelpButton(discord.ui.Button):
 | 
			
		|||
        param = msg.content
 | 
			
		||||
 | 
			
		||||
        if param == 'a':
 | 
			
		||||
            await msg.reply(content=HELP_TEXT)
 | 
			
		||||
            await msg.reply(content=HELP_TEXT, view=SkynetView(self.bot))
 | 
			
		||||
 | 
			
		||||
        else:
 | 
			
		||||
            if param in HELP_TOPICS:
 | 
			
		||||
                await msg.reply(content=HELP_TOPICS[param])
 | 
			
		||||
                await msg.reply(content=HELP_TOPICS[param], view=SkynetView(self.bot))
 | 
			
		||||
 | 
			
		||||
            else:
 | 
			
		||||
                await msg.reply(content=HELP_UNKWNOWN_PARAM)
 | 
			
		||||
                await msg.reply(content=HELP_UNKWNOWN_PARAM, view=SkynetView(self.bot))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue