Add autoconfiguration feature for telegram frontend

pull/26/head
Guillermo Rodriguez 2023-10-08 09:14:01 -03:00
parent d3b5d56187
commit aa1d52dba0
No known key found for this signature in database
GPG Key ID: EC3AB66D5D83B392
3 changed files with 61 additions and 6 deletions

View File

@ -165,6 +165,15 @@ async def open_database_connection(
else: else:
await conn.execute(DB_INIT_SQL) await conn.execute(DB_INIT_SQL)
col_check = await conn.fetch(f'''
select column_name
from information_schema.columns
where table_name = 'user_config' and column_name = 'autoconf';
''')
if not col_check:
await conn.execute('alter table skynet.user_config add column autoconf boolean;')
async def _db_call(method: str, *args, **kwargs): async def _db_call(method: str, *args, **kwargs):
method = getattr(db, method) method = getattr(db, method)

View File

@ -1,5 +1,7 @@
#!/usr/bin/python #!/usr/bin/python
import random
from ..constants import * from ..constants import *
@ -15,10 +17,14 @@ class ConfigUnknownAlgorithm(BaseException):
class ConfigUnknownUpscaler(BaseException): class ConfigUnknownUpscaler(BaseException):
... ...
class ConfigUnknownAutoConfSetting(BaseException):
...
class ConfigSizeDivisionByEight(BaseException): class ConfigSizeDivisionByEight(BaseException):
... ...
def validate_user_config_request(req: str): def validate_user_config_request(req: str):
params = req.split(' ') params = req.split(' ')
@ -78,6 +84,18 @@ def validate_user_config_request(req: str):
raise ConfigUnknownUpscaler( raise ConfigUnknownUpscaler(
f'\"{val}\" is not a valid upscaler') f'\"{val}\" is not a valid upscaler')
case 'autoconf':
val = params[2]
if val == 'on':
val = True
elif val == 'off':
val = False
else:
raise ConfigUnknownAutoConfSetting(
f'\"{val}\" not a valid setting for autoconf')
case _: case _:
raise ConfigUnknownAttribute( raise ConfigUnknownAttribute(
f'\"{attr}\" not a configurable parameter') f'\"{attr}\" not a configurable parameter')
@ -92,3 +110,22 @@ def validate_user_config_request(req: str):
except ValueError: except ValueError:
raise ValueError(f'\"{val}\" is not a number silly') raise ValueError(f'\"{val}\" is not a number silly')
def perform_auto_conf(config: dict) -> dict:
model = config['model']
prefered_size_w = 512
prefered_size_h = 512
if 'xl' in model:
prefered_size_w = 1024
prefered_size_h = 1024
else:
prefered_size_w = 512
prefered_size_h = 512
config['step'] = random.randint(20, 35)
config['width'] = prefered_size_w
config['height'] = prefered_size_h
return config

View File

@ -9,7 +9,7 @@ from datetime import datetime, timedelta
from PIL import Image from PIL import Image
from telebot.types import CallbackQuery, Message from telebot.types import CallbackQuery, Message
from skynet.frontend import validate_user_config_request from skynet.frontend import validate_user_config_request, perform_auto_conf
from skynet.constants import * from skynet.constants import *
@ -149,6 +149,10 @@ def create_handler_context(frontend: 'SkynetTelegramFrontend'):
user_config = {**user_row} user_config = {**user_row}
del user_config['id'] del user_config['id']
breakpoint()
if user_config['autoconf']:
user_config = perform_auto_conf(user_config)
params = { params = {
'prompt': prompt, 'prompt': prompt,
**user_config **user_config
@ -209,12 +213,18 @@ def create_handler_context(frontend: 'SkynetTelegramFrontend'):
file_path = (await bot.get_file(file_id)).file_path file_path = (await bot.get_file(file_id)).file_path
image_raw = await bot.download_file(file_path) image_raw = await bot.download_file(file_path)
user_config = {**user_row}
del user_config['id']
if user_config['autoconf']:
user_config = perform_auto_conf(user_config)
with Image.open(io.BytesIO(image_raw)) as image: with Image.open(io.BytesIO(image_raw)) as image:
w, h = image.size w, h = image.size
if w > 512 or h > 512: if w > user_config['width'] or h > user_config['height']:
logging.warning(f'user sent img of size {image.size}') logging.warning(f'user sent img of size {image.size}')
image.thumbnail((512, 512)) image.thumbnail(
(user_config['width'], user_config['height']))
logging.warning(f'resized it to {image.size}') logging.warning(f'resized it to {image.size}')
image_loc = 'ipfs-staging/image.png' image_loc = 'ipfs-staging/image.png'
@ -228,9 +238,6 @@ def create_handler_context(frontend: 'SkynetTelegramFrontend'):
logging.info(f'mid: {message.id}') logging.info(f'mid: {message.id}')
user_config = {**user_row}
del user_config['id']
params = { params = {
'prompt': prompt, 'prompt': prompt,
**user_config **user_config
@ -303,6 +310,8 @@ def create_handler_context(frontend: 'SkynetTelegramFrontend'):
'new_user_request', user.id, message.id, status_msg.id, status=init_msg) 'new_user_request', user.id, message.id, status_msg.id, status=init_msg)
user_config = {**user_row} user_config = {**user_row}
del user_config['id'] del user_config['id']
if user_config['autoconf']:
user_config = perform_auto_conf(user_config)
params = { params = {
'prompt': prompt, 'prompt': prompt,