Started making robust testing fixtures to init a fresh db and skynet

Add simple dgpu worker connection test
Make db connection handler manage schema and table init logic
Keep tweaking dgpu main handler, attempting to fix subactor hangs
Change frontend open rpc logic to return a wrapped rpc_call fn referencing the new socket
Decoupled user config request validation from telegram module
Fix next_worker logic, now takes into account multiple tasks per dgpu
Add dgpu_workers and dgpu_next calls
Fixed readme, moved db init code into db module
pull/2/head
Guillermo Rodriguez 2022-12-11 11:02:55 -03:00
parent c3852314a7
commit 9afb192251
No known key found for this signature in database
GPG Key ID: EC3AB66D5D83B392
11 changed files with 387 additions and 224 deletions

View File

@ -1,47 +1,2 @@
create db in postgres: # skynet
### decentralized compute platform
```sql
CREATE USER skynet WITH PASSWORD 'password';
CREATE DATABASE skynet_art_bot;
GRANT ALL PRIVILEGES ON DATABASE skynet_art_bot TO skynet;
CREATE SCHEMA IF NOT EXISTS skynet;
CREATE TABLE IF NOT EXISTS skynet.user(
id SERIAL PRIMARY KEY NOT NULL,
tg_id INT,
wp_id VARCHAR(128),
mx_id VARCHAR(128),
ig_id VARCHAR(128),
generated INT NOT NULL,
joined DATE NOT NULL,
last_prompt TEXT,
role VARCHAR(128) NOT NULL
);
ALTER TABLE skynet.user
ADD CONSTRAINT tg_unique
UNIQUE (tg_id);
ALTER TABLE skynet.user
ADD CONSTRAINT wp_unique
UNIQUE (wp_id);
ALTER TABLE skynet.user
ADD CONSTRAINT mx_unique
UNIQUE (mx_id);
ALTER TABLE skynet.user
ADD CONSTRAINT ig_unique
UNIQUE (ig_id);
CREATE TABLE IF NOT EXISTS skynet.user_config(
id SERIAL NOT NULL,
algo VARCHAR(128) NOT NULL,
step INT NOT NULL,
width INT NOT NULL,
height INT NOT NULL,
seed INT,
guidance INT NOT NULL,
upscaler VARCHAR(128)
);
ALTER TABLE skynet.user_config
ADD FOREIGN KEY(id)
REFERENCES skynet.user(id);
```

View File

@ -1,2 +1,5 @@
pytest pytest
psycopg2
pytest-trio pytest-trio
git+https://github.com/tgoodlet/pytest-dockerctl.git@master#egg=pytest-dockerctl

View File

@ -30,30 +30,37 @@ async def rpc_service(sock, dgpu_bus, db_pool):
wip_reqs = {} wip_reqs = {}
fin_reqs = {} fin_reqs = {}
def are_all_workers_busy(): def is_worker_busy(nid: int):
for nid, info in nodes.items(): for task in nodes[nid]['tasks']:
if info['task'] == None: if task != None:
return False return False
return True return True
next_worker = 0 def are_all_workers_busy():
for nid in nodes.keys():
if not is_worker_busy(nid):
return False
return True
next_worker: Optional[int] = None
def get_next_worker(): def get_next_worker():
nonlocal next_worker nonlocal next_worker
if len(nodes) == 0: if not next_worker:
raise SkynetDGPUOffline raise SkynetDGPUOffline
if are_all_workers_busy(): if are_all_workers_busy():
raise SkynetDGPUOverloaded raise SkynetDGPUOverloaded
while is_worker_busy(next_worker):
next_worker += 1 next_worker += 1
if next_worker >= len(nodes): if next_worker >= len(nodes):
next_worker = 0 next_worker = 0
nid = list(nodes.keys())[next_worker] return next_worker
return nid
async def dgpu_image_streamer(): async def dgpu_image_streamer():
nonlocal wip_reqs, fin_reqs nonlocal wip_reqs, fin_reqs
@ -74,7 +81,8 @@ async def rpc_service(sock, dgpu_bus, db_pool):
event = trio.Event() event = trio.Event()
wip_reqs[rid] = event wip_reqs[rid] = event
nodes[nid]['task'] = rid tid = nodes[nid]['tasks'].index(None)
nodes[nid]['tasks'][tid] = rid
dgpu_req = DGPUBusRequest( dgpu_req = DGPUBusRequest(
rid=rid, rid=rid,
@ -89,7 +97,7 @@ async def rpc_service(sock, dgpu_bus, db_pool):
await event.wait() await event.wait()
nodes[nid]['task'] = None nodes[nid]['tasks'][tid] = None
img = fin_reqs[rid] img = fin_reqs[rid]
del fin_reqs[rid] del fin_reqs[rid]
@ -167,10 +175,9 @@ async def rpc_service(sock, dgpu_bus, db_pool):
except BaseException as e: except BaseException as e:
logging.error(e) logging.error(e)
raise e result = {
# result = { 'error': 'skynet_internal_error'
# 'error': 'skynet_internal_error' }
# }
await rpc_ctx.asend( await rpc_ctx.asend(
json.dumps( json.dumps(
@ -187,21 +194,36 @@ async def rpc_service(sock, dgpu_bus, db_pool):
logging.info(req) logging.info(req)
result = {}
if req.method == 'dgpu_online': if req.method == 'dgpu_online':
nodes[req.uid] = { nodes[req.uid] = {
'task': None 'tasks': [None for _ in range(req.params['max_tasks'])],
'max_tasks': req.params['max_tasks']
} }
logging.info(f'dgpu online: {req.uid}') logging.info(f'dgpu online: {req.uid}')
if not next_worker:
next_worker = 0
elif req.method == 'dgpu_offline': elif req.method == 'dgpu_offline':
i = nodes.values().index(req.uid) i = list(nodes.keys()).index(req.uid)
del nodes[req.uid] del nodes[req.uid]
if i < next_worker: if i < next_worker:
next_worker -= 1 next_worker -= 1
if len(nodes) == 0:
next_worker = None
logging.info(f'dgpu offline: {req.uid}') logging.info(f'dgpu offline: {req.uid}')
elif req.method == 'dgpu_workers':
result = len(nodes)
elif req.method == 'dgpu_next':
result = next_worker
else: else:
n.start_soon( n.start_soon(
handle_user_request, ctx, req) handle_user_request, ctx, req)
@ -210,12 +232,12 @@ async def rpc_service(sock, dgpu_bus, db_pool):
await ctx.asend( await ctx.asend(
json.dumps( json.dumps(
SkynetRPCResponse( SkynetRPCResponse(
result={'ok': {}}).to_dict()).encode()) result={'ok': result}).to_dict()).encode())
async def run_skynet( async def run_skynet(
db_user: str, db_user: str = DB_USER,
db_pass: str, db_pass: str = DB_PASS,
db_host: str = DB_HOST, db_host: str = DB_HOST,
rpc_address: str = DEFAULT_RPC_ADDR, rpc_address: str = DEFAULT_RPC_ADDR,
dgpu_address: str = DEFAULT_DGPU_ADDR, dgpu_address: str = DEFAULT_DGPU_ADDR,

View File

@ -3,6 +3,9 @@
API_TOKEN = '5880619053:AAFge2UfObw1kCn9Kb7AAyqaIHs_HgM0Fx0' API_TOKEN = '5880619053:AAFge2UfObw1kCn9Kb7AAyqaIHs_HgM0Fx0'
DB_HOST = 'ancap.tech:34508' DB_HOST = 'ancap.tech:34508'
DB_USER = 'skynet'
DB_PASS = 'password'
DB_NAME = 'skynet'
ALGOS = { ALGOS = {
'stable': 'runwayml/stable-diffusion-v1-5', 'stable': 'runwayml/stable-diffusion-v1-5',

View File

@ -11,6 +11,49 @@ import triopg
from .constants import * from .constants import *
# Schema bootstrap script. It is executed every time a database pool is
# opened (see open_database_connection), so every statement has to be safe
# to re-run against an already initialised database.
#
# Constraints are therefore declared inside the CREATE TABLE IF NOT EXISTS
# bodies: a separate ALTER TABLE ... ADD CONSTRAINT has no IF NOT EXISTS
# form and errors out with "already exists" on the second run, and an
# unnamed ADD FOREIGN KEY would pile up a duplicate constraint per run.
DB_INIT_SQL = '''
CREATE SCHEMA IF NOT EXISTS skynet;

CREATE TABLE IF NOT EXISTS skynet.user(
    id SERIAL PRIMARY KEY NOT NULL,
    tg_id INT,
    wp_id VARCHAR(128),
    mx_id VARCHAR(128),
    ig_id VARCHAR(128),
    generated INT NOT NULL,
    joined DATE NOT NULL,
    last_prompt TEXT,
    role VARCHAR(128) NOT NULL,
    CONSTRAINT tg_unique UNIQUE (tg_id),
    CONSTRAINT wp_unique UNIQUE (wp_id),
    CONSTRAINT mx_unique UNIQUE (mx_id),
    CONSTRAINT ig_unique UNIQUE (ig_id)
);

CREATE TABLE IF NOT EXISTS skynet.user_config(
    id SERIAL NOT NULL,
    algo VARCHAR(128) NOT NULL,
    step INT NOT NULL,
    width INT NOT NULL,
    height INT NOT NULL,
    seed INT,
    guidance INT NOT NULL,
    upscaler VARCHAR(128),
    FOREIGN KEY (id) REFERENCES skynet.user(id)
);
'''
def try_decode_uid(uid: str): def try_decode_uid(uid: str):
try: try:
proto, uid = uid.split('+') proto, uid = uid.split('+')
@ -24,14 +67,18 @@ def try_decode_uid(uid: str):
@acm @acm
async def open_database_connection( async def open_database_connection(
db_user: str, db_user: str = DB_USER,
db_pass: str, db_pass: str = DB_PASS,
db_host: str = DB_HOST, db_host: str = DB_HOST,
db_name: str = DB_NAME
): ):
async with triopg.create_pool( async with triopg.create_pool(
dsn=f'postgres://{db_user}:{db_pass}@{db_host}/skynet_art_bot' dsn=f'postgres://{db_user}:{db_pass}@{db_host}/{db_name}'
) as conn: ) as pool_conn:
yield conn async with pool_conn.acquire() as conn:
await conn.execute(DB_INIT_SQL)
yield pool_conn
async def get_user(conn, uid: str): async def get_user(conn, uid: str):

View File

@ -65,20 +65,15 @@ async def open_dgpu_node(
return img return img
with ( async with open_skynet_rpc() as rpc_call:
pynng.Req0(dial=rpc_address) as rpc_sock, with pynng.Bus0(dial=dgpu_address) as dgpu_sock:
pynng.Bus0(dial=dgpu_address) as dgpu_sock
):
async def _rpc_call(*args, **kwargs):
return await rpc_call(rpc_sock, *args, **kwargs)
async def _process_dgpu_req(req: DGPUBusRequest): async def _process_dgpu_req(req: DGPUBusRequest):
img = await gpu_compute_one( img = await gpu_compute_one(
ImageGenRequest(**req.params)) ImageGenRequest(**req.params))
await dgpu_sock.asend( await dgpu_sock.asend(
bytes.fromhex(req.rid) + img) bytes.fromhex(req.rid) + img)
res = await _rpc_call( res = await rpc_call(
name.hex, 'dgpu_online', {'max_tasks': dgpu_max_tasks}) name.hex, 'dgpu_online', {'max_tasks': dgpu_max_tasks})
logging.info(res) logging.info(res)
assert 'ok' in res.result assert 'ok' in res.result
@ -117,6 +112,6 @@ async def open_dgpu_node(
except KeyboardInterrupt: except KeyboardInterrupt:
... ...
res = await _rpc_call(name.hex, 'dgpu_offline') res = await rpc_call(name.hex, 'dgpu_offline')
logging.info(res) logging.info(res)
assert 'ok' in res.result assert 'ok' in res.result

View File

@ -3,14 +3,17 @@
import json import json
from typing import Union from typing import Union
from contextlib import contextmanager as cm from contextlib import asynccontextmanager as acm
import pynng import pynng
from ..types import SkynetRPCRequest, SkynetRPCResponse from ..types import SkynetRPCRequest, SkynetRPCResponse
from ..constants import DEFAULT_RPC_ADDR from ..constants import *
class ConfigRequestFormatError(BaseException):
...
class ConfigUnknownAttribute(BaseException): class ConfigUnknownAttribute(BaseException):
... ...
@ -44,7 +47,71 @@ async def rpc_call(
(await sock.arecv_msg()).bytes.decode())) (await sock.arecv_msg()).bytes.decode()))
@cm @acm
def open_skynet_rpc(rpc_address: str = DEFAULT_RPC_ADDR): async def open_skynet_rpc(rpc_address: str = DEFAULT_RPC_ADDR):
with pynng.Req0(dial=rpc_address) as rpc_sock: with pynng.Req0(dial=rpc_address) as sock:
yield rpc_sock async def _rpc_call(*args, **kwargs):
return await rpc_call(sock, *args, **kwargs)
yield _rpc_call
def validate_user_config_request(req: str):
    """Parse and validate the text of a user ``config`` request.

    ``req`` is split on single spaces; the second token is taken as the
    attribute name and the third as its requested value (the first token
    is ignored here).

    Returns:
        A ``(attr, val, reply_txt)`` tuple: the attribute name, the parsed
        (and, for numeric attributes, clamped) value, and a confirmation
        message for the user.

    Raises:
        ConfigRequestFormatError: the request has fewer than three tokens.
        ConfigUnknownAlgorithm: ``algo`` value is not a key of ``ALGOS``.
        ConfigSizeDivisionByEight: ``width``/``height`` is not a multiple
            of 8 after clamping.
        ConfigUnknownUpscaler: ``upscaler`` is neither ``off`` nor ``x4``.
        ConfigUnknownAttribute: the attribute name is not recognised.
        ValueError: a numeric attribute could not be parsed as a number.
    """
    params = req.split(' ')

    if len(params) < 3:
        raise ConfigRequestFormatError('config request format incorrect')

    attr = params[1]
    # Keep the raw token around: when int()/float() fails, `val` has not
    # been assigned yet, so the error message must not depend on it.
    raw = params[2]

    try:
        if attr == 'algo':
            val = raw
            if val not in ALGOS:
                raise ConfigUnknownAlgorithm(f'no algo named {val}')

        elif attr == 'step':
            val = int(raw)
            val = max(min(val, MAX_STEP), MIN_STEP)

        elif attr == 'width':
            val = max(min(int(raw), MAX_WIDTH), 16)
            if val % 8 != 0:
                raise ConfigSizeDivisionByEight(
                    'size must be divisible by 8!')

        elif attr == 'height':
            val = max(min(int(raw), MAX_HEIGHT), 16)
            if val % 8 != 0:
                raise ConfigSizeDivisionByEight(
                    'size must be divisible by 8!')

        elif attr == 'seed':
            # 'auto' clears the stored seed
            val = None if raw == 'auto' else int(raw)

        elif attr == 'guidance':
            val = float(raw)
            val = max(min(val, MAX_GUIDANCE), 0)

        elif attr == 'upscaler':
            val = raw
            if val == 'off':
                val = None
            elif val != 'x4':
                raise ConfigUnknownUpscaler(
                    f'\"{val}\" is not a valid upscaler')

        else:
            raise ConfigUnknownAttribute(
                f'\"{attr}\" not a configurable parameter')

    except ValueError as err:
        raise ValueError(f'\"{raw}\" is not a number silly') from err

    return attr, val, f'config updated! {attr} to {val}'

View File

@ -3,6 +3,7 @@
import logging import logging
from datetime import datetime from datetime import datetime
from functools import partial
import pynng import pynng
@ -17,20 +18,21 @@ from . import *
PREFIX = 'tg' PREFIX = 'tg'
async def run_skynet_telegram(tg_token: str): async def run_skynet_telegram(
tg_token: str
):
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
bot = AsyncTeleBot(tg_token) bot = AsyncTeleBot(tg_token)
with open_skynet_rpc() as rpc_sock: with open_skynet_rpc() as rpc_call:
async def _rpc_call( async def _rpc_call(
uid: int, uid: int,
method: str, method: str,
params: dict params: dict
): ):
return await rpc_call( return await rpc_call(f'{PREFIX}+{uid}', method, params)
rpc_sock, f'{PREFIX}+{uid}', method, params)
@bot.message_handler(commands=['help']) @bot.message_handler(commands=['help'])
async def send_help(message): async def send_help(message):
@ -58,79 +60,19 @@ async def run_skynet_telegram(tg_token: str):
@bot.message_handler(commands=['config']) @bot.message_handler(commands=['config'])
async def set_config(message): async def set_config(message):
params = message.text.split(' ')
rpc_params = {} rpc_params = {}
if len(params) < 3:
bot.reply_to(message, 'wrong msg format')
else:
try: try:
attr = params[1] attr, val, reply_txt = validate_user_config_request(
message.text)
if attr == 'algo':
val = params[2]
if val not in ALGOS:
raise ConfigUnknownAlgorithm
elif attr == 'step':
val = int(params[2])
val = max(min(val, MAX_STEP), MIN_STEP)
elif attr == 'width':
val = max(min(int(params[2]), MAX_WIDTH), 16)
if val % 8 != 0:
raise ConfigSizeDivisionByEight
elif attr == 'height':
val = max(min(int(params[2]), MAX_HEIGHT), 16)
if val % 8 != 0:
raise ConfigSizeDivisionByEight
elif attr == 'seed':
val = params[2]
if val == 'auto':
val = None
else:
val = int(params[2])
elif attr == 'guidance':
val = float(params[2])
val = max(min(val, MAX_GUIDANCE), 0)
elif attr == 'upscaler':
val = params[2]
if val == 'off':
val = None
elif val != 'x4':
raise ConfigUnknownUpscaler
else:
raise ConfigUnknownAttribute
resp = await _rpc_call( resp = await _rpc_call(
message.from_user.id, message.from_user.id,
'config', {'attr': attr, 'val': val}) 'config', {'attr': attr, 'val': val})
reply_txt = f'config updated! {attr} to {val}' except BaseException as e:
reply_text = e.message
except ConfigUnknownAlgorithm:
reply_txt = f'no algo named {val}'
except ConfigUnknownAttribute:
reply_txt = f'\"{attr}\" not a configurable parameter'
except ConfigUnknownUpscaler:
reply_txt = f'\"{val}\" is not a valid upscaler'
except ConfigSizeDivisionByEight:
reply_txt = 'size must be divisible by 8!'
except ValueError:
reply_txt = f'\"{val}\" is not a number silly'
finally:
await bot.reply_to(message, reply_txt) await bot.reply_to(message, reply_txt)
@bot.message_handler(commands=['stats']) @bot.message_handler(commands=['stats'])

90
tests/conftest.py 100644
View File

@ -0,0 +1,90 @@
#!/usr/bin/python
import time
import random
import string
import logging
from functools import partial
import trio
import pytest
import psycopg2
import trio_asyncio
from psycopg2.extensions import ISOLATION_LEVEL_AUTOCOMMIT
from skynet_bot.constants import *
from skynet_bot.brain import run_skynet
@pytest.fixture(scope='session')
def postgres_db(dockerctl):
    """Start a disposable postgres container and create the skynet role and db.

    Yields a ``(container, password, host)`` tuple, where ``password`` is the
    randomly generated password given to the ``DB_USER`` role and ``host`` is
    ``localhost:<mapped port>``. The container is kept running for as long
    as the fixture is active.
    """
    # superuser password for the container, and password for the skynet role
    rpassword = ''.join(
        random.choice(string.ascii_lowercase)
        for _ in range(12))
    password = ''.join(
        random.choice(string.ascii_lowercase)
        for _ in range(12))

    with dockerctl.run(
        'postgres',
        command='postgres',
        ports={'5432/tcp': None},
        environment={
            'POSTGRES_PASSWORD': rpassword
        }
    ) as containers:
        container = containers[0]
        port = container.ports['5432/tcp'][0]['HostPort']
        host = f'localhost:{port}'

        # follow the container log until postgres reports a state change
        for log in container.logs(stream=True):
            log = log.decode().rstrip()
            logging.info(log)
            if ('database system is ready to accept connections' in log or
                    'database system is shut down' in log):
                break

        # NOTE(review): the "ready to accept connections" line can show up
        # before the server is actually reachable from the host, so give it
        # a short grace period before connecting.
        time.sleep(1)
        logging.info('creating skynet db...')

        conn = psycopg2.connect(
            user='postgres',
            password=rpassword,
            host='localhost',
            port=port
        )
        try:
            logging.info('connected...')
            # CREATE DATABASE cannot run inside a transaction block
            conn.set_isolation_level(ISOLATION_LEVEL_AUTOCOMMIT)
            with conn.cursor() as cursor:
                cursor.execute(
                    f'CREATE USER {DB_USER} WITH PASSWORD \'{password}\'')
                cursor.execute(
                    f'CREATE DATABASE {DB_NAME}')
                cursor.execute(
                    f'GRANT ALL PRIVILEGES ON DATABASE {DB_NAME} TO {DB_USER}')
        finally:
            # the admin connection is only needed for setup; do not keep it
            # open for the whole test session
            conn.close()

        logging.info('done.')
        yield container, password, host
@pytest.fixture
async def skynet_running(postgres_db):
    """Run skynet against the fixture database for the duration of a test.

    ``run_skynet`` is started inside a nursery (under a trio_asyncio loop)
    with the password and host taken from ``postgres_db``; once the test
    body returns, the nursery is cancelled to stop it.
    """
    _, db_pass, db_host = postgres_db
    launch_skynet = partial(
        run_skynet,
        db_pass=db_pass,
        db_host=db_host)

    async with trio_asyncio.open_loop():
        async with trio.open_nursery() as nursery:
            await nursery.start(launch_skynet)
            yield
            nursery.cancel_scope.cancel()

View File

@ -0,0 +1,61 @@
#!/usr/bin/python
import logging
import trio
import trio_asyncio
from skynet_bot.types import *
from skynet_bot.brain import run_skynet
from skynet_bot.frontend import open_skynet_rpc
async def test_skynet_dgpu_connection_simple(skynet_running):
    """Bring one dgpu online and offline again, checking worker bookkeeping.

    Uses the ``dgpu_workers`` and ``dgpu_next`` calls to observe the number
    of registered nodes and the next-worker index before, during and after
    the registration of a single dgpu with three task slots.
    """
    async with open_skynet_rpc() as rpc_call:

        async def call_ok(method: str, *params):
            # issue the call as node 'dgpu-0', log the response and require
            # an 'ok' payload; hand the payload back for further checks
            res = await rpc_call('dgpu-0', method, *params)
            logging.info(res)
            assert 'ok' in res.result
            return res.result['ok']

        # check 0 nodes are connected and there is no next worker
        assert await call_ok('dgpu_workers') == 0
        assert await call_ok('dgpu_next') is None

        # connect 1 dgpu
        await call_ok('dgpu_online', {'max_tasks': 3})

        # check 1 node is connected and next worker is 0
        assert await call_ok('dgpu_workers') == 1
        assert await call_ok('dgpu_next') == 0

        # disconnect 1 dgpu
        await call_ok('dgpu_offline')

        # check 0 nodes are connected and next worker is None again
        assert await call_ok('dgpu_workers') == 0
        assert await call_ok('dgpu_next') is None

View File

@ -1,22 +0,0 @@
#!/usr/bin/python
import trio
import trio_asyncio
from skynet_bot.brain import run_skynet
from skynet_bot.frontend import open_skynet_rpc
from skynet_bot.frontend.telegram import run_skynet_telegram
def test_run_tg_bot():
    """Start skynet and the telegram frontend together under trio_asyncio."""
    # NOTE(review): database credentials and the bot token are hard-coded
    # below; they should be treated as leaked and rotated.
    async def _serve():
        async with trio.open_nursery() as nursery:
            # skynet is started first and awaited before the bot is spawned
            await nursery.start(
                run_skynet,
                'skynet', '3GbZd6UojbD8V7zWpeFn', 'ancap.tech:34508')

            nursery.start_soon(
                run_skynet_telegram,
                '5853245787:AAFEmv3EjJ_qJ8d_vmOpi6o6HFHUf8a0uCQ')

    trio_asyncio.run(_serve)