mirror of https://github.com/skygpu/skynet.git
Rework dgpu client to be single task
Add a lot of dgpu real image gen tests Modified docker files and environment to allow for quick test relaunch without image rebuild Rename package from skynet_bot to skynet Drop tractor usage cause cuda is oriented to just a single proc managing gpu resources Add ackwnoledge phase to image request for quick gpu disconnected type scenarios Add click entry point for dgpu Add posibility to reuse postgres_db fixture on same session by checking if schema init has been already donepull/2/head
parent
d2e676627a
commit
f6326ad05c
|
@ -1,3 +1,9 @@
|
||||||
|
.git
|
||||||
hf_home
|
hf_home
|
||||||
inputs
|
|
||||||
outputs
|
outputs
|
||||||
|
.python-version
|
||||||
|
.pytest-cache
|
||||||
|
**/__pycache__
|
||||||
|
*.egg-info
|
||||||
|
**/*.key
|
||||||
|
**/*.cert
|
||||||
|
|
|
@ -1,5 +1,8 @@
|
||||||
.python-version
|
.python-version
|
||||||
hf_home
|
hf_home
|
||||||
outputs
|
outputs
|
||||||
|
secrets
|
||||||
**/__pycache__
|
**/__pycache__
|
||||||
*.egg-info
|
*.egg-info
|
||||||
|
**/*.key
|
||||||
|
**/*.cert
|
||||||
|
|
|
@ -4,10 +4,16 @@ env DEBIAN_FRONTEND=noninteractive
|
||||||
|
|
||||||
workdir /skynet
|
workdir /skynet
|
||||||
|
|
||||||
copy requirements.* ./
|
copy requirements.test.txt requirements.test.txt
|
||||||
|
copy requirements.txt requirements.txt
|
||||||
|
copy pytest.ini ./
|
||||||
|
copy setup.py ./
|
||||||
|
copy skynet ./skynet
|
||||||
|
|
||||||
run pip install \
|
run pip install \
|
||||||
|
-e . \
|
||||||
-r requirements.txt \
|
-r requirements.txt \
|
||||||
-r requirements.test.txt
|
-r requirements.test.txt
|
||||||
|
|
||||||
workdir /scripts
|
copy scripts ./
|
||||||
|
copy tests ./
|
||||||
|
|
|
@ -5,19 +5,25 @@ env DEBIAN_FRONTEND=noninteractive
|
||||||
|
|
||||||
workdir /skynet
|
workdir /skynet
|
||||||
|
|
||||||
copy requirements.* ./
|
copy requirements.cuda* ./
|
||||||
|
|
||||||
run pip install -U pip ninja
|
run pip install -U pip ninja
|
||||||
run pip install -r requirements.cuda.0.txt
|
run pip install -r requirements.cuda.0.txt
|
||||||
run pip install -v -r requirements.cuda.1.txt
|
run pip install -v -r requirements.cuda.1.txt
|
||||||
|
|
||||||
run pip install \
|
copy requirements.test.txt requirements.test.txt
|
||||||
|
copy requirements.txt requirements.txt
|
||||||
|
copy pytest.ini pytest.ini
|
||||||
|
copy setup.py setup.py
|
||||||
|
copy skynet skynet
|
||||||
|
|
||||||
|
run pip install -e . \
|
||||||
-r requirements.txt \
|
-r requirements.txt \
|
||||||
-r requirements.test.txt
|
-r requirements.test.txt
|
||||||
|
|
||||||
|
env PYTORCH_CUDA_ALLOC_CONF max_split_size_mb:128
|
||||||
env NVIDIA_VISIBLE_DEVICES=all
|
env NVIDIA_VISIBLE_DEVICES=all
|
||||||
env HF_HOME /hf_home
|
env HF_HOME /hf_home
|
||||||
|
|
||||||
env PYTORCH_CUDA_ALLOC_CONF max_split_size_mb:128
|
copy scripts scripts
|
||||||
|
copy tests tests
|
||||||
workdir /scripts
|
|
|
@ -1,6 +1,6 @@
|
||||||
docker build \
|
docker build \
|
||||||
-t skynet:runtime-cuda \
|
-t skynet:runtime-cuda \
|
||||||
-f Dockerfile.runtime-cuda .
|
-f Dockerfile.runtime+cuda .
|
||||||
|
|
||||||
docker build \
|
docker build \
|
||||||
-t skynet:runtime \
|
-t skynet:runtime \
|
||||||
|
|
|
@ -1,2 +1,4 @@
|
||||||
[pytest]
|
[pytest]
|
||||||
|
log_cli = True
|
||||||
|
log_level = info
|
||||||
trio_mode = true
|
trio_mode = true
|
||||||
|
|
|
@ -1,4 +1,3 @@
|
||||||
pdbpp
|
|
||||||
scipy
|
scipy
|
||||||
triton
|
triton
|
||||||
accelerate
|
accelerate
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
|
pdbpp
|
||||||
pytest
|
pytest
|
||||||
psycopg2
|
psycopg2
|
||||||
pytest-trio
|
pytest-trio
|
||||||
|
|
||||||
git+https://github.com/tgoodlet/pytest-dockerctl.git@master#egg=pytest-dockerctl
|
git+https://github.com/guilledk/pytest-dockerctl.git@host_network#egg=pytest-dockerctl
|
||||||
|
|
|
@ -5,5 +5,3 @@ aiohttp
|
||||||
msgspec
|
msgspec
|
||||||
pyOpenSSL
|
pyOpenSSL
|
||||||
trio_asyncio
|
trio_asyncio
|
||||||
|
|
||||||
git+https://github.com/goodboy/tractor.git@piker_pin#egg=tractor
|
|
||||||
|
|
|
@ -8,7 +8,7 @@ import sys
|
||||||
|
|
||||||
from OpenSSL import crypto, SSL
|
from OpenSSL import crypto, SSL
|
||||||
|
|
||||||
from skynet_bot.constants import DEFAULT_CERTS_DIR
|
from skynet.constants import DEFAULT_CERTS_DIR
|
||||||
|
|
||||||
|
|
||||||
def input_or_skip(txt, default):
|
def input_or_skip(txt, default):
|
||||||
|
|
9
setup.py
9
setup.py
|
@ -1,11 +1,16 @@
|
||||||
from setuptools import setup, find_packages
|
from setuptools import setup, find_packages
|
||||||
|
|
||||||
setup(
|
setup(
|
||||||
name='skynet-bot',
|
name='skynet',
|
||||||
version='0.1.0a6',
|
version='0.1.0a6',
|
||||||
description='Decentralized compute platform',
|
description='Decentralized compute platform',
|
||||||
author='Guillermo Rodriguez',
|
author='Guillermo Rodriguez',
|
||||||
author_email='guillermo@telos.net',
|
author_email='guillermo@telos.net',
|
||||||
packages=find_packages(),
|
packages=find_packages(),
|
||||||
install_requires=[]
|
entry_points={
|
||||||
|
'console_scripts': [
|
||||||
|
'skynet = skynet.cli:skynet',
|
||||||
|
]
|
||||||
|
},
|
||||||
|
install_requires=['click']
|
||||||
)
|
)
|
||||||
|
|
|
@ -8,6 +8,7 @@ import logging
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
from contextlib import asynccontextmanager as acm
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
|
|
||||||
import trio
|
import trio
|
||||||
|
@ -17,7 +18,7 @@ import trio_asyncio
|
||||||
from pynng import TLSConfig
|
from pynng import TLSConfig
|
||||||
|
|
||||||
from .db import *
|
from .db import *
|
||||||
from .types import *
|
from .structs import *
|
||||||
from .constants import *
|
from .constants import *
|
||||||
|
|
||||||
|
|
||||||
|
@ -27,18 +28,47 @@ class SkynetDGPUOffline(BaseException):
|
||||||
class SkynetDGPUOverloaded(BaseException):
|
class SkynetDGPUOverloaded(BaseException):
|
||||||
...
|
...
|
||||||
|
|
||||||
|
class SkynetDGPUComputeError(BaseException):
|
||||||
|
...
|
||||||
|
|
||||||
async def rpc_service(sock, dgpu_bus, db_pool):
|
class SkynetShutdownRequested(BaseException):
|
||||||
|
...
|
||||||
|
|
||||||
|
@acm
|
||||||
|
async def open_rpc_service(sock, dgpu_bus, db_pool):
|
||||||
nodes = OrderedDict()
|
nodes = OrderedDict()
|
||||||
wip_reqs = {}
|
wip_reqs = {}
|
||||||
fin_reqs = {}
|
fin_reqs = {}
|
||||||
|
next_worker: Optional[int] = None
|
||||||
|
|
||||||
def is_worker_busy(nid: int):
|
def connect_node(uid):
|
||||||
for task in nodes[nid]['tasks']:
|
nonlocal next_worker
|
||||||
if task != None:
|
nodes[uid] = {
|
||||||
return False
|
'task': None
|
||||||
|
}
|
||||||
|
logging.info(f'dgpu online: {uid}')
|
||||||
|
|
||||||
return True
|
if not next_worker:
|
||||||
|
next_worker = 0
|
||||||
|
|
||||||
|
def disconnect_node(uid):
|
||||||
|
nonlocal next_worker
|
||||||
|
if uid not in nodes:
|
||||||
|
return
|
||||||
|
i = list(nodes.keys()).index(uid)
|
||||||
|
del nodes[uid]
|
||||||
|
|
||||||
|
if i < next_worker:
|
||||||
|
next_worker -= 1
|
||||||
|
|
||||||
|
if len(nodes) == 0:
|
||||||
|
logging.info('nw: None')
|
||||||
|
next_worker = None
|
||||||
|
|
||||||
|
logging.warning(f'dgpu offline: {uid}')
|
||||||
|
|
||||||
|
def is_worker_busy(nid: str):
|
||||||
|
return nodes[nid]['task'] != None
|
||||||
|
|
||||||
def are_all_workers_busy():
|
def are_all_workers_busy():
|
||||||
for nid in nodes.keys():
|
for nid in nodes.keys():
|
||||||
|
@ -47,30 +77,55 @@ async def rpc_service(sock, dgpu_bus, db_pool):
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
next_worker: Optional[int] = None
|
|
||||||
def get_next_worker():
|
def get_next_worker():
|
||||||
nonlocal next_worker
|
nonlocal next_worker
|
||||||
|
logging.info('get next_worker called')
|
||||||
|
logging.info(f'pre next_worker: {next_worker}')
|
||||||
|
|
||||||
if not next_worker:
|
if next_worker == None:
|
||||||
raise SkynetDGPUOffline
|
raise SkynetDGPUOffline
|
||||||
|
|
||||||
if are_all_workers_busy():
|
if are_all_workers_busy():
|
||||||
raise SkynetDGPUOverloaded
|
raise SkynetDGPUOverloaded
|
||||||
|
|
||||||
while is_worker_busy(next_worker):
|
|
||||||
|
nid = list(nodes.keys())[next_worker]
|
||||||
|
while is_worker_busy(nid):
|
||||||
next_worker += 1
|
next_worker += 1
|
||||||
|
|
||||||
if next_worker >= len(nodes):
|
if next_worker >= len(nodes):
|
||||||
next_worker = 0
|
next_worker = 0
|
||||||
|
|
||||||
return next_worker
|
nid = list(nodes.keys())[next_worker]
|
||||||
|
|
||||||
|
next_worker += 1
|
||||||
|
if next_worker >= len(nodes):
|
||||||
|
next_worker = 0
|
||||||
|
|
||||||
|
logging.info(f'post next_worker: {next_worker}')
|
||||||
|
|
||||||
|
return nid
|
||||||
|
|
||||||
async def dgpu_image_streamer():
|
async def dgpu_image_streamer():
|
||||||
nonlocal wip_reqs, fin_reqs
|
nonlocal wip_reqs, fin_reqs
|
||||||
while True:
|
while True:
|
||||||
msg = await dgpu_bus.arecv_msg()
|
msg = await dgpu_bus.arecv_msg()
|
||||||
rid = UUID(bytes=msg.bytes[:16]).hex
|
rid = UUID(bytes=msg.bytes[:16]).hex
|
||||||
img = msg.bytes[16:].hex()
|
raw_msg = msg.bytes[16:]
|
||||||
|
logging.info(f'streamer got back {rid} of size {len(raw_msg)}')
|
||||||
|
match raw_msg[:5]:
|
||||||
|
case b'error':
|
||||||
|
img = raw_msg.decode()
|
||||||
|
|
||||||
|
case b'ack':
|
||||||
|
img = raw_msg
|
||||||
|
|
||||||
|
case _:
|
||||||
|
img = base64.b64encode(raw_msg).hex()
|
||||||
|
|
||||||
|
if rid not in wip_reqs:
|
||||||
|
continue
|
||||||
|
|
||||||
fin_reqs[rid] = img
|
fin_reqs[rid] = img
|
||||||
event = wip_reqs[rid]
|
event = wip_reqs[rid]
|
||||||
event.set()
|
event.set()
|
||||||
|
@ -79,13 +134,14 @@ async def rpc_service(sock, dgpu_bus, db_pool):
|
||||||
async def dgpu_stream_one_img(req: ImageGenRequest):
|
async def dgpu_stream_one_img(req: ImageGenRequest):
|
||||||
nonlocal wip_reqs, fin_reqs, next_worker
|
nonlocal wip_reqs, fin_reqs, next_worker
|
||||||
nid = get_next_worker()
|
nid = get_next_worker()
|
||||||
logging.info(f'dgpu_stream_one_img {next_worker} {nid}')
|
idx = list(nodes.keys()).index(nid)
|
||||||
|
logging.info(f'dgpu_stream_one_img {idx}/{len(nodes)} {nid}')
|
||||||
rid = uuid.uuid4().hex
|
rid = uuid.uuid4().hex
|
||||||
event = trio.Event()
|
ack_event = trio.Event()
|
||||||
wip_reqs[rid] = event
|
img_event = trio.Event()
|
||||||
|
wip_reqs[rid] = ack_event
|
||||||
|
|
||||||
tid = nodes[nid]['tasks'].index(None)
|
nodes[nid]['task'] = rid
|
||||||
nodes[nid]['tasks'][tid] = rid
|
|
||||||
|
|
||||||
dgpu_req = DGPUBusRequest(
|
dgpu_req = DGPUBusRequest(
|
||||||
rid=rid,
|
rid=rid,
|
||||||
|
@ -98,14 +154,37 @@ async def rpc_service(sock, dgpu_bus, db_pool):
|
||||||
await dgpu_bus.asend(
|
await dgpu_bus.asend(
|
||||||
json.dumps(dgpu_req.to_dict()).encode())
|
json.dumps(dgpu_req.to_dict()).encode())
|
||||||
|
|
||||||
await event.wait()
|
with trio.move_on_after(4):
|
||||||
|
await ack_event.wait()
|
||||||
|
|
||||||
nodes[nid]['tasks'][tid] = None
|
logging.info(f'ack event: {ack_event.is_set()}')
|
||||||
|
|
||||||
|
if not ack_event.is_set():
|
||||||
|
disconnect_node(nid)
|
||||||
|
raise SkynetDGPUOffline('dgpu failed to acknowledge request')
|
||||||
|
|
||||||
|
ack = fin_reqs[rid]
|
||||||
|
if ack != b'ack':
|
||||||
|
disconnect_node(nid)
|
||||||
|
raise SkynetDGPUOffline('dgpu failed to acknowledge request')
|
||||||
|
|
||||||
|
wip_reqs[rid] = img_event
|
||||||
|
with trio.move_on_after(30):
|
||||||
|
await img_event.wait()
|
||||||
|
|
||||||
|
if not img_event.is_set():
|
||||||
|
disconnect_node(nid)
|
||||||
|
raise SkynetDGPUComputeError('30 seconds timeout while processing request')
|
||||||
|
|
||||||
|
nodes[nid]['task'] = None
|
||||||
|
|
||||||
img = fin_reqs[rid]
|
img = fin_reqs[rid]
|
||||||
del fin_reqs[rid]
|
del fin_reqs[rid]
|
||||||
|
|
||||||
logging.info(f'done streaming {img}')
|
logging.info(f'done streaming {len(img)} bytes')
|
||||||
|
|
||||||
|
if 'error' in img:
|
||||||
|
raise SkynetDGPUComputeError(img)
|
||||||
|
|
||||||
return rid, img
|
return rid, img
|
||||||
|
|
||||||
|
@ -122,6 +201,10 @@ async def rpc_service(sock, dgpu_bus, db_pool):
|
||||||
user_config = {**(await get_user_config(conn, user))}
|
user_config = {**(await get_user_config(conn, user))}
|
||||||
del user_config['id']
|
del user_config['id']
|
||||||
prompt = req.params['prompt']
|
prompt = req.params['prompt']
|
||||||
|
user_config= {
|
||||||
|
key : req.params.get(key, val)
|
||||||
|
for key, val in user_config.items()
|
||||||
|
}
|
||||||
req = ImageGenRequest(
|
req = ImageGenRequest(
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
**user_config
|
**user_config
|
||||||
|
@ -165,9 +248,10 @@ async def rpc_service(sock, dgpu_bus, db_pool):
|
||||||
case _:
|
case _:
|
||||||
logging.warn('unknown method')
|
logging.warn('unknown method')
|
||||||
|
|
||||||
except SkynetDGPUOffline:
|
except SkynetDGPUOffline as e:
|
||||||
result = {
|
result = {
|
||||||
'error': 'skynet_dgpu_offline'
|
'error': 'skynet_dgpu_offline',
|
||||||
|
'message': str(e)
|
||||||
}
|
}
|
||||||
|
|
||||||
except SkynetDGPUOverloaded:
|
except SkynetDGPUOverloaded:
|
||||||
|
@ -176,22 +260,22 @@ async def rpc_service(sock, dgpu_bus, db_pool):
|
||||||
'nodes': len(nodes)
|
'nodes': len(nodes)
|
||||||
}
|
}
|
||||||
|
|
||||||
except BaseException as e:
|
except SkynetDGPUComputeError as e:
|
||||||
logging.error(e)
|
|
||||||
result = {
|
result = {
|
||||||
'error': 'skynet_internal_error'
|
'error': 'skynet_dgpu_compute_error',
|
||||||
|
'message': str(e)
|
||||||
}
|
}
|
||||||
|
|
||||||
await rpc_ctx.asend(
|
await rpc_ctx.asend(
|
||||||
json.dumps(
|
json.dumps(
|
||||||
SkynetRPCResponse(result=result).to_dict()).encode())
|
SkynetRPCResponse(result=result).to_dict()).encode())
|
||||||
|
|
||||||
|
async def request_service(n):
|
||||||
async with trio.open_nursery() as n:
|
nonlocal next_worker
|
||||||
n.start_soon(dgpu_image_streamer)
|
|
||||||
while True:
|
while True:
|
||||||
ctx = sock.new_context()
|
ctx = sock.new_context()
|
||||||
msg = await ctx.arecv_msg()
|
msg = await ctx.arecv_msg()
|
||||||
|
|
||||||
content = msg.bytes.decode()
|
content = msg.bytes.decode()
|
||||||
req = SkynetRPCRequest(**json.loads(content))
|
req = SkynetRPCRequest(**json.loads(content))
|
||||||
|
|
||||||
|
@ -199,27 +283,14 @@ async def rpc_service(sock, dgpu_bus, db_pool):
|
||||||
|
|
||||||
result = {}
|
result = {}
|
||||||
|
|
||||||
if req.method == 'dgpu_online':
|
if req.method == 'skynet_shutdown':
|
||||||
nodes[req.uid] = {
|
raise SkynetShutdownRequested
|
||||||
'tasks': [None for _ in range(req.params['max_tasks'])],
|
|
||||||
'max_tasks': req.params['max_tasks']
|
|
||||||
}
|
|
||||||
logging.info(f'dgpu online: {req.uid}')
|
|
||||||
|
|
||||||
if not next_worker:
|
elif req.method == 'dgpu_online':
|
||||||
next_worker = 0
|
connect_node(req.uid)
|
||||||
|
|
||||||
elif req.method == 'dgpu_offline':
|
elif req.method == 'dgpu_offline':
|
||||||
i = list(nodes.keys()).index(req.uid)
|
disconnect_node(req.uid)
|
||||||
del nodes[req.uid]
|
|
||||||
|
|
||||||
if i < next_worker:
|
|
||||||
next_worker -= 1
|
|
||||||
|
|
||||||
if len(nodes) == 0:
|
|
||||||
next_worker = None
|
|
||||||
|
|
||||||
logging.info(f'dgpu offline: {req.uid}')
|
|
||||||
|
|
||||||
elif req.method == 'dgpu_workers':
|
elif req.method == 'dgpu_workers':
|
||||||
result = len(nodes)
|
result = len(nodes)
|
||||||
|
@ -238,13 +309,22 @@ async def rpc_service(sock, dgpu_bus, db_pool):
|
||||||
result={'ok': result}).to_dict()).encode())
|
result={'ok': result}).to_dict()).encode())
|
||||||
|
|
||||||
|
|
||||||
|
async with trio.open_nursery() as n:
|
||||||
|
n.start_soon(dgpu_image_streamer)
|
||||||
|
n.start_soon(request_service, n)
|
||||||
|
logging.info('starting rpc service')
|
||||||
|
yield
|
||||||
|
logging.info('stopping rpc service')
|
||||||
|
n.cancel_scope.cancel()
|
||||||
|
|
||||||
|
|
||||||
|
@acm
|
||||||
async def run_skynet(
|
async def run_skynet(
|
||||||
db_user: str = DB_USER,
|
db_user: str = DB_USER,
|
||||||
db_pass: str = DB_PASS,
|
db_pass: str = DB_PASS,
|
||||||
db_host: str = DB_HOST,
|
db_host: str = DB_HOST,
|
||||||
rpc_address: str = DEFAULT_RPC_ADDR,
|
rpc_address: str = DEFAULT_RPC_ADDR,
|
||||||
dgpu_address: str = DEFAULT_DGPU_ADDR,
|
dgpu_address: str = DEFAULT_DGPU_ADDR,
|
||||||
task_status = trio.TASK_STATUS_IGNORED,
|
|
||||||
security: bool = True
|
security: bool = True
|
||||||
):
|
):
|
||||||
logging.basicConfig(level=logging.INFO)
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
@ -260,8 +340,8 @@ async def run_skynet(
|
||||||
(cert_path).read_text()
|
(cert_path).read_text()
|
||||||
for cert_path in (certs_dir / 'whitelist').glob('*.cert')]
|
for cert_path in (certs_dir / 'whitelist').glob('*.cert')]
|
||||||
|
|
||||||
logging.info(f'tls_key: {tls_key}')
|
cert_start = tls_cert.index('\n') + 1
|
||||||
logging.info(f'tls_cert: {tls_cert}')
|
logging.info(f'tls_cert: {tls_cert[cert_start:cert_start+64]}...')
|
||||||
logging.info(f'tls_whitelist len: {len(tls_whitelist)}')
|
logging.info(f'tls_whitelist len: {len(tls_whitelist)}')
|
||||||
|
|
||||||
rpc_address = 'tls+' + rpc_address
|
rpc_address = 'tls+' + rpc_address
|
||||||
|
@ -271,16 +351,14 @@ async def run_skynet(
|
||||||
own_key_string=tls_key,
|
own_key_string=tls_key,
|
||||||
own_cert_string=tls_cert)
|
own_cert_string=tls_cert)
|
||||||
|
|
||||||
async with (
|
with (
|
||||||
trio.open_nursery() as n,
|
pynng.Rep0() as rpc_sock,
|
||||||
open_database_connection(
|
pynng.Bus0() as dgpu_bus
|
||||||
db_user, db_pass, db_host) as db_pool
|
|
||||||
):
|
):
|
||||||
logging.info('connected to db.')
|
async with open_database_connection(
|
||||||
with (
|
db_user, db_pass, db_host) as db_pool:
|
||||||
pynng.Rep0() as rpc_sock,
|
|
||||||
pynng.Bus0() as dgpu_bus
|
logging.info('connected to db.')
|
||||||
):
|
|
||||||
if security:
|
if security:
|
||||||
rpc_sock.tls_config = tls_config
|
rpc_sock.tls_config = tls_config
|
||||||
dgpu_bus.tls_config = tls_config
|
dgpu_bus.tls_config = tls_config
|
||||||
|
@ -288,13 +366,11 @@ async def run_skynet(
|
||||||
rpc_sock.listen(rpc_address)
|
rpc_sock.listen(rpc_address)
|
||||||
dgpu_bus.listen(dgpu_address)
|
dgpu_bus.listen(dgpu_address)
|
||||||
|
|
||||||
n.start_soon(
|
|
||||||
rpc_service, rpc_sock, dgpu_bus, db_pool)
|
|
||||||
task_status.started()
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await trio.sleep_forever()
|
async with open_rpc_service(rpc_sock, dgpu_bus, db_pool):
|
||||||
|
yield
|
||||||
|
|
||||||
except KeyboardInterrupt:
|
except SkynetShutdownRequested:
|
||||||
...
|
...
|
||||||
|
|
||||||
|
logging.info('disconnected from db.')
|
|
@ -0,0 +1,68 @@
|
||||||
|
#!/usr/bin/python
|
||||||
|
|
||||||
|
import os
|
||||||
|
import json
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
|
import trio
|
||||||
|
import click
|
||||||
|
|
||||||
|
from .dgpu import open_dgpu_node
|
||||||
|
from .utils import txt2img
|
||||||
|
from .constants import ALGOS
|
||||||
|
|
||||||
|
|
||||||
|
@click.group()
|
||||||
|
def skynet(*args, **kwargs):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@skynet.command()
|
||||||
|
@click.option('--model', '-m', default=ALGOS['midj'])
|
||||||
|
@click.option(
|
||||||
|
'--prompt', '-p', default='a red tractor in a wheat field')
|
||||||
|
@click.option('--output', '-o', default='output.png')
|
||||||
|
@click.option('--width', '-w', default=512)
|
||||||
|
@click.option('--height', '-h', default=512)
|
||||||
|
@click.option('--guidance', '-g', default=10.0)
|
||||||
|
@click.option('--steps', '-s', default=26)
|
||||||
|
@click.option('--seed', '-S', default=None)
|
||||||
|
def txt2img(*args
|
||||||
|
# model: str,
|
||||||
|
# prompt: str,
|
||||||
|
# output: str
|
||||||
|
# width: int, height: int,
|
||||||
|
# guidance: float,
|
||||||
|
# steps: int,
|
||||||
|
# seed: Optional[int]
|
||||||
|
):
|
||||||
|
assert 'HF_TOKEN' in os.environ
|
||||||
|
txt2img(
|
||||||
|
os.environ['HF_TOKEN'], *args)
|
||||||
|
|
||||||
|
@skynet.group()
|
||||||
|
def run(*args, **kwargs):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@run.command()
|
||||||
|
@click.option('--loglevel', '-l', default='warning', help='Logging level')
|
||||||
|
@click.option(
|
||||||
|
'--key', '-k', default='dgpu')
|
||||||
|
@click.option(
|
||||||
|
'--cert', '-c', default='whitelist/dgpu')
|
||||||
|
@click.option(
|
||||||
|
'--algos', '-a', default=None)
|
||||||
|
def dgpu(
|
||||||
|
loglevel: str,
|
||||||
|
key: str,
|
||||||
|
cert: str,
|
||||||
|
algos: Optional[str]
|
||||||
|
):
|
||||||
|
trio.run(
|
||||||
|
partial(
|
||||||
|
open_dgpu_node,
|
||||||
|
cert,
|
||||||
|
key_name=key,
|
||||||
|
initial_algos=json.loads(algos)
|
||||||
|
))
|
|
@ -1,6 +1,6 @@
|
||||||
#!/usr/bin/python
|
#!/usr/bin/python
|
||||||
|
|
||||||
API_TOKEN = '5880619053:AAFge2UfObw1kCn9Kb7AAyqaIHs_HgM0Fx0'
|
DOCKER_RUNTIME_CUDA = 'skynet:runtime-cuda'
|
||||||
|
|
||||||
DB_HOST = 'ancap.tech:34508'
|
DB_HOST = 'ancap.tech:34508'
|
||||||
DB_USER = 'skynet'
|
DB_USER = 'skynet'
|
||||||
|
@ -8,8 +8,8 @@ DB_PASS = 'password'
|
||||||
DB_NAME = 'skynet'
|
DB_NAME = 'skynet'
|
||||||
|
|
||||||
ALGOS = {
|
ALGOS = {
|
||||||
'stable': 'runwayml/stable-diffusion-v1-5',
|
|
||||||
'midj': 'prompthero/openjourney',
|
'midj': 'prompthero/openjourney',
|
||||||
|
'stable': 'runwayml/stable-diffusion-v1-5',
|
||||||
'hdanime': 'Linaqruf/anything-v3.0',
|
'hdanime': 'Linaqruf/anything-v3.0',
|
||||||
'waifu': 'hakurei/waifu-diffusion',
|
'waifu': 'hakurei/waifu-diffusion',
|
||||||
'ghibli': 'nitrosocke/Ghibli-Diffusion',
|
'ghibli': 'nitrosocke/Ghibli-Diffusion',
|
||||||
|
@ -122,7 +122,7 @@ DEFAULT_CERT_DGPU = 'dgpu.key'
|
||||||
DEFAULT_RPC_ADDR = 'tcp://127.0.0.1:41000'
|
DEFAULT_RPC_ADDR = 'tcp://127.0.0.1:41000'
|
||||||
|
|
||||||
DEFAULT_DGPU_ADDR = 'tcp://127.0.0.1:41069'
|
DEFAULT_DGPU_ADDR = 'tcp://127.0.0.1:41069'
|
||||||
DEFAULT_DGPU_MAX_TASKS = 3
|
DEFAULT_DGPU_MAX_TASKS = 2
|
||||||
DEFAULT_INITAL_ALGOS = ['midj', 'stable', 'ink']
|
DEFAULT_INITAL_ALGOS = ['midj', 'stable', 'ink']
|
||||||
|
|
||||||
DATE_FORMAT = '%B the %dth %Y, %H:%M:%S'
|
DATE_FORMAT = '%B the %dth %Y, %H:%M:%S'
|
|
@ -7,6 +7,9 @@ from contextlib import asynccontextmanager as acm
|
||||||
|
|
||||||
import trio
|
import trio
|
||||||
import triopg
|
import triopg
|
||||||
|
import trio_asyncio
|
||||||
|
|
||||||
|
from asyncpg.exceptions import UndefinedColumnError
|
||||||
|
|
||||||
from .constants import *
|
from .constants import *
|
||||||
|
|
||||||
|
@ -72,13 +75,22 @@ async def open_database_connection(
|
||||||
db_host: str = DB_HOST,
|
db_host: str = DB_HOST,
|
||||||
db_name: str = DB_NAME
|
db_name: str = DB_NAME
|
||||||
):
|
):
|
||||||
async with triopg.create_pool(
|
async with trio_asyncio.open_loop() as loop:
|
||||||
dsn=f'postgres://{db_user}:{db_pass}@{db_host}/{db_name}'
|
async with triopg.create_pool(
|
||||||
) as pool_conn:
|
dsn=f'postgres://{db_user}:{db_pass}@{db_host}/{db_name}'
|
||||||
async with pool_conn.acquire() as conn:
|
) as pool_conn:
|
||||||
await conn.execute(DB_INIT_SQL)
|
async with pool_conn.acquire() as conn:
|
||||||
|
res = await conn.execute(f'''
|
||||||
|
select distinct table_schema
|
||||||
|
from information_schema.tables
|
||||||
|
where table_schema = \'{db_name}\'
|
||||||
|
''')
|
||||||
|
if '1' in res:
|
||||||
|
logging.info('schema already in db, skipping init')
|
||||||
|
else:
|
||||||
|
await conn.execute(DB_INIT_SQL)
|
||||||
|
|
||||||
yield pool_conn
|
yield pool_conn
|
||||||
|
|
||||||
|
|
||||||
async def get_user(conn, uid: str):
|
async def get_user(conn, uid: str):
|
||||||
|
@ -135,6 +147,7 @@ async def new_user(conn, uid: str):
|
||||||
tg_id, generated, joined, last_prompt, role)
|
tg_id, generated, joined, last_prompt, role)
|
||||||
|
|
||||||
VALUES($1, $2, $3, $4, $5)
|
VALUES($1, $2, $3, $4, $5)
|
||||||
|
ON CONFLICT DO NOTHING
|
||||||
''')
|
''')
|
||||||
await stmt.fetch(
|
await stmt.fetch(
|
||||||
tg_id, 0, date, None, DEFAULT_ROLE
|
tg_id, 0, date, None, DEFAULT_ROLE
|
||||||
|
@ -147,6 +160,7 @@ async def new_user(conn, uid: str):
|
||||||
id, algo, step, width, height, seed, guidance, upscaler)
|
id, algo, step, width, height, seed, guidance, upscaler)
|
||||||
|
|
||||||
VALUES($1, $2, $3, $4, $5, $6, $7, $8)
|
VALUES($1, $2, $3, $4, $5, $6, $7, $8)
|
||||||
|
ON CONFLICT DO NOTHING
|
||||||
''')
|
''')
|
||||||
user = await stmt.fetch(
|
user = await stmt.fetch(
|
||||||
new_uid,
|
new_uid,
|
|
@ -0,0 +1,197 @@
|
||||||
|
#!/usr/bin/python
|
||||||
|
|
||||||
|
import gc
|
||||||
|
import io
|
||||||
|
import trio
|
||||||
|
import json
|
||||||
|
import uuid
|
||||||
|
import random
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from typing import List, Optional
|
||||||
|
from pathlib import Path
|
||||||
|
from contextlib import AsyncExitStack
|
||||||
|
|
||||||
|
import pynng
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from pynng import TLSConfig
|
||||||
|
from diffusers import (
|
||||||
|
StableDiffusionPipeline,
|
||||||
|
EulerAncestralDiscreteScheduler
|
||||||
|
)
|
||||||
|
|
||||||
|
from .structs import *
|
||||||
|
from .constants import *
|
||||||
|
from .frontend import open_skynet_rpc
|
||||||
|
|
||||||
|
|
||||||
|
def pipeline_for(algo: str, mem_fraction: float = 1.0):
|
||||||
|
assert torch.cuda.is_available()
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
torch.cuda.set_per_process_memory_fraction(mem_fraction)
|
||||||
|
torch.backends.cuda.matmul.allow_tf32 = True
|
||||||
|
torch.backends.cudnn.allow_tf32 = True
|
||||||
|
|
||||||
|
params = {
|
||||||
|
'torch_dtype': torch.float16,
|
||||||
|
'safety_checker': None
|
||||||
|
}
|
||||||
|
|
||||||
|
if algo == 'stable':
|
||||||
|
params['revision'] = 'fp16'
|
||||||
|
|
||||||
|
pipe = StableDiffusionPipeline.from_pretrained(
|
||||||
|
ALGOS[algo], **params)
|
||||||
|
|
||||||
|
pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(
|
||||||
|
pipe.scheduler.config)
|
||||||
|
|
||||||
|
pipe.enable_vae_slicing()
|
||||||
|
|
||||||
|
return pipe.to("cuda")
|
||||||
|
|
||||||
|
|
||||||
|
class DGPUComputeError(BaseException):
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
async def open_dgpu_node(
|
||||||
|
cert_name: str,
|
||||||
|
key_name: Optional[str],
|
||||||
|
rpc_address: str = DEFAULT_RPC_ADDR,
|
||||||
|
dgpu_address: str = DEFAULT_DGPU_ADDR,
|
||||||
|
initial_algos: Optional[List[str]] = None,
|
||||||
|
security: bool = True
|
||||||
|
):
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
logging.info(f'starting dgpu node!')
|
||||||
|
|
||||||
|
name = uuid.uuid4()
|
||||||
|
|
||||||
|
logging.info(f'loading models...')
|
||||||
|
|
||||||
|
initial_algos = (
|
||||||
|
initial_algos
|
||||||
|
if initial_algos else DEFAULT_INITAL_ALGOS
|
||||||
|
)
|
||||||
|
models = {}
|
||||||
|
for algo in initial_algos:
|
||||||
|
models[algo] = {
|
||||||
|
'pipe': pipeline_for(algo),
|
||||||
|
'generated': 0
|
||||||
|
}
|
||||||
|
logging.info(f'loaded {algo}.')
|
||||||
|
|
||||||
|
logging.info('memory summary:\n')
|
||||||
|
logging.info(torch.cuda.memory_summary())
|
||||||
|
|
||||||
|
async def gpu_compute_one(ireq: ImageGenRequest):
|
||||||
|
if ireq.algo not in models:
|
||||||
|
least_used = list(models.keys())[0]
|
||||||
|
for model in models:
|
||||||
|
if models[least_used]['generated'] > models[model]['generated']:
|
||||||
|
least_used = model
|
||||||
|
|
||||||
|
del models[least_used]
|
||||||
|
gc.collect()
|
||||||
|
|
||||||
|
models[ireq.algo] = {
|
||||||
|
'pipe': pipeline_for(ireq.algo),
|
||||||
|
'generated': 0
|
||||||
|
}
|
||||||
|
|
||||||
|
seed = ireq.seed if ireq.seed else random.randint(0, 2 ** 64)
|
||||||
|
try:
|
||||||
|
image = models[ireq.algo]['pipe'](
|
||||||
|
ireq.prompt,
|
||||||
|
width=ireq.width,
|
||||||
|
height=ireq.height,
|
||||||
|
guidance_scale=ireq.guidance,
|
||||||
|
num_inference_steps=ireq.step,
|
||||||
|
generator=torch.Generator("cuda").manual_seed(seed)
|
||||||
|
).images[0]
|
||||||
|
return image.tobytes()
|
||||||
|
|
||||||
|
except BaseException as e:
|
||||||
|
logging.error(e)
|
||||||
|
raise DGPUComputeError(str(e))
|
||||||
|
|
||||||
|
finally:
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
|
||||||
|
async with open_skynet_rpc(
|
||||||
|
security=security,
|
||||||
|
cert_name=cert_name,
|
||||||
|
key_name=key_name
|
||||||
|
) as rpc_call:
|
||||||
|
|
||||||
|
tls_config = None
|
||||||
|
if security:
|
||||||
|
# load tls certs
|
||||||
|
if not key_name:
|
||||||
|
key_name = certs_name
|
||||||
|
certs_dir = Path(DEFAULT_CERTS_DIR).resolve()
|
||||||
|
|
||||||
|
skynet_cert_path = certs_dir / 'brain.cert'
|
||||||
|
tls_cert_path = certs_dir / f'{cert_name}.cert'
|
||||||
|
tls_key_path = certs_dir / f'{key_name}.key'
|
||||||
|
|
||||||
|
skynet_cert = skynet_cert_path.read_text()
|
||||||
|
tls_cert = tls_cert_path.read_text()
|
||||||
|
tls_key = tls_key_path.read_text()
|
||||||
|
|
||||||
|
logging.info(f'skynet cert: {skynet_cert_path}')
|
||||||
|
logging.info(f'dgpu cert: {tls_cert_path}')
|
||||||
|
logging.info(f'dgpu key: {tls_key_path}')
|
||||||
|
|
||||||
|
dgpu_address = 'tls+' + dgpu_address
|
||||||
|
tls_config = TLSConfig(
|
||||||
|
TLSConfig.MODE_CLIENT,
|
||||||
|
own_key_string=tls_key,
|
||||||
|
own_cert_string=tls_cert,
|
||||||
|
ca_string=skynet_cert)
|
||||||
|
|
||||||
|
logging.info(f'connecting to {dgpu_address}')
|
||||||
|
with pynng.Bus0() as dgpu_sock:
|
||||||
|
dgpu_sock.tls_config = tls_config
|
||||||
|
dgpu_sock.dial(dgpu_address)
|
||||||
|
|
||||||
|
res = await rpc_call(name.hex, 'dgpu_online')
|
||||||
|
logging.info(res)
|
||||||
|
assert 'ok' in res.result
|
||||||
|
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
msg = await dgpu_sock.arecv()
|
||||||
|
req = DGPUBusRequest(
|
||||||
|
**json.loads(msg.decode()))
|
||||||
|
|
||||||
|
if req.nid != name.hex:
|
||||||
|
logging.info('witnessed request {req.rid}, for {req.nid}')
|
||||||
|
continue
|
||||||
|
|
||||||
|
# send ack
|
||||||
|
await dgpu_sock.asend(
|
||||||
|
bytes.fromhex(req.rid) + b'ack')
|
||||||
|
|
||||||
|
logging.info(f'sent ack, processing {req.rid}...')
|
||||||
|
|
||||||
|
try:
|
||||||
|
img = await gpu_compute_one(
|
||||||
|
ImageGenRequest(**req.params))
|
||||||
|
|
||||||
|
except DGPUComputeError as e:
|
||||||
|
img = b'error' + str(e).encode()
|
||||||
|
|
||||||
|
await dgpu_sock.asend(
|
||||||
|
bytes.fromhex(req.rid) + img)
|
||||||
|
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
logging.info('interrupt caught, stopping...')
|
||||||
|
|
||||||
|
finally:
|
||||||
|
res = await rpc_call(name.hex, 'dgpu_offline')
|
||||||
|
logging.info(res)
|
||||||
|
assert 'ok' in res.result
|
|
@ -10,7 +10,7 @@ import pynng
|
||||||
|
|
||||||
from pynng import TLSConfig
|
from pynng import TLSConfig
|
||||||
|
|
||||||
from ..types import SkynetRPCRequest, SkynetRPCResponse
|
from ..structs import SkynetRPCRequest, SkynetRPCResponse
|
||||||
from ..constants import *
|
from ..constants import *
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,57 @@
|
||||||
|
#!/usr/bin/python
|
||||||
|
|
||||||
|
import random
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from diffusers import StableDiffusionPipeline
|
||||||
|
from huggingface_hub import login
|
||||||
|
|
||||||
|
|
||||||
|
def txt2img(
|
||||||
|
hf_token: str,
|
||||||
|
model_name: str,
|
||||||
|
prompt: str,
|
||||||
|
output: str,
|
||||||
|
width: int, height: int,
|
||||||
|
guidance: float,
|
||||||
|
steps: int,
|
||||||
|
seed: Optional[int]
|
||||||
|
):
|
||||||
|
assert torch.cuda.is_available()
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
torch.cuda.set_per_process_memory_fraction(0.333)
|
||||||
|
torch.backends.cuda.matmul.allow_tf32 = True
|
||||||
|
torch.backends.cudnn.allow_tf32 = True
|
||||||
|
|
||||||
|
login(token=hf_token)
|
||||||
|
|
||||||
|
params = {
|
||||||
|
'torch_dtype': torch.float16,
|
||||||
|
'safety_checker': None
|
||||||
|
}
|
||||||
|
if model_name == 'runwayml/stable-diffusion-v1-5':
|
||||||
|
params['revision'] = 'fp16'
|
||||||
|
|
||||||
|
pipe = StableDiffusionPipeline.from_pretrained(
|
||||||
|
model_name, **params)
|
||||||
|
|
||||||
|
pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(
|
||||||
|
pipe.scheduler.config)
|
||||||
|
|
||||||
|
pipe = pipe.to("cuda")
|
||||||
|
|
||||||
|
seed = ireq.seed if ireq.seed else random.randint(0, 2 ** 64)
|
||||||
|
prompt = prompt
|
||||||
|
image = pipe(
|
||||||
|
prompt,
|
||||||
|
width=width,
|
||||||
|
height=height,
|
||||||
|
guidance_scale=guidance, num_inference_steps=steps,
|
||||||
|
generator=torch.Generator("cuda").manual_seed(seed)
|
||||||
|
).images[0]
|
||||||
|
|
||||||
|
image.save(output)
|
|
@ -1,124 +0,0 @@
|
||||||
#!/usr/bin/python
|
|
||||||
|
|
||||||
import trio
|
|
||||||
import json
|
|
||||||
import uuid
|
|
||||||
import logging
|
|
||||||
|
|
||||||
import pynng
|
|
||||||
import tractor
|
|
||||||
|
|
||||||
from . import gpu
|
|
||||||
from .gpu import open_gpu_worker
|
|
||||||
from .types import *
|
|
||||||
from .constants import *
|
|
||||||
from .frontend import rpc_call
|
|
||||||
|
|
||||||
|
|
||||||
async def open_dgpu_node(
|
|
||||||
cert_name: str,
|
|
||||||
key_name: Optional[str],
|
|
||||||
rpc_address: str = DEFAULT_RPC_ADDR,
|
|
||||||
dgpu_address: str = DEFAULT_DGPU_ADDR,
|
|
||||||
dgpu_max_tasks: int = DEFAULT_DGPU_MAX_TASKS,
|
|
||||||
initial_algos: str = DEFAULT_INITAL_ALGOS,
|
|
||||||
security: bool = True
|
|
||||||
):
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
|
||||||
|
|
||||||
name = uuid.uuid4()
|
|
||||||
workers = initial_algos.copy()
|
|
||||||
tasks = [None for _ in range(dgpu_max_tasks)]
|
|
||||||
|
|
||||||
portal_map: dict[int, tractor.Portal]
|
|
||||||
contexts: dict[int, tractor.Context]
|
|
||||||
|
|
||||||
def get_next_worker(need_algo: str):
|
|
||||||
nonlocal workers, tasks
|
|
||||||
for task, algo in zip(workers, tasks):
|
|
||||||
if need_algo == algo and not task:
|
|
||||||
return workers.index(need_algo)
|
|
||||||
|
|
||||||
return tasks.index(None)
|
|
||||||
|
|
||||||
async def gpu_streamer(
|
|
||||||
ctx: tractor.Context,
|
|
||||||
nid: int
|
|
||||||
):
|
|
||||||
nonlocal tasks
|
|
||||||
async with ctx.open_stream() as stream:
|
|
||||||
async for img in stream:
|
|
||||||
tasks[nid]['res'] = img
|
|
||||||
tasks[nid]['event'].set()
|
|
||||||
|
|
||||||
async def gpu_compute_one(ireq: ImageGenRequest):
|
|
||||||
wid = get_next_worker(ireq.algo)
|
|
||||||
event = trio.Event()
|
|
||||||
|
|
||||||
workers[wid] = ireq.algo
|
|
||||||
tasks[wid] = {
|
|
||||||
'res': None, 'event': event}
|
|
||||||
|
|
||||||
await contexts[i].send(ireq)
|
|
||||||
|
|
||||||
await event.wait()
|
|
||||||
|
|
||||||
img = tasks[wid]['res']
|
|
||||||
tasks[wid] = None
|
|
||||||
return img
|
|
||||||
|
|
||||||
|
|
||||||
async with open_skynet_rpc(
|
|
||||||
security=security,
|
|
||||||
cert_name=cert_name,
|
|
||||||
key_name=key_name
|
|
||||||
) as rpc_call:
|
|
||||||
with pynng.Bus0(dial=dgpu_address) as dgpu_sock:
|
|
||||||
async def _process_dgpu_req(req: DGPUBusRequest):
|
|
||||||
img = await gpu_compute_one(
|
|
||||||
ImageGenRequest(**req.params))
|
|
||||||
await dgpu_sock.asend(
|
|
||||||
bytes.fromhex(req.rid) + img)
|
|
||||||
|
|
||||||
res = await rpc_call(
|
|
||||||
name.hex, 'dgpu_online', {'max_tasks': dgpu_max_tasks})
|
|
||||||
logging.info(res)
|
|
||||||
assert 'ok' in res.result
|
|
||||||
|
|
||||||
async with (
|
|
||||||
tractor.open_actor_cluster(
|
|
||||||
modules=['skynet_bot.gpu'],
|
|
||||||
count=dgpu_max_tasks,
|
|
||||||
names=[i for i in range(dgpu_max_tasks)]
|
|
||||||
) as portal_map,
|
|
||||||
trio.open_nursery() as n
|
|
||||||
):
|
|
||||||
logging.info(f'starting {dgpu_max_tasks} gpu workers')
|
|
||||||
async with tractor.gather_contexts((
|
|
||||||
portal.open_context(
|
|
||||||
open_gpu_worker, algo, 1.0 / dgpu_max_tasks)
|
|
||||||
for portal in portal_map.values()
|
|
||||||
)) as contexts:
|
|
||||||
contexts = {i: ctx for i, ctx in enumerate(contexts)}
|
|
||||||
for i, ctx in contexts.items():
|
|
||||||
n.start_soon(
|
|
||||||
gpu_streamer, ctx, i)
|
|
||||||
try:
|
|
||||||
while True:
|
|
||||||
msg = await dgpu_sock.arecv()
|
|
||||||
req = DGPUBusRequest(
|
|
||||||
**json.loads(msg.decode()))
|
|
||||||
|
|
||||||
if req.nid != name.hex:
|
|
||||||
continue
|
|
||||||
|
|
||||||
logging.info(f'dgpu: {name}, req: {req}')
|
|
||||||
n.start_soon(
|
|
||||||
_process_dgpu_req, req)
|
|
||||||
|
|
||||||
except KeyboardInterrupt:
|
|
||||||
...
|
|
||||||
|
|
||||||
res = await rpc_call(name.hex, 'dgpu_offline')
|
|
||||||
logging.info(res)
|
|
||||||
assert 'ok' in res.result
|
|
|
@ -1,77 +0,0 @@
|
||||||
#!/usr/bin/python
|
|
||||||
|
|
||||||
import io
|
|
||||||
import random
|
|
||||||
import logging
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import tractor
|
|
||||||
|
|
||||||
from diffusers import (
|
|
||||||
StableDiffusionPipeline,
|
|
||||||
EulerAncestralDiscreteScheduler
|
|
||||||
)
|
|
||||||
|
|
||||||
from .types import ImageGenRequest
|
|
||||||
from .constants import ALGOS
|
|
||||||
|
|
||||||
|
|
||||||
def pipeline_for(algo: str, mem_fraction: float):
|
|
||||||
assert torch.cuda.is_available()
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
torch.cuda.set_per_process_memory_fraction(mem_fraction)
|
|
||||||
torch.backends.cuda.matmul.allow_tf32 = True
|
|
||||||
torch.backends.cudnn.allow_tf32 = True
|
|
||||||
|
|
||||||
params = {
|
|
||||||
'torch_dtype': torch.float16,
|
|
||||||
'safety_checker': None
|
|
||||||
}
|
|
||||||
|
|
||||||
if algo == 'stable':
|
|
||||||
params['revision'] = 'fp16'
|
|
||||||
|
|
||||||
pipe = StableDiffusionPipeline.from_pretrained(
|
|
||||||
ALGOS[algo], **params)
|
|
||||||
|
|
||||||
pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(
|
|
||||||
pipe.scheduler.config)
|
|
||||||
|
|
||||||
return pipe.to("cuda")
|
|
||||||
|
|
||||||
@tractor.context
|
|
||||||
async def open_gpu_worker(
|
|
||||||
ctx: tractor.Context,
|
|
||||||
start_algo: str,
|
|
||||||
mem_fraction: float
|
|
||||||
):
|
|
||||||
log = tractor.log.get_logger(name='gpu', _root_name='skynet')
|
|
||||||
log.info(f'starting gpu worker with algo {start_algo}...')
|
|
||||||
current_algo = start_algo
|
|
||||||
with torch.no_grad():
|
|
||||||
pipe = pipeline_for(current_algo, mem_fraction)
|
|
||||||
log.info('pipeline loaded')
|
|
||||||
await ctx.started()
|
|
||||||
async with ctx.open_stream() as bus:
|
|
||||||
async for ireq in bus:
|
|
||||||
if ireq.algo != current_algo:
|
|
||||||
current_algo = ireq.algo
|
|
||||||
pipe = pipeline_for(current_algo, mem_fraction)
|
|
||||||
|
|
||||||
seed = ireq.seed if ireq.seed else random.randint(0, 2 ** 64)
|
|
||||||
image = pipe(
|
|
||||||
ireq.prompt,
|
|
||||||
width=ireq.width,
|
|
||||||
height=ireq.height,
|
|
||||||
guidance_scale=ireq.guidance,
|
|
||||||
num_inference_steps=ireq.step,
|
|
||||||
generator=torch.Generator("cuda").manual_seed(seed)
|
|
||||||
).images[0]
|
|
||||||
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
|
|
||||||
# convert PIL.Image to BytesIO
|
|
||||||
img_bytes = io.BytesIO()
|
|
||||||
image.save(img_bytes, format='PNG')
|
|
||||||
await bus.send(img_bytes.getvalue())
|
|
||||||
|
|
|
@ -1,2 +0,0 @@
|
||||||
from OpenSSL.crypto import load_publickey, FILETYPE_PEM, verify, X509
|
|
||||||
|
|
9
test.sh
9
test.sh
|
@ -1,9 +0,0 @@
|
||||||
docker run \
|
|
||||||
-it \
|
|
||||||
--rm \
|
|
||||||
--gpus=all \
|
|
||||||
--mount type=bind,source="$(pwd)",target=/skynet \
|
|
||||||
skynet:runtime-cuda \
|
|
||||||
bash -c \
|
|
||||||
"cd /skynet && pip install -e . && \
|
|
||||||
pytest $1 --log-cli-level=info"
|
|
|
@ -1,21 +1,25 @@
|
||||||
#!/usr/bin/python
|
#!/usr/bin/python
|
||||||
|
|
||||||
|
import os
|
||||||
|
import json
|
||||||
import time
|
import time
|
||||||
import random
|
import random
|
||||||
import string
|
import string
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
import trio
|
import trio
|
||||||
import pytest
|
import pytest
|
||||||
import psycopg2
|
import psycopg2
|
||||||
import trio_asyncio
|
import trio_asyncio
|
||||||
|
|
||||||
|
from docker.types import Mount, DeviceRequest
|
||||||
from psycopg2.extensions import ISOLATION_LEVEL_AUTOCOMMIT
|
from psycopg2.extensions import ISOLATION_LEVEL_AUTOCOMMIT
|
||||||
|
|
||||||
from skynet_bot.constants import *
|
from skynet.constants import *
|
||||||
from skynet_bot.brain import run_skynet
|
from skynet.brain import run_skynet
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope='session')
|
@pytest.fixture(scope='session')
|
||||||
|
@ -29,6 +33,7 @@ def postgres_db(dockerctl):
|
||||||
|
|
||||||
with dockerctl.run(
|
with dockerctl.run(
|
||||||
'postgres',
|
'postgres',
|
||||||
|
name='skynet-test-postgres',
|
||||||
ports={'5432/tcp': None},
|
ports={'5432/tcp': None},
|
||||||
environment={
|
environment={
|
||||||
'POSTGRES_PASSWORD': rpassword
|
'POSTGRES_PASSWORD': rpassword
|
||||||
|
@ -67,6 +72,8 @@ def postgres_db(dockerctl):
|
||||||
cursor.execute(
|
cursor.execute(
|
||||||
f'GRANT ALL PRIVILEGES ON DATABASE {DB_NAME} TO {DB_USER}')
|
f'GRANT ALL PRIVILEGES ON DATABASE {DB_NAME} TO {DB_USER}')
|
||||||
|
|
||||||
|
conn.close()
|
||||||
|
|
||||||
logging.info('done.')
|
logging.info('done.')
|
||||||
yield container, password, host
|
yield container, password, host
|
||||||
|
|
||||||
|
@ -74,16 +81,44 @@ def postgres_db(dockerctl):
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
async def skynet_running(postgres_db):
|
async def skynet_running(postgres_db):
|
||||||
db_container, db_pass, db_host = postgres_db
|
db_container, db_pass, db_host = postgres_db
|
||||||
async with (
|
|
||||||
trio_asyncio.open_loop(),
|
async with run_skynet(
|
||||||
trio.open_nursery() as n
|
db_pass=db_pass,
|
||||||
|
db_host=db_host
|
||||||
):
|
):
|
||||||
await n.start(
|
|
||||||
partial(run_skynet,
|
|
||||||
db_pass=db_pass,
|
|
||||||
db_host=db_host))
|
|
||||||
|
|
||||||
yield
|
yield
|
||||||
n.cancel_scope.cancel()
|
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def dgpu_workers(request, dockerctl, skynet_running):
|
||||||
|
devices = [DeviceRequest(capabilities=[['gpu']])]
|
||||||
|
mounts = [Mount(
|
||||||
|
'/skynet', str(Path().resolve()), type='bind')]
|
||||||
|
|
||||||
|
num_containers, initial_algos = request.param
|
||||||
|
|
||||||
|
cmd = f'''
|
||||||
|
pip install -e . && \
|
||||||
|
skynet run dgpu --algos=\'{json.dumps(initial_algos)}\'
|
||||||
|
'''
|
||||||
|
|
||||||
|
logging.info(f'launching: \n{cmd}')
|
||||||
|
|
||||||
|
with dockerctl.run(
|
||||||
|
DOCKER_RUNTIME_CUDA,
|
||||||
|
name='skynet-test-runtime-cuda',
|
||||||
|
command=['bash', '-c', cmd],
|
||||||
|
environment={
|
||||||
|
'HF_TOKEN': os.environ['HF_TOKEN'],
|
||||||
|
'HF_HOME': '/skynet/hf_home'
|
||||||
|
},
|
||||||
|
network='host',
|
||||||
|
mounts=mounts,
|
||||||
|
device_requests=devices,
|
||||||
|
num=num_containers
|
||||||
|
) as containers:
|
||||||
|
yield containers
|
||||||
|
|
||||||
|
#for i, container in enumerate(containers):
|
||||||
|
# logging.info(f'container {i} logs:')
|
||||||
|
# logging.info(container.logs().decode())
|
||||||
|
|
|
@ -1,57 +1,248 @@
|
||||||
#!/usr/bin/python
|
#!/usr/bin/python
|
||||||
|
|
||||||
|
import io
|
||||||
import time
|
import time
|
||||||
import json
|
import json
|
||||||
|
import base64
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
from hashlib import sha256
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
import trio
|
import trio
|
||||||
import pynng
|
import pytest
|
||||||
import tractor
|
import tractor
|
||||||
import trio_asyncio
|
import trio_asyncio
|
||||||
|
|
||||||
from skynet_bot.gpu import open_gpu_worker
|
from PIL import Image
|
||||||
from skynet_bot.dgpu import open_dgpu_node
|
|
||||||
from skynet_bot.types import *
|
from skynet.brain import SkynetDGPUComputeError
|
||||||
from skynet_bot.brain import run_skynet
|
from skynet.constants import *
|
||||||
from skynet_bot.constants import *
|
from skynet.frontend import open_skynet_rpc
|
||||||
from skynet_bot.frontend import open_skynet_rpc, rpc_call
|
|
||||||
|
|
||||||
|
|
||||||
def test_dgpu_simple():
|
async def wait_for_dgpus(rpc, amount: int, timeout: float = 30.0):
|
||||||
async def main():
|
gpu_ready = False
|
||||||
|
start_time = time.time()
|
||||||
|
current_time = time.time()
|
||||||
|
while not gpu_ready and (current_time - start_time) < timeout:
|
||||||
|
res = await rpc('dgpu-test', 'dgpu_workers')
|
||||||
|
logging.info(res)
|
||||||
|
if res.result['ok'] >= amount:
|
||||||
|
break
|
||||||
|
|
||||||
|
await trio.sleep(1)
|
||||||
|
current_time = time.time()
|
||||||
|
|
||||||
|
assert (current_time - start_time) < timeout
|
||||||
|
|
||||||
|
|
||||||
|
_images = set()
|
||||||
|
async def check_request_img(
|
||||||
|
i: int,
|
||||||
|
width: int = 512,
|
||||||
|
height: int = 512,
|
||||||
|
expect_unique=True
|
||||||
|
):
|
||||||
|
global _images
|
||||||
|
|
||||||
|
async with open_skynet_rpc(
|
||||||
|
security=True,
|
||||||
|
cert_name='whitelist/testing',
|
||||||
|
key_name='testing'
|
||||||
|
) as rpc_call:
|
||||||
|
res = await rpc_call(
|
||||||
|
'tg+580213293', 'txt2img', {
|
||||||
|
'prompt': 'red old tractor in a sunny wheat field',
|
||||||
|
'step': 28,
|
||||||
|
'width': width, 'height': height,
|
||||||
|
'guidance': 7.5,
|
||||||
|
'seed': None,
|
||||||
|
'algo': list(ALGOS.keys())[i],
|
||||||
|
'upscaler': None
|
||||||
|
})
|
||||||
|
|
||||||
|
if 'error' in res.result:
|
||||||
|
raise SkynetDGPUComputeError(json.dumps(res.result))
|
||||||
|
|
||||||
|
img_raw = base64.b64decode(bytes.fromhex(res.result['img']))
|
||||||
|
img_sha = sha256(img_raw).hexdigest()
|
||||||
|
img = Image.frombytes(
|
||||||
|
'RGB', (width, height), img_raw)
|
||||||
|
|
||||||
|
if expect_unique and img_sha in _images:
|
||||||
|
raise ValueError('Duplicated image sha: {img_sha}')
|
||||||
|
|
||||||
|
_images.add(img_sha)
|
||||||
|
|
||||||
|
logging.info(f'img sha256: {img_sha} size: {len(img_raw)}')
|
||||||
|
|
||||||
|
assert len(img_raw) > 100000
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
'dgpu_workers', [(1, ['midj'])], indirect=True)
|
||||||
|
async def test_dgpu_worker_compute_error(dgpu_workers):
|
||||||
|
'''Attempt to generate a huge image and check we get the right error,
|
||||||
|
then generate a smaller image to show gpu worker recovery
|
||||||
|
'''
|
||||||
|
|
||||||
|
async with open_skynet_rpc(
|
||||||
|
security=True,
|
||||||
|
cert_name='whitelist/testing',
|
||||||
|
key_name='testing'
|
||||||
|
) as test_rpc:
|
||||||
|
await wait_for_dgpus(test_rpc, 1)
|
||||||
|
|
||||||
|
with pytest.raises(SkynetDGPUComputeError) as e:
|
||||||
|
await check_request_img(0, width=4096, height=4096)
|
||||||
|
|
||||||
|
logging.info(e)
|
||||||
|
|
||||||
|
await check_request_img(0)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
'dgpu_workers', [(1, ['midj', 'stable'])], indirect=True)
|
||||||
|
async def test_dgpu_workers(dgpu_workers):
|
||||||
|
'''Generate two images in a single dgpu worker using
|
||||||
|
two different models.
|
||||||
|
'''
|
||||||
|
|
||||||
|
async with open_skynet_rpc(
|
||||||
|
security=True,
|
||||||
|
cert_name='whitelist/testing',
|
||||||
|
key_name='testing'
|
||||||
|
) as test_rpc:
|
||||||
|
await wait_for_dgpus(test_rpc, 1)
|
||||||
|
|
||||||
|
await check_request_img(0)
|
||||||
|
await check_request_img(1)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
'dgpu_workers', [(2, ['midj'])], indirect=True)
|
||||||
|
async def test_dgpu_workers_two(dgpu_workers):
|
||||||
|
'''Generate two images in two separate dgpu workers
|
||||||
|
'''
|
||||||
|
async with open_skynet_rpc(
|
||||||
|
security=True,
|
||||||
|
cert_name='whitelist/testing',
|
||||||
|
key_name='testing'
|
||||||
|
) as test_rpc:
|
||||||
|
await wait_for_dgpus(test_rpc, 2)
|
||||||
|
|
||||||
async with trio.open_nursery() as n:
|
async with trio.open_nursery() as n:
|
||||||
await n.start(
|
n.start_soon(check_request_img, 0)
|
||||||
run_skynet,
|
n.start_soon(check_request_img, 0)
|
||||||
'skynet', '3GbZd6UojbD8V7zWpeFn', 'ancap.tech:34508')
|
|
||||||
|
|
||||||
await trio.sleep(2)
|
|
||||||
|
|
||||||
for i in range(3):
|
|
||||||
n.start_soon(open_dgpu_node)
|
|
||||||
|
|
||||||
await trio.sleep(1)
|
|
||||||
start = time.time()
|
|
||||||
async def request_img():
|
|
||||||
with pynng.Req0(dial=DEFAULT_RPC_ADDR) as rpc_sock:
|
|
||||||
res = await rpc_call(
|
|
||||||
rpc_sock, 'tg+1', 'txt2img', {
|
|
||||||
'prompt': 'test',
|
|
||||||
'step': 28,
|
|
||||||
'width': 512, 'height': 512,
|
|
||||||
'guidance': 7.5,
|
|
||||||
'seed': None,
|
|
||||||
'algo': 'stable',
|
|
||||||
'upscaler': None
|
|
||||||
})
|
|
||||||
|
|
||||||
logging.info(res)
|
|
||||||
|
|
||||||
async with trio.open_nursery() as inner_n:
|
|
||||||
for i in range(3):
|
|
||||||
inner_n.start_soon(request_img)
|
|
||||||
|
|
||||||
logging.info(f'time elapsed: {time.time() - start}')
|
|
||||||
n.cancel_scope.cancel()
|
|
||||||
|
|
||||||
|
|
||||||
trio_asyncio.run(main)
|
@pytest.mark.parametrize(
|
||||||
|
'dgpu_workers', [(1, ['midj'])], indirect=True)
|
||||||
|
async def test_dgpu_worker_algo_swap(dgpu_workers):
|
||||||
|
'''Generate an image using a non default model
|
||||||
|
'''
|
||||||
|
async with open_skynet_rpc(
|
||||||
|
security=True,
|
||||||
|
cert_name='whitelist/testing',
|
||||||
|
key_name='testing'
|
||||||
|
) as test_rpc:
|
||||||
|
await wait_for_dgpus(test_rpc, 1)
|
||||||
|
await check_request_img(5)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
'dgpu_workers', [(3, ['midj'])], indirect=True)
|
||||||
|
async def test_dgpu_rotation_next_worker(dgpu_workers):
|
||||||
|
'''Connect three dgpu workers, disconnect and check next_worker
|
||||||
|
rotation happens correctly
|
||||||
|
'''
|
||||||
|
async with open_skynet_rpc(
|
||||||
|
security=True,
|
||||||
|
cert_name='whitelist/testing',
|
||||||
|
key_name='testing'
|
||||||
|
) as test_rpc:
|
||||||
|
await wait_for_dgpus(test_rpc, 3)
|
||||||
|
|
||||||
|
res = await test_rpc('testing-rpc', 'dgpu_next')
|
||||||
|
logging.info(res)
|
||||||
|
assert 'ok' in res.result
|
||||||
|
assert res.result['ok'] == 0
|
||||||
|
|
||||||
|
await check_request_img(0)
|
||||||
|
|
||||||
|
res = await test_rpc('testing-rpc', 'dgpu_next')
|
||||||
|
logging.info(res)
|
||||||
|
assert 'ok' in res.result
|
||||||
|
assert res.result['ok'] == 1
|
||||||
|
|
||||||
|
await check_request_img(0)
|
||||||
|
|
||||||
|
res = await test_rpc('testing-rpc', 'dgpu_next')
|
||||||
|
logging.info(res)
|
||||||
|
assert 'ok' in res.result
|
||||||
|
assert res.result['ok'] == 2
|
||||||
|
|
||||||
|
await check_request_img(0)
|
||||||
|
|
||||||
|
res = await test_rpc('testing-rpc', 'dgpu_next')
|
||||||
|
logging.info(res)
|
||||||
|
assert 'ok' in res.result
|
||||||
|
assert res.result['ok'] == 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
'dgpu_workers', [(3, ['midj'])], indirect=True)
|
||||||
|
async def test_dgpu_rotation_next_worker_disconnect(dgpu_workers):
|
||||||
|
'''Connect three dgpu workers, disconnect the first one and check
|
||||||
|
next_worker rotation happens correctly
|
||||||
|
'''
|
||||||
|
async with open_skynet_rpc(
|
||||||
|
security=True,
|
||||||
|
cert_name='whitelist/testing',
|
||||||
|
key_name='testing'
|
||||||
|
) as test_rpc:
|
||||||
|
await wait_for_dgpus(test_rpc, 3)
|
||||||
|
|
||||||
|
await trio.sleep(3)
|
||||||
|
|
||||||
|
# stop worker who's turn is next
|
||||||
|
for _ in range(2):
|
||||||
|
ec, out = dgpu_workers[0].exec_run(['pkill', '-INT', '-f', 'skynet'])
|
||||||
|
assert ec == 0
|
||||||
|
|
||||||
|
dgpu_workers[0].wait()
|
||||||
|
|
||||||
|
res = await test_rpc('testing-rpc', 'dgpu_workers')
|
||||||
|
logging.info(res)
|
||||||
|
assert 'ok' in res.result
|
||||||
|
assert res.result['ok'] == 2
|
||||||
|
|
||||||
|
async with trio.open_nursery() as n:
|
||||||
|
n.start_soon(check_request_img, 0)
|
||||||
|
n.start_soon(check_request_img, 0)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_dgpu_no_ack_node_disconnect(skynet_running):
|
||||||
|
async with open_skynet_rpc(
|
||||||
|
security=True,
|
||||||
|
cert_name='whitelist/testing',
|
||||||
|
key_name='testing'
|
||||||
|
) as rpc_call:
|
||||||
|
|
||||||
|
res = await rpc_call('dgpu-0', 'dgpu_online')
|
||||||
|
logging.info(res)
|
||||||
|
assert 'ok' in res.result
|
||||||
|
|
||||||
|
await wait_for_dgpus(rpc_call, 1)
|
||||||
|
|
||||||
|
with pytest.raises(SkynetDGPUComputeError) as e:
|
||||||
|
await check_request_img(0)
|
||||||
|
|
||||||
|
assert 'dgpu failed to acknowledge request' in str(e)
|
||||||
|
|
||||||
|
res = await rpc_call('testing-rpc', 'dgpu_workers')
|
||||||
|
logging.info(res)
|
||||||
|
assert 'ok' in res.result
|
||||||
|
assert res.result['ok'] == 0
|
||||||
|
|
||||||
|
|
|
@ -1,107 +0,0 @@
|
||||||
import trio
|
|
||||||
import tractor
|
|
||||||
|
|
||||||
from skynet_bot.types import *
|
|
||||||
|
|
||||||
@tractor.context
|
|
||||||
async def open_fake_worker(
|
|
||||||
ctx: tractor.Context,
|
|
||||||
start_algo: str,
|
|
||||||
mem_fraction: float
|
|
||||||
):
|
|
||||||
log = tractor.log.get_logger(name='gpu', _root_name='skynet')
|
|
||||||
log.info(f'starting gpu worker with algo {start_algo}...')
|
|
||||||
current_algo = start_algo
|
|
||||||
log.info('pipeline loaded')
|
|
||||||
await ctx.started()
|
|
||||||
async with ctx.open_stream() as bus:
|
|
||||||
async for ireq in bus:
|
|
||||||
if ireq:
|
|
||||||
await bus.send('hello!')
|
|
||||||
else:
|
|
||||||
break
|
|
||||||
|
|
||||||
def test_gpu_worker():
|
|
||||||
log = tractor.log.get_logger(name='root', _root_name='skynet')
|
|
||||||
async def main():
|
|
||||||
async with (
|
|
||||||
tractor.open_nursery(debug_mode=True) as an,
|
|
||||||
trio.open_nursery() as n
|
|
||||||
):
|
|
||||||
portal = await an.start_actor(
|
|
||||||
'gpu_worker',
|
|
||||||
enable_modules=[__name__],
|
|
||||||
debug_mode=True
|
|
||||||
)
|
|
||||||
|
|
||||||
log.info('portal opened')
|
|
||||||
async with (
|
|
||||||
portal.open_context(
|
|
||||||
open_fake_worker,
|
|
||||||
start_algo='midj',
|
|
||||||
mem_fraction=0.6
|
|
||||||
) as (ctx, _),
|
|
||||||
ctx.open_stream() as stream,
|
|
||||||
):
|
|
||||||
log.info('opened worker sending req...')
|
|
||||||
ireq = ImageGenRequest(
|
|
||||||
prompt='a red tractor on a wheat field',
|
|
||||||
step=28,
|
|
||||||
width=512, height=512,
|
|
||||||
guidance=10, seed=None,
|
|
||||||
algo='midj', upscaler=None)
|
|
||||||
|
|
||||||
await stream.send(ireq)
|
|
||||||
log.info('sent, await respnse')
|
|
||||||
async for msg in stream:
|
|
||||||
log.info(f'got {msg}')
|
|
||||||
break
|
|
||||||
|
|
||||||
assert msg == 'hello!'
|
|
||||||
await stream.send(None)
|
|
||||||
log.info('done.')
|
|
||||||
|
|
||||||
await portal.cancel_actor()
|
|
||||||
|
|
||||||
trio.run(main)
|
|
||||||
|
|
||||||
|
|
||||||
def test_gpu_two_workers():
|
|
||||||
async def main():
|
|
||||||
outputs = []
|
|
||||||
async with (
|
|
||||||
tractor.open_actor_cluster(
|
|
||||||
modules=[__name__],
|
|
||||||
count=2,
|
|
||||||
names=[0, 1]) as portal_map,
|
|
||||||
tractor.trionics.gather_contexts((
|
|
||||||
portal.open_context(
|
|
||||||
open_fake_worker,
|
|
||||||
start_algo='midj',
|
|
||||||
mem_fraction=0.333)
|
|
||||||
for portal in portal_map.values()
|
|
||||||
)) as contexts,
|
|
||||||
trio.open_nursery() as n
|
|
||||||
):
|
|
||||||
ireq = ImageGenRequest(
|
|
||||||
prompt='a red tractor on a wheat field',
|
|
||||||
step=28,
|
|
||||||
width=512, height=512,
|
|
||||||
guidance=10, seed=None,
|
|
||||||
algo='midj', upscaler=None)
|
|
||||||
|
|
||||||
async def get_img(i):
|
|
||||||
ctx = contexts[i]
|
|
||||||
async with ctx.open_stream() as stream:
|
|
||||||
await stream.send(ireq)
|
|
||||||
async for img in stream:
|
|
||||||
outputs[i] = img
|
|
||||||
await portal_map[i].cancel_actor()
|
|
||||||
|
|
||||||
n.start_soon(get_img, 0)
|
|
||||||
n.start_soon(get_img, 1)
|
|
||||||
|
|
||||||
|
|
||||||
assert len(outputs) == 2
|
|
||||||
|
|
||||||
trio.run(main)
|
|
|
@ -7,9 +7,9 @@ import pynng
|
||||||
import pytest
|
import pytest
|
||||||
import trio_asyncio
|
import trio_asyncio
|
||||||
|
|
||||||
from skynet_bot.types import *
|
from skynet.brain import run_skynet
|
||||||
from skynet_bot.brain import run_skynet
|
from skynet.structs import *
|
||||||
from skynet_bot.frontend import open_skynet_rpc
|
from skynet.frontend import open_skynet_rpc
|
||||||
|
|
||||||
|
|
||||||
async def test_skynet_attempt_insecure(skynet_running):
|
async def test_skynet_attempt_insecure(skynet_running):
|
||||||
|
@ -40,7 +40,7 @@ async def test_skynet_dgpu_connection_simple(skynet_running):
|
||||||
|
|
||||||
# connect 1 dgpu
|
# connect 1 dgpu
|
||||||
res = await rpc_call(
|
res = await rpc_call(
|
||||||
'dgpu-0', 'dgpu_online', {'max_tasks': 3})
|
'dgpu-0', 'dgpu_online')
|
||||||
logging.info(res)
|
logging.info(res)
|
||||||
assert 'ok' in res.result
|
assert 'ok' in res.result
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue