Started making robust testing fixtures to init fresh db and skynet

Add simple dgpu worker connection test
Make db connection handler manage schema and table init logic
Keep tweaking dgpu main handler attempting to fix subactor hangs
Change frontend open rpc logic to return a wrapped rpc_call fn referencing the new socket
Decoupled user config request validation from telegram module
Fix next_worker logic, now takes into account multiple tasks per dgpu
Add dgpu_workers and dgpu_next calls
Fixed readme, moved db init code into db module
pull/2/head
Guillermo Rodriguez 2022-12-11 11:02:55 -03:00
parent c3852314a7
commit 9afb192251
No known key found for this signature in database
GPG Key ID: EC3AB66D5D83B392
11 changed files with 387 additions and 224 deletions

View File

@ -1,47 +1,2 @@
create db in postgres:
```sql
CREATE USER skynet WITH PASSWORD 'password';
CREATE DATABASE skynet_art_bot;
GRANT ALL PRIVILEGES ON DATABASE skynet_art_bot TO skynet;
CREATE SCHEMA IF NOT EXISTS skynet;
CREATE TABLE IF NOT EXISTS skynet.user(
id SERIAL PRIMARY KEY NOT NULL,
tg_id INT,
wp_id VARCHAR(128),
mx_id VARCHAR(128),
ig_id VARCHAR(128),
generated INT NOT NULL,
joined DATE NOT NULL,
last_prompt TEXT,
role VARCHAR(128) NOT NULL
);
ALTER TABLE skynet.user
ADD CONSTRAINT tg_unique
UNIQUE (tg_id);
ALTER TABLE skynet.user
ADD CONSTRAINT wp_unique
UNIQUE (wp_id);
ALTER TABLE skynet.user
ADD CONSTRAINT mx_unique
UNIQUE (mx_id);
ALTER TABLE skynet.user
ADD CONSTRAINT ig_unique
UNIQUE (ig_id);
CREATE TABLE IF NOT EXISTS skynet.user_config(
id SERIAL NOT NULL,
algo VARCHAR(128) NOT NULL,
step INT NOT NULL,
width INT NOT NULL,
height INT NOT NULL,
seed INT,
guidance INT NOT NULL,
upscaler VARCHAR(128)
);
ALTER TABLE skynet.user_config
ADD FOREIGN KEY(id)
REFERENCES skynet.user(id);
```
# skynet
### decentralized compute platform

View File

@ -1,2 +1,5 @@
pytest
psycopg2
pytest-trio
git+https://github.com/tgoodlet/pytest-dockerctl.git@master#egg=pytest-dockerctl

View File

@ -30,30 +30,37 @@ async def rpc_service(sock, dgpu_bus, db_pool):
wip_reqs = {}
fin_reqs = {}
def are_all_workers_busy():
for nid, info in nodes.items():
if info['task'] == None:
def is_worker_busy(nid: int):
for task in nodes[nid]['tasks']:
if task != None:
return False
return True
next_worker = 0
def are_all_workers_busy():
for nid in nodes.keys():
if not is_worker_busy(nid):
return False
return True
next_worker: Optional[int] = None
def get_next_worker():
nonlocal next_worker
if len(nodes) == 0:
if not next_worker:
raise SkynetDGPUOffline
if are_all_workers_busy():
raise SkynetDGPUOverloaded
next_worker += 1
while is_worker_busy(next_worker):
next_worker += 1
if next_worker >= len(nodes):
next_worker = 0
if next_worker >= len(nodes):
next_worker = 0
nid = list(nodes.keys())[next_worker]
return nid
return next_worker
async def dgpu_image_streamer():
nonlocal wip_reqs, fin_reqs
@ -74,7 +81,8 @@ async def rpc_service(sock, dgpu_bus, db_pool):
event = trio.Event()
wip_reqs[rid] = event
nodes[nid]['task'] = rid
tid = nodes[nid]['tasks'].index(None)
nodes[nid]['tasks'][tid] = rid
dgpu_req = DGPUBusRequest(
rid=rid,
@ -89,7 +97,7 @@ async def rpc_service(sock, dgpu_bus, db_pool):
await event.wait()
nodes[nid]['task'] = None
nodes[nid]['tasks'][tid] = None
img = fin_reqs[rid]
del fin_reqs[rid]
@ -167,10 +175,9 @@ async def rpc_service(sock, dgpu_bus, db_pool):
except BaseException as e:
logging.error(e)
raise e
# result = {
# 'error': 'skynet_internal_error'
# }
result = {
'error': 'skynet_internal_error'
}
await rpc_ctx.asend(
json.dumps(
@ -187,21 +194,36 @@ async def rpc_service(sock, dgpu_bus, db_pool):
logging.info(req)
result = {}
if req.method == 'dgpu_online':
nodes[req.uid] = {
'task': None
'tasks': [None for _ in range(req.params['max_tasks'])],
'max_tasks': req.params['max_tasks']
}
logging.info(f'dgpu online: {req.uid}')
if not next_worker:
next_worker = 0
elif req.method == 'dgpu_offline':
i = nodes.values().index(req.uid)
i = list(nodes.keys()).index(req.uid)
del nodes[req.uid]
if i < next_worker:
next_worker -= 1
if len(nodes) == 0:
next_worker = None
logging.info(f'dgpu offline: {req.uid}')
elif req.method == 'dgpu_workers':
result = len(nodes)
elif req.method == 'dgpu_next':
result = next_worker
else:
n.start_soon(
handle_user_request, ctx, req)
@ -210,12 +232,12 @@ async def rpc_service(sock, dgpu_bus, db_pool):
await ctx.asend(
json.dumps(
SkynetRPCResponse(
result={'ok': {}}).to_dict()).encode())
result={'ok': result}).to_dict()).encode())
async def run_skynet(
db_user: str,
db_pass: str,
db_user: str = DB_USER,
db_pass: str = DB_PASS,
db_host: str = DB_HOST,
rpc_address: str = DEFAULT_RPC_ADDR,
dgpu_address: str = DEFAULT_DGPU_ADDR,

View File

@ -3,6 +3,9 @@
API_TOKEN = '5880619053:AAFge2UfObw1kCn9Kb7AAyqaIHs_HgM0Fx0'
DB_HOST = 'ancap.tech:34508'
DB_USER = 'skynet'
DB_PASS = 'password'
DB_NAME = 'skynet'
ALGOS = {
'stable': 'runwayml/stable-diffusion-v1-5',

View File

@ -11,6 +11,49 @@ import triopg
from .constants import *
# Idempotent schema bootstrap for the skynet database.
#
# Every constraint is declared inline inside CREATE TABLE IF NOT EXISTS
# instead of through separate ALTER TABLE ... ADD CONSTRAINT statements:
# ADD CONSTRAINT has no IF NOT EXISTS form, so running the script a second
# time against an already initialised database would fail on the existing
# constraint names (and an unnamed ADD FOREIGN KEY would be added again on
# each run).  Constraint names are kept as they were; the foreign key is
# given an explicit name.
DB_INIT_SQL = '''
CREATE SCHEMA IF NOT EXISTS skynet;

CREATE TABLE IF NOT EXISTS skynet.user(
    id SERIAL PRIMARY KEY NOT NULL,
    tg_id INT,
    wp_id VARCHAR(128),
    mx_id VARCHAR(128),
    ig_id VARCHAR(128),
    generated INT NOT NULL,
    joined DATE NOT NULL,
    last_prompt TEXT,
    role VARCHAR(128) NOT NULL,
    CONSTRAINT tg_unique UNIQUE (tg_id),
    CONSTRAINT wp_unique UNIQUE (wp_id),
    CONSTRAINT mx_unique UNIQUE (mx_id),
    CONSTRAINT ig_unique UNIQUE (ig_id)
);

CREATE TABLE IF NOT EXISTS skynet.user_config(
    id SERIAL NOT NULL,
    algo VARCHAR(128) NOT NULL,
    step INT NOT NULL,
    width INT NOT NULL,
    height INT NOT NULL,
    seed INT,
    guidance INT NOT NULL,
    upscaler VARCHAR(128),
    CONSTRAINT user_config_id_fkey FOREIGN KEY (id)
        REFERENCES skynet.user(id)
);
'''
def try_decode_uid(uid: str):
try:
proto, uid = uid.split('+')
@ -24,14 +67,18 @@ def try_decode_uid(uid: str):
@acm
async def open_database_connection(
db_user: str,
db_pass: str,
db_user: str = DB_USER,
db_pass: str = DB_PASS,
db_host: str = DB_HOST,
db_name: str = DB_NAME
):
async with triopg.create_pool(
dsn=f'postgres://{db_user}:{db_pass}@{db_host}/skynet_art_bot'
) as conn:
yield conn
dsn=f'postgres://{db_user}:{db_pass}@{db_host}/{db_name}'
) as pool_conn:
async with pool_conn.acquire() as conn:
await conn.execute(DB_INIT_SQL)
yield pool_conn
async def get_user(conn, uid: str):

View File

@ -65,58 +65,53 @@ async def open_dgpu_node(
return img
with (
pynng.Req0(dial=rpc_address) as rpc_sock,
pynng.Bus0(dial=dgpu_address) as dgpu_sock
):
async def _rpc_call(*args, **kwargs):
return await rpc_call(rpc_sock, *args, **kwargs)
async with open_skynet_rpc() as rpc_call:
with pynng.Bus0(dial=dgpu_address) as dgpu_sock:
async def _process_dgpu_req(req: DGPUBusRequest):
img = await gpu_compute_one(
ImageGenRequest(**req.params))
await dgpu_sock.asend(
bytes.fromhex(req.rid) + img)
async def _process_dgpu_req(req: DGPUBusRequest):
img = await gpu_compute_one(
ImageGenRequest(**req.params))
await dgpu_sock.asend(
bytes.fromhex(req.rid) + img)
res = await rpc_call(
name.hex, 'dgpu_online', {'max_tasks': dgpu_max_tasks})
logging.info(res)
assert 'ok' in res.result
res = await _rpc_call(
name.hex, 'dgpu_online', {'max_tasks': dgpu_max_tasks})
logging.info(res)
assert 'ok' in res.result
async with (
tractor.open_actor_cluster(
modules=['skynet_bot.gpu'],
count=dgpu_max_tasks,
names=[i for i in range(dgpu_max_tasks)]
) as portal_map,
trio.open_nursery() as n
):
logging.info(f'starting {dgpu_max_tasks} gpu workers')
async with tractor.gather_contexts((
portal.open_context(
open_gpu_worker, algo, 1.0 / dgpu_max_tasks)
for portal in portal_map.values()
)) as contexts:
contexts = {i: ctx for i, ctx in enumerate(contexts)}
for i, ctx in contexts.items():
n.start_soon(
gpu_streamer, ctx, i)
try:
while True:
msg = await dgpu_sock.arecv()
req = DGPUBusRequest(
**json.loads(msg.decode()))
if req.nid != name.hex:
continue
logging.info(f'dgpu: {name}, req: {req}')
async with (
tractor.open_actor_cluster(
modules=['skynet_bot.gpu'],
count=dgpu_max_tasks,
names=[i for i in range(dgpu_max_tasks)]
) as portal_map,
trio.open_nursery() as n
):
logging.info(f'starting {dgpu_max_tasks} gpu workers')
async with tractor.gather_contexts((
portal.open_context(
open_gpu_worker, algo, 1.0 / dgpu_max_tasks)
for portal in portal_map.values()
)) as contexts:
contexts = {i: ctx for i, ctx in enumerate(contexts)}
for i, ctx in contexts.items():
n.start_soon(
_process_dgpu_req, req)
gpu_streamer, ctx, i)
try:
while True:
msg = await dgpu_sock.arecv()
req = DGPUBusRequest(
**json.loads(msg.decode()))
except KeyboardInterrupt:
...
if req.nid != name.hex:
continue
res = await _rpc_call(name.hex, 'dgpu_offline')
logging.info(res)
assert 'ok' in res.result
logging.info(f'dgpu: {name}, req: {req}')
n.start_soon(
_process_dgpu_req, req)
except KeyboardInterrupt:
...
res = await rpc_call(name.hex, 'dgpu_offline')
logging.info(res)
assert 'ok' in res.result

View File

@ -3,14 +3,17 @@
import json
from typing import Union
from contextlib import contextmanager as cm
from contextlib import asynccontextmanager as acm
import pynng
from ..types import SkynetRPCRequest, SkynetRPCResponse
from ..constants import DEFAULT_RPC_ADDR
from ..constants import *
class ConfigRequestFormatError(BaseException):
    """Raised when a config request does not have the expected fields."""
class ConfigUnknownAttribute(BaseException):
    """Raised when a config request names a non-configurable attribute."""
@ -44,7 +47,71 @@ async def rpc_call(
(await sock.arecv_msg()).bytes.decode()))
@cm
def open_skynet_rpc(rpc_address: str = DEFAULT_RPC_ADDR):
with pynng.Req0(dial=rpc_address) as rpc_sock:
yield rpc_sock
@acm
async def open_skynet_rpc(rpc_address: str = DEFAULT_RPC_ADDR):
    """Dial the skynet rpc address and yield a call function bound to it.

    The yielded coroutine function forwards its arguments to ``rpc_call``
    with the freshly opened request socket prepended, so callers never
    handle the socket themselves.  The socket is released when the context
    exits.
    """
    with pynng.Req0(dial=rpc_address) as req_sock:

        async def bound_rpc_call(*call_args, **call_kwargs):
            response = await rpc_call(req_sock, *call_args, **call_kwargs)
            return response

        yield bound_rpc_call
def validate_user_config_request(req: str):
    """Parse and validate a user config request string.

    The request is split on single spaces.  The first field is ignored
    here, the second is the attribute name and the third is its new value.

    Returns:
        A ``(attr, val, reply_txt)`` tuple: the attribute name, the parsed
        (and, for numeric attributes, clamped) value, and a confirmation
        message.

    Raises:
        ConfigRequestFormatError: the request has fewer than three fields.
        ConfigUnknownAlgorithm: ``algo`` value is not a key of ``ALGOS``.
        ConfigSizeDivisionByEight: ``width``/``height`` is not a multiple
            of 8 after clamping.
        ConfigUnknownUpscaler: ``upscaler`` is neither ``off`` nor ``x4``.
        ConfigUnknownAttribute: the attribute is not configurable.
        ValueError: a numeric attribute was given a non-numeric value.
    """
    params = req.split(' ')

    if len(params) < 3:
        raise ConfigRequestFormatError('config request format incorrect')

    attr, raw = params[1], params[2]

    try:
        if attr == 'algo':
            val = raw
            if val not in ALGOS:
                raise ConfigUnknownAlgorithm(f'no algo named {val}')

        elif attr == 'step':
            val = int(raw)
            val = max(min(val, MAX_STEP), MIN_STEP)

        elif attr == 'width':
            val = max(min(int(raw), MAX_WIDTH), 16)
            if val % 8 != 0:
                raise ConfigSizeDivisionByEight(
                    'size must be divisible by 8!')

        elif attr == 'height':
            val = max(min(int(raw), MAX_HEIGHT), 16)
            if val % 8 != 0:
                raise ConfigSizeDivisionByEight(
                    'size must be divisible by 8!')

        elif attr == 'seed':
            # 'auto' clears the seed; anything else must be an integer
            val = None if raw == 'auto' else int(raw)

        elif attr == 'guidance':
            val = float(raw)
            val = max(min(val, MAX_GUIDANCE), 0)

        elif attr == 'upscaler':
            val = raw
            if val == 'off':
                val = None
            elif val != 'x4':
                raise ConfigUnknownUpscaler(
                    f'\"{val}\" is not a valid upscaler')

        else:
            raise ConfigUnknownAttribute(
                f'\"{attr}\" not a configurable parameter')

        return attr, val, f'config updated! {attr} to {val}'

    except ValueError:
        # Report the raw token: when the int()/float() conversion is what
        # failed, ``val`` has not been assigned yet, so formatting it here
        # would raise UnboundLocalError instead of the intended error.
        raise ValueError(f'\"{raw}\" is not a number silly')

View File

@ -3,6 +3,7 @@
import logging
from datetime import datetime
from functools import partial
import pynng
@ -17,20 +18,21 @@ from . import *
PREFIX = 'tg'
async def run_skynet_telegram(tg_token: str):
async def run_skynet_telegram(
tg_token: str
):
logging.basicConfig(level=logging.INFO)
bot = AsyncTeleBot(tg_token)
with open_skynet_rpc() as rpc_sock:
with open_skynet_rpc() as rpc_call:
async def _rpc_call(
uid: int,
method: str,
params: dict
):
return await rpc_call(
rpc_sock, f'{PREFIX}+{uid}', method, params)
return await rpc_call(f'{PREFIX}+{uid}', method, params)
@bot.message_handler(commands=['help'])
async def send_help(message):
@ -58,79 +60,19 @@ async def run_skynet_telegram(tg_token: str):
@bot.message_handler(commands=['config'])
async def set_config(message):
params = message.text.split(' ')
rpc_params = {}
try:
attr, val, reply_txt = validate_user_config_request(
message.text)
if len(params) < 3:
bot.reply_to(message, 'wrong msg format')
resp = await _rpc_call(
message.from_user.id,
'config', {'attr': attr, 'val': val})
else:
try:
attr = params[1]
if attr == 'algo':
val = params[2]
if val not in ALGOS:
raise ConfigUnknownAlgorithm
elif attr == 'step':
val = int(params[2])
val = max(min(val, MAX_STEP), MIN_STEP)
elif attr == 'width':
val = max(min(int(params[2]), MAX_WIDTH), 16)
if val % 8 != 0:
raise ConfigSizeDivisionByEight
elif attr == 'height':
val = max(min(int(params[2]), MAX_HEIGHT), 16)
if val % 8 != 0:
raise ConfigSizeDivisionByEight
elif attr == 'seed':
val = params[2]
if val == 'auto':
val = None
else:
val = int(params[2])
elif attr == 'guidance':
val = float(params[2])
val = max(min(val, MAX_GUIDANCE), 0)
elif attr == 'upscaler':
val = params[2]
if val == 'off':
val = None
elif val != 'x4':
raise ConfigUnknownUpscaler
else:
raise ConfigUnknownAttribute
resp = await _rpc_call(
message.from_user.id,
'config', {'attr': attr, 'val': val})
reply_txt = f'config updated! {attr} to {val}'
except ConfigUnknownAlgorithm:
reply_txt = f'no algo named {val}'
except ConfigUnknownAttribute:
reply_txt = f'\"{attr}\" not a configurable parameter'
except ConfigUnknownUpscaler:
reply_txt = f'\"{val}\" is not a valid upscaler'
except ConfigSizeDivisionByEight:
reply_txt = 'size must be divisible by 8!'
except ValueError:
reply_txt = f'\"{val}\" is not a number silly'
except BaseException as e:
reply_text = e.message
finally:
await bot.reply_to(message, reply_txt)
@bot.message_handler(commands=['stats'])

90
tests/conftest.py 100644
View File

@ -0,0 +1,90 @@
#!/usr/bin/python
import time
import random
import string
import logging
from functools import partial
import trio
import pytest
import psycopg2
import trio_asyncio
from psycopg2.extensions import ISOLATION_LEVEL_AUTOCOMMIT
from skynet_bot.constants import *
from skynet_bot.brain import run_skynet
@pytest.fixture(scope='session')
def postgres_db(dockerctl):
    """Start a throwaway postgres container and create the skynet db in it.

    A random password is generated for the postgres superuser and another
    one for ``DB_USER``.  Once the container logs report that the server is
    up, ``DB_USER`` and ``DB_NAME`` are created through a short-lived admin
    connection.

    Yields:
        ``(container, password, host)`` where ``password`` belongs to
        ``DB_USER`` and ``host`` is ``localhost:<mapped port>``.
    """

    def _random_password(length: int = 12) -> str:
        # Random lowercase string; only used for the disposable test db.
        return ''.join(
            random.choice(string.ascii_lowercase)
            for _ in range(length))

    rpassword = _random_password()  # postgres superuser
    password = _random_password()   # DB_USER

    with dockerctl.run(
        'postgres',
        command='postgres',
        ports={'5432/tcp': None},
        environment={
            'POSTGRES_PASSWORD': rpassword
        }
    ) as containers:
        container = containers[0]
        port = container.ports['5432/tcp'][0]['HostPort']
        host = f'localhost:{port}'

        # Follow the container logs until the server reports a state change.
        for log in container.logs(stream=True):
            log = log.decode().rstrip()
            logging.info(log)
            if ('database system is ready to accept connections' in log or
                    'database system is shut down' in log):
                break

        # The "ready" message can show up before connections are actually
        # accepted, hence the extra grace period.
        time.sleep(1)
        logging.info('creating skynet db...')

        conn = psycopg2.connect(
            user='postgres',
            password=rpassword,
            host='localhost',
            port=port
        )
        try:
            logging.info('connected...')
            # CREATE DATABASE cannot run inside a transaction block.
            conn.set_isolation_level(ISOLATION_LEVEL_AUTOCOMMIT)

            with conn.cursor() as cursor:
                cursor.execute(
                    f'CREATE USER {DB_USER} WITH PASSWORD \'{password}\'')
                cursor.execute(
                    f'CREATE DATABASE {DB_NAME}')
                cursor.execute(
                    f'GRANT ALL PRIVILEGES ON DATABASE {DB_NAME} TO {DB_USER}')
        finally:
            # The admin connection is only needed for setup; release it
            # instead of leaking it for the rest of the session.
            conn.close()

        logging.info('done.')
        yield container, password, host
@pytest.fixture
async def skynet_running(postgres_db):
    """Run skynet against the fixture database for the duration of a test.

    ``run_skynet`` is started inside a nursery under a trio_asyncio loop
    with the password and host provided by ``postgres_db``; the nursery is
    cancelled once the test hands control back.
    """
    _, db_pass, db_host = postgres_db

    async with trio_asyncio.open_loop():
        async with trio.open_nursery() as nursery:
            launch = partial(
                run_skynet,
                db_pass=db_pass,
                db_host=db_host)
            await nursery.start(launch)

            yield

            nursery.cancel_scope.cancel()

View File

@ -0,0 +1,61 @@
#!/usr/bin/python
import logging
import trio
import trio_asyncio
from skynet_bot.types import *
from skynet_bot.brain import run_skynet
from skynet_bot.frontend import open_skynet_rpc
async def test_skynet_dgpu_connection_simple(skynet_running):
    """Check worker bookkeeping while one dgpu connects and disconnects.

    Walks through: no workers -> one worker online -> no workers again,
    verifying ``dgpu_workers`` and ``dgpu_next`` at every step.
    """
    async with open_skynet_rpc() as rpc_call:

        async def call_ok(method, *params):
            # Issue one rpc as 'dgpu-0', log it, and return the 'ok' payload.
            res = await rpc_call('dgpu-0', method, *params)
            logging.info(res)
            assert 'ok' in res.result
            return res.result['ok']

        # check 0 nodes are connected
        assert await call_ok('dgpu_workers') == 0

        # check next worker is None
        assert await call_ok('dgpu_next') is None

        # connect 1 dgpu
        await call_ok('dgpu_online', {'max_tasks': 3})

        # check 1 node is connected
        assert await call_ok('dgpu_workers') == 1

        # check next worker is 0
        assert await call_ok('dgpu_next') == 0

        # disconnect 1 dgpu
        await call_ok('dgpu_offline')

        # check 0 nodes are connected
        assert await call_ok('dgpu_workers') == 0

        # check next worker is None
        assert await call_ok('dgpu_next') is None

View File

@ -1,22 +0,0 @@
#!/usr/bin/python
import trio
import trio_asyncio
from skynet_bot.brain import run_skynet
from skynet_bot.frontend import open_skynet_rpc
from skynet_bot.frontend.telegram import run_skynet_telegram
def test_run_tg_bot():
async def main():
async with trio.open_nursery() as n:
await n.start(
run_skynet,
'skynet', '3GbZd6UojbD8V7zWpeFn', 'ancap.tech:34508')
n.start_soon(
run_skynet_telegram,
'5853245787:AAFEmv3EjJ_qJ8d_vmOpi6o6HFHUf8a0uCQ')
trio_asyncio.run(main)