mirror of https://github.com/skygpu/skynet.git
Finish telegram frontend integration
Correctly update user stats after a txt2img request Swap tg_id to BIGINT on database Refactor help topic system to be more generic Fix skynet brain cli entry point Fix multi request rpc session, we werent creating a context on the req side Fix redo method DGPU now returns image metadata Improve error messages brain to frontend Add version as constant Add telegram dep to requirementspull/2/head
parent
896b0f684b
commit
cb92aed51c
|
@ -5,3 +5,4 @@ aiohttp
|
||||||
msgspec
|
msgspec
|
||||||
pyOpenSSL
|
pyOpenSSL
|
||||||
trio_asyncio
|
trio_asyncio
|
||||||
|
pyTelegramBotAPI
|
||||||
|
|
4
setup.py
4
setup.py
|
@ -1,8 +1,10 @@
|
||||||
from setuptools import setup, find_packages
|
from setuptools import setup, find_packages
|
||||||
|
|
||||||
|
from skynet.constants import VERSION
|
||||||
|
|
||||||
setup(
|
setup(
|
||||||
name='skynet',
|
name='skynet',
|
||||||
version='0.1.0a6',
|
version=VERSION,
|
||||||
description='Decentralized compute platform',
|
description='Decentralized compute platform',
|
||||||
author='Guillermo Rodriguez',
|
author='Guillermo Rodriguez',
|
||||||
author_email='guillermo@telos.net',
|
author_email='guillermo@telos.net',
|
||||||
|
|
|
@ -4,6 +4,7 @@ import json
|
||||||
import uuid
|
import uuid
|
||||||
import base64
|
import base64
|
||||||
import logging
|
import logging
|
||||||
|
import traceback
|
||||||
|
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
@ -90,10 +91,10 @@ async def open_rpc_service(sock, dgpu_bus, db_pool, tls_whitelist, tls_key):
|
||||||
logging.info(f'pre next_worker: {next_worker}')
|
logging.info(f'pre next_worker: {next_worker}')
|
||||||
|
|
||||||
if next_worker == None:
|
if next_worker == None:
|
||||||
raise SkynetDGPUOffline
|
raise SkynetDGPUOffline('No workers connected, try again later')
|
||||||
|
|
||||||
if are_all_workers_busy():
|
if are_all_workers_busy():
|
||||||
raise SkynetDGPUOverloaded
|
raise SkynetDGPUOverloaded('All workers are busy at the moment')
|
||||||
|
|
||||||
|
|
||||||
nid = list(nodes.keys())[next_worker]
|
nid = list(nodes.keys())[next_worker]
|
||||||
|
@ -175,6 +176,8 @@ async def open_rpc_service(sock, dgpu_bus, db_pool, tls_whitelist, tls_key):
|
||||||
with trio.move_on_after(30):
|
with trio.move_on_after(30):
|
||||||
await img_event.wait()
|
await img_event.wait()
|
||||||
|
|
||||||
|
logging.info(f'img event: {ack_event.is_set()}')
|
||||||
|
|
||||||
if not img_event.is_set():
|
if not img_event.is_set():
|
||||||
disconnect_node(nid)
|
disconnect_node(nid)
|
||||||
raise SkynetDGPUComputeError('30 seconds timeout while processing request')
|
raise SkynetDGPUComputeError('30 seconds timeout while processing request')
|
||||||
|
@ -187,7 +190,7 @@ async def open_rpc_service(sock, dgpu_bus, db_pool, tls_whitelist, tls_key):
|
||||||
if 'error' in img_resp.params:
|
if 'error' in img_resp.params:
|
||||||
raise SkynetDGPUComputeError(img_resp.params['error'])
|
raise SkynetDGPUComputeError(img_resp.params['error'])
|
||||||
|
|
||||||
return rid, img_resp.params['img']
|
return rid, img_resp.params['img'], img_resp.params['meta']
|
||||||
|
|
||||||
async def handle_user_request(rpc_ctx, req):
|
async def handle_user_request(rpc_ctx, req):
|
||||||
try:
|
try:
|
||||||
|
@ -202,39 +205,54 @@ async def open_rpc_service(sock, dgpu_bus, db_pool, tls_whitelist, tls_key):
|
||||||
user_config = {**(await get_user_config(conn, user))}
|
user_config = {**(await get_user_config(conn, user))}
|
||||||
del user_config['id']
|
del user_config['id']
|
||||||
prompt = req.params['prompt']
|
prompt = req.params['prompt']
|
||||||
user_config= {
|
|
||||||
key : req.params.get(key, val)
|
|
||||||
for key, val in user_config.items()
|
|
||||||
}
|
|
||||||
req = ImageGenRequest(
|
req = ImageGenRequest(
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
**user_config
|
**user_config
|
||||||
)
|
)
|
||||||
rid, img = await dgpu_stream_one_img(req)
|
rid, img, meta = await dgpu_stream_one_img(req)
|
||||||
|
logging.info(f'done streaming {rid}')
|
||||||
result = {
|
result = {
|
||||||
'id': rid,
|
'id': rid,
|
||||||
'img': img
|
'img': img,
|
||||||
|
'meta': meta
|
||||||
}
|
}
|
||||||
|
|
||||||
|
await update_user_stats(conn, user, last_prompt=prompt)
|
||||||
|
logging.info('updated user stats.')
|
||||||
|
|
||||||
case 'redo':
|
case 'redo':
|
||||||
logging.info('redo')
|
logging.info('redo')
|
||||||
user_config = await get_user_config(conn, user)
|
user_config = {**(await get_user_config(conn, user))}
|
||||||
|
del user_config['id']
|
||||||
prompt = await get_last_prompt_of(conn, user)
|
prompt = await get_last_prompt_of(conn, user)
|
||||||
|
|
||||||
|
if prompt:
|
||||||
req = ImageGenRequest(
|
req = ImageGenRequest(
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
**user_config
|
**user_config
|
||||||
)
|
)
|
||||||
rid, img = await dgpu_stream_one_img(req)
|
rid, img, meta = await dgpu_stream_one_img(req)
|
||||||
result = {
|
result = {
|
||||||
'id': rid,
|
'id': rid,
|
||||||
'img': img
|
'img': img,
|
||||||
|
'meta': meta
|
||||||
|
}
|
||||||
|
await update_user_stats(conn, user)
|
||||||
|
logging.info('updated user stats.')
|
||||||
|
|
||||||
|
else:
|
||||||
|
result = {
|
||||||
|
'error': 'skynet_no_last_prompt',
|
||||||
|
'message': 'No prompt to redo, do txt2img first'
|
||||||
}
|
}
|
||||||
|
|
||||||
case 'config':
|
case 'config':
|
||||||
logging.info('config')
|
logging.info('config')
|
||||||
if req.params['attr'] in CONFIG_ATTRS:
|
if req.params['attr'] in CONFIG_ATTRS:
|
||||||
|
logging.info(f'update: {req.params}')
|
||||||
await update_user_config(
|
await update_user_config(
|
||||||
conn, user, req.params['attr'], req.params['val'])
|
conn, user, req.params['attr'], req.params['val'])
|
||||||
|
logging.info('done')
|
||||||
|
|
||||||
case 'stats':
|
case 'stats':
|
||||||
logging.info('stats')
|
logging.info('stats')
|
||||||
|
@ -255,9 +273,10 @@ async def open_rpc_service(sock, dgpu_bus, db_pool, tls_whitelist, tls_key):
|
||||||
'message': str(e)
|
'message': str(e)
|
||||||
}
|
}
|
||||||
|
|
||||||
except SkynetDGPUOverloaded:
|
except SkynetDGPUOverloaded as e:
|
||||||
result = {
|
result = {
|
||||||
'error': 'skynet_dgpu_overloaded',
|
'error': 'skynet_dgpu_overloaded',
|
||||||
|
'message': str(e),
|
||||||
'nodes': len(nodes)
|
'nodes': len(nodes)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -266,14 +285,23 @@ async def open_rpc_service(sock, dgpu_bus, db_pool, tls_whitelist, tls_key):
|
||||||
'error': 'skynet_dgpu_compute_error',
|
'error': 'skynet_dgpu_compute_error',
|
||||||
'message': str(e)
|
'message': str(e)
|
||||||
}
|
}
|
||||||
|
except BaseException as e:
|
||||||
|
traceback.print_exception(type(e), e, e.__traceback__)
|
||||||
|
result = {
|
||||||
|
'error': 'skynet_internal_error',
|
||||||
|
'message': str(e)
|
||||||
|
}
|
||||||
|
|
||||||
resp = SkynetRPCResponse(result=result)
|
resp = SkynetRPCResponse(result=result)
|
||||||
|
|
||||||
if security:
|
if security:
|
||||||
resp.sign(tls_key, 'skynet')
|
resp.sign(tls_key, 'skynet')
|
||||||
|
|
||||||
|
logging.info('sending response')
|
||||||
await rpc_ctx.asend(
|
await rpc_ctx.asend(
|
||||||
json.dumps(resp.to_dict()).encode())
|
json.dumps(resp.to_dict()).encode())
|
||||||
|
rpc_ctx.close()
|
||||||
|
logging.info('done')
|
||||||
|
|
||||||
async def request_service(n):
|
async def request_service(n):
|
||||||
nonlocal next_worker
|
nonlocal next_worker
|
||||||
|
@ -329,6 +357,8 @@ async def open_rpc_service(sock, dgpu_bus, db_pool, tls_whitelist, tls_key):
|
||||||
await ctx.asend(
|
await ctx.asend(
|
||||||
json.dumps(resp.to_dict()).encode())
|
json.dumps(resp.to_dict()).encode())
|
||||||
|
|
||||||
|
ctx.close()
|
||||||
|
|
||||||
|
|
||||||
async with trio.open_nursery() as n:
|
async with trio.open_nursery() as n:
|
||||||
n.start_soon(dgpu_image_streamer)
|
n.start_soon(dgpu_image_streamer)
|
||||||
|
|
|
@ -8,10 +8,12 @@ from functools import partial
|
||||||
|
|
||||||
import trio
|
import trio
|
||||||
import click
|
import click
|
||||||
|
import trio_asyncio
|
||||||
|
|
||||||
from . import utils
|
from . import utils
|
||||||
from .dgpu import open_dgpu_node
|
from .dgpu import open_dgpu_node
|
||||||
from .brain import run_skynet
|
from .brain import run_skynet
|
||||||
|
from .constants import ALGOS
|
||||||
|
|
||||||
from .frontend.telegram import run_skynet_telegram
|
from .frontend.telegram import run_skynet_telegram
|
||||||
|
|
||||||
|
@ -61,16 +63,16 @@ def run(*args, **kwargs):
|
||||||
@click.option(
|
@click.option(
|
||||||
'--host', '-h', default='localhost:5432')
|
'--host', '-h', default='localhost:5432')
|
||||||
@click.option(
|
@click.option(
|
||||||
'--pass', '-p', default='password')
|
'--passwd', '-p', default='password')
|
||||||
def skynet(
|
def brain(
|
||||||
loglevel: str,
|
loglevel: str,
|
||||||
host: str,
|
host: str,
|
||||||
passw: str
|
passwd: str
|
||||||
):
|
):
|
||||||
async def _run_skynet():
|
async def _run_skynet():
|
||||||
async with run_skynet(
|
async with run_skynet(
|
||||||
db_host=host,
|
db_host=host,
|
||||||
db_pass=passw
|
db_pass=passwd
|
||||||
):
|
):
|
||||||
await trio.sleep_forever()
|
await trio.sleep_forever()
|
||||||
|
|
||||||
|
@ -86,13 +88,13 @@ def skynet(
|
||||||
@click.option(
|
@click.option(
|
||||||
'--cert', '-c', default='whitelist/dgpu')
|
'--cert', '-c', default='whitelist/dgpu')
|
||||||
@click.option(
|
@click.option(
|
||||||
'--algos', '-a', default=None)
|
'--algos', '-a', default=json.dumps(['midj']))
|
||||||
def dgpu(
|
def dgpu(
|
||||||
loglevel: str,
|
loglevel: str,
|
||||||
uid: str,
|
uid: str,
|
||||||
key: str,
|
key: str,
|
||||||
cert: str,
|
cert: str,
|
||||||
algos: Optional[str]
|
algos: str
|
||||||
):
|
):
|
||||||
trio.run(
|
trio.run(
|
||||||
partial(
|
partial(
|
||||||
|
|
|
@ -1,8 +1,10 @@
|
||||||
#!/usr/bin/python
|
#!/usr/bin/python
|
||||||
|
|
||||||
|
VERSION = '0.1a6'
|
||||||
|
|
||||||
DOCKER_RUNTIME_CUDA = 'skynet:runtime-cuda'
|
DOCKER_RUNTIME_CUDA = 'skynet:runtime-cuda'
|
||||||
|
|
||||||
DB_HOST = 'ancap.tech:34508'
|
DB_HOST = 'localhost:5432'
|
||||||
DB_USER = 'skynet'
|
DB_USER = 'skynet'
|
||||||
DB_PASS = 'password'
|
DB_PASS = 'password'
|
||||||
DB_NAME = 'skynet'
|
DB_NAME = 'skynet'
|
||||||
|
@ -21,7 +23,7 @@ ALGOS = {
|
||||||
|
|
||||||
N = '\n'
|
N = '\n'
|
||||||
HELP_TEXT = f'''
|
HELP_TEXT = f'''
|
||||||
test art bot v0.1a4
|
test art bot v{VERSION}
|
||||||
|
|
||||||
commands work on a user per user basis!
|
commands work on a user per user basis!
|
||||||
config is individual to each user!
|
config is individual to each user!
|
||||||
|
@ -47,7 +49,7 @@ config is individual to each user!
|
||||||
/config guidance NUMBER - prompt text importance
|
/config guidance NUMBER - prompt text importance
|
||||||
'''
|
'''
|
||||||
|
|
||||||
UNKNOWN_CMD_TEXT = 'unknown command! try sending \"/help\"'
|
UNKNOWN_CMD_TEXT = 'Unknown command! Try sending \"/help\"'
|
||||||
|
|
||||||
DONATION_INFO = '0xf95335682DF281FFaB7E104EB87B69625d9622B6\ngoal: 25/650usd'
|
DONATION_INFO = '0xf95335682DF281FFaB7E104EB87B69625d9622B6\ngoal: 25/650usd'
|
||||||
|
|
||||||
|
@ -74,22 +76,24 @@ COOL_WORDS = [
|
||||||
'michelangelo'
|
'michelangelo'
|
||||||
]
|
]
|
||||||
|
|
||||||
HELP_STEP = '''
|
HELP_TOPICS = {
|
||||||
diffusion models are iterative processes – a repeated cycle that starts with a\
|
'step': '''
|
||||||
|
Diffusion models are iterative processes – a repeated cycle that starts with a\
|
||||||
random noise generated from text input. With each step, some noise is removed\
|
random noise generated from text input. With each step, some noise is removed\
|
||||||
, resulting in a higher-quality image over time. The repetition stops when the\
|
, resulting in a higher-quality image over time. The repetition stops when the\
|
||||||
desired number of steps completes.
|
desired number of steps completes.
|
||||||
|
|
||||||
around 25 sampling steps are usually enough to achieve high-quality images. Us\
|
Around 25 sampling steps are usually enough to achieve high-quality images. Us\
|
||||||
ing more may produce a slightly different picture, but not necessarily better \
|
ing more may produce a slightly different picture, but not necessarily better \
|
||||||
quality.
|
quality.
|
||||||
'''
|
''',
|
||||||
|
|
||||||
HELP_GUIDANCE = '''
|
'guidance': '''
|
||||||
the guidance scale is a parameter that controls how much the image generation\
|
The guidance scale is a parameter that controls how much the image generation\
|
||||||
process follows the text prompt. The higher the value, the more image sticks\
|
process follows the text prompt. The higher the value, the more image sticks\
|
||||||
to a given text input.
|
to a given text input.
|
||||||
'''
|
'''
|
||||||
|
}
|
||||||
|
|
||||||
HELP_UNKWNOWN_PARAM = 'don\'t have any info on that.'
|
HELP_UNKWNOWN_PARAM = 'don\'t have any info on that.'
|
||||||
|
|
||||||
|
|
29
skynet/db.py
29
skynet/db.py
|
@ -2,6 +2,7 @@
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from contextlib import asynccontextmanager as acm
|
from contextlib import asynccontextmanager as acm
|
||||||
|
|
||||||
|
@ -19,7 +20,7 @@ CREATE SCHEMA IF NOT EXISTS skynet;
|
||||||
|
|
||||||
CREATE TABLE IF NOT EXISTS skynet.user(
|
CREATE TABLE IF NOT EXISTS skynet.user(
|
||||||
id SERIAL PRIMARY KEY NOT NULL,
|
id SERIAL PRIMARY KEY NOT NULL,
|
||||||
tg_id INT,
|
tg_id BIGINT,
|
||||||
wp_id VARCHAR(128),
|
wp_id VARCHAR(128),
|
||||||
mx_id VARCHAR(128),
|
mx_id VARCHAR(128),
|
||||||
ig_id VARCHAR(128),
|
ig_id VARCHAR(128),
|
||||||
|
@ -47,7 +48,7 @@ CREATE TABLE IF NOT EXISTS skynet.user_config(
|
||||||
step INT NOT NULL,
|
step INT NOT NULL,
|
||||||
width INT NOT NULL,
|
width INT NOT NULL,
|
||||||
height INT NOT NULL,
|
height INT NOT NULL,
|
||||||
seed INT,
|
seed BIGINT,
|
||||||
guidance INT NOT NULL,
|
guidance INT NOT NULL,
|
||||||
upscaler VARCHAR(128)
|
upscaler VARCHAR(128)
|
||||||
);
|
);
|
||||||
|
@ -124,7 +125,7 @@ async def get_user_config(conn, user: int):
|
||||||
|
|
||||||
|
|
||||||
async def get_last_prompt_of(conn, user: int):
|
async def get_last_prompt_of(conn, user: int):
|
||||||
stms = await conn.prepare(
|
stmt = await conn.prepare(
|
||||||
'SELECT last_prompt FROM skynet.user WHERE id = $1')
|
'SELECT last_prompt FROM skynet.user WHERE id = $1')
|
||||||
return await stmt.fetchval(user)
|
return await stmt.fetchval(user)
|
||||||
|
|
||||||
|
@ -198,7 +199,12 @@ async def get_or_create_user(conn, uid: str):
|
||||||
return user
|
return user
|
||||||
|
|
||||||
async def update_user(conn, user: int, attr: str, val):
|
async def update_user(conn, user: int, attr: str, val):
|
||||||
...
|
stmt = await conn.prepare(f'''
|
||||||
|
UPDATE skynet.user
|
||||||
|
SET {attr} = $2
|
||||||
|
WHERE id = $1
|
||||||
|
''')
|
||||||
|
await stmt.fetch(user, val)
|
||||||
|
|
||||||
async def update_user_config(conn, user: int, attr: str, val):
|
async def update_user_config(conn, user: int, attr: str, val):
|
||||||
stmt = await conn.prepare(f'''
|
stmt = await conn.prepare(f'''
|
||||||
|
@ -218,3 +224,18 @@ async def get_user_stats(conn, user: int):
|
||||||
assert len(records) == 1
|
assert len(records) == 1
|
||||||
record = records[0]
|
record = records[0]
|
||||||
return record
|
return record
|
||||||
|
|
||||||
|
async def update_user_stats(
|
||||||
|
conn,
|
||||||
|
user: int,
|
||||||
|
last_prompt: Optional[str] = None
|
||||||
|
):
|
||||||
|
stmt = await conn.prepare('''
|
||||||
|
UPDATE skynet.user
|
||||||
|
SET generated = generated + 1
|
||||||
|
WHERE id = $1
|
||||||
|
''')
|
||||||
|
await stmt.fetch(user)
|
||||||
|
|
||||||
|
if last_prompt:
|
||||||
|
await update_user(conn, user, 'last_prompt', last_prompt)
|
||||||
|
|
|
@ -109,7 +109,6 @@ async def open_dgpu_node(
|
||||||
'generated': 0
|
'generated': 0
|
||||||
}
|
}
|
||||||
|
|
||||||
seed = ireq.seed if ireq.seed else random.randint(0, 2 ** 64)
|
|
||||||
try:
|
try:
|
||||||
image = models[ireq.algo]['pipe'](
|
image = models[ireq.algo]['pipe'](
|
||||||
ireq.prompt,
|
ireq.prompt,
|
||||||
|
@ -117,7 +116,7 @@ async def open_dgpu_node(
|
||||||
height=ireq.height,
|
height=ireq.height,
|
||||||
guidance_scale=ireq.guidance,
|
guidance_scale=ireq.guidance,
|
||||||
num_inference_steps=ireq.step,
|
num_inference_steps=ireq.step,
|
||||||
generator=torch.Generator("cuda").manual_seed(seed)
|
generator=torch.Generator("cuda").manual_seed(ireq.seed)
|
||||||
).images[0]
|
).images[0]
|
||||||
return image.tobytes()
|
return image.tobytes()
|
||||||
|
|
||||||
|
@ -207,12 +206,18 @@ async def open_dgpu_node(
|
||||||
logging.info(f'sent ack, processing {req.rid}...')
|
logging.info(f'sent ack, processing {req.rid}...')
|
||||||
|
|
||||||
try:
|
try:
|
||||||
img = await gpu_compute_one(
|
img_req = ImageGenRequest(**req.params)
|
||||||
ImageGenRequest(**req.params))
|
if not img_req.seed:
|
||||||
|
img_req.seed = random.randint(0, 2 ** 64)
|
||||||
|
|
||||||
|
img = await gpu_compute_one(img_req)
|
||||||
img_resp = DGPUBusResponse(
|
img_resp = DGPUBusResponse(
|
||||||
rid=req.rid,
|
rid=req.rid,
|
||||||
nid=req.nid,
|
nid=req.nid,
|
||||||
params={'img': base64.b64encode(img).hex()}
|
params={
|
||||||
|
'img': base64.b64encode(img).hex(),
|
||||||
|
'meta': img_req.to_dict()
|
||||||
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
except DGPUComputeError as e:
|
except DGPUComputeError as e:
|
||||||
|
|
|
@ -90,13 +90,14 @@ async def open_skynet_rpc(
|
||||||
if security:
|
if security:
|
||||||
req.sign(tls_key, cert_name)
|
req.sign(tls_key, cert_name)
|
||||||
|
|
||||||
await sock.asend(
|
ctx = sock.new_context()
|
||||||
|
await ctx.asend(
|
||||||
json.dumps(
|
json.dumps(
|
||||||
req.to_dict()).encode())
|
req.to_dict()).encode())
|
||||||
|
|
||||||
resp = SkynetRPCResponse(
|
resp = SkynetRPCResponse(
|
||||||
**json.loads(
|
**json.loads((await ctx.arecv()).decode()))
|
||||||
(await sock.arecv_msg()).bytes.decode()))
|
ctx.close()
|
||||||
|
|
||||||
if security:
|
if security:
|
||||||
resp.verify(skynet_cert)
|
resp.verify(skynet_cert)
|
||||||
|
|
|
@ -1,14 +1,19 @@
|
||||||
#!/usr/bin/python
|
#!/usr/bin/python
|
||||||
|
|
||||||
|
import io
|
||||||
|
import base64
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
import pynng
|
import pynng
|
||||||
|
|
||||||
from telebot.async_telebot import AsyncTeleBot
|
from PIL import Image
|
||||||
from trio_asyncio import aio_as_trio
|
from trio_asyncio import aio_as_trio
|
||||||
|
|
||||||
|
from telebot.types import InputFile
|
||||||
|
from telebot.async_telebot import AsyncTeleBot
|
||||||
|
|
||||||
from ..constants import *
|
from ..constants import *
|
||||||
|
|
||||||
from . import *
|
from . import *
|
||||||
|
@ -17,6 +22,17 @@ from . import *
|
||||||
PREFIX = 'tg'
|
PREFIX = 'tg'
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_metainfo_caption(meta: dict) -> str:
|
||||||
|
meta_str = f'prompt: \"{meta["prompt"]}\"\n'
|
||||||
|
meta_str += f'seed: {meta["seed"]}\n'
|
||||||
|
meta_str += f'step: {meta["step"]}\n'
|
||||||
|
meta_str += f'guidance: {meta["guidance"]}\n'
|
||||||
|
meta_str += f'algo: \"{meta["algo"]}\"\n'
|
||||||
|
meta_str += f'sampler: k_euler_ancestral\n'
|
||||||
|
meta_str += f'skynet v{VERSION}'
|
||||||
|
return meta_str
|
||||||
|
|
||||||
|
|
||||||
async def run_skynet_telegram(
|
async def run_skynet_telegram(
|
||||||
tg_token: str,
|
tg_token: str,
|
||||||
key_name: str = 'telegram-frontend',
|
key_name: str = 'telegram-frontend',
|
||||||
|
@ -26,44 +42,96 @@ async def run_skynet_telegram(
|
||||||
logging.basicConfig(level=logging.INFO)
|
logging.basicConfig(level=logging.INFO)
|
||||||
bot = AsyncTeleBot(tg_token)
|
bot = AsyncTeleBot(tg_token)
|
||||||
|
|
||||||
with open_skynet_rpc(
|
async with open_skynet_rpc(
|
||||||
'skynet-telegram-0',
|
'skynet-telegram-0',
|
||||||
security=True,
|
security=True,
|
||||||
cert_name=cert,
|
cert_name=cert_name,
|
||||||
key_name=key
|
key_name=key_name
|
||||||
) as rpc_call:
|
) as rpc_call:
|
||||||
|
|
||||||
async def _rpc_call(
|
async def _rpc_call(
|
||||||
uid: int,
|
uid: int,
|
||||||
method: str,
|
method: str,
|
||||||
params: dict
|
params: dict = {}
|
||||||
):
|
):
|
||||||
return await rpc_call(
|
return await rpc_call(
|
||||||
method, params, uid=f'{PREFIX}+{uid}')
|
method, params, uid=f'{PREFIX}+{uid}')
|
||||||
|
|
||||||
@bot.message_handler(commands=['help'])
|
@bot.message_handler(commands=['help'])
|
||||||
async def send_help(message):
|
async def send_help(message):
|
||||||
|
splt_msg = message.text.split(' ')
|
||||||
|
|
||||||
|
if len(splt_msg) == 1:
|
||||||
await bot.reply_to(message, HELP_TEXT)
|
await bot.reply_to(message, HELP_TEXT)
|
||||||
|
|
||||||
|
else:
|
||||||
|
param = splt_msg[1]
|
||||||
|
if param in HELP_TOPICS:
|
||||||
|
await bot.reply_to(message, HELP_TOPICS[param])
|
||||||
|
|
||||||
|
else:
|
||||||
|
await bot.reply_to(message, HELP_UNKWNOWN_PARAM)
|
||||||
|
|
||||||
@bot.message_handler(commands=['cool'])
|
@bot.message_handler(commands=['cool'])
|
||||||
async def send_cool_words(message):
|
async def send_cool_words(message):
|
||||||
await bot.reply_to(message, '\n'.join(COOL_WORDS))
|
await bot.reply_to(message, '\n'.join(COOL_WORDS))
|
||||||
|
|
||||||
@bot.message_handler(commands=['txt2img'])
|
@bot.message_handler(commands=['txt2img'])
|
||||||
async def send_txt2img(message):
|
async def send_txt2img(message):
|
||||||
|
prompt = ' '.join(message.text.split(' ')[1:])
|
||||||
|
|
||||||
|
if len(prompt) == 0:
|
||||||
|
await bot.reply_to(message, 'Empty text prompt ignored.')
|
||||||
|
return
|
||||||
|
|
||||||
|
logging.info(f'mid: {message.id}')
|
||||||
resp = await _rpc_call(
|
resp = await _rpc_call(
|
||||||
message.from_user.id,
|
message.from_user.id,
|
||||||
'txt2img',
|
'txt2img',
|
||||||
{}
|
{'prompt': prompt}
|
||||||
)
|
)
|
||||||
|
logging.info(f'resp to {message.id} arrived')
|
||||||
|
|
||||||
|
resp_txt = ''
|
||||||
|
if 'error' in resp.result:
|
||||||
|
resp_txt = resp.result['message']
|
||||||
|
|
||||||
|
else:
|
||||||
|
logging.info(resp.result['id'])
|
||||||
|
img_raw = base64.b64decode(bytes.fromhex(resp.result['img']))
|
||||||
|
img = Image.frombytes('RGB', (512, 512), img_raw)
|
||||||
|
|
||||||
|
await bot.send_photo(
|
||||||
|
message.chat.id,
|
||||||
|
caption=prepare_metainfo_caption(resp.result['meta']),
|
||||||
|
photo=img,
|
||||||
|
reply_to_message_id=message.id
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
await bot.reply_to(message, resp_txt)
|
||||||
|
|
||||||
@bot.message_handler(commands=['redo'])
|
@bot.message_handler(commands=['redo'])
|
||||||
async def redo_txt2img(message):
|
async def redo_txt2img(message):
|
||||||
resp = await _rpc_call(
|
resp = await _rpc_call(message.from_user.id, 'redo')
|
||||||
message.from_user.id,
|
|
||||||
'redo',
|
resp_txt = ''
|
||||||
{}
|
if 'error' in resp.result:
|
||||||
|
resp_txt = resp.result['message']
|
||||||
|
|
||||||
|
else:
|
||||||
|
img_raw = base64.b64decode(bytes.fromhex(resp.result['img']))
|
||||||
|
img = Image.frombytes('RGB', (512, 512), img_raw)
|
||||||
|
|
||||||
|
await bot.send_photo(
|
||||||
|
message.chat.id,
|
||||||
|
caption=prepare_metainfo_caption(resp.result['meta']),
|
||||||
|
photo=img,
|
||||||
|
reply_to_message_id=message.id
|
||||||
)
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
await bot.reply_to(message, resp_txt)
|
||||||
|
|
||||||
@bot.message_handler(commands=['config'])
|
@bot.message_handler(commands=['config'])
|
||||||
async def set_config(message):
|
async def set_config(message):
|
||||||
|
|
Loading…
Reference in New Issue