mirror of https://github.com/skygpu/skynet.git
remove variant fp16 for updated diffusers
parent
60be53ca6a
commit
c26871536a
|
@ -87,18 +87,18 @@ def pipeline_for(
|
||||||
if over_mem:
|
if over_mem:
|
||||||
logging.warn(f'model requires {req_mem} but card has {mem_gb}, model will run slower..')
|
logging.warn(f'model requires {req_mem} but card has {mem_gb}, model will run slower..')
|
||||||
|
|
||||||
shortname = model_info['short']
|
# shortname = model_info['short']
|
||||||
|
|
||||||
params = {
|
params = {
|
||||||
'safety_checker': None,
|
'safety_checker': None,
|
||||||
'torch_dtype': torch.float16,
|
'torch_dtype': torch.float16,
|
||||||
'cache_dir': cache_dir,
|
'cache_dir': cache_dir,
|
||||||
'variant': 'fp16'
|
# 'variant': 'fp16'
|
||||||
}
|
}
|
||||||
|
|
||||||
match shortname:
|
# match shortname:
|
||||||
case 'stable':
|
# case 'stable':
|
||||||
params['revision'] = 'fp16'
|
# params['revision'] = 'fp16'
|
||||||
|
|
||||||
torch.cuda.set_per_process_memory_fraction(mem_fraction)
|
torch.cuda.set_per_process_memory_fraction(mem_fraction)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue