address comments, add compute logic to support transformers

add-txt2txt-models
Konstantine Tsafatinos 2023-06-10 13:54:52 -04:00
parent 82a7a3e076
commit 9f5a70ee11
2 changed files with 86 additions and 46 deletions

View File

@ -7,12 +7,13 @@ from hashlib import sha256
import json import json
import logging import logging
from diffusers import DiffusionPipeline from diffusers import DiffusionPipeline
from transformers import AutoModelForCausalLM
import torch import torch
from skynet.constants import DEFAULT_INITAL_MODELS, MODELS from skynet.constants import DEFAULT_INITAL_MODELS, MODELS
from skynet.dgpu.errors import DGPUComputeError from skynet.dgpu.errors import DGPUComputeError
from skynet.utils import convert_from_bytes_and_crop, convert_from_cv2_to_image, convert_from_image_to_cv2, convert_from_img_to_bytes, init_upscaler, pipeline_for_image from skynet.utils import convert_from_bytes_and_crop, convert_from_cv2_to_image, convert_from_image_to_cv2, convert_from_img_to_bytes, init_upscaler, pipeline_for_diffuse
def prepare_params_for_diffuse( def prepare_params_for_diffuse(
@ -42,6 +43,20 @@ def prepare_params_for_diffuse(
) )
def prepare_params_for_transform(
params: dict,
):
return (
params['prompt'],
int(params['num_return_sequences']),
int(params['no_repeat_ngram_size']),
float(params['top_p']),
float(params['temperature']),
int(params['max_length']),
_params
)
class SkynetMM: class SkynetMM:
def __init__(self, config: dict): def __init__(self, config: dict):
@ -53,7 +68,8 @@ class SkynetMM:
self._models = {} self._models = {}
for model in self.initial_models: for model in self.initial_models:
self.load_model(model, False, force=True) # TODO: look into config for adding model type
self.load_model(model, False, config['model_type'], force=True)
def log_debug_info(self): def log_debug_info(self):
logging.info('memory summary:') logging.info('memory summary:')
@ -71,11 +87,14 @@ class SkynetMM:
self, self,
model_name: str, model_name: str,
image: bool, image: bool,
model_type: str,
force=False force=False
): ):
logging.info(f'loading model {model_name}...') logging.info(f'loading model {model_name}...')
match model_type:
case 'diffuse':
if force or len(self._models.keys()) == 0: if force or len(self._models.keys()) == 0:
pipe = pipeline_for_image(model_name, image=image) pipe = pipeline_for_diffuse(model_name, image=image)
self._models[model_name] = { self._models[model_name] = {
'pipe': pipe, 'pipe': pipe,
'generated': 0, 'generated': 0,
@ -92,12 +111,13 @@ class SkynetMM:
del self._models[least_used] del self._models[least_used]
logging.info(f'swapping model {least_used} for {model_name}...') logging.info(
f'swapping model {least_used} for {model_name}...')
gc.collect() gc.collect()
torch.cuda.empty_cache() torch.cuda.empty_cache()
pipe = pipeline_for_image(model_name, image=image) pipe = pipeline_for_diffuse(model_name, image=image)
self._models[model_name] = { self._models[model_name] = {
'pipe': pipe, 'pipe': pipe,
@ -107,8 +127,12 @@ class SkynetMM:
logging.info(f'loaded model {model_name}') logging.info(f'loaded model {model_name}')
return pipe return pipe
case 'transform':
...
def get_model(self, model_name: str, image: bool) -> DiffusionPipeline: def get_model(self, model_name: str, image: bool, model_type: str) -> DiffusionPipeline | AutoModelForCausalLM:
match model_type:
case 'diffuse':
if model_name not in MODELS: if model_name not in MODELS:
raise DGPUComputeError(f'Unknown model {model_name}') raise DGPUComputeError(f'Unknown model {model_name}')
@ -119,6 +143,8 @@ class SkynetMM:
pipe = self._models[model_name]['pipe'] pipe = self._models[model_name]['pipe']
return pipe return pipe
case 'transform':
...
def compute_one( def compute_one(
self, self,
@ -134,7 +160,7 @@ class SkynetMM:
arguments = prepare_params_for_diffuse(params, binary) arguments = prepare_params_for_diffuse(params, binary)
prompt, guidance, step, seed, upscaler, extra_params = arguments prompt, guidance, step, seed, upscaler, extra_params = arguments
model = self.get_model( model = self.get_model(
params['model'], 'image' in extra_params) params['model'], 'image' in extra_params, method)
image = model( image = model(
prompt, prompt,
@ -156,9 +182,23 @@ class SkynetMM:
return img_sha, img_raw return img_sha, img_raw
case 'transformer': case 'transform':
# TODO: Understand dpgu code and figure out what to put here arguments = prepare_params_for_transform(params)
pass prompt, num_return_sequences, no_repeat_ngram_size, top_p, \
temperature, max_length, extra_params = arguments
model = self.get_model(params['model'], False, method)
response = model(
prompt,
num_return_sequences,
no_repeat_ngram_size,
top_p,
temperature,
max_length,
)
return response
case _: case _:
raise DGPUComputeError('Unsupported compute method') raise DGPUComputeError('Unsupported compute method')

View File

@ -59,7 +59,7 @@ def convert_from_bytes_and_crop(raw: bytes, max_w: int, max_h: int) -> Image:
return image.convert('RGB') return image.convert('RGB')
def pipeline_for_image(model: str, mem_fraction: float = 1.0, image=False) -> DiffusionPipeline: def pipeline_for_diffuse(model: str, mem_fraction: float = 1.0, image=False) -> DiffusionPipeline:
assert torch.cuda.is_available() assert torch.cuda.is_available()
torch.cuda.empty_cache() torch.cuda.empty_cache()
torch.cuda.set_per_process_memory_fraction(mem_fraction) torch.cuda.set_per_process_memory_fraction(mem_fraction)
@ -98,7 +98,7 @@ def pipeline_for_image(model: str, mem_fraction: float = 1.0, image=False) -> Di
return pipe.to('cuda') return pipe.to('cuda')
def pipeline_for_text(model: str, mem_fraction: float = 1.0, image=False) -> DiffusionPipeline: def pipeline_for_transform(model: str, mem_fraction: float = 1.0, image=False) -> DiffusionPipeline:
assert torch.cuda.is_available() assert torch.cuda.is_available()
torch.cuda.empty_cache() torch.cuda.empty_cache()
torch.cuda.set_per_process_memory_fraction(mem_fraction) torch.cuda.set_per_process_memory_fraction(mem_fraction)
@ -150,7 +150,7 @@ def txt2img(
torch.backends.cudnn.allow_tf32 = True torch.backends.cudnn.allow_tf32 = True
login(token=hf_token) login(token=hf_token)
pipe = pipeline_for_image(model) pipe = pipeline_for_diffuse(model)
seed = seed if seed else random.randint(0, 2 ** 64) seed = seed if seed else random.randint(0, 2 ** 64)
prompt = prompt prompt = prompt
@ -183,7 +183,7 @@ def img2img(
torch.backends.cudnn.allow_tf32 = True torch.backends.cudnn.allow_tf32 = True
login(token=hf_token) login(token=hf_token)
pipe = pipeline_for_image(model, image=True) pipe = pipeline_for_diffuse(model, image=True)
with open(img_path, 'rb') as img_file: with open(img_path, 'rb') as img_file:
input_img = convert_from_bytes_and_crop(img_file.read(), 512, 512) input_img = convert_from_bytes_and_crop(img_file.read(), 512, 512)
@ -220,7 +220,7 @@ def txt2txt(
login(token=hf_token) login(token=hf_token)
tokenizer = AutoTokenizer.from_pretrained(model) tokenizer = AutoTokenizer.from_pretrained(model)
pipe = pipeline_for_text(model) pipe = pipeline_for_transform(model)
prompt = prompt prompt = prompt
# TODO: learn more about return tensors and model params # TODO: learn more about return tensors and model params
@ -286,6 +286,6 @@ def download_all_models(hf_token: str):
login(token=hf_token) login(token=hf_token)
for model in MODELS: for model in MODELS:
print(f'DOWNLOADING {model.upper()}') print(f'DOWNLOADING {model.upper()}')
pipeline_for_image(model) pipeline_for_diffuse(model)
print(f'DOWNLOADING IMAGE {model.upper()}') print(f'DOWNLOADING IMAGE {model.upper()}')
pipeline_for_image(model, image=True) pipeline_for_diffuse(model, image=True)