From f6326ad05ccf14162f5e1e63aa928a56a522d1bc Mon Sep 17 00:00:00 2001 From: Guillermo Rodriguez Date: Sat, 17 Dec 2022 11:39:42 -0300 Subject: [PATCH] Rework dgpu client to be single task Add a lot of dgpu real image gen tests Modified docker files and environment to allow for quick test relaunch without image rebuild Rename package from skynet_bot to skynet Drop tractor usage cause cuda is oriented to just a single proc managing gpu resources Add ackwnoledge phase to image request for quick gpu disconnected type scenarios Add click entry point for dgpu Add posibility to reuse postgres_db fixture on same session by checking if schema init has been already done --- .dockerignore | 8 +- .gitignore | 3 + Dockerfile.runtime | 10 +- ...le.runtime-cuda => Dockerfile.runtime+cuda | 16 +- build_docker.sh | 2 +- pytest.ini | 2 + requirements.cuda.0.txt | 1 - requirements.test.txt | 3 +- requirements.txt | 2 - scripts/generate_cert.py | 2 +- setup.py | 9 +- {skynet_bot => skynet}/__init__.py | 0 {skynet_bot => skynet}/brain.py | 204 +++++++++---- skynet/cli.py | 68 +++++ {skynet_bot => skynet}/constants.py | 6 +- {skynet_bot => skynet}/db.py | 26 +- skynet/dgpu.py | 197 +++++++++++++ {skynet_bot => skynet}/frontend/__init__.py | 2 +- {skynet_bot => skynet}/frontend/telegram.py | 0 skynet_bot/types.py => skynet/structs.py | 0 skynet/utils.py | 57 ++++ skynet_bot/dgpu.py | 124 -------- skynet_bot/gpu.py | 77 ----- skynet_bot/utils.py | 2 - test.sh | 9 - tests/conftest.py | 57 +++- tests/test_dgpu.py | 275 +++++++++++++++--- tests/test_gpu_workers.py | 107 ------- tests/test_skynet.py | 8 +- 29 files changed, 811 insertions(+), 466 deletions(-) rename Dockerfile.runtime-cuda => Dockerfile.runtime+cuda (61%) rename {skynet_bot => skynet}/__init__.py (100%) rename {skynet_bot => skynet}/brain.py (60%) create mode 100644 skynet/cli.py rename {skynet_bot => skynet}/constants.py (97%) rename {skynet_bot => skynet}/db.py (84%) create mode 100644 skynet/dgpu.py rename {skynet_bot => 
skynet}/frontend/__init__.py (98%) rename {skynet_bot => skynet}/frontend/telegram.py (100%) rename skynet_bot/types.py => skynet/structs.py (100%) create mode 100644 skynet/utils.py delete mode 100644 skynet_bot/dgpu.py delete mode 100644 skynet_bot/gpu.py delete mode 100644 skynet_bot/utils.py delete mode 100755 test.sh delete mode 100644 tests/test_gpu_workers.py diff --git a/.dockerignore b/.dockerignore index d611665..a9214c5 100644 --- a/.dockerignore +++ b/.dockerignore @@ -1,3 +1,9 @@ +.git hf_home -inputs outputs +.python-version +.pytest-cache +**/__pycache__ +*.egg-info +**/*.key +**/*.cert diff --git a/.gitignore b/.gitignore index c60d49a..e264fa9 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,8 @@ .python-version hf_home outputs +secrets **/__pycache__ *.egg-info +**/*.key +**/*.cert diff --git a/Dockerfile.runtime b/Dockerfile.runtime index 84b38b6..7f09a6e 100644 --- a/Dockerfile.runtime +++ b/Dockerfile.runtime @@ -4,10 +4,16 @@ env DEBIAN_FRONTEND=noninteractive workdir /skynet -copy requirements.* ./ +copy requirements.test.txt requirements.test.txt +copy requirements.txt requirements.txt +copy pytest.ini ./ +copy setup.py ./ +copy skynet ./skynet run pip install \ + -e . 
\ -r requirements.txt \ -r requirements.test.txt -workdir /scripts +copy scripts ./ +copy tests ./ diff --git a/Dockerfile.runtime-cuda b/Dockerfile.runtime+cuda similarity index 61% rename from Dockerfile.runtime-cuda rename to Dockerfile.runtime+cuda index af7c77f..48520cd 100644 --- a/Dockerfile.runtime-cuda +++ b/Dockerfile.runtime+cuda @@ -5,19 +5,25 @@ env DEBIAN_FRONTEND=noninteractive workdir /skynet -copy requirements.* ./ +copy requirements.cuda* ./ run pip install -U pip ninja run pip install -r requirements.cuda.0.txt run pip install -v -r requirements.cuda.1.txt -run pip install \ +copy requirements.test.txt requirements.test.txt +copy requirements.txt requirements.txt +copy pytest.ini pytest.ini +copy setup.py setup.py +copy skynet skynet + +run pip install -e . \ -r requirements.txt \ -r requirements.test.txt +env PYTORCH_CUDA_ALLOC_CONF max_split_size_mb:128 env NVIDIA_VISIBLE_DEVICES=all env HF_HOME /hf_home -env PYTORCH_CUDA_ALLOC_CONF max_split_size_mb:128 - -workdir /scripts +copy scripts scripts +copy tests tests diff --git a/build_docker.sh b/build_docker.sh index 72e37f1..5d67269 100755 --- a/build_docker.sh +++ b/build_docker.sh @@ -1,6 +1,6 @@ docker build \ -t skynet:runtime-cuda \ - -f Dockerfile.runtime-cuda . + -f Dockerfile.runtime+cuda . 
docker build \ -t skynet:runtime \ diff --git a/pytest.ini b/pytest.ini index 5f4a13a..7f91c13 100644 --- a/pytest.ini +++ b/pytest.ini @@ -1,2 +1,4 @@ [pytest] +log_cli = True +log_level = info trio_mode = true diff --git a/requirements.cuda.0.txt b/requirements.cuda.0.txt index 3c1ce2a..e31de88 100644 --- a/requirements.cuda.0.txt +++ b/requirements.cuda.0.txt @@ -1,4 +1,3 @@ -pdbpp scipy triton accelerate diff --git a/requirements.test.txt b/requirements.test.txt index 48af1de..9dd1ac9 100644 --- a/requirements.test.txt +++ b/requirements.test.txt @@ -1,5 +1,6 @@ +pdbpp pytest psycopg2 pytest-trio -git+https://github.com/tgoodlet/pytest-dockerctl.git@master#egg=pytest-dockerctl +git+https://github.com/guilledk/pytest-dockerctl.git@host_network#egg=pytest-dockerctl diff --git a/requirements.txt b/requirements.txt index 8220873..b1034c9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,5 +5,3 @@ aiohttp msgspec pyOpenSSL trio_asyncio - -git+https://github.com/goodboy/tractor.git@piker_pin#egg=tractor diff --git a/scripts/generate_cert.py b/scripts/generate_cert.py index 621e4b1..2b1634a 100644 --- a/scripts/generate_cert.py +++ b/scripts/generate_cert.py @@ -8,7 +8,7 @@ import sys from OpenSSL import crypto, SSL -from skynet_bot.constants import DEFAULT_CERTS_DIR +from skynet.constants import DEFAULT_CERTS_DIR def input_or_skip(txt, default): diff --git a/setup.py b/setup.py index f48893c..007883f 100644 --- a/setup.py +++ b/setup.py @@ -1,11 +1,16 @@ from setuptools import setup, find_packages setup( - name='skynet-bot', + name='skynet', version='0.1.0a6', description='Decentralized compute platform', author='Guillermo Rodriguez', author_email='guillermo@telos.net', packages=find_packages(), - install_requires=[] + entry_points={ + 'console_scripts': [ + 'skynet = skynet.cli:skynet', + ] + }, + install_requires=['click'] ) diff --git a/skynet_bot/__init__.py b/skynet/__init__.py similarity index 100% rename from skynet_bot/__init__.py rename to 
skynet/__init__.py diff --git a/skynet_bot/brain.py b/skynet/brain.py similarity index 60% rename from skynet_bot/brain.py rename to skynet/brain.py index 79b62d2..99732ce 100644 --- a/skynet_bot/brain.py +++ b/skynet/brain.py @@ -8,6 +8,7 @@ import logging from uuid import UUID from pathlib import Path from functools import partial +from contextlib import asynccontextmanager as acm from collections import OrderedDict import trio @@ -17,7 +18,7 @@ import trio_asyncio from pynng import TLSConfig from .db import * -from .types import * +from .structs import * from .constants import * @@ -27,18 +28,47 @@ class SkynetDGPUOffline(BaseException): class SkynetDGPUOverloaded(BaseException): ... +class SkynetDGPUComputeError(BaseException): + ... -async def rpc_service(sock, dgpu_bus, db_pool): +class SkynetShutdownRequested(BaseException): + ... + +@acm +async def open_rpc_service(sock, dgpu_bus, db_pool): nodes = OrderedDict() wip_reqs = {} fin_reqs = {} + next_worker: Optional[int] = None - def is_worker_busy(nid: int): - for task in nodes[nid]['tasks']: - if task != None: - return False + def connect_node(uid): + nonlocal next_worker + nodes[uid] = { + 'task': None + } + logging.info(f'dgpu online: {uid}') - return True + if not next_worker: + next_worker = 0 + + def disconnect_node(uid): + nonlocal next_worker + if uid not in nodes: + return + i = list(nodes.keys()).index(uid) + del nodes[uid] + + if i < next_worker: + next_worker -= 1 + + if len(nodes) == 0: + logging.info('nw: None') + next_worker = None + + logging.warning(f'dgpu offline: {uid}') + + def is_worker_busy(nid: str): + return nodes[nid]['task'] != None def are_all_workers_busy(): for nid in nodes.keys(): @@ -47,30 +77,55 @@ async def rpc_service(sock, dgpu_bus, db_pool): return True - next_worker: Optional[int] = None def get_next_worker(): nonlocal next_worker + logging.info('get next_worker called') + logging.info(f'pre next_worker: {next_worker}') - if not next_worker: + if next_worker == None: raise 
SkynetDGPUOffline if are_all_workers_busy(): raise SkynetDGPUOverloaded - while is_worker_busy(next_worker): + + nid = list(nodes.keys())[next_worker] + while is_worker_busy(nid): next_worker += 1 if next_worker >= len(nodes): next_worker = 0 - return next_worker + nid = list(nodes.keys())[next_worker] + + next_worker += 1 + if next_worker >= len(nodes): + next_worker = 0 + + logging.info(f'post next_worker: {next_worker}') + + return nid async def dgpu_image_streamer(): nonlocal wip_reqs, fin_reqs while True: msg = await dgpu_bus.arecv_msg() rid = UUID(bytes=msg.bytes[:16]).hex - img = msg.bytes[16:].hex() + raw_msg = msg.bytes[16:] + logging.info(f'streamer got back {rid} of size {len(raw_msg)}') + match raw_msg[:5]: + case b'error': + img = raw_msg.decode() + + case b'ack': + img = raw_msg + + case _: + img = base64.b64encode(raw_msg).hex() + + if rid not in wip_reqs: + continue + fin_reqs[rid] = img event = wip_reqs[rid] event.set() @@ -79,13 +134,14 @@ async def rpc_service(sock, dgpu_bus, db_pool): async def dgpu_stream_one_img(req: ImageGenRequest): nonlocal wip_reqs, fin_reqs, next_worker nid = get_next_worker() - logging.info(f'dgpu_stream_one_img {next_worker} {nid}') + idx = list(nodes.keys()).index(nid) + logging.info(f'dgpu_stream_one_img {idx}/{len(nodes)} {nid}') rid = uuid.uuid4().hex - event = trio.Event() - wip_reqs[rid] = event + ack_event = trio.Event() + img_event = trio.Event() + wip_reqs[rid] = ack_event - tid = nodes[nid]['tasks'].index(None) - nodes[nid]['tasks'][tid] = rid + nodes[nid]['task'] = rid dgpu_req = DGPUBusRequest( rid=rid, @@ -98,14 +154,37 @@ async def rpc_service(sock, dgpu_bus, db_pool): await dgpu_bus.asend( json.dumps(dgpu_req.to_dict()).encode()) - await event.wait() + with trio.move_on_after(4): + await ack_event.wait() - nodes[nid]['tasks'][tid] = None + logging.info(f'ack event: {ack_event.is_set()}') + + if not ack_event.is_set(): + disconnect_node(nid) + raise SkynetDGPUOffline('dgpu failed to acknowledge request') + 
+ ack = fin_reqs[rid] + if ack != b'ack': + disconnect_node(nid) + raise SkynetDGPUOffline('dgpu failed to acknowledge request') + + wip_reqs[rid] = img_event + with trio.move_on_after(30): + await img_event.wait() + + if not img_event.is_set(): + disconnect_node(nid) + raise SkynetDGPUComputeError('30 seconds timeout while processing request') + + nodes[nid]['task'] = None img = fin_reqs[rid] del fin_reqs[rid] - logging.info(f'done streaming {img}') + logging.info(f'done streaming {len(img)} bytes') + + if 'error' in img: + raise SkynetDGPUComputeError(img) return rid, img @@ -122,6 +201,10 @@ async def rpc_service(sock, dgpu_bus, db_pool): user_config = {**(await get_user_config(conn, user))} del user_config['id'] prompt = req.params['prompt'] + user_config= { + key : req.params.get(key, val) + for key, val in user_config.items() + } req = ImageGenRequest( prompt=prompt, **user_config @@ -165,9 +248,10 @@ async def rpc_service(sock, dgpu_bus, db_pool): case _: logging.warn('unknown method') - except SkynetDGPUOffline: + except SkynetDGPUOffline as e: result = { - 'error': 'skynet_dgpu_offline' + 'error': 'skynet_dgpu_offline', + 'message': str(e) } except SkynetDGPUOverloaded: @@ -176,22 +260,22 @@ async def rpc_service(sock, dgpu_bus, db_pool): 'nodes': len(nodes) } - except BaseException as e: - logging.error(e) + except SkynetDGPUComputeError as e: result = { - 'error': 'skynet_internal_error' + 'error': 'skynet_dgpu_compute_error', + 'message': str(e) } await rpc_ctx.asend( json.dumps( SkynetRPCResponse(result=result).to_dict()).encode()) - - async with trio.open_nursery() as n: - n.start_soon(dgpu_image_streamer) + async def request_service(n): + nonlocal next_worker while True: ctx = sock.new_context() msg = await ctx.arecv_msg() + content = msg.bytes.decode() req = SkynetRPCRequest(**json.loads(content)) @@ -199,27 +283,14 @@ async def rpc_service(sock, dgpu_bus, db_pool): result = {} - if req.method == 'dgpu_online': - nodes[req.uid] = { - 'tasks': [None 
for _ in range(req.params['max_tasks'])], - 'max_tasks': req.params['max_tasks'] - } - logging.info(f'dgpu online: {req.uid}') + if req.method == 'skynet_shutdown': + raise SkynetShutdownRequested - if not next_worker: - next_worker = 0 + elif req.method == 'dgpu_online': + connect_node(req.uid) elif req.method == 'dgpu_offline': - i = list(nodes.keys()).index(req.uid) - del nodes[req.uid] - - if i < next_worker: - next_worker -= 1 - - if len(nodes) == 0: - next_worker = None - - logging.info(f'dgpu offline: {req.uid}') + disconnect_node(req.uid) elif req.method == 'dgpu_workers': result = len(nodes) @@ -238,13 +309,22 @@ async def rpc_service(sock, dgpu_bus, db_pool): result={'ok': result}).to_dict()).encode()) + async with trio.open_nursery() as n: + n.start_soon(dgpu_image_streamer) + n.start_soon(request_service, n) + logging.info('starting rpc service') + yield + logging.info('stopping rpc service') + n.cancel_scope.cancel() + + +@acm async def run_skynet( db_user: str = DB_USER, db_pass: str = DB_PASS, db_host: str = DB_HOST, rpc_address: str = DEFAULT_RPC_ADDR, dgpu_address: str = DEFAULT_DGPU_ADDR, - task_status = trio.TASK_STATUS_IGNORED, security: bool = True ): logging.basicConfig(level=logging.INFO) @@ -260,8 +340,8 @@ async def run_skynet( (cert_path).read_text() for cert_path in (certs_dir / 'whitelist').glob('*.cert')] - logging.info(f'tls_key: {tls_key}') - logging.info(f'tls_cert: {tls_cert}') + cert_start = tls_cert.index('\n') + 1 + logging.info(f'tls_cert: {tls_cert[cert_start:cert_start+64]}...') logging.info(f'tls_whitelist len: {len(tls_whitelist)}') rpc_address = 'tls+' + rpc_address @@ -271,16 +351,14 @@ async def run_skynet( own_key_string=tls_key, own_cert_string=tls_cert) - async with ( - trio.open_nursery() as n, - open_database_connection( - db_user, db_pass, db_host) as db_pool + with ( + pynng.Rep0() as rpc_sock, + pynng.Bus0() as dgpu_bus ): - logging.info('connected to db.') - with ( - pynng.Rep0() as rpc_sock, - pynng.Bus0() as 
dgpu_bus - ): + async with open_database_connection( + db_user, db_pass, db_host) as db_pool: + + logging.info('connected to db.') if security: rpc_sock.tls_config = tls_config dgpu_bus.tls_config = tls_config @@ -288,13 +366,11 @@ async def run_skynet( rpc_sock.listen(rpc_address) dgpu_bus.listen(dgpu_address) - n.start_soon( - rpc_service, rpc_sock, dgpu_bus, db_pool) - task_status.started() - try: - await trio.sleep_forever() + async with open_rpc_service(rpc_sock, dgpu_bus, db_pool): + yield - except KeyboardInterrupt: + except SkynetShutdownRequested: ... + logging.info('disconnected from db.') diff --git a/skynet/cli.py b/skynet/cli.py new file mode 100644 index 0000000..d0e3fa8 --- /dev/null +++ b/skynet/cli.py @@ -0,0 +1,68 @@ +#!/usr/bin/python + +import os +import json + +from typing import Optional +from functools import partial + +import trio +import click + +from .dgpu import open_dgpu_node +from .utils import txt2img +from .constants import ALGOS + + +@click.group() +def skynet(*args, **kwargs): + pass + +@skynet.command() +@click.option('--model', '-m', default=ALGOS['midj']) +@click.option( + '--prompt', '-p', default='a red tractor in a wheat field') +@click.option('--output', '-o', default='output.png') +@click.option('--width', '-w', default=512) +@click.option('--height', '-h', default=512) +@click.option('--guidance', '-g', default=10.0) +@click.option('--steps', '-s', default=26) +@click.option('--seed', '-S', default=None) +def txt2img(*args +# model: str, +# prompt: str, +# output: str +# width: int, height: int, +# guidance: float, +# steps: int, +# seed: Optional[int] +): + assert 'HF_TOKEN' in os.environ + txt2img( + os.environ['HF_TOKEN'], *args) + +@skynet.group() +def run(*args, **kwargs): + pass + +@run.command() +@click.option('--loglevel', '-l', default='warning', help='Logging level') +@click.option( + '--key', '-k', default='dgpu') +@click.option( + '--cert', '-c', default='whitelist/dgpu') +@click.option( + '--algos', '-a', 
default=None) +def dgpu( + loglevel: str, + key: str, + cert: str, + algos: Optional[str] +): + trio.run( + partial( + open_dgpu_node, + cert, + key_name=key, + initial_algos=json.loads(algos) + )) diff --git a/skynet_bot/constants.py b/skynet/constants.py similarity index 97% rename from skynet_bot/constants.py rename to skynet/constants.py index 4fe4439..5e7d767 100644 --- a/skynet_bot/constants.py +++ b/skynet/constants.py @@ -1,6 +1,6 @@ #!/usr/bin/python -API_TOKEN = '5880619053:AAFge2UfObw1kCn9Kb7AAyqaIHs_HgM0Fx0' +DOCKER_RUNTIME_CUDA = 'skynet:runtime-cuda' DB_HOST = 'ancap.tech:34508' DB_USER = 'skynet' @@ -8,8 +8,8 @@ DB_PASS = 'password' DB_NAME = 'skynet' ALGOS = { - 'stable': 'runwayml/stable-diffusion-v1-5', 'midj': 'prompthero/openjourney', + 'stable': 'runwayml/stable-diffusion-v1-5', 'hdanime': 'Linaqruf/anything-v3.0', 'waifu': 'hakurei/waifu-diffusion', 'ghibli': 'nitrosocke/Ghibli-Diffusion', @@ -122,7 +122,7 @@ DEFAULT_CERT_DGPU = 'dgpu.key' DEFAULT_RPC_ADDR = 'tcp://127.0.0.1:41000' DEFAULT_DGPU_ADDR = 'tcp://127.0.0.1:41069' -DEFAULT_DGPU_MAX_TASKS = 3 +DEFAULT_DGPU_MAX_TASKS = 2 DEFAULT_INITAL_ALGOS = ['midj', 'stable', 'ink'] DATE_FORMAT = '%B the %dth %Y, %H:%M:%S' diff --git a/skynet_bot/db.py b/skynet/db.py similarity index 84% rename from skynet_bot/db.py rename to skynet/db.py index 9998e77..1b12e2c 100644 --- a/skynet_bot/db.py +++ b/skynet/db.py @@ -7,6 +7,9 @@ from contextlib import asynccontextmanager as acm import trio import triopg +import trio_asyncio + +from asyncpg.exceptions import UndefinedColumnError from .constants import * @@ -72,13 +75,22 @@ async def open_database_connection( db_host: str = DB_HOST, db_name: str = DB_NAME ): - async with triopg.create_pool( - dsn=f'postgres://{db_user}:{db_pass}@{db_host}/{db_name}' - ) as pool_conn: - async with pool_conn.acquire() as conn: - await conn.execute(DB_INIT_SQL) + async with trio_asyncio.open_loop() as loop: + async with triopg.create_pool( + 
dsn=f'postgres://{db_user}:{db_pass}@{db_host}/{db_name}' + ) as pool_conn: + async with pool_conn.acquire() as conn: + res = await conn.execute(f''' + select distinct table_schema + from information_schema.tables + where table_schema = \'{db_name}\' + ''') + if '1' in res: + logging.info('schema already in db, skipping init') + else: + await conn.execute(DB_INIT_SQL) - yield pool_conn + yield pool_conn async def get_user(conn, uid: str): @@ -135,6 +147,7 @@ async def new_user(conn, uid: str): tg_id, generated, joined, last_prompt, role) VALUES($1, $2, $3, $4, $5) + ON CONFLICT DO NOTHING ''') await stmt.fetch( tg_id, 0, date, None, DEFAULT_ROLE @@ -147,6 +160,7 @@ async def new_user(conn, uid: str): id, algo, step, width, height, seed, guidance, upscaler) VALUES($1, $2, $3, $4, $5, $6, $7, $8) + ON CONFLICT DO NOTHING ''') user = await stmt.fetch( new_uid, diff --git a/skynet/dgpu.py b/skynet/dgpu.py new file mode 100644 index 0000000..4efe1b9 --- /dev/null +++ b/skynet/dgpu.py @@ -0,0 +1,197 @@ +#!/usr/bin/python + +import gc +import io +import trio +import json +import uuid +import random +import logging + +from typing import List, Optional +from pathlib import Path +from contextlib import AsyncExitStack + +import pynng +import torch + +from pynng import TLSConfig +from diffusers import ( + StableDiffusionPipeline, + EulerAncestralDiscreteScheduler +) + +from .structs import * +from .constants import * +from .frontend import open_skynet_rpc + + +def pipeline_for(algo: str, mem_fraction: float = 1.0): + assert torch.cuda.is_available() + torch.cuda.empty_cache() + torch.cuda.set_per_process_memory_fraction(mem_fraction) + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + + params = { + 'torch_dtype': torch.float16, + 'safety_checker': None + } + + if algo == 'stable': + params['revision'] = 'fp16' + + pipe = StableDiffusionPipeline.from_pretrained( + ALGOS[algo], **params) + + pipe.scheduler = 
EulerAncestralDiscreteScheduler.from_config( + pipe.scheduler.config) + + pipe.enable_vae_slicing() + + return pipe.to("cuda") + + +class DGPUComputeError(BaseException): + ... + + +async def open_dgpu_node( + cert_name: str, + key_name: Optional[str], + rpc_address: str = DEFAULT_RPC_ADDR, + dgpu_address: str = DEFAULT_DGPU_ADDR, + initial_algos: Optional[List[str]] = None, + security: bool = True +): + logging.basicConfig(level=logging.INFO) + logging.info(f'starting dgpu node!') + + name = uuid.uuid4() + + logging.info(f'loading models...') + + initial_algos = ( + initial_algos + if initial_algos else DEFAULT_INITAL_ALGOS + ) + models = {} + for algo in initial_algos: + models[algo] = { + 'pipe': pipeline_for(algo), + 'generated': 0 + } + logging.info(f'loaded {algo}.') + + logging.info('memory summary:\n') + logging.info(torch.cuda.memory_summary()) + + async def gpu_compute_one(ireq: ImageGenRequest): + if ireq.algo not in models: + least_used = list(models.keys())[0] + for model in models: + if models[least_used]['generated'] > models[model]['generated']: + least_used = model + + del models[least_used] + gc.collect() + + models[ireq.algo] = { + 'pipe': pipeline_for(ireq.algo), + 'generated': 0 + } + + seed = ireq.seed if ireq.seed else random.randint(0, 2 ** 64) + try: + image = models[ireq.algo]['pipe']( + ireq.prompt, + width=ireq.width, + height=ireq.height, + guidance_scale=ireq.guidance, + num_inference_steps=ireq.step, + generator=torch.Generator("cuda").manual_seed(seed) + ).images[0] + return image.tobytes() + + except BaseException as e: + logging.error(e) + raise DGPUComputeError(str(e)) + + finally: + torch.cuda.empty_cache() + + + async with open_skynet_rpc( + security=security, + cert_name=cert_name, + key_name=key_name + ) as rpc_call: + + tls_config = None + if security: + # load tls certs + if not key_name: + key_name = certs_name + certs_dir = Path(DEFAULT_CERTS_DIR).resolve() + + skynet_cert_path = certs_dir / 'brain.cert' + tls_cert_path = 
certs_dir / f'{cert_name}.cert' + tls_key_path = certs_dir / f'{key_name}.key' + + skynet_cert = skynet_cert_path.read_text() + tls_cert = tls_cert_path.read_text() + tls_key = tls_key_path.read_text() + + logging.info(f'skynet cert: {skynet_cert_path}') + logging.info(f'dgpu cert: {tls_cert_path}') + logging.info(f'dgpu key: {tls_key_path}') + + dgpu_address = 'tls+' + dgpu_address + tls_config = TLSConfig( + TLSConfig.MODE_CLIENT, + own_key_string=tls_key, + own_cert_string=tls_cert, + ca_string=skynet_cert) + + logging.info(f'connecting to {dgpu_address}') + with pynng.Bus0() as dgpu_sock: + dgpu_sock.tls_config = tls_config + dgpu_sock.dial(dgpu_address) + + res = await rpc_call(name.hex, 'dgpu_online') + logging.info(res) + assert 'ok' in res.result + + try: + while True: + msg = await dgpu_sock.arecv() + req = DGPUBusRequest( + **json.loads(msg.decode())) + + if req.nid != name.hex: + logging.info('witnessed request {req.rid}, for {req.nid}') + continue + + # send ack + await dgpu_sock.asend( + bytes.fromhex(req.rid) + b'ack') + + logging.info(f'sent ack, processing {req.rid}...') + + try: + img = await gpu_compute_one( + ImageGenRequest(**req.params)) + + except DGPUComputeError as e: + img = b'error' + str(e).encode() + + await dgpu_sock.asend( + bytes.fromhex(req.rid) + img) + + except KeyboardInterrupt: + logging.info('interrupt caught, stopping...') + + finally: + res = await rpc_call(name.hex, 'dgpu_offline') + logging.info(res) + assert 'ok' in res.result diff --git a/skynet_bot/frontend/__init__.py b/skynet/frontend/__init__.py similarity index 98% rename from skynet_bot/frontend/__init__.py rename to skynet/frontend/__init__.py index 62ac0af..0532bcd 100644 --- a/skynet_bot/frontend/__init__.py +++ b/skynet/frontend/__init__.py @@ -10,7 +10,7 @@ import pynng from pynng import TLSConfig -from ..types import SkynetRPCRequest, SkynetRPCResponse +from ..structs import SkynetRPCRequest, SkynetRPCResponse from ..constants import * diff --git 
a/skynet_bot/frontend/telegram.py b/skynet/frontend/telegram.py similarity index 100% rename from skynet_bot/frontend/telegram.py rename to skynet/frontend/telegram.py diff --git a/skynet_bot/types.py b/skynet/structs.py similarity index 100% rename from skynet_bot/types.py rename to skynet/structs.py diff --git a/skynet/utils.py b/skynet/utils.py new file mode 100644 index 0000000..06b8863 --- /dev/null +++ b/skynet/utils.py @@ -0,0 +1,57 @@ +#!/usr/bin/python + +import random + +from typing import Optional +from pathlib import Path + +import torch + +from diffusers import StableDiffusionPipeline +from huggingface_hub import login + + +def txt2img( + hf_token: str, + model_name: str, + prompt: str, + output: str, + width: int, height: int, + guidance: float, + steps: int, + seed: Optional[int] +): + assert torch.cuda.is_available() + torch.cuda.empty_cache() + torch.cuda.set_per_process_memory_fraction(0.333) + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + + login(token=hf_token) + + params = { + 'torch_dtype': torch.float16, + 'safety_checker': None + } + if model_name == 'runwayml/stable-diffusion-v1-5': + params['revision'] = 'fp16' + + pipe = StableDiffusionPipeline.from_pretrained( + model_name, **params) + + pipe.scheduler = EulerAncestralDiscreteScheduler.from_config( + pipe.scheduler.config) + + pipe = pipe.to("cuda") + + seed = ireq.seed if ireq.seed else random.randint(0, 2 ** 64) + prompt = prompt + image = pipe( + prompt, + width=width, + height=height, + guidance_scale=guidance, num_inference_steps=steps, + generator=torch.Generator("cuda").manual_seed(seed) + ).images[0] + + image.save(output) diff --git a/skynet_bot/dgpu.py b/skynet_bot/dgpu.py deleted file mode 100644 index 9f1fde3..0000000 --- a/skynet_bot/dgpu.py +++ /dev/null @@ -1,124 +0,0 @@ -#!/usr/bin/python - -import trio -import json -import uuid -import logging - -import pynng -import tractor - -from . 
import gpu -from .gpu import open_gpu_worker -from .types import * -from .constants import * -from .frontend import rpc_call - - -async def open_dgpu_node( - cert_name: str, - key_name: Optional[str], - rpc_address: str = DEFAULT_RPC_ADDR, - dgpu_address: str = DEFAULT_DGPU_ADDR, - dgpu_max_tasks: int = DEFAULT_DGPU_MAX_TASKS, - initial_algos: str = DEFAULT_INITAL_ALGOS, - security: bool = True -): - logging.basicConfig(level=logging.INFO) - - name = uuid.uuid4() - workers = initial_algos.copy() - tasks = [None for _ in range(dgpu_max_tasks)] - - portal_map: dict[int, tractor.Portal] - contexts: dict[int, tractor.Context] - - def get_next_worker(need_algo: str): - nonlocal workers, tasks - for task, algo in zip(workers, tasks): - if need_algo == algo and not task: - return workers.index(need_algo) - - return tasks.index(None) - - async def gpu_streamer( - ctx: tractor.Context, - nid: int - ): - nonlocal tasks - async with ctx.open_stream() as stream: - async for img in stream: - tasks[nid]['res'] = img - tasks[nid]['event'].set() - - async def gpu_compute_one(ireq: ImageGenRequest): - wid = get_next_worker(ireq.algo) - event = trio.Event() - - workers[wid] = ireq.algo - tasks[wid] = { - 'res': None, 'event': event} - - await contexts[i].send(ireq) - - await event.wait() - - img = tasks[wid]['res'] - tasks[wid] = None - return img - - - async with open_skynet_rpc( - security=security, - cert_name=cert_name, - key_name=key_name - ) as rpc_call: - with pynng.Bus0(dial=dgpu_address) as dgpu_sock: - async def _process_dgpu_req(req: DGPUBusRequest): - img = await gpu_compute_one( - ImageGenRequest(**req.params)) - await dgpu_sock.asend( - bytes.fromhex(req.rid) + img) - - res = await rpc_call( - name.hex, 'dgpu_online', {'max_tasks': dgpu_max_tasks}) - logging.info(res) - assert 'ok' in res.result - - async with ( - tractor.open_actor_cluster( - modules=['skynet_bot.gpu'], - count=dgpu_max_tasks, - names=[i for i in range(dgpu_max_tasks)] - ) as portal_map, - 
trio.open_nursery() as n - ): - logging.info(f'starting {dgpu_max_tasks} gpu workers') - async with tractor.gather_contexts(( - portal.open_context( - open_gpu_worker, algo, 1.0 / dgpu_max_tasks) - for portal in portal_map.values() - )) as contexts: - contexts = {i: ctx for i, ctx in enumerate(contexts)} - for i, ctx in contexts.items(): - n.start_soon( - gpu_streamer, ctx, i) - try: - while True: - msg = await dgpu_sock.arecv() - req = DGPUBusRequest( - **json.loads(msg.decode())) - - if req.nid != name.hex: - continue - - logging.info(f'dgpu: {name}, req: {req}') - n.start_soon( - _process_dgpu_req, req) - - except KeyboardInterrupt: - ... - - res = await rpc_call(name.hex, 'dgpu_offline') - logging.info(res) - assert 'ok' in res.result diff --git a/skynet_bot/gpu.py b/skynet_bot/gpu.py deleted file mode 100644 index 2756e0b..0000000 --- a/skynet_bot/gpu.py +++ /dev/null @@ -1,77 +0,0 @@ -#!/usr/bin/python - -import io -import random -import logging - -import torch -import tractor - -from diffusers import ( - StableDiffusionPipeline, - EulerAncestralDiscreteScheduler -) - -from .types import ImageGenRequest -from .constants import ALGOS - - -def pipeline_for(algo: str, mem_fraction: float): - assert torch.cuda.is_available() - torch.cuda.empty_cache() - torch.cuda.set_per_process_memory_fraction(mem_fraction) - torch.backends.cuda.matmul.allow_tf32 = True - torch.backends.cudnn.allow_tf32 = True - - params = { - 'torch_dtype': torch.float16, - 'safety_checker': None - } - - if algo == 'stable': - params['revision'] = 'fp16' - - pipe = StableDiffusionPipeline.from_pretrained( - ALGOS[algo], **params) - - pipe.scheduler = EulerAncestralDiscreteScheduler.from_config( - pipe.scheduler.config) - - return pipe.to("cuda") - -@tractor.context -async def open_gpu_worker( - ctx: tractor.Context, - start_algo: str, - mem_fraction: float -): - log = tractor.log.get_logger(name='gpu', _root_name='skynet') - log.info(f'starting gpu worker with algo {start_algo}...') - 
current_algo = start_algo - with torch.no_grad(): - pipe = pipeline_for(current_algo, mem_fraction) - log.info('pipeline loaded') - await ctx.started() - async with ctx.open_stream() as bus: - async for ireq in bus: - if ireq.algo != current_algo: - current_algo = ireq.algo - pipe = pipeline_for(current_algo, mem_fraction) - - seed = ireq.seed if ireq.seed else random.randint(0, 2 ** 64) - image = pipe( - ireq.prompt, - width=ireq.width, - height=ireq.height, - guidance_scale=ireq.guidance, - num_inference_steps=ireq.step, - generator=torch.Generator("cuda").manual_seed(seed) - ).images[0] - - torch.cuda.empty_cache() - - # convert PIL.Image to BytesIO - img_bytes = io.BytesIO() - image.save(img_bytes, format='PNG') - await bus.send(img_bytes.getvalue()) - diff --git a/skynet_bot/utils.py b/skynet_bot/utils.py deleted file mode 100644 index 8a60885..0000000 --- a/skynet_bot/utils.py +++ /dev/null @@ -1,2 +0,0 @@ -from OpenSSL.crypto import load_publickey, FILETYPE_PEM, verify, X509 - diff --git a/test.sh b/test.sh deleted file mode 100755 index 7cc11fd..0000000 --- a/test.sh +++ /dev/null @@ -1,9 +0,0 @@ -docker run \ - -it \ - --rm \ - --gpus=all \ - --mount type=bind,source="$(pwd)",target=/skynet \ - skynet:runtime-cuda \ - bash -c \ - "cd /skynet && pip install -e . 
&& \ - pytest $1 --log-cli-level=info" diff --git a/tests/conftest.py b/tests/conftest.py index 4157062..ff58407 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,21 +1,25 @@ #!/usr/bin/python +import os +import json import time import random import string import logging from functools import partial +from pathlib import Path import trio import pytest import psycopg2 import trio_asyncio +from docker.types import Mount, DeviceRequest from psycopg2.extensions import ISOLATION_LEVEL_AUTOCOMMIT -from skynet_bot.constants import * -from skynet_bot.brain import run_skynet +from skynet.constants import * +from skynet.brain import run_skynet @pytest.fixture(scope='session') @@ -29,6 +33,7 @@ def postgres_db(dockerctl): with dockerctl.run( 'postgres', + name='skynet-test-postgres', ports={'5432/tcp': None}, environment={ 'POSTGRES_PASSWORD': rpassword @@ -67,6 +72,8 @@ def postgres_db(dockerctl): cursor.execute( f'GRANT ALL PRIVILEGES ON DATABASE {DB_NAME} TO {DB_USER}') + conn.close() + logging.info('done.') yield container, password, host @@ -74,16 +81,44 @@ def postgres_db(dockerctl): @pytest.fixture async def skynet_running(postgres_db): db_container, db_pass, db_host = postgres_db - async with ( - trio_asyncio.open_loop(), - trio.open_nursery() as n + + async with run_skynet( + db_pass=db_pass, + db_host=db_host ): - await n.start( - partial(run_skynet, - db_pass=db_pass, - db_host=db_host)) - yield - n.cancel_scope.cancel() +@pytest.fixture +def dgpu_workers(request, dockerctl, skynet_running): + devices = [DeviceRequest(capabilities=[['gpu']])] + mounts = [Mount( + '/skynet', str(Path().resolve()), type='bind')] + + num_containers, initial_algos = request.param + + cmd = f''' + pip install -e . 
&& \ + skynet run dgpu --algos=\'{json.dumps(initial_algos)}\' + ''' + + logging.info(f'launching: \n{cmd}') + + with dockerctl.run( + DOCKER_RUNTIME_CUDA, + name='skynet-test-runtime-cuda', + command=['bash', '-c', cmd], + environment={ + 'HF_TOKEN': os.environ['HF_TOKEN'], + 'HF_HOME': '/skynet/hf_home' + }, + network='host', + mounts=mounts, + device_requests=devices, + num=num_containers + ) as containers: + yield containers + + #for i, container in enumerate(containers): + # logging.info(f'container {i} logs:') + # logging.info(container.logs().decode()) diff --git a/tests/test_dgpu.py b/tests/test_dgpu.py index cb9cc80..51f1423 100644 --- a/tests/test_dgpu.py +++ b/tests/test_dgpu.py @@ -1,57 +1,248 @@ #!/usr/bin/python +import io import time import json +import base64 import logging +from hashlib import sha256 +from functools import partial + import trio -import pynng +import pytest import tractor import trio_asyncio -from skynet_bot.gpu import open_gpu_worker -from skynet_bot.dgpu import open_dgpu_node -from skynet_bot.types import * -from skynet_bot.brain import run_skynet -from skynet_bot.constants import * -from skynet_bot.frontend import open_skynet_rpc, rpc_call +from PIL import Image + +from skynet.brain import SkynetDGPUComputeError +from skynet.constants import * +from skynet.frontend import open_skynet_rpc -def test_dgpu_simple(): - async def main(): +async def wait_for_dgpus(rpc, amount: int, timeout: float = 30.0): + gpu_ready = False + start_time = time.time() + current_time = time.time() + while not gpu_ready and (current_time - start_time) < timeout: + res = await rpc('dgpu-test', 'dgpu_workers') + logging.info(res) + if res.result['ok'] >= amount: + break + + await trio.sleep(1) + current_time = time.time() + + assert (current_time - start_time) < timeout + + +_images = set() +async def check_request_img( + i: int, + width: int = 512, + height: int = 512, + expect_unique=True +): + global _images + + async with open_skynet_rpc( + 
security=True, + cert_name='whitelist/testing', + key_name='testing' + ) as rpc_call: + res = await rpc_call( + 'tg+580213293', 'txt2img', { + 'prompt': 'red old tractor in a sunny wheat field', + 'step': 28, + 'width': width, 'height': height, + 'guidance': 7.5, + 'seed': None, + 'algo': list(ALGOS.keys())[i], + 'upscaler': None + }) + + if 'error' in res.result: + raise SkynetDGPUComputeError(json.dumps(res.result)) + + img_raw = base64.b64decode(bytes.fromhex(res.result['img'])) + img_sha = sha256(img_raw).hexdigest() + img = Image.frombytes( + 'RGB', (width, height), img_raw) + + if expect_unique and img_sha in _images: + raise ValueError(f'Duplicated image sha: {img_sha}') + + _images.add(img_sha) + + logging.info(f'img sha256: {img_sha} size: {len(img_raw)}') + + assert len(img_raw) > 100000 + + +@pytest.mark.parametrize( + 'dgpu_workers', [(1, ['midj'])], indirect=True) +async def test_dgpu_worker_compute_error(dgpu_workers): + '''Attempt to generate a huge image and check we get the right error, + then generate a smaller image to show gpu worker recovery + ''' + + async with open_skynet_rpc( + security=True, + cert_name='whitelist/testing', + key_name='testing' + ) as test_rpc: + await wait_for_dgpus(test_rpc, 1) + + with pytest.raises(SkynetDGPUComputeError) as e: + await check_request_img(0, width=4096, height=4096) + + logging.info(e) + + await check_request_img(0) + + +@pytest.mark.parametrize( + 'dgpu_workers', [(1, ['midj', 'stable'])], indirect=True) +async def test_dgpu_workers(dgpu_workers): + '''Generate two images in a single dgpu worker using + two different models. 
+ ''' + + async with open_skynet_rpc( + security=True, + cert_name='whitelist/testing', + key_name='testing' + ) as test_rpc: + await wait_for_dgpus(test_rpc, 1) + + await check_request_img(0) + await check_request_img(1) + + +@pytest.mark.parametrize( + 'dgpu_workers', [(2, ['midj'])], indirect=True) +async def test_dgpu_workers_two(dgpu_workers): + '''Generate two images in two separate dgpu workers + ''' + async with open_skynet_rpc( + security=True, + cert_name='whitelist/testing', + key_name='testing' + ) as test_rpc: + await wait_for_dgpus(test_rpc, 2) + async with trio.open_nursery() as n: - await n.start( - run_skynet, - 'skynet', '3GbZd6UojbD8V7zWpeFn', 'ancap.tech:34508') - - await trio.sleep(2) - - for i in range(3): - n.start_soon(open_dgpu_node) - - await trio.sleep(1) - start = time.time() - async def request_img(): - with pynng.Req0(dial=DEFAULT_RPC_ADDR) as rpc_sock: - res = await rpc_call( - rpc_sock, 'tg+1', 'txt2img', { - 'prompt': 'test', - 'step': 28, - 'width': 512, 'height': 512, - 'guidance': 7.5, - 'seed': None, - 'algo': 'stable', - 'upscaler': None - }) - - logging.info(res) - - async with trio.open_nursery() as inner_n: - for i in range(3): - inner_n.start_soon(request_img) - - logging.info(f'time elapsed: {time.time() - start}') - n.cancel_scope.cancel() + n.start_soon(check_request_img, 0) + n.start_soon(check_request_img, 0) - trio_asyncio.run(main) +@pytest.mark.parametrize( + 'dgpu_workers', [(1, ['midj'])], indirect=True) +async def test_dgpu_worker_algo_swap(dgpu_workers): + '''Generate an image using a non default model + ''' + async with open_skynet_rpc( + security=True, + cert_name='whitelist/testing', + key_name='testing' + ) as test_rpc: + await wait_for_dgpus(test_rpc, 1) + await check_request_img(5) + + +@pytest.mark.parametrize( + 'dgpu_workers', [(3, ['midj'])], indirect=True) +async def test_dgpu_rotation_next_worker(dgpu_workers): + '''Connect three dgpu workers, disconnect and check next_worker + rotation happens 
correctly + ''' + async with open_skynet_rpc( + security=True, + cert_name='whitelist/testing', + key_name='testing' + ) as test_rpc: + await wait_for_dgpus(test_rpc, 3) + + res = await test_rpc('testing-rpc', 'dgpu_next') + logging.info(res) + assert 'ok' in res.result + assert res.result['ok'] == 0 + + await check_request_img(0) + + res = await test_rpc('testing-rpc', 'dgpu_next') + logging.info(res) + assert 'ok' in res.result + assert res.result['ok'] == 1 + + await check_request_img(0) + + res = await test_rpc('testing-rpc', 'dgpu_next') + logging.info(res) + assert 'ok' in res.result + assert res.result['ok'] == 2 + + await check_request_img(0) + + res = await test_rpc('testing-rpc', 'dgpu_next') + logging.info(res) + assert 'ok' in res.result + assert res.result['ok'] == 0 + + +@pytest.mark.parametrize( + 'dgpu_workers', [(3, ['midj'])], indirect=True) +async def test_dgpu_rotation_next_worker_disconnect(dgpu_workers): + '''Connect three dgpu workers, disconnect the first one and check + next_worker rotation happens correctly + ''' + async with open_skynet_rpc( + security=True, + cert_name='whitelist/testing', + key_name='testing' + ) as test_rpc: + await wait_for_dgpus(test_rpc, 3) + + await trio.sleep(3) + + # stop worker whose turn is next + for _ in range(2): + ec, out = dgpu_workers[0].exec_run(['pkill', '-INT', '-f', 'skynet']) + assert ec == 0 + + dgpu_workers[0].wait() + + res = await test_rpc('testing-rpc', 'dgpu_workers') + logging.info(res) + assert 'ok' in res.result + assert res.result['ok'] == 2 + + async with trio.open_nursery() as n: + n.start_soon(check_request_img, 0) + n.start_soon(check_request_img, 0) + + +async def test_dgpu_no_ack_node_disconnect(skynet_running): + async with open_skynet_rpc( + security=True, + cert_name='whitelist/testing', + key_name='testing' + ) as rpc_call: + + res = await rpc_call('dgpu-0', 'dgpu_online') + logging.info(res) + assert 'ok' in res.result + + await wait_for_dgpus(rpc_call, 1) + + with 
pytest.raises(SkynetDGPUComputeError) as e: + await check_request_img(0) + + assert 'dgpu failed to acknowledge request' in str(e) + + res = await rpc_call('testing-rpc', 'dgpu_workers') + logging.info(res) + assert 'ok' in res.result + assert res.result['ok'] == 0 + diff --git a/tests/test_gpu_workers.py b/tests/test_gpu_workers.py deleted file mode 100644 index fca2920..0000000 --- a/tests/test_gpu_workers.py +++ /dev/null @@ -1,107 +0,0 @@ -import trio -import tractor - -from skynet_bot.types import * - -@tractor.context -async def open_fake_worker( - ctx: tractor.Context, - start_algo: str, - mem_fraction: float -): - log = tractor.log.get_logger(name='gpu', _root_name='skynet') - log.info(f'starting gpu worker with algo {start_algo}...') - current_algo = start_algo - log.info('pipeline loaded') - await ctx.started() - async with ctx.open_stream() as bus: - async for ireq in bus: - if ireq: - await bus.send('hello!') - else: - break - -def test_gpu_worker(): - log = tractor.log.get_logger(name='root', _root_name='skynet') - async def main(): - async with ( - tractor.open_nursery(debug_mode=True) as an, - trio.open_nursery() as n - ): - portal = await an.start_actor( - 'gpu_worker', - enable_modules=[__name__], - debug_mode=True - ) - - log.info('portal opened') - async with ( - portal.open_context( - open_fake_worker, - start_algo='midj', - mem_fraction=0.6 - ) as (ctx, _), - ctx.open_stream() as stream, - ): - log.info('opened worker sending req...') - ireq = ImageGenRequest( - prompt='a red tractor on a wheat field', - step=28, - width=512, height=512, - guidance=10, seed=None, - algo='midj', upscaler=None) - - await stream.send(ireq) - log.info('sent, await respnse') - async for msg in stream: - log.info(f'got {msg}') - break - - assert msg == 'hello!' 
- await stream.send(None) - log.info('done.') - - await portal.cancel_actor() - - trio.run(main) - - -def test_gpu_two_workers(): - async def main(): - outputs = [] - async with ( - tractor.open_actor_cluster( - modules=[__name__], - count=2, - names=[0, 1]) as portal_map, - tractor.trionics.gather_contexts(( - portal.open_context( - open_fake_worker, - start_algo='midj', - mem_fraction=0.333) - for portal in portal_map.values() - )) as contexts, - trio.open_nursery() as n - ): - ireq = ImageGenRequest( - prompt='a red tractor on a wheat field', - step=28, - width=512, height=512, - guidance=10, seed=None, - algo='midj', upscaler=None) - - async def get_img(i): - ctx = contexts[i] - async with ctx.open_stream() as stream: - await stream.send(ireq) - async for img in stream: - outputs[i] = img - await portal_map[i].cancel_actor() - - n.start_soon(get_img, 0) - n.start_soon(get_img, 1) - - - assert len(outputs) == 2 - - trio.run(main) diff --git a/tests/test_skynet.py b/tests/test_skynet.py index 5c9367a..c6f5a89 100644 --- a/tests/test_skynet.py +++ b/tests/test_skynet.py @@ -7,9 +7,9 @@ import pynng import pytest import trio_asyncio -from skynet_bot.types import * -from skynet_bot.brain import run_skynet -from skynet_bot.frontend import open_skynet_rpc +from skynet.brain import run_skynet +from skynet.structs import * +from skynet.frontend import open_skynet_rpc async def test_skynet_attempt_insecure(skynet_running): @@ -40,7 +40,7 @@ async def test_skynet_dgpu_connection_simple(skynet_running): # connect 1 dgpu res = await rpc_call( - 'dgpu-0', 'dgpu_online', {'max_tasks': 3}) + 'dgpu-0', 'dgpu_online') logging.info(res) assert 'ok' in res.result