mirror of https://github.com/skygpu/skynet.git
Started making roboust testing fixtures to init fresh db and skynet
Add simple dgpu worker connection test Make db connection handler manage schema and table init logic Keep tweaking dgpu main handler attemtping to fix subactor hangs Change frontend open rpc logic to return a wrapped rpc_call fn referencing the new socket Decupled user config request validation from telegram module Fix next_worker logic, now takes in account multiple tasks per dgpu Add dgpu_workers and dgpu_next calls Fixed readme, moved db init code into db modulepull/2/head
parent
c3852314a7
commit
9afb192251
49
README.md
49
README.md
|
@ -1,47 +1,2 @@
|
||||||
create db in postgres:
|
# skynet
|
||||||
|
### decentralized compute platform
|
||||||
```sql
|
|
||||||
CREATE USER skynet WITH PASSWORD 'password';
|
|
||||||
CREATE DATABASE skynet_art_bot;
|
|
||||||
GRANT ALL PRIVILEGES ON DATABASE skynet_art_bot TO skynet;
|
|
||||||
|
|
||||||
CREATE SCHEMA IF NOT EXISTS skynet;
|
|
||||||
|
|
||||||
CREATE TABLE IF NOT EXISTS skynet.user(
|
|
||||||
id SERIAL PRIMARY KEY NOT NULL,
|
|
||||||
tg_id INT,
|
|
||||||
wp_id VARCHAR(128),
|
|
||||||
mx_id VARCHAR(128),
|
|
||||||
ig_id VARCHAR(128),
|
|
||||||
generated INT NOT NULL,
|
|
||||||
joined DATE NOT NULL,
|
|
||||||
last_prompt TEXT,
|
|
||||||
role VARCHAR(128) NOT NULL
|
|
||||||
);
|
|
||||||
ALTER TABLE skynet.user
|
|
||||||
ADD CONSTRAINT tg_unique
|
|
||||||
UNIQUE (tg_id);
|
|
||||||
ALTER TABLE skynet.user
|
|
||||||
ADD CONSTRAINT wp_unique
|
|
||||||
UNIQUE (wp_id);
|
|
||||||
ALTER TABLE skynet.user
|
|
||||||
ADD CONSTRAINT mx_unique
|
|
||||||
UNIQUE (mx_id);
|
|
||||||
ALTER TABLE skynet.user
|
|
||||||
ADD CONSTRAINT ig_unique
|
|
||||||
UNIQUE (ig_id);
|
|
||||||
|
|
||||||
CREATE TABLE IF NOT EXISTS skynet.user_config(
|
|
||||||
id SERIAL NOT NULL,
|
|
||||||
algo VARCHAR(128) NOT NULL,
|
|
||||||
step INT NOT NULL,
|
|
||||||
width INT NOT NULL,
|
|
||||||
height INT NOT NULL,
|
|
||||||
seed INT,
|
|
||||||
guidance INT NOT NULL,
|
|
||||||
upscaler VARCHAR(128)
|
|
||||||
);
|
|
||||||
ALTER TABLE skynet.user_config
|
|
||||||
ADD FOREIGN KEY(id)
|
|
||||||
REFERENCES skynet.user(id);
|
|
||||||
```
|
|
||||||
|
|
|
@ -1,2 +1,5 @@
|
||||||
pytest
|
pytest
|
||||||
|
psycopg2
|
||||||
pytest-trio
|
pytest-trio
|
||||||
|
|
||||||
|
git+https://github.com/tgoodlet/pytest-dockerctl.git@master#egg=pytest-dockerctl
|
||||||
|
|
|
@ -30,30 +30,37 @@ async def rpc_service(sock, dgpu_bus, db_pool):
|
||||||
wip_reqs = {}
|
wip_reqs = {}
|
||||||
fin_reqs = {}
|
fin_reqs = {}
|
||||||
|
|
||||||
def are_all_workers_busy():
|
def is_worker_busy(nid: int):
|
||||||
for nid, info in nodes.items():
|
for task in nodes[nid]['tasks']:
|
||||||
if info['task'] == None:
|
if task != None:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
next_worker = 0
|
def are_all_workers_busy():
|
||||||
|
for nid in nodes.keys():
|
||||||
|
if not is_worker_busy(nid):
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
next_worker: Optional[int] = None
|
||||||
def get_next_worker():
|
def get_next_worker():
|
||||||
nonlocal next_worker
|
nonlocal next_worker
|
||||||
|
|
||||||
if len(nodes) == 0:
|
if not next_worker:
|
||||||
raise SkynetDGPUOffline
|
raise SkynetDGPUOffline
|
||||||
|
|
||||||
if are_all_workers_busy():
|
if are_all_workers_busy():
|
||||||
raise SkynetDGPUOverloaded
|
raise SkynetDGPUOverloaded
|
||||||
|
|
||||||
next_worker += 1
|
while is_worker_busy(next_worker):
|
||||||
|
next_worker += 1
|
||||||
|
|
||||||
if next_worker >= len(nodes):
|
if next_worker >= len(nodes):
|
||||||
next_worker = 0
|
next_worker = 0
|
||||||
|
|
||||||
nid = list(nodes.keys())[next_worker]
|
return next_worker
|
||||||
return nid
|
|
||||||
|
|
||||||
async def dgpu_image_streamer():
|
async def dgpu_image_streamer():
|
||||||
nonlocal wip_reqs, fin_reqs
|
nonlocal wip_reqs, fin_reqs
|
||||||
|
@ -74,7 +81,8 @@ async def rpc_service(sock, dgpu_bus, db_pool):
|
||||||
event = trio.Event()
|
event = trio.Event()
|
||||||
wip_reqs[rid] = event
|
wip_reqs[rid] = event
|
||||||
|
|
||||||
nodes[nid]['task'] = rid
|
tid = nodes[nid]['tasks'].index(None)
|
||||||
|
nodes[nid]['tasks'][tid] = rid
|
||||||
|
|
||||||
dgpu_req = DGPUBusRequest(
|
dgpu_req = DGPUBusRequest(
|
||||||
rid=rid,
|
rid=rid,
|
||||||
|
@ -89,7 +97,7 @@ async def rpc_service(sock, dgpu_bus, db_pool):
|
||||||
|
|
||||||
await event.wait()
|
await event.wait()
|
||||||
|
|
||||||
nodes[nid]['task'] = None
|
nodes[nid]['tasks'][tid] = None
|
||||||
|
|
||||||
img = fin_reqs[rid]
|
img = fin_reqs[rid]
|
||||||
del fin_reqs[rid]
|
del fin_reqs[rid]
|
||||||
|
@ -167,10 +175,9 @@ async def rpc_service(sock, dgpu_bus, db_pool):
|
||||||
|
|
||||||
except BaseException as e:
|
except BaseException as e:
|
||||||
logging.error(e)
|
logging.error(e)
|
||||||
raise e
|
result = {
|
||||||
# result = {
|
'error': 'skynet_internal_error'
|
||||||
# 'error': 'skynet_internal_error'
|
}
|
||||||
# }
|
|
||||||
|
|
||||||
await rpc_ctx.asend(
|
await rpc_ctx.asend(
|
||||||
json.dumps(
|
json.dumps(
|
||||||
|
@ -187,21 +194,36 @@ async def rpc_service(sock, dgpu_bus, db_pool):
|
||||||
|
|
||||||
logging.info(req)
|
logging.info(req)
|
||||||
|
|
||||||
|
result = {}
|
||||||
|
|
||||||
if req.method == 'dgpu_online':
|
if req.method == 'dgpu_online':
|
||||||
nodes[req.uid] = {
|
nodes[req.uid] = {
|
||||||
'task': None
|
'tasks': [None for _ in range(req.params['max_tasks'])],
|
||||||
|
'max_tasks': req.params['max_tasks']
|
||||||
}
|
}
|
||||||
logging.info(f'dgpu online: {req.uid}')
|
logging.info(f'dgpu online: {req.uid}')
|
||||||
|
|
||||||
|
if not next_worker:
|
||||||
|
next_worker = 0
|
||||||
|
|
||||||
elif req.method == 'dgpu_offline':
|
elif req.method == 'dgpu_offline':
|
||||||
i = nodes.values().index(req.uid)
|
i = list(nodes.keys()).index(req.uid)
|
||||||
del nodes[req.uid]
|
del nodes[req.uid]
|
||||||
|
|
||||||
if i < next_worker:
|
if i < next_worker:
|
||||||
next_worker -= 1
|
next_worker -= 1
|
||||||
|
|
||||||
|
if len(nodes) == 0:
|
||||||
|
next_worker = None
|
||||||
|
|
||||||
logging.info(f'dgpu offline: {req.uid}')
|
logging.info(f'dgpu offline: {req.uid}')
|
||||||
|
|
||||||
|
elif req.method == 'dgpu_workers':
|
||||||
|
result = len(nodes)
|
||||||
|
|
||||||
|
elif req.method == 'dgpu_next':
|
||||||
|
result = next_worker
|
||||||
|
|
||||||
else:
|
else:
|
||||||
n.start_soon(
|
n.start_soon(
|
||||||
handle_user_request, ctx, req)
|
handle_user_request, ctx, req)
|
||||||
|
@ -210,12 +232,12 @@ async def rpc_service(sock, dgpu_bus, db_pool):
|
||||||
await ctx.asend(
|
await ctx.asend(
|
||||||
json.dumps(
|
json.dumps(
|
||||||
SkynetRPCResponse(
|
SkynetRPCResponse(
|
||||||
result={'ok': {}}).to_dict()).encode())
|
result={'ok': result}).to_dict()).encode())
|
||||||
|
|
||||||
|
|
||||||
async def run_skynet(
|
async def run_skynet(
|
||||||
db_user: str,
|
db_user: str = DB_USER,
|
||||||
db_pass: str,
|
db_pass: str = DB_PASS,
|
||||||
db_host: str = DB_HOST,
|
db_host: str = DB_HOST,
|
||||||
rpc_address: str = DEFAULT_RPC_ADDR,
|
rpc_address: str = DEFAULT_RPC_ADDR,
|
||||||
dgpu_address: str = DEFAULT_DGPU_ADDR,
|
dgpu_address: str = DEFAULT_DGPU_ADDR,
|
||||||
|
|
|
@ -3,6 +3,9 @@
|
||||||
API_TOKEN = '5880619053:AAFge2UfObw1kCn9Kb7AAyqaIHs_HgM0Fx0'
|
API_TOKEN = '5880619053:AAFge2UfObw1kCn9Kb7AAyqaIHs_HgM0Fx0'
|
||||||
|
|
||||||
DB_HOST = 'ancap.tech:34508'
|
DB_HOST = 'ancap.tech:34508'
|
||||||
|
DB_USER = 'skynet'
|
||||||
|
DB_PASS = 'password'
|
||||||
|
DB_NAME = 'skynet'
|
||||||
|
|
||||||
ALGOS = {
|
ALGOS = {
|
||||||
'stable': 'runwayml/stable-diffusion-v1-5',
|
'stable': 'runwayml/stable-diffusion-v1-5',
|
||||||
|
|
|
@ -11,6 +11,49 @@ import triopg
|
||||||
from .constants import *
|
from .constants import *
|
||||||
|
|
||||||
|
|
||||||
|
DB_INIT_SQL = '''
|
||||||
|
CREATE SCHEMA IF NOT EXISTS skynet;
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS skynet.user(
|
||||||
|
id SERIAL PRIMARY KEY NOT NULL,
|
||||||
|
tg_id INT,
|
||||||
|
wp_id VARCHAR(128),
|
||||||
|
mx_id VARCHAR(128),
|
||||||
|
ig_id VARCHAR(128),
|
||||||
|
generated INT NOT NULL,
|
||||||
|
joined DATE NOT NULL,
|
||||||
|
last_prompt TEXT,
|
||||||
|
role VARCHAR(128) NOT NULL
|
||||||
|
);
|
||||||
|
ALTER TABLE skynet.user
|
||||||
|
ADD CONSTRAINT tg_unique
|
||||||
|
UNIQUE (tg_id);
|
||||||
|
ALTER TABLE skynet.user
|
||||||
|
ADD CONSTRAINT wp_unique
|
||||||
|
UNIQUE (wp_id);
|
||||||
|
ALTER TABLE skynet.user
|
||||||
|
ADD CONSTRAINT mx_unique
|
||||||
|
UNIQUE (mx_id);
|
||||||
|
ALTER TABLE skynet.user
|
||||||
|
ADD CONSTRAINT ig_unique
|
||||||
|
UNIQUE (ig_id);
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS skynet.user_config(
|
||||||
|
id SERIAL NOT NULL,
|
||||||
|
algo VARCHAR(128) NOT NULL,
|
||||||
|
step INT NOT NULL,
|
||||||
|
width INT NOT NULL,
|
||||||
|
height INT NOT NULL,
|
||||||
|
seed INT,
|
||||||
|
guidance INT NOT NULL,
|
||||||
|
upscaler VARCHAR(128)
|
||||||
|
);
|
||||||
|
ALTER TABLE skynet.user_config
|
||||||
|
ADD FOREIGN KEY(id)
|
||||||
|
REFERENCES skynet.user(id);
|
||||||
|
'''
|
||||||
|
|
||||||
|
|
||||||
def try_decode_uid(uid: str):
|
def try_decode_uid(uid: str):
|
||||||
try:
|
try:
|
||||||
proto, uid = uid.split('+')
|
proto, uid = uid.split('+')
|
||||||
|
@ -24,14 +67,18 @@ def try_decode_uid(uid: str):
|
||||||
|
|
||||||
@acm
|
@acm
|
||||||
async def open_database_connection(
|
async def open_database_connection(
|
||||||
db_user: str,
|
db_user: str = DB_USER,
|
||||||
db_pass: str,
|
db_pass: str = DB_PASS,
|
||||||
db_host: str = DB_HOST,
|
db_host: str = DB_HOST,
|
||||||
|
db_name: str = DB_NAME
|
||||||
):
|
):
|
||||||
async with triopg.create_pool(
|
async with triopg.create_pool(
|
||||||
dsn=f'postgres://{db_user}:{db_pass}@{db_host}/skynet_art_bot'
|
dsn=f'postgres://{db_user}:{db_pass}@{db_host}/{db_name}'
|
||||||
) as conn:
|
) as pool_conn:
|
||||||
yield conn
|
async with pool_conn.acquire() as conn:
|
||||||
|
await conn.execute(DB_INIT_SQL)
|
||||||
|
|
||||||
|
yield pool_conn
|
||||||
|
|
||||||
|
|
||||||
async def get_user(conn, uid: str):
|
async def get_user(conn, uid: str):
|
||||||
|
|
|
@ -65,58 +65,53 @@ async def open_dgpu_node(
|
||||||
return img
|
return img
|
||||||
|
|
||||||
|
|
||||||
with (
|
async with open_skynet_rpc() as rpc_call:
|
||||||
pynng.Req0(dial=rpc_address) as rpc_sock,
|
with pynng.Bus0(dial=dgpu_address) as dgpu_sock:
|
||||||
pynng.Bus0(dial=dgpu_address) as dgpu_sock
|
async def _process_dgpu_req(req: DGPUBusRequest):
|
||||||
):
|
img = await gpu_compute_one(
|
||||||
async def _rpc_call(*args, **kwargs):
|
ImageGenRequest(**req.params))
|
||||||
return await rpc_call(rpc_sock, *args, **kwargs)
|
await dgpu_sock.asend(
|
||||||
|
bytes.fromhex(req.rid) + img)
|
||||||
|
|
||||||
async def _process_dgpu_req(req: DGPUBusRequest):
|
res = await rpc_call(
|
||||||
img = await gpu_compute_one(
|
name.hex, 'dgpu_online', {'max_tasks': dgpu_max_tasks})
|
||||||
ImageGenRequest(**req.params))
|
logging.info(res)
|
||||||
await dgpu_sock.asend(
|
assert 'ok' in res.result
|
||||||
bytes.fromhex(req.rid) + img)
|
|
||||||
|
|
||||||
res = await _rpc_call(
|
async with (
|
||||||
name.hex, 'dgpu_online', {'max_tasks': dgpu_max_tasks})
|
tractor.open_actor_cluster(
|
||||||
logging.info(res)
|
modules=['skynet_bot.gpu'],
|
||||||
assert 'ok' in res.result
|
count=dgpu_max_tasks,
|
||||||
|
names=[i for i in range(dgpu_max_tasks)]
|
||||||
async with (
|
) as portal_map,
|
||||||
tractor.open_actor_cluster(
|
trio.open_nursery() as n
|
||||||
modules=['skynet_bot.gpu'],
|
):
|
||||||
count=dgpu_max_tasks,
|
logging.info(f'starting {dgpu_max_tasks} gpu workers')
|
||||||
names=[i for i in range(dgpu_max_tasks)]
|
async with tractor.gather_contexts((
|
||||||
) as portal_map,
|
portal.open_context(
|
||||||
trio.open_nursery() as n
|
open_gpu_worker, algo, 1.0 / dgpu_max_tasks)
|
||||||
):
|
for portal in portal_map.values()
|
||||||
logging.info(f'starting {dgpu_max_tasks} gpu workers')
|
)) as contexts:
|
||||||
async with tractor.gather_contexts((
|
contexts = {i: ctx for i, ctx in enumerate(contexts)}
|
||||||
portal.open_context(
|
for i, ctx in contexts.items():
|
||||||
open_gpu_worker, algo, 1.0 / dgpu_max_tasks)
|
|
||||||
for portal in portal_map.values()
|
|
||||||
)) as contexts:
|
|
||||||
contexts = {i: ctx for i, ctx in enumerate(contexts)}
|
|
||||||
for i, ctx in contexts.items():
|
|
||||||
n.start_soon(
|
|
||||||
gpu_streamer, ctx, i)
|
|
||||||
try:
|
|
||||||
while True:
|
|
||||||
msg = await dgpu_sock.arecv()
|
|
||||||
req = DGPUBusRequest(
|
|
||||||
**json.loads(msg.decode()))
|
|
||||||
|
|
||||||
if req.nid != name.hex:
|
|
||||||
continue
|
|
||||||
|
|
||||||
logging.info(f'dgpu: {name}, req: {req}')
|
|
||||||
n.start_soon(
|
n.start_soon(
|
||||||
_process_dgpu_req, req)
|
gpu_streamer, ctx, i)
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
msg = await dgpu_sock.arecv()
|
||||||
|
req = DGPUBusRequest(
|
||||||
|
**json.loads(msg.decode()))
|
||||||
|
|
||||||
except KeyboardInterrupt:
|
if req.nid != name.hex:
|
||||||
...
|
continue
|
||||||
|
|
||||||
res = await _rpc_call(name.hex, 'dgpu_offline')
|
logging.info(f'dgpu: {name}, req: {req}')
|
||||||
logging.info(res)
|
n.start_soon(
|
||||||
assert 'ok' in res.result
|
_process_dgpu_req, req)
|
||||||
|
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
...
|
||||||
|
|
||||||
|
res = await rpc_call(name.hex, 'dgpu_offline')
|
||||||
|
logging.info(res)
|
||||||
|
assert 'ok' in res.result
|
||||||
|
|
|
@ -3,14 +3,17 @@
|
||||||
import json
|
import json
|
||||||
|
|
||||||
from typing import Union
|
from typing import Union
|
||||||
from contextlib import contextmanager as cm
|
from contextlib import asynccontextmanager as acm
|
||||||
|
|
||||||
import pynng
|
import pynng
|
||||||
|
|
||||||
from ..types import SkynetRPCRequest, SkynetRPCResponse
|
from ..types import SkynetRPCRequest, SkynetRPCResponse
|
||||||
from ..constants import DEFAULT_RPC_ADDR
|
from ..constants import *
|
||||||
|
|
||||||
|
|
||||||
|
class ConfigRequestFormatError(BaseException):
|
||||||
|
...
|
||||||
|
|
||||||
class ConfigUnknownAttribute(BaseException):
|
class ConfigUnknownAttribute(BaseException):
|
||||||
...
|
...
|
||||||
|
|
||||||
|
@ -44,7 +47,71 @@ async def rpc_call(
|
||||||
(await sock.arecv_msg()).bytes.decode()))
|
(await sock.arecv_msg()).bytes.decode()))
|
||||||
|
|
||||||
|
|
||||||
@cm
|
@acm
|
||||||
def open_skynet_rpc(rpc_address: str = DEFAULT_RPC_ADDR):
|
async def open_skynet_rpc(rpc_address: str = DEFAULT_RPC_ADDR):
|
||||||
with pynng.Req0(dial=rpc_address) as rpc_sock:
|
with pynng.Req0(dial=rpc_address) as sock:
|
||||||
yield rpc_sock
|
async def _rpc_call(*args, **kwargs):
|
||||||
|
return await rpc_call(sock, *args, **kwargs)
|
||||||
|
|
||||||
|
yield _rpc_call
|
||||||
|
|
||||||
|
|
||||||
|
def validate_user_config_request(req: str):
|
||||||
|
params = req.split(' ')
|
||||||
|
|
||||||
|
if len(params) < 3:
|
||||||
|
raise ConfigRequestFormatError('config request format incorrect')
|
||||||
|
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
attr = params[1]
|
||||||
|
|
||||||
|
if attr == 'algo':
|
||||||
|
val = params[2]
|
||||||
|
if val not in ALGOS:
|
||||||
|
raise ConfigUnknownAlgorithm(f'no algo named {val}')
|
||||||
|
|
||||||
|
elif attr == 'step':
|
||||||
|
val = int(params[2])
|
||||||
|
val = max(min(val, MAX_STEP), MIN_STEP)
|
||||||
|
|
||||||
|
elif attr == 'width':
|
||||||
|
val = max(min(int(params[2]), MAX_WIDTH), 16)
|
||||||
|
if val % 8 != 0:
|
||||||
|
raise ConfigSizeDivisionByEight(
|
||||||
|
'size must be divisible by 8!')
|
||||||
|
|
||||||
|
elif attr == 'height':
|
||||||
|
val = max(min(int(params[2]), MAX_HEIGHT), 16)
|
||||||
|
if val % 8 != 0:
|
||||||
|
raise ConfigSizeDivisionByEight(
|
||||||
|
'size must be divisible by 8!')
|
||||||
|
|
||||||
|
elif attr == 'seed':
|
||||||
|
val = params[2]
|
||||||
|
if val == 'auto':
|
||||||
|
val = None
|
||||||
|
else:
|
||||||
|
val = int(params[2])
|
||||||
|
|
||||||
|
elif attr == 'guidance':
|
||||||
|
val = float(params[2])
|
||||||
|
val = max(min(val, MAX_GUIDANCE), 0)
|
||||||
|
|
||||||
|
elif attr == 'upscaler':
|
||||||
|
val = params[2]
|
||||||
|
if val == 'off':
|
||||||
|
val = None
|
||||||
|
elif val != 'x4':
|
||||||
|
raise ConfigUnknownUpscaler(
|
||||||
|
f'\"{val}\" is not a valid upscaler')
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise ConfigUnknownAttribute(
|
||||||
|
f'\"{attr}\" not a configurable parameter')
|
||||||
|
|
||||||
|
return attr, val, f'config updated! {attr} to {val}'
|
||||||
|
|
||||||
|
except ValueError:
|
||||||
|
raise ValueError(f'\"{val}\" is not a number silly')
|
||||||
|
|
||||||
|
|
|
@ -3,6 +3,7 @@
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
import pynng
|
import pynng
|
||||||
|
|
||||||
|
@ -17,20 +18,21 @@ from . import *
|
||||||
PREFIX = 'tg'
|
PREFIX = 'tg'
|
||||||
|
|
||||||
|
|
||||||
async def run_skynet_telegram(tg_token: str):
|
async def run_skynet_telegram(
|
||||||
|
tg_token: str
|
||||||
|
):
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
logging.basicConfig(level=logging.INFO)
|
||||||
bot = AsyncTeleBot(tg_token)
|
bot = AsyncTeleBot(tg_token)
|
||||||
|
|
||||||
with open_skynet_rpc() as rpc_sock:
|
with open_skynet_rpc() as rpc_call:
|
||||||
|
|
||||||
async def _rpc_call(
|
async def _rpc_call(
|
||||||
uid: int,
|
uid: int,
|
||||||
method: str,
|
method: str,
|
||||||
params: dict
|
params: dict
|
||||||
):
|
):
|
||||||
return await rpc_call(
|
return await rpc_call(f'{PREFIX}+{uid}', method, params)
|
||||||
rpc_sock, f'{PREFIX}+{uid}', method, params)
|
|
||||||
|
|
||||||
@bot.message_handler(commands=['help'])
|
@bot.message_handler(commands=['help'])
|
||||||
async def send_help(message):
|
async def send_help(message):
|
||||||
|
@ -58,79 +60,19 @@ async def run_skynet_telegram(tg_token: str):
|
||||||
|
|
||||||
@bot.message_handler(commands=['config'])
|
@bot.message_handler(commands=['config'])
|
||||||
async def set_config(message):
|
async def set_config(message):
|
||||||
params = message.text.split(' ')
|
|
||||||
|
|
||||||
rpc_params = {}
|
rpc_params = {}
|
||||||
|
try:
|
||||||
|
attr, val, reply_txt = validate_user_config_request(
|
||||||
|
message.text)
|
||||||
|
|
||||||
if len(params) < 3:
|
resp = await _rpc_call(
|
||||||
bot.reply_to(message, 'wrong msg format')
|
message.from_user.id,
|
||||||
|
'config', {'attr': attr, 'val': val})
|
||||||
|
|
||||||
else:
|
except BaseException as e:
|
||||||
|
reply_text = e.message
|
||||||
try:
|
|
||||||
attr = params[1]
|
|
||||||
|
|
||||||
if attr == 'algo':
|
|
||||||
val = params[2]
|
|
||||||
if val not in ALGOS:
|
|
||||||
raise ConfigUnknownAlgorithm
|
|
||||||
|
|
||||||
elif attr == 'step':
|
|
||||||
val = int(params[2])
|
|
||||||
val = max(min(val, MAX_STEP), MIN_STEP)
|
|
||||||
|
|
||||||
elif attr == 'width':
|
|
||||||
val = max(min(int(params[2]), MAX_WIDTH), 16)
|
|
||||||
if val % 8 != 0:
|
|
||||||
raise ConfigSizeDivisionByEight
|
|
||||||
|
|
||||||
elif attr == 'height':
|
|
||||||
val = max(min(int(params[2]), MAX_HEIGHT), 16)
|
|
||||||
if val % 8 != 0:
|
|
||||||
raise ConfigSizeDivisionByEight
|
|
||||||
|
|
||||||
elif attr == 'seed':
|
|
||||||
val = params[2]
|
|
||||||
if val == 'auto':
|
|
||||||
val = None
|
|
||||||
else:
|
|
||||||
val = int(params[2])
|
|
||||||
|
|
||||||
elif attr == 'guidance':
|
|
||||||
val = float(params[2])
|
|
||||||
val = max(min(val, MAX_GUIDANCE), 0)
|
|
||||||
|
|
||||||
elif attr == 'upscaler':
|
|
||||||
val = params[2]
|
|
||||||
if val == 'off':
|
|
||||||
val = None
|
|
||||||
elif val != 'x4':
|
|
||||||
raise ConfigUnknownUpscaler
|
|
||||||
|
|
||||||
else:
|
|
||||||
raise ConfigUnknownAttribute
|
|
||||||
|
|
||||||
resp = await _rpc_call(
|
|
||||||
message.from_user.id,
|
|
||||||
'config', {'attr': attr, 'val': val})
|
|
||||||
|
|
||||||
reply_txt = f'config updated! {attr} to {val}'
|
|
||||||
|
|
||||||
except ConfigUnknownAlgorithm:
|
|
||||||
reply_txt = f'no algo named {val}'
|
|
||||||
|
|
||||||
except ConfigUnknownAttribute:
|
|
||||||
reply_txt = f'\"{attr}\" not a configurable parameter'
|
|
||||||
|
|
||||||
except ConfigUnknownUpscaler:
|
|
||||||
reply_txt = f'\"{val}\" is not a valid upscaler'
|
|
||||||
|
|
||||||
except ConfigSizeDivisionByEight:
|
|
||||||
reply_txt = 'size must be divisible by 8!'
|
|
||||||
|
|
||||||
except ValueError:
|
|
||||||
reply_txt = f'\"{val}\" is not a number silly'
|
|
||||||
|
|
||||||
|
finally:
|
||||||
await bot.reply_to(message, reply_txt)
|
await bot.reply_to(message, reply_txt)
|
||||||
|
|
||||||
@bot.message_handler(commands=['stats'])
|
@bot.message_handler(commands=['stats'])
|
||||||
|
|
|
@ -0,0 +1,90 @@
|
||||||
|
#!/usr/bin/python
|
||||||
|
|
||||||
|
import time
|
||||||
|
import random
|
||||||
|
import string
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
|
import trio
|
||||||
|
import pytest
|
||||||
|
import psycopg2
|
||||||
|
import trio_asyncio
|
||||||
|
|
||||||
|
from psycopg2.extensions import ISOLATION_LEVEL_AUTOCOMMIT
|
||||||
|
|
||||||
|
from skynet_bot.constants import *
|
||||||
|
from skynet_bot.brain import run_skynet
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope='session')
|
||||||
|
def postgres_db(dockerctl):
|
||||||
|
rpassword = ''.join(
|
||||||
|
random.choice(string.ascii_lowercase)
|
||||||
|
for i in range(12))
|
||||||
|
password = ''.join(
|
||||||
|
random.choice(string.ascii_lowercase)
|
||||||
|
for i in range(12))
|
||||||
|
|
||||||
|
with dockerctl.run(
|
||||||
|
'postgres',
|
||||||
|
command='postgres',
|
||||||
|
ports={'5432/tcp': None},
|
||||||
|
environment={
|
||||||
|
'POSTGRES_PASSWORD': rpassword
|
||||||
|
}
|
||||||
|
) as containers:
|
||||||
|
container = containers[0]
|
||||||
|
# ip = container.attrs['NetworkSettings']['IPAddress']
|
||||||
|
port = container.ports['5432/tcp'][0]['HostPort']
|
||||||
|
host = f'localhost:{port}'
|
||||||
|
|
||||||
|
for log in container.logs(stream=True):
|
||||||
|
log = log.decode().rstrip()
|
||||||
|
logging.info(log)
|
||||||
|
if ('database system is ready to accept connections' in log or
|
||||||
|
'database system is shut down' in log):
|
||||||
|
break
|
||||||
|
|
||||||
|
# why print the system is ready to accept connections when its not
|
||||||
|
# postgres? wtf
|
||||||
|
time.sleep(1)
|
||||||
|
logging.info('creating skynet db...')
|
||||||
|
|
||||||
|
conn = psycopg2.connect(
|
||||||
|
user='postgres',
|
||||||
|
password=rpassword,
|
||||||
|
host='localhost',
|
||||||
|
port=port
|
||||||
|
)
|
||||||
|
logging.info('connected...')
|
||||||
|
conn.set_isolation_level(ISOLATION_LEVEL_AUTOCOMMIT)
|
||||||
|
with conn.cursor() as cursor:
|
||||||
|
cursor.execute(
|
||||||
|
f'CREATE USER {DB_USER} WITH PASSWORD \'{password}\'')
|
||||||
|
cursor.execute(
|
||||||
|
f'CREATE DATABASE {DB_NAME}')
|
||||||
|
cursor.execute(
|
||||||
|
f'GRANT ALL PRIVILEGES ON DATABASE {DB_NAME} TO {DB_USER}')
|
||||||
|
|
||||||
|
logging.info('done.')
|
||||||
|
yield container, password, host
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def skynet_running(postgres_db):
|
||||||
|
db_container, db_pass, db_host = postgres_db
|
||||||
|
async with (
|
||||||
|
trio_asyncio.open_loop(),
|
||||||
|
trio.open_nursery() as n
|
||||||
|
):
|
||||||
|
await n.start(
|
||||||
|
partial(run_skynet,
|
||||||
|
db_pass=db_pass,
|
||||||
|
db_host=db_host))
|
||||||
|
|
||||||
|
yield
|
||||||
|
n.cancel_scope.cancel()
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,61 @@
|
||||||
|
#!/usr/bin/python
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
import trio
|
||||||
|
import trio_asyncio
|
||||||
|
|
||||||
|
from skynet_bot.types import *
|
||||||
|
from skynet_bot.brain import run_skynet
|
||||||
|
from skynet_bot.frontend import open_skynet_rpc
|
||||||
|
|
||||||
|
|
||||||
|
async def test_skynet_dgpu_connection_simple(skynet_running):
|
||||||
|
async with open_skynet_rpc() as rpc_call:
|
||||||
|
# check 0 nodes are connected
|
||||||
|
res = await rpc_call('dgpu-0', 'dgpu_workers')
|
||||||
|
logging.info(res)
|
||||||
|
assert 'ok' in res.result
|
||||||
|
assert res.result['ok'] == 0
|
||||||
|
|
||||||
|
# check next worker is None
|
||||||
|
res = await rpc_call('dgpu-0', 'dgpu_next')
|
||||||
|
logging.info(res)
|
||||||
|
assert 'ok' in res.result
|
||||||
|
assert res.result['ok'] == None
|
||||||
|
|
||||||
|
# connect 1 dgpu
|
||||||
|
res = await rpc_call(
|
||||||
|
'dgpu-0', 'dgpu_online', {'max_tasks': 3})
|
||||||
|
logging.info(res)
|
||||||
|
assert 'ok' in res.result
|
||||||
|
|
||||||
|
# check 1 node is connected
|
||||||
|
res = await rpc_call('dgpu-0', 'dgpu_workers')
|
||||||
|
logging.info(res)
|
||||||
|
assert 'ok' in res.result
|
||||||
|
assert res.result['ok'] == 1
|
||||||
|
|
||||||
|
# check next worker is 0
|
||||||
|
res = await rpc_call('dgpu-0', 'dgpu_next')
|
||||||
|
logging.info(res)
|
||||||
|
assert 'ok' in res.result
|
||||||
|
assert res.result['ok'] == 0
|
||||||
|
|
||||||
|
# disconnect 1 dgpu
|
||||||
|
res = await rpc_call(
|
||||||
|
'dgpu-0', 'dgpu_offline')
|
||||||
|
logging.info(res)
|
||||||
|
assert 'ok' in res.result
|
||||||
|
|
||||||
|
# check 0 nodes are connected
|
||||||
|
res = await rpc_call('dgpu-0', 'dgpu_workers')
|
||||||
|
logging.info(res)
|
||||||
|
assert 'ok' in res.result
|
||||||
|
assert res.result['ok'] == 0
|
||||||
|
|
||||||
|
# check next worker is None
|
||||||
|
res = await rpc_call('dgpu-0', 'dgpu_next')
|
||||||
|
logging.info(res)
|
||||||
|
assert 'ok' in res.result
|
||||||
|
assert res.result['ok'] == None
|
|
@ -1,22 +0,0 @@
|
||||||
#!/usr/bin/python
|
|
||||||
|
|
||||||
import trio
|
|
||||||
import trio_asyncio
|
|
||||||
|
|
||||||
from skynet_bot.brain import run_skynet
|
|
||||||
from skynet_bot.frontend import open_skynet_rpc
|
|
||||||
from skynet_bot.frontend.telegram import run_skynet_telegram
|
|
||||||
|
|
||||||
|
|
||||||
def test_run_tg_bot():
|
|
||||||
async def main():
|
|
||||||
async with trio.open_nursery() as n:
|
|
||||||
await n.start(
|
|
||||||
run_skynet,
|
|
||||||
'skynet', '3GbZd6UojbD8V7zWpeFn', 'ancap.tech:34508')
|
|
||||||
n.start_soon(
|
|
||||||
run_skynet_telegram,
|
|
||||||
'5853245787:AAFEmv3EjJ_qJ8d_vmOpi6o6HFHUf8a0uCQ')
|
|
||||||
|
|
||||||
|
|
||||||
trio_asyncio.run(main)
|
|
Loading…
Reference in New Issue