Finish telegram frontend integration

Correctly update user stats after a txt2img request
Swap tg_id to BIGINT on database
Refactor help topic system to be more generic
Fix skynet brain cli entry point
Fix multi request rpc session, we werent creating a context on the req side
Fix redo method
DGPU now returns image metadata
Improve error messages brain to frontend
Add version as constant
Add telegram dep to requirements
pull/2/head
Guillermo Rodriguez 2022-12-21 11:53:50 -03:00
parent 896b0f684b
commit cb92aed51c
No known key found for this signature in database
GPG Key ID: EC3AB66D5D83B392
9 changed files with 194 additions and 60 deletions

View File

@ -5,3 +5,4 @@ aiohttp
msgspec
pyOpenSSL
trio_asyncio
pyTelegramBotAPI

View File

@ -1,8 +1,10 @@
from setuptools import setup, find_packages
from skynet.constants import VERSION
setup(
name='skynet',
version='0.1.0a6',
version=VERSION,
description='Decentralized compute platform',
author='Guillermo Rodriguez',
author_email='guillermo@telos.net',

View File

@ -4,6 +4,7 @@ import json
import uuid
import base64
import logging
import traceback
from uuid import UUID
from pathlib import Path
@ -90,10 +91,10 @@ async def open_rpc_service(sock, dgpu_bus, db_pool, tls_whitelist, tls_key):
logging.info(f'pre next_worker: {next_worker}')
if next_worker == None:
raise SkynetDGPUOffline
raise SkynetDGPUOffline('No workers connected, try again later')
if are_all_workers_busy():
raise SkynetDGPUOverloaded
raise SkynetDGPUOverloaded('All workers are busy at the moment')
nid = list(nodes.keys())[next_worker]
@ -175,6 +176,8 @@ async def open_rpc_service(sock, dgpu_bus, db_pool, tls_whitelist, tls_key):
with trio.move_on_after(30):
await img_event.wait()
logging.info(f'img event: {ack_event.is_set()}')
if not img_event.is_set():
disconnect_node(nid)
raise SkynetDGPUComputeError('30 seconds timeout while processing request')
@ -187,7 +190,7 @@ async def open_rpc_service(sock, dgpu_bus, db_pool, tls_whitelist, tls_key):
if 'error' in img_resp.params:
raise SkynetDGPUComputeError(img_resp.params['error'])
return rid, img_resp.params['img']
return rid, img_resp.params['img'], img_resp.params['meta']
async def handle_user_request(rpc_ctx, req):
try:
@ -202,39 +205,54 @@ async def open_rpc_service(sock, dgpu_bus, db_pool, tls_whitelist, tls_key):
user_config = {**(await get_user_config(conn, user))}
del user_config['id']
prompt = req.params['prompt']
user_config= {
key : req.params.get(key, val)
for key, val in user_config.items()
}
req = ImageGenRequest(
prompt=prompt,
**user_config
)
rid, img = await dgpu_stream_one_img(req)
rid, img, meta = await dgpu_stream_one_img(req)
logging.info(f'done streaming {rid}')
result = {
'id': rid,
'img': img
'img': img,
'meta': meta
}
await update_user_stats(conn, user, last_prompt=prompt)
logging.info('updated user stats.')
case 'redo':
logging.info('redo')
user_config = await get_user_config(conn, user)
user_config = {**(await get_user_config(conn, user))}
del user_config['id']
prompt = await get_last_prompt_of(conn, user)
if prompt:
req = ImageGenRequest(
prompt=prompt,
**user_config
)
rid, img = await dgpu_stream_one_img(req)
rid, img, meta = await dgpu_stream_one_img(req)
result = {
'id': rid,
'img': img
'img': img,
'meta': meta
}
await update_user_stats(conn, user)
logging.info('updated user stats.')
else:
result = {
'error': 'skynet_no_last_prompt',
'message': 'No prompt to redo, do txt2img first'
}
case 'config':
logging.info('config')
if req.params['attr'] in CONFIG_ATTRS:
logging.info(f'update: {req.params}')
await update_user_config(
conn, user, req.params['attr'], req.params['val'])
logging.info('done')
case 'stats':
logging.info('stats')
@ -255,9 +273,10 @@ async def open_rpc_service(sock, dgpu_bus, db_pool, tls_whitelist, tls_key):
'message': str(e)
}
except SkynetDGPUOverloaded:
except SkynetDGPUOverloaded as e:
result = {
'error': 'skynet_dgpu_overloaded',
'message': str(e),
'nodes': len(nodes)
}
@ -266,14 +285,23 @@ async def open_rpc_service(sock, dgpu_bus, db_pool, tls_whitelist, tls_key):
'error': 'skynet_dgpu_compute_error',
'message': str(e)
}
except BaseException as e:
traceback.print_exception(type(e), e, e.__traceback__)
result = {
'error': 'skynet_internal_error',
'message': str(e)
}
resp = SkynetRPCResponse(result=result)
if security:
resp.sign(tls_key, 'skynet')
logging.info('sending response')
await rpc_ctx.asend(
json.dumps(resp.to_dict()).encode())
rpc_ctx.close()
logging.info('done')
async def request_service(n):
nonlocal next_worker
@ -329,6 +357,8 @@ async def open_rpc_service(sock, dgpu_bus, db_pool, tls_whitelist, tls_key):
await ctx.asend(
json.dumps(resp.to_dict()).encode())
ctx.close()
async with trio.open_nursery() as n:
n.start_soon(dgpu_image_streamer)

View File

@ -8,10 +8,12 @@ from functools import partial
import trio
import click
import trio_asyncio
from . import utils
from .dgpu import open_dgpu_node
from .brain import run_skynet
from .constants import ALGOS
from .frontend.telegram import run_skynet_telegram
@ -61,16 +63,16 @@ def run(*args, **kwargs):
@click.option(
'--host', '-h', default='localhost:5432')
@click.option(
'--pass', '-p', default='password')
def skynet(
'--passwd', '-p', default='password')
def brain(
loglevel: str,
host: str,
passw: str
passwd: str
):
async def _run_skynet():
async with run_skynet(
db_host=host,
db_pass=passw
db_pass=passwd
):
await trio.sleep_forever()
@ -86,13 +88,13 @@ def skynet(
@click.option(
'--cert', '-c', default='whitelist/dgpu')
@click.option(
'--algos', '-a', default=None)
'--algos', '-a', default=json.dumps(['midj']))
def dgpu(
loglevel: str,
uid: str,
key: str,
cert: str,
algos: Optional[str]
algos: str
):
trio.run(
partial(

View File

@ -1,8 +1,10 @@
#!/usr/bin/python
VERSION = '0.1a6'
DOCKER_RUNTIME_CUDA = 'skynet:runtime-cuda'
DB_HOST = 'ancap.tech:34508'
DB_HOST = 'localhost:5432'
DB_USER = 'skynet'
DB_PASS = 'password'
DB_NAME = 'skynet'
@ -21,7 +23,7 @@ ALGOS = {
N = '\n'
HELP_TEXT = f'''
test art bot v0.1a4
test art bot v{VERSION}
commands work on a user per user basis!
config is individual to each user!
@ -47,7 +49,7 @@ config is individual to each user!
/config guidance NUMBER - prompt text importance
'''
UNKNOWN_CMD_TEXT = 'unknown command! try sending \"/help\"'
UNKNOWN_CMD_TEXT = 'Unknown command! Try sending \"/help\"'
DONATION_INFO = '0xf95335682DF281FFaB7E104EB87B69625d9622B6\ngoal: 25/650usd'
@ -74,22 +76,24 @@ COOL_WORDS = [
'michelangelo'
]
HELP_STEP = '''
diffusion models are iterative processes a repeated cycle that starts with a\
HELP_TOPICS = {
'step': '''
Diffusion models are iterative processes a repeated cycle that starts with a\
random noise generated from text input. With each step, some noise is removed\
, resulting in a higher-quality image over time. The repetition stops when the\
desired number of steps completes.
around 25 sampling steps are usually enough to achieve high-quality images. Us\
Around 25 sampling steps are usually enough to achieve high-quality images. Us\
ing more may produce a slightly different picture, but not necessarily better \
quality.
'''
''',
HELP_GUIDANCE = '''
the guidance scale is a parameter that controls how much the image generation\
'guidance': '''
The guidance scale is a parameter that controls how much the image generation\
process follows the text prompt. The higher the value, the more image sticks\
to a given text input.
'''
}
HELP_UNKWNOWN_PARAM = 'don\'t have any info on that.'

View File

@ -2,6 +2,7 @@
import logging
from typing import Optional
from datetime import datetime
from contextlib import asynccontextmanager as acm
@ -19,7 +20,7 @@ CREATE SCHEMA IF NOT EXISTS skynet;
CREATE TABLE IF NOT EXISTS skynet.user(
id SERIAL PRIMARY KEY NOT NULL,
tg_id INT,
tg_id BIGINT,
wp_id VARCHAR(128),
mx_id VARCHAR(128),
ig_id VARCHAR(128),
@ -47,7 +48,7 @@ CREATE TABLE IF NOT EXISTS skynet.user_config(
step INT NOT NULL,
width INT NOT NULL,
height INT NOT NULL,
seed INT,
seed BIGINT,
guidance INT NOT NULL,
upscaler VARCHAR(128)
);
@ -124,7 +125,7 @@ async def get_user_config(conn, user: int):
async def get_last_prompt_of(conn, user: int):
stms = await conn.prepare(
stmt = await conn.prepare(
'SELECT last_prompt FROM skynet.user WHERE id = $1')
return await stmt.fetchval(user)
@ -198,7 +199,12 @@ async def get_or_create_user(conn, uid: str):
return user
async def update_user(conn, user: int, attr: str, val):
...
stmt = await conn.prepare(f'''
UPDATE skynet.user
SET {attr} = $2
WHERE id = $1
''')
await stmt.fetch(user, val)
async def update_user_config(conn, user: int, attr: str, val):
stmt = await conn.prepare(f'''
@ -218,3 +224,18 @@ async def get_user_stats(conn, user: int):
assert len(records) == 1
record = records[0]
return record
async def update_user_stats(
conn,
user: int,
last_prompt: Optional[str] = None
):
stmt = await conn.prepare('''
UPDATE skynet.user
SET generated = generated + 1
WHERE id = $1
''')
await stmt.fetch(user)
if last_prompt:
await update_user(conn, user, 'last_prompt', last_prompt)

View File

@ -109,7 +109,6 @@ async def open_dgpu_node(
'generated': 0
}
seed = ireq.seed if ireq.seed else random.randint(0, 2 ** 64)
try:
image = models[ireq.algo]['pipe'](
ireq.prompt,
@ -117,7 +116,7 @@ async def open_dgpu_node(
height=ireq.height,
guidance_scale=ireq.guidance,
num_inference_steps=ireq.step,
generator=torch.Generator("cuda").manual_seed(seed)
generator=torch.Generator("cuda").manual_seed(ireq.seed)
).images[0]
return image.tobytes()
@ -207,12 +206,18 @@ async def open_dgpu_node(
logging.info(f'sent ack, processing {req.rid}...')
try:
img = await gpu_compute_one(
ImageGenRequest(**req.params))
img_req = ImageGenRequest(**req.params)
if not img_req.seed:
img_req.seed = random.randint(0, 2 ** 64)
img = await gpu_compute_one(img_req)
img_resp = DGPUBusResponse(
rid=req.rid,
nid=req.nid,
params={'img': base64.b64encode(img).hex()}
params={
'img': base64.b64encode(img).hex(),
'meta': img_req.to_dict()
}
)
except DGPUComputeError as e:

View File

@ -90,13 +90,14 @@ async def open_skynet_rpc(
if security:
req.sign(tls_key, cert_name)
await sock.asend(
ctx = sock.new_context()
await ctx.asend(
json.dumps(
req.to_dict()).encode())
resp = SkynetRPCResponse(
**json.loads(
(await sock.arecv_msg()).bytes.decode()))
**json.loads((await ctx.arecv()).decode()))
ctx.close()
if security:
resp.verify(skynet_cert)

View File

@ -1,14 +1,19 @@
#!/usr/bin/python
import io
import base64
import logging
from datetime import datetime
import pynng
from telebot.async_telebot import AsyncTeleBot
from PIL import Image
from trio_asyncio import aio_as_trio
from telebot.types import InputFile
from telebot.async_telebot import AsyncTeleBot
from ..constants import *
from . import *
@ -17,6 +22,17 @@ from . import *
PREFIX = 'tg'
def prepare_metainfo_caption(meta: dict) -> str:
meta_str = f'prompt: \"{meta["prompt"]}\"\n'
meta_str += f'seed: {meta["seed"]}\n'
meta_str += f'step: {meta["step"]}\n'
meta_str += f'guidance: {meta["guidance"]}\n'
meta_str += f'algo: \"{meta["algo"]}\"\n'
meta_str += f'sampler: k_euler_ancestral\n'
meta_str += f'skynet v{VERSION}'
return meta_str
async def run_skynet_telegram(
tg_token: str,
key_name: str = 'telegram-frontend',
@ -26,44 +42,96 @@ async def run_skynet_telegram(
logging.basicConfig(level=logging.INFO)
bot = AsyncTeleBot(tg_token)
with open_skynet_rpc(
async with open_skynet_rpc(
'skynet-telegram-0',
security=True,
cert_name=cert,
key_name=key
cert_name=cert_name,
key_name=key_name
) as rpc_call:
async def _rpc_call(
uid: int,
method: str,
params: dict
params: dict = {}
):
return await rpc_call(
method, params, uid=f'{PREFIX}+{uid}')
@bot.message_handler(commands=['help'])
async def send_help(message):
splt_msg = message.text.split(' ')
if len(splt_msg) == 1:
await bot.reply_to(message, HELP_TEXT)
else:
param = splt_msg[1]
if param in HELP_TOPICS:
await bot.reply_to(message, HELP_TOPICS[param])
else:
await bot.reply_to(message, HELP_UNKWNOWN_PARAM)
@bot.message_handler(commands=['cool'])
async def send_cool_words(message):
await bot.reply_to(message, '\n'.join(COOL_WORDS))
@bot.message_handler(commands=['txt2img'])
async def send_txt2img(message):
prompt = ' '.join(message.text.split(' ')[1:])
if len(prompt) == 0:
await bot.reply_to(message, 'Empty text prompt ignored.')
return
logging.info(f'mid: {message.id}')
resp = await _rpc_call(
message.from_user.id,
'txt2img',
{}
{'prompt': prompt}
)
logging.info(f'resp to {message.id} arrived')
resp_txt = ''
if 'error' in resp.result:
resp_txt = resp.result['message']
else:
logging.info(resp.result['id'])
img_raw = base64.b64decode(bytes.fromhex(resp.result['img']))
img = Image.frombytes('RGB', (512, 512), img_raw)
await bot.send_photo(
message.chat.id,
caption=prepare_metainfo_caption(resp.result['meta']),
photo=img,
reply_to_message_id=message.id
)
return
await bot.reply_to(message, resp_txt)
@bot.message_handler(commands=['redo'])
async def redo_txt2img(message):
resp = await _rpc_call(
message.from_user.id,
'redo',
{}
resp = await _rpc_call(message.from_user.id, 'redo')
resp_txt = ''
if 'error' in resp.result:
resp_txt = resp.result['message']
else:
img_raw = base64.b64decode(bytes.fromhex(resp.result['img']))
img = Image.frombytes('RGB', (512, 512), img_raw)
await bot.send_photo(
message.chat.id,
caption=prepare_metainfo_caption(resp.result['meta']),
photo=img,
reply_to_message_id=message.id
)
return
await bot.reply_to(message, resp_txt)
@bot.message_handler(commands=['config'])
async def set_config(message):