Frontend DB fixes and starting to add img2img

add-txt2txt-models
Guillermo Rodriguez 2023-05-28 20:17:55 -03:00
parent 5e017ffac0
commit e63d395d5c
No known key found for this signature in database
GPG Key ID: EC3AB66D5D83B392
5 changed files with 155 additions and 115 deletions

View File

@ -329,7 +329,7 @@ def dgpu(
@run.command() @run.command()
@click.option('--loglevel', '-l', default='warning', help='logging level') @click.option('--loglevel', '-l', default='INFO', help='logging level')
@click.option( @click.option(
'--account', '-a', default='telegram') '--account', '-a', default='telegram')
@click.option( @click.option(
@ -357,6 +357,7 @@ def telegram(
db_user: str, db_user: str,
db_pass: str db_pass: str
): ):
logging.basicConfig(level=loglevel)
key, account, permission = load_account_info( key, account, permission = load_account_info(
key, account, permission) key, account, permission)
@ -422,3 +423,6 @@ def pinner(loglevel, container):
ipfs_node.pin(cid) ipfs_node.pin(cid)
cleanup_pinned(now) cleanup_pinned(now)
except KeyboardInterrupt:
...

View File

@ -102,7 +102,7 @@ MAX_WIDTH = 512
MAX_HEIGHT = 656 MAX_HEIGHT = 656
MAX_GUIDANCE = 20 MAX_GUIDANCE = 20
DEFAULT_SEED = 0 DEFAULT_SEED = None
DEFAULT_WIDTH = 512 DEFAULT_WIDTH = 512
DEFAULT_HEIGHT = 512 DEFAULT_HEIGHT = 512
DEFAULT_GUIDANCE = 7.5 DEFAULT_GUIDANCE = 7.5

View File

@ -26,27 +26,11 @@ CREATE SCHEMA IF NOT EXISTS skynet;
CREATE TABLE IF NOT EXISTS skynet.user( CREATE TABLE IF NOT EXISTS skynet.user(
id SERIAL PRIMARY KEY NOT NULL, id SERIAL PRIMARY KEY NOT NULL,
tg_id BIGINT,
wp_id VARCHAR(128),
mx_id VARCHAR(128),
ig_id VARCHAR(128),
generated INT NOT NULL, generated INT NOT NULL,
joined DATE NOT NULL, joined TIMESTAMP NOT NULL,
last_prompt TEXT, last_prompt TEXT,
role VARCHAR(128) NOT NULL role VARCHAR(128) NOT NULL
); );
ALTER TABLE skynet.user
ADD CONSTRAINT tg_unique
UNIQUE (tg_id);
ALTER TABLE skynet.user
ADD CONSTRAINT wp_unique
UNIQUE (wp_id);
ALTER TABLE skynet.user
ADD CONSTRAINT mx_unique
UNIQUE (mx_id);
ALTER TABLE skynet.user
ADD CONSTRAINT ig_unique
UNIQUE (ig_id);
CREATE TABLE IF NOT EXISTS skynet.user_config( CREATE TABLE IF NOT EXISTS skynet.user_config(
id SERIAL NOT NULL, id SERIAL NOT NULL,
@ -54,7 +38,7 @@ CREATE TABLE IF NOT EXISTS skynet.user_config(
step INT NOT NULL, step INT NOT NULL,
width INT NOT NULL, width INT NOT NULL,
height INT NOT NULL, height INT NOT NULL,
seed BIGINT NOT NULL, seed BIGINT,
guidance REAL NOT NULL, guidance REAL NOT NULL,
strength REAL NOT NULL, strength REAL NOT NULL,
upscaler VARCHAR(128) upscaler VARCHAR(128)
@ -177,16 +161,19 @@ async def open_database_connection(
yield _db_call yield _db_call
async def get_user(conn, uid: int):
stmt = await conn.prepare(
'SELECT * FROM skynet.user WHERE id = $1')
return await stmt.fetchval(uid)
async def get_user_config(conn, user: int): async def get_user_config(conn, user: int):
stmt = await conn.prepare( stmt = await conn.prepare(
'SELECT * FROM skynet.user_config WHERE id = $1') 'SELECT * FROM skynet.user_config WHERE id = $1')
return (await stmt.fetch(user))[0] conf = await stmt.fetch(user)
if len(conf) == 1:
return conf[0]
else:
return None
async def get_user(conn, uid: int):
return await get_user_config(conn, uid)
async def get_last_prompt_of(conn, user: int): async def get_last_prompt_of(conn, user: int):
@ -208,7 +195,6 @@ async def new_user(conn, uid: int):
id, generated, joined, last_prompt, role) id, generated, joined, last_prompt, role)
VALUES($1, $2, $3, $4, $5) VALUES($1, $2, $3, $4, $5)
ON CONFLICT DO NOTHING
''') ''')
await stmt.fetch( await stmt.fetch(
uid, 0, date, None, DEFAULT_ROLE uid, 0, date, None, DEFAULT_ROLE
@ -216,18 +202,16 @@ async def new_user(conn, uid: int):
stmt = await conn.prepare(''' stmt = await conn.prepare('''
INSERT INTO skynet.user_config( INSERT INTO skynet.user_config(
id, algo, step, width, height, seed, guidance, strength, upscaler) id, algo, step, width, height, guidance, strength, upscaler)
VALUES($1, $2, $3, $4, $5, $6, $7, $8, $9) VALUES($1, $2, $3, $4, $5, $6, $7, $8)
ON CONFLICT DO NOTHING
''') ''')
user = await stmt.fetch( resp = await stmt.fetch(
new_uid, uid,
DEFAULT_ALGO, DEFAULT_ALGO,
DEFAULT_STEP, DEFAULT_STEP,
DEFAULT_WIDTH, DEFAULT_WIDTH,
DEFAULT_HEIGHT, DEFAULT_HEIGHT,
DEFAULT_SEED,
DEFAULT_GUIDANCE, DEFAULT_GUIDANCE,
DEFAULT_STRENGTH, DEFAULT_STRENGTH,
DEFAULT_UPSCALER DEFAULT_UPSCALER

View File

@ -84,7 +84,12 @@ def validate_user_config_request(req: str):
raise ConfigUnknownAttribute( raise ConfigUnknownAttribute(
f'\"{attr}\" not a configurable parameter') f'\"{attr}\" not a configurable parameter')
return attr, val, f'config updated! {attr} to {val}' display_val = val
if attr == 'seed':
if not val:
display_val = 'Random'
return attr, val, f'config updated! {attr} to {display_val}'
except ValueError: except ValueError:
raise ValueError(f'\"{val}\" is not a number silly') raise ValueError(f'\"{val}\" is not a number silly')

View File

@ -2,12 +2,15 @@
import io import io
import zlib import zlib
import random
import logging import logging
import asyncio import asyncio
import traceback
from hashlib import sha256 from hashlib import sha256
from datetime import datetime from datetime import datetime
import asks
import docker import docker
from PIL import Image from PIL import Image
@ -18,7 +21,9 @@ from trio_asyncio import aio_as_trio
from telebot.types import ( from telebot.types import (
InputFile, InputMediaPhoto, InlineKeyboardButton, InlineKeyboardMarkup InputFile, InputMediaPhoto, InlineKeyboardButton, InlineKeyboardMarkup
) )
from telebot.async_telebot import AsyncTeleBot
from telebot.types import CallbackQuery
from telebot.async_telebot import AsyncTeleBot, ExceptionHandler
from telebot.formatting import hlink from telebot.formatting import hlink
from ..db import open_new_database, open_database_connection from ..db import open_new_database, open_database_connection
@ -27,7 +32,11 @@ from ..constants import *
from . import * from . import *
PREFIX = 'tg' class SKYExceptionHandler(ExceptionHandler):
def handle(exception):
traceback.print_exc()
def build_redo_menu(): def build_redo_menu():
btn_redo = InlineKeyboardButton("Redo", callback_data=json.dumps({'method': 'redo'})) btn_redo = InlineKeyboardButton("Redo", callback_data=json.dumps({'method': 'redo'}))
@ -113,20 +122,36 @@ async def get_user_nonce(cleos, user: str):
async def work_request( async def work_request(
bot, cleos, hyperion, bot, cleos, hyperion,
message, message, user, chat,
account: str, account: str,
permission: str, permission: str,
params: dict params: dict,
file_id: str | None = None,
file_path: str | None = None
): ):
if params['seed'] == None:
params['seed'] = random.randint(0, 9e18)
body = json.dumps({ body = json.dumps({
'method': 'diffuse', 'method': 'diffuse',
'params': params 'params': params
}) })
user = message.from_user
chat = message.chat
request_time = datetime.now().isoformat() request_time = datetime.now().isoformat()
if file_id:
image_raw = await bot.download_file(file_path)
image = Image.open(io.BytesIO(image_raw))
w, h = image.size
logging.info(f'user sent img of size {image.size}')
if w > 512 or h > 512:
image.thumbnail((512, 512))
logging.warning(f'resized it to {image.size}')
binary = image_raw.hex()
ec, out = cleos.push_action( ec, out = cleos.push_action(
'telos.gpu', 'enqueue', [account, body, '', '20.0000 GPU'], f'{account}@{permission}' 'telos.gpu', 'enqueue', [account, body, binary, '20.0000 GPU'], f'{account}@{permission}'
) )
out = collect_stdout(out) out = collect_stdout(out)
if ec != 0: if ec != 0:
@ -168,10 +193,51 @@ async def work_request(
await bot.reply_to(message, 'timeout processing request') await bot.reply_to(message, 'timeout processing request')
return return
# attempt to get the image and send it
ipfs_link = f'http://test1.us.telos.net:8080/ipfs/{ipfs_hash}/image.png'
logging.info(f'attempting to get image at {ipfs_link}')
resp = None
for i in range(10):
try:
resp = await asks.get(ipfs_link, timeout=2)
except asks.errors.RequestTimeout:
logging.warning('timeout...')
...
logging.info(f'status_code: {resp.status_code}')
caption = generate_reply_caption(
user, params, ipfs_hash, tx_hash)
if resp.status_code != 200:
await bot.reply_to( await bot.reply_to(
message, message,
generate_reply_caption( caption,
user, params, ipfs_hash, tx_hash), reply_markup=build_redo_menu(),
parse_mode='HTML'
)
else:
if file_id: # img2img
await bot.send_media_group(
chat.id,
media=[
InputMediaPhoto(file_id),
InputMediaPhoto(
resp.raw,
caption=caption
)
],
reply_markup=build_redo_menu(),
parse_mode='HTML'
)
else: # txt2img
await bot.send_photo(
chat.id,
caption=caption,
photo=resp.raw,
reply_markup=build_redo_menu(), reply_markup=build_redo_menu(),
parse_mode='HTML' parse_mode='HTML'
) )
@ -205,7 +271,7 @@ async def run_skynet_telegram(
if key: if key:
cleos.setup_wallet(key) cleos.setup_wallet(key)
bot = AsyncTeleBot(tg_token) bot = AsyncTeleBot(tg_token, exception_handler=SKYExceptionHandler)
logging.info(f'tg_token: {tg_token}') logging.info(f'tg_token: {tg_token}')
async with open_database_connection( async with open_database_connection(
@ -233,7 +299,7 @@ async def run_skynet_telegram(
@bot.message_handler(commands=['txt2img']) @bot.message_handler(commands=['txt2img'])
async def send_txt2img(message): async def send_txt2img(message):
user = message.from_user.id user = message.from_user
chat = message.chat chat = message.chat
reply_id = None reply_id = None
if chat.type == 'group' and chat.id == GROUP_ID: if chat.type == 'group' and chat.id == GROUP_ID:
@ -247,8 +313,8 @@ async def run_skynet_telegram(
logging.info(f'mid: {message.id}') logging.info(f'mid: {message.id}')
await db_call('get_or_create_user', user) user_row = await db_call('get_or_create_user', user.id)
user_config = {**(await db_call('get_user_config', user))} user_config = {**user_row}
del user_config['id'] del user_config['id']
params = { params = {
@ -256,22 +322,22 @@ async def run_skynet_telegram(
**user_config **user_config
} }
await db_call('update_user_stats', user, last_prompt=prompt) await db_call('update_user_stats', user.id, last_prompt=prompt)
await work_request( await work_request(
bot, cleos, hyperion, bot, cleos, hyperion,
message, account, permission, params) message, user, chat,
account, permission, params
)
@bot.message_handler(func=lambda message: True, content_types=['photo']) @bot.message_handler(func=lambda message: True, content_types=['photo'])
async def send_img2img(message): async def send_img2img(message):
user = message.from_user.id user = message.from_user
chat = message.chat chat = message.chat
reply_id = None reply_id = None
if chat.type == 'group' and chat.id == GROUP_ID: if chat.type == 'group' and chat.id == GROUP_ID:
reply_id = message.message_id reply_id = message.message_id
user_id = f'tg+{message.from_user.id}'
if not message.caption.startswith('/img2img'): if not message.caption.startswith('/img2img'):
await bot.reply_to( await bot.reply_to(
message, message,
@ -287,62 +353,26 @@ async def run_skynet_telegram(
file_id = message.photo[-1].file_id file_id = message.photo[-1].file_id
file_path = (await bot.get_file(file_id)).file_path file_path = (await bot.get_file(file_id)).file_path
file_raw = await bot.download_file(file_path)
logging.info(f'mid: {message.id}') logging.info(f'mid: {message.id}')
user = await db_call('get_or_create_user', user_id) user_row = await db_call('get_or_create_user', user.id)
user_config = {**(await db_call('get_user_config', user))} user_config = {**user_row}
del user_config['id'] del user_config['id']
req = json.dumps({ params = {
'method': 'diffuse',
'params': {
'prompt': prompt, 'prompt': prompt,
**user_config **user_config
} }
})
ec, out = cleos.push_action( await db_call('update_user_stats', user.id, last_prompt=prompt)
'telos.gpu', 'enqueue', [account, req, file_raw.hex()], f'{account}@{permission}'
await work_request(
bot, cleos, hyperion,
message, user, chat,
account, permission, params,
file_id=file_id, file_path=file_path
) )
if ec != 0:
await bot.reply_to(message, out)
return
request_id = int(out)
logging.info(f'{request_id} enqueued.')
ipfs_hash = None
sha_hash = None
for i in range(60):
result = cleos.get_table(
'telos.gpu', 'telos.gpu', 'results',
index_position=2,
key_type='i64',
lower_bound=request_id,
upper_bound=request_id
)
if len(results) > 0:
ipfs_hash = result[0]['ipfs_hash']
sha_hash = result[0]['result_hash']
break
else:
await asyncio.sleep(1)
if not ipfs_hash:
await bot.reply_to(message, 'timeout processing request')
ipfs_link = f'https://ipfs.io/ipfs/{ipfs_hash}/image.png'
await bot.reply_to(
message,
ipfs_link + '\n' +
prepare_metainfo_caption(user, result['meta']['meta']),
reply_to_message_id=reply_id,
reply_markup=build_redo_menu()
)
return
@bot.message_handler(commands=['img2img']) @bot.message_handler(commands=['img2img'])
@ -352,17 +382,23 @@ async def run_skynet_telegram(
'seems you tried to do an img2img command without sending image' 'seems you tried to do an img2img command without sending image'
) )
@bot.message_handler(commands=['redo']) async def _redo(message_or_query):
async def redo(message): if isinstance(message_or_query, CallbackQuery):
user = message.from_user.id query = message_or_query
message = query.message
user = query.from_user
chat = query.message.chat
else:
message = message_or_query
user = message.from_user
chat = message.chat chat = message.chat
reply_id = None reply_id = None
if chat.type == 'group' and chat.id == GROUP_ID: if chat.type == 'group' and chat.id == GROUP_ID:
reply_id = message.message_id reply_id = message.message_id
user_config = {**(await db_call('get_user_config', user))} prompt = await db_call('get_last_prompt_of', user.id)
del user_config['id']
prompt = await db_call('get_last_prompt_of', user)
if not prompt: if not prompt:
await bot.reply_to( await bot.reply_to(
@ -371,6 +407,11 @@ async def run_skynet_telegram(
) )
return return
user_row = await db_call('get_or_create_user', user.id)
user_config = {**user_row}
del user_config['id']
params = { params = {
'prompt': prompt, 'prompt': prompt,
**user_config **user_config
@ -378,7 +419,13 @@ async def run_skynet_telegram(
await work_request( await work_request(
bot, cleos, hyperion, bot, cleos, hyperion,
message, account, permission, params) message, user, chat,
account, permission, params
)
@bot.message_handler(commands=['redo'])
async def redo(message):
await _redo(message)
@bot.message_handler(commands=['config']) @bot.message_handler(commands=['config'])
async def set_config(message): async def set_config(message):