mirror of https://github.com/skygpu/skynet.git
Fix import bug and only enable unet compilation on high end cards
parent
7cd539a944
commit
5b6e18e1ef
|
@ -1,10 +1,11 @@
|
|||
#!/usr/bin/python
|
||||
|
||||
import io
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import random
|
||||
import logging
|
||||
|
||||
from typing import Optional
|
||||
from pathlib import Path
|
||||
|
@ -122,11 +123,6 @@ def pipeline_for(
|
|||
|
||||
pipe.enable_xformers_memory_efficient_attention()
|
||||
|
||||
if sys.version_info[1] < 11:
|
||||
# torch.compile only supported on python < 3.11
|
||||
pipe.unet = torch.compile(
|
||||
pipe.unet, mode='reduce-overhead', fullgraph=True)
|
||||
|
||||
if over_mem:
|
||||
if not image:
|
||||
pipe.enable_vae_slicing()
|
||||
|
@ -135,6 +131,11 @@ def pipeline_for(
|
|||
pipe.enable_model_cpu_offload()
|
||||
|
||||
else:
|
||||
if sys.version_info[1] < 11:
|
||||
# torch.compile only supported on python < 3.11
|
||||
pipe.unet = torch.compile(
|
||||
pipe.unet, mode='reduce-overhead', fullgraph=True)
|
||||
|
||||
pipe = pipe.to('cuda')
|
||||
|
||||
return pipe
|
||||
|
|
Loading…
Reference in New Issue