mirror of https://github.com/skygpu/skynet.git
Finish telegram frontend integration
Correctly update user stats after a txt2img request Swap tg_id to BIGINT on database Refactor help topic system to be more generic Fix skynet brain cli entry point Fix multi request rpc session, we werent creating a context on the req side Fix redo method DGPU now returns image metadata Improve error messages brain to frontend Add version as constant Add telegram dep to requirementspull/2/head
parent
896b0f684b
commit
cb92aed51c
|
@ -5,3 +5,4 @@ aiohttp
|
|||
msgspec
|
||||
pyOpenSSL
|
||||
trio_asyncio
|
||||
pyTelegramBotAPI
|
||||
|
|
4
setup.py
4
setup.py
|
@ -1,8 +1,10 @@
|
|||
from setuptools import setup, find_packages
|
||||
|
||||
from skynet.constants import VERSION
|
||||
|
||||
setup(
|
||||
name='skynet',
|
||||
version='0.1.0a6',
|
||||
version=VERSION,
|
||||
description='Decentralized compute platform',
|
||||
author='Guillermo Rodriguez',
|
||||
author_email='guillermo@telos.net',
|
||||
|
|
|
@ -4,6 +4,7 @@ import json
|
|||
import uuid
|
||||
import base64
|
||||
import logging
|
||||
import traceback
|
||||
|
||||
from uuid import UUID
|
||||
from pathlib import Path
|
||||
|
@ -90,10 +91,10 @@ async def open_rpc_service(sock, dgpu_bus, db_pool, tls_whitelist, tls_key):
|
|||
logging.info(f'pre next_worker: {next_worker}')
|
||||
|
||||
if next_worker == None:
|
||||
raise SkynetDGPUOffline
|
||||
raise SkynetDGPUOffline('No workers connected, try again later')
|
||||
|
||||
if are_all_workers_busy():
|
||||
raise SkynetDGPUOverloaded
|
||||
raise SkynetDGPUOverloaded('All workers are busy at the moment')
|
||||
|
||||
|
||||
nid = list(nodes.keys())[next_worker]
|
||||
|
@ -175,6 +176,8 @@ async def open_rpc_service(sock, dgpu_bus, db_pool, tls_whitelist, tls_key):
|
|||
with trio.move_on_after(30):
|
||||
await img_event.wait()
|
||||
|
||||
logging.info(f'img event: {ack_event.is_set()}')
|
||||
|
||||
if not img_event.is_set():
|
||||
disconnect_node(nid)
|
||||
raise SkynetDGPUComputeError('30 seconds timeout while processing request')
|
||||
|
@ -187,7 +190,7 @@ async def open_rpc_service(sock, dgpu_bus, db_pool, tls_whitelist, tls_key):
|
|||
if 'error' in img_resp.params:
|
||||
raise SkynetDGPUComputeError(img_resp.params['error'])
|
||||
|
||||
return rid, img_resp.params['img']
|
||||
return rid, img_resp.params['img'], img_resp.params['meta']
|
||||
|
||||
async def handle_user_request(rpc_ctx, req):
|
||||
try:
|
||||
|
@ -202,39 +205,54 @@ async def open_rpc_service(sock, dgpu_bus, db_pool, tls_whitelist, tls_key):
|
|||
user_config = {**(await get_user_config(conn, user))}
|
||||
del user_config['id']
|
||||
prompt = req.params['prompt']
|
||||
user_config= {
|
||||
key : req.params.get(key, val)
|
||||
for key, val in user_config.items()
|
||||
}
|
||||
req = ImageGenRequest(
|
||||
prompt=prompt,
|
||||
**user_config
|
||||
)
|
||||
rid, img = await dgpu_stream_one_img(req)
|
||||
rid, img, meta = await dgpu_stream_one_img(req)
|
||||
logging.info(f'done streaming {rid}')
|
||||
result = {
|
||||
'id': rid,
|
||||
'img': img
|
||||
'img': img,
|
||||
'meta': meta
|
||||
}
|
||||
|
||||
await update_user_stats(conn, user, last_prompt=prompt)
|
||||
logging.info('updated user stats.')
|
||||
|
||||
case 'redo':
|
||||
logging.info('redo')
|
||||
user_config = await get_user_config(conn, user)
|
||||
user_config = {**(await get_user_config(conn, user))}
|
||||
del user_config['id']
|
||||
prompt = await get_last_prompt_of(conn, user)
|
||||
req = ImageGenRequest(
|
||||
prompt=prompt,
|
||||
**user_config
|
||||
)
|
||||
rid, img = await dgpu_stream_one_img(req)
|
||||
result = {
|
||||
'id': rid,
|
||||
'img': img
|
||||
}
|
||||
|
||||
if prompt:
|
||||
req = ImageGenRequest(
|
||||
prompt=prompt,
|
||||
**user_config
|
||||
)
|
||||
rid, img, meta = await dgpu_stream_one_img(req)
|
||||
result = {
|
||||
'id': rid,
|
||||
'img': img,
|
||||
'meta': meta
|
||||
}
|
||||
await update_user_stats(conn, user)
|
||||
logging.info('updated user stats.')
|
||||
|
||||
else:
|
||||
result = {
|
||||
'error': 'skynet_no_last_prompt',
|
||||
'message': 'No prompt to redo, do txt2img first'
|
||||
}
|
||||
|
||||
case 'config':
|
||||
logging.info('config')
|
||||
if req.params['attr'] in CONFIG_ATTRS:
|
||||
logging.info(f'update: {req.params}')
|
||||
await update_user_config(
|
||||
conn, user, req.params['attr'], req.params['val'])
|
||||
logging.info('done')
|
||||
|
||||
case 'stats':
|
||||
logging.info('stats')
|
||||
|
@ -255,9 +273,10 @@ async def open_rpc_service(sock, dgpu_bus, db_pool, tls_whitelist, tls_key):
|
|||
'message': str(e)
|
||||
}
|
||||
|
||||
except SkynetDGPUOverloaded:
|
||||
except SkynetDGPUOverloaded as e:
|
||||
result = {
|
||||
'error': 'skynet_dgpu_overloaded',
|
||||
'message': str(e),
|
||||
'nodes': len(nodes)
|
||||
}
|
||||
|
||||
|
@ -266,14 +285,23 @@ async def open_rpc_service(sock, dgpu_bus, db_pool, tls_whitelist, tls_key):
|
|||
'error': 'skynet_dgpu_compute_error',
|
||||
'message': str(e)
|
||||
}
|
||||
except BaseException as e:
|
||||
traceback.print_exception(type(e), e, e.__traceback__)
|
||||
result = {
|
||||
'error': 'skynet_internal_error',
|
||||
'message': str(e)
|
||||
}
|
||||
|
||||
resp = SkynetRPCResponse(result=result)
|
||||
|
||||
if security:
|
||||
resp.sign(tls_key, 'skynet')
|
||||
|
||||
logging.info('sending response')
|
||||
await rpc_ctx.asend(
|
||||
json.dumps(resp.to_dict()).encode())
|
||||
rpc_ctx.close()
|
||||
logging.info('done')
|
||||
|
||||
async def request_service(n):
|
||||
nonlocal next_worker
|
||||
|
@ -329,6 +357,8 @@ async def open_rpc_service(sock, dgpu_bus, db_pool, tls_whitelist, tls_key):
|
|||
await ctx.asend(
|
||||
json.dumps(resp.to_dict()).encode())
|
||||
|
||||
ctx.close()
|
||||
|
||||
|
||||
async with trio.open_nursery() as n:
|
||||
n.start_soon(dgpu_image_streamer)
|
||||
|
|
|
@ -8,10 +8,12 @@ from functools import partial
|
|||
|
||||
import trio
|
||||
import click
|
||||
import trio_asyncio
|
||||
|
||||
from . import utils
|
||||
from .dgpu import open_dgpu_node
|
||||
from .brain import run_skynet
|
||||
from .constants import ALGOS
|
||||
|
||||
from .frontend.telegram import run_skynet_telegram
|
||||
|
||||
|
@ -61,16 +63,16 @@ def run(*args, **kwargs):
|
|||
@click.option(
|
||||
'--host', '-h', default='localhost:5432')
|
||||
@click.option(
|
||||
'--pass', '-p', default='password')
|
||||
def skynet(
|
||||
'--passwd', '-p', default='password')
|
||||
def brain(
|
||||
loglevel: str,
|
||||
host: str,
|
||||
passw: str
|
||||
passwd: str
|
||||
):
|
||||
async def _run_skynet():
|
||||
async with run_skynet(
|
||||
db_host=host,
|
||||
db_pass=passw
|
||||
db_pass=passwd
|
||||
):
|
||||
await trio.sleep_forever()
|
||||
|
||||
|
@ -86,13 +88,13 @@ def skynet(
|
|||
@click.option(
|
||||
'--cert', '-c', default='whitelist/dgpu')
|
||||
@click.option(
|
||||
'--algos', '-a', default=None)
|
||||
'--algos', '-a', default=json.dumps(['midj']))
|
||||
def dgpu(
|
||||
loglevel: str,
|
||||
uid: str,
|
||||
key: str,
|
||||
cert: str,
|
||||
algos: Optional[str]
|
||||
algos: str
|
||||
):
|
||||
trio.run(
|
||||
partial(
|
||||
|
|
|
@ -1,8 +1,10 @@
|
|||
#!/usr/bin/python
|
||||
|
||||
VERSION = '0.1a6'
|
||||
|
||||
DOCKER_RUNTIME_CUDA = 'skynet:runtime-cuda'
|
||||
|
||||
DB_HOST = 'ancap.tech:34508'
|
||||
DB_HOST = 'localhost:5432'
|
||||
DB_USER = 'skynet'
|
||||
DB_PASS = 'password'
|
||||
DB_NAME = 'skynet'
|
||||
|
@ -21,7 +23,7 @@ ALGOS = {
|
|||
|
||||
N = '\n'
|
||||
HELP_TEXT = f'''
|
||||
test art bot v0.1a4
|
||||
test art bot v{VERSION}
|
||||
|
||||
commands work on a user per user basis!
|
||||
config is individual to each user!
|
||||
|
@ -47,7 +49,7 @@ config is individual to each user!
|
|||
/config guidance NUMBER - prompt text importance
|
||||
'''
|
||||
|
||||
UNKNOWN_CMD_TEXT = 'unknown command! try sending \"/help\"'
|
||||
UNKNOWN_CMD_TEXT = 'Unknown command! Try sending \"/help\"'
|
||||
|
||||
DONATION_INFO = '0xf95335682DF281FFaB7E104EB87B69625d9622B6\ngoal: 25/650usd'
|
||||
|
||||
|
@ -74,22 +76,24 @@ COOL_WORDS = [
|
|||
'michelangelo'
|
||||
]
|
||||
|
||||
HELP_STEP = '''
|
||||
diffusion models are iterative processes – a repeated cycle that starts with a\
|
||||
HELP_TOPICS = {
|
||||
'step': '''
|
||||
Diffusion models are iterative processes – a repeated cycle that starts with a\
|
||||
random noise generated from text input. With each step, some noise is removed\
|
||||
, resulting in a higher-quality image over time. The repetition stops when the\
|
||||
desired number of steps completes.
|
||||
|
||||
around 25 sampling steps are usually enough to achieve high-quality images. Us\
|
||||
Around 25 sampling steps are usually enough to achieve high-quality images. Us\
|
||||
ing more may produce a slightly different picture, but not necessarily better \
|
||||
quality.
|
||||
'''
|
||||
''',
|
||||
|
||||
HELP_GUIDANCE = '''
|
||||
the guidance scale is a parameter that controls how much the image generation\
|
||||
'guidance': '''
|
||||
The guidance scale is a parameter that controls how much the image generation\
|
||||
process follows the text prompt. The higher the value, the more image sticks\
|
||||
to a given text input.
|
||||
'''
|
||||
}
|
||||
|
||||
HELP_UNKWNOWN_PARAM = 'don\'t have any info on that.'
|
||||
|
||||
|
|
29
skynet/db.py
29
skynet/db.py
|
@ -2,6 +2,7 @@
|
|||
|
||||
import logging
|
||||
|
||||
from typing import Optional
|
||||
from datetime import datetime
|
||||
from contextlib import asynccontextmanager as acm
|
||||
|
||||
|
@ -19,7 +20,7 @@ CREATE SCHEMA IF NOT EXISTS skynet;
|
|||
|
||||
CREATE TABLE IF NOT EXISTS skynet.user(
|
||||
id SERIAL PRIMARY KEY NOT NULL,
|
||||
tg_id INT,
|
||||
tg_id BIGINT,
|
||||
wp_id VARCHAR(128),
|
||||
mx_id VARCHAR(128),
|
||||
ig_id VARCHAR(128),
|
||||
|
@ -47,7 +48,7 @@ CREATE TABLE IF NOT EXISTS skynet.user_config(
|
|||
step INT NOT NULL,
|
||||
width INT NOT NULL,
|
||||
height INT NOT NULL,
|
||||
seed INT,
|
||||
seed BIGINT,
|
||||
guidance INT NOT NULL,
|
||||
upscaler VARCHAR(128)
|
||||
);
|
||||
|
@ -124,7 +125,7 @@ async def get_user_config(conn, user: int):
|
|||
|
||||
|
||||
async def get_last_prompt_of(conn, user: int):
|
||||
stms = await conn.prepare(
|
||||
stmt = await conn.prepare(
|
||||
'SELECT last_prompt FROM skynet.user WHERE id = $1')
|
||||
return await stmt.fetchval(user)
|
||||
|
||||
|
@ -198,7 +199,12 @@ async def get_or_create_user(conn, uid: str):
|
|||
return user
|
||||
|
||||
async def update_user(conn, user: int, attr: str, val):
|
||||
...
|
||||
stmt = await conn.prepare(f'''
|
||||
UPDATE skynet.user
|
||||
SET {attr} = $2
|
||||
WHERE id = $1
|
||||
''')
|
||||
await stmt.fetch(user, val)
|
||||
|
||||
async def update_user_config(conn, user: int, attr: str, val):
|
||||
stmt = await conn.prepare(f'''
|
||||
|
@ -218,3 +224,18 @@ async def get_user_stats(conn, user: int):
|
|||
assert len(records) == 1
|
||||
record = records[0]
|
||||
return record
|
||||
|
||||
async def update_user_stats(
|
||||
conn,
|
||||
user: int,
|
||||
last_prompt: Optional[str] = None
|
||||
):
|
||||
stmt = await conn.prepare('''
|
||||
UPDATE skynet.user
|
||||
SET generated = generated + 1
|
||||
WHERE id = $1
|
||||
''')
|
||||
await stmt.fetch(user)
|
||||
|
||||
if last_prompt:
|
||||
await update_user(conn, user, 'last_prompt', last_prompt)
|
||||
|
|
|
@ -109,7 +109,6 @@ async def open_dgpu_node(
|
|||
'generated': 0
|
||||
}
|
||||
|
||||
seed = ireq.seed if ireq.seed else random.randint(0, 2 ** 64)
|
||||
try:
|
||||
image = models[ireq.algo]['pipe'](
|
||||
ireq.prompt,
|
||||
|
@ -117,7 +116,7 @@ async def open_dgpu_node(
|
|||
height=ireq.height,
|
||||
guidance_scale=ireq.guidance,
|
||||
num_inference_steps=ireq.step,
|
||||
generator=torch.Generator("cuda").manual_seed(seed)
|
||||
generator=torch.Generator("cuda").manual_seed(ireq.seed)
|
||||
).images[0]
|
||||
return image.tobytes()
|
||||
|
||||
|
@ -207,12 +206,18 @@ async def open_dgpu_node(
|
|||
logging.info(f'sent ack, processing {req.rid}...')
|
||||
|
||||
try:
|
||||
img = await gpu_compute_one(
|
||||
ImageGenRequest(**req.params))
|
||||
img_req = ImageGenRequest(**req.params)
|
||||
if not img_req.seed:
|
||||
img_req.seed = random.randint(0, 2 ** 64)
|
||||
|
||||
img = await gpu_compute_one(img_req)
|
||||
img_resp = DGPUBusResponse(
|
||||
rid=req.rid,
|
||||
nid=req.nid,
|
||||
params={'img': base64.b64encode(img).hex()}
|
||||
params={
|
||||
'img': base64.b64encode(img).hex(),
|
||||
'meta': img_req.to_dict()
|
||||
}
|
||||
)
|
||||
|
||||
except DGPUComputeError as e:
|
||||
|
|
|
@ -90,13 +90,14 @@ async def open_skynet_rpc(
|
|||
if security:
|
||||
req.sign(tls_key, cert_name)
|
||||
|
||||
await sock.asend(
|
||||
ctx = sock.new_context()
|
||||
await ctx.asend(
|
||||
json.dumps(
|
||||
req.to_dict()).encode())
|
||||
|
||||
resp = SkynetRPCResponse(
|
||||
**json.loads(
|
||||
(await sock.arecv_msg()).bytes.decode()))
|
||||
**json.loads((await ctx.arecv()).decode()))
|
||||
ctx.close()
|
||||
|
||||
if security:
|
||||
resp.verify(skynet_cert)
|
||||
|
|
|
@ -1,14 +1,19 @@
|
|||
#!/usr/bin/python
|
||||
|
||||
import io
|
||||
import base64
|
||||
import logging
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
import pynng
|
||||
|
||||
from telebot.async_telebot import AsyncTeleBot
|
||||
from PIL import Image
|
||||
from trio_asyncio import aio_as_trio
|
||||
|
||||
from telebot.types import InputFile
|
||||
from telebot.async_telebot import AsyncTeleBot
|
||||
|
||||
from ..constants import *
|
||||
|
||||
from . import *
|
||||
|
@ -17,6 +22,17 @@ from . import *
|
|||
PREFIX = 'tg'
|
||||
|
||||
|
||||
def prepare_metainfo_caption(meta: dict) -> str:
|
||||
meta_str = f'prompt: \"{meta["prompt"]}\"\n'
|
||||
meta_str += f'seed: {meta["seed"]}\n'
|
||||
meta_str += f'step: {meta["step"]}\n'
|
||||
meta_str += f'guidance: {meta["guidance"]}\n'
|
||||
meta_str += f'algo: \"{meta["algo"]}\"\n'
|
||||
meta_str += f'sampler: k_euler_ancestral\n'
|
||||
meta_str += f'skynet v{VERSION}'
|
||||
return meta_str
|
||||
|
||||
|
||||
async def run_skynet_telegram(
|
||||
tg_token: str,
|
||||
key_name: str = 'telegram-frontend',
|
||||
|
@ -26,24 +42,35 @@ async def run_skynet_telegram(
|
|||
logging.basicConfig(level=logging.INFO)
|
||||
bot = AsyncTeleBot(tg_token)
|
||||
|
||||
with open_skynet_rpc(
|
||||
async with open_skynet_rpc(
|
||||
'skynet-telegram-0',
|
||||
security=True,
|
||||
cert_name=cert,
|
||||
key_name=key
|
||||
cert_name=cert_name,
|
||||
key_name=key_name
|
||||
) as rpc_call:
|
||||
|
||||
async def _rpc_call(
|
||||
uid: int,
|
||||
method: str,
|
||||
params: dict
|
||||
params: dict = {}
|
||||
):
|
||||
return await rpc_call(
|
||||
method, params, uid=f'{PREFIX}+{uid}')
|
||||
|
||||
@bot.message_handler(commands=['help'])
|
||||
async def send_help(message):
|
||||
await bot.reply_to(message, HELP_TEXT)
|
||||
splt_msg = message.text.split(' ')
|
||||
|
||||
if len(splt_msg) == 1:
|
||||
await bot.reply_to(message, HELP_TEXT)
|
||||
|
||||
else:
|
||||
param = splt_msg[1]
|
||||
if param in HELP_TOPICS:
|
||||
await bot.reply_to(message, HELP_TOPICS[param])
|
||||
|
||||
else:
|
||||
await bot.reply_to(message, HELP_UNKWNOWN_PARAM)
|
||||
|
||||
@bot.message_handler(commands=['cool'])
|
||||
async def send_cool_words(message):
|
||||
|
@ -51,19 +78,60 @@ async def run_skynet_telegram(
|
|||
|
||||
@bot.message_handler(commands=['txt2img'])
|
||||
async def send_txt2img(message):
|
||||
prompt = ' '.join(message.text.split(' ')[1:])
|
||||
|
||||
if len(prompt) == 0:
|
||||
await bot.reply_to(message, 'Empty text prompt ignored.')
|
||||
return
|
||||
|
||||
logging.info(f'mid: {message.id}')
|
||||
resp = await _rpc_call(
|
||||
message.from_user.id,
|
||||
'txt2img',
|
||||
{}
|
||||
{'prompt': prompt}
|
||||
)
|
||||
logging.info(f'resp to {message.id} arrived')
|
||||
|
||||
resp_txt = ''
|
||||
if 'error' in resp.result:
|
||||
resp_txt = resp.result['message']
|
||||
|
||||
else:
|
||||
logging.info(resp.result['id'])
|
||||
img_raw = base64.b64decode(bytes.fromhex(resp.result['img']))
|
||||
img = Image.frombytes('RGB', (512, 512), img_raw)
|
||||
|
||||
await bot.send_photo(
|
||||
message.chat.id,
|
||||
caption=prepare_metainfo_caption(resp.result['meta']),
|
||||
photo=img,
|
||||
reply_to_message_id=message.id
|
||||
)
|
||||
return
|
||||
|
||||
await bot.reply_to(message, resp_txt)
|
||||
|
||||
@bot.message_handler(commands=['redo'])
|
||||
async def redo_txt2img(message):
|
||||
resp = await _rpc_call(
|
||||
message.from_user.id,
|
||||
'redo',
|
||||
{}
|
||||
)
|
||||
resp = await _rpc_call(message.from_user.id, 'redo')
|
||||
|
||||
resp_txt = ''
|
||||
if 'error' in resp.result:
|
||||
resp_txt = resp.result['message']
|
||||
|
||||
else:
|
||||
img_raw = base64.b64decode(bytes.fromhex(resp.result['img']))
|
||||
img = Image.frombytes('RGB', (512, 512), img_raw)
|
||||
|
||||
await bot.send_photo(
|
||||
message.chat.id,
|
||||
caption=prepare_metainfo_caption(resp.result['meta']),
|
||||
photo=img,
|
||||
reply_to_message_id=message.id
|
||||
)
|
||||
return
|
||||
|
||||
await bot.reply_to(message, resp_txt)
|
||||
|
||||
@bot.message_handler(commands=['config'])
|
||||
async def set_config(message):
|
||||
|
|
Loading…
Reference in New Issue