address comments, add compute logic to support transformers

add-txt2txt-models
Konstantine Tsafatinos 2023-06-10 13:54:52 -04:00
parent 82a7a3e076
commit 9f5a70ee11
2 changed files with 86 additions and 46 deletions

View File

@ -7,12 +7,13 @@ from hashlib import sha256
import json
import logging
from diffusers import DiffusionPipeline
from transformers import AutoModelForCausalLM
import torch
from skynet.constants import DEFAULT_INITAL_MODELS, MODELS
from skynet.dgpu.errors import DGPUComputeError
from skynet.utils import convert_from_bytes_and_crop, convert_from_cv2_to_image, convert_from_image_to_cv2, convert_from_img_to_bytes, init_upscaler, pipeline_for_image
from skynet.utils import convert_from_bytes_and_crop, convert_from_cv2_to_image, convert_from_image_to_cv2, convert_from_img_to_bytes, init_upscaler, pipeline_for_diffuse
def prepare_params_for_diffuse(
@ -42,6 +43,20 @@ def prepare_params_for_diffuse(
)
def prepare_params_for_transform(
params: dict,
):
return (
params['prompt'],
int(params['num_return_sequences']),
int(params['no_repeat_ngram_size']),
float(params['top_p']),
float(params['temperature']),
int(params['max_length']),
_params
)
class SkynetMM:
def __init__(self, config: dict):
@ -53,7 +68,8 @@ class SkynetMM:
self._models = {}
for model in self.initial_models:
self.load_model(model, False, force=True)
# TODO: look into config for adding model type
self.load_model(model, False, config['model_type'], force=True)
def log_debug_info(self):
logging.info('memory summary:')
@ -71,54 +87,64 @@ class SkynetMM:
self,
model_name: str,
image: bool,
model_type: str,
force=False
):
logging.info(f'loading model {model_name}...')
if force or len(self._models.keys()) == 0:
pipe = pipeline_for_image(model_name, image=image)
self._models[model_name] = {
'pipe': pipe,
'generated': 0,
'image': image
}
match model_type:
case 'diffuse':
if force or len(self._models.keys()) == 0:
pipe = pipeline_for_diffuse(model_name, image=image)
self._models[model_name] = {
'pipe': pipe,
'generated': 0,
'image': image
}
else:
least_used = list(self._models.keys())[0]
else:
least_used = list(self._models.keys())[0]
for model in self._models:
if self._models[
least_used]['generated'] > self._models[model]['generated']:
least_used = model
for model in self._models:
if self._models[
least_used]['generated'] > self._models[model]['generated']:
least_used = model
del self._models[least_used]
del self._models[least_used]
logging.info(f'swapping model {least_used} for {model_name}...')
logging.info(
f'swapping model {least_used} for {model_name}...')
gc.collect()
torch.cuda.empty_cache()
gc.collect()
torch.cuda.empty_cache()
pipe = pipeline_for_image(model_name, image=image)
pipe = pipeline_for_diffuse(model_name, image=image)
self._models[model_name] = {
'pipe': pipe,
'generated': 0,
'image': image
}
self._models[model_name] = {
'pipe': pipe,
'generated': 0,
'image': image
}
logging.info(f'loaded model {model_name}')
return pipe
logging.info(f'loaded model {model_name}')
return pipe
case 'transform':
...
def get_model(self, model_name: str, image: bool) -> DiffusionPipeline:
if model_name not in MODELS:
raise DGPUComputeError(f'Unknown model {model_name}')
def get_model(self, model_name: str, image: bool, model_type: str) -> DiffusionPipeline | AutoModelForCausalLM:
match model_type:
case 'diffuse':
if model_name not in MODELS:
raise DGPUComputeError(f'Unknown model {model_name}')
if not self.is_model_loaded(model_name, image):
pipe = self.load_model(model_name, image=image)
if not self.is_model_loaded(model_name, image):
pipe = self.load_model(model_name, image=image)
else:
pipe = self._models[model_name]['pipe']
else:
pipe = self._models[model_name]['pipe']
return pipe
return pipe
case 'transform':
...
def compute_one(
self,
@ -134,7 +160,7 @@ class SkynetMM:
arguments = prepare_params_for_diffuse(params, binary)
prompt, guidance, step, seed, upscaler, extra_params = arguments
model = self.get_model(
params['model'], 'image' in extra_params)
params['model'], 'image' in extra_params, method)
image = model(
prompt,
@ -156,9 +182,23 @@ class SkynetMM:
return img_sha, img_raw
case 'transformer':
# TODO: Understand dpgu code and figure out what to put here
pass
case 'transform':
arguments = prepare_params_for_transform(params)
prompt, num_return_sequences, no_repeat_ngram_size, top_p, \
temperature, max_length, extra_params = arguments
model = self.get_model(params['model'], False, method)
response = model(
prompt,
num_return_sequences,
no_repeat_ngram_size,
top_p,
temperature,
max_length,
)
return response
case _:
raise DGPUComputeError('Unsupported compute method')

View File

@ -59,7 +59,7 @@ def convert_from_bytes_and_crop(raw: bytes, max_w: int, max_h: int) -> Image:
return image.convert('RGB')
def pipeline_for_image(model: str, mem_fraction: float = 1.0, image=False) -> DiffusionPipeline:
def pipeline_for_diffuse(model: str, mem_fraction: float = 1.0, image=False) -> DiffusionPipeline:
assert torch.cuda.is_available()
torch.cuda.empty_cache()
torch.cuda.set_per_process_memory_fraction(mem_fraction)
@ -98,7 +98,7 @@ def pipeline_for_image(model: str, mem_fraction: float = 1.0, image=False) -> Di
return pipe.to('cuda')
def pipeline_for_text(model: str, mem_fraction: float = 1.0, image=False) -> DiffusionPipeline:
def pipeline_for_transform(model: str, mem_fraction: float = 1.0, image=False) -> DiffusionPipeline:
assert torch.cuda.is_available()
torch.cuda.empty_cache()
torch.cuda.set_per_process_memory_fraction(mem_fraction)
@ -150,7 +150,7 @@ def txt2img(
torch.backends.cudnn.allow_tf32 = True
login(token=hf_token)
pipe = pipeline_for_image(model)
pipe = pipeline_for_diffuse(model)
seed = seed if seed else random.randint(0, 2 ** 64)
prompt = prompt
@ -183,7 +183,7 @@ def img2img(
torch.backends.cudnn.allow_tf32 = True
login(token=hf_token)
pipe = pipeline_for_image(model, image=True)
pipe = pipeline_for_diffuse(model, image=True)
with open(img_path, 'rb') as img_file:
input_img = convert_from_bytes_and_crop(img_file.read(), 512, 512)
@ -220,7 +220,7 @@ def txt2txt(
login(token=hf_token)
tokenizer = AutoTokenizer.from_pretrained(model)
pipe = pipeline_for_text(model)
pipe = pipeline_for_transform(model)
prompt = prompt
# TODO: learn more about return tensors and model params
@ -286,6 +286,6 @@ def download_all_models(hf_token: str):
login(token=hf_token)
for model in MODELS:
print(f'DOWNLOADING {model.upper()}')
pipeline_for_image(model)
pipeline_for_diffuse(model)
print(f'DOWNLOADING IMAGE {model.upper()}')
pipeline_for_image(model, image=True)
pipeline_for_diffuse(model, image=True)