Add autoconfiguration feature for telegram frontend

pull/26/head
Guillermo Rodriguez 2023-10-08 09:14:01 -03:00
parent d3b5d56187
commit aa1d52dba0
No known key found for this signature in database
GPG Key ID: EC3AB66D5D83B392
3 changed files with 61 additions and 6 deletions

View File

@ -165,6 +165,15 @@ async def open_database_connection(
else:
await conn.execute(DB_INIT_SQL)
col_check = await conn.fetch(f'''
select column_name
from information_schema.columns
where table_name = 'user_config' and column_name = 'autoconf';
''')
if not col_check:
await conn.execute('alter table skynet.user_config add column autoconf boolean;')
async def _db_call(method: str, *args, **kwargs):
method = getattr(db, method)

View File

@ -1,5 +1,7 @@
#!/usr/bin/python
import random
from ..constants import *
@ -15,10 +17,14 @@ class ConfigUnknownAlgorithm(BaseException):
class ConfigUnknownUpscaler(BaseException):
...
class ConfigUnknownAutoConfSetting(BaseException):
...
class ConfigSizeDivisionByEight(BaseException):
...
def validate_user_config_request(req: str):
params = req.split(' ')
@ -78,6 +84,18 @@ def validate_user_config_request(req: str):
raise ConfigUnknownUpscaler(
f'\"{val}\" is not a valid upscaler')
case 'autoconf':
val = params[2]
if val == 'on':
val = True
elif val == 'off':
val = False
else:
raise ConfigUnknownAutoConfSetting(
f'\"{val}\" not a valid setting for autoconf')
case _:
raise ConfigUnknownAttribute(
f'\"{attr}\" not a configurable parameter')
@ -92,3 +110,22 @@ def validate_user_config_request(req: str):
except ValueError:
raise ValueError(f'\"{val}\" is not a number silly')
def perform_auto_conf(config: dict) -> dict:
model = config['model']
prefered_size_w = 512
prefered_size_h = 512
if 'xl' in model:
prefered_size_w = 1024
prefered_size_h = 1024
else:
prefered_size_w = 512
prefered_size_h = 512
config['step'] = random.randint(20, 35)
config['width'] = prefered_size_w
config['height'] = prefered_size_h
return config

View File

@ -9,7 +9,7 @@ from datetime import datetime, timedelta
from PIL import Image
from telebot.types import CallbackQuery, Message
from skynet.frontend import validate_user_config_request
from skynet.frontend import validate_user_config_request, perform_auto_conf
from skynet.constants import *
@ -149,6 +149,10 @@ def create_handler_context(frontend: 'SkynetTelegramFrontend'):
user_config = {**user_row}
del user_config['id']
breakpoint()
if user_config['autoconf']:
user_config = perform_auto_conf(user_config)
params = {
'prompt': prompt,
**user_config
@ -209,12 +213,18 @@ def create_handler_context(frontend: 'SkynetTelegramFrontend'):
file_path = (await bot.get_file(file_id)).file_path
image_raw = await bot.download_file(file_path)
user_config = {**user_row}
del user_config['id']
if user_config['autoconf']:
user_config = perform_auto_conf(user_config)
with Image.open(io.BytesIO(image_raw)) as image:
w, h = image.size
if w > 512 or h > 512:
if w > user_config['width'] or h > user_config['height']:
logging.warning(f'user sent img of size {image.size}')
image.thumbnail((512, 512))
image.thumbnail(
(user_config['width'], user_config['height']))
logging.warning(f'resized it to {image.size}')
image_loc = 'ipfs-staging/image.png'
@ -228,9 +238,6 @@ def create_handler_context(frontend: 'SkynetTelegramFrontend'):
logging.info(f'mid: {message.id}')
user_config = {**user_row}
del user_config['id']
params = {
'prompt': prompt,
**user_config
@ -303,6 +310,8 @@ def create_handler_context(frontend: 'SkynetTelegramFrontend'):
'new_user_request', user.id, message.id, status_msg.id, status=init_msg)
user_config = {**user_row}
del user_config['id']
if user_config['autoconf']:
user_config = perform_auto_conf(user_config)
params = {
'prompt': prompt,