Fix import bug and only enable unet compilation on high end cards

pull/26/head
Guillermo Rodriguez 2023-10-07 11:12:15 -03:00
parent 7cd539a944
commit 5b6e18e1ef
No known key found for this signature in database
GPG Key ID: EC3AB66D5D83B392
1 changed files with 7 additions and 6 deletions

View File

@ -1,10 +1,11 @@
#!/usr/bin/python #!/usr/bin/python
import io import io
import logging
import os import os
import sys
import time import time
import random import random
import logging
from typing import Optional from typing import Optional
from pathlib import Path from pathlib import Path
@ -122,11 +123,6 @@ def pipeline_for(
pipe.enable_xformers_memory_efficient_attention() pipe.enable_xformers_memory_efficient_attention()
if sys.version_info[1] < 11:
# torch.compile only supported on python < 3.11
pipe.unet = torch.compile(
pipe.unet, mode='reduce-overhead', fullgraph=True)
if over_mem: if over_mem:
if not image: if not image:
pipe.enable_vae_slicing() pipe.enable_vae_slicing()
@ -135,6 +131,11 @@ def pipeline_for(
pipe.enable_model_cpu_offload() pipe.enable_model_cpu_offload()
else: else:
if sys.version_info[1] < 11:
# torch.compile only supported on python < 3.11
pipe.unet = torch.compile(
pipe.unet, mode='reduce-overhead', fullgraph=True)
pipe = pipe.to('cuda') pipe = pipe.to('cuda')
return pipe return pipe