diff --git a/skynet/frontend/discord/__init__.py b/skynet/frontend/discord/__init__.py index 9fee423..f020734 100644 --- a/skynet/frontend/discord/__init__.py +++ b/skynet/frontend/discord/__init__.py @@ -57,7 +57,7 @@ class SkynetDiscordFrontend: self.remote_ipfs_node = remote_ipfs_node self.key = key - self.bot = DiscordBot() + self.bot = DiscordBot(self) self.cleos = CLEOS(None, None, url=node_url, remote=node_url) self.hyperion = HyperionAPI(hyperion_url) diff --git a/skynet/frontend/discord/bot.py b/skynet/frontend/discord/bot.py index c82d924..97cff8e 100644 --- a/skynet/frontend/discord/bot.py +++ b/skynet/frontend/discord/bot.py @@ -4,6 +4,7 @@ import discord # from dotenv import load_dotenv # from pathlib import Path from discord.ext import commands +from .ui import SkynetView # # Auth @@ -18,7 +19,8 @@ from discord.ext import commands # Actual Discord bot. class DiscordBot(commands.Bot): - def __init__(self, *args, **kwargs): + def __init__(self, bot, *args, **kwargs): + self.bot = bot intents = discord.Intents( messages=True, guilds=True, @@ -39,7 +41,7 @@ class DiscordBot(commands.Bot): for guild in self.guilds: for channel in guild.channels: if channel.name == "skynet": - await channel.send('Skynet bot online') + await channel.send('Skynet bot online', view=SkynetView(self.bot)) print("\n==============") print("Logged in as") @@ -48,7 +50,12 @@ class DiscordBot(commands.Bot): print("==============") async def on_message(self, message): - if message.channel.name != 'skynet': + if isinstance(message.channel, discord.DMChannel): + return + elif message.channel.name != 'skynet': + return + elif message.author != self.user: + await message.channel.send('', view=SkynetView(self.bot)) return await self.process_commands(message) diff --git a/skynet/frontend/discord/handlers.py b/skynet/frontend/discord/handlers.py index 6262b28..bc0cf56 100644 --- a/skynet/frontend/discord/handlers.py +++ b/skynet/frontend/discord/handlers.py @@ -22,7 +22,6 @@ def create_handler_context(frontend: 'SkynetDiscordFrontend'): ipfs_node = frontend.ipfs_node - @bot.command(name='config', help='Responds with the configuration') async def set_config(ctx): @@ -138,6 +137,7 @@ def create_handler_context(frontend: 'SkynetDiscordFrontend'): binary_data=binary ) + # TODO: DELETE BELOW # user = 'testworker3' # status_msg = 'status' diff --git a/skynet/frontend/discord/ui.py b/skynet/frontend/discord/ui.py new file mode 100644 index 0000000..6704bbe --- /dev/null +++ b/skynet/frontend/discord/ui.py @@ -0,0 +1,93 @@ +import discord +import logging +from skynet.constants import * + + +class SkynetView(discord.ui.View): + + def __init__(self, bot): + self.bot = bot + super().__init__(timeout=None) + self.add_item(Txt2ImgButton('Txt2Img', discord.ButtonStyle.green, self.bot)) + self.add_item(HelpButton('Help', discord.ButtonStyle.grey)) + + +class Txt2ImgButton(discord.ui.Button): + + def __init__(self, label:str, style:discord.ButtonStyle, bot): + self.bot = bot + super().__init__(label=label, style = style) + + async def callback(self, interaction): + db_call = self.bot.db_call + work_request = self.bot.work_request + msg = await grab('Text Prompt:', interaction) + # grab user from msg + user = msg.author + user_row = await db_call('get_or_create_user', user.id) + + # init new msg + init_msg = 'started processing txt2img request...' + status_msg = await msg.reply(init_msg) + await db_call( + 'new_user_request', user.id, msg.id, status_msg.id, status=init_msg) + + prompt = msg.content + + if len(prompt) == 0: + await status_msg.edit(content= + 'Empty text prompt ignored.' + ) + await db_call('update_user_request', status_msg.id, 'Empty text prompt ignored.') + return + + logging.info(f'mid: {msg.id}') + + user_config = {**user_row} + del user_config['id'] + + params = { + 'prompt': prompt, + **user_config + } + + await db_call( + 'update_user_stats', user.id, 'txt2img', last_prompt=prompt) + + ec = await work_request(user.name, status_msg, 'txt2img', params, msg) + + if ec == 0: + await db_call('increment_generated', user.id) + + +class HelpButton(discord.ui.Button): + + def __init__(self, label:str, style:discord.ButtonStyle): + super().__init__(label=label, style = style) + + async def callback(self, interaction): + msg = await grab('What would you like help with? (a for all)', interaction) + + param = msg.content + + if param == 'a': + await msg.reply(content=HELP_TEXT) + + else: + if param in HELP_TOPICS: + await msg.reply(content=HELP_TOPICS[param]) + + else: + await msg.reply(content=HELP_UNKWNOWN_PARAM) + + + +async def grab(prompt, interaction): + def vet(m): + return m.author == interaction.user and m.channel == interaction.channel + + await interaction.response.send_message(prompt, ephemeral=True) + message = await interaction.client.wait_for('message', check=vet) + return message + +