diff --git a/README.md b/README.md index a4e8a19..483bdbd 100644 --- a/README.md +++ b/README.md @@ -1,47 +1,2 @@ -create db in postgres: - -```sql -CREATE USER skynet WITH PASSWORD 'password'; -CREATE DATABASE skynet_art_bot; -GRANT ALL PRIVILEGES ON DATABASE skynet_art_bot TO skynet; - -CREATE SCHEMA IF NOT EXISTS skynet; - -CREATE TABLE IF NOT EXISTS skynet.user( - id SERIAL PRIMARY KEY NOT NULL, - tg_id INT, - wp_id VARCHAR(128), - mx_id VARCHAR(128), - ig_id VARCHAR(128), - generated INT NOT NULL, - joined DATE NOT NULL, - last_prompt TEXT, - role VARCHAR(128) NOT NULL -); -ALTER TABLE skynet.user - ADD CONSTRAINT tg_unique - UNIQUE (tg_id); -ALTER TABLE skynet.user - ADD CONSTRAINT wp_unique - UNIQUE (wp_id); -ALTER TABLE skynet.user - ADD CONSTRAINT mx_unique - UNIQUE (mx_id); -ALTER TABLE skynet.user - ADD CONSTRAINT ig_unique - UNIQUE (ig_id); - -CREATE TABLE IF NOT EXISTS skynet.user_config( - id SERIAL NOT NULL, - algo VARCHAR(128) NOT NULL, - step INT NOT NULL, - width INT NOT NULL, - height INT NOT NULL, - seed INT, - guidance INT NOT NULL, - upscaler VARCHAR(128) -); -ALTER TABLE skynet.user_config - ADD FOREIGN KEY(id) - REFERENCES skynet.user(id); -``` +# skynet +### decentralized compute platform diff --git a/requirements.test.txt b/requirements.test.txt index 5f0802d..48af1de 100644 --- a/requirements.test.txt +++ b/requirements.test.txt @@ -1,2 +1,5 @@ pytest +psycopg2 pytest-trio + +git+https://github.com/tgoodlet/pytest-dockerctl.git@master#egg=pytest-dockerctl diff --git a/skynet_bot/brain.py b/skynet_bot/brain.py index 12563e2..37f00dc 100644 --- a/skynet_bot/brain.py +++ b/skynet_bot/brain.py @@ -30,30 +30,37 @@ async def rpc_service(sock, dgpu_bus, db_pool): wip_reqs = {} fin_reqs = {} - def are_all_workers_busy(): - for nid, info in nodes.items(): - if info['task'] == None: + def is_worker_busy(nid: int): + for task in nodes[nid]['tasks']: + if task != None: return False return True - next_worker = 0 + def are_all_workers_busy(): + for nid in nodes.keys(): + if not is_worker_busy(nid): + return False + + return True + + next_worker: Optional[int] = None def get_next_worker(): nonlocal next_worker - if len(nodes) == 0: + if not next_worker: raise SkynetDGPUOffline if are_all_workers_busy(): raise SkynetDGPUOverloaded - next_worker += 1 + while is_worker_busy(next_worker): + next_worker += 1 - if next_worker >= len(nodes): - next_worker = 0 + if next_worker >= len(nodes): + next_worker = 0 - nid = list(nodes.keys())[next_worker] - return nid + return next_worker async def dgpu_image_streamer(): nonlocal wip_reqs, fin_reqs @@ -74,7 +81,8 @@ async def rpc_service(sock, dgpu_bus, db_pool): event = trio.Event() wip_reqs[rid] = event - nodes[nid]['task'] = rid + tid = nodes[nid]['tasks'].index(None) + nodes[nid]['tasks'][tid] = rid dgpu_req = DGPUBusRequest( rid=rid, @@ -89,7 +97,7 @@ async def rpc_service(sock, dgpu_bus, db_pool): await event.wait() - nodes[nid]['task'] = None + nodes[nid]['tasks'][tid] = None img = fin_reqs[rid] del fin_reqs[rid] @@ -167,10 +175,9 @@ async def rpc_service(sock, dgpu_bus, db_pool): except BaseException as e: logging.error(e) - raise e - # result = { - # 'error': 'skynet_internal_error' - # } + result = { + 'error': 'skynet_internal_error' + } await rpc_ctx.asend( json.dumps( @@ -187,21 +194,36 @@ async def rpc_service(sock, dgpu_bus, db_pool): logging.info(req) + result = {} + if req.method == 'dgpu_online': nodes[req.uid] = { - 'task': None + 'tasks': [None for _ in range(req.params['max_tasks'])], + 'max_tasks': req.params['max_tasks'] } logging.info(f'dgpu online: {req.uid}') + if not next_worker: + next_worker = 0 elif req.method == 'dgpu_offline': - i = nodes.values().index(req.uid) + i = list(nodes.keys()).index(req.uid) del nodes[req.uid] if i < next_worker: next_worker -= 1 + + if len(nodes) == 0: + next_worker = None + logging.info(f'dgpu offline: {req.uid}') + elif req.method == 'dgpu_workers': + result = len(nodes) + + elif req.method == 'dgpu_next': + result = next_worker + else: n.start_soon( handle_user_request, ctx, req) @@ -210,12 +232,12 @@ async def rpc_service(sock, dgpu_bus, db_pool): await ctx.asend( json.dumps( SkynetRPCResponse( - result={'ok': {}}).to_dict()).encode()) + result={'ok': result}).to_dict()).encode()) async def run_skynet( - db_user: str, - db_pass: str, + db_user: str = DB_USER, + db_pass: str = DB_PASS, db_host: str = DB_HOST, rpc_address: str = DEFAULT_RPC_ADDR, dgpu_address: str = DEFAULT_DGPU_ADDR, diff --git a/skynet_bot/constants.py b/skynet_bot/constants.py index a7b21ae..0c0c03b 100644 --- a/skynet_bot/constants.py +++ b/skynet_bot/constants.py @@ -3,6 +3,9 @@ API_TOKEN = '5880619053:AAFge2UfObw1kCn9Kb7AAyqaIHs_HgM0Fx0' DB_HOST = 'ancap.tech:34508' +DB_USER = 'skynet' +DB_PASS = 'password' +DB_NAME = 'skynet' ALGOS = { 'stable': 'runwayml/stable-diffusion-v1-5', diff --git a/skynet_bot/db.py b/skynet_bot/db.py index d5c94d7..9998e77 100644 --- a/skynet_bot/db.py +++ b/skynet_bot/db.py @@ -11,6 +11,49 @@ import triopg from .constants import * +DB_INIT_SQL = ''' +CREATE SCHEMA IF NOT EXISTS skynet; + +CREATE TABLE IF NOT EXISTS skynet.user( + id SERIAL PRIMARY KEY NOT NULL, + tg_id INT, + wp_id VARCHAR(128), + mx_id VARCHAR(128), + ig_id VARCHAR(128), + generated INT NOT NULL, + joined DATE NOT NULL, + last_prompt TEXT, + role VARCHAR(128) NOT NULL +); +ALTER TABLE skynet.user + ADD CONSTRAINT tg_unique + UNIQUE (tg_id); +ALTER TABLE skynet.user + ADD CONSTRAINT wp_unique + UNIQUE (wp_id); +ALTER TABLE skynet.user + ADD CONSTRAINT mx_unique + UNIQUE (mx_id); +ALTER TABLE skynet.user + ADD CONSTRAINT ig_unique + UNIQUE (ig_id); + +CREATE TABLE IF NOT EXISTS skynet.user_config( + id SERIAL NOT NULL, + algo VARCHAR(128) NOT NULL, + step INT NOT NULL, + width INT NOT NULL, + height INT NOT NULL, + seed INT, + guidance INT NOT NULL, + upscaler VARCHAR(128) +); +ALTER TABLE skynet.user_config + ADD FOREIGN KEY(id) + REFERENCES skynet.user(id); +''' + + def try_decode_uid(uid: str): try: proto, uid = uid.split('+') @@ -24,14 +67,18 @@ def try_decode_uid(uid: str): @acm async def open_database_connection( - db_user: str, - db_pass: str, + db_user: str = DB_USER, + db_pass: str = DB_PASS, db_host: str = DB_HOST, + db_name: str = DB_NAME ): async with triopg.create_pool( - dsn=f'postgres://{db_user}:{db_pass}@{db_host}/skynet_art_bot' - ) as conn: - yield conn + dsn=f'postgres://{db_user}:{db_pass}@{db_host}/{db_name}' + ) as pool_conn: + async with pool_conn.acquire() as conn: + await conn.execute(DB_INIT_SQL) + + yield pool_conn async def get_user(conn, uid: str): diff --git a/skynet_bot/dgpu.py b/skynet_bot/dgpu.py index c454717..8c019e4 100644 --- a/skynet_bot/dgpu.py +++ b/skynet_bot/dgpu.py @@ -65,58 +65,53 @@ async def open_dgpu_node( return img - with ( - pynng.Req0(dial=rpc_address) as rpc_sock, - pynng.Bus0(dial=dgpu_address) as dgpu_sock - ): - async def _rpc_call(*args, **kwargs): - return await rpc_call(rpc_sock, *args, **kwargs) + async with open_skynet_rpc() as rpc_call: + with pynng.Bus0(dial=dgpu_address) as dgpu_sock: + async def _process_dgpu_req(req: DGPUBusRequest): + img = await gpu_compute_one( + ImageGenRequest(**req.params)) + await dgpu_sock.asend( + bytes.fromhex(req.rid) + img) - async def _process_dgpu_req(req: DGPUBusRequest): - img = await gpu_compute_one( - ImageGenRequest(**req.params)) - await dgpu_sock.asend( - bytes.fromhex(req.rid) + img) + res = await rpc_call( + name.hex, 'dgpu_online', {'max_tasks': dgpu_max_tasks}) + logging.info(res) + assert 'ok' in res.result - res = await _rpc_call( - name.hex, 'dgpu_online', {'max_tasks': dgpu_max_tasks}) - logging.info(res) - assert 'ok' in res.result - - async with ( - tractor.open_actor_cluster( - modules=['skynet_bot.gpu'], - count=dgpu_max_tasks, - names=[i for i in range(dgpu_max_tasks)] - ) as portal_map, - trio.open_nursery() as n - ): - logging.info(f'starting {dgpu_max_tasks} gpu workers') - async with tractor.gather_contexts(( - portal.open_context( - open_gpu_worker, algo, 1.0 / dgpu_max_tasks) - for portal in portal_map.values() - )) as contexts: - contexts = {i: ctx for i, ctx in enumerate(contexts)} - for i, ctx in contexts.items(): - n.start_soon( - gpu_streamer, ctx, i) - try: - while True: - msg = await dgpu_sock.arecv() - req = DGPUBusRequest( - **json.loads(msg.decode())) - - if req.nid != name.hex: - continue - - logging.info(f'dgpu: {name}, req: {req}') + async with ( + tractor.open_actor_cluster( + modules=['skynet_bot.gpu'], + count=dgpu_max_tasks, + names=[i for i in range(dgpu_max_tasks)] + ) as portal_map, + trio.open_nursery() as n + ): + logging.info(f'starting {dgpu_max_tasks} gpu workers') + async with tractor.gather_contexts(( + portal.open_context( + open_gpu_worker, algo, 1.0 / dgpu_max_tasks) + for portal in portal_map.values() + )) as contexts: + contexts = {i: ctx for i, ctx in enumerate(contexts)} + for i, ctx in contexts.items(): n.start_soon( - _process_dgpu_req, req) + gpu_streamer, ctx, i) + try: + while True: + msg = await dgpu_sock.arecv() + req = DGPUBusRequest( + **json.loads(msg.decode())) - except KeyboardInterrupt: - ... + if req.nid != name.hex: + continue - res = await _rpc_call(name.hex, 'dgpu_offline') - logging.info(res) - assert 'ok' in res.result + logging.info(f'dgpu: {name}, req: {req}') + n.start_soon( + _process_dgpu_req, req) + + except KeyboardInterrupt: + ... + + res = await rpc_call(name.hex, 'dgpu_offline') + logging.info(res) + assert 'ok' in res.result diff --git a/skynet_bot/frontend/__init__.py b/skynet_bot/frontend/__init__.py index 7211eb5..4e728f9 100644 --- a/skynet_bot/frontend/__init__.py +++ b/skynet_bot/frontend/__init__.py @@ -3,14 +3,17 @@ import json from typing import Union -from contextlib import contextmanager as cm +from contextlib import asynccontextmanager as acm import pynng from ..types import SkynetRPCRequest, SkynetRPCResponse -from ..constants import DEFAULT_RPC_ADDR +from ..constants import * +class ConfigRequestFormatError(BaseException): + ... + class ConfigUnknownAttribute(BaseException): ... @@ -44,7 +47,71 @@ async def rpc_call( (await sock.arecv_msg()).bytes.decode())) -@cm -def open_skynet_rpc(rpc_address: str = DEFAULT_RPC_ADDR): - with pynng.Req0(dial=rpc_address) as rpc_sock: - yield rpc_sock +@acm +async def open_skynet_rpc(rpc_address: str = DEFAULT_RPC_ADDR): + with pynng.Req0(dial=rpc_address) as sock: + async def _rpc_call(*args, **kwargs): + return await rpc_call(sock, *args, **kwargs) + + yield _rpc_call + + +def validate_user_config_request(req: str): + params = req.split(' ') + + if len(params) < 3: + raise ConfigRequestFormatError('config request format incorrect') + + else: + try: + attr = params[1] + + if attr == 'algo': + val = params[2] + if val not in ALGOS: + raise ConfigUnknownAlgorithm(f'no algo named {val}') + + elif attr == 'step': + val = int(params[2]) + val = max(min(val, MAX_STEP), MIN_STEP) + + elif attr == 'width': + val = max(min(int(params[2]), MAX_WIDTH), 16) + if val % 8 != 0: + raise ConfigSizeDivisionByEight( + 'size must be divisible by 8!') + + elif attr == 'height': + val = max(min(int(params[2]), MAX_HEIGHT), 16) + if val % 8 != 0: + raise ConfigSizeDivisionByEight( + 'size must be divisible by 8!') + + elif attr == 'seed': + val = params[2] + if val == 'auto': + val = None + else: + val = int(params[2]) + + elif attr == 'guidance': + val = float(params[2]) + val = max(min(val, MAX_GUIDANCE), 0) + + elif attr == 'upscaler': + val = params[2] + if val == 'off': + val = None + elif val != 'x4': + raise ConfigUnknownUpscaler( + f'\"{val}\" is not a valid upscaler') + + else: + raise ConfigUnknownAttribute( + f'\"{attr}\" not a configurable parameter') + + return attr, val, f'config updated! {attr} to {val}' + + except ValueError: + raise ValueError(f'\"{val}\" is not a number silly') + diff --git a/skynet_bot/frontend/telegram.py b/skynet_bot/frontend/telegram.py index 8affa29..cc217f8 100644 --- a/skynet_bot/frontend/telegram.py +++ b/skynet_bot/frontend/telegram.py @@ -3,6 +3,7 @@ import logging from datetime import datetime +from functools import partial import pynng @@ -17,20 +18,21 @@ from . import * PREFIX = 'tg' -async def run_skynet_telegram(tg_token: str): +async def run_skynet_telegram( + tg_token: str +): logging.basicConfig(level=logging.INFO) bot = AsyncTeleBot(tg_token) - with open_skynet_rpc() as rpc_sock: + with open_skynet_rpc() as rpc_call: async def _rpc_call( uid: int, method: str, params: dict ): - return await rpc_call( - rpc_sock, f'{PREFIX}+{uid}', method, params) + return await rpc_call(f'{PREFIX}+{uid}', method, params) @bot.message_handler(commands=['help']) async def send_help(message): @@ -58,79 +60,19 @@ async def run_skynet_telegram(tg_token: str): @bot.message_handler(commands=['config']) async def set_config(message): - params = message.text.split(' ') - rpc_params = {} + try: + attr, val, reply_txt = validate_user_config_request( + message.text) - if len(params) < 3: - bot.reply_to(message, 'wrong msg format') + resp = await _rpc_call( + message.from_user.id, + 'config', {'attr': attr, 'val': val}) - else: - - try: - attr = params[1] - - if attr == 'algo': - val = params[2] - if val not in ALGOS: - raise ConfigUnknownAlgorithm - - elif attr == 'step': - val = int(params[2]) - val = max(min(val, MAX_STEP), MIN_STEP) - - elif attr == 'width': - val = max(min(int(params[2]), MAX_WIDTH), 16) - if val % 8 != 0: - raise ConfigSizeDivisionByEight - - elif attr == 'height': - val = max(min(int(params[2]), MAX_HEIGHT), 16) - if val % 8 != 0: - raise ConfigSizeDivisionByEight - - elif attr == 'seed': - val = params[2] - if val == 'auto': - val = None - else: - val = int(params[2]) - - elif attr == 'guidance': - val = float(params[2]) - val = max(min(val, MAX_GUIDANCE), 0) - - elif attr == 'upscaler': - val = params[2] - if val == 'off': - val = None - elif val != 'x4': - raise ConfigUnknownUpscaler - - else: - raise ConfigUnknownAttribute - - resp = await _rpc_call( - message.from_user.id, - 'config', {'attr': attr, 'val': val}) - - reply_txt = f'config updated! {attr} to {val}' - - except ConfigUnknownAlgorithm: - reply_txt = f'no algo named {val}' - - except ConfigUnknownAttribute: - reply_txt = f'\"{attr}\" not a configurable parameter' - - except ConfigUnknownUpscaler: - reply_txt = f'\"{val}\" is not a valid upscaler' - - except ConfigSizeDivisionByEight: - reply_txt = 'size must be divisible by 8!' - - except ValueError: - reply_txt = f'\"{val}\" is not a number silly' + except BaseException as e: + reply_text = e.message + finally: await bot.reply_to(message, reply_txt) @bot.message_handler(commands=['stats']) diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..88d7f76 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,90 @@ +#!/usr/bin/python + +import time +import random +import string +import logging + +from functools import partial + +import trio +import pytest +import psycopg2 +import trio_asyncio + +from psycopg2.extensions import ISOLATION_LEVEL_AUTOCOMMIT + +from skynet_bot.constants import * +from skynet_bot.brain import run_skynet + + +@pytest.fixture(scope='session') +def postgres_db(dockerctl): + rpassword = ''.join( + random.choice(string.ascii_lowercase) + for i in range(12)) + password = ''.join( + random.choice(string.ascii_lowercase) + for i in range(12)) + + with dockerctl.run( + 'postgres', + command='postgres', + ports={'5432/tcp': None}, + environment={ + 'POSTGRES_PASSWORD': rpassword + } + ) as containers: + container = containers[0] + # ip = container.attrs['NetworkSettings']['IPAddress'] + port = container.ports['5432/tcp'][0]['HostPort'] + host = f'localhost:{port}' + + for log in container.logs(stream=True): + log = log.decode().rstrip() + logging.info(log) + if ('database system is ready to accept connections' in log or + 'database system is shut down' in log): + break + + # why print the system is ready to accept connections when its not + # postgres? wtf + time.sleep(1) + logging.info('creating skynet db...') + + conn = psycopg2.connect( + user='postgres', + password=rpassword, + host='localhost', + port=port + ) + logging.info('connected...') + conn.set_isolation_level(ISOLATION_LEVEL_AUTOCOMMIT) + with conn.cursor() as cursor: + cursor.execute( + f'CREATE USER {DB_USER} WITH PASSWORD \'{password}\'') + cursor.execute( + f'CREATE DATABASE {DB_NAME}') + cursor.execute( + f'GRANT ALL PRIVILEGES ON DATABASE {DB_NAME} TO {DB_USER}') + + logging.info('done.') + yield container, password, host + + +@pytest.fixture +async def skynet_running(postgres_db): + db_container, db_pass, db_host = postgres_db + async with ( + trio_asyncio.open_loop(), + trio.open_nursery() as n + ): + await n.start( + partial(run_skynet, + db_pass=db_pass, + db_host=db_host)) + + yield + n.cancel_scope.cancel() + + diff --git a/tests/test_skynet.py b/tests/test_skynet.py new file mode 100644 index 0000000..f1520f8 --- /dev/null +++ b/tests/test_skynet.py @@ -0,0 +1,61 @@ +#!/usr/bin/python + +import logging + +import trio +import trio_asyncio + +from skynet_bot.types import * +from skynet_bot.brain import run_skynet +from skynet_bot.frontend import open_skynet_rpc + + +async def test_skynet_dgpu_connection_simple(skynet_running): + async with open_skynet_rpc() as rpc_call: + # check 0 nodes are connected + res = await rpc_call('dgpu-0', 'dgpu_workers') + logging.info(res) + assert 'ok' in res.result + assert res.result['ok'] == 0 + + # check next worker is None + res = await rpc_call('dgpu-0', 'dgpu_next') + logging.info(res) + assert 'ok' in res.result + assert res.result['ok'] == None + + # connect 1 dgpu + res = await rpc_call( + 'dgpu-0', 'dgpu_online', {'max_tasks': 3}) + logging.info(res) + assert 'ok' in res.result + + # check 1 node is connected + res = await rpc_call('dgpu-0', 'dgpu_workers') + logging.info(res) + assert 'ok' in res.result + assert res.result['ok'] == 1 + + # check next worker is 0 + res = await rpc_call('dgpu-0', 'dgpu_next') + logging.info(res) + assert 'ok' in res.result + assert res.result['ok'] == 0 + + # disconnect 1 dgpu + res = await rpc_call( + 'dgpu-0', 'dgpu_offline') + logging.info(res) + assert 'ok' in res.result + + # check 0 nodes are connected + res = await rpc_call('dgpu-0', 'dgpu_workers') + logging.info(res) + assert 'ok' in res.result + assert res.result['ok'] == 0 + + # check next worker is None + res = await rpc_call('dgpu-0', 'dgpu_next') + logging.info(res) + assert 'ok' in res.result + assert res.result['ok'] == None diff --git a/tests/test_telegram.py b/tests/test_telegram.py deleted file mode 100644 index fe99566..0000000 --- a/tests/test_telegram.py +++ /dev/null @@ -1,22 +0,0 @@ -#!/usr/bin/python - -import trio -import trio_asyncio - -from skynet_bot.brain import run_skynet -from skynet_bot.frontend import open_skynet_rpc -from skynet_bot.frontend.telegram import run_skynet_telegram - - -def test_run_tg_bot(): - async def main(): - async with trio.open_nursery() as n: - await n.start( - run_skynet, - 'skynet', '3GbZd6UojbD8V7zWpeFn', 'ancap.tech:34508') - n.start_soon( - run_skynet_telegram, - '5853245787:AAFEmv3EjJ_qJ8d_vmOpi6o6HFHUf8a0uCQ') - - - trio_asyncio.run(main)