diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index cf4c5c5fdeca..cca11281359a 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -119,14 +119,13 @@ def disable_attention_slicing(self): # set slice_size = `None` to disable `attention slicing` self.enable_attention_slicing(None) - def cuda_with_minimal_gpu_usage(self): + def enable_sequential_cpu_offload(self): if is_accelerate_available(): from accelerate import cpu_offload else: raise ImportError("Please install accelerate via `pip install accelerate`") device = torch.device("cuda") - self.enable_attention_slicing(1) for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.safety_checker]: cpu_offload(cpu_offloaded_model, device) diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion.py b/tests/pipelines/stable_diffusion/test_stable_diffusion.py index d7e6c362d1d7..a81710987814 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion.py @@ -15,6 +15,7 @@ import gc import random +import time import unittest import numpy as np @@ -730,3 +731,39 @@ def test_callback_fn(step: int, timestep: int, latents: torch.FloatTensor) -> No ) assert test_callback_fn.has_been_called assert number_of_steps == 51 + + def test_stable_diffusion_accelerate_auto_device(self): + pipeline_id = "CompVis/stable-diffusion-v1-4" + + start_time = time.time() + pipeline_normal_load = StableDiffusionPipeline.from_pretrained( + pipeline_id, revision="fp16", torch_dtype=torch.float16, use_auth_token=True + ) + pipeline_normal_load.to(torch_device) + normal_load_time = time.time() - start_time + + start_time = time.time() + _ = StableDiffusionPipeline.from_pretrained( + pipeline_id, revision="fp16", torch_dtype=torch.float16, use_auth_token=True, device_map="auto" + ) + meta_device_load_time = time.time() - start_time + + assert 2 * meta_device_load_time < normal_load_time + + @unittest.skipIf(torch_device == "cpu", "This test is supposed to run on GPU") + def test_stable_diffusion_pipeline_with_unet_on_gpu_only(self): + torch.cuda.empty_cache() + torch.cuda.reset_max_memory_allocated() + + pipeline_id = "CompVis/stable-diffusion-v1-4" + prompt = "Andromeda galaxy in a bottle" + + pipeline = StableDiffusionPipeline.from_pretrained(pipeline_id, revision="fp16", torch_dtype=torch.float16) + pipeline.enable_attention_slicing(1) + pipeline.enable_sequential_cpu_offload() + + _ = pipeline(prompt, num_inference_steps=5) + + mem_bytes = torch.cuda.max_memory_allocated() + # make sure that less than 1.5 GB is allocated + assert mem_bytes < 1.5 * 10**9 diff --git a/tests/test_pipelines.py b/tests/test_pipelines.py index 6e9388ca3a65..3cc94962c38f 100644 --- a/tests/test_pipelines.py +++ b/tests/test_pipelines.py @@ -17,15 +17,12 @@ import os import random import tempfile -import tracemalloc import unittest import numpy as np import torch -import accelerate import PIL -import transformers from diffusers import ( AutoencoderKL, DDIMPipeline, @@ -44,8 +41,7 @@ from diffusers.pipeline_utils import DiffusionPipeline from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME from diffusers.utils import CONFIG_NAME, WEIGHTS_NAME, floats_tensor, slow, torch_device -from diffusers.utils.testing_utils import CaptureLogger, get_tests_dir, require_torch_gpu -from packaging import version +from diffusers.utils.testing_utils import CaptureLogger, get_tests_dir from PIL import Image from transformers import CLIPFeatureExtractor, CLIPModel, CLIPTextConfig, CLIPTextModel, CLIPTokenizer @@ -487,71 +483,3 @@ def test_ddpm_ddim_equality_batched(self): # the values aren't exactly equal, but the images look the same visually assert np.abs(ddpm_images - ddim_images).max() < 1e-1 - - @require_torch_gpu - def test_stable_diffusion_accelerate_load_works(self): - if version.parse(version.parse(transformers.__version__).base_version) < version.parse("4.23"): - return - - if version.parse(version.parse(accelerate.__version__).base_version) < version.parse("0.14"): - return - - model_id = "CompVis/stable-diffusion-v1-4" - _ = StableDiffusionPipeline.from_pretrained( - model_id, revision="fp16", torch_dtype=torch.float16, use_auth_token=True, device_map="auto" - ).to(torch_device) - - @require_torch_gpu - def test_stable_diffusion_accelerate_load_reduces_memory_footprint(self): - if version.parse(version.parse(transformers.__version__).base_version) < version.parse("4.23"): - return - - if version.parse(version.parse(accelerate.__version__).base_version) < version.parse("0.14"): - return - - pipeline_id = "CompVis/stable-diffusion-v1-4" - - torch.cuda.empty_cache() - gc.collect() - - tracemalloc.start() - pipeline_normal_load = StableDiffusionPipeline.from_pretrained( - pipeline_id, revision="fp16", torch_dtype=torch.float16, use_auth_token=True - ) - pipeline_normal_load.to(torch_device) - _, peak_normal = tracemalloc.get_traced_memory() - tracemalloc.stop() - - del pipeline_normal_load - torch.cuda.empty_cache() - gc.collect() - - tracemalloc.start() - _ = StableDiffusionPipeline.from_pretrained( - pipeline_id, revision="fp16", torch_dtype=torch.float16, use_auth_token=True, device_map="auto" - ) - _, peak_accelerate = tracemalloc.get_traced_memory() - - tracemalloc.stop() - - assert peak_accelerate < peak_normal - - @slow - @unittest.skipIf(torch_device == "cpu", "This test is supposed to run on GPU") - def test_stable_diffusion_pipeline_with_unet_on_gpu_only(self): - torch.cuda.empty_cache() - torch.cuda.reset_max_memory_allocated() - - pipeline_id = "CompVis/stable-diffusion-v1-4" - prompt = "Andromeda galaxy in a bottle" - - pipeline = StableDiffusionPipeline.from_pretrained( - pipeline_id, revision="fp16", torch_dtype=torch.float32, use_auth_token=True - ) - pipeline.cuda_with_minimal_gpu_usage() - - _ = pipeline(prompt) - - mem_bytes = torch.cuda.max_memory_allocated() - # make sure that less than 0.8 GB is allocated - assert mem_bytes < 0.8 * 10**9