mirror of https://github.com/skygpu/skynet.git
address comments, add compute logic to support transformers
parent
82a7a3e076
commit
9f5a70ee11
|
@ -7,12 +7,13 @@ from hashlib import sha256
|
|||
import json
|
||||
import logging
|
||||
from diffusers import DiffusionPipeline
|
||||
from transformers import AutoModelForCausalLM
|
||||
|
||||
import torch
|
||||
from skynet.constants import DEFAULT_INITAL_MODELS, MODELS
|
||||
from skynet.dgpu.errors import DGPUComputeError
|
||||
|
||||
from skynet.utils import convert_from_bytes_and_crop, convert_from_cv2_to_image, convert_from_image_to_cv2, convert_from_img_to_bytes, init_upscaler, pipeline_for_image
|
||||
from skynet.utils import convert_from_bytes_and_crop, convert_from_cv2_to_image, convert_from_image_to_cv2, convert_from_img_to_bytes, init_upscaler, pipeline_for_diffuse
|
||||
|
||||
|
||||
def prepare_params_for_diffuse(
|
||||
|
@ -42,6 +43,20 @@ def prepare_params_for_diffuse(
|
|||
)
|
||||
|
||||
|
||||
def prepare_params_for_transform(
|
||||
params: dict,
|
||||
):
|
||||
return (
|
||||
params['prompt'],
|
||||
int(params['num_return_sequences']),
|
||||
int(params['no_repeat_ngram_size']),
|
||||
float(params['top_p']),
|
||||
float(params['temperature']),
|
||||
int(params['max_length']),
|
||||
_params
|
||||
)
|
||||
|
||||
|
||||
class SkynetMM:
|
||||
|
||||
def __init__(self, config: dict):
|
||||
|
@ -53,7 +68,8 @@ class SkynetMM:
|
|||
|
||||
self._models = {}
|
||||
for model in self.initial_models:
|
||||
self.load_model(model, False, force=True)
|
||||
# TODO: look into config for adding model type
|
||||
self.load_model(model, False, config['model_type'], force=True)
|
||||
|
||||
def log_debug_info(self):
|
||||
logging.info('memory summary:')
|
||||
|
@ -71,11 +87,14 @@ class SkynetMM:
|
|||
self,
|
||||
model_name: str,
|
||||
image: bool,
|
||||
model_type: str,
|
||||
force=False
|
||||
):
|
||||
logging.info(f'loading model {model_name}...')
|
||||
match model_type:
|
||||
case 'diffuse':
|
||||
if force or len(self._models.keys()) == 0:
|
||||
pipe = pipeline_for_image(model_name, image=image)
|
||||
pipe = pipeline_for_diffuse(model_name, image=image)
|
||||
self._models[model_name] = {
|
||||
'pipe': pipe,
|
||||
'generated': 0,
|
||||
|
@ -92,12 +111,13 @@ class SkynetMM:
|
|||
|
||||
del self._models[least_used]
|
||||
|
||||
logging.info(f'swapping model {least_used} for {model_name}...')
|
||||
logging.info(
|
||||
f'swapping model {least_used} for {model_name}...')
|
||||
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
pipe = pipeline_for_image(model_name, image=image)
|
||||
pipe = pipeline_for_diffuse(model_name, image=image)
|
||||
|
||||
self._models[model_name] = {
|
||||
'pipe': pipe,
|
||||
|
@ -107,8 +127,12 @@ class SkynetMM:
|
|||
|
||||
logging.info(f'loaded model {model_name}')
|
||||
return pipe
|
||||
case 'transform':
|
||||
...
|
||||
|
||||
def get_model(self, model_name: str, image: bool) -> DiffusionPipeline:
|
||||
def get_model(self, model_name: str, image: bool, model_type: str) -> DiffusionPipeline | AutoModelForCausalLM:
|
||||
match model_type:
|
||||
case 'diffuse':
|
||||
if model_name not in MODELS:
|
||||
raise DGPUComputeError(f'Unknown model {model_name}')
|
||||
|
||||
|
@ -119,6 +143,8 @@ class SkynetMM:
|
|||
pipe = self._models[model_name]['pipe']
|
||||
|
||||
return pipe
|
||||
case 'transform':
|
||||
...
|
||||
|
||||
def compute_one(
|
||||
self,
|
||||
|
@ -134,7 +160,7 @@ class SkynetMM:
|
|||
arguments = prepare_params_for_diffuse(params, binary)
|
||||
prompt, guidance, step, seed, upscaler, extra_params = arguments
|
||||
model = self.get_model(
|
||||
params['model'], 'image' in extra_params)
|
||||
params['model'], 'image' in extra_params, method)
|
||||
|
||||
image = model(
|
||||
prompt,
|
||||
|
@ -156,9 +182,23 @@ class SkynetMM:
|
|||
|
||||
return img_sha, img_raw
|
||||
|
||||
case 'transformer':
|
||||
# TODO: Understand dpgu code and figure out what to put here
|
||||
pass
|
||||
case 'transform':
|
||||
arguments = prepare_params_for_transform(params)
|
||||
prompt, num_return_sequences, no_repeat_ngram_size, top_p, \
|
||||
temperature, max_length, extra_params = arguments
|
||||
model = self.get_model(params['model'], False, method)
|
||||
|
||||
response = model(
|
||||
prompt,
|
||||
num_return_sequences,
|
||||
no_repeat_ngram_size,
|
||||
top_p,
|
||||
temperature,
|
||||
max_length,
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
case _:
|
||||
raise DGPUComputeError('Unsupported compute method')
|
||||
|
||||
|
|
|
@ -59,7 +59,7 @@ def convert_from_bytes_and_crop(raw: bytes, max_w: int, max_h: int) -> Image:
|
|||
return image.convert('RGB')
|
||||
|
||||
|
||||
def pipeline_for_image(model: str, mem_fraction: float = 1.0, image=False) -> DiffusionPipeline:
|
||||
def pipeline_for_diffuse(model: str, mem_fraction: float = 1.0, image=False) -> DiffusionPipeline:
|
||||
assert torch.cuda.is_available()
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.set_per_process_memory_fraction(mem_fraction)
|
||||
|
@ -98,7 +98,7 @@ def pipeline_for_image(model: str, mem_fraction: float = 1.0, image=False) -> Di
|
|||
return pipe.to('cuda')
|
||||
|
||||
|
||||
def pipeline_for_text(model: str, mem_fraction: float = 1.0, image=False) -> DiffusionPipeline:
|
||||
def pipeline_for_transform(model: str, mem_fraction: float = 1.0, image=False) -> DiffusionPipeline:
|
||||
assert torch.cuda.is_available()
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.set_per_process_memory_fraction(mem_fraction)
|
||||
|
@ -150,7 +150,7 @@ def txt2img(
|
|||
torch.backends.cudnn.allow_tf32 = True
|
||||
|
||||
login(token=hf_token)
|
||||
pipe = pipeline_for_image(model)
|
||||
pipe = pipeline_for_diffuse(model)
|
||||
|
||||
seed = seed if seed else random.randint(0, 2 ** 64)
|
||||
prompt = prompt
|
||||
|
@ -183,7 +183,7 @@ def img2img(
|
|||
torch.backends.cudnn.allow_tf32 = True
|
||||
|
||||
login(token=hf_token)
|
||||
pipe = pipeline_for_image(model, image=True)
|
||||
pipe = pipeline_for_diffuse(model, image=True)
|
||||
|
||||
with open(img_path, 'rb') as img_file:
|
||||
input_img = convert_from_bytes_and_crop(img_file.read(), 512, 512)
|
||||
|
@ -220,7 +220,7 @@ def txt2txt(
|
|||
|
||||
login(token=hf_token)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model)
|
||||
pipe = pipeline_for_text(model)
|
||||
pipe = pipeline_for_transform(model)
|
||||
|
||||
prompt = prompt
|
||||
# TODO: learn more about return tensors and model params
|
||||
|
@ -286,6 +286,6 @@ def download_all_models(hf_token: str):
|
|||
login(token=hf_token)
|
||||
for model in MODELS:
|
||||
print(f'DOWNLOADING {model.upper()}')
|
||||
pipeline_for_image(model)
|
||||
pipeline_for_diffuse(model)
|
||||
print(f'DOWNLOADING IMAGE {model.upper()}')
|
||||
pipeline_for_image(model, image=True)
|
||||
pipeline_for_diffuse(model, image=True)
|
||||
|
|
Loading…
Reference in New Issue