mirror of https://github.com/skygpu/skynet.git
				
				
				
			Rework dgpu client to be single task
Add a lot of dgpu real image gen tests Modified docker files and environment to allow for quick test relaunch without image rebuild Rename package from skynet_bot to skynet Drop tractor usage cause cuda is oriented to just a single proc managing gpu resources Add ackwnoledge phase to image request for quick gpu disconnected type scenarios Add click entry point for dgpu Add posibility to reuse postgres_db fixture on same session by checking if schema init has been already donepull/2/head
							parent
							
								
									d2e676627a
								
							
						
					
					
						commit
						f6326ad05c
					
				| 
						 | 
					@ -1,3 +1,9 @@
 | 
				
			||||||
 | 
					.git
 | 
				
			||||||
hf_home
 | 
					hf_home
 | 
				
			||||||
inputs
 | 
					 | 
				
			||||||
outputs
 | 
					outputs
 | 
				
			||||||
 | 
					.python-version
 | 
				
			||||||
 | 
					.pytest-cache
 | 
				
			||||||
 | 
					**/__pycache__
 | 
				
			||||||
 | 
					*.egg-info
 | 
				
			||||||
 | 
					**/*.key
 | 
				
			||||||
 | 
					**/*.cert
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1,5 +1,8 @@
 | 
				
			||||||
.python-version
 | 
					.python-version
 | 
				
			||||||
hf_home
 | 
					hf_home
 | 
				
			||||||
outputs
 | 
					outputs
 | 
				
			||||||
 | 
					secrets
 | 
				
			||||||
**/__pycache__
 | 
					**/__pycache__
 | 
				
			||||||
*.egg-info
 | 
					*.egg-info
 | 
				
			||||||
 | 
					**/*.key
 | 
				
			||||||
 | 
					**/*.cert
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -4,10 +4,16 @@ env DEBIAN_FRONTEND=noninteractive
 | 
				
			||||||
 | 
					
 | 
				
			||||||
workdir /skynet
 | 
					workdir /skynet
 | 
				
			||||||
 | 
					
 | 
				
			||||||
copy requirements.* ./
 | 
					copy requirements.test.txt requirements.test.txt
 | 
				
			||||||
 | 
					copy requirements.txt requirements.txt
 | 
				
			||||||
 | 
					copy pytest.ini ./
 | 
				
			||||||
 | 
					copy setup.py ./
 | 
				
			||||||
 | 
					copy skynet ./skynet
 | 
				
			||||||
 | 
					
 | 
				
			||||||
run pip install \
 | 
					run pip install \
 | 
				
			||||||
 | 
					    -e . \
 | 
				
			||||||
    -r requirements.txt \
 | 
					    -r requirements.txt \
 | 
				
			||||||
    -r requirements.test.txt
 | 
					    -r requirements.test.txt
 | 
				
			||||||
 | 
					
 | 
				
			||||||
workdir /scripts
 | 
					copy scripts ./
 | 
				
			||||||
 | 
					copy tests ./
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -5,19 +5,25 @@ env DEBIAN_FRONTEND=noninteractive
 | 
				
			||||||
 | 
					
 | 
				
			||||||
workdir /skynet
 | 
					workdir /skynet
 | 
				
			||||||
 | 
					
 | 
				
			||||||
copy requirements.* ./
 | 
					copy requirements.cuda* ./
 | 
				
			||||||
 | 
					
 | 
				
			||||||
run pip install -U pip ninja
 | 
					run pip install -U pip ninja
 | 
				
			||||||
run pip install -r requirements.cuda.0.txt
 | 
					run pip install -r requirements.cuda.0.txt
 | 
				
			||||||
run pip install -v -r requirements.cuda.1.txt
 | 
					run pip install -v -r requirements.cuda.1.txt
 | 
				
			||||||
 | 
					
 | 
				
			||||||
run pip install \
 | 
					copy requirements.test.txt requirements.test.txt
 | 
				
			||||||
 | 
					copy requirements.txt requirements.txt
 | 
				
			||||||
 | 
					copy pytest.ini pytest.ini
 | 
				
			||||||
 | 
					copy setup.py setup.py
 | 
				
			||||||
 | 
					copy skynet skynet
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					run pip install -e . \
 | 
				
			||||||
    -r requirements.txt \
 | 
					    -r requirements.txt \
 | 
				
			||||||
    -r requirements.test.txt
 | 
					    -r requirements.test.txt
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					env PYTORCH_CUDA_ALLOC_CONF max_split_size_mb:128
 | 
				
			||||||
env NVIDIA_VISIBLE_DEVICES=all
 | 
					env NVIDIA_VISIBLE_DEVICES=all
 | 
				
			||||||
env HF_HOME /hf_home
 | 
					env HF_HOME /hf_home
 | 
				
			||||||
 | 
					
 | 
				
			||||||
env PYTORCH_CUDA_ALLOC_CONF max_split_size_mb:128
 | 
					copy scripts scripts
 | 
				
			||||||
 | 
					copy tests tests
 | 
				
			||||||
workdir /scripts
 | 
					 | 
				
			||||||
| 
						 | 
					@ -1,6 +1,6 @@
 | 
				
			||||||
docker build \
 | 
					docker build \
 | 
				
			||||||
    -t skynet:runtime-cuda \
 | 
					    -t skynet:runtime-cuda \
 | 
				
			||||||
    -f Dockerfile.runtime-cuda .
 | 
					    -f Dockerfile.runtime+cuda .
 | 
				
			||||||
 | 
					
 | 
				
			||||||
docker build \
 | 
					docker build \
 | 
				
			||||||
    -t skynet:runtime \
 | 
					    -t skynet:runtime \
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1,2 +1,4 @@
 | 
				
			||||||
[pytest]
 | 
					[pytest]
 | 
				
			||||||
 | 
					log_cli = True
 | 
				
			||||||
 | 
					log_level = info
 | 
				
			||||||
trio_mode = true
 | 
					trio_mode = true
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1,4 +1,3 @@
 | 
				
			||||||
pdbpp
 | 
					 | 
				
			||||||
scipy
 | 
					scipy
 | 
				
			||||||
triton
 | 
					triton
 | 
				
			||||||
accelerate
 | 
					accelerate
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1,5 +1,6 @@
 | 
				
			||||||
 | 
					pdbpp
 | 
				
			||||||
pytest
 | 
					pytest
 | 
				
			||||||
psycopg2
 | 
					psycopg2
 | 
				
			||||||
pytest-trio
 | 
					pytest-trio
 | 
				
			||||||
 | 
					
 | 
				
			||||||
git+https://github.com/tgoodlet/pytest-dockerctl.git@master#egg=pytest-dockerctl
 | 
					git+https://github.com/guilledk/pytest-dockerctl.git@host_network#egg=pytest-dockerctl
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -5,5 +5,3 @@ aiohttp
 | 
				
			||||||
msgspec
 | 
					msgspec
 | 
				
			||||||
pyOpenSSL
 | 
					pyOpenSSL
 | 
				
			||||||
trio_asyncio
 | 
					trio_asyncio
 | 
				
			||||||
 | 
					 | 
				
			||||||
git+https://github.com/goodboy/tractor.git@piker_pin#egg=tractor
 | 
					 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -8,7 +8,7 @@ import sys
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from OpenSSL import crypto, SSL
 | 
					from OpenSSL import crypto, SSL
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from skynet_bot.constants import DEFAULT_CERTS_DIR
 | 
					from skynet.constants import DEFAULT_CERTS_DIR
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def input_or_skip(txt, default):
 | 
					def input_or_skip(txt, default):
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
							
								
								
									
										9
									
								
								setup.py
								
								
								
								
							
							
						
						
									
										9
									
								
								setup.py
								
								
								
								
							| 
						 | 
					@ -1,11 +1,16 @@
 | 
				
			||||||
from setuptools import setup, find_packages
 | 
					from setuptools import setup, find_packages
 | 
				
			||||||
 | 
					
 | 
				
			||||||
setup(
 | 
					setup(
 | 
				
			||||||
    name='skynet-bot',
 | 
					    name='skynet',
 | 
				
			||||||
    version='0.1.0a6',
 | 
					    version='0.1.0a6',
 | 
				
			||||||
    description='Decentralized compute platform',
 | 
					    description='Decentralized compute platform',
 | 
				
			||||||
    author='Guillermo Rodriguez',
 | 
					    author='Guillermo Rodriguez',
 | 
				
			||||||
    author_email='guillermo@telos.net',
 | 
					    author_email='guillermo@telos.net',
 | 
				
			||||||
    packages=find_packages(),
 | 
					    packages=find_packages(),
 | 
				
			||||||
    install_requires=[]
 | 
					    entry_points={
 | 
				
			||||||
 | 
					        'console_scripts': [
 | 
				
			||||||
 | 
					            'skynet = skynet.cli:skynet',
 | 
				
			||||||
 | 
					        ]
 | 
				
			||||||
 | 
					    },
 | 
				
			||||||
 | 
					    install_requires=['click']
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -8,6 +8,7 @@ import logging
 | 
				
			||||||
from uuid import UUID
 | 
					from uuid import UUID
 | 
				
			||||||
from pathlib import Path
 | 
					from pathlib import Path
 | 
				
			||||||
from functools import partial
 | 
					from functools import partial
 | 
				
			||||||
 | 
					from contextlib import asynccontextmanager as acm
 | 
				
			||||||
from collections import OrderedDict
 | 
					from collections import OrderedDict
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import trio
 | 
					import trio
 | 
				
			||||||
| 
						 | 
					@ -17,7 +18,7 @@ import trio_asyncio
 | 
				
			||||||
from pynng import TLSConfig
 | 
					from pynng import TLSConfig
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from .db import *
 | 
					from .db import *
 | 
				
			||||||
from .types import *
 | 
					from .structs import *
 | 
				
			||||||
from .constants import *
 | 
					from .constants import *
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -27,18 +28,47 @@ class SkynetDGPUOffline(BaseException):
 | 
				
			||||||
class SkynetDGPUOverloaded(BaseException):
 | 
					class SkynetDGPUOverloaded(BaseException):
 | 
				
			||||||
    ...
 | 
					    ...
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class SkynetDGPUComputeError(BaseException):
 | 
				
			||||||
 | 
					    ...
 | 
				
			||||||
 | 
					
 | 
				
			||||||
async def rpc_service(sock, dgpu_bus, db_pool):
 | 
					class SkynetShutdownRequested(BaseException):
 | 
				
			||||||
 | 
					    ...
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@acm
 | 
				
			||||||
 | 
					async def open_rpc_service(sock, dgpu_bus, db_pool):
 | 
				
			||||||
    nodes = OrderedDict()
 | 
					    nodes = OrderedDict()
 | 
				
			||||||
    wip_reqs = {}
 | 
					    wip_reqs = {}
 | 
				
			||||||
    fin_reqs = {}
 | 
					    fin_reqs = {}
 | 
				
			||||||
 | 
					    next_worker: Optional[int] = None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def is_worker_busy(nid: int):
 | 
					    def connect_node(uid):
 | 
				
			||||||
        for task in nodes[nid]['tasks']:
 | 
					        nonlocal next_worker
 | 
				
			||||||
            if task != None:
 | 
					        nodes[uid] = {
 | 
				
			||||||
                return False
 | 
					            'task': None
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					        logging.info(f'dgpu online: {uid}')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        return True
 | 
					        if not next_worker:
 | 
				
			||||||
 | 
					            next_worker = 0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def disconnect_node(uid):
 | 
				
			||||||
 | 
					        nonlocal next_worker
 | 
				
			||||||
 | 
					        if uid not in nodes:
 | 
				
			||||||
 | 
					            return
 | 
				
			||||||
 | 
					        i = list(nodes.keys()).index(uid)
 | 
				
			||||||
 | 
					        del nodes[uid]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if i < next_worker:
 | 
				
			||||||
 | 
					            next_worker -= 1
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if len(nodes) == 0:
 | 
				
			||||||
 | 
					            logging.info('nw: None')
 | 
				
			||||||
 | 
					            next_worker = None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        logging.warning(f'dgpu offline: {uid}')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def is_worker_busy(nid: str):
 | 
				
			||||||
 | 
					        return nodes[nid]['task'] != None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def are_all_workers_busy():
 | 
					    def are_all_workers_busy():
 | 
				
			||||||
        for nid in nodes.keys():
 | 
					        for nid in nodes.keys():
 | 
				
			||||||
| 
						 | 
					@ -47,30 +77,55 @@ async def rpc_service(sock, dgpu_bus, db_pool):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        return True
 | 
					        return True
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    next_worker: Optional[int] = None
 | 
					 | 
				
			||||||
    def get_next_worker():
 | 
					    def get_next_worker():
 | 
				
			||||||
        nonlocal next_worker
 | 
					        nonlocal next_worker
 | 
				
			||||||
 | 
					        logging.info('get next_worker called')
 | 
				
			||||||
 | 
					        logging.info(f'pre next_worker: {next_worker}')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        if not next_worker:
 | 
					        if next_worker == None:
 | 
				
			||||||
            raise SkynetDGPUOffline
 | 
					            raise SkynetDGPUOffline
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        if are_all_workers_busy():
 | 
					        if are_all_workers_busy():
 | 
				
			||||||
            raise SkynetDGPUOverloaded
 | 
					            raise SkynetDGPUOverloaded
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        while is_worker_busy(next_worker):
 | 
					
 | 
				
			||||||
 | 
					        nid = list(nodes.keys())[next_worker]
 | 
				
			||||||
 | 
					        while is_worker_busy(nid):
 | 
				
			||||||
            next_worker += 1
 | 
					            next_worker += 1
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            if next_worker >= len(nodes):
 | 
					            if next_worker >= len(nodes):
 | 
				
			||||||
                next_worker = 0
 | 
					                next_worker = 0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        return next_worker
 | 
					            nid = list(nodes.keys())[next_worker]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        next_worker += 1
 | 
				
			||||||
 | 
					        if next_worker >= len(nodes):
 | 
				
			||||||
 | 
					            next_worker = 0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        logging.info(f'post next_worker: {next_worker}')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return nid
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    async def dgpu_image_streamer():
 | 
					    async def dgpu_image_streamer():
 | 
				
			||||||
        nonlocal wip_reqs, fin_reqs
 | 
					        nonlocal wip_reqs, fin_reqs
 | 
				
			||||||
        while True:
 | 
					        while True:
 | 
				
			||||||
            msg = await dgpu_bus.arecv_msg()
 | 
					            msg = await dgpu_bus.arecv_msg()
 | 
				
			||||||
            rid = UUID(bytes=msg.bytes[:16]).hex
 | 
					            rid = UUID(bytes=msg.bytes[:16]).hex
 | 
				
			||||||
            img = msg.bytes[16:].hex()
 | 
					            raw_msg = msg.bytes[16:]
 | 
				
			||||||
 | 
					            logging.info(f'streamer got back {rid} of size {len(raw_msg)}')
 | 
				
			||||||
 | 
					            match raw_msg[:5]:
 | 
				
			||||||
 | 
					                case b'error':
 | 
				
			||||||
 | 
					                    img = raw_msg.decode()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                case b'ack':
 | 
				
			||||||
 | 
					                    img = raw_msg
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                case _:
 | 
				
			||||||
 | 
					                    img = base64.b64encode(raw_msg).hex()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            if rid not in wip_reqs:
 | 
				
			||||||
 | 
					                continue
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            fin_reqs[rid] = img
 | 
					            fin_reqs[rid] = img
 | 
				
			||||||
            event = wip_reqs[rid]
 | 
					            event = wip_reqs[rid]
 | 
				
			||||||
            event.set()
 | 
					            event.set()
 | 
				
			||||||
| 
						 | 
					@ -79,13 +134,14 @@ async def rpc_service(sock, dgpu_bus, db_pool):
 | 
				
			||||||
    async def dgpu_stream_one_img(req: ImageGenRequest):
 | 
					    async def dgpu_stream_one_img(req: ImageGenRequest):
 | 
				
			||||||
        nonlocal wip_reqs, fin_reqs, next_worker
 | 
					        nonlocal wip_reqs, fin_reqs, next_worker
 | 
				
			||||||
        nid = get_next_worker()
 | 
					        nid = get_next_worker()
 | 
				
			||||||
        logging.info(f'dgpu_stream_one_img {next_worker} {nid}')
 | 
					        idx = list(nodes.keys()).index(nid)
 | 
				
			||||||
 | 
					        logging.info(f'dgpu_stream_one_img {idx}/{len(nodes)} {nid}')
 | 
				
			||||||
        rid = uuid.uuid4().hex
 | 
					        rid = uuid.uuid4().hex
 | 
				
			||||||
        event = trio.Event()
 | 
					        ack_event = trio.Event()
 | 
				
			||||||
        wip_reqs[rid] = event
 | 
					        img_event = trio.Event()
 | 
				
			||||||
 | 
					        wip_reqs[rid] = ack_event
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        tid = nodes[nid]['tasks'].index(None)
 | 
					        nodes[nid]['task'] = rid
 | 
				
			||||||
        nodes[nid]['tasks'][tid] = rid
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
        dgpu_req = DGPUBusRequest(
 | 
					        dgpu_req = DGPUBusRequest(
 | 
				
			||||||
            rid=rid,
 | 
					            rid=rid,
 | 
				
			||||||
| 
						 | 
					@ -98,14 +154,37 @@ async def rpc_service(sock, dgpu_bus, db_pool):
 | 
				
			||||||
        await dgpu_bus.asend(
 | 
					        await dgpu_bus.asend(
 | 
				
			||||||
            json.dumps(dgpu_req.to_dict()).encode())
 | 
					            json.dumps(dgpu_req.to_dict()).encode())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        await event.wait()
 | 
					        with trio.move_on_after(4):
 | 
				
			||||||
 | 
					            await ack_event.wait()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        nodes[nid]['tasks'][tid] = None
 | 
					        logging.info(f'ack event: {ack_event.is_set()}')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if not ack_event.is_set():
 | 
				
			||||||
 | 
					            disconnect_node(nid)
 | 
				
			||||||
 | 
					            raise SkynetDGPUOffline('dgpu failed to acknowledge request')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        ack = fin_reqs[rid]
 | 
				
			||||||
 | 
					        if ack != b'ack':
 | 
				
			||||||
 | 
					            disconnect_node(nid)
 | 
				
			||||||
 | 
					            raise SkynetDGPUOffline('dgpu failed to acknowledge request')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        wip_reqs[rid] = img_event
 | 
				
			||||||
 | 
					        with trio.move_on_after(30):
 | 
				
			||||||
 | 
					            await img_event.wait()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if not img_event.is_set():
 | 
				
			||||||
 | 
					            disconnect_node(nid)
 | 
				
			||||||
 | 
					            raise SkynetDGPUComputeError('30 seconds timeout while processing request')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        nodes[nid]['task'] = None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        img = fin_reqs[rid]
 | 
					        img = fin_reqs[rid]
 | 
				
			||||||
        del fin_reqs[rid]
 | 
					        del fin_reqs[rid]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        logging.info(f'done streaming {img}')
 | 
					        logging.info(f'done streaming {len(img)} bytes')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if 'error' in img:
 | 
				
			||||||
 | 
					            raise SkynetDGPUComputeError(img)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        return rid, img
 | 
					        return rid, img
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -122,6 +201,10 @@ async def rpc_service(sock, dgpu_bus, db_pool):
 | 
				
			||||||
                        user_config = {**(await get_user_config(conn, user))}
 | 
					                        user_config = {**(await get_user_config(conn, user))}
 | 
				
			||||||
                        del user_config['id']
 | 
					                        del user_config['id']
 | 
				
			||||||
                        prompt = req.params['prompt']
 | 
					                        prompt = req.params['prompt']
 | 
				
			||||||
 | 
					                        user_config= {
 | 
				
			||||||
 | 
					                            key : req.params.get(key, val) 
 | 
				
			||||||
 | 
					                            for key, val in user_config.items()
 | 
				
			||||||
 | 
					                        }
 | 
				
			||||||
                        req = ImageGenRequest(
 | 
					                        req = ImageGenRequest(
 | 
				
			||||||
                            prompt=prompt,
 | 
					                            prompt=prompt,
 | 
				
			||||||
                            **user_config
 | 
					                            **user_config
 | 
				
			||||||
| 
						 | 
					@ -165,9 +248,10 @@ async def rpc_service(sock, dgpu_bus, db_pool):
 | 
				
			||||||
                    case _:
 | 
					                    case _:
 | 
				
			||||||
                        logging.warn('unknown method')
 | 
					                        logging.warn('unknown method')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        except SkynetDGPUOffline:
 | 
					        except SkynetDGPUOffline as e:
 | 
				
			||||||
            result = {
 | 
					            result = {
 | 
				
			||||||
                'error': 'skynet_dgpu_offline'
 | 
					                'error': 'skynet_dgpu_offline',
 | 
				
			||||||
 | 
					                'message': str(e)
 | 
				
			||||||
            }
 | 
					            }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        except SkynetDGPUOverloaded:
 | 
					        except SkynetDGPUOverloaded:
 | 
				
			||||||
| 
						 | 
					@ -176,22 +260,22 @@ async def rpc_service(sock, dgpu_bus, db_pool):
 | 
				
			||||||
                'nodes': len(nodes)
 | 
					                'nodes': len(nodes)
 | 
				
			||||||
            }
 | 
					            }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        except BaseException as e:
 | 
					        except SkynetDGPUComputeError as e:
 | 
				
			||||||
            logging.error(e)
 | 
					 | 
				
			||||||
            result = {
 | 
					            result = {
 | 
				
			||||||
                'error': 'skynet_internal_error'
 | 
					                'error': 'skynet_dgpu_compute_error',
 | 
				
			||||||
 | 
					                'message': str(e)
 | 
				
			||||||
            }
 | 
					            }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        await rpc_ctx.asend(
 | 
					        await rpc_ctx.asend(
 | 
				
			||||||
            json.dumps(
 | 
					            json.dumps(
 | 
				
			||||||
                SkynetRPCResponse(result=result).to_dict()).encode())
 | 
					                SkynetRPCResponse(result=result).to_dict()).encode())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    async def request_service(n):
 | 
				
			||||||
    async with trio.open_nursery() as n:
 | 
					        nonlocal next_worker
 | 
				
			||||||
        n.start_soon(dgpu_image_streamer)
 | 
					 | 
				
			||||||
        while True:
 | 
					        while True:
 | 
				
			||||||
            ctx = sock.new_context()
 | 
					            ctx = sock.new_context()
 | 
				
			||||||
            msg = await ctx.arecv_msg()
 | 
					            msg = await ctx.arecv_msg()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            content = msg.bytes.decode()
 | 
					            content = msg.bytes.decode()
 | 
				
			||||||
            req = SkynetRPCRequest(**json.loads(content))
 | 
					            req = SkynetRPCRequest(**json.loads(content))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -199,27 +283,14 @@ async def rpc_service(sock, dgpu_bus, db_pool):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            result = {}
 | 
					            result = {}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            if req.method == 'dgpu_online':
 | 
					            if req.method == 'skynet_shutdown':
 | 
				
			||||||
                nodes[req.uid] = {
 | 
					                raise SkynetShutdownRequested
 | 
				
			||||||
                    'tasks': [None for _ in range(req.params['max_tasks'])],
 | 
					 | 
				
			||||||
                    'max_tasks': req.params['max_tasks']
 | 
					 | 
				
			||||||
                }
 | 
					 | 
				
			||||||
                logging.info(f'dgpu online: {req.uid}')
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
                if not next_worker:
 | 
					            elif req.method == 'dgpu_online':
 | 
				
			||||||
                    next_worker = 0
 | 
					                connect_node(req.uid)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            elif req.method == 'dgpu_offline':
 | 
					            elif req.method == 'dgpu_offline':
 | 
				
			||||||
                i = list(nodes.keys()).index(req.uid)
 | 
					                disconnect_node(req.uid)
 | 
				
			||||||
                del nodes[req.uid]
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                if i < next_worker:
 | 
					 | 
				
			||||||
                    next_worker -= 1
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                if len(nodes) == 0:
 | 
					 | 
				
			||||||
                    next_worker = None
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                logging.info(f'dgpu offline: {req.uid}')
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
            elif req.method == 'dgpu_workers':
 | 
					            elif req.method == 'dgpu_workers':
 | 
				
			||||||
                result = len(nodes)
 | 
					                result = len(nodes)
 | 
				
			||||||
| 
						 | 
					@ -238,13 +309,22 @@ async def rpc_service(sock, dgpu_bus, db_pool):
 | 
				
			||||||
                        result={'ok': result}).to_dict()).encode())
 | 
					                        result={'ok': result}).to_dict()).encode())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    async with trio.open_nursery() as n:
 | 
				
			||||||
 | 
					        n.start_soon(dgpu_image_streamer)
 | 
				
			||||||
 | 
					        n.start_soon(request_service, n)
 | 
				
			||||||
 | 
					        logging.info('starting rpc service')
 | 
				
			||||||
 | 
					        yield
 | 
				
			||||||
 | 
					        logging.info('stopping rpc service')
 | 
				
			||||||
 | 
					        n.cancel_scope.cancel()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@acm
 | 
				
			||||||
async def run_skynet(
 | 
					async def run_skynet(
 | 
				
			||||||
    db_user: str = DB_USER,
 | 
					    db_user: str = DB_USER,
 | 
				
			||||||
    db_pass: str = DB_PASS,
 | 
					    db_pass: str = DB_PASS,
 | 
				
			||||||
    db_host: str = DB_HOST,
 | 
					    db_host: str = DB_HOST,
 | 
				
			||||||
    rpc_address: str = DEFAULT_RPC_ADDR,
 | 
					    rpc_address: str = DEFAULT_RPC_ADDR,
 | 
				
			||||||
    dgpu_address: str = DEFAULT_DGPU_ADDR,
 | 
					    dgpu_address: str = DEFAULT_DGPU_ADDR,
 | 
				
			||||||
    task_status = trio.TASK_STATUS_IGNORED,
 | 
					 | 
				
			||||||
    security: bool = True
 | 
					    security: bool = True
 | 
				
			||||||
):
 | 
					):
 | 
				
			||||||
    logging.basicConfig(level=logging.INFO)
 | 
					    logging.basicConfig(level=logging.INFO)
 | 
				
			||||||
| 
						 | 
					@ -260,8 +340,8 @@ async def run_skynet(
 | 
				
			||||||
            (cert_path).read_text()
 | 
					            (cert_path).read_text()
 | 
				
			||||||
            for cert_path in (certs_dir / 'whitelist').glob('*.cert')]
 | 
					            for cert_path in (certs_dir / 'whitelist').glob('*.cert')]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        logging.info(f'tls_key: {tls_key}')
 | 
					        cert_start = tls_cert.index('\n') + 1
 | 
				
			||||||
        logging.info(f'tls_cert: {tls_cert}')
 | 
					        logging.info(f'tls_cert: {tls_cert[cert_start:cert_start+64]}...')
 | 
				
			||||||
        logging.info(f'tls_whitelist len: {len(tls_whitelist)}')
 | 
					        logging.info(f'tls_whitelist len: {len(tls_whitelist)}')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        rpc_address = 'tls+' + rpc_address
 | 
					        rpc_address = 'tls+' + rpc_address
 | 
				
			||||||
| 
						 | 
					@ -271,16 +351,14 @@ async def run_skynet(
 | 
				
			||||||
            own_key_string=tls_key,
 | 
					            own_key_string=tls_key,
 | 
				
			||||||
            own_cert_string=tls_cert)
 | 
					            own_cert_string=tls_cert)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    async with (
 | 
					    with (
 | 
				
			||||||
        trio.open_nursery() as n,
 | 
					        pynng.Rep0() as rpc_sock,
 | 
				
			||||||
        open_database_connection(
 | 
					        pynng.Bus0() as dgpu_bus
 | 
				
			||||||
            db_user, db_pass, db_host) as db_pool
 | 
					 | 
				
			||||||
    ):
 | 
					    ):
 | 
				
			||||||
        logging.info('connected to db.')
 | 
					        async with open_database_connection(
 | 
				
			||||||
        with (
 | 
					            db_user, db_pass, db_host) as db_pool:
 | 
				
			||||||
            pynng.Rep0() as rpc_sock,
 | 
					
 | 
				
			||||||
            pynng.Bus0() as dgpu_bus
 | 
					            logging.info('connected to db.')
 | 
				
			||||||
        ):
 | 
					 | 
				
			||||||
            if security:
 | 
					            if security:
 | 
				
			||||||
                rpc_sock.tls_config = tls_config
 | 
					                rpc_sock.tls_config = tls_config
 | 
				
			||||||
                dgpu_bus.tls_config = tls_config
 | 
					                dgpu_bus.tls_config = tls_config
 | 
				
			||||||
| 
						 | 
					@ -288,13 +366,11 @@ async def run_skynet(
 | 
				
			||||||
            rpc_sock.listen(rpc_address)
 | 
					            rpc_sock.listen(rpc_address)
 | 
				
			||||||
            dgpu_bus.listen(dgpu_address)
 | 
					            dgpu_bus.listen(dgpu_address)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            n.start_soon(
 | 
					 | 
				
			||||||
                rpc_service, rpc_sock, dgpu_bus, db_pool)
 | 
					 | 
				
			||||||
            task_status.started()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            try:
 | 
					            try:
 | 
				
			||||||
                await trio.sleep_forever()
 | 
					                async with open_rpc_service(rpc_sock, dgpu_bus, db_pool):
 | 
				
			||||||
 | 
					                    yield
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            except KeyboardInterrupt:
 | 
					            except SkynetShutdownRequested:
 | 
				
			||||||
                ...
 | 
					                ...
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        logging.info('disconnected from db.')
 | 
				
			||||||
| 
						 | 
					@ -0,0 +1,68 @@
 | 
				
			||||||
 | 
					#!/usr/bin/python
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import os
 | 
				
			||||||
 | 
					import json
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from typing import Optional
 | 
				
			||||||
 | 
					from functools import partial
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import trio
 | 
				
			||||||
 | 
					import click
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from .dgpu import open_dgpu_node
 | 
				
			||||||
 | 
					from .utils import txt2img
 | 
				
			||||||
 | 
					from .constants import ALGOS
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@click.group()
 | 
				
			||||||
 | 
					def skynet(*args, **kwargs):
 | 
				
			||||||
 | 
					    pass
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@skynet.command()
 | 
				
			||||||
 | 
					@click.option('--model', '-m', default=ALGOS['midj'])
 | 
				
			||||||
 | 
					@click.option(
 | 
				
			||||||
 | 
					    '--prompt', '-p', default='a red tractor in a wheat field')
 | 
				
			||||||
 | 
					@click.option('--output', '-o', default='output.png')
 | 
				
			||||||
 | 
					@click.option('--width', '-w', default=512)
 | 
				
			||||||
 | 
					@click.option('--height', '-h', default=512)
 | 
				
			||||||
 | 
					@click.option('--guidance', '-g', default=10.0)
 | 
				
			||||||
 | 
					@click.option('--steps', '-s', default=26)
 | 
				
			||||||
 | 
					@click.option('--seed', '-S', default=None)
 | 
				
			||||||
 | 
					def txt2img(*args
 | 
				
			||||||
 | 
					#     model: str,
 | 
				
			||||||
 | 
					#     prompt: str,
 | 
				
			||||||
 | 
					#     output: str
 | 
				
			||||||
 | 
					#     width: int, height: int,
 | 
				
			||||||
 | 
					#     guidance: float,
 | 
				
			||||||
 | 
					#     steps: int,
 | 
				
			||||||
 | 
					#     seed: Optional[int]
 | 
				
			||||||
 | 
					):
 | 
				
			||||||
 | 
					    assert 'HF_TOKEN' in os.environ
 | 
				
			||||||
 | 
					    txt2img(
 | 
				
			||||||
 | 
					        os.environ['HF_TOKEN'], *args)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@skynet.group()
 | 
				
			||||||
 | 
					def run(*args, **kwargs):
 | 
				
			||||||
 | 
					    pass
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@run.command()
 | 
				
			||||||
 | 
					@click.option('--loglevel', '-l', default='warning', help='Logging level')
 | 
				
			||||||
 | 
					@click.option(
 | 
				
			||||||
 | 
					    '--key', '-k', default='dgpu')
 | 
				
			||||||
 | 
					@click.option(
 | 
				
			||||||
 | 
					    '--cert', '-c', default='whitelist/dgpu')
 | 
				
			||||||
 | 
					@click.option(
 | 
				
			||||||
 | 
					    '--algos', '-a', default=None)
 | 
				
			||||||
 | 
					def dgpu(
 | 
				
			||||||
 | 
					    loglevel: str,
 | 
				
			||||||
 | 
					    key: str,
 | 
				
			||||||
 | 
					    cert: str,
 | 
				
			||||||
 | 
					    algos: Optional[str]
 | 
				
			||||||
 | 
					):
 | 
				
			||||||
 | 
					    trio.run(
 | 
				
			||||||
 | 
					        partial(
 | 
				
			||||||
 | 
					            open_dgpu_node,
 | 
				
			||||||
 | 
					            cert,
 | 
				
			||||||
 | 
					            key_name=key,
 | 
				
			||||||
 | 
					            initial_algos=json.loads(algos)
 | 
				
			||||||
 | 
					    ))
 | 
				
			||||||
| 
						 | 
					@ -1,6 +1,6 @@
 | 
				
			||||||
#!/usr/bin/python
 | 
					#!/usr/bin/python
 | 
				
			||||||
 | 
					
 | 
				
			||||||
API_TOKEN = '5880619053:AAFge2UfObw1kCn9Kb7AAyqaIHs_HgM0Fx0'
 | 
					DOCKER_RUNTIME_CUDA = 'skynet:runtime-cuda'
 | 
				
			||||||
 | 
					
 | 
				
			||||||
DB_HOST = 'ancap.tech:34508'
 | 
					DB_HOST = 'ancap.tech:34508'
 | 
				
			||||||
DB_USER = 'skynet'
 | 
					DB_USER = 'skynet'
 | 
				
			||||||
| 
						 | 
					@ -8,8 +8,8 @@ DB_PASS = 'password'
 | 
				
			||||||
DB_NAME = 'skynet'
 | 
					DB_NAME = 'skynet'
 | 
				
			||||||
 | 
					
 | 
				
			||||||
ALGOS = {
 | 
					ALGOS = {
 | 
				
			||||||
    'stable': 'runwayml/stable-diffusion-v1-5',
 | 
					 | 
				
			||||||
    'midj': 'prompthero/openjourney',
 | 
					    'midj': 'prompthero/openjourney',
 | 
				
			||||||
 | 
					    'stable': 'runwayml/stable-diffusion-v1-5',
 | 
				
			||||||
    'hdanime': 'Linaqruf/anything-v3.0',
 | 
					    'hdanime': 'Linaqruf/anything-v3.0',
 | 
				
			||||||
    'waifu': 'hakurei/waifu-diffusion',
 | 
					    'waifu': 'hakurei/waifu-diffusion',
 | 
				
			||||||
    'ghibli': 'nitrosocke/Ghibli-Diffusion',
 | 
					    'ghibli': 'nitrosocke/Ghibli-Diffusion',
 | 
				
			||||||
| 
						 | 
					@ -122,7 +122,7 @@ DEFAULT_CERT_DGPU = 'dgpu.key'
 | 
				
			||||||
DEFAULT_RPC_ADDR = 'tcp://127.0.0.1:41000'
 | 
					DEFAULT_RPC_ADDR = 'tcp://127.0.0.1:41000'
 | 
				
			||||||
 | 
					
 | 
				
			||||||
DEFAULT_DGPU_ADDR = 'tcp://127.0.0.1:41069'
 | 
					DEFAULT_DGPU_ADDR = 'tcp://127.0.0.1:41069'
 | 
				
			||||||
DEFAULT_DGPU_MAX_TASKS = 3
 | 
					DEFAULT_DGPU_MAX_TASKS = 2
 | 
				
			||||||
DEFAULT_INITAL_ALGOS = ['midj', 'stable', 'ink']
 | 
					DEFAULT_INITAL_ALGOS = ['midj', 'stable', 'ink']
 | 
				
			||||||
 | 
					
 | 
				
			||||||
DATE_FORMAT = '%B the %dth %Y, %H:%M:%S'
 | 
					DATE_FORMAT = '%B the %dth %Y, %H:%M:%S'
 | 
				
			||||||
| 
						 | 
					@ -7,6 +7,9 @@ from contextlib import asynccontextmanager as acm
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import trio
 | 
					import trio
 | 
				
			||||||
import triopg
 | 
					import triopg
 | 
				
			||||||
 | 
					import trio_asyncio
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from asyncpg.exceptions import UndefinedColumnError
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from .constants import *
 | 
					from .constants import *
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -72,13 +75,22 @@ async def open_database_connection(
 | 
				
			||||||
    db_host: str = DB_HOST,
 | 
					    db_host: str = DB_HOST,
 | 
				
			||||||
    db_name: str = DB_NAME
 | 
					    db_name: str = DB_NAME
 | 
				
			||||||
):
 | 
					):
 | 
				
			||||||
    async with triopg.create_pool(
 | 
					    async with trio_asyncio.open_loop() as loop:
 | 
				
			||||||
        dsn=f'postgres://{db_user}:{db_pass}@{db_host}/{db_name}'
 | 
					        async with triopg.create_pool(
 | 
				
			||||||
    ) as pool_conn:
 | 
					            dsn=f'postgres://{db_user}:{db_pass}@{db_host}/{db_name}'
 | 
				
			||||||
        async with pool_conn.acquire() as conn:
 | 
					        ) as pool_conn:
 | 
				
			||||||
            await conn.execute(DB_INIT_SQL)
 | 
					            async with pool_conn.acquire() as conn:
 | 
				
			||||||
 | 
					                res = await conn.execute(f'''
 | 
				
			||||||
 | 
					                    select distinct table_schema
 | 
				
			||||||
 | 
					                    from information_schema.tables
 | 
				
			||||||
 | 
					                    where table_schema = \'{db_name}\'
 | 
				
			||||||
 | 
					                ''')
 | 
				
			||||||
 | 
					                if '1' in res:
 | 
				
			||||||
 | 
					                    logging.info('schema already in db, skipping init')
 | 
				
			||||||
 | 
					                else:
 | 
				
			||||||
 | 
					                    await conn.execute(DB_INIT_SQL)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        yield pool_conn
 | 
					            yield pool_conn
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
async def get_user(conn, uid: str):
 | 
					async def get_user(conn, uid: str):
 | 
				
			||||||
| 
						 | 
					@ -135,6 +147,7 @@ async def new_user(conn, uid: str):
 | 
				
			||||||
                tg_id, generated, joined, last_prompt, role)
 | 
					                tg_id, generated, joined, last_prompt, role)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            VALUES($1, $2, $3, $4, $5)
 | 
					            VALUES($1, $2, $3, $4, $5)
 | 
				
			||||||
 | 
					            ON CONFLICT DO NOTHING
 | 
				
			||||||
        ''')
 | 
					        ''')
 | 
				
			||||||
        await stmt.fetch(
 | 
					        await stmt.fetch(
 | 
				
			||||||
            tg_id, 0, date, None, DEFAULT_ROLE
 | 
					            tg_id, 0, date, None, DEFAULT_ROLE
 | 
				
			||||||
| 
						 | 
					@ -147,6 +160,7 @@ async def new_user(conn, uid: str):
 | 
				
			||||||
                id, algo, step, width, height, seed, guidance, upscaler)
 | 
					                id, algo, step, width, height, seed, guidance, upscaler)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            VALUES($1, $2, $3, $4, $5, $6, $7, $8)
 | 
					            VALUES($1, $2, $3, $4, $5, $6, $7, $8)
 | 
				
			||||||
 | 
					            ON CONFLICT DO NOTHING
 | 
				
			||||||
        ''')
 | 
					        ''')
 | 
				
			||||||
        user = await stmt.fetch(
 | 
					        user = await stmt.fetch(
 | 
				
			||||||
            new_uid,
 | 
					            new_uid,
 | 
				
			||||||
| 
						 | 
					@ -0,0 +1,197 @@
 | 
				
			||||||
 | 
					#!/usr/bin/python
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import gc
 | 
				
			||||||
 | 
					import io
 | 
				
			||||||
 | 
					import trio
 | 
				
			||||||
 | 
					import json
 | 
				
			||||||
 | 
					import uuid
 | 
				
			||||||
 | 
					import random
 | 
				
			||||||
 | 
					import logging
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from typing import List, Optional
 | 
				
			||||||
 | 
					from pathlib import Path
 | 
				
			||||||
 | 
					from contextlib import AsyncExitStack
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import pynng
 | 
				
			||||||
 | 
					import torch
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from pynng import TLSConfig
 | 
				
			||||||
 | 
					from diffusers import (
 | 
				
			||||||
 | 
					    StableDiffusionPipeline,
 | 
				
			||||||
 | 
					    EulerAncestralDiscreteScheduler
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from .structs import *
 | 
				
			||||||
 | 
					from .constants import *
 | 
				
			||||||
 | 
					from .frontend import open_skynet_rpc
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def pipeline_for(algo: str, mem_fraction: float = 1.0):
 | 
				
			||||||
 | 
					    assert torch.cuda.is_available()
 | 
				
			||||||
 | 
					    torch.cuda.empty_cache()
 | 
				
			||||||
 | 
					    torch.cuda.set_per_process_memory_fraction(mem_fraction)
 | 
				
			||||||
 | 
					    torch.backends.cuda.matmul.allow_tf32 = True
 | 
				
			||||||
 | 
					    torch.backends.cudnn.allow_tf32 = True
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    params = {
 | 
				
			||||||
 | 
					        'torch_dtype': torch.float16,
 | 
				
			||||||
 | 
					        'safety_checker': None
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if algo == 'stable':
 | 
				
			||||||
 | 
					        params['revision'] = 'fp16'
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    pipe = StableDiffusionPipeline.from_pretrained(
 | 
				
			||||||
 | 
					        ALGOS[algo], **params)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(
 | 
				
			||||||
 | 
					        pipe.scheduler.config)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    pipe.enable_vae_slicing()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    return pipe.to("cuda")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class DGPUComputeError(BaseException):
 | 
				
			||||||
 | 
					    ...
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					async def open_dgpu_node(
 | 
				
			||||||
 | 
					    cert_name: str,
 | 
				
			||||||
 | 
					    key_name: Optional[str],
 | 
				
			||||||
 | 
					    rpc_address: str = DEFAULT_RPC_ADDR,
 | 
				
			||||||
 | 
					    dgpu_address: str = DEFAULT_DGPU_ADDR,
 | 
				
			||||||
 | 
					    initial_algos: Optional[List[str]] = None,
 | 
				
			||||||
 | 
					    security: bool = True
 | 
				
			||||||
 | 
					):
 | 
				
			||||||
 | 
					    logging.basicConfig(level=logging.INFO)
 | 
				
			||||||
 | 
					    logging.info(f'starting dgpu node!')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    name = uuid.uuid4()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    logging.info(f'loading models...')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    initial_algos = (
 | 
				
			||||||
 | 
					        initial_algos
 | 
				
			||||||
 | 
					        if initial_algos else DEFAULT_INITAL_ALGOS
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    models = {}
 | 
				
			||||||
 | 
					    for algo in initial_algos:
 | 
				
			||||||
 | 
					        models[algo] = {
 | 
				
			||||||
 | 
					            'pipe': pipeline_for(algo),
 | 
				
			||||||
 | 
					            'generated': 0
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					        logging.info(f'loaded {algo}.')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    logging.info('memory summary:\n')
 | 
				
			||||||
 | 
					    logging.info(torch.cuda.memory_summary())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    async def gpu_compute_one(ireq: ImageGenRequest):
 | 
				
			||||||
 | 
					        if ireq.algo not in models:
 | 
				
			||||||
 | 
					            least_used = list(models.keys())[0]
 | 
				
			||||||
 | 
					            for model in models:
 | 
				
			||||||
 | 
					                if models[least_used]['generated'] > models[model]['generated']:
 | 
				
			||||||
 | 
					                    least_used = model
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            del models[least_used]
 | 
				
			||||||
 | 
					            gc.collect()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            models[ireq.algo] = {
 | 
				
			||||||
 | 
					                'pipe': pipeline_for(ireq.algo),
 | 
				
			||||||
 | 
					                'generated': 0
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        seed = ireq.seed if ireq.seed else random.randint(0, 2 ** 64)
 | 
				
			||||||
 | 
					        try:
 | 
				
			||||||
 | 
					            image = models[ireq.algo]['pipe'](
 | 
				
			||||||
 | 
					                ireq.prompt,
 | 
				
			||||||
 | 
					                width=ireq.width,
 | 
				
			||||||
 | 
					                height=ireq.height,
 | 
				
			||||||
 | 
					                guidance_scale=ireq.guidance,
 | 
				
			||||||
 | 
					                num_inference_steps=ireq.step,
 | 
				
			||||||
 | 
					                generator=torch.Generator("cuda").manual_seed(seed)
 | 
				
			||||||
 | 
					            ).images[0]
 | 
				
			||||||
 | 
					            return image.tobytes()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        except BaseException as e:
 | 
				
			||||||
 | 
					            logging.error(e)
 | 
				
			||||||
 | 
					            raise DGPUComputeError(str(e))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        finally:
 | 
				
			||||||
 | 
					            torch.cuda.empty_cache()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    async with open_skynet_rpc(
 | 
				
			||||||
 | 
					        security=security,
 | 
				
			||||||
 | 
					        cert_name=cert_name,
 | 
				
			||||||
 | 
					        key_name=key_name
 | 
				
			||||||
 | 
					    ) as rpc_call:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        tls_config = None
 | 
				
			||||||
 | 
					        if security:
 | 
				
			||||||
 | 
					            # load tls certs
 | 
				
			||||||
 | 
					            if not key_name:
 | 
				
			||||||
 | 
					                key_name = certs_name
 | 
				
			||||||
 | 
					            certs_dir = Path(DEFAULT_CERTS_DIR).resolve()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            skynet_cert_path = certs_dir / 'brain.cert'
 | 
				
			||||||
 | 
					            tls_cert_path = certs_dir / f'{cert_name}.cert'
 | 
				
			||||||
 | 
					            tls_key_path = certs_dir / f'{key_name}.key'
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            skynet_cert = skynet_cert_path.read_text()
 | 
				
			||||||
 | 
					            tls_cert = tls_cert_path.read_text()
 | 
				
			||||||
 | 
					            tls_key = tls_key_path.read_text()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            logging.info(f'skynet cert: {skynet_cert_path}')
 | 
				
			||||||
 | 
					            logging.info(f'dgpu cert:  {tls_cert_path}')
 | 
				
			||||||
 | 
					            logging.info(f'dgpu key: {tls_key_path}')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            dgpu_address = 'tls+' + dgpu_address
 | 
				
			||||||
 | 
					            tls_config = TLSConfig(
 | 
				
			||||||
 | 
					                TLSConfig.MODE_CLIENT,
 | 
				
			||||||
 | 
					                own_key_string=tls_key,
 | 
				
			||||||
 | 
					                own_cert_string=tls_cert,
 | 
				
			||||||
 | 
					                ca_string=skynet_cert)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        logging.info(f'connecting to {dgpu_address}')
 | 
				
			||||||
 | 
					        with pynng.Bus0() as dgpu_sock:
 | 
				
			||||||
 | 
					            dgpu_sock.tls_config = tls_config
 | 
				
			||||||
 | 
					            dgpu_sock.dial(dgpu_address)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            res = await rpc_call(name.hex, 'dgpu_online')
 | 
				
			||||||
 | 
					            logging.info(res)
 | 
				
			||||||
 | 
					            assert 'ok' in res.result
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            try:
 | 
				
			||||||
 | 
					                while True:
 | 
				
			||||||
 | 
					                    msg = await dgpu_sock.arecv()
 | 
				
			||||||
 | 
					                    req = DGPUBusRequest(
 | 
				
			||||||
 | 
					                        **json.loads(msg.decode()))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                    if req.nid != name.hex:
 | 
				
			||||||
 | 
					                        logging.info('witnessed request {req.rid}, for {req.nid}')
 | 
				
			||||||
 | 
					                        continue
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                    # send ack
 | 
				
			||||||
 | 
					                    await dgpu_sock.asend(
 | 
				
			||||||
 | 
					                        bytes.fromhex(req.rid) + b'ack')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                    logging.info(f'sent ack, processing {req.rid}...')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                    try:
 | 
				
			||||||
 | 
					                        img = await gpu_compute_one(
 | 
				
			||||||
 | 
					                            ImageGenRequest(**req.params))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                    except DGPUComputeError as e:
 | 
				
			||||||
 | 
					                        img = b'error' + str(e).encode()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                    await dgpu_sock.asend(
 | 
				
			||||||
 | 
					                        bytes.fromhex(req.rid) + img)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            except KeyboardInterrupt:
 | 
				
			||||||
 | 
					                logging.info('interrupt caught, stopping...')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            finally:
 | 
				
			||||||
 | 
					                res = await rpc_call(name.hex, 'dgpu_offline')
 | 
				
			||||||
 | 
					                logging.info(res)
 | 
				
			||||||
 | 
					                assert 'ok' in res.result
 | 
				
			||||||
| 
						 | 
					@ -10,7 +10,7 @@ import pynng
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from pynng import TLSConfig
 | 
					from pynng import TLSConfig
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from ..types import SkynetRPCRequest, SkynetRPCResponse
 | 
					from ..structs import SkynetRPCRequest, SkynetRPCResponse
 | 
				
			||||||
from ..constants import *
 | 
					from ..constants import *
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -0,0 +1,57 @@
 | 
				
			||||||
 | 
					#!/usr/bin/python
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import random
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from typing import Optional
 | 
				
			||||||
 | 
					from pathlib import Path
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import torch
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from diffusers import StableDiffusionPipeline
 | 
				
			||||||
 | 
					from huggingface_hub import login
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def txt2img(
 | 
				
			||||||
 | 
					    hf_token: str,
 | 
				
			||||||
 | 
					    model_name: str,
 | 
				
			||||||
 | 
					    prompt: str,
 | 
				
			||||||
 | 
					    output: str,
 | 
				
			||||||
 | 
					    width: int, height: int,
 | 
				
			||||||
 | 
					    guidance: float,
 | 
				
			||||||
 | 
					    steps: int,
 | 
				
			||||||
 | 
					    seed: Optional[int]
 | 
				
			||||||
 | 
					):
 | 
				
			||||||
 | 
					    assert torch.cuda.is_available()
 | 
				
			||||||
 | 
					    torch.cuda.empty_cache()
 | 
				
			||||||
 | 
					    torch.cuda.set_per_process_memory_fraction(0.333)
 | 
				
			||||||
 | 
					    torch.backends.cuda.matmul.allow_tf32 = True
 | 
				
			||||||
 | 
					    torch.backends.cudnn.allow_tf32 = True
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    login(token=hf_token)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    params = {
 | 
				
			||||||
 | 
					        'torch_dtype': torch.float16,
 | 
				
			||||||
 | 
					        'safety_checker': None
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					    if model_name == 'runwayml/stable-diffusion-v1-5':
 | 
				
			||||||
 | 
					        params['revision'] = 'fp16'
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    pipe = StableDiffusionPipeline.from_pretrained(
 | 
				
			||||||
 | 
					        model_name, **params)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(
 | 
				
			||||||
 | 
					        pipe.scheduler.config)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    pipe = pipe.to("cuda")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    seed = ireq.seed if ireq.seed else random.randint(0, 2 ** 64)
 | 
				
			||||||
 | 
					    prompt = prompt
 | 
				
			||||||
 | 
					    image = pipe(
 | 
				
			||||||
 | 
					        prompt,
 | 
				
			||||||
 | 
					        width=width,
 | 
				
			||||||
 | 
					        height=height,
 | 
				
			||||||
 | 
					        guidance_scale=guidance, num_inference_steps=steps,
 | 
				
			||||||
 | 
					        generator=torch.Generator("cuda").manual_seed(seed)
 | 
				
			||||||
 | 
					    ).images[0]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    image.save(output)
 | 
				
			||||||
| 
						 | 
					@ -1,124 +0,0 @@
 | 
				
			||||||
#!/usr/bin/python
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
import trio
 | 
					 | 
				
			||||||
import json
 | 
					 | 
				
			||||||
import uuid
 | 
					 | 
				
			||||||
import logging
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
import pynng
 | 
					 | 
				
			||||||
import tractor
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
from . import gpu
 | 
					 | 
				
			||||||
from .gpu import open_gpu_worker
 | 
					 | 
				
			||||||
from .types import *
 | 
					 | 
				
			||||||
from .constants import *
 | 
					 | 
				
			||||||
from .frontend import rpc_call
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
async def open_dgpu_node(
 | 
					 | 
				
			||||||
    cert_name: str,
 | 
					 | 
				
			||||||
    key_name: Optional[str],
 | 
					 | 
				
			||||||
    rpc_address: str = DEFAULT_RPC_ADDR,
 | 
					 | 
				
			||||||
    dgpu_address: str = DEFAULT_DGPU_ADDR,
 | 
					 | 
				
			||||||
    dgpu_max_tasks: int = DEFAULT_DGPU_MAX_TASKS,
 | 
					 | 
				
			||||||
    initial_algos: str = DEFAULT_INITAL_ALGOS,
 | 
					 | 
				
			||||||
    security: bool = True
 | 
					 | 
				
			||||||
):
 | 
					 | 
				
			||||||
    logging.basicConfig(level=logging.INFO)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    name = uuid.uuid4()
 | 
					 | 
				
			||||||
    workers = initial_algos.copy()
 | 
					 | 
				
			||||||
    tasks = [None for _ in range(dgpu_max_tasks)]
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    portal_map: dict[int, tractor.Portal]
 | 
					 | 
				
			||||||
    contexts: dict[int, tractor.Context]
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def get_next_worker(need_algo: str):
 | 
					 | 
				
			||||||
        nonlocal workers, tasks
 | 
					 | 
				
			||||||
        for task, algo in zip(workers, tasks):
 | 
					 | 
				
			||||||
            if need_algo == algo and not task:
 | 
					 | 
				
			||||||
                return workers.index(need_algo)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        return tasks.index(None)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    async def gpu_streamer(
 | 
					 | 
				
			||||||
        ctx: tractor.Context,
 | 
					 | 
				
			||||||
        nid: int
 | 
					 | 
				
			||||||
    ):
 | 
					 | 
				
			||||||
        nonlocal tasks
 | 
					 | 
				
			||||||
        async with ctx.open_stream() as stream:
 | 
					 | 
				
			||||||
            async for img in stream:
 | 
					 | 
				
			||||||
                tasks[nid]['res'] = img
 | 
					 | 
				
			||||||
                tasks[nid]['event'].set()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    async def gpu_compute_one(ireq: ImageGenRequest):
 | 
					 | 
				
			||||||
        wid = get_next_worker(ireq.algo)
 | 
					 | 
				
			||||||
        event = trio.Event()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        workers[wid] = ireq.algo
 | 
					 | 
				
			||||||
        tasks[wid] = {
 | 
					 | 
				
			||||||
            'res': None, 'event': event}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        await contexts[i].send(ireq)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        await event.wait()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        img = tasks[wid]['res']
 | 
					 | 
				
			||||||
        tasks[wid] = None
 | 
					 | 
				
			||||||
        return img
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    async with open_skynet_rpc(
 | 
					 | 
				
			||||||
        security=security,
 | 
					 | 
				
			||||||
        cert_name=cert_name,
 | 
					 | 
				
			||||||
        key_name=key_name
 | 
					 | 
				
			||||||
    ) as rpc_call:
 | 
					 | 
				
			||||||
        with pynng.Bus0(dial=dgpu_address) as dgpu_sock:
 | 
					 | 
				
			||||||
            async def _process_dgpu_req(req: DGPUBusRequest):
 | 
					 | 
				
			||||||
                img = await gpu_compute_one(
 | 
					 | 
				
			||||||
                    ImageGenRequest(**req.params))
 | 
					 | 
				
			||||||
                await dgpu_sock.asend(
 | 
					 | 
				
			||||||
                    bytes.fromhex(req.rid) + img)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            res = await rpc_call(
 | 
					 | 
				
			||||||
                name.hex, 'dgpu_online', {'max_tasks': dgpu_max_tasks})
 | 
					 | 
				
			||||||
            logging.info(res)
 | 
					 | 
				
			||||||
            assert 'ok' in res.result
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            async with (
 | 
					 | 
				
			||||||
                tractor.open_actor_cluster(
 | 
					 | 
				
			||||||
                    modules=['skynet_bot.gpu'],
 | 
					 | 
				
			||||||
                    count=dgpu_max_tasks,
 | 
					 | 
				
			||||||
                    names=[i for i in range(dgpu_max_tasks)]
 | 
					 | 
				
			||||||
                    ) as portal_map,
 | 
					 | 
				
			||||||
                trio.open_nursery() as n
 | 
					 | 
				
			||||||
            ):
 | 
					 | 
				
			||||||
                logging.info(f'starting {dgpu_max_tasks} gpu workers')
 | 
					 | 
				
			||||||
                async with tractor.gather_contexts((
 | 
					 | 
				
			||||||
                    portal.open_context(
 | 
					 | 
				
			||||||
                        open_gpu_worker, algo, 1.0 / dgpu_max_tasks)
 | 
					 | 
				
			||||||
                    for portal in portal_map.values()
 | 
					 | 
				
			||||||
                )) as contexts:
 | 
					 | 
				
			||||||
                    contexts = {i: ctx for i, ctx in enumerate(contexts)}
 | 
					 | 
				
			||||||
                    for i, ctx in contexts.items():
 | 
					 | 
				
			||||||
                        n.start_soon(
 | 
					 | 
				
			||||||
                            gpu_streamer, ctx, i)
 | 
					 | 
				
			||||||
                    try:
 | 
					 | 
				
			||||||
                        while True:
 | 
					 | 
				
			||||||
                            msg = await dgpu_sock.arecv()
 | 
					 | 
				
			||||||
                            req = DGPUBusRequest(
 | 
					 | 
				
			||||||
                                **json.loads(msg.decode()))
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                            if req.nid != name.hex:
 | 
					 | 
				
			||||||
                                continue
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                            logging.info(f'dgpu: {name}, req: {req}')
 | 
					 | 
				
			||||||
                            n.start_soon(
 | 
					 | 
				
			||||||
                                _process_dgpu_req, req)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                    except KeyboardInterrupt:
 | 
					 | 
				
			||||||
                        ...
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            res = await rpc_call(name.hex, 'dgpu_offline')
 | 
					 | 
				
			||||||
            logging.info(res)
 | 
					 | 
				
			||||||
            assert 'ok' in res.result
 | 
					 | 
				
			||||||
| 
						 | 
					@ -1,77 +0,0 @@
 | 
				
			||||||
#!/usr/bin/python
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
import io
 | 
					 | 
				
			||||||
import random
 | 
					 | 
				
			||||||
import logging
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
import torch
 | 
					 | 
				
			||||||
import tractor
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
from diffusers import (
 | 
					 | 
				
			||||||
    StableDiffusionPipeline,
 | 
					 | 
				
			||||||
    EulerAncestralDiscreteScheduler
 | 
					 | 
				
			||||||
)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
from .types import ImageGenRequest
 | 
					 | 
				
			||||||
from .constants import ALGOS
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def pipeline_for(algo: str, mem_fraction: float):
 | 
					 | 
				
			||||||
    assert torch.cuda.is_available()
 | 
					 | 
				
			||||||
    torch.cuda.empty_cache()
 | 
					 | 
				
			||||||
    torch.cuda.set_per_process_memory_fraction(mem_fraction)
 | 
					 | 
				
			||||||
    torch.backends.cuda.matmul.allow_tf32 = True
 | 
					 | 
				
			||||||
    torch.backends.cudnn.allow_tf32 = True
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    params = {
 | 
					 | 
				
			||||||
        'torch_dtype': torch.float16,
 | 
					 | 
				
			||||||
        'safety_checker': None
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    if algo == 'stable':
 | 
					 | 
				
			||||||
        params['revision'] = 'fp16'
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    pipe = StableDiffusionPipeline.from_pretrained(
 | 
					 | 
				
			||||||
        ALGOS[algo], **params)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(
 | 
					 | 
				
			||||||
        pipe.scheduler.config)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    return pipe.to("cuda")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
@tractor.context
 | 
					 | 
				
			||||||
async def open_gpu_worker(
 | 
					 | 
				
			||||||
    ctx: tractor.Context,
 | 
					 | 
				
			||||||
    start_algo: str,
 | 
					 | 
				
			||||||
    mem_fraction: float
 | 
					 | 
				
			||||||
):
 | 
					 | 
				
			||||||
    log = tractor.log.get_logger(name='gpu', _root_name='skynet')
 | 
					 | 
				
			||||||
    log.info(f'starting gpu worker with algo {start_algo}...')
 | 
					 | 
				
			||||||
    current_algo = start_algo
 | 
					 | 
				
			||||||
    with torch.no_grad():
 | 
					 | 
				
			||||||
        pipe = pipeline_for(current_algo, mem_fraction)
 | 
					 | 
				
			||||||
        log.info('pipeline loaded')
 | 
					 | 
				
			||||||
        await ctx.started()
 | 
					 | 
				
			||||||
        async with ctx.open_stream() as bus:
 | 
					 | 
				
			||||||
            async for ireq in bus:
 | 
					 | 
				
			||||||
                if ireq.algo != current_algo:
 | 
					 | 
				
			||||||
                    current_algo = ireq.algo
 | 
					 | 
				
			||||||
                    pipe = pipeline_for(current_algo, mem_fraction)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                seed = ireq.seed if ireq.seed else random.randint(0, 2 ** 64)
 | 
					 | 
				
			||||||
                image = pipe(
 | 
					 | 
				
			||||||
                    ireq.prompt,
 | 
					 | 
				
			||||||
                    width=ireq.width,
 | 
					 | 
				
			||||||
                    height=ireq.height,
 | 
					 | 
				
			||||||
                    guidance_scale=ireq.guidance,
 | 
					 | 
				
			||||||
                    num_inference_steps=ireq.step,
 | 
					 | 
				
			||||||
                    generator=torch.Generator("cuda").manual_seed(seed)
 | 
					 | 
				
			||||||
                ).images[0]
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                torch.cuda.empty_cache()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                # convert PIL.Image to BytesIO
 | 
					 | 
				
			||||||
                img_bytes = io.BytesIO()
 | 
					 | 
				
			||||||
                image.save(img_bytes, format='PNG')
 | 
					 | 
				
			||||||
                await bus.send(img_bytes.getvalue())
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
| 
						 | 
					@ -1,2 +0,0 @@
 | 
				
			||||||
from OpenSSL.crypto import load_publickey, FILETYPE_PEM, verify, X509
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
							
								
								
									
										9
									
								
								test.sh
								
								
								
								
							
							
						
						
									
										9
									
								
								test.sh
								
								
								
								
							| 
						 | 
					@ -1,9 +0,0 @@
 | 
				
			||||||
docker run \
 | 
					 | 
				
			||||||
    -it \
 | 
					 | 
				
			||||||
    --rm \
 | 
					 | 
				
			||||||
    --gpus=all \
 | 
					 | 
				
			||||||
    --mount type=bind,source="$(pwd)",target=/skynet \
 | 
					 | 
				
			||||||
    skynet:runtime-cuda \
 | 
					 | 
				
			||||||
    bash -c \
 | 
					 | 
				
			||||||
        "cd /skynet && pip install -e . && \
 | 
					 | 
				
			||||||
        pytest $1 --log-cli-level=info"
 | 
					 | 
				
			||||||
| 
						 | 
					@ -1,21 +1,25 @@
 | 
				
			||||||
#!/usr/bin/python
 | 
					#!/usr/bin/python
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import os
 | 
				
			||||||
 | 
					import json
 | 
				
			||||||
import time
 | 
					import time
 | 
				
			||||||
import random
 | 
					import random
 | 
				
			||||||
import string
 | 
					import string
 | 
				
			||||||
import logging
 | 
					import logging
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from functools import partial
 | 
					from functools import partial
 | 
				
			||||||
 | 
					from pathlib import Path
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import trio
 | 
					import trio
 | 
				
			||||||
import pytest
 | 
					import pytest
 | 
				
			||||||
import psycopg2
 | 
					import psycopg2
 | 
				
			||||||
import trio_asyncio
 | 
					import trio_asyncio
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from docker.types import Mount, DeviceRequest
 | 
				
			||||||
from psycopg2.extensions import ISOLATION_LEVEL_AUTOCOMMIT
 | 
					from psycopg2.extensions import ISOLATION_LEVEL_AUTOCOMMIT
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from skynet_bot.constants import *
 | 
					from skynet.constants import *
 | 
				
			||||||
from skynet_bot.brain import run_skynet
 | 
					from skynet.brain import run_skynet
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@pytest.fixture(scope='session')
 | 
					@pytest.fixture(scope='session')
 | 
				
			||||||
| 
						 | 
					@ -29,6 +33,7 @@ def postgres_db(dockerctl):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    with dockerctl.run(
 | 
					    with dockerctl.run(
 | 
				
			||||||
        'postgres',
 | 
					        'postgres',
 | 
				
			||||||
 | 
					        name='skynet-test-postgres',
 | 
				
			||||||
        ports={'5432/tcp': None},
 | 
					        ports={'5432/tcp': None},
 | 
				
			||||||
        environment={
 | 
					        environment={
 | 
				
			||||||
            'POSTGRES_PASSWORD': rpassword
 | 
					            'POSTGRES_PASSWORD': rpassword
 | 
				
			||||||
| 
						 | 
					@ -67,6 +72,8 @@ def postgres_db(dockerctl):
 | 
				
			||||||
            cursor.execute(
 | 
					            cursor.execute(
 | 
				
			||||||
                f'GRANT ALL PRIVILEGES ON DATABASE {DB_NAME} TO {DB_USER}')
 | 
					                f'GRANT ALL PRIVILEGES ON DATABASE {DB_NAME} TO {DB_USER}')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        conn.close()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        logging.info('done.')
 | 
					        logging.info('done.')
 | 
				
			||||||
        yield container, password, host
 | 
					        yield container, password, host
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -74,16 +81,44 @@ def postgres_db(dockerctl):
 | 
				
			||||||
@pytest.fixture
 | 
					@pytest.fixture
 | 
				
			||||||
async def skynet_running(postgres_db):
 | 
					async def skynet_running(postgres_db):
 | 
				
			||||||
    db_container, db_pass, db_host = postgres_db
 | 
					    db_container, db_pass, db_host = postgres_db
 | 
				
			||||||
    async with (
 | 
					
 | 
				
			||||||
        trio_asyncio.open_loop(),
 | 
					    async with run_skynet(
 | 
				
			||||||
        trio.open_nursery() as n
 | 
					        db_pass=db_pass,
 | 
				
			||||||
 | 
					        db_host=db_host
 | 
				
			||||||
    ):
 | 
					    ):
 | 
				
			||||||
        await n.start(
 | 
					 | 
				
			||||||
            partial(run_skynet,
 | 
					 | 
				
			||||||
                db_pass=db_pass,
 | 
					 | 
				
			||||||
                db_host=db_host))
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        yield
 | 
					        yield
 | 
				
			||||||
        n.cancel_scope.cancel()
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@pytest.fixture
 | 
				
			||||||
 | 
					def dgpu_workers(request, dockerctl, skynet_running):
 | 
				
			||||||
 | 
					    devices = [DeviceRequest(capabilities=[['gpu']])]
 | 
				
			||||||
 | 
					    mounts = [Mount(
 | 
				
			||||||
 | 
					        '/skynet', str(Path().resolve()), type='bind')]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    num_containers, initial_algos = request.param
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    cmd = f'''
 | 
				
			||||||
 | 
					    pip install -e . && \
 | 
				
			||||||
 | 
					    skynet run dgpu --algos=\'{json.dumps(initial_algos)}\'
 | 
				
			||||||
 | 
					    '''
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    logging.info(f'launching: \n{cmd}')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    with dockerctl.run(
 | 
				
			||||||
 | 
					        DOCKER_RUNTIME_CUDA,
 | 
				
			||||||
 | 
					        name='skynet-test-runtime-cuda',
 | 
				
			||||||
 | 
					        command=['bash', '-c', cmd],
 | 
				
			||||||
 | 
					        environment={
 | 
				
			||||||
 | 
					            'HF_TOKEN': os.environ['HF_TOKEN'],
 | 
				
			||||||
 | 
					            'HF_HOME': '/skynet/hf_home'
 | 
				
			||||||
 | 
					        },
 | 
				
			||||||
 | 
					        network='host',
 | 
				
			||||||
 | 
					        mounts=mounts,
 | 
				
			||||||
 | 
					        device_requests=devices,
 | 
				
			||||||
 | 
					        num=num_containers
 | 
				
			||||||
 | 
					    ) as containers:
 | 
				
			||||||
 | 
					        yield containers
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        #for i, container in enumerate(containers):
 | 
				
			||||||
 | 
					        #    logging.info(f'container {i} logs:')
 | 
				
			||||||
 | 
					        #    logging.info(container.logs().decode())
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1,57 +1,248 @@
 | 
				
			||||||
#!/usr/bin/python
 | 
					#!/usr/bin/python
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import io
 | 
				
			||||||
import time
 | 
					import time
 | 
				
			||||||
import json
 | 
					import json
 | 
				
			||||||
 | 
					import base64
 | 
				
			||||||
import logging
 | 
					import logging
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from hashlib import sha256
 | 
				
			||||||
 | 
					from functools import partial
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import trio
 | 
					import trio
 | 
				
			||||||
import pynng
 | 
					import pytest
 | 
				
			||||||
import tractor
 | 
					import tractor
 | 
				
			||||||
import trio_asyncio
 | 
					import trio_asyncio
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from skynet_bot.gpu import open_gpu_worker
 | 
					from PIL import Image
 | 
				
			||||||
from skynet_bot.dgpu import open_dgpu_node
 | 
					
 | 
				
			||||||
from skynet_bot.types import *
 | 
					from skynet.brain import SkynetDGPUComputeError
 | 
				
			||||||
from skynet_bot.brain import run_skynet
 | 
					from skynet.constants import *
 | 
				
			||||||
from skynet_bot.constants import *
 | 
					from skynet.frontend import open_skynet_rpc
 | 
				
			||||||
from skynet_bot.frontend import open_skynet_rpc, rpc_call
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def test_dgpu_simple():
 | 
					async def wait_for_dgpus(rpc, amount: int, timeout: float = 30.0):
 | 
				
			||||||
    async def main():
 | 
					    gpu_ready = False
 | 
				
			||||||
 | 
					    start_time = time.time()
 | 
				
			||||||
 | 
					    current_time = time.time()
 | 
				
			||||||
 | 
					    while not gpu_ready and (current_time - start_time) < timeout:
 | 
				
			||||||
 | 
					        res = await rpc('dgpu-test', 'dgpu_workers')
 | 
				
			||||||
 | 
					        logging.info(res)
 | 
				
			||||||
 | 
					        if res.result['ok'] >= amount:
 | 
				
			||||||
 | 
					            break
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        await trio.sleep(1)
 | 
				
			||||||
 | 
					        current_time = time.time()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    assert (current_time - start_time) < timeout
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					_images = set()
 | 
				
			||||||
 | 
					async def check_request_img(
 | 
				
			||||||
 | 
					    i: int,
 | 
				
			||||||
 | 
					    width: int = 512,
 | 
				
			||||||
 | 
					    height: int = 512,
 | 
				
			||||||
 | 
					    expect_unique=True
 | 
				
			||||||
 | 
					):
 | 
				
			||||||
 | 
					    global _images
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    async with open_skynet_rpc(
 | 
				
			||||||
 | 
					        security=True,
 | 
				
			||||||
 | 
					        cert_name='whitelist/testing',
 | 
				
			||||||
 | 
					        key_name='testing'
 | 
				
			||||||
 | 
					    ) as rpc_call:
 | 
				
			||||||
 | 
					        res = await rpc_call(
 | 
				
			||||||
 | 
					            'tg+580213293', 'txt2img', {
 | 
				
			||||||
 | 
					                'prompt': 'red old tractor in a sunny wheat field',
 | 
				
			||||||
 | 
					                'step': 28,
 | 
				
			||||||
 | 
					                'width': width, 'height': height,
 | 
				
			||||||
 | 
					                'guidance': 7.5,
 | 
				
			||||||
 | 
					                'seed': None,
 | 
				
			||||||
 | 
					                'algo': list(ALGOS.keys())[i],
 | 
				
			||||||
 | 
					                'upscaler': None
 | 
				
			||||||
 | 
					            })
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if 'error' in res.result:
 | 
				
			||||||
 | 
					            raise SkynetDGPUComputeError(json.dumps(res.result))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        img_raw = base64.b64decode(bytes.fromhex(res.result['img']))
 | 
				
			||||||
 | 
					        img_sha = sha256(img_raw).hexdigest()
 | 
				
			||||||
 | 
					        img = Image.frombytes(
 | 
				
			||||||
 | 
					            'RGB', (width, height), img_raw)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if expect_unique and img_sha in _images:
 | 
				
			||||||
 | 
					            raise ValueError('Duplicated image sha: {img_sha}')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        _images.add(img_sha)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        logging.info(f'img sha256: {img_sha} size: {len(img_raw)}')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        assert len(img_raw) > 100000
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@pytest.mark.parametrize(
 | 
				
			||||||
 | 
					    'dgpu_workers', [(1, ['midj'])], indirect=True)
 | 
				
			||||||
 | 
					async def test_dgpu_worker_compute_error(dgpu_workers):
 | 
				
			||||||
 | 
					    '''Attempt to generate a huge image and check we get the right error,
 | 
				
			||||||
 | 
					    then generate a smaller image to show gpu worker recovery
 | 
				
			||||||
 | 
					    '''
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    async with open_skynet_rpc(
 | 
				
			||||||
 | 
					        security=True,
 | 
				
			||||||
 | 
					        cert_name='whitelist/testing',
 | 
				
			||||||
 | 
					        key_name='testing'
 | 
				
			||||||
 | 
					    ) as test_rpc:
 | 
				
			||||||
 | 
					        await wait_for_dgpus(test_rpc, 1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        with pytest.raises(SkynetDGPUComputeError) as e:
 | 
				
			||||||
 | 
					            await check_request_img(0, width=4096, height=4096)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        logging.info(e)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        await check_request_img(0)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@pytest.mark.parametrize(
 | 
				
			||||||
 | 
					    'dgpu_workers', [(1, ['midj', 'stable'])], indirect=True)
 | 
				
			||||||
 | 
					async def test_dgpu_workers(dgpu_workers):
 | 
				
			||||||
 | 
					    '''Generate two images in a single dgpu worker using
 | 
				
			||||||
 | 
					    two different models.
 | 
				
			||||||
 | 
					    '''
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    async with open_skynet_rpc(
 | 
				
			||||||
 | 
					        security=True,
 | 
				
			||||||
 | 
					        cert_name='whitelist/testing',
 | 
				
			||||||
 | 
					        key_name='testing'
 | 
				
			||||||
 | 
					    ) as test_rpc:
 | 
				
			||||||
 | 
					        await wait_for_dgpus(test_rpc, 1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        await check_request_img(0)
 | 
				
			||||||
 | 
					        await check_request_img(1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@pytest.mark.parametrize(
 | 
				
			||||||
 | 
					    'dgpu_workers', [(2, ['midj'])], indirect=True)
 | 
				
			||||||
 | 
					async def test_dgpu_workers_two(dgpu_workers):
 | 
				
			||||||
 | 
					    '''Generate two images in two separate dgpu workers
 | 
				
			||||||
 | 
					    '''
 | 
				
			||||||
 | 
					    async with open_skynet_rpc(
 | 
				
			||||||
 | 
					        security=True,
 | 
				
			||||||
 | 
					        cert_name='whitelist/testing',
 | 
				
			||||||
 | 
					        key_name='testing'
 | 
				
			||||||
 | 
					    ) as test_rpc:
 | 
				
			||||||
 | 
					        await wait_for_dgpus(test_rpc, 2)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        async with trio.open_nursery() as n:
 | 
					        async with trio.open_nursery() as n:
 | 
				
			||||||
            await n.start(
 | 
					            n.start_soon(check_request_img, 0)
 | 
				
			||||||
                run_skynet,
 | 
					            n.start_soon(check_request_img, 0)
 | 
				
			||||||
                'skynet', '3GbZd6UojbD8V7zWpeFn', 'ancap.tech:34508')
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            await trio.sleep(2)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            for i in range(3):
 | 
					 | 
				
			||||||
                n.start_soon(open_dgpu_node)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            await trio.sleep(1)
 | 
					 | 
				
			||||||
            start = time.time()
 | 
					 | 
				
			||||||
            async def request_img():
 | 
					 | 
				
			||||||
                with pynng.Req0(dial=DEFAULT_RPC_ADDR) as rpc_sock:
 | 
					 | 
				
			||||||
                    res = await rpc_call(
 | 
					 | 
				
			||||||
                        rpc_sock, 'tg+1', 'txt2img', {
 | 
					 | 
				
			||||||
                            'prompt': 'test',
 | 
					 | 
				
			||||||
                            'step': 28,
 | 
					 | 
				
			||||||
                            'width': 512, 'height': 512,
 | 
					 | 
				
			||||||
                            'guidance': 7.5,
 | 
					 | 
				
			||||||
                            'seed': None,
 | 
					 | 
				
			||||||
                            'algo': 'stable',
 | 
					 | 
				
			||||||
                            'upscaler': None
 | 
					 | 
				
			||||||
                        })
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                    logging.info(res)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            async with trio.open_nursery() as inner_n:
 | 
					 | 
				
			||||||
                for i in range(3):
 | 
					 | 
				
			||||||
                    inner_n.start_soon(request_img)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            logging.info(f'time elapsed: {time.time() - start}')
 | 
					 | 
				
			||||||
            n.cancel_scope.cancel()
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    trio_asyncio.run(main)
 | 
					@pytest.mark.parametrize(
 | 
				
			||||||
 | 
					    'dgpu_workers', [(1, ['midj'])], indirect=True)
 | 
				
			||||||
 | 
					async def test_dgpu_worker_algo_swap(dgpu_workers):
 | 
				
			||||||
 | 
					    '''Generate an image using a non default model
 | 
				
			||||||
 | 
					    '''
 | 
				
			||||||
 | 
					    async with open_skynet_rpc(
 | 
				
			||||||
 | 
					        security=True,
 | 
				
			||||||
 | 
					        cert_name='whitelist/testing',
 | 
				
			||||||
 | 
					        key_name='testing'
 | 
				
			||||||
 | 
					    ) as test_rpc:
 | 
				
			||||||
 | 
					        await wait_for_dgpus(test_rpc, 1)
 | 
				
			||||||
 | 
					        await check_request_img(5)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@pytest.mark.parametrize(
 | 
				
			||||||
 | 
					    'dgpu_workers', [(3, ['midj'])], indirect=True)
 | 
				
			||||||
 | 
					async def test_dgpu_rotation_next_worker(dgpu_workers):
 | 
				
			||||||
 | 
					    '''Connect three dgpu workers, disconnect and check next_worker
 | 
				
			||||||
 | 
					    rotation happens correctly
 | 
				
			||||||
 | 
					    '''
 | 
				
			||||||
 | 
					    async with open_skynet_rpc(
 | 
				
			||||||
 | 
					        security=True,
 | 
				
			||||||
 | 
					        cert_name='whitelist/testing',
 | 
				
			||||||
 | 
					        key_name='testing'
 | 
				
			||||||
 | 
					    ) as test_rpc:
 | 
				
			||||||
 | 
					        await wait_for_dgpus(test_rpc, 3)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        res = await test_rpc('testing-rpc', 'dgpu_next')
 | 
				
			||||||
 | 
					        logging.info(res)
 | 
				
			||||||
 | 
					        assert 'ok' in res.result
 | 
				
			||||||
 | 
					        assert res.result['ok'] == 0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        await check_request_img(0)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        res = await test_rpc('testing-rpc', 'dgpu_next')
 | 
				
			||||||
 | 
					        logging.info(res)
 | 
				
			||||||
 | 
					        assert 'ok' in res.result
 | 
				
			||||||
 | 
					        assert res.result['ok'] == 1
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        await check_request_img(0)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        res = await test_rpc('testing-rpc', 'dgpu_next')
 | 
				
			||||||
 | 
					        logging.info(res)
 | 
				
			||||||
 | 
					        assert 'ok' in res.result
 | 
				
			||||||
 | 
					        assert res.result['ok'] == 2
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        await check_request_img(0)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        res = await test_rpc('testing-rpc', 'dgpu_next')
 | 
				
			||||||
 | 
					        logging.info(res)
 | 
				
			||||||
 | 
					        assert 'ok' in res.result
 | 
				
			||||||
 | 
					        assert res.result['ok'] == 0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@pytest.mark.parametrize(
 | 
				
			||||||
 | 
					    'dgpu_workers', [(3, ['midj'])], indirect=True)
 | 
				
			||||||
 | 
					async def test_dgpu_rotation_next_worker_disconnect(dgpu_workers):
 | 
				
			||||||
 | 
					    '''Connect three dgpu workers, disconnect the first one and check
 | 
				
			||||||
 | 
					    next_worker rotation happens correctly
 | 
				
			||||||
 | 
					    '''
 | 
				
			||||||
 | 
					    async with open_skynet_rpc(
 | 
				
			||||||
 | 
					        security=True,
 | 
				
			||||||
 | 
					        cert_name='whitelist/testing',
 | 
				
			||||||
 | 
					        key_name='testing'
 | 
				
			||||||
 | 
					    ) as test_rpc:
 | 
				
			||||||
 | 
					        await wait_for_dgpus(test_rpc, 3)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        await trio.sleep(3)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # stop worker who's turn is next
 | 
				
			||||||
 | 
					        for _ in range(2):
 | 
				
			||||||
 | 
					            ec, out = dgpu_workers[0].exec_run(['pkill', '-INT', '-f', 'skynet'])
 | 
				
			||||||
 | 
					            assert ec == 0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        dgpu_workers[0].wait()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        res = await test_rpc('testing-rpc', 'dgpu_workers')
 | 
				
			||||||
 | 
					        logging.info(res)
 | 
				
			||||||
 | 
					        assert 'ok' in res.result
 | 
				
			||||||
 | 
					        assert res.result['ok'] == 2
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        async with trio.open_nursery() as n:
 | 
				
			||||||
 | 
					            n.start_soon(check_request_img, 0)
 | 
				
			||||||
 | 
					            n.start_soon(check_request_img, 0)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					async def test_dgpu_no_ack_node_disconnect(skynet_running):
 | 
				
			||||||
 | 
					    async with open_skynet_rpc(
 | 
				
			||||||
 | 
					        security=True,
 | 
				
			||||||
 | 
					        cert_name='whitelist/testing',
 | 
				
			||||||
 | 
					        key_name='testing'
 | 
				
			||||||
 | 
					    ) as rpc_call:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        res = await rpc_call('dgpu-0', 'dgpu_online')
 | 
				
			||||||
 | 
					        logging.info(res)
 | 
				
			||||||
 | 
					        assert 'ok' in res.result
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        await wait_for_dgpus(rpc_call, 1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        with pytest.raises(SkynetDGPUComputeError) as e:
 | 
				
			||||||
 | 
					            await check_request_img(0)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        assert 'dgpu failed to acknowledge request' in str(e)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        res = await rpc_call('testing-rpc', 'dgpu_workers')
 | 
				
			||||||
 | 
					        logging.info(res)
 | 
				
			||||||
 | 
					        assert 'ok' in res.result
 | 
				
			||||||
 | 
					        assert res.result['ok'] == 0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1,107 +0,0 @@
 | 
				
			||||||
import trio
 | 
					 | 
				
			||||||
import tractor
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
from skynet_bot.types import *
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
@tractor.context
 | 
					 | 
				
			||||||
async def open_fake_worker(
 | 
					 | 
				
			||||||
    ctx: tractor.Context,
 | 
					 | 
				
			||||||
    start_algo: str,
 | 
					 | 
				
			||||||
    mem_fraction: float
 | 
					 | 
				
			||||||
):
 | 
					 | 
				
			||||||
    log = tractor.log.get_logger(name='gpu', _root_name='skynet')
 | 
					 | 
				
			||||||
    log.info(f'starting gpu worker with algo {start_algo}...')
 | 
					 | 
				
			||||||
    current_algo = start_algo
 | 
					 | 
				
			||||||
    log.info('pipeline loaded')
 | 
					 | 
				
			||||||
    await ctx.started()
 | 
					 | 
				
			||||||
    async with ctx.open_stream() as bus:
 | 
					 | 
				
			||||||
        async for ireq in bus:
 | 
					 | 
				
			||||||
            if ireq:
 | 
					 | 
				
			||||||
                await bus.send('hello!')
 | 
					 | 
				
			||||||
            else:
 | 
					 | 
				
			||||||
                break
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def test_gpu_worker():
 | 
					 | 
				
			||||||
    log = tractor.log.get_logger(name='root', _root_name='skynet')
 | 
					 | 
				
			||||||
    async def main():
 | 
					 | 
				
			||||||
        async with (
 | 
					 | 
				
			||||||
            tractor.open_nursery(debug_mode=True) as an,
 | 
					 | 
				
			||||||
            trio.open_nursery() as n
 | 
					 | 
				
			||||||
        ):
 | 
					 | 
				
			||||||
            portal = await an.start_actor(
 | 
					 | 
				
			||||||
                'gpu_worker',
 | 
					 | 
				
			||||||
                enable_modules=[__name__],
 | 
					 | 
				
			||||||
                debug_mode=True
 | 
					 | 
				
			||||||
            )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            log.info('portal opened')
 | 
					 | 
				
			||||||
            async with (
 | 
					 | 
				
			||||||
                portal.open_context(
 | 
					 | 
				
			||||||
                    open_fake_worker,
 | 
					 | 
				
			||||||
                    start_algo='midj',
 | 
					 | 
				
			||||||
                    mem_fraction=0.6
 | 
					 | 
				
			||||||
                ) as (ctx, _),
 | 
					 | 
				
			||||||
                ctx.open_stream() as stream,
 | 
					 | 
				
			||||||
            ):
 | 
					 | 
				
			||||||
                log.info('opened worker sending req...')
 | 
					 | 
				
			||||||
                ireq = ImageGenRequest(
 | 
					 | 
				
			||||||
                    prompt='a red tractor on a wheat field',
 | 
					 | 
				
			||||||
                    step=28,
 | 
					 | 
				
			||||||
                    width=512, height=512,
 | 
					 | 
				
			||||||
                    guidance=10, seed=None,
 | 
					 | 
				
			||||||
                    algo='midj', upscaler=None)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                await stream.send(ireq)
 | 
					 | 
				
			||||||
                log.info('sent, await respnse')
 | 
					 | 
				
			||||||
                async for msg in stream:
 | 
					 | 
				
			||||||
                    log.info(f'got {msg}')
 | 
					 | 
				
			||||||
                    break
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                assert msg == 'hello!'
 | 
					 | 
				
			||||||
                await stream.send(None)
 | 
					 | 
				
			||||||
                log.info('done.')
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            await portal.cancel_actor()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    trio.run(main)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def test_gpu_two_workers():
 | 
					 | 
				
			||||||
    async def main():
 | 
					 | 
				
			||||||
        outputs = []
 | 
					 | 
				
			||||||
        async with (
 | 
					 | 
				
			||||||
            tractor.open_actor_cluster(
 | 
					 | 
				
			||||||
                modules=[__name__],
 | 
					 | 
				
			||||||
                count=2,
 | 
					 | 
				
			||||||
                names=[0, 1]) as portal_map,
 | 
					 | 
				
			||||||
            tractor.trionics.gather_contexts((
 | 
					 | 
				
			||||||
                portal.open_context(
 | 
					 | 
				
			||||||
                    open_fake_worker,
 | 
					 | 
				
			||||||
                    start_algo='midj',
 | 
					 | 
				
			||||||
                    mem_fraction=0.333)
 | 
					 | 
				
			||||||
                for portal in portal_map.values()
 | 
					 | 
				
			||||||
            )) as contexts,
 | 
					 | 
				
			||||||
            trio.open_nursery() as n
 | 
					 | 
				
			||||||
        ):
 | 
					 | 
				
			||||||
            ireq = ImageGenRequest(
 | 
					 | 
				
			||||||
                prompt='a red tractor on a wheat field',
 | 
					 | 
				
			||||||
                step=28,
 | 
					 | 
				
			||||||
                width=512, height=512,
 | 
					 | 
				
			||||||
                guidance=10, seed=None,
 | 
					 | 
				
			||||||
                algo='midj', upscaler=None)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            async def get_img(i):
 | 
					 | 
				
			||||||
                ctx = contexts[i]
 | 
					 | 
				
			||||||
                async with ctx.open_stream() as stream:
 | 
					 | 
				
			||||||
                    await stream.send(ireq)
 | 
					 | 
				
			||||||
                    async for img in stream:
 | 
					 | 
				
			||||||
                        outputs[i] = img
 | 
					 | 
				
			||||||
                        await portal_map[i].cancel_actor()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            n.start_soon(get_img, 0)
 | 
					 | 
				
			||||||
            n.start_soon(get_img, 1)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        assert len(outputs) == 2
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    trio.run(main)
 | 
					 | 
				
			||||||
| 
						 | 
					@ -7,9 +7,9 @@ import pynng
 | 
				
			||||||
import pytest
 | 
					import pytest
 | 
				
			||||||
import trio_asyncio
 | 
					import trio_asyncio
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from skynet_bot.types import *
 | 
					from skynet.brain import run_skynet
 | 
				
			||||||
from skynet_bot.brain import run_skynet
 | 
					from skynet.structs import *
 | 
				
			||||||
from skynet_bot.frontend import open_skynet_rpc
 | 
					from skynet.frontend import open_skynet_rpc
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
async def test_skynet_attempt_insecure(skynet_running):
 | 
					async def test_skynet_attempt_insecure(skynet_running):
 | 
				
			||||||
| 
						 | 
					@ -40,7 +40,7 @@ async def test_skynet_dgpu_connection_simple(skynet_running):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # connect 1 dgpu
 | 
					        # connect 1 dgpu
 | 
				
			||||||
        res = await rpc_call(
 | 
					        res = await rpc_call(
 | 
				
			||||||
            'dgpu-0', 'dgpu_online', {'max_tasks': 3})
 | 
					            'dgpu-0', 'dgpu_online')
 | 
				
			||||||
        logging.info(res)
 | 
					        logging.info(res)
 | 
				
			||||||
        assert 'ok' in res.result
 | 
					        assert 'ok' in res.result
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue