mirror of https://github.com/skygpu/skynet.git
add img2img support, add stats and donate button, finalize UI, add live updates
parent
58c6a2070e
commit
4260187208
|
@ -36,6 +36,7 @@ commands work on a user per user basis!
|
||||||
config is individual to each user!
|
config is individual to each user!
|
||||||
|
|
||||||
/txt2img TEXT - request an image based on a prompt
|
/txt2img TEXT - request an image based on a prompt
|
||||||
|
/img2img <attach_image> TEXT - request an image base on an image and a promtp
|
||||||
|
|
||||||
/redo - redo last command (only works for txt2img for now!)
|
/redo - redo last command (only works for txt2img for now!)
|
||||||
|
|
||||||
|
|
|
@ -96,7 +96,8 @@ def open_new_database(cleanup=True):
|
||||||
'POSTGRES_PASSWORD': rpassword
|
'POSTGRES_PASSWORD': rpassword
|
||||||
},
|
},
|
||||||
detach=True,
|
detach=True,
|
||||||
remove=True
|
# could remove this if we ant the dockers to be persistent.
|
||||||
|
# remove=True
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
|
|
||||||
|
|
|
@ -89,6 +89,7 @@ class SkynetDiscordFrontend:
|
||||||
yield self
|
yield self
|
||||||
await self.stop()
|
await self.stop()
|
||||||
|
|
||||||
|
# maybe do this?
|
||||||
# async def update_status_message(
|
# async def update_status_message(
|
||||||
# self, status_msg, new_text: str, **kwargs
|
# self, status_msg, new_text: str, **kwargs
|
||||||
# ):
|
# ):
|
||||||
|
@ -139,17 +140,14 @@ class SkynetDiscordFrontend:
|
||||||
})
|
})
|
||||||
request_time = datetime.now().isoformat()
|
request_time = datetime.now().isoformat()
|
||||||
|
|
||||||
# maybe get rid of this
|
await status_msg.delete()
|
||||||
# await self.update_status_message(
|
msg_text = f'processing a \'{method}\' request by {user.name}\n[{timestamp_pretty()}] *broadcasting transaction to chain...* '
|
||||||
# status_msg,
|
embed = discord.Embed(
|
||||||
# f'processing a \'{method}\' request by {tg_user_pretty(user)}\n'
|
title='live updates',
|
||||||
# f'[{timestamp_pretty()}] <i>broadcasting transaction to chain...</i>',
|
description=msg_text,
|
||||||
# parse_mode='HTML'
|
color=discord.Color.blue())
|
||||||
# )
|
|
||||||
# message = await ctx.send(
|
message = await send(embed=embed)
|
||||||
# f'processing a \'{method}\' request by {user}\n \
|
|
||||||
# [{timestamp_pretty()}] *broadcasting transaction to chain...*'
|
|
||||||
# )
|
|
||||||
|
|
||||||
reward = '20.0000 GPU'
|
reward = '20.0000 GPU'
|
||||||
res = await self.cleos.a_push_action(
|
res = await self.cleos.a_push_action(
|
||||||
|
@ -164,7 +162,6 @@ class SkynetDiscordFrontend:
|
||||||
},
|
},
|
||||||
self.account, self.key, permission=self.permission
|
self.account, self.key, permission=self.permission
|
||||||
)
|
)
|
||||||
# print(res)
|
|
||||||
|
|
||||||
if 'code' in res or 'statusCode' in res:
|
if 'code' in res or 'statusCode' in res:
|
||||||
logging.error(json.dumps(res, indent=4))
|
logging.error(json.dumps(res, indent=4))
|
||||||
|
@ -174,23 +171,15 @@ class SkynetDiscordFrontend:
|
||||||
return
|
return
|
||||||
|
|
||||||
enqueue_tx_id = res['transaction_id']
|
enqueue_tx_id = res['transaction_id']
|
||||||
enqueue_tx_link = hlink(
|
enqueue_tx_link = f'[**Your request on Skynet Explorer**](https://explorer.{DEFAULT_DOMAIN}/v2/explore/transaction/{enqueue_tx_id})'
|
||||||
'Your request on Skynet Explorer',
|
|
||||||
f'https://explorer.{DEFAULT_DOMAIN}/v2/explore/transaction/{enqueue_tx_id}'
|
|
||||||
)
|
|
||||||
|
|
||||||
# await self.append_status_message(
|
msg_text += f'**broadcasted!** \n{enqueue_tx_link}\n[{timestamp_pretty()}] *workers are processing request...* '
|
||||||
# status_msg,
|
embed = discord.Embed(
|
||||||
# f' <b>broadcasted!</b>\n'
|
title='live updates',
|
||||||
# f'<b>{enqueue_tx_link}</b>\n'
|
description=msg_text,
|
||||||
# f'[{timestamp_pretty()}] <i>workers are processing request...</i>',
|
color=discord.Color.blue())
|
||||||
# parse_mode='HTML'
|
|
||||||
# )
|
await message.edit(embed=embed)
|
||||||
# await message.edit(content=
|
|
||||||
# f'**broadcasted!**\n \
|
|
||||||
# **{enqueue_tx_link}**\n \
|
|
||||||
# [{timestamp_pretty()}] *workers are processing request...*'
|
|
||||||
# )
|
|
||||||
|
|
||||||
out = collect_stdout(res)
|
out = collect_stdout(res)
|
||||||
|
|
||||||
|
@ -233,77 +222,45 @@ class SkynetDiscordFrontend:
|
||||||
await asyncio.sleep(1)
|
await asyncio.sleep(1)
|
||||||
|
|
||||||
if not ipfs_hash:
|
if not ipfs_hash:
|
||||||
# await self.update_status_message(
|
|
||||||
# status_msg,
|
msg_text += f'\n[{timestamp_pretty()}] **timeout processing request**'
|
||||||
# f'\n[{timestamp_pretty()}] <b>timeout processing request</b>',
|
embed = discord.Embed(
|
||||||
# parse_mode='HTML'
|
title='live updates',
|
||||||
# )
|
description=msg_text,
|
||||||
|
color=discord.Color.blue())
|
||||||
|
|
||||||
|
await message.edit(embed=embed)
|
||||||
return
|
return
|
||||||
|
|
||||||
tx_link = hlink(
|
tx_link = f'[**Your result on Skynet Explorer**](https://explorer.{DEFAULT_DOMAIN}/v2/explore/transaction/{tx_hash})'
|
||||||
'Your result on Skynet Explorer',
|
|
||||||
f'https://explorer.{DEFAULT_DOMAIN}/v2/explore/transaction/{tx_hash}'
|
|
||||||
)
|
|
||||||
|
|
||||||
# await self.append_status_message(
|
msg_text += f'**request processed!**\n{tx_link}\n[{timestamp_pretty()}] *trying to download image...*\n '
|
||||||
# status_msg,
|
embed = discord.Embed(
|
||||||
# f' <b>request processed!</b>\n'
|
title='live updates',
|
||||||
# f'<b>{tx_link}</b>\n'
|
description=msg_text,
|
||||||
# f'[{timestamp_pretty()}] <i>trying to download image...</i>\n',
|
color=discord.Color.blue())
|
||||||
# parse_mode='HTML'
|
|
||||||
# )
|
await message.edit(embed=embed)
|
||||||
# await message.edit(content=
|
|
||||||
# f'**request processed!**\n \
|
|
||||||
# **{tx_link}**\n \
|
|
||||||
# [{timestamp_pretty()}] *trying to download image...*\n'
|
|
||||||
# )
|
|
||||||
|
|
||||||
# attempt to get the image and send it
|
# attempt to get the image and send it
|
||||||
ipfs_link = f'https://ipfs.{DEFAULT_DOMAIN}/ipfs/{ipfs_hash}/image.png'
|
ipfs_link = f'https://ipfs.{DEFAULT_DOMAIN}/ipfs/{ipfs_hash}/image.png'
|
||||||
resp = await get_ipfs_file(ipfs_link)
|
resp = await get_ipfs_file(ipfs_link)
|
||||||
|
|
||||||
|
# reword this function, may not need caption
|
||||||
caption, embed = generate_reply_caption(
|
caption, embed = generate_reply_caption(
|
||||||
user, params, tx_hash, worker, reward)
|
user, params, tx_hash, worker, reward)
|
||||||
|
|
||||||
if not resp or resp.status_code != 200:
|
if not resp or resp.status_code != 200:
|
||||||
logging.error(f'couldn\'t get ipfs hosted image at {ipfs_link}!')
|
logging.error(f'couldn\'t get ipfs hosted image at {ipfs_link}!')
|
||||||
# await self.update_status_message(
|
await message.edit(embed=embed, view=SkynetView(self))
|
||||||
# status_msg,
|
|
||||||
# caption,
|
|
||||||
# reply_markup=build_redo_menu(),
|
|
||||||
# parse_mode='HTML'
|
|
||||||
# )
|
|
||||||
#
|
|
||||||
else:
|
else:
|
||||||
logging.info(f'success! sending generated image')
|
logging.info(f'success! sending generated image')
|
||||||
# image = io.BytesIO(resp.raw)
|
await message.delete()
|
||||||
# embed.set_image(url=ipfs_link)
|
|
||||||
# embed.add_field(name='params', value=caption)
|
|
||||||
# await self.bot.delete_message(
|
|
||||||
# chat_id=status_msg.chat.id, message_id=status_msg.id)
|
|
||||||
if file_id: # img2img
|
if file_id: # img2img
|
||||||
pass
|
embed.set_thumbnail(
|
||||||
# await self.bot.send_media_group(
|
url='https://ipfs.skygpu.net/ipfs/' + binary_data + '/image.png')
|
||||||
# status_msg.chat.id,
|
embed.set_image(url=ipfs_link)
|
||||||
# media=[
|
await send(embed=embed, view=SkynetView(self))
|
||||||
# InputMediaPhoto(file_id),
|
else: # txt2img
|
||||||
# InputMediaPhoto(
|
|
||||||
# resp.raw,
|
|
||||||
# caption=caption,
|
|
||||||
# parse_mode='HTML'
|
|
||||||
# )
|
|
||||||
# ],
|
|
||||||
# )
|
|
||||||
#
|
|
||||||
else: # txt2img
|
|
||||||
# await self.bot.send_photo(
|
|
||||||
# status_msg.chat.id,
|
|
||||||
# caption=caption,
|
|
||||||
# photo=resp.raw,
|
|
||||||
# reply_markup=build_redo_menu(),
|
|
||||||
# parse_mode='HTML'
|
|
||||||
# )
|
|
||||||
|
|
||||||
embed.set_image(url=ipfs_link)
|
embed.set_image(url=ipfs_link)
|
||||||
embed.add_field(name='Parameters:', value=caption)
|
|
||||||
await send(embed=embed, view=SkynetView(self))
|
await send(embed=embed, view=SkynetView(self))
|
||||||
|
|
|
@ -41,24 +41,44 @@ def create_handler_context(frontend: 'SkynetDiscordFrontend'):
|
||||||
finally:
|
finally:
|
||||||
await ctx.reply(content=reply_txt, view=SkynetView(frontend))
|
await ctx.reply(content=reply_txt, view=SkynetView(frontend))
|
||||||
|
|
||||||
@bot.command(name='helper', help='Responds with a help')
|
bot.remove_command('help')
|
||||||
async def helper(ctx):
|
@bot.command(name='help', help='Responds with a help')
|
||||||
|
async def help(ctx):
|
||||||
splt_msg = ctx.message.content.split(' ')
|
splt_msg = ctx.message.content.split(' ')
|
||||||
|
|
||||||
if len(splt_msg) == 1:
|
if len(splt_msg) == 1:
|
||||||
await ctx.reply(content=HELP_TEXT, view=SkynetView(frontend))
|
await ctx.send(content=f'```{HELP_TEXT}```', view=SkynetView(frontend))
|
||||||
|
|
||||||
else:
|
else:
|
||||||
param = splt_msg[1]
|
param = splt_msg[1]
|
||||||
if param in HELP_TOPICS:
|
if param in HELP_TOPICS:
|
||||||
await ctx.reply(content=HELP_TOPICS[param], view=SkynetView(frontend))
|
await ctx.send(content=f'```{HELP_TOPICS[param]}```', view=SkynetView(frontend))
|
||||||
|
|
||||||
else:
|
else:
|
||||||
await ctx.reply(content=HELP_UNKWNOWN_PARAM, view=SkynetView(frontend))
|
await ctx.send(content=f'```{HELP_UNKWNOWN_PARAM}```', view=SkynetView(frontend))
|
||||||
|
|
||||||
@bot.command(name='cool', help='Display a list of cool prompt words')
|
@bot.command(name='cool', help='Display a list of cool prompt words')
|
||||||
async def send_cool_words(ctx):
|
async def send_cool_words(ctx):
|
||||||
await ctx.reply(content='\n'.join(CLEAN_COOL_WORDS), view=SkynetView(frontend))
|
clean_cool_word = '\n'.join(CLEAN_COOL_WORDS)
|
||||||
|
await ctx.send(content=f'```{clean_cool_word}```', view=SkynetView(frontend))
|
||||||
|
|
||||||
|
@bot.command(name='stats', help='See user statistics' )
|
||||||
|
async def user_stats(ctx):
|
||||||
|
user = ctx.author
|
||||||
|
|
||||||
|
await db_call('get_or_create_user', user.id)
|
||||||
|
generated, joined, role = await db_call('get_user_stats', user.id)
|
||||||
|
|
||||||
|
stats_str = f'```generated: {generated}\n'
|
||||||
|
stats_str += f'joined: {joined}\n'
|
||||||
|
stats_str += f'role: {role}\n```'
|
||||||
|
|
||||||
|
await ctx.reply(stats_str, view=SkynetView(frontend))
|
||||||
|
|
||||||
|
@bot.command(name='donate', help='See donate info')
|
||||||
|
async def donation_info(ctx):
|
||||||
|
await ctx.reply(
|
||||||
|
f'```\n{DONATION_INFO}```', view=SkynetView(frontend))
|
||||||
|
|
||||||
@bot.command(name='txt2img', help='Responds with an image')
|
@bot.command(name='txt2img', help='Responds with an image')
|
||||||
async def send_txt2img(ctx):
|
async def send_txt2img(ctx):
|
||||||
|
@ -69,7 +89,7 @@ def create_handler_context(frontend: 'SkynetDiscordFrontend'):
|
||||||
|
|
||||||
# init new msg
|
# init new msg
|
||||||
init_msg = 'started processing txt2img request...'
|
init_msg = 'started processing txt2img request...'
|
||||||
status_msg = await ctx.reply(init_msg)
|
status_msg = await ctx.send(init_msg)
|
||||||
await db_call(
|
await db_call(
|
||||||
'new_user_request', user.id, ctx.message.id, status_msg.id, status=init_msg)
|
'new_user_request', user.id, ctx.message.id, status_msg.id, status=init_msg)
|
||||||
|
|
||||||
|
@ -97,13 +117,13 @@ def create_handler_context(frontend: 'SkynetDiscordFrontend'):
|
||||||
|
|
||||||
ec = await work_request(user, status_msg, 'txt2img', params, ctx)
|
ec = await work_request(user, status_msg, 'txt2img', params, ctx)
|
||||||
|
|
||||||
if ec == 0:
|
if ec == None:
|
||||||
await db_call('increment_generated', user.id)
|
await db_call('increment_generated', user.id)
|
||||||
|
|
||||||
@bot.command(name='redo', help='Redo last request')
|
@bot.command(name='redo', help='Redo last request')
|
||||||
async def redo(ctx):
|
async def redo(ctx):
|
||||||
init_msg = 'started processing redo request...'
|
init_msg = 'started processing redo request...'
|
||||||
status_msg = await ctx.reply(init_msg)
|
status_msg = await ctx.send(init_msg)
|
||||||
user = ctx.author
|
user = ctx.author
|
||||||
|
|
||||||
method = await db_call('get_last_method_of', user.id)
|
method = await db_call('get_last_method_of', user.id)
|
||||||
|
@ -116,8 +136,9 @@ def create_handler_context(frontend: 'SkynetDiscordFrontend'):
|
||||||
binary = await db_call('get_last_binary_of', user.id)
|
binary = await db_call('get_last_binary_of', user.id)
|
||||||
|
|
||||||
if not prompt:
|
if not prompt:
|
||||||
await ctx.reply(
|
await status_msg.edit(
|
||||||
'no last prompt found, do a txt2img cmd first!'
|
content='no last prompt found, do a txt2img cmd first!',
|
||||||
|
view=SkynetView(frontend)
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
@ -132,12 +153,106 @@ def create_handler_context(frontend: 'SkynetDiscordFrontend'):
|
||||||
**user_config
|
**user_config
|
||||||
}
|
}
|
||||||
|
|
||||||
await work_request(
|
ec = await work_request(
|
||||||
user, status_msg, 'redo', params, ctx,
|
user, status_msg, 'redo', params, ctx,
|
||||||
file_id=file_id,
|
file_id=file_id,
|
||||||
binary_data=binary
|
binary_data=binary
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if ec == None:
|
||||||
|
await db_call('increment_generated', user.id)
|
||||||
|
|
||||||
|
@bot.command(name='img2img', help='Responds with an image')
|
||||||
|
async def send_img2img(ctx):
|
||||||
|
# if isinstance(message_or_query, CallbackQuery):
|
||||||
|
# query = message_or_query
|
||||||
|
# message = query.message
|
||||||
|
# user = query.from_user
|
||||||
|
# chat = query.message.chat
|
||||||
|
#
|
||||||
|
# else:
|
||||||
|
# message = message_or_query
|
||||||
|
# user = message.from_user
|
||||||
|
# chat = message.chat
|
||||||
|
|
||||||
|
# reply_id = None
|
||||||
|
# if chat.type == 'group' and chat.id == GROUP_ID:
|
||||||
|
# reply_id = message.message_id
|
||||||
|
#
|
||||||
|
user = ctx.author
|
||||||
|
user_row = await db_call('get_or_create_user', user.id)
|
||||||
|
|
||||||
|
# init new msg
|
||||||
|
init_msg = 'started processing img2img request...'
|
||||||
|
status_msg = await ctx.send(init_msg)
|
||||||
|
await db_call(
|
||||||
|
'new_user_request', user.id, ctx.message.id, status_msg.id, status=init_msg)
|
||||||
|
|
||||||
|
if not ctx.message.content.startswith('/img2img'):
|
||||||
|
await ctx.reply(
|
||||||
|
'For image to image you need to add /img2img to the beggining of your caption'
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
prompt = ' '.join(ctx.message.content.split(' ')[1:])
|
||||||
|
|
||||||
|
if len(prompt) == 0:
|
||||||
|
await ctx.reply('Empty text prompt ignored.')
|
||||||
|
return
|
||||||
|
|
||||||
|
# file_id = message.photo[-1].file_id
|
||||||
|
# file_path = (await bot.get_file(file_id)).file_path
|
||||||
|
# image_raw = await bot.download_file(file_path)
|
||||||
|
#
|
||||||
|
|
||||||
|
file = ctx.message.attachments[-1]
|
||||||
|
file_id = str(file.id)
|
||||||
|
# file bytes
|
||||||
|
image_raw = await file.read()
|
||||||
|
with Image.open(io.BytesIO(image_raw)) as image:
|
||||||
|
w, h = image.size
|
||||||
|
|
||||||
|
if w > 512 or h > 512:
|
||||||
|
logging.warning(f'user sent img of size {image.size}')
|
||||||
|
image.thumbnail((512, 512))
|
||||||
|
logging.warning(f'resized it to {image.size}')
|
||||||
|
|
||||||
|
image.save(f'ipfs-docker-staging/image.png', format='PNG')
|
||||||
|
|
||||||
|
ipfs_hash = ipfs_node.add('image.png')
|
||||||
|
ipfs_node.pin(ipfs_hash)
|
||||||
|
|
||||||
|
logging.info(f'published input image {ipfs_hash} on ipfs')
|
||||||
|
|
||||||
|
logging.info(f'mid: {ctx.message.id}')
|
||||||
|
|
||||||
|
user_config = {**user_row}
|
||||||
|
del user_config['id']
|
||||||
|
|
||||||
|
params = {
|
||||||
|
'prompt': prompt,
|
||||||
|
**user_config
|
||||||
|
}
|
||||||
|
|
||||||
|
await db_call(
|
||||||
|
'update_user_stats',
|
||||||
|
user.id,
|
||||||
|
'img2img',
|
||||||
|
last_file=file_id,
|
||||||
|
last_prompt=prompt,
|
||||||
|
last_binary=ipfs_hash
|
||||||
|
)
|
||||||
|
|
||||||
|
ec = await work_request(
|
||||||
|
user, status_msg, 'img2img', params, ctx,
|
||||||
|
file_id=file_id,
|
||||||
|
binary_data=ipfs_hash
|
||||||
|
)
|
||||||
|
|
||||||
|
if ec == None:
|
||||||
|
await db_call('increment_generated', user.id)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# TODO: DELETE BELOW
|
# TODO: DELETE BELOW
|
||||||
# user = 'testworker3'
|
# user = 'testworker3'
|
||||||
|
|
|
@ -1,4 +1,6 @@
|
||||||
|
import io
|
||||||
import discord
|
import discord
|
||||||
|
from PIL import Image
|
||||||
import logging
|
import logging
|
||||||
from skynet.constants import *
|
from skynet.constants import *
|
||||||
from skynet.frontend import validate_user_config_request
|
from skynet.frontend import validate_user_config_request
|
||||||
|
@ -9,11 +11,14 @@ class SkynetView(discord.ui.View):
|
||||||
def __init__(self, bot):
|
def __init__(self, bot):
|
||||||
self.bot = bot
|
self.bot = bot
|
||||||
super().__init__(timeout=None)
|
super().__init__(timeout=None)
|
||||||
self.add_item(RedoButton('redo', discord.ButtonStyle.green, self.bot))
|
self.add_item(RedoButton('redo', discord.ButtonStyle.primary, self.bot))
|
||||||
self.add_item(Txt2ImgButton('txt2img', discord.ButtonStyle.green, self.bot))
|
self.add_item(Txt2ImgButton('txt2img', discord.ButtonStyle.primary, self.bot))
|
||||||
self.add_item(ConfigButton('config', discord.ButtonStyle.grey, self.bot))
|
self.add_item(Img2ImgButton('img2img', discord.ButtonStyle.primary, self.bot))
|
||||||
self.add_item(HelpButton('help', discord.ButtonStyle.grey, self.bot))
|
self.add_item(StatsButton('stats', discord.ButtonStyle.secondary, self.bot))
|
||||||
self.add_item(CoolButton('cool', discord.ButtonStyle.gray, self.bot))
|
self.add_item(DonateButton('donate', discord.ButtonStyle.secondary, self.bot))
|
||||||
|
self.add_item(ConfigButton('config', discord.ButtonStyle.secondary, self.bot))
|
||||||
|
self.add_item(HelpButton('help', discord.ButtonStyle.secondary, self.bot))
|
||||||
|
self.add_item(CoolButton('cool', discord.ButtonStyle.secondary, self.bot))
|
||||||
|
|
||||||
|
|
||||||
class Txt2ImgButton(discord.ui.Button):
|
class Txt2ImgButton(discord.ui.Button):
|
||||||
|
@ -32,7 +37,7 @@ class Txt2ImgButton(discord.ui.Button):
|
||||||
|
|
||||||
# init new msg
|
# init new msg
|
||||||
init_msg = 'started processing txt2img request...'
|
init_msg = 'started processing txt2img request...'
|
||||||
status_msg = await msg.reply(init_msg)
|
status_msg = await msg.channel.send(init_msg)
|
||||||
await db_call(
|
await db_call(
|
||||||
'new_user_request', user.id, msg.id, status_msg.id, status=init_msg)
|
'new_user_request', user.id, msg.id, status_msg.id, status=init_msg)
|
||||||
|
|
||||||
|
@ -60,7 +65,93 @@ class Txt2ImgButton(discord.ui.Button):
|
||||||
|
|
||||||
ec = await work_request(user, status_msg, 'txt2img', params, msg)
|
ec = await work_request(user, status_msg, 'txt2img', params, msg)
|
||||||
|
|
||||||
if ec == 0:
|
if ec == None:
|
||||||
|
await db_call('increment_generated', user.id)
|
||||||
|
|
||||||
|
|
||||||
|
class Img2ImgButton(discord.ui.Button):
|
||||||
|
|
||||||
|
def __init__(self, label: str, style: discord.ButtonStyle, bot):
|
||||||
|
self.bot = bot
|
||||||
|
super().__init__(label=label, style=style)
|
||||||
|
|
||||||
|
async def callback(self, interaction):
|
||||||
|
db_call = self.bot.db_call
|
||||||
|
work_request = self.bot.work_request
|
||||||
|
ipfs_node = self.bot.ipfs_node
|
||||||
|
msg = await grab('Attach an Image. Enter your prompt:', interaction)
|
||||||
|
|
||||||
|
user = msg.author
|
||||||
|
user_row = await db_call('get_or_create_user', user.id)
|
||||||
|
|
||||||
|
# init new msg
|
||||||
|
init_msg = 'started processing img2img request...'
|
||||||
|
status_msg = await msg.channel.send(init_msg)
|
||||||
|
await db_call(
|
||||||
|
'new_user_request', user.id, msg.id, status_msg.id, status=init_msg)
|
||||||
|
|
||||||
|
# if not msg.content.startswith('/img2img'):
|
||||||
|
# await msg.reply(
|
||||||
|
# 'For image to image you need to add /img2img to the beggining of your caption'
|
||||||
|
# )
|
||||||
|
# return
|
||||||
|
|
||||||
|
prompt = msg.content
|
||||||
|
|
||||||
|
if len(prompt) == 0:
|
||||||
|
await msg.reply('Empty text prompt ignored.')
|
||||||
|
return
|
||||||
|
|
||||||
|
# file_id = message.photo[-1].file_id
|
||||||
|
# file_path = (await bot.get_file(file_id)).file_path
|
||||||
|
# image_raw = await bot.download_file(file_path)
|
||||||
|
#
|
||||||
|
|
||||||
|
file = msg.attachments[-1]
|
||||||
|
file_id = str(file.id)
|
||||||
|
# file bytes
|
||||||
|
image_raw = await file.read()
|
||||||
|
with Image.open(io.BytesIO(image_raw)) as image:
|
||||||
|
w, h = image.size
|
||||||
|
|
||||||
|
if w > 512 or h > 512:
|
||||||
|
logging.warning(f'user sent img of size {image.size}')
|
||||||
|
image.thumbnail((512, 512))
|
||||||
|
logging.warning(f'resized it to {image.size}')
|
||||||
|
|
||||||
|
image.save(f'ipfs-docker-staging/image.png', format='PNG')
|
||||||
|
|
||||||
|
ipfs_hash = ipfs_node.add('image.png')
|
||||||
|
ipfs_node.pin(ipfs_hash)
|
||||||
|
|
||||||
|
logging.info(f'published input image {ipfs_hash} on ipfs')
|
||||||
|
|
||||||
|
logging.info(f'mid: {msg.id}')
|
||||||
|
|
||||||
|
user_config = {**user_row}
|
||||||
|
del user_config['id']
|
||||||
|
|
||||||
|
params = {
|
||||||
|
'prompt': prompt,
|
||||||
|
**user_config
|
||||||
|
}
|
||||||
|
|
||||||
|
await db_call(
|
||||||
|
'update_user_stats',
|
||||||
|
user.id,
|
||||||
|
'img2img',
|
||||||
|
last_file=file_id,
|
||||||
|
last_prompt=prompt,
|
||||||
|
last_binary=ipfs_hash
|
||||||
|
)
|
||||||
|
|
||||||
|
ec = await work_request(
|
||||||
|
user, status_msg, 'img2img', params, msg,
|
||||||
|
file_id=file_id,
|
||||||
|
binary_data=ipfs_hash
|
||||||
|
)
|
||||||
|
|
||||||
|
if ec == None:
|
||||||
await db_call('increment_generated', user.id)
|
await db_call('increment_generated', user.id)
|
||||||
|
|
||||||
|
|
||||||
|
@ -88,8 +179,9 @@ class RedoButton(discord.ui.Button):
|
||||||
binary = await db_call('get_last_binary_of', user.id)
|
binary = await db_call('get_last_binary_of', user.id)
|
||||||
|
|
||||||
if not prompt:
|
if not prompt:
|
||||||
await interaction.response.edit_message(
|
await status_msg.edit(
|
||||||
'no last prompt found, do a txt2img cmd first!'
|
content='no last prompt found, do a txt2img cmd first!',
|
||||||
|
view=SkynetView(self.bot)
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
@ -103,12 +195,15 @@ class RedoButton(discord.ui.Button):
|
||||||
'prompt': prompt,
|
'prompt': prompt,
|
||||||
**user_config
|
**user_config
|
||||||
}
|
}
|
||||||
await work_request(
|
ec = await work_request(
|
||||||
user, status_msg, 'redo', params, interaction,
|
user, status_msg, 'redo', params, interaction,
|
||||||
file_id=file_id,
|
file_id=file_id,
|
||||||
binary_data=binary
|
binary_data=binary
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if ec == None:
|
||||||
|
await db_call('increment_generated', user.id)
|
||||||
|
|
||||||
|
|
||||||
class ConfigButton(discord.ui.Button):
|
class ConfigButton(discord.ui.Button):
|
||||||
|
|
||||||
|
@ -136,7 +231,29 @@ class ConfigButton(discord.ui.Button):
|
||||||
await msg.reply(content=reply_txt, view=SkynetView(self.bot))
|
await msg.reply(content=reply_txt, view=SkynetView(self.bot))
|
||||||
|
|
||||||
|
|
||||||
class CoolButton(discord.ui.Button):
|
class StatsButton(discord.ui.Button):
|
||||||
|
|
||||||
|
def __init__(self, label: str, style: discord.ButtonStyle, bot):
|
||||||
|
self.bot = bot
|
||||||
|
super().__init__(label=label, style=style)
|
||||||
|
|
||||||
|
async def callback(self, interaction):
|
||||||
|
db_call = self.bot.db_call
|
||||||
|
|
||||||
|
user = interaction.user
|
||||||
|
|
||||||
|
await db_call('get_or_create_user', user.id)
|
||||||
|
generated, joined, role = await db_call('get_user_stats', user.id)
|
||||||
|
|
||||||
|
stats_str = f'```generated: {generated}\n'
|
||||||
|
stats_str += f'joined: {joined}\n'
|
||||||
|
stats_str += f'role: {role}\n```'
|
||||||
|
|
||||||
|
await interaction.response.send_message(
|
||||||
|
content=stats_str, view=SkynetView(self.bot))
|
||||||
|
|
||||||
|
|
||||||
|
class DonateButton(discord.ui.Button):
|
||||||
|
|
||||||
def __init__(self, label: str, style: discord.ButtonStyle, bot):
|
def __init__(self, label: str, style: discord.ButtonStyle, bot):
|
||||||
self.bot = bot
|
self.bot = bot
|
||||||
|
@ -144,7 +261,20 @@ class CoolButton(discord.ui.Button):
|
||||||
|
|
||||||
async def callback(self, interaction):
|
async def callback(self, interaction):
|
||||||
await interaction.response.send_message(
|
await interaction.response.send_message(
|
||||||
content='\n'.join(CLEAN_COOL_WORDS),
|
content=f'```\n{DONATION_INFO}```',
|
||||||
|
view=SkynetView(self.bot))
|
||||||
|
|
||||||
|
|
||||||
|
class CoolButton(discord.ui.Button):
|
||||||
|
|
||||||
|
def __init__(self, label: str, style: discord.ButtonStyle, bot):
|
||||||
|
self.bot = bot
|
||||||
|
super().__init__(label=label, style=style)
|
||||||
|
|
||||||
|
async def callback(self, interaction):
|
||||||
|
clean_cool_word = '\n'.join(CLEAN_COOL_WORDS)
|
||||||
|
await interaction.response.send_message(
|
||||||
|
content=f'```{clean_cool_word}```',
|
||||||
view=SkynetView(self.bot))
|
view=SkynetView(self.bot))
|
||||||
|
|
||||||
|
|
||||||
|
@ -160,15 +290,14 @@ class HelpButton(discord.ui.Button):
|
||||||
param = msg.content
|
param = msg.content
|
||||||
|
|
||||||
if param == 'a':
|
if param == 'a':
|
||||||
await msg.reply(content=HELP_TEXT, view=SkynetView(self.bot))
|
await msg.reply(content=f'```{HELP_TEXT}```', view=SkynetView(self.bot))
|
||||||
|
|
||||||
else:
|
else:
|
||||||
if param in HELP_TOPICS:
|
if param in HELP_TOPICS:
|
||||||
await msg.reply(content=HELP_TOPICS[param], view=SkynetView(self.bot))
|
await msg.reply(content=f'```{HELP_TOPICS[param]}```', view=SkynetView(self.bot))
|
||||||
|
|
||||||
else:
|
else:
|
||||||
await msg.reply(content=HELP_UNKWNOWN_PARAM, view=SkynetView(self.bot))
|
await msg.reply(content=f'```{HELP_UNKWNOWN_PARAM}```', view=SkynetView(self.bot))
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
async def grab(prompt, interaction):
|
async def grab(prompt, interaction):
|
||||||
|
|
|
@ -38,27 +38,41 @@ def build_redo_menu():
|
||||||
return inline_keyboard
|
return inline_keyboard
|
||||||
|
|
||||||
|
|
||||||
def prepare_metainfo_caption(user, worker: str, reward: str, meta: dict) -> str:
|
def prepare_metainfo_caption(user, worker: str, reward: str, meta: dict, embed) -> str:
|
||||||
prompt = meta["prompt"]
|
prompt = meta["prompt"]
|
||||||
if len(prompt) > 256:
|
if len(prompt) > 256:
|
||||||
prompt = prompt[:256]
|
prompt = prompt[:256]
|
||||||
|
|
||||||
meta_str = f'__by {user.name}__\n'
|
gen_str = f'generated by {user.name}\n'
|
||||||
meta_str += f'*performed by {worker}*\n'
|
gen_str += f'performed by {worker}\n'
|
||||||
meta_str += f'__**reward: {reward}**__\n'
|
gen_str += f'reward: {reward}\n'
|
||||||
|
|
||||||
|
embed.add_field(
|
||||||
|
name='General Info', value=f'```{gen_str}```', inline=False)
|
||||||
|
# meta_str = f'__by {user.name}__\n'
|
||||||
|
# meta_str += f'*performed by {worker}*\n'
|
||||||
|
# meta_str += f'__**reward: {reward}**__\n'
|
||||||
|
embed.add_field(name='Prompt', value=f'```{prompt}\n```', inline=False)
|
||||||
|
|
||||||
|
# meta_str = f'`prompt:` {prompt}\n'
|
||||||
|
|
||||||
|
meta_str = f'seed: {meta["seed"]}\n'
|
||||||
|
meta_str += f'step: {meta["step"]}\n'
|
||||||
|
meta_str += f'guidance: {meta["guidance"]}\n'
|
||||||
|
|
||||||
meta_str += f'`prompt:` {prompt}\n'
|
|
||||||
meta_str += f'`seed: {meta["seed"]}`\n'
|
|
||||||
meta_str += f'`step: {meta["step"]}`\n'
|
|
||||||
meta_str += f'`guidance: {meta["guidance"]}`\n'
|
|
||||||
if meta['strength']:
|
if meta['strength']:
|
||||||
meta_str += f'`strength: {meta["strength"]}`\n'
|
meta_str += f'strength: {meta["strength"]}\n'
|
||||||
meta_str += f'`algo: {meta["model"]}`\n'
|
meta_str += f'algo: {meta["model"]}\n'
|
||||||
if meta['upscaler']:
|
if meta['upscaler']:
|
||||||
meta_str += f'`upscaler: {meta["upscaler"]}`\n'
|
meta_str += f'upscaler: {meta["upscaler"]}\n'
|
||||||
|
|
||||||
|
embed.add_field(name='Parameters', value=f'```{meta_str}```', inline=False)
|
||||||
|
|
||||||
|
foot_str = f'Made with Skynet v{VERSION}\n'
|
||||||
|
foot_str += f'JOIN THE SWARM: @skynetgpu'
|
||||||
|
|
||||||
|
embed.set_footer(text=foot_str)
|
||||||
|
|
||||||
meta_str += f'__**Made with Skynet v{VERSION}**__\n'
|
|
||||||
meta_str += f'**JOIN THE SWARM: @skynetgpu**'
|
|
||||||
return meta_str
|
return meta_str
|
||||||
|
|
||||||
|
|
||||||
|
@ -74,7 +88,7 @@ def generate_reply_caption(
|
||||||
url=f'https://explorer.{DEFAULT_DOMAIN}/v2/explore/transaction/{tx_hash}',
|
url=f'https://explorer.{DEFAULT_DOMAIN}/v2/explore/transaction/{tx_hash}',
|
||||||
color=discord.Color.blue())
|
color=discord.Color.blue())
|
||||||
|
|
||||||
meta_info = prepare_metainfo_caption(user, worker, reward, params)
|
meta_info = prepare_metainfo_caption(user, worker, reward, params, explorer_link)
|
||||||
|
|
||||||
# why do we have this?
|
# why do we have this?
|
||||||
final_msg = '\n'.join([
|
final_msg = '\n'.join([
|
||||||
|
|
Loading…
Reference in New Issue