Add dgpu fixture

txt2txt
Guillermo Rodriguez 2025-01-09 21:25:00 -03:00
parent 8d35e5ed9a
commit 22e40b766f
No known key found for this signature in database
GPG Key ID: 002CC5F1E6BDA53E
6 changed files with 51 additions and 44 deletions

View File

@ -34,6 +34,7 @@ def prepare_params_for_diffuse(
inputs[1], params['width'], params['height']) inputs[1], params['width'], params['height'])
_params['image'] = image _params['image'] = image
_params['mask_image'] = mask
_params['strength'] = float(params['strength']) _params['strength'] = float(params['strength'])
case 'img2img': case 'img2img':

View File

@ -146,6 +146,7 @@ class SkynetDGPUDaemon:
inputs = [ inputs = [
await self.conn.get_input_data(_input) await self.conn.get_input_data(_input)
for _input in req['binary_data'].split(',') for _input in req['binary_data'].split(',')
if _input
] ]
hash_str = ( hash_str = (

View File

@ -268,7 +268,7 @@ class SkynetGPUConnector:
return file_cid return file_cid
async def get_input_data(self, ipfs_hash: str) -> Image: async def get_input_data(self, ipfs_hash: str) -> Image:
ipfs_link = f'https://{self.ipfs_domain}/ipfs/{ipfs_hash}' link = f'https://{self.ipfs_domain}/ipfs/{ipfs_hash}'
res = await get_ipfs_file(link, timeout=1) res = await get_ipfs_file(link, timeout=1)
logging.info(f'got response from {link}') logging.info(f'got response from {link}')

View File

@ -65,8 +65,8 @@ def convert_from_bytes_and_crop(raw: bytes, max_w: int, max_h: int) -> Image:
def pipeline_for( def pipeline_for(
model: str, model: str,
mode: str,
mem_fraction: float = 1.0, mem_fraction: float = 1.0,
mode: str = [],
cache_dir: str | None = None cache_dir: str | None = None
) -> DiffusionPipeline: ) -> DiffusionPipeline:
@ -112,7 +112,7 @@ def pipeline_for(
else: else:
pipe_class = DiffusionPipeline pipe_class = DiffusionPipeline
pipe = AutoPipelineForInpainting.from_pretrained( pipe = pipe_class.from_pretrained(
model, **params) model, **params)
pipe.scheduler = EulerAncestralDiscreteScheduler.from_config( pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(

View File

@ -2,7 +2,7 @@
import pytest import pytest
from skynet.db import open_new_database from skynet.config import *
from skynet.ipfs import AsyncIPFSHTTP from skynet.ipfs import AsyncIPFSHTTP
from skynet.ipfs.docker import open_ipfs_node from skynet.ipfs.docker import open_ipfs_node
from skynet.nodeos import open_nodeos from skynet.nodeos import open_nodeos
@ -15,6 +15,7 @@ def ipfs_client():
@pytest.fixture(scope='session') @pytest.fixture(scope='session')
def postgres_db(): def postgres_db():
from skynet.db import open_new_database
with open_new_database() as db_params: with open_new_database() as db_params:
yield db_params yield db_params
@ -22,3 +23,20 @@ def postgres_db():
def cleos(): def cleos():
with open_nodeos() as cli: with open_nodeos() as cli:
yield cli yield cli
@pytest.fixture(scope='session')
def dgpu():
from skynet.dgpu.network import SkynetGPUConnector
from skynet.dgpu.compute import SkynetMM
from skynet.dgpu.daemon import SkynetDGPUDaemon
config = load_skynet_toml(file_path='skynet.toml')
hf_token = load_key(config, 'skynet.dgpu.hf_token')
hf_home = load_key(config, 'skynet.dgpu.hf_home')
set_hf_vars(hf_token, hf_home)
config = config['skynet']['dgpu']
conn = SkynetGPUConnector(config)
mm = SkynetMM(config)
daemon = SkynetDGPUDaemon(mm, conn, config)
yield conn, mm, daemon

View File

@ -1,9 +1,17 @@
import json
from skynet.dgpu.compute import SkynetMM
from skynet.constants import *
from skynet.config import * from skynet.config import *
async def test_txt2img(): async def test_txt2img(dgpu):
conn, mm, daemon = dgpu
await conn.cancel_work(0, 'testing')
daemon._snap['requests'][0] = {}
req = { req = {
'id': 0, 'id': 0,
'nonce': 0,
'body': json.dumps({ 'body': json.dumps({
"method": "txt2img", "method": "txt2img",
"params": { "params": {
@ -16,25 +24,20 @@ async def test_txt2img():
"guidance": "7.5" "guidance": "7.5"
} }
}), }),
'inputs': [], 'binary_data': '',
} }
config = load_skynet_toml(file_path=config_path) await daemon.maybe_serve_one(req)
hf_token = load_key(config, 'skynet.dgpu.hf_token')
hf_home = load_key(config, 'skynet.dgpu.hf_home')
set_hf_vars(hf_token, hf_home)
assert 'skynet' in config
assert 'dgpu' in config['skynet']
mm = SkynetMM(config['skynet']['dgpu'])
mm.maybe_serve_one(req)
async def test_img2img(): async def test_img2img(dgpu):
conn, mm, daemon = dgpu
await conn.cancel_work(0, 'testing')
daemon._snap['requests'][0] = {}
req = { req = {
'id': 0, 'id': 0,
'nonce': 0,
'body': json.dumps({ 'body': json.dumps({
"method": "img2img", "method": "img2img",
"params": { "params": {
@ -48,24 +51,19 @@ async def test_img2img():
"strength": "0.5" "strength": "0.5"
} }
}), }),
'inputs': ['QmZcGdXXVQfpco1G3tr2CGFBtv8xVsCwcwuq9gnJBWDymi'], 'binary_data': 'QmZcGdXXVQfpco1G3tr2CGFBtv8xVsCwcwuq9gnJBWDymi',
} }
config = load_skynet_toml(file_path=config_path) await daemon.maybe_serve_one(req)
hf_token = load_key(config, 'skynet.dgpu.hf_token')
hf_home = load_key(config, 'skynet.dgpu.hf_home')
set_hf_vars(hf_token, hf_home)
assert 'skynet' in config async def test_inpaint(dgpu):
assert 'dgpu' in config['skynet'] conn, mm, daemon = dgpu
await conn.cancel_work(0, 'testing')
mm = SkynetMM(config['skynet']['dgpu']) daemon._snap['requests'][0] = {}
mm.maybe_serve_one(req)
async def test_inpaint():
req = { req = {
'id': 0, 'id': 0,
'nonce': 0,
'body': json.dumps({ 'body': json.dumps({
"method": "inpaint", "method": "inpaint",
"params": { "params": {
@ -79,20 +77,9 @@ async def test_inpaint():
"strength": "0.5" "strength": "0.5"
} }
}), }),
'inputs': [ 'binary_data':
'QmZcGdXXVQfpco1G3tr2CGFBtv8xVsCwcwuq9gnJBWDymi', 'QmZcGdXXVQfpco1G3tr2CGFBtv8xVsCwcwuq9gnJBWDymi,' +
'Qmccx1aXNmq5mZDS3YviUhgGHXWhQeHvca3AgA7MDjj2hR' 'Qmccx1aXNmq5mZDS3YviUhgGHXWhQeHvca3AgA7MDjj2hR'
],
} }
config = load_skynet_toml(file_path=config_path) await daemon.maybe_serve_one(req)
hf_token = load_key(config, 'skynet.dgpu.hf_token')
hf_home = load_key(config, 'skynet.dgpu.hf_home')
set_hf_vars(hf_token, hf_home)
assert 'skynet' in config
assert 'dgpu' in config['skynet']
mm = SkynetMM(config['skynet']['dgpu'])
mm.maybe_serve_one(req)