mirror of https://github.com/skygpu/skynet.git
				
				
				
			add initial buttons, help and txt2img
							parent
							
								
									ff0114d341
								
							
						
					
					
						commit
						8625b5747b
					
				| 
						 | 
				
			
			@ -57,7 +57,7 @@ class SkynetDiscordFrontend:
 | 
			
		|||
        self.remote_ipfs_node = remote_ipfs_node
 | 
			
		||||
        self.key = key
 | 
			
		||||
 | 
			
		||||
        self.bot = DiscordBot()
 | 
			
		||||
        self.bot = DiscordBot(self)
 | 
			
		||||
        self.cleos = CLEOS(None, None, url=node_url, remote=node_url)
 | 
			
		||||
        self.hyperion = HyperionAPI(hyperion_url)
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -4,6 +4,7 @@ import discord
 | 
			
		|||
# from dotenv import load_dotenv
 | 
			
		||||
# from pathlib import Path
 | 
			
		||||
from discord.ext import commands
 | 
			
		||||
from .ui import SkynetView
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# # Auth
 | 
			
		||||
| 
						 | 
				
			
			@ -18,7 +19,8 @@ from discord.ext import commands
 | 
			
		|||
# Actual Discord bot.
 | 
			
		||||
class DiscordBot(commands.Bot):
 | 
			
		||||
 | 
			
		||||
    def __init__(self, *args, **kwargs):
 | 
			
		||||
    def __init__(self, bot, *args, **kwargs):
 | 
			
		||||
        self.bot = bot
 | 
			
		||||
        intents = discord.Intents(
 | 
			
		||||
            messages=True,
 | 
			
		||||
            guilds=True,
 | 
			
		||||
| 
						 | 
				
			
			@ -39,7 +41,7 @@ class DiscordBot(commands.Bot):
 | 
			
		|||
        for guild in self.guilds:
 | 
			
		||||
            for channel in guild.channels:
 | 
			
		||||
                if channel.name == "skynet":
 | 
			
		||||
                    await channel.send('Skynet bot online')
 | 
			
		||||
                    await channel.send('Skynet bot online', view=SkynetView(self.bot))
 | 
			
		||||
 | 
			
		||||
        print("\n==============")
 | 
			
		||||
        print("Logged in as")
 | 
			
		||||
| 
						 | 
				
			
			@ -48,7 +50,12 @@ class DiscordBot(commands.Bot):
 | 
			
		|||
        print("==============")
 | 
			
		||||
 | 
			
		||||
    async def on_message(self, message):
 | 
			
		||||
        if message.channel.name != 'skynet':
 | 
			
		||||
        if isinstance(message.channel, discord.DMChannel):
 | 
			
		||||
            return
 | 
			
		||||
        elif message.channel.name != 'skynet':
 | 
			
		||||
            return
 | 
			
		||||
        elif message.author != self.user:
 | 
			
		||||
            await message.channel.send('', view=SkynetView(self.bot))
 | 
			
		||||
            return
 | 
			
		||||
        await self.process_commands(message)
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -22,7 +22,6 @@ def create_handler_context(frontend: 'SkynetDiscordFrontend'):
 | 
			
		|||
 | 
			
		||||
    ipfs_node = frontend.ipfs_node
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
    @bot.command(name='config', help='Responds with the configuration')
 | 
			
		||||
    async def set_config(ctx):
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -138,6 +137,7 @@ def create_handler_context(frontend: 'SkynetDiscordFrontend'):
 | 
			
		|||
            binary_data=binary
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        
 | 
			
		||||
        # TODO: DELETE BELOW
 | 
			
		||||
        # user = 'testworker3'
 | 
			
		||||
        # status_msg = 'status'
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -0,0 +1,93 @@
 | 
			
		|||
import discord
 | 
			
		||||
import logging
 | 
			
		||||
from skynet.constants import *
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class SkynetView(discord.ui.View):
 | 
			
		||||
 | 
			
		||||
    def __init__(self, bot):
 | 
			
		||||
        self.bot = bot
 | 
			
		||||
        super().__init__(timeout=None)
 | 
			
		||||
        self.add_item(Txt2ImgButton('Txt2Img', discord.ButtonStyle.green, self.bot))
 | 
			
		||||
        self.add_item(HelpButton('Help', discord.ButtonStyle.grey))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Txt2ImgButton(discord.ui.Button):
 | 
			
		||||
 | 
			
		||||
    def __init__(self, label:str, style:discord.ButtonStyle, bot):
 | 
			
		||||
        self.bot = bot
 | 
			
		||||
        super().__init__(label=label, style = style)
 | 
			
		||||
 | 
			
		||||
    async def callback(self, interaction):
 | 
			
		||||
        db_call = self.bot.db_call
 | 
			
		||||
        work_request = self.bot.work_request
 | 
			
		||||
        msg = await grab('Text Prompt:', interaction)
 | 
			
		||||
        # grab user from msg
 | 
			
		||||
        user = msg.author
 | 
			
		||||
        user_row = await db_call('get_or_create_user', user.id)
 | 
			
		||||
 | 
			
		||||
        # init new msg
 | 
			
		||||
        init_msg = 'started processing txt2img request...'
 | 
			
		||||
        status_msg = await msg.reply(init_msg)
 | 
			
		||||
        await db_call(
 | 
			
		||||
            'new_user_request', user.id, msg.id, status_msg.id, status=init_msg)
 | 
			
		||||
 | 
			
		||||
        prompt = msg.content
 | 
			
		||||
 | 
			
		||||
        if len(prompt) == 0:
 | 
			
		||||
            await status_msg.edit(content=
 | 
			
		||||
                'Empty text prompt ignored.'
 | 
			
		||||
            )
 | 
			
		||||
            await db_call('update_user_request', status_msg.id, 'Empty text prompt ignored.')
 | 
			
		||||
            return
 | 
			
		||||
 | 
			
		||||
        logging.info(f'mid: {msg.id}')
 | 
			
		||||
 | 
			
		||||
        user_config = {**user_row}
 | 
			
		||||
        del user_config['id']
 | 
			
		||||
 | 
			
		||||
        params = {
 | 
			
		||||
            'prompt': prompt,
 | 
			
		||||
            **user_config
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        await db_call(
 | 
			
		||||
            'update_user_stats', user.id, 'txt2img', last_prompt=prompt)
 | 
			
		||||
 | 
			
		||||
        ec = await work_request(user.name, status_msg, 'txt2img', params, msg)
 | 
			
		||||
 | 
			
		||||
        if ec == 0:
 | 
			
		||||
            await db_call('increment_generated', user.id)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class HelpButton(discord.ui.Button):
 | 
			
		||||
 | 
			
		||||
    def __init__(self, label:str, style:discord.ButtonStyle):
 | 
			
		||||
        super().__init__(label=label, style = style)
 | 
			
		||||
 | 
			
		||||
    async def callback(self, interaction):
 | 
			
		||||
        msg = await grab('What would you like help with? (a for all)', interaction)
 | 
			
		||||
 | 
			
		||||
        param = msg.content
 | 
			
		||||
 | 
			
		||||
        if param == 'a':
 | 
			
		||||
            await msg.reply(content=HELP_TEXT)
 | 
			
		||||
 | 
			
		||||
        else:
 | 
			
		||||
            if param in HELP_TOPICS:
 | 
			
		||||
                await msg.reply(content=HELP_TOPICS[param])
 | 
			
		||||
 | 
			
		||||
            else:
 | 
			
		||||
                await msg.reply(content=HELP_UNKWNOWN_PARAM)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
async def grab(prompt, interaction):
 | 
			
		||||
    def vet(m):
 | 
			
		||||
        return m.author == interaction.user and m.channel == interaction.channel
 | 
			
		||||
 | 
			
		||||
    await interaction.response.send_message(prompt, ephemeral=True)
 | 
			
		||||
    message = await interaction.client.wait_for('message', check=vet)
 | 
			
		||||
    return message
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
		Loading…
	
		Reference in New Issue