From 426018720882fe6e52922aee2d079d9f5c057503 Mon Sep 17 00:00:00 2001 From: Konstantine Tsafatinos Date: Fri, 21 Jul 2023 16:57:54 -0400 Subject: [PATCH] add img2img support, add stats and donate button, finalize UI, add live updates --- skynet/constants.py | 1 + skynet/db/functions.py | 3 +- skynet/frontend/discord/__init__.py | 125 +++++++-------------- skynet/frontend/discord/handlers.py | 139 +++++++++++++++++++++--- skynet/frontend/discord/ui.py | 161 +++++++++++++++++++++++++--- skynet/frontend/discord/utils.py | 42 +++++--- 6 files changed, 344 insertions(+), 127 deletions(-) diff --git a/skynet/constants.py b/skynet/constants.py index 0ddfa65..4dc1c48 100755 --- a/skynet/constants.py +++ b/skynet/constants.py @@ -36,6 +36,7 @@ commands work on a user per user basis! config is individual to each user! /txt2img TEXT - request an image based on a prompt +/img2img TEXT - request an image base on an image and a promtp /redo - redo last command (only works for txt2img for now!) diff --git a/skynet/db/functions.py b/skynet/db/functions.py index d98b099..f52703e 100644 --- a/skynet/db/functions.py +++ b/skynet/db/functions.py @@ -96,7 +96,8 @@ def open_new_database(cleanup=True): 'POSTGRES_PASSWORD': rpassword }, detach=True, - remove=True + # could remove this if we ant the dockers to be persistent. + # remove=True ) try: diff --git a/skynet/frontend/discord/__init__.py b/skynet/frontend/discord/__init__.py index 656f84f..7ab0f17 100644 --- a/skynet/frontend/discord/__init__.py +++ b/skynet/frontend/discord/__init__.py @@ -89,6 +89,7 @@ class SkynetDiscordFrontend: yield self await self.stop() + # maybe do this? # async def update_status_message( # self, status_msg, new_text: str, **kwargs # ): @@ -139,17 +140,14 @@ class SkynetDiscordFrontend: }) request_time = datetime.now().isoformat() - # maybe get rid of this - # await self.update_status_message( - # status_msg, - # f'processing a \'{method}\' request by {tg_user_pretty(user)}\n' - # f'[{timestamp_pretty()}] broadcasting transaction to chain...', - # parse_mode='HTML' - # ) - # message = await ctx.send( - # f'processing a \'{method}\' request by {user}\n \ - # [{timestamp_pretty()}] *broadcasting transaction to chain...*' - # ) + await status_msg.delete() + msg_text = f'processing a \'{method}\' request by {user.name}\n[{timestamp_pretty()}] *broadcasting transaction to chain...* ' + embed = discord.Embed( + title='live updates', + description=msg_text, + color=discord.Color.blue()) + + message = await send(embed=embed) reward = '20.0000 GPU' res = await self.cleos.a_push_action( @@ -164,7 +162,6 @@ class SkynetDiscordFrontend: }, self.account, self.key, permission=self.permission ) - # print(res) if 'code' in res or 'statusCode' in res: logging.error(json.dumps(res, indent=4)) @@ -174,23 +171,15 @@ class SkynetDiscordFrontend: return enqueue_tx_id = res['transaction_id'] - enqueue_tx_link = hlink( - 'Your request on Skynet Explorer', - f'https://explorer.{DEFAULT_DOMAIN}/v2/explore/transaction/{enqueue_tx_id}' - ) + enqueue_tx_link = f'[**Your request on Skynet Explorer**](https://explorer.{DEFAULT_DOMAIN}/v2/explore/transaction/{enqueue_tx_id})' - # await self.append_status_message( - # status_msg, - # f' broadcasted!\n' - # f'{enqueue_tx_link}\n' - # f'[{timestamp_pretty()}] workers are processing request...', - # parse_mode='HTML' - # ) - # await message.edit(content= - # f'**broadcasted!**\n \ - # **{enqueue_tx_link}**\n \ - # [{timestamp_pretty()}] *workers are processing request...*' - # ) + msg_text += f'**broadcasted!** \n{enqueue_tx_link}\n[{timestamp_pretty()}] *workers are processing request...* ' + embed = discord.Embed( + title='live updates', + description=msg_text, + color=discord.Color.blue()) + + await message.edit(embed=embed) out = collect_stdout(res) @@ -233,77 +222,45 @@ class SkynetDiscordFrontend: await asyncio.sleep(1) if not ipfs_hash: - # await self.update_status_message( - # status_msg, - # f'\n[{timestamp_pretty()}] timeout processing request', - # parse_mode='HTML' - # ) + + msg_text += f'\n[{timestamp_pretty()}] **timeout processing request**' + embed = discord.Embed( + title='live updates', + description=msg_text, + color=discord.Color.blue()) + + await message.edit(embed=embed) return - tx_link = hlink( - 'Your result on Skynet Explorer', - f'https://explorer.{DEFAULT_DOMAIN}/v2/explore/transaction/{tx_hash}' - ) + tx_link = f'[**Your result on Skynet Explorer**](https://explorer.{DEFAULT_DOMAIN}/v2/explore/transaction/{tx_hash})' - # await self.append_status_message( - # status_msg, - # f' request processed!\n' - # f'{tx_link}\n' - # f'[{timestamp_pretty()}] trying to download image...\n', - # parse_mode='HTML' - # ) - # await message.edit(content= - # f'**request processed!**\n \ - # **{tx_link}**\n \ - # [{timestamp_pretty()}] *trying to download image...*\n' - # ) + msg_text += f'**request processed!**\n{tx_link}\n[{timestamp_pretty()}] *trying to download image...*\n ' + embed = discord.Embed( + title='live updates', + description=msg_text, + color=discord.Color.blue()) + + await message.edit(embed=embed) # attempt to get the image and send it ipfs_link = f'https://ipfs.{DEFAULT_DOMAIN}/ipfs/{ipfs_hash}/image.png' resp = await get_ipfs_file(ipfs_link) + # reword this function, may not need caption caption, embed = generate_reply_caption( user, params, tx_hash, worker, reward) if not resp or resp.status_code != 200: logging.error(f'couldn\'t get ipfs hosted image at {ipfs_link}!') - # await self.update_status_message( - # status_msg, - # caption, - # reply_markup=build_redo_menu(), - # parse_mode='HTML' - # ) - # + await message.edit(embed=embed, view=SkynetView(self)) else: logging.info(f'success! sending generated image') - # image = io.BytesIO(resp.raw) - # embed.set_image(url=ipfs_link) - # embed.add_field(name='params', value=caption) - # await self.bot.delete_message( - # chat_id=status_msg.chat.id, message_id=status_msg.id) + await message.delete() if file_id: # img2img - pass - # await self.bot.send_media_group( - # status_msg.chat.id, - # media=[ - # InputMediaPhoto(file_id), - # InputMediaPhoto( - # resp.raw, - # caption=caption, - # parse_mode='HTML' - # ) - # ], - # ) - # - else: # txt2img - # await self.bot.send_photo( - # status_msg.chat.id, - # caption=caption, - # photo=resp.raw, - # reply_markup=build_redo_menu(), - # parse_mode='HTML' - # ) - + embed.set_thumbnail( + url='https://ipfs.skygpu.net/ipfs/' + binary_data + '/image.png') + embed.set_image(url=ipfs_link) + await send(embed=embed, view=SkynetView(self)) + else: # txt2img embed.set_image(url=ipfs_link) - embed.add_field(name='Parameters:', value=caption) await send(embed=embed, view=SkynetView(self)) diff --git a/skynet/frontend/discord/handlers.py b/skynet/frontend/discord/handlers.py index dbf19f4..c6f2735 100644 --- a/skynet/frontend/discord/handlers.py +++ b/skynet/frontend/discord/handlers.py @@ -41,24 +41,44 @@ def create_handler_context(frontend: 'SkynetDiscordFrontend'): finally: await ctx.reply(content=reply_txt, view=SkynetView(frontend)) - @bot.command(name='helper', help='Responds with a help') - async def helper(ctx): + bot.remove_command('help') + @bot.command(name='help', help='Responds with a help') + async def help(ctx): splt_msg = ctx.message.content.split(' ') if len(splt_msg) == 1: - await ctx.reply(content=HELP_TEXT, view=SkynetView(frontend)) + await ctx.send(content=f'```{HELP_TEXT}```', view=SkynetView(frontend)) else: param = splt_msg[1] if param in HELP_TOPICS: - await ctx.reply(content=HELP_TOPICS[param], view=SkynetView(frontend)) + await ctx.send(content=f'```{HELP_TOPICS[param]}```', view=SkynetView(frontend)) else: - await ctx.reply(content=HELP_UNKWNOWN_PARAM, view=SkynetView(frontend)) + await ctx.send(content=f'```{HELP_UNKWNOWN_PARAM}```', view=SkynetView(frontend)) @bot.command(name='cool', help='Display a list of cool prompt words') async def send_cool_words(ctx): - await ctx.reply(content='\n'.join(CLEAN_COOL_WORDS), view=SkynetView(frontend)) + clean_cool_word = '\n'.join(CLEAN_COOL_WORDS) + await ctx.send(content=f'```{clean_cool_word}```', view=SkynetView(frontend)) + + @bot.command(name='stats', help='See user statistics' ) + async def user_stats(ctx): + user = ctx.author + + await db_call('get_or_create_user', user.id) + generated, joined, role = await db_call('get_user_stats', user.id) + + stats_str = f'```generated: {generated}\n' + stats_str += f'joined: {joined}\n' + stats_str += f'role: {role}\n```' + + await ctx.reply(stats_str, view=SkynetView(frontend)) + + @bot.command(name='donate', help='See donate info') + async def donation_info(ctx): + await ctx.reply( + f'```\n{DONATION_INFO}```', view=SkynetView(frontend)) @bot.command(name='txt2img', help='Responds with an image') async def send_txt2img(ctx): @@ -69,7 +89,7 @@ def create_handler_context(frontend: 'SkynetDiscordFrontend'): # init new msg init_msg = 'started processing txt2img request...' - status_msg = await ctx.reply(init_msg) + status_msg = await ctx.send(init_msg) await db_call( 'new_user_request', user.id, ctx.message.id, status_msg.id, status=init_msg) @@ -97,13 +117,13 @@ def create_handler_context(frontend: 'SkynetDiscordFrontend'): ec = await work_request(user, status_msg, 'txt2img', params, ctx) - if ec == 0: + if ec == None: await db_call('increment_generated', user.id) @bot.command(name='redo', help='Redo last request') async def redo(ctx): init_msg = 'started processing redo request...' - status_msg = await ctx.reply(init_msg) + status_msg = await ctx.send(init_msg) user = ctx.author method = await db_call('get_last_method_of', user.id) @@ -116,8 +136,9 @@ def create_handler_context(frontend: 'SkynetDiscordFrontend'): binary = await db_call('get_last_binary_of', user.id) if not prompt: - await ctx.reply( - 'no last prompt found, do a txt2img cmd first!' + await status_msg.edit( + content='no last prompt found, do a txt2img cmd first!', + view=SkynetView(frontend) ) return @@ -132,12 +153,106 @@ def create_handler_context(frontend: 'SkynetDiscordFrontend'): **user_config } - await work_request( + ec = await work_request( user, status_msg, 'redo', params, ctx, file_id=file_id, binary_data=binary ) + if ec == None: + await db_call('increment_generated', user.id) + + @bot.command(name='img2img', help='Responds with an image') + async def send_img2img(ctx): + # if isinstance(message_or_query, CallbackQuery): + # query = message_or_query + # message = query.message + # user = query.from_user + # chat = query.message.chat + # + # else: + # message = message_or_query + # user = message.from_user + # chat = message.chat + + # reply_id = None + # if chat.type == 'group' and chat.id == GROUP_ID: + # reply_id = message.message_id + # + user = ctx.author + user_row = await db_call('get_or_create_user', user.id) + + # init new msg + init_msg = 'started processing img2img request...' + status_msg = await ctx.send(init_msg) + await db_call( + 'new_user_request', user.id, ctx.message.id, status_msg.id, status=init_msg) + + if not ctx.message.content.startswith('/img2img'): + await ctx.reply( + 'For image to image you need to add /img2img to the beggining of your caption' + ) + return + + prompt = ' '.join(ctx.message.content.split(' ')[1:]) + + if len(prompt) == 0: + await ctx.reply('Empty text prompt ignored.') + return + + # file_id = message.photo[-1].file_id + # file_path = (await bot.get_file(file_id)).file_path + # image_raw = await bot.download_file(file_path) + # + + file = ctx.message.attachments[-1] + file_id = str(file.id) + # file bytes + image_raw = await file.read() + with Image.open(io.BytesIO(image_raw)) as image: + w, h = image.size + + if w > 512 or h > 512: + logging.warning(f'user sent img of size {image.size}') + image.thumbnail((512, 512)) + logging.warning(f'resized it to {image.size}') + + image.save(f'ipfs-docker-staging/image.png', format='PNG') + + ipfs_hash = ipfs_node.add('image.png') + ipfs_node.pin(ipfs_hash) + + logging.info(f'published input image {ipfs_hash} on ipfs') + + logging.info(f'mid: {ctx.message.id}') + + user_config = {**user_row} + del user_config['id'] + + params = { + 'prompt': prompt, + **user_config + } + + await db_call( + 'update_user_stats', + user.id, + 'img2img', + last_file=file_id, + last_prompt=prompt, + last_binary=ipfs_hash + ) + + ec = await work_request( + user, status_msg, 'img2img', params, ctx, + file_id=file_id, + binary_data=ipfs_hash + ) + + if ec == None: + await db_call('increment_generated', user.id) + + # TODO: DELETE BELOW # user = 'testworker3' diff --git a/skynet/frontend/discord/ui.py b/skynet/frontend/discord/ui.py index 0f9452e..95d71b1 100644 --- a/skynet/frontend/discord/ui.py +++ b/skynet/frontend/discord/ui.py @@ -1,4 +1,6 @@ +import io import discord +from PIL import Image import logging from skynet.constants import * from skynet.frontend import validate_user_config_request @@ -9,11 +11,14 @@ class SkynetView(discord.ui.View): def __init__(self, bot): self.bot = bot super().__init__(timeout=None) - self.add_item(RedoButton('redo', discord.ButtonStyle.green, self.bot)) - self.add_item(Txt2ImgButton('txt2img', discord.ButtonStyle.green, self.bot)) - self.add_item(ConfigButton('config', discord.ButtonStyle.grey, self.bot)) - self.add_item(HelpButton('help', discord.ButtonStyle.grey, self.bot)) - self.add_item(CoolButton('cool', discord.ButtonStyle.gray, self.bot)) + self.add_item(RedoButton('redo', discord.ButtonStyle.primary, self.bot)) + self.add_item(Txt2ImgButton('txt2img', discord.ButtonStyle.primary, self.bot)) + self.add_item(Img2ImgButton('img2img', discord.ButtonStyle.primary, self.bot)) + self.add_item(StatsButton('stats', discord.ButtonStyle.secondary, self.bot)) + self.add_item(DonateButton('donate', discord.ButtonStyle.secondary, self.bot)) + self.add_item(ConfigButton('config', discord.ButtonStyle.secondary, self.bot)) + self.add_item(HelpButton('help', discord.ButtonStyle.secondary, self.bot)) + self.add_item(CoolButton('cool', discord.ButtonStyle.secondary, self.bot)) class Txt2ImgButton(discord.ui.Button): @@ -32,7 +37,7 @@ class Txt2ImgButton(discord.ui.Button): # init new msg init_msg = 'started processing txt2img request...' - status_msg = await msg.reply(init_msg) + status_msg = await msg.channel.send(init_msg) await db_call( 'new_user_request', user.id, msg.id, status_msg.id, status=init_msg) @@ -60,7 +65,93 @@ class Txt2ImgButton(discord.ui.Button): ec = await work_request(user, status_msg, 'txt2img', params, msg) - if ec == 0: + if ec == None: + await db_call('increment_generated', user.id) + + +class Img2ImgButton(discord.ui.Button): + + def __init__(self, label: str, style: discord.ButtonStyle, bot): + self.bot = bot + super().__init__(label=label, style=style) + + async def callback(self, interaction): + db_call = self.bot.db_call + work_request = self.bot.work_request + ipfs_node = self.bot.ipfs_node + msg = await grab('Attach an Image. Enter your prompt:', interaction) + + user = msg.author + user_row = await db_call('get_or_create_user', user.id) + + # init new msg + init_msg = 'started processing img2img request...' + status_msg = await msg.channel.send(init_msg) + await db_call( + 'new_user_request', user.id, msg.id, status_msg.id, status=init_msg) + + # if not msg.content.startswith('/img2img'): + # await msg.reply( + # 'For image to image you need to add /img2img to the beggining of your caption' + # ) + # return + + prompt = msg.content + + if len(prompt) == 0: + await msg.reply('Empty text prompt ignored.') + return + + # file_id = message.photo[-1].file_id + # file_path = (await bot.get_file(file_id)).file_path + # image_raw = await bot.download_file(file_path) + # + + file = msg.attachments[-1] + file_id = str(file.id) + # file bytes + image_raw = await file.read() + with Image.open(io.BytesIO(image_raw)) as image: + w, h = image.size + + if w > 512 or h > 512: + logging.warning(f'user sent img of size {image.size}') + image.thumbnail((512, 512)) + logging.warning(f'resized it to {image.size}') + + image.save(f'ipfs-docker-staging/image.png', format='PNG') + + ipfs_hash = ipfs_node.add('image.png') + ipfs_node.pin(ipfs_hash) + + logging.info(f'published input image {ipfs_hash} on ipfs') + + logging.info(f'mid: {msg.id}') + + user_config = {**user_row} + del user_config['id'] + + params = { + 'prompt': prompt, + **user_config + } + + await db_call( + 'update_user_stats', + user.id, + 'img2img', + last_file=file_id, + last_prompt=prompt, + last_binary=ipfs_hash + ) + + ec = await work_request( + user, status_msg, 'img2img', params, msg, + file_id=file_id, + binary_data=ipfs_hash + ) + + if ec == None: await db_call('increment_generated', user.id) @@ -88,8 +179,9 @@ class RedoButton(discord.ui.Button): binary = await db_call('get_last_binary_of', user.id) if not prompt: - await interaction.response.edit_message( - 'no last prompt found, do a txt2img cmd first!' + await status_msg.edit( + content='no last prompt found, do a txt2img cmd first!', + view=SkynetView(self.bot) ) return @@ -103,12 +195,15 @@ class RedoButton(discord.ui.Button): 'prompt': prompt, **user_config } - await work_request( + ec = await work_request( user, status_msg, 'redo', params, interaction, file_id=file_id, binary_data=binary ) + if ec == None: + await db_call('increment_generated', user.id) + class ConfigButton(discord.ui.Button): @@ -136,7 +231,29 @@ class ConfigButton(discord.ui.Button): await msg.reply(content=reply_txt, view=SkynetView(self.bot)) -class CoolButton(discord.ui.Button): +class StatsButton(discord.ui.Button): + + def __init__(self, label: str, style: discord.ButtonStyle, bot): + self.bot = bot + super().__init__(label=label, style=style) + + async def callback(self, interaction): + db_call = self.bot.db_call + + user = interaction.user + + await db_call('get_or_create_user', user.id) + generated, joined, role = await db_call('get_user_stats', user.id) + + stats_str = f'```generated: {generated}\n' + stats_str += f'joined: {joined}\n' + stats_str += f'role: {role}\n```' + + await interaction.response.send_message( + content=stats_str, view=SkynetView(self.bot)) + + +class DonateButton(discord.ui.Button): def __init__(self, label: str, style: discord.ButtonStyle, bot): self.bot = bot @@ -144,7 +261,20 @@ class CoolButton(discord.ui.Button): async def callback(self, interaction): await interaction.response.send_message( - content='\n'.join(CLEAN_COOL_WORDS), + content=f'```\n{DONATION_INFO}```', + view=SkynetView(self.bot)) + + +class CoolButton(discord.ui.Button): + + def __init__(self, label: str, style: discord.ButtonStyle, bot): + self.bot = bot + super().__init__(label=label, style=style) + + async def callback(self, interaction): + clean_cool_word = '\n'.join(CLEAN_COOL_WORDS) + await interaction.response.send_message( + content=f'```{clean_cool_word}```', view=SkynetView(self.bot)) @@ -160,15 +290,14 @@ class HelpButton(discord.ui.Button): param = msg.content if param == 'a': - await msg.reply(content=HELP_TEXT, view=SkynetView(self.bot)) + await msg.reply(content=f'```{HELP_TEXT}```', view=SkynetView(self.bot)) else: if param in HELP_TOPICS: - await msg.reply(content=HELP_TOPICS[param], view=SkynetView(self.bot)) + await msg.reply(content=f'```{HELP_TOPICS[param]}```', view=SkynetView(self.bot)) else: - await msg.reply(content=HELP_UNKWNOWN_PARAM, view=SkynetView(self.bot)) - + await msg.reply(content=f'```{HELP_UNKWNOWN_PARAM}```', view=SkynetView(self.bot)) async def grab(prompt, interaction): diff --git a/skynet/frontend/discord/utils.py b/skynet/frontend/discord/utils.py index 81724b9..1fd6618 100644 --- a/skynet/frontend/discord/utils.py +++ b/skynet/frontend/discord/utils.py @@ -38,27 +38,41 @@ def build_redo_menu(): return inline_keyboard -def prepare_metainfo_caption(user, worker: str, reward: str, meta: dict) -> str: +def prepare_metainfo_caption(user, worker: str, reward: str, meta: dict, embed) -> str: prompt = meta["prompt"] if len(prompt) > 256: prompt = prompt[:256] + + gen_str = f'generated by {user.name}\n' + gen_str += f'performed by {worker}\n' + gen_str += f'reward: {reward}\n' - meta_str = f'__by {user.name}__\n' - meta_str += f'*performed by {worker}*\n' - meta_str += f'__**reward: {reward}**__\n' + embed.add_field( + name='General Info', value=f'```{gen_str}```', inline=False) + # meta_str = f'__by {user.name}__\n' + # meta_str += f'*performed by {worker}*\n' + # meta_str += f'__**reward: {reward}**__\n' + embed.add_field(name='Prompt', value=f'```{prompt}\n```', inline=False) + + # meta_str = f'`prompt:` {prompt}\n' + + meta_str = f'seed: {meta["seed"]}\n' + meta_str += f'step: {meta["step"]}\n' + meta_str += f'guidance: {meta["guidance"]}\n' - meta_str += f'`prompt:` {prompt}\n' - meta_str += f'`seed: {meta["seed"]}`\n' - meta_str += f'`step: {meta["step"]}`\n' - meta_str += f'`guidance: {meta["guidance"]}`\n' if meta['strength']: - meta_str += f'`strength: {meta["strength"]}`\n' - meta_str += f'`algo: {meta["model"]}`\n' + meta_str += f'strength: {meta["strength"]}\n' + meta_str += f'algo: {meta["model"]}\n' if meta['upscaler']: - meta_str += f'`upscaler: {meta["upscaler"]}`\n' + meta_str += f'upscaler: {meta["upscaler"]}\n' + + embed.add_field(name='Parameters', value=f'```{meta_str}```', inline=False) + + foot_str = f'Made with Skynet v{VERSION}\n' + foot_str += f'JOIN THE SWARM: @skynetgpu' + + embed.set_footer(text=foot_str) - meta_str += f'__**Made with Skynet v{VERSION}**__\n' - meta_str += f'**JOIN THE SWARM: @skynetgpu**' return meta_str @@ -74,7 +88,7 @@ def generate_reply_caption( url=f'https://explorer.{DEFAULT_DOMAIN}/v2/explore/transaction/{tx_hash}', color=discord.Color.blue()) - meta_info = prepare_metainfo_caption(user, worker, reward, params) + meta_info = prepare_metainfo_caption(user, worker, reward, params, explorer_link) # why do we have this? final_msg = '\n'.join([