mirror of https://github.com/skygpu/skynet.git
				
				
				
			Frontend DB fixes and starting to add img2img
							parent
							
								
									5e017ffac0
								
							
						
					
					
						commit
						e63d395d5c
					
				| 
						 | 
					@ -329,7 +329,7 @@ def dgpu(
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@run.command()
 | 
					@run.command()
 | 
				
			||||||
@click.option('--loglevel', '-l', default='warning', help='logging level')
 | 
					@click.option('--loglevel', '-l', default='INFO', help='logging level')
 | 
				
			||||||
@click.option(
 | 
					@click.option(
 | 
				
			||||||
    '--account', '-a', default='telegram')
 | 
					    '--account', '-a', default='telegram')
 | 
				
			||||||
@click.option(
 | 
					@click.option(
 | 
				
			||||||
| 
						 | 
					@ -357,6 +357,7 @@ def telegram(
 | 
				
			||||||
    db_user: str,
 | 
					    db_user: str,
 | 
				
			||||||
    db_pass: str
 | 
					    db_pass: str
 | 
				
			||||||
):
 | 
					):
 | 
				
			||||||
 | 
					    logging.basicConfig(level=loglevel)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    key, account, permission = load_account_info(
 | 
					    key, account, permission = load_account_info(
 | 
				
			||||||
        key, account, permission)
 | 
					        key, account, permission)
 | 
				
			||||||
| 
						 | 
					@ -422,3 +423,6 @@ def pinner(loglevel, container):
 | 
				
			||||||
                ipfs_node.pin(cid)
 | 
					                ipfs_node.pin(cid)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            cleanup_pinned(now)
 | 
					            cleanup_pinned(now)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    except KeyboardInterrupt:
 | 
				
			||||||
 | 
					        ...
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -102,7 +102,7 @@ MAX_WIDTH = 512
 | 
				
			||||||
MAX_HEIGHT = 656
 | 
					MAX_HEIGHT = 656
 | 
				
			||||||
MAX_GUIDANCE = 20
 | 
					MAX_GUIDANCE = 20
 | 
				
			||||||
 | 
					
 | 
				
			||||||
DEFAULT_SEED = 0
 | 
					DEFAULT_SEED = None
 | 
				
			||||||
DEFAULT_WIDTH = 512
 | 
					DEFAULT_WIDTH = 512
 | 
				
			||||||
DEFAULT_HEIGHT = 512
 | 
					DEFAULT_HEIGHT = 512
 | 
				
			||||||
DEFAULT_GUIDANCE = 7.5
 | 
					DEFAULT_GUIDANCE = 7.5
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -26,27 +26,11 @@ CREATE SCHEMA IF NOT EXISTS skynet;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
CREATE TABLE IF NOT EXISTS skynet.user(
 | 
					CREATE TABLE IF NOT EXISTS skynet.user(
 | 
				
			||||||
   id SERIAL PRIMARY KEY NOT NULL,
 | 
					   id SERIAL PRIMARY KEY NOT NULL,
 | 
				
			||||||
   tg_id BIGINT,
 | 
					 | 
				
			||||||
   wp_id VARCHAR(128),
 | 
					 | 
				
			||||||
   mx_id VARCHAR(128),
 | 
					 | 
				
			||||||
   ig_id VARCHAR(128),
 | 
					 | 
				
			||||||
   generated INT NOT NULL,
 | 
					   generated INT NOT NULL,
 | 
				
			||||||
   joined DATE NOT NULL,
 | 
					   joined TIMESTAMP NOT NULL,
 | 
				
			||||||
   last_prompt TEXT,
 | 
					   last_prompt TEXT,
 | 
				
			||||||
   role VARCHAR(128) NOT NULL
 | 
					   role VARCHAR(128) NOT NULL
 | 
				
			||||||
);
 | 
					);
 | 
				
			||||||
ALTER TABLE skynet.user
 | 
					 | 
				
			||||||
    ADD CONSTRAINT tg_unique
 | 
					 | 
				
			||||||
    UNIQUE (tg_id);
 | 
					 | 
				
			||||||
ALTER TABLE skynet.user
 | 
					 | 
				
			||||||
    ADD CONSTRAINT wp_unique
 | 
					 | 
				
			||||||
    UNIQUE (wp_id);
 | 
					 | 
				
			||||||
ALTER TABLE skynet.user
 | 
					 | 
				
			||||||
    ADD CONSTRAINT mx_unique
 | 
					 | 
				
			||||||
    UNIQUE (mx_id);
 | 
					 | 
				
			||||||
ALTER TABLE skynet.user
 | 
					 | 
				
			||||||
    ADD CONSTRAINT ig_unique
 | 
					 | 
				
			||||||
    UNIQUE (ig_id);
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
CREATE TABLE IF NOT EXISTS skynet.user_config(
 | 
					CREATE TABLE IF NOT EXISTS skynet.user_config(
 | 
				
			||||||
    id SERIAL NOT NULL,
 | 
					    id SERIAL NOT NULL,
 | 
				
			||||||
| 
						 | 
					@ -54,7 +38,7 @@ CREATE TABLE IF NOT EXISTS skynet.user_config(
 | 
				
			||||||
    step INT NOT NULL,
 | 
					    step INT NOT NULL,
 | 
				
			||||||
    width INT NOT NULL,
 | 
					    width INT NOT NULL,
 | 
				
			||||||
    height INT NOT NULL,
 | 
					    height INT NOT NULL,
 | 
				
			||||||
    seed BIGINT NOT NULL,
 | 
					    seed BIGINT,
 | 
				
			||||||
    guidance REAL NOT NULL,
 | 
					    guidance REAL NOT NULL,
 | 
				
			||||||
    strength REAL NOT NULL,
 | 
					    strength REAL NOT NULL,
 | 
				
			||||||
    upscaler VARCHAR(128)
 | 
					    upscaler VARCHAR(128)
 | 
				
			||||||
| 
						 | 
					@ -177,16 +161,19 @@ async def open_database_connection(
 | 
				
			||||||
    yield _db_call
 | 
					    yield _db_call
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
async def get_user(conn, uid: int):
 | 
					 | 
				
			||||||
    stmt = await conn.prepare(
 | 
					 | 
				
			||||||
        'SELECT * FROM skynet.user WHERE id = $1')
 | 
					 | 
				
			||||||
    return await stmt.fetchval(uid)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
async def get_user_config(conn, user: int):
 | 
					async def get_user_config(conn, user: int):
 | 
				
			||||||
    stmt = await conn.prepare(
 | 
					    stmt = await conn.prepare(
 | 
				
			||||||
        'SELECT * FROM skynet.user_config WHERE id = $1')
 | 
					        'SELECT * FROM skynet.user_config WHERE id = $1')
 | 
				
			||||||
    return (await stmt.fetch(user))[0]
 | 
					    conf = await stmt.fetch(user)
 | 
				
			||||||
 | 
					    if len(conf) == 1:
 | 
				
			||||||
 | 
					        return conf[0]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    else:
 | 
				
			||||||
 | 
					        return None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					async def get_user(conn, uid: int):
 | 
				
			||||||
 | 
					    return await get_user_config(conn, uid)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
async def get_last_prompt_of(conn, user: int):
 | 
					async def get_last_prompt_of(conn, user: int):
 | 
				
			||||||
| 
						 | 
					@ -208,7 +195,6 @@ async def new_user(conn, uid: int):
 | 
				
			||||||
                id, generated, joined, last_prompt, role)
 | 
					                id, generated, joined, last_prompt, role)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            VALUES($1, $2, $3, $4, $5)
 | 
					            VALUES($1, $2, $3, $4, $5)
 | 
				
			||||||
            ON CONFLICT DO NOTHING
 | 
					 | 
				
			||||||
        ''')
 | 
					        ''')
 | 
				
			||||||
        await stmt.fetch(
 | 
					        await stmt.fetch(
 | 
				
			||||||
            uid, 0, date, None, DEFAULT_ROLE
 | 
					            uid, 0, date, None, DEFAULT_ROLE
 | 
				
			||||||
| 
						 | 
					@ -216,18 +202,16 @@ async def new_user(conn, uid: int):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        stmt = await conn.prepare('''
 | 
					        stmt = await conn.prepare('''
 | 
				
			||||||
            INSERT INTO skynet.user_config(
 | 
					            INSERT INTO skynet.user_config(
 | 
				
			||||||
                id, algo, step, width, height, seed, guidance, strength, upscaler)
 | 
					                id, algo, step, width, height, guidance, strength, upscaler)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            VALUES($1, $2, $3, $4, $5, $6, $7, $8, $9)
 | 
					            VALUES($1, $2, $3, $4, $5, $6, $7, $8)
 | 
				
			||||||
            ON CONFLICT DO NOTHING
 | 
					 | 
				
			||||||
        ''')
 | 
					        ''')
 | 
				
			||||||
        user = await stmt.fetch(
 | 
					        resp = await stmt.fetch(
 | 
				
			||||||
            new_uid,
 | 
					            uid,
 | 
				
			||||||
            DEFAULT_ALGO,
 | 
					            DEFAULT_ALGO,
 | 
				
			||||||
            DEFAULT_STEP,
 | 
					            DEFAULT_STEP,
 | 
				
			||||||
            DEFAULT_WIDTH,
 | 
					            DEFAULT_WIDTH,
 | 
				
			||||||
            DEFAULT_HEIGHT,
 | 
					            DEFAULT_HEIGHT,
 | 
				
			||||||
            DEFAULT_SEED,
 | 
					 | 
				
			||||||
            DEFAULT_GUIDANCE,
 | 
					            DEFAULT_GUIDANCE,
 | 
				
			||||||
            DEFAULT_STRENGTH,
 | 
					            DEFAULT_STRENGTH,
 | 
				
			||||||
            DEFAULT_UPSCALER
 | 
					            DEFAULT_UPSCALER
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -84,7 +84,12 @@ def validate_user_config_request(req: str):
 | 
				
			||||||
                    raise ConfigUnknownAttribute(
 | 
					                    raise ConfigUnknownAttribute(
 | 
				
			||||||
                        f'\"{attr}\" not a configurable parameter')
 | 
					                        f'\"{attr}\" not a configurable parameter')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            return attr, val, f'config updated! {attr} to {val}'
 | 
					            display_val = val
 | 
				
			||||||
 | 
					            if attr == 'seed':
 | 
				
			||||||
 | 
					                if not val:
 | 
				
			||||||
 | 
					                    display_val = 'Random'
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            return attr, val, f'config updated! {attr} to {display_val}'
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        except ValueError:
 | 
					        except ValueError:
 | 
				
			||||||
            raise ValueError(f'\"{val}\" is not a number silly')
 | 
					            raise ValueError(f'\"{val}\" is not a number silly')
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -2,12 +2,15 @@
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import io
 | 
					import io
 | 
				
			||||||
import zlib
 | 
					import zlib
 | 
				
			||||||
 | 
					import random
 | 
				
			||||||
import logging
 | 
					import logging
 | 
				
			||||||
import asyncio
 | 
					import asyncio
 | 
				
			||||||
 | 
					import traceback
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from hashlib import sha256
 | 
					from hashlib import sha256
 | 
				
			||||||
from datetime import datetime
 | 
					from datetime import datetime
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import asks
 | 
				
			||||||
import docker
 | 
					import docker
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from PIL import Image
 | 
					from PIL import Image
 | 
				
			||||||
| 
						 | 
					@ -18,7 +21,9 @@ from trio_asyncio import aio_as_trio
 | 
				
			||||||
from telebot.types import (
 | 
					from telebot.types import (
 | 
				
			||||||
    InputFile, InputMediaPhoto, InlineKeyboardButton, InlineKeyboardMarkup
 | 
					    InputFile, InputMediaPhoto, InlineKeyboardButton, InlineKeyboardMarkup
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
from telebot.async_telebot import AsyncTeleBot
 | 
					
 | 
				
			||||||
 | 
					from telebot.types import CallbackQuery
 | 
				
			||||||
 | 
					from telebot.async_telebot import AsyncTeleBot, ExceptionHandler
 | 
				
			||||||
from telebot.formatting import hlink
 | 
					from telebot.formatting import hlink
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from ..db import open_new_database, open_database_connection
 | 
					from ..db import open_new_database, open_database_connection
 | 
				
			||||||
| 
						 | 
					@ -27,7 +32,11 @@ from ..constants import *
 | 
				
			||||||
from . import *
 | 
					from . import *
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
PREFIX = 'tg'
 | 
					class SKYExceptionHandler(ExceptionHandler):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def handle(exception):
 | 
				
			||||||
 | 
					        traceback.print_exc()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def build_redo_menu():
 | 
					def build_redo_menu():
 | 
				
			||||||
    btn_redo = InlineKeyboardButton("Redo", callback_data=json.dumps({'method': 'redo'}))
 | 
					    btn_redo = InlineKeyboardButton("Redo", callback_data=json.dumps({'method': 'redo'}))
 | 
				
			||||||
| 
						 | 
					@ -113,20 +122,36 @@ async def get_user_nonce(cleos, user: str):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
async def work_request(
 | 
					async def work_request(
 | 
				
			||||||
    bot, cleos, hyperion,
 | 
					    bot, cleos, hyperion,
 | 
				
			||||||
    message,
 | 
					    message, user, chat,
 | 
				
			||||||
    account: str,
 | 
					    account: str,
 | 
				
			||||||
    permission: str,
 | 
					    permission: str,
 | 
				
			||||||
    params: dict
 | 
					    params: dict,
 | 
				
			||||||
 | 
					    file_id: str | None = None,
 | 
				
			||||||
 | 
					    file_path: str | None = None
 | 
				
			||||||
):
 | 
					):
 | 
				
			||||||
 | 
					    if params['seed'] == None:
 | 
				
			||||||
 | 
					        params['seed'] = random.randint(0, 9e18)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    body = json.dumps({
 | 
					    body = json.dumps({
 | 
				
			||||||
        'method': 'diffuse',
 | 
					        'method': 'diffuse',
 | 
				
			||||||
        'params': params
 | 
					        'params': params
 | 
				
			||||||
    })
 | 
					    })
 | 
				
			||||||
    user = message.from_user
 | 
					 | 
				
			||||||
    chat = message.chat
 | 
					 | 
				
			||||||
    request_time = datetime.now().isoformat()
 | 
					    request_time = datetime.now().isoformat()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if file_id:
 | 
				
			||||||
 | 
					        image_raw = await bot.download_file(file_path)
 | 
				
			||||||
 | 
					        image = Image.open(io.BytesIO(image_raw))
 | 
				
			||||||
 | 
					        w, h = image.size
 | 
				
			||||||
 | 
					        logging.info(f'user sent img of size {image.size}')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if w > 512 or h > 512:
 | 
				
			||||||
 | 
					            image.thumbnail((512, 512))
 | 
				
			||||||
 | 
					            logging.warning(f'resized it to {image.size}')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        binary = image_raw.hex()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    ec, out = cleos.push_action(
 | 
					    ec, out = cleos.push_action(
 | 
				
			||||||
        'telos.gpu', 'enqueue', [account, body, '', '20.0000 GPU'], f'{account}@{permission}'
 | 
					        'telos.gpu', 'enqueue', [account, body, binary, '20.0000 GPU'], f'{account}@{permission}'
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
    out = collect_stdout(out)
 | 
					    out = collect_stdout(out)
 | 
				
			||||||
    if ec != 0:
 | 
					    if ec != 0:
 | 
				
			||||||
| 
						 | 
					@ -168,10 +193,51 @@ async def work_request(
 | 
				
			||||||
        await bot.reply_to(message, 'timeout processing request')
 | 
					        await bot.reply_to(message, 'timeout processing request')
 | 
				
			||||||
        return
 | 
					        return
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # attempt to get the image and send it
 | 
				
			||||||
 | 
					    ipfs_link = f'http://test1.us.telos.net:8080/ipfs/{ipfs_hash}/image.png'
 | 
				
			||||||
 | 
					    logging.info(f'attempting to get image at {ipfs_link}')
 | 
				
			||||||
 | 
					    resp = None
 | 
				
			||||||
 | 
					    for i in range(10):
 | 
				
			||||||
 | 
					        try:
 | 
				
			||||||
 | 
					            resp = await asks.get(ipfs_link, timeout=2)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        except asks.errors.RequestTimeout:
 | 
				
			||||||
 | 
					            logging.warning('timeout...')
 | 
				
			||||||
 | 
					            ...
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    logging.info(f'status_code: {resp.status_code}')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    caption = generate_reply_caption(
 | 
				
			||||||
 | 
					        user, params, ipfs_hash, tx_hash)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if resp.status_code != 200:
 | 
				
			||||||
        await bot.reply_to(
 | 
					        await bot.reply_to(
 | 
				
			||||||
            message,
 | 
					            message,
 | 
				
			||||||
        generate_reply_caption(
 | 
					            caption,
 | 
				
			||||||
            user, params, ipfs_hash, tx_hash),
 | 
					            reply_markup=build_redo_menu(),
 | 
				
			||||||
 | 
					            parse_mode='HTML'
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    else:
 | 
				
			||||||
 | 
					        if file_id:  # img2img
 | 
				
			||||||
 | 
					            await bot.send_media_group(
 | 
				
			||||||
 | 
					                chat.id,
 | 
				
			||||||
 | 
					                media=[
 | 
				
			||||||
 | 
					                    InputMediaPhoto(file_id),
 | 
				
			||||||
 | 
					                    InputMediaPhoto(
 | 
				
			||||||
 | 
					                        resp.raw,
 | 
				
			||||||
 | 
					                        caption=caption
 | 
				
			||||||
 | 
					                    )
 | 
				
			||||||
 | 
					                ],
 | 
				
			||||||
 | 
					                reply_markup=build_redo_menu(),
 | 
				
			||||||
 | 
					                parse_mode='HTML'
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        else:  # txt2img
 | 
				
			||||||
 | 
					            await bot.send_photo(
 | 
				
			||||||
 | 
					                chat.id,
 | 
				
			||||||
 | 
					                caption=caption,
 | 
				
			||||||
 | 
					                photo=resp.raw,
 | 
				
			||||||
                reply_markup=build_redo_menu(),
 | 
					                reply_markup=build_redo_menu(),
 | 
				
			||||||
                parse_mode='HTML'
 | 
					                parse_mode='HTML'
 | 
				
			||||||
            )
 | 
					            )
 | 
				
			||||||
| 
						 | 
					@ -205,7 +271,7 @@ async def run_skynet_telegram(
 | 
				
			||||||
    if key:
 | 
					    if key:
 | 
				
			||||||
        cleos.setup_wallet(key)
 | 
					        cleos.setup_wallet(key)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    bot = AsyncTeleBot(tg_token)
 | 
					    bot = AsyncTeleBot(tg_token, exception_handler=SKYExceptionHandler)
 | 
				
			||||||
    logging.info(f'tg_token: {tg_token}')
 | 
					    logging.info(f'tg_token: {tg_token}')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    async with open_database_connection(
 | 
					    async with open_database_connection(
 | 
				
			||||||
| 
						 | 
					@ -233,7 +299,7 @@ async def run_skynet_telegram(
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        @bot.message_handler(commands=['txt2img'])
 | 
					        @bot.message_handler(commands=['txt2img'])
 | 
				
			||||||
        async def send_txt2img(message):
 | 
					        async def send_txt2img(message):
 | 
				
			||||||
            user = message.from_user.id
 | 
					            user = message.from_user
 | 
				
			||||||
            chat = message.chat
 | 
					            chat = message.chat
 | 
				
			||||||
            reply_id = None
 | 
					            reply_id = None
 | 
				
			||||||
            if chat.type == 'group' and chat.id == GROUP_ID:
 | 
					            if chat.type == 'group' and chat.id == GROUP_ID:
 | 
				
			||||||
| 
						 | 
					@ -247,8 +313,8 @@ async def run_skynet_telegram(
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            logging.info(f'mid: {message.id}')
 | 
					            logging.info(f'mid: {message.id}')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            await db_call('get_or_create_user', user)
 | 
					            user_row = await db_call('get_or_create_user', user.id)
 | 
				
			||||||
            user_config = {**(await db_call('get_user_config', user))}
 | 
					            user_config = {**user_row}
 | 
				
			||||||
            del user_config['id']
 | 
					            del user_config['id']
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            params = {
 | 
					            params = {
 | 
				
			||||||
| 
						 | 
					@ -256,22 +322,22 @@ async def run_skynet_telegram(
 | 
				
			||||||
                **user_config
 | 
					                **user_config
 | 
				
			||||||
            }
 | 
					            }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            await db_call('update_user_stats', user, last_prompt=prompt)
 | 
					            await db_call('update_user_stats', user.id, last_prompt=prompt)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            await work_request(
 | 
					            await work_request(
 | 
				
			||||||
                bot, cleos, hyperion,
 | 
					                bot, cleos, hyperion,
 | 
				
			||||||
                message, account, permission, params)
 | 
					                message, user, chat,
 | 
				
			||||||
 | 
					                account, permission, params
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        @bot.message_handler(func=lambda message: True, content_types=['photo'])
 | 
					        @bot.message_handler(func=lambda message: True, content_types=['photo'])
 | 
				
			||||||
        async def send_img2img(message):
 | 
					        async def send_img2img(message):
 | 
				
			||||||
            user = message.from_user.id
 | 
					            user = message.from_user
 | 
				
			||||||
            chat = message.chat
 | 
					            chat = message.chat
 | 
				
			||||||
            reply_id = None
 | 
					            reply_id = None
 | 
				
			||||||
            if chat.type == 'group' and chat.id == GROUP_ID:
 | 
					            if chat.type == 'group' and chat.id == GROUP_ID:
 | 
				
			||||||
                reply_id = message.message_id
 | 
					                reply_id = message.message_id
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            user_id = f'tg+{message.from_user.id}'
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            if not message.caption.startswith('/img2img'):
 | 
					            if not message.caption.startswith('/img2img'):
 | 
				
			||||||
                await bot.reply_to(
 | 
					                await bot.reply_to(
 | 
				
			||||||
                    message,
 | 
					                    message,
 | 
				
			||||||
| 
						 | 
					@ -287,62 +353,26 @@ async def run_skynet_telegram(
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            file_id = message.photo[-1].file_id
 | 
					            file_id = message.photo[-1].file_id
 | 
				
			||||||
            file_path = (await bot.get_file(file_id)).file_path
 | 
					            file_path = (await bot.get_file(file_id)).file_path
 | 
				
			||||||
            file_raw = await bot.download_file(file_path)
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
            logging.info(f'mid: {message.id}')
 | 
					            logging.info(f'mid: {message.id}')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            user = await db_call('get_or_create_user', user_id)
 | 
					            user_row = await db_call('get_or_create_user', user.id)
 | 
				
			||||||
            user_config = {**(await db_call('get_user_config', user))}
 | 
					            user_config = {**user_row}
 | 
				
			||||||
            del user_config['id']
 | 
					            del user_config['id']
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            req = json.dumps({
 | 
					            params = {
 | 
				
			||||||
                'method': 'diffuse',
 | 
					 | 
				
			||||||
                'params': {
 | 
					 | 
				
			||||||
                'prompt': prompt,
 | 
					                'prompt': prompt,
 | 
				
			||||||
                **user_config
 | 
					                **user_config
 | 
				
			||||||
            }
 | 
					            }
 | 
				
			||||||
            })
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
            ec, out = cleos.push_action(
 | 
					            await db_call('update_user_stats', user.id, last_prompt=prompt)
 | 
				
			||||||
                'telos.gpu', 'enqueue', [account, req, file_raw.hex()], f'{account}@{permission}'
 | 
					
 | 
				
			||||||
 | 
					            await work_request(
 | 
				
			||||||
 | 
					                bot, cleos, hyperion,
 | 
				
			||||||
 | 
					                message, user, chat,
 | 
				
			||||||
 | 
					                account, permission, params,
 | 
				
			||||||
 | 
					                file_id=file_id, file_path=file_path
 | 
				
			||||||
            )
 | 
					            )
 | 
				
			||||||
            if ec != 0:
 | 
					 | 
				
			||||||
                await bot.reply_to(message, out)
 | 
					 | 
				
			||||||
                return
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            request_id = int(out)
 | 
					 | 
				
			||||||
            logging.info(f'{request_id} enqueued.')
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            ipfs_hash = None
 | 
					 | 
				
			||||||
            sha_hash = None
 | 
					 | 
				
			||||||
            for i in range(60):
 | 
					 | 
				
			||||||
                result = cleos.get_table(
 | 
					 | 
				
			||||||
                    'telos.gpu', 'telos.gpu', 'results',
 | 
					 | 
				
			||||||
                    index_position=2,
 | 
					 | 
				
			||||||
                    key_type='i64',
 | 
					 | 
				
			||||||
                    lower_bound=request_id,
 | 
					 | 
				
			||||||
                    upper_bound=request_id
 | 
					 | 
				
			||||||
                )
 | 
					 | 
				
			||||||
                if len(results) > 0:
 | 
					 | 
				
			||||||
                    ipfs_hash = result[0]['ipfs_hash']
 | 
					 | 
				
			||||||
                    sha_hash = result[0]['result_hash']
 | 
					 | 
				
			||||||
                    break
 | 
					 | 
				
			||||||
                else:
 | 
					 | 
				
			||||||
                    await asyncio.sleep(1)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            if not ipfs_hash:
 | 
					 | 
				
			||||||
                await bot.reply_to(message, 'timeout processing request')
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            ipfs_link = f'https://ipfs.io/ipfs/{ipfs_hash}/image.png'
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            await bot.reply_to(
 | 
					 | 
				
			||||||
                message,
 | 
					 | 
				
			||||||
                ipfs_link + '\n' +
 | 
					 | 
				
			||||||
                prepare_metainfo_caption(user, result['meta']['meta']),
 | 
					 | 
				
			||||||
                reply_to_message_id=reply_id,
 | 
					 | 
				
			||||||
                reply_markup=build_redo_menu()
 | 
					 | 
				
			||||||
            )
 | 
					 | 
				
			||||||
            return
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        @bot.message_handler(commands=['img2img'])
 | 
					        @bot.message_handler(commands=['img2img'])
 | 
				
			||||||
| 
						 | 
					@ -352,17 +382,23 @@ async def run_skynet_telegram(
 | 
				
			||||||
                'seems you tried to do an img2img command without sending image'
 | 
					                'seems you tried to do an img2img command without sending image'
 | 
				
			||||||
            )
 | 
					            )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        @bot.message_handler(commands=['redo'])
 | 
					        async def _redo(message_or_query):
 | 
				
			||||||
        async def redo(message):
 | 
					            if isinstance(message_or_query, CallbackQuery):
 | 
				
			||||||
            user = message.from_user.id
 | 
					                query = message_or_query
 | 
				
			||||||
 | 
					                message = query.message
 | 
				
			||||||
 | 
					                user = query.from_user
 | 
				
			||||||
 | 
					                chat = query.message.chat
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            else:
 | 
				
			||||||
 | 
					                message = message_or_query
 | 
				
			||||||
 | 
					                user = message.from_user
 | 
				
			||||||
                chat = message.chat
 | 
					                chat = message.chat
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            reply_id = None
 | 
					            reply_id = None
 | 
				
			||||||
            if chat.type == 'group' and chat.id == GROUP_ID:
 | 
					            if chat.type == 'group' and chat.id == GROUP_ID:
 | 
				
			||||||
                reply_id = message.message_id
 | 
					                reply_id = message.message_id
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            user_config = {**(await db_call('get_user_config', user))}
 | 
					            prompt = await db_call('get_last_prompt_of', user.id)
 | 
				
			||||||
            del user_config['id']
 | 
					 | 
				
			||||||
            prompt = await db_call('get_last_prompt_of', user)
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
            if not prompt:
 | 
					            if not prompt:
 | 
				
			||||||
                await bot.reply_to(
 | 
					                await bot.reply_to(
 | 
				
			||||||
| 
						 | 
					@ -371,6 +407,11 @@ async def run_skynet_telegram(
 | 
				
			||||||
                )
 | 
					                )
 | 
				
			||||||
                return
 | 
					                return
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            user_row = await db_call('get_or_create_user', user.id)
 | 
				
			||||||
 | 
					            user_config = {**user_row}
 | 
				
			||||||
 | 
					            del user_config['id']
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            params = {
 | 
					            params = {
 | 
				
			||||||
                'prompt': prompt,
 | 
					                'prompt': prompt,
 | 
				
			||||||
                **user_config
 | 
					                **user_config
 | 
				
			||||||
| 
						 | 
					@ -378,7 +419,13 @@ async def run_skynet_telegram(
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            await work_request(
 | 
					            await work_request(
 | 
				
			||||||
                bot, cleos, hyperion,
 | 
					                bot, cleos, hyperion,
 | 
				
			||||||
                message, account, permission, params)
 | 
					                message, user, chat,
 | 
				
			||||||
 | 
					                account, permission, params
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        @bot.message_handler(commands=['redo'])
 | 
				
			||||||
 | 
					        async def redo(message):
 | 
				
			||||||
 | 
					            await _redo(message)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        @bot.message_handler(commands=['config'])
 | 
					        @bot.message_handler(commands=['config'])
 | 
				
			||||||
        async def set_config(message):
 | 
					        async def set_config(message):
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue