From 5b6e18e1effc9d3157e7365723733121c17ce049 Mon Sep 17 00:00:00 2001 From: Guillermo Rodriguez Date: Sat, 7 Oct 2023 11:12:15 -0300 Subject: [PATCH] Fix import bug and only enable unet compilation on high end cards --- skynet/utils.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/skynet/utils.py b/skynet/utils.py index 36219bc..c47fec6 100755 --- a/skynet/utils.py +++ b/skynet/utils.py @@ -1,10 +1,11 @@ #!/usr/bin/python import io -import logging import os +import sys import time import random +import logging from typing import Optional from pathlib import Path @@ -122,11 +123,6 @@ def pipeline_for( pipe.enable_xformers_memory_efficient_attention() - if sys.version_info[1] < 11: - # torch.compile only supported on python < 3.11 - pipe.unet = torch.compile( - pipe.unet, mode='reduce-overhead', fullgraph=True) - if over_mem: if not image: pipe.enable_vae_slicing() @@ -135,6 +131,11 @@ def pipeline_for( pipe.enable_model_cpu_offload() else: + if sys.version_info[1] < 11: + # torch.compile only supported on python < 3.11 + pipe.unet = torch.compile( + pipe.unet, mode='reduce-overhead', fullgraph=True) + pipe = pipe.to('cuda') return pipe