From aa1d52dba0f5c7a813b9e194c786ecb099c107e8 Mon Sep 17 00:00:00 2001 From: Guillermo Rodriguez Date: Sun, 8 Oct 2023 09:14:01 -0300 Subject: [PATCH] Add autoconfiguration feature for telegram frontend --- skynet/db/functions.py | 9 +++++++ skynet/frontend/__init__.py | 37 ++++++++++++++++++++++++++++ skynet/frontend/telegram/handlers.py | 21 +++++++++++----- 3 files changed, 61 insertions(+), 6 deletions(-) diff --git a/skynet/db/functions.py b/skynet/db/functions.py index f52703e..bc759d3 100644 --- a/skynet/db/functions.py +++ b/skynet/db/functions.py @@ -165,6 +165,15 @@ async def open_database_connection( else: await conn.execute(DB_INIT_SQL) + col_check = await conn.fetch(f''' + select column_name + from information_schema.columns + where table_name = 'user_config' and column_name = 'autoconf'; + ''') + + if not col_check: + await conn.execute('alter table skynet.user_config add column autoconf boolean;') + async def _db_call(method: str, *args, **kwargs): method = getattr(db, method) diff --git a/skynet/frontend/__init__.py b/skynet/frontend/__init__.py index 42145a3..30d6fa1 100644 --- a/skynet/frontend/__init__.py +++ b/skynet/frontend/__init__.py @@ -1,5 +1,7 @@ #!/usr/bin/python +import random + from ..constants import * @@ -15,10 +17,14 @@ class ConfigUnknownAlgorithm(BaseException): class ConfigUnknownUpscaler(BaseException): ... +class ConfigUnknownAutoConfSetting(BaseException): + ... + class ConfigSizeDivisionByEight(BaseException): ... + def validate_user_config_request(req: str): params = req.split(' ') @@ -78,6 +84,18 @@ def validate_user_config_request(req: str): raise ConfigUnknownUpscaler( f'\"{val}\" is not a valid upscaler') + case 'autoconf': + val = params[2] + if val == 'on': + val = True + + elif val == 'off': + val = False + + else: + raise ConfigUnknownAutoConfSetting( + f'\"{val}\" not a valid setting for autoconf') + case _: raise ConfigUnknownAttribute( f'\"{attr}\" not a configurable parameter') @@ -92,3 +110,22 @@ def validate_user_config_request(req: str): except ValueError: raise ValueError(f'\"{val}\" is not a number silly') + +def perform_auto_conf(config: dict) -> dict: + model = config['model'] + prefered_size_w = 512 + prefered_size_h = 512 + + if 'xl' in model: + prefered_size_w = 1024 + prefered_size_h = 1024 + + else: + prefered_size_w = 512 + prefered_size_h = 512 + + config['step'] = random.randint(20, 35) + config['width'] = prefered_size_w + config['height'] = prefered_size_h + + return config diff --git a/skynet/frontend/telegram/handlers.py b/skynet/frontend/telegram/handlers.py index da99941..b2019a7 100644 --- a/skynet/frontend/telegram/handlers.py +++ b/skynet/frontend/telegram/handlers.py @@ -9,7 +9,7 @@ from datetime import datetime, timedelta from PIL import Image from telebot.types import CallbackQuery, Message -from skynet.frontend import validate_user_config_request +from skynet.frontend import validate_user_config_request, perform_auto_conf from skynet.constants import * @@ -149,6 +149,10 @@ def create_handler_context(frontend: 'SkynetTelegramFrontend'): user_config = {**user_row} del user_config['id'] + breakpoint() + if user_config['autoconf']: + user_config = perform_auto_conf(user_config) + params = { 'prompt': prompt, **user_config @@ -209,12 +213,18 @@ def create_handler_context(frontend: 'SkynetTelegramFrontend'): file_path = (await bot.get_file(file_id)).file_path image_raw = await bot.download_file(file_path) + user_config = {**user_row} + del user_config['id'] + if user_config['autoconf']: + user_config = perform_auto_conf(user_config) + with Image.open(io.BytesIO(image_raw)) as image: w, h = image.size - if w > 512 or h > 512: + if w > user_config['width'] or h > user_config['height']: logging.warning(f'user sent img of size {image.size}') - image.thumbnail((512, 512)) + image.thumbnail( + (user_config['width'], user_config['height'])) logging.warning(f'resized it to {image.size}') image_loc = 'ipfs-staging/image.png' @@ -228,9 +238,6 @@ def create_handler_context(frontend: 'SkynetTelegramFrontend'): logging.info(f'mid: {message.id}') - user_config = {**user_row} - del user_config['id'] - params = { 'prompt': prompt, **user_config @@ -303,6 +310,8 @@ def create_handler_context(frontend: 'SkynetTelegramFrontend'): 'new_user_request', user.id, message.id, status_msg.id, status=init_msg) user_config = {**user_row} del user_config['id'] + if user_config['autoconf']: + user_config = perform_auto_conf(user_config) params = { 'prompt': prompt,