|
17 | 17 | import os |
18 | 18 | import random |
19 | 19 | import tempfile |
20 | | -import tracemalloc |
21 | 20 | import unittest |
22 | 21 |
|
23 | 22 | import numpy as np |
24 | 23 | import torch |
25 | 24 |
|
26 | | -import accelerate |
27 | 25 | import PIL |
28 | | -import transformers |
29 | 26 | from diffusers import ( |
30 | 27 | AutoencoderKL, |
31 | 28 | DDIMPipeline, |
|
44 | 41 | from diffusers.pipeline_utils import DiffusionPipeline |
45 | 42 | from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME |
46 | 43 | from diffusers.utils import CONFIG_NAME, WEIGHTS_NAME, floats_tensor, slow, torch_device |
47 | | -from diffusers.utils.testing_utils import CaptureLogger, get_tests_dir, require_torch_gpu |
48 | | -from packaging import version |
| 44 | +from diffusers.utils.testing_utils import CaptureLogger, get_tests_dir |
49 | 45 | from PIL import Image |
50 | 46 | from transformers import CLIPFeatureExtractor, CLIPModel, CLIPTextConfig, CLIPTextModel, CLIPTokenizer |
51 | 47 |
|
@@ -487,71 +483,3 @@ def test_ddpm_ddim_equality_batched(self): |
487 | 483 |
|
488 | 484 | # the values aren't exactly equal, but the images look the same visually |
489 | 485 | assert np.abs(ddpm_images - ddim_images).max() < 1e-1 |
490 | | - |
491 | | - @require_torch_gpu |
492 | | - def test_stable_diffusion_accelerate_load_works(self): |
493 | | - if version.parse(version.parse(transformers.__version__).base_version) < version.parse("4.23"): |
494 | | - return |
495 | | - |
496 | | - if version.parse(version.parse(accelerate.__version__).base_version) < version.parse("0.14"): |
497 | | - return |
498 | | - |
499 | | - model_id = "CompVis/stable-diffusion-v1-4" |
500 | | - _ = StableDiffusionPipeline.from_pretrained( |
501 | | - model_id, revision="fp16", torch_dtype=torch.float16, use_auth_token=True, device_map="auto" |
502 | | - ).to(torch_device) |
503 | | - |
504 | | - @require_torch_gpu |
505 | | - def test_stable_diffusion_accelerate_load_reduces_memory_footprint(self): |
506 | | - if version.parse(version.parse(transformers.__version__).base_version) < version.parse("4.23"): |
507 | | - return |
508 | | - |
509 | | - if version.parse(version.parse(accelerate.__version__).base_version) < version.parse("0.14"): |
510 | | - return |
511 | | - |
512 | | - pipeline_id = "CompVis/stable-diffusion-v1-4" |
513 | | - |
514 | | - torch.cuda.empty_cache() |
515 | | - gc.collect() |
516 | | - |
517 | | - tracemalloc.start() |
518 | | - pipeline_normal_load = StableDiffusionPipeline.from_pretrained( |
519 | | - pipeline_id, revision="fp16", torch_dtype=torch.float16, use_auth_token=True |
520 | | - ) |
521 | | - pipeline_normal_load.to(torch_device) |
522 | | - _, peak_normal = tracemalloc.get_traced_memory() |
523 | | - tracemalloc.stop() |
524 | | - |
525 | | - del pipeline_normal_load |
526 | | - torch.cuda.empty_cache() |
527 | | - gc.collect() |
528 | | - |
529 | | - tracemalloc.start() |
530 | | - _ = StableDiffusionPipeline.from_pretrained( |
531 | | - pipeline_id, revision="fp16", torch_dtype=torch.float16, use_auth_token=True, device_map="auto" |
532 | | - ) |
533 | | - _, peak_accelerate = tracemalloc.get_traced_memory() |
534 | | - |
535 | | - tracemalloc.stop() |
536 | | - |
537 | | - assert peak_accelerate < peak_normal |
538 | | - |
539 | | - @slow |
540 | | - @unittest.skipIf(torch_device == "cpu", "This test is supposed to run on GPU") |
541 | | - def test_stable_diffusion_pipeline_with_unet_on_gpu_only(self): |
542 | | - torch.cuda.empty_cache() |
543 | | - torch.cuda.reset_max_memory_allocated() |
544 | | - |
545 | | - pipeline_id = "CompVis/stable-diffusion-v1-4" |
546 | | - prompt = "Andromeda galaxy in a bottle" |
547 | | - |
548 | | - pipeline = StableDiffusionPipeline.from_pretrained( |
549 | | - pipeline_id, revision="fp16", torch_dtype=torch.float32, use_auth_token=True |
550 | | - ) |
551 | | - pipeline.cuda_with_minimal_gpu_usage() |
552 | | - |
553 | | - _ = pipeline(prompt) |
554 | | - |
555 | | - mem_bytes = torch.cuda.max_memory_allocated() |
556 | | - # make sure that less than 0.8 GB is allocated |
557 | | - assert mem_bytes < 0.8 * 10**9 |
0 commit comments