add initial buttons, help and txt2img

pull/11/head
Konstantine Tsafatinos 2023-07-20 01:16:22 -04:00
parent ff0114d341
commit 8625b5747b
4 changed files with 105 additions and 5 deletions

View File

@ -57,7 +57,7 @@ class SkynetDiscordFrontend:
self.remote_ipfs_node = remote_ipfs_node self.remote_ipfs_node = remote_ipfs_node
self.key = key self.key = key
self.bot = DiscordBot() self.bot = DiscordBot(self)
self.cleos = CLEOS(None, None, url=node_url, remote=node_url) self.cleos = CLEOS(None, None, url=node_url, remote=node_url)
self.hyperion = HyperionAPI(hyperion_url) self.hyperion = HyperionAPI(hyperion_url)

View File

@ -4,6 +4,7 @@ import discord
# from dotenv import load_dotenv # from dotenv import load_dotenv
# from pathlib import Path # from pathlib import Path
from discord.ext import commands from discord.ext import commands
from .ui import SkynetView
# # Auth # # Auth
@ -18,7 +19,8 @@ from discord.ext import commands
# Actual Discord bot. # Actual Discord bot.
class DiscordBot(commands.Bot): class DiscordBot(commands.Bot):
def __init__(self, *args, **kwargs): def __init__(self, bot, *args, **kwargs):
self.bot = bot
intents = discord.Intents( intents = discord.Intents(
messages=True, messages=True,
guilds=True, guilds=True,
@ -39,7 +41,7 @@ class DiscordBot(commands.Bot):
for guild in self.guilds: for guild in self.guilds:
for channel in guild.channels: for channel in guild.channels:
if channel.name == "skynet": if channel.name == "skynet":
await channel.send('Skynet bot online') await channel.send('Skynet bot online', view=SkynetView(self.bot))
print("\n==============") print("\n==============")
print("Logged in as") print("Logged in as")
@ -48,7 +50,12 @@ class DiscordBot(commands.Bot):
print("==============") print("==============")
async def on_message(self, message): async def on_message(self, message):
if message.channel.name != 'skynet': if isinstance(message.channel, discord.DMChannel):
return
elif message.channel.name != 'skynet':
return
elif message.author != self.user:
await message.channel.send('', view=SkynetView(self.bot))
return return
await self.process_commands(message) await self.process_commands(message)

View File

@ -22,7 +22,6 @@ def create_handler_context(frontend: 'SkynetDiscordFrontend'):
ipfs_node = frontend.ipfs_node ipfs_node = frontend.ipfs_node
@bot.command(name='config', help='Responds with the configuration') @bot.command(name='config', help='Responds with the configuration')
async def set_config(ctx): async def set_config(ctx):
@ -138,6 +137,7 @@ def create_handler_context(frontend: 'SkynetDiscordFrontend'):
binary_data=binary binary_data=binary
) )
# TODO: DELETE BELOW # TODO: DELETE BELOW
# user = 'testworker3' # user = 'testworker3'
# status_msg = 'status' # status_msg = 'status'

View File

@ -0,0 +1,93 @@
import discord
import logging
from skynet.constants import *
class SkynetView(discord.ui.View):
def __init__(self, bot):
self.bot = bot
super().__init__(timeout=None)
self.add_item(Txt2ImgButton('Txt2Img', discord.ButtonStyle.green, self.bot))
self.add_item(HelpButton('Help', discord.ButtonStyle.grey))
class Txt2ImgButton(discord.ui.Button):
def __init__(self, label:str, style:discord.ButtonStyle, bot):
self.bot = bot
super().__init__(label=label, style = style)
async def callback(self, interaction):
db_call = self.bot.db_call
work_request = self.bot.work_request
msg = await grab('Text Prompt:', interaction)
# grab user from msg
user = msg.author
user_row = await db_call('get_or_create_user', user.id)
# init new msg
init_msg = 'started processing txt2img request...'
status_msg = await msg.reply(init_msg)
await db_call(
'new_user_request', user.id, msg.id, status_msg.id, status=init_msg)
prompt = msg.content
if len(prompt) == 0:
await status_msg.edit(content=
'Empty text prompt ignored.'
)
await db_call('update_user_request', status_msg.id, 'Empty text prompt ignored.')
return
logging.info(f'mid: {msg.id}')
user_config = {**user_row}
del user_config['id']
params = {
'prompt': prompt,
**user_config
}
await db_call(
'update_user_stats', user.id, 'txt2img', last_prompt=prompt)
ec = await work_request(user.name, status_msg, 'txt2img', params, msg)
if ec == 0:
await db_call('increment_generated', user.id)
class HelpButton(discord.ui.Button):
def __init__(self, label:str, style:discord.ButtonStyle):
super().__init__(label=label, style = style)
async def callback(self, interaction):
msg = await grab('What would you like help with? (a for all)', interaction)
param = msg.content
if param == 'a':
await msg.reply(content=HELP_TEXT)
else:
if param in HELP_TOPICS:
await msg.reply(content=HELP_TOPICS[param])
else:
await msg.reply(content=HELP_UNKWNOWN_PARAM)
async def grab(prompt, interaction):
def vet(m):
return m.author == interaction.user and m.channel == interaction.channel
await interaction.response.send_message(prompt, ephemeral=True)
message = await interaction.client.wait_for('message', check=vet)
return message