mirror of https://github.com/skygpu/skynet.git
				
				
				
			add img2img support, add stats and donate button, finalize UI, add live updates
							parent
							
								
									58c6a2070e
								
							
						
					
					
						commit
						4260187208
					
				| 
						 | 
				
			
			@ -36,6 +36,7 @@ commands work on a user per user basis!
 | 
			
		|||
config is individual to each user!
 | 
			
		||||
 | 
			
		||||
/txt2img TEXT - request an image based on a prompt
 | 
			
		||||
/img2img <attach_image> TEXT - request an image base on an image and a promtp
 | 
			
		||||
 | 
			
		||||
/redo - redo last command (only works for txt2img for now!)
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -96,7 +96,8 @@ def open_new_database(cleanup=True):
 | 
			
		|||
            'POSTGRES_PASSWORD': rpassword
 | 
			
		||||
        },
 | 
			
		||||
        detach=True,
 | 
			
		||||
        remove=True
 | 
			
		||||
        # could remove this if we ant the dockers to be persistent.
 | 
			
		||||
        # remove=True
 | 
			
		||||
    )
 | 
			
		||||
    try:
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -89,6 +89,7 @@ class SkynetDiscordFrontend:
 | 
			
		|||
        yield self
 | 
			
		||||
        await self.stop()
 | 
			
		||||
 | 
			
		||||
    # maybe do this?
 | 
			
		||||
    # async def update_status_message(
 | 
			
		||||
    #     self, status_msg, new_text: str, **kwargs
 | 
			
		||||
    # ):
 | 
			
		||||
| 
						 | 
				
			
			@ -139,17 +140,14 @@ class SkynetDiscordFrontend:
 | 
			
		|||
        })
 | 
			
		||||
        request_time = datetime.now().isoformat()
 | 
			
		||||
 | 
			
		||||
        # maybe get rid of this
 | 
			
		||||
        # await self.update_status_message(
 | 
			
		||||
        #     status_msg,
 | 
			
		||||
        #     f'processing a \'{method}\' request by {tg_user_pretty(user)}\n'
 | 
			
		||||
        #     f'[{timestamp_pretty()}] <i>broadcasting transaction to chain...</i>',
 | 
			
		||||
        #     parse_mode='HTML'
 | 
			
		||||
        # )
 | 
			
		||||
        # message = await ctx.send(
 | 
			
		||||
        #     f'processing a \'{method}\' request by {user}\n \
 | 
			
		||||
        #     [{timestamp_pretty()}] *broadcasting transaction to chain...*'
 | 
			
		||||
        # )
 | 
			
		||||
        await status_msg.delete()
 | 
			
		||||
        msg_text = f'processing a \'{method}\' request by {user.name}\n[{timestamp_pretty()}] *broadcasting transaction to chain...* '
 | 
			
		||||
        embed = discord.Embed(
 | 
			
		||||
            title='live updates',
 | 
			
		||||
            description=msg_text,
 | 
			
		||||
            color=discord.Color.blue())
 | 
			
		||||
 | 
			
		||||
        message = await send(embed=embed)
 | 
			
		||||
 | 
			
		||||
        reward = '20.0000 GPU'
 | 
			
		||||
        res = await self.cleos.a_push_action(
 | 
			
		||||
| 
						 | 
				
			
			@ -164,7 +162,6 @@ class SkynetDiscordFrontend:
 | 
			
		|||
            },
 | 
			
		||||
            self.account, self.key, permission=self.permission
 | 
			
		||||
        )
 | 
			
		||||
        # print(res)
 | 
			
		||||
 | 
			
		||||
        if 'code' in res or 'statusCode' in res:
 | 
			
		||||
            logging.error(json.dumps(res, indent=4))
 | 
			
		||||
| 
						 | 
				
			
			@ -174,23 +171,15 @@ class SkynetDiscordFrontend:
 | 
			
		|||
            return
 | 
			
		||||
 | 
			
		||||
        enqueue_tx_id = res['transaction_id']
 | 
			
		||||
        enqueue_tx_link = hlink(
 | 
			
		||||
            'Your request on Skynet Explorer',
 | 
			
		||||
            f'https://explorer.{DEFAULT_DOMAIN}/v2/explore/transaction/{enqueue_tx_id}'
 | 
			
		||||
        )
 | 
			
		||||
        enqueue_tx_link = f'[**Your request on Skynet Explorer**](https://explorer.{DEFAULT_DOMAIN}/v2/explore/transaction/{enqueue_tx_id})'
 | 
			
		||||
 | 
			
		||||
        # await self.append_status_message(
 | 
			
		||||
        #     status_msg,
 | 
			
		||||
        #     f' <b>broadcasted!</b>\n'
 | 
			
		||||
        #     f'<b>{enqueue_tx_link}</b>\n'
 | 
			
		||||
        #     f'[{timestamp_pretty()}] <i>workers are processing request...</i>',
 | 
			
		||||
        #     parse_mode='HTML'
 | 
			
		||||
        # )
 | 
			
		||||
        # await message.edit(content=
 | 
			
		||||
        #     f'**broadcasted!**\n \
 | 
			
		||||
        #     **{enqueue_tx_link}**\n \
 | 
			
		||||
        #     [{timestamp_pretty()}] *workers are processing request...*'
 | 
			
		||||
        # )
 | 
			
		||||
        msg_text += f'**broadcasted!** \n{enqueue_tx_link}\n[{timestamp_pretty()}] *workers are processing request...* '
 | 
			
		||||
        embed = discord.Embed(
 | 
			
		||||
            title='live updates',
 | 
			
		||||
            description=msg_text,
 | 
			
		||||
            color=discord.Color.blue())
 | 
			
		||||
 | 
			
		||||
        await message.edit(embed=embed)
 | 
			
		||||
 | 
			
		||||
        out = collect_stdout(res)
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -233,77 +222,45 @@ class SkynetDiscordFrontend:
 | 
			
		|||
            await asyncio.sleep(1)
 | 
			
		||||
 | 
			
		||||
        if not ipfs_hash:
 | 
			
		||||
            # await self.update_status_message(
 | 
			
		||||
            #     status_msg,
 | 
			
		||||
            #     f'\n[{timestamp_pretty()}] <b>timeout processing request</b>',
 | 
			
		||||
            #     parse_mode='HTML'
 | 
			
		||||
            # )
 | 
			
		||||
 | 
			
		||||
            msg_text += f'\n[{timestamp_pretty()}] **timeout processing request**'
 | 
			
		||||
            embed = discord.Embed(
 | 
			
		||||
                title='live updates',
 | 
			
		||||
                description=msg_text,
 | 
			
		||||
                color=discord.Color.blue())
 | 
			
		||||
 | 
			
		||||
            await message.edit(embed=embed)
 | 
			
		||||
            return
 | 
			
		||||
 | 
			
		||||
        tx_link = hlink(
 | 
			
		||||
            'Your result on Skynet Explorer',
 | 
			
		||||
            f'https://explorer.{DEFAULT_DOMAIN}/v2/explore/transaction/{tx_hash}'
 | 
			
		||||
        )
 | 
			
		||||
        tx_link = f'[**Your result on Skynet Explorer**](https://explorer.{DEFAULT_DOMAIN}/v2/explore/transaction/{tx_hash})'
 | 
			
		||||
 | 
			
		||||
        # await self.append_status_message(
 | 
			
		||||
        #     status_msg,
 | 
			
		||||
        #     f' <b>request processed!</b>\n'
 | 
			
		||||
        #     f'<b>{tx_link}</b>\n'
 | 
			
		||||
        #     f'[{timestamp_pretty()}] <i>trying to download image...</i>\n',
 | 
			
		||||
        #     parse_mode='HTML'
 | 
			
		||||
        # )
 | 
			
		||||
        # await message.edit(content=
 | 
			
		||||
        #     f'**request processed!**\n \
 | 
			
		||||
        #     **{tx_link}**\n \
 | 
			
		||||
        #     [{timestamp_pretty()}] *trying to download image...*\n'
 | 
			
		||||
        # )
 | 
			
		||||
        msg_text += f'**request processed!**\n{tx_link}\n[{timestamp_pretty()}] *trying to download image...*\n '
 | 
			
		||||
        embed = discord.Embed(
 | 
			
		||||
            title='live updates',
 | 
			
		||||
            description=msg_text,
 | 
			
		||||
            color=discord.Color.blue())
 | 
			
		||||
 | 
			
		||||
        await message.edit(embed=embed)
 | 
			
		||||
 | 
			
		||||
        # attempt to get the image and send it
 | 
			
		||||
        ipfs_link = f'https://ipfs.{DEFAULT_DOMAIN}/ipfs/{ipfs_hash}/image.png'
 | 
			
		||||
        resp = await get_ipfs_file(ipfs_link)
 | 
			
		||||
 | 
			
		||||
        # reword this function, may not need caption
 | 
			
		||||
        caption, embed = generate_reply_caption(
 | 
			
		||||
            user, params, tx_hash, worker, reward)
 | 
			
		||||
 | 
			
		||||
        if not resp or resp.status_code != 200:
 | 
			
		||||
            logging.error(f'couldn\'t get ipfs hosted image at {ipfs_link}!')
 | 
			
		||||
            # await self.update_status_message(
 | 
			
		||||
            #     status_msg,
 | 
			
		||||
            #     caption,
 | 
			
		||||
            #     reply_markup=build_redo_menu(),
 | 
			
		||||
            #     parse_mode='HTML'
 | 
			
		||||
            # )
 | 
			
		||||
            #
 | 
			
		||||
            await message.edit(embed=embed, view=SkynetView(self))
 | 
			
		||||
        else:
 | 
			
		||||
            logging.info(f'success! sending generated image')
 | 
			
		||||
            # image = io.BytesIO(resp.raw)
 | 
			
		||||
            # embed.set_image(url=ipfs_link)
 | 
			
		||||
            # embed.add_field(name='params', value=caption)
 | 
			
		||||
            # await self.bot.delete_message(
 | 
			
		||||
            #     chat_id=status_msg.chat.id, message_id=status_msg.id)
 | 
			
		||||
            await message.delete()
 | 
			
		||||
            if file_id:  # img2img
 | 
			
		||||
                pass
 | 
			
		||||
            #     await self.bot.send_media_group(
 | 
			
		||||
            #         status_msg.chat.id,
 | 
			
		||||
            #         media=[
 | 
			
		||||
            #             InputMediaPhoto(file_id),
 | 
			
		||||
            #             InputMediaPhoto(
 | 
			
		||||
            #                 resp.raw,
 | 
			
		||||
            #                 caption=caption,
 | 
			
		||||
            #                 parse_mode='HTML'
 | 
			
		||||
            #             )
 | 
			
		||||
            #         ],
 | 
			
		||||
            #     )
 | 
			
		||||
            #
 | 
			
		||||
            else:  # txt2img
 | 
			
		||||
                # await self.bot.send_photo(
 | 
			
		||||
                #     status_msg.chat.id,
 | 
			
		||||
                #     caption=caption,
 | 
			
		||||
                #     photo=resp.raw,
 | 
			
		||||
                #     reply_markup=build_redo_menu(),
 | 
			
		||||
                #     parse_mode='HTML'
 | 
			
		||||
                # )
 | 
			
		||||
 | 
			
		||||
                embed.set_thumbnail(
 | 
			
		||||
                    url='https://ipfs.skygpu.net/ipfs/' + binary_data + '/image.png')
 | 
			
		||||
                embed.set_image(url=ipfs_link)
 | 
			
		||||
                await send(embed=embed, view=SkynetView(self))
 | 
			
		||||
            else:  # txt2img
 | 
			
		||||
                embed.set_image(url=ipfs_link)
 | 
			
		||||
                embed.add_field(name='Parameters:', value=caption)
 | 
			
		||||
                await send(embed=embed, view=SkynetView(self))
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -41,24 +41,44 @@ def create_handler_context(frontend: 'SkynetDiscordFrontend'):
 | 
			
		|||
        finally:
 | 
			
		||||
            await ctx.reply(content=reply_txt, view=SkynetView(frontend))
 | 
			
		||||
 | 
			
		||||
    @bot.command(name='helper', help='Responds with a help')
 | 
			
		||||
    async def helper(ctx):
 | 
			
		||||
    bot.remove_command('help')
 | 
			
		||||
    @bot.command(name='help', help='Responds with a help')
 | 
			
		||||
    async def help(ctx):
 | 
			
		||||
        splt_msg = ctx.message.content.split(' ')
 | 
			
		||||
 | 
			
		||||
        if len(splt_msg) == 1:
 | 
			
		||||
            await ctx.reply(content=HELP_TEXT, view=SkynetView(frontend))
 | 
			
		||||
            await ctx.send(content=f'```{HELP_TEXT}```', view=SkynetView(frontend))
 | 
			
		||||
 | 
			
		||||
        else:
 | 
			
		||||
            param = splt_msg[1]
 | 
			
		||||
            if param in HELP_TOPICS:
 | 
			
		||||
                await ctx.reply(content=HELP_TOPICS[param], view=SkynetView(frontend))
 | 
			
		||||
                await ctx.send(content=f'```{HELP_TOPICS[param]}```', view=SkynetView(frontend))
 | 
			
		||||
 | 
			
		||||
            else:
 | 
			
		||||
                await ctx.reply(content=HELP_UNKWNOWN_PARAM, view=SkynetView(frontend))
 | 
			
		||||
                await ctx.send(content=f'```{HELP_UNKWNOWN_PARAM}```', view=SkynetView(frontend))
 | 
			
		||||
 | 
			
		||||
    @bot.command(name='cool', help='Display a list of cool prompt words')
 | 
			
		||||
    async def send_cool_words(ctx):
 | 
			
		||||
        await ctx.reply(content='\n'.join(CLEAN_COOL_WORDS), view=SkynetView(frontend))
 | 
			
		||||
        clean_cool_word = '\n'.join(CLEAN_COOL_WORDS)
 | 
			
		||||
        await ctx.send(content=f'```{clean_cool_word}```', view=SkynetView(frontend))
 | 
			
		||||
 | 
			
		||||
    @bot.command(name='stats', help='See user statistics' )
 | 
			
		||||
    async def user_stats(ctx):
 | 
			
		||||
        user = ctx.author
 | 
			
		||||
 | 
			
		||||
        await db_call('get_or_create_user', user.id)
 | 
			
		||||
        generated, joined, role = await db_call('get_user_stats', user.id)
 | 
			
		||||
 | 
			
		||||
        stats_str = f'```generated: {generated}\n'
 | 
			
		||||
        stats_str += f'joined: {joined}\n'
 | 
			
		||||
        stats_str += f'role: {role}\n```'
 | 
			
		||||
 | 
			
		||||
        await ctx.reply(stats_str, view=SkynetView(frontend))
 | 
			
		||||
 | 
			
		||||
    @bot.command(name='donate', help='See donate info')
 | 
			
		||||
    async def donation_info(ctx):
 | 
			
		||||
        await ctx.reply(
 | 
			
		||||
            f'```\n{DONATION_INFO}```', view=SkynetView(frontend))
 | 
			
		||||
 | 
			
		||||
    @bot.command(name='txt2img', help='Responds with an image')
 | 
			
		||||
    async def send_txt2img(ctx):
 | 
			
		||||
| 
						 | 
				
			
			@ -69,7 +89,7 @@ def create_handler_context(frontend: 'SkynetDiscordFrontend'):
 | 
			
		|||
 | 
			
		||||
        # init new msg
 | 
			
		||||
        init_msg = 'started processing txt2img request...'
 | 
			
		||||
        status_msg = await ctx.reply(init_msg)
 | 
			
		||||
        status_msg = await ctx.send(init_msg)
 | 
			
		||||
        await db_call(
 | 
			
		||||
            'new_user_request', user.id, ctx.message.id, status_msg.id, status=init_msg)
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -97,13 +117,13 @@ def create_handler_context(frontend: 'SkynetDiscordFrontend'):
 | 
			
		|||
 | 
			
		||||
        ec = await work_request(user, status_msg, 'txt2img', params, ctx)
 | 
			
		||||
 | 
			
		||||
        if ec == 0:
 | 
			
		||||
        if ec == None:
 | 
			
		||||
            await db_call('increment_generated', user.id)
 | 
			
		||||
 | 
			
		||||
    @bot.command(name='redo', help='Redo last request')
 | 
			
		||||
    async def redo(ctx):
 | 
			
		||||
        init_msg = 'started processing redo request...'
 | 
			
		||||
        status_msg = await ctx.reply(init_msg)
 | 
			
		||||
        status_msg = await ctx.send(init_msg)
 | 
			
		||||
        user = ctx.author
 | 
			
		||||
 | 
			
		||||
        method = await db_call('get_last_method_of', user.id)
 | 
			
		||||
| 
						 | 
				
			
			@ -116,8 +136,9 @@ def create_handler_context(frontend: 'SkynetDiscordFrontend'):
 | 
			
		|||
            binary = await db_call('get_last_binary_of', user.id)
 | 
			
		||||
 | 
			
		||||
        if not prompt:
 | 
			
		||||
            await ctx.reply(
 | 
			
		||||
                'no last prompt found, do a txt2img cmd first!'
 | 
			
		||||
            await status_msg.edit(
 | 
			
		||||
                content='no last prompt found, do a txt2img cmd first!',
 | 
			
		||||
                view=SkynetView(frontend)
 | 
			
		||||
            )
 | 
			
		||||
            return
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -132,12 +153,106 @@ def create_handler_context(frontend: 'SkynetDiscordFrontend'):
 | 
			
		|||
            **user_config
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        await work_request(
 | 
			
		||||
        ec = await work_request(
 | 
			
		||||
            user, status_msg, 'redo', params, ctx,
 | 
			
		||||
            file_id=file_id,
 | 
			
		||||
            binary_data=binary
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        if ec == None:
 | 
			
		||||
            await db_call('increment_generated', user.id)
 | 
			
		||||
 | 
			
		||||
    @bot.command(name='img2img', help='Responds with an image')
 | 
			
		||||
    async def send_img2img(ctx):
 | 
			
		||||
        # if isinstance(message_or_query, CallbackQuery):
 | 
			
		||||
        #     query = message_or_query
 | 
			
		||||
        #     message = query.message
 | 
			
		||||
        #     user = query.from_user
 | 
			
		||||
        #     chat = query.message.chat
 | 
			
		||||
        #
 | 
			
		||||
        # else:
 | 
			
		||||
        #     message = message_or_query
 | 
			
		||||
        #     user = message.from_user
 | 
			
		||||
        #     chat = message.chat
 | 
			
		||||
 | 
			
		||||
        # reply_id = None
 | 
			
		||||
        # if chat.type == 'group' and chat.id == GROUP_ID:
 | 
			
		||||
        #     reply_id = message.message_id
 | 
			
		||||
        #
 | 
			
		||||
        user = ctx.author
 | 
			
		||||
        user_row = await db_call('get_or_create_user', user.id)
 | 
			
		||||
 | 
			
		||||
        # init new msg
 | 
			
		||||
        init_msg = 'started processing img2img request...'
 | 
			
		||||
        status_msg = await ctx.send(init_msg)
 | 
			
		||||
        await db_call(
 | 
			
		||||
            'new_user_request', user.id, ctx.message.id, status_msg.id, status=init_msg)
 | 
			
		||||
 | 
			
		||||
        if not ctx.message.content.startswith('/img2img'):
 | 
			
		||||
            await ctx.reply(
 | 
			
		||||
                'For image to image you need to add /img2img to the beggining of your caption'
 | 
			
		||||
            )
 | 
			
		||||
            return
 | 
			
		||||
 | 
			
		||||
        prompt = ' '.join(ctx.message.content.split(' ')[1:])
 | 
			
		||||
 | 
			
		||||
        if len(prompt) == 0:
 | 
			
		||||
            await ctx.reply('Empty text prompt ignored.')
 | 
			
		||||
            return
 | 
			
		||||
 | 
			
		||||
        # file_id = message.photo[-1].file_id
 | 
			
		||||
        # file_path = (await bot.get_file(file_id)).file_path
 | 
			
		||||
        # image_raw = await bot.download_file(file_path)
 | 
			
		||||
        #
 | 
			
		||||
 | 
			
		||||
        file = ctx.message.attachments[-1]
 | 
			
		||||
        file_id = str(file.id)
 | 
			
		||||
        # file bytes
 | 
			
		||||
        image_raw = await file.read()
 | 
			
		||||
        with Image.open(io.BytesIO(image_raw)) as image:
 | 
			
		||||
            w, h = image.size
 | 
			
		||||
 | 
			
		||||
            if w > 512 or h > 512:
 | 
			
		||||
                logging.warning(f'user sent img of size {image.size}')
 | 
			
		||||
                image.thumbnail((512, 512))
 | 
			
		||||
                logging.warning(f'resized it to {image.size}')
 | 
			
		||||
 | 
			
		||||
            image.save(f'ipfs-docker-staging/image.png', format='PNG')
 | 
			
		||||
 | 
			
		||||
            ipfs_hash = ipfs_node.add('image.png')
 | 
			
		||||
            ipfs_node.pin(ipfs_hash)
 | 
			
		||||
 | 
			
		||||
            logging.info(f'published input image {ipfs_hash} on ipfs')
 | 
			
		||||
 | 
			
		||||
        logging.info(f'mid: {ctx.message.id}')
 | 
			
		||||
 | 
			
		||||
        user_config = {**user_row}
 | 
			
		||||
        del user_config['id']
 | 
			
		||||
 | 
			
		||||
        params = {
 | 
			
		||||
            'prompt': prompt,
 | 
			
		||||
            **user_config
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        await db_call(
 | 
			
		||||
            'update_user_stats',
 | 
			
		||||
            user.id,
 | 
			
		||||
            'img2img',
 | 
			
		||||
            last_file=file_id,
 | 
			
		||||
            last_prompt=prompt,
 | 
			
		||||
            last_binary=ipfs_hash
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        ec = await work_request(
 | 
			
		||||
            user, status_msg, 'img2img', params, ctx,
 | 
			
		||||
            file_id=file_id,
 | 
			
		||||
            binary_data=ipfs_hash
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        if ec == None:
 | 
			
		||||
            await db_call('increment_generated', user.id)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
        
 | 
			
		||||
        # TODO: DELETE BELOW
 | 
			
		||||
        # user = 'testworker3'
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -1,4 +1,6 @@
 | 
			
		|||
import io
 | 
			
		||||
import discord
 | 
			
		||||
from PIL import Image
 | 
			
		||||
import logging
 | 
			
		||||
from skynet.constants import *
 | 
			
		||||
from skynet.frontend import validate_user_config_request
 | 
			
		||||
| 
						 | 
				
			
			@ -9,11 +11,14 @@ class SkynetView(discord.ui.View):
 | 
			
		|||
    def __init__(self, bot):
 | 
			
		||||
        self.bot = bot
 | 
			
		||||
        super().__init__(timeout=None)
 | 
			
		||||
        self.add_item(RedoButton('redo', discord.ButtonStyle.green, self.bot))
 | 
			
		||||
        self.add_item(Txt2ImgButton('txt2img', discord.ButtonStyle.green, self.bot))
 | 
			
		||||
        self.add_item(ConfigButton('config', discord.ButtonStyle.grey, self.bot))
 | 
			
		||||
        self.add_item(HelpButton('help', discord.ButtonStyle.grey, self.bot))
 | 
			
		||||
        self.add_item(CoolButton('cool', discord.ButtonStyle.gray, self.bot))
 | 
			
		||||
        self.add_item(RedoButton('redo', discord.ButtonStyle.primary, self.bot))
 | 
			
		||||
        self.add_item(Txt2ImgButton('txt2img', discord.ButtonStyle.primary, self.bot))
 | 
			
		||||
        self.add_item(Img2ImgButton('img2img', discord.ButtonStyle.primary, self.bot))
 | 
			
		||||
        self.add_item(StatsButton('stats', discord.ButtonStyle.secondary, self.bot))
 | 
			
		||||
        self.add_item(DonateButton('donate', discord.ButtonStyle.secondary, self.bot))
 | 
			
		||||
        self.add_item(ConfigButton('config', discord.ButtonStyle.secondary, self.bot))
 | 
			
		||||
        self.add_item(HelpButton('help', discord.ButtonStyle.secondary, self.bot))
 | 
			
		||||
        self.add_item(CoolButton('cool', discord.ButtonStyle.secondary, self.bot))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Txt2ImgButton(discord.ui.Button):
 | 
			
		||||
| 
						 | 
				
			
			@ -32,7 +37,7 @@ class Txt2ImgButton(discord.ui.Button):
 | 
			
		|||
 | 
			
		||||
        # init new msg
 | 
			
		||||
        init_msg = 'started processing txt2img request...'
 | 
			
		||||
        status_msg = await msg.reply(init_msg)
 | 
			
		||||
        status_msg = await msg.channel.send(init_msg)
 | 
			
		||||
        await db_call(
 | 
			
		||||
            'new_user_request', user.id, msg.id, status_msg.id, status=init_msg)
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -60,7 +65,93 @@ class Txt2ImgButton(discord.ui.Button):
 | 
			
		|||
 | 
			
		||||
        ec = await work_request(user, status_msg, 'txt2img', params, msg)
 | 
			
		||||
 | 
			
		||||
        if ec == 0:
 | 
			
		||||
        if ec == None:
 | 
			
		||||
            await db_call('increment_generated', user.id)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Img2ImgButton(discord.ui.Button):
 | 
			
		||||
 | 
			
		||||
    def __init__(self, label: str, style: discord.ButtonStyle, bot):
 | 
			
		||||
        self.bot = bot
 | 
			
		||||
        super().__init__(label=label, style=style)
 | 
			
		||||
 | 
			
		||||
    async def callback(self, interaction):
 | 
			
		||||
        db_call = self.bot.db_call
 | 
			
		||||
        work_request = self.bot.work_request
 | 
			
		||||
        ipfs_node = self.bot.ipfs_node
 | 
			
		||||
        msg = await grab('Attach an Image. Enter your prompt:', interaction)
 | 
			
		||||
 | 
			
		||||
        user = msg.author
 | 
			
		||||
        user_row = await db_call('get_or_create_user', user.id)
 | 
			
		||||
 | 
			
		||||
        # init new msg
 | 
			
		||||
        init_msg = 'started processing img2img request...'
 | 
			
		||||
        status_msg = await msg.channel.send(init_msg)
 | 
			
		||||
        await db_call(
 | 
			
		||||
            'new_user_request', user.id, msg.id, status_msg.id, status=init_msg)
 | 
			
		||||
 | 
			
		||||
        # if not msg.content.startswith('/img2img'):
 | 
			
		||||
        #     await msg.reply(
 | 
			
		||||
        #         'For image to image you need to add /img2img to the beggining of your caption'
 | 
			
		||||
        #     )
 | 
			
		||||
        #     return
 | 
			
		||||
 | 
			
		||||
        prompt = msg.content
 | 
			
		||||
 | 
			
		||||
        if len(prompt) == 0:
 | 
			
		||||
            await msg.reply('Empty text prompt ignored.')
 | 
			
		||||
            return
 | 
			
		||||
 | 
			
		||||
        # file_id = message.photo[-1].file_id
 | 
			
		||||
        # file_path = (await bot.get_file(file_id)).file_path
 | 
			
		||||
        # image_raw = await bot.download_file(file_path)
 | 
			
		||||
        #
 | 
			
		||||
 | 
			
		||||
        file = msg.attachments[-1]
 | 
			
		||||
        file_id = str(file.id)
 | 
			
		||||
        # file bytes
 | 
			
		||||
        image_raw = await file.read()
 | 
			
		||||
        with Image.open(io.BytesIO(image_raw)) as image:
 | 
			
		||||
            w, h = image.size
 | 
			
		||||
 | 
			
		||||
            if w > 512 or h > 512:
 | 
			
		||||
                logging.warning(f'user sent img of size {image.size}')
 | 
			
		||||
                image.thumbnail((512, 512))
 | 
			
		||||
                logging.warning(f'resized it to {image.size}')
 | 
			
		||||
 | 
			
		||||
            image.save(f'ipfs-docker-staging/image.png', format='PNG')
 | 
			
		||||
 | 
			
		||||
            ipfs_hash = ipfs_node.add('image.png')
 | 
			
		||||
            ipfs_node.pin(ipfs_hash)
 | 
			
		||||
 | 
			
		||||
            logging.info(f'published input image {ipfs_hash} on ipfs')
 | 
			
		||||
 | 
			
		||||
        logging.info(f'mid: {msg.id}')
 | 
			
		||||
 | 
			
		||||
        user_config = {**user_row}
 | 
			
		||||
        del user_config['id']
 | 
			
		||||
 | 
			
		||||
        params = {
 | 
			
		||||
            'prompt': prompt,
 | 
			
		||||
            **user_config
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        await db_call(
 | 
			
		||||
            'update_user_stats',
 | 
			
		||||
            user.id,
 | 
			
		||||
            'img2img',
 | 
			
		||||
            last_file=file_id,
 | 
			
		||||
            last_prompt=prompt,
 | 
			
		||||
            last_binary=ipfs_hash
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        ec = await work_request(
 | 
			
		||||
            user, status_msg, 'img2img', params, msg,
 | 
			
		||||
            file_id=file_id,
 | 
			
		||||
            binary_data=ipfs_hash
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        if ec == None:
 | 
			
		||||
            await db_call('increment_generated', user.id)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -88,8 +179,9 @@ class RedoButton(discord.ui.Button):
 | 
			
		|||
            binary = await db_call('get_last_binary_of', user.id)
 | 
			
		||||
 | 
			
		||||
        if not prompt:
 | 
			
		||||
            await interaction.response.edit_message(
 | 
			
		||||
                'no last prompt found, do a txt2img cmd first!'
 | 
			
		||||
            await status_msg.edit(
 | 
			
		||||
                content='no last prompt found, do a txt2img cmd first!',
 | 
			
		||||
                view=SkynetView(self.bot)
 | 
			
		||||
            )
 | 
			
		||||
            return
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -103,12 +195,15 @@ class RedoButton(discord.ui.Button):
 | 
			
		|||
            'prompt': prompt,
 | 
			
		||||
            **user_config
 | 
			
		||||
        }
 | 
			
		||||
        await work_request(
 | 
			
		||||
        ec = await work_request(
 | 
			
		||||
            user, status_msg, 'redo', params, interaction,
 | 
			
		||||
            file_id=file_id,
 | 
			
		||||
            binary_data=binary
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        if ec == None:
 | 
			
		||||
            await db_call('increment_generated', user.id)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ConfigButton(discord.ui.Button):
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -136,7 +231,29 @@ class ConfigButton(discord.ui.Button):
 | 
			
		|||
            await msg.reply(content=reply_txt, view=SkynetView(self.bot))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class CoolButton(discord.ui.Button):
 | 
			
		||||
class StatsButton(discord.ui.Button):
 | 
			
		||||
 | 
			
		||||
    def __init__(self, label: str, style: discord.ButtonStyle, bot):
 | 
			
		||||
        self.bot = bot
 | 
			
		||||
        super().__init__(label=label, style=style)
 | 
			
		||||
 | 
			
		||||
    async def callback(self, interaction):
 | 
			
		||||
        db_call = self.bot.db_call
 | 
			
		||||
 | 
			
		||||
        user = interaction.user
 | 
			
		||||
 | 
			
		||||
        await db_call('get_or_create_user', user.id)
 | 
			
		||||
        generated, joined, role = await db_call('get_user_stats', user.id)
 | 
			
		||||
 | 
			
		||||
        stats_str = f'```generated: {generated}\n'
 | 
			
		||||
        stats_str += f'joined: {joined}\n'
 | 
			
		||||
        stats_str += f'role: {role}\n```'
 | 
			
		||||
 | 
			
		||||
        await interaction.response.send_message(
 | 
			
		||||
            content=stats_str, view=SkynetView(self.bot))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class DonateButton(discord.ui.Button):
 | 
			
		||||
 | 
			
		||||
    def __init__(self, label: str, style: discord.ButtonStyle, bot):
 | 
			
		||||
        self.bot = bot
 | 
			
		||||
| 
						 | 
				
			
			@ -144,7 +261,20 @@ class CoolButton(discord.ui.Button):
 | 
			
		|||
 | 
			
		||||
    async def callback(self, interaction):
 | 
			
		||||
        await interaction.response.send_message(
 | 
			
		||||
            content='\n'.join(CLEAN_COOL_WORDS),
 | 
			
		||||
            content=f'```\n{DONATION_INFO}```',
 | 
			
		||||
            view=SkynetView(self.bot))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class CoolButton(discord.ui.Button):
 | 
			
		||||
 | 
			
		||||
    def __init__(self, label: str, style: discord.ButtonStyle, bot):
 | 
			
		||||
        self.bot = bot
 | 
			
		||||
        super().__init__(label=label, style=style)
 | 
			
		||||
 | 
			
		||||
    async def callback(self, interaction):
 | 
			
		||||
        clean_cool_word = '\n'.join(CLEAN_COOL_WORDS)
 | 
			
		||||
        await interaction.response.send_message(
 | 
			
		||||
            content=f'```{clean_cool_word}```',
 | 
			
		||||
            view=SkynetView(self.bot))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -160,15 +290,14 @@ class HelpButton(discord.ui.Button):
 | 
			
		|||
        param = msg.content
 | 
			
		||||
 | 
			
		||||
        if param == 'a':
 | 
			
		||||
            await msg.reply(content=HELP_TEXT, view=SkynetView(self.bot))
 | 
			
		||||
            await msg.reply(content=f'```{HELP_TEXT}```', view=SkynetView(self.bot))
 | 
			
		||||
 | 
			
		||||
        else:
 | 
			
		||||
            if param in HELP_TOPICS:
 | 
			
		||||
                await msg.reply(content=HELP_TOPICS[param], view=SkynetView(self.bot))
 | 
			
		||||
                await msg.reply(content=f'```{HELP_TOPICS[param]}```', view=SkynetView(self.bot))
 | 
			
		||||
 | 
			
		||||
            else:
 | 
			
		||||
                await msg.reply(content=HELP_UNKWNOWN_PARAM, view=SkynetView(self.bot))
 | 
			
		||||
 | 
			
		||||
                await msg.reply(content=f'```{HELP_UNKWNOWN_PARAM}```', view=SkynetView(self.bot))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
async def grab(prompt, interaction):
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -38,27 +38,41 @@ def build_redo_menu():
 | 
			
		|||
    return inline_keyboard
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def prepare_metainfo_caption(user, worker: str, reward: str, meta: dict) -> str:
 | 
			
		||||
def prepare_metainfo_caption(user, worker: str, reward: str, meta: dict, embed) -> str:
 | 
			
		||||
    prompt = meta["prompt"]
 | 
			
		||||
    if len(prompt) > 256:
 | 
			
		||||
        prompt = prompt[:256]
 | 
			
		||||
    
 | 
			
		||||
    gen_str = f'generated by {user.name}\n'
 | 
			
		||||
    gen_str += f'performed by {worker}\n'
 | 
			
		||||
    gen_str += f'reward: {reward}\n'
 | 
			
		||||
 | 
			
		||||
    meta_str = f'__by {user.name}__\n'
 | 
			
		||||
    meta_str += f'*performed by {worker}*\n'
 | 
			
		||||
    meta_str += f'__**reward: {reward}**__\n'
 | 
			
		||||
    embed.add_field(
 | 
			
		||||
        name='General Info', value=f'```{gen_str}```', inline=False)
 | 
			
		||||
    # meta_str = f'__by {user.name}__\n'
 | 
			
		||||
    # meta_str += f'*performed by {worker}*\n'
 | 
			
		||||
    # meta_str += f'__**reward: {reward}**__\n'
 | 
			
		||||
    embed.add_field(name='Prompt', value=f'```{prompt}\n```', inline=False)
 | 
			
		||||
 | 
			
		||||
    # meta_str = f'`prompt:` {prompt}\n'
 | 
			
		||||
 | 
			
		||||
    meta_str = f'seed: {meta["seed"]}\n'
 | 
			
		||||
    meta_str += f'step: {meta["step"]}\n'
 | 
			
		||||
    meta_str += f'guidance: {meta["guidance"]}\n'
 | 
			
		||||
 | 
			
		||||
    meta_str += f'`prompt:` {prompt}\n'
 | 
			
		||||
    meta_str += f'`seed: {meta["seed"]}`\n'
 | 
			
		||||
    meta_str += f'`step: {meta["step"]}`\n'
 | 
			
		||||
    meta_str += f'`guidance: {meta["guidance"]}`\n'
 | 
			
		||||
    if meta['strength']:
 | 
			
		||||
        meta_str += f'`strength: {meta["strength"]}`\n'
 | 
			
		||||
    meta_str += f'`algo: {meta["model"]}`\n'
 | 
			
		||||
        meta_str += f'strength: {meta["strength"]}\n'
 | 
			
		||||
    meta_str += f'algo: {meta["model"]}\n'
 | 
			
		||||
    if meta['upscaler']:
 | 
			
		||||
        meta_str += f'`upscaler: {meta["upscaler"]}`\n'
 | 
			
		||||
        meta_str += f'upscaler: {meta["upscaler"]}\n'
 | 
			
		||||
 | 
			
		||||
    embed.add_field(name='Parameters', value=f'```{meta_str}```', inline=False)
 | 
			
		||||
 | 
			
		||||
    foot_str = f'Made with Skynet v{VERSION}\n'
 | 
			
		||||
    foot_str += f'JOIN THE SWARM: @skynetgpu'
 | 
			
		||||
 | 
			
		||||
    embed.set_footer(text=foot_str)
 | 
			
		||||
 | 
			
		||||
    meta_str += f'__**Made with Skynet v{VERSION}**__\n'
 | 
			
		||||
    meta_str += f'**JOIN THE SWARM: @skynetgpu**'
 | 
			
		||||
    return meta_str
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -74,7 +88,7 @@ def generate_reply_caption(
 | 
			
		|||
        url=f'https://explorer.{DEFAULT_DOMAIN}/v2/explore/transaction/{tx_hash}',
 | 
			
		||||
        color=discord.Color.blue())
 | 
			
		||||
 | 
			
		||||
    meta_info = prepare_metainfo_caption(user, worker, reward, params)
 | 
			
		||||
    meta_info = prepare_metainfo_caption(user, worker, reward, params, explorer_link)
 | 
			
		||||
 | 
			
		||||
    # why do we have this?
 | 
			
		||||
    final_msg = '\n'.join([
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue