mirror of https://github.com/skygpu/skynet.git
				
				
				
			Add authenticated messaging, also cmd line utils txt2img and upscale
							parent
							
								
									f6326ad05c
								
							
						
					
					
						commit
						6bc555f0d6
					
				
							
								
								
									
										2
									
								
								setup.py
								
								
								
								
							
							
						
						
									
										2
									
								
								setup.py
								
								
								
								
							| 
						 | 
				
			
			@ -10,6 +10,8 @@ setup(
 | 
			
		|||
    entry_points={
 | 
			
		||||
        'console_scripts': [
 | 
			
		||||
            'skynet = skynet.cli:skynet',
 | 
			
		||||
            'txt2img = skynet.cli:txt2img',
 | 
			
		||||
            'upscale = skynet.cli:upscale'
 | 
			
		||||
        ]
 | 
			
		||||
    },
 | 
			
		||||
    install_requires=['click']
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
							
								
								
									
										110
									
								
								skynet/brain.py
								
								
								
								
							
							
						
						
									
										110
									
								
								skynet/brain.py
								
								
								
								
							| 
						 | 
				
			
			@ -16,6 +16,11 @@ import pynng
 | 
			
		|||
import trio_asyncio
 | 
			
		||||
 | 
			
		||||
from pynng import TLSConfig
 | 
			
		||||
from OpenSSL.crypto import (
 | 
			
		||||
    load_privatekey,
 | 
			
		||||
    load_certificate,
 | 
			
		||||
    FILETYPE_PEM
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
from .db import *
 | 
			
		||||
from .structs import *
 | 
			
		||||
| 
						 | 
				
			
			@ -34,12 +39,14 @@ class SkynetDGPUComputeError(BaseException):
 | 
			
		|||
class SkynetShutdownRequested(BaseException):
 | 
			
		||||
    ...
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@acm
 | 
			
		||||
async def open_rpc_service(sock, dgpu_bus, db_pool):
 | 
			
		||||
async def open_rpc_service(sock, dgpu_bus, db_pool, tls_whitelist, tls_key):
 | 
			
		||||
    nodes = OrderedDict()
 | 
			
		||||
    wip_reqs = {}
 | 
			
		||||
    fin_reqs = {}
 | 
			
		||||
    next_worker: Optional[int] = None
 | 
			
		||||
    security = len(tls_whitelist) > 0
 | 
			
		||||
 | 
			
		||||
    def connect_node(uid):
 | 
			
		||||
        nonlocal next_worker
 | 
			
		||||
| 
						 | 
				
			
			@ -109,27 +116,20 @@ async def open_rpc_service(sock, dgpu_bus, db_pool):
 | 
			
		|||
    async def dgpu_image_streamer():
 | 
			
		||||
        nonlocal wip_reqs, fin_reqs
 | 
			
		||||
        while True:
 | 
			
		||||
            msg = await dgpu_bus.arecv_msg()
 | 
			
		||||
            rid = UUID(bytes=msg.bytes[:16]).hex
 | 
			
		||||
            raw_msg = msg.bytes[16:]
 | 
			
		||||
            logging.info(f'streamer got back {rid} of size {len(raw_msg)}')
 | 
			
		||||
            match raw_msg[:5]:
 | 
			
		||||
                case b'error':
 | 
			
		||||
                    img = raw_msg.decode()
 | 
			
		||||
            msg = DGPUBusResponse(
 | 
			
		||||
                **json.loads(
 | 
			
		||||
                    (await dgpu_bus.arecv()).decode()))
 | 
			
		||||
 | 
			
		||||
                case b'ack':
 | 
			
		||||
                    img = raw_msg
 | 
			
		||||
            if security:
 | 
			
		||||
                msg.verify(tls_whitelist[msg.cert])
 | 
			
		||||
 | 
			
		||||
                case _:
 | 
			
		||||
                    img = base64.b64encode(raw_msg).hex()
 | 
			
		||||
 | 
			
		||||
            if rid not in wip_reqs:
 | 
			
		||||
            if msg.rid not in wip_reqs:
 | 
			
		||||
                continue
 | 
			
		||||
 | 
			
		||||
            fin_reqs[rid] = img
 | 
			
		||||
            event = wip_reqs[rid]
 | 
			
		||||
            fin_reqs[msg.rid] = msg
 | 
			
		||||
            event = wip_reqs[msg.rid]
 | 
			
		||||
            event.set()
 | 
			
		||||
            del wip_reqs[rid]
 | 
			
		||||
            del wip_reqs[msg.rid]
 | 
			
		||||
 | 
			
		||||
    async def dgpu_stream_one_img(req: ImageGenRequest):
 | 
			
		||||
        nonlocal wip_reqs, fin_reqs, next_worker
 | 
			
		||||
| 
						 | 
				
			
			@ -151,6 +151,9 @@ async def open_rpc_service(sock, dgpu_bus, db_pool):
 | 
			
		|||
 | 
			
		||||
        logging.info(f'dgpu_bus req: {dgpu_req}')
 | 
			
		||||
 | 
			
		||||
        if security:
 | 
			
		||||
            dgpu_req.sign(tls_key, 'skynet')
 | 
			
		||||
 | 
			
		||||
        await dgpu_bus.asend(
 | 
			
		||||
            json.dumps(dgpu_req.to_dict()).encode())
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -163,8 +166,8 @@ async def open_rpc_service(sock, dgpu_bus, db_pool):
 | 
			
		|||
            disconnect_node(nid)
 | 
			
		||||
            raise SkynetDGPUOffline('dgpu failed to acknowledge request')
 | 
			
		||||
 | 
			
		||||
        ack = fin_reqs[rid]
 | 
			
		||||
        if ack != b'ack':
 | 
			
		||||
        ack_msg = fin_reqs[rid]
 | 
			
		||||
        if 'ack' not in ack_msg.params:
 | 
			
		||||
            disconnect_node(nid)
 | 
			
		||||
            raise SkynetDGPUOffline('dgpu failed to acknowledge request')
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -178,15 +181,13 @@ async def open_rpc_service(sock, dgpu_bus, db_pool):
 | 
			
		|||
 | 
			
		||||
        nodes[nid]['task'] = None
 | 
			
		||||
 | 
			
		||||
        img = fin_reqs[rid]
 | 
			
		||||
        img_resp = fin_reqs[rid]
 | 
			
		||||
        del fin_reqs[rid]
 | 
			
		||||
 | 
			
		||||
        logging.info(f'done streaming {len(img)} bytes')
 | 
			
		||||
        if 'error' in img_resp.params:
 | 
			
		||||
            raise SkynetDGPUComputeError(img_resp.params['error'])
 | 
			
		||||
 | 
			
		||||
        if 'error' in img:
 | 
			
		||||
            raise SkynetDGPUComputeError(img)
 | 
			
		||||
 | 
			
		||||
        return rid, img
 | 
			
		||||
        return rid, img_resp.params['img']
 | 
			
		||||
 | 
			
		||||
    async def handle_user_request(rpc_ctx, req):
 | 
			
		||||
        try:
 | 
			
		||||
| 
						 | 
				
			
			@ -266,9 +267,13 @@ async def open_rpc_service(sock, dgpu_bus, db_pool):
 | 
			
		|||
                'message': str(e)
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
        resp = SkynetRPCResponse(result=result)
 | 
			
		||||
 | 
			
		||||
        if security:
 | 
			
		||||
            resp.sign(tls_key, 'skynet')
 | 
			
		||||
 | 
			
		||||
        await rpc_ctx.asend(
 | 
			
		||||
            json.dumps(
 | 
			
		||||
                SkynetRPCResponse(result=result).to_dict()).encode())
 | 
			
		||||
            json.dumps(resp.to_dict()).encode())
 | 
			
		||||
 | 
			
		||||
    async def request_service(n):
 | 
			
		||||
        nonlocal next_worker
 | 
			
		||||
| 
						 | 
				
			
			@ -279,7 +284,19 @@ async def open_rpc_service(sock, dgpu_bus, db_pool):
 | 
			
		|||
            content = msg.bytes.decode()
 | 
			
		||||
            req = SkynetRPCRequest(**json.loads(content))
 | 
			
		||||
 | 
			
		||||
            logging.info(req)
 | 
			
		||||
            if security:
 | 
			
		||||
                if req.cert not in tls_whitelist:
 | 
			
		||||
                    logging.warning(
 | 
			
		||||
                        f'{req.cert} not in tls whitelist and security=True')
 | 
			
		||||
                    continue
 | 
			
		||||
 | 
			
		||||
                try:
 | 
			
		||||
                    req.verify(tls_whitelist[req.cert])
 | 
			
		||||
 | 
			
		||||
                except ValueError:
 | 
			
		||||
                    logging.warning(
 | 
			
		||||
                        f'{req.cert} sent an unauthenticated msg with security=True')
 | 
			
		||||
                    continue
 | 
			
		||||
 | 
			
		||||
            result = {}
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -303,10 +320,14 @@ async def open_rpc_service(sock, dgpu_bus, db_pool):
 | 
			
		|||
                    handle_user_request, ctx, req)
 | 
			
		||||
                continue
 | 
			
		||||
 | 
			
		||||
            resp = SkynetRPCResponse(
 | 
			
		||||
                result={'ok': result})
 | 
			
		||||
 | 
			
		||||
            if security:
 | 
			
		||||
                resp.sign(tls_key, 'skynet')
 | 
			
		||||
 | 
			
		||||
            await ctx.asend(
 | 
			
		||||
                json.dumps(
 | 
			
		||||
                    SkynetRPCResponse(
 | 
			
		||||
                        result={'ok': result}).to_dict()).encode())
 | 
			
		||||
                json.dumps(resp.to_dict()).encode())
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
    async with trio.open_nursery() as n:
 | 
			
		||||
| 
						 | 
				
			
			@ -334,22 +355,28 @@ async def run_skynet(
 | 
			
		|||
    if security:
 | 
			
		||||
        # load tls certs
 | 
			
		||||
        certs_dir = Path(DEFAULT_CERTS_DIR).resolve()
 | 
			
		||||
        tls_key = (certs_dir / DEFAULT_CERT_SKYNET_PRIV).read_text()
 | 
			
		||||
        tls_cert = (certs_dir / DEFAULT_CERT_SKYNET_PUB).read_text()
 | 
			
		||||
        tls_whitelist = [
 | 
			
		||||
            (cert_path).read_text()
 | 
			
		||||
            for cert_path in (certs_dir / 'whitelist').glob('*.cert')]
 | 
			
		||||
 | 
			
		||||
        cert_start = tls_cert.index('\n') + 1
 | 
			
		||||
        logging.info(f'tls_cert: {tls_cert[cert_start:cert_start+64]}...')
 | 
			
		||||
        tls_key_data = (certs_dir / DEFAULT_CERT_SKYNET_PRIV).read_text()
 | 
			
		||||
        tls_key = load_privatekey(FILETYPE_PEM, tls_key_data)
 | 
			
		||||
 | 
			
		||||
        tls_cert_data = (certs_dir / DEFAULT_CERT_SKYNET_PUB).read_text()
 | 
			
		||||
        tls_cert = load_certificate(FILETYPE_PEM, tls_cert_data)
 | 
			
		||||
 | 
			
		||||
        tls_whitelist = {}
 | 
			
		||||
        for cert_path in (certs_dir / 'whitelist').glob('*.cert'):
 | 
			
		||||
            tls_whitelist[cert_path.stem] = load_certificate(
 | 
			
		||||
                FILETYPE_PEM, cert_path.read_text())
 | 
			
		||||
 | 
			
		||||
        cert_start = tls_cert_data.index('\n') + 1
 | 
			
		||||
        logging.info(f'tls_cert: {tls_cert_data[cert_start:cert_start+64]}...')
 | 
			
		||||
        logging.info(f'tls_whitelist len: {len(tls_whitelist)}')
 | 
			
		||||
 | 
			
		||||
        rpc_address = 'tls+' + rpc_address
 | 
			
		||||
        dgpu_address = 'tls+' + dgpu_address
 | 
			
		||||
        tls_config = TLSConfig(
 | 
			
		||||
            TLSConfig.MODE_SERVER,
 | 
			
		||||
            own_key_string=tls_key,
 | 
			
		||||
            own_cert_string=tls_cert)
 | 
			
		||||
            own_key_string=tls_key_data,
 | 
			
		||||
            own_cert_string=tls_cert_data)
 | 
			
		||||
 | 
			
		||||
    with (
 | 
			
		||||
        pynng.Rep0() as rpc_sock,
 | 
			
		||||
| 
						 | 
				
			
			@ -367,7 +394,8 @@ async def run_skynet(
 | 
			
		|||
            dgpu_bus.listen(dgpu_address)
 | 
			
		||||
 | 
			
		||||
            try:
 | 
			
		||||
                async with open_rpc_service(rpc_sock, dgpu_bus, db_pool):
 | 
			
		||||
                async with open_rpc_service(
 | 
			
		||||
                    rpc_sock, dgpu_bus, db_pool, tls_whitelist, tls_key):
 | 
			
		||||
                    yield
 | 
			
		||||
 | 
			
		||||
            except SkynetShutdownRequested:
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -9,44 +9,78 @@ from functools import partial
 | 
			
		|||
import trio
 | 
			
		||||
import click
 | 
			
		||||
 | 
			
		||||
from . import utils
 | 
			
		||||
from .dgpu import open_dgpu_node
 | 
			
		||||
from .utils import txt2img
 | 
			
		||||
from .constants import ALGOS
 | 
			
		||||
from .brain import run_skynet
 | 
			
		||||
 | 
			
		||||
from .frontend.telegram import run_skynet_telegram
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@click.group()
 | 
			
		||||
def skynet(*args, **kwargs):
 | 
			
		||||
    pass
 | 
			
		||||
 | 
			
		||||
@skynet.command()
 | 
			
		||||
@click.option('--model', '-m', default=ALGOS['midj'])
 | 
			
		||||
 | 
			
		||||
@click.command()
 | 
			
		||||
@click.option('--model', '-m', default='midj')
 | 
			
		||||
@click.option(
 | 
			
		||||
    '--prompt', '-p', default='a red tractor in a wheat field')
 | 
			
		||||
    '--prompt', '-p', default='a red old tractor in a sunny wheat field')
 | 
			
		||||
@click.option('--output', '-o', default='output.png')
 | 
			
		||||
@click.option('--width', '-w', default=512)
 | 
			
		||||
@click.option('--height', '-h', default=512)
 | 
			
		||||
@click.option('--guidance', '-g', default=10.0)
 | 
			
		||||
@click.option('--steps', '-s', default=26)
 | 
			
		||||
@click.option('--seed', '-S', default=None)
 | 
			
		||||
def txt2img(*args
 | 
			
		||||
#     model: str,
 | 
			
		||||
#     prompt: str,
 | 
			
		||||
#     output: str
 | 
			
		||||
#     width: int, height: int,
 | 
			
		||||
#     guidance: float,
 | 
			
		||||
#     steps: int,
 | 
			
		||||
#     seed: Optional[int]
 | 
			
		||||
):
 | 
			
		||||
def txt2img(*args, **kwargs):
 | 
			
		||||
    assert 'HF_TOKEN' in os.environ
 | 
			
		||||
    txt2img(
 | 
			
		||||
        os.environ['HF_TOKEN'], *args)
 | 
			
		||||
    utils.txt2img(os.environ['HF_TOKEN'], **kwargs)
 | 
			
		||||
 | 
			
		||||
@click.command()
 | 
			
		||||
@click.option(
 | 
			
		||||
    '--prompt', '-p', default='a red old tractor in a sunny wheat field')
 | 
			
		||||
@click.option('--input', '-i', default='input.png')
 | 
			
		||||
@click.option('--output', '-o', default='output.png')
 | 
			
		||||
@click.option('--steps', '-s', default=26)
 | 
			
		||||
def upscale(prompt, input, output, steps):
 | 
			
		||||
    assert 'HF_TOKEN' in os.environ
 | 
			
		||||
    utils.upscale(
 | 
			
		||||
        os.environ['HF_TOKEN'],
 | 
			
		||||
        prompt=prompt,
 | 
			
		||||
        img_path=input,
 | 
			
		||||
        output=output,
 | 
			
		||||
        steps=steps)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@skynet.group()
 | 
			
		||||
def run(*args, **kwargs):
 | 
			
		||||
    pass
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@run.command()
 | 
			
		||||
@click.option('--loglevel', '-l', default='warning', help='Logging level')
 | 
			
		||||
@click.option(
 | 
			
		||||
    '--host', '-h', default='localhost:5432')
 | 
			
		||||
@click.option(
 | 
			
		||||
    '--pass', '-p', default='password')
 | 
			
		||||
def skynet(
 | 
			
		||||
    loglevel: str,
 | 
			
		||||
    host: str,
 | 
			
		||||
    passw: str
 | 
			
		||||
):
 | 
			
		||||
    async def _run_skynet():
 | 
			
		||||
        async with run_skynet(
 | 
			
		||||
            db_host=host,
 | 
			
		||||
            db_pass=passw
 | 
			
		||||
        ):
 | 
			
		||||
            await trio.sleep_forever()
 | 
			
		||||
 | 
			
		||||
    trio_asyncio.run(_run_skynet)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@run.command()
 | 
			
		||||
@click.option('--loglevel', '-l', default='warning', help='Logging level')
 | 
			
		||||
@click.option(
 | 
			
		||||
    '--uid', '-u', required=True)
 | 
			
		||||
@click.option(
 | 
			
		||||
    '--key', '-k', default='dgpu')
 | 
			
		||||
@click.option(
 | 
			
		||||
| 
						 | 
				
			
			@ -55,6 +89,7 @@ def run(*args, **kwargs):
 | 
			
		|||
    '--algos', '-a', default=None)
 | 
			
		||||
def dgpu(
 | 
			
		||||
    loglevel: str,
 | 
			
		||||
    uid: str,
 | 
			
		||||
    key: str,
 | 
			
		||||
    cert: str,
 | 
			
		||||
    algos: Optional[str]
 | 
			
		||||
| 
						 | 
				
			
			@ -63,6 +98,28 @@ def dgpu(
 | 
			
		|||
        partial(
 | 
			
		||||
            open_dgpu_node,
 | 
			
		||||
            cert,
 | 
			
		||||
            uid,
 | 
			
		||||
            key_name=key,
 | 
			
		||||
            initial_algos=json.loads(algos)
 | 
			
		||||
    ))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@run.command()
 | 
			
		||||
@click.option('--loglevel', '-l', default='warning', help='Logging level')
 | 
			
		||||
@click.option(
 | 
			
		||||
    '--key', '-k', default='telegram-frontend')
 | 
			
		||||
@click.option(
 | 
			
		||||
    '--cert', '-c', default='whitelist/telegram-frontend')
 | 
			
		||||
def telegram(
 | 
			
		||||
    loglevel: str,
 | 
			
		||||
    key: str,
 | 
			
		||||
    cert: str
 | 
			
		||||
):
 | 
			
		||||
    assert 'TG_TOKEN' in os.environ
 | 
			
		||||
    trio_asyncio.run(
 | 
			
		||||
        partial(
 | 
			
		||||
            run_skynet_telegram,
 | 
			
		||||
            os.environ['TG_TOKEN'],
 | 
			
		||||
            key_name=key,
 | 
			
		||||
            cert_name=cert
 | 
			
		||||
    ))
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
							
								
								
									
										45
									
								
								skynet/db.py
								
								
								
								
							
							
						
						
									
										45
									
								
								skynet/db.py
								
								
								
								
							| 
						 | 
				
			
			@ -58,13 +58,16 @@ ALTER TABLE skynet.user_config
 | 
			
		|||
 | 
			
		||||
 | 
			
		||||
def try_decode_uid(uid: str):
 | 
			
		||||
    if isinstance(uid, int):
 | 
			
		||||
        return None, uid
 | 
			
		||||
 | 
			
		||||
    try:
 | 
			
		||||
        proto, uid = uid.split('+')
 | 
			
		||||
        uid = int(uid)
 | 
			
		||||
        return proto, uid
 | 
			
		||||
 | 
			
		||||
    except ValueError:
 | 
			
		||||
        logging.warning(f'got non numeric uid?: {uid}')
 | 
			
		||||
        logging.warning(f'got non chat proto uid?: {uid}')
 | 
			
		||||
        return None, None
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -132,28 +135,38 @@ async def new_user(conn, uid: str):
 | 
			
		|||
 | 
			
		||||
    logging.info(f'new user! {uid}')
 | 
			
		||||
 | 
			
		||||
    tg_id = None
 | 
			
		||||
    date = datetime.utcnow()
 | 
			
		||||
 | 
			
		||||
    proto, pid = try_decode_uid(uid)
 | 
			
		||||
 | 
			
		||||
    match proto:
 | 
			
		||||
        case 'tg':
 | 
			
		||||
            tg_id = pid
 | 
			
		||||
 | 
			
		||||
    async with conn.transaction():
 | 
			
		||||
        stmt = await conn.prepare('''
 | 
			
		||||
            INSERT INTO skynet.user(
 | 
			
		||||
                tg_id, generated, joined, last_prompt, role)
 | 
			
		||||
        match proto:
 | 
			
		||||
            case 'tg':
 | 
			
		||||
                tg_id = pid
 | 
			
		||||
                stmt = await conn.prepare('''
 | 
			
		||||
                    INSERT INTO skynet.user(
 | 
			
		||||
                        tg_id, generated, joined, last_prompt, role)
 | 
			
		||||
 | 
			
		||||
            VALUES($1, $2, $3, $4, $5)
 | 
			
		||||
            ON CONFLICT DO NOTHING
 | 
			
		||||
        ''')
 | 
			
		||||
        await stmt.fetch(
 | 
			
		||||
            tg_id, 0, date, None, DEFAULT_ROLE
 | 
			
		||||
        )
 | 
			
		||||
                    VALUES($1, $2, $3, $4, $5)
 | 
			
		||||
                    ON CONFLICT DO NOTHING
 | 
			
		||||
                ''')
 | 
			
		||||
                await stmt.fetch(
 | 
			
		||||
                    tg_id, 0, date, None, DEFAULT_ROLE
 | 
			
		||||
                )
 | 
			
		||||
                new_uid = await get_user(conn, uid)
 | 
			
		||||
 | 
			
		||||
        new_uid = await get_user(conn, uid)
 | 
			
		||||
            case None:
 | 
			
		||||
                stmt = await conn.prepare('''
 | 
			
		||||
                    INSERT INTO skynet.user(
 | 
			
		||||
                        id, generated, joined, last_prompt, role)
 | 
			
		||||
 | 
			
		||||
                    VALUES($1, $2, $3, $4, $5)
 | 
			
		||||
                    ON CONFLICT DO NOTHING
 | 
			
		||||
                ''')
 | 
			
		||||
                await stmt.fetch(
 | 
			
		||||
                    pid, 0, date, None, DEFAULT_ROLE
 | 
			
		||||
                )
 | 
			
		||||
                new_uid = pid
 | 
			
		||||
 | 
			
		||||
        stmt = await conn.prepare('''
 | 
			
		||||
            INSERT INTO skynet.user_config(
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -5,6 +5,7 @@ import io
 | 
			
		|||
import trio
 | 
			
		||||
import json
 | 
			
		||||
import uuid
 | 
			
		||||
import base64
 | 
			
		||||
import random
 | 
			
		||||
import logging
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -16,10 +17,16 @@ import pynng
 | 
			
		|||
import torch
 | 
			
		||||
 | 
			
		||||
from pynng import TLSConfig
 | 
			
		||||
from OpenSSL.crypto import (
 | 
			
		||||
    load_privatekey,
 | 
			
		||||
    load_certificate,
 | 
			
		||||
    FILETYPE_PEM
 | 
			
		||||
)
 | 
			
		||||
from diffusers import (
 | 
			
		||||
    StableDiffusionPipeline,
 | 
			
		||||
    EulerAncestralDiscreteScheduler
 | 
			
		||||
)
 | 
			
		||||
from diffusers.models import UNet2DConditionModel
 | 
			
		||||
 | 
			
		||||
from .structs import *
 | 
			
		||||
from .constants import *
 | 
			
		||||
| 
						 | 
				
			
			@ -58,6 +65,7 @@ class DGPUComputeError(BaseException):
 | 
			
		|||
 | 
			
		||||
async def open_dgpu_node(
 | 
			
		||||
    cert_name: str,
 | 
			
		||||
    unique_id: str,
 | 
			
		||||
    key_name: Optional[str],
 | 
			
		||||
    rpc_address: str = DEFAULT_RPC_ADDR,
 | 
			
		||||
    dgpu_address: str = DEFAULT_DGPU_ADDR,
 | 
			
		||||
| 
						 | 
				
			
			@ -122,6 +130,7 @@ async def open_dgpu_node(
 | 
			
		|||
 | 
			
		||||
 | 
			
		||||
    async with open_skynet_rpc(
 | 
			
		||||
        unique_id,
 | 
			
		||||
        security=security,
 | 
			
		||||
        cert_name=cert_name,
 | 
			
		||||
        key_name=key_name
 | 
			
		||||
| 
						 | 
				
			
			@ -131,16 +140,23 @@ async def open_dgpu_node(
 | 
			
		|||
        if security:
 | 
			
		||||
            # load tls certs
 | 
			
		||||
            if not key_name:
 | 
			
		||||
                key_name = certs_name
 | 
			
		||||
                key_name = cert_name
 | 
			
		||||
 | 
			
		||||
            certs_dir = Path(DEFAULT_CERTS_DIR).resolve()
 | 
			
		||||
 | 
			
		||||
            skynet_cert_path = certs_dir / 'brain.cert'
 | 
			
		||||
            tls_cert_path = certs_dir / f'{cert_name}.cert'
 | 
			
		||||
            tls_key_path = certs_dir / f'{key_name}.key'
 | 
			
		||||
 | 
			
		||||
            skynet_cert = skynet_cert_path.read_text()
 | 
			
		||||
            tls_cert = tls_cert_path.read_text()
 | 
			
		||||
            tls_key = tls_key_path.read_text()
 | 
			
		||||
            cert_name = tls_cert_path.stem
 | 
			
		||||
 | 
			
		||||
            skynet_cert_data = skynet_cert_path.read_text()
 | 
			
		||||
            skynet_cert = load_certificate(FILETYPE_PEM, skynet_cert_data)
 | 
			
		||||
 | 
			
		||||
            tls_cert_data = tls_cert_path.read_text()
 | 
			
		||||
 | 
			
		||||
            tls_key_data = tls_key_path.read_text()
 | 
			
		||||
            tls_key = load_privatekey(FILETYPE_PEM, tls_key_data)
 | 
			
		||||
 | 
			
		||||
            logging.info(f'skynet cert: {skynet_cert_path}')
 | 
			
		||||
            logging.info(f'dgpu cert:  {tls_cert_path}')
 | 
			
		||||
| 
						 | 
				
			
			@ -149,17 +165,16 @@ async def open_dgpu_node(
 | 
			
		|||
            dgpu_address = 'tls+' + dgpu_address
 | 
			
		||||
            tls_config = TLSConfig(
 | 
			
		||||
                TLSConfig.MODE_CLIENT,
 | 
			
		||||
                own_key_string=tls_key,
 | 
			
		||||
                own_cert_string=tls_cert,
 | 
			
		||||
                ca_string=skynet_cert)
 | 
			
		||||
                own_key_string=tls_key_data,
 | 
			
		||||
                own_cert_string=tls_cert_data,
 | 
			
		||||
                ca_string=skynet_cert_data)
 | 
			
		||||
 | 
			
		||||
        logging.info(f'connecting to {dgpu_address}')
 | 
			
		||||
        with pynng.Bus0() as dgpu_sock:
 | 
			
		||||
            dgpu_sock.tls_config = tls_config
 | 
			
		||||
            dgpu_sock.dial(dgpu_address)
 | 
			
		||||
 | 
			
		||||
            res = await rpc_call(name.hex, 'dgpu_online')
 | 
			
		||||
            logging.info(res)
 | 
			
		||||
            res = await rpc_call('dgpu_online')
 | 
			
		||||
            assert 'ok' in res.result
 | 
			
		||||
 | 
			
		||||
            try:
 | 
			
		||||
| 
						 | 
				
			
			@ -168,30 +183,55 @@ async def open_dgpu_node(
 | 
			
		|||
                    req = DGPUBusRequest(
 | 
			
		||||
                        **json.loads(msg.decode()))
 | 
			
		||||
 | 
			
		||||
                    if req.nid != name.hex:
 | 
			
		||||
                        logging.info('witnessed request {req.rid}, for {req.nid}')
 | 
			
		||||
                    if req.nid != unique_id:
 | 
			
		||||
                        logging.info(
 | 
			
		||||
                            f'witnessed msg {req.rid}, node involved: {req.nid}')
 | 
			
		||||
                        continue
 | 
			
		||||
 | 
			
		||||
                    if security:
 | 
			
		||||
                        req.verify(skynet_cert)
 | 
			
		||||
 | 
			
		||||
                    ack_resp = DGPUBusResponse(
 | 
			
		||||
                        rid=req.rid,
 | 
			
		||||
                        nid=req.nid,
 | 
			
		||||
                        params={'ack': {}}
 | 
			
		||||
                    )
 | 
			
		||||
 | 
			
		||||
                    if security:
 | 
			
		||||
                        ack_resp.sign(tls_key, cert_name)
 | 
			
		||||
 | 
			
		||||
                    # send ack
 | 
			
		||||
                    await dgpu_sock.asend(
 | 
			
		||||
                        bytes.fromhex(req.rid) + b'ack')
 | 
			
		||||
                        json.dumps(ack_resp.to_dict()).encode())
 | 
			
		||||
 | 
			
		||||
                    logging.info(f'sent ack, processing {req.rid}...')
 | 
			
		||||
 | 
			
		||||
                    try:
 | 
			
		||||
                        img = await gpu_compute_one(
 | 
			
		||||
                            ImageGenRequest(**req.params))
 | 
			
		||||
                        img_resp = DGPUBusResponse(
 | 
			
		||||
                            rid=req.rid,
 | 
			
		||||
                            nid=req.nid,
 | 
			
		||||
                            params={'img': base64.b64encode(img).hex()}
 | 
			
		||||
                        )
 | 
			
		||||
 | 
			
		||||
                    except DGPUComputeError as e:
 | 
			
		||||
                        img = b'error' + str(e).encode()
 | 
			
		||||
                        img_resp = DGPUBusResponse(
 | 
			
		||||
                            rid=req.rid,
 | 
			
		||||
                            nid=req.nid,
 | 
			
		||||
                            params={'error': str(e)}
 | 
			
		||||
                        )
 | 
			
		||||
 | 
			
		||||
                    if security:
 | 
			
		||||
                        img_resp.sign(tls_key, cert_name)
 | 
			
		||||
 | 
			
		||||
                    # send final image
 | 
			
		||||
                    await dgpu_sock.asend(
 | 
			
		||||
                        bytes.fromhex(req.rid) + img)
 | 
			
		||||
                        json.dumps(img_resp.to_dict()).encode())
 | 
			
		||||
 | 
			
		||||
            except KeyboardInterrupt:
 | 
			
		||||
                logging.info('interrupt caught, stopping...')
 | 
			
		||||
 | 
			
		||||
            finally:
 | 
			
		||||
                res = await rpc_call(name.hex, 'dgpu_offline')
 | 
			
		||||
                logging.info(res)
 | 
			
		||||
                res = await rpc_call('dgpu_offline')
 | 
			
		||||
                assert 'ok' in res.result
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -9,6 +9,11 @@ from contextlib import asynccontextmanager as acm
 | 
			
		|||
import pynng
 | 
			
		||||
 | 
			
		||||
from pynng import TLSConfig
 | 
			
		||||
from OpenSSL.crypto import (
 | 
			
		||||
    load_privatekey,
 | 
			
		||||
    load_certificate,
 | 
			
		||||
    FILETYPE_PEM
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
from ..structs import SkynetRPCRequest, SkynetRPCResponse
 | 
			
		||||
from ..constants import *
 | 
			
		||||
| 
						 | 
				
			
			@ -30,56 +35,73 @@ class ConfigSizeDivisionByEight(BaseException):
 | 
			
		|||
    ...
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
async def rpc_call(
 | 
			
		||||
    sock,
 | 
			
		||||
    uid: Union[int, str],
 | 
			
		||||
    method: str,
 | 
			
		||||
    params: dict = {}
 | 
			
		||||
):
 | 
			
		||||
    req = SkynetRPCRequest(
 | 
			
		||||
        uid=uid,
 | 
			
		||||
        method=method,
 | 
			
		||||
        params=params
 | 
			
		||||
    )
 | 
			
		||||
    await sock.asend(
 | 
			
		||||
        json.dumps(
 | 
			
		||||
            req.to_dict()).encode())
 | 
			
		||||
 | 
			
		||||
    return SkynetRPCResponse(
 | 
			
		||||
        **json.loads(
 | 
			
		||||
            (await sock.arecv_msg()).bytes.decode()))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@acm
 | 
			
		||||
async def open_skynet_rpc(
 | 
			
		||||
    unique_id: str,
 | 
			
		||||
    rpc_address: str = DEFAULT_RPC_ADDR,
 | 
			
		||||
    security: bool = False,
 | 
			
		||||
    cert_name: Optional[str] = None,
 | 
			
		||||
    key_name: Optional[str] = None
 | 
			
		||||
):
 | 
			
		||||
    tls_config = None
 | 
			
		||||
 | 
			
		||||
    if security:
 | 
			
		||||
        # load tls certs
 | 
			
		||||
        if not key_name:
 | 
			
		||||
            key_name = certs_name
 | 
			
		||||
            key_name = cert_name
 | 
			
		||||
 | 
			
		||||
        certs_dir = Path(DEFAULT_CERTS_DIR).resolve()
 | 
			
		||||
        skynet_cert = (certs_dir / 'brain.cert').read_text()
 | 
			
		||||
        tls_cert = (certs_dir / f'{cert_name}.cert').read_text()
 | 
			
		||||
        tls_key = (certs_dir / f'{key_name}.key').read_text()
 | 
			
		||||
 | 
			
		||||
        skynet_cert_data = (certs_dir / 'brain.cert').read_text()
 | 
			
		||||
        skynet_cert = load_certificate(FILETYPE_PEM, skynet_cert_data)
 | 
			
		||||
 | 
			
		||||
        tls_cert_path = certs_dir / f'{cert_name}.cert'
 | 
			
		||||
        tls_cert_data = tls_cert_path.read_text()
 | 
			
		||||
        tls_cert = load_certificate(FILETYPE_PEM, tls_cert_data)
 | 
			
		||||
        cert_name = tls_cert_path.stem
 | 
			
		||||
 | 
			
		||||
        tls_key_data = (certs_dir / f'{key_name}.key').read_text()
 | 
			
		||||
        tls_key = load_privatekey(FILETYPE_PEM, tls_key_data)
 | 
			
		||||
 | 
			
		||||
        rpc_address = 'tls+' + rpc_address
 | 
			
		||||
        tls_config = TLSConfig(
 | 
			
		||||
            TLSConfig.MODE_CLIENT,
 | 
			
		||||
            own_key_string=tls_key,
 | 
			
		||||
            own_cert_string=tls_cert,
 | 
			
		||||
            ca_string=skynet_cert)
 | 
			
		||||
            own_key_string=tls_key_data,
 | 
			
		||||
            own_cert_string=tls_cert_data,
 | 
			
		||||
            ca_string=skynet_cert_data)
 | 
			
		||||
 | 
			
		||||
    with pynng.Req0() as sock:
 | 
			
		||||
        if security:
 | 
			
		||||
            sock.tls_config = tls_config
 | 
			
		||||
 | 
			
		||||
        sock.dial(rpc_address)
 | 
			
		||||
        async def _rpc_call(*args, **kwargs):
 | 
			
		||||
            return await rpc_call(sock, *args, **kwargs)
 | 
			
		||||
 | 
			
		||||
        async def _rpc_call(
 | 
			
		||||
            method: str,
 | 
			
		||||
            params: dict = {},
 | 
			
		||||
            uid: Optional[Union[int, str]] = None
 | 
			
		||||
        ):
 | 
			
		||||
            req = SkynetRPCRequest(
 | 
			
		||||
                uid=uid if uid else unique_id,
 | 
			
		||||
                method=method,
 | 
			
		||||
                params=params
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
            if security:
 | 
			
		||||
                req.sign(tls_key, cert_name)
 | 
			
		||||
 | 
			
		||||
            await sock.asend(
 | 
			
		||||
                json.dumps(
 | 
			
		||||
                    req.to_dict()).encode())
 | 
			
		||||
 | 
			
		||||
            resp = SkynetRPCResponse(
 | 
			
		||||
                **json.loads(
 | 
			
		||||
                    (await sock.arecv_msg()).bytes.decode()))
 | 
			
		||||
 | 
			
		||||
            if security:
 | 
			
		||||
                resp.verify(skynet_cert)
 | 
			
		||||
 | 
			
		||||
            return resp
 | 
			
		||||
 | 
			
		||||
        yield _rpc_call
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -18,14 +18,19 @@ PREFIX = 'tg'
 | 
			
		|||
 | 
			
		||||
 | 
			
		||||
async def run_skynet_telegram(
 | 
			
		||||
    tg_token: str
 | 
			
		||||
    tg_token: str,
 | 
			
		||||
    key_name: str = 'telegram-frontend',
 | 
			
		||||
    cert_name: str = 'whitelist/telegram-frontend'
 | 
			
		||||
):
 | 
			
		||||
 | 
			
		||||
    logging.basicConfig(level=logging.INFO)
 | 
			
		||||
    bot = AsyncTeleBot(tg_token)
 | 
			
		||||
 | 
			
		||||
    with open_skynet_rpc(
 | 
			
		||||
        security=True, cert_name='telegram-frontend'
 | 
			
		||||
        'skynet-telegram-0',
 | 
			
		||||
        security=True,
 | 
			
		||||
        cert_name=cert,
 | 
			
		||||
        key_name=key
 | 
			
		||||
    ) as rpc_call:
 | 
			
		||||
 | 
			
		||||
        async def _rpc_call(
 | 
			
		||||
| 
						 | 
				
			
			@ -33,7 +38,8 @@ async def run_skynet_telegram(
 | 
			
		|||
            method: str,
 | 
			
		||||
            params: dict
 | 
			
		||||
        ):
 | 
			
		||||
            return await rpc_call(f'{PREFIX}+{uid}', method, params)
 | 
			
		||||
            return await rpc_call(
 | 
			
		||||
                method, params, uid=f'{PREFIX}+{uid}')
 | 
			
		||||
 | 
			
		||||
        @bot.message_handler(commands=['help'])
 | 
			
		||||
        async def send_help(message):
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -18,6 +18,8 @@
 | 
			
		|||
Built-in (extension) types.
 | 
			
		||||
"""
 | 
			
		||||
import sys
 | 
			
		||||
import json
 | 
			
		||||
 | 
			
		||||
from typing import Optional, Union
 | 
			
		||||
from pprint import pformat
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -83,15 +85,51 @@ class Struct(
 | 
			
		|||
            setattr(self, fname, ftype(getattr(self, fname)))
 | 
			
		||||
 | 
			
		||||
# proto
 | 
			
		||||
from OpenSSL.crypto import PKey, X509, verify, sign
 | 
			
		||||
 | 
			
		||||
class SkynetRPCRequest(Struct):
 | 
			
		||||
 | 
			
		||||
class AuthenticatedStruct(Struct):
 | 
			
		||||
    cert: Optional[str] = None
 | 
			
		||||
    sig: Optional[str] = None
 | 
			
		||||
 | 
			
		||||
    def to_unsigned_dict(self) -> dict:
 | 
			
		||||
        self_dict = self.to_dict()
 | 
			
		||||
 | 
			
		||||
        if 'sig' in self_dict:
 | 
			
		||||
            del self_dict['sig']
 | 
			
		||||
 | 
			
		||||
        if 'cert' in self_dict:
 | 
			
		||||
            del self_dict['cert']
 | 
			
		||||
 | 
			
		||||
        return self_dict
 | 
			
		||||
 | 
			
		||||
    def unsigned_to_bytes(self) -> bytes:
 | 
			
		||||
        return json.dumps(
 | 
			
		||||
            self.to_unsigned_dict()).encode()
 | 
			
		||||
 | 
			
		||||
    def sign(self, key: PKey, cert: str):
 | 
			
		||||
        self.cert = cert
 | 
			
		||||
        self.sig = sign(
 | 
			
		||||
            key, self.unsigned_to_bytes(), 'sha256').hex()
 | 
			
		||||
 | 
			
		||||
    def verify(self, cert: X509):
 | 
			
		||||
        if not self.sig:
 | 
			
		||||
            raise ValueError('Tried to verify unsigned request')
 | 
			
		||||
 | 
			
		||||
        return verify(
 | 
			
		||||
            cert, bytes.fromhex(self.sig), self.unsigned_to_bytes(), 'sha256')
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class SkynetRPCRequest(AuthenticatedStruct):
 | 
			
		||||
    uid: Union[str, int]  # user unique id
 | 
			
		||||
    method: str  # rpc method name
 | 
			
		||||
    params: dict  # variable params
 | 
			
		||||
 | 
			
		||||
class SkynetRPCResponse(Struct):
 | 
			
		||||
 | 
			
		||||
class SkynetRPCResponse(AuthenticatedStruct):
 | 
			
		||||
    result: dict
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ImageGenRequest(Struct):
 | 
			
		||||
    prompt: str
 | 
			
		||||
    step: int
 | 
			
		||||
| 
						 | 
				
			
			@ -102,8 +140,15 @@ class ImageGenRequest(Struct):
 | 
			
		|||
    algo: str
 | 
			
		||||
    upscaler: Optional[str]
 | 
			
		||||
 | 
			
		||||
class DGPUBusRequest(Struct):
 | 
			
		||||
 | 
			
		||||
class DGPUBusRequest(AuthenticatedStruct):
 | 
			
		||||
    rid: str  # req id
 | 
			
		||||
    nid: str  # node id
 | 
			
		||||
    task: str
 | 
			
		||||
    params: dict
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class DGPUBusResponse(AuthenticatedStruct):
 | 
			
		||||
    rid: str  # req id
 | 
			
		||||
    nid: str  # node id
 | 
			
		||||
    params: dict
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -7,44 +7,37 @@ from pathlib import Path
 | 
			
		|||
 | 
			
		||||
import torch
 | 
			
		||||
 | 
			
		||||
from diffusers import StableDiffusionPipeline
 | 
			
		||||
from PIL import Image
 | 
			
		||||
from diffusers import (
 | 
			
		||||
    StableDiffusionPipeline,
 | 
			
		||||
    StableDiffusionUpscalePipeline,
 | 
			
		||||
    EulerAncestralDiscreteScheduler
 | 
			
		||||
)
 | 
			
		||||
from huggingface_hub import login
 | 
			
		||||
 | 
			
		||||
from .dgpu import pipeline_for
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def txt2img(
 | 
			
		||||
    hf_token: str,
 | 
			
		||||
    model_name: str,
 | 
			
		||||
    prompt: str,
 | 
			
		||||
    output: str,
 | 
			
		||||
    width: int, height: int,
 | 
			
		||||
    guidance: float,
 | 
			
		||||
    steps: int,
 | 
			
		||||
    seed: Optional[int]
 | 
			
		||||
    model: str = 'midj',
 | 
			
		||||
    prompt: str = 'a red old tractor in a sunny wheat field',
 | 
			
		||||
    output: str = 'output.png',
 | 
			
		||||
    width: int = 512, height: int = 512,
 | 
			
		||||
    guidance: float = 10,
 | 
			
		||||
    steps: int = 28,
 | 
			
		||||
    seed: Optional[int] = None
 | 
			
		||||
):
 | 
			
		||||
    assert torch.cuda.is_available()
 | 
			
		||||
    torch.cuda.empty_cache()
 | 
			
		||||
    torch.cuda.set_per_process_memory_fraction(0.333)
 | 
			
		||||
    torch.cuda.set_per_process_memory_fraction(1.0)
 | 
			
		||||
    torch.backends.cuda.matmul.allow_tf32 = True
 | 
			
		||||
    torch.backends.cudnn.allow_tf32 = True
 | 
			
		||||
 | 
			
		||||
    login(token=hf_token)
 | 
			
		||||
    pipe = pipeline_for(model)
 | 
			
		||||
 | 
			
		||||
    params = {
 | 
			
		||||
        'torch_dtype': torch.float16,
 | 
			
		||||
        'safety_checker': None
 | 
			
		||||
    }
 | 
			
		||||
    if model_name == 'runwayml/stable-diffusion-v1-5':
 | 
			
		||||
        params['revision'] = 'fp16'
 | 
			
		||||
 | 
			
		||||
    pipe = StableDiffusionPipeline.from_pretrained(
 | 
			
		||||
        model_name, **params)
 | 
			
		||||
 | 
			
		||||
    pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(
 | 
			
		||||
        pipe.scheduler.config)
 | 
			
		||||
 | 
			
		||||
    pipe = pipe.to("cuda")
 | 
			
		||||
 | 
			
		||||
    seed = ireq.seed if ireq.seed else random.randint(0, 2 ** 64)
 | 
			
		||||
    seed = seed if seed else random.randint(0, 2 ** 64)
 | 
			
		||||
    prompt = prompt
 | 
			
		||||
    image = pipe(
 | 
			
		||||
        prompt,
 | 
			
		||||
| 
						 | 
				
			
			@ -55,3 +48,34 @@ def txt2img(
 | 
			
		|||
    ).images[0]
 | 
			
		||||
 | 
			
		||||
    image.save(output)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def upscale(
 | 
			
		||||
    hf_token: str,
 | 
			
		||||
    prompt: str = 'a red old tractor in a sunny wheat field',
 | 
			
		||||
    img_path: str = 'input.png',
 | 
			
		||||
    output: str = 'output.png',
 | 
			
		||||
    steps: int = 28
 | 
			
		||||
):
 | 
			
		||||
    assert torch.cuda.is_available()
 | 
			
		||||
    torch.cuda.empty_cache()
 | 
			
		||||
    torch.cuda.set_per_process_memory_fraction(1.0)
 | 
			
		||||
    torch.backends.cuda.matmul.allow_tf32 = True
 | 
			
		||||
    torch.backends.cudnn.allow_tf32 = True
 | 
			
		||||
 | 
			
		||||
    login(token=hf_token)
 | 
			
		||||
    params = {
 | 
			
		||||
        'torch_dtype': torch.float16,
 | 
			
		||||
        'safety_checker': None
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pipe = StableDiffusionUpscalePipeline.from_pretrained(
 | 
			
		||||
        'stabilityai/stable-diffusion-x4-upscaler', **params)
 | 
			
		||||
 | 
			
		||||
    prompt = prompt
 | 
			
		||||
    image = pipe(
 | 
			
		||||
        prompt,
 | 
			
		||||
        image=Image.open(img_path)
 | 
			
		||||
    ).images[0]
 | 
			
		||||
 | 
			
		||||
    image.save(output)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -97,17 +97,22 @@ def dgpu_workers(request, dockerctl, skynet_running):
 | 
			
		|||
 | 
			
		||||
    num_containers, initial_algos = request.param
 | 
			
		||||
 | 
			
		||||
    cmd = f'''
 | 
			
		||||
    pip install -e . && \
 | 
			
		||||
    skynet run dgpu --algos=\'{json.dumps(initial_algos)}\'
 | 
			
		||||
    '''
 | 
			
		||||
    cmds = []
 | 
			
		||||
    for i in range(num_containers):
 | 
			
		||||
        cmd = f'''
 | 
			
		||||
        pip install -e . && \
 | 
			
		||||
        skynet run dgpu \
 | 
			
		||||
        --algos=\'{json.dumps(initial_algos)}\' \
 | 
			
		||||
        --uid=dgpu-{i}
 | 
			
		||||
        '''
 | 
			
		||||
        cmds.append(['bash', '-c', cmd])
 | 
			
		||||
 | 
			
		||||
    logging.info(f'launching: \n{cmd}')
 | 
			
		||||
 | 
			
		||||
    with dockerctl.run(
 | 
			
		||||
        DOCKER_RUNTIME_CUDA,
 | 
			
		||||
        name='skynet-test-runtime-cuda',
 | 
			
		||||
        command=['bash', '-c', cmd],
 | 
			
		||||
        commands=cmds,
 | 
			
		||||
        environment={
 | 
			
		||||
            'HF_TOKEN': os.environ['HF_TOKEN'],
 | 
			
		||||
            'HF_HOME': '/skynet/hf_home'
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -26,8 +26,7 @@ async def wait_for_dgpus(rpc, amount: int, timeout: float = 30.0):
 | 
			
		|||
    start_time = time.time()
 | 
			
		||||
    current_time = time.time()
 | 
			
		||||
    while not gpu_ready and (current_time - start_time) < timeout:
 | 
			
		||||
        res = await rpc('dgpu-test', 'dgpu_workers')
 | 
			
		||||
        logging.info(res)
 | 
			
		||||
        res = await rpc('dgpu_workers')
 | 
			
		||||
        if res.result['ok'] >= amount:
 | 
			
		||||
            break
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -40,6 +39,7 @@ async def wait_for_dgpus(rpc, amount: int, timeout: float = 30.0):
 | 
			
		|||
_images = set()
 | 
			
		||||
async def check_request_img(
 | 
			
		||||
    i: int,
 | 
			
		||||
    uid: int = 0,
 | 
			
		||||
    width: int = 512,
 | 
			
		||||
    height: int = 512,
 | 
			
		||||
    expect_unique=True
 | 
			
		||||
| 
						 | 
				
			
			@ -47,12 +47,13 @@ async def check_request_img(
 | 
			
		|||
    global _images
 | 
			
		||||
 | 
			
		||||
    async with open_skynet_rpc(
 | 
			
		||||
        uid,
 | 
			
		||||
        security=True,
 | 
			
		||||
        cert_name='whitelist/testing',
 | 
			
		||||
        key_name='testing'
 | 
			
		||||
    ) as rpc_call:
 | 
			
		||||
        res = await rpc_call(
 | 
			
		||||
            'tg+580213293', 'txt2img', {
 | 
			
		||||
            'txt2img', {
 | 
			
		||||
                'prompt': 'red old tractor in a sunny wheat field',
 | 
			
		||||
                'step': 28,
 | 
			
		||||
                'width': width, 'height': height,
 | 
			
		||||
| 
						 | 
				
			
			@ -88,6 +89,7 @@ async def test_dgpu_worker_compute_error(dgpu_workers):
 | 
			
		|||
    '''
 | 
			
		||||
 | 
			
		||||
    async with open_skynet_rpc(
 | 
			
		||||
        'test-ctx',
 | 
			
		||||
        security=True,
 | 
			
		||||
        cert_name='whitelist/testing',
 | 
			
		||||
        key_name='testing'
 | 
			
		||||
| 
						 | 
				
			
			@ -110,6 +112,7 @@ async def test_dgpu_workers(dgpu_workers):
 | 
			
		|||
    '''
 | 
			
		||||
 | 
			
		||||
    async with open_skynet_rpc(
 | 
			
		||||
        'test-ctx',
 | 
			
		||||
        security=True,
 | 
			
		||||
        cert_name='whitelist/testing',
 | 
			
		||||
        key_name='testing'
 | 
			
		||||
| 
						 | 
				
			
			@ -126,6 +129,7 @@ async def test_dgpu_workers_two(dgpu_workers):
 | 
			
		|||
    '''Generate two images in two separate dgpu workers
 | 
			
		||||
    '''
 | 
			
		||||
    async with open_skynet_rpc(
 | 
			
		||||
        'test-ctx',
 | 
			
		||||
        security=True,
 | 
			
		||||
        cert_name='whitelist/testing',
 | 
			
		||||
        key_name='testing'
 | 
			
		||||
| 
						 | 
				
			
			@ -143,6 +147,7 @@ async def test_dgpu_worker_algo_swap(dgpu_workers):
 | 
			
		|||
    '''Generate an image using a non default model
 | 
			
		||||
    '''
 | 
			
		||||
    async with open_skynet_rpc(
 | 
			
		||||
        'test-ctx',
 | 
			
		||||
        security=True,
 | 
			
		||||
        cert_name='whitelist/testing',
 | 
			
		||||
        key_name='testing'
 | 
			
		||||
| 
						 | 
				
			
			@ -158,35 +163,32 @@ async def test_dgpu_rotation_next_worker(dgpu_workers):
 | 
			
		|||
    rotation happens correctly
 | 
			
		||||
    '''
 | 
			
		||||
    async with open_skynet_rpc(
 | 
			
		||||
        'test-ctx',
 | 
			
		||||
        security=True,
 | 
			
		||||
        cert_name='whitelist/testing',
 | 
			
		||||
        key_name='testing'
 | 
			
		||||
    ) as test_rpc:
 | 
			
		||||
        await wait_for_dgpus(test_rpc, 3)
 | 
			
		||||
 | 
			
		||||
        res = await test_rpc('testing-rpc', 'dgpu_next')
 | 
			
		||||
        logging.info(res)
 | 
			
		||||
        res = await test_rpc('dgpu_next')
 | 
			
		||||
        assert 'ok' in res.result
 | 
			
		||||
        assert res.result['ok'] == 0
 | 
			
		||||
 | 
			
		||||
        await check_request_img(0)
 | 
			
		||||
 | 
			
		||||
        res = await test_rpc('testing-rpc', 'dgpu_next')
 | 
			
		||||
        logging.info(res)
 | 
			
		||||
        res = await test_rpc('dgpu_next')
 | 
			
		||||
        assert 'ok' in res.result
 | 
			
		||||
        assert res.result['ok'] == 1
 | 
			
		||||
 | 
			
		||||
        await check_request_img(0)
 | 
			
		||||
 | 
			
		||||
        res = await test_rpc('testing-rpc', 'dgpu_next')
 | 
			
		||||
        logging.info(res)
 | 
			
		||||
        res = await test_rpc('dgpu_next')
 | 
			
		||||
        assert 'ok' in res.result
 | 
			
		||||
        assert res.result['ok'] == 2
 | 
			
		||||
 | 
			
		||||
        await check_request_img(0)
 | 
			
		||||
 | 
			
		||||
        res = await test_rpc('testing-rpc', 'dgpu_next')
 | 
			
		||||
        logging.info(res)
 | 
			
		||||
        res = await test_rpc('dgpu_next')
 | 
			
		||||
        assert 'ok' in res.result
 | 
			
		||||
        assert res.result['ok'] == 0
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -198,6 +200,7 @@ async def test_dgpu_rotation_next_worker_disconnect(dgpu_workers):
 | 
			
		|||
    next_worker rotation happens correctly
 | 
			
		||||
    '''
 | 
			
		||||
    async with open_skynet_rpc(
 | 
			
		||||
        'test-ctx',
 | 
			
		||||
        security=True,
 | 
			
		||||
        cert_name='whitelist/testing',
 | 
			
		||||
        key_name='testing'
 | 
			
		||||
| 
						 | 
				
			
			@ -213,8 +216,7 @@ async def test_dgpu_rotation_next_worker_disconnect(dgpu_workers):
 | 
			
		|||
 | 
			
		||||
        dgpu_workers[0].wait()
 | 
			
		||||
 | 
			
		||||
        res = await test_rpc('testing-rpc', 'dgpu_workers')
 | 
			
		||||
        logging.info(res)
 | 
			
		||||
        res = await test_rpc('dgpu_workers')
 | 
			
		||||
        assert 'ok' in res.result
 | 
			
		||||
        assert res.result['ok'] == 2
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -224,14 +226,17 @@ async def test_dgpu_rotation_next_worker_disconnect(dgpu_workers):
 | 
			
		|||
 | 
			
		||||
 | 
			
		||||
async def test_dgpu_no_ack_node_disconnect(skynet_running):
 | 
			
		||||
    '''Mock a node that connects, gets a request but fails to
 | 
			
		||||
    acknowledge it, then check skynet correctly drops the node
 | 
			
		||||
    '''
 | 
			
		||||
    async with open_skynet_rpc(
 | 
			
		||||
        'test-ctx',
 | 
			
		||||
        security=True,
 | 
			
		||||
        cert_name='whitelist/testing',
 | 
			
		||||
        key_name='testing'
 | 
			
		||||
    ) as rpc_call:
 | 
			
		||||
 | 
			
		||||
        res = await rpc_call('dgpu-0', 'dgpu_online')
 | 
			
		||||
        logging.info(res)
 | 
			
		||||
        res = await rpc_call('dgpu_online')
 | 
			
		||||
        assert 'ok' in res.result
 | 
			
		||||
 | 
			
		||||
        await wait_for_dgpus(rpc_call, 1)
 | 
			
		||||
| 
						 | 
				
			
			@ -241,8 +246,34 @@ async def test_dgpu_no_ack_node_disconnect(skynet_running):
 | 
			
		|||
 | 
			
		||||
        assert 'dgpu failed to acknowledge request' in str(e)
 | 
			
		||||
 | 
			
		||||
        res = await rpc_call('testing-rpc', 'dgpu_workers')
 | 
			
		||||
        logging.info(res)
 | 
			
		||||
        res = await rpc_call('dgpu_workers')
 | 
			
		||||
        assert 'ok' in res.result
 | 
			
		||||
        assert res.result['ok'] == 0
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@pytest.mark.parametrize(
 | 
			
		||||
    'dgpu_workers', [(1, ['midj'])], indirect=True)
 | 
			
		||||
async def test_dgpu_timeout_while_processing(dgpu_workers):
 | 
			
		||||
    '''Stop node while processing request to cause timeout and
 | 
			
		||||
    then check skynet correctly drops the node.
 | 
			
		||||
    '''
 | 
			
		||||
    async with open_skynet_rpc(
 | 
			
		||||
        'test-ctx',
 | 
			
		||||
        security=True,
 | 
			
		||||
        cert_name='whitelist/testing',
 | 
			
		||||
        key_name='testing'
 | 
			
		||||
    ) as test_rpc:
 | 
			
		||||
        await wait_for_dgpus(test_rpc, 1)
 | 
			
		||||
 | 
			
		||||
        async def check_request_img_raises():
 | 
			
		||||
            with pytest.raises(SkynetDGPUComputeError) as e:
 | 
			
		||||
                await check_request_img(0)
 | 
			
		||||
 | 
			
		||||
            assert 'timeout while processing request' in str(e)
 | 
			
		||||
 | 
			
		||||
        async with trio.open_nursery() as n:
 | 
			
		||||
            n.start_soon(check_request_img_raises)
 | 
			
		||||
            await trio.sleep(1)
 | 
			
		||||
            ec, out = dgpu_workers[0].exec_run(
 | 
			
		||||
                ['pkill', '-TERM', '-f', 'skynet'])
 | 
			
		||||
            assert ec == 0
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -12,64 +12,59 @@ from skynet.structs import *
 | 
			
		|||
from skynet.frontend import open_skynet_rpc
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
async def test_skynet(skynet_running):
 | 
			
		||||
    ...
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
async def test_skynet_attempt_insecure(skynet_running):
 | 
			
		||||
    with pytest.raises(pynng.exceptions.NNGException) as e:
 | 
			
		||||
        async with open_skynet_rpc():
 | 
			
		||||
            ...
 | 
			
		||||
        async with open_skynet_rpc('bad-actor'):
 | 
			
		||||
           ...
 | 
			
		||||
 | 
			
		||||
    assert str(e.value) == 'Connection shutdown'
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
async def test_skynet_dgpu_connection_simple(skynet_running):
 | 
			
		||||
    async with open_skynet_rpc(
 | 
			
		||||
        'dgpu-0',
 | 
			
		||||
        security=True,
 | 
			
		||||
        cert_name='whitelist/testing',
 | 
			
		||||
        key_name='testing'
 | 
			
		||||
    ) as rpc_call:
 | 
			
		||||
        # check 0 nodes are connected
 | 
			
		||||
        res = await rpc_call('dgpu-0', 'dgpu_workers')
 | 
			
		||||
        logging.info(res)
 | 
			
		||||
        res = await rpc_call('dgpu_workers')
 | 
			
		||||
        assert 'ok' in res.result
 | 
			
		||||
        assert res.result['ok'] == 0
 | 
			
		||||
 | 
			
		||||
        # check next worker is None
 | 
			
		||||
        res = await rpc_call('dgpu-0', 'dgpu_next')
 | 
			
		||||
        logging.info(res)
 | 
			
		||||
        res = await rpc_call('dgpu_next')
 | 
			
		||||
        assert 'ok' in res.result
 | 
			
		||||
        assert res.result['ok'] == None
 | 
			
		||||
 | 
			
		||||
        # connect 1 dgpu
 | 
			
		||||
        res = await rpc_call(
 | 
			
		||||
            'dgpu-0', 'dgpu_online')
 | 
			
		||||
        logging.info(res)
 | 
			
		||||
        res = await rpc_call('dgpu_online')
 | 
			
		||||
        assert 'ok' in res.result
 | 
			
		||||
 | 
			
		||||
        # check 1 node is connected
 | 
			
		||||
        res = await rpc_call('dgpu-0', 'dgpu_workers')
 | 
			
		||||
        logging.info(res)
 | 
			
		||||
        res = await rpc_call('dgpu_workers')
 | 
			
		||||
        assert 'ok' in res.result
 | 
			
		||||
        assert res.result['ok'] == 1
 | 
			
		||||
 | 
			
		||||
        # check next worker is 0
 | 
			
		||||
        res = await rpc_call('dgpu-0', 'dgpu_next')
 | 
			
		||||
        logging.info(res)
 | 
			
		||||
        res = await rpc_call('dgpu_next')
 | 
			
		||||
        assert 'ok' in res.result
 | 
			
		||||
        assert res.result['ok'] == 0
 | 
			
		||||
 | 
			
		||||
        # disconnect 1 dgpu
 | 
			
		||||
        res = await rpc_call(
 | 
			
		||||
            'dgpu-0', 'dgpu_offline')
 | 
			
		||||
        logging.info(res)
 | 
			
		||||
        res = await rpc_call('dgpu_offline')
 | 
			
		||||
        assert 'ok' in res.result
 | 
			
		||||
 | 
			
		||||
        # check 0 nodes are connected
 | 
			
		||||
        res = await rpc_call('dgpu-0', 'dgpu_workers')
 | 
			
		||||
        logging.info(res)
 | 
			
		||||
        res = await rpc_call('dgpu_workers')
 | 
			
		||||
        assert 'ok' in res.result
 | 
			
		||||
        assert res.result['ok'] == 0
 | 
			
		||||
 | 
			
		||||
        # check next worker is None
 | 
			
		||||
        res = await rpc_call('dgpu-0', 'dgpu_next')
 | 
			
		||||
        logging.info(res)
 | 
			
		||||
        res = await rpc_call('dgpu_next')
 | 
			
		||||
        assert 'ok' in res.result
 | 
			
		||||
        assert res.result['ok'] == None
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue