Commit fab1752

[Low CPU memory] + device map (#772)
* add accelerate to load models with a smaller memory footprint
* remove low_cpu_mem_usage as it is redundant
* move accelerate init weights context to modeling utils
* add test to ensure results are the same when loading with accelerate
* add tests to ensure RAM usage gets lower when using accelerate
* move accelerate logic to a single snippet under modeling utils and remove it from configuration utils
* format code to pass the quality check
* fix imports with isort
* add accelerate to test extra deps
* only import accelerate if device_map is set to "auto"
* move accelerate availability check to diffusers import utils
* format code
* add device_map to the pipeline abstraction
* lint it to pass the PR quality check
* fix class check to use accelerate when using diffusers ModelMixin subclasses
* use low_cpu_mem_usage in transformers if device_map is not available
* NoModuleLayer
* comment out tests
* up
* uP
* finish
* Update src/diffusers/pipelines/stable_diffusion/safety_checker.py
* finish
* uP
* make style

Co-authored-by: Pi Esposito <[email protected]>
1 parent feaa732 commit fab1752

File tree: 3 files changed, +79 −3 lines changed

- src/diffusers/pipeline_utils.py
- src/diffusers/pipelines/stable_diffusion/safety_checker.py
- tests/test_pipelines.py

src/diffusers/pipeline_utils.py

Lines changed: 21 additions & 1 deletion

@@ -32,7 +32,19 @@
 from .configuration_utils import ConfigMixin
 from .dynamic_modules_utils import get_class_from_dynamic_module
 from .schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
-from .utils import CONFIG_NAME, DIFFUSERS_CACHE, ONNX_WEIGHTS_NAME, WEIGHTS_NAME, BaseOutput, logging
+from .utils import (
+    CONFIG_NAME,
+    DIFFUSERS_CACHE,
+    ONNX_WEIGHTS_NAME,
+    WEIGHTS_NAME,
+    BaseOutput,
+    is_transformers_available,
+    logging,
+)
+
+
+if is_transformers_available():
+    from transformers import PreTrainedModel
 
 
 INDEX_FILE = "diffusion_pytorch_model.bin"

@@ -338,6 +350,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
         custom_pipeline = kwargs.pop("custom_pipeline", None)
         provider = kwargs.pop("provider", None)
         sess_options = kwargs.pop("sess_options", None)
+        device_map = kwargs.pop("device_map", None)
 
         # 1. Download the checkpoints and configs
         # use snapshot download here to get it working from from_pretrained

@@ -463,6 +476,13 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
                 loading_kwargs["provider"] = provider
                 loading_kwargs["sess_options"] = sess_options
 
+            if (
+                issubclass(class_obj, diffusers.ModelMixin)
+                or is_transformers_available()
+                and issubclass(class_obj, PreTrainedModel)
+            ):
+                loading_kwargs["device_map"] = device_map
+
             # check if the module is in a subdirectory
             if os.path.isdir(os.path.join(cached_folder, name)):
                 loaded_sub_model = load_method(os.path.join(cached_folder, name), **loading_kwargs)
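The conditional above forwards the new device_map kwarg to every sub-model that can accept it: diffusers ModelMixin subclasses and, when transformers is installed, PreTrainedModel subclasses. A minimal usage sketch of the resulting API, using the same model id and fp16 revision as the tests further down (it assumes accelerate is installed and a CUDA device is available):

    # Minimal sketch: load a pipeline with accelerate's automatic device map.
    # device_map="auto" is forwarded to each ModelMixin / PreTrainedModel sub-model.
    import torch
    from diffusers import StableDiffusionPipeline

    pipe = StableDiffusionPipeline.from_pretrained(
        "CompVis/stable-diffusion-v1-4",
        revision="fp16",
        torch_dtype=torch.float16,
        use_auth_token=True,
        device_map="auto",
    )
    pipe = pipe.to("cuda")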

src/diffusers/pipelines/stable_diffusion/safety_checker.py

Lines changed: 4 additions & 2 deletions

@@ -19,6 +19,8 @@ def cosine_distance(image_embeds, text_embeds):
 class StableDiffusionSafetyChecker(PreTrainedModel):
     config_class = CLIPConfig
 
+    _no_split_modules = ["CLIPEncoderLayer"]
+
     def __init__(self, config: CLIPConfig):
         super().__init__(config)
 
@@ -28,8 +30,8 @@ def __init__(self, config: CLIPConfig):
         self.concept_embeds = nn.Parameter(torch.ones(17, config.projection_dim), requires_grad=False)
         self.special_care_embeds = nn.Parameter(torch.ones(3, config.projection_dim), requires_grad=False)
 
-        self.register_buffer("concept_embeds_weights", torch.ones(17))
-        self.register_buffer("special_care_embeds_weights", torch.ones(3))
+        self.concept_embeds_weights = nn.Parameter(torch.ones(17), requires_grad=False)
+        self.special_care_embeds_weights = nn.Parameter(torch.ones(3), requires_grad=False)
 
     @torch.no_grad()
     def forward(self, clip_input, images):
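The new _no_split_modules attribute tells accelerate which module classes must stay whole on a single device when an automatic device map is inferred. A rough sketch of that interaction, assuming accelerate >= 0.14 (instantiating the checker from a default CLIPConfig here is purely illustrative):

    # Rough sketch: accelerate keeps every module class listed in _no_split_modules
    # on a single device when inferring a device map.
    from accelerate import infer_auto_device_map, init_empty_weights
    from transformers import CLIPConfig

    from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker

    with init_empty_weights():
        # Parameters are created on the "meta" device, so no real weights are allocated here.
        checker = StableDiffusionSafetyChecker(CLIPConfig())

    device_map = infer_auto_device_map(
        checker, no_split_module_classes=checker._no_split_modules
    )
    print(device_map)  # module-name -> device assignments; listed classes are never sharded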

tests/test_pipelines.py

Lines changed: 54 additions & 0 deletions

@@ -17,12 +17,15 @@
 import os
 import random
 import tempfile
+import tracemalloc
 import unittest
 
 import numpy as np
 import torch
 
+import accelerate
 import PIL
+import transformers
 from diffusers import (
     AutoencoderKL,
     DDIMPipeline,

@@ -50,6 +53,7 @@
 from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
 from diffusers.utils import CONFIG_NAME, WEIGHTS_NAME, floats_tensor, load_image, slow, torch_device
 from diffusers.utils.testing_utils import get_tests_dir
+from packaging import version
 from PIL import Image
 from transformers import CLIPFeatureExtractor, CLIPModel, CLIPTextConfig, CLIPTextModel, CLIPTokenizer
 

@@ -2034,3 +2038,53 @@ def test_callback_fn(step: int, timestep: int, latents: np.ndarray) -> None:
         pipe(prompt=prompt, num_inference_steps=5, guidance_scale=7.5, callback=test_callback_fn, callback_steps=1)
         assert test_callback_fn.has_been_called
         assert number_of_steps == 6
+
+    @slow
+    @unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU")
+    def test_stable_diffusion_accelerate_load_works(self):
+        if version.parse(version.parse(transformers.__version__).base_version) < version.parse("4.23"):
+            return
+
+        if version.parse(version.parse(accelerate.__version__).base_version) < version.parse("0.14"):
+            return
+
+        model_id = "CompVis/stable-diffusion-v1-4"
+        _ = StableDiffusionPipeline.from_pretrained(
+            model_id, revision="fp16", torch_dtype=torch.float16, use_auth_token=True, device_map="auto"
+        ).to(torch_device)
+
+    @slow
+    @unittest.skipIf(torch_device == "cpu", "This test is supposed to run on GPU")
+    def test_stable_diffusion_accelerate_load_reduces_memory_footprint(self):
+        if version.parse(version.parse(transformers.__version__).base_version) < version.parse("4.23"):
+            return
+
+        if version.parse(version.parse(accelerate.__version__).base_version) < version.parse("0.14"):
+            return
+
+        pipeline_id = "CompVis/stable-diffusion-v1-4"
+
+        torch.cuda.empty_cache()
+        gc.collect()
+
+        tracemalloc.start()
+        pipeline_normal_load = StableDiffusionPipeline.from_pretrained(
+            pipeline_id, revision="fp16", torch_dtype=torch.float16, use_auth_token=True
+        )
+        pipeline_normal_load.to(torch_device)
+        _, peak_normal = tracemalloc.get_traced_memory()
+        tracemalloc.stop()
+
+        del pipeline_normal_load
+        torch.cuda.empty_cache()
+        gc.collect()
+
+        tracemalloc.start()
+        _ = StableDiffusionPipeline.from_pretrained(
+            pipeline_id, revision="fp16", torch_dtype=torch.float16, use_auth_token=True, device_map="auto"
+        )
+        _, peak_accelerate = tracemalloc.get_traced_memory()
+
+        tracemalloc.stop()
+
+        assert peak_accelerate < peak_normal
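The memory comparison above relies on tracemalloc, which traces allocations made through Python's memory allocator on the host; it does not measure GPU memory. A condensed sketch of the measurement pattern (the peak_cpu_allocations helper is illustrative only, not part of the test suite):

    # Condensed sketch of the measurement pattern used in the test above.
    import gc
    import tracemalloc

    def peak_cpu_allocations(load_fn):
        # Return the peak traced host allocation, in bytes, while load_fn runs.
        gc.collect()
        tracemalloc.start()
        obj = load_fn()
        _, peak = tracemalloc.get_traced_memory()
        tracemalloc.stop()
        del obj
        gc.collect()
        return peak

With such a helper, the test boils down to comparing the peak of a plain from_pretrained call against one passing device_map="auto".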
