Fix minor issues on compute daemon found thanks to tests, vastly improve pipeline_for function and support old diffuse method

pull/44/head
Guillermo Rodriguez 2025-01-10 12:33:23 -03:00
parent 22e40b766f
commit 7108543709
No known key found for this signature in database
GPG Key ID: 002CC5F1E6BDA53E
3 changed files with 52 additions and 13 deletions

View File

@ -44,14 +44,14 @@ def prepare_params_for_diffuse(
_params['image'] = image _params['image'] = image
_params['strength'] = float(params['strength']) _params['strength'] = float(params['strength'])
case 'txt2img': case 'txt2img' | 'diffuse':
... ...
case _: case _:
raise DGPUComputeError(f'Unknown input_type {input_type}') raise DGPUComputeError(f'Unknown mode {mode}')
_params['width'] = int(params['width']) # _params['width'] = int(params['width'])
_params['height'] = int(params['height']) # _params['height'] = int(params['height'])
return ( return (
params['prompt'], params['prompt'],
@ -72,7 +72,10 @@ class SkynetMM:
if 'hf_home' in config: if 'hf_home' in config:
self.cache_dir = config['hf_home'] self.cache_dir = config['hf_home']
self.load_model(DEFAULT_INITAL_MODEL, 'txt2img') self._model_name = ''
self._model_mode = ''
# self.load_model(DEFAULT_INITAL_MODEL, 'txt2img')
def log_debug_info(self): def log_debug_info(self):
logging.info('memory summary:') logging.info('memory summary:')
@ -90,10 +93,13 @@ class SkynetMM:
name: str, name: str,
mode: str mode: str
): ):
logging.info(f'loading model {model_name}...') logging.info(f'loading model {name}...')
self._model_mode = mode self._model_mode = mode
self._model_name = name self._model_name = name
if getattr(self, '_model', None):
del self._model
gc.collect() gc.collect()
torch.cuda.empty_cache() torch.cuda.empty_cache()
@ -131,7 +137,7 @@ class SkynetMM:
output_hash = None output_hash = None
try: try:
match method: match method:
case 'txt2img' | 'img2img' | 'inpaint': case 'diffuse' | 'txt2img' | 'img2img' | 'inpaint':
arguments = prepare_params_for_diffuse( arguments = prepare_params_for_diffuse(
params, method, inputs) params, method, inputs)
prompt, guidance, step, seed, upscaler, extra_params = arguments prompt, guidance, step, seed, upscaler, extra_params = arguments

View File

@ -18,6 +18,8 @@ from PIL import Image
from basicsr.archs.rrdbnet_arch import RRDBNet from basicsr.archs.rrdbnet_arch import RRDBNet
from diffusers import ( from diffusers import (
DiffusionPipeline, DiffusionPipeline,
AutoPipelineForText2Image,
AutoPipelineForImage2Image,
AutoPipelineForInpainting, AutoPipelineForInpainting,
EulerAncestralDiscreteScheduler EulerAncestralDiscreteScheduler
) )
@ -106,11 +108,15 @@ def pipeline_for(
torch.cuda.set_per_process_memory_fraction(mem_fraction) torch.cuda.set_per_process_memory_fraction(mem_fraction)
if 'inpaint' in mode: match mode:
pipe_class = AutoPipelineForInpainting case 'inpaint':
pipe_class = AutoPipelineForInpainting
else: case 'img2img':
pipe_class = DiffusionPipeline pipe_class = AutoPipelineForImage2Image
case 'txt2img' | 'diffuse':
pipe_class = AutoPipelineForText2Image
pipe = pipe_class.from_pretrained( pipe = pipe_class.from_pretrained(
model, **params) model, **params)
@ -121,7 +127,7 @@ def pipeline_for(
pipe.enable_xformers_memory_efficient_attention() pipe.enable_xformers_memory_efficient_attention()
if over_mem: if over_mem:
if 'img2img' not in mode: if mode == 'txt2img':
pipe.enable_vae_slicing() pipe.enable_vae_slicing()
pipe.enable_vae_tiling() pipe.enable_vae_tiling()

View File

@ -4,6 +4,33 @@ from skynet.dgpu.compute import SkynetMM
from skynet.constants import * from skynet.constants import *
from skynet.config import * from skynet.config import *
async def test_diffuse(dgpu):
conn, mm, daemon = dgpu
await conn.cancel_work(0, 'testing')
daemon._snap['requests'][0] = {}
req = {
'id': 0,
'nonce': 0,
'body': json.dumps({
"method": "diffuse",
"params": {
"prompt": "Kronos God Realistic 4k",
"model": list(MODELS.keys())[-1],
"step": 21,
"width": 1024,
"height": 1024,
"seed": 168402949,
"guidance": "7.5"
}
}),
'binary_data': '',
}
await daemon.maybe_serve_one(req)
async def test_txt2img(dgpu): async def test_txt2img(dgpu):
conn, mm, daemon = dgpu conn, mm, daemon = dgpu
await conn.cancel_work(0, 'testing') await conn.cancel_work(0, 'testing')
@ -41,7 +68,7 @@ async def test_img2img(dgpu):
'body': json.dumps({ 'body': json.dumps({
"method": "img2img", "method": "img2img",
"params": { "params": {
"prompt": "Kronos God Realistic 4k", "prompt": "a hindu cat god feline god on a house roof",
"model": list(MODELS.keys())[-2], "model": list(MODELS.keys())[-2],
"step": 21, "step": 21,
"width": 1024, "width": 1024,