Add authenticated messaging, also cmd line utils txt2img and upscale

pull/2/head
Guillermo Rodriguez 2022-12-19 12:36:02 -03:00
parent f6326ad05c
commit 6bc555f0d6
No known key found for this signature in database
GPG Key ID: EC3AB66D5D83B392
12 changed files with 459 additions and 191 deletions

View File

@ -10,6 +10,8 @@ setup(
entry_points={
'console_scripts': [
'skynet = skynet.cli:skynet',
'txt2img = skynet.cli:txt2img',
'upscale = skynet.cli:upscale'
]
},
install_requires=['click']

View File

@ -16,6 +16,11 @@ import pynng
import trio_asyncio
from pynng import TLSConfig
from OpenSSL.crypto import (
load_privatekey,
load_certificate,
FILETYPE_PEM
)
from .db import *
from .structs import *
@ -34,12 +39,14 @@ class SkynetDGPUComputeError(BaseException):
class SkynetShutdownRequested(BaseException):
...
@acm
async def open_rpc_service(sock, dgpu_bus, db_pool):
async def open_rpc_service(sock, dgpu_bus, db_pool, tls_whitelist, tls_key):
nodes = OrderedDict()
wip_reqs = {}
fin_reqs = {}
next_worker: Optional[int] = None
security = len(tls_whitelist) > 0
def connect_node(uid):
nonlocal next_worker
@ -109,27 +116,20 @@ async def open_rpc_service(sock, dgpu_bus, db_pool):
async def dgpu_image_streamer():
nonlocal wip_reqs, fin_reqs
while True:
msg = await dgpu_bus.arecv_msg()
rid = UUID(bytes=msg.bytes[:16]).hex
raw_msg = msg.bytes[16:]
logging.info(f'streamer got back {rid} of size {len(raw_msg)}')
match raw_msg[:5]:
case b'error':
img = raw_msg.decode()
msg = DGPUBusResponse(
**json.loads(
(await dgpu_bus.arecv()).decode()))
case b'ack':
img = raw_msg
if security:
msg.verify(tls_whitelist[msg.cert])
case _:
img = base64.b64encode(raw_msg).hex()
if rid not in wip_reqs:
if msg.rid not in wip_reqs:
continue
fin_reqs[rid] = img
event = wip_reqs[rid]
fin_reqs[msg.rid] = msg
event = wip_reqs[msg.rid]
event.set()
del wip_reqs[rid]
del wip_reqs[msg.rid]
async def dgpu_stream_one_img(req: ImageGenRequest):
nonlocal wip_reqs, fin_reqs, next_worker
@ -151,6 +151,9 @@ async def open_rpc_service(sock, dgpu_bus, db_pool):
logging.info(f'dgpu_bus req: {dgpu_req}')
if security:
dgpu_req.sign(tls_key, 'skynet')
await dgpu_bus.asend(
json.dumps(dgpu_req.to_dict()).encode())
@ -163,8 +166,8 @@ async def open_rpc_service(sock, dgpu_bus, db_pool):
disconnect_node(nid)
raise SkynetDGPUOffline('dgpu failed to acknowledge request')
ack = fin_reqs[rid]
if ack != b'ack':
ack_msg = fin_reqs[rid]
if 'ack' not in ack_msg.params:
disconnect_node(nid)
raise SkynetDGPUOffline('dgpu failed to acknowledge request')
@ -178,15 +181,13 @@ async def open_rpc_service(sock, dgpu_bus, db_pool):
nodes[nid]['task'] = None
img = fin_reqs[rid]
img_resp = fin_reqs[rid]
del fin_reqs[rid]
logging.info(f'done streaming {len(img)} bytes')
if 'error' in img_resp.params:
raise SkynetDGPUComputeError(img_resp.params['error'])
if 'error' in img:
raise SkynetDGPUComputeError(img)
return rid, img
return rid, img_resp.params['img']
async def handle_user_request(rpc_ctx, req):
try:
@ -266,9 +267,13 @@ async def open_rpc_service(sock, dgpu_bus, db_pool):
'message': str(e)
}
resp = SkynetRPCResponse(result=result)
if security:
resp.sign(tls_key, 'skynet')
await rpc_ctx.asend(
json.dumps(
SkynetRPCResponse(result=result).to_dict()).encode())
json.dumps(resp.to_dict()).encode())
async def request_service(n):
nonlocal next_worker
@ -279,7 +284,19 @@ async def open_rpc_service(sock, dgpu_bus, db_pool):
content = msg.bytes.decode()
req = SkynetRPCRequest(**json.loads(content))
logging.info(req)
if security:
if req.cert not in tls_whitelist:
logging.warning(
f'{req.cert} not in tls whitelist and security=True')
continue
try:
req.verify(tls_whitelist[req.cert])
except ValueError:
logging.warning(
f'{req.cert} sent an unauthenticated msg with security=True')
continue
result = {}
@ -303,10 +320,14 @@ async def open_rpc_service(sock, dgpu_bus, db_pool):
handle_user_request, ctx, req)
continue
resp = SkynetRPCResponse(
result={'ok': result})
if security:
resp.sign(tls_key, 'skynet')
await ctx.asend(
json.dumps(
SkynetRPCResponse(
result={'ok': result}).to_dict()).encode())
json.dumps(resp.to_dict()).encode())
async with trio.open_nursery() as n:
@ -334,22 +355,28 @@ async def run_skynet(
if security:
# load tls certs
certs_dir = Path(DEFAULT_CERTS_DIR).resolve()
tls_key = (certs_dir / DEFAULT_CERT_SKYNET_PRIV).read_text()
tls_cert = (certs_dir / DEFAULT_CERT_SKYNET_PUB).read_text()
tls_whitelist = [
(cert_path).read_text()
for cert_path in (certs_dir / 'whitelist').glob('*.cert')]
cert_start = tls_cert.index('\n') + 1
logging.info(f'tls_cert: {tls_cert[cert_start:cert_start+64]}...')
tls_key_data = (certs_dir / DEFAULT_CERT_SKYNET_PRIV).read_text()
tls_key = load_privatekey(FILETYPE_PEM, tls_key_data)
tls_cert_data = (certs_dir / DEFAULT_CERT_SKYNET_PUB).read_text()
tls_cert = load_certificate(FILETYPE_PEM, tls_cert_data)
tls_whitelist = {}
for cert_path in (certs_dir / 'whitelist').glob('*.cert'):
tls_whitelist[cert_path.stem] = load_certificate(
FILETYPE_PEM, cert_path.read_text())
cert_start = tls_cert_data.index('\n') + 1
logging.info(f'tls_cert: {tls_cert_data[cert_start:cert_start+64]}...')
logging.info(f'tls_whitelist len: {len(tls_whitelist)}')
rpc_address = 'tls+' + rpc_address
dgpu_address = 'tls+' + dgpu_address
tls_config = TLSConfig(
TLSConfig.MODE_SERVER,
own_key_string=tls_key,
own_cert_string=tls_cert)
own_key_string=tls_key_data,
own_cert_string=tls_cert_data)
with (
pynng.Rep0() as rpc_sock,
@ -367,7 +394,8 @@ async def run_skynet(
dgpu_bus.listen(dgpu_address)
try:
async with open_rpc_service(rpc_sock, dgpu_bus, db_pool):
async with open_rpc_service(
rpc_sock, dgpu_bus, db_pool, tls_whitelist, tls_key):
yield
except SkynetShutdownRequested:

View File

@ -9,44 +9,78 @@ from functools import partial
import trio
import click
from . import utils
from .dgpu import open_dgpu_node
from .utils import txt2img
from .constants import ALGOS
from .brain import run_skynet
from .frontend.telegram import run_skynet_telegram
@click.group()
def skynet(*args, **kwargs):
pass
@skynet.command()
@click.option('--model', '-m', default=ALGOS['midj'])
@click.command()
@click.option('--model', '-m', default='midj')
@click.option(
'--prompt', '-p', default='a red tractor in a wheat field')
'--prompt', '-p', default='a red old tractor in a sunny wheat field')
@click.option('--output', '-o', default='output.png')
@click.option('--width', '-w', default=512)
@click.option('--height', '-h', default=512)
@click.option('--guidance', '-g', default=10.0)
@click.option('--steps', '-s', default=26)
@click.option('--seed', '-S', default=None)
def txt2img(*args
# model: str,
# prompt: str,
# output: str
# width: int, height: int,
# guidance: float,
# steps: int,
# seed: Optional[int]
):
def txt2img(*args, **kwargs):
assert 'HF_TOKEN' in os.environ
txt2img(
os.environ['HF_TOKEN'], *args)
utils.txt2img(os.environ['HF_TOKEN'], **kwargs)
@click.command()
@click.option(
'--prompt', '-p', default='a red old tractor in a sunny wheat field')
@click.option('--input', '-i', default='input.png')
@click.option('--output', '-o', default='output.png')
@click.option('--steps', '-s', default=26)
def upscale(prompt, input, output, steps):
assert 'HF_TOKEN' in os.environ
utils.upscale(
os.environ['HF_TOKEN'],
prompt=prompt,
img_path=input,
output=output,
steps=steps)
@skynet.group()
def run(*args, **kwargs):
pass
@run.command()
@click.option('--loglevel', '-l', default='warning', help='Logging level')
@click.option(
'--host', '-h', default='localhost:5432')
@click.option(
'--pass', '-p', default='password')
def skynet(
loglevel: str,
host: str,
passw: str
):
async def _run_skynet():
async with run_skynet(
db_host=host,
db_pass=passw
):
await trio.sleep_forever()
trio_asyncio.run(_run_skynet)
@run.command()
@click.option('--loglevel', '-l', default='warning', help='Logging level')
@click.option(
'--uid', '-u', required=True)
@click.option(
'--key', '-k', default='dgpu')
@click.option(
@ -55,6 +89,7 @@ def run(*args, **kwargs):
'--algos', '-a', default=None)
def dgpu(
loglevel: str,
uid: str,
key: str,
cert: str,
algos: Optional[str]
@ -63,6 +98,28 @@ def dgpu(
partial(
open_dgpu_node,
cert,
uid,
key_name=key,
initial_algos=json.loads(algos)
))
@run.command()
@click.option('--loglevel', '-l', default='warning', help='Logging level')
@click.option(
'--key', '-k', default='telegram-frontend')
@click.option(
'--cert', '-c', default='whitelist/telegram-frontend')
def telegram(
loglevel: str,
key: str,
cert: str
):
assert 'TG_TOKEN' in os.environ
trio_asyncio.run(
partial(
run_skynet_telegram,
os.environ['TG_TOKEN'],
key_name=key,
cert_name=cert
))

View File

@ -58,13 +58,16 @@ ALTER TABLE skynet.user_config
def try_decode_uid(uid: str):
if isinstance(uid, int):
return None, uid
try:
proto, uid = uid.split('+')
uid = int(uid)
return proto, uid
except ValueError:
logging.warning(f'got non numeric uid?: {uid}')
logging.warning(f'got non chat proto uid?: {uid}')
return None, None
@ -132,28 +135,38 @@ async def new_user(conn, uid: str):
logging.info(f'new user! {uid}')
tg_id = None
date = datetime.utcnow()
proto, pid = try_decode_uid(uid)
match proto:
case 'tg':
tg_id = pid
async with conn.transaction():
stmt = await conn.prepare('''
INSERT INTO skynet.user(
tg_id, generated, joined, last_prompt, role)
match proto:
case 'tg':
tg_id = pid
stmt = await conn.prepare('''
INSERT INTO skynet.user(
tg_id, generated, joined, last_prompt, role)
VALUES($1, $2, $3, $4, $5)
ON CONFLICT DO NOTHING
''')
await stmt.fetch(
tg_id, 0, date, None, DEFAULT_ROLE
)
VALUES($1, $2, $3, $4, $5)
ON CONFLICT DO NOTHING
''')
await stmt.fetch(
tg_id, 0, date, None, DEFAULT_ROLE
)
new_uid = await get_user(conn, uid)
new_uid = await get_user(conn, uid)
case None:
stmt = await conn.prepare('''
INSERT INTO skynet.user(
id, generated, joined, last_prompt, role)
VALUES($1, $2, $3, $4, $5)
ON CONFLICT DO NOTHING
''')
await stmt.fetch(
pid, 0, date, None, DEFAULT_ROLE
)
new_uid = pid
stmt = await conn.prepare('''
INSERT INTO skynet.user_config(

View File

@ -5,6 +5,7 @@ import io
import trio
import json
import uuid
import base64
import random
import logging
@ -16,10 +17,16 @@ import pynng
import torch
from pynng import TLSConfig
from OpenSSL.crypto import (
load_privatekey,
load_certificate,
FILETYPE_PEM
)
from diffusers import (
StableDiffusionPipeline,
EulerAncestralDiscreteScheduler
)
from diffusers.models import UNet2DConditionModel
from .structs import *
from .constants import *
@ -58,6 +65,7 @@ class DGPUComputeError(BaseException):
async def open_dgpu_node(
cert_name: str,
unique_id: str,
key_name: Optional[str],
rpc_address: str = DEFAULT_RPC_ADDR,
dgpu_address: str = DEFAULT_DGPU_ADDR,
@ -122,6 +130,7 @@ async def open_dgpu_node(
async with open_skynet_rpc(
unique_id,
security=security,
cert_name=cert_name,
key_name=key_name
@ -131,16 +140,23 @@ async def open_dgpu_node(
if security:
# load tls certs
if not key_name:
key_name = certs_name
key_name = cert_name
certs_dir = Path(DEFAULT_CERTS_DIR).resolve()
skynet_cert_path = certs_dir / 'brain.cert'
tls_cert_path = certs_dir / f'{cert_name}.cert'
tls_key_path = certs_dir / f'{key_name}.key'
skynet_cert = skynet_cert_path.read_text()
tls_cert = tls_cert_path.read_text()
tls_key = tls_key_path.read_text()
cert_name = tls_cert_path.stem
skynet_cert_data = skynet_cert_path.read_text()
skynet_cert = load_certificate(FILETYPE_PEM, skynet_cert_data)
tls_cert_data = tls_cert_path.read_text()
tls_key_data = tls_key_path.read_text()
tls_key = load_privatekey(FILETYPE_PEM, tls_key_data)
logging.info(f'skynet cert: {skynet_cert_path}')
logging.info(f'dgpu cert: {tls_cert_path}')
@ -149,17 +165,16 @@ async def open_dgpu_node(
dgpu_address = 'tls+' + dgpu_address
tls_config = TLSConfig(
TLSConfig.MODE_CLIENT,
own_key_string=tls_key,
own_cert_string=tls_cert,
ca_string=skynet_cert)
own_key_string=tls_key_data,
own_cert_string=tls_cert_data,
ca_string=skynet_cert_data)
logging.info(f'connecting to {dgpu_address}')
with pynng.Bus0() as dgpu_sock:
dgpu_sock.tls_config = tls_config
dgpu_sock.dial(dgpu_address)
res = await rpc_call(name.hex, 'dgpu_online')
logging.info(res)
res = await rpc_call('dgpu_online')
assert 'ok' in res.result
try:
@ -168,30 +183,55 @@ async def open_dgpu_node(
req = DGPUBusRequest(
**json.loads(msg.decode()))
if req.nid != name.hex:
logging.info('witnessed request {req.rid}, for {req.nid}')
if req.nid != unique_id:
logging.info(
f'witnessed msg {req.rid}, node involved: {req.nid}')
continue
if security:
req.verify(skynet_cert)
ack_resp = DGPUBusResponse(
rid=req.rid,
nid=req.nid,
params={'ack': {}}
)
if security:
ack_resp.sign(tls_key, cert_name)
# send ack
await dgpu_sock.asend(
bytes.fromhex(req.rid) + b'ack')
json.dumps(ack_resp.to_dict()).encode())
logging.info(f'sent ack, processing {req.rid}...')
try:
img = await gpu_compute_one(
ImageGenRequest(**req.params))
img_resp = DGPUBusResponse(
rid=req.rid,
nid=req.nid,
params={'img': base64.b64encode(img).hex()}
)
except DGPUComputeError as e:
img = b'error' + str(e).encode()
img_resp = DGPUBusResponse(
rid=req.rid,
nid=req.nid,
params={'error': str(e)}
)
if security:
img_resp.sign(tls_key, cert_name)
# send final image
await dgpu_sock.asend(
bytes.fromhex(req.rid) + img)
json.dumps(img_resp.to_dict()).encode())
except KeyboardInterrupt:
logging.info('interrupt caught, stopping...')
finally:
res = await rpc_call(name.hex, 'dgpu_offline')
logging.info(res)
res = await rpc_call('dgpu_offline')
assert 'ok' in res.result

View File

@ -9,6 +9,11 @@ from contextlib import asynccontextmanager as acm
import pynng
from pynng import TLSConfig
from OpenSSL.crypto import (
load_privatekey,
load_certificate,
FILETYPE_PEM
)
from ..structs import SkynetRPCRequest, SkynetRPCResponse
from ..constants import *
@ -30,56 +35,73 @@ class ConfigSizeDivisionByEight(BaseException):
...
async def rpc_call(
sock,
uid: Union[int, str],
method: str,
params: dict = {}
):
req = SkynetRPCRequest(
uid=uid,
method=method,
params=params
)
await sock.asend(
json.dumps(
req.to_dict()).encode())
return SkynetRPCResponse(
**json.loads(
(await sock.arecv_msg()).bytes.decode()))
@acm
async def open_skynet_rpc(
unique_id: str,
rpc_address: str = DEFAULT_RPC_ADDR,
security: bool = False,
cert_name: Optional[str] = None,
key_name: Optional[str] = None
):
tls_config = None
if security:
# load tls certs
if not key_name:
key_name = certs_name
key_name = cert_name
certs_dir = Path(DEFAULT_CERTS_DIR).resolve()
skynet_cert = (certs_dir / 'brain.cert').read_text()
tls_cert = (certs_dir / f'{cert_name}.cert').read_text()
tls_key = (certs_dir / f'{key_name}.key').read_text()
skynet_cert_data = (certs_dir / 'brain.cert').read_text()
skynet_cert = load_certificate(FILETYPE_PEM, skynet_cert_data)
tls_cert_path = certs_dir / f'{cert_name}.cert'
tls_cert_data = tls_cert_path.read_text()
tls_cert = load_certificate(FILETYPE_PEM, tls_cert_data)
cert_name = tls_cert_path.stem
tls_key_data = (certs_dir / f'{key_name}.key').read_text()
tls_key = load_privatekey(FILETYPE_PEM, tls_key_data)
rpc_address = 'tls+' + rpc_address
tls_config = TLSConfig(
TLSConfig.MODE_CLIENT,
own_key_string=tls_key,
own_cert_string=tls_cert,
ca_string=skynet_cert)
own_key_string=tls_key_data,
own_cert_string=tls_cert_data,
ca_string=skynet_cert_data)
with pynng.Req0() as sock:
if security:
sock.tls_config = tls_config
sock.dial(rpc_address)
async def _rpc_call(*args, **kwargs):
return await rpc_call(sock, *args, **kwargs)
async def _rpc_call(
method: str,
params: dict = {},
uid: Optional[Union[int, str]] = None
):
req = SkynetRPCRequest(
uid=uid if uid else unique_id,
method=method,
params=params
)
if security:
req.sign(tls_key, cert_name)
await sock.asend(
json.dumps(
req.to_dict()).encode())
resp = SkynetRPCResponse(
**json.loads(
(await sock.arecv_msg()).bytes.decode()))
if security:
resp.verify(skynet_cert)
return resp
yield _rpc_call

View File

@ -18,14 +18,19 @@ PREFIX = 'tg'
async def run_skynet_telegram(
tg_token: str
tg_token: str,
key_name: str = 'telegram-frontend',
cert_name: str = 'whitelist/telegram-frontend'
):
logging.basicConfig(level=logging.INFO)
bot = AsyncTeleBot(tg_token)
with open_skynet_rpc(
security=True, cert_name='telegram-frontend'
'skynet-telegram-0',
security=True,
cert_name=cert,
key_name=key
) as rpc_call:
async def _rpc_call(
@ -33,7 +38,8 @@ async def run_skynet_telegram(
method: str,
params: dict
):
return await rpc_call(f'{PREFIX}+{uid}', method, params)
return await rpc_call(
method, params, uid=f'{PREFIX}+{uid}')
@bot.message_handler(commands=['help'])
async def send_help(message):

View File

@ -18,6 +18,8 @@
Built-in (extension) types.
"""
import sys
import json
from typing import Optional, Union
from pprint import pformat
@ -83,15 +85,51 @@ class Struct(
setattr(self, fname, ftype(getattr(self, fname)))
# proto
from OpenSSL.crypto import PKey, X509, verify, sign
class SkynetRPCRequest(Struct):
class AuthenticatedStruct(Struct):
cert: Optional[str] = None
sig: Optional[str] = None
def to_unsigned_dict(self) -> dict:
self_dict = self.to_dict()
if 'sig' in self_dict:
del self_dict['sig']
if 'cert' in self_dict:
del self_dict['cert']
return self_dict
def unsigned_to_bytes(self) -> bytes:
return json.dumps(
self.to_unsigned_dict()).encode()
def sign(self, key: PKey, cert: str):
self.cert = cert
self.sig = sign(
key, self.unsigned_to_bytes(), 'sha256').hex()
def verify(self, cert: X509):
if not self.sig:
raise ValueError('Tried to verify unsigned request')
return verify(
cert, bytes.fromhex(self.sig), self.unsigned_to_bytes(), 'sha256')
class SkynetRPCRequest(AuthenticatedStruct):
uid: Union[str, int] # user unique id
method: str # rpc method name
params: dict # variable params
class SkynetRPCResponse(Struct):
class SkynetRPCResponse(AuthenticatedStruct):
result: dict
class ImageGenRequest(Struct):
prompt: str
step: int
@ -102,8 +140,15 @@ class ImageGenRequest(Struct):
algo: str
upscaler: Optional[str]
class DGPUBusRequest(Struct):
class DGPUBusRequest(AuthenticatedStruct):
rid: str # req id
nid: str # node id
task: str
params: dict
class DGPUBusResponse(AuthenticatedStruct):
rid: str # req id
nid: str # node id
params: dict

View File

@ -7,44 +7,37 @@ from pathlib import Path
import torch
from diffusers import StableDiffusionPipeline
from PIL import Image
from diffusers import (
StableDiffusionPipeline,
StableDiffusionUpscalePipeline,
EulerAncestralDiscreteScheduler
)
from huggingface_hub import login
from .dgpu import pipeline_for
def txt2img(
hf_token: str,
model_name: str,
prompt: str,
output: str,
width: int, height: int,
guidance: float,
steps: int,
seed: Optional[int]
model: str = 'midj',
prompt: str = 'a red old tractor in a sunny wheat field',
output: str = 'output.png',
width: int = 512, height: int = 512,
guidance: float = 10,
steps: int = 28,
seed: Optional[int] = None
):
assert torch.cuda.is_available()
torch.cuda.empty_cache()
torch.cuda.set_per_process_memory_fraction(0.333)
torch.cuda.set_per_process_memory_fraction(1.0)
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
login(token=hf_token)
pipe = pipeline_for(model)
params = {
'torch_dtype': torch.float16,
'safety_checker': None
}
if model_name == 'runwayml/stable-diffusion-v1-5':
params['revision'] = 'fp16'
pipe = StableDiffusionPipeline.from_pretrained(
model_name, **params)
pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(
pipe.scheduler.config)
pipe = pipe.to("cuda")
seed = ireq.seed if ireq.seed else random.randint(0, 2 ** 64)
seed = seed if seed else random.randint(0, 2 ** 64)
prompt = prompt
image = pipe(
prompt,
@ -55,3 +48,34 @@ def txt2img(
).images[0]
image.save(output)
def upscale(
hf_token: str,
prompt: str = 'a red old tractor in a sunny wheat field',
img_path: str = 'input.png',
output: str = 'output.png',
steps: int = 28
):
assert torch.cuda.is_available()
torch.cuda.empty_cache()
torch.cuda.set_per_process_memory_fraction(1.0)
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
login(token=hf_token)
params = {
'torch_dtype': torch.float16,
'safety_checker': None
}
pipe = StableDiffusionUpscalePipeline.from_pretrained(
'stabilityai/stable-diffusion-x4-upscaler', **params)
prompt = prompt
image = pipe(
prompt,
image=Image.open(img_path)
).images[0]
image.save(output)

View File

@ -97,17 +97,22 @@ def dgpu_workers(request, dockerctl, skynet_running):
num_containers, initial_algos = request.param
cmd = f'''
pip install -e . && \
skynet run dgpu --algos=\'{json.dumps(initial_algos)}\'
'''
cmds = []
for i in range(num_containers):
cmd = f'''
pip install -e . && \
skynet run dgpu \
--algos=\'{json.dumps(initial_algos)}\' \
--uid=dgpu-{i}
'''
cmds.append(['bash', '-c', cmd])
logging.info(f'launching: \n{cmd}')
with dockerctl.run(
DOCKER_RUNTIME_CUDA,
name='skynet-test-runtime-cuda',
command=['bash', '-c', cmd],
commands=cmds,
environment={
'HF_TOKEN': os.environ['HF_TOKEN'],
'HF_HOME': '/skynet/hf_home'

View File

@ -26,8 +26,7 @@ async def wait_for_dgpus(rpc, amount: int, timeout: float = 30.0):
start_time = time.time()
current_time = time.time()
while not gpu_ready and (current_time - start_time) < timeout:
res = await rpc('dgpu-test', 'dgpu_workers')
logging.info(res)
res = await rpc('dgpu_workers')
if res.result['ok'] >= amount:
break
@ -40,6 +39,7 @@ async def wait_for_dgpus(rpc, amount: int, timeout: float = 30.0):
_images = set()
async def check_request_img(
i: int,
uid: int = 0,
width: int = 512,
height: int = 512,
expect_unique=True
@ -47,12 +47,13 @@ async def check_request_img(
global _images
async with open_skynet_rpc(
uid,
security=True,
cert_name='whitelist/testing',
key_name='testing'
) as rpc_call:
res = await rpc_call(
'tg+580213293', 'txt2img', {
'txt2img', {
'prompt': 'red old tractor in a sunny wheat field',
'step': 28,
'width': width, 'height': height,
@ -88,6 +89,7 @@ async def test_dgpu_worker_compute_error(dgpu_workers):
'''
async with open_skynet_rpc(
'test-ctx',
security=True,
cert_name='whitelist/testing',
key_name='testing'
@ -110,6 +112,7 @@ async def test_dgpu_workers(dgpu_workers):
'''
async with open_skynet_rpc(
'test-ctx',
security=True,
cert_name='whitelist/testing',
key_name='testing'
@ -126,6 +129,7 @@ async def test_dgpu_workers_two(dgpu_workers):
'''Generate two images in two separate dgpu workers
'''
async with open_skynet_rpc(
'test-ctx',
security=True,
cert_name='whitelist/testing',
key_name='testing'
@ -143,6 +147,7 @@ async def test_dgpu_worker_algo_swap(dgpu_workers):
'''Generate an image using a non default model
'''
async with open_skynet_rpc(
'test-ctx',
security=True,
cert_name='whitelist/testing',
key_name='testing'
@ -158,35 +163,32 @@ async def test_dgpu_rotation_next_worker(dgpu_workers):
rotation happens correctly
'''
async with open_skynet_rpc(
'test-ctx',
security=True,
cert_name='whitelist/testing',
key_name='testing'
) as test_rpc:
await wait_for_dgpus(test_rpc, 3)
res = await test_rpc('testing-rpc', 'dgpu_next')
logging.info(res)
res = await test_rpc('dgpu_next')
assert 'ok' in res.result
assert res.result['ok'] == 0
await check_request_img(0)
res = await test_rpc('testing-rpc', 'dgpu_next')
logging.info(res)
res = await test_rpc('dgpu_next')
assert 'ok' in res.result
assert res.result['ok'] == 1
await check_request_img(0)
res = await test_rpc('testing-rpc', 'dgpu_next')
logging.info(res)
res = await test_rpc('dgpu_next')
assert 'ok' in res.result
assert res.result['ok'] == 2
await check_request_img(0)
res = await test_rpc('testing-rpc', 'dgpu_next')
logging.info(res)
res = await test_rpc('dgpu_next')
assert 'ok' in res.result
assert res.result['ok'] == 0
@ -198,6 +200,7 @@ async def test_dgpu_rotation_next_worker_disconnect(dgpu_workers):
next_worker rotation happens correctly
'''
async with open_skynet_rpc(
'test-ctx',
security=True,
cert_name='whitelist/testing',
key_name='testing'
@ -213,8 +216,7 @@ async def test_dgpu_rotation_next_worker_disconnect(dgpu_workers):
dgpu_workers[0].wait()
res = await test_rpc('testing-rpc', 'dgpu_workers')
logging.info(res)
res = await test_rpc('dgpu_workers')
assert 'ok' in res.result
assert res.result['ok'] == 2
@ -224,14 +226,17 @@ async def test_dgpu_rotation_next_worker_disconnect(dgpu_workers):
async def test_dgpu_no_ack_node_disconnect(skynet_running):
'''Mock a node that connects, gets a request but fails to
acknowledge it, then check skynet correctly drops the node
'''
async with open_skynet_rpc(
'test-ctx',
security=True,
cert_name='whitelist/testing',
key_name='testing'
) as rpc_call:
res = await rpc_call('dgpu-0', 'dgpu_online')
logging.info(res)
res = await rpc_call('dgpu_online')
assert 'ok' in res.result
await wait_for_dgpus(rpc_call, 1)
@ -241,8 +246,34 @@ async def test_dgpu_no_ack_node_disconnect(skynet_running):
assert 'dgpu failed to acknowledge request' in str(e)
res = await rpc_call('testing-rpc', 'dgpu_workers')
logging.info(res)
res = await rpc_call('dgpu_workers')
assert 'ok' in res.result
assert res.result['ok'] == 0
@pytest.mark.parametrize(
'dgpu_workers', [(1, ['midj'])], indirect=True)
async def test_dgpu_timeout_while_processing(dgpu_workers):
'''Stop node while processing request to cause timeout and
then check skynet correctly drops the node.
'''
async with open_skynet_rpc(
'test-ctx',
security=True,
cert_name='whitelist/testing',
key_name='testing'
) as test_rpc:
await wait_for_dgpus(test_rpc, 1)
async def check_request_img_raises():
with pytest.raises(SkynetDGPUComputeError) as e:
await check_request_img(0)
assert 'timeout while processing request' in str(e)
async with trio.open_nursery() as n:
n.start_soon(check_request_img_raises)
await trio.sleep(1)
ec, out = dgpu_workers[0].exec_run(
['pkill', '-TERM', '-f', 'skynet'])
assert ec == 0

View File

@ -12,64 +12,59 @@ from skynet.structs import *
from skynet.frontend import open_skynet_rpc
async def test_skynet(skynet_running):
...
async def test_skynet_attempt_insecure(skynet_running):
with pytest.raises(pynng.exceptions.NNGException) as e:
async with open_skynet_rpc():
...
async with open_skynet_rpc('bad-actor'):
...
assert str(e.value) == 'Connection shutdown'
async def test_skynet_dgpu_connection_simple(skynet_running):
async with open_skynet_rpc(
'dgpu-0',
security=True,
cert_name='whitelist/testing',
key_name='testing'
) as rpc_call:
# check 0 nodes are connected
res = await rpc_call('dgpu-0', 'dgpu_workers')
logging.info(res)
res = await rpc_call('dgpu_workers')
assert 'ok' in res.result
assert res.result['ok'] == 0
# check next worker is None
res = await rpc_call('dgpu-0', 'dgpu_next')
logging.info(res)
res = await rpc_call('dgpu_next')
assert 'ok' in res.result
assert res.result['ok'] == None
# connect 1 dgpu
res = await rpc_call(
'dgpu-0', 'dgpu_online')
logging.info(res)
res = await rpc_call('dgpu_online')
assert 'ok' in res.result
# check 1 node is connected
res = await rpc_call('dgpu-0', 'dgpu_workers')
logging.info(res)
res = await rpc_call('dgpu_workers')
assert 'ok' in res.result
assert res.result['ok'] == 1
# check next worker is 0
res = await rpc_call('dgpu-0', 'dgpu_next')
logging.info(res)
res = await rpc_call('dgpu_next')
assert 'ok' in res.result
assert res.result['ok'] == 0
# disconnect 1 dgpu
res = await rpc_call(
'dgpu-0', 'dgpu_offline')
logging.info(res)
res = await rpc_call('dgpu_offline')
assert 'ok' in res.result
# check 0 nodes are connected
res = await rpc_call('dgpu-0', 'dgpu_workers')
logging.info(res)
res = await rpc_call('dgpu_workers')
assert 'ok' in res.result
assert res.result['ok'] == 0
# check next worker is None
res = await rpc_call('dgpu-0', 'dgpu_next')
logging.info(res)
res = await rpc_call('dgpu_next')
assert 'ok' in res.result
assert res.result['ok'] == None