@@ -17,12 +17,15 @@
 import os
 import random
 import tempfile
+import tracemalloc
 import unittest
 
 import numpy as np
 import torch
 
+import accelerate
 import PIL
+import transformers
 from diffusers import (
     AutoencoderKL,
     DDIMPipeline,
@@ -50,6 +53,7 @@
 from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
 from diffusers.utils import CONFIG_NAME, WEIGHTS_NAME, floats_tensor, load_image, slow, torch_device
 from diffusers.utils.testing_utils import get_tests_dir
+from packaging import version
 from PIL import Image
 from transformers import CLIPFeatureExtractor, CLIPModel, CLIPTextConfig, CLIPTextModel, CLIPTokenizer
 
@@ -2034,3 +2038,53 @@ def test_callback_fn(step: int, timestep: int, latents: np.ndarray) -> None:
         pipe(prompt=prompt, num_inference_steps=5, guidance_scale=7.5, callback=test_callback_fn, callback_steps=1)
         assert test_callback_fn.has_been_called
         assert number_of_steps == 6
+
+    @slow
+    @unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU")
+    def test_stable_diffusion_accelerate_load_works(self):
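+        # the accelerate-backed loading path needs transformers >= 4.23 and accelerate >= 0.14; skip quietly on older versions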
+        if version.parse(version.parse(transformers.__version__).base_version) < version.parse("4.23"):
+            return
+
+        if version.parse(version.parse(accelerate.__version__).base_version) < version.parse("0.14"):
+            return
+
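+        # smoke test: loading with device_map="auto" and moving the pipeline to the device should simply succeed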
+        model_id = "CompVis/stable-diffusion-v1-4"
+        _ = StableDiffusionPipeline.from_pretrained(
+            model_id, revision="fp16", torch_dtype=torch.float16, use_auth_token=True, device_map="auto"
+        ).to(torch_device)
+
+    @slow
+    @unittest.skipIf(torch_device == "cpu", "This test is supposed to run on GPU")
+    def test_stable_diffusion_accelerate_load_reduces_memory_footprint(self):
+        if version.parse(version.parse(transformers.__version__).base_version) < version.parse("4.23"):
+            return
+
+        if version.parse(version.parse(accelerate.__version__).base_version) < version.parse("0.14"):
+            return
+
+        pipeline_id = "CompVis/stable-diffusion-v1-4"
+
+        torch.cuda.empty_cache()
+        gc.collect()
+
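+        # track peak memory allocations (as traced by tracemalloc) for a standard load followed by .to(device)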
+        tracemalloc.start()
+        pipeline_normal_load = StableDiffusionPipeline.from_pretrained(
+            pipeline_id, revision="fp16", torch_dtype=torch.float16, use_auth_token=True
+        )
+        pipeline_normal_load.to(torch_device)
+        _, peak_normal = tracemalloc.get_traced_memory()
+        tracemalloc.stop()
+
+        del pipeline_normal_load
+        torch.cuda.empty_cache()
+        gc.collect()
+
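+        # load again, this time letting accelerate place the weights via device_map="auto"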
+        tracemalloc.start()
+        _ = StableDiffusionPipeline.from_pretrained(
+            pipeline_id, revision="fp16", torch_dtype=torch.float16, use_auth_token=True, device_map="auto"
+        )
+        _, peak_accelerate = tracemalloc.get_traced_memory()
+
+        tracemalloc.stop()
+
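+        # the accelerate-backed load is expected to peak lower than the standard load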
+        assert peak_accelerate < peak_normal