mirror of https://github.com/skygpu/skynet.git
Rework dgpu client to be single task
Add a lot of dgpu real image gen tests Modified docker files and environment to allow for quick test relaunch without image rebuild Rename package from skynet_bot to skynet Drop tractor usage cause cuda is oriented to just a single proc managing gpu resources Add ackwnoledge phase to image request for quick gpu disconnected type scenarios Add click entry point for dgpu Add posibility to reuse postgres_db fixture on same session by checking if schema init has been already donepull/2/head
parent
d2e676627a
commit
f6326ad05c
|
@ -1,3 +1,9 @@
|
|||
.git
|
||||
hf_home
|
||||
inputs
|
||||
outputs
|
||||
.python-version
|
||||
.pytest-cache
|
||||
**/__pycache__
|
||||
*.egg-info
|
||||
**/*.key
|
||||
**/*.cert
|
||||
|
|
|
@ -1,5 +1,8 @@
|
|||
.python-version
|
||||
hf_home
|
||||
outputs
|
||||
secrets
|
||||
**/__pycache__
|
||||
*.egg-info
|
||||
**/*.key
|
||||
**/*.cert
|
||||
|
|
|
@ -4,10 +4,16 @@ env DEBIAN_FRONTEND=noninteractive
|
|||
|
||||
workdir /skynet
|
||||
|
||||
copy requirements.* ./
|
||||
copy requirements.test.txt requirements.test.txt
|
||||
copy requirements.txt requirements.txt
|
||||
copy pytest.ini ./
|
||||
copy setup.py ./
|
||||
copy skynet ./skynet
|
||||
|
||||
run pip install \
|
||||
-e . \
|
||||
-r requirements.txt \
|
||||
-r requirements.test.txt
|
||||
|
||||
workdir /scripts
|
||||
copy scripts ./
|
||||
copy tests ./
|
||||
|
|
|
@ -5,19 +5,25 @@ env DEBIAN_FRONTEND=noninteractive
|
|||
|
||||
workdir /skynet
|
||||
|
||||
copy requirements.* ./
|
||||
copy requirements.cuda* ./
|
||||
|
||||
run pip install -U pip ninja
|
||||
run pip install -r requirements.cuda.0.txt
|
||||
run pip install -v -r requirements.cuda.1.txt
|
||||
|
||||
run pip install \
|
||||
copy requirements.test.txt requirements.test.txt
|
||||
copy requirements.txt requirements.txt
|
||||
copy pytest.ini pytest.ini
|
||||
copy setup.py setup.py
|
||||
copy skynet skynet
|
||||
|
||||
run pip install -e . \
|
||||
-r requirements.txt \
|
||||
-r requirements.test.txt
|
||||
|
||||
env PYTORCH_CUDA_ALLOC_CONF max_split_size_mb:128
|
||||
env NVIDIA_VISIBLE_DEVICES=all
|
||||
env HF_HOME /hf_home
|
||||
|
||||
env PYTORCH_CUDA_ALLOC_CONF max_split_size_mb:128
|
||||
|
||||
workdir /scripts
|
||||
copy scripts scripts
|
||||
copy tests tests
|
|
@ -1,6 +1,6 @@
|
|||
docker build \
|
||||
-t skynet:runtime-cuda \
|
||||
-f Dockerfile.runtime-cuda .
|
||||
-f Dockerfile.runtime+cuda .
|
||||
|
||||
docker build \
|
||||
-t skynet:runtime \
|
||||
|
|
|
@ -1,2 +1,4 @@
|
|||
[pytest]
|
||||
log_cli = True
|
||||
log_level = info
|
||||
trio_mode = true
|
||||
|
|
|
@ -1,4 +1,3 @@
|
|||
pdbpp
|
||||
scipy
|
||||
triton
|
||||
accelerate
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
pdbpp
|
||||
pytest
|
||||
psycopg2
|
||||
pytest-trio
|
||||
|
||||
git+https://github.com/tgoodlet/pytest-dockerctl.git@master#egg=pytest-dockerctl
|
||||
git+https://github.com/guilledk/pytest-dockerctl.git@host_network#egg=pytest-dockerctl
|
||||
|
|
|
@ -5,5 +5,3 @@ aiohttp
|
|||
msgspec
|
||||
pyOpenSSL
|
||||
trio_asyncio
|
||||
|
||||
git+https://github.com/goodboy/tractor.git@piker_pin#egg=tractor
|
||||
|
|
|
@ -8,7 +8,7 @@ import sys
|
|||
|
||||
from OpenSSL import crypto, SSL
|
||||
|
||||
from skynet_bot.constants import DEFAULT_CERTS_DIR
|
||||
from skynet.constants import DEFAULT_CERTS_DIR
|
||||
|
||||
|
||||
def input_or_skip(txt, default):
|
||||
|
|
9
setup.py
9
setup.py
|
@ -1,11 +1,16 @@
|
|||
from setuptools import setup, find_packages
|
||||
|
||||
setup(
|
||||
name='skynet-bot',
|
||||
name='skynet',
|
||||
version='0.1.0a6',
|
||||
description='Decentralized compute platform',
|
||||
author='Guillermo Rodriguez',
|
||||
author_email='guillermo@telos.net',
|
||||
packages=find_packages(),
|
||||
install_requires=[]
|
||||
entry_points={
|
||||
'console_scripts': [
|
||||
'skynet = skynet.cli:skynet',
|
||||
]
|
||||
},
|
||||
install_requires=['click']
|
||||
)
|
||||
|
|
|
@ -8,6 +8,7 @@ import logging
|
|||
from uuid import UUID
|
||||
from pathlib import Path
|
||||
from functools import partial
|
||||
from contextlib import asynccontextmanager as acm
|
||||
from collections import OrderedDict
|
||||
|
||||
import trio
|
||||
|
@ -17,7 +18,7 @@ import trio_asyncio
|
|||
from pynng import TLSConfig
|
||||
|
||||
from .db import *
|
||||
from .types import *
|
||||
from .structs import *
|
||||
from .constants import *
|
||||
|
||||
|
||||
|
@ -27,18 +28,47 @@ class SkynetDGPUOffline(BaseException):
|
|||
class SkynetDGPUOverloaded(BaseException):
|
||||
...
|
||||
|
||||
class SkynetDGPUComputeError(BaseException):
|
||||
...
|
||||
|
||||
async def rpc_service(sock, dgpu_bus, db_pool):
|
||||
class SkynetShutdownRequested(BaseException):
|
||||
...
|
||||
|
||||
@acm
|
||||
async def open_rpc_service(sock, dgpu_bus, db_pool):
|
||||
nodes = OrderedDict()
|
||||
wip_reqs = {}
|
||||
fin_reqs = {}
|
||||
next_worker: Optional[int] = None
|
||||
|
||||
def is_worker_busy(nid: int):
|
||||
for task in nodes[nid]['tasks']:
|
||||
if task != None:
|
||||
return False
|
||||
def connect_node(uid):
|
||||
nonlocal next_worker
|
||||
nodes[uid] = {
|
||||
'task': None
|
||||
}
|
||||
logging.info(f'dgpu online: {uid}')
|
||||
|
||||
return True
|
||||
if not next_worker:
|
||||
next_worker = 0
|
||||
|
||||
def disconnect_node(uid):
|
||||
nonlocal next_worker
|
||||
if uid not in nodes:
|
||||
return
|
||||
i = list(nodes.keys()).index(uid)
|
||||
del nodes[uid]
|
||||
|
||||
if i < next_worker:
|
||||
next_worker -= 1
|
||||
|
||||
if len(nodes) == 0:
|
||||
logging.info('nw: None')
|
||||
next_worker = None
|
||||
|
||||
logging.warning(f'dgpu offline: {uid}')
|
||||
|
||||
def is_worker_busy(nid: str):
|
||||
return nodes[nid]['task'] != None
|
||||
|
||||
def are_all_workers_busy():
|
||||
for nid in nodes.keys():
|
||||
|
@ -47,30 +77,55 @@ async def rpc_service(sock, dgpu_bus, db_pool):
|
|||
|
||||
return True
|
||||
|
||||
next_worker: Optional[int] = None
|
||||
def get_next_worker():
|
||||
nonlocal next_worker
|
||||
logging.info('get next_worker called')
|
||||
logging.info(f'pre next_worker: {next_worker}')
|
||||
|
||||
if not next_worker:
|
||||
if next_worker == None:
|
||||
raise SkynetDGPUOffline
|
||||
|
||||
if are_all_workers_busy():
|
||||
raise SkynetDGPUOverloaded
|
||||
|
||||
while is_worker_busy(next_worker):
|
||||
|
||||
nid = list(nodes.keys())[next_worker]
|
||||
while is_worker_busy(nid):
|
||||
next_worker += 1
|
||||
|
||||
if next_worker >= len(nodes):
|
||||
next_worker = 0
|
||||
|
||||
return next_worker
|
||||
nid = list(nodes.keys())[next_worker]
|
||||
|
||||
next_worker += 1
|
||||
if next_worker >= len(nodes):
|
||||
next_worker = 0
|
||||
|
||||
logging.info(f'post next_worker: {next_worker}')
|
||||
|
||||
return nid
|
||||
|
||||
async def dgpu_image_streamer():
|
||||
nonlocal wip_reqs, fin_reqs
|
||||
while True:
|
||||
msg = await dgpu_bus.arecv_msg()
|
||||
rid = UUID(bytes=msg.bytes[:16]).hex
|
||||
img = msg.bytes[16:].hex()
|
||||
raw_msg = msg.bytes[16:]
|
||||
logging.info(f'streamer got back {rid} of size {len(raw_msg)}')
|
||||
match raw_msg[:5]:
|
||||
case b'error':
|
||||
img = raw_msg.decode()
|
||||
|
||||
case b'ack':
|
||||
img = raw_msg
|
||||
|
||||
case _:
|
||||
img = base64.b64encode(raw_msg).hex()
|
||||
|
||||
if rid not in wip_reqs:
|
||||
continue
|
||||
|
||||
fin_reqs[rid] = img
|
||||
event = wip_reqs[rid]
|
||||
event.set()
|
||||
|
@ -79,13 +134,14 @@ async def rpc_service(sock, dgpu_bus, db_pool):
|
|||
async def dgpu_stream_one_img(req: ImageGenRequest):
|
||||
nonlocal wip_reqs, fin_reqs, next_worker
|
||||
nid = get_next_worker()
|
||||
logging.info(f'dgpu_stream_one_img {next_worker} {nid}')
|
||||
idx = list(nodes.keys()).index(nid)
|
||||
logging.info(f'dgpu_stream_one_img {idx}/{len(nodes)} {nid}')
|
||||
rid = uuid.uuid4().hex
|
||||
event = trio.Event()
|
||||
wip_reqs[rid] = event
|
||||
ack_event = trio.Event()
|
||||
img_event = trio.Event()
|
||||
wip_reqs[rid] = ack_event
|
||||
|
||||
tid = nodes[nid]['tasks'].index(None)
|
||||
nodes[nid]['tasks'][tid] = rid
|
||||
nodes[nid]['task'] = rid
|
||||
|
||||
dgpu_req = DGPUBusRequest(
|
||||
rid=rid,
|
||||
|
@ -98,14 +154,37 @@ async def rpc_service(sock, dgpu_bus, db_pool):
|
|||
await dgpu_bus.asend(
|
||||
json.dumps(dgpu_req.to_dict()).encode())
|
||||
|
||||
await event.wait()
|
||||
with trio.move_on_after(4):
|
||||
await ack_event.wait()
|
||||
|
||||
nodes[nid]['tasks'][tid] = None
|
||||
logging.info(f'ack event: {ack_event.is_set()}')
|
||||
|
||||
if not ack_event.is_set():
|
||||
disconnect_node(nid)
|
||||
raise SkynetDGPUOffline('dgpu failed to acknowledge request')
|
||||
|
||||
ack = fin_reqs[rid]
|
||||
if ack != b'ack':
|
||||
disconnect_node(nid)
|
||||
raise SkynetDGPUOffline('dgpu failed to acknowledge request')
|
||||
|
||||
wip_reqs[rid] = img_event
|
||||
with trio.move_on_after(30):
|
||||
await img_event.wait()
|
||||
|
||||
if not img_event.is_set():
|
||||
disconnect_node(nid)
|
||||
raise SkynetDGPUComputeError('30 seconds timeout while processing request')
|
||||
|
||||
nodes[nid]['task'] = None
|
||||
|
||||
img = fin_reqs[rid]
|
||||
del fin_reqs[rid]
|
||||
|
||||
logging.info(f'done streaming {img}')
|
||||
logging.info(f'done streaming {len(img)} bytes')
|
||||
|
||||
if 'error' in img:
|
||||
raise SkynetDGPUComputeError(img)
|
||||
|
||||
return rid, img
|
||||
|
||||
|
@ -122,6 +201,10 @@ async def rpc_service(sock, dgpu_bus, db_pool):
|
|||
user_config = {**(await get_user_config(conn, user))}
|
||||
del user_config['id']
|
||||
prompt = req.params['prompt']
|
||||
user_config= {
|
||||
key : req.params.get(key, val)
|
||||
for key, val in user_config.items()
|
||||
}
|
||||
req = ImageGenRequest(
|
||||
prompt=prompt,
|
||||
**user_config
|
||||
|
@ -165,9 +248,10 @@ async def rpc_service(sock, dgpu_bus, db_pool):
|
|||
case _:
|
||||
logging.warn('unknown method')
|
||||
|
||||
except SkynetDGPUOffline:
|
||||
except SkynetDGPUOffline as e:
|
||||
result = {
|
||||
'error': 'skynet_dgpu_offline'
|
||||
'error': 'skynet_dgpu_offline',
|
||||
'message': str(e)
|
||||
}
|
||||
|
||||
except SkynetDGPUOverloaded:
|
||||
|
@ -176,22 +260,22 @@ async def rpc_service(sock, dgpu_bus, db_pool):
|
|||
'nodes': len(nodes)
|
||||
}
|
||||
|
||||
except BaseException as e:
|
||||
logging.error(e)
|
||||
except SkynetDGPUComputeError as e:
|
||||
result = {
|
||||
'error': 'skynet_internal_error'
|
||||
'error': 'skynet_dgpu_compute_error',
|
||||
'message': str(e)
|
||||
}
|
||||
|
||||
await rpc_ctx.asend(
|
||||
json.dumps(
|
||||
SkynetRPCResponse(result=result).to_dict()).encode())
|
||||
|
||||
|
||||
async with trio.open_nursery() as n:
|
||||
n.start_soon(dgpu_image_streamer)
|
||||
async def request_service(n):
|
||||
nonlocal next_worker
|
||||
while True:
|
||||
ctx = sock.new_context()
|
||||
msg = await ctx.arecv_msg()
|
||||
|
||||
content = msg.bytes.decode()
|
||||
req = SkynetRPCRequest(**json.loads(content))
|
||||
|
||||
|
@ -199,27 +283,14 @@ async def rpc_service(sock, dgpu_bus, db_pool):
|
|||
|
||||
result = {}
|
||||
|
||||
if req.method == 'dgpu_online':
|
||||
nodes[req.uid] = {
|
||||
'tasks': [None for _ in range(req.params['max_tasks'])],
|
||||
'max_tasks': req.params['max_tasks']
|
||||
}
|
||||
logging.info(f'dgpu online: {req.uid}')
|
||||
if req.method == 'skynet_shutdown':
|
||||
raise SkynetShutdownRequested
|
||||
|
||||
if not next_worker:
|
||||
next_worker = 0
|
||||
elif req.method == 'dgpu_online':
|
||||
connect_node(req.uid)
|
||||
|
||||
elif req.method == 'dgpu_offline':
|
||||
i = list(nodes.keys()).index(req.uid)
|
||||
del nodes[req.uid]
|
||||
|
||||
if i < next_worker:
|
||||
next_worker -= 1
|
||||
|
||||
if len(nodes) == 0:
|
||||
next_worker = None
|
||||
|
||||
logging.info(f'dgpu offline: {req.uid}')
|
||||
disconnect_node(req.uid)
|
||||
|
||||
elif req.method == 'dgpu_workers':
|
||||
result = len(nodes)
|
||||
|
@ -238,13 +309,22 @@ async def rpc_service(sock, dgpu_bus, db_pool):
|
|||
result={'ok': result}).to_dict()).encode())
|
||||
|
||||
|
||||
async with trio.open_nursery() as n:
|
||||
n.start_soon(dgpu_image_streamer)
|
||||
n.start_soon(request_service, n)
|
||||
logging.info('starting rpc service')
|
||||
yield
|
||||
logging.info('stopping rpc service')
|
||||
n.cancel_scope.cancel()
|
||||
|
||||
|
||||
@acm
|
||||
async def run_skynet(
|
||||
db_user: str = DB_USER,
|
||||
db_pass: str = DB_PASS,
|
||||
db_host: str = DB_HOST,
|
||||
rpc_address: str = DEFAULT_RPC_ADDR,
|
||||
dgpu_address: str = DEFAULT_DGPU_ADDR,
|
||||
task_status = trio.TASK_STATUS_IGNORED,
|
||||
security: bool = True
|
||||
):
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
@ -260,8 +340,8 @@ async def run_skynet(
|
|||
(cert_path).read_text()
|
||||
for cert_path in (certs_dir / 'whitelist').glob('*.cert')]
|
||||
|
||||
logging.info(f'tls_key: {tls_key}')
|
||||
logging.info(f'tls_cert: {tls_cert}')
|
||||
cert_start = tls_cert.index('\n') + 1
|
||||
logging.info(f'tls_cert: {tls_cert[cert_start:cert_start+64]}...')
|
||||
logging.info(f'tls_whitelist len: {len(tls_whitelist)}')
|
||||
|
||||
rpc_address = 'tls+' + rpc_address
|
||||
|
@ -271,16 +351,14 @@ async def run_skynet(
|
|||
own_key_string=tls_key,
|
||||
own_cert_string=tls_cert)
|
||||
|
||||
async with (
|
||||
trio.open_nursery() as n,
|
||||
open_database_connection(
|
||||
db_user, db_pass, db_host) as db_pool
|
||||
with (
|
||||
pynng.Rep0() as rpc_sock,
|
||||
pynng.Bus0() as dgpu_bus
|
||||
):
|
||||
logging.info('connected to db.')
|
||||
with (
|
||||
pynng.Rep0() as rpc_sock,
|
||||
pynng.Bus0() as dgpu_bus
|
||||
):
|
||||
async with open_database_connection(
|
||||
db_user, db_pass, db_host) as db_pool:
|
||||
|
||||
logging.info('connected to db.')
|
||||
if security:
|
||||
rpc_sock.tls_config = tls_config
|
||||
dgpu_bus.tls_config = tls_config
|
||||
|
@ -288,13 +366,11 @@ async def run_skynet(
|
|||
rpc_sock.listen(rpc_address)
|
||||
dgpu_bus.listen(dgpu_address)
|
||||
|
||||
n.start_soon(
|
||||
rpc_service, rpc_sock, dgpu_bus, db_pool)
|
||||
task_status.started()
|
||||
|
||||
try:
|
||||
await trio.sleep_forever()
|
||||
async with open_rpc_service(rpc_sock, dgpu_bus, db_pool):
|
||||
yield
|
||||
|
||||
except KeyboardInterrupt:
|
||||
except SkynetShutdownRequested:
|
||||
...
|
||||
|
||||
logging.info('disconnected from db.')
|
|
@ -0,0 +1,68 @@
|
|||
#!/usr/bin/python
|
||||
|
||||
import os
|
||||
import json
|
||||
|
||||
from typing import Optional
|
||||
from functools import partial
|
||||
|
||||
import trio
|
||||
import click
|
||||
|
||||
from .dgpu import open_dgpu_node
|
||||
from .utils import txt2img
|
||||
from .constants import ALGOS
|
||||
|
||||
|
||||
@click.group()
|
||||
def skynet(*args, **kwargs):
|
||||
pass
|
||||
|
||||
@skynet.command()
|
||||
@click.option('--model', '-m', default=ALGOS['midj'])
|
||||
@click.option(
|
||||
'--prompt', '-p', default='a red tractor in a wheat field')
|
||||
@click.option('--output', '-o', default='output.png')
|
||||
@click.option('--width', '-w', default=512)
|
||||
@click.option('--height', '-h', default=512)
|
||||
@click.option('--guidance', '-g', default=10.0)
|
||||
@click.option('--steps', '-s', default=26)
|
||||
@click.option('--seed', '-S', default=None)
|
||||
def txt2img(*args
|
||||
# model: str,
|
||||
# prompt: str,
|
||||
# output: str
|
||||
# width: int, height: int,
|
||||
# guidance: float,
|
||||
# steps: int,
|
||||
# seed: Optional[int]
|
||||
):
|
||||
assert 'HF_TOKEN' in os.environ
|
||||
txt2img(
|
||||
os.environ['HF_TOKEN'], *args)
|
||||
|
||||
@skynet.group()
|
||||
def run(*args, **kwargs):
|
||||
pass
|
||||
|
||||
@run.command()
|
||||
@click.option('--loglevel', '-l', default='warning', help='Logging level')
|
||||
@click.option(
|
||||
'--key', '-k', default='dgpu')
|
||||
@click.option(
|
||||
'--cert', '-c', default='whitelist/dgpu')
|
||||
@click.option(
|
||||
'--algos', '-a', default=None)
|
||||
def dgpu(
|
||||
loglevel: str,
|
||||
key: str,
|
||||
cert: str,
|
||||
algos: Optional[str]
|
||||
):
|
||||
trio.run(
|
||||
partial(
|
||||
open_dgpu_node,
|
||||
cert,
|
||||
key_name=key,
|
||||
initial_algos=json.loads(algos)
|
||||
))
|
|
@ -1,6 +1,6 @@
|
|||
#!/usr/bin/python
|
||||
|
||||
API_TOKEN = '5880619053:AAFge2UfObw1kCn9Kb7AAyqaIHs_HgM0Fx0'
|
||||
DOCKER_RUNTIME_CUDA = 'skynet:runtime-cuda'
|
||||
|
||||
DB_HOST = 'ancap.tech:34508'
|
||||
DB_USER = 'skynet'
|
||||
|
@ -8,8 +8,8 @@ DB_PASS = 'password'
|
|||
DB_NAME = 'skynet'
|
||||
|
||||
ALGOS = {
|
||||
'stable': 'runwayml/stable-diffusion-v1-5',
|
||||
'midj': 'prompthero/openjourney',
|
||||
'stable': 'runwayml/stable-diffusion-v1-5',
|
||||
'hdanime': 'Linaqruf/anything-v3.0',
|
||||
'waifu': 'hakurei/waifu-diffusion',
|
||||
'ghibli': 'nitrosocke/Ghibli-Diffusion',
|
||||
|
@ -122,7 +122,7 @@ DEFAULT_CERT_DGPU = 'dgpu.key'
|
|||
DEFAULT_RPC_ADDR = 'tcp://127.0.0.1:41000'
|
||||
|
||||
DEFAULT_DGPU_ADDR = 'tcp://127.0.0.1:41069'
|
||||
DEFAULT_DGPU_MAX_TASKS = 3
|
||||
DEFAULT_DGPU_MAX_TASKS = 2
|
||||
DEFAULT_INITAL_ALGOS = ['midj', 'stable', 'ink']
|
||||
|
||||
DATE_FORMAT = '%B the %dth %Y, %H:%M:%S'
|
|
@ -7,6 +7,9 @@ from contextlib import asynccontextmanager as acm
|
|||
|
||||
import trio
|
||||
import triopg
|
||||
import trio_asyncio
|
||||
|
||||
from asyncpg.exceptions import UndefinedColumnError
|
||||
|
||||
from .constants import *
|
||||
|
||||
|
@ -72,13 +75,22 @@ async def open_database_connection(
|
|||
db_host: str = DB_HOST,
|
||||
db_name: str = DB_NAME
|
||||
):
|
||||
async with triopg.create_pool(
|
||||
dsn=f'postgres://{db_user}:{db_pass}@{db_host}/{db_name}'
|
||||
) as pool_conn:
|
||||
async with pool_conn.acquire() as conn:
|
||||
await conn.execute(DB_INIT_SQL)
|
||||
async with trio_asyncio.open_loop() as loop:
|
||||
async with triopg.create_pool(
|
||||
dsn=f'postgres://{db_user}:{db_pass}@{db_host}/{db_name}'
|
||||
) as pool_conn:
|
||||
async with pool_conn.acquire() as conn:
|
||||
res = await conn.execute(f'''
|
||||
select distinct table_schema
|
||||
from information_schema.tables
|
||||
where table_schema = \'{db_name}\'
|
||||
''')
|
||||
if '1' in res:
|
||||
logging.info('schema already in db, skipping init')
|
||||
else:
|
||||
await conn.execute(DB_INIT_SQL)
|
||||
|
||||
yield pool_conn
|
||||
yield pool_conn
|
||||
|
||||
|
||||
async def get_user(conn, uid: str):
|
||||
|
@ -135,6 +147,7 @@ async def new_user(conn, uid: str):
|
|||
tg_id, generated, joined, last_prompt, role)
|
||||
|
||||
VALUES($1, $2, $3, $4, $5)
|
||||
ON CONFLICT DO NOTHING
|
||||
''')
|
||||
await stmt.fetch(
|
||||
tg_id, 0, date, None, DEFAULT_ROLE
|
||||
|
@ -147,6 +160,7 @@ async def new_user(conn, uid: str):
|
|||
id, algo, step, width, height, seed, guidance, upscaler)
|
||||
|
||||
VALUES($1, $2, $3, $4, $5, $6, $7, $8)
|
||||
ON CONFLICT DO NOTHING
|
||||
''')
|
||||
user = await stmt.fetch(
|
||||
new_uid,
|
|
@ -0,0 +1,197 @@
|
|||
#!/usr/bin/python
|
||||
|
||||
import gc
|
||||
import io
|
||||
import trio
|
||||
import json
|
||||
import uuid
|
||||
import random
|
||||
import logging
|
||||
|
||||
from typing import List, Optional
|
||||
from pathlib import Path
|
||||
from contextlib import AsyncExitStack
|
||||
|
||||
import pynng
|
||||
import torch
|
||||
|
||||
from pynng import TLSConfig
|
||||
from diffusers import (
|
||||
StableDiffusionPipeline,
|
||||
EulerAncestralDiscreteScheduler
|
||||
)
|
||||
|
||||
from .structs import *
|
||||
from .constants import *
|
||||
from .frontend import open_skynet_rpc
|
||||
|
||||
|
||||
def pipeline_for(algo: str, mem_fraction: float = 1.0):
|
||||
assert torch.cuda.is_available()
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.set_per_process_memory_fraction(mem_fraction)
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
torch.backends.cudnn.allow_tf32 = True
|
||||
|
||||
params = {
|
||||
'torch_dtype': torch.float16,
|
||||
'safety_checker': None
|
||||
}
|
||||
|
||||
if algo == 'stable':
|
||||
params['revision'] = 'fp16'
|
||||
|
||||
pipe = StableDiffusionPipeline.from_pretrained(
|
||||
ALGOS[algo], **params)
|
||||
|
||||
pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(
|
||||
pipe.scheduler.config)
|
||||
|
||||
pipe.enable_vae_slicing()
|
||||
|
||||
return pipe.to("cuda")
|
||||
|
||||
|
||||
class DGPUComputeError(BaseException):
|
||||
...
|
||||
|
||||
|
||||
async def open_dgpu_node(
|
||||
cert_name: str,
|
||||
key_name: Optional[str],
|
||||
rpc_address: str = DEFAULT_RPC_ADDR,
|
||||
dgpu_address: str = DEFAULT_DGPU_ADDR,
|
||||
initial_algos: Optional[List[str]] = None,
|
||||
security: bool = True
|
||||
):
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logging.info(f'starting dgpu node!')
|
||||
|
||||
name = uuid.uuid4()
|
||||
|
||||
logging.info(f'loading models...')
|
||||
|
||||
initial_algos = (
|
||||
initial_algos
|
||||
if initial_algos else DEFAULT_INITAL_ALGOS
|
||||
)
|
||||
models = {}
|
||||
for algo in initial_algos:
|
||||
models[algo] = {
|
||||
'pipe': pipeline_for(algo),
|
||||
'generated': 0
|
||||
}
|
||||
logging.info(f'loaded {algo}.')
|
||||
|
||||
logging.info('memory summary:\n')
|
||||
logging.info(torch.cuda.memory_summary())
|
||||
|
||||
async def gpu_compute_one(ireq: ImageGenRequest):
|
||||
if ireq.algo not in models:
|
||||
least_used = list(models.keys())[0]
|
||||
for model in models:
|
||||
if models[least_used]['generated'] > models[model]['generated']:
|
||||
least_used = model
|
||||
|
||||
del models[least_used]
|
||||
gc.collect()
|
||||
|
||||
models[ireq.algo] = {
|
||||
'pipe': pipeline_for(ireq.algo),
|
||||
'generated': 0
|
||||
}
|
||||
|
||||
seed = ireq.seed if ireq.seed else random.randint(0, 2 ** 64)
|
||||
try:
|
||||
image = models[ireq.algo]['pipe'](
|
||||
ireq.prompt,
|
||||
width=ireq.width,
|
||||
height=ireq.height,
|
||||
guidance_scale=ireq.guidance,
|
||||
num_inference_steps=ireq.step,
|
||||
generator=torch.Generator("cuda").manual_seed(seed)
|
||||
).images[0]
|
||||
return image.tobytes()
|
||||
|
||||
except BaseException as e:
|
||||
logging.error(e)
|
||||
raise DGPUComputeError(str(e))
|
||||
|
||||
finally:
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
async with open_skynet_rpc(
|
||||
security=security,
|
||||
cert_name=cert_name,
|
||||
key_name=key_name
|
||||
) as rpc_call:
|
||||
|
||||
tls_config = None
|
||||
if security:
|
||||
# load tls certs
|
||||
if not key_name:
|
||||
key_name = certs_name
|
||||
certs_dir = Path(DEFAULT_CERTS_DIR).resolve()
|
||||
|
||||
skynet_cert_path = certs_dir / 'brain.cert'
|
||||
tls_cert_path = certs_dir / f'{cert_name}.cert'
|
||||
tls_key_path = certs_dir / f'{key_name}.key'
|
||||
|
||||
skynet_cert = skynet_cert_path.read_text()
|
||||
tls_cert = tls_cert_path.read_text()
|
||||
tls_key = tls_key_path.read_text()
|
||||
|
||||
logging.info(f'skynet cert: {skynet_cert_path}')
|
||||
logging.info(f'dgpu cert: {tls_cert_path}')
|
||||
logging.info(f'dgpu key: {tls_key_path}')
|
||||
|
||||
dgpu_address = 'tls+' + dgpu_address
|
||||
tls_config = TLSConfig(
|
||||
TLSConfig.MODE_CLIENT,
|
||||
own_key_string=tls_key,
|
||||
own_cert_string=tls_cert,
|
||||
ca_string=skynet_cert)
|
||||
|
||||
logging.info(f'connecting to {dgpu_address}')
|
||||
with pynng.Bus0() as dgpu_sock:
|
||||
dgpu_sock.tls_config = tls_config
|
||||
dgpu_sock.dial(dgpu_address)
|
||||
|
||||
res = await rpc_call(name.hex, 'dgpu_online')
|
||||
logging.info(res)
|
||||
assert 'ok' in res.result
|
||||
|
||||
try:
|
||||
while True:
|
||||
msg = await dgpu_sock.arecv()
|
||||
req = DGPUBusRequest(
|
||||
**json.loads(msg.decode()))
|
||||
|
||||
if req.nid != name.hex:
|
||||
logging.info('witnessed request {req.rid}, for {req.nid}')
|
||||
continue
|
||||
|
||||
# send ack
|
||||
await dgpu_sock.asend(
|
||||
bytes.fromhex(req.rid) + b'ack')
|
||||
|
||||
logging.info(f'sent ack, processing {req.rid}...')
|
||||
|
||||
try:
|
||||
img = await gpu_compute_one(
|
||||
ImageGenRequest(**req.params))
|
||||
|
||||
except DGPUComputeError as e:
|
||||
img = b'error' + str(e).encode()
|
||||
|
||||
await dgpu_sock.asend(
|
||||
bytes.fromhex(req.rid) + img)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
logging.info('interrupt caught, stopping...')
|
||||
|
||||
finally:
|
||||
res = await rpc_call(name.hex, 'dgpu_offline')
|
||||
logging.info(res)
|
||||
assert 'ok' in res.result
|
|
@ -10,7 +10,7 @@ import pynng
|
|||
|
||||
from pynng import TLSConfig
|
||||
|
||||
from ..types import SkynetRPCRequest, SkynetRPCResponse
|
||||
from ..structs import SkynetRPCRequest, SkynetRPCResponse
|
||||
from ..constants import *
|
||||
|
||||
|
|
@ -0,0 +1,57 @@
|
|||
#!/usr/bin/python
|
||||
|
||||
import random
|
||||
|
||||
from typing import Optional
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
|
||||
from diffusers import StableDiffusionPipeline
|
||||
from huggingface_hub import login
|
||||
|
||||
|
||||
def txt2img(
|
||||
hf_token: str,
|
||||
model_name: str,
|
||||
prompt: str,
|
||||
output: str,
|
||||
width: int, height: int,
|
||||
guidance: float,
|
||||
steps: int,
|
||||
seed: Optional[int]
|
||||
):
|
||||
assert torch.cuda.is_available()
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.set_per_process_memory_fraction(0.333)
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
torch.backends.cudnn.allow_tf32 = True
|
||||
|
||||
login(token=hf_token)
|
||||
|
||||
params = {
|
||||
'torch_dtype': torch.float16,
|
||||
'safety_checker': None
|
||||
}
|
||||
if model_name == 'runwayml/stable-diffusion-v1-5':
|
||||
params['revision'] = 'fp16'
|
||||
|
||||
pipe = StableDiffusionPipeline.from_pretrained(
|
||||
model_name, **params)
|
||||
|
||||
pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(
|
||||
pipe.scheduler.config)
|
||||
|
||||
pipe = pipe.to("cuda")
|
||||
|
||||
seed = ireq.seed if ireq.seed else random.randint(0, 2 ** 64)
|
||||
prompt = prompt
|
||||
image = pipe(
|
||||
prompt,
|
||||
width=width,
|
||||
height=height,
|
||||
guidance_scale=guidance, num_inference_steps=steps,
|
||||
generator=torch.Generator("cuda").manual_seed(seed)
|
||||
).images[0]
|
||||
|
||||
image.save(output)
|
|
@ -1,124 +0,0 @@
|
|||
#!/usr/bin/python
|
||||
|
||||
import trio
|
||||
import json
|
||||
import uuid
|
||||
import logging
|
||||
|
||||
import pynng
|
||||
import tractor
|
||||
|
||||
from . import gpu
|
||||
from .gpu import open_gpu_worker
|
||||
from .types import *
|
||||
from .constants import *
|
||||
from .frontend import rpc_call
|
||||
|
||||
|
||||
async def open_dgpu_node(
|
||||
cert_name: str,
|
||||
key_name: Optional[str],
|
||||
rpc_address: str = DEFAULT_RPC_ADDR,
|
||||
dgpu_address: str = DEFAULT_DGPU_ADDR,
|
||||
dgpu_max_tasks: int = DEFAULT_DGPU_MAX_TASKS,
|
||||
initial_algos: str = DEFAULT_INITAL_ALGOS,
|
||||
security: bool = True
|
||||
):
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
name = uuid.uuid4()
|
||||
workers = initial_algos.copy()
|
||||
tasks = [None for _ in range(dgpu_max_tasks)]
|
||||
|
||||
portal_map: dict[int, tractor.Portal]
|
||||
contexts: dict[int, tractor.Context]
|
||||
|
||||
def get_next_worker(need_algo: str):
|
||||
nonlocal workers, tasks
|
||||
for task, algo in zip(workers, tasks):
|
||||
if need_algo == algo and not task:
|
||||
return workers.index(need_algo)
|
||||
|
||||
return tasks.index(None)
|
||||
|
||||
async def gpu_streamer(
|
||||
ctx: tractor.Context,
|
||||
nid: int
|
||||
):
|
||||
nonlocal tasks
|
||||
async with ctx.open_stream() as stream:
|
||||
async for img in stream:
|
||||
tasks[nid]['res'] = img
|
||||
tasks[nid]['event'].set()
|
||||
|
||||
async def gpu_compute_one(ireq: ImageGenRequest):
|
||||
wid = get_next_worker(ireq.algo)
|
||||
event = trio.Event()
|
||||
|
||||
workers[wid] = ireq.algo
|
||||
tasks[wid] = {
|
||||
'res': None, 'event': event}
|
||||
|
||||
await contexts[i].send(ireq)
|
||||
|
||||
await event.wait()
|
||||
|
||||
img = tasks[wid]['res']
|
||||
tasks[wid] = None
|
||||
return img
|
||||
|
||||
|
||||
async with open_skynet_rpc(
|
||||
security=security,
|
||||
cert_name=cert_name,
|
||||
key_name=key_name
|
||||
) as rpc_call:
|
||||
with pynng.Bus0(dial=dgpu_address) as dgpu_sock:
|
||||
async def _process_dgpu_req(req: DGPUBusRequest):
|
||||
img = await gpu_compute_one(
|
||||
ImageGenRequest(**req.params))
|
||||
await dgpu_sock.asend(
|
||||
bytes.fromhex(req.rid) + img)
|
||||
|
||||
res = await rpc_call(
|
||||
name.hex, 'dgpu_online', {'max_tasks': dgpu_max_tasks})
|
||||
logging.info(res)
|
||||
assert 'ok' in res.result
|
||||
|
||||
async with (
|
||||
tractor.open_actor_cluster(
|
||||
modules=['skynet_bot.gpu'],
|
||||
count=dgpu_max_tasks,
|
||||
names=[i for i in range(dgpu_max_tasks)]
|
||||
) as portal_map,
|
||||
trio.open_nursery() as n
|
||||
):
|
||||
logging.info(f'starting {dgpu_max_tasks} gpu workers')
|
||||
async with tractor.gather_contexts((
|
||||
portal.open_context(
|
||||
open_gpu_worker, algo, 1.0 / dgpu_max_tasks)
|
||||
for portal in portal_map.values()
|
||||
)) as contexts:
|
||||
contexts = {i: ctx for i, ctx in enumerate(contexts)}
|
||||
for i, ctx in contexts.items():
|
||||
n.start_soon(
|
||||
gpu_streamer, ctx, i)
|
||||
try:
|
||||
while True:
|
||||
msg = await dgpu_sock.arecv()
|
||||
req = DGPUBusRequest(
|
||||
**json.loads(msg.decode()))
|
||||
|
||||
if req.nid != name.hex:
|
||||
continue
|
||||
|
||||
logging.info(f'dgpu: {name}, req: {req}')
|
||||
n.start_soon(
|
||||
_process_dgpu_req, req)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
...
|
||||
|
||||
res = await rpc_call(name.hex, 'dgpu_offline')
|
||||
logging.info(res)
|
||||
assert 'ok' in res.result
|
|
@ -1,77 +0,0 @@
|
|||
#!/usr/bin/python
|
||||
|
||||
import io
|
||||
import random
|
||||
import logging
|
||||
|
||||
import torch
|
||||
import tractor
|
||||
|
||||
from diffusers import (
|
||||
StableDiffusionPipeline,
|
||||
EulerAncestralDiscreteScheduler
|
||||
)
|
||||
|
||||
from .types import ImageGenRequest
|
||||
from .constants import ALGOS
|
||||
|
||||
|
||||
def pipeline_for(algo: str, mem_fraction: float):
|
||||
assert torch.cuda.is_available()
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.set_per_process_memory_fraction(mem_fraction)
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
torch.backends.cudnn.allow_tf32 = True
|
||||
|
||||
params = {
|
||||
'torch_dtype': torch.float16,
|
||||
'safety_checker': None
|
||||
}
|
||||
|
||||
if algo == 'stable':
|
||||
params['revision'] = 'fp16'
|
||||
|
||||
pipe = StableDiffusionPipeline.from_pretrained(
|
||||
ALGOS[algo], **params)
|
||||
|
||||
pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(
|
||||
pipe.scheduler.config)
|
||||
|
||||
return pipe.to("cuda")
|
||||
|
||||
@tractor.context
|
||||
async def open_gpu_worker(
|
||||
ctx: tractor.Context,
|
||||
start_algo: str,
|
||||
mem_fraction: float
|
||||
):
|
||||
log = tractor.log.get_logger(name='gpu', _root_name='skynet')
|
||||
log.info(f'starting gpu worker with algo {start_algo}...')
|
||||
current_algo = start_algo
|
||||
with torch.no_grad():
|
||||
pipe = pipeline_for(current_algo, mem_fraction)
|
||||
log.info('pipeline loaded')
|
||||
await ctx.started()
|
||||
async with ctx.open_stream() as bus:
|
||||
async for ireq in bus:
|
||||
if ireq.algo != current_algo:
|
||||
current_algo = ireq.algo
|
||||
pipe = pipeline_for(current_algo, mem_fraction)
|
||||
|
||||
seed = ireq.seed if ireq.seed else random.randint(0, 2 ** 64)
|
||||
image = pipe(
|
||||
ireq.prompt,
|
||||
width=ireq.width,
|
||||
height=ireq.height,
|
||||
guidance_scale=ireq.guidance,
|
||||
num_inference_steps=ireq.step,
|
||||
generator=torch.Generator("cuda").manual_seed(seed)
|
||||
).images[0]
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# convert PIL.Image to BytesIO
|
||||
img_bytes = io.BytesIO()
|
||||
image.save(img_bytes, format='PNG')
|
||||
await bus.send(img_bytes.getvalue())
|
||||
|
|
@ -1,2 +0,0 @@
|
|||
from OpenSSL.crypto import load_publickey, FILETYPE_PEM, verify, X509
|
||||
|
9
test.sh
9
test.sh
|
@ -1,9 +0,0 @@
|
|||
docker run \
|
||||
-it \
|
||||
--rm \
|
||||
--gpus=all \
|
||||
--mount type=bind,source="$(pwd)",target=/skynet \
|
||||
skynet:runtime-cuda \
|
||||
bash -c \
|
||||
"cd /skynet && pip install -e . && \
|
||||
pytest $1 --log-cli-level=info"
|
|
@ -1,21 +1,25 @@
|
|||
#!/usr/bin/python
|
||||
|
||||
import os
|
||||
import json
|
||||
import time
|
||||
import random
|
||||
import string
|
||||
import logging
|
||||
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
|
||||
import trio
|
||||
import pytest
|
||||
import psycopg2
|
||||
import trio_asyncio
|
||||
|
||||
from docker.types import Mount, DeviceRequest
|
||||
from psycopg2.extensions import ISOLATION_LEVEL_AUTOCOMMIT
|
||||
|
||||
from skynet_bot.constants import *
|
||||
from skynet_bot.brain import run_skynet
|
||||
from skynet.constants import *
|
||||
from skynet.brain import run_skynet
|
||||
|
||||
|
||||
@pytest.fixture(scope='session')
|
||||
|
@ -29,6 +33,7 @@ def postgres_db(dockerctl):
|
|||
|
||||
with dockerctl.run(
|
||||
'postgres',
|
||||
name='skynet-test-postgres',
|
||||
ports={'5432/tcp': None},
|
||||
environment={
|
||||
'POSTGRES_PASSWORD': rpassword
|
||||
|
@ -67,6 +72,8 @@ def postgres_db(dockerctl):
|
|||
cursor.execute(
|
||||
f'GRANT ALL PRIVILEGES ON DATABASE {DB_NAME} TO {DB_USER}')
|
||||
|
||||
conn.close()
|
||||
|
||||
logging.info('done.')
|
||||
yield container, password, host
|
||||
|
||||
|
@ -74,16 +81,44 @@ def postgres_db(dockerctl):
|
|||
@pytest.fixture
|
||||
async def skynet_running(postgres_db):
|
||||
db_container, db_pass, db_host = postgres_db
|
||||
async with (
|
||||
trio_asyncio.open_loop(),
|
||||
trio.open_nursery() as n
|
||||
|
||||
async with run_skynet(
|
||||
db_pass=db_pass,
|
||||
db_host=db_host
|
||||
):
|
||||
await n.start(
|
||||
partial(run_skynet,
|
||||
db_pass=db_pass,
|
||||
db_host=db_host))
|
||||
|
||||
yield
|
||||
n.cancel_scope.cancel()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dgpu_workers(request, dockerctl, skynet_running):
|
||||
devices = [DeviceRequest(capabilities=[['gpu']])]
|
||||
mounts = [Mount(
|
||||
'/skynet', str(Path().resolve()), type='bind')]
|
||||
|
||||
num_containers, initial_algos = request.param
|
||||
|
||||
cmd = f'''
|
||||
pip install -e . && \
|
||||
skynet run dgpu --algos=\'{json.dumps(initial_algos)}\'
|
||||
'''
|
||||
|
||||
logging.info(f'launching: \n{cmd}')
|
||||
|
||||
with dockerctl.run(
|
||||
DOCKER_RUNTIME_CUDA,
|
||||
name='skynet-test-runtime-cuda',
|
||||
command=['bash', '-c', cmd],
|
||||
environment={
|
||||
'HF_TOKEN': os.environ['HF_TOKEN'],
|
||||
'HF_HOME': '/skynet/hf_home'
|
||||
},
|
||||
network='host',
|
||||
mounts=mounts,
|
||||
device_requests=devices,
|
||||
num=num_containers
|
||||
) as containers:
|
||||
yield containers
|
||||
|
||||
#for i, container in enumerate(containers):
|
||||
# logging.info(f'container {i} logs:')
|
||||
# logging.info(container.logs().decode())
|
||||
|
|
|
@ -1,57 +1,248 @@
|
|||
#!/usr/bin/python
|
||||
|
||||
import io
|
||||
import time
|
||||
import json
|
||||
import base64
|
||||
import logging
|
||||
|
||||
from hashlib import sha256
|
||||
from functools import partial
|
||||
|
||||
import trio
|
||||
import pynng
|
||||
import pytest
|
||||
import tractor
|
||||
import trio_asyncio
|
||||
|
||||
from skynet_bot.gpu import open_gpu_worker
|
||||
from skynet_bot.dgpu import open_dgpu_node
|
||||
from skynet_bot.types import *
|
||||
from skynet_bot.brain import run_skynet
|
||||
from skynet_bot.constants import *
|
||||
from skynet_bot.frontend import open_skynet_rpc, rpc_call
|
||||
from PIL import Image
|
||||
|
||||
from skynet.brain import SkynetDGPUComputeError
|
||||
from skynet.constants import *
|
||||
from skynet.frontend import open_skynet_rpc
|
||||
|
||||
|
||||
def test_dgpu_simple():
|
||||
async def main():
|
||||
async def wait_for_dgpus(rpc, amount: int, timeout: float = 30.0):
|
||||
gpu_ready = False
|
||||
start_time = time.time()
|
||||
current_time = time.time()
|
||||
while not gpu_ready and (current_time - start_time) < timeout:
|
||||
res = await rpc('dgpu-test', 'dgpu_workers')
|
||||
logging.info(res)
|
||||
if res.result['ok'] >= amount:
|
||||
break
|
||||
|
||||
await trio.sleep(1)
|
||||
current_time = time.time()
|
||||
|
||||
assert (current_time - start_time) < timeout
|
||||
|
||||
|
||||
_images = set()
|
||||
async def check_request_img(
|
||||
i: int,
|
||||
width: int = 512,
|
||||
height: int = 512,
|
||||
expect_unique=True
|
||||
):
|
||||
global _images
|
||||
|
||||
async with open_skynet_rpc(
|
||||
security=True,
|
||||
cert_name='whitelist/testing',
|
||||
key_name='testing'
|
||||
) as rpc_call:
|
||||
res = await rpc_call(
|
||||
'tg+580213293', 'txt2img', {
|
||||
'prompt': 'red old tractor in a sunny wheat field',
|
||||
'step': 28,
|
||||
'width': width, 'height': height,
|
||||
'guidance': 7.5,
|
||||
'seed': None,
|
||||
'algo': list(ALGOS.keys())[i],
|
||||
'upscaler': None
|
||||
})
|
||||
|
||||
if 'error' in res.result:
|
||||
raise SkynetDGPUComputeError(json.dumps(res.result))
|
||||
|
||||
img_raw = base64.b64decode(bytes.fromhex(res.result['img']))
|
||||
img_sha = sha256(img_raw).hexdigest()
|
||||
img = Image.frombytes(
|
||||
'RGB', (width, height), img_raw)
|
||||
|
||||
if expect_unique and img_sha in _images:
|
||||
raise ValueError('Duplicated image sha: {img_sha}')
|
||||
|
||||
_images.add(img_sha)
|
||||
|
||||
logging.info(f'img sha256: {img_sha} size: {len(img_raw)}')
|
||||
|
||||
assert len(img_raw) > 100000
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'dgpu_workers', [(1, ['midj'])], indirect=True)
|
||||
async def test_dgpu_worker_compute_error(dgpu_workers):
|
||||
'''Attempt to generate a huge image and check we get the right error,
|
||||
then generate a smaller image to show gpu worker recovery
|
||||
'''
|
||||
|
||||
async with open_skynet_rpc(
|
||||
security=True,
|
||||
cert_name='whitelist/testing',
|
||||
key_name='testing'
|
||||
) as test_rpc:
|
||||
await wait_for_dgpus(test_rpc, 1)
|
||||
|
||||
with pytest.raises(SkynetDGPUComputeError) as e:
|
||||
await check_request_img(0, width=4096, height=4096)
|
||||
|
||||
logging.info(e)
|
||||
|
||||
await check_request_img(0)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'dgpu_workers', [(1, ['midj', 'stable'])], indirect=True)
|
||||
async def test_dgpu_workers(dgpu_workers):
|
||||
'''Generate two images in a single dgpu worker using
|
||||
two different models.
|
||||
'''
|
||||
|
||||
async with open_skynet_rpc(
|
||||
security=True,
|
||||
cert_name='whitelist/testing',
|
||||
key_name='testing'
|
||||
) as test_rpc:
|
||||
await wait_for_dgpus(test_rpc, 1)
|
||||
|
||||
await check_request_img(0)
|
||||
await check_request_img(1)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'dgpu_workers', [(2, ['midj'])], indirect=True)
|
||||
async def test_dgpu_workers_two(dgpu_workers):
|
||||
'''Generate two images in two separate dgpu workers
|
||||
'''
|
||||
async with open_skynet_rpc(
|
||||
security=True,
|
||||
cert_name='whitelist/testing',
|
||||
key_name='testing'
|
||||
) as test_rpc:
|
||||
await wait_for_dgpus(test_rpc, 2)
|
||||
|
||||
async with trio.open_nursery() as n:
|
||||
await n.start(
|
||||
run_skynet,
|
||||
'skynet', '3GbZd6UojbD8V7zWpeFn', 'ancap.tech:34508')
|
||||
|
||||
await trio.sleep(2)
|
||||
|
||||
for i in range(3):
|
||||
n.start_soon(open_dgpu_node)
|
||||
|
||||
await trio.sleep(1)
|
||||
start = time.time()
|
||||
async def request_img():
|
||||
with pynng.Req0(dial=DEFAULT_RPC_ADDR) as rpc_sock:
|
||||
res = await rpc_call(
|
||||
rpc_sock, 'tg+1', 'txt2img', {
|
||||
'prompt': 'test',
|
||||
'step': 28,
|
||||
'width': 512, 'height': 512,
|
||||
'guidance': 7.5,
|
||||
'seed': None,
|
||||
'algo': 'stable',
|
||||
'upscaler': None
|
||||
})
|
||||
|
||||
logging.info(res)
|
||||
|
||||
async with trio.open_nursery() as inner_n:
|
||||
for i in range(3):
|
||||
inner_n.start_soon(request_img)
|
||||
|
||||
logging.info(f'time elapsed: {time.time() - start}')
|
||||
n.cancel_scope.cancel()
|
||||
n.start_soon(check_request_img, 0)
|
||||
n.start_soon(check_request_img, 0)
|
||||
|
||||
|
||||
trio_asyncio.run(main)
|
||||
@pytest.mark.parametrize(
|
||||
'dgpu_workers', [(1, ['midj'])], indirect=True)
|
||||
async def test_dgpu_worker_algo_swap(dgpu_workers):
|
||||
'''Generate an image using a non default model
|
||||
'''
|
||||
async with open_skynet_rpc(
|
||||
security=True,
|
||||
cert_name='whitelist/testing',
|
||||
key_name='testing'
|
||||
) as test_rpc:
|
||||
await wait_for_dgpus(test_rpc, 1)
|
||||
await check_request_img(5)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'dgpu_workers', [(3, ['midj'])], indirect=True)
|
||||
async def test_dgpu_rotation_next_worker(dgpu_workers):
|
||||
'''Connect three dgpu workers, disconnect and check next_worker
|
||||
rotation happens correctly
|
||||
'''
|
||||
async with open_skynet_rpc(
|
||||
security=True,
|
||||
cert_name='whitelist/testing',
|
||||
key_name='testing'
|
||||
) as test_rpc:
|
||||
await wait_for_dgpus(test_rpc, 3)
|
||||
|
||||
res = await test_rpc('testing-rpc', 'dgpu_next')
|
||||
logging.info(res)
|
||||
assert 'ok' in res.result
|
||||
assert res.result['ok'] == 0
|
||||
|
||||
await check_request_img(0)
|
||||
|
||||
res = await test_rpc('testing-rpc', 'dgpu_next')
|
||||
logging.info(res)
|
||||
assert 'ok' in res.result
|
||||
assert res.result['ok'] == 1
|
||||
|
||||
await check_request_img(0)
|
||||
|
||||
res = await test_rpc('testing-rpc', 'dgpu_next')
|
||||
logging.info(res)
|
||||
assert 'ok' in res.result
|
||||
assert res.result['ok'] == 2
|
||||
|
||||
await check_request_img(0)
|
||||
|
||||
res = await test_rpc('testing-rpc', 'dgpu_next')
|
||||
logging.info(res)
|
||||
assert 'ok' in res.result
|
||||
assert res.result['ok'] == 0
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'dgpu_workers', [(3, ['midj'])], indirect=True)
|
||||
async def test_dgpu_rotation_next_worker_disconnect(dgpu_workers):
|
||||
'''Connect three dgpu workers, disconnect the first one and check
|
||||
next_worker rotation happens correctly
|
||||
'''
|
||||
async with open_skynet_rpc(
|
||||
security=True,
|
||||
cert_name='whitelist/testing',
|
||||
key_name='testing'
|
||||
) as test_rpc:
|
||||
await wait_for_dgpus(test_rpc, 3)
|
||||
|
||||
await trio.sleep(3)
|
||||
|
||||
# stop worker who's turn is next
|
||||
for _ in range(2):
|
||||
ec, out = dgpu_workers[0].exec_run(['pkill', '-INT', '-f', 'skynet'])
|
||||
assert ec == 0
|
||||
|
||||
dgpu_workers[0].wait()
|
||||
|
||||
res = await test_rpc('testing-rpc', 'dgpu_workers')
|
||||
logging.info(res)
|
||||
assert 'ok' in res.result
|
||||
assert res.result['ok'] == 2
|
||||
|
||||
async with trio.open_nursery() as n:
|
||||
n.start_soon(check_request_img, 0)
|
||||
n.start_soon(check_request_img, 0)
|
||||
|
||||
|
||||
async def test_dgpu_no_ack_node_disconnect(skynet_running):
|
||||
async with open_skynet_rpc(
|
||||
security=True,
|
||||
cert_name='whitelist/testing',
|
||||
key_name='testing'
|
||||
) as rpc_call:
|
||||
|
||||
res = await rpc_call('dgpu-0', 'dgpu_online')
|
||||
logging.info(res)
|
||||
assert 'ok' in res.result
|
||||
|
||||
await wait_for_dgpus(rpc_call, 1)
|
||||
|
||||
with pytest.raises(SkynetDGPUComputeError) as e:
|
||||
await check_request_img(0)
|
||||
|
||||
assert 'dgpu failed to acknowledge request' in str(e)
|
||||
|
||||
res = await rpc_call('testing-rpc', 'dgpu_workers')
|
||||
logging.info(res)
|
||||
assert 'ok' in res.result
|
||||
assert res.result['ok'] == 0
|
||||
|
||||
|
|
|
@ -1,107 +0,0 @@
|
|||
import trio
|
||||
import tractor
|
||||
|
||||
from skynet_bot.types import *
|
||||
|
||||
@tractor.context
|
||||
async def open_fake_worker(
|
||||
ctx: tractor.Context,
|
||||
start_algo: str,
|
||||
mem_fraction: float
|
||||
):
|
||||
log = tractor.log.get_logger(name='gpu', _root_name='skynet')
|
||||
log.info(f'starting gpu worker with algo {start_algo}...')
|
||||
current_algo = start_algo
|
||||
log.info('pipeline loaded')
|
||||
await ctx.started()
|
||||
async with ctx.open_stream() as bus:
|
||||
async for ireq in bus:
|
||||
if ireq:
|
||||
await bus.send('hello!')
|
||||
else:
|
||||
break
|
||||
|
||||
def test_gpu_worker():
|
||||
log = tractor.log.get_logger(name='root', _root_name='skynet')
|
||||
async def main():
|
||||
async with (
|
||||
tractor.open_nursery(debug_mode=True) as an,
|
||||
trio.open_nursery() as n
|
||||
):
|
||||
portal = await an.start_actor(
|
||||
'gpu_worker',
|
||||
enable_modules=[__name__],
|
||||
debug_mode=True
|
||||
)
|
||||
|
||||
log.info('portal opened')
|
||||
async with (
|
||||
portal.open_context(
|
||||
open_fake_worker,
|
||||
start_algo='midj',
|
||||
mem_fraction=0.6
|
||||
) as (ctx, _),
|
||||
ctx.open_stream() as stream,
|
||||
):
|
||||
log.info('opened worker sending req...')
|
||||
ireq = ImageGenRequest(
|
||||
prompt='a red tractor on a wheat field',
|
||||
step=28,
|
||||
width=512, height=512,
|
||||
guidance=10, seed=None,
|
||||
algo='midj', upscaler=None)
|
||||
|
||||
await stream.send(ireq)
|
||||
log.info('sent, await respnse')
|
||||
async for msg in stream:
|
||||
log.info(f'got {msg}')
|
||||
break
|
||||
|
||||
assert msg == 'hello!'
|
||||
await stream.send(None)
|
||||
log.info('done.')
|
||||
|
||||
await portal.cancel_actor()
|
||||
|
||||
trio.run(main)
|
||||
|
||||
|
||||
def test_gpu_two_workers():
|
||||
async def main():
|
||||
outputs = []
|
||||
async with (
|
||||
tractor.open_actor_cluster(
|
||||
modules=[__name__],
|
||||
count=2,
|
||||
names=[0, 1]) as portal_map,
|
||||
tractor.trionics.gather_contexts((
|
||||
portal.open_context(
|
||||
open_fake_worker,
|
||||
start_algo='midj',
|
||||
mem_fraction=0.333)
|
||||
for portal in portal_map.values()
|
||||
)) as contexts,
|
||||
trio.open_nursery() as n
|
||||
):
|
||||
ireq = ImageGenRequest(
|
||||
prompt='a red tractor on a wheat field',
|
||||
step=28,
|
||||
width=512, height=512,
|
||||
guidance=10, seed=None,
|
||||
algo='midj', upscaler=None)
|
||||
|
||||
async def get_img(i):
|
||||
ctx = contexts[i]
|
||||
async with ctx.open_stream() as stream:
|
||||
await stream.send(ireq)
|
||||
async for img in stream:
|
||||
outputs[i] = img
|
||||
await portal_map[i].cancel_actor()
|
||||
|
||||
n.start_soon(get_img, 0)
|
||||
n.start_soon(get_img, 1)
|
||||
|
||||
|
||||
assert len(outputs) == 2
|
||||
|
||||
trio.run(main)
|
|
@ -7,9 +7,9 @@ import pynng
|
|||
import pytest
|
||||
import trio_asyncio
|
||||
|
||||
from skynet_bot.types import *
|
||||
from skynet_bot.brain import run_skynet
|
||||
from skynet_bot.frontend import open_skynet_rpc
|
||||
from skynet.brain import run_skynet
|
||||
from skynet.structs import *
|
||||
from skynet.frontend import open_skynet_rpc
|
||||
|
||||
|
||||
async def test_skynet_attempt_insecure(skynet_running):
|
||||
|
@ -40,7 +40,7 @@ async def test_skynet_dgpu_connection_simple(skynet_running):
|
|||
|
||||
# connect 1 dgpu
|
||||
res = await rpc_call(
|
||||
'dgpu-0', 'dgpu_online', {'max_tasks': 3})
|
||||
'dgpu-0', 'dgpu_online')
|
||||
logging.info(res)
|
||||
assert 'ok' in res.result
|
||||
|
||||
|
|
Loading…
Reference in New Issue