mirror of https://github.com/skygpu/skynet.git
				
				
				
			
		
			
				
	
	
		
			318 lines
		
	
	
		
			10 KiB
		
	
	
	
		
			Python
		
	
	
			
		
		
	
	
			318 lines
		
	
	
		
			10 KiB
		
	
	
	
		
			Python
		
	
	
#!/usr/bin/python
 | 
						|
 | 
						|
from json import JSONDecodeError
 | 
						|
import random
 | 
						|
import logging
 | 
						|
import asyncio
 | 
						|
 | 
						|
from decimal import Decimal
 | 
						|
from hashlib import sha256
 | 
						|
from datetime import datetime
 | 
						|
from contextlib import ExitStack, AsyncExitStack
 | 
						|
from contextlib import asynccontextmanager as acm
 | 
						|
 | 
						|
from leap.cleos import CLEOS
 | 
						|
from leap.sugar import Name, asset_from_str, collect_stdout
 | 
						|
from leap.hyperion import HyperionAPI
 | 
						|
# from telebot.types import InputMediaPhoto
 | 
						|
 | 
						|
import discord
 | 
						|
import requests
 | 
						|
import io
 | 
						|
from PIL import Image, UnidentifiedImageError
 | 
						|
 | 
						|
from skynet.db import open_database_connection
 | 
						|
from skynet.ipfs import get_ipfs_file, AsyncIPFSHTTP
 | 
						|
from skynet.constants import *
 | 
						|
 | 
						|
from . import *
 | 
						|
from .bot import DiscordBot
 | 
						|
 | 
						|
from .utils import *
 | 
						|
from .handlers import create_handler_context
 | 
						|
from .ui import SkynetView
 | 
						|
 | 
						|
 | 
						|
class SkynetDiscordFrontend:
    """Discord frontend for Skynet.

    Holds the chain clients (``CLEOS`` for pushing actions, ``HyperionAPI``
    for querying them), the IPFS client and the database connection used by
    the Discord bot, and implements ``work_request``. That method enqueues a
    request on the ``gpu.scd`` contract, waits for a matching ``submit``
    action and posts the result to the requesting Discord channel.
    """

    def __init__(
        self,
        # token: str,
        account: str,
        permission: str,
        node_url: str,
        hyperion_url: str,
        db_host: str,
        db_user: str,
        db_pass: str,
        ipfs_url: str,
        remote_ipfs_node: str,
        key: str,
        explorer_domain: str,
        ipfs_domain: str
    ):
        """Store configuration and build the chain / IPFS / bot clients.

        Args:
            account: chain account that enqueues requests and whose
                ``gpu.scd:submit`` actions are polled.
            permission: permission passed to ``a_push_action``.
            node_url: chain node URL for ``CLEOS``.
            hyperion_url: Hyperion API URL used to look up submit actions.
            db_host, db_user, db_pass: database connection settings, used by
                ``start()``.
            ipfs_url: URL of the IPFS HTTP API for ``AsyncIPFSHTTP``.
            remote_ipfs_node: optional remote IPFS peer; when truthy,
                ``start()`` connects to it.
            key: key passed to ``a_push_action`` together with ``account``
                (presumably the signing key — confirm).
            explorer_domain: domain used to build transaction explorer links.
            ipfs_domain: IPFS gateway domain used to build image links.

        No network connection is opened here; see ``start()``.
        """
        # self.token = token
        self.account = account
        self.permission = permission
        self.node_url = node_url
        self.hyperion_url = hyperion_url
        self.db_host = db_host
        self.db_user = db_user
        self.db_pass = db_pass
        self.ipfs_url = ipfs_url
        self.remote_ipfs_node = remote_ipfs_node
        self.key = key
        self.explorer_domain = explorer_domain
        self.ipfs_domain = ipfs_domain

        self.bot = DiscordBot(self)
        self.cleos = CLEOS(None, None, url=node_url, remote=node_url)
        self.hyperion = HyperionAPI(hyperion_url)
        self.ipfs_node = AsyncIPFSHTTP(ipfs_url)

        # Resources opened in start() are registered here and released in stop().
        self._exit_stack = ExitStack()
        self._async_exit_stack = AsyncExitStack()

    async def start(self):
        """Connect to the remote IPFS node (if configured), open the database
        connection (exposed as ``self.db_call``) and register the handlers."""
        if self.remote_ipfs_node:
            await self.ipfs_node.connect(self.remote_ipfs_node)

        self.db_call = await self._async_exit_stack.enter_async_context(
            open_database_connection(
                self.db_user, self.db_pass, self.db_host))

        create_handler_context(self)

    async def stop(self):
        """Release everything registered on the exit stacks.

        The sync stack is closed even when closing the async stack raises.
        """
        try:
            await self._async_exit_stack.aclose()
        finally:
            self._exit_stack.close()

    @acm
    async def open(self):
        """Async context manager: ``start()`` on enter, ``stop()`` on exit.

        ``stop()`` runs in a ``finally`` so resources are released even when
        ``start()`` fails part-way or the ``async with`` body raises.
        """
        try:
            await self.start()
            yield self
        finally:
            await self.stop()

    # maybe do this?
    # async def update_status_message(
    #     self, status_msg, new_text: str, **kwargs
    # ):
    #     await self.db_call(
    #         'update_user_request_by_sid', status_msg.id, new_text)
    #     return await self.bot.edit_message_text(
    #         new_text,
    #         chat_id=status_msg.chat.id,
    #         message_id=status_msg.id,
    #         **kwargs
    #     )

    # async def append_status_message(
    #     self, status_msg, add_text: str, **kwargs
    # ):
    #     request = await self.db_call('get_user_request_by_sid', status_msg.id)
    #     await self.update_status_message(
    #         status_msg,
    #         request['status'] + add_text,
    #         **kwargs
    #     )

    @staticmethod
    def _live_updates_embed(description: str) -> discord.Embed:
        """Build the blue 'live updates' embed used for progress messages."""
        return discord.Embed(
            title='live updates',
            description=description,
            color=discord.Color.blue())

    async def _wait_for_submit(
        self,
        request_hash: str,
        request_time: str,
        attempts: int = 60,
        interval: float = 1
    ) -> tuple[str, str, str] | None:
        """Poll Hyperion for the ``gpu.scd:submit`` action matching a request.

        Queries up to ``attempts`` times and sleeps ``interval`` seconds after
        each unsuccessful attempt. A ``JSONDecodeError`` from Hyperion is
        logged and retried.

        Returns:
            ``(tx_hash, ipfs_hash, worker)`` taken from the first matching
            submit action, or ``None`` if none was found.
        """
        for _ in range(attempts):
            try:
                submits = await self.hyperion.aget_actions(
                    account=self.account,
                    filter='gpu.scd:submit',
                    sort='desc',
                    after=request_time
                )
                actions = [
                    action
                    for action in submits['actions']
                    if action['act']['data']['request_hash'] == request_hash
                ]
                if actions:
                    data = actions[0]['act']['data']
                    logging.info('Found matching submit!')
                    return (
                        actions[0]['trx_id'], data['ipfs_hash'], data['worker'])

            except JSONDecodeError:
                logging.error('network error while getting actions, retry..')

            await asyncio.sleep(interval)

        return None

    async def _fetch_result_png(self, ipfs_link: str) -> bytes | None:
        """Download the result image from IPFS and re-encode it as PNG.

        ``ipfs_link`` and the legacy ``<ipfs_link>/image.png`` location are
        fetched concurrently. When both succeed, the image at ``ipfs_link``
        is preferred.

        Returns:
            PNG bytes, or ``None`` if neither location yielded an image PIL
            could open.
        """
        ipfs_link_legacy = ipfs_link + '/image.png'
        results = {}

        async def get_and_set_results(link: str):
            res = await get_ipfs_file(link)
            logging.info(f'got response from {link}')
            if not res or res.status_code != 200:
                logging.warning(f'couldn\'t get ipfs binary data at {link}!')
                return

            try:
                with Image.open(io.BytesIO(res.raw)) as image:
                    tmp_buf = io.BytesIO()
                    image.save(tmp_buf, format='PNG')
                    results[link] = tmp_buf.getvalue()

            except UnidentifiedImageError:
                logging.warning(
                    f'couldn\'t get ipfs binary data at {link}!')

        await asyncio.gather(
            get_and_set_results(ipfs_link),
            get_and_set_results(ipfs_link_legacy))

        return results.get(ipfs_link, results.get(ipfs_link_legacy))

    async def work_request(
        self,
        user,
        status_msg,
        method: str,
        params: dict,
        ctx: discord.ext.commands.context.Context | discord.Message,
        file_id: str | None = None,
        binary_data: str = ''
    ) -> bool:
        """Run a request end to end and report progress in Discord.

        The steps are:

        1. Delete ``status_msg`` and post a 'live updates' embed.
        2. Push ``gpu.scd::enqueue``.
        3. Poll Hyperion for the matching ``submit``.
        4. Download the result from IPFS.
        5. Replace the progress message with the final result embed.

        Args:
            user: requester; ``user.name`` is shown in the progress text and
                ``user`` is forwarded to ``generate_reply_caption``.
            status_msg: placeholder message; it is deleted at the start.
            method: request name shown to the user. The on-chain body always
                uses ``'diffuse'`` as its method.
            params: request parameters. A missing or ``None`` ``seed`` is
                replaced in place with a random 32-bit value.
            ctx: command context or message; replies go to ``ctx.channel``.
            file_id: truthy for img2img; enables the input-image thumbnail.
            binary_data: appended to the IPFS gateway URL to fetch the input
                image for img2img (``''`` otherwise). It is also sent on
                chain and hashed into the request hash.

        Returns:
            ``False`` if enqueueing failed or no matching submit appeared in
            time. ``True`` otherwise, including when the result image could
            not be downloaded; in that case an error field is shown instead.
        """
        # The module only imports JSONDecodeError by name; import json here so
        # this method does not rely on a star import providing it.
        import json

        send = ctx.channel.send

        if params.get('seed') is None:
            params['seed'] = random.randint(0, 0xFFFFFFFF)

        # Decimal values are not JSON serialisable; send them as strings.
        sanitized_params = {
            key: str(val) if isinstance(val, Decimal) else val
            for key, val in params.items()
        }
        body = json.dumps({
            'method': 'diffuse',
            'params': sanitized_params
        })
        # NOTE(review): naive local time is used as Hyperion's `after` filter;
        # confirm the node interprets it in the same timezone.
        request_time = datetime.now().isoformat()

        await status_msg.delete()

        msg_text = f'processing a \'{method}\' request by {user.name}\n[{timestamp_pretty()}] *broadcasting transaction to chain...* '
        message = await send(embed=self._live_updates_embed(msg_text))

        reward = '20.0000 GPU'
        res = await self.cleos.a_push_action(
            'gpu.scd',
            'enqueue',
            {
                'user': Name(self.account),
                'request_body': body,
                'binary_data': binary_data,
                'reward': asset_from_str(reward),
                'min_verification': 1
            },
            self.account, self.key, permission=self.permission
        )

        if 'code' in res or 'statusCode' in res:
            logging.error(json.dumps(res, indent=4))
            # status_msg was deleted above, so report on the request's channel.
            await send(
                'skynet has suffered an internal error trying to fill this request')
            return False

        enqueue_tx_id = res['transaction_id']
        enqueue_tx_link = f'[**Your request on Skynet Explorer**](https://{self.explorer_domain}/v2/explore/transaction/{enqueue_tx_id})'

        msg_text += f'**broadcasted!** \n{enqueue_tx_link}\n[{timestamp_pretty()}] *workers are processing request...* '
        await message.edit(embed=self._live_updates_embed(msg_text))

        # The enqueue action's console output is expected as '<request_id>:<nonce>'.
        request_id, nonce = collect_stdout(res).split(':')

        # Submit actions reference the request by this hash (see _wait_for_submit).
        request_hash = sha256(
            (nonce + body + binary_data).encode('utf-8')).hexdigest().upper()

        request_id = int(request_id)
        logging.info(f'{request_id} enqueued.')

        submit = await self._wait_for_submit(request_hash, request_time)
        tx_hash, ipfs_hash, worker = submit if submit else (None, None, None)

        if not ipfs_hash:
            # Append to the progress log (timeout_text starts with a newline)
            # rather than replacing the whole history.
            timeout_text = f'\n[{timestamp_pretty()}] **timeout processing request**'
            await message.edit(
                embed=self._live_updates_embed(msg_text + timeout_text))
            return False

        tx_link = f'[**Your result on Skynet Explorer**](https://{self.explorer_domain}/v2/explore/transaction/{tx_hash})'

        msg_text += f'**request processed!**\n{tx_link}\n[{timestamp_pretty()}] *trying to download image...*\n '
        embed = self._live_updates_embed(msg_text)
        await message.edit(embed=embed)

        # Attempt to get the image; it is only used to confirm the result is
        # reachable, the final embed links to the gateway URL.
        ipfs_link = f'https://{self.ipfs_domain}/ipfs/{ipfs_hash}'
        png_img = await self._fetch_result_png(ipfs_link)

        if not png_img:
            logging.error(f'couldn\'t get ipfs hosted image at {ipfs_link}!')
            embed.add_field(
                name='Error', value=f'couldn\'t get ipfs hosted image [**here**]({ipfs_link})!')
            await message.edit(embed=embed, view=SkynetView(self))
            return True

        # reword this function, may not need caption
        _caption, embed = generate_reply_caption(
            user, params, tx_hash, worker, reward, self.explorer_domain)

        logging.info('success! sending generated image')
        await message.delete()

        embed.set_image(url=ipfs_link)

        if not file_id:  # txt2img
            await send(embed=embed, view=SkynetView(self))
            return True

        # img2img: attach the input image as a thumbnail when it can be fetched.
        orig_url = f'https://{self.ipfs_domain}/ipfs/' + binary_data
        try:
            # requests is blocking; keep it off the event loop and bound it.
            res = await asyncio.to_thread(requests.get, orig_url, timeout=30)
        except requests.RequestException as err:
            logging.warning(f'couldn\'t get input image at {orig_url}: {err}')
            res = None

        if res is not None and res.status_code == 200:
            with io.BytesIO(res.content) as img:
                file = discord.File(img, filename='image.png')
                embed.set_thumbnail(url='attachment://image.png')
                await send(embed=embed, view=SkynetView(self), file=file)
        else:
            await send(embed=embed, view=SkynetView(self))

        return True
 |