skynet/skynet/brain.py

469 lines
14 KiB
Python

#!/usr/bin/python
import time
import json
import uuid
import zlib
import logging
import traceback
from uuid import UUID
from pathlib import Path
from functools import partial
from contextlib import asynccontextmanager as acm
from collections import OrderedDict
import trio
import pynng
import trio_asyncio
from pynng import TLSConfig
from OpenSSL.crypto import (
load_privatekey,
load_certificate,
FILETYPE_PEM
)
from .db import *
from .constants import *
from .protobuf import *
class SkynetDGPUOffline(BaseException):
...
class SkynetDGPUOverloaded(BaseException):
...
class SkynetDGPUComputeError(BaseException):
...
class SkynetShutdownRequested(BaseException):
...
@acm
async def open_rpc_service(sock, dgpu_bus, db_pool, tls_whitelist, tls_key):
nodes = OrderedDict()
wip_reqs = {}
fin_reqs = {}
heartbeats = {}
next_worker: Optional[int] = None
security = len(tls_whitelist) > 0
def connect_node(uid):
nonlocal next_worker
nodes[uid] = {
'task': None
}
logging.info(f'dgpu online: {uid}')
if not next_worker:
next_worker = 0
def disconnect_node(uid):
nonlocal next_worker
if uid not in nodes:
return
i = list(nodes.keys()).index(uid)
del nodes[uid]
if i < next_worker:
next_worker -= 1
if len(nodes) == 0:
logging.info('nw: None')
next_worker = None
logging.warning(f'dgpu offline: {uid}')
def is_worker_busy(nid: str):
return nodes[nid]['task'] != None
def are_all_workers_busy():
for nid in nodes.keys():
if not is_worker_busy(nid):
return False
return True
def get_next_worker():
nonlocal next_worker
logging.info('get next_worker called')
logging.info(f'pre next_worker: {next_worker}')
if next_worker == None:
raise SkynetDGPUOffline('No workers connected, try again later')
if are_all_workers_busy():
raise SkynetDGPUOverloaded('All workers are busy at the moment')
nid = list(nodes.keys())[next_worker]
while is_worker_busy(nid):
next_worker += 1
if next_worker >= len(nodes):
next_worker = 0
nid = list(nodes.keys())[next_worker]
next_worker += 1
if next_worker >= len(nodes):
next_worker = 0
logging.info(f'post next_worker: {next_worker}')
return nid
async def dgpu_heartbeat_service():
nonlocal heartbeats
while True:
await trio.sleep(60)
rid = uuid.uuid4().hex
beat_msg = DGPUBusMessage(
rid=rid,
nid='',
method='heartbeat'
)
heartbeats.clear()
heartbeats[rid] = int(time.time() * 1000)
await dgpu_bus.asend(beat_msg.SerializeToString())
logging.info('sent heartbeat')
async def dgpu_bus_streamer():
nonlocal wip_reqs, fin_reqs, heartbeats
while True:
raw_msg = await dgpu_bus.arecv()
logging.info(f'streamer got {len(raw_msg)} bytes.')
msg = DGPUBusMessage()
msg.ParseFromString(raw_msg)
if security:
verify_protobuf_msg(msg, tls_whitelist[msg.auth.cert])
rid = msg.rid
if msg.method == 'heartbeat':
sent_time = heartbeats[rid]
delta = msg.params['time'] - sent_time
logging.info(f'got heartbeat reply from {msg.nid}, ping: {delta}')
continue
if rid not in wip_reqs:
continue
if msg.method == 'binary-reply':
logging.info('bin reply, recv extra data')
raw_img = await dgpu_bus.arecv()
msg = (msg, raw_img)
fin_reqs[rid] = msg
event = wip_reqs[rid]
event.set()
del wip_reqs[rid]
async def dgpu_stream_one_img(req: Text2ImageParameters):
nonlocal wip_reqs, fin_reqs, next_worker
nid = get_next_worker()
idx = list(nodes.keys()).index(nid)
logging.info(f'dgpu_stream_one_img {idx}/{len(nodes)} {nid}')
rid = uuid.uuid4().hex
ack_event = trio.Event()
img_event = trio.Event()
wip_reqs[rid] = ack_event
nodes[nid]['task'] = rid
dgpu_req = DGPUBusMessage(
rid=rid,
nid=nid,
method='diffuse')
dgpu_req.params.update(req.to_dict())
if security:
dgpu_req.auth.cert = 'skynet'
dgpu_req.auth.sig = sign_protobuf_msg(dgpu_req, tls_key)
await dgpu_bus.asend(dgpu_req.SerializeToString())
with trio.move_on_after(4):
await ack_event.wait()
logging.info(f'ack event: {ack_event.is_set()}')
if not ack_event.is_set():
disconnect_node(nid)
raise SkynetDGPUOffline('dgpu failed to acknowledge request')
ack_msg = fin_reqs[rid]
if 'ack' not in ack_msg.params:
disconnect_node(nid)
raise SkynetDGPUOffline('dgpu failed to acknowledge request')
wip_reqs[rid] = img_event
with trio.move_on_after(30):
await img_event.wait()
logging.info(f'img event: {ack_event.is_set()}')
if not img_event.is_set():
disconnect_node(nid)
raise SkynetDGPUComputeError('30 seconds timeout while processing request')
nodes[nid]['task'] = None
resp = fin_reqs[rid]
del fin_reqs[rid]
if isinstance(resp, tuple):
meta, img = resp
return rid, img, meta.params
raise SkynetDGPUComputeError(MessageToDict(resp.params))
async def handle_user_request(rpc_ctx, req):
try:
async with db_pool.acquire() as conn:
user = await get_or_create_user(conn, req.uid)
result = {}
match req.method:
case 'txt2img':
logging.info('txt2img')
user_config = {**(await get_user_config(conn, user))}
del user_config['id']
user_config.update(MessageToDict(req.params))
req = Text2ImageParameters(**user_config)
rid, img, meta = await dgpu_stream_one_img(req)
logging.info(f'done streaming {rid}')
result = {
'id': rid,
'img': zlib.compress(img).hex(),
'meta': meta
}
await update_user_stats(conn, user, last_prompt=user_config['prompt'])
logging.info('updated user stats.')
case 'redo':
logging.info('redo')
user_config = {**(await get_user_config(conn, user))}
del user_config['id']
prompt = await get_last_prompt_of(conn, user)
if prompt:
req = Text2ImageParameters(
prompt=prompt,
**user_config
)
rid, img, meta = await dgpu_stream_one_img(req)
result = {
'id': rid,
'img': zlib.compress(img).hex(),
'meta': meta
}
await update_user_stats(conn, user)
logging.info('updated user stats.')
else:
result = {
'error': 'skynet_no_last_prompt',
'message': 'No prompt to redo, do txt2img first'
}
case 'config':
logging.info('config')
if req.params['attr'] in CONFIG_ATTRS:
logging.info(f'update: {req.params}')
await update_user_config(
conn, user, req.params['attr'], req.params['val'])
logging.info('done')
case 'stats':
logging.info('stats')
generated, joined, role = await get_user_stats(conn, user)
result = {
'generated': generated,
'joined': joined.strftime(DATE_FORMAT),
'role': role
}
case _:
logging.warn('unknown method')
except SkynetDGPUOffline as e:
result = {
'error': 'skynet_dgpu_offline',
'message': str(e)
}
except SkynetDGPUOverloaded as e:
result = {
'error': 'skynet_dgpu_overloaded',
'message': str(e),
'nodes': len(nodes)
}
except SkynetDGPUComputeError as e:
result = {
'error': 'skynet_dgpu_compute_error',
'message': str(e)
}
except BaseException as e:
traceback.print_exception(type(e), e, e.__traceback__)
result = {
'error': 'skynet_internal_error',
'message': str(e)
}
resp = SkynetRPCResponse()
resp.result.update(result)
if security:
resp.auth.cert = 'skynet'
resp.auth.sig = sign_protobuf_msg(resp, tls_key)
logging.info('sending response')
await rpc_ctx.asend(resp.SerializeToString())
rpc_ctx.close()
logging.info('done')
async def request_service(n):
nonlocal next_worker
while True:
ctx = sock.new_context()
req = SkynetRPCRequest()
req.ParseFromString(await ctx.arecv())
if security:
if req.auth.cert not in tls_whitelist:
logging.warning(
f'{req.cert} not in tls whitelist and security=True')
continue
try:
verify_protobuf_msg(req, tls_whitelist[req.auth.cert])
except ValueError:
logging.warning(
f'{req.cert} sent an unauthenticated msg with security=True')
continue
result = {}
match req.method:
case 'skynet_shutdown':
raise SkynetShutdownRequested
case 'dgpu_online':
connect_node(req.uid)
case 'dgpu_offline':
disconnect_node(req.uid)
case 'dgpu_workers':
result = len(nodes)
case 'dgpu_next':
result = next_worker
case 'heartbeat':
logging.info('beat')
result = {'time': time.time()}
case _:
n.start_soon(
handle_user_request, ctx, req)
continue
resp = SkynetRPCResponse()
resp.result.update({'ok': result})
if security:
resp.auth.cert = 'skynet'
resp.auth.sig = sign_protobuf_msg(resp, tls_key)
await ctx.asend(resp.SerializeToString())
ctx.close()
async with trio.open_nursery() as n:
n.start_soon(dgpu_bus_streamer)
n.start_soon(dgpu_heartbeat_service)
n.start_soon(request_service, n)
logging.info('starting rpc service')
yield
logging.info('stopping rpc service')
n.cancel_scope.cancel()
@acm
async def run_skynet(
db_user: str = DB_USER,
db_pass: str = DB_PASS,
db_host: str = DB_HOST,
rpc_address: str = DEFAULT_RPC_ADDR,
dgpu_address: str = DEFAULT_DGPU_ADDR,
security: bool = True
):
logging.basicConfig(level=logging.INFO)
logging.info('skynet is starting')
tls_config = None
if security:
# load tls certs
certs_dir = Path(DEFAULT_CERTS_DIR).resolve()
tls_key_data = (certs_dir / DEFAULT_CERT_SKYNET_PRIV).read_text()
tls_key = load_privatekey(FILETYPE_PEM, tls_key_data)
tls_cert_data = (certs_dir / DEFAULT_CERT_SKYNET_PUB).read_text()
tls_cert = load_certificate(FILETYPE_PEM, tls_cert_data)
tls_whitelist = {}
for cert_path in (certs_dir / 'whitelist').glob('*.cert'):
tls_whitelist[cert_path.stem] = load_certificate(
FILETYPE_PEM, cert_path.read_text())
cert_start = tls_cert_data.index('\n') + 1
logging.info(f'tls_cert: {tls_cert_data[cert_start:cert_start+64]}...')
logging.info(f'tls_whitelist len: {len(tls_whitelist)}')
rpc_address = 'tls+' + rpc_address
dgpu_address = 'tls+' + dgpu_address
tls_config = TLSConfig(
TLSConfig.MODE_SERVER,
own_key_string=tls_key_data,
own_cert_string=tls_cert_data)
with (
pynng.Rep0(recv_max_size=0) as rpc_sock,
pynng.Bus0(recv_max_size=0) as dgpu_bus
):
async with open_database_connection(
db_user, db_pass, db_host) as db_pool:
logging.info('connected to db.')
if security:
rpc_sock.tls_config = tls_config
dgpu_bus.tls_config = tls_config
rpc_sock.listen(rpc_address)
dgpu_bus.listen(dgpu_address)
try:
async with open_rpc_service(
rpc_sock, dgpu_bus, db_pool, tls_whitelist, tls_key):
yield
except SkynetShutdownRequested:
...
logging.info('disconnected from db.')