Rework dgpu client to be single task

Add a lot of dgpu real image gen tests
Modified docker files and environment to allow for quick test relaunch without image rebuild
Rename package from skynet_bot to skynet
Drop tractor usage cause cuda is oriented to just a single proc managing gpu resources
Add ackwnoledge phase to image request for quick gpu disconnected type scenarios
Add click entry point for dgpu
Add posibility to reuse postgres_db fixture on same session by checking if schema init has been already done
pull/2/head
Guillermo Rodriguez 2022-12-17 11:39:42 -03:00
parent d2e676627a
commit f6326ad05c
No known key found for this signature in database
GPG Key ID: EC3AB66D5D83B392
29 changed files with 811 additions and 466 deletions

View File

@ -1,3 +1,9 @@
.git
hf_home hf_home
inputs
outputs outputs
.python-version
.pytest-cache
**/__pycache__
*.egg-info
**/*.key
**/*.cert

3
.gitignore vendored
View File

@ -1,5 +1,8 @@
.python-version .python-version
hf_home hf_home
outputs outputs
secrets
**/__pycache__ **/__pycache__
*.egg-info *.egg-info
**/*.key
**/*.cert

View File

@ -4,10 +4,16 @@ env DEBIAN_FRONTEND=noninteractive
workdir /skynet workdir /skynet
copy requirements.* ./ copy requirements.test.txt requirements.test.txt
copy requirements.txt requirements.txt
copy pytest.ini ./
copy setup.py ./
copy skynet ./skynet
run pip install \ run pip install \
-e . \
-r requirements.txt \ -r requirements.txt \
-r requirements.test.txt -r requirements.test.txt
workdir /scripts copy scripts ./
copy tests ./

View File

@ -5,19 +5,25 @@ env DEBIAN_FRONTEND=noninteractive
workdir /skynet workdir /skynet
copy requirements.* ./ copy requirements.cuda* ./
run pip install -U pip ninja run pip install -U pip ninja
run pip install -r requirements.cuda.0.txt run pip install -r requirements.cuda.0.txt
run pip install -v -r requirements.cuda.1.txt run pip install -v -r requirements.cuda.1.txt
run pip install \ copy requirements.test.txt requirements.test.txt
copy requirements.txt requirements.txt
copy pytest.ini pytest.ini
copy setup.py setup.py
copy skynet skynet
run pip install -e . \
-r requirements.txt \ -r requirements.txt \
-r requirements.test.txt -r requirements.test.txt
env PYTORCH_CUDA_ALLOC_CONF max_split_size_mb:128
env NVIDIA_VISIBLE_DEVICES=all env NVIDIA_VISIBLE_DEVICES=all
env HF_HOME /hf_home env HF_HOME /hf_home
env PYTORCH_CUDA_ALLOC_CONF max_split_size_mb:128 copy scripts scripts
copy tests tests
workdir /scripts

View File

@ -1,6 +1,6 @@
docker build \ docker build \
-t skynet:runtime-cuda \ -t skynet:runtime-cuda \
-f Dockerfile.runtime-cuda . -f Dockerfile.runtime+cuda .
docker build \ docker build \
-t skynet:runtime \ -t skynet:runtime \

View File

@ -1,2 +1,4 @@
[pytest] [pytest]
log_cli = True
log_level = info
trio_mode = true trio_mode = true

View File

@ -1,4 +1,3 @@
pdbpp
scipy scipy
triton triton
accelerate accelerate

View File

@ -1,5 +1,6 @@
pdbpp
pytest pytest
psycopg2 psycopg2
pytest-trio pytest-trio
git+https://github.com/tgoodlet/pytest-dockerctl.git@master#egg=pytest-dockerctl git+https://github.com/guilledk/pytest-dockerctl.git@host_network#egg=pytest-dockerctl

View File

@ -5,5 +5,3 @@ aiohttp
msgspec msgspec
pyOpenSSL pyOpenSSL
trio_asyncio trio_asyncio
git+https://github.com/goodboy/tractor.git@piker_pin#egg=tractor

View File

@ -8,7 +8,7 @@ import sys
from OpenSSL import crypto, SSL from OpenSSL import crypto, SSL
from skynet_bot.constants import DEFAULT_CERTS_DIR from skynet.constants import DEFAULT_CERTS_DIR
def input_or_skip(txt, default): def input_or_skip(txt, default):

View File

@ -1,11 +1,16 @@
from setuptools import setup, find_packages from setuptools import setup, find_packages
setup( setup(
name='skynet-bot', name='skynet',
version='0.1.0a6', version='0.1.0a6',
description='Decentralized compute platform', description='Decentralized compute platform',
author='Guillermo Rodriguez', author='Guillermo Rodriguez',
author_email='guillermo@telos.net', author_email='guillermo@telos.net',
packages=find_packages(), packages=find_packages(),
install_requires=[] entry_points={
'console_scripts': [
'skynet = skynet.cli:skynet',
]
},
install_requires=['click']
) )

View File

@ -8,6 +8,7 @@ import logging
from uuid import UUID from uuid import UUID
from pathlib import Path from pathlib import Path
from functools import partial from functools import partial
from contextlib import asynccontextmanager as acm
from collections import OrderedDict from collections import OrderedDict
import trio import trio
@ -17,7 +18,7 @@ import trio_asyncio
from pynng import TLSConfig from pynng import TLSConfig
from .db import * from .db import *
from .types import * from .structs import *
from .constants import * from .constants import *
@ -27,18 +28,47 @@ class SkynetDGPUOffline(BaseException):
class SkynetDGPUOverloaded(BaseException): class SkynetDGPUOverloaded(BaseException):
... ...
class SkynetDGPUComputeError(BaseException):
...
async def rpc_service(sock, dgpu_bus, db_pool): class SkynetShutdownRequested(BaseException):
...
@acm
async def open_rpc_service(sock, dgpu_bus, db_pool):
nodes = OrderedDict() nodes = OrderedDict()
wip_reqs = {} wip_reqs = {}
fin_reqs = {} fin_reqs = {}
next_worker: Optional[int] = None
def is_worker_busy(nid: int): def connect_node(uid):
for task in nodes[nid]['tasks']: nonlocal next_worker
if task != None: nodes[uid] = {
return False 'task': None
}
logging.info(f'dgpu online: {uid}')
return True if not next_worker:
next_worker = 0
def disconnect_node(uid):
nonlocal next_worker
if uid not in nodes:
return
i = list(nodes.keys()).index(uid)
del nodes[uid]
if i < next_worker:
next_worker -= 1
if len(nodes) == 0:
logging.info('nw: None')
next_worker = None
logging.warning(f'dgpu offline: {uid}')
def is_worker_busy(nid: str):
return nodes[nid]['task'] != None
def are_all_workers_busy(): def are_all_workers_busy():
for nid in nodes.keys(): for nid in nodes.keys():
@ -47,30 +77,55 @@ async def rpc_service(sock, dgpu_bus, db_pool):
return True return True
next_worker: Optional[int] = None
def get_next_worker(): def get_next_worker():
nonlocal next_worker nonlocal next_worker
logging.info('get next_worker called')
logging.info(f'pre next_worker: {next_worker}')
if not next_worker: if next_worker == None:
raise SkynetDGPUOffline raise SkynetDGPUOffline
if are_all_workers_busy(): if are_all_workers_busy():
raise SkynetDGPUOverloaded raise SkynetDGPUOverloaded
while is_worker_busy(next_worker):
nid = list(nodes.keys())[next_worker]
while is_worker_busy(nid):
next_worker += 1 next_worker += 1
if next_worker >= len(nodes): if next_worker >= len(nodes):
next_worker = 0 next_worker = 0
return next_worker nid = list(nodes.keys())[next_worker]
next_worker += 1
if next_worker >= len(nodes):
next_worker = 0
logging.info(f'post next_worker: {next_worker}')
return nid
async def dgpu_image_streamer(): async def dgpu_image_streamer():
nonlocal wip_reqs, fin_reqs nonlocal wip_reqs, fin_reqs
while True: while True:
msg = await dgpu_bus.arecv_msg() msg = await dgpu_bus.arecv_msg()
rid = UUID(bytes=msg.bytes[:16]).hex rid = UUID(bytes=msg.bytes[:16]).hex
img = msg.bytes[16:].hex() raw_msg = msg.bytes[16:]
logging.info(f'streamer got back {rid} of size {len(raw_msg)}')
match raw_msg[:5]:
case b'error':
img = raw_msg.decode()
case b'ack':
img = raw_msg
case _:
img = base64.b64encode(raw_msg).hex()
if rid not in wip_reqs:
continue
fin_reqs[rid] = img fin_reqs[rid] = img
event = wip_reqs[rid] event = wip_reqs[rid]
event.set() event.set()
@ -79,13 +134,14 @@ async def rpc_service(sock, dgpu_bus, db_pool):
async def dgpu_stream_one_img(req: ImageGenRequest): async def dgpu_stream_one_img(req: ImageGenRequest):
nonlocal wip_reqs, fin_reqs, next_worker nonlocal wip_reqs, fin_reqs, next_worker
nid = get_next_worker() nid = get_next_worker()
logging.info(f'dgpu_stream_one_img {next_worker} {nid}') idx = list(nodes.keys()).index(nid)
logging.info(f'dgpu_stream_one_img {idx}/{len(nodes)} {nid}')
rid = uuid.uuid4().hex rid = uuid.uuid4().hex
event = trio.Event() ack_event = trio.Event()
wip_reqs[rid] = event img_event = trio.Event()
wip_reqs[rid] = ack_event
tid = nodes[nid]['tasks'].index(None) nodes[nid]['task'] = rid
nodes[nid]['tasks'][tid] = rid
dgpu_req = DGPUBusRequest( dgpu_req = DGPUBusRequest(
rid=rid, rid=rid,
@ -98,14 +154,37 @@ async def rpc_service(sock, dgpu_bus, db_pool):
await dgpu_bus.asend( await dgpu_bus.asend(
json.dumps(dgpu_req.to_dict()).encode()) json.dumps(dgpu_req.to_dict()).encode())
await event.wait() with trio.move_on_after(4):
await ack_event.wait()
nodes[nid]['tasks'][tid] = None logging.info(f'ack event: {ack_event.is_set()}')
if not ack_event.is_set():
disconnect_node(nid)
raise SkynetDGPUOffline('dgpu failed to acknowledge request')
ack = fin_reqs[rid]
if ack != b'ack':
disconnect_node(nid)
raise SkynetDGPUOffline('dgpu failed to acknowledge request')
wip_reqs[rid] = img_event
with trio.move_on_after(30):
await img_event.wait()
if not img_event.is_set():
disconnect_node(nid)
raise SkynetDGPUComputeError('30 seconds timeout while processing request')
nodes[nid]['task'] = None
img = fin_reqs[rid] img = fin_reqs[rid]
del fin_reqs[rid] del fin_reqs[rid]
logging.info(f'done streaming {img}') logging.info(f'done streaming {len(img)} bytes')
if 'error' in img:
raise SkynetDGPUComputeError(img)
return rid, img return rid, img
@ -122,6 +201,10 @@ async def rpc_service(sock, dgpu_bus, db_pool):
user_config = {**(await get_user_config(conn, user))} user_config = {**(await get_user_config(conn, user))}
del user_config['id'] del user_config['id']
prompt = req.params['prompt'] prompt = req.params['prompt']
user_config= {
key : req.params.get(key, val)
for key, val in user_config.items()
}
req = ImageGenRequest( req = ImageGenRequest(
prompt=prompt, prompt=prompt,
**user_config **user_config
@ -165,9 +248,10 @@ async def rpc_service(sock, dgpu_bus, db_pool):
case _: case _:
logging.warn('unknown method') logging.warn('unknown method')
except SkynetDGPUOffline: except SkynetDGPUOffline as e:
result = { result = {
'error': 'skynet_dgpu_offline' 'error': 'skynet_dgpu_offline',
'message': str(e)
} }
except SkynetDGPUOverloaded: except SkynetDGPUOverloaded:
@ -176,22 +260,22 @@ async def rpc_service(sock, dgpu_bus, db_pool):
'nodes': len(nodes) 'nodes': len(nodes)
} }
except BaseException as e: except SkynetDGPUComputeError as e:
logging.error(e)
result = { result = {
'error': 'skynet_internal_error' 'error': 'skynet_dgpu_compute_error',
'message': str(e)
} }
await rpc_ctx.asend( await rpc_ctx.asend(
json.dumps( json.dumps(
SkynetRPCResponse(result=result).to_dict()).encode()) SkynetRPCResponse(result=result).to_dict()).encode())
async def request_service(n):
async with trio.open_nursery() as n: nonlocal next_worker
n.start_soon(dgpu_image_streamer)
while True: while True:
ctx = sock.new_context() ctx = sock.new_context()
msg = await ctx.arecv_msg() msg = await ctx.arecv_msg()
content = msg.bytes.decode() content = msg.bytes.decode()
req = SkynetRPCRequest(**json.loads(content)) req = SkynetRPCRequest(**json.loads(content))
@ -199,27 +283,14 @@ async def rpc_service(sock, dgpu_bus, db_pool):
result = {} result = {}
if req.method == 'dgpu_online': if req.method == 'skynet_shutdown':
nodes[req.uid] = { raise SkynetShutdownRequested
'tasks': [None for _ in range(req.params['max_tasks'])],
'max_tasks': req.params['max_tasks']
}
logging.info(f'dgpu online: {req.uid}')
if not next_worker: elif req.method == 'dgpu_online':
next_worker = 0 connect_node(req.uid)
elif req.method == 'dgpu_offline': elif req.method == 'dgpu_offline':
i = list(nodes.keys()).index(req.uid) disconnect_node(req.uid)
del nodes[req.uid]
if i < next_worker:
next_worker -= 1
if len(nodes) == 0:
next_worker = None
logging.info(f'dgpu offline: {req.uid}')
elif req.method == 'dgpu_workers': elif req.method == 'dgpu_workers':
result = len(nodes) result = len(nodes)
@ -238,13 +309,22 @@ async def rpc_service(sock, dgpu_bus, db_pool):
result={'ok': result}).to_dict()).encode()) result={'ok': result}).to_dict()).encode())
async with trio.open_nursery() as n:
n.start_soon(dgpu_image_streamer)
n.start_soon(request_service, n)
logging.info('starting rpc service')
yield
logging.info('stopping rpc service')
n.cancel_scope.cancel()
@acm
async def run_skynet( async def run_skynet(
db_user: str = DB_USER, db_user: str = DB_USER,
db_pass: str = DB_PASS, db_pass: str = DB_PASS,
db_host: str = DB_HOST, db_host: str = DB_HOST,
rpc_address: str = DEFAULT_RPC_ADDR, rpc_address: str = DEFAULT_RPC_ADDR,
dgpu_address: str = DEFAULT_DGPU_ADDR, dgpu_address: str = DEFAULT_DGPU_ADDR,
task_status = trio.TASK_STATUS_IGNORED,
security: bool = True security: bool = True
): ):
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
@ -260,8 +340,8 @@ async def run_skynet(
(cert_path).read_text() (cert_path).read_text()
for cert_path in (certs_dir / 'whitelist').glob('*.cert')] for cert_path in (certs_dir / 'whitelist').glob('*.cert')]
logging.info(f'tls_key: {tls_key}') cert_start = tls_cert.index('\n') + 1
logging.info(f'tls_cert: {tls_cert}') logging.info(f'tls_cert: {tls_cert[cert_start:cert_start+64]}...')
logging.info(f'tls_whitelist len: {len(tls_whitelist)}') logging.info(f'tls_whitelist len: {len(tls_whitelist)}')
rpc_address = 'tls+' + rpc_address rpc_address = 'tls+' + rpc_address
@ -271,16 +351,14 @@ async def run_skynet(
own_key_string=tls_key, own_key_string=tls_key,
own_cert_string=tls_cert) own_cert_string=tls_cert)
async with ( with (
trio.open_nursery() as n, pynng.Rep0() as rpc_sock,
open_database_connection( pynng.Bus0() as dgpu_bus
db_user, db_pass, db_host) as db_pool
): ):
logging.info('connected to db.') async with open_database_connection(
with ( db_user, db_pass, db_host) as db_pool:
pynng.Rep0() as rpc_sock,
pynng.Bus0() as dgpu_bus logging.info('connected to db.')
):
if security: if security:
rpc_sock.tls_config = tls_config rpc_sock.tls_config = tls_config
dgpu_bus.tls_config = tls_config dgpu_bus.tls_config = tls_config
@ -288,13 +366,11 @@ async def run_skynet(
rpc_sock.listen(rpc_address) rpc_sock.listen(rpc_address)
dgpu_bus.listen(dgpu_address) dgpu_bus.listen(dgpu_address)
n.start_soon(
rpc_service, rpc_sock, dgpu_bus, db_pool)
task_status.started()
try: try:
await trio.sleep_forever() async with open_rpc_service(rpc_sock, dgpu_bus, db_pool):
yield
except KeyboardInterrupt: except SkynetShutdownRequested:
... ...
logging.info('disconnected from db.')

68
skynet/cli.py 100644
View File

@ -0,0 +1,68 @@
#!/usr/bin/python
import os
import json
from typing import Optional
from functools import partial
import trio
import click
from .dgpu import open_dgpu_node
from .utils import txt2img
from .constants import ALGOS
@click.group()
def skynet(*args, **kwargs):
pass
@skynet.command()
@click.option('--model', '-m', default=ALGOS['midj'])
@click.option(
'--prompt', '-p', default='a red tractor in a wheat field')
@click.option('--output', '-o', default='output.png')
@click.option('--width', '-w', default=512)
@click.option('--height', '-h', default=512)
@click.option('--guidance', '-g', default=10.0)
@click.option('--steps', '-s', default=26)
@click.option('--seed', '-S', default=None)
def txt2img(*args
# model: str,
# prompt: str,
# output: str
# width: int, height: int,
# guidance: float,
# steps: int,
# seed: Optional[int]
):
assert 'HF_TOKEN' in os.environ
txt2img(
os.environ['HF_TOKEN'], *args)
@skynet.group()
def run(*args, **kwargs):
pass
@run.command()
@click.option('--loglevel', '-l', default='warning', help='Logging level')
@click.option(
'--key', '-k', default='dgpu')
@click.option(
'--cert', '-c', default='whitelist/dgpu')
@click.option(
'--algos', '-a', default=None)
def dgpu(
loglevel: str,
key: str,
cert: str,
algos: Optional[str]
):
trio.run(
partial(
open_dgpu_node,
cert,
key_name=key,
initial_algos=json.loads(algos)
))

View File

@ -1,6 +1,6 @@
#!/usr/bin/python #!/usr/bin/python
API_TOKEN = '5880619053:AAFge2UfObw1kCn9Kb7AAyqaIHs_HgM0Fx0' DOCKER_RUNTIME_CUDA = 'skynet:runtime-cuda'
DB_HOST = 'ancap.tech:34508' DB_HOST = 'ancap.tech:34508'
DB_USER = 'skynet' DB_USER = 'skynet'
@ -8,8 +8,8 @@ DB_PASS = 'password'
DB_NAME = 'skynet' DB_NAME = 'skynet'
ALGOS = { ALGOS = {
'stable': 'runwayml/stable-diffusion-v1-5',
'midj': 'prompthero/openjourney', 'midj': 'prompthero/openjourney',
'stable': 'runwayml/stable-diffusion-v1-5',
'hdanime': 'Linaqruf/anything-v3.0', 'hdanime': 'Linaqruf/anything-v3.0',
'waifu': 'hakurei/waifu-diffusion', 'waifu': 'hakurei/waifu-diffusion',
'ghibli': 'nitrosocke/Ghibli-Diffusion', 'ghibli': 'nitrosocke/Ghibli-Diffusion',
@ -122,7 +122,7 @@ DEFAULT_CERT_DGPU = 'dgpu.key'
DEFAULT_RPC_ADDR = 'tcp://127.0.0.1:41000' DEFAULT_RPC_ADDR = 'tcp://127.0.0.1:41000'
DEFAULT_DGPU_ADDR = 'tcp://127.0.0.1:41069' DEFAULT_DGPU_ADDR = 'tcp://127.0.0.1:41069'
DEFAULT_DGPU_MAX_TASKS = 3 DEFAULT_DGPU_MAX_TASKS = 2
DEFAULT_INITAL_ALGOS = ['midj', 'stable', 'ink'] DEFAULT_INITAL_ALGOS = ['midj', 'stable', 'ink']
DATE_FORMAT = '%B the %dth %Y, %H:%M:%S' DATE_FORMAT = '%B the %dth %Y, %H:%M:%S'

View File

@ -7,6 +7,9 @@ from contextlib import asynccontextmanager as acm
import trio import trio
import triopg import triopg
import trio_asyncio
from asyncpg.exceptions import UndefinedColumnError
from .constants import * from .constants import *
@ -72,13 +75,22 @@ async def open_database_connection(
db_host: str = DB_HOST, db_host: str = DB_HOST,
db_name: str = DB_NAME db_name: str = DB_NAME
): ):
async with triopg.create_pool( async with trio_asyncio.open_loop() as loop:
dsn=f'postgres://{db_user}:{db_pass}@{db_host}/{db_name}' async with triopg.create_pool(
) as pool_conn: dsn=f'postgres://{db_user}:{db_pass}@{db_host}/{db_name}'
async with pool_conn.acquire() as conn: ) as pool_conn:
await conn.execute(DB_INIT_SQL) async with pool_conn.acquire() as conn:
res = await conn.execute(f'''
select distinct table_schema
from information_schema.tables
where table_schema = \'{db_name}\'
''')
if '1' in res:
logging.info('schema already in db, skipping init')
else:
await conn.execute(DB_INIT_SQL)
yield pool_conn yield pool_conn
async def get_user(conn, uid: str): async def get_user(conn, uid: str):
@ -135,6 +147,7 @@ async def new_user(conn, uid: str):
tg_id, generated, joined, last_prompt, role) tg_id, generated, joined, last_prompt, role)
VALUES($1, $2, $3, $4, $5) VALUES($1, $2, $3, $4, $5)
ON CONFLICT DO NOTHING
''') ''')
await stmt.fetch( await stmt.fetch(
tg_id, 0, date, None, DEFAULT_ROLE tg_id, 0, date, None, DEFAULT_ROLE
@ -147,6 +160,7 @@ async def new_user(conn, uid: str):
id, algo, step, width, height, seed, guidance, upscaler) id, algo, step, width, height, seed, guidance, upscaler)
VALUES($1, $2, $3, $4, $5, $6, $7, $8) VALUES($1, $2, $3, $4, $5, $6, $7, $8)
ON CONFLICT DO NOTHING
''') ''')
user = await stmt.fetch( user = await stmt.fetch(
new_uid, new_uid,

197
skynet/dgpu.py 100644
View File

@ -0,0 +1,197 @@
#!/usr/bin/python
import gc
import io
import trio
import json
import uuid
import random
import logging
from typing import List, Optional
from pathlib import Path
from contextlib import AsyncExitStack
import pynng
import torch
from pynng import TLSConfig
from diffusers import (
StableDiffusionPipeline,
EulerAncestralDiscreteScheduler
)
from .structs import *
from .constants import *
from .frontend import open_skynet_rpc
def pipeline_for(algo: str, mem_fraction: float = 1.0):
assert torch.cuda.is_available()
torch.cuda.empty_cache()
torch.cuda.set_per_process_memory_fraction(mem_fraction)
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
params = {
'torch_dtype': torch.float16,
'safety_checker': None
}
if algo == 'stable':
params['revision'] = 'fp16'
pipe = StableDiffusionPipeline.from_pretrained(
ALGOS[algo], **params)
pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(
pipe.scheduler.config)
pipe.enable_vae_slicing()
return pipe.to("cuda")
class DGPUComputeError(BaseException):
...
async def open_dgpu_node(
cert_name: str,
key_name: Optional[str],
rpc_address: str = DEFAULT_RPC_ADDR,
dgpu_address: str = DEFAULT_DGPU_ADDR,
initial_algos: Optional[List[str]] = None,
security: bool = True
):
logging.basicConfig(level=logging.INFO)
logging.info(f'starting dgpu node!')
name = uuid.uuid4()
logging.info(f'loading models...')
initial_algos = (
initial_algos
if initial_algos else DEFAULT_INITAL_ALGOS
)
models = {}
for algo in initial_algos:
models[algo] = {
'pipe': pipeline_for(algo),
'generated': 0
}
logging.info(f'loaded {algo}.')
logging.info('memory summary:\n')
logging.info(torch.cuda.memory_summary())
async def gpu_compute_one(ireq: ImageGenRequest):
if ireq.algo not in models:
least_used = list(models.keys())[0]
for model in models:
if models[least_used]['generated'] > models[model]['generated']:
least_used = model
del models[least_used]
gc.collect()
models[ireq.algo] = {
'pipe': pipeline_for(ireq.algo),
'generated': 0
}
seed = ireq.seed if ireq.seed else random.randint(0, 2 ** 64)
try:
image = models[ireq.algo]['pipe'](
ireq.prompt,
width=ireq.width,
height=ireq.height,
guidance_scale=ireq.guidance,
num_inference_steps=ireq.step,
generator=torch.Generator("cuda").manual_seed(seed)
).images[0]
return image.tobytes()
except BaseException as e:
logging.error(e)
raise DGPUComputeError(str(e))
finally:
torch.cuda.empty_cache()
async with open_skynet_rpc(
security=security,
cert_name=cert_name,
key_name=key_name
) as rpc_call:
tls_config = None
if security:
# load tls certs
if not key_name:
key_name = certs_name
certs_dir = Path(DEFAULT_CERTS_DIR).resolve()
skynet_cert_path = certs_dir / 'brain.cert'
tls_cert_path = certs_dir / f'{cert_name}.cert'
tls_key_path = certs_dir / f'{key_name}.key'
skynet_cert = skynet_cert_path.read_text()
tls_cert = tls_cert_path.read_text()
tls_key = tls_key_path.read_text()
logging.info(f'skynet cert: {skynet_cert_path}')
logging.info(f'dgpu cert: {tls_cert_path}')
logging.info(f'dgpu key: {tls_key_path}')
dgpu_address = 'tls+' + dgpu_address
tls_config = TLSConfig(
TLSConfig.MODE_CLIENT,
own_key_string=tls_key,
own_cert_string=tls_cert,
ca_string=skynet_cert)
logging.info(f'connecting to {dgpu_address}')
with pynng.Bus0() as dgpu_sock:
dgpu_sock.tls_config = tls_config
dgpu_sock.dial(dgpu_address)
res = await rpc_call(name.hex, 'dgpu_online')
logging.info(res)
assert 'ok' in res.result
try:
while True:
msg = await dgpu_sock.arecv()
req = DGPUBusRequest(
**json.loads(msg.decode()))
if req.nid != name.hex:
logging.info('witnessed request {req.rid}, for {req.nid}')
continue
# send ack
await dgpu_sock.asend(
bytes.fromhex(req.rid) + b'ack')
logging.info(f'sent ack, processing {req.rid}...')
try:
img = await gpu_compute_one(
ImageGenRequest(**req.params))
except DGPUComputeError as e:
img = b'error' + str(e).encode()
await dgpu_sock.asend(
bytes.fromhex(req.rid) + img)
except KeyboardInterrupt:
logging.info('interrupt caught, stopping...')
finally:
res = await rpc_call(name.hex, 'dgpu_offline')
logging.info(res)
assert 'ok' in res.result

View File

@ -10,7 +10,7 @@ import pynng
from pynng import TLSConfig from pynng import TLSConfig
from ..types import SkynetRPCRequest, SkynetRPCResponse from ..structs import SkynetRPCRequest, SkynetRPCResponse
from ..constants import * from ..constants import *

57
skynet/utils.py 100644
View File

@ -0,0 +1,57 @@
#!/usr/bin/python
import random
from typing import Optional
from pathlib import Path
import torch
from diffusers import StableDiffusionPipeline
from huggingface_hub import login
def txt2img(
hf_token: str,
model_name: str,
prompt: str,
output: str,
width: int, height: int,
guidance: float,
steps: int,
seed: Optional[int]
):
assert torch.cuda.is_available()
torch.cuda.empty_cache()
torch.cuda.set_per_process_memory_fraction(0.333)
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
login(token=hf_token)
params = {
'torch_dtype': torch.float16,
'safety_checker': None
}
if model_name == 'runwayml/stable-diffusion-v1-5':
params['revision'] = 'fp16'
pipe = StableDiffusionPipeline.from_pretrained(
model_name, **params)
pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(
pipe.scheduler.config)
pipe = pipe.to("cuda")
seed = ireq.seed if ireq.seed else random.randint(0, 2 ** 64)
prompt = prompt
image = pipe(
prompt,
width=width,
height=height,
guidance_scale=guidance, num_inference_steps=steps,
generator=torch.Generator("cuda").manual_seed(seed)
).images[0]
image.save(output)

View File

@ -1,124 +0,0 @@
#!/usr/bin/python
import trio
import json
import uuid
import logging
import pynng
import tractor
from . import gpu
from .gpu import open_gpu_worker
from .types import *
from .constants import *
from .frontend import rpc_call
async def open_dgpu_node(
cert_name: str,
key_name: Optional[str],
rpc_address: str = DEFAULT_RPC_ADDR,
dgpu_address: str = DEFAULT_DGPU_ADDR,
dgpu_max_tasks: int = DEFAULT_DGPU_MAX_TASKS,
initial_algos: str = DEFAULT_INITAL_ALGOS,
security: bool = True
):
logging.basicConfig(level=logging.INFO)
name = uuid.uuid4()
workers = initial_algos.copy()
tasks = [None for _ in range(dgpu_max_tasks)]
portal_map: dict[int, tractor.Portal]
contexts: dict[int, tractor.Context]
def get_next_worker(need_algo: str):
nonlocal workers, tasks
for task, algo in zip(workers, tasks):
if need_algo == algo and not task:
return workers.index(need_algo)
return tasks.index(None)
async def gpu_streamer(
ctx: tractor.Context,
nid: int
):
nonlocal tasks
async with ctx.open_stream() as stream:
async for img in stream:
tasks[nid]['res'] = img
tasks[nid]['event'].set()
async def gpu_compute_one(ireq: ImageGenRequest):
wid = get_next_worker(ireq.algo)
event = trio.Event()
workers[wid] = ireq.algo
tasks[wid] = {
'res': None, 'event': event}
await contexts[i].send(ireq)
await event.wait()
img = tasks[wid]['res']
tasks[wid] = None
return img
async with open_skynet_rpc(
security=security,
cert_name=cert_name,
key_name=key_name
) as rpc_call:
with pynng.Bus0(dial=dgpu_address) as dgpu_sock:
async def _process_dgpu_req(req: DGPUBusRequest):
img = await gpu_compute_one(
ImageGenRequest(**req.params))
await dgpu_sock.asend(
bytes.fromhex(req.rid) + img)
res = await rpc_call(
name.hex, 'dgpu_online', {'max_tasks': dgpu_max_tasks})
logging.info(res)
assert 'ok' in res.result
async with (
tractor.open_actor_cluster(
modules=['skynet_bot.gpu'],
count=dgpu_max_tasks,
names=[i for i in range(dgpu_max_tasks)]
) as portal_map,
trio.open_nursery() as n
):
logging.info(f'starting {dgpu_max_tasks} gpu workers')
async with tractor.gather_contexts((
portal.open_context(
open_gpu_worker, algo, 1.0 / dgpu_max_tasks)
for portal in portal_map.values()
)) as contexts:
contexts = {i: ctx for i, ctx in enumerate(contexts)}
for i, ctx in contexts.items():
n.start_soon(
gpu_streamer, ctx, i)
try:
while True:
msg = await dgpu_sock.arecv()
req = DGPUBusRequest(
**json.loads(msg.decode()))
if req.nid != name.hex:
continue
logging.info(f'dgpu: {name}, req: {req}')
n.start_soon(
_process_dgpu_req, req)
except KeyboardInterrupt:
...
res = await rpc_call(name.hex, 'dgpu_offline')
logging.info(res)
assert 'ok' in res.result

View File

@ -1,77 +0,0 @@
#!/usr/bin/python
import io
import random
import logging
import torch
import tractor
from diffusers import (
StableDiffusionPipeline,
EulerAncestralDiscreteScheduler
)
from .types import ImageGenRequest
from .constants import ALGOS
def pipeline_for(algo: str, mem_fraction: float):
assert torch.cuda.is_available()
torch.cuda.empty_cache()
torch.cuda.set_per_process_memory_fraction(mem_fraction)
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
params = {
'torch_dtype': torch.float16,
'safety_checker': None
}
if algo == 'stable':
params['revision'] = 'fp16'
pipe = StableDiffusionPipeline.from_pretrained(
ALGOS[algo], **params)
pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(
pipe.scheduler.config)
return pipe.to("cuda")
@tractor.context
async def open_gpu_worker(
ctx: tractor.Context,
start_algo: str,
mem_fraction: float
):
log = tractor.log.get_logger(name='gpu', _root_name='skynet')
log.info(f'starting gpu worker with algo {start_algo}...')
current_algo = start_algo
with torch.no_grad():
pipe = pipeline_for(current_algo, mem_fraction)
log.info('pipeline loaded')
await ctx.started()
async with ctx.open_stream() as bus:
async for ireq in bus:
if ireq.algo != current_algo:
current_algo = ireq.algo
pipe = pipeline_for(current_algo, mem_fraction)
seed = ireq.seed if ireq.seed else random.randint(0, 2 ** 64)
image = pipe(
ireq.prompt,
width=ireq.width,
height=ireq.height,
guidance_scale=ireq.guidance,
num_inference_steps=ireq.step,
generator=torch.Generator("cuda").manual_seed(seed)
).images[0]
torch.cuda.empty_cache()
# convert PIL.Image to BytesIO
img_bytes = io.BytesIO()
image.save(img_bytes, format='PNG')
await bus.send(img_bytes.getvalue())

View File

@ -1,2 +0,0 @@
from OpenSSL.crypto import load_publickey, FILETYPE_PEM, verify, X509

View File

@ -1,9 +0,0 @@
docker run \
-it \
--rm \
--gpus=all \
--mount type=bind,source="$(pwd)",target=/skynet \
skynet:runtime-cuda \
bash -c \
"cd /skynet && pip install -e . && \
pytest $1 --log-cli-level=info"

View File

@ -1,21 +1,25 @@
#!/usr/bin/python #!/usr/bin/python
import os
import json
import time import time
import random import random
import string import string
import logging import logging
from functools import partial from functools import partial
from pathlib import Path
import trio import trio
import pytest import pytest
import psycopg2 import psycopg2
import trio_asyncio import trio_asyncio
from docker.types import Mount, DeviceRequest
from psycopg2.extensions import ISOLATION_LEVEL_AUTOCOMMIT from psycopg2.extensions import ISOLATION_LEVEL_AUTOCOMMIT
from skynet_bot.constants import * from skynet.constants import *
from skynet_bot.brain import run_skynet from skynet.brain import run_skynet
@pytest.fixture(scope='session') @pytest.fixture(scope='session')
@ -29,6 +33,7 @@ def postgres_db(dockerctl):
with dockerctl.run( with dockerctl.run(
'postgres', 'postgres',
name='skynet-test-postgres',
ports={'5432/tcp': None}, ports={'5432/tcp': None},
environment={ environment={
'POSTGRES_PASSWORD': rpassword 'POSTGRES_PASSWORD': rpassword
@ -67,6 +72,8 @@ def postgres_db(dockerctl):
cursor.execute( cursor.execute(
f'GRANT ALL PRIVILEGES ON DATABASE {DB_NAME} TO {DB_USER}') f'GRANT ALL PRIVILEGES ON DATABASE {DB_NAME} TO {DB_USER}')
conn.close()
logging.info('done.') logging.info('done.')
yield container, password, host yield container, password, host
@ -74,16 +81,44 @@ def postgres_db(dockerctl):
@pytest.fixture @pytest.fixture
async def skynet_running(postgres_db): async def skynet_running(postgres_db):
db_container, db_pass, db_host = postgres_db db_container, db_pass, db_host = postgres_db
async with (
trio_asyncio.open_loop(), async with run_skynet(
trio.open_nursery() as n db_pass=db_pass,
db_host=db_host
): ):
await n.start(
partial(run_skynet,
db_pass=db_pass,
db_host=db_host))
yield yield
n.cancel_scope.cancel()
@pytest.fixture
def dgpu_workers(request, dockerctl, skynet_running):
devices = [DeviceRequest(capabilities=[['gpu']])]
mounts = [Mount(
'/skynet', str(Path().resolve()), type='bind')]
num_containers, initial_algos = request.param
cmd = f'''
pip install -e . && \
skynet run dgpu --algos=\'{json.dumps(initial_algos)}\'
'''
logging.info(f'launching: \n{cmd}')
with dockerctl.run(
DOCKER_RUNTIME_CUDA,
name='skynet-test-runtime-cuda',
command=['bash', '-c', cmd],
environment={
'HF_TOKEN': os.environ['HF_TOKEN'],
'HF_HOME': '/skynet/hf_home'
},
network='host',
mounts=mounts,
device_requests=devices,
num=num_containers
) as containers:
yield containers
#for i, container in enumerate(containers):
# logging.info(f'container {i} logs:')
# logging.info(container.logs().decode())

View File

@ -1,57 +1,248 @@
#!/usr/bin/python #!/usr/bin/python
import io
import time import time
import json import json
import base64
import logging import logging
from hashlib import sha256
from functools import partial
import trio import trio
import pynng import pytest
import tractor import tractor
import trio_asyncio import trio_asyncio
from skynet_bot.gpu import open_gpu_worker from PIL import Image
from skynet_bot.dgpu import open_dgpu_node
from skynet_bot.types import * from skynet.brain import SkynetDGPUComputeError
from skynet_bot.brain import run_skynet from skynet.constants import *
from skynet_bot.constants import * from skynet.frontend import open_skynet_rpc
from skynet_bot.frontend import open_skynet_rpc, rpc_call
async def wait_for_dgpus(rpc, amount: int, timeout: float = 30.0):
    '''Poll the brain over *rpc* until at least *amount* dgpu workers
    report online; fail with AssertionError after *timeout* seconds.

    (Previous version carried a ``gpu_ready`` flag that was never set,
    making it dead state -- removed.)
    '''
    deadline = time.time() + timeout
    while time.time() < deadline:
        res = await rpc('dgpu-test', 'dgpu_workers')
        logging.info(res)

        if res.result['ok'] >= amount:
            return

        await trio.sleep(1)

    assert False, f'{amount} dgpu worker(s) not online after {timeout}s'
# sha256 digests of every image generated so far in the test session
_images = set()


async def check_request_img(
    i: int,
    width: int = 512,
    height: int = 512,
    expect_unique=True
):
    '''Request a txt2img generation using the *i*-th known algo and
    sanity check the result: payload decodes, has the right pixel buffer
    size and (optionally) is not a duplicate of a previous image.

    Raises SkynetDGPUComputeError if the brain reports an error.
    '''
    global _images

    async with open_skynet_rpc(
        security=True,
        cert_name='whitelist/testing',
        key_name='testing'
    ) as rpc_call:
        res = await rpc_call(
            'tg+580213293', 'txt2img', {
                'prompt': 'red old tractor in a sunny wheat field',
                'step': 28,
                'width': width, 'height': height,
                'guidance': 7.5,
                'seed': None,
                'algo': list(ALGOS.keys())[i],
                'upscaler': None
            })

        if 'error' in res.result:
            raise SkynetDGPUComputeError(json.dumps(res.result))

        img_raw = base64.b64decode(bytes.fromhex(res.result['img']))
        img_sha = sha256(img_raw).hexdigest()
        # raises if the payload is not a valid width x height RGB buffer
        img = Image.frombytes(
            'RGB', (width, height), img_raw)

        if expect_unique and img_sha in _images:
            # was a plain string literal missing the f prefix, so the
            # digest never appeared in the message
            raise ValueError(f'Duplicated image sha: {img_sha}')

        _images.add(img_sha)

        logging.info(f'img sha256: {img_sha} size: {len(img_raw)}')

        assert len(img_raw) > 100000
@pytest.mark.parametrize(
    'dgpu_workers', [(1, ['midj'])], indirect=True)
async def test_dgpu_worker_compute_error(dgpu_workers):
    '''Ask for an image too big for the gpu and expect the proper
    compute error, then confirm the worker recovered by serving a
    normal sized request.
    '''
    async with open_skynet_rpc(
        security=True,
        cert_name='whitelist/testing',
        key_name='testing'
    ) as rpc:
        await wait_for_dgpus(rpc, 1)

        # oversized request must surface as a compute error
        with pytest.raises(SkynetDGPUComputeError) as e:
            await check_request_img(0, width=4096, height=4096)

        logging.info(e)

        # worker survived the failure: a regular request succeeds
        await check_request_img(0)
@pytest.mark.parametrize(
    'dgpu_workers', [(1, ['midj', 'stable'])], indirect=True)
async def test_dgpu_workers(dgpu_workers):
    '''One dgpu worker loaded with two models: request one image from
    each model in sequence.
    '''
    async with open_skynet_rpc(
        security=True,
        cert_name='whitelist/testing',
        key_name='testing'
    ) as rpc:
        await wait_for_dgpus(rpc, 1)

        # first model, then the second one
        await check_request_img(0)
        await check_request_img(1)
@pytest.mark.parametrize(
    'dgpu_workers', [(2, ['midj'])], indirect=True)
async def test_dgpu_workers_two(dgpu_workers):
    '''Generate two images in two separate dgpu workers.
    '''
    async with open_skynet_rpc(
        security=True,
        cert_name='whitelist/testing',
        key_name='testing'
    ) as test_rpc:
        await wait_for_dgpus(test_rpc, 2)

        # both generations run concurrently, one per worker
        async with trio.open_nursery() as n:
            n.start_soon(check_request_img, 0)
            n.start_soon(check_request_img, 0)
@pytest.mark.parametrize(
    'dgpu_workers', [(1, ['midj'])], indirect=True)
async def test_dgpu_worker_algo_swap(dgpu_workers):
    '''Generate an image using a non default model, forcing the worker
    to swap its loaded algo.
    '''
    async with open_skynet_rpc(
        security=True,
        cert_name='whitelist/testing',
        key_name='testing'
    ) as test_rpc:
        await wait_for_dgpus(test_rpc, 1)
        # index 5 is not the worker's initially loaded model
        await check_request_img(5)
@pytest.mark.parametrize(
    'dgpu_workers', [(3, ['midj'])], indirect=True)
async def test_dgpu_rotation_next_worker(dgpu_workers):
    '''Connect three dgpu workers and check the next_worker round robin
    advances 0 -> 1 -> 2 -> 0 as requests are served.

    (Previous version copy-pasted the query/assert block four times;
    folded into a helper plus a loop.)
    '''
    async def assert_next_worker(rpc, expected: int):
        # ask the brain which worker is next in rotation
        res = await rpc('testing-rpc', 'dgpu_next')
        logging.info(res)
        assert 'ok' in res.result
        assert res.result['ok'] == expected

    async with open_skynet_rpc(
        security=True,
        cert_name='whitelist/testing',
        key_name='testing'
    ) as test_rpc:
        await wait_for_dgpus(test_rpc, 3)

        # each served request advances the rotation by one
        for expected in (0, 1, 2):
            await assert_next_worker(test_rpc, expected)
            await check_request_img(0)

        # after three requests the rotation wraps back to worker 0
        await assert_next_worker(test_rpc, 0)
@pytest.mark.parametrize(
    'dgpu_workers', [(3, ['midj'])], indirect=True)
async def test_dgpu_rotation_next_worker_disconnect(dgpu_workers):
    '''Connect three dgpu workers, disconnect the first one and check
    next_worker rotation happens correctly
    '''
    async with open_skynet_rpc(
        security=True,
        cert_name='whitelist/testing',
        key_name='testing'
    ) as test_rpc:
        await wait_for_dgpus(test_rpc, 3)

        # give the workers a moment to settle before killing one
        await trio.sleep(3)

        # stop worker whose turn is next
        # NOTE(review): SIGINT is sent twice -- presumably the first one
        # only triggers a graceful shutdown; confirm against dgpu client
        for _ in range(2):
            ec, out = dgpu_workers[0].exec_run(['pkill', '-INT', '-f', 'skynet'])
            assert ec == 0

        # block until the container actually exits
        dgpu_workers[0].wait()

        # brain should now report only two online workers
        res = await test_rpc('testing-rpc', 'dgpu_workers')
        logging.info(res)
        assert 'ok' in res.result
        assert res.result['ok'] == 2

        # remaining two workers each pick up one concurrent request
        async with trio.open_nursery() as n:
            n.start_soon(check_request_img, 0)
            n.start_soon(check_request_img, 0)
async def test_dgpu_no_ack_node_disconnect(skynet_running):
    '''Register a dgpu worker that never acknowledges requests and check
    the brain reports the failure and drops the worker.
    '''
    async with open_skynet_rpc(
        security=True,
        cert_name='whitelist/testing',
        key_name='testing'
    ) as rpc:
        # announce a worker without actually running one
        res = await rpc('dgpu-0', 'dgpu_online')
        logging.info(res)
        assert 'ok' in res.result

        await wait_for_dgpus(rpc, 1)

        # the phantom worker can never ack, so the request must fail
        with pytest.raises(SkynetDGPUComputeError) as e:
            await check_request_img(0)

        assert 'dgpu failed to acknowledge request' in str(e)

        # brain should have dropped the unresponsive worker
        res = await rpc('testing-rpc', 'dgpu_workers')
        logging.info(res)
        assert 'ok' in res.result
        assert res.result['ok'] == 0

View File

@ -1,107 +0,0 @@
import trio
import tractor
from skynet_bot.types import *
@tractor.context
async def open_fake_worker(
    ctx: tractor.Context,
    start_algo: str,
    mem_fraction: float
):
    '''Stand-in gpu worker actor: signals readiness, then answers every
    truthy message on its stream with 'hello!' until a falsy message
    (e.g. None) arrives.

    *start_algo* and *mem_fraction* are only logged / stored here; no
    real pipeline is loaded.
    '''
    log = tractor.log.get_logger(name='gpu', _root_name='skynet')
    log.info(f'starting gpu worker with algo {start_algo}...')
    current_algo = start_algo
    log.info('pipeline loaded')
    # tell the opener we are ready before accepting the stream
    await ctx.started()
    async with ctx.open_stream() as bus:
        async for ireq in bus:
            if ireq:
                await bus.send('hello!')
            else:
                # falsy message is the shutdown signal
                break
def test_gpu_worker():
    '''Spawn a single fake gpu worker actor, send it one image request
    over a stream and check the canned 'hello!' reply comes back.
    '''
    log = tractor.log.get_logger(name='root', _root_name='skynet')
    async def main():
        async with (
            tractor.open_nursery(debug_mode=True) as an,
            trio.open_nursery() as n
        ):
            # spawn the worker actor with this module loaded so it can
            # resolve open_fake_worker
            portal = await an.start_actor(
                'gpu_worker',
                enable_modules=[__name__],
                debug_mode=True
            )
            log.info('portal opened')
            async with (
                portal.open_context(
                    open_fake_worker,
                    start_algo='midj',
                    mem_fraction=0.6
                ) as (ctx, _),
                ctx.open_stream() as stream,
            ):
                log.info('opened worker sending req...')
                ireq = ImageGenRequest(
                    prompt='a red tractor on a wheat field',
                    step=28,
                    width=512, height=512,
                    guidance=10, seed=None,
                    algo='midj', upscaler=None)
                await stream.send(ireq)
                log.info('sent, await respnse')
                # take only the first reply off the stream
                async for msg in stream:
                    log.info(f'got {msg}')
                    break
                assert msg == 'hello!'
                # falsy message tells the worker to shut down
                await stream.send(None)
                log.info('done.')
            await portal.cancel_actor()
    trio.run(main)
def test_gpu_two_workers():
    '''Spawn two fake gpu worker actors and collect one reply from each.

    Fix: ``outputs`` was an empty list assigned by index
    (``outputs[i] = img``) which raises IndexError; preallocate one slot
    per worker instead.
    '''
    async def main():
        # one slot per worker; indexed assignment keeps results ordered
        outputs = [None, None]
        async with (
            tractor.open_actor_cluster(
                modules=[__name__],
                count=2,
                names=[0, 1]) as portal_map,
            tractor.trionics.gather_contexts((
                portal.open_context(
                    open_fake_worker,
                    start_algo='midj',
                    mem_fraction=0.333)
                for portal in portal_map.values()
            )) as contexts,
            trio.open_nursery() as n
        ):
            ireq = ImageGenRequest(
                prompt='a red tractor on a wheat field',
                step=28,
                width=512, height=512,
                guidance=10, seed=None,
                algo='midj', upscaler=None)
            async def get_img(i):
                ctx = contexts[i]
                async with ctx.open_stream() as stream:
                    await stream.send(ireq)
                    async for img in stream:
                        outputs[i] = img
                        await portal_map[i].cancel_actor()

            n.start_soon(get_img, 0)
            n.start_soon(get_img, 1)

        # nursery exited: both tasks finished and filled their slot
        assert len(outputs) == 2
        assert all(out is not None for out in outputs)
    trio.run(main)

View File

@ -7,9 +7,9 @@ import pynng
import pytest import pytest
import trio_asyncio import trio_asyncio
from skynet_bot.types import * from skynet.brain import run_skynet
from skynet_bot.brain import run_skynet from skynet.structs import *
from skynet_bot.frontend import open_skynet_rpc from skynet.frontend import open_skynet_rpc
async def test_skynet_attempt_insecure(skynet_running): async def test_skynet_attempt_insecure(skynet_running):
@ -40,7 +40,7 @@ async def test_skynet_dgpu_connection_simple(skynet_running):
# connect 1 dgpu # connect 1 dgpu
res = await rpc_call( res = await rpc_call(
'dgpu-0', 'dgpu_online', {'max_tasks': 3}) 'dgpu-0', 'dgpu_online')
logging.info(res) logging.info(res)
assert 'ok' in res.result assert 'ok' in res.result