2022-12-10 21:18:03 +00:00
|
|
|
#!/usr/bin/python
|
|
|
|
|
2022-12-17 14:39:42 +00:00
|
|
|
import io
|
2022-12-10 21:18:03 +00:00
|
|
|
import time
|
|
|
|
import json
|
2023-01-06 17:36:50 +00:00
|
|
|
import zlib
|
2022-12-10 21:18:03 +00:00
|
|
|
import logging
|
|
|
|
|
2022-12-24 13:39:40 +00:00
|
|
|
from typing import Optional
|
2022-12-17 14:39:42 +00:00
|
|
|
from hashlib import sha256
|
|
|
|
from functools import partial
|
|
|
|
|
2022-12-10 21:18:03 +00:00
|
|
|
import trio
|
2022-12-17 14:39:42 +00:00
|
|
|
import pytest
|
2022-12-10 21:18:03 +00:00
|
|
|
|
2022-12-17 14:39:42 +00:00
|
|
|
from PIL import Image
|
2023-01-06 17:36:50 +00:00
|
|
|
from google.protobuf.json_format import MessageToDict
|
2022-12-17 14:39:42 +00:00
|
|
|
|
|
|
|
from skynet.brain import SkynetDGPUComputeError
|
2023-01-22 15:12:33 +00:00
|
|
|
from skynet.network import get_random_port, SessionServer
|
|
|
|
from skynet.protobuf import SkynetRPCResponse
|
2022-12-17 14:39:42 +00:00
|
|
|
from skynet.frontend import open_skynet_rpc
|
2023-01-22 15:12:33 +00:00
|
|
|
from skynet.constants import *
|
2022-12-17 14:39:42 +00:00
|
|
|
|
|
|
|
|
2023-01-22 15:12:33 +00:00
|
|
|
async def wait_for_dgpus(session, amount: int, timeout: float = 30.0):
    '''Poll skynet until at least ``amount`` dgpu workers are reported.

    Calls the ``dgpu_workers`` rpc through ``session`` once per second and
    returns as soon as the ``ok`` field of the reply is >= ``amount``.

    The whole wait runs under ``trio.fail_after(timeout)``, so trio raises
    ``trio.TooSlowError`` if the count is not reached within ``timeout``
    seconds.
    '''
    with trio.fail_after(timeout):
        while True:
            reply = await session.rpc('dgpu_workers')
            if reply.result['ok'] >= amount:
                return

            # not enough workers yet, retry in a second
            await trio.sleep(1)
|
2022-12-17 14:39:42 +00:00
|
|
|
|
|
|
|
|
|
|
|
# sha256 hex digests of every image check_request_img has accepted so far;
# module level, so it is shared by all tests in this file and lets
# check_request_img detect the same image being returned twice
_images = set()
|
|
|
|
async def check_request_img(
    i: int,
    uid: str = '1',
    width: int = 512,
    height: int = 512,
    expect_unique: bool = True,
    upscaler: Optional[str] = None
):
    '''Request one image through skynet and sanity check the reply.

    Opens its own rpc session as ``uid`` and issues a ``dgpu_call`` /
    ``diffuse`` request with a fixed prompt.

    Parameters:
        i: index into ``list(ALGOS.keys())`` selecting the algo to request.
        uid: identity passed to ``open_skynet_rpc``.
        width, height: requested image dimensions.
        expect_unique: when true, fail if an image with the same sha256
            was already seen by a previous call in this module.
        upscaler: forwarded as the ``upscaler`` request param.

    Returns:
        The ``PIL.Image`` opened from the binary payload of the reply.

    Raises:
        SkynetDGPUComputeError: the reply result contains an ``error`` key.
        ValueError: ``expect_unique`` is set and the image sha256 repeats.
        AssertionError: the payload is not larger than 100000 bytes.
    '''
    with open_skynet_rpc(
        uid,
        cert_name='whitelist/testing.cert',
        key_name='testing.key'
    ) as session:
        res = await session.rpc(
            'dgpu_call', {
                'method': 'diffuse',
                'params': {
                    'prompt': 'red old tractor in a sunny wheat field',
                    'step': 28,
                    'width': width, 'height': height,
                    'guidance': 7.5,
                    'seed': None,
                    'algo': list(ALGOS.keys())[i],
                    'upscaler': upscaler
                }
            },
            timeout=60
        )

        if 'error' in res.result:
            raise SkynetDGPUComputeError(MessageToDict(res.result))

        img_raw = res.bin
        img_sha = sha256(img_raw).hexdigest()
        img = Image.open(io.BytesIO(img_raw))

        if expect_unique and img_sha in _images:
            # f-string so the offending digest actually shows in the message
            raise ValueError(f'Duplicated image sha: {img_sha}')

        # _images is only mutated, never rebound, so no `global` is needed
        _images.add(img_sha)

        logging.info(f'img sha256: {img_sha} size: {len(img_raw)}')

        # a payload this small is treated as a failed generation
        assert len(img_raw) > 100000

        return img
|
|
|
|
|
2022-12-17 14:39:42 +00:00
|
|
|
|
|
|
|
@pytest.mark.parametrize('dgpu_workers', [(1, ['midj'])], indirect=True)
async def test_dgpu_worker_compute_error(dgpu_workers):
    '''Request a 4096x4096 image and expect a SkynetDGPUComputeError,
    then request a default sized image to show the worker keeps serving
    requests after the failure.
    '''
    creds = {
        'cert_name': 'whitelist/testing.cert',
        'key_name': 'testing.key'
    }

    with open_skynet_rpc('test-ctx', **creds) as rpc:
        # wait until one dgpu worker is reported
        await wait_for_dgpus(rpc, 1)

        with pytest.raises(SkynetDGPUComputeError) as excinfo:
            await check_request_img(0, width=4096, height=4096)

        logging.info(excinfo)

        # recovery check: default size request must succeed
        await check_request_img(0)
|
|
|
|
|
|
|
|
|
2023-01-22 15:12:33 +00:00
|
|
|
@pytest.mark.parametrize('dgpu_workers', [(1, ['midj'])], indirect=True)
async def test_dgpu_worker(dgpu_workers):
    '''Single worker smoke test: wait for one dgpu worker and request
    one image from it.
    '''
    creds = {
        'cert_name': 'whitelist/testing.cert',
        'key_name': 'testing.key'
    }

    with open_skynet_rpc('test-ctx', **creds) as rpc:
        await wait_for_dgpus(rpc, 1)

        await check_request_img(0)
|
|
|
|
|
|
|
|
|
2022-12-17 14:39:42 +00:00
|
|
|
@pytest.mark.parametrize(
    'dgpu_workers', [(1, ['midj', 'stable'])], indirect=True)
async def test_dgpu_worker_two_models(dgpu_workers):
    '''Request two images from a single dgpu worker, one for each of
    the first two algos listed in ALGOS.
    '''
    creds = {
        'cert_name': 'whitelist/testing.cert',
        'key_name': 'testing.key'
    }

    with open_skynet_rpc('test-ctx', **creds) as rpc:
        await wait_for_dgpus(rpc, 1)

        # sequential requests, algo index 0 then 1
        for algo_index in (0, 1):
            await check_request_img(algo_index)
|
|
|
|
|
|
|
|
|
2022-12-24 13:39:40 +00:00
|
|
|
@pytest.mark.parametrize('dgpu_workers', [(1, ['midj'])], indirect=True)
async def test_dgpu_worker_upscale(dgpu_workers):
    '''Request one default sized (512x512) image with the 'x4' upscaler
    and check the returned image is 2048x2048.
    '''
    creds = {
        'cert_name': 'whitelist/testing.cert',
        'key_name': 'testing.key'
    }

    with open_skynet_rpc('test-ctx', **creds) as rpc:
        await wait_for_dgpus(rpc, 1)

        upscaled = await check_request_img(0, upscaler='x4')

        # 512 * 4 on each side
        assert upscaled.size == (2048, 2048)
|
|
|
|
|
|
|
|
|
2022-12-17 14:39:42 +00:00
|
|
|
@pytest.mark.parametrize('dgpu_workers', [(2, ['midj'])], indirect=True)
async def test_dgpu_workers_two(dgpu_workers):
    '''Start two dgpu workers and issue two image requests concurrently.
    '''
    creds = {
        'cert_name': 'whitelist/testing.cert',
        'key_name': 'testing.key'
    }

    with open_skynet_rpc('test-ctx', **creds) as rpc:
        # two workers get a longer startup allowance than the default
        await wait_for_dgpus(rpc, 2, timeout=60)

        async with trio.open_nursery() as nursery:
            for _ in range(2):
                nursery.start_soon(check_request_img, 0)
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.parametrize('dgpu_workers', [(1, ['midj'])], indirect=True)
async def test_dgpu_worker_algo_swap(dgpu_workers):
    '''Request an image with the algo at index 5 of ALGOS from a worker
    parametrized with only 'midj'.

    NOTE(review): presumably index 5 is not 'midj', forcing the worker
    to swap models - confirm against ALGOS.
    '''
    creds = {
        'cert_name': 'whitelist/testing.cert',
        'key_name': 'testing.key'
    }

    with open_skynet_rpc('test-ctx', **creds) as rpc:
        await wait_for_dgpus(rpc, 1)
        await check_request_img(5)
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.parametrize('dgpu_workers', [(3, ['midj'])], indirect=True)
async def test_dgpu_rotation_next_worker(dgpu_workers):
    '''Connect three dgpu workers and check the value reported by
    'dgpu_next' goes 0, 1, 2 and back to 0 as one image request is
    served between each query.
    '''
    creds = {
        'cert_name': 'whitelist/testing.cert',
        'key_name': 'testing.key'
    }

    with open_skynet_rpc('test-ctx', **creds) as rpc:
        await wait_for_dgpus(rpc, 3)

        # one request after each query should advance the rotation
        for expected_next in (0, 1, 2):
            reply = await rpc.rpc('dgpu_next')
            assert 'ok' in reply.result
            assert reply.result['ok'] == expected_next

            await check_request_img(0)

        # after three requests the rotation wraps around
        reply = await rpc.rpc('dgpu_next')
        assert 'ok' in reply.result
        assert reply.result['ok'] == 0
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.parametrize('dgpu_workers', [(3, ['midj'])], indirect=True)
async def test_dgpu_rotation_next_worker_disconnect(dgpu_workers):
    '''Connect three dgpu workers, stop the first one and check skynet
    reports two workers left and still serves two concurrent requests.
    '''
    creds = {
        'cert_name': 'whitelist/testing.cert',
        'key_name': 'testing.key'
    }

    with open_skynet_rpc('test-ctx', **creds) as rpc:
        await wait_for_dgpus(rpc, 3)

        await trio.sleep(3)

        # stop worker whose turn is next, SIGINT is sent twice
        for _ in range(2):
            exit_code, _output = dgpu_workers[0].exec_run(
                ['pkill', '-INT', '-f', 'skynet'])
            assert exit_code == 0

        dgpu_workers[0].wait()

        reply = await rpc.rpc('dgpu_workers')
        assert 'ok' in reply.result
        assert reply.result['ok'] == 2

        async with trio.open_nursery() as nursery:
            for _ in range(2):
                nursery.start_soon(check_request_img, 0)
|
2022-12-10 21:18:03 +00:00
|
|
|
|
|
|
|
|
2022-12-17 14:39:42 +00:00
|
|
|
async def test_dgpu_no_ack_node_disconnect(skynet_running):
    '''Register a mock dgpu node whose rpc handler answers every request
    with an error, check the error reaches the requester and that
    skynet reports zero workers afterwards.
    '''
    creds = {
        'cert_name': 'whitelist/testing.cert',
        'key_name': 'testing.key'
    }

    async def mock_rpc(req, ctx):
        # always refuse, whatever the request is
        reply = SkynetRPCResponse()
        reply.result.update({'error': 'can\'t do it mate'})
        return reply

    dgpu_addr = f'tcp://127.0.0.1:{get_random_port()}'
    mock_server = SessionServer(dgpu_addr, mock_rpc, **creds)

    async with mock_server.open():
        with open_skynet_rpc('test-ctx', **creds) as rpc:

            # announce the mock node to skynet
            online = await rpc.rpc('dgpu_online', {
                'dgpu_addr': dgpu_addr,
                'cert': 'whitelist/testing.cert'
            })
            assert 'ok' in online.result

            await wait_for_dgpus(rpc, 1)

            with pytest.raises(SkynetDGPUComputeError) as excinfo:
                await check_request_img(0)

            assert 'can\'t do it mate' in str(excinfo.value)

            # the failing node must be gone from the worker count
            workers = await rpc.rpc('dgpu_workers')
            assert 'ok' in workers.result
            assert workers.result['ok'] == 0
|
2022-12-10 21:18:03 +00:00
|
|
|
|
2022-12-19 15:36:02 +00:00
|
|
|
|
|
|
|
@pytest.mark.parametrize(
    'dgpu_workers', [(1, ['midj'])], indirect=True)
async def test_dgpu_timeout_while_processing(dgpu_workers):
    '''Stop node while processing request to cause timeout and
    then check the requester gets the matching compute error.

    A background task issues an image request and expects it to fail;
    one second later the worker process is sent SIGTERM.
    '''
    with open_skynet_rpc(
        'test-ctx',
        cert_name='whitelist/testing.cert',
        key_name='testing.key'
    ) as session:
        await wait_for_dgpus(session, 1)

        async def check_request_img_raises():
            with pytest.raises(SkynetDGPUComputeError) as e:
                await check_request_img(0)

            # inspect the raised exception itself (e.value), not the
            # pytest ExceptionInfo wrapper, same as the no-ack test does
            assert 'timeout while processing request' in str(e.value)

        async with trio.open_nursery() as n:
            n.start_soon(check_request_img_raises)

            # give the request a moment to reach the worker before killing it
            await trio.sleep(1)
            ec, out = dgpu_workers[0].exec_run(
                ['pkill', '-TERM', '-f', 'skynet'])
            assert ec == 0
|
2023-01-07 09:59:50 +00:00
|
|
|
|
|
|
|
|
2023-01-16 02:42:45 +00:00
|
|
|
@pytest.mark.parametrize('dgpu_workers', [(1, ['midj'])], indirect=True)
async def test_dgpu_img2img(dgpu_workers):
    '''Generate an image from a prompt, then send that image back as the
    binary attachment of a second diffuse request with a different
    prompt, a strength value and the 'x4' upscaler.

    Both results are written to the current working directory as
    'txt2img.png' and 'img2img.png'.
    '''
    creds = {
        'cert_name': 'whitelist/testing.cert',
        'key_name': 'testing.key'
    }

    def raise_on_error(reply):
        # same error convention for both requests
        if 'error' in reply.result:
            raise SkynetDGPUComputeError(MessageToDict(reply.result))

    with open_skynet_rpc('test-ctx', **creds) as rpc:
        await wait_for_dgpus(rpc, 1)

        await trio.sleep(2)

        # first pass: plain text to image
        txt2img_params = {
            'prompt': 'red old tractor in a sunny wheat field',
            'step': 28,
            'width': 512, 'height': 512,
            'guidance': 7.5,
            'seed': None,
            'algo': list(ALGOS.keys())[0],
            'upscaler': None
        }
        reply = await rpc.rpc(
            'dgpu_call',
            {'method': 'diffuse', 'params': txt2img_params},
            timeout=60
        )
        raise_on_error(reply)

        first_raw = reply.bin
        Image.open(io.BytesIO(first_raw)).save('txt2img.png')

        # second pass: feed the first image back in as binary input
        img2img_params = {
            'prompt': 'red ferrari in a sunny wheat field',
            'step': 28,
            'guidance': 8,
            'strength': 0.7,
            'seed': None,
            'algo': list(ALGOS.keys())[0],
            'upscaler': 'x4'
        }
        reply = await rpc.rpc(
            'dgpu_call',
            {'method': 'diffuse', 'params': img2img_params},
            binext=first_raw,
            timeout=60
        )
        raise_on_error(reply)

        second_raw = reply.bin
        Image.open(io.BytesIO(second_raw)).save('img2img.png')
|