mirror of https://github.com/skygpu/skynet.git
469 lines
14 KiB
Python
469 lines
14 KiB
Python
#!/usr/bin/python
|
|
|
|
import time
|
|
import json
|
|
import uuid
|
|
import zlib
|
|
import logging
|
|
import traceback
|
|
|
|
from uuid import UUID
|
|
from pathlib import Path
|
|
from functools import partial
|
|
from contextlib import asynccontextmanager as acm
|
|
from collections import OrderedDict
|
|
|
|
import trio
|
|
import pynng
|
|
import trio_asyncio
|
|
|
|
from pynng import TLSConfig
|
|
from OpenSSL.crypto import (
|
|
load_privatekey,
|
|
load_certificate,
|
|
FILETYPE_PEM
|
|
)
|
|
|
|
from .db import *
|
|
from .constants import *
|
|
|
|
from .protobuf import *
|
|
|
|
|
|
class SkynetDGPUOffline(BaseException):
|
|
...
|
|
|
|
class SkynetDGPUOverloaded(BaseException):
|
|
...
|
|
|
|
class SkynetDGPUComputeError(BaseException):
|
|
...
|
|
|
|
class SkynetShutdownRequested(BaseException):
|
|
...
|
|
|
|
|
|
@acm
|
|
async def open_rpc_service(sock, dgpu_bus, db_pool, tls_whitelist, tls_key):
|
|
nodes = OrderedDict()
|
|
wip_reqs = {}
|
|
fin_reqs = {}
|
|
heartbeats = {}
|
|
next_worker: Optional[int] = None
|
|
security = len(tls_whitelist) > 0
|
|
|
|
def connect_node(uid):
|
|
nonlocal next_worker
|
|
nodes[uid] = {
|
|
'task': None
|
|
}
|
|
logging.info(f'dgpu online: {uid}')
|
|
|
|
if not next_worker:
|
|
next_worker = 0
|
|
|
|
def disconnect_node(uid):
|
|
nonlocal next_worker
|
|
if uid not in nodes:
|
|
return
|
|
i = list(nodes.keys()).index(uid)
|
|
del nodes[uid]
|
|
|
|
if i < next_worker:
|
|
next_worker -= 1
|
|
|
|
if len(nodes) == 0:
|
|
logging.info('nw: None')
|
|
next_worker = None
|
|
|
|
logging.warning(f'dgpu offline: {uid}')
|
|
|
|
def is_worker_busy(nid: str):
|
|
return nodes[nid]['task'] != None
|
|
|
|
def are_all_workers_busy():
|
|
for nid in nodes.keys():
|
|
if not is_worker_busy(nid):
|
|
return False
|
|
|
|
return True
|
|
|
|
def get_next_worker():
|
|
nonlocal next_worker
|
|
logging.info('get next_worker called')
|
|
logging.info(f'pre next_worker: {next_worker}')
|
|
|
|
if next_worker == None:
|
|
raise SkynetDGPUOffline('No workers connected, try again later')
|
|
|
|
if are_all_workers_busy():
|
|
raise SkynetDGPUOverloaded('All workers are busy at the moment')
|
|
|
|
|
|
nid = list(nodes.keys())[next_worker]
|
|
while is_worker_busy(nid):
|
|
next_worker += 1
|
|
|
|
if next_worker >= len(nodes):
|
|
next_worker = 0
|
|
|
|
nid = list(nodes.keys())[next_worker]
|
|
|
|
next_worker += 1
|
|
if next_worker >= len(nodes):
|
|
next_worker = 0
|
|
|
|
logging.info(f'post next_worker: {next_worker}')
|
|
|
|
return nid
|
|
|
|
async def dgpu_heartbeat_service():
|
|
nonlocal heartbeats
|
|
while True:
|
|
await trio.sleep(60)
|
|
rid = uuid.uuid4().hex
|
|
beat_msg = DGPUBusMessage(
|
|
rid=rid,
|
|
nid='',
|
|
method='heartbeat'
|
|
)
|
|
heartbeats.clear()
|
|
heartbeats[rid] = int(time.time() * 1000)
|
|
await dgpu_bus.asend(beat_msg.SerializeToString())
|
|
logging.info('sent heartbeat')
|
|
|
|
async def dgpu_bus_streamer():
|
|
nonlocal wip_reqs, fin_reqs, heartbeats
|
|
while True:
|
|
raw_msg = await dgpu_bus.arecv()
|
|
logging.info(f'streamer got {len(raw_msg)} bytes.')
|
|
msg = DGPUBusMessage()
|
|
msg.ParseFromString(raw_msg)
|
|
|
|
if security:
|
|
verify_protobuf_msg(msg, tls_whitelist[msg.auth.cert])
|
|
|
|
rid = msg.rid
|
|
|
|
if msg.method == 'heartbeat':
|
|
sent_time = heartbeats[rid]
|
|
delta = msg.params['time'] - sent_time
|
|
logging.info(f'got heartbeat reply from {msg.nid}, ping: {delta}')
|
|
continue
|
|
|
|
if rid not in wip_reqs:
|
|
continue
|
|
|
|
if msg.method == 'binary-reply':
|
|
logging.info('bin reply, recv extra data')
|
|
raw_img = await dgpu_bus.arecv()
|
|
msg = (msg, raw_img)
|
|
|
|
fin_reqs[rid] = msg
|
|
event = wip_reqs[rid]
|
|
event.set()
|
|
del wip_reqs[rid]
|
|
|
|
async def dgpu_stream_one_img(req: Text2ImageParameters):
|
|
nonlocal wip_reqs, fin_reqs, next_worker
|
|
nid = get_next_worker()
|
|
idx = list(nodes.keys()).index(nid)
|
|
logging.info(f'dgpu_stream_one_img {idx}/{len(nodes)} {nid}')
|
|
rid = uuid.uuid4().hex
|
|
ack_event = trio.Event()
|
|
img_event = trio.Event()
|
|
wip_reqs[rid] = ack_event
|
|
|
|
nodes[nid]['task'] = rid
|
|
|
|
dgpu_req = DGPUBusMessage(
|
|
rid=rid,
|
|
nid=nid,
|
|
method='diffuse')
|
|
dgpu_req.params.update(req.to_dict())
|
|
|
|
if security:
|
|
dgpu_req.auth.cert = 'skynet'
|
|
dgpu_req.auth.sig = sign_protobuf_msg(dgpu_req, tls_key)
|
|
|
|
await dgpu_bus.asend(dgpu_req.SerializeToString())
|
|
|
|
with trio.move_on_after(4):
|
|
await ack_event.wait()
|
|
|
|
logging.info(f'ack event: {ack_event.is_set()}')
|
|
|
|
if not ack_event.is_set():
|
|
disconnect_node(nid)
|
|
raise SkynetDGPUOffline('dgpu failed to acknowledge request')
|
|
|
|
ack_msg = fin_reqs[rid]
|
|
if 'ack' not in ack_msg.params:
|
|
disconnect_node(nid)
|
|
raise SkynetDGPUOffline('dgpu failed to acknowledge request')
|
|
|
|
wip_reqs[rid] = img_event
|
|
with trio.move_on_after(30):
|
|
await img_event.wait()
|
|
|
|
logging.info(f'img event: {ack_event.is_set()}')
|
|
|
|
if not img_event.is_set():
|
|
disconnect_node(nid)
|
|
raise SkynetDGPUComputeError('30 seconds timeout while processing request')
|
|
|
|
nodes[nid]['task'] = None
|
|
|
|
resp = fin_reqs[rid]
|
|
del fin_reqs[rid]
|
|
if isinstance(resp, tuple):
|
|
meta, img = resp
|
|
return rid, img, meta.params
|
|
|
|
raise SkynetDGPUComputeError(MessageToDict(resp.params))
|
|
|
|
|
|
async def handle_user_request(rpc_ctx, req):
|
|
try:
|
|
async with db_pool.acquire() as conn:
|
|
user = await get_or_create_user(conn, req.uid)
|
|
|
|
result = {}
|
|
|
|
match req.method:
|
|
case 'txt2img':
|
|
logging.info('txt2img')
|
|
user_config = {**(await get_user_config(conn, user))}
|
|
del user_config['id']
|
|
user_config.update(MessageToDict(req.params))
|
|
|
|
req = Text2ImageParameters(**user_config)
|
|
rid, img, meta = await dgpu_stream_one_img(req)
|
|
logging.info(f'done streaming {rid}')
|
|
result = {
|
|
'id': rid,
|
|
'img': zlib.compress(img).hex(),
|
|
'meta': meta
|
|
}
|
|
|
|
await update_user_stats(conn, user, last_prompt=user_config['prompt'])
|
|
logging.info('updated user stats.')
|
|
|
|
case 'redo':
|
|
logging.info('redo')
|
|
user_config = {**(await get_user_config(conn, user))}
|
|
del user_config['id']
|
|
prompt = await get_last_prompt_of(conn, user)
|
|
|
|
if prompt:
|
|
req = Text2ImageParameters(
|
|
prompt=prompt,
|
|
**user_config
|
|
)
|
|
rid, img, meta = await dgpu_stream_one_img(req)
|
|
result = {
|
|
'id': rid,
|
|
'img': zlib.compress(img).hex(),
|
|
'meta': meta
|
|
}
|
|
await update_user_stats(conn, user)
|
|
logging.info('updated user stats.')
|
|
|
|
else:
|
|
result = {
|
|
'error': 'skynet_no_last_prompt',
|
|
'message': 'No prompt to redo, do txt2img first'
|
|
}
|
|
|
|
case 'config':
|
|
logging.info('config')
|
|
if req.params['attr'] in CONFIG_ATTRS:
|
|
logging.info(f'update: {req.params}')
|
|
await update_user_config(
|
|
conn, user, req.params['attr'], req.params['val'])
|
|
logging.info('done')
|
|
|
|
case 'stats':
|
|
logging.info('stats')
|
|
generated, joined, role = await get_user_stats(conn, user)
|
|
|
|
result = {
|
|
'generated': generated,
|
|
'joined': joined.strftime(DATE_FORMAT),
|
|
'role': role
|
|
}
|
|
|
|
case _:
|
|
logging.warn('unknown method')
|
|
|
|
except SkynetDGPUOffline as e:
|
|
result = {
|
|
'error': 'skynet_dgpu_offline',
|
|
'message': str(e)
|
|
}
|
|
|
|
except SkynetDGPUOverloaded as e:
|
|
result = {
|
|
'error': 'skynet_dgpu_overloaded',
|
|
'message': str(e),
|
|
'nodes': len(nodes)
|
|
}
|
|
|
|
except SkynetDGPUComputeError as e:
|
|
result = {
|
|
'error': 'skynet_dgpu_compute_error',
|
|
'message': str(e)
|
|
}
|
|
except BaseException as e:
|
|
traceback.print_exception(type(e), e, e.__traceback__)
|
|
result = {
|
|
'error': 'skynet_internal_error',
|
|
'message': str(e)
|
|
}
|
|
|
|
resp = SkynetRPCResponse()
|
|
resp.result.update(result)
|
|
|
|
if security:
|
|
resp.auth.cert = 'skynet'
|
|
resp.auth.sig = sign_protobuf_msg(resp, tls_key)
|
|
|
|
logging.info('sending response')
|
|
await rpc_ctx.asend(resp.SerializeToString())
|
|
rpc_ctx.close()
|
|
logging.info('done')
|
|
|
|
async def request_service(n):
|
|
nonlocal next_worker
|
|
while True:
|
|
ctx = sock.new_context()
|
|
req = SkynetRPCRequest()
|
|
req.ParseFromString(await ctx.arecv())
|
|
|
|
if security:
|
|
if req.auth.cert not in tls_whitelist:
|
|
logging.warning(
|
|
f'{req.cert} not in tls whitelist and security=True')
|
|
continue
|
|
|
|
try:
|
|
verify_protobuf_msg(req, tls_whitelist[req.auth.cert])
|
|
|
|
except ValueError:
|
|
logging.warning(
|
|
f'{req.cert} sent an unauthenticated msg with security=True')
|
|
continue
|
|
|
|
result = {}
|
|
|
|
match req.method:
|
|
case 'skynet_shutdown':
|
|
raise SkynetShutdownRequested
|
|
|
|
case 'dgpu_online':
|
|
connect_node(req.uid)
|
|
|
|
case 'dgpu_offline':
|
|
disconnect_node(req.uid)
|
|
|
|
case 'dgpu_workers':
|
|
result = len(nodes)
|
|
|
|
case 'dgpu_next':
|
|
result = next_worker
|
|
|
|
case 'heartbeat':
|
|
logging.info('beat')
|
|
result = {'time': time.time()}
|
|
|
|
case _:
|
|
n.start_soon(
|
|
handle_user_request, ctx, req)
|
|
continue
|
|
|
|
resp = SkynetRPCResponse()
|
|
resp.result.update({'ok': result})
|
|
|
|
if security:
|
|
resp.auth.cert = 'skynet'
|
|
resp.auth.sig = sign_protobuf_msg(resp, tls_key)
|
|
|
|
await ctx.asend(resp.SerializeToString())
|
|
|
|
ctx.close()
|
|
|
|
|
|
async with trio.open_nursery() as n:
|
|
n.start_soon(dgpu_bus_streamer)
|
|
n.start_soon(dgpu_heartbeat_service)
|
|
n.start_soon(request_service, n)
|
|
logging.info('starting rpc service')
|
|
yield
|
|
logging.info('stopping rpc service')
|
|
n.cancel_scope.cancel()
|
|
|
|
|
|
@acm
|
|
async def run_skynet(
|
|
db_user: str = DB_USER,
|
|
db_pass: str = DB_PASS,
|
|
db_host: str = DB_HOST,
|
|
rpc_address: str = DEFAULT_RPC_ADDR,
|
|
dgpu_address: str = DEFAULT_DGPU_ADDR,
|
|
security: bool = True
|
|
):
|
|
logging.basicConfig(level=logging.INFO)
|
|
logging.info('skynet is starting')
|
|
|
|
tls_config = None
|
|
if security:
|
|
# load tls certs
|
|
certs_dir = Path(DEFAULT_CERTS_DIR).resolve()
|
|
|
|
tls_key_data = (certs_dir / DEFAULT_CERT_SKYNET_PRIV).read_text()
|
|
tls_key = load_privatekey(FILETYPE_PEM, tls_key_data)
|
|
|
|
tls_cert_data = (certs_dir / DEFAULT_CERT_SKYNET_PUB).read_text()
|
|
tls_cert = load_certificate(FILETYPE_PEM, tls_cert_data)
|
|
|
|
tls_whitelist = {}
|
|
for cert_path in (certs_dir / 'whitelist').glob('*.cert'):
|
|
tls_whitelist[cert_path.stem] = load_certificate(
|
|
FILETYPE_PEM, cert_path.read_text())
|
|
|
|
cert_start = tls_cert_data.index('\n') + 1
|
|
logging.info(f'tls_cert: {tls_cert_data[cert_start:cert_start+64]}...')
|
|
logging.info(f'tls_whitelist len: {len(tls_whitelist)}')
|
|
|
|
rpc_address = 'tls+' + rpc_address
|
|
dgpu_address = 'tls+' + dgpu_address
|
|
tls_config = TLSConfig(
|
|
TLSConfig.MODE_SERVER,
|
|
own_key_string=tls_key_data,
|
|
own_cert_string=tls_cert_data)
|
|
|
|
with (
|
|
pynng.Rep0(recv_max_size=0) as rpc_sock,
|
|
pynng.Bus0(recv_max_size=0) as dgpu_bus
|
|
):
|
|
async with open_database_connection(
|
|
db_user, db_pass, db_host) as db_pool:
|
|
|
|
logging.info('connected to db.')
|
|
if security:
|
|
rpc_sock.tls_config = tls_config
|
|
dgpu_bus.tls_config = tls_config
|
|
|
|
rpc_sock.listen(rpc_address)
|
|
dgpu_bus.listen(dgpu_address)
|
|
|
|
try:
|
|
async with open_rpc_service(
|
|
rpc_sock, dgpu_bus, db_pool, tls_whitelist, tls_key):
|
|
yield
|
|
|
|
except SkynetShutdownRequested:
|
|
...
|
|
|
|
logging.info('disconnected from db.')
|