diff --git a/skynet/utils.py b/skynet/utils.py index 36219bc..c47fec6 100755 --- a/skynet/utils.py +++ b/skynet/utils.py @@ -1,10 +1,11 @@ #!/usr/bin/python import io -import logging import os +import sys import time import random +import logging from typing import Optional from pathlib import Path @@ -122,11 +123,6 @@ def pipeline_for( pipe.enable_xformers_memory_efficient_attention() - if sys.version_info[1] < 11: - # torch.compile only supported on python < 3.11 - pipe.unet = torch.compile( - pipe.unet, mode='reduce-overhead', fullgraph=True) - if over_mem: if not image: pipe.enable_vae_slicing() @@ -135,6 +131,11 @@ def pipeline_for( pipe.enable_model_cpu_offload() else: + if sys.version_info[1] < 11: + # torch.compile only supported on python < 3.11 + pipe.unet = torch.compile( + pipe.unet, mode='reduce-overhead', fullgraph=True) + pipe = pipe.to('cuda') return pipe