mirror of https://github.com/skygpu/skynet.git
Add autoconfiguration feature for telegram frontend
parent
d3b5d56187
commit
aa1d52dba0
|
@ -165,6 +165,15 @@ async def open_database_connection(
|
||||||
else:
|
else:
|
||||||
await conn.execute(DB_INIT_SQL)
|
await conn.execute(DB_INIT_SQL)
|
||||||
|
|
||||||
|
col_check = await conn.fetch(f'''
|
||||||
|
select column_name
|
||||||
|
from information_schema.columns
|
||||||
|
where table_name = 'user_config' and column_name = 'autoconf';
|
||||||
|
''')
|
||||||
|
|
||||||
|
if not col_check:
|
||||||
|
await conn.execute('alter table skynet.user_config add column autoconf boolean;')
|
||||||
|
|
||||||
async def _db_call(method: str, *args, **kwargs):
|
async def _db_call(method: str, *args, **kwargs):
|
||||||
method = getattr(db, method)
|
method = getattr(db, method)
|
||||||
|
|
||||||
|
|
|
@ -1,5 +1,7 @@
|
||||||
#!/usr/bin/python
|
#!/usr/bin/python
|
||||||
|
|
||||||
|
import random
|
||||||
|
|
||||||
from ..constants import *
|
from ..constants import *
|
||||||
|
|
||||||
|
|
||||||
|
@ -15,10 +17,14 @@ class ConfigUnknownAlgorithm(BaseException):
|
||||||
class ConfigUnknownUpscaler(BaseException):
|
class ConfigUnknownUpscaler(BaseException):
|
||||||
...
|
...
|
||||||
|
|
||||||
|
class ConfigUnknownAutoConfSetting(BaseException):
|
||||||
|
...
|
||||||
|
|
||||||
class ConfigSizeDivisionByEight(BaseException):
|
class ConfigSizeDivisionByEight(BaseException):
|
||||||
...
|
...
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def validate_user_config_request(req: str):
|
def validate_user_config_request(req: str):
|
||||||
params = req.split(' ')
|
params = req.split(' ')
|
||||||
|
|
||||||
|
@ -78,6 +84,18 @@ def validate_user_config_request(req: str):
|
||||||
raise ConfigUnknownUpscaler(
|
raise ConfigUnknownUpscaler(
|
||||||
f'\"{val}\" is not a valid upscaler')
|
f'\"{val}\" is not a valid upscaler')
|
||||||
|
|
||||||
|
case 'autoconf':
|
||||||
|
val = params[2]
|
||||||
|
if val == 'on':
|
||||||
|
val = True
|
||||||
|
|
||||||
|
elif val == 'off':
|
||||||
|
val = False
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise ConfigUnknownAutoConfSetting(
|
||||||
|
f'\"{val}\" not a valid setting for autoconf')
|
||||||
|
|
||||||
case _:
|
case _:
|
||||||
raise ConfigUnknownAttribute(
|
raise ConfigUnknownAttribute(
|
||||||
f'\"{attr}\" not a configurable parameter')
|
f'\"{attr}\" not a configurable parameter')
|
||||||
|
@ -92,3 +110,22 @@ def validate_user_config_request(req: str):
|
||||||
except ValueError:
|
except ValueError:
|
||||||
raise ValueError(f'\"{val}\" is not a number silly')
|
raise ValueError(f'\"{val}\" is not a number silly')
|
||||||
|
|
||||||
|
|
||||||
|
def perform_auto_conf(config: dict) -> dict:
|
||||||
|
model = config['model']
|
||||||
|
prefered_size_w = 512
|
||||||
|
prefered_size_h = 512
|
||||||
|
|
||||||
|
if 'xl' in model:
|
||||||
|
prefered_size_w = 1024
|
||||||
|
prefered_size_h = 1024
|
||||||
|
|
||||||
|
else:
|
||||||
|
prefered_size_w = 512
|
||||||
|
prefered_size_h = 512
|
||||||
|
|
||||||
|
config['step'] = random.randint(20, 35)
|
||||||
|
config['width'] = prefered_size_w
|
||||||
|
config['height'] = prefered_size_h
|
||||||
|
|
||||||
|
return config
|
||||||
|
|
|
@ -9,7 +9,7 @@ from datetime import datetime, timedelta
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from telebot.types import CallbackQuery, Message
|
from telebot.types import CallbackQuery, Message
|
||||||
|
|
||||||
from skynet.frontend import validate_user_config_request
|
from skynet.frontend import validate_user_config_request, perform_auto_conf
|
||||||
from skynet.constants import *
|
from skynet.constants import *
|
||||||
|
|
||||||
|
|
||||||
|
@ -149,6 +149,10 @@ def create_handler_context(frontend: 'SkynetTelegramFrontend'):
|
||||||
user_config = {**user_row}
|
user_config = {**user_row}
|
||||||
del user_config['id']
|
del user_config['id']
|
||||||
|
|
||||||
|
breakpoint()
|
||||||
|
if user_config['autoconf']:
|
||||||
|
user_config = perform_auto_conf(user_config)
|
||||||
|
|
||||||
params = {
|
params = {
|
||||||
'prompt': prompt,
|
'prompt': prompt,
|
||||||
**user_config
|
**user_config
|
||||||
|
@ -209,12 +213,18 @@ def create_handler_context(frontend: 'SkynetTelegramFrontend'):
|
||||||
file_path = (await bot.get_file(file_id)).file_path
|
file_path = (await bot.get_file(file_id)).file_path
|
||||||
image_raw = await bot.download_file(file_path)
|
image_raw = await bot.download_file(file_path)
|
||||||
|
|
||||||
|
user_config = {**user_row}
|
||||||
|
del user_config['id']
|
||||||
|
if user_config['autoconf']:
|
||||||
|
user_config = perform_auto_conf(user_config)
|
||||||
|
|
||||||
with Image.open(io.BytesIO(image_raw)) as image:
|
with Image.open(io.BytesIO(image_raw)) as image:
|
||||||
w, h = image.size
|
w, h = image.size
|
||||||
|
|
||||||
if w > 512 or h > 512:
|
if w > user_config['width'] or h > user_config['height']:
|
||||||
logging.warning(f'user sent img of size {image.size}')
|
logging.warning(f'user sent img of size {image.size}')
|
||||||
image.thumbnail((512, 512))
|
image.thumbnail(
|
||||||
|
(user_config['width'], user_config['height']))
|
||||||
logging.warning(f'resized it to {image.size}')
|
logging.warning(f'resized it to {image.size}')
|
||||||
|
|
||||||
image_loc = 'ipfs-staging/image.png'
|
image_loc = 'ipfs-staging/image.png'
|
||||||
|
@ -228,9 +238,6 @@ def create_handler_context(frontend: 'SkynetTelegramFrontend'):
|
||||||
|
|
||||||
logging.info(f'mid: {message.id}')
|
logging.info(f'mid: {message.id}')
|
||||||
|
|
||||||
user_config = {**user_row}
|
|
||||||
del user_config['id']
|
|
||||||
|
|
||||||
params = {
|
params = {
|
||||||
'prompt': prompt,
|
'prompt': prompt,
|
||||||
**user_config
|
**user_config
|
||||||
|
@ -303,6 +310,8 @@ def create_handler_context(frontend: 'SkynetTelegramFrontend'):
|
||||||
'new_user_request', user.id, message.id, status_msg.id, status=init_msg)
|
'new_user_request', user.id, message.id, status_msg.id, status=init_msg)
|
||||||
user_config = {**user_row}
|
user_config = {**user_row}
|
||||||
del user_config['id']
|
del user_config['id']
|
||||||
|
if user_config['autoconf']:
|
||||||
|
user_config = perform_auto_conf(user_config)
|
||||||
|
|
||||||
params = {
|
params = {
|
||||||
'prompt': prompt,
|
'prompt': prompt,
|
||||||
|
|
Loading…
Reference in New Issue