mirror of https://github.com/skygpu/skynet.git
Fix minor issues on compute daemon found thanks to tests, vastly improve pipeline_for function and support old diffuse method
parent
22e40b766f
commit
7108543709
|
@ -44,14 +44,14 @@ def prepare_params_for_diffuse(
|
||||||
_params['image'] = image
|
_params['image'] = image
|
||||||
_params['strength'] = float(params['strength'])
|
_params['strength'] = float(params['strength'])
|
||||||
|
|
||||||
case 'txt2img':
|
case 'txt2img' | 'diffuse':
|
||||||
...
|
...
|
||||||
|
|
||||||
case _:
|
case _:
|
||||||
raise DGPUComputeError(f'Unknown input_type {input_type}')
|
raise DGPUComputeError(f'Unknown mode {mode}')
|
||||||
|
|
||||||
_params['width'] = int(params['width'])
|
# _params['width'] = int(params['width'])
|
||||||
_params['height'] = int(params['height'])
|
# _params['height'] = int(params['height'])
|
||||||
|
|
||||||
return (
|
return (
|
||||||
params['prompt'],
|
params['prompt'],
|
||||||
|
@ -72,7 +72,10 @@ class SkynetMM:
|
||||||
if 'hf_home' in config:
|
if 'hf_home' in config:
|
||||||
self.cache_dir = config['hf_home']
|
self.cache_dir = config['hf_home']
|
||||||
|
|
||||||
self.load_model(DEFAULT_INITAL_MODEL, 'txt2img')
|
self._model_name = ''
|
||||||
|
self._model_mode = ''
|
||||||
|
|
||||||
|
# self.load_model(DEFAULT_INITAL_MODEL, 'txt2img')
|
||||||
|
|
||||||
def log_debug_info(self):
|
def log_debug_info(self):
|
||||||
logging.info('memory summary:')
|
logging.info('memory summary:')
|
||||||
|
@ -90,10 +93,13 @@ class SkynetMM:
|
||||||
name: str,
|
name: str,
|
||||||
mode: str
|
mode: str
|
||||||
):
|
):
|
||||||
logging.info(f'loading model {model_name}...')
|
logging.info(f'loading model {name}...')
|
||||||
self._model_mode = mode
|
self._model_mode = mode
|
||||||
self._model_name = name
|
self._model_name = name
|
||||||
|
|
||||||
|
if getattr(self, '_model', None):
|
||||||
|
del self._model
|
||||||
|
|
||||||
gc.collect()
|
gc.collect()
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
@ -131,7 +137,7 @@ class SkynetMM:
|
||||||
output_hash = None
|
output_hash = None
|
||||||
try:
|
try:
|
||||||
match method:
|
match method:
|
||||||
case 'txt2img' | 'img2img' | 'inpaint':
|
case 'diffuse' | 'txt2img' | 'img2img' | 'inpaint':
|
||||||
arguments = prepare_params_for_diffuse(
|
arguments = prepare_params_for_diffuse(
|
||||||
params, method, inputs)
|
params, method, inputs)
|
||||||
prompt, guidance, step, seed, upscaler, extra_params = arguments
|
prompt, guidance, step, seed, upscaler, extra_params = arguments
|
||||||
|
|
|
@ -18,6 +18,8 @@ from PIL import Image
|
||||||
from basicsr.archs.rrdbnet_arch import RRDBNet
|
from basicsr.archs.rrdbnet_arch import RRDBNet
|
||||||
from diffusers import (
|
from diffusers import (
|
||||||
DiffusionPipeline,
|
DiffusionPipeline,
|
||||||
|
AutoPipelineForText2Image,
|
||||||
|
AutoPipelineForImage2Image,
|
||||||
AutoPipelineForInpainting,
|
AutoPipelineForInpainting,
|
||||||
EulerAncestralDiscreteScheduler
|
EulerAncestralDiscreteScheduler
|
||||||
)
|
)
|
||||||
|
@ -106,11 +108,15 @@ def pipeline_for(
|
||||||
|
|
||||||
torch.cuda.set_per_process_memory_fraction(mem_fraction)
|
torch.cuda.set_per_process_memory_fraction(mem_fraction)
|
||||||
|
|
||||||
if 'inpaint' in mode:
|
match mode:
|
||||||
|
case 'inpaint':
|
||||||
pipe_class = AutoPipelineForInpainting
|
pipe_class = AutoPipelineForInpainting
|
||||||
|
|
||||||
else:
|
case 'img2img':
|
||||||
pipe_class = DiffusionPipeline
|
pipe_class = AutoPipelineForImage2Image
|
||||||
|
|
||||||
|
case 'txt2img' | 'diffuse':
|
||||||
|
pipe_class = AutoPipelineForText2Image
|
||||||
|
|
||||||
pipe = pipe_class.from_pretrained(
|
pipe = pipe_class.from_pretrained(
|
||||||
model, **params)
|
model, **params)
|
||||||
|
@ -121,7 +127,7 @@ def pipeline_for(
|
||||||
pipe.enable_xformers_memory_efficient_attention()
|
pipe.enable_xformers_memory_efficient_attention()
|
||||||
|
|
||||||
if over_mem:
|
if over_mem:
|
||||||
if 'img2img' not in mode:
|
if mode == 'txt2img':
|
||||||
pipe.enable_vae_slicing()
|
pipe.enable_vae_slicing()
|
||||||
pipe.enable_vae_tiling()
|
pipe.enable_vae_tiling()
|
||||||
|
|
||||||
|
|
|
@ -4,6 +4,33 @@ from skynet.dgpu.compute import SkynetMM
|
||||||
from skynet.constants import *
|
from skynet.constants import *
|
||||||
from skynet.config import *
|
from skynet.config import *
|
||||||
|
|
||||||
|
|
||||||
|
async def test_diffuse(dgpu):
|
||||||
|
conn, mm, daemon = dgpu
|
||||||
|
await conn.cancel_work(0, 'testing')
|
||||||
|
|
||||||
|
daemon._snap['requests'][0] = {}
|
||||||
|
req = {
|
||||||
|
'id': 0,
|
||||||
|
'nonce': 0,
|
||||||
|
'body': json.dumps({
|
||||||
|
"method": "diffuse",
|
||||||
|
"params": {
|
||||||
|
"prompt": "Kronos God Realistic 4k",
|
||||||
|
"model": list(MODELS.keys())[-1],
|
||||||
|
"step": 21,
|
||||||
|
"width": 1024,
|
||||||
|
"height": 1024,
|
||||||
|
"seed": 168402949,
|
||||||
|
"guidance": "7.5"
|
||||||
|
}
|
||||||
|
}),
|
||||||
|
'binary_data': '',
|
||||||
|
}
|
||||||
|
|
||||||
|
await daemon.maybe_serve_one(req)
|
||||||
|
|
||||||
|
|
||||||
async def test_txt2img(dgpu):
|
async def test_txt2img(dgpu):
|
||||||
conn, mm, daemon = dgpu
|
conn, mm, daemon = dgpu
|
||||||
await conn.cancel_work(0, 'testing')
|
await conn.cancel_work(0, 'testing')
|
||||||
|
@ -41,7 +68,7 @@ async def test_img2img(dgpu):
|
||||||
'body': json.dumps({
|
'body': json.dumps({
|
||||||
"method": "img2img",
|
"method": "img2img",
|
||||||
"params": {
|
"params": {
|
||||||
"prompt": "Kronos God Realistic 4k",
|
"prompt": "a hindu cat god feline god on a house roof",
|
||||||
"model": list(MODELS.keys())[-2],
|
"model": list(MODELS.keys())[-2],
|
||||||
"step": 21,
|
"step": 21,
|
||||||
"width": 1024,
|
"width": 1024,
|
||||||
|
|
Loading…
Reference in New Issue