mirror of https://github.com/skygpu/skynet.git
Add authenticated messaging, also cmd line utils txt2img and upscale
parent
f6326ad05c
commit
6bc555f0d6
2
setup.py
2
setup.py
|
@ -10,6 +10,8 @@ setup(
|
||||||
entry_points={
|
entry_points={
|
||||||
'console_scripts': [
|
'console_scripts': [
|
||||||
'skynet = skynet.cli:skynet',
|
'skynet = skynet.cli:skynet',
|
||||||
|
'txt2img = skynet.cli:txt2img',
|
||||||
|
'upscale = skynet.cli:upscale'
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
install_requires=['click']
|
install_requires=['click']
|
||||||
|
|
110
skynet/brain.py
110
skynet/brain.py
|
@ -16,6 +16,11 @@ import pynng
|
||||||
import trio_asyncio
|
import trio_asyncio
|
||||||
|
|
||||||
from pynng import TLSConfig
|
from pynng import TLSConfig
|
||||||
|
from OpenSSL.crypto import (
|
||||||
|
load_privatekey,
|
||||||
|
load_certificate,
|
||||||
|
FILETYPE_PEM
|
||||||
|
)
|
||||||
|
|
||||||
from .db import *
|
from .db import *
|
||||||
from .structs import *
|
from .structs import *
|
||||||
|
@ -34,12 +39,14 @@ class SkynetDGPUComputeError(BaseException):
|
||||||
class SkynetShutdownRequested(BaseException):
|
class SkynetShutdownRequested(BaseException):
|
||||||
...
|
...
|
||||||
|
|
||||||
|
|
||||||
@acm
|
@acm
|
||||||
async def open_rpc_service(sock, dgpu_bus, db_pool):
|
async def open_rpc_service(sock, dgpu_bus, db_pool, tls_whitelist, tls_key):
|
||||||
nodes = OrderedDict()
|
nodes = OrderedDict()
|
||||||
wip_reqs = {}
|
wip_reqs = {}
|
||||||
fin_reqs = {}
|
fin_reqs = {}
|
||||||
next_worker: Optional[int] = None
|
next_worker: Optional[int] = None
|
||||||
|
security = len(tls_whitelist) > 0
|
||||||
|
|
||||||
def connect_node(uid):
|
def connect_node(uid):
|
||||||
nonlocal next_worker
|
nonlocal next_worker
|
||||||
|
@ -109,27 +116,20 @@ async def open_rpc_service(sock, dgpu_bus, db_pool):
|
||||||
async def dgpu_image_streamer():
|
async def dgpu_image_streamer():
|
||||||
nonlocal wip_reqs, fin_reqs
|
nonlocal wip_reqs, fin_reqs
|
||||||
while True:
|
while True:
|
||||||
msg = await dgpu_bus.arecv_msg()
|
msg = DGPUBusResponse(
|
||||||
rid = UUID(bytes=msg.bytes[:16]).hex
|
**json.loads(
|
||||||
raw_msg = msg.bytes[16:]
|
(await dgpu_bus.arecv()).decode()))
|
||||||
logging.info(f'streamer got back {rid} of size {len(raw_msg)}')
|
|
||||||
match raw_msg[:5]:
|
|
||||||
case b'error':
|
|
||||||
img = raw_msg.decode()
|
|
||||||
|
|
||||||
case b'ack':
|
if security:
|
||||||
img = raw_msg
|
msg.verify(tls_whitelist[msg.cert])
|
||||||
|
|
||||||
case _:
|
if msg.rid not in wip_reqs:
|
||||||
img = base64.b64encode(raw_msg).hex()
|
|
||||||
|
|
||||||
if rid not in wip_reqs:
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
fin_reqs[rid] = img
|
fin_reqs[msg.rid] = msg
|
||||||
event = wip_reqs[rid]
|
event = wip_reqs[msg.rid]
|
||||||
event.set()
|
event.set()
|
||||||
del wip_reqs[rid]
|
del wip_reqs[msg.rid]
|
||||||
|
|
||||||
async def dgpu_stream_one_img(req: ImageGenRequest):
|
async def dgpu_stream_one_img(req: ImageGenRequest):
|
||||||
nonlocal wip_reqs, fin_reqs, next_worker
|
nonlocal wip_reqs, fin_reqs, next_worker
|
||||||
|
@ -151,6 +151,9 @@ async def open_rpc_service(sock, dgpu_bus, db_pool):
|
||||||
|
|
||||||
logging.info(f'dgpu_bus req: {dgpu_req}')
|
logging.info(f'dgpu_bus req: {dgpu_req}')
|
||||||
|
|
||||||
|
if security:
|
||||||
|
dgpu_req.sign(tls_key, 'skynet')
|
||||||
|
|
||||||
await dgpu_bus.asend(
|
await dgpu_bus.asend(
|
||||||
json.dumps(dgpu_req.to_dict()).encode())
|
json.dumps(dgpu_req.to_dict()).encode())
|
||||||
|
|
||||||
|
@ -163,8 +166,8 @@ async def open_rpc_service(sock, dgpu_bus, db_pool):
|
||||||
disconnect_node(nid)
|
disconnect_node(nid)
|
||||||
raise SkynetDGPUOffline('dgpu failed to acknowledge request')
|
raise SkynetDGPUOffline('dgpu failed to acknowledge request')
|
||||||
|
|
||||||
ack = fin_reqs[rid]
|
ack_msg = fin_reqs[rid]
|
||||||
if ack != b'ack':
|
if 'ack' not in ack_msg.params:
|
||||||
disconnect_node(nid)
|
disconnect_node(nid)
|
||||||
raise SkynetDGPUOffline('dgpu failed to acknowledge request')
|
raise SkynetDGPUOffline('dgpu failed to acknowledge request')
|
||||||
|
|
||||||
|
@ -178,15 +181,13 @@ async def open_rpc_service(sock, dgpu_bus, db_pool):
|
||||||
|
|
||||||
nodes[nid]['task'] = None
|
nodes[nid]['task'] = None
|
||||||
|
|
||||||
img = fin_reqs[rid]
|
img_resp = fin_reqs[rid]
|
||||||
del fin_reqs[rid]
|
del fin_reqs[rid]
|
||||||
|
|
||||||
logging.info(f'done streaming {len(img)} bytes')
|
if 'error' in img_resp.params:
|
||||||
|
raise SkynetDGPUComputeError(img_resp.params['error'])
|
||||||
|
|
||||||
if 'error' in img:
|
return rid, img_resp.params['img']
|
||||||
raise SkynetDGPUComputeError(img)
|
|
||||||
|
|
||||||
return rid, img
|
|
||||||
|
|
||||||
async def handle_user_request(rpc_ctx, req):
|
async def handle_user_request(rpc_ctx, req):
|
||||||
try:
|
try:
|
||||||
|
@ -266,9 +267,13 @@ async def open_rpc_service(sock, dgpu_bus, db_pool):
|
||||||
'message': str(e)
|
'message': str(e)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
resp = SkynetRPCResponse(result=result)
|
||||||
|
|
||||||
|
if security:
|
||||||
|
resp.sign(tls_key, 'skynet')
|
||||||
|
|
||||||
await rpc_ctx.asend(
|
await rpc_ctx.asend(
|
||||||
json.dumps(
|
json.dumps(resp.to_dict()).encode())
|
||||||
SkynetRPCResponse(result=result).to_dict()).encode())
|
|
||||||
|
|
||||||
async def request_service(n):
|
async def request_service(n):
|
||||||
nonlocal next_worker
|
nonlocal next_worker
|
||||||
|
@ -279,7 +284,19 @@ async def open_rpc_service(sock, dgpu_bus, db_pool):
|
||||||
content = msg.bytes.decode()
|
content = msg.bytes.decode()
|
||||||
req = SkynetRPCRequest(**json.loads(content))
|
req = SkynetRPCRequest(**json.loads(content))
|
||||||
|
|
||||||
logging.info(req)
|
if security:
|
||||||
|
if req.cert not in tls_whitelist:
|
||||||
|
logging.warning(
|
||||||
|
f'{req.cert} not in tls whitelist and security=True')
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
req.verify(tls_whitelist[req.cert])
|
||||||
|
|
||||||
|
except ValueError:
|
||||||
|
logging.warning(
|
||||||
|
f'{req.cert} sent an unauthenticated msg with security=True')
|
||||||
|
continue
|
||||||
|
|
||||||
result = {}
|
result = {}
|
||||||
|
|
||||||
|
@ -303,10 +320,14 @@ async def open_rpc_service(sock, dgpu_bus, db_pool):
|
||||||
handle_user_request, ctx, req)
|
handle_user_request, ctx, req)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
resp = SkynetRPCResponse(
|
||||||
|
result={'ok': result})
|
||||||
|
|
||||||
|
if security:
|
||||||
|
resp.sign(tls_key, 'skynet')
|
||||||
|
|
||||||
await ctx.asend(
|
await ctx.asend(
|
||||||
json.dumps(
|
json.dumps(resp.to_dict()).encode())
|
||||||
SkynetRPCResponse(
|
|
||||||
result={'ok': result}).to_dict()).encode())
|
|
||||||
|
|
||||||
|
|
||||||
async with trio.open_nursery() as n:
|
async with trio.open_nursery() as n:
|
||||||
|
@ -334,22 +355,28 @@ async def run_skynet(
|
||||||
if security:
|
if security:
|
||||||
# load tls certs
|
# load tls certs
|
||||||
certs_dir = Path(DEFAULT_CERTS_DIR).resolve()
|
certs_dir = Path(DEFAULT_CERTS_DIR).resolve()
|
||||||
tls_key = (certs_dir / DEFAULT_CERT_SKYNET_PRIV).read_text()
|
|
||||||
tls_cert = (certs_dir / DEFAULT_CERT_SKYNET_PUB).read_text()
|
|
||||||
tls_whitelist = [
|
|
||||||
(cert_path).read_text()
|
|
||||||
for cert_path in (certs_dir / 'whitelist').glob('*.cert')]
|
|
||||||
|
|
||||||
cert_start = tls_cert.index('\n') + 1
|
tls_key_data = (certs_dir / DEFAULT_CERT_SKYNET_PRIV).read_text()
|
||||||
logging.info(f'tls_cert: {tls_cert[cert_start:cert_start+64]}...')
|
tls_key = load_privatekey(FILETYPE_PEM, tls_key_data)
|
||||||
|
|
||||||
|
tls_cert_data = (certs_dir / DEFAULT_CERT_SKYNET_PUB).read_text()
|
||||||
|
tls_cert = load_certificate(FILETYPE_PEM, tls_cert_data)
|
||||||
|
|
||||||
|
tls_whitelist = {}
|
||||||
|
for cert_path in (certs_dir / 'whitelist').glob('*.cert'):
|
||||||
|
tls_whitelist[cert_path.stem] = load_certificate(
|
||||||
|
FILETYPE_PEM, cert_path.read_text())
|
||||||
|
|
||||||
|
cert_start = tls_cert_data.index('\n') + 1
|
||||||
|
logging.info(f'tls_cert: {tls_cert_data[cert_start:cert_start+64]}...')
|
||||||
logging.info(f'tls_whitelist len: {len(tls_whitelist)}')
|
logging.info(f'tls_whitelist len: {len(tls_whitelist)}')
|
||||||
|
|
||||||
rpc_address = 'tls+' + rpc_address
|
rpc_address = 'tls+' + rpc_address
|
||||||
dgpu_address = 'tls+' + dgpu_address
|
dgpu_address = 'tls+' + dgpu_address
|
||||||
tls_config = TLSConfig(
|
tls_config = TLSConfig(
|
||||||
TLSConfig.MODE_SERVER,
|
TLSConfig.MODE_SERVER,
|
||||||
own_key_string=tls_key,
|
own_key_string=tls_key_data,
|
||||||
own_cert_string=tls_cert)
|
own_cert_string=tls_cert_data)
|
||||||
|
|
||||||
with (
|
with (
|
||||||
pynng.Rep0() as rpc_sock,
|
pynng.Rep0() as rpc_sock,
|
||||||
|
@ -367,7 +394,8 @@ async def run_skynet(
|
||||||
dgpu_bus.listen(dgpu_address)
|
dgpu_bus.listen(dgpu_address)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
async with open_rpc_service(rpc_sock, dgpu_bus, db_pool):
|
async with open_rpc_service(
|
||||||
|
rpc_sock, dgpu_bus, db_pool, tls_whitelist, tls_key):
|
||||||
yield
|
yield
|
||||||
|
|
||||||
except SkynetShutdownRequested:
|
except SkynetShutdownRequested:
|
||||||
|
|
|
@ -9,44 +9,78 @@ from functools import partial
|
||||||
import trio
|
import trio
|
||||||
import click
|
import click
|
||||||
|
|
||||||
|
from . import utils
|
||||||
from .dgpu import open_dgpu_node
|
from .dgpu import open_dgpu_node
|
||||||
from .utils import txt2img
|
from .brain import run_skynet
|
||||||
from .constants import ALGOS
|
|
||||||
|
from .frontend.telegram import run_skynet_telegram
|
||||||
|
|
||||||
|
|
||||||
@click.group()
|
@click.group()
|
||||||
def skynet(*args, **kwargs):
|
def skynet(*args, **kwargs):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@skynet.command()
|
|
||||||
@click.option('--model', '-m', default=ALGOS['midj'])
|
@click.command()
|
||||||
|
@click.option('--model', '-m', default='midj')
|
||||||
@click.option(
|
@click.option(
|
||||||
'--prompt', '-p', default='a red tractor in a wheat field')
|
'--prompt', '-p', default='a red old tractor in a sunny wheat field')
|
||||||
@click.option('--output', '-o', default='output.png')
|
@click.option('--output', '-o', default='output.png')
|
||||||
@click.option('--width', '-w', default=512)
|
@click.option('--width', '-w', default=512)
|
||||||
@click.option('--height', '-h', default=512)
|
@click.option('--height', '-h', default=512)
|
||||||
@click.option('--guidance', '-g', default=10.0)
|
@click.option('--guidance', '-g', default=10.0)
|
||||||
@click.option('--steps', '-s', default=26)
|
@click.option('--steps', '-s', default=26)
|
||||||
@click.option('--seed', '-S', default=None)
|
@click.option('--seed', '-S', default=None)
|
||||||
def txt2img(*args
|
def txt2img(*args, **kwargs):
|
||||||
# model: str,
|
|
||||||
# prompt: str,
|
|
||||||
# output: str
|
|
||||||
# width: int, height: int,
|
|
||||||
# guidance: float,
|
|
||||||
# steps: int,
|
|
||||||
# seed: Optional[int]
|
|
||||||
):
|
|
||||||
assert 'HF_TOKEN' in os.environ
|
assert 'HF_TOKEN' in os.environ
|
||||||
txt2img(
|
utils.txt2img(os.environ['HF_TOKEN'], **kwargs)
|
||||||
os.environ['HF_TOKEN'], *args)
|
|
||||||
|
@click.command()
|
||||||
|
@click.option(
|
||||||
|
'--prompt', '-p', default='a red old tractor in a sunny wheat field')
|
||||||
|
@click.option('--input', '-i', default='input.png')
|
||||||
|
@click.option('--output', '-o', default='output.png')
|
||||||
|
@click.option('--steps', '-s', default=26)
|
||||||
|
def upscale(prompt, input, output, steps):
|
||||||
|
assert 'HF_TOKEN' in os.environ
|
||||||
|
utils.upscale(
|
||||||
|
os.environ['HF_TOKEN'],
|
||||||
|
prompt=prompt,
|
||||||
|
img_path=input,
|
||||||
|
output=output,
|
||||||
|
steps=steps)
|
||||||
|
|
||||||
|
|
||||||
@skynet.group()
|
@skynet.group()
|
||||||
def run(*args, **kwargs):
|
def run(*args, **kwargs):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
@run.command()
|
@run.command()
|
||||||
@click.option('--loglevel', '-l', default='warning', help='Logging level')
|
@click.option('--loglevel', '-l', default='warning', help='Logging level')
|
||||||
|
@click.option(
|
||||||
|
'--host', '-h', default='localhost:5432')
|
||||||
|
@click.option(
|
||||||
|
'--pass', '-p', default='password')
|
||||||
|
def skynet(
|
||||||
|
loglevel: str,
|
||||||
|
host: str,
|
||||||
|
passw: str
|
||||||
|
):
|
||||||
|
async def _run_skynet():
|
||||||
|
async with run_skynet(
|
||||||
|
db_host=host,
|
||||||
|
db_pass=passw
|
||||||
|
):
|
||||||
|
await trio.sleep_forever()
|
||||||
|
|
||||||
|
trio_asyncio.run(_run_skynet)
|
||||||
|
|
||||||
|
|
||||||
|
@run.command()
|
||||||
|
@click.option('--loglevel', '-l', default='warning', help='Logging level')
|
||||||
|
@click.option(
|
||||||
|
'--uid', '-u', required=True)
|
||||||
@click.option(
|
@click.option(
|
||||||
'--key', '-k', default='dgpu')
|
'--key', '-k', default='dgpu')
|
||||||
@click.option(
|
@click.option(
|
||||||
|
@ -55,6 +89,7 @@ def run(*args, **kwargs):
|
||||||
'--algos', '-a', default=None)
|
'--algos', '-a', default=None)
|
||||||
def dgpu(
|
def dgpu(
|
||||||
loglevel: str,
|
loglevel: str,
|
||||||
|
uid: str,
|
||||||
key: str,
|
key: str,
|
||||||
cert: str,
|
cert: str,
|
||||||
algos: Optional[str]
|
algos: Optional[str]
|
||||||
|
@ -63,6 +98,28 @@ def dgpu(
|
||||||
partial(
|
partial(
|
||||||
open_dgpu_node,
|
open_dgpu_node,
|
||||||
cert,
|
cert,
|
||||||
|
uid,
|
||||||
key_name=key,
|
key_name=key,
|
||||||
initial_algos=json.loads(algos)
|
initial_algos=json.loads(algos)
|
||||||
))
|
))
|
||||||
|
|
||||||
|
|
||||||
|
@run.command()
|
||||||
|
@click.option('--loglevel', '-l', default='warning', help='Logging level')
|
||||||
|
@click.option(
|
||||||
|
'--key', '-k', default='telegram-frontend')
|
||||||
|
@click.option(
|
||||||
|
'--cert', '-c', default='whitelist/telegram-frontend')
|
||||||
|
def telegram(
|
||||||
|
loglevel: str,
|
||||||
|
key: str,
|
||||||
|
cert: str
|
||||||
|
):
|
||||||
|
assert 'TG_TOKEN' in os.environ
|
||||||
|
trio_asyncio.run(
|
||||||
|
partial(
|
||||||
|
run_skynet_telegram,
|
||||||
|
os.environ['TG_TOKEN'],
|
||||||
|
key_name=key,
|
||||||
|
cert_name=cert
|
||||||
|
))
|
||||||
|
|
45
skynet/db.py
45
skynet/db.py
|
@ -58,13 +58,16 @@ ALTER TABLE skynet.user_config
|
||||||
|
|
||||||
|
|
||||||
def try_decode_uid(uid: str):
|
def try_decode_uid(uid: str):
|
||||||
|
if isinstance(uid, int):
|
||||||
|
return None, uid
|
||||||
|
|
||||||
try:
|
try:
|
||||||
proto, uid = uid.split('+')
|
proto, uid = uid.split('+')
|
||||||
uid = int(uid)
|
uid = int(uid)
|
||||||
return proto, uid
|
return proto, uid
|
||||||
|
|
||||||
except ValueError:
|
except ValueError:
|
||||||
logging.warning(f'got non numeric uid?: {uid}')
|
logging.warning(f'got non chat proto uid?: {uid}')
|
||||||
return None, None
|
return None, None
|
||||||
|
|
||||||
|
|
||||||
|
@ -132,28 +135,38 @@ async def new_user(conn, uid: str):
|
||||||
|
|
||||||
logging.info(f'new user! {uid}')
|
logging.info(f'new user! {uid}')
|
||||||
|
|
||||||
tg_id = None
|
|
||||||
date = datetime.utcnow()
|
date = datetime.utcnow()
|
||||||
|
|
||||||
proto, pid = try_decode_uid(uid)
|
proto, pid = try_decode_uid(uid)
|
||||||
|
|
||||||
match proto:
|
|
||||||
case 'tg':
|
|
||||||
tg_id = pid
|
|
||||||
|
|
||||||
async with conn.transaction():
|
async with conn.transaction():
|
||||||
stmt = await conn.prepare('''
|
match proto:
|
||||||
INSERT INTO skynet.user(
|
case 'tg':
|
||||||
tg_id, generated, joined, last_prompt, role)
|
tg_id = pid
|
||||||
|
stmt = await conn.prepare('''
|
||||||
|
INSERT INTO skynet.user(
|
||||||
|
tg_id, generated, joined, last_prompt, role)
|
||||||
|
|
||||||
VALUES($1, $2, $3, $4, $5)
|
VALUES($1, $2, $3, $4, $5)
|
||||||
ON CONFLICT DO NOTHING
|
ON CONFLICT DO NOTHING
|
||||||
''')
|
''')
|
||||||
await stmt.fetch(
|
await stmt.fetch(
|
||||||
tg_id, 0, date, None, DEFAULT_ROLE
|
tg_id, 0, date, None, DEFAULT_ROLE
|
||||||
)
|
)
|
||||||
|
new_uid = await get_user(conn, uid)
|
||||||
|
|
||||||
new_uid = await get_user(conn, uid)
|
case None:
|
||||||
|
stmt = await conn.prepare('''
|
||||||
|
INSERT INTO skynet.user(
|
||||||
|
id, generated, joined, last_prompt, role)
|
||||||
|
|
||||||
|
VALUES($1, $2, $3, $4, $5)
|
||||||
|
ON CONFLICT DO NOTHING
|
||||||
|
''')
|
||||||
|
await stmt.fetch(
|
||||||
|
pid, 0, date, None, DEFAULT_ROLE
|
||||||
|
)
|
||||||
|
new_uid = pid
|
||||||
|
|
||||||
stmt = await conn.prepare('''
|
stmt = await conn.prepare('''
|
||||||
INSERT INTO skynet.user_config(
|
INSERT INTO skynet.user_config(
|
||||||
|
|
|
@ -5,6 +5,7 @@ import io
|
||||||
import trio
|
import trio
|
||||||
import json
|
import json
|
||||||
import uuid
|
import uuid
|
||||||
|
import base64
|
||||||
import random
|
import random
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
@ -16,10 +17,16 @@ import pynng
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from pynng import TLSConfig
|
from pynng import TLSConfig
|
||||||
|
from OpenSSL.crypto import (
|
||||||
|
load_privatekey,
|
||||||
|
load_certificate,
|
||||||
|
FILETYPE_PEM
|
||||||
|
)
|
||||||
from diffusers import (
|
from diffusers import (
|
||||||
StableDiffusionPipeline,
|
StableDiffusionPipeline,
|
||||||
EulerAncestralDiscreteScheduler
|
EulerAncestralDiscreteScheduler
|
||||||
)
|
)
|
||||||
|
from diffusers.models import UNet2DConditionModel
|
||||||
|
|
||||||
from .structs import *
|
from .structs import *
|
||||||
from .constants import *
|
from .constants import *
|
||||||
|
@ -58,6 +65,7 @@ class DGPUComputeError(BaseException):
|
||||||
|
|
||||||
async def open_dgpu_node(
|
async def open_dgpu_node(
|
||||||
cert_name: str,
|
cert_name: str,
|
||||||
|
unique_id: str,
|
||||||
key_name: Optional[str],
|
key_name: Optional[str],
|
||||||
rpc_address: str = DEFAULT_RPC_ADDR,
|
rpc_address: str = DEFAULT_RPC_ADDR,
|
||||||
dgpu_address: str = DEFAULT_DGPU_ADDR,
|
dgpu_address: str = DEFAULT_DGPU_ADDR,
|
||||||
|
@ -122,6 +130,7 @@ async def open_dgpu_node(
|
||||||
|
|
||||||
|
|
||||||
async with open_skynet_rpc(
|
async with open_skynet_rpc(
|
||||||
|
unique_id,
|
||||||
security=security,
|
security=security,
|
||||||
cert_name=cert_name,
|
cert_name=cert_name,
|
||||||
key_name=key_name
|
key_name=key_name
|
||||||
|
@ -131,16 +140,23 @@ async def open_dgpu_node(
|
||||||
if security:
|
if security:
|
||||||
# load tls certs
|
# load tls certs
|
||||||
if not key_name:
|
if not key_name:
|
||||||
key_name = certs_name
|
key_name = cert_name
|
||||||
|
|
||||||
certs_dir = Path(DEFAULT_CERTS_DIR).resolve()
|
certs_dir = Path(DEFAULT_CERTS_DIR).resolve()
|
||||||
|
|
||||||
skynet_cert_path = certs_dir / 'brain.cert'
|
skynet_cert_path = certs_dir / 'brain.cert'
|
||||||
tls_cert_path = certs_dir / f'{cert_name}.cert'
|
tls_cert_path = certs_dir / f'{cert_name}.cert'
|
||||||
tls_key_path = certs_dir / f'{key_name}.key'
|
tls_key_path = certs_dir / f'{key_name}.key'
|
||||||
|
|
||||||
skynet_cert = skynet_cert_path.read_text()
|
cert_name = tls_cert_path.stem
|
||||||
tls_cert = tls_cert_path.read_text()
|
|
||||||
tls_key = tls_key_path.read_text()
|
skynet_cert_data = skynet_cert_path.read_text()
|
||||||
|
skynet_cert = load_certificate(FILETYPE_PEM, skynet_cert_data)
|
||||||
|
|
||||||
|
tls_cert_data = tls_cert_path.read_text()
|
||||||
|
|
||||||
|
tls_key_data = tls_key_path.read_text()
|
||||||
|
tls_key = load_privatekey(FILETYPE_PEM, tls_key_data)
|
||||||
|
|
||||||
logging.info(f'skynet cert: {skynet_cert_path}')
|
logging.info(f'skynet cert: {skynet_cert_path}')
|
||||||
logging.info(f'dgpu cert: {tls_cert_path}')
|
logging.info(f'dgpu cert: {tls_cert_path}')
|
||||||
|
@ -149,17 +165,16 @@ async def open_dgpu_node(
|
||||||
dgpu_address = 'tls+' + dgpu_address
|
dgpu_address = 'tls+' + dgpu_address
|
||||||
tls_config = TLSConfig(
|
tls_config = TLSConfig(
|
||||||
TLSConfig.MODE_CLIENT,
|
TLSConfig.MODE_CLIENT,
|
||||||
own_key_string=tls_key,
|
own_key_string=tls_key_data,
|
||||||
own_cert_string=tls_cert,
|
own_cert_string=tls_cert_data,
|
||||||
ca_string=skynet_cert)
|
ca_string=skynet_cert_data)
|
||||||
|
|
||||||
logging.info(f'connecting to {dgpu_address}')
|
logging.info(f'connecting to {dgpu_address}')
|
||||||
with pynng.Bus0() as dgpu_sock:
|
with pynng.Bus0() as dgpu_sock:
|
||||||
dgpu_sock.tls_config = tls_config
|
dgpu_sock.tls_config = tls_config
|
||||||
dgpu_sock.dial(dgpu_address)
|
dgpu_sock.dial(dgpu_address)
|
||||||
|
|
||||||
res = await rpc_call(name.hex, 'dgpu_online')
|
res = await rpc_call('dgpu_online')
|
||||||
logging.info(res)
|
|
||||||
assert 'ok' in res.result
|
assert 'ok' in res.result
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -168,30 +183,55 @@ async def open_dgpu_node(
|
||||||
req = DGPUBusRequest(
|
req = DGPUBusRequest(
|
||||||
**json.loads(msg.decode()))
|
**json.loads(msg.decode()))
|
||||||
|
|
||||||
if req.nid != name.hex:
|
if req.nid != unique_id:
|
||||||
logging.info('witnessed request {req.rid}, for {req.nid}')
|
logging.info(
|
||||||
|
f'witnessed msg {req.rid}, node involved: {req.nid}')
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
if security:
|
||||||
|
req.verify(skynet_cert)
|
||||||
|
|
||||||
|
ack_resp = DGPUBusResponse(
|
||||||
|
rid=req.rid,
|
||||||
|
nid=req.nid,
|
||||||
|
params={'ack': {}}
|
||||||
|
)
|
||||||
|
|
||||||
|
if security:
|
||||||
|
ack_resp.sign(tls_key, cert_name)
|
||||||
|
|
||||||
# send ack
|
# send ack
|
||||||
await dgpu_sock.asend(
|
await dgpu_sock.asend(
|
||||||
bytes.fromhex(req.rid) + b'ack')
|
json.dumps(ack_resp.to_dict()).encode())
|
||||||
|
|
||||||
logging.info(f'sent ack, processing {req.rid}...')
|
logging.info(f'sent ack, processing {req.rid}...')
|
||||||
|
|
||||||
try:
|
try:
|
||||||
img = await gpu_compute_one(
|
img = await gpu_compute_one(
|
||||||
ImageGenRequest(**req.params))
|
ImageGenRequest(**req.params))
|
||||||
|
img_resp = DGPUBusResponse(
|
||||||
|
rid=req.rid,
|
||||||
|
nid=req.nid,
|
||||||
|
params={'img': base64.b64encode(img).hex()}
|
||||||
|
)
|
||||||
|
|
||||||
except DGPUComputeError as e:
|
except DGPUComputeError as e:
|
||||||
img = b'error' + str(e).encode()
|
img_resp = DGPUBusResponse(
|
||||||
|
rid=req.rid,
|
||||||
|
nid=req.nid,
|
||||||
|
params={'error': str(e)}
|
||||||
|
)
|
||||||
|
|
||||||
|
if security:
|
||||||
|
img_resp.sign(tls_key, cert_name)
|
||||||
|
|
||||||
|
# send final image
|
||||||
await dgpu_sock.asend(
|
await dgpu_sock.asend(
|
||||||
bytes.fromhex(req.rid) + img)
|
json.dumps(img_resp.to_dict()).encode())
|
||||||
|
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
logging.info('interrupt caught, stopping...')
|
logging.info('interrupt caught, stopping...')
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
res = await rpc_call(name.hex, 'dgpu_offline')
|
res = await rpc_call('dgpu_offline')
|
||||||
logging.info(res)
|
|
||||||
assert 'ok' in res.result
|
assert 'ok' in res.result
|
||||||
|
|
|
@ -9,6 +9,11 @@ from contextlib import asynccontextmanager as acm
|
||||||
import pynng
|
import pynng
|
||||||
|
|
||||||
from pynng import TLSConfig
|
from pynng import TLSConfig
|
||||||
|
from OpenSSL.crypto import (
|
||||||
|
load_privatekey,
|
||||||
|
load_certificate,
|
||||||
|
FILETYPE_PEM
|
||||||
|
)
|
||||||
|
|
||||||
from ..structs import SkynetRPCRequest, SkynetRPCResponse
|
from ..structs import SkynetRPCRequest, SkynetRPCResponse
|
||||||
from ..constants import *
|
from ..constants import *
|
||||||
|
@ -30,56 +35,73 @@ class ConfigSizeDivisionByEight(BaseException):
|
||||||
...
|
...
|
||||||
|
|
||||||
|
|
||||||
async def rpc_call(
|
|
||||||
sock,
|
|
||||||
uid: Union[int, str],
|
|
||||||
method: str,
|
|
||||||
params: dict = {}
|
|
||||||
):
|
|
||||||
req = SkynetRPCRequest(
|
|
||||||
uid=uid,
|
|
||||||
method=method,
|
|
||||||
params=params
|
|
||||||
)
|
|
||||||
await sock.asend(
|
|
||||||
json.dumps(
|
|
||||||
req.to_dict()).encode())
|
|
||||||
|
|
||||||
return SkynetRPCResponse(
|
|
||||||
**json.loads(
|
|
||||||
(await sock.arecv_msg()).bytes.decode()))
|
|
||||||
|
|
||||||
|
|
||||||
@acm
|
@acm
|
||||||
async def open_skynet_rpc(
|
async def open_skynet_rpc(
|
||||||
|
unique_id: str,
|
||||||
rpc_address: str = DEFAULT_RPC_ADDR,
|
rpc_address: str = DEFAULT_RPC_ADDR,
|
||||||
security: bool = False,
|
security: bool = False,
|
||||||
cert_name: Optional[str] = None,
|
cert_name: Optional[str] = None,
|
||||||
key_name: Optional[str] = None
|
key_name: Optional[str] = None
|
||||||
):
|
):
|
||||||
tls_config = None
|
tls_config = None
|
||||||
|
|
||||||
if security:
|
if security:
|
||||||
# load tls certs
|
# load tls certs
|
||||||
if not key_name:
|
if not key_name:
|
||||||
key_name = certs_name
|
key_name = cert_name
|
||||||
|
|
||||||
certs_dir = Path(DEFAULT_CERTS_DIR).resolve()
|
certs_dir = Path(DEFAULT_CERTS_DIR).resolve()
|
||||||
skynet_cert = (certs_dir / 'brain.cert').read_text()
|
|
||||||
tls_cert = (certs_dir / f'{cert_name}.cert').read_text()
|
skynet_cert_data = (certs_dir / 'brain.cert').read_text()
|
||||||
tls_key = (certs_dir / f'{key_name}.key').read_text()
|
skynet_cert = load_certificate(FILETYPE_PEM, skynet_cert_data)
|
||||||
|
|
||||||
|
tls_cert_path = certs_dir / f'{cert_name}.cert'
|
||||||
|
tls_cert_data = tls_cert_path.read_text()
|
||||||
|
tls_cert = load_certificate(FILETYPE_PEM, tls_cert_data)
|
||||||
|
cert_name = tls_cert_path.stem
|
||||||
|
|
||||||
|
tls_key_data = (certs_dir / f'{key_name}.key').read_text()
|
||||||
|
tls_key = load_privatekey(FILETYPE_PEM, tls_key_data)
|
||||||
|
|
||||||
rpc_address = 'tls+' + rpc_address
|
rpc_address = 'tls+' + rpc_address
|
||||||
tls_config = TLSConfig(
|
tls_config = TLSConfig(
|
||||||
TLSConfig.MODE_CLIENT,
|
TLSConfig.MODE_CLIENT,
|
||||||
own_key_string=tls_key,
|
own_key_string=tls_key_data,
|
||||||
own_cert_string=tls_cert,
|
own_cert_string=tls_cert_data,
|
||||||
ca_string=skynet_cert)
|
ca_string=skynet_cert_data)
|
||||||
|
|
||||||
with pynng.Req0() as sock:
|
with pynng.Req0() as sock:
|
||||||
if security:
|
if security:
|
||||||
sock.tls_config = tls_config
|
sock.tls_config = tls_config
|
||||||
|
|
||||||
sock.dial(rpc_address)
|
sock.dial(rpc_address)
|
||||||
async def _rpc_call(*args, **kwargs):
|
|
||||||
return await rpc_call(sock, *args, **kwargs)
|
async def _rpc_call(
|
||||||
|
method: str,
|
||||||
|
params: dict = {},
|
||||||
|
uid: Optional[Union[int, str]] = None
|
||||||
|
):
|
||||||
|
req = SkynetRPCRequest(
|
||||||
|
uid=uid if uid else unique_id,
|
||||||
|
method=method,
|
||||||
|
params=params
|
||||||
|
)
|
||||||
|
|
||||||
|
if security:
|
||||||
|
req.sign(tls_key, cert_name)
|
||||||
|
|
||||||
|
await sock.asend(
|
||||||
|
json.dumps(
|
||||||
|
req.to_dict()).encode())
|
||||||
|
|
||||||
|
resp = SkynetRPCResponse(
|
||||||
|
**json.loads(
|
||||||
|
(await sock.arecv_msg()).bytes.decode()))
|
||||||
|
|
||||||
|
if security:
|
||||||
|
resp.verify(skynet_cert)
|
||||||
|
|
||||||
|
return resp
|
||||||
|
|
||||||
yield _rpc_call
|
yield _rpc_call
|
||||||
|
|
||||||
|
|
|
@ -18,14 +18,19 @@ PREFIX = 'tg'
|
||||||
|
|
||||||
|
|
||||||
async def run_skynet_telegram(
|
async def run_skynet_telegram(
|
||||||
tg_token: str
|
tg_token: str,
|
||||||
|
key_name: str = 'telegram-frontend',
|
||||||
|
cert_name: str = 'whitelist/telegram-frontend'
|
||||||
):
|
):
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
logging.basicConfig(level=logging.INFO)
|
||||||
bot = AsyncTeleBot(tg_token)
|
bot = AsyncTeleBot(tg_token)
|
||||||
|
|
||||||
with open_skynet_rpc(
|
with open_skynet_rpc(
|
||||||
security=True, cert_name='telegram-frontend'
|
'skynet-telegram-0',
|
||||||
|
security=True,
|
||||||
|
cert_name=cert,
|
||||||
|
key_name=key
|
||||||
) as rpc_call:
|
) as rpc_call:
|
||||||
|
|
||||||
async def _rpc_call(
|
async def _rpc_call(
|
||||||
|
@ -33,7 +38,8 @@ async def run_skynet_telegram(
|
||||||
method: str,
|
method: str,
|
||||||
params: dict
|
params: dict
|
||||||
):
|
):
|
||||||
return await rpc_call(f'{PREFIX}+{uid}', method, params)
|
return await rpc_call(
|
||||||
|
method, params, uid=f'{PREFIX}+{uid}')
|
||||||
|
|
||||||
@bot.message_handler(commands=['help'])
|
@bot.message_handler(commands=['help'])
|
||||||
async def send_help(message):
|
async def send_help(message):
|
||||||
|
|
|
@ -18,6 +18,8 @@
|
||||||
Built-in (extension) types.
|
Built-in (extension) types.
|
||||||
"""
|
"""
|
||||||
import sys
|
import sys
|
||||||
|
import json
|
||||||
|
|
||||||
from typing import Optional, Union
|
from typing import Optional, Union
|
||||||
from pprint import pformat
|
from pprint import pformat
|
||||||
|
|
||||||
|
@ -83,15 +85,51 @@ class Struct(
|
||||||
setattr(self, fname, ftype(getattr(self, fname)))
|
setattr(self, fname, ftype(getattr(self, fname)))
|
||||||
|
|
||||||
# proto
|
# proto
|
||||||
|
from OpenSSL.crypto import PKey, X509, verify, sign
|
||||||
|
|
||||||
class SkynetRPCRequest(Struct):
|
|
||||||
|
class AuthenticatedStruct(Struct):
|
||||||
|
cert: Optional[str] = None
|
||||||
|
sig: Optional[str] = None
|
||||||
|
|
||||||
|
def to_unsigned_dict(self) -> dict:
|
||||||
|
self_dict = self.to_dict()
|
||||||
|
|
||||||
|
if 'sig' in self_dict:
|
||||||
|
del self_dict['sig']
|
||||||
|
|
||||||
|
if 'cert' in self_dict:
|
||||||
|
del self_dict['cert']
|
||||||
|
|
||||||
|
return self_dict
|
||||||
|
|
||||||
|
def unsigned_to_bytes(self) -> bytes:
|
||||||
|
return json.dumps(
|
||||||
|
self.to_unsigned_dict()).encode()
|
||||||
|
|
||||||
|
def sign(self, key: PKey, cert: str):
|
||||||
|
self.cert = cert
|
||||||
|
self.sig = sign(
|
||||||
|
key, self.unsigned_to_bytes(), 'sha256').hex()
|
||||||
|
|
||||||
|
def verify(self, cert: X509):
|
||||||
|
if not self.sig:
|
||||||
|
raise ValueError('Tried to verify unsigned request')
|
||||||
|
|
||||||
|
return verify(
|
||||||
|
cert, bytes.fromhex(self.sig), self.unsigned_to_bytes(), 'sha256')
|
||||||
|
|
||||||
|
|
||||||
|
class SkynetRPCRequest(AuthenticatedStruct):
|
||||||
uid: Union[str, int] # user unique id
|
uid: Union[str, int] # user unique id
|
||||||
method: str # rpc method name
|
method: str # rpc method name
|
||||||
params: dict # variable params
|
params: dict # variable params
|
||||||
|
|
||||||
class SkynetRPCResponse(Struct):
|
|
||||||
|
class SkynetRPCResponse(AuthenticatedStruct):
|
||||||
result: dict
|
result: dict
|
||||||
|
|
||||||
|
|
||||||
class ImageGenRequest(Struct):
|
class ImageGenRequest(Struct):
|
||||||
prompt: str
|
prompt: str
|
||||||
step: int
|
step: int
|
||||||
|
@ -102,8 +140,15 @@ class ImageGenRequest(Struct):
|
||||||
algo: str
|
algo: str
|
||||||
upscaler: Optional[str]
|
upscaler: Optional[str]
|
||||||
|
|
||||||
class DGPUBusRequest(Struct):
|
|
||||||
|
class DGPUBusRequest(AuthenticatedStruct):
|
||||||
rid: str # req id
|
rid: str # req id
|
||||||
nid: str # node id
|
nid: str # node id
|
||||||
task: str
|
task: str
|
||||||
params: dict
|
params: dict
|
||||||
|
|
||||||
|
|
||||||
|
class DGPUBusResponse(AuthenticatedStruct):
|
||||||
|
rid: str # req id
|
||||||
|
nid: str # node id
|
||||||
|
params: dict
|
||||||
|
|
|
@ -7,44 +7,37 @@ from pathlib import Path
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from diffusers import StableDiffusionPipeline
|
from PIL import Image
|
||||||
|
from diffusers import (
|
||||||
|
StableDiffusionPipeline,
|
||||||
|
StableDiffusionUpscalePipeline,
|
||||||
|
EulerAncestralDiscreteScheduler
|
||||||
|
)
|
||||||
from huggingface_hub import login
|
from huggingface_hub import login
|
||||||
|
|
||||||
|
from .dgpu import pipeline_for
|
||||||
|
|
||||||
|
|
||||||
def txt2img(
|
def txt2img(
|
||||||
hf_token: str,
|
hf_token: str,
|
||||||
model_name: str,
|
model: str = 'midj',
|
||||||
prompt: str,
|
prompt: str = 'a red old tractor in a sunny wheat field',
|
||||||
output: str,
|
output: str = 'output.png',
|
||||||
width: int, height: int,
|
width: int = 512, height: int = 512,
|
||||||
guidance: float,
|
guidance: float = 10,
|
||||||
steps: int,
|
steps: int = 28,
|
||||||
seed: Optional[int]
|
seed: Optional[int] = None
|
||||||
):
|
):
|
||||||
assert torch.cuda.is_available()
|
assert torch.cuda.is_available()
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
torch.cuda.set_per_process_memory_fraction(0.333)
|
torch.cuda.set_per_process_memory_fraction(1.0)
|
||||||
torch.backends.cuda.matmul.allow_tf32 = True
|
torch.backends.cuda.matmul.allow_tf32 = True
|
||||||
torch.backends.cudnn.allow_tf32 = True
|
torch.backends.cudnn.allow_tf32 = True
|
||||||
|
|
||||||
login(token=hf_token)
|
login(token=hf_token)
|
||||||
|
pipe = pipeline_for(model)
|
||||||
|
|
||||||
params = {
|
seed = seed if seed else random.randint(0, 2 ** 64)
|
||||||
'torch_dtype': torch.float16,
|
|
||||||
'safety_checker': None
|
|
||||||
}
|
|
||||||
if model_name == 'runwayml/stable-diffusion-v1-5':
|
|
||||||
params['revision'] = 'fp16'
|
|
||||||
|
|
||||||
pipe = StableDiffusionPipeline.from_pretrained(
|
|
||||||
model_name, **params)
|
|
||||||
|
|
||||||
pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(
|
|
||||||
pipe.scheduler.config)
|
|
||||||
|
|
||||||
pipe = pipe.to("cuda")
|
|
||||||
|
|
||||||
seed = ireq.seed if ireq.seed else random.randint(0, 2 ** 64)
|
|
||||||
prompt = prompt
|
prompt = prompt
|
||||||
image = pipe(
|
image = pipe(
|
||||||
prompt,
|
prompt,
|
||||||
|
@ -55,3 +48,34 @@ def txt2img(
|
||||||
).images[0]
|
).images[0]
|
||||||
|
|
||||||
image.save(output)
|
image.save(output)
|
||||||
|
|
||||||
|
|
||||||
|
def upscale(
|
||||||
|
hf_token: str,
|
||||||
|
prompt: str = 'a red old tractor in a sunny wheat field',
|
||||||
|
img_path: str = 'input.png',
|
||||||
|
output: str = 'output.png',
|
||||||
|
steps: int = 28
|
||||||
|
):
|
||||||
|
assert torch.cuda.is_available()
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
torch.cuda.set_per_process_memory_fraction(1.0)
|
||||||
|
torch.backends.cuda.matmul.allow_tf32 = True
|
||||||
|
torch.backends.cudnn.allow_tf32 = True
|
||||||
|
|
||||||
|
login(token=hf_token)
|
||||||
|
params = {
|
||||||
|
'torch_dtype': torch.float16,
|
||||||
|
'safety_checker': None
|
||||||
|
}
|
||||||
|
|
||||||
|
pipe = StableDiffusionUpscalePipeline.from_pretrained(
|
||||||
|
'stabilityai/stable-diffusion-x4-upscaler', **params)
|
||||||
|
|
||||||
|
prompt = prompt
|
||||||
|
image = pipe(
|
||||||
|
prompt,
|
||||||
|
image=Image.open(img_path)
|
||||||
|
).images[0]
|
||||||
|
|
||||||
|
image.save(output)
|
||||||
|
|
|
@ -97,17 +97,22 @@ def dgpu_workers(request, dockerctl, skynet_running):
|
||||||
|
|
||||||
num_containers, initial_algos = request.param
|
num_containers, initial_algos = request.param
|
||||||
|
|
||||||
cmd = f'''
|
cmds = []
|
||||||
pip install -e . && \
|
for i in range(num_containers):
|
||||||
skynet run dgpu --algos=\'{json.dumps(initial_algos)}\'
|
cmd = f'''
|
||||||
'''
|
pip install -e . && \
|
||||||
|
skynet run dgpu \
|
||||||
|
--algos=\'{json.dumps(initial_algos)}\' \
|
||||||
|
--uid=dgpu-{i}
|
||||||
|
'''
|
||||||
|
cmds.append(['bash', '-c', cmd])
|
||||||
|
|
||||||
logging.info(f'launching: \n{cmd}')
|
logging.info(f'launching: \n{cmd}')
|
||||||
|
|
||||||
with dockerctl.run(
|
with dockerctl.run(
|
||||||
DOCKER_RUNTIME_CUDA,
|
DOCKER_RUNTIME_CUDA,
|
||||||
name='skynet-test-runtime-cuda',
|
name='skynet-test-runtime-cuda',
|
||||||
command=['bash', '-c', cmd],
|
commands=cmds,
|
||||||
environment={
|
environment={
|
||||||
'HF_TOKEN': os.environ['HF_TOKEN'],
|
'HF_TOKEN': os.environ['HF_TOKEN'],
|
||||||
'HF_HOME': '/skynet/hf_home'
|
'HF_HOME': '/skynet/hf_home'
|
||||||
|
|
|
@ -26,8 +26,7 @@ async def wait_for_dgpus(rpc, amount: int, timeout: float = 30.0):
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
current_time = time.time()
|
current_time = time.time()
|
||||||
while not gpu_ready and (current_time - start_time) < timeout:
|
while not gpu_ready and (current_time - start_time) < timeout:
|
||||||
res = await rpc('dgpu-test', 'dgpu_workers')
|
res = await rpc('dgpu_workers')
|
||||||
logging.info(res)
|
|
||||||
if res.result['ok'] >= amount:
|
if res.result['ok'] >= amount:
|
||||||
break
|
break
|
||||||
|
|
||||||
|
@ -40,6 +39,7 @@ async def wait_for_dgpus(rpc, amount: int, timeout: float = 30.0):
|
||||||
_images = set()
|
_images = set()
|
||||||
async def check_request_img(
|
async def check_request_img(
|
||||||
i: int,
|
i: int,
|
||||||
|
uid: int = 0,
|
||||||
width: int = 512,
|
width: int = 512,
|
||||||
height: int = 512,
|
height: int = 512,
|
||||||
expect_unique=True
|
expect_unique=True
|
||||||
|
@ -47,12 +47,13 @@ async def check_request_img(
|
||||||
global _images
|
global _images
|
||||||
|
|
||||||
async with open_skynet_rpc(
|
async with open_skynet_rpc(
|
||||||
|
uid,
|
||||||
security=True,
|
security=True,
|
||||||
cert_name='whitelist/testing',
|
cert_name='whitelist/testing',
|
||||||
key_name='testing'
|
key_name='testing'
|
||||||
) as rpc_call:
|
) as rpc_call:
|
||||||
res = await rpc_call(
|
res = await rpc_call(
|
||||||
'tg+580213293', 'txt2img', {
|
'txt2img', {
|
||||||
'prompt': 'red old tractor in a sunny wheat field',
|
'prompt': 'red old tractor in a sunny wheat field',
|
||||||
'step': 28,
|
'step': 28,
|
||||||
'width': width, 'height': height,
|
'width': width, 'height': height,
|
||||||
|
@ -88,6 +89,7 @@ async def test_dgpu_worker_compute_error(dgpu_workers):
|
||||||
'''
|
'''
|
||||||
|
|
||||||
async with open_skynet_rpc(
|
async with open_skynet_rpc(
|
||||||
|
'test-ctx',
|
||||||
security=True,
|
security=True,
|
||||||
cert_name='whitelist/testing',
|
cert_name='whitelist/testing',
|
||||||
key_name='testing'
|
key_name='testing'
|
||||||
|
@ -110,6 +112,7 @@ async def test_dgpu_workers(dgpu_workers):
|
||||||
'''
|
'''
|
||||||
|
|
||||||
async with open_skynet_rpc(
|
async with open_skynet_rpc(
|
||||||
|
'test-ctx',
|
||||||
security=True,
|
security=True,
|
||||||
cert_name='whitelist/testing',
|
cert_name='whitelist/testing',
|
||||||
key_name='testing'
|
key_name='testing'
|
||||||
|
@ -126,6 +129,7 @@ async def test_dgpu_workers_two(dgpu_workers):
|
||||||
'''Generate two images in two separate dgpu workers
|
'''Generate two images in two separate dgpu workers
|
||||||
'''
|
'''
|
||||||
async with open_skynet_rpc(
|
async with open_skynet_rpc(
|
||||||
|
'test-ctx',
|
||||||
security=True,
|
security=True,
|
||||||
cert_name='whitelist/testing',
|
cert_name='whitelist/testing',
|
||||||
key_name='testing'
|
key_name='testing'
|
||||||
|
@ -143,6 +147,7 @@ async def test_dgpu_worker_algo_swap(dgpu_workers):
|
||||||
'''Generate an image using a non default model
|
'''Generate an image using a non default model
|
||||||
'''
|
'''
|
||||||
async with open_skynet_rpc(
|
async with open_skynet_rpc(
|
||||||
|
'test-ctx',
|
||||||
security=True,
|
security=True,
|
||||||
cert_name='whitelist/testing',
|
cert_name='whitelist/testing',
|
||||||
key_name='testing'
|
key_name='testing'
|
||||||
|
@ -158,35 +163,32 @@ async def test_dgpu_rotation_next_worker(dgpu_workers):
|
||||||
rotation happens correctly
|
rotation happens correctly
|
||||||
'''
|
'''
|
||||||
async with open_skynet_rpc(
|
async with open_skynet_rpc(
|
||||||
|
'test-ctx',
|
||||||
security=True,
|
security=True,
|
||||||
cert_name='whitelist/testing',
|
cert_name='whitelist/testing',
|
||||||
key_name='testing'
|
key_name='testing'
|
||||||
) as test_rpc:
|
) as test_rpc:
|
||||||
await wait_for_dgpus(test_rpc, 3)
|
await wait_for_dgpus(test_rpc, 3)
|
||||||
|
|
||||||
res = await test_rpc('testing-rpc', 'dgpu_next')
|
res = await test_rpc('dgpu_next')
|
||||||
logging.info(res)
|
|
||||||
assert 'ok' in res.result
|
assert 'ok' in res.result
|
||||||
assert res.result['ok'] == 0
|
assert res.result['ok'] == 0
|
||||||
|
|
||||||
await check_request_img(0)
|
await check_request_img(0)
|
||||||
|
|
||||||
res = await test_rpc('testing-rpc', 'dgpu_next')
|
res = await test_rpc('dgpu_next')
|
||||||
logging.info(res)
|
|
||||||
assert 'ok' in res.result
|
assert 'ok' in res.result
|
||||||
assert res.result['ok'] == 1
|
assert res.result['ok'] == 1
|
||||||
|
|
||||||
await check_request_img(0)
|
await check_request_img(0)
|
||||||
|
|
||||||
res = await test_rpc('testing-rpc', 'dgpu_next')
|
res = await test_rpc('dgpu_next')
|
||||||
logging.info(res)
|
|
||||||
assert 'ok' in res.result
|
assert 'ok' in res.result
|
||||||
assert res.result['ok'] == 2
|
assert res.result['ok'] == 2
|
||||||
|
|
||||||
await check_request_img(0)
|
await check_request_img(0)
|
||||||
|
|
||||||
res = await test_rpc('testing-rpc', 'dgpu_next')
|
res = await test_rpc('dgpu_next')
|
||||||
logging.info(res)
|
|
||||||
assert 'ok' in res.result
|
assert 'ok' in res.result
|
||||||
assert res.result['ok'] == 0
|
assert res.result['ok'] == 0
|
||||||
|
|
||||||
|
@ -198,6 +200,7 @@ async def test_dgpu_rotation_next_worker_disconnect(dgpu_workers):
|
||||||
next_worker rotation happens correctly
|
next_worker rotation happens correctly
|
||||||
'''
|
'''
|
||||||
async with open_skynet_rpc(
|
async with open_skynet_rpc(
|
||||||
|
'test-ctx',
|
||||||
security=True,
|
security=True,
|
||||||
cert_name='whitelist/testing',
|
cert_name='whitelist/testing',
|
||||||
key_name='testing'
|
key_name='testing'
|
||||||
|
@ -213,8 +216,7 @@ async def test_dgpu_rotation_next_worker_disconnect(dgpu_workers):
|
||||||
|
|
||||||
dgpu_workers[0].wait()
|
dgpu_workers[0].wait()
|
||||||
|
|
||||||
res = await test_rpc('testing-rpc', 'dgpu_workers')
|
res = await test_rpc('dgpu_workers')
|
||||||
logging.info(res)
|
|
||||||
assert 'ok' in res.result
|
assert 'ok' in res.result
|
||||||
assert res.result['ok'] == 2
|
assert res.result['ok'] == 2
|
||||||
|
|
||||||
|
@ -224,14 +226,17 @@ async def test_dgpu_rotation_next_worker_disconnect(dgpu_workers):
|
||||||
|
|
||||||
|
|
||||||
async def test_dgpu_no_ack_node_disconnect(skynet_running):
|
async def test_dgpu_no_ack_node_disconnect(skynet_running):
|
||||||
|
'''Mock a node that connects, gets a request but fails to
|
||||||
|
acknowledge it, then check skynet correctly drops the node
|
||||||
|
'''
|
||||||
async with open_skynet_rpc(
|
async with open_skynet_rpc(
|
||||||
|
'test-ctx',
|
||||||
security=True,
|
security=True,
|
||||||
cert_name='whitelist/testing',
|
cert_name='whitelist/testing',
|
||||||
key_name='testing'
|
key_name='testing'
|
||||||
) as rpc_call:
|
) as rpc_call:
|
||||||
|
|
||||||
res = await rpc_call('dgpu-0', 'dgpu_online')
|
res = await rpc_call('dgpu_online')
|
||||||
logging.info(res)
|
|
||||||
assert 'ok' in res.result
|
assert 'ok' in res.result
|
||||||
|
|
||||||
await wait_for_dgpus(rpc_call, 1)
|
await wait_for_dgpus(rpc_call, 1)
|
||||||
|
@ -241,8 +246,34 @@ async def test_dgpu_no_ack_node_disconnect(skynet_running):
|
||||||
|
|
||||||
assert 'dgpu failed to acknowledge request' in str(e)
|
assert 'dgpu failed to acknowledge request' in str(e)
|
||||||
|
|
||||||
res = await rpc_call('testing-rpc', 'dgpu_workers')
|
res = await rpc_call('dgpu_workers')
|
||||||
logging.info(res)
|
|
||||||
assert 'ok' in res.result
|
assert 'ok' in res.result
|
||||||
assert res.result['ok'] == 0
|
assert res.result['ok'] == 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
'dgpu_workers', [(1, ['midj'])], indirect=True)
|
||||||
|
async def test_dgpu_timeout_while_processing(dgpu_workers):
|
||||||
|
'''Stop node while processing request to cause timeout and
|
||||||
|
then check skynet correctly drops the node.
|
||||||
|
'''
|
||||||
|
async with open_skynet_rpc(
|
||||||
|
'test-ctx',
|
||||||
|
security=True,
|
||||||
|
cert_name='whitelist/testing',
|
||||||
|
key_name='testing'
|
||||||
|
) as test_rpc:
|
||||||
|
await wait_for_dgpus(test_rpc, 1)
|
||||||
|
|
||||||
|
async def check_request_img_raises():
|
||||||
|
with pytest.raises(SkynetDGPUComputeError) as e:
|
||||||
|
await check_request_img(0)
|
||||||
|
|
||||||
|
assert 'timeout while processing request' in str(e)
|
||||||
|
|
||||||
|
async with trio.open_nursery() as n:
|
||||||
|
n.start_soon(check_request_img_raises)
|
||||||
|
await trio.sleep(1)
|
||||||
|
ec, out = dgpu_workers[0].exec_run(
|
||||||
|
['pkill', '-TERM', '-f', 'skynet'])
|
||||||
|
assert ec == 0
|
||||||
|
|
|
@ -12,64 +12,59 @@ from skynet.structs import *
|
||||||
from skynet.frontend import open_skynet_rpc
|
from skynet.frontend import open_skynet_rpc
|
||||||
|
|
||||||
|
|
||||||
|
async def test_skynet(skynet_running):
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
async def test_skynet_attempt_insecure(skynet_running):
|
async def test_skynet_attempt_insecure(skynet_running):
|
||||||
with pytest.raises(pynng.exceptions.NNGException) as e:
|
with pytest.raises(pynng.exceptions.NNGException) as e:
|
||||||
async with open_skynet_rpc():
|
async with open_skynet_rpc('bad-actor'):
|
||||||
...
|
...
|
||||||
|
|
||||||
assert str(e.value) == 'Connection shutdown'
|
assert str(e.value) == 'Connection shutdown'
|
||||||
|
|
||||||
|
|
||||||
async def test_skynet_dgpu_connection_simple(skynet_running):
|
async def test_skynet_dgpu_connection_simple(skynet_running):
|
||||||
async with open_skynet_rpc(
|
async with open_skynet_rpc(
|
||||||
|
'dgpu-0',
|
||||||
security=True,
|
security=True,
|
||||||
cert_name='whitelist/testing',
|
cert_name='whitelist/testing',
|
||||||
key_name='testing'
|
key_name='testing'
|
||||||
) as rpc_call:
|
) as rpc_call:
|
||||||
# check 0 nodes are connected
|
# check 0 nodes are connected
|
||||||
res = await rpc_call('dgpu-0', 'dgpu_workers')
|
res = await rpc_call('dgpu_workers')
|
||||||
logging.info(res)
|
|
||||||
assert 'ok' in res.result
|
assert 'ok' in res.result
|
||||||
assert res.result['ok'] == 0
|
assert res.result['ok'] == 0
|
||||||
|
|
||||||
# check next worker is None
|
# check next worker is None
|
||||||
res = await rpc_call('dgpu-0', 'dgpu_next')
|
res = await rpc_call('dgpu_next')
|
||||||
logging.info(res)
|
|
||||||
assert 'ok' in res.result
|
assert 'ok' in res.result
|
||||||
assert res.result['ok'] == None
|
assert res.result['ok'] == None
|
||||||
|
|
||||||
# connect 1 dgpu
|
# connect 1 dgpu
|
||||||
res = await rpc_call(
|
res = await rpc_call('dgpu_online')
|
||||||
'dgpu-0', 'dgpu_online')
|
|
||||||
logging.info(res)
|
|
||||||
assert 'ok' in res.result
|
assert 'ok' in res.result
|
||||||
|
|
||||||
# check 1 node is connected
|
# check 1 node is connected
|
||||||
res = await rpc_call('dgpu-0', 'dgpu_workers')
|
res = await rpc_call('dgpu_workers')
|
||||||
logging.info(res)
|
|
||||||
assert 'ok' in res.result
|
assert 'ok' in res.result
|
||||||
assert res.result['ok'] == 1
|
assert res.result['ok'] == 1
|
||||||
|
|
||||||
# check next worker is 0
|
# check next worker is 0
|
||||||
res = await rpc_call('dgpu-0', 'dgpu_next')
|
res = await rpc_call('dgpu_next')
|
||||||
logging.info(res)
|
|
||||||
assert 'ok' in res.result
|
assert 'ok' in res.result
|
||||||
assert res.result['ok'] == 0
|
assert res.result['ok'] == 0
|
||||||
|
|
||||||
# disconnect 1 dgpu
|
# disconnect 1 dgpu
|
||||||
res = await rpc_call(
|
res = await rpc_call('dgpu_offline')
|
||||||
'dgpu-0', 'dgpu_offline')
|
|
||||||
logging.info(res)
|
|
||||||
assert 'ok' in res.result
|
assert 'ok' in res.result
|
||||||
|
|
||||||
# check 0 nodes are connected
|
# check 0 nodes are connected
|
||||||
res = await rpc_call('dgpu-0', 'dgpu_workers')
|
res = await rpc_call('dgpu_workers')
|
||||||
logging.info(res)
|
|
||||||
assert 'ok' in res.result
|
assert 'ok' in res.result
|
||||||
assert res.result['ok'] == 0
|
assert res.result['ok'] == 0
|
||||||
|
|
||||||
# check next worker is None
|
# check next worker is None
|
||||||
res = await rpc_call('dgpu-0', 'dgpu_next')
|
res = await rpc_call('dgpu_next')
|
||||||
logging.info(res)
|
|
||||||
assert 'ok' in res.result
|
assert 'ok' in res.result
|
||||||
assert res.result['ok'] == None
|
assert res.result['ok'] == None
|
||||||
|
|
Loading…
Reference in New Issue