mirror of https://github.com/skygpu/skynet.git
Add authenticated messaging, also cmd line utils txt2img and upscale
parent
f6326ad05c
commit
6bc555f0d6
2
setup.py
2
setup.py
|
@ -10,6 +10,8 @@ setup(
|
|||
entry_points={
|
||||
'console_scripts': [
|
||||
'skynet = skynet.cli:skynet',
|
||||
'txt2img = skynet.cli:txt2img',
|
||||
'upscale = skynet.cli:upscale'
|
||||
]
|
||||
},
|
||||
install_requires=['click']
|
||||
|
|
110
skynet/brain.py
110
skynet/brain.py
|
@ -16,6 +16,11 @@ import pynng
|
|||
import trio_asyncio
|
||||
|
||||
from pynng import TLSConfig
|
||||
from OpenSSL.crypto import (
|
||||
load_privatekey,
|
||||
load_certificate,
|
||||
FILETYPE_PEM
|
||||
)
|
||||
|
||||
from .db import *
|
||||
from .structs import *
|
||||
|
@ -34,12 +39,14 @@ class SkynetDGPUComputeError(BaseException):
|
|||
class SkynetShutdownRequested(BaseException):
|
||||
...
|
||||
|
||||
|
||||
@acm
|
||||
async def open_rpc_service(sock, dgpu_bus, db_pool):
|
||||
async def open_rpc_service(sock, dgpu_bus, db_pool, tls_whitelist, tls_key):
|
||||
nodes = OrderedDict()
|
||||
wip_reqs = {}
|
||||
fin_reqs = {}
|
||||
next_worker: Optional[int] = None
|
||||
security = len(tls_whitelist) > 0
|
||||
|
||||
def connect_node(uid):
|
||||
nonlocal next_worker
|
||||
|
@ -109,27 +116,20 @@ async def open_rpc_service(sock, dgpu_bus, db_pool):
|
|||
async def dgpu_image_streamer():
|
||||
nonlocal wip_reqs, fin_reqs
|
||||
while True:
|
||||
msg = await dgpu_bus.arecv_msg()
|
||||
rid = UUID(bytes=msg.bytes[:16]).hex
|
||||
raw_msg = msg.bytes[16:]
|
||||
logging.info(f'streamer got back {rid} of size {len(raw_msg)}')
|
||||
match raw_msg[:5]:
|
||||
case b'error':
|
||||
img = raw_msg.decode()
|
||||
msg = DGPUBusResponse(
|
||||
**json.loads(
|
||||
(await dgpu_bus.arecv()).decode()))
|
||||
|
||||
case b'ack':
|
||||
img = raw_msg
|
||||
if security:
|
||||
msg.verify(tls_whitelist[msg.cert])
|
||||
|
||||
case _:
|
||||
img = base64.b64encode(raw_msg).hex()
|
||||
|
||||
if rid not in wip_reqs:
|
||||
if msg.rid not in wip_reqs:
|
||||
continue
|
||||
|
||||
fin_reqs[rid] = img
|
||||
event = wip_reqs[rid]
|
||||
fin_reqs[msg.rid] = msg
|
||||
event = wip_reqs[msg.rid]
|
||||
event.set()
|
||||
del wip_reqs[rid]
|
||||
del wip_reqs[msg.rid]
|
||||
|
||||
async def dgpu_stream_one_img(req: ImageGenRequest):
|
||||
nonlocal wip_reqs, fin_reqs, next_worker
|
||||
|
@ -151,6 +151,9 @@ async def open_rpc_service(sock, dgpu_bus, db_pool):
|
|||
|
||||
logging.info(f'dgpu_bus req: {dgpu_req}')
|
||||
|
||||
if security:
|
||||
dgpu_req.sign(tls_key, 'skynet')
|
||||
|
||||
await dgpu_bus.asend(
|
||||
json.dumps(dgpu_req.to_dict()).encode())
|
||||
|
||||
|
@ -163,8 +166,8 @@ async def open_rpc_service(sock, dgpu_bus, db_pool):
|
|||
disconnect_node(nid)
|
||||
raise SkynetDGPUOffline('dgpu failed to acknowledge request')
|
||||
|
||||
ack = fin_reqs[rid]
|
||||
if ack != b'ack':
|
||||
ack_msg = fin_reqs[rid]
|
||||
if 'ack' not in ack_msg.params:
|
||||
disconnect_node(nid)
|
||||
raise SkynetDGPUOffline('dgpu failed to acknowledge request')
|
||||
|
||||
|
@ -178,15 +181,13 @@ async def open_rpc_service(sock, dgpu_bus, db_pool):
|
|||
|
||||
nodes[nid]['task'] = None
|
||||
|
||||
img = fin_reqs[rid]
|
||||
img_resp = fin_reqs[rid]
|
||||
del fin_reqs[rid]
|
||||
|
||||
logging.info(f'done streaming {len(img)} bytes')
|
||||
if 'error' in img_resp.params:
|
||||
raise SkynetDGPUComputeError(img_resp.params['error'])
|
||||
|
||||
if 'error' in img:
|
||||
raise SkynetDGPUComputeError(img)
|
||||
|
||||
return rid, img
|
||||
return rid, img_resp.params['img']
|
||||
|
||||
async def handle_user_request(rpc_ctx, req):
|
||||
try:
|
||||
|
@ -266,9 +267,13 @@ async def open_rpc_service(sock, dgpu_bus, db_pool):
|
|||
'message': str(e)
|
||||
}
|
||||
|
||||
resp = SkynetRPCResponse(result=result)
|
||||
|
||||
if security:
|
||||
resp.sign(tls_key, 'skynet')
|
||||
|
||||
await rpc_ctx.asend(
|
||||
json.dumps(
|
||||
SkynetRPCResponse(result=result).to_dict()).encode())
|
||||
json.dumps(resp.to_dict()).encode())
|
||||
|
||||
async def request_service(n):
|
||||
nonlocal next_worker
|
||||
|
@ -279,7 +284,19 @@ async def open_rpc_service(sock, dgpu_bus, db_pool):
|
|||
content = msg.bytes.decode()
|
||||
req = SkynetRPCRequest(**json.loads(content))
|
||||
|
||||
logging.info(req)
|
||||
if security:
|
||||
if req.cert not in tls_whitelist:
|
||||
logging.warning(
|
||||
f'{req.cert} not in tls whitelist and security=True')
|
||||
continue
|
||||
|
||||
try:
|
||||
req.verify(tls_whitelist[req.cert])
|
||||
|
||||
except ValueError:
|
||||
logging.warning(
|
||||
f'{req.cert} sent an unauthenticated msg with security=True')
|
||||
continue
|
||||
|
||||
result = {}
|
||||
|
||||
|
@ -303,10 +320,14 @@ async def open_rpc_service(sock, dgpu_bus, db_pool):
|
|||
handle_user_request, ctx, req)
|
||||
continue
|
||||
|
||||
resp = SkynetRPCResponse(
|
||||
result={'ok': result})
|
||||
|
||||
if security:
|
||||
resp.sign(tls_key, 'skynet')
|
||||
|
||||
await ctx.asend(
|
||||
json.dumps(
|
||||
SkynetRPCResponse(
|
||||
result={'ok': result}).to_dict()).encode())
|
||||
json.dumps(resp.to_dict()).encode())
|
||||
|
||||
|
||||
async with trio.open_nursery() as n:
|
||||
|
@ -334,22 +355,28 @@ async def run_skynet(
|
|||
if security:
|
||||
# load tls certs
|
||||
certs_dir = Path(DEFAULT_CERTS_DIR).resolve()
|
||||
tls_key = (certs_dir / DEFAULT_CERT_SKYNET_PRIV).read_text()
|
||||
tls_cert = (certs_dir / DEFAULT_CERT_SKYNET_PUB).read_text()
|
||||
tls_whitelist = [
|
||||
(cert_path).read_text()
|
||||
for cert_path in (certs_dir / 'whitelist').glob('*.cert')]
|
||||
|
||||
cert_start = tls_cert.index('\n') + 1
|
||||
logging.info(f'tls_cert: {tls_cert[cert_start:cert_start+64]}...')
|
||||
tls_key_data = (certs_dir / DEFAULT_CERT_SKYNET_PRIV).read_text()
|
||||
tls_key = load_privatekey(FILETYPE_PEM, tls_key_data)
|
||||
|
||||
tls_cert_data = (certs_dir / DEFAULT_CERT_SKYNET_PUB).read_text()
|
||||
tls_cert = load_certificate(FILETYPE_PEM, tls_cert_data)
|
||||
|
||||
tls_whitelist = {}
|
||||
for cert_path in (certs_dir / 'whitelist').glob('*.cert'):
|
||||
tls_whitelist[cert_path.stem] = load_certificate(
|
||||
FILETYPE_PEM, cert_path.read_text())
|
||||
|
||||
cert_start = tls_cert_data.index('\n') + 1
|
||||
logging.info(f'tls_cert: {tls_cert_data[cert_start:cert_start+64]}...')
|
||||
logging.info(f'tls_whitelist len: {len(tls_whitelist)}')
|
||||
|
||||
rpc_address = 'tls+' + rpc_address
|
||||
dgpu_address = 'tls+' + dgpu_address
|
||||
tls_config = TLSConfig(
|
||||
TLSConfig.MODE_SERVER,
|
||||
own_key_string=tls_key,
|
||||
own_cert_string=tls_cert)
|
||||
own_key_string=tls_key_data,
|
||||
own_cert_string=tls_cert_data)
|
||||
|
||||
with (
|
||||
pynng.Rep0() as rpc_sock,
|
||||
|
@ -367,7 +394,8 @@ async def run_skynet(
|
|||
dgpu_bus.listen(dgpu_address)
|
||||
|
||||
try:
|
||||
async with open_rpc_service(rpc_sock, dgpu_bus, db_pool):
|
||||
async with open_rpc_service(
|
||||
rpc_sock, dgpu_bus, db_pool, tls_whitelist, tls_key):
|
||||
yield
|
||||
|
||||
except SkynetShutdownRequested:
|
||||
|
|
|
@ -9,44 +9,78 @@ from functools import partial
|
|||
import trio
|
||||
import click
|
||||
|
||||
from . import utils
|
||||
from .dgpu import open_dgpu_node
|
||||
from .utils import txt2img
|
||||
from .constants import ALGOS
|
||||
from .brain import run_skynet
|
||||
|
||||
from .frontend.telegram import run_skynet_telegram
|
||||
|
||||
|
||||
@click.group()
|
||||
def skynet(*args, **kwargs):
|
||||
pass
|
||||
|
||||
@skynet.command()
|
||||
@click.option('--model', '-m', default=ALGOS['midj'])
|
||||
|
||||
@click.command()
|
||||
@click.option('--model', '-m', default='midj')
|
||||
@click.option(
|
||||
'--prompt', '-p', default='a red tractor in a wheat field')
|
||||
'--prompt', '-p', default='a red old tractor in a sunny wheat field')
|
||||
@click.option('--output', '-o', default='output.png')
|
||||
@click.option('--width', '-w', default=512)
|
||||
@click.option('--height', '-h', default=512)
|
||||
@click.option('--guidance', '-g', default=10.0)
|
||||
@click.option('--steps', '-s', default=26)
|
||||
@click.option('--seed', '-S', default=None)
|
||||
def txt2img(*args
|
||||
# model: str,
|
||||
# prompt: str,
|
||||
# output: str
|
||||
# width: int, height: int,
|
||||
# guidance: float,
|
||||
# steps: int,
|
||||
# seed: Optional[int]
|
||||
):
|
||||
def txt2img(*args, **kwargs):
|
||||
assert 'HF_TOKEN' in os.environ
|
||||
txt2img(
|
||||
os.environ['HF_TOKEN'], *args)
|
||||
utils.txt2img(os.environ['HF_TOKEN'], **kwargs)
|
||||
|
||||
@click.command()
|
||||
@click.option(
|
||||
'--prompt', '-p', default='a red old tractor in a sunny wheat field')
|
||||
@click.option('--input', '-i', default='input.png')
|
||||
@click.option('--output', '-o', default='output.png')
|
||||
@click.option('--steps', '-s', default=26)
|
||||
def upscale(prompt, input, output, steps):
|
||||
assert 'HF_TOKEN' in os.environ
|
||||
utils.upscale(
|
||||
os.environ['HF_TOKEN'],
|
||||
prompt=prompt,
|
||||
img_path=input,
|
||||
output=output,
|
||||
steps=steps)
|
||||
|
||||
|
||||
@skynet.group()
|
||||
def run(*args, **kwargs):
|
||||
pass
|
||||
|
||||
|
||||
@run.command()
|
||||
@click.option('--loglevel', '-l', default='warning', help='Logging level')
|
||||
@click.option(
|
||||
'--host', '-h', default='localhost:5432')
|
||||
@click.option(
|
||||
'--pass', '-p', default='password')
|
||||
def skynet(
|
||||
loglevel: str,
|
||||
host: str,
|
||||
passw: str
|
||||
):
|
||||
async def _run_skynet():
|
||||
async with run_skynet(
|
||||
db_host=host,
|
||||
db_pass=passw
|
||||
):
|
||||
await trio.sleep_forever()
|
||||
|
||||
trio_asyncio.run(_run_skynet)
|
||||
|
||||
|
||||
@run.command()
|
||||
@click.option('--loglevel', '-l', default='warning', help='Logging level')
|
||||
@click.option(
|
||||
'--uid', '-u', required=True)
|
||||
@click.option(
|
||||
'--key', '-k', default='dgpu')
|
||||
@click.option(
|
||||
|
@ -55,6 +89,7 @@ def run(*args, **kwargs):
|
|||
'--algos', '-a', default=None)
|
||||
def dgpu(
|
||||
loglevel: str,
|
||||
uid: str,
|
||||
key: str,
|
||||
cert: str,
|
||||
algos: Optional[str]
|
||||
|
@ -63,6 +98,28 @@ def dgpu(
|
|||
partial(
|
||||
open_dgpu_node,
|
||||
cert,
|
||||
uid,
|
||||
key_name=key,
|
||||
initial_algos=json.loads(algos)
|
||||
))
|
||||
|
||||
|
||||
@run.command()
|
||||
@click.option('--loglevel', '-l', default='warning', help='Logging level')
|
||||
@click.option(
|
||||
'--key', '-k', default='telegram-frontend')
|
||||
@click.option(
|
||||
'--cert', '-c', default='whitelist/telegram-frontend')
|
||||
def telegram(
|
||||
loglevel: str,
|
||||
key: str,
|
||||
cert: str
|
||||
):
|
||||
assert 'TG_TOKEN' in os.environ
|
||||
trio_asyncio.run(
|
||||
partial(
|
||||
run_skynet_telegram,
|
||||
os.environ['TG_TOKEN'],
|
||||
key_name=key,
|
||||
cert_name=cert
|
||||
))
|
||||
|
|
45
skynet/db.py
45
skynet/db.py
|
@ -58,13 +58,16 @@ ALTER TABLE skynet.user_config
|
|||
|
||||
|
||||
def try_decode_uid(uid: str):
|
||||
if isinstance(uid, int):
|
||||
return None, uid
|
||||
|
||||
try:
|
||||
proto, uid = uid.split('+')
|
||||
uid = int(uid)
|
||||
return proto, uid
|
||||
|
||||
except ValueError:
|
||||
logging.warning(f'got non numeric uid?: {uid}')
|
||||
logging.warning(f'got non chat proto uid?: {uid}')
|
||||
return None, None
|
||||
|
||||
|
||||
|
@ -132,28 +135,38 @@ async def new_user(conn, uid: str):
|
|||
|
||||
logging.info(f'new user! {uid}')
|
||||
|
||||
tg_id = None
|
||||
date = datetime.utcnow()
|
||||
|
||||
proto, pid = try_decode_uid(uid)
|
||||
|
||||
match proto:
|
||||
case 'tg':
|
||||
tg_id = pid
|
||||
|
||||
async with conn.transaction():
|
||||
stmt = await conn.prepare('''
|
||||
INSERT INTO skynet.user(
|
||||
tg_id, generated, joined, last_prompt, role)
|
||||
match proto:
|
||||
case 'tg':
|
||||
tg_id = pid
|
||||
stmt = await conn.prepare('''
|
||||
INSERT INTO skynet.user(
|
||||
tg_id, generated, joined, last_prompt, role)
|
||||
|
||||
VALUES($1, $2, $3, $4, $5)
|
||||
ON CONFLICT DO NOTHING
|
||||
''')
|
||||
await stmt.fetch(
|
||||
tg_id, 0, date, None, DEFAULT_ROLE
|
||||
)
|
||||
VALUES($1, $2, $3, $4, $5)
|
||||
ON CONFLICT DO NOTHING
|
||||
''')
|
||||
await stmt.fetch(
|
||||
tg_id, 0, date, None, DEFAULT_ROLE
|
||||
)
|
||||
new_uid = await get_user(conn, uid)
|
||||
|
||||
new_uid = await get_user(conn, uid)
|
||||
case None:
|
||||
stmt = await conn.prepare('''
|
||||
INSERT INTO skynet.user(
|
||||
id, generated, joined, last_prompt, role)
|
||||
|
||||
VALUES($1, $2, $3, $4, $5)
|
||||
ON CONFLICT DO NOTHING
|
||||
''')
|
||||
await stmt.fetch(
|
||||
pid, 0, date, None, DEFAULT_ROLE
|
||||
)
|
||||
new_uid = pid
|
||||
|
||||
stmt = await conn.prepare('''
|
||||
INSERT INTO skynet.user_config(
|
||||
|
|
|
@ -5,6 +5,7 @@ import io
|
|||
import trio
|
||||
import json
|
||||
import uuid
|
||||
import base64
|
||||
import random
|
||||
import logging
|
||||
|
||||
|
@ -16,10 +17,16 @@ import pynng
|
|||
import torch
|
||||
|
||||
from pynng import TLSConfig
|
||||
from OpenSSL.crypto import (
|
||||
load_privatekey,
|
||||
load_certificate,
|
||||
FILETYPE_PEM
|
||||
)
|
||||
from diffusers import (
|
||||
StableDiffusionPipeline,
|
||||
EulerAncestralDiscreteScheduler
|
||||
)
|
||||
from diffusers.models import UNet2DConditionModel
|
||||
|
||||
from .structs import *
|
||||
from .constants import *
|
||||
|
@ -58,6 +65,7 @@ class DGPUComputeError(BaseException):
|
|||
|
||||
async def open_dgpu_node(
|
||||
cert_name: str,
|
||||
unique_id: str,
|
||||
key_name: Optional[str],
|
||||
rpc_address: str = DEFAULT_RPC_ADDR,
|
||||
dgpu_address: str = DEFAULT_DGPU_ADDR,
|
||||
|
@ -122,6 +130,7 @@ async def open_dgpu_node(
|
|||
|
||||
|
||||
async with open_skynet_rpc(
|
||||
unique_id,
|
||||
security=security,
|
||||
cert_name=cert_name,
|
||||
key_name=key_name
|
||||
|
@ -131,16 +140,23 @@ async def open_dgpu_node(
|
|||
if security:
|
||||
# load tls certs
|
||||
if not key_name:
|
||||
key_name = certs_name
|
||||
key_name = cert_name
|
||||
|
||||
certs_dir = Path(DEFAULT_CERTS_DIR).resolve()
|
||||
|
||||
skynet_cert_path = certs_dir / 'brain.cert'
|
||||
tls_cert_path = certs_dir / f'{cert_name}.cert'
|
||||
tls_key_path = certs_dir / f'{key_name}.key'
|
||||
|
||||
skynet_cert = skynet_cert_path.read_text()
|
||||
tls_cert = tls_cert_path.read_text()
|
||||
tls_key = tls_key_path.read_text()
|
||||
cert_name = tls_cert_path.stem
|
||||
|
||||
skynet_cert_data = skynet_cert_path.read_text()
|
||||
skynet_cert = load_certificate(FILETYPE_PEM, skynet_cert_data)
|
||||
|
||||
tls_cert_data = tls_cert_path.read_text()
|
||||
|
||||
tls_key_data = tls_key_path.read_text()
|
||||
tls_key = load_privatekey(FILETYPE_PEM, tls_key_data)
|
||||
|
||||
logging.info(f'skynet cert: {skynet_cert_path}')
|
||||
logging.info(f'dgpu cert: {tls_cert_path}')
|
||||
|
@ -149,17 +165,16 @@ async def open_dgpu_node(
|
|||
dgpu_address = 'tls+' + dgpu_address
|
||||
tls_config = TLSConfig(
|
||||
TLSConfig.MODE_CLIENT,
|
||||
own_key_string=tls_key,
|
||||
own_cert_string=tls_cert,
|
||||
ca_string=skynet_cert)
|
||||
own_key_string=tls_key_data,
|
||||
own_cert_string=tls_cert_data,
|
||||
ca_string=skynet_cert_data)
|
||||
|
||||
logging.info(f'connecting to {dgpu_address}')
|
||||
with pynng.Bus0() as dgpu_sock:
|
||||
dgpu_sock.tls_config = tls_config
|
||||
dgpu_sock.dial(dgpu_address)
|
||||
|
||||
res = await rpc_call(name.hex, 'dgpu_online')
|
||||
logging.info(res)
|
||||
res = await rpc_call('dgpu_online')
|
||||
assert 'ok' in res.result
|
||||
|
||||
try:
|
||||
|
@ -168,30 +183,55 @@ async def open_dgpu_node(
|
|||
req = DGPUBusRequest(
|
||||
**json.loads(msg.decode()))
|
||||
|
||||
if req.nid != name.hex:
|
||||
logging.info('witnessed request {req.rid}, for {req.nid}')
|
||||
if req.nid != unique_id:
|
||||
logging.info(
|
||||
f'witnessed msg {req.rid}, node involved: {req.nid}')
|
||||
continue
|
||||
|
||||
if security:
|
||||
req.verify(skynet_cert)
|
||||
|
||||
ack_resp = DGPUBusResponse(
|
||||
rid=req.rid,
|
||||
nid=req.nid,
|
||||
params={'ack': {}}
|
||||
)
|
||||
|
||||
if security:
|
||||
ack_resp.sign(tls_key, cert_name)
|
||||
|
||||
# send ack
|
||||
await dgpu_sock.asend(
|
||||
bytes.fromhex(req.rid) + b'ack')
|
||||
json.dumps(ack_resp.to_dict()).encode())
|
||||
|
||||
logging.info(f'sent ack, processing {req.rid}...')
|
||||
|
||||
try:
|
||||
img = await gpu_compute_one(
|
||||
ImageGenRequest(**req.params))
|
||||
img_resp = DGPUBusResponse(
|
||||
rid=req.rid,
|
||||
nid=req.nid,
|
||||
params={'img': base64.b64encode(img).hex()}
|
||||
)
|
||||
|
||||
except DGPUComputeError as e:
|
||||
img = b'error' + str(e).encode()
|
||||
img_resp = DGPUBusResponse(
|
||||
rid=req.rid,
|
||||
nid=req.nid,
|
||||
params={'error': str(e)}
|
||||
)
|
||||
|
||||
if security:
|
||||
img_resp.sign(tls_key, cert_name)
|
||||
|
||||
# send final image
|
||||
await dgpu_sock.asend(
|
||||
bytes.fromhex(req.rid) + img)
|
||||
json.dumps(img_resp.to_dict()).encode())
|
||||
|
||||
except KeyboardInterrupt:
|
||||
logging.info('interrupt caught, stopping...')
|
||||
|
||||
finally:
|
||||
res = await rpc_call(name.hex, 'dgpu_offline')
|
||||
logging.info(res)
|
||||
res = await rpc_call('dgpu_offline')
|
||||
assert 'ok' in res.result
|
||||
|
|
|
@ -9,6 +9,11 @@ from contextlib import asynccontextmanager as acm
|
|||
import pynng
|
||||
|
||||
from pynng import TLSConfig
|
||||
from OpenSSL.crypto import (
|
||||
load_privatekey,
|
||||
load_certificate,
|
||||
FILETYPE_PEM
|
||||
)
|
||||
|
||||
from ..structs import SkynetRPCRequest, SkynetRPCResponse
|
||||
from ..constants import *
|
||||
|
@ -30,56 +35,73 @@ class ConfigSizeDivisionByEight(BaseException):
|
|||
...
|
||||
|
||||
|
||||
async def rpc_call(
|
||||
sock,
|
||||
uid: Union[int, str],
|
||||
method: str,
|
||||
params: dict = {}
|
||||
):
|
||||
req = SkynetRPCRequest(
|
||||
uid=uid,
|
||||
method=method,
|
||||
params=params
|
||||
)
|
||||
await sock.asend(
|
||||
json.dumps(
|
||||
req.to_dict()).encode())
|
||||
|
||||
return SkynetRPCResponse(
|
||||
**json.loads(
|
||||
(await sock.arecv_msg()).bytes.decode()))
|
||||
|
||||
|
||||
@acm
|
||||
async def open_skynet_rpc(
|
||||
unique_id: str,
|
||||
rpc_address: str = DEFAULT_RPC_ADDR,
|
||||
security: bool = False,
|
||||
cert_name: Optional[str] = None,
|
||||
key_name: Optional[str] = None
|
||||
):
|
||||
tls_config = None
|
||||
|
||||
if security:
|
||||
# load tls certs
|
||||
if not key_name:
|
||||
key_name = certs_name
|
||||
key_name = cert_name
|
||||
|
||||
certs_dir = Path(DEFAULT_CERTS_DIR).resolve()
|
||||
skynet_cert = (certs_dir / 'brain.cert').read_text()
|
||||
tls_cert = (certs_dir / f'{cert_name}.cert').read_text()
|
||||
tls_key = (certs_dir / f'{key_name}.key').read_text()
|
||||
|
||||
skynet_cert_data = (certs_dir / 'brain.cert').read_text()
|
||||
skynet_cert = load_certificate(FILETYPE_PEM, skynet_cert_data)
|
||||
|
||||
tls_cert_path = certs_dir / f'{cert_name}.cert'
|
||||
tls_cert_data = tls_cert_path.read_text()
|
||||
tls_cert = load_certificate(FILETYPE_PEM, tls_cert_data)
|
||||
cert_name = tls_cert_path.stem
|
||||
|
||||
tls_key_data = (certs_dir / f'{key_name}.key').read_text()
|
||||
tls_key = load_privatekey(FILETYPE_PEM, tls_key_data)
|
||||
|
||||
rpc_address = 'tls+' + rpc_address
|
||||
tls_config = TLSConfig(
|
||||
TLSConfig.MODE_CLIENT,
|
||||
own_key_string=tls_key,
|
||||
own_cert_string=tls_cert,
|
||||
ca_string=skynet_cert)
|
||||
own_key_string=tls_key_data,
|
||||
own_cert_string=tls_cert_data,
|
||||
ca_string=skynet_cert_data)
|
||||
|
||||
with pynng.Req0() as sock:
|
||||
if security:
|
||||
sock.tls_config = tls_config
|
||||
|
||||
sock.dial(rpc_address)
|
||||
async def _rpc_call(*args, **kwargs):
|
||||
return await rpc_call(sock, *args, **kwargs)
|
||||
|
||||
async def _rpc_call(
|
||||
method: str,
|
||||
params: dict = {},
|
||||
uid: Optional[Union[int, str]] = None
|
||||
):
|
||||
req = SkynetRPCRequest(
|
||||
uid=uid if uid else unique_id,
|
||||
method=method,
|
||||
params=params
|
||||
)
|
||||
|
||||
if security:
|
||||
req.sign(tls_key, cert_name)
|
||||
|
||||
await sock.asend(
|
||||
json.dumps(
|
||||
req.to_dict()).encode())
|
||||
|
||||
resp = SkynetRPCResponse(
|
||||
**json.loads(
|
||||
(await sock.arecv_msg()).bytes.decode()))
|
||||
|
||||
if security:
|
||||
resp.verify(skynet_cert)
|
||||
|
||||
return resp
|
||||
|
||||
yield _rpc_call
|
||||
|
||||
|
|
|
@ -18,14 +18,19 @@ PREFIX = 'tg'
|
|||
|
||||
|
||||
async def run_skynet_telegram(
|
||||
tg_token: str
|
||||
tg_token: str,
|
||||
key_name: str = 'telegram-frontend',
|
||||
cert_name: str = 'whitelist/telegram-frontend'
|
||||
):
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
bot = AsyncTeleBot(tg_token)
|
||||
|
||||
with open_skynet_rpc(
|
||||
security=True, cert_name='telegram-frontend'
|
||||
'skynet-telegram-0',
|
||||
security=True,
|
||||
cert_name=cert,
|
||||
key_name=key
|
||||
) as rpc_call:
|
||||
|
||||
async def _rpc_call(
|
||||
|
@ -33,7 +38,8 @@ async def run_skynet_telegram(
|
|||
method: str,
|
||||
params: dict
|
||||
):
|
||||
return await rpc_call(f'{PREFIX}+{uid}', method, params)
|
||||
return await rpc_call(
|
||||
method, params, uid=f'{PREFIX}+{uid}')
|
||||
|
||||
@bot.message_handler(commands=['help'])
|
||||
async def send_help(message):
|
||||
|
|
|
@ -18,6 +18,8 @@
|
|||
Built-in (extension) types.
|
||||
"""
|
||||
import sys
|
||||
import json
|
||||
|
||||
from typing import Optional, Union
|
||||
from pprint import pformat
|
||||
|
||||
|
@ -83,15 +85,51 @@ class Struct(
|
|||
setattr(self, fname, ftype(getattr(self, fname)))
|
||||
|
||||
# proto
|
||||
from OpenSSL.crypto import PKey, X509, verify, sign
|
||||
|
||||
class SkynetRPCRequest(Struct):
|
||||
|
||||
class AuthenticatedStruct(Struct):
|
||||
cert: Optional[str] = None
|
||||
sig: Optional[str] = None
|
||||
|
||||
def to_unsigned_dict(self) -> dict:
|
||||
self_dict = self.to_dict()
|
||||
|
||||
if 'sig' in self_dict:
|
||||
del self_dict['sig']
|
||||
|
||||
if 'cert' in self_dict:
|
||||
del self_dict['cert']
|
||||
|
||||
return self_dict
|
||||
|
||||
def unsigned_to_bytes(self) -> bytes:
|
||||
return json.dumps(
|
||||
self.to_unsigned_dict()).encode()
|
||||
|
||||
def sign(self, key: PKey, cert: str):
|
||||
self.cert = cert
|
||||
self.sig = sign(
|
||||
key, self.unsigned_to_bytes(), 'sha256').hex()
|
||||
|
||||
def verify(self, cert: X509):
|
||||
if not self.sig:
|
||||
raise ValueError('Tried to verify unsigned request')
|
||||
|
||||
return verify(
|
||||
cert, bytes.fromhex(self.sig), self.unsigned_to_bytes(), 'sha256')
|
||||
|
||||
|
||||
class SkynetRPCRequest(AuthenticatedStruct):
|
||||
uid: Union[str, int] # user unique id
|
||||
method: str # rpc method name
|
||||
params: dict # variable params
|
||||
|
||||
class SkynetRPCResponse(Struct):
|
||||
|
||||
class SkynetRPCResponse(AuthenticatedStruct):
|
||||
result: dict
|
||||
|
||||
|
||||
class ImageGenRequest(Struct):
|
||||
prompt: str
|
||||
step: int
|
||||
|
@ -102,8 +140,15 @@ class ImageGenRequest(Struct):
|
|||
algo: str
|
||||
upscaler: Optional[str]
|
||||
|
||||
class DGPUBusRequest(Struct):
|
||||
|
||||
class DGPUBusRequest(AuthenticatedStruct):
|
||||
rid: str # req id
|
||||
nid: str # node id
|
||||
task: str
|
||||
params: dict
|
||||
|
||||
|
||||
class DGPUBusResponse(AuthenticatedStruct):
|
||||
rid: str # req id
|
||||
nid: str # node id
|
||||
params: dict
|
||||
|
|
|
@ -7,44 +7,37 @@ from pathlib import Path
|
|||
|
||||
import torch
|
||||
|
||||
from diffusers import StableDiffusionPipeline
|
||||
from PIL import Image
|
||||
from diffusers import (
|
||||
StableDiffusionPipeline,
|
||||
StableDiffusionUpscalePipeline,
|
||||
EulerAncestralDiscreteScheduler
|
||||
)
|
||||
from huggingface_hub import login
|
||||
|
||||
from .dgpu import pipeline_for
|
||||
|
||||
|
||||
def txt2img(
|
||||
hf_token: str,
|
||||
model_name: str,
|
||||
prompt: str,
|
||||
output: str,
|
||||
width: int, height: int,
|
||||
guidance: float,
|
||||
steps: int,
|
||||
seed: Optional[int]
|
||||
model: str = 'midj',
|
||||
prompt: str = 'a red old tractor in a sunny wheat field',
|
||||
output: str = 'output.png',
|
||||
width: int = 512, height: int = 512,
|
||||
guidance: float = 10,
|
||||
steps: int = 28,
|
||||
seed: Optional[int] = None
|
||||
):
|
||||
assert torch.cuda.is_available()
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.set_per_process_memory_fraction(0.333)
|
||||
torch.cuda.set_per_process_memory_fraction(1.0)
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
torch.backends.cudnn.allow_tf32 = True
|
||||
|
||||
login(token=hf_token)
|
||||
pipe = pipeline_for(model)
|
||||
|
||||
params = {
|
||||
'torch_dtype': torch.float16,
|
||||
'safety_checker': None
|
||||
}
|
||||
if model_name == 'runwayml/stable-diffusion-v1-5':
|
||||
params['revision'] = 'fp16'
|
||||
|
||||
pipe = StableDiffusionPipeline.from_pretrained(
|
||||
model_name, **params)
|
||||
|
||||
pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(
|
||||
pipe.scheduler.config)
|
||||
|
||||
pipe = pipe.to("cuda")
|
||||
|
||||
seed = ireq.seed if ireq.seed else random.randint(0, 2 ** 64)
|
||||
seed = seed if seed else random.randint(0, 2 ** 64)
|
||||
prompt = prompt
|
||||
image = pipe(
|
||||
prompt,
|
||||
|
@ -55,3 +48,34 @@ def txt2img(
|
|||
).images[0]
|
||||
|
||||
image.save(output)
|
||||
|
||||
|
||||
def upscale(
|
||||
hf_token: str,
|
||||
prompt: str = 'a red old tractor in a sunny wheat field',
|
||||
img_path: str = 'input.png',
|
||||
output: str = 'output.png',
|
||||
steps: int = 28
|
||||
):
|
||||
assert torch.cuda.is_available()
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.set_per_process_memory_fraction(1.0)
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
torch.backends.cudnn.allow_tf32 = True
|
||||
|
||||
login(token=hf_token)
|
||||
params = {
|
||||
'torch_dtype': torch.float16,
|
||||
'safety_checker': None
|
||||
}
|
||||
|
||||
pipe = StableDiffusionUpscalePipeline.from_pretrained(
|
||||
'stabilityai/stable-diffusion-x4-upscaler', **params)
|
||||
|
||||
prompt = prompt
|
||||
image = pipe(
|
||||
prompt,
|
||||
image=Image.open(img_path)
|
||||
).images[0]
|
||||
|
||||
image.save(output)
|
||||
|
|
|
@ -97,17 +97,22 @@ def dgpu_workers(request, dockerctl, skynet_running):
|
|||
|
||||
num_containers, initial_algos = request.param
|
||||
|
||||
cmd = f'''
|
||||
pip install -e . && \
|
||||
skynet run dgpu --algos=\'{json.dumps(initial_algos)}\'
|
||||
'''
|
||||
cmds = []
|
||||
for i in range(num_containers):
|
||||
cmd = f'''
|
||||
pip install -e . && \
|
||||
skynet run dgpu \
|
||||
--algos=\'{json.dumps(initial_algos)}\' \
|
||||
--uid=dgpu-{i}
|
||||
'''
|
||||
cmds.append(['bash', '-c', cmd])
|
||||
|
||||
logging.info(f'launching: \n{cmd}')
|
||||
|
||||
with dockerctl.run(
|
||||
DOCKER_RUNTIME_CUDA,
|
||||
name='skynet-test-runtime-cuda',
|
||||
command=['bash', '-c', cmd],
|
||||
commands=cmds,
|
||||
environment={
|
||||
'HF_TOKEN': os.environ['HF_TOKEN'],
|
||||
'HF_HOME': '/skynet/hf_home'
|
||||
|
|
|
@ -26,8 +26,7 @@ async def wait_for_dgpus(rpc, amount: int, timeout: float = 30.0):
|
|||
start_time = time.time()
|
||||
current_time = time.time()
|
||||
while not gpu_ready and (current_time - start_time) < timeout:
|
||||
res = await rpc('dgpu-test', 'dgpu_workers')
|
||||
logging.info(res)
|
||||
res = await rpc('dgpu_workers')
|
||||
if res.result['ok'] >= amount:
|
||||
break
|
||||
|
||||
|
@ -40,6 +39,7 @@ async def wait_for_dgpus(rpc, amount: int, timeout: float = 30.0):
|
|||
_images = set()
|
||||
async def check_request_img(
|
||||
i: int,
|
||||
uid: int = 0,
|
||||
width: int = 512,
|
||||
height: int = 512,
|
||||
expect_unique=True
|
||||
|
@ -47,12 +47,13 @@ async def check_request_img(
|
|||
global _images
|
||||
|
||||
async with open_skynet_rpc(
|
||||
uid,
|
||||
security=True,
|
||||
cert_name='whitelist/testing',
|
||||
key_name='testing'
|
||||
) as rpc_call:
|
||||
res = await rpc_call(
|
||||
'tg+580213293', 'txt2img', {
|
||||
'txt2img', {
|
||||
'prompt': 'red old tractor in a sunny wheat field',
|
||||
'step': 28,
|
||||
'width': width, 'height': height,
|
||||
|
@ -88,6 +89,7 @@ async def test_dgpu_worker_compute_error(dgpu_workers):
|
|||
'''
|
||||
|
||||
async with open_skynet_rpc(
|
||||
'test-ctx',
|
||||
security=True,
|
||||
cert_name='whitelist/testing',
|
||||
key_name='testing'
|
||||
|
@ -110,6 +112,7 @@ async def test_dgpu_workers(dgpu_workers):
|
|||
'''
|
||||
|
||||
async with open_skynet_rpc(
|
||||
'test-ctx',
|
||||
security=True,
|
||||
cert_name='whitelist/testing',
|
||||
key_name='testing'
|
||||
|
@ -126,6 +129,7 @@ async def test_dgpu_workers_two(dgpu_workers):
|
|||
'''Generate two images in two separate dgpu workers
|
||||
'''
|
||||
async with open_skynet_rpc(
|
||||
'test-ctx',
|
||||
security=True,
|
||||
cert_name='whitelist/testing',
|
||||
key_name='testing'
|
||||
|
@ -143,6 +147,7 @@ async def test_dgpu_worker_algo_swap(dgpu_workers):
|
|||
'''Generate an image using a non default model
|
||||
'''
|
||||
async with open_skynet_rpc(
|
||||
'test-ctx',
|
||||
security=True,
|
||||
cert_name='whitelist/testing',
|
||||
key_name='testing'
|
||||
|
@ -158,35 +163,32 @@ async def test_dgpu_rotation_next_worker(dgpu_workers):
|
|||
rotation happens correctly
|
||||
'''
|
||||
async with open_skynet_rpc(
|
||||
'test-ctx',
|
||||
security=True,
|
||||
cert_name='whitelist/testing',
|
||||
key_name='testing'
|
||||
) as test_rpc:
|
||||
await wait_for_dgpus(test_rpc, 3)
|
||||
|
||||
res = await test_rpc('testing-rpc', 'dgpu_next')
|
||||
logging.info(res)
|
||||
res = await test_rpc('dgpu_next')
|
||||
assert 'ok' in res.result
|
||||
assert res.result['ok'] == 0
|
||||
|
||||
await check_request_img(0)
|
||||
|
||||
res = await test_rpc('testing-rpc', 'dgpu_next')
|
||||
logging.info(res)
|
||||
res = await test_rpc('dgpu_next')
|
||||
assert 'ok' in res.result
|
||||
assert res.result['ok'] == 1
|
||||
|
||||
await check_request_img(0)
|
||||
|
||||
res = await test_rpc('testing-rpc', 'dgpu_next')
|
||||
logging.info(res)
|
||||
res = await test_rpc('dgpu_next')
|
||||
assert 'ok' in res.result
|
||||
assert res.result['ok'] == 2
|
||||
|
||||
await check_request_img(0)
|
||||
|
||||
res = await test_rpc('testing-rpc', 'dgpu_next')
|
||||
logging.info(res)
|
||||
res = await test_rpc('dgpu_next')
|
||||
assert 'ok' in res.result
|
||||
assert res.result['ok'] == 0
|
||||
|
||||
|
@ -198,6 +200,7 @@ async def test_dgpu_rotation_next_worker_disconnect(dgpu_workers):
|
|||
next_worker rotation happens correctly
|
||||
'''
|
||||
async with open_skynet_rpc(
|
||||
'test-ctx',
|
||||
security=True,
|
||||
cert_name='whitelist/testing',
|
||||
key_name='testing'
|
||||
|
@ -213,8 +216,7 @@ async def test_dgpu_rotation_next_worker_disconnect(dgpu_workers):
|
|||
|
||||
dgpu_workers[0].wait()
|
||||
|
||||
res = await test_rpc('testing-rpc', 'dgpu_workers')
|
||||
logging.info(res)
|
||||
res = await test_rpc('dgpu_workers')
|
||||
assert 'ok' in res.result
|
||||
assert res.result['ok'] == 2
|
||||
|
||||
|
@ -224,14 +226,17 @@ async def test_dgpu_rotation_next_worker_disconnect(dgpu_workers):
|
|||
|
||||
|
||||
async def test_dgpu_no_ack_node_disconnect(skynet_running):
|
||||
'''Mock a node that connects, gets a request but fails to
|
||||
acknowledge it, then check skynet correctly drops the node
|
||||
'''
|
||||
async with open_skynet_rpc(
|
||||
'test-ctx',
|
||||
security=True,
|
||||
cert_name='whitelist/testing',
|
||||
key_name='testing'
|
||||
) as rpc_call:
|
||||
|
||||
res = await rpc_call('dgpu-0', 'dgpu_online')
|
||||
logging.info(res)
|
||||
res = await rpc_call('dgpu_online')
|
||||
assert 'ok' in res.result
|
||||
|
||||
await wait_for_dgpus(rpc_call, 1)
|
||||
|
@ -241,8 +246,34 @@ async def test_dgpu_no_ack_node_disconnect(skynet_running):
|
|||
|
||||
assert 'dgpu failed to acknowledge request' in str(e)
|
||||
|
||||
res = await rpc_call('testing-rpc', 'dgpu_workers')
|
||||
logging.info(res)
|
||||
res = await rpc_call('dgpu_workers')
|
||||
assert 'ok' in res.result
|
||||
assert res.result['ok'] == 0
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'dgpu_workers', [(1, ['midj'])], indirect=True)
|
||||
async def test_dgpu_timeout_while_processing(dgpu_workers):
|
||||
'''Stop node while processing request to cause timeout and
|
||||
then check skynet correctly drops the node.
|
||||
'''
|
||||
async with open_skynet_rpc(
|
||||
'test-ctx',
|
||||
security=True,
|
||||
cert_name='whitelist/testing',
|
||||
key_name='testing'
|
||||
) as test_rpc:
|
||||
await wait_for_dgpus(test_rpc, 1)
|
||||
|
||||
async def check_request_img_raises():
|
||||
with pytest.raises(SkynetDGPUComputeError) as e:
|
||||
await check_request_img(0)
|
||||
|
||||
assert 'timeout while processing request' in str(e)
|
||||
|
||||
async with trio.open_nursery() as n:
|
||||
n.start_soon(check_request_img_raises)
|
||||
await trio.sleep(1)
|
||||
ec, out = dgpu_workers[0].exec_run(
|
||||
['pkill', '-TERM', '-f', 'skynet'])
|
||||
assert ec == 0
|
||||
|
|
|
@ -12,64 +12,59 @@ from skynet.structs import *
|
|||
from skynet.frontend import open_skynet_rpc
|
||||
|
||||
|
||||
async def test_skynet(skynet_running):
|
||||
...
|
||||
|
||||
|
||||
async def test_skynet_attempt_insecure(skynet_running):
|
||||
with pytest.raises(pynng.exceptions.NNGException) as e:
|
||||
async with open_skynet_rpc():
|
||||
...
|
||||
async with open_skynet_rpc('bad-actor'):
|
||||
...
|
||||
|
||||
assert str(e.value) == 'Connection shutdown'
|
||||
|
||||
|
||||
async def test_skynet_dgpu_connection_simple(skynet_running):
|
||||
async with open_skynet_rpc(
|
||||
'dgpu-0',
|
||||
security=True,
|
||||
cert_name='whitelist/testing',
|
||||
key_name='testing'
|
||||
) as rpc_call:
|
||||
# check 0 nodes are connected
|
||||
res = await rpc_call('dgpu-0', 'dgpu_workers')
|
||||
logging.info(res)
|
||||
res = await rpc_call('dgpu_workers')
|
||||
assert 'ok' in res.result
|
||||
assert res.result['ok'] == 0
|
||||
|
||||
# check next worker is None
|
||||
res = await rpc_call('dgpu-0', 'dgpu_next')
|
||||
logging.info(res)
|
||||
res = await rpc_call('dgpu_next')
|
||||
assert 'ok' in res.result
|
||||
assert res.result['ok'] == None
|
||||
|
||||
# connect 1 dgpu
|
||||
res = await rpc_call(
|
||||
'dgpu-0', 'dgpu_online')
|
||||
logging.info(res)
|
||||
res = await rpc_call('dgpu_online')
|
||||
assert 'ok' in res.result
|
||||
|
||||
# check 1 node is connected
|
||||
res = await rpc_call('dgpu-0', 'dgpu_workers')
|
||||
logging.info(res)
|
||||
res = await rpc_call('dgpu_workers')
|
||||
assert 'ok' in res.result
|
||||
assert res.result['ok'] == 1
|
||||
|
||||
# check next worker is 0
|
||||
res = await rpc_call('dgpu-0', 'dgpu_next')
|
||||
logging.info(res)
|
||||
res = await rpc_call('dgpu_next')
|
||||
assert 'ok' in res.result
|
||||
assert res.result['ok'] == 0
|
||||
|
||||
# disconnect 1 dgpu
|
||||
res = await rpc_call(
|
||||
'dgpu-0', 'dgpu_offline')
|
||||
logging.info(res)
|
||||
res = await rpc_call('dgpu_offline')
|
||||
assert 'ok' in res.result
|
||||
|
||||
# check 0 nodes are connected
|
||||
res = await rpc_call('dgpu-0', 'dgpu_workers')
|
||||
logging.info(res)
|
||||
res = await rpc_call('dgpu_workers')
|
||||
assert 'ok' in res.result
|
||||
assert res.result['ok'] == 0
|
||||
|
||||
# check next worker is None
|
||||
res = await rpc_call('dgpu-0', 'dgpu_next')
|
||||
logging.info(res)
|
||||
res = await rpc_call('dgpu_next')
|
||||
assert 'ok' in res.result
|
||||
assert res.result['ok'] == None
|
||||
|
|
Loading…
Reference in New Issue