mirror of https://github.com/skygpu/skynet.git
Add autoconfiguration feature for telegram frontend
parent
d3b5d56187
commit
aa1d52dba0
|
@ -165,6 +165,15 @@ async def open_database_connection(
|
|||
else:
|
||||
await conn.execute(DB_INIT_SQL)
|
||||
|
||||
col_check = await conn.fetch(f'''
|
||||
select column_name
|
||||
from information_schema.columns
|
||||
where table_name = 'user_config' and column_name = 'autoconf';
|
||||
''')
|
||||
|
||||
if not col_check:
|
||||
await conn.execute('alter table skynet.user_config add column autoconf boolean;')
|
||||
|
||||
async def _db_call(method: str, *args, **kwargs):
|
||||
method = getattr(db, method)
|
||||
|
||||
|
|
|
@ -1,5 +1,7 @@
|
|||
#!/usr/bin/python
|
||||
|
||||
import random
|
||||
|
||||
from ..constants import *
|
||||
|
||||
|
||||
|
@ -15,10 +17,14 @@ class ConfigUnknownAlgorithm(BaseException):
|
|||
class ConfigUnknownUpscaler(BaseException):
|
||||
...
|
||||
|
||||
class ConfigUnknownAutoConfSetting(BaseException):
|
||||
...
|
||||
|
||||
class ConfigSizeDivisionByEight(BaseException):
|
||||
...
|
||||
|
||||
|
||||
|
||||
def validate_user_config_request(req: str):
|
||||
params = req.split(' ')
|
||||
|
||||
|
@ -78,6 +84,18 @@ def validate_user_config_request(req: str):
|
|||
raise ConfigUnknownUpscaler(
|
||||
f'\"{val}\" is not a valid upscaler')
|
||||
|
||||
case 'autoconf':
|
||||
val = params[2]
|
||||
if val == 'on':
|
||||
val = True
|
||||
|
||||
elif val == 'off':
|
||||
val = False
|
||||
|
||||
else:
|
||||
raise ConfigUnknownAutoConfSetting(
|
||||
f'\"{val}\" not a valid setting for autoconf')
|
||||
|
||||
case _:
|
||||
raise ConfigUnknownAttribute(
|
||||
f'\"{attr}\" not a configurable parameter')
|
||||
|
@ -92,3 +110,22 @@ def validate_user_config_request(req: str):
|
|||
except ValueError:
|
||||
raise ValueError(f'\"{val}\" is not a number silly')
|
||||
|
||||
|
||||
def perform_auto_conf(config: dict) -> dict:
|
||||
model = config['model']
|
||||
prefered_size_w = 512
|
||||
prefered_size_h = 512
|
||||
|
||||
if 'xl' in model:
|
||||
prefered_size_w = 1024
|
||||
prefered_size_h = 1024
|
||||
|
||||
else:
|
||||
prefered_size_w = 512
|
||||
prefered_size_h = 512
|
||||
|
||||
config['step'] = random.randint(20, 35)
|
||||
config['width'] = prefered_size_w
|
||||
config['height'] = prefered_size_h
|
||||
|
||||
return config
|
||||
|
|
|
@ -9,7 +9,7 @@ from datetime import datetime, timedelta
|
|||
from PIL import Image
|
||||
from telebot.types import CallbackQuery, Message
|
||||
|
||||
from skynet.frontend import validate_user_config_request
|
||||
from skynet.frontend import validate_user_config_request, perform_auto_conf
|
||||
from skynet.constants import *
|
||||
|
||||
|
||||
|
@ -149,6 +149,10 @@ def create_handler_context(frontend: 'SkynetTelegramFrontend'):
|
|||
user_config = {**user_row}
|
||||
del user_config['id']
|
||||
|
||||
breakpoint()
|
||||
if user_config['autoconf']:
|
||||
user_config = perform_auto_conf(user_config)
|
||||
|
||||
params = {
|
||||
'prompt': prompt,
|
||||
**user_config
|
||||
|
@ -209,12 +213,18 @@ def create_handler_context(frontend: 'SkynetTelegramFrontend'):
|
|||
file_path = (await bot.get_file(file_id)).file_path
|
||||
image_raw = await bot.download_file(file_path)
|
||||
|
||||
user_config = {**user_row}
|
||||
del user_config['id']
|
||||
if user_config['autoconf']:
|
||||
user_config = perform_auto_conf(user_config)
|
||||
|
||||
with Image.open(io.BytesIO(image_raw)) as image:
|
||||
w, h = image.size
|
||||
|
||||
if w > 512 or h > 512:
|
||||
if w > user_config['width'] or h > user_config['height']:
|
||||
logging.warning(f'user sent img of size {image.size}')
|
||||
image.thumbnail((512, 512))
|
||||
image.thumbnail(
|
||||
(user_config['width'], user_config['height']))
|
||||
logging.warning(f'resized it to {image.size}')
|
||||
|
||||
image_loc = 'ipfs-staging/image.png'
|
||||
|
@ -228,9 +238,6 @@ def create_handler_context(frontend: 'SkynetTelegramFrontend'):
|
|||
|
||||
logging.info(f'mid: {message.id}')
|
||||
|
||||
user_config = {**user_row}
|
||||
del user_config['id']
|
||||
|
||||
params = {
|
||||
'prompt': prompt,
|
||||
**user_config
|
||||
|
@ -303,6 +310,8 @@ def create_handler_context(frontend: 'SkynetTelegramFrontend'):
|
|||
'new_user_request', user.id, message.id, status_msg.id, status=init_msg)
|
||||
user_config = {**user_row}
|
||||
del user_config['id']
|
||||
if user_config['autoconf']:
|
||||
user_config = perform_auto_conf(user_config)
|
||||
|
||||
params = {
|
||||
'prompt': prompt,
|
||||
|
|
Loading…
Reference in New Issue