From 9afb192251409b76c6046dec59c1b3f1b5fbf17a Mon Sep 17 00:00:00 2001 From: Guillermo Rodriguez Date: Sun, 11 Dec 2022 11:02:55 -0300 Subject: [PATCH] Started making roboust testing fixtures to init fresh db and skynet Add simple dgpu worker connection test Make db connection handler manage schema and table init logic Keep tweaking dgpu main handler attemtping to fix subactor hangs Change frontend open rpc logic to return a wrapped rpc_call fn referencing the new socket Decupled user config request validation from telegram module Fix next_worker logic, now takes in account multiple tasks per dgpu Add dgpu_workers and dgpu_next calls Fixed readme, moved db init code into db module --- README.md | 49 +---------------- requirements.test.txt | 3 ++ skynet_bot/brain.py | 64 ++++++++++++++-------- skynet_bot/constants.py | 3 ++ skynet_bot/db.py | 57 ++++++++++++++++++-- skynet_bot/dgpu.py | 95 ++++++++++++++++----------------- skynet_bot/frontend/__init__.py | 79 ++++++++++++++++++++++++--- skynet_bot/frontend/telegram.py | 88 ++++++------------------------ tests/conftest.py | 90 +++++++++++++++++++++++++++++++ tests/test_skynet.py | 61 +++++++++++++++++++++ tests/test_telegram.py | 22 -------- 11 files changed, 387 insertions(+), 224 deletions(-) create mode 100644 tests/conftest.py create mode 100644 tests/test_skynet.py delete mode 100644 tests/test_telegram.py diff --git a/README.md b/README.md index a4e8a19..483bdbd 100644 --- a/README.md +++ b/README.md @@ -1,47 +1,2 @@ -create db in postgres: - -```sql -CREATE USER skynet WITH PASSWORD 'password'; -CREATE DATABASE skynet_art_bot; -GRANT ALL PRIVILEGES ON DATABASE skynet_art_bot TO skynet; - -CREATE SCHEMA IF NOT EXISTS skynet; - -CREATE TABLE IF NOT EXISTS skynet.user( - id SERIAL PRIMARY KEY NOT NULL, - tg_id INT, - wp_id VARCHAR(128), - mx_id VARCHAR(128), - ig_id VARCHAR(128), - generated INT NOT NULL, - joined DATE NOT NULL, - last_prompt TEXT, - role VARCHAR(128) NOT NULL -); -ALTER TABLE skynet.user - ADD CONSTRAINT tg_unique - UNIQUE (tg_id); -ALTER TABLE skynet.user - ADD CONSTRAINT wp_unique - UNIQUE (wp_id); -ALTER TABLE skynet.user - ADD CONSTRAINT mx_unique - UNIQUE (mx_id); -ALTER TABLE skynet.user - ADD CONSTRAINT ig_unique - UNIQUE (ig_id); - -CREATE TABLE IF NOT EXISTS skynet.user_config( - id SERIAL NOT NULL, - algo VARCHAR(128) NOT NULL, - step INT NOT NULL, - width INT NOT NULL, - height INT NOT NULL, - seed INT, - guidance INT NOT NULL, - upscaler VARCHAR(128) -); -ALTER TABLE skynet.user_config - ADD FOREIGN KEY(id) - REFERENCES skynet.user(id); -``` +# skynet +### decentralized compute platform diff --git a/requirements.test.txt b/requirements.test.txt index 5f0802d..48af1de 100644 --- a/requirements.test.txt +++ b/requirements.test.txt @@ -1,2 +1,5 @@ pytest +psycopg2 pytest-trio + +git+https://github.com/tgoodlet/pytest-dockerctl.git@master#egg=pytest-dockerctl diff --git a/skynet_bot/brain.py b/skynet_bot/brain.py index 12563e2..37f00dc 100644 --- a/skynet_bot/brain.py +++ b/skynet_bot/brain.py @@ -30,30 +30,37 @@ async def rpc_service(sock, dgpu_bus, db_pool): wip_reqs = {} fin_reqs = {} - def are_all_workers_busy(): - for nid, info in nodes.items(): - if info['task'] == None: + def is_worker_busy(nid: int): + for task in nodes[nid]['tasks']: + if task != None: return False return True - next_worker = 0 + def are_all_workers_busy(): + for nid in nodes.keys(): + if not is_worker_busy(nid): + return False + + return True + + next_worker: Optional[int] = None def get_next_worker(): nonlocal next_worker - if len(nodes) == 0: + if not next_worker: raise SkynetDGPUOffline if are_all_workers_busy(): raise SkynetDGPUOverloaded - next_worker += 1 + while is_worker_busy(next_worker): + next_worker += 1 - if next_worker >= len(nodes): - next_worker = 0 + if next_worker >= len(nodes): + next_worker = 0 - nid = list(nodes.keys())[next_worker] - return nid + return next_worker async def dgpu_image_streamer(): nonlocal wip_reqs, fin_reqs @@ -74,7 +81,8 @@ async def rpc_service(sock, dgpu_bus, db_pool): event = trio.Event() wip_reqs[rid] = event - nodes[nid]['task'] = rid + tid = nodes[nid]['tasks'].index(None) + nodes[nid]['tasks'][tid] = rid dgpu_req = DGPUBusRequest( rid=rid, @@ -89,7 +97,7 @@ async def rpc_service(sock, dgpu_bus, db_pool): await event.wait() - nodes[nid]['task'] = None + nodes[nid]['tasks'][tid] = None img = fin_reqs[rid] del fin_reqs[rid] @@ -167,10 +175,9 @@ async def rpc_service(sock, dgpu_bus, db_pool): except BaseException as e: logging.error(e) - raise e - # result = { - # 'error': 'skynet_internal_error' - # } + result = { + 'error': 'skynet_internal_error' + } await rpc_ctx.asend( json.dumps( @@ -187,21 +194,36 @@ async def rpc_service(sock, dgpu_bus, db_pool): logging.info(req) + result = {} + if req.method == 'dgpu_online': nodes[req.uid] = { - 'task': None + 'tasks': [None for _ in range(req.params['max_tasks'])], + 'max_tasks': req.params['max_tasks'] } logging.info(f'dgpu online: {req.uid}') + if not next_worker: + next_worker = 0 elif req.method == 'dgpu_offline': - i = nodes.values().index(req.uid) + i = list(nodes.keys()).index(req.uid) del nodes[req.uid] if i < next_worker: next_worker -= 1 + + if len(nodes) == 0: + next_worker = None + logging.info(f'dgpu offline: {req.uid}') + elif req.method == 'dgpu_workers': + result = len(nodes) + + elif req.method == 'dgpu_next': + result = next_worker + else: n.start_soon( handle_user_request, ctx, req) @@ -210,12 +232,12 @@ async def rpc_service(sock, dgpu_bus, db_pool): await ctx.asend( json.dumps( SkynetRPCResponse( - result={'ok': {}}).to_dict()).encode()) + result={'ok': result}).to_dict()).encode()) async def run_skynet( - db_user: str, - db_pass: str, + db_user: str = DB_USER, + db_pass: str = DB_PASS, db_host: str = DB_HOST, rpc_address: str = DEFAULT_RPC_ADDR, dgpu_address: str = DEFAULT_DGPU_ADDR, diff --git a/skynet_bot/constants.py b/skynet_bot/constants.py index a7b21ae..0c0c03b 100644 --- a/skynet_bot/constants.py +++ b/skynet_bot/constants.py @@ -3,6 +3,9 @@ API_TOKEN = '5880619053:AAFge2UfObw1kCn9Kb7AAyqaIHs_HgM0Fx0' DB_HOST = 'ancap.tech:34508' +DB_USER = 'skynet' +DB_PASS = 'password' +DB_NAME = 'skynet' ALGOS = { 'stable': 'runwayml/stable-diffusion-v1-5', diff --git a/skynet_bot/db.py b/skynet_bot/db.py index d5c94d7..9998e77 100644 --- a/skynet_bot/db.py +++ b/skynet_bot/db.py @@ -11,6 +11,49 @@ import triopg from .constants import * +DB_INIT_SQL = ''' +CREATE SCHEMA IF NOT EXISTS skynet; + +CREATE TABLE IF NOT EXISTS skynet.user( + id SERIAL PRIMARY KEY NOT NULL, + tg_id INT, + wp_id VARCHAR(128), + mx_id VARCHAR(128), + ig_id VARCHAR(128), + generated INT NOT NULL, + joined DATE NOT NULL, + last_prompt TEXT, + role VARCHAR(128) NOT NULL +); +ALTER TABLE skynet.user + ADD CONSTRAINT tg_unique + UNIQUE (tg_id); +ALTER TABLE skynet.user + ADD CONSTRAINT wp_unique + UNIQUE (wp_id); +ALTER TABLE skynet.user + ADD CONSTRAINT mx_unique + UNIQUE (mx_id); +ALTER TABLE skynet.user + ADD CONSTRAINT ig_unique + UNIQUE (ig_id); + +CREATE TABLE IF NOT EXISTS skynet.user_config( + id SERIAL NOT NULL, + algo VARCHAR(128) NOT NULL, + step INT NOT NULL, + width INT NOT NULL, + height INT NOT NULL, + seed INT, + guidance INT NOT NULL, + upscaler VARCHAR(128) +); +ALTER TABLE skynet.user_config + ADD FOREIGN KEY(id) + REFERENCES skynet.user(id); +''' + + def try_decode_uid(uid: str): try: proto, uid = uid.split('+') @@ -24,14 +67,18 @@ def try_decode_uid(uid: str): @acm async def open_database_connection( - db_user: str, - db_pass: str, + db_user: str = DB_USER, + db_pass: str = DB_PASS, db_host: str = DB_HOST, + db_name: str = DB_NAME ): async with triopg.create_pool( - dsn=f'postgres://{db_user}:{db_pass}@{db_host}/skynet_art_bot' - ) as conn: - yield conn + dsn=f'postgres://{db_user}:{db_pass}@{db_host}/{db_name}' + ) as pool_conn: + async with pool_conn.acquire() as conn: + await conn.execute(DB_INIT_SQL) + + yield pool_conn async def get_user(conn, uid: str): diff --git a/skynet_bot/dgpu.py b/skynet_bot/dgpu.py index c454717..8c019e4 100644 --- a/skynet_bot/dgpu.py +++ b/skynet_bot/dgpu.py @@ -65,58 +65,53 @@ async def open_dgpu_node( return img - with ( - pynng.Req0(dial=rpc_address) as rpc_sock, - pynng.Bus0(dial=dgpu_address) as dgpu_sock - ): - async def _rpc_call(*args, **kwargs): - return await rpc_call(rpc_sock, *args, **kwargs) + async with open_skynet_rpc() as rpc_call: + with pynng.Bus0(dial=dgpu_address) as dgpu_sock: + async def _process_dgpu_req(req: DGPUBusRequest): + img = await gpu_compute_one( + ImageGenRequest(**req.params)) + await dgpu_sock.asend( + bytes.fromhex(req.rid) + img) - async def _process_dgpu_req(req: DGPUBusRequest): - img = await gpu_compute_one( - ImageGenRequest(**req.params)) - await dgpu_sock.asend( - bytes.fromhex(req.rid) + img) + res = await rpc_call( + name.hex, 'dgpu_online', {'max_tasks': dgpu_max_tasks}) + logging.info(res) + assert 'ok' in res.result - res = await _rpc_call( - name.hex, 'dgpu_online', {'max_tasks': dgpu_max_tasks}) - logging.info(res) - assert 'ok' in res.result - - async with ( - tractor.open_actor_cluster( - modules=['skynet_bot.gpu'], - count=dgpu_max_tasks, - names=[i for i in range(dgpu_max_tasks)] - ) as portal_map, - trio.open_nursery() as n - ): - logging.info(f'starting {dgpu_max_tasks} gpu workers') - async with tractor.gather_contexts(( - portal.open_context( - open_gpu_worker, algo, 1.0 / dgpu_max_tasks) - for portal in portal_map.values() - )) as contexts: - contexts = {i: ctx for i, ctx in enumerate(contexts)} - for i, ctx in contexts.items(): - n.start_soon( - gpu_streamer, ctx, i) - try: - while True: - msg = await dgpu_sock.arecv() - req = DGPUBusRequest( - **json.loads(msg.decode())) - - if req.nid != name.hex: - continue - - logging.info(f'dgpu: {name}, req: {req}') + async with ( + tractor.open_actor_cluster( + modules=['skynet_bot.gpu'], + count=dgpu_max_tasks, + names=[i for i in range(dgpu_max_tasks)] + ) as portal_map, + trio.open_nursery() as n + ): + logging.info(f'starting {dgpu_max_tasks} gpu workers') + async with tractor.gather_contexts(( + portal.open_context( + open_gpu_worker, algo, 1.0 / dgpu_max_tasks) + for portal in portal_map.values() + )) as contexts: + contexts = {i: ctx for i, ctx in enumerate(contexts)} + for i, ctx in contexts.items(): n.start_soon( - _process_dgpu_req, req) + gpu_streamer, ctx, i) + try: + while True: + msg = await dgpu_sock.arecv() + req = DGPUBusRequest( + **json.loads(msg.decode())) - except KeyboardInterrupt: - ... + if req.nid != name.hex: + continue - res = await _rpc_call(name.hex, 'dgpu_offline') - logging.info(res) - assert 'ok' in res.result + logging.info(f'dgpu: {name}, req: {req}') + n.start_soon( + _process_dgpu_req, req) + + except KeyboardInterrupt: + ... + + res = await rpc_call(name.hex, 'dgpu_offline') + logging.info(res) + assert 'ok' in res.result diff --git a/skynet_bot/frontend/__init__.py b/skynet_bot/frontend/__init__.py index 7211eb5..4e728f9 100644 --- a/skynet_bot/frontend/__init__.py +++ b/skynet_bot/frontend/__init__.py @@ -3,14 +3,17 @@ import json from typing import Union -from contextlib import contextmanager as cm +from contextlib import asynccontextmanager as acm import pynng from ..types import SkynetRPCRequest, SkynetRPCResponse -from ..constants import DEFAULT_RPC_ADDR +from ..constants import * +class ConfigRequestFormatError(BaseException): + ... + class ConfigUnknownAttribute(BaseException): ... @@ -44,7 +47,71 @@ async def rpc_call( (await sock.arecv_msg()).bytes.decode())) -@cm -def open_skynet_rpc(rpc_address: str = DEFAULT_RPC_ADDR): - with pynng.Req0(dial=rpc_address) as rpc_sock: - yield rpc_sock +@acm +async def open_skynet_rpc(rpc_address: str = DEFAULT_RPC_ADDR): + with pynng.Req0(dial=rpc_address) as sock: + async def _rpc_call(*args, **kwargs): + return await rpc_call(sock, *args, **kwargs) + + yield _rpc_call + + +def validate_user_config_request(req: str): + params = req.split(' ') + + if len(params) < 3: + raise ConfigRequestFormatError('config request format incorrect') + + else: + try: + attr = params[1] + + if attr == 'algo': + val = params[2] + if val not in ALGOS: + raise ConfigUnknownAlgorithm(f'no algo named {val}') + + elif attr == 'step': + val = int(params[2]) + val = max(min(val, MAX_STEP), MIN_STEP) + + elif attr == 'width': + val = max(min(int(params[2]), MAX_WIDTH), 16) + if val % 8 != 0: + raise ConfigSizeDivisionByEight( + 'size must be divisible by 8!') + + elif attr == 'height': + val = max(min(int(params[2]), MAX_HEIGHT), 16) + if val % 8 != 0: + raise ConfigSizeDivisionByEight( + 'size must be divisible by 8!') + + elif attr == 'seed': + val = params[2] + if val == 'auto': + val = None + else: + val = int(params[2]) + + elif attr == 'guidance': + val = float(params[2]) + val = max(min(val, MAX_GUIDANCE), 0) + + elif attr == 'upscaler': + val = params[2] + if val == 'off': + val = None + elif val != 'x4': + raise ConfigUnknownUpscaler( + f'\"{val}\" is not a valid upscaler') + + else: + raise ConfigUnknownAttribute( + f'\"{attr}\" not a configurable parameter') + + return attr, val, f'config updated! {attr} to {val}' + + except ValueError: + raise ValueError(f'\"{val}\" is not a number silly') + diff --git a/skynet_bot/frontend/telegram.py b/skynet_bot/frontend/telegram.py index 8affa29..cc217f8 100644 --- a/skynet_bot/frontend/telegram.py +++ b/skynet_bot/frontend/telegram.py @@ -3,6 +3,7 @@ import logging from datetime import datetime +from functools import partial import pynng @@ -17,20 +18,21 @@ from . import * PREFIX = 'tg' -async def run_skynet_telegram(tg_token: str): +async def run_skynet_telegram( + tg_token: str +): logging.basicConfig(level=logging.INFO) bot = AsyncTeleBot(tg_token) - with open_skynet_rpc() as rpc_sock: + with open_skynet_rpc() as rpc_call: async def _rpc_call( uid: int, method: str, params: dict ): - return await rpc_call( - rpc_sock, f'{PREFIX}+{uid}', method, params) + return await rpc_call(f'{PREFIX}+{uid}', method, params) @bot.message_handler(commands=['help']) async def send_help(message): @@ -58,79 +60,19 @@ async def run_skynet_telegram(tg_token: str): @bot.message_handler(commands=['config']) async def set_config(message): - params = message.text.split(' ') - rpc_params = {} + try: + attr, val, reply_txt = validate_user_config_request( + message.text) - if len(params) < 3: - bot.reply_to(message, 'wrong msg format') + resp = await _rpc_call( + message.from_user.id, + 'config', {'attr': attr, 'val': val}) - else: - - try: - attr = params[1] - - if attr == 'algo': - val = params[2] - if val not in ALGOS: - raise ConfigUnknownAlgorithm - - elif attr == 'step': - val = int(params[2]) - val = max(min(val, MAX_STEP), MIN_STEP) - - elif attr == 'width': - val = max(min(int(params[2]), MAX_WIDTH), 16) - if val % 8 != 0: - raise ConfigSizeDivisionByEight - - elif attr == 'height': - val = max(min(int(params[2]), MAX_HEIGHT), 16) - if val % 8 != 0: - raise ConfigSizeDivisionByEight - - elif attr == 'seed': - val = params[2] - if val == 'auto': - val = None - else: - val = int(params[2]) - - elif attr == 'guidance': - val = float(params[2]) - val = max(min(val, MAX_GUIDANCE), 0) - - elif attr == 'upscaler': - val = params[2] - if val == 'off': - val = None - elif val != 'x4': - raise ConfigUnknownUpscaler - - else: - raise ConfigUnknownAttribute - - resp = await _rpc_call( - message.from_user.id, - 'config', {'attr': attr, 'val': val}) - - reply_txt = f'config updated! {attr} to {val}' - - except ConfigUnknownAlgorithm: - reply_txt = f'no algo named {val}' - - except ConfigUnknownAttribute: - reply_txt = f'\"{attr}\" not a configurable parameter' - - except ConfigUnknownUpscaler: - reply_txt = f'\"{val}\" is not a valid upscaler' - - except ConfigSizeDivisionByEight: - reply_txt = 'size must be divisible by 8!' - - except ValueError: - reply_txt = f'\"{val}\" is not a number silly' + except BaseException as e: + reply_text = e.message + finally: await bot.reply_to(message, reply_txt) @bot.message_handler(commands=['stats']) diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..88d7f76 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,90 @@ +#!/usr/bin/python + +import time +import random +import string +import logging + +from functools import partial + +import trio +import pytest +import psycopg2 +import trio_asyncio + +from psycopg2.extensions import ISOLATION_LEVEL_AUTOCOMMIT + +from skynet_bot.constants import * +from skynet_bot.brain import run_skynet + + +@pytest.fixture(scope='session') +def postgres_db(dockerctl): + rpassword = ''.join( + random.choice(string.ascii_lowercase) + for i in range(12)) + password = ''.join( + random.choice(string.ascii_lowercase) + for i in range(12)) + + with dockerctl.run( + 'postgres', + command='postgres', + ports={'5432/tcp': None}, + environment={ + 'POSTGRES_PASSWORD': rpassword + } + ) as containers: + container = containers[0] + # ip = container.attrs['NetworkSettings']['IPAddress'] + port = container.ports['5432/tcp'][0]['HostPort'] + host = f'localhost:{port}' + + for log in container.logs(stream=True): + log = log.decode().rstrip() + logging.info(log) + if ('database system is ready to accept connections' in log or + 'database system is shut down' in log): + break + + # why print the system is ready to accept connections when its not + # postgres? wtf + time.sleep(1) + logging.info('creating skynet db...') + + conn = psycopg2.connect( + user='postgres', + password=rpassword, + host='localhost', + port=port + ) + logging.info('connected...') + conn.set_isolation_level(ISOLATION_LEVEL_AUTOCOMMIT) + with conn.cursor() as cursor: + cursor.execute( + f'CREATE USER {DB_USER} WITH PASSWORD \'{password}\'') + cursor.execute( + f'CREATE DATABASE {DB_NAME}') + cursor.execute( + f'GRANT ALL PRIVILEGES ON DATABASE {DB_NAME} TO {DB_USER}') + + logging.info('done.') + yield container, password, host + + +@pytest.fixture +async def skynet_running(postgres_db): + db_container, db_pass, db_host = postgres_db + async with ( + trio_asyncio.open_loop(), + trio.open_nursery() as n + ): + await n.start( + partial(run_skynet, + db_pass=db_pass, + db_host=db_host)) + + yield + n.cancel_scope.cancel() + + diff --git a/tests/test_skynet.py b/tests/test_skynet.py new file mode 100644 index 0000000..f1520f8 --- /dev/null +++ b/tests/test_skynet.py @@ -0,0 +1,61 @@ +#!/usr/bin/python + +import logging + +import trio +import trio_asyncio + +from skynet_bot.types import * +from skynet_bot.brain import run_skynet +from skynet_bot.frontend import open_skynet_rpc + + +async def test_skynet_dgpu_connection_simple(skynet_running): + async with open_skynet_rpc() as rpc_call: + # check 0 nodes are connected + res = await rpc_call('dgpu-0', 'dgpu_workers') + logging.info(res) + assert 'ok' in res.result + assert res.result['ok'] == 0 + + # check next worker is None + res = await rpc_call('dgpu-0', 'dgpu_next') + logging.info(res) + assert 'ok' in res.result + assert res.result['ok'] == None + + # connect 1 dgpu + res = await rpc_call( + 'dgpu-0', 'dgpu_online', {'max_tasks': 3}) + logging.info(res) + assert 'ok' in res.result + + # check 1 node is connected + res = await rpc_call('dgpu-0', 'dgpu_workers') + logging.info(res) + assert 'ok' in res.result + assert res.result['ok'] == 1 + + # check next worker is 0 + res = await rpc_call('dgpu-0', 'dgpu_next') + logging.info(res) + assert 'ok' in res.result + assert res.result['ok'] == 0 + + # disconnect 1 dgpu + res = await rpc_call( + 'dgpu-0', 'dgpu_offline') + logging.info(res) + assert 'ok' in res.result + + # check 0 nodes are connected + res = await rpc_call('dgpu-0', 'dgpu_workers') + logging.info(res) + assert 'ok' in res.result + assert res.result['ok'] == 0 + + # check next worker is None + res = await rpc_call('dgpu-0', 'dgpu_next') + logging.info(res) + assert 'ok' in res.result + assert res.result['ok'] == None diff --git a/tests/test_telegram.py b/tests/test_telegram.py deleted file mode 100644 index fe99566..0000000 --- a/tests/test_telegram.py +++ /dev/null @@ -1,22 +0,0 @@ -#!/usr/bin/python - -import trio -import trio_asyncio - -from skynet_bot.brain import run_skynet -from skynet_bot.frontend import open_skynet_rpc -from skynet_bot.frontend.telegram import run_skynet_telegram - - -def test_run_tg_bot(): - async def main(): - async with trio.open_nursery() as n: - await n.start( - run_skynet, - 'skynet', '3GbZd6UojbD8V7zWpeFn', 'ancap.tech:34508') - n.start_soon( - run_skynet_telegram, - '5853245787:AAFEmv3EjJ_qJ8d_vmOpi6o6HFHUf8a0uCQ') - - - trio_asyncio.run(main)