mirror of https://github.com/skygpu/skynet.git
address comments, add compute logic to support transformers
parent
82a7a3e076
commit
9f5a70ee11
|
@ -7,12 +7,13 @@ from hashlib import sha256
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
from diffusers import DiffusionPipeline
|
from diffusers import DiffusionPipeline
|
||||||
|
from transformers import AutoModelForCausalLM
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from skynet.constants import DEFAULT_INITAL_MODELS, MODELS
|
from skynet.constants import DEFAULT_INITAL_MODELS, MODELS
|
||||||
from skynet.dgpu.errors import DGPUComputeError
|
from skynet.dgpu.errors import DGPUComputeError
|
||||||
|
|
||||||
from skynet.utils import convert_from_bytes_and_crop, convert_from_cv2_to_image, convert_from_image_to_cv2, convert_from_img_to_bytes, init_upscaler, pipeline_for_image
|
from skynet.utils import convert_from_bytes_and_crop, convert_from_cv2_to_image, convert_from_image_to_cv2, convert_from_img_to_bytes, init_upscaler, pipeline_for_diffuse
|
||||||
|
|
||||||
|
|
||||||
def prepare_params_for_diffuse(
|
def prepare_params_for_diffuse(
|
||||||
|
@ -42,6 +43,20 @@ def prepare_params_for_diffuse(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_params_for_transform(
|
||||||
|
params: dict,
|
||||||
|
):
|
||||||
|
return (
|
||||||
|
params['prompt'],
|
||||||
|
int(params['num_return_sequences']),
|
||||||
|
int(params['no_repeat_ngram_size']),
|
||||||
|
float(params['top_p']),
|
||||||
|
float(params['temperature']),
|
||||||
|
int(params['max_length']),
|
||||||
|
_params
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class SkynetMM:
|
class SkynetMM:
|
||||||
|
|
||||||
def __init__(self, config: dict):
|
def __init__(self, config: dict):
|
||||||
|
@ -53,7 +68,8 @@ class SkynetMM:
|
||||||
|
|
||||||
self._models = {}
|
self._models = {}
|
||||||
for model in self.initial_models:
|
for model in self.initial_models:
|
||||||
self.load_model(model, False, force=True)
|
# TODO: look into config for adding model type
|
||||||
|
self.load_model(model, False, config['model_type'], force=True)
|
||||||
|
|
||||||
def log_debug_info(self):
|
def log_debug_info(self):
|
||||||
logging.info('memory summary:')
|
logging.info('memory summary:')
|
||||||
|
@ -71,11 +87,14 @@ class SkynetMM:
|
||||||
self,
|
self,
|
||||||
model_name: str,
|
model_name: str,
|
||||||
image: bool,
|
image: bool,
|
||||||
|
model_type: str,
|
||||||
force=False
|
force=False
|
||||||
):
|
):
|
||||||
logging.info(f'loading model {model_name}...')
|
logging.info(f'loading model {model_name}...')
|
||||||
|
match model_type:
|
||||||
|
case 'diffuse':
|
||||||
if force or len(self._models.keys()) == 0:
|
if force or len(self._models.keys()) == 0:
|
||||||
pipe = pipeline_for_image(model_name, image=image)
|
pipe = pipeline_for_diffuse(model_name, image=image)
|
||||||
self._models[model_name] = {
|
self._models[model_name] = {
|
||||||
'pipe': pipe,
|
'pipe': pipe,
|
||||||
'generated': 0,
|
'generated': 0,
|
||||||
|
@ -92,12 +111,13 @@ class SkynetMM:
|
||||||
|
|
||||||
del self._models[least_used]
|
del self._models[least_used]
|
||||||
|
|
||||||
logging.info(f'swapping model {least_used} for {model_name}...')
|
logging.info(
|
||||||
|
f'swapping model {least_used} for {model_name}...')
|
||||||
|
|
||||||
gc.collect()
|
gc.collect()
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
pipe = pipeline_for_image(model_name, image=image)
|
pipe = pipeline_for_diffuse(model_name, image=image)
|
||||||
|
|
||||||
self._models[model_name] = {
|
self._models[model_name] = {
|
||||||
'pipe': pipe,
|
'pipe': pipe,
|
||||||
|
@ -107,8 +127,12 @@ class SkynetMM:
|
||||||
|
|
||||||
logging.info(f'loaded model {model_name}')
|
logging.info(f'loaded model {model_name}')
|
||||||
return pipe
|
return pipe
|
||||||
|
case 'transform':
|
||||||
|
...
|
||||||
|
|
||||||
def get_model(self, model_name: str, image: bool) -> DiffusionPipeline:
|
def get_model(self, model_name: str, image: bool, model_type: str) -> DiffusionPipeline | AutoModelForCausalLM:
|
||||||
|
match model_type:
|
||||||
|
case 'diffuse':
|
||||||
if model_name not in MODELS:
|
if model_name not in MODELS:
|
||||||
raise DGPUComputeError(f'Unknown model {model_name}')
|
raise DGPUComputeError(f'Unknown model {model_name}')
|
||||||
|
|
||||||
|
@ -119,6 +143,8 @@ class SkynetMM:
|
||||||
pipe = self._models[model_name]['pipe']
|
pipe = self._models[model_name]['pipe']
|
||||||
|
|
||||||
return pipe
|
return pipe
|
||||||
|
case 'transform':
|
||||||
|
...
|
||||||
|
|
||||||
def compute_one(
|
def compute_one(
|
||||||
self,
|
self,
|
||||||
|
@ -134,7 +160,7 @@ class SkynetMM:
|
||||||
arguments = prepare_params_for_diffuse(params, binary)
|
arguments = prepare_params_for_diffuse(params, binary)
|
||||||
prompt, guidance, step, seed, upscaler, extra_params = arguments
|
prompt, guidance, step, seed, upscaler, extra_params = arguments
|
||||||
model = self.get_model(
|
model = self.get_model(
|
||||||
params['model'], 'image' in extra_params)
|
params['model'], 'image' in extra_params, method)
|
||||||
|
|
||||||
image = model(
|
image = model(
|
||||||
prompt,
|
prompt,
|
||||||
|
@ -156,9 +182,23 @@ class SkynetMM:
|
||||||
|
|
||||||
return img_sha, img_raw
|
return img_sha, img_raw
|
||||||
|
|
||||||
case 'transformer':
|
case 'transform':
|
||||||
# TODO: Understand dpgu code and figure out what to put here
|
arguments = prepare_params_for_transform(params)
|
||||||
pass
|
prompt, num_return_sequences, no_repeat_ngram_size, top_p, \
|
||||||
|
temperature, max_length, extra_params = arguments
|
||||||
|
model = self.get_model(params['model'], False, method)
|
||||||
|
|
||||||
|
response = model(
|
||||||
|
prompt,
|
||||||
|
num_return_sequences,
|
||||||
|
no_repeat_ngram_size,
|
||||||
|
top_p,
|
||||||
|
temperature,
|
||||||
|
max_length,
|
||||||
|
)
|
||||||
|
|
||||||
|
return response
|
||||||
|
|
||||||
case _:
|
case _:
|
||||||
raise DGPUComputeError('Unsupported compute method')
|
raise DGPUComputeError('Unsupported compute method')
|
||||||
|
|
||||||
|
|
|
@ -59,7 +59,7 @@ def convert_from_bytes_and_crop(raw: bytes, max_w: int, max_h: int) -> Image:
|
||||||
return image.convert('RGB')
|
return image.convert('RGB')
|
||||||
|
|
||||||
|
|
||||||
def pipeline_for_image(model: str, mem_fraction: float = 1.0, image=False) -> DiffusionPipeline:
|
def pipeline_for_diffuse(model: str, mem_fraction: float = 1.0, image=False) -> DiffusionPipeline:
|
||||||
assert torch.cuda.is_available()
|
assert torch.cuda.is_available()
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
torch.cuda.set_per_process_memory_fraction(mem_fraction)
|
torch.cuda.set_per_process_memory_fraction(mem_fraction)
|
||||||
|
@ -98,7 +98,7 @@ def pipeline_for_image(model: str, mem_fraction: float = 1.0, image=False) -> Di
|
||||||
return pipe.to('cuda')
|
return pipe.to('cuda')
|
||||||
|
|
||||||
|
|
||||||
def pipeline_for_text(model: str, mem_fraction: float = 1.0, image=False) -> DiffusionPipeline:
|
def pipeline_for_transform(model: str, mem_fraction: float = 1.0, image=False) -> DiffusionPipeline:
|
||||||
assert torch.cuda.is_available()
|
assert torch.cuda.is_available()
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
torch.cuda.set_per_process_memory_fraction(mem_fraction)
|
torch.cuda.set_per_process_memory_fraction(mem_fraction)
|
||||||
|
@ -150,7 +150,7 @@ def txt2img(
|
||||||
torch.backends.cudnn.allow_tf32 = True
|
torch.backends.cudnn.allow_tf32 = True
|
||||||
|
|
||||||
login(token=hf_token)
|
login(token=hf_token)
|
||||||
pipe = pipeline_for_image(model)
|
pipe = pipeline_for_diffuse(model)
|
||||||
|
|
||||||
seed = seed if seed else random.randint(0, 2 ** 64)
|
seed = seed if seed else random.randint(0, 2 ** 64)
|
||||||
prompt = prompt
|
prompt = prompt
|
||||||
|
@ -183,7 +183,7 @@ def img2img(
|
||||||
torch.backends.cudnn.allow_tf32 = True
|
torch.backends.cudnn.allow_tf32 = True
|
||||||
|
|
||||||
login(token=hf_token)
|
login(token=hf_token)
|
||||||
pipe = pipeline_for_image(model, image=True)
|
pipe = pipeline_for_diffuse(model, image=True)
|
||||||
|
|
||||||
with open(img_path, 'rb') as img_file:
|
with open(img_path, 'rb') as img_file:
|
||||||
input_img = convert_from_bytes_and_crop(img_file.read(), 512, 512)
|
input_img = convert_from_bytes_and_crop(img_file.read(), 512, 512)
|
||||||
|
@ -220,7 +220,7 @@ def txt2txt(
|
||||||
|
|
||||||
login(token=hf_token)
|
login(token=hf_token)
|
||||||
tokenizer = AutoTokenizer.from_pretrained(model)
|
tokenizer = AutoTokenizer.from_pretrained(model)
|
||||||
pipe = pipeline_for_text(model)
|
pipe = pipeline_for_transform(model)
|
||||||
|
|
||||||
prompt = prompt
|
prompt = prompt
|
||||||
# TODO: learn more about return tensors and model params
|
# TODO: learn more about return tensors and model params
|
||||||
|
@ -286,6 +286,6 @@ def download_all_models(hf_token: str):
|
||||||
login(token=hf_token)
|
login(token=hf_token)
|
||||||
for model in MODELS:
|
for model in MODELS:
|
||||||
print(f'DOWNLOADING {model.upper()}')
|
print(f'DOWNLOADING {model.upper()}')
|
||||||
pipeline_for_image(model)
|
pipeline_for_diffuse(model)
|
||||||
print(f'DOWNLOADING IMAGE {model.upper()}')
|
print(f'DOWNLOADING IMAGE {model.upper()}')
|
||||||
pipeline_for_image(model, image=True)
|
pipeline_for_diffuse(model, image=True)
|
||||||
|
|
Loading…
Reference in New Issue