mirror of https://github.com/skygpu/skynet.git
247 lines
6.7 KiB
Python
247 lines
6.7 KiB
Python
#!/usr/bin/python
|
|
|
|
import json
|
|
import uuid
|
|
import base64
|
|
import logging
|
|
|
|
from uuid import UUID
|
|
from functools import partial
|
|
from collections import OrderedDict
|
|
|
|
import trio
|
|
import pynng
|
|
import trio_asyncio
|
|
|
|
from .db import *
|
|
from .types import *
|
|
from .constants import *
|
|
|
|
|
|
class SkynetDGPUOffline(BaseException):
    """Signal that a job cannot be scheduled because no dgpu worker is registered."""
|
|
|
|
class SkynetDGPUOverloaded(BaseException):
    """Signal that a job cannot be scheduled because every dgpu worker is busy."""
|
|
|
|
|
|
async def rpc_service(sock, dgpu_bus, db_pool):
|
|
nodes = OrderedDict()
|
|
wip_reqs = {}
|
|
fin_reqs = {}
|
|
|
|
def are_all_workers_busy():
|
|
for nid, info in nodes.items():
|
|
if info['task'] == None:
|
|
return False
|
|
|
|
return True
|
|
|
|
next_worker = 0
|
|
def get_next_worker():
|
|
nonlocal next_worker
|
|
|
|
if len(nodes) == 0:
|
|
raise SkynetDGPUOffline
|
|
|
|
if are_all_workers_busy():
|
|
raise SkynetDGPUOverloaded
|
|
|
|
next_worker += 1
|
|
|
|
if next_worker >= len(nodes):
|
|
next_worker = 0
|
|
|
|
nid = list(nodes.keys())[next_worker]
|
|
return nid
|
|
|
|
async def dgpu_image_streamer():
|
|
nonlocal wip_reqs, fin_reqs
|
|
while True:
|
|
msg = await dgpu_bus.arecv_msg()
|
|
rid = UUID(bytes=msg.bytes[:16]).hex
|
|
img = msg.bytes[16:].hex()
|
|
fin_reqs[rid] = img
|
|
event = wip_reqs[rid]
|
|
event.set()
|
|
del wip_reqs[rid]
|
|
|
|
async def dgpu_stream_one_img(req: ImageGenRequest):
|
|
nonlocal wip_reqs, fin_reqs, next_worker
|
|
nid = get_next_worker()
|
|
logging.info(f'dgpu_stream_one_img {next_worker} {nid}')
|
|
rid = uuid.uuid4().hex
|
|
event = trio.Event()
|
|
wip_reqs[rid] = event
|
|
|
|
nodes[nid]['task'] = rid
|
|
|
|
dgpu_req = DGPUBusRequest(
|
|
rid=rid,
|
|
nid=nid,
|
|
task='diffuse',
|
|
params=req.to_dict())
|
|
|
|
logging.info(f'dgpu_bus req: {dgpu_req}')
|
|
|
|
await dgpu_bus.asend(
|
|
json.dumps(dgpu_req.to_dict()).encode())
|
|
|
|
await event.wait()
|
|
|
|
nodes[nid]['task'] = None
|
|
|
|
img = fin_reqs[rid]
|
|
del fin_reqs[rid]
|
|
|
|
logging.info(f'done streaming {img}')
|
|
|
|
return rid, img
|
|
|
|
async def handle_user_request(rpc_ctx, req):
|
|
try:
|
|
async with db_pool.acquire() as conn:
|
|
user = await get_or_create_user(conn, req.uid)
|
|
|
|
result = {}
|
|
|
|
match req.method:
|
|
case 'txt2img':
|
|
logging.info('txt2img')
|
|
user_config = {**(await get_user_config(conn, user))}
|
|
del user_config['id']
|
|
prompt = req.params['prompt']
|
|
req = ImageGenRequest(
|
|
prompt=prompt,
|
|
**user_config
|
|
)
|
|
rid, img = await dgpu_stream_one_img(req)
|
|
result = {
|
|
'id': rid,
|
|
'img': img
|
|
}
|
|
|
|
case 'redo':
|
|
logging.info('redo')
|
|
user_config = await get_user_config(conn, user)
|
|
prompt = await get_last_prompt_of(conn, user)
|
|
req = ImageGenRequest(
|
|
prompt=prompt,
|
|
**user_config
|
|
)
|
|
rid, img = await dgpu_stream_one_img(req)
|
|
result = {
|
|
'id': rid,
|
|
'img': img
|
|
}
|
|
|
|
case 'config':
|
|
logging.info('config')
|
|
if req.params['attr'] in CONFIG_ATTRS:
|
|
await update_user_config(
|
|
conn, user, req.params['attr'], req.params['val'])
|
|
|
|
case 'stats':
|
|
logging.info('stats')
|
|
generated, joined, role = await get_user_stats(conn, user)
|
|
|
|
result = {
|
|
'generated': generated,
|
|
'joined': joined.strftime(DATE_FORMAT),
|
|
'role': role
|
|
}
|
|
|
|
case _:
|
|
logging.warn('unknown method')
|
|
|
|
except SkynetDGPUOffline:
|
|
result = {
|
|
'error': 'skynet_dgpu_offline'
|
|
}
|
|
|
|
except SkynetDGPUOverloaded:
|
|
result = {
|
|
'error': 'skynet_dgpu_overloaded',
|
|
'nodes': len(nodes)
|
|
}
|
|
|
|
except BaseException as e:
|
|
logging.error(e)
|
|
raise e
|
|
# result = {
|
|
# 'error': 'skynet_internal_error'
|
|
# }
|
|
|
|
await rpc_ctx.asend(
|
|
json.dumps(
|
|
SkynetRPCResponse(result=result).to_dict()).encode())
|
|
|
|
|
|
async with trio.open_nursery() as n:
|
|
n.start_soon(dgpu_image_streamer)
|
|
while True:
|
|
ctx = sock.new_context()
|
|
msg = await ctx.arecv_msg()
|
|
content = msg.bytes.decode()
|
|
req = SkynetRPCRequest(**json.loads(content))
|
|
|
|
logging.info(req)
|
|
|
|
if req.method == 'dgpu_online':
|
|
nodes[req.uid] = {
|
|
'task': None
|
|
}
|
|
logging.info(f'dgpu online: {req.uid}')
|
|
|
|
|
|
elif req.method == 'dgpu_offline':
|
|
i = nodes.values().index(req.uid)
|
|
del nodes[req.uid]
|
|
|
|
if i < next_worker:
|
|
next_worker -= 1
|
|
logging.info(f'dgpu offline: {req.uid}')
|
|
|
|
else:
|
|
n.start_soon(
|
|
handle_user_request, ctx, req)
|
|
continue
|
|
|
|
await ctx.asend(
|
|
json.dumps(
|
|
SkynetRPCResponse(
|
|
result={'ok': {}}).to_dict()).encode())
|
|
|
|
|
|
async def run_skynet(
    db_user: str,
    db_pass: str,
    db_host: str = DB_HOST,
    rpc_address: str = DEFAULT_RPC_ADDR,
    dgpu_address: str = DEFAULT_DGPU_ADDR,
    task_status = trio.TASK_STATUS_IGNORED
):
    """Start skynet: connect to the database, open the sockets, serve RPC.

    Args:
        db_user: Database user name.
        db_pass: Database password.
        db_host: Database host passed to open_database_connection.
        rpc_address: Address the Rep0 RPC socket listens on.
        dgpu_address: Address the Bus0 worker socket listens on.
        task_status: trio task status; ``started()`` is called once
            rpc_service has been scheduled.
    """
    logging.basicConfig(level=logging.INFO)
    logging.info('skynet is starting')

    async with trio.open_nursery() as n:
        async with open_database_connection(
            db_user, db_pass, db_host
        ) as db_pool:
            logging.info('connected to db.')

            with pynng.Rep0(listen=rpc_address) as rpc_sock:
                with pynng.Bus0(listen=dgpu_address) as dgpu_bus:
                    n.start_soon(
                        rpc_service, rpc_sock, dgpu_bus, db_pool)
                    task_status.started()

                    # Park here so the sockets stay open while the service
                    # task runs; a KeyboardInterrupt ends the wait quietly.
                    try:
                        await trio.sleep_forever()

                    except KeyboardInterrupt:
                        pass
|
|
|