Fix import bug and only enable unet compilation on high end cards

pull/26/head
Guillermo Rodriguez 2023-10-07 11:12:15 -03:00
parent 7cd539a944
commit 5b6e18e1ef
No known key found for this signature in database
GPG Key ID: EC3AB66D5D83B392
1 changed files with 7 additions and 6 deletions

View File

@ -1,10 +1,11 @@
#!/usr/bin/python
import io
import logging
import os
import sys
import time
import random
import logging
from typing import Optional
from pathlib import Path
@ -122,11 +123,6 @@ def pipeline_for(
pipe.enable_xformers_memory_efficient_attention()
if sys.version_info[1] < 11:
# torch.compile only supported on python < 3.11
pipe.unet = torch.compile(
pipe.unet, mode='reduce-overhead', fullgraph=True)
if over_mem:
if not image:
pipe.enable_vae_slicing()
@ -135,6 +131,11 @@ def pipeline_for(
pipe.enable_model_cpu_offload()
else:
if sys.version_info[1] < 11:
# torch.compile only supported on python < 3.11
pipe.unet = torch.compile(
pipe.unet, mode='reduce-overhead', fullgraph=True)
pipe = pipe.to('cuda')
return pipe