From 6bc555f0d6d4b8ca91e6c384313f222fbdf6df7c Mon Sep 17 00:00:00 2001 From: Guillermo Rodriguez Date: Mon, 19 Dec 2022 12:36:02 -0300 Subject: [PATCH] Add authenticated messaging, also cmd line utils txt2img and upscale --- setup.py | 2 + skynet/brain.py | 110 ++++++++++++++++++++++-------------- skynet/cli.py | 89 +++++++++++++++++++++++------ skynet/db.py | 45 +++++++++------ skynet/dgpu.py | 72 +++++++++++++++++------ skynet/frontend/__init__.py | 80 ++++++++++++++++---------- skynet/frontend/telegram.py | 12 +++- skynet/structs.py | 51 ++++++++++++++++- skynet/utils.py | 74 ++++++++++++++++-------- tests/conftest.py | 15 +++-- tests/test_dgpu.py | 65 +++++++++++++++------ tests/test_skynet.py | 35 +++++------- 12 files changed, 459 insertions(+), 191 deletions(-) diff --git a/setup.py b/setup.py index 007883f..4781c43 100644 --- a/setup.py +++ b/setup.py @@ -10,6 +10,8 @@ setup( entry_points={ 'console_scripts': [ 'skynet = skynet.cli:skynet', + 'txt2img = skynet.cli:txt2img', + 'upscale = skynet.cli:upscale' ] }, install_requires=['click'] diff --git a/skynet/brain.py b/skynet/brain.py index 99732ce..7bd3cf7 100644 --- a/skynet/brain.py +++ b/skynet/brain.py @@ -16,6 +16,11 @@ import pynng import trio_asyncio from pynng import TLSConfig +from OpenSSL.crypto import ( + load_privatekey, + load_certificate, + FILETYPE_PEM +) from .db import * from .structs import * @@ -34,12 +39,14 @@ class SkynetDGPUComputeError(BaseException): class SkynetShutdownRequested(BaseException): ... + @acm -async def open_rpc_service(sock, dgpu_bus, db_pool): +async def open_rpc_service(sock, dgpu_bus, db_pool, tls_whitelist, tls_key): nodes = OrderedDict() wip_reqs = {} fin_reqs = {} next_worker: Optional[int] = None + security = len(tls_whitelist) > 0 def connect_node(uid): nonlocal next_worker @@ -109,27 +116,20 @@ async def open_rpc_service(sock, dgpu_bus, db_pool): async def dgpu_image_streamer(): nonlocal wip_reqs, fin_reqs while True: - msg = await dgpu_bus.arecv_msg() - rid = UUID(bytes=msg.bytes[:16]).hex - raw_msg = msg.bytes[16:] - logging.info(f'streamer got back {rid} of size {len(raw_msg)}') - match raw_msg[:5]: - case b'error': - img = raw_msg.decode() + msg = DGPUBusResponse( + **json.loads( + (await dgpu_bus.arecv()).decode())) - case b'ack': - img = raw_msg + if security: + msg.verify(tls_whitelist[msg.cert]) - case _: - img = base64.b64encode(raw_msg).hex() - - if rid not in wip_reqs: + if msg.rid not in wip_reqs: continue - fin_reqs[rid] = img - event = wip_reqs[rid] + fin_reqs[msg.rid] = msg + event = wip_reqs[msg.rid] event.set() - del wip_reqs[rid] + del wip_reqs[msg.rid] async def dgpu_stream_one_img(req: ImageGenRequest): nonlocal wip_reqs, fin_reqs, next_worker @@ -151,6 +151,9 @@ async def open_rpc_service(sock, dgpu_bus, db_pool): logging.info(f'dgpu_bus req: {dgpu_req}') + if security: + dgpu_req.sign(tls_key, 'skynet') + await dgpu_bus.asend( json.dumps(dgpu_req.to_dict()).encode()) @@ -163,8 +166,8 @@ async def open_rpc_service(sock, dgpu_bus, db_pool): disconnect_node(nid) raise SkynetDGPUOffline('dgpu failed to acknowledge request') - ack = fin_reqs[rid] - if ack != b'ack': + ack_msg = fin_reqs[rid] + if 'ack' not in ack_msg.params: disconnect_node(nid) raise SkynetDGPUOffline('dgpu failed to acknowledge request') @@ -178,15 +181,13 @@ async def open_rpc_service(sock, dgpu_bus, db_pool): nodes[nid]['task'] = None - img = fin_reqs[rid] + img_resp = fin_reqs[rid] del fin_reqs[rid] - logging.info(f'done streaming {len(img)} bytes') + if 'error' in img_resp.params: + raise SkynetDGPUComputeError(img_resp.params['error']) - if 'error' in img: - raise SkynetDGPUComputeError(img) - - return rid, img + return rid, img_resp.params['img'] async def handle_user_request(rpc_ctx, req): try: @@ -266,9 +267,13 @@ async def open_rpc_service(sock, dgpu_bus, db_pool): 'message': str(e) } + resp = SkynetRPCResponse(result=result) + + if security: + resp.sign(tls_key, 'skynet') + await rpc_ctx.asend( - json.dumps( - SkynetRPCResponse(result=result).to_dict()).encode()) + json.dumps(resp.to_dict()).encode()) async def request_service(n): nonlocal next_worker @@ -279,7 +284,19 @@ async def open_rpc_service(sock, dgpu_bus, db_pool): content = msg.bytes.decode() req = SkynetRPCRequest(**json.loads(content)) - logging.info(req) + if security: + if req.cert not in tls_whitelist: + logging.warning( + f'{req.cert} not in tls whitelist and security=True') + continue + + try: + req.verify(tls_whitelist[req.cert]) + + except ValueError: + logging.warning( + f'{req.cert} sent an unauthenticated msg with security=True') + continue result = {} @@ -303,10 +320,14 @@ async def open_rpc_service(sock, dgpu_bus, db_pool): handle_user_request, ctx, req) continue + resp = SkynetRPCResponse( + result={'ok': result}) + + if security: + resp.sign(tls_key, 'skynet') + await ctx.asend( - json.dumps( - SkynetRPCResponse( - result={'ok': result}).to_dict()).encode()) + json.dumps(resp.to_dict()).encode()) async with trio.open_nursery() as n: @@ -334,22 +355,28 @@ async def run_skynet( if security: # load tls certs certs_dir = Path(DEFAULT_CERTS_DIR).resolve() - tls_key = (certs_dir / DEFAULT_CERT_SKYNET_PRIV).read_text() - tls_cert = (certs_dir / DEFAULT_CERT_SKYNET_PUB).read_text() - tls_whitelist = [ - (cert_path).read_text() - for cert_path in (certs_dir / 'whitelist').glob('*.cert')] - cert_start = tls_cert.index('\n') + 1 - logging.info(f'tls_cert: {tls_cert[cert_start:cert_start+64]}...') + tls_key_data = (certs_dir / DEFAULT_CERT_SKYNET_PRIV).read_text() + tls_key = load_privatekey(FILETYPE_PEM, tls_key_data) + + tls_cert_data = (certs_dir / DEFAULT_CERT_SKYNET_PUB).read_text() + tls_cert = load_certificate(FILETYPE_PEM, tls_cert_data) + + tls_whitelist = {} + for cert_path in (certs_dir / 'whitelist').glob('*.cert'): + tls_whitelist[cert_path.stem] = load_certificate( + FILETYPE_PEM, cert_path.read_text()) + + cert_start = tls_cert_data.index('\n') + 1 + logging.info(f'tls_cert: {tls_cert_data[cert_start:cert_start+64]}...') logging.info(f'tls_whitelist len: {len(tls_whitelist)}') rpc_address = 'tls+' + rpc_address dgpu_address = 'tls+' + dgpu_address tls_config = TLSConfig( TLSConfig.MODE_SERVER, - own_key_string=tls_key, - own_cert_string=tls_cert) + own_key_string=tls_key_data, + own_cert_string=tls_cert_data) with ( pynng.Rep0() as rpc_sock, @@ -367,7 +394,8 @@ async def run_skynet( dgpu_bus.listen(dgpu_address) try: - async with open_rpc_service(rpc_sock, dgpu_bus, db_pool): + async with open_rpc_service( + rpc_sock, dgpu_bus, db_pool, tls_whitelist, tls_key): yield except SkynetShutdownRequested: diff --git a/skynet/cli.py b/skynet/cli.py index d0e3fa8..50360b1 100644 --- a/skynet/cli.py +++ b/skynet/cli.py @@ -9,44 +9,78 @@ from functools import partial import trio import click +from . import utils from .dgpu import open_dgpu_node -from .utils import txt2img -from .constants import ALGOS +from .brain import run_skynet + +from .frontend.telegram import run_skynet_telegram @click.group() def skynet(*args, **kwargs): pass -@skynet.command() -@click.option('--model', '-m', default=ALGOS['midj']) + +@click.command() +@click.option('--model', '-m', default='midj') @click.option( - '--prompt', '-p', default='a red tractor in a wheat field') + '--prompt', '-p', default='a red old tractor in a sunny wheat field') @click.option('--output', '-o', default='output.png') @click.option('--width', '-w', default=512) @click.option('--height', '-h', default=512) @click.option('--guidance', '-g', default=10.0) @click.option('--steps', '-s', default=26) @click.option('--seed', '-S', default=None) -def txt2img(*args -# model: str, -# prompt: str, -# output: str -# width: int, height: int, -# guidance: float, -# steps: int, -# seed: Optional[int] -): +def txt2img(*args, **kwargs): assert 'HF_TOKEN' in os.environ - txt2img( - os.environ['HF_TOKEN'], *args) + utils.txt2img(os.environ['HF_TOKEN'], **kwargs) + +@click.command() +@click.option( + '--prompt', '-p', default='a red old tractor in a sunny wheat field') +@click.option('--input', '-i', default='input.png') +@click.option('--output', '-o', default='output.png') +@click.option('--steps', '-s', default=26) +def upscale(prompt, input, output, steps): + assert 'HF_TOKEN' in os.environ + utils.upscale( + os.environ['HF_TOKEN'], + prompt=prompt, + img_path=input, + output=output, + steps=steps) + @skynet.group() def run(*args, **kwargs): pass + @run.command() @click.option('--loglevel', '-l', default='warning', help='Logging level') +@click.option( + '--host', '-h', default='localhost:5432') +@click.option( + '--pass', '-p', default='password') +def skynet( + loglevel: str, + host: str, + passw: str +): + async def _run_skynet(): + async with run_skynet( + db_host=host, + db_pass=passw + ): + await trio.sleep_forever() + + trio_asyncio.run(_run_skynet) + + +@run.command() +@click.option('--loglevel', '-l', default='warning', help='Logging level') +@click.option( + '--uid', '-u', required=True) @click.option( '--key', '-k', default='dgpu') @click.option( @@ -55,6 +89,7 @@ def run(*args, **kwargs): '--algos', '-a', default=None) def dgpu( loglevel: str, + uid: str, key: str, cert: str, algos: Optional[str] @@ -63,6 +98,28 @@ def dgpu( partial( open_dgpu_node, cert, + uid, key_name=key, initial_algos=json.loads(algos) )) + + +@run.command() +@click.option('--loglevel', '-l', default='warning', help='Logging level') +@click.option( + '--key', '-k', default='telegram-frontend') +@click.option( + '--cert', '-c', default='whitelist/telegram-frontend') +def telegram( + loglevel: str, + key: str, + cert: str +): + assert 'TG_TOKEN' in os.environ + trio_asyncio.run( + partial( + run_skynet_telegram, + os.environ['TG_TOKEN'], + key_name=key, + cert_name=cert + )) diff --git a/skynet/db.py b/skynet/db.py index 1b12e2c..7745e56 100644 --- a/skynet/db.py +++ b/skynet/db.py @@ -58,13 +58,16 @@ ALTER TABLE skynet.user_config def try_decode_uid(uid: str): + if isinstance(uid, int): + return None, uid + try: proto, uid = uid.split('+') uid = int(uid) return proto, uid except ValueError: - logging.warning(f'got non numeric uid?: {uid}') + logging.warning(f'got non chat proto uid?: {uid}') return None, None @@ -132,28 +135,38 @@ async def new_user(conn, uid: str): logging.info(f'new user! {uid}') - tg_id = None date = datetime.utcnow() proto, pid = try_decode_uid(uid) - match proto: - case 'tg': - tg_id = pid - async with conn.transaction(): - stmt = await conn.prepare(''' - INSERT INTO skynet.user( - tg_id, generated, joined, last_prompt, role) + match proto: + case 'tg': + tg_id = pid + stmt = await conn.prepare(''' + INSERT INTO skynet.user( + tg_id, generated, joined, last_prompt, role) - VALUES($1, $2, $3, $4, $5) - ON CONFLICT DO NOTHING - ''') - await stmt.fetch( - tg_id, 0, date, None, DEFAULT_ROLE - ) + VALUES($1, $2, $3, $4, $5) + ON CONFLICT DO NOTHING + ''') + await stmt.fetch( + tg_id, 0, date, None, DEFAULT_ROLE + ) + new_uid = await get_user(conn, uid) - new_uid = await get_user(conn, uid) + case None: + stmt = await conn.prepare(''' + INSERT INTO skynet.user( + id, generated, joined, last_prompt, role) + + VALUES($1, $2, $3, $4, $5) + ON CONFLICT DO NOTHING + ''') + await stmt.fetch( + pid, 0, date, None, DEFAULT_ROLE + ) + new_uid = pid stmt = await conn.prepare(''' INSERT INTO skynet.user_config( diff --git a/skynet/dgpu.py b/skynet/dgpu.py index 4efe1b9..a26336e 100644 --- a/skynet/dgpu.py +++ b/skynet/dgpu.py @@ -5,6 +5,7 @@ import io import trio import json import uuid +import base64 import random import logging @@ -16,10 +17,16 @@ import pynng import torch from pynng import TLSConfig +from OpenSSL.crypto import ( + load_privatekey, + load_certificate, + FILETYPE_PEM +) from diffusers import ( StableDiffusionPipeline, EulerAncestralDiscreteScheduler ) +from diffusers.models import UNet2DConditionModel from .structs import * from .constants import * @@ -58,6 +65,7 @@ class DGPUComputeError(BaseException): async def open_dgpu_node( cert_name: str, + unique_id: str, key_name: Optional[str], rpc_address: str = DEFAULT_RPC_ADDR, dgpu_address: str = DEFAULT_DGPU_ADDR, @@ -122,6 +130,7 @@ async def open_dgpu_node( async with open_skynet_rpc( + unique_id, security=security, cert_name=cert_name, key_name=key_name @@ -131,16 +140,23 @@ async def open_dgpu_node( if security: # load tls certs if not key_name: - key_name = certs_name + key_name = cert_name + certs_dir = Path(DEFAULT_CERTS_DIR).resolve() skynet_cert_path = certs_dir / 'brain.cert' tls_cert_path = certs_dir / f'{cert_name}.cert' tls_key_path = certs_dir / f'{key_name}.key' - skynet_cert = skynet_cert_path.read_text() - tls_cert = tls_cert_path.read_text() - tls_key = tls_key_path.read_text() + cert_name = tls_cert_path.stem + + skynet_cert_data = skynet_cert_path.read_text() + skynet_cert = load_certificate(FILETYPE_PEM, skynet_cert_data) + + tls_cert_data = tls_cert_path.read_text() + + tls_key_data = tls_key_path.read_text() + tls_key = load_privatekey(FILETYPE_PEM, tls_key_data) logging.info(f'skynet cert: {skynet_cert_path}') logging.info(f'dgpu cert: {tls_cert_path}') @@ -149,17 +165,16 @@ async def open_dgpu_node( dgpu_address = 'tls+' + dgpu_address tls_config = TLSConfig( TLSConfig.MODE_CLIENT, - own_key_string=tls_key, - own_cert_string=tls_cert, - ca_string=skynet_cert) + own_key_string=tls_key_data, + own_cert_string=tls_cert_data, + ca_string=skynet_cert_data) logging.info(f'connecting to {dgpu_address}') with pynng.Bus0() as dgpu_sock: dgpu_sock.tls_config = tls_config dgpu_sock.dial(dgpu_address) - res = await rpc_call(name.hex, 'dgpu_online') - logging.info(res) + res = await rpc_call('dgpu_online') assert 'ok' in res.result try: @@ -168,30 +183,55 @@ async def open_dgpu_node( req = DGPUBusRequest( **json.loads(msg.decode())) - if req.nid != name.hex: - logging.info('witnessed request {req.rid}, for {req.nid}') + if req.nid != unique_id: + logging.info( + f'witnessed msg {req.rid}, node involved: {req.nid}') continue + if security: + req.verify(skynet_cert) + + ack_resp = DGPUBusResponse( + rid=req.rid, + nid=req.nid, + params={'ack': {}} + ) + + if security: + ack_resp.sign(tls_key, cert_name) + # send ack await dgpu_sock.asend( - bytes.fromhex(req.rid) + b'ack') + json.dumps(ack_resp.to_dict()).encode()) logging.info(f'sent ack, processing {req.rid}...') try: img = await gpu_compute_one( ImageGenRequest(**req.params)) + img_resp = DGPUBusResponse( + rid=req.rid, + nid=req.nid, + params={'img': base64.b64encode(img).hex()} + ) except DGPUComputeError as e: - img = b'error' + str(e).encode() + img_resp = DGPUBusResponse( + rid=req.rid, + nid=req.nid, + params={'error': str(e)} + ) + if security: + img_resp.sign(tls_key, cert_name) + + # send final image await dgpu_sock.asend( - bytes.fromhex(req.rid) + img) + json.dumps(img_resp.to_dict()).encode()) except KeyboardInterrupt: logging.info('interrupt caught, stopping...') finally: - res = await rpc_call(name.hex, 'dgpu_offline') - logging.info(res) + res = await rpc_call('dgpu_offline') assert 'ok' in res.result diff --git a/skynet/frontend/__init__.py b/skynet/frontend/__init__.py index 0532bcd..4eaf918 100644 --- a/skynet/frontend/__init__.py +++ b/skynet/frontend/__init__.py @@ -9,6 +9,11 @@ from contextlib import asynccontextmanager as acm import pynng from pynng import TLSConfig +from OpenSSL.crypto import ( + load_privatekey, + load_certificate, + FILETYPE_PEM +) from ..structs import SkynetRPCRequest, SkynetRPCResponse from ..constants import * @@ -30,56 +35,73 @@ class ConfigSizeDivisionByEight(BaseException): ... -async def rpc_call( - sock, - uid: Union[int, str], - method: str, - params: dict = {} -): - req = SkynetRPCRequest( - uid=uid, - method=method, - params=params - ) - await sock.asend( - json.dumps( - req.to_dict()).encode()) - - return SkynetRPCResponse( - **json.loads( - (await sock.arecv_msg()).bytes.decode())) - - @acm async def open_skynet_rpc( + unique_id: str, rpc_address: str = DEFAULT_RPC_ADDR, security: bool = False, cert_name: Optional[str] = None, key_name: Optional[str] = None ): tls_config = None + if security: # load tls certs if not key_name: - key_name = certs_name + key_name = cert_name + certs_dir = Path(DEFAULT_CERTS_DIR).resolve() - skynet_cert = (certs_dir / 'brain.cert').read_text() - tls_cert = (certs_dir / f'{cert_name}.cert').read_text() - tls_key = (certs_dir / f'{key_name}.key').read_text() + + skynet_cert_data = (certs_dir / 'brain.cert').read_text() + skynet_cert = load_certificate(FILETYPE_PEM, skynet_cert_data) + + tls_cert_path = certs_dir / f'{cert_name}.cert' + tls_cert_data = tls_cert_path.read_text() + tls_cert = load_certificate(FILETYPE_PEM, tls_cert_data) + cert_name = tls_cert_path.stem + + tls_key_data = (certs_dir / f'{key_name}.key').read_text() + tls_key = load_privatekey(FILETYPE_PEM, tls_key_data) + rpc_address = 'tls+' + rpc_address tls_config = TLSConfig( TLSConfig.MODE_CLIENT, - own_key_string=tls_key, - own_cert_string=tls_cert, - ca_string=skynet_cert) + own_key_string=tls_key_data, + own_cert_string=tls_cert_data, + ca_string=skynet_cert_data) with pynng.Req0() as sock: if security: sock.tls_config = tls_config sock.dial(rpc_address) - async def _rpc_call(*args, **kwargs): - return await rpc_call(sock, *args, **kwargs) + + async def _rpc_call( + method: str, + params: dict = {}, + uid: Optional[Union[int, str]] = None + ): + req = SkynetRPCRequest( + uid=uid if uid else unique_id, + method=method, + params=params + ) + + if security: + req.sign(tls_key, cert_name) + + await sock.asend( + json.dumps( + req.to_dict()).encode()) + + resp = SkynetRPCResponse( + **json.loads( + (await sock.arecv_msg()).bytes.decode())) + + if security: + resp.verify(skynet_cert) + + return resp yield _rpc_call diff --git a/skynet/frontend/telegram.py b/skynet/frontend/telegram.py index 1e0b0ab..6f217b3 100644 --- a/skynet/frontend/telegram.py +++ b/skynet/frontend/telegram.py @@ -18,14 +18,19 @@ PREFIX = 'tg' async def run_skynet_telegram( - tg_token: str + tg_token: str, + key_name: str = 'telegram-frontend', + cert_name: str = 'whitelist/telegram-frontend' ): logging.basicConfig(level=logging.INFO) bot = AsyncTeleBot(tg_token) with open_skynet_rpc( - security=True, cert_name='telegram-frontend' + 'skynet-telegram-0', + security=True, + cert_name=cert, + key_name=key ) as rpc_call: async def _rpc_call( @@ -33,7 +38,8 @@ async def run_skynet_telegram( method: str, params: dict ): - return await rpc_call(f'{PREFIX}+{uid}', method, params) + return await rpc_call( + method, params, uid=f'{PREFIX}+{uid}') @bot.message_handler(commands=['help']) async def send_help(message): diff --git a/skynet/structs.py b/skynet/structs.py index 4332229..cc9f25f 100644 --- a/skynet/structs.py +++ b/skynet/structs.py @@ -18,6 +18,8 @@ Built-in (extension) types. """ import sys +import json + from typing import Optional, Union from pprint import pformat @@ -83,15 +85,51 @@ class Struct( setattr(self, fname, ftype(getattr(self, fname))) # proto +from OpenSSL.crypto import PKey, X509, verify, sign -class SkynetRPCRequest(Struct): + +class AuthenticatedStruct(Struct): + cert: Optional[str] = None + sig: Optional[str] = None + + def to_unsigned_dict(self) -> dict: + self_dict = self.to_dict() + + if 'sig' in self_dict: + del self_dict['sig'] + + if 'cert' in self_dict: + del self_dict['cert'] + + return self_dict + + def unsigned_to_bytes(self) -> bytes: + return json.dumps( + self.to_unsigned_dict()).encode() + + def sign(self, key: PKey, cert: str): + self.cert = cert + self.sig = sign( + key, self.unsigned_to_bytes(), 'sha256').hex() + + def verify(self, cert: X509): + if not self.sig: + raise ValueError('Tried to verify unsigned request') + + return verify( + cert, bytes.fromhex(self.sig), self.unsigned_to_bytes(), 'sha256') + + +class SkynetRPCRequest(AuthenticatedStruct): uid: Union[str, int] # user unique id method: str # rpc method name params: dict # variable params -class SkynetRPCResponse(Struct): + +class SkynetRPCResponse(AuthenticatedStruct): result: dict + class ImageGenRequest(Struct): prompt: str step: int @@ -102,8 +140,15 @@ class ImageGenRequest(Struct): algo: str upscaler: Optional[str] -class DGPUBusRequest(Struct): + +class DGPUBusRequest(AuthenticatedStruct): rid: str # req id nid: str # node id task: str params: dict + + +class DGPUBusResponse(AuthenticatedStruct): + rid: str # req id + nid: str # node id + params: dict diff --git a/skynet/utils.py b/skynet/utils.py index 06b8863..0534160 100644 --- a/skynet/utils.py +++ b/skynet/utils.py @@ -7,44 +7,37 @@ from pathlib import Path import torch -from diffusers import StableDiffusionPipeline +from PIL import Image +from diffusers import ( + StableDiffusionPipeline, + StableDiffusionUpscalePipeline, + EulerAncestralDiscreteScheduler +) from huggingface_hub import login +from .dgpu import pipeline_for + def txt2img( hf_token: str, - model_name: str, - prompt: str, - output: str, - width: int, height: int, - guidance: float, - steps: int, - seed: Optional[int] + model: str = 'midj', + prompt: str = 'a red old tractor in a sunny wheat field', + output: str = 'output.png', + width: int = 512, height: int = 512, + guidance: float = 10, + steps: int = 28, + seed: Optional[int] = None ): assert torch.cuda.is_available() torch.cuda.empty_cache() - torch.cuda.set_per_process_memory_fraction(0.333) + torch.cuda.set_per_process_memory_fraction(1.0) torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cudnn.allow_tf32 = True login(token=hf_token) + pipe = pipeline_for(model) - params = { - 'torch_dtype': torch.float16, - 'safety_checker': None - } - if model_name == 'runwayml/stable-diffusion-v1-5': - params['revision'] = 'fp16' - - pipe = StableDiffusionPipeline.from_pretrained( - model_name, **params) - - pipe.scheduler = EulerAncestralDiscreteScheduler.from_config( - pipe.scheduler.config) - - pipe = pipe.to("cuda") - - seed = ireq.seed if ireq.seed else random.randint(0, 2 ** 64) + seed = seed if seed else random.randint(0, 2 ** 64) prompt = prompt image = pipe( prompt, @@ -55,3 +48,34 @@ def txt2img( ).images[0] image.save(output) + + +def upscale( + hf_token: str, + prompt: str = 'a red old tractor in a sunny wheat field', + img_path: str = 'input.png', + output: str = 'output.png', + steps: int = 28 +): + assert torch.cuda.is_available() + torch.cuda.empty_cache() + torch.cuda.set_per_process_memory_fraction(1.0) + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + + login(token=hf_token) + params = { + 'torch_dtype': torch.float16, + 'safety_checker': None + } + + pipe = StableDiffusionUpscalePipeline.from_pretrained( + 'stabilityai/stable-diffusion-x4-upscaler', **params) + + prompt = prompt + image = pipe( + prompt, + image=Image.open(img_path) + ).images[0] + + image.save(output) diff --git a/tests/conftest.py b/tests/conftest.py index ff58407..64a369f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -97,17 +97,22 @@ def dgpu_workers(request, dockerctl, skynet_running): num_containers, initial_algos = request.param - cmd = f''' - pip install -e . && \ - skynet run dgpu --algos=\'{json.dumps(initial_algos)}\' - ''' + cmds = [] + for i in range(num_containers): + cmd = f''' + pip install -e . && \ + skynet run dgpu \ + --algos=\'{json.dumps(initial_algos)}\' \ + --uid=dgpu-{i} + ''' + cmds.append(['bash', '-c', cmd]) logging.info(f'launching: \n{cmd}') with dockerctl.run( DOCKER_RUNTIME_CUDA, name='skynet-test-runtime-cuda', - command=['bash', '-c', cmd], + commands=cmds, environment={ 'HF_TOKEN': os.environ['HF_TOKEN'], 'HF_HOME': '/skynet/hf_home' diff --git a/tests/test_dgpu.py b/tests/test_dgpu.py index 51f1423..28426a9 100644 --- a/tests/test_dgpu.py +++ b/tests/test_dgpu.py @@ -26,8 +26,7 @@ async def wait_for_dgpus(rpc, amount: int, timeout: float = 30.0): start_time = time.time() current_time = time.time() while not gpu_ready and (current_time - start_time) < timeout: - res = await rpc('dgpu-test', 'dgpu_workers') - logging.info(res) + res = await rpc('dgpu_workers') if res.result['ok'] >= amount: break @@ -40,6 +39,7 @@ async def wait_for_dgpus(rpc, amount: int, timeout: float = 30.0): _images = set() async def check_request_img( i: int, + uid: int = 0, width: int = 512, height: int = 512, expect_unique=True @@ -47,12 +47,13 @@ async def check_request_img( global _images async with open_skynet_rpc( + uid, security=True, cert_name='whitelist/testing', key_name='testing' ) as rpc_call: res = await rpc_call( - 'tg+580213293', 'txt2img', { + 'txt2img', { 'prompt': 'red old tractor in a sunny wheat field', 'step': 28, 'width': width, 'height': height, @@ -88,6 +89,7 @@ async def test_dgpu_worker_compute_error(dgpu_workers): ''' async with open_skynet_rpc( + 'test-ctx', security=True, cert_name='whitelist/testing', key_name='testing' @@ -110,6 +112,7 @@ async def test_dgpu_workers(dgpu_workers): ''' async with open_skynet_rpc( + 'test-ctx', security=True, cert_name='whitelist/testing', key_name='testing' @@ -126,6 +129,7 @@ async def test_dgpu_workers_two(dgpu_workers): '''Generate two images in two separate dgpu workers ''' async with open_skynet_rpc( + 'test-ctx', security=True, cert_name='whitelist/testing', key_name='testing' @@ -143,6 +147,7 @@ async def test_dgpu_worker_algo_swap(dgpu_workers): '''Generate an image using a non default model ''' async with open_skynet_rpc( + 'test-ctx', security=True, cert_name='whitelist/testing', key_name='testing' @@ -158,35 +163,32 @@ async def test_dgpu_rotation_next_worker(dgpu_workers): rotation happens correctly ''' async with open_skynet_rpc( + 'test-ctx', security=True, cert_name='whitelist/testing', key_name='testing' ) as test_rpc: await wait_for_dgpus(test_rpc, 3) - res = await test_rpc('testing-rpc', 'dgpu_next') - logging.info(res) + res = await test_rpc('dgpu_next') assert 'ok' in res.result assert res.result['ok'] == 0 await check_request_img(0) - res = await test_rpc('testing-rpc', 'dgpu_next') - logging.info(res) + res = await test_rpc('dgpu_next') assert 'ok' in res.result assert res.result['ok'] == 1 await check_request_img(0) - res = await test_rpc('testing-rpc', 'dgpu_next') - logging.info(res) + res = await test_rpc('dgpu_next') assert 'ok' in res.result assert res.result['ok'] == 2 await check_request_img(0) - res = await test_rpc('testing-rpc', 'dgpu_next') - logging.info(res) + res = await test_rpc('dgpu_next') assert 'ok' in res.result assert res.result['ok'] == 0 @@ -198,6 +200,7 @@ async def test_dgpu_rotation_next_worker_disconnect(dgpu_workers): next_worker rotation happens correctly ''' async with open_skynet_rpc( + 'test-ctx', security=True, cert_name='whitelist/testing', key_name='testing' @@ -213,8 +216,7 @@ async def test_dgpu_rotation_next_worker_disconnect(dgpu_workers): dgpu_workers[0].wait() - res = await test_rpc('testing-rpc', 'dgpu_workers') - logging.info(res) + res = await test_rpc('dgpu_workers') assert 'ok' in res.result assert res.result['ok'] == 2 @@ -224,14 +226,17 @@ async def test_dgpu_rotation_next_worker_disconnect(dgpu_workers): async def test_dgpu_no_ack_node_disconnect(skynet_running): + '''Mock a node that connects, gets a request but fails to + acknowledge it, then check skynet correctly drops the node + ''' async with open_skynet_rpc( + 'test-ctx', security=True, cert_name='whitelist/testing', key_name='testing' ) as rpc_call: - res = await rpc_call('dgpu-0', 'dgpu_online') - logging.info(res) + res = await rpc_call('dgpu_online') assert 'ok' in res.result await wait_for_dgpus(rpc_call, 1) @@ -241,8 +246,34 @@ async def test_dgpu_no_ack_node_disconnect(skynet_running): assert 'dgpu failed to acknowledge request' in str(e) - res = await rpc_call('testing-rpc', 'dgpu_workers') - logging.info(res) + res = await rpc_call('dgpu_workers') assert 'ok' in res.result assert res.result['ok'] == 0 + +@pytest.mark.parametrize( + 'dgpu_workers', [(1, ['midj'])], indirect=True) +async def test_dgpu_timeout_while_processing(dgpu_workers): + '''Stop node while processing request to cause timeout and + then check skynet correctly drops the node. + ''' + async with open_skynet_rpc( + 'test-ctx', + security=True, + cert_name='whitelist/testing', + key_name='testing' + ) as test_rpc: + await wait_for_dgpus(test_rpc, 1) + + async def check_request_img_raises(): + with pytest.raises(SkynetDGPUComputeError) as e: + await check_request_img(0) + + assert 'timeout while processing request' in str(e) + + async with trio.open_nursery() as n: + n.start_soon(check_request_img_raises) + await trio.sleep(1) + ec, out = dgpu_workers[0].exec_run( + ['pkill', '-TERM', '-f', 'skynet']) + assert ec == 0 diff --git a/tests/test_skynet.py b/tests/test_skynet.py index c6f5a89..5572a70 100644 --- a/tests/test_skynet.py +++ b/tests/test_skynet.py @@ -12,64 +12,59 @@ from skynet.structs import * from skynet.frontend import open_skynet_rpc +async def test_skynet(skynet_running): + ... + + async def test_skynet_attempt_insecure(skynet_running): with pytest.raises(pynng.exceptions.NNGException) as e: - async with open_skynet_rpc(): - ... + async with open_skynet_rpc('bad-actor'): + ... assert str(e.value) == 'Connection shutdown' async def test_skynet_dgpu_connection_simple(skynet_running): async with open_skynet_rpc( + 'dgpu-0', security=True, cert_name='whitelist/testing', key_name='testing' ) as rpc_call: # check 0 nodes are connected - res = await rpc_call('dgpu-0', 'dgpu_workers') - logging.info(res) + res = await rpc_call('dgpu_workers') assert 'ok' in res.result assert res.result['ok'] == 0 # check next worker is None - res = await rpc_call('dgpu-0', 'dgpu_next') - logging.info(res) + res = await rpc_call('dgpu_next') assert 'ok' in res.result assert res.result['ok'] == None # connect 1 dgpu - res = await rpc_call( - 'dgpu-0', 'dgpu_online') - logging.info(res) + res = await rpc_call('dgpu_online') assert 'ok' in res.result # check 1 node is connected - res = await rpc_call('dgpu-0', 'dgpu_workers') - logging.info(res) + res = await rpc_call('dgpu_workers') assert 'ok' in res.result assert res.result['ok'] == 1 # check next worker is 0 - res = await rpc_call('dgpu-0', 'dgpu_next') - logging.info(res) + res = await rpc_call('dgpu_next') assert 'ok' in res.result assert res.result['ok'] == 0 # disconnect 1 dgpu - res = await rpc_call( - 'dgpu-0', 'dgpu_offline') - logging.info(res) + res = await rpc_call('dgpu_offline') assert 'ok' in res.result # check 0 nodes are connected - res = await rpc_call('dgpu-0', 'dgpu_workers') - logging.info(res) + res = await rpc_call('dgpu_workers') assert 'ok' in res.result assert res.result['ok'] == 0 # check next worker is None - res = await rpc_call('dgpu-0', 'dgpu_next') - logging.info(res) + res = await rpc_call('dgpu_next') assert 'ok' in res.result assert res.result['ok'] == None