mirror of https://github.com/skygpu/skynet.git
Frontend DB fixes and starting to add img2img
parent
5e017ffac0
commit
e63d395d5c
|
@ -329,7 +329,7 @@ def dgpu(
|
|||
|
||||
|
||||
@run.command()
|
||||
@click.option('--loglevel', '-l', default='warning', help='logging level')
|
||||
@click.option('--loglevel', '-l', default='INFO', help='logging level')
|
||||
@click.option(
|
||||
'--account', '-a', default='telegram')
|
||||
@click.option(
|
||||
|
@ -357,6 +357,7 @@ def telegram(
|
|||
db_user: str,
|
||||
db_pass: str
|
||||
):
|
||||
logging.basicConfig(level=loglevel)
|
||||
|
||||
key, account, permission = load_account_info(
|
||||
key, account, permission)
|
||||
|
@ -422,3 +423,6 @@ def pinner(loglevel, container):
|
|||
ipfs_node.pin(cid)
|
||||
|
||||
cleanup_pinned(now)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
...
|
||||
|
|
|
@ -102,7 +102,7 @@ MAX_WIDTH = 512
|
|||
MAX_HEIGHT = 656
|
||||
MAX_GUIDANCE = 20
|
||||
|
||||
DEFAULT_SEED = 0
|
||||
DEFAULT_SEED = None
|
||||
DEFAULT_WIDTH = 512
|
||||
DEFAULT_HEIGHT = 512
|
||||
DEFAULT_GUIDANCE = 7.5
|
||||
|
|
|
@ -26,27 +26,11 @@ CREATE SCHEMA IF NOT EXISTS skynet;
|
|||
|
||||
CREATE TABLE IF NOT EXISTS skynet.user(
|
||||
id SERIAL PRIMARY KEY NOT NULL,
|
||||
tg_id BIGINT,
|
||||
wp_id VARCHAR(128),
|
||||
mx_id VARCHAR(128),
|
||||
ig_id VARCHAR(128),
|
||||
generated INT NOT NULL,
|
||||
joined DATE NOT NULL,
|
||||
joined TIMESTAMP NOT NULL,
|
||||
last_prompt TEXT,
|
||||
role VARCHAR(128) NOT NULL
|
||||
);
|
||||
ALTER TABLE skynet.user
|
||||
ADD CONSTRAINT tg_unique
|
||||
UNIQUE (tg_id);
|
||||
ALTER TABLE skynet.user
|
||||
ADD CONSTRAINT wp_unique
|
||||
UNIQUE (wp_id);
|
||||
ALTER TABLE skynet.user
|
||||
ADD CONSTRAINT mx_unique
|
||||
UNIQUE (mx_id);
|
||||
ALTER TABLE skynet.user
|
||||
ADD CONSTRAINT ig_unique
|
||||
UNIQUE (ig_id);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS skynet.user_config(
|
||||
id SERIAL NOT NULL,
|
||||
|
@ -54,7 +38,7 @@ CREATE TABLE IF NOT EXISTS skynet.user_config(
|
|||
step INT NOT NULL,
|
||||
width INT NOT NULL,
|
||||
height INT NOT NULL,
|
||||
seed BIGINT NOT NULL,
|
||||
seed BIGINT,
|
||||
guidance REAL NOT NULL,
|
||||
strength REAL NOT NULL,
|
||||
upscaler VARCHAR(128)
|
||||
|
@ -177,16 +161,19 @@ async def open_database_connection(
|
|||
yield _db_call
|
||||
|
||||
|
||||
async def get_user(conn, uid: int):
|
||||
stmt = await conn.prepare(
|
||||
'SELECT * FROM skynet.user WHERE id = $1')
|
||||
return await stmt.fetchval(uid)
|
||||
|
||||
|
||||
async def get_user_config(conn, user: int):
|
||||
stmt = await conn.prepare(
|
||||
'SELECT * FROM skynet.user_config WHERE id = $1')
|
||||
return (await stmt.fetch(user))[0]
|
||||
conf = await stmt.fetch(user)
|
||||
if len(conf) == 1:
|
||||
return conf[0]
|
||||
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
async def get_user(conn, uid: int):
|
||||
return await get_user_config(conn, uid)
|
||||
|
||||
|
||||
async def get_last_prompt_of(conn, user: int):
|
||||
|
@ -208,7 +195,6 @@ async def new_user(conn, uid: int):
|
|||
id, generated, joined, last_prompt, role)
|
||||
|
||||
VALUES($1, $2, $3, $4, $5)
|
||||
ON CONFLICT DO NOTHING
|
||||
''')
|
||||
await stmt.fetch(
|
||||
uid, 0, date, None, DEFAULT_ROLE
|
||||
|
@ -216,18 +202,16 @@ async def new_user(conn, uid: int):
|
|||
|
||||
stmt = await conn.prepare('''
|
||||
INSERT INTO skynet.user_config(
|
||||
id, algo, step, width, height, seed, guidance, strength, upscaler)
|
||||
id, algo, step, width, height, guidance, strength, upscaler)
|
||||
|
||||
VALUES($1, $2, $3, $4, $5, $6, $7, $8, $9)
|
||||
ON CONFLICT DO NOTHING
|
||||
VALUES($1, $2, $3, $4, $5, $6, $7, $8)
|
||||
''')
|
||||
user = await stmt.fetch(
|
||||
new_uid,
|
||||
resp = await stmt.fetch(
|
||||
uid,
|
||||
DEFAULT_ALGO,
|
||||
DEFAULT_STEP,
|
||||
DEFAULT_WIDTH,
|
||||
DEFAULT_HEIGHT,
|
||||
DEFAULT_SEED,
|
||||
DEFAULT_GUIDANCE,
|
||||
DEFAULT_STRENGTH,
|
||||
DEFAULT_UPSCALER
|
||||
|
|
|
@ -84,7 +84,12 @@ def validate_user_config_request(req: str):
|
|||
raise ConfigUnknownAttribute(
|
||||
f'\"{attr}\" not a configurable parameter')
|
||||
|
||||
return attr, val, f'config updated! {attr} to {val}'
|
||||
display_val = val
|
||||
if attr == 'seed':
|
||||
if not val:
|
||||
display_val = 'Random'
|
||||
|
||||
return attr, val, f'config updated! {attr} to {display_val}'
|
||||
|
||||
except ValueError:
|
||||
raise ValueError(f'\"{val}\" is not a number silly')
|
||||
|
|
|
@ -2,12 +2,15 @@
|
|||
|
||||
import io
|
||||
import zlib
|
||||
import random
|
||||
import logging
|
||||
import asyncio
|
||||
import traceback
|
||||
|
||||
from hashlib import sha256
|
||||
from datetime import datetime
|
||||
|
||||
import asks
|
||||
import docker
|
||||
|
||||
from PIL import Image
|
||||
|
@ -18,7 +21,9 @@ from trio_asyncio import aio_as_trio
|
|||
from telebot.types import (
|
||||
InputFile, InputMediaPhoto, InlineKeyboardButton, InlineKeyboardMarkup
|
||||
)
|
||||
from telebot.async_telebot import AsyncTeleBot
|
||||
|
||||
from telebot.types import CallbackQuery
|
||||
from telebot.async_telebot import AsyncTeleBot, ExceptionHandler
|
||||
from telebot.formatting import hlink
|
||||
|
||||
from ..db import open_new_database, open_database_connection
|
||||
|
@ -27,7 +32,11 @@ from ..constants import *
|
|||
from . import *
|
||||
|
||||
|
||||
PREFIX = 'tg'
|
||||
class SKYExceptionHandler(ExceptionHandler):
|
||||
|
||||
def handle(exception):
|
||||
traceback.print_exc()
|
||||
|
||||
|
||||
def build_redo_menu():
|
||||
btn_redo = InlineKeyboardButton("Redo", callback_data=json.dumps({'method': 'redo'}))
|
||||
|
@ -113,20 +122,36 @@ async def get_user_nonce(cleos, user: str):
|
|||
|
||||
async def work_request(
|
||||
bot, cleos, hyperion,
|
||||
message,
|
||||
message, user, chat,
|
||||
account: str,
|
||||
permission: str,
|
||||
params: dict
|
||||
params: dict,
|
||||
file_id: str | None = None,
|
||||
file_path: str | None = None
|
||||
):
|
||||
if params['seed'] == None:
|
||||
params['seed'] = random.randint(0, 9e18)
|
||||
|
||||
body = json.dumps({
|
||||
'method': 'diffuse',
|
||||
'params': params
|
||||
})
|
||||
user = message.from_user
|
||||
chat = message.chat
|
||||
request_time = datetime.now().isoformat()
|
||||
|
||||
if file_id:
|
||||
image_raw = await bot.download_file(file_path)
|
||||
image = Image.open(io.BytesIO(image_raw))
|
||||
w, h = image.size
|
||||
logging.info(f'user sent img of size {image.size}')
|
||||
|
||||
if w > 512 or h > 512:
|
||||
image.thumbnail((512, 512))
|
||||
logging.warning(f'resized it to {image.size}')
|
||||
|
||||
binary = image_raw.hex()
|
||||
|
||||
ec, out = cleos.push_action(
|
||||
'telos.gpu', 'enqueue', [account, body, '', '20.0000 GPU'], f'{account}@{permission}'
|
||||
'telos.gpu', 'enqueue', [account, body, binary, '20.0000 GPU'], f'{account}@{permission}'
|
||||
)
|
||||
out = collect_stdout(out)
|
||||
if ec != 0:
|
||||
|
@ -168,10 +193,51 @@ async def work_request(
|
|||
await bot.reply_to(message, 'timeout processing request')
|
||||
return
|
||||
|
||||
# attempt to get the image and send it
|
||||
ipfs_link = f'http://test1.us.telos.net:8080/ipfs/{ipfs_hash}/image.png'
|
||||
logging.info(f'attempting to get image at {ipfs_link}')
|
||||
resp = None
|
||||
for i in range(10):
|
||||
try:
|
||||
resp = await asks.get(ipfs_link, timeout=2)
|
||||
|
||||
except asks.errors.RequestTimeout:
|
||||
logging.warning('timeout...')
|
||||
...
|
||||
|
||||
logging.info(f'status_code: {resp.status_code}')
|
||||
|
||||
caption = generate_reply_caption(
|
||||
user, params, ipfs_hash, tx_hash)
|
||||
|
||||
if resp.status_code != 200:
|
||||
await bot.reply_to(
|
||||
message,
|
||||
generate_reply_caption(
|
||||
user, params, ipfs_hash, tx_hash),
|
||||
caption,
|
||||
reply_markup=build_redo_menu(),
|
||||
parse_mode='HTML'
|
||||
)
|
||||
|
||||
else:
|
||||
if file_id: # img2img
|
||||
await bot.send_media_group(
|
||||
chat.id,
|
||||
media=[
|
||||
InputMediaPhoto(file_id),
|
||||
InputMediaPhoto(
|
||||
resp.raw,
|
||||
caption=caption
|
||||
)
|
||||
],
|
||||
reply_markup=build_redo_menu(),
|
||||
parse_mode='HTML'
|
||||
)
|
||||
|
||||
else: # txt2img
|
||||
await bot.send_photo(
|
||||
chat.id,
|
||||
caption=caption,
|
||||
photo=resp.raw,
|
||||
reply_markup=build_redo_menu(),
|
||||
parse_mode='HTML'
|
||||
)
|
||||
|
@ -205,7 +271,7 @@ async def run_skynet_telegram(
|
|||
if key:
|
||||
cleos.setup_wallet(key)
|
||||
|
||||
bot = AsyncTeleBot(tg_token)
|
||||
bot = AsyncTeleBot(tg_token, exception_handler=SKYExceptionHandler)
|
||||
logging.info(f'tg_token: {tg_token}')
|
||||
|
||||
async with open_database_connection(
|
||||
|
@ -233,7 +299,7 @@ async def run_skynet_telegram(
|
|||
|
||||
@bot.message_handler(commands=['txt2img'])
|
||||
async def send_txt2img(message):
|
||||
user = message.from_user.id
|
||||
user = message.from_user
|
||||
chat = message.chat
|
||||
reply_id = None
|
||||
if chat.type == 'group' and chat.id == GROUP_ID:
|
||||
|
@ -247,8 +313,8 @@ async def run_skynet_telegram(
|
|||
|
||||
logging.info(f'mid: {message.id}')
|
||||
|
||||
await db_call('get_or_create_user', user)
|
||||
user_config = {**(await db_call('get_user_config', user))}
|
||||
user_row = await db_call('get_or_create_user', user.id)
|
||||
user_config = {**user_row}
|
||||
del user_config['id']
|
||||
|
||||
params = {
|
||||
|
@ -256,22 +322,22 @@ async def run_skynet_telegram(
|
|||
**user_config
|
||||
}
|
||||
|
||||
await db_call('update_user_stats', user, last_prompt=prompt)
|
||||
await db_call('update_user_stats', user.id, last_prompt=prompt)
|
||||
|
||||
await work_request(
|
||||
bot, cleos, hyperion,
|
||||
message, account, permission, params)
|
||||
message, user, chat,
|
||||
account, permission, params
|
||||
)
|
||||
|
||||
@bot.message_handler(func=lambda message: True, content_types=['photo'])
|
||||
async def send_img2img(message):
|
||||
user = message.from_user.id
|
||||
user = message.from_user
|
||||
chat = message.chat
|
||||
reply_id = None
|
||||
if chat.type == 'group' and chat.id == GROUP_ID:
|
||||
reply_id = message.message_id
|
||||
|
||||
user_id = f'tg+{message.from_user.id}'
|
||||
|
||||
if not message.caption.startswith('/img2img'):
|
||||
await bot.reply_to(
|
||||
message,
|
||||
|
@ -287,62 +353,26 @@ async def run_skynet_telegram(
|
|||
|
||||
file_id = message.photo[-1].file_id
|
||||
file_path = (await bot.get_file(file_id)).file_path
|
||||
file_raw = await bot.download_file(file_path)
|
||||
|
||||
logging.info(f'mid: {message.id}')
|
||||
|
||||
user = await db_call('get_or_create_user', user_id)
|
||||
user_config = {**(await db_call('get_user_config', user))}
|
||||
user_row = await db_call('get_or_create_user', user.id)
|
||||
user_config = {**user_row}
|
||||
del user_config['id']
|
||||
|
||||
req = json.dumps({
|
||||
'method': 'diffuse',
|
||||
'params': {
|
||||
params = {
|
||||
'prompt': prompt,
|
||||
**user_config
|
||||
}
|
||||
})
|
||||
|
||||
ec, out = cleos.push_action(
|
||||
'telos.gpu', 'enqueue', [account, req, file_raw.hex()], f'{account}@{permission}'
|
||||
await db_call('update_user_stats', user.id, last_prompt=prompt)
|
||||
|
||||
await work_request(
|
||||
bot, cleos, hyperion,
|
||||
message, user, chat,
|
||||
account, permission, params,
|
||||
file_id=file_id, file_path=file_path
|
||||
)
|
||||
if ec != 0:
|
||||
await bot.reply_to(message, out)
|
||||
return
|
||||
|
||||
request_id = int(out)
|
||||
logging.info(f'{request_id} enqueued.')
|
||||
|
||||
ipfs_hash = None
|
||||
sha_hash = None
|
||||
for i in range(60):
|
||||
result = cleos.get_table(
|
||||
'telos.gpu', 'telos.gpu', 'results',
|
||||
index_position=2,
|
||||
key_type='i64',
|
||||
lower_bound=request_id,
|
||||
upper_bound=request_id
|
||||
)
|
||||
if len(results) > 0:
|
||||
ipfs_hash = result[0]['ipfs_hash']
|
||||
sha_hash = result[0]['result_hash']
|
||||
break
|
||||
else:
|
||||
await asyncio.sleep(1)
|
||||
|
||||
if not ipfs_hash:
|
||||
await bot.reply_to(message, 'timeout processing request')
|
||||
|
||||
ipfs_link = f'https://ipfs.io/ipfs/{ipfs_hash}/image.png'
|
||||
|
||||
await bot.reply_to(
|
||||
message,
|
||||
ipfs_link + '\n' +
|
||||
prepare_metainfo_caption(user, result['meta']['meta']),
|
||||
reply_to_message_id=reply_id,
|
||||
reply_markup=build_redo_menu()
|
||||
)
|
||||
return
|
||||
|
||||
|
||||
@bot.message_handler(commands=['img2img'])
|
||||
|
@ -352,17 +382,23 @@ async def run_skynet_telegram(
|
|||
'seems you tried to do an img2img command without sending image'
|
||||
)
|
||||
|
||||
@bot.message_handler(commands=['redo'])
|
||||
async def redo(message):
|
||||
user = message.from_user.id
|
||||
async def _redo(message_or_query):
|
||||
if isinstance(message_or_query, CallbackQuery):
|
||||
query = message_or_query
|
||||
message = query.message
|
||||
user = query.from_user
|
||||
chat = query.message.chat
|
||||
|
||||
else:
|
||||
message = message_or_query
|
||||
user = message.from_user
|
||||
chat = message.chat
|
||||
|
||||
reply_id = None
|
||||
if chat.type == 'group' and chat.id == GROUP_ID:
|
||||
reply_id = message.message_id
|
||||
|
||||
user_config = {**(await db_call('get_user_config', user))}
|
||||
del user_config['id']
|
||||
prompt = await db_call('get_last_prompt_of', user)
|
||||
prompt = await db_call('get_last_prompt_of', user.id)
|
||||
|
||||
if not prompt:
|
||||
await bot.reply_to(
|
||||
|
@ -371,6 +407,11 @@ async def run_skynet_telegram(
|
|||
)
|
||||
return
|
||||
|
||||
|
||||
user_row = await db_call('get_or_create_user', user.id)
|
||||
user_config = {**user_row}
|
||||
del user_config['id']
|
||||
|
||||
params = {
|
||||
'prompt': prompt,
|
||||
**user_config
|
||||
|
@ -378,7 +419,13 @@ async def run_skynet_telegram(
|
|||
|
||||
await work_request(
|
||||
bot, cleos, hyperion,
|
||||
message, account, permission, params)
|
||||
message, user, chat,
|
||||
account, permission, params
|
||||
)
|
||||
|
||||
@bot.message_handler(commands=['redo'])
|
||||
async def redo(message):
|
||||
await _redo(message)
|
||||
|
||||
@bot.message_handler(commands=['config'])
|
||||
async def set_config(message):
|
||||
|
|
Loading…
Reference in New Issue