Frontend DB fixes and starting to add img2img

add-txt2txt-models
Guillermo Rodriguez 2023-05-28 20:17:55 -03:00
parent 5e017ffac0
commit e63d395d5c
No known key found for this signature in database
GPG Key ID: EC3AB66D5D83B392
5 changed files with 155 additions and 115 deletions

View File

@ -329,7 +329,7 @@ def dgpu(
@run.command()
@click.option('--loglevel', '-l', default='warning', help='logging level')
@click.option('--loglevel', '-l', default='INFO', help='logging level')
@click.option(
'--account', '-a', default='telegram')
@click.option(
@ -357,6 +357,7 @@ def telegram(
db_user: str,
db_pass: str
):
logging.basicConfig(level=loglevel)
key, account, permission = load_account_info(
key, account, permission)
@ -422,3 +423,6 @@ def pinner(loglevel, container):
ipfs_node.pin(cid)
cleanup_pinned(now)
except KeyboardInterrupt:
...

View File

@ -102,7 +102,7 @@ MAX_WIDTH = 512
MAX_HEIGHT = 656
MAX_GUIDANCE = 20
DEFAULT_SEED = 0
DEFAULT_SEED = None
DEFAULT_WIDTH = 512
DEFAULT_HEIGHT = 512
DEFAULT_GUIDANCE = 7.5

View File

@ -26,27 +26,11 @@ CREATE SCHEMA IF NOT EXISTS skynet;
CREATE TABLE IF NOT EXISTS skynet.user(
id SERIAL PRIMARY KEY NOT NULL,
tg_id BIGINT,
wp_id VARCHAR(128),
mx_id VARCHAR(128),
ig_id VARCHAR(128),
generated INT NOT NULL,
joined DATE NOT NULL,
joined TIMESTAMP NOT NULL,
last_prompt TEXT,
role VARCHAR(128) NOT NULL
);
ALTER TABLE skynet.user
ADD CONSTRAINT tg_unique
UNIQUE (tg_id);
ALTER TABLE skynet.user
ADD CONSTRAINT wp_unique
UNIQUE (wp_id);
ALTER TABLE skynet.user
ADD CONSTRAINT mx_unique
UNIQUE (mx_id);
ALTER TABLE skynet.user
ADD CONSTRAINT ig_unique
UNIQUE (ig_id);
CREATE TABLE IF NOT EXISTS skynet.user_config(
id SERIAL NOT NULL,
@ -54,7 +38,7 @@ CREATE TABLE IF NOT EXISTS skynet.user_config(
step INT NOT NULL,
width INT NOT NULL,
height INT NOT NULL,
seed BIGINT NOT NULL,
seed BIGINT,
guidance REAL NOT NULL,
strength REAL NOT NULL,
upscaler VARCHAR(128)
@ -177,16 +161,19 @@ async def open_database_connection(
yield _db_call
async def get_user(conn, uid: int):
stmt = await conn.prepare(
'SELECT * FROM skynet.user WHERE id = $1')
return await stmt.fetchval(uid)
async def get_user_config(conn, user: int):
stmt = await conn.prepare(
'SELECT * FROM skynet.user_config WHERE id = $1')
return (await stmt.fetch(user))[0]
conf = await stmt.fetch(user)
if len(conf) == 1:
return conf[0]
else:
return None
async def get_user(conn, uid: int):
return await get_user_config(conn, uid)
async def get_last_prompt_of(conn, user: int):
@ -208,7 +195,6 @@ async def new_user(conn, uid: int):
id, generated, joined, last_prompt, role)
VALUES($1, $2, $3, $4, $5)
ON CONFLICT DO NOTHING
''')
await stmt.fetch(
uid, 0, date, None, DEFAULT_ROLE
@ -216,18 +202,16 @@ async def new_user(conn, uid: int):
stmt = await conn.prepare('''
INSERT INTO skynet.user_config(
id, algo, step, width, height, seed, guidance, strength, upscaler)
id, algo, step, width, height, guidance, strength, upscaler)
VALUES($1, $2, $3, $4, $5, $6, $7, $8, $9)
ON CONFLICT DO NOTHING
VALUES($1, $2, $3, $4, $5, $6, $7, $8)
''')
user = await stmt.fetch(
new_uid,
resp = await stmt.fetch(
uid,
DEFAULT_ALGO,
DEFAULT_STEP,
DEFAULT_WIDTH,
DEFAULT_HEIGHT,
DEFAULT_SEED,
DEFAULT_GUIDANCE,
DEFAULT_STRENGTH,
DEFAULT_UPSCALER

View File

@ -84,7 +84,12 @@ def validate_user_config_request(req: str):
raise ConfigUnknownAttribute(
f'\"{attr}\" not a configurable parameter')
return attr, val, f'config updated! {attr} to {val}'
display_val = val
if attr == 'seed':
if not val:
display_val = 'Random'
return attr, val, f'config updated! {attr} to {display_val}'
except ValueError:
raise ValueError(f'\"{val}\" is not a number silly')

View File

@ -2,12 +2,15 @@
import io
import zlib
import random
import logging
import asyncio
import traceback
from hashlib import sha256
from datetime import datetime
import asks
import docker
from PIL import Image
@ -18,7 +21,9 @@ from trio_asyncio import aio_as_trio
from telebot.types import (
InputFile, InputMediaPhoto, InlineKeyboardButton, InlineKeyboardMarkup
)
from telebot.async_telebot import AsyncTeleBot
from telebot.types import CallbackQuery
from telebot.async_telebot import AsyncTeleBot, ExceptionHandler
from telebot.formatting import hlink
from ..db import open_new_database, open_database_connection
@ -27,7 +32,11 @@ from ..constants import *
from . import *
PREFIX = 'tg'
class SKYExceptionHandler(ExceptionHandler):
def handle(exception):
traceback.print_exc()
def build_redo_menu():
btn_redo = InlineKeyboardButton("Redo", callback_data=json.dumps({'method': 'redo'}))
@ -113,20 +122,36 @@ async def get_user_nonce(cleos, user: str):
async def work_request(
bot, cleos, hyperion,
message,
message, user, chat,
account: str,
permission: str,
params: dict
params: dict,
file_id: str | None = None,
file_path: str | None = None
):
if params['seed'] == None:
params['seed'] = random.randint(0, 9e18)
body = json.dumps({
'method': 'diffuse',
'params': params
})
user = message.from_user
chat = message.chat
request_time = datetime.now().isoformat()
if file_id:
image_raw = await bot.download_file(file_path)
image = Image.open(io.BytesIO(image_raw))
w, h = image.size
logging.info(f'user sent img of size {image.size}')
if w > 512 or h > 512:
image.thumbnail((512, 512))
logging.warning(f'resized it to {image.size}')
binary = image_raw.hex()
ec, out = cleos.push_action(
'telos.gpu', 'enqueue', [account, body, '', '20.0000 GPU'], f'{account}@{permission}'
'telos.gpu', 'enqueue', [account, body, binary, '20.0000 GPU'], f'{account}@{permission}'
)
out = collect_stdout(out)
if ec != 0:
@ -168,13 +193,54 @@ async def work_request(
await bot.reply_to(message, 'timeout processing request')
return
await bot.reply_to(
message,
generate_reply_caption(
user, params, ipfs_hash, tx_hash),
reply_markup=build_redo_menu(),
parse_mode='HTML'
)
# attempt to get the image and send it
ipfs_link = f'http://test1.us.telos.net:8080/ipfs/{ipfs_hash}/image.png'
logging.info(f'attempting to get image at {ipfs_link}')
resp = None
for i in range(10):
try:
resp = await asks.get(ipfs_link, timeout=2)
except asks.errors.RequestTimeout:
logging.warning('timeout...')
...
logging.info(f'status_code: {resp.status_code}')
caption = generate_reply_caption(
user, params, ipfs_hash, tx_hash)
if resp.status_code != 200:
await bot.reply_to(
message,
caption,
reply_markup=build_redo_menu(),
parse_mode='HTML'
)
else:
if file_id: # img2img
await bot.send_media_group(
chat.id,
media=[
InputMediaPhoto(file_id),
InputMediaPhoto(
resp.raw,
caption=caption
)
],
reply_markup=build_redo_menu(),
parse_mode='HTML'
)
else: # txt2img
await bot.send_photo(
chat.id,
caption=caption,
photo=resp.raw,
reply_markup=build_redo_menu(),
parse_mode='HTML'
)
async def run_skynet_telegram(
@ -205,7 +271,7 @@ async def run_skynet_telegram(
if key:
cleos.setup_wallet(key)
bot = AsyncTeleBot(tg_token)
bot = AsyncTeleBot(tg_token, exception_handler=SKYExceptionHandler)
logging.info(f'tg_token: {tg_token}')
async with open_database_connection(
@ -233,7 +299,7 @@ async def run_skynet_telegram(
@bot.message_handler(commands=['txt2img'])
async def send_txt2img(message):
user = message.from_user.id
user = message.from_user
chat = message.chat
reply_id = None
if chat.type == 'group' and chat.id == GROUP_ID:
@ -247,8 +313,8 @@ async def run_skynet_telegram(
logging.info(f'mid: {message.id}')
await db_call('get_or_create_user', user)
user_config = {**(await db_call('get_user_config', user))}
user_row = await db_call('get_or_create_user', user.id)
user_config = {**user_row}
del user_config['id']
params = {
@ -256,22 +322,22 @@ async def run_skynet_telegram(
**user_config
}
await db_call('update_user_stats', user, last_prompt=prompt)
await db_call('update_user_stats', user.id, last_prompt=prompt)
await work_request(
bot, cleos, hyperion,
message, account, permission, params)
message, user, chat,
account, permission, params
)
@bot.message_handler(func=lambda message: True, content_types=['photo'])
async def send_img2img(message):
user = message.from_user.id
user = message.from_user
chat = message.chat
reply_id = None
if chat.type == 'group' and chat.id == GROUP_ID:
reply_id = message.message_id
user_id = f'tg+{message.from_user.id}'
if not message.caption.startswith('/img2img'):
await bot.reply_to(
message,
@ -287,62 +353,26 @@ async def run_skynet_telegram(
file_id = message.photo[-1].file_id
file_path = (await bot.get_file(file_id)).file_path
file_raw = await bot.download_file(file_path)
logging.info(f'mid: {message.id}')
user = await db_call('get_or_create_user', user_id)
user_config = {**(await db_call('get_user_config', user))}
user_row = await db_call('get_or_create_user', user.id)
user_config = {**user_row}
del user_config['id']
req = json.dumps({
'method': 'diffuse',
'params': {
'prompt': prompt,
**user_config
}
})
params = {
'prompt': prompt,
**user_config
}
ec, out = cleos.push_action(
'telos.gpu', 'enqueue', [account, req, file_raw.hex()], f'{account}@{permission}'
await db_call('update_user_stats', user.id, last_prompt=prompt)
await work_request(
bot, cleos, hyperion,
message, user, chat,
account, permission, params,
file_id=file_id, file_path=file_path
)
if ec != 0:
await bot.reply_to(message, out)
return
request_id = int(out)
logging.info(f'{request_id} enqueued.')
ipfs_hash = None
sha_hash = None
for i in range(60):
result = cleos.get_table(
'telos.gpu', 'telos.gpu', 'results',
index_position=2,
key_type='i64',
lower_bound=request_id,
upper_bound=request_id
)
if len(results) > 0:
ipfs_hash = result[0]['ipfs_hash']
sha_hash = result[0]['result_hash']
break
else:
await asyncio.sleep(1)
if not ipfs_hash:
await bot.reply_to(message, 'timeout processing request')
ipfs_link = f'https://ipfs.io/ipfs/{ipfs_hash}/image.png'
await bot.reply_to(
message,
ipfs_link + '\n' +
prepare_metainfo_caption(user, result['meta']['meta']),
reply_to_message_id=reply_id,
reply_markup=build_redo_menu()
)
return
@bot.message_handler(commands=['img2img'])
@ -352,17 +382,23 @@ async def run_skynet_telegram(
'seems you tried to do an img2img command without sending image'
)
@bot.message_handler(commands=['redo'])
async def redo(message):
user = message.from_user.id
chat = message.chat
async def _redo(message_or_query):
if isinstance(message_or_query, CallbackQuery):
query = message_or_query
message = query.message
user = query.from_user
chat = query.message.chat
else:
message = message_or_query
user = message.from_user
chat = message.chat
reply_id = None
if chat.type == 'group' and chat.id == GROUP_ID:
reply_id = message.message_id
user_config = {**(await db_call('get_user_config', user))}
del user_config['id']
prompt = await db_call('get_last_prompt_of', user)
prompt = await db_call('get_last_prompt_of', user.id)
if not prompt:
await bot.reply_to(
@ -371,6 +407,11 @@ async def run_skynet_telegram(
)
return
user_row = await db_call('get_or_create_user', user.id)
user_config = {**user_row}
del user_config['id']
params = {
'prompt': prompt,
**user_config
@ -378,7 +419,13 @@ async def run_skynet_telegram(
await work_request(
bot, cleos, hyperion,
message, account, permission, params)
message, user, chat,
account, permission, params
)
@bot.message_handler(commands=['redo'])
async def redo(message):
await _redo(message)
@bot.message_handler(commands=['config'])
async def set_config(message):