mirror of https://github.com/skygpu/skynet.git
Fix import bug and only enable unet compilation on high end cards
parent
7cd539a944
commit
5b6e18e1ef
|
@ -1,10 +1,11 @@
|
||||||
#!/usr/bin/python
|
#!/usr/bin/python
|
||||||
|
|
||||||
import io
|
import io
|
||||||
import logging
|
|
||||||
import os
|
import os
|
||||||
|
import sys
|
||||||
import time
|
import time
|
||||||
import random
|
import random
|
||||||
|
import logging
|
||||||
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
@ -122,11 +123,6 @@ def pipeline_for(
|
||||||
|
|
||||||
pipe.enable_xformers_memory_efficient_attention()
|
pipe.enable_xformers_memory_efficient_attention()
|
||||||
|
|
||||||
if sys.version_info[1] < 11:
|
|
||||||
# torch.compile only supported on python < 3.11
|
|
||||||
pipe.unet = torch.compile(
|
|
||||||
pipe.unet, mode='reduce-overhead', fullgraph=True)
|
|
||||||
|
|
||||||
if over_mem:
|
if over_mem:
|
||||||
if not image:
|
if not image:
|
||||||
pipe.enable_vae_slicing()
|
pipe.enable_vae_slicing()
|
||||||
|
@ -135,6 +131,11 @@ def pipeline_for(
|
||||||
pipe.enable_model_cpu_offload()
|
pipe.enable_model_cpu_offload()
|
||||||
|
|
||||||
else:
|
else:
|
||||||
|
if sys.version_info[1] < 11:
|
||||||
|
# torch.compile only supported on python < 3.11
|
||||||
|
pipe.unet = torch.compile(
|
||||||
|
pipe.unet, mode='reduce-overhead', fullgraph=True)
|
||||||
|
|
||||||
pipe = pipe.to('cuda')
|
pipe = pipe.to('cuda')
|
||||||
|
|
||||||
return pipe
|
return pipe
|
||||||
|
|
Loading…
Reference in New Issue