Finish telegram frontend integration

Correctly update user stats after a txt2img request
Swap tg_id to BIGINT on database
Refactor help topic system to be more generic
Fix skynet brain cli entry point
Fix multi request rpc session, we werent creating a context on the req side
Fix redo method
DGPU now returns image metadata
Improve error messages brain to frontend
Add version as constant
Add telegram dep to requirements
pull/2/head
Guillermo Rodriguez 2022-12-21 11:53:50 -03:00
parent 896b0f684b
commit cb92aed51c
No known key found for this signature in database
GPG Key ID: EC3AB66D5D83B392
9 changed files with 194 additions and 60 deletions

View File

@ -5,3 +5,4 @@ aiohttp
msgspec msgspec
pyOpenSSL pyOpenSSL
trio_asyncio trio_asyncio
pyTelegramBotAPI

View File

@ -1,8 +1,10 @@
from setuptools import setup, find_packages from setuptools import setup, find_packages
from skynet.constants import VERSION
setup( setup(
name='skynet', name='skynet',
version='0.1.0a6', version=VERSION,
description='Decentralized compute platform', description='Decentralized compute platform',
author='Guillermo Rodriguez', author='Guillermo Rodriguez',
author_email='guillermo@telos.net', author_email='guillermo@telos.net',

View File

@ -4,6 +4,7 @@ import json
import uuid import uuid
import base64 import base64
import logging import logging
import traceback
from uuid import UUID from uuid import UUID
from pathlib import Path from pathlib import Path
@ -90,10 +91,10 @@ async def open_rpc_service(sock, dgpu_bus, db_pool, tls_whitelist, tls_key):
logging.info(f'pre next_worker: {next_worker}') logging.info(f'pre next_worker: {next_worker}')
if next_worker == None: if next_worker == None:
raise SkynetDGPUOffline raise SkynetDGPUOffline('No workers connected, try again later')
if are_all_workers_busy(): if are_all_workers_busy():
raise SkynetDGPUOverloaded raise SkynetDGPUOverloaded('All workers are busy at the moment')
nid = list(nodes.keys())[next_worker] nid = list(nodes.keys())[next_worker]
@ -175,6 +176,8 @@ async def open_rpc_service(sock, dgpu_bus, db_pool, tls_whitelist, tls_key):
with trio.move_on_after(30): with trio.move_on_after(30):
await img_event.wait() await img_event.wait()
logging.info(f'img event: {ack_event.is_set()}')
if not img_event.is_set(): if not img_event.is_set():
disconnect_node(nid) disconnect_node(nid)
raise SkynetDGPUComputeError('30 seconds timeout while processing request') raise SkynetDGPUComputeError('30 seconds timeout while processing request')
@ -187,7 +190,7 @@ async def open_rpc_service(sock, dgpu_bus, db_pool, tls_whitelist, tls_key):
if 'error' in img_resp.params: if 'error' in img_resp.params:
raise SkynetDGPUComputeError(img_resp.params['error']) raise SkynetDGPUComputeError(img_resp.params['error'])
return rid, img_resp.params['img'] return rid, img_resp.params['img'], img_resp.params['meta']
async def handle_user_request(rpc_ctx, req): async def handle_user_request(rpc_ctx, req):
try: try:
@ -202,39 +205,54 @@ async def open_rpc_service(sock, dgpu_bus, db_pool, tls_whitelist, tls_key):
user_config = {**(await get_user_config(conn, user))} user_config = {**(await get_user_config(conn, user))}
del user_config['id'] del user_config['id']
prompt = req.params['prompt'] prompt = req.params['prompt']
user_config= {
key : req.params.get(key, val)
for key, val in user_config.items()
}
req = ImageGenRequest( req = ImageGenRequest(
prompt=prompt, prompt=prompt,
**user_config **user_config
) )
rid, img = await dgpu_stream_one_img(req) rid, img, meta = await dgpu_stream_one_img(req)
logging.info(f'done streaming {rid}')
result = { result = {
'id': rid, 'id': rid,
'img': img 'img': img,
'meta': meta
} }
await update_user_stats(conn, user, last_prompt=prompt)
logging.info('updated user stats.')
case 'redo': case 'redo':
logging.info('redo') logging.info('redo')
user_config = await get_user_config(conn, user) user_config = {**(await get_user_config(conn, user))}
del user_config['id']
prompt = await get_last_prompt_of(conn, user) prompt = await get_last_prompt_of(conn, user)
if prompt:
req = ImageGenRequest( req = ImageGenRequest(
prompt=prompt, prompt=prompt,
**user_config **user_config
) )
rid, img = await dgpu_stream_one_img(req) rid, img, meta = await dgpu_stream_one_img(req)
result = { result = {
'id': rid, 'id': rid,
'img': img 'img': img,
'meta': meta
}
await update_user_stats(conn, user)
logging.info('updated user stats.')
else:
result = {
'error': 'skynet_no_last_prompt',
'message': 'No prompt to redo, do txt2img first'
} }
case 'config': case 'config':
logging.info('config') logging.info('config')
if req.params['attr'] in CONFIG_ATTRS: if req.params['attr'] in CONFIG_ATTRS:
logging.info(f'update: {req.params}')
await update_user_config( await update_user_config(
conn, user, req.params['attr'], req.params['val']) conn, user, req.params['attr'], req.params['val'])
logging.info('done')
case 'stats': case 'stats':
logging.info('stats') logging.info('stats')
@ -255,9 +273,10 @@ async def open_rpc_service(sock, dgpu_bus, db_pool, tls_whitelist, tls_key):
'message': str(e) 'message': str(e)
} }
except SkynetDGPUOverloaded: except SkynetDGPUOverloaded as e:
result = { result = {
'error': 'skynet_dgpu_overloaded', 'error': 'skynet_dgpu_overloaded',
'message': str(e),
'nodes': len(nodes) 'nodes': len(nodes)
} }
@ -266,14 +285,23 @@ async def open_rpc_service(sock, dgpu_bus, db_pool, tls_whitelist, tls_key):
'error': 'skynet_dgpu_compute_error', 'error': 'skynet_dgpu_compute_error',
'message': str(e) 'message': str(e)
} }
except BaseException as e:
traceback.print_exception(type(e), e, e.__traceback__)
result = {
'error': 'skynet_internal_error',
'message': str(e)
}
resp = SkynetRPCResponse(result=result) resp = SkynetRPCResponse(result=result)
if security: if security:
resp.sign(tls_key, 'skynet') resp.sign(tls_key, 'skynet')
logging.info('sending response')
await rpc_ctx.asend( await rpc_ctx.asend(
json.dumps(resp.to_dict()).encode()) json.dumps(resp.to_dict()).encode())
rpc_ctx.close()
logging.info('done')
async def request_service(n): async def request_service(n):
nonlocal next_worker nonlocal next_worker
@ -329,6 +357,8 @@ async def open_rpc_service(sock, dgpu_bus, db_pool, tls_whitelist, tls_key):
await ctx.asend( await ctx.asend(
json.dumps(resp.to_dict()).encode()) json.dumps(resp.to_dict()).encode())
ctx.close()
async with trio.open_nursery() as n: async with trio.open_nursery() as n:
n.start_soon(dgpu_image_streamer) n.start_soon(dgpu_image_streamer)

View File

@ -8,10 +8,12 @@ from functools import partial
import trio import trio
import click import click
import trio_asyncio
from . import utils from . import utils
from .dgpu import open_dgpu_node from .dgpu import open_dgpu_node
from .brain import run_skynet from .brain import run_skynet
from .constants import ALGOS
from .frontend.telegram import run_skynet_telegram from .frontend.telegram import run_skynet_telegram
@ -61,16 +63,16 @@ def run(*args, **kwargs):
@click.option( @click.option(
'--host', '-h', default='localhost:5432') '--host', '-h', default='localhost:5432')
@click.option( @click.option(
'--pass', '-p', default='password') '--passwd', '-p', default='password')
def skynet( def brain(
loglevel: str, loglevel: str,
host: str, host: str,
passw: str passwd: str
): ):
async def _run_skynet(): async def _run_skynet():
async with run_skynet( async with run_skynet(
db_host=host, db_host=host,
db_pass=passw db_pass=passwd
): ):
await trio.sleep_forever() await trio.sleep_forever()
@ -86,13 +88,13 @@ def skynet(
@click.option( @click.option(
'--cert', '-c', default='whitelist/dgpu') '--cert', '-c', default='whitelist/dgpu')
@click.option( @click.option(
'--algos', '-a', default=None) '--algos', '-a', default=json.dumps(['midj']))
def dgpu( def dgpu(
loglevel: str, loglevel: str,
uid: str, uid: str,
key: str, key: str,
cert: str, cert: str,
algos: Optional[str] algos: str
): ):
trio.run( trio.run(
partial( partial(

View File

@ -1,8 +1,10 @@
#!/usr/bin/python #!/usr/bin/python
VERSION = '0.1a6'
DOCKER_RUNTIME_CUDA = 'skynet:runtime-cuda' DOCKER_RUNTIME_CUDA = 'skynet:runtime-cuda'
DB_HOST = 'ancap.tech:34508' DB_HOST = 'localhost:5432'
DB_USER = 'skynet' DB_USER = 'skynet'
DB_PASS = 'password' DB_PASS = 'password'
DB_NAME = 'skynet' DB_NAME = 'skynet'
@ -21,7 +23,7 @@ ALGOS = {
N = '\n' N = '\n'
HELP_TEXT = f''' HELP_TEXT = f'''
test art bot v0.1a4 test art bot v{VERSION}
commands work on a user per user basis! commands work on a user per user basis!
config is individual to each user! config is individual to each user!
@ -47,7 +49,7 @@ config is individual to each user!
/config guidance NUMBER - prompt text importance /config guidance NUMBER - prompt text importance
''' '''
UNKNOWN_CMD_TEXT = 'unknown command! try sending \"/help\"' UNKNOWN_CMD_TEXT = 'Unknown command! Try sending \"/help\"'
DONATION_INFO = '0xf95335682DF281FFaB7E104EB87B69625d9622B6\ngoal: 25/650usd' DONATION_INFO = '0xf95335682DF281FFaB7E104EB87B69625d9622B6\ngoal: 25/650usd'
@ -74,22 +76,24 @@ COOL_WORDS = [
'michelangelo' 'michelangelo'
] ]
HELP_STEP = ''' HELP_TOPICS = {
diffusion models are iterative processes a repeated cycle that starts with a\ 'step': '''
Diffusion models are iterative processes a repeated cycle that starts with a\
random noise generated from text input. With each step, some noise is removed\ random noise generated from text input. With each step, some noise is removed\
, resulting in a higher-quality image over time. The repetition stops when the\ , resulting in a higher-quality image over time. The repetition stops when the\
desired number of steps completes. desired number of steps completes.
around 25 sampling steps are usually enough to achieve high-quality images. Us\ Around 25 sampling steps are usually enough to achieve high-quality images. Us\
ing more may produce a slightly different picture, but not necessarily better \ ing more may produce a slightly different picture, but not necessarily better \
quality. quality.
''' ''',
HELP_GUIDANCE = ''' 'guidance': '''
the guidance scale is a parameter that controls how much the image generation\ The guidance scale is a parameter that controls how much the image generation\
process follows the text prompt. The higher the value, the more image sticks\ process follows the text prompt. The higher the value, the more image sticks\
to a given text input. to a given text input.
''' '''
}
HELP_UNKWNOWN_PARAM = 'don\'t have any info on that.' HELP_UNKWNOWN_PARAM = 'don\'t have any info on that.'

View File

@ -2,6 +2,7 @@
import logging import logging
from typing import Optional
from datetime import datetime from datetime import datetime
from contextlib import asynccontextmanager as acm from contextlib import asynccontextmanager as acm
@ -19,7 +20,7 @@ CREATE SCHEMA IF NOT EXISTS skynet;
CREATE TABLE IF NOT EXISTS skynet.user( CREATE TABLE IF NOT EXISTS skynet.user(
id SERIAL PRIMARY KEY NOT NULL, id SERIAL PRIMARY KEY NOT NULL,
tg_id INT, tg_id BIGINT,
wp_id VARCHAR(128), wp_id VARCHAR(128),
mx_id VARCHAR(128), mx_id VARCHAR(128),
ig_id VARCHAR(128), ig_id VARCHAR(128),
@ -47,7 +48,7 @@ CREATE TABLE IF NOT EXISTS skynet.user_config(
step INT NOT NULL, step INT NOT NULL,
width INT NOT NULL, width INT NOT NULL,
height INT NOT NULL, height INT NOT NULL,
seed INT, seed BIGINT,
guidance INT NOT NULL, guidance INT NOT NULL,
upscaler VARCHAR(128) upscaler VARCHAR(128)
); );
@ -124,7 +125,7 @@ async def get_user_config(conn, user: int):
async def get_last_prompt_of(conn, user: int): async def get_last_prompt_of(conn, user: int):
stms = await conn.prepare( stmt = await conn.prepare(
'SELECT last_prompt FROM skynet.user WHERE id = $1') 'SELECT last_prompt FROM skynet.user WHERE id = $1')
return await stmt.fetchval(user) return await stmt.fetchval(user)
@ -198,7 +199,12 @@ async def get_or_create_user(conn, uid: str):
return user return user
async def update_user(conn, user: int, attr: str, val): async def update_user(conn, user: int, attr: str, val):
... stmt = await conn.prepare(f'''
UPDATE skynet.user
SET {attr} = $2
WHERE id = $1
''')
await stmt.fetch(user, val)
async def update_user_config(conn, user: int, attr: str, val): async def update_user_config(conn, user: int, attr: str, val):
stmt = await conn.prepare(f''' stmt = await conn.prepare(f'''
@ -218,3 +224,18 @@ async def get_user_stats(conn, user: int):
assert len(records) == 1 assert len(records) == 1
record = records[0] record = records[0]
return record return record
async def update_user_stats(
conn,
user: int,
last_prompt: Optional[str] = None
):
stmt = await conn.prepare('''
UPDATE skynet.user
SET generated = generated + 1
WHERE id = $1
''')
await stmt.fetch(user)
if last_prompt:
await update_user(conn, user, 'last_prompt', last_prompt)

View File

@ -109,7 +109,6 @@ async def open_dgpu_node(
'generated': 0 'generated': 0
} }
seed = ireq.seed if ireq.seed else random.randint(0, 2 ** 64)
try: try:
image = models[ireq.algo]['pipe']( image = models[ireq.algo]['pipe'](
ireq.prompt, ireq.prompt,
@ -117,7 +116,7 @@ async def open_dgpu_node(
height=ireq.height, height=ireq.height,
guidance_scale=ireq.guidance, guidance_scale=ireq.guidance,
num_inference_steps=ireq.step, num_inference_steps=ireq.step,
generator=torch.Generator("cuda").manual_seed(seed) generator=torch.Generator("cuda").manual_seed(ireq.seed)
).images[0] ).images[0]
return image.tobytes() return image.tobytes()
@ -207,12 +206,18 @@ async def open_dgpu_node(
logging.info(f'sent ack, processing {req.rid}...') logging.info(f'sent ack, processing {req.rid}...')
try: try:
img = await gpu_compute_one( img_req = ImageGenRequest(**req.params)
ImageGenRequest(**req.params)) if not img_req.seed:
img_req.seed = random.randint(0, 2 ** 64)
img = await gpu_compute_one(img_req)
img_resp = DGPUBusResponse( img_resp = DGPUBusResponse(
rid=req.rid, rid=req.rid,
nid=req.nid, nid=req.nid,
params={'img': base64.b64encode(img).hex()} params={
'img': base64.b64encode(img).hex(),
'meta': img_req.to_dict()
}
) )
except DGPUComputeError as e: except DGPUComputeError as e:

View File

@ -90,13 +90,14 @@ async def open_skynet_rpc(
if security: if security:
req.sign(tls_key, cert_name) req.sign(tls_key, cert_name)
await sock.asend( ctx = sock.new_context()
await ctx.asend(
json.dumps( json.dumps(
req.to_dict()).encode()) req.to_dict()).encode())
resp = SkynetRPCResponse( resp = SkynetRPCResponse(
**json.loads( **json.loads((await ctx.arecv()).decode()))
(await sock.arecv_msg()).bytes.decode())) ctx.close()
if security: if security:
resp.verify(skynet_cert) resp.verify(skynet_cert)

View File

@ -1,14 +1,19 @@
#!/usr/bin/python #!/usr/bin/python
import io
import base64
import logging import logging
from datetime import datetime from datetime import datetime
import pynng import pynng
from telebot.async_telebot import AsyncTeleBot from PIL import Image
from trio_asyncio import aio_as_trio from trio_asyncio import aio_as_trio
from telebot.types import InputFile
from telebot.async_telebot import AsyncTeleBot
from ..constants import * from ..constants import *
from . import * from . import *
@ -17,6 +22,17 @@ from . import *
PREFIX = 'tg' PREFIX = 'tg'
def prepare_metainfo_caption(meta: dict) -> str:
meta_str = f'prompt: \"{meta["prompt"]}\"\n'
meta_str += f'seed: {meta["seed"]}\n'
meta_str += f'step: {meta["step"]}\n'
meta_str += f'guidance: {meta["guidance"]}\n'
meta_str += f'algo: \"{meta["algo"]}\"\n'
meta_str += f'sampler: k_euler_ancestral\n'
meta_str += f'skynet v{VERSION}'
return meta_str
async def run_skynet_telegram( async def run_skynet_telegram(
tg_token: str, tg_token: str,
key_name: str = 'telegram-frontend', key_name: str = 'telegram-frontend',
@ -26,44 +42,96 @@ async def run_skynet_telegram(
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
bot = AsyncTeleBot(tg_token) bot = AsyncTeleBot(tg_token)
with open_skynet_rpc( async with open_skynet_rpc(
'skynet-telegram-0', 'skynet-telegram-0',
security=True, security=True,
cert_name=cert, cert_name=cert_name,
key_name=key key_name=key_name
) as rpc_call: ) as rpc_call:
async def _rpc_call( async def _rpc_call(
uid: int, uid: int,
method: str, method: str,
params: dict params: dict = {}
): ):
return await rpc_call( return await rpc_call(
method, params, uid=f'{PREFIX}+{uid}') method, params, uid=f'{PREFIX}+{uid}')
@bot.message_handler(commands=['help']) @bot.message_handler(commands=['help'])
async def send_help(message): async def send_help(message):
splt_msg = message.text.split(' ')
if len(splt_msg) == 1:
await bot.reply_to(message, HELP_TEXT) await bot.reply_to(message, HELP_TEXT)
else:
param = splt_msg[1]
if param in HELP_TOPICS:
await bot.reply_to(message, HELP_TOPICS[param])
else:
await bot.reply_to(message, HELP_UNKWNOWN_PARAM)
@bot.message_handler(commands=['cool']) @bot.message_handler(commands=['cool'])
async def send_cool_words(message): async def send_cool_words(message):
await bot.reply_to(message, '\n'.join(COOL_WORDS)) await bot.reply_to(message, '\n'.join(COOL_WORDS))
@bot.message_handler(commands=['txt2img']) @bot.message_handler(commands=['txt2img'])
async def send_txt2img(message): async def send_txt2img(message):
prompt = ' '.join(message.text.split(' ')[1:])
if len(prompt) == 0:
await bot.reply_to(message, 'Empty text prompt ignored.')
return
logging.info(f'mid: {message.id}')
resp = await _rpc_call( resp = await _rpc_call(
message.from_user.id, message.from_user.id,
'txt2img', 'txt2img',
{} {'prompt': prompt}
) )
logging.info(f'resp to {message.id} arrived')
resp_txt = ''
if 'error' in resp.result:
resp_txt = resp.result['message']
else:
logging.info(resp.result['id'])
img_raw = base64.b64decode(bytes.fromhex(resp.result['img']))
img = Image.frombytes('RGB', (512, 512), img_raw)
await bot.send_photo(
message.chat.id,
caption=prepare_metainfo_caption(resp.result['meta']),
photo=img,
reply_to_message_id=message.id
)
return
await bot.reply_to(message, resp_txt)
@bot.message_handler(commands=['redo']) @bot.message_handler(commands=['redo'])
async def redo_txt2img(message): async def redo_txt2img(message):
resp = await _rpc_call( resp = await _rpc_call(message.from_user.id, 'redo')
message.from_user.id,
'redo', resp_txt = ''
{} if 'error' in resp.result:
resp_txt = resp.result['message']
else:
img_raw = base64.b64decode(bytes.fromhex(resp.result['img']))
img = Image.frombytes('RGB', (512, 512), img_raw)
await bot.send_photo(
message.chat.id,
caption=prepare_metainfo_caption(resp.result['meta']),
photo=img,
reply_to_message_id=message.id
) )
return
await bot.reply_to(message, resp_txt)
@bot.message_handler(commands=['config']) @bot.message_handler(commands=['config'])
async def set_config(message): async def set_config(message):