diff --git a/skynet/utils.py b/skynet/utils.py index 5e88ed8..7b5621e 100755 --- a/skynet/utils.py +++ b/skynet/utils.py @@ -87,18 +87,18 @@ def pipeline_for( if over_mem: logging.warn(f'model requires {req_mem} but card has {mem_gb}, model will run slower..') - shortname = model_info['short'] + # shortname = model_info['short'] params = { 'safety_checker': None, 'torch_dtype': torch.float16, 'cache_dir': cache_dir, - 'variant': 'fp16' + # 'variant': 'fp16' } - match shortname: - case 'stable': - params['revision'] = 'fp16' + # match shortname: + # case 'stable': + # params['revision'] = 'fp16' torch.cuda.set_per_process_memory_fraction(mem_fraction)