mirror of https://github.com/skygpu/skynet.git
add initial buttons, help and txt2img
parent
ff0114d341
commit
8625b5747b
|
@ -57,7 +57,7 @@ class SkynetDiscordFrontend:
|
||||||
self.remote_ipfs_node = remote_ipfs_node
|
self.remote_ipfs_node = remote_ipfs_node
|
||||||
self.key = key
|
self.key = key
|
||||||
|
|
||||||
self.bot = DiscordBot()
|
self.bot = DiscordBot(self)
|
||||||
self.cleos = CLEOS(None, None, url=node_url, remote=node_url)
|
self.cleos = CLEOS(None, None, url=node_url, remote=node_url)
|
||||||
self.hyperion = HyperionAPI(hyperion_url)
|
self.hyperion = HyperionAPI(hyperion_url)
|
||||||
|
|
||||||
|
|
|
@ -4,6 +4,7 @@ import discord
|
||||||
# from dotenv import load_dotenv
|
# from dotenv import load_dotenv
|
||||||
# from pathlib import Path
|
# from pathlib import Path
|
||||||
from discord.ext import commands
|
from discord.ext import commands
|
||||||
|
from .ui import SkynetView
|
||||||
|
|
||||||
|
|
||||||
# # Auth
|
# # Auth
|
||||||
|
@ -18,7 +19,8 @@ from discord.ext import commands
|
||||||
# Actual Discord bot.
|
# Actual Discord bot.
|
||||||
class DiscordBot(commands.Bot):
|
class DiscordBot(commands.Bot):
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, bot, *args, **kwargs):
|
||||||
|
self.bot = bot
|
||||||
intents = discord.Intents(
|
intents = discord.Intents(
|
||||||
messages=True,
|
messages=True,
|
||||||
guilds=True,
|
guilds=True,
|
||||||
|
@ -39,7 +41,7 @@ class DiscordBot(commands.Bot):
|
||||||
for guild in self.guilds:
|
for guild in self.guilds:
|
||||||
for channel in guild.channels:
|
for channel in guild.channels:
|
||||||
if channel.name == "skynet":
|
if channel.name == "skynet":
|
||||||
await channel.send('Skynet bot online')
|
await channel.send('Skynet bot online', view=SkynetView(self.bot))
|
||||||
|
|
||||||
print("\n==============")
|
print("\n==============")
|
||||||
print("Logged in as")
|
print("Logged in as")
|
||||||
|
@ -48,7 +50,12 @@ class DiscordBot(commands.Bot):
|
||||||
print("==============")
|
print("==============")
|
||||||
|
|
||||||
async def on_message(self, message):
|
async def on_message(self, message):
|
||||||
if message.channel.name != 'skynet':
|
if isinstance(message.channel, discord.DMChannel):
|
||||||
|
return
|
||||||
|
elif message.channel.name != 'skynet':
|
||||||
|
return
|
||||||
|
elif message.author != self.user:
|
||||||
|
await message.channel.send('', view=SkynetView(self.bot))
|
||||||
return
|
return
|
||||||
await self.process_commands(message)
|
await self.process_commands(message)
|
||||||
|
|
||||||
|
|
|
@ -22,7 +22,6 @@ def create_handler_context(frontend: 'SkynetDiscordFrontend'):
|
||||||
|
|
||||||
ipfs_node = frontend.ipfs_node
|
ipfs_node = frontend.ipfs_node
|
||||||
|
|
||||||
|
|
||||||
@bot.command(name='config', help='Responds with the configuration')
|
@bot.command(name='config', help='Responds with the configuration')
|
||||||
async def set_config(ctx):
|
async def set_config(ctx):
|
||||||
|
|
||||||
|
@ -138,6 +137,7 @@ def create_handler_context(frontend: 'SkynetDiscordFrontend'):
|
||||||
binary_data=binary
|
binary_data=binary
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# TODO: DELETE BELOW
|
# TODO: DELETE BELOW
|
||||||
# user = 'testworker3'
|
# user = 'testworker3'
|
||||||
# status_msg = 'status'
|
# status_msg = 'status'
|
||||||
|
|
|
@ -0,0 +1,93 @@
|
||||||
|
import discord
|
||||||
|
import logging
|
||||||
|
from skynet.constants import *
|
||||||
|
|
||||||
|
|
||||||
|
class SkynetView(discord.ui.View):
|
||||||
|
|
||||||
|
def __init__(self, bot):
|
||||||
|
self.bot = bot
|
||||||
|
super().__init__(timeout=None)
|
||||||
|
self.add_item(Txt2ImgButton('Txt2Img', discord.ButtonStyle.green, self.bot))
|
||||||
|
self.add_item(HelpButton('Help', discord.ButtonStyle.grey))
|
||||||
|
|
||||||
|
|
||||||
|
class Txt2ImgButton(discord.ui.Button):
|
||||||
|
|
||||||
|
def __init__(self, label:str, style:discord.ButtonStyle, bot):
|
||||||
|
self.bot = bot
|
||||||
|
super().__init__(label=label, style = style)
|
||||||
|
|
||||||
|
async def callback(self, interaction):
|
||||||
|
db_call = self.bot.db_call
|
||||||
|
work_request = self.bot.work_request
|
||||||
|
msg = await grab('Text Prompt:', interaction)
|
||||||
|
# grab user from msg
|
||||||
|
user = msg.author
|
||||||
|
user_row = await db_call('get_or_create_user', user.id)
|
||||||
|
|
||||||
|
# init new msg
|
||||||
|
init_msg = 'started processing txt2img request...'
|
||||||
|
status_msg = await msg.reply(init_msg)
|
||||||
|
await db_call(
|
||||||
|
'new_user_request', user.id, msg.id, status_msg.id, status=init_msg)
|
||||||
|
|
||||||
|
prompt = msg.content
|
||||||
|
|
||||||
|
if len(prompt) == 0:
|
||||||
|
await status_msg.edit(content=
|
||||||
|
'Empty text prompt ignored.'
|
||||||
|
)
|
||||||
|
await db_call('update_user_request', status_msg.id, 'Empty text prompt ignored.')
|
||||||
|
return
|
||||||
|
|
||||||
|
logging.info(f'mid: {msg.id}')
|
||||||
|
|
||||||
|
user_config = {**user_row}
|
||||||
|
del user_config['id']
|
||||||
|
|
||||||
|
params = {
|
||||||
|
'prompt': prompt,
|
||||||
|
**user_config
|
||||||
|
}
|
||||||
|
|
||||||
|
await db_call(
|
||||||
|
'update_user_stats', user.id, 'txt2img', last_prompt=prompt)
|
||||||
|
|
||||||
|
ec = await work_request(user.name, status_msg, 'txt2img', params, msg)
|
||||||
|
|
||||||
|
if ec == 0:
|
||||||
|
await db_call('increment_generated', user.id)
|
||||||
|
|
||||||
|
|
||||||
|
class HelpButton(discord.ui.Button):
|
||||||
|
|
||||||
|
def __init__(self, label:str, style:discord.ButtonStyle):
|
||||||
|
super().__init__(label=label, style = style)
|
||||||
|
|
||||||
|
async def callback(self, interaction):
|
||||||
|
msg = await grab('What would you like help with? (a for all)', interaction)
|
||||||
|
|
||||||
|
param = msg.content
|
||||||
|
|
||||||
|
if param == 'a':
|
||||||
|
await msg.reply(content=HELP_TEXT)
|
||||||
|
|
||||||
|
else:
|
||||||
|
if param in HELP_TOPICS:
|
||||||
|
await msg.reply(content=HELP_TOPICS[param])
|
||||||
|
|
||||||
|
else:
|
||||||
|
await msg.reply(content=HELP_UNKWNOWN_PARAM)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
async def grab(prompt, interaction):
|
||||||
|
def vet(m):
|
||||||
|
return m.author == interaction.user and m.channel == interaction.channel
|
||||||
|
|
||||||
|
await interaction.response.send_message(prompt, ephemeral=True)
|
||||||
|
message = await interaction.client.wait_for('message', check=vet)
|
||||||
|
return message
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue