diff --git a/.gitignore b/.gitignore index e264fa9..e98d124 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ +skynet.ini .python-version hf_home outputs diff --git a/Dockerfile.runtime+cuda b/Dockerfile.runtime+cuda index 32d3c4a..27a5a66 100644 --- a/Dockerfile.runtime+cuda +++ b/Dockerfile.runtime+cuda @@ -32,3 +32,4 @@ env HF_HOME /hf_home copy scripts scripts copy tests tests +expose 40000-45000 diff --git a/skynet.ini.example b/skynet.ini.example new file mode 100644 index 0000000..7035920 --- /dev/null +++ b/skynet.ini.example @@ -0,0 +1,12 @@ +[skynet] +certs_dir = certs + +[skynet.dgpu] +hf_home = hf_home +hf_token = hf_XxXxXxXxXxXxXxXxXxXxXxXxXxXxXxXxXx + +[skynet.telegram] +token = XXXXXXXXXX:xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx + +[skynet.telegram-test] +token = XXXXXXXXXX:xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx diff --git a/skynet/brain.py b/skynet/brain.py index c442ba5..b121bd3 100644 --- a/skynet/brain.py +++ b/skynet/brain.py @@ -1,35 +1,24 @@ #!/usr/bin/python -import time -import json -import uuid -import zlib import logging -import traceback -from uuid import UUID -from pathlib import Path -from functools import partial from contextlib import asynccontextmanager as acm from collections import OrderedDict import trio -import pynng -import trio_asyncio -from pynng import TLSConfig -from OpenSSL.crypto import ( - load_privatekey, - load_certificate, - FILETYPE_PEM -) +from pynng import Context -from .db import * +from .utils import time_ms +from .network import * +from .protobuf import * from .constants import * -from .protobuf import * +class SkynetRPCBadRequest(BaseException): + ... + class SkynetDGPUOffline(BaseException): ... @@ -44,39 +33,71 @@ class SkynetShutdownRequested(BaseException): @acm -async def open_rpc_service(sock, dgpu_bus, db_pool, tls_whitelist, tls_key): +async def run_skynet( + rpc_address: str = DEFAULT_RPC_ADDR +): + logging.basicConfig(level=logging.INFO) + logging.info('skynet is starting') + nodes = OrderedDict() - wip_reqs = {} - fin_reqs = {} heartbeats = {} next_worker: Optional[int] = None - security = len(tls_whitelist) > 0 - def connect_node(uid): + def connect_node(req: SkynetRPCRequest): nonlocal next_worker - nodes[uid] = { - 'task': None - } - logging.info(f'dgpu online: {uid}') - if not next_worker: - next_worker = 0 + node_params = MessageToDict(req.params) + logging.info(f'got node params {node_params}') + + if 'dgpu_addr' not in node_params: + raise SkynetRPCBadRequest( + f'DGPU connection params don\'t include dgpu addr') + + session = SessionClient( + node_params['dgpu_addr'], + 'skynet', + cert_name='brain.cert', + key_name='brain.key', + ca_name=node_params['cert'] + ) + try: + session.connect() + + node = { + 'task': None, + 'session': session + } + node.update(node_params) + + nodes[req.uid] = node + logging.info(f'DGPU node online: {req.uid}') + + if not next_worker: + next_worker = 0 + + except pynng.exceptions.ConnectionRefused: + logging.warning(f'error while dialing dgpu node... dropping...') + raise SkynetDGPUOffline('Connection to dgpu node addr failed.') def disconnect_node(uid): nonlocal next_worker if uid not in nodes: + logging.warning(f'Attempt to disconnect unknown node {uid}') return + i = list(nodes.keys()).index(uid) + nodes[uid]['session'].disconnect() del nodes[uid] if i < next_worker: next_worker -= 1 + logging.warning(f'DGPU node offline: {uid}') + if len(nodes) == 0: - logging.info('nw: None') + logging.info('All nodes disconnected.') next_worker = None - logging.warning(f'dgpu offline: {uid}') def is_worker_busy(nid: str): return nodes[nid]['task'] != None @@ -90,8 +111,6 @@ async def open_rpc_service(sock, dgpu_bus, db_pool, tls_whitelist, tls_key): def get_next_worker(): nonlocal next_worker - logging.info('get next_worker called') - logging.info(f'pre next_worker: {next_worker}') if next_worker == None: raise SkynetDGPUOffline('No workers connected, try again later') @@ -113,392 +132,79 @@ async def open_rpc_service(sock, dgpu_bus, db_pool, tls_whitelist, tls_key): if next_worker >= len(nodes): next_worker = 0 - logging.info(f'post next_worker: {next_worker}') - return nid - async def dgpu_heartbeat_service(): - nonlocal heartbeats - while True: - await trio.sleep(60) - rid = uuid.uuid4().hex - beat_msg = DGPUBusMessage( - rid=rid, - nid='', - method='heartbeat' - ) - heartbeats.clear() - heartbeats[rid] = int(time.time() * 1000) - await dgpu_bus.asend(beat_msg.SerializeToString()) - logging.info('sent heartbeat') - - async def dgpu_bus_streamer(): - nonlocal wip_reqs, fin_reqs, heartbeats - while True: - raw_msg = await dgpu_bus.arecv() - logging.info(f'streamer got {len(raw_msg)} bytes.') - msg = DGPUBusMessage() - msg.ParseFromString(raw_msg) - - if security: - verify_protobuf_msg(msg, tls_whitelist[msg.auth.cert]) - - rid = msg.rid - - if msg.method == 'heartbeat': - sent_time = heartbeats[rid] - delta = msg.params['time'] - sent_time - logging.info(f'got heartbeat reply from {msg.nid}, ping: {delta}') - continue - - if rid not in wip_reqs: - continue - - if msg.method == 'binary-reply': - logging.info('bin reply, recv extra data') - raw_img = await dgpu_bus.arecv() - msg = (msg, raw_img) - - fin_reqs[rid] = msg - event = wip_reqs[rid] - event.set() - del wip_reqs[rid] - - async def dgpu_stream_one_img(req: DiffusionParameters, img_buf=None): - nonlocal wip_reqs, fin_reqs, next_worker - nid = get_next_worker() - idx = list(nodes.keys()).index(nid) - logging.info(f'dgpu_stream_one_img {idx}/{len(nodes)} {nid}') - rid = uuid.uuid4().hex - ack_event = trio.Event() - img_event = trio.Event() - wip_reqs[rid] = ack_event - - nodes[nid]['task'] = rid - - dgpu_req = DGPUBusMessage( - rid=rid, - nid=nid, - method='diffuse') - dgpu_req.params.update(req.to_dict()) - - if security: - dgpu_req.auth.cert = 'skynet' - dgpu_req.auth.sig = sign_protobuf_msg(dgpu_req, tls_key) - - msg = dgpu_req.SerializeToString() - if img_buf: - logging.info(f'sending img of size {len(img_buf)} as attachment') - logging.info(img_buf[:10]) - msg = f'BINEXT%$%$'.encode() + msg + b'%$%$' + img_buf - - await dgpu_bus.asend(msg) - - with trio.move_on_after(4): - await ack_event.wait() - - logging.info(f'ack event: {ack_event.is_set()}') - - if not ack_event.is_set(): - disconnect_node(nid) - raise SkynetDGPUOffline('dgpu failed to acknowledge request') - - ack_msg = fin_reqs[rid] - if 'ack' not in ack_msg.params: - disconnect_node(nid) - raise SkynetDGPUOffline('dgpu failed to acknowledge request') - - wip_reqs[rid] = img_event - with trio.move_on_after(30): - await img_event.wait() - - logging.info(f'img event: {ack_event.is_set()}') - - if not img_event.is_set(): - disconnect_node(nid) - raise SkynetDGPUComputeError('30 seconds timeout while processing request') - - nodes[nid]['task'] = None - - resp = fin_reqs[rid] - del fin_reqs[rid] - if isinstance(resp, tuple): - meta, img = resp - return rid, img, meta.params - - raise SkynetDGPUComputeError(MessageToDict(resp.params)) - - - async def handle_user_request(rpc_ctx, req): - try: - async with db_pool.acquire() as conn: - user = await get_or_create_user(conn, req.uid) - - result = {} - - match req.method: - case 'txt2img': - logging.info('txt2img') - user_config = {**(await get_user_config(conn, user))} - del user_config['id'] - user_config.update(MessageToDict(req.params)) - - req = DiffusionParameters(**user_config, image=False) - rid, img, meta = await dgpu_stream_one_img(req) - logging.info(f'done streaming {rid}') - result = { - 'id': rid, - 'img': img.hex(), - 'meta': meta - } - - await update_user_stats(conn, user, last_prompt=user_config['prompt']) - logging.info('updated user stats.') - - case 'img2img': - logging.info('img2img') - user_config = {**(await get_user_config(conn, user))} - del user_config['id'] - - params = MessageToDict(req.params) - img_buf = bytes.fromhex(params['img']) - del params['img'] - user_config.update(params) - - req = DiffusionParameters(**user_config, image=True) - - if not req.image: - raise AssertionError('Didn\'t enable image flag for img2img?') - - rid, img, meta = await dgpu_stream_one_img(req, img_buf=img_buf) - logging.info(f'done streaming {rid}') - result = { - 'id': rid, - 'img': img.hex(), - 'meta': meta - } - - await update_user_stats(conn, user, last_prompt=user_config['prompt']) - logging.info('updated user stats.') - - case 'redo': - logging.info('redo') - user_config = {**(await get_user_config(conn, user))} - del user_config['id'] - prompt = await get_last_prompt_of(conn, user) - - if prompt: - req = DiffusionParameters( - prompt=prompt, - **user_config, - image=False - ) - rid, img, meta = await dgpu_stream_one_img(req) - result = { - 'id': rid, - 'img': img.hex(), - 'meta': meta - } - await update_user_stats(conn, user) - logging.info('updated user stats.') - - else: - result = { - 'error': 'skynet_no_last_prompt', - 'message': 'No prompt to redo, do txt2img first' - } - - case 'config': - logging.info('config') - if req.params['attr'] in CONFIG_ATTRS: - logging.info(f'update: {req.params}') - await update_user_config( - conn, user, req.params['attr'], req.params['val']) - logging.info('done') - - else: - logging.warning(f'{req.params["attr"]} not in {CONFIG_ATTRS}') - - case 'stats': - logging.info('stats') - generated, joined, role = await get_user_stats(conn, user) - - result = { - 'generated': generated, - 'joined': joined.strftime(DATE_FORMAT), - 'role': role - } - - case _: - logging.warn('unknown method') - - except SkynetDGPUOffline as e: - result = { - 'error': 'skynet_dgpu_offline', - 'message': str(e) - } - - except SkynetDGPUOverloaded as e: - result = { - 'error': 'skynet_dgpu_overloaded', - 'message': str(e), - 'nodes': len(nodes) - } - - except SkynetDGPUComputeError as e: - result = { - 'error': 'skynet_dgpu_compute_error', - 'message': str(e) - } - except BaseException as e: - traceback.print_exception(type(e), e, e.__traceback__) - result = { - 'error': 'skynet_internal_error', - 'message': str(e) - } - + async def rpc_handler(req: SkynetRPCRequest, ctx: Context): + result = {'ok': {}} resp = SkynetRPCResponse() - resp.result.update(result) - - if security: - resp.auth.cert = 'skynet' - resp.auth.sig = sign_protobuf_msg(resp, tls_key) - - logging.info('sending response') - await rpc_ctx.asend(resp.SerializeToString()) - rpc_ctx.close() - logging.info('done') - - async def request_service(n): - nonlocal next_worker - while True: - ctx = sock.new_context() - req = SkynetRPCRequest() - req.ParseFromString(await ctx.arecv()) - - if security: - if req.auth.cert not in tls_whitelist: - logging.warning( - f'{req.cert} not in tls whitelist and security=True') - continue - - try: - verify_protobuf_msg(req, tls_whitelist[req.auth.cert]) - - except ValueError: - logging.warning( - f'{req.cert} sent an unauthenticated msg with security=True') - continue - - result = {} + try: match req.method: - case 'skynet_shutdown': - raise SkynetShutdownRequested - case 'dgpu_online': - connect_node(req.uid) + connect_node(req) + + case 'dgpu_call': + nid = get_next_worker() + idx = list(nodes.keys()).index(nid) + node = nodes[nid] + logging.info(f'dgpu_call {idx}/{len(nodes)} {nid} @ {node["dgpu_addr"]}') + dgpu_time = await node['session'].rpc('dgpu_time') + if 'ok' not in dgpu_time.result: + status = MessageToDict(dgpu_time.result) + logging.warning(json.dumps(status, indent=4)) + disconnect_node(nid) + raise SkynetDGPUComputeError(status['error']) + + dgpu_time = dgpu_time.result['ok'] + logging.info(f'ping to {nid}: {time_ms() - dgpu_time} ms') + + try: + dgpu_result = await node['session'].rpc( + timeout=45, # give this 45 sec to run cause its compute + binext=req.bin, + **req.params + ) + result = MessageToDict(dgpu_result.result) + + if dgpu_result.bin: + resp.bin = dgpu_result.bin + + except trio.TooSlowError: + result = {'error': 'timeout while processing request'} case 'dgpu_offline': disconnect_node(req.uid) case 'dgpu_workers': - result = len(nodes) + result = {'ok': len(nodes)} case 'dgpu_next': - result = next_worker + result = {'ok': next_worker} - case 'heartbeat': - logging.info('beat') - result = {'time': time.time()} + case 'skynet_shutdown': + raise SkynetShutdownRequested case _: - n.start_soon( - handle_user_request, ctx, req) - continue + logging.warning(f'Unknown method {req.method}') + result = {'error': 'unknown method'} - resp = SkynetRPCResponse() - resp.result.update({'ok': result}) + except BaseException as e: + result = {'error': str(e)} - if security: - resp.auth.cert = 'skynet' - resp.auth.sig = sign_protobuf_msg(resp, tls_key) + resp.result.update(result) - await ctx.asend(resp.SerializeToString()) + return resp - ctx.close() + rpc_server = SessionServer( + rpc_address, + rpc_handler, + cert_name='brain.cert', + key_name='brain.key' + ) - - async with trio.open_nursery() as n: - n.start_soon(dgpu_bus_streamer) - n.start_soon(dgpu_heartbeat_service) - n.start_soon(request_service, n) - logging.info('starting rpc service') + async with rpc_server.open(): + logging.info('rpc server is up') yield - logging.info('stopping rpc service') - n.cancel_scope.cancel() + logging.info('skynet is shuting down...') - -@acm -async def run_skynet( - db_user: str = DB_USER, - db_pass: str = DB_PASS, - db_host: str = DB_HOST, - rpc_address: str = DEFAULT_RPC_ADDR, - dgpu_address: str = DEFAULT_DGPU_ADDR, - security: bool = True -): - logging.basicConfig(level=logging.INFO) - logging.info('skynet is starting') - - tls_config = None - if security: - # load tls certs - certs_dir = Path(DEFAULT_CERTS_DIR).resolve() - - tls_key_data = (certs_dir / DEFAULT_CERT_SKYNET_PRIV).read_text() - tls_key = load_privatekey(FILETYPE_PEM, tls_key_data) - - tls_cert_data = (certs_dir / DEFAULT_CERT_SKYNET_PUB).read_text() - tls_cert = load_certificate(FILETYPE_PEM, tls_cert_data) - - tls_whitelist = {} - for cert_path in (certs_dir / 'whitelist').glob('*.cert'): - tls_whitelist[cert_path.stem] = load_certificate( - FILETYPE_PEM, cert_path.read_text()) - - cert_start = tls_cert_data.index('\n') + 1 - logging.info(f'tls_cert: {tls_cert_data[cert_start:cert_start+64]}...') - logging.info(f'tls_whitelist len: {len(tls_whitelist)}') - - rpc_address = 'tls+' + rpc_address - dgpu_address = 'tls+' + dgpu_address - tls_config = TLSConfig( - TLSConfig.MODE_SERVER, - own_key_string=tls_key_data, - own_cert_string=tls_cert_data) - - with ( - pynng.Rep0(recv_max_size=0) as rpc_sock, - pynng.Bus0(recv_max_size=0) as dgpu_bus - ): - async with open_database_connection( - db_user, db_pass, db_host) as db_pool: - - logging.info('connected to db.') - if security: - rpc_sock.tls_config = tls_config - dgpu_bus.tls_config = tls_config - - rpc_sock.listen(rpc_address) - dgpu_bus.listen(dgpu_address) - - try: - async with open_rpc_service( - rpc_sock, dgpu_bus, db_pool, tls_whitelist, tls_key): - yield - - except SkynetShutdownRequested: - ... - - logging.info('disconnected from db.') + logging.info('skynet down.') diff --git a/skynet/cli.py b/skynet/cli.py index 2573106..856835f 100644 --- a/skynet/cli.py +++ b/skynet/cli.py @@ -17,8 +17,8 @@ if torch_enabled: from .dgpu import open_dgpu_node from .brain import run_skynet +from .config import * from .constants import ALGOS, DEFAULT_RPC_ADDR, DEFAULT_DGPU_ADDR - from .frontend.telegram import run_skynet_telegram @@ -38,8 +38,8 @@ def skynet(*args, **kwargs): @click.option('--steps', '-s', default=26) @click.option('--seed', '-S', default=None) def txt2img(*args, **kwargs): - assert 'HF_TOKEN' in os.environ - utils.txt2img(os.environ['HF_TOKEN'], **kwargs) + _, hf_token, _, cfg = init_env_from_config() + utils.txt2img(hf_token, **kwargs) @click.command() @click.option('--model', '-m', default='midj') @@ -52,9 +52,9 @@ def txt2img(*args, **kwargs): @click.option('--steps', '-s', default=26) @click.option('--seed', '-S', default=None) def img2img(model, prompt, input, output, strength, guidance, steps, seed): - assert 'HF_TOKEN' in os.environ + _, hf_token, _, cfg = init_env_from_config() utils.img2img( - os.environ['HF_TOKEN'], + hf_token, model=model, prompt=prompt, img_path=input, @@ -85,29 +85,17 @@ def run(*args, **kwargs): @click.option('--loglevel', '-l', default='warning', help='Logging level') @click.option( '--host', '-H', default=DEFAULT_RPC_ADDR) -@click.option( - '--host-dgpu', '-D', default=DEFAULT_DGPU_ADDR) -@click.option( - '--db-host', '-h', default='localhost:5432') -@click.option( - '--db-pass', '-p', default='password') def brain( loglevel: str, - host: str, - host_dgpu: str, - db_host: str, - db_pass: str + host: str ): async def _run_skynet(): async with run_skynet( - db_host=db_host, - db_pass=db_pass, - rpc_address=host, - dgpu_address=host_dgpu + rpc_address=host ): await trio.sleep_forever() - trio_asyncio.run(_run_skynet) + trio.run(_run_skynet) @run.command() @@ -115,9 +103,9 @@ def brain( @click.option( '--uid', '-u', required=True) @click.option( - '--key', '-k', default='dgpu') + '--key', '-k', default='dgpu.key') @click.option( - '--cert', '-c', default='whitelist/dgpu') + '--cert', '-c', default='whitelist/dgpu.cert') @click.option( '--algos', '-a', default=json.dumps(['midj'])) @click.option( @@ -159,11 +147,11 @@ def telegram( cert: str, rpc: str ): - assert 'TG_TOKEN' in os.environ + _, _, tg_token, cfg = init_env_from_config() trio_asyncio.run( partial( run_skynet_telegram, - os.environ['TG_TOKEN'], + tg_token, key_name=key, cert_name=cert, rpc_address=rpc diff --git a/skynet/config.py b/skynet/config.py new file mode 100644 index 0000000..65158c6 --- /dev/null +++ b/skynet/config.py @@ -0,0 +1,36 @@ +#!/usr/bin/python + +from pathlib import Path +from configparser import ConfigParser + +from .constants import DEFAULT_CONFIG_PATH + + +def load_skynet_ini( + file_path=DEFAULT_CONFIG_PATH +): + config = ConfigParser() + config.read(file_path) + return config + + +def init_env_from_config( + file_path=DEFAULT_CONFIG_PATH +): + config = load_skynet_ini() + if 'HF_TOKEN' in os.environ: + hf_token = os.environ['HF_TOKEN'] + else: + hf_token = config['skynet']['dgpu']['hf_token'] + + if 'HF_HOME' in os.environ: + hf_home = os.environ['HF_HOME'] + else: + hf_home = config['skynet']['dgpu']['hf_home'] + + if 'TG_TOKEN' in os.environ: + tg_token = os.environ['TG_TOKEN'] + else: + tg_token = config['skynet']['telegram']['token'] + + return hf_home, hf_token, tg_token, config diff --git a/skynet/constants.py b/skynet/constants.py index 1478269..3d96a2c 100644 --- a/skynet/constants.py +++ b/skynet/constants.py @@ -1,14 +1,9 @@ #!/usr/bin/python -VERSION = '0.1a8' +VERSION = '0.1a9' DOCKER_RUNTIME_CUDA = 'skynet:runtime-cuda' -DB_HOST = 'localhost:5432' -DB_USER = 'skynet' -DB_PASS = 'password' -DB_NAME = 'skynet' - ALGOS = { 'midj': 'prompthero/openjourney', 'stable': 'runwayml/stable-diffusion-v1-5', @@ -118,6 +113,7 @@ DEFAULT_ALGO = 'midj' DEFAULT_ROLE = 'pleb' DEFAULT_UPSCALER = None +DEFAULT_CONFIG_PATH = 'skynet.ini' DEFAULT_CERTS_DIR = 'certs' DEFAULT_CERT_WHITELIST_DIR = 'whitelist' DEFAULT_CERT_SKYNET_PUB = 'brain.cert' diff --git a/skynet/db/__init__.py b/skynet/db/__init__.py new file mode 100644 index 0000000..fd45c9e --- /dev/null +++ b/skynet/db/__init__.py @@ -0,0 +1,5 @@ +#!/usr/bin/python + +from .proxy import open_database_connection + +from .functions import open_new_database diff --git a/skynet/db.py b/skynet/db/functions.py similarity index 73% rename from skynet/db.py rename to skynet/db/functions.py index fbcf202..10863c2 100644 --- a/skynet/db.py +++ b/skynet/db/functions.py @@ -1,18 +1,21 @@ #!/usr/bin/python +import time +import random +import string import logging from typing import Optional from datetime import datetime -from contextlib import asynccontextmanager as acm +from contextlib import contextmanager as cm -import trio -import triopg -import trio_asyncio +import docker +import psycopg2 from asyncpg.exceptions import UndefinedColumnError +from psycopg2.extensions import ISOLATION_LEVEL_AUTOCOMMIT -from .constants import * +from ..constants import * DB_INIT_SQL = ''' @@ -75,29 +78,67 @@ def try_decode_uid(uid: str): return None, None -@acm -async def open_database_connection( - db_user: str = DB_USER, - db_pass: str = DB_PASS, - db_host: str = DB_HOST, - db_name: str = DB_NAME -): - async with trio_asyncio.open_loop() as loop: - async with triopg.create_pool( - dsn=f'postgres://{db_user}:{db_pass}@{db_host}/{db_name}' - ) as pool_conn: - async with pool_conn.acquire() as conn: - res = await conn.execute(f''' - select distinct table_schema - from information_schema.tables - where table_schema = \'{db_name}\' - ''') - if '1' in res: - logging.info('schema already in db, skipping init') - else: - await conn.execute(DB_INIT_SQL) +@cm +def open_new_database(): + rpassword = ''.join( + random.choice(string.ascii_lowercase) + for i in range(12)) + password = ''.join( + random.choice(string.ascii_lowercase) + for i in range(12)) - yield pool_conn + dclient = docker.from_env() + + container = dclient.containers.run( + 'postgres', + name='skynet-test-postgres', + ports={'5432/tcp': None}, + environment={ + 'POSTGRES_PASSWORD': rpassword + }, + detach=True, + remove=True + ) + + for log in container.logs(stream=True): + log = log.decode().rstrip() + logging.info(log) + if ('database system is ready to accept connections' in log or + 'database system is shut down' in log): + break + + # ip = container.attrs['NetworkSettings']['IPAddress'] + container.reload() + port = container.ports['5432/tcp'][0]['HostPort'] + host = f'localhost:{port}' + + # why print the system is ready to accept connections when its not + # postgres? wtf + time.sleep(1) + logging.info('creating skynet db...') + + conn = psycopg2.connect( + user='postgres', + password=rpassword, + host='localhost', + port=port + ) + logging.info('connected...') + conn.set_isolation_level(ISOLATION_LEVEL_AUTOCOMMIT) + with conn.cursor() as cursor: + cursor.execute( + f'CREATE USER skynet WITH PASSWORD \'{password}\'') + cursor.execute( + f'CREATE DATABASE skynet') + cursor.execute( + f'GRANT ALL PRIVILEGES ON DATABASE skynet TO skynet') + + conn.close() + + logging.info('done.') + yield container, password, host + + container.stop() async def get_user(conn, uid: str): diff --git a/skynet/db/proxy.py b/skynet/db/proxy.py new file mode 100644 index 0000000..d2f86c1 --- /dev/null +++ b/skynet/db/proxy.py @@ -0,0 +1,123 @@ +#!/usr/bin/python + +import importlib + +from contextlib import asynccontextmanager as acm + +import trio +import tractor +import asyncpg +import asyncio +import trio_asyncio + + +_spawn_kwargs = { + 'infect_asyncio': True, +} + + +async def aio_db_proxy( + to_trio: trio.MemorySendChannel, + from_trio: asyncio.Queue, + db_user: str = 'skynet', + db_pass: str = 'password', + db_host: str = 'localhost:5432', + db_name: str = 'skynet' +) -> None: + db = importlib.import_module('skynet.db.functions') + + pool = await asyncpg.create_pool( + dsn=f'postgres://{db_user}:{db_pass}@{db_host}/{db_name}') + + async with pool_conn.acquire() as conn: + res = await conn.execute(f''' + select distinct table_schema + from information_schema.tables + where table_schema = \'{db_name}\' + ''') + if '1' in res: + logging.info('schema already in db, skipping init') + else: + await conn.execute(DB_INIT_SQL) + + # a first message must be sent **from** this ``asyncio`` + # task or the ``trio`` side will never unblock from + # ``tractor.to_asyncio.open_channel_from():`` + to_trio.send_nowait('start') + + # XXX: this uses an ``from_trio: asyncio.Queue`` currently but we + # should probably offer something better. + while True: + msg = await from_trio.get() + + method = getattr(db, msg.get('method')) + args = getattr(db, msg.get('args', [])) + kwargs = getattr(db, msg.get('kwargs', {})) + + async with pool_conn.acquire() as conn: + result = await method(conn, *args, **kwargs) + to_trio.send_nowait(result) + + +@tractor.context +async def trio_to_aio_db_proxy( + ctx: tractor.Context, + db_user: str = 'skynet', + db_pass: str = 'password', + db_host: str = 'localhost:5432', + db_name: str = 'skynet' +): + # this will block until the ``asyncio`` task sends a "first" + # message. + async with tractor.to_asyncio.open_channel_from( + aio_db_proxy, + db_user=db_user, + db_pass=db_pass, + db_host=db_host, + db_name=db_name + ) as (first, chan): + + assert first == 'start' + await ctx.started(first) + + async with ctx.open_stream() as stream: + + async for msg in stream: + await chan.send(msg) + + out = await chan.receive() + # echo back to parent actor-task + await stream.send(out) + + +@acm +async def open_database_connection( + db_user: str = 'skynet', + db_pass: str = 'password', + db_host: str = 'localhost:5432', + db_name: str = 'skynet' +): + async with tractor.open_nursery() as n: + p = await n.start_actor( + 'aio_db_proxy', + enable_modules=[__name__], + infect_asyncio=True, + ) + async with p.open_context( + trio_to_aio_db_proxy, + db_user=db_user, + db_pass=db_pass, + db_host=db_host, + db_name=db_name + ) as (ctx, first): + async with ctx.open_stream() as stream: + + async def _db_pc(method: str, *args, **kwargs): + await stream.send({ + 'method': method, + 'args': args, + 'kwargs': kwargs + }) + return await stream.receive() + + yield _db_pc diff --git a/skynet/dgpu.py b/skynet/dgpu.py index 752c8b8..79c6c49 100644 --- a/skynet/dgpu.py +++ b/skynet/dgpu.py @@ -2,29 +2,17 @@ import gc import io -import trio import json -import uuid -import time -import zlib import random import logging -import traceback from PIL import Image from typing import List, Optional -from pathlib import Path -from contextlib import ExitStack -import pynng +import trio import torch -from pynng import TLSConfig -from OpenSSL.crypto import ( - load_privatekey, - load_certificate, - FILETYPE_PEM -) +from pynng import Context from diffusers import ( StableDiffusionPipeline, StableDiffusionImg2ImgPipeline, @@ -34,12 +22,9 @@ from realesrgan import RealESRGANer from basicsr.archs.rrdbnet_arch import RRDBNet from diffusers.models import UNet2DConditionModel -from .utils import ( - pipeline_for, - convert_from_cv2_to_image, convert_from_image_to_cv2 -) +from .utils import * +from .network import * from .protobuf import * -from .frontend import open_skynet_rpc from .constants import * @@ -64,65 +49,16 @@ class DGPUComputeError(BaseException): ... -class ReconnectingBus: - - def __init__(self, address: str, tls_config: Optional[TLSConfig]): - self.address = address - self.tls_config = tls_config - - self._stack = ExitStack() - self._sock = None - self._closed = True - - def connect(self): - self._sock = self._stack.enter_context( - pynng.Bus0(recv_max_size=0)) - self._sock.tls_config = self.tls_config - self._sock.dial(self.address) - self._closed = False - - async def arecv(self): - while True: - try: - return await self._sock.arecv() - - except pynng.exceptions.Closed: - if self._closed: - raise - - async def asend(self, msg): - while True: - try: - return await self._sock.asend(msg) - - except pynng.exceptions.Closed: - if self._closed: - raise - - def close(self): - self._stack.close() - self._stack = ExitStack() - self._closed = True - - def reconnect(self): - self.close() - self.connect() - - async def open_dgpu_node( cert_name: str, unique_id: str, key_name: Optional[str], rpc_address: str = DEFAULT_RPC_ADDR, dgpu_address: str = DEFAULT_DGPU_ADDR, - initial_algos: Optional[List[str]] = None, - security: bool = True + initial_algos: Optional[List[str]] = None ): - logging.basicConfig(level=logging.INFO) + logging.basicConfig(level=logging.DEBUG) logging.info(f'starting dgpu node!') - - name = uuid.uuid4() - logging.info(f'loading models...') upscaler = init_upscaler() @@ -141,241 +77,140 @@ async def open_dgpu_node( logging.info('memory summary:') logging.info('\n' + torch.cuda.memory_summary()) - async def gpu_compute_one(ireq: DiffusionParameters, image=None): - algo = ireq.algo + 'img' if image else ireq.algo - if algo not in models: - least_used = list(models.keys())[0] - for model in models: - if models[least_used]['generated'] > models[model]['generated']: - least_used = model + async def gpu_compute_one(method: str, params: dict, binext: Optional[bytes] = None): + match method: + case 'diffuse': + image = None + algo = params['algo'] + if binext: + algo += 'img' + image = Image.open(io.BytesIO(binext)) + w, h = image.size + logging.info(f'user sent img of size {image.size}') - del models[least_used] - gc.collect() + if w > 512 or h > 512: + image.thumbnail((512, 512)) + logging.info(f'resized it to {image.size}') - models[algo] = { - 'pipe': pipeline_for(ireq.algo, image=True if image else False), - 'generated': 0 - } + if algo not in models: + logging.info(f'{algo} not in loaded models, swapping...') + least_used = list(models.keys())[0] + for model in models: + if models[least_used]['generated'] > models[model]['generated']: + least_used = model - _params = {} - if ireq.image: - _params['image'] = image - _params['strength'] = ireq.strength + del models[least_used] + gc.collect() - else: - _params['width'] = int(ireq.width) - _params['height'] = int(ireq.height) + models[algo] = { + 'pipe': pipeline_for(params['algo'], image=True if binext else False), + 'generated': 0 + } + logging.info(f'swapping done.') - try: - image = models[algo]['pipe']( - ireq.prompt, - **_params, - guidance_scale=ireq.guidance, - num_inference_steps=int(ireq.step), - generator=torch.Generator("cuda").manual_seed(ireq.seed) - ).images[0] + _params = {} + logging.info(method) + logging.info(json.dumps(params, indent=4)) + logging.info(f'binext: {len(binext) if binext else 0} bytes') + if binext: + _params['image'] = image + _params['strength'] = params['strength'] - if ireq.upscaler == 'x4': - logging.info(f'size: {len(image.tobytes())}') - logging.info('performing upscale...') - input_img = image.convert('RGB') - up_img, _ = upscaler.enhance( - convert_from_image_to_cv2(input_img), outscale=4) + else: + _params['width'] = int(params['width']) + _params['height'] = int(params['height']) - image = convert_from_cv2_to_image(up_img) - logging.info('done') + try: + image = models[algo]['pipe']( + params['prompt'], + **_params, + guidance_scale=params['guidance'], + num_inference_steps=int(params['step']), + generator=torch.Generator("cuda").manual_seed( + int(params['seed']) if params['seed'] else random.randint(0, 2 ** 64) + ) + ).images[0] - img_byte_arr = io.BytesIO() - image.save(img_byte_arr, format='PNG') - raw_img = img_byte_arr.getvalue() - logging.info(f'final img size {len(raw_img)} bytes.') + if params['upscaler'] == 'x4': + logging.info(f'size: {len(image.tobytes())}') + logging.info('performing upscale...') + input_img = image.convert('RGB') + up_img, _ = upscaler.enhance( + convert_from_image_to_cv2(input_img), outscale=4) - return raw_img + image = convert_from_cv2_to_image(up_img) + logging.info('done') - except BaseException as e: - logging.error(e) - raise DGPUComputeError(str(e)) + img_byte_arr = io.BytesIO() + image.save(img_byte_arr, format='PNG') + raw_img = img_byte_arr.getvalue() + logging.info(f'final img size {len(raw_img)} bytes.') - finally: - torch.cuda.empty_cache() + return raw_img + + except BaseException as e: + logging.error(e) + raise DGPUComputeError(str(e)) + + finally: + torch.cuda.empty_cache() + + case _: + raise DGPUComputeError('Unsupported compute method') + + async def rpc_handler(req: SkynetRPCRequest, ctx: Context): + result = {} + resp = SkynetRPCResponse() + + match req.method: + case 'dgpu_time': + result = {'ok': time_ms()} + + case _: + logging.debug(f'dgpu got one request: {req.method}') + try: + resp.bin = await gpu_compute_one( + req.method, MessageToDict(req.params), + binext=req.bin if req.bin else None + ) + logging.debug(f'dgpu processed one request') + + except DGPUComputeError as e: + result = {'error': str(e)} + + resp.result.update(result) + return resp + + rpc_server = SessionServer( + dgpu_address, + rpc_handler, + cert_name=cert_name, + key_name=key_name + ) + skynet_rpc = SessionClient( + rpc_address, + unique_id, + cert_name=cert_name, + key_name=key_name + ) + skynet_rpc.connect() - async with ( - open_skynet_rpc( - unique_id, - rpc_address=rpc_address, - security=security, - cert_name=cert_name, - key_name=key_name - ) as rpc_call, - trio.open_nursery() as n - ): + async with rpc_server.open() as rpc_server: + res = await skynet_rpc.rpc( + 'dgpu_online', { + 'dgpu_addr': rpc_server.addr, + 'cert': cert_name + }) - tls_config = None - if security: - # load tls certs - if not key_name: - key_name = cert_name - - certs_dir = Path(DEFAULT_CERTS_DIR).resolve() - - skynet_cert_path = certs_dir / 'brain.cert' - tls_cert_path = certs_dir / f'{cert_name}.cert' - tls_key_path = certs_dir / f'{key_name}.key' - - cert_name = tls_cert_path.stem - - skynet_cert_data = skynet_cert_path.read_text() - skynet_cert = load_certificate(FILETYPE_PEM, skynet_cert_data) - - tls_cert_data = tls_cert_path.read_text() - - tls_key_data = tls_key_path.read_text() - tls_key = load_privatekey(FILETYPE_PEM, tls_key_data) - - logging.info(f'skynet cert: {skynet_cert_path}') - logging.info(f'dgpu cert: {tls_cert_path}') - logging.info(f'dgpu key: {tls_key_path}') - - dgpu_address = 'tls+' + dgpu_address - tls_config = TLSConfig( - TLSConfig.MODE_CLIENT, - own_key_string=tls_key_data, - own_cert_string=tls_cert_data, - ca_string=skynet_cert_data) - - logging.info(f'connecting to {dgpu_address}') - - dgpu_bus = ReconnectingBus(dgpu_address, tls_config) - dgpu_bus.connect() - - last_msg = time.time() - async def connection_refresher(refresh_time: int = 120): - nonlocal last_msg - while True: - now = time.time() - last_msg_time_delta = now - last_msg - logging.info(f'time since last msg: {last_msg_time_delta}') - if last_msg_time_delta > refresh_time: - dgpu_bus.reconnect() - logging.info('reconnected!') - last_msg = now - - await trio.sleep(refresh_time) - - n.start_soon(connection_refresher) - - res = await rpc_call('dgpu_online') assert 'ok' in res.result try: - while True: - msg = await dgpu_bus.arecv() - - img = None - if b'BINEXT' in msg: - header, msg, img_raw = msg.split(b'%$%$') - logging.info(f'got img attachment of size {len(img_raw)}') - logging.info(img_raw[:10]) - raw_img = zlib.decompress(img_raw) - logging.info(raw_img[:10]) - img = Image.open(io.BytesIO(raw_img)) - w, h = img.size - logging.info(f'user sent img of size {img.size}') - - if w > 512 or h > 512: - img.thumbnail((512, 512)) - logging.info(f'resized it to {img.size}') - - - req = DGPUBusMessage() - req.ParseFromString(msg) - last_msg = time.time() - - if req.method == 'heartbeat': - rep = DGPUBusMessage( - rid=req.rid, - nid=unique_id, - method=req.method - ) - rep.params.update({'time': int(time.time() * 1000)}) - - if security: - rep.auth.cert = cert_name - rep.auth.sig = sign_protobuf_msg(rep, tls_key) - - await dgpu_bus.asend(rep.SerializeToString()) - logging.info('heartbeat reply') - continue - - if req.nid != unique_id: - logging.info( - f'witnessed msg {req.rid}, node involved: {req.nid}') - continue - - if security: - verify_protobuf_msg(req, skynet_cert) - - - ack_resp = DGPUBusMessage( - rid=req.rid, - nid=req.nid - ) - ack_resp.params.update({'ack': {}}) - - if security: - ack_resp.auth.cert = cert_name - ack_resp.auth.sig = sign_protobuf_msg(ack_resp, tls_key) - - # send ack - await dgpu_bus.asend(ack_resp.SerializeToString()) - - logging.info(f'sent ack, processing {req.rid}...') - - try: - img_req = DiffusionParameters(**req.params) - - if not img_req.seed: - img_req.seed = random.randint(0, 2 ** 64) - - img = await gpu_compute_one(img_req, image=img) - img_resp = DGPUBusMessage( - rid=req.rid, - nid=req.nid, - method='binary-reply' - ) - img_resp.params.update({ - 'len': len(img), - 'meta': img_req.to_dict() - }) - - except DGPUComputeError as e: - traceback.print_exception(type(e), e, e.__traceback__) - img_resp = DGPUBusMessage( - rid=req.rid, - nid=req.nid - ) - img_resp.params.update({'error': str(e)}) - - - if security: - img_resp.auth.cert = cert_name - img_resp.auth.sig = sign_protobuf_msg(img_resp, tls_key) - - # send final image - logging.info('sending img back...') - raw_msg = img_resp.SerializeToString() - await dgpu_bus.asend(raw_msg) - logging.info(f'sent {len(raw_msg)} bytes.') - if img_resp.method == 'binary-reply': - await dgpu_bus.asend(zlib.compress(img)) - logging.info(f'sent {len(img)} bytes.') + await trio.sleep_forever() except KeyboardInterrupt: logging.info('interrupt caught, stopping...') - n.cancel_scope.cancel() - dgpu_bus.close() finally: - res = await rpc_call('dgpu_offline') + res = await skynet_rpc.rpc('dgpu_offline') assert 'ok' in res.result diff --git a/skynet/frontend/__init__.py b/skynet/frontend/__init__.py index f8193a2..04d6b90 100644 --- a/skynet/frontend/__init__.py +++ b/skynet/frontend/__init__.py @@ -4,7 +4,7 @@ import json from typing import Union, Optional from pathlib import Path -from contextlib import asynccontextmanager as acm +from contextlib import contextmanager as cm import pynng @@ -17,6 +17,7 @@ from OpenSSL.crypto import ( from google.protobuf.struct_pb2 import Struct +from ..network import SessionClient from ..constants import * from ..protobuf.auth import * @@ -39,75 +40,23 @@ class ConfigSizeDivisionByEight(BaseException): ... -@acm -async def open_skynet_rpc( +@cm +def open_skynet_rpc( unique_id: str, rpc_address: str = DEFAULT_RPC_ADDR, - security: bool = False, cert_name: Optional[str] = None, key_name: Optional[str] = None ): - tls_config = None - - if security: - # load tls certs - if not key_name: - key_name = cert_name - - certs_dir = Path(DEFAULT_CERTS_DIR).resolve() - - skynet_cert_data = (certs_dir / 'brain.cert').read_text() - skynet_cert = load_certificate(FILETYPE_PEM, skynet_cert_data) - - tls_cert_path = certs_dir / f'{cert_name}.cert' - tls_cert_data = tls_cert_path.read_text() - tls_cert = load_certificate(FILETYPE_PEM, tls_cert_data) - cert_name = tls_cert_path.stem - - tls_key_data = (certs_dir / f'{key_name}.key').read_text() - tls_key = load_privatekey(FILETYPE_PEM, tls_key_data) - - rpc_address = 'tls+' + rpc_address - tls_config = TLSConfig( - TLSConfig.MODE_CLIENT, - own_key_string=tls_key_data, - own_cert_string=tls_cert_data, - ca_string=skynet_cert_data) - - with pynng.Req0(recv_max_size=0) as sock: - if security: - sock.tls_config = tls_config - - sock.dial(rpc_address) - - async def _rpc_call( - method: str, - params: dict = {}, - uid: Optional[str] = None - ): - req = SkynetRPCRequest() - req.uid = uid if uid else unique_id - req.method = method - req.params.update(params) - - if security: - req.auth.cert = cert_name - req.auth.sig = sign_protobuf_msg(req, tls_key) - - ctx = sock.new_context() - await ctx.asend(req.SerializeToString()) - - resp = SkynetRPCResponse() - resp.ParseFromString(await ctx.arecv()) - ctx.close() - - if security: - verify_protobuf_msg(resp, skynet_cert) - - return resp - - yield _rpc_call - + sesh = SessionClient( + rpc_address, + unique_id, + cert_name=cert_name, + key_name=key_name + ) + logging.debug(f'opening skynet rpc...') + sesh.connect() + yield sesh + sesh.disconnect() def validate_user_config_request(req: str): params = req.split(' ') diff --git a/skynet/frontend/telegram.py b/skynet/frontend/telegram.py index 3287b3a..65a6fcb 100644 --- a/skynet/frontend/telegram.py +++ b/skynet/frontend/telegram.py @@ -6,8 +6,6 @@ import logging from datetime import datetime -import pynng - from PIL import Image from trio_asyncio import aio_as_trio @@ -16,6 +14,7 @@ from telebot.types import ( ) from telebot.async_telebot import AsyncTeleBot +from ..db import open_database_connection from ..constants import * from . import * @@ -56,228 +55,274 @@ def prepare_metainfo_caption(tguser, meta: dict) -> str: async def run_skynet_telegram( + name: str, tg_token: str, - key_name: str = 'telegram-frontend', - cert_name: str = 'whitelist/telegram-frontend', - rpc_address: str = DEFAULT_RPC_ADDR + key_name: str = 'telegram-frontend.key', + cert_name: str = 'whitelist/telegram-frontend.cert', + rpc_address: str = DEFAULT_RPC_ADDR, + db_host: str = 'localhost:5432', + db_user: str = 'skynet', + db_pass: str = 'password' ): logging.basicConfig(level=logging.INFO) bot = AsyncTeleBot(tg_token) + logging.info(f'tg_token: {tg_token}') - async with open_skynet_rpc( - 'skynet-telegram-0', - rpc_address=rpc_address, - security=True, - cert_name=cert_name, - key_name=key_name - ) as rpc_call: + async with open_database_connection( + db_user, db_pass, db_host + ) as db_call: + with open_skynet_rpc( + f'skynet-telegram-{name}', + rpc_address=rpc_address, + cert_name=cert_name, + key_name=key_name + ) as session: - async def _rpc_call( - uid: int, - method: str, - params: dict = {} - ): - return await rpc_call( - method, params, uid=f'{PREFIX}+{uid}') + @bot.message_handler(commands=['help']) + async def send_help(message): + splt_msg = message.text.split(' ') - @bot.message_handler(commands=['help']) - async def send_help(message): - splt_msg = message.text.split(' ') - - if len(splt_msg) == 1: - await bot.reply_to(message, HELP_TEXT) - - else: - param = splt_msg[1] - if param in HELP_TOPICS: - await bot.reply_to(message, HELP_TOPICS[param]) + if len(splt_msg) == 1: + await bot.reply_to(message, HELP_TEXT) else: - await bot.reply_to(message, HELP_UNKWNOWN_PARAM) + param = splt_msg[1] + if param in HELP_TOPICS: + await bot.reply_to(message, HELP_TOPICS[param]) - @bot.message_handler(commands=['cool']) - async def send_cool_words(message): - await bot.reply_to(message, '\n'.join(COOL_WORDS)) + else: + await bot.reply_to(message, HELP_UNKWNOWN_PARAM) - @bot.message_handler(commands=['txt2img']) - async def send_txt2img(message): - chat = message.chat + @bot.message_handler(commands=['cool']) + async def send_cool_words(message): + await bot.reply_to(message, '\n'.join(COOL_WORDS)) - prompt = ' '.join(message.text.split(' ')[1:]) + @bot.message_handler(commands=['txt2img']) + async def send_txt2img(message): + chat = message.chat + reply_id = None + if chat.type == 'group' and chat.id == GROUP_ID: + reply_id = message.message_id - if len(prompt) == 0: - await bot.reply_to(message, 'Empty text prompt ignored.') - return + user_id = f'tg+{message.from_user.id}' - logging.info(f'mid: {message.id}') - resp = await _rpc_call( - message.from_user.id, - 'txt2img', - {'prompt': prompt} - ) - logging.info(f'resp to {message.id} arrived') + prompt = ' '.join(message.text.split(' ')[1:]) - resp_txt = '' - result = MessageToDict(resp.result) - if 'error' in resp.result: - resp_txt = resp.result['message'] + if len(prompt) == 0: + await bot.reply_to(message, 'Empty text prompt ignored.') + return - else: - logging.info(result['id']) - img_raw = zlib.decompress(bytes.fromhex(result['img'])) - logging.info(f'got image of size: {len(img_raw)}') - img = Image.open(io.BytesIO(img_raw)) + logging.info(f'mid: {message.id}') + user = await db_call('get_or_create_user', user_id) + user_config = {**(await db_call('get_user_config', user))} + del user_config['id'] - await bot.send_photo( - GROUP_ID, - caption=prepare_metainfo_caption(message.from_user, result['meta']['meta']), - photo=img, - reply_markup=build_redo_menu() + resp = await session.rpc( + 'dgpu_call', { + 'method': 'diffuse', + 'params': { + 'prompt': prompt, + **user_config + } + }, + timeout=60 ) - return + logging.info(f'resp to {message.id} arrived') - await bot.reply_to(message, resp_txt) + resp_txt = '' + result = MessageToDict(resp.result) + if 'error' in resp.result: + resp_txt = resp.result['message'] + await bot.reply_to(message, resp_txt) - @bot.message_handler(func=lambda message: True, content_types=['photo']) - async def send_img2img(message): - chat = message.chat + else: + logging.info(result['id']) + img_raw = resp.bin + logging.info(f'got image of size: {len(img_raw)}') + img = Image.open(io.BytesIO(img_raw)) - if not message.caption.startswith('/img2img'): - return + await bot.send_photo( + GROUP_ID, + caption=prepare_metainfo_caption(message.from_user, result['meta']['meta']), + photo=img, + reply_to_message_id=reply_id, + reply_markup=build_redo_menu() + ) + return - prompt = ' '.join(message.caption.split(' ')[1:]) - if len(prompt) == 0: - await bot.reply_to(message, 'Empty text prompt ignored.') - return + @bot.message_handler(func=lambda message: True, content_types=['photo']) + async def send_img2img(message): + chat = message.chat + reply_id = None + if chat.type == 'group' and chat.id == GROUP_ID: + reply_id = message.message_id - file_id = message.photo[-1].file_id - file_path = (await bot.get_file(file_id)).file_path - file_raw = await bot.download_file(file_path) - img = zlib.compress(file_raw) + user_id = f'tg+{message.from_user.id}' - logging.info(f'mid: {message.id}') - resp = await _rpc_call( - message.from_user.id, - 'img2img', - {'prompt': prompt, 'img': img.hex()} - ) - logging.info(f'resp to {message.id} arrived') + if not message.caption.startswith('/img2img'): + await bot.reply_to( + message, + 'For image to image you need to add /img2img to the beggining of your caption' + ) + return - resp_txt = '' - result = MessageToDict(resp.result) - if 'error' in resp.result: - resp_txt = resp.result['message'] + prompt = ' '.join(message.caption.split(' ')[1:]) - else: - logging.info(result['id']) - img_raw = zlib.decompress(bytes.fromhex(result['img'])) - logging.info(f'got image of size: {len(img_raw)}') - img = Image.open(io.BytesIO(img_raw)) + if len(prompt) == 0: + await bot.reply_to(message, 'Empty text prompt ignored.') + return - await bot.send_media_group( - GROUP_ID, - media=[ - InputMediaPhoto(file_id), - InputMediaPhoto( - img, - caption=prepare_metainfo_caption(message.from_user, result['meta']['meta']) - ) - ] + file_id = message.photo[-1].file_id + file_path = (await bot.get_file(file_id)).file_path + file_raw = await bot.download_file(file_path) + + logging.info(f'mid: {message.id}') + + user = await db_call('get_or_create_user', user_id) + user_config = {**(await db_call('get_user_config', user))} + del user_config['id'] + + resp = await session.rpc( + 'dgpu_call', { + 'method': 'diffuse', + 'params': { + 'prompt': prompt, + **user_config + } + }, + binext=file_raw, + timeout=60 ) - return + logging.info(f'resp to {message.id} arrived') - await bot.reply_to(message, resp_txt) + resp_txt = '' + result = MessageToDict(resp.result) + if 'error' in resp.result: + resp_txt = resp.result['message'] + await bot.reply_to(message, resp_txt) - @bot.message_handler(commands=['img2img']) - async def redo_txt2img(message): - await bot.reply_to( - message, - 'seems you tried to do an img2img command without sending image' - ) + else: + logging.info(result['id']) + img_raw = resp.bin + logging.info(f'got image of size: {len(img_raw)}') + img = Image.open(io.BytesIO(img_raw)) - async def _redo(message): - resp = await _rpc_call(message.from_user.id, 'redo') + await bot.send_media_group( + GROUP_ID, + media=[ + InputMediaPhoto(file_id), + InputMediaPhoto( + img, + caption=prepare_metainfo_caption(message.from_user, result['meta']['meta']) + ) + ], + reply_to_message_id=reply_id + ) + return - resp_txt = '' - result = MessageToDict(resp.result) - if 'error' in resp.result: - resp_txt = resp.result['message'] - else: - logging.info(result['id']) - img_raw = zlib.decompress(bytes.fromhex(result['img'])) - logging.info(f'got image of size: {len(img_raw)}') - img = Image.open(io.BytesIO(img_raw)) - - await bot.send_photo( - GROUP_ID, - caption=prepare_metainfo_caption(message.from_user, result['meta']['meta']), - photo=img, - reply_markup=build_redo_menu() + @bot.message_handler(commands=['img2img']) + async def img2img_missing_image(message): + await bot.reply_to( + message, + 'seems you tried to do an img2img command without sending image' ) - return - await bot.reply_to(message, resp_txt) + @bot.message_handler(commands=['redo']) + async def redo(message): + chat = message.chat + reply_id = None + if chat.type == 'group' and chat.id == GROUP_ID: + reply_id = message.message_id - @bot.message_handler(commands=['redo']) - async def redo_txt2img(message): - await _redo(message) + user_config = {**(await db_call('get_user_config', user))} + del user_config['id'] + prompt = await db_call('get_last_prompt_of', user) - @bot.message_handler(commands=['config']) - async def set_config(message): - rpc_params = {} - try: - attr, val, reply_txt = validate_user_config_request( - message.text) + resp = await session.rpc( + 'dgpu_call', { + 'method': 'diffuse', + 'params': { + 'prompt': prompt, + **user_config + } + }, + timeout=60 + ) + logging.info(f'resp to {message.id} arrived') - resp = await _rpc_call( - message.from_user.id, - 'config', {'attr': attr, 'val': val}) + resp_txt = '' + result = MessageToDict(resp.result) + if 'error' in resp.result: + resp_txt = resp.result['message'] + await bot.reply_to(message, resp_txt) - except BaseException as e: - reply_txt = str(e) + else: + logging.info(result['id']) + img_raw = resp.bin + logging.info(f'got image of size: {len(img_raw)}') + img = Image.open(io.BytesIO(img_raw)) - finally: - await bot.reply_to(message, reply_txt) + await bot.send_photo( + GROUP_ID, + caption=prepare_metainfo_caption(message.from_user, result['meta']['meta']), + photo=img, + reply_to_message_id=reply_id + ) + return - @bot.message_handler(commands=['stats']) - async def user_stats(message): - resp = await _rpc_call( - message.from_user.id, - 'stats', - {} - ) - stats = resp.result + @bot.message_handler(commands=['config']) + async def set_config(message): + rpc_params = {} + try: + attr, val, reply_txt = validate_user_config_request( + message.text) - stats_str = f'generated: {stats["generated"]}\n' - stats_str += f'joined: {stats["joined"]}\n' - stats_str += f'role: {stats["role"]}\n' + logging.info(f'user config update: {attr} to {val}') + await db_call('update_user_config', + user, req.params['attr'], req.params['val']) + logging.info('done') - await bot.reply_to( - message, stats_str) + except BaseException as e: + reply_txt = str(e) - @bot.message_handler(commands=['donate']) - async def donation_info(message): - await bot.reply_to( - message, DONATION_INFO) + finally: + await bot.reply_to(message, reply_txt) - @bot.message_handler(commands=['say']) - async def say(message): - chat = message.chat - user = message.from_user + @bot.message_handler(commands=['stats']) + async def user_stats(message): - if (chat.type == 'group') or (user.id != 383385940): - return + generated, joined, role = await db_call('get_user_stats', user) - await bot.send_message(GROUP_ID, message.text[4:]) + stats_str = f'generated: {generated}\n' + stats_str += f'joined: {joined}\n' + stats_str += f'role: {role}\n' + + await bot.reply_to( + message, stats_str) + + @bot.message_handler(commands=['donate']) + async def donation_info(message): + await bot.reply_to( + message, DONATION_INFO) + + @bot.message_handler(commands=['say']) + async def say(message): + chat = message.chat + user = message.from_user + + if (chat.type == 'group') or (user.id != 383385940): + return + + await bot.send_message(GROUP_ID, message.text[4:]) - @bot.message_handler(func=lambda message: True) - async def echo_message(message): - if message.text[0] == '/': - await bot.reply_to(message, UNKNOWN_CMD_TEXT) + @bot.message_handler(func=lambda message: True) + async def echo_message(message): + if message.text[0] == '/': + await bot.reply_to(message, UNKNOWN_CMD_TEXT) @bot.callback_query_handler(func=lambda call: True) async def callback_query(call): @@ -289,4 +334,4 @@ async def run_skynet_telegram( await _redo(call) - await aio_as_trio(bot.infinity_polling()) + await aio_as_trio(bot.infinity_polling)() diff --git a/skynet/network.py b/skynet/network.py new file mode 100644 index 0000000..95fb60f --- /dev/null +++ b/skynet/network.py @@ -0,0 +1,341 @@ +#!/usr/bin/python + +import zlib +import socket + +from typing import Callable, Awaitable, Optional +from pathlib import Path +from contextlib import asynccontextmanager as acm +from cryptography import x509 +from cryptography.hazmat.primitives import serialization + +import trio +import pynng + +from pynng import TLSConfig, Context + +from .protobuf import * +from .constants import * + + +def get_random_port(): + s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + s.bind(('', 0)) + return s.getsockname()[1] + + +def load_certs( + certs_dir: str, + cert_name: str, + key_name: str +): + certs_dir = Path(certs_dir).resolve() + tls_key_data = (certs_dir / key_name).read_bytes() + tls_key = serialization.load_pem_private_key( + tls_key_data, + password=None + ) + + tls_cert_data = (certs_dir / cert_name).read_bytes() + tls_cert = x509.load_pem_x509_certificate( + tls_cert_data + ) + + tls_whitelist = {} + for cert_path in (*(certs_dir / 'whitelist').glob('*.cert'), certs_dir / 'brain.cert'): + tls_whitelist[cert_path.stem] = x509.load_pem_x509_certificate( + cert_path.read_bytes() + ) + + return ( + SessionTLSConfig( + TLSConfig.MODE_SERVER, + own_key_string=tls_key_data, + own_cert_string=tls_cert_data + ), + + tls_whitelist + ) + + +def load_certs_client( + certs_dir: str, + cert_name: str, + key_name: str, + ca_name: Optional[str] = None +): + certs_dir = Path(certs_dir).resolve() + if not ca_name: + ca_name = 'brain.cert' + + ca_cert_data = (certs_dir / ca_name).read_bytes() + + tls_key_data = (certs_dir / key_name).read_bytes() + + + tls_cert_data = (certs_dir / cert_name).read_bytes() + + + tls_whitelist = {} + for cert_path in (*(certs_dir / 'whitelist').glob('*.cert'), certs_dir / 'brain.cert'): + tls_whitelist[cert_path.stem] = x509.load_pem_x509_certificate( + cert_path.read_bytes() + ) + + return ( + SessionTLSConfig( + TLSConfig.MODE_CLIENT, + own_key_string=tls_key_data, + own_cert_string=tls_cert_data, + ca_string=ca_cert_data + ), + + tls_whitelist + ) + + +class SessionError(BaseException): + ... + + +class SessionTLSConfig(TLSConfig): + + def __init__( + self, + mode, + server_name=None, + ca_string=None, + own_key_string=None, + own_cert_string=None, + auth_mode=None, + ca_files=None, + cert_key_file=None, + passwd=None + ): + super().__init__( + mode, + server_name=server_name, + ca_string=ca_string, + own_key_string=own_key_string, + own_cert_string=own_cert_string, + auth_mode=auth_mode, + ca_files=ca_files, + cert_key_file=cert_key_file, + passwd=passwd + ) + + if ca_string: + self.ca_cert = x509.load_pem_x509_certificate(ca_string) + + self.cert = x509.load_pem_x509_certificate(own_cert_string) + self.key = serialization.load_pem_private_key( + own_key_string, + password=passwd + ) + + +class SessionServer: + + def __init__( + self, + addr: str, + msg_handler: Callable[ + [SkynetRPCRequest, Context], Awaitable[SkynetRPCResponse] + ], + cert_name: Optional[str] = None, + key_name: Optional[str] = None, + cert_dir: str = DEFAULT_CERTS_DIR, + recv_max_size = 0 + ): + self.addr = addr + self.msg_handler = msg_handler + + self.cert_name = cert_name + self.tls_config = None + self.tls_whitelist = None + if cert_name and key_name: + self.cert_name = cert_name + self.tls_config, self.tls_whitelist = load_certs( + cert_dir, cert_name, key_name) + + self.addr = 'tls+' + self.addr + + self.recv_max_size = recv_max_size + + async def _handle_msg(self, req: SkynetRPCRequest, ctx: Context): + resp = await self.msg_handler(req, ctx) + + if self.tls_config: + resp.auth.cert = 'skynet' + resp.auth.sig = sign_protobuf_msg( + resp, self.tls_config.key) + + raw_msg = zlib.compress(resp.SerializeToString()) + + await ctx.asend(raw_msg) + + ctx.close() + + async def _listener (self, sock): + async with trio.open_nursery() as n: + while True: + ctx = sock.new_context() + + raw_msg = await ctx.arecv() + raw_size = len(raw_msg) + logging.debug(f'rpc server new msg {raw_size} bytes') + + try: + msg = zlib.decompress(raw_msg) + msg_size = len(msg) + + except zlib.error: + logging.warning(f'Zlib decompress error, dropping msg of size {len(raw_msg)}') + continue + + logging.debug(f'msg after decompress {msg_size} bytes, +{msg_size - raw_size} bytes') + + req = SkynetRPCRequest() + try: + req.ParseFromString(msg) + + except google.protobuf.message.DecodeError: + logging.warning(f'Dropping malfomed msg of size {len(msg)}') + continue + + logging.debug(f'msg method: {req.method}') + + if self.tls_config: + if req.auth.cert not in self.tls_whitelist: + logging.warning( + f'{req.auth.cert} not in tls whitelist') + continue + + try: + verify_protobuf_msg(req, self.tls_whitelist[req.auth.cert]) + + except ValueError: + logging.warning( + f'{req.cert} sent an unauthenticated msg') + continue + + n.start_soon(self._handle_msg, req, ctx) + + @acm + async def open(self): + with pynng.Rep0( + recv_max_size=self.recv_max_size + ) as sock: + + if self.tls_config: + sock.tls_config = self.tls_config + + sock.listen(self.addr) + + logging.debug(f'server socket listening at {self.addr}') + + async with trio.open_nursery() as n: + n.start_soon(self._listener, sock) + + try: + yield self + + finally: + n.cancel_scope.cancel() + + logging.debug('server socket is off.') + + +class SessionClient: + + def __init__( + self, + connect_addr: str, + uid: str, + cert_name: Optional[str] = None, + key_name: Optional[str] = None, + ca_name: Optional[str] = None, + cert_dir: str = DEFAULT_CERTS_DIR, + recv_max_size = 0 + ): + self.uid = uid + self.connect_addr = connect_addr + + self.cert_name = None + self.tls_config = None + self.tls_whitelist = None + self.tls_cert = None + self.tls_key = None + if cert_name and key_name: + self.cert_name = Path(cert_name).stem + self.tls_config, self.tls_whitelist = load_certs_client( + cert_dir, cert_name, key_name, ca_name=ca_name) + + if not self.connect_addr.startswith('tls'): + self.connect_addr = 'tls+' + self.connect_addr + + self.recv_max_size = recv_max_size + + self._connected = False + self._sock = None + + def connect(self): + self._sock = pynng.Req0( + recv_max_size=0, + name=self.uid + ) + + if self.tls_config: + self._sock.tls_config = self.tls_config + + logging.debug(f'client is dialing {self.connect_addr}...') + self._sock.dial(self.connect_addr, block=True) + self._connected = True + logging.debug(f'client is connected to {self.connect_addr}') + + def disconnect(self): + self._sock.close() + self._connected = False + logging.debug(f'client disconnected.') + + async def rpc( + self, + method: str, + params: dict = {}, + binext: Optional[bytes] = None, + timeout: float = 2. + ): + if not self._connected: + raise SessionError('tried to use rpc without connecting') + + req = SkynetRPCRequest() + req.uid = self.uid + req.method = method + req.params.update(params) + if binext: + logging.debug('added binary extension') + req.bin = binext + + if self.tls_config: + req.auth.cert = self.cert_name + req.auth.sig = sign_protobuf_msg(req, self.tls_config.key) + + with trio.fail_after(timeout): + ctx = self._sock.new_context() + raw_req = zlib.compress(req.SerializeToString()) + logging.debug(f'rpc client sending new msg {method} of size {len(raw_req)}') + await ctx.asend(raw_req) + logging.debug('sent, awaiting response...') + raw_resp = await ctx.arecv() + logging.debug(f'rpc client got response of size {len(raw_resp)}') + raw_resp = zlib.decompress(raw_resp) + + resp = SkynetRPCResponse() + resp.ParseFromString(raw_resp) + ctx.close() + + if self.tls_config: + verify_protobuf_msg(resp, self.tls_config.ca_cert) + + return resp diff --git a/skynet/protobuf/__init__.py b/skynet/protobuf/__init__.py index b985940..acafec8 100644 --- a/skynet/protobuf/__init__.py +++ b/skynet/protobuf/__init__.py @@ -1,29 +1,4 @@ #!/usr/bin/python -from typing import Optional -from dataclasses import dataclass, asdict - -from google.protobuf.json_format import MessageToDict - from .auth import * from .skynet_pb2 import * - - -class Struct: - - def to_dict(self): - return asdict(self) - - -@dataclass -class DiffusionParameters(Struct): - algo: str - prompt: str - step: int - width: int - height: int - guidance: float - strength: float - seed: Optional[int] - image: bool # if true indicates a bytestream is next msg - upscaler: Optional[str] diff --git a/skynet/protobuf/auth.py b/skynet/protobuf/auth.py index e2904cb..876683d 100644 --- a/skynet/protobuf/auth.py +++ b/skynet/protobuf/auth.py @@ -7,7 +7,8 @@ from hashlib import sha256 from collections import OrderedDict from google.protobuf.json_format import MessageToDict -from OpenSSL.crypto import PKey, X509, verify, sign +from cryptography.hazmat.primitives import serialization, hashes +from cryptography.hazmat.primitives.asymmetric import padding from .skynet_pb2 import * @@ -46,20 +47,23 @@ def serialize_msg_deterministic(msg): if field_descriptor.message_type.name == 'Struct': hash_dict(MessageToDict(getattr(msg, field_name))) - deterministic_msg = shasum.hexdigest() + deterministic_msg = shasum.digest() return deterministic_msg -def sign_protobuf_msg(msg, key: PKey): - return sign( - key, serialize_msg_deterministic(msg), 'sha256').hex() +def sign_protobuf_msg(msg, key): + return key.sign( + serialize_msg_deterministic(msg), + padding.PKCS1v15(), + hashes.SHA256() + ).hex() -def verify_protobuf_msg(msg, cert: X509): - return verify( - cert, +def verify_protobuf_msg(msg, cert): + return cert.public_key().verify( bytes.fromhex(msg.auth.sig), serialize_msg_deterministic(msg), - 'sha256' + padding.PKCS1v15(), + hashes.SHA256() ) diff --git a/skynet/protobuf/skynet.proto b/skynet/protobuf/skynet.proto index 6e66274..0bdccad 100644 --- a/skynet/protobuf/skynet.proto +++ b/skynet/protobuf/skynet.proto @@ -13,18 +13,12 @@ message SkynetRPCRequest { string uid = 1; string method = 2; google.protobuf.Struct params = 3; - optional Auth auth = 4; + optional bytes bin = 4; + optional Auth auth = 5; } message SkynetRPCResponse { google.protobuf.Struct result = 1; - optional Auth auth = 2; -} - -message DGPUBusMessage { - string rid = 1; - string nid = 2; - string method = 3; - google.protobuf.Struct params = 4; - optional Auth auth = 5; + optional bytes bin = 2; + optional Auth auth = 3; } diff --git a/skynet/protobuf/skynet_pb2.py b/skynet/protobuf/skynet_pb2.py index dd7db33..84b0527 100644 --- a/skynet/protobuf/skynet_pb2.py +++ b/skynet/protobuf/skynet_pb2.py @@ -14,7 +14,7 @@ _sym_db = _symbol_database.Default() from google.protobuf import struct_pb2 as google_dot_protobuf_dot_struct__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0cskynet.proto\x12\x06skynet\x1a\x1cgoogle/protobuf/struct.proto\"!\n\x04\x41uth\x12\x0c\n\x04\x63\x65rt\x18\x01 \x01(\t\x12\x0b\n\x03sig\x18\x02 \x01(\t\"\x82\x01\n\x10SkynetRPCRequest\x12\x0b\n\x03uid\x18\x01 \x01(\t\x12\x0e\n\x06method\x18\x02 \x01(\t\x12\'\n\x06params\x18\x03 \x01(\x0b\x32\x17.google.protobuf.Struct\x12\x1f\n\x04\x61uth\x18\x04 \x01(\x0b\x32\x0c.skynet.AuthH\x00\x88\x01\x01\x42\x07\n\x05_auth\"f\n\x11SkynetRPCResponse\x12\'\n\x06result\x18\x01 \x01(\x0b\x32\x17.google.protobuf.Struct\x12\x1f\n\x04\x61uth\x18\x02 \x01(\x0b\x32\x0c.skynet.AuthH\x00\x88\x01\x01\x42\x07\n\x05_auth\"\x8d\x01\n\x0e\x44GPUBusMessage\x12\x0b\n\x03rid\x18\x01 \x01(\t\x12\x0b\n\x03nid\x18\x02 \x01(\t\x12\x0e\n\x06method\x18\x03 \x01(\t\x12\'\n\x06params\x18\x04 \x01(\x0b\x32\x17.google.protobuf.Struct\x12\x1f\n\x04\x61uth\x18\x05 \x01(\x0b\x32\x0c.skynet.AuthH\x00\x88\x01\x01\x42\x07\n\x05_authb\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0cskynet.proto\x12\x06skynet\x1a\x1cgoogle/protobuf/struct.proto\"!\n\x04\x41uth\x12\x0c\n\x04\x63\x65rt\x18\x01 \x01(\t\x12\x0b\n\x03sig\x18\x02 \x01(\t\"\x9c\x01\n\x10SkynetRPCRequest\x12\x0b\n\x03uid\x18\x01 \x01(\t\x12\x0e\n\x06method\x18\x02 \x01(\t\x12\'\n\x06params\x18\x03 \x01(\x0b\x32\x17.google.protobuf.Struct\x12\x10\n\x03\x62in\x18\x04 \x01(\x0cH\x00\x88\x01\x01\x12\x1f\n\x04\x61uth\x18\x05 \x01(\x0b\x32\x0c.skynet.AuthH\x01\x88\x01\x01\x42\x06\n\x04_binB\x07\n\x05_auth\"\x80\x01\n\x11SkynetRPCResponse\x12\'\n\x06result\x18\x01 \x01(\x0b\x32\x17.google.protobuf.Struct\x12\x10\n\x03\x62in\x18\x02 \x01(\x0cH\x00\x88\x01\x01\x12\x1f\n\x04\x61uth\x18\x03 \x01(\x0b\x32\x0c.skynet.AuthH\x01\x88\x01\x01\x42\x06\n\x04_binB\x07\n\x05_authb\x06proto3') _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'skynet_pb2', globals()) @@ -24,9 +24,7 @@ if _descriptor._USE_C_DESCRIPTORS == False: _AUTH._serialized_start=54 _AUTH._serialized_end=87 _SKYNETRPCREQUEST._serialized_start=90 - _SKYNETRPCREQUEST._serialized_end=220 - _SKYNETRPCRESPONSE._serialized_start=222 - _SKYNETRPCRESPONSE._serialized_end=324 - _DGPUBUSMESSAGE._serialized_start=327 - _DGPUBUSMESSAGE._serialized_end=468 + _SKYNETRPCREQUEST._serialized_end=246 + _SKYNETRPCRESPONSE._serialized_start=249 + _SKYNETRPCRESPONSE._serialized_end=377 # @@protoc_insertion_point(module_scope) diff --git a/skynet/utils.py b/skynet/utils.py index ba1ce2d..f84c0ef 100644 --- a/skynet/utils.py +++ b/skynet/utils.py @@ -1,5 +1,6 @@ #!/usr/bin/python +import time import random from typing import Optional @@ -21,6 +22,10 @@ from huggingface_hub import login from .constants import ALGOS +def time_ms(): + return int(time.time() * 1000) + + def convert_from_cv2_to_image(img: np.ndarray) -> Image: # return Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB)) return Image.fromarray(img) diff --git a/tests/conftest.py b/tests/conftest.py index 64a369f..ac2f4be 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,89 +3,30 @@ import os import json import time -import random -import string import logging -from functools import partial from pathlib import Path +from functools import partial -import trio import pytest -import psycopg2 -import trio_asyncio from docker.types import Mount, DeviceRequest -from psycopg2.extensions import ISOLATION_LEVEL_AUTOCOMMIT -from skynet.constants import * +from skynet.db import open_new_database from skynet.brain import run_skynet +from skynet.network import get_random_port +from skynet.constants import * @pytest.fixture(scope='session') def postgres_db(dockerctl): - rpassword = ''.join( - random.choice(string.ascii_lowercase) - for i in range(12)) - password = ''.join( - random.choice(string.ascii_lowercase) - for i in range(12)) - - with dockerctl.run( - 'postgres', - name='skynet-test-postgres', - ports={'5432/tcp': None}, - environment={ - 'POSTGRES_PASSWORD': rpassword - } - ) as containers: - container = containers[0] - # ip = container.attrs['NetworkSettings']['IPAddress'] - port = container.ports['5432/tcp'][0]['HostPort'] - host = f'localhost:{port}' - - for log in container.logs(stream=True): - log = log.decode().rstrip() - logging.info(log) - if ('database system is ready to accept connections' in log or - 'database system is shut down' in log): - break - - # why print the system is ready to accept connections when its not - # postgres? wtf - time.sleep(1) - logging.info('creating skynet db...') - - conn = psycopg2.connect( - user='postgres', - password=rpassword, - host='localhost', - port=port - ) - logging.info('connected...') - conn.set_isolation_level(ISOLATION_LEVEL_AUTOCOMMIT) - with conn.cursor() as cursor: - cursor.execute( - f'CREATE USER {DB_USER} WITH PASSWORD \'{password}\'') - cursor.execute( - f'CREATE DATABASE {DB_NAME}') - cursor.execute( - f'GRANT ALL PRIVILEGES ON DATABASE {DB_NAME} TO {DB_USER}') - - conn.close() - - logging.info('done.') - yield container, password, host + with open_new_database() as db_params: + yield db_params @pytest.fixture -async def skynet_running(postgres_db): - db_container, db_pass, db_host = postgres_db - - async with run_skynet( - db_pass=db_pass, - db_host=db_host - ): +async def skynet_running(): + async with run_skynet(): yield @@ -99,11 +40,13 @@ def dgpu_workers(request, dockerctl, skynet_running): cmds = [] for i in range(num_containers): + dgpu_addr = f'tcp://127.0.0.1:{get_random_port()}' cmd = f''' pip install -e . && \ skynet run dgpu \ --algos=\'{json.dumps(initial_algos)}\' \ - --uid=dgpu-{i} + --uid=dgpu-{i} \ + --dgpu={dgpu_addr} ''' cmds.append(['bash', '-c', cmd]) @@ -120,7 +63,7 @@ def dgpu_workers(request, dockerctl, skynet_running): network='host', mounts=mounts, device_requests=devices, - num=num_containers + num=num_containers, ) as containers: yield containers diff --git a/tests/test_dgpu.py b/tests/test_dgpu.py index 4ce93bf..c187af0 100644 --- a/tests/test_dgpu.py +++ b/tests/test_dgpu.py @@ -12,29 +12,26 @@ from functools import partial import trio import pytest -import trio_asyncio from PIL import Image from google.protobuf.json_format import MessageToDict from skynet.brain import SkynetDGPUComputeError -from skynet.constants import * +from skynet.network import get_random_port, SessionServer +from skynet.protobuf import SkynetRPCResponse from skynet.frontend import open_skynet_rpc +from skynet.constants import * -async def wait_for_dgpus(rpc, amount: int, timeout: float = 30.0): +async def wait_for_dgpus(session, amount: int, timeout: float = 30.0): gpu_ready = False - start_time = time.time() - current_time = time.time() - while not gpu_ready and (current_time - start_time) < timeout: - res = await rpc('dgpu_workers') - if res.result['ok'] >= amount: - break + with trio.fail_after(timeout): + while not gpu_ready: + res = await session.rpc('dgpu_workers') + if res.result['ok'] >= amount: + break - await trio.sleep(1) - current_time = time.time() - - assert (current_time - start_time) < timeout + await trio.sleep(1) _images = set() @@ -48,34 +45,33 @@ async def check_request_img( ): global _images - async with open_skynet_rpc( + with open_skynet_rpc( uid, - security=True, - cert_name='whitelist/testing', - key_name='testing' - ) as rpc_call: - res = await rpc_call( - 'txt2img', { - 'prompt': 'red old tractor in a sunny wheat field', - 'step': 28, - 'width': width, 'height': height, - 'guidance': 7.5, - 'seed': None, - 'algo': list(ALGOS.keys())[i], - 'upscaler': upscaler - }) + cert_name='whitelist/testing.cert', + key_name='testing.key' + ) as session: + res = await session.rpc( + 'dgpu_call', { + 'method': 'diffuse', + 'params': { + 'prompt': 'red old tractor in a sunny wheat field', + 'step': 28, + 'width': width, 'height': height, + 'guidance': 7.5, + 'seed': None, + 'algo': list(ALGOS.keys())[i], + 'upscaler': upscaler + } + }, + timeout=60 + ) if 'error' in res.result: raise SkynetDGPUComputeError(MessageToDict(res.result)) - if upscaler == 'x4': - width *= 4 - height *= 4 - - img_raw = zlib.decompress(bytes.fromhex(res.result['img'])) + img_raw = res.bin img_sha = sha256(img_raw).hexdigest() - img = Image.frombytes( - 'RGB', (width, height), img_raw) + img = Image.open(io.BytesIO(img_raw)) if expect_unique and img_sha in _images: raise ValueError('Duplicated image sha: {img_sha}') @@ -96,13 +92,12 @@ async def test_dgpu_worker_compute_error(dgpu_workers): then generate a smaller image to show gpu worker recovery ''' - async with open_skynet_rpc( + with open_skynet_rpc( 'test-ctx', - security=True, - cert_name='whitelist/testing', - key_name='testing' - ) as test_rpc: - await wait_for_dgpus(test_rpc, 1) + cert_name='whitelist/testing.cert', + key_name='testing.key' + ) as session: + await wait_for_dgpus(session, 1) with pytest.raises(SkynetDGPUComputeError) as e: await check_request_img(0, width=4096, height=4096) @@ -112,20 +107,35 @@ async def test_dgpu_worker_compute_error(dgpu_workers): await check_request_img(0) +@pytest.mark.parametrize( + 'dgpu_workers', [(1, ['midj'])], indirect=True) +async def test_dgpu_worker(dgpu_workers): + '''Generate one image in a single dgpu worker + ''' + + with open_skynet_rpc( + 'test-ctx', + cert_name='whitelist/testing.cert', + key_name='testing.key' + ) as session: + await wait_for_dgpus(session, 1) + + await check_request_img(0) + + @pytest.mark.parametrize( 'dgpu_workers', [(1, ['midj', 'stable'])], indirect=True) -async def test_dgpu_workers(dgpu_workers): +async def test_dgpu_worker_two_models(dgpu_workers): '''Generate two images in a single dgpu worker using two different models. ''' - async with open_skynet_rpc( + with open_skynet_rpc( 'test-ctx', - security=True, - cert_name='whitelist/testing', - key_name='testing' - ) as test_rpc: - await wait_for_dgpus(test_rpc, 1) + cert_name='whitelist/testing.cert', + key_name='testing.key' + ) as session: + await wait_for_dgpus(session, 1) await check_request_img(0) await check_request_img(1) @@ -138,14 +148,12 @@ async def test_dgpu_worker_upscale(dgpu_workers): two different models. ''' - async with open_skynet_rpc( + with open_skynet_rpc( 'test-ctx', - security=True, - cert_name='whitelist/testing', - key_name='testing' - ) as test_rpc: - await wait_for_dgpus(test_rpc, 1) - logging.error('UPSCALE') + cert_name='whitelist/testing.cert', + key_name='testing.key' + ) as session: + await wait_for_dgpus(session, 1) img = await check_request_img(0, upscaler='x4') @@ -157,13 +165,12 @@ async def test_dgpu_worker_upscale(dgpu_workers): async def test_dgpu_workers_two(dgpu_workers): '''Generate two images in two separate dgpu workers ''' - async with open_skynet_rpc( + with open_skynet_rpc( 'test-ctx', - security=True, - cert_name='whitelist/testing', - key_name='testing' - ) as test_rpc: - await wait_for_dgpus(test_rpc, 2) + cert_name='whitelist/testing.cert', + key_name='testing.key' + ) as session: + await wait_for_dgpus(session, 2, timeout=60) async with trio.open_nursery() as n: n.start_soon(check_request_img, 0) @@ -175,13 +182,12 @@ async def test_dgpu_workers_two(dgpu_workers): async def test_dgpu_worker_algo_swap(dgpu_workers): '''Generate an image using a non default model ''' - async with open_skynet_rpc( + with open_skynet_rpc( 'test-ctx', - security=True, - cert_name='whitelist/testing', - key_name='testing' - ) as test_rpc: - await wait_for_dgpus(test_rpc, 1) + cert_name='whitelist/testing.cert', + key_name='testing.key' + ) as session: + await wait_for_dgpus(session, 1) await check_request_img(5) @@ -191,33 +197,32 @@ async def test_dgpu_rotation_next_worker(dgpu_workers): '''Connect three dgpu workers, disconnect and check next_worker rotation happens correctly ''' - async with open_skynet_rpc( + with open_skynet_rpc( 'test-ctx', - security=True, - cert_name='whitelist/testing', - key_name='testing' - ) as test_rpc: - await wait_for_dgpus(test_rpc, 3) + cert_name='whitelist/testing.cert', + key_name='testing.key' + ) as session: + await wait_for_dgpus(session, 3) - res = await test_rpc('dgpu_next') + res = await session.rpc('dgpu_next') assert 'ok' in res.result assert res.result['ok'] == 0 await check_request_img(0) - res = await test_rpc('dgpu_next') + res = await session.rpc('dgpu_next') assert 'ok' in res.result assert res.result['ok'] == 1 await check_request_img(0) - res = await test_rpc('dgpu_next') + res = await session.rpc('dgpu_next') assert 'ok' in res.result assert res.result['ok'] == 2 await check_request_img(0) - res = await test_rpc('dgpu_next') + res = await session.rpc('dgpu_next') assert 'ok' in res.result assert res.result['ok'] == 0 @@ -228,13 +233,12 @@ async def test_dgpu_rotation_next_worker_disconnect(dgpu_workers): '''Connect three dgpu workers, disconnect the first one and check next_worker rotation happens correctly ''' - async with open_skynet_rpc( + with open_skynet_rpc( 'test-ctx', - security=True, - cert_name='whitelist/testing', - key_name='testing' - ) as test_rpc: - await wait_for_dgpus(test_rpc, 3) + cert_name='whitelist/testing.cert', + key_name='testing.key' + ) as session: + await wait_for_dgpus(session, 3) await trio.sleep(3) @@ -245,7 +249,7 @@ async def test_dgpu_rotation_next_worker_disconnect(dgpu_workers): dgpu_workers[0].wait() - res = await test_rpc('dgpu_workers') + res = await session.rpc('dgpu_workers') assert 'ok' in res.result assert res.result['ok'] == 2 @@ -258,26 +262,43 @@ async def test_dgpu_no_ack_node_disconnect(skynet_running): '''Mock a node that connects, gets a request but fails to acknowledge it, then check skynet correctly drops the node ''' - async with open_skynet_rpc( - 'test-ctx', - security=True, - cert_name='whitelist/testing', - key_name='testing' - ) as rpc_call: - res = await rpc_call('dgpu_online') - assert 'ok' in res.result + async def mock_rpc(req, ctx): + resp = SkynetRPCResponse() + resp.result.update({'error': 'can\'t do it mate'}) + return resp - await wait_for_dgpus(rpc_call, 1) + dgpu_addr = f'tcp://127.0.0.1:{get_random_port()}' + mock_server = SessionServer( + dgpu_addr, + mock_rpc, + cert_name='whitelist/testing.cert', + key_name='testing.key' + ) - with pytest.raises(SkynetDGPUComputeError) as e: - await check_request_img(0) + async with mock_server.open(): + with open_skynet_rpc( + 'test-ctx', + cert_name='whitelist/testing.cert', + key_name='testing.key' + ) as session: - assert 'dgpu failed to acknowledge request' in str(e) + res = await session.rpc('dgpu_online', { + 'dgpu_addr': dgpu_addr, + 'cert': 'whitelist/testing.cert' + }) + assert 'ok' in res.result - res = await rpc_call('dgpu_workers') - assert 'ok' in res.result - assert res.result['ok'] == 0 + await wait_for_dgpus(session, 1) + + with pytest.raises(SkynetDGPUComputeError) as e: + await check_request_img(0) + + assert 'can\'t do it mate' in str(e.value) + + res = await session.rpc('dgpu_workers') + assert 'ok' in res.result + assert res.result['ok'] == 0 @pytest.mark.parametrize( @@ -286,13 +307,12 @@ async def test_dgpu_timeout_while_processing(dgpu_workers): '''Stop node while processing request to cause timeout and then check skynet correctly drops the node. ''' - async with open_skynet_rpc( + with open_skynet_rpc( 'test-ctx', - security=True, - cert_name='whitelist/testing', - key_name='testing' - ) as test_rpc: - await wait_for_dgpus(test_rpc, 1) + cert_name='whitelist/testing.cert', + key_name='testing.key' + ) as session: + await wait_for_dgpus(session, 1) async def check_request_img_raises(): with pytest.raises(SkynetDGPUComputeError) as e: @@ -308,72 +328,62 @@ async def test_dgpu_timeout_while_processing(dgpu_workers): assert ec == 0 -@pytest.mark.parametrize( - 'dgpu_workers', [(1, ['midj'])], indirect=True) -async def test_dgpu_heartbeat(dgpu_workers): - ''' - ''' - async with open_skynet_rpc( - 'test-ctx', - security=True, - cert_name='whitelist/testing', - key_name='testing' - ) as test_rpc: - await wait_for_dgpus(test_rpc, 1) - await trio.sleep(120) - - @pytest.mark.parametrize( 'dgpu_workers', [(1, ['midj'])], indirect=True) async def test_dgpu_img2img(dgpu_workers): - async with open_skynet_rpc( - '1', - security=True, - cert_name='whitelist/testing', - key_name='testing' - ) as rpc_call: - await wait_for_dgpus(rpc_call, 1) + with open_skynet_rpc( + 'test-ctx', + cert_name='whitelist/testing.cert', + key_name='testing.key' + ) as session: + await wait_for_dgpus(session, 1) + await trio.sleep(2) - res = await rpc_call( - 'txt2img', { - 'prompt': 'red old tractor in a sunny wheat field', - 'step': 28, - 'width': 512, 'height': 512, - 'guidance': 7.5, - 'seed': None, - 'algo': list(ALGOS.keys())[0], - 'upscaler': None - }) + res = await session.rpc( + 'dgpu_call', { + 'method': 'diffuse', + 'params': { + 'prompt': 'red old tractor in a sunny wheat field', + 'step': 28, + 'width': 512, 'height': 512, + 'guidance': 7.5, + 'seed': None, + 'algo': list(ALGOS.keys())[0], + 'upscaler': None + } + }, + timeout=60 + ) if 'error' in res.result: raise SkynetDGPUComputeError(MessageToDict(res.result)) - img_raw = res.result['img'] - img = zlib.decompress(bytes.fromhex(img_raw)) - logging.info(img[:10]) - img = Image.open(io.BytesIO(img)) - + img_raw = res.bin + img = Image.open(io.BytesIO(img_raw)) img.save('txt2img.png') - res = await rpc_call( - 'img2img', { - 'prompt': 'red sports car in a sunny wheat field', - 'step': 28, - 'img': img_raw, - 'guidance': 12, - 'seed': None, - 'algo': list(ALGOS.keys())[0], - 'upscaler': 'x4' - }) + res = await session.rpc( + 'dgpu_call', { + 'method': 'diffuse', + 'params': { + 'prompt': 'red ferrari in a sunny wheat field', + 'step': 28, + 'guidance': 8, + 'strength': 0.7, + 'seed': None, + 'algo': list(ALGOS.keys())[0], + 'upscaler': 'x4' + } + }, + binext=img_raw, + timeout=60 + ) if 'error' in res.result: raise SkynetDGPUComputeError(MessageToDict(res.result)) - img_raw = res.result['img'] - img = zlib.decompress(bytes.fromhex(img_raw)) - logging.info(img[:10]) - img = Image.open(io.BytesIO(img)) - + img_raw = res.bin + img = Image.open(io.BytesIO(img_raw)) img.save('img2img.png') diff --git a/tests/test_skynet.py b/tests/test_skynet.py index 5572a70..ad1c488 100644 --- a/tests/test_skynet.py +++ b/tests/test_skynet.py @@ -9,6 +9,7 @@ import trio_asyncio from skynet.brain import run_skynet from skynet.structs import * +from skynet.network import SessionServer from skynet.frontend import open_skynet_rpc @@ -17,54 +18,66 @@ async def test_skynet(skynet_running): async def test_skynet_attempt_insecure(skynet_running): - with pytest.raises(pynng.exceptions.NNGException) as e: - async with open_skynet_rpc('bad-actor'): - ... - - assert str(e.value) == 'Connection shutdown' + with pytest.raises(trio.TooSlowError) as e: + with open_skynet_rpc('bad-actor') as session: + with trio.fail_after(5): + await session.rpc('skynet_shutdown') async def test_skynet_dgpu_connection_simple(skynet_running): - async with open_skynet_rpc( + + async def rpc_handler(req, ctx): + ... + + fake_dgpu_addr = 'tcp://127.0.0.1:41001' + rpc_server = SessionServer( + fake_dgpu_addr, + rpc_handler, + cert_name='whitelist/testing.cert', + key_name='testing.key' + ) + + with open_skynet_rpc( 'dgpu-0', - security=True, - cert_name='whitelist/testing', - key_name='testing' - ) as rpc_call: + cert_name='whitelist/testing.cert', + key_name='testing.key' + ) as session: # check 0 nodes are connected - res = await rpc_call('dgpu_workers') + res = await session.rpc('dgpu_workers') assert 'ok' in res.result assert res.result['ok'] == 0 # check next worker is None - res = await rpc_call('dgpu_next') + res = await session.rpc('dgpu_next') assert 'ok' in res.result assert res.result['ok'] == None - # connect 1 dgpu - res = await rpc_call('dgpu_online') - assert 'ok' in res.result + async with rpc_server.open() as rpc_server: + # connect 1 dgpu + res = await session.rpc( + 'dgpu_online', {'dgpu_addr': fake_dgpu_addr}) + assert 'ok' in res.result - # check 1 node is connected - res = await rpc_call('dgpu_workers') - assert 'ok' in res.result - assert res.result['ok'] == 1 + # check 1 node is connected + res = await session.rpc('dgpu_workers') + assert 'ok' in res.result + assert res.result['ok'] == 1 - # check next worker is 0 - res = await rpc_call('dgpu_next') - assert 'ok' in res.result - assert res.result['ok'] == 0 + # check next worker is 0 + res = await session.rpc('dgpu_next') + assert 'ok' in res.result + assert res.result['ok'] == 0 - # disconnect 1 dgpu - res = await rpc_call('dgpu_offline') - assert 'ok' in res.result + # disconnect 1 dgpu + res = await session.rpc('dgpu_offline') + assert 'ok' in res.result # check 0 nodes are connected - res = await rpc_call('dgpu_workers') + res = await session.rpc('dgpu_workers') assert 'ok' in res.result assert res.result['ok'] == 0 # check next worker is None - res = await rpc_call('dgpu_next') + res = await session.rpc('dgpu_next') assert 'ok' in res.result assert res.result['ok'] == None diff --git a/tests/test_telegram.py b/tests/test_telegram.py new file mode 100644 index 0000000..d94a6bf --- /dev/null +++ b/tests/test_telegram.py @@ -0,0 +1,28 @@ +#!/usr/bin/python + +import trio + +from functools import partial + +from skynet.db import open_new_database +from skynet.brain import run_skynet +from skynet.config import load_skynet_ini +from skynet.frontend.telegram import run_skynet_telegram + + +if __name__ == '__main__': + '''You will need a telegram bot token configured on skynet.ini for this + ''' + with open_new_database() as db_params: + db_container, db_pass, db_host = db_params + config = load_skynet_ini() + + async def main(): + await run_skynet_telegram( + 'telegram-test', + config['skynet.telegram-test']['token'], + db_host=db_host, + db_pass=db_pass + ) + + trio.run(main)