Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
0ea501b
add accelerate to load models with smaller memory footprint
piEsposito Sep 5, 2022
7631dd6
remove low_cpu_mem_usage as it is redundant
piEsposito Sep 12, 2022
973eb23
Merge branch 'main' of github.com:huggingface/diffusers into main
piEsposito Sep 12, 2022
8592e23
move accelerate init weights context to modelling utils
piEsposito Sep 16, 2022
76b8e4a
add test to ensure results are the same when loading with accelerate
piEsposito Sep 16, 2022
dd7f9b9
add tests to ensure ram usage gets lower when using accelerate
piEsposito Sep 16, 2022
ec5f7aa
move accelerate logic to single snippet under modelling utils and rem…
piEsposito Sep 16, 2022
ae5f56d
Merge branch 'huggingface:main' into main
piEsposito Sep 16, 2022
8392e3f
format code to pass quality check
piEsposito Sep 16, 2022
615054a
fix imports with isort
piEsposito Sep 16, 2022
75c08a9
add accelerate to test extra deps
piEsposito Sep 16, 2022
7e06f3d
Merge branch 'main' into main
piEsposito Sep 16, 2022
6189b86
only import accelerate if device_map is set to auto
piEsposito Sep 21, 2022
02818b5
Merge branch 'main' of github.com:piEsposito/diffusers into main
piEsposito Sep 21, 2022
dc14ace
Merge branch 'main' of github.com:huggingface/diffusers into main
piEsposito Sep 21, 2022
bc51061
move accelerate availability check to diffusers import utils
piEsposito Sep 22, 2022
ad1b55d
Merge remote-tracking branch 'upstream/main' into main
piEsposito Sep 22, 2022
e020d73
format code
piEsposito Sep 22, 2022
c3778bb
Merge branch 'main' into main
piEsposito Sep 22, 2022
0e2319d
Merge branch 'main' into main
piEsposito Oct 3, 2022
25e07d8
Merge branch 'main' into main
patrickvonplaten Oct 4, 2022
6206595
add device map to pipeline abstraction
piEsposito Oct 5, 2022
8912f53
lint it to pass PR quality check
piEsposito Oct 5, 2022
fd8829f
fix class check to use accelerate when using diffusers ModelMixin sub…
piEsposito Oct 5, 2022
85c2442
Merge branch 'main' of https://github.com/piEsposito/diffusers into main
patrickvonplaten Oct 5, 2022
8132cfd
use low_cpu_mem_usage in transformers if device_map is not available
piEsposito Oct 5, 2022
fe251f3
Merge branch 'main' into main
piEsposito Oct 5, 2022
4f7e319
Merge branch 'main' of https://github.com/piEsposito/diffusers into main
patrickvonplaten Oct 5, 2022
06adc23
Merge branch 'main' into main
piEsposito Oct 5, 2022
1fd2ea4
NoModuleLayer
patrickvonplaten Oct 7, 2022
c2d9a84
comment out tests
patrickvonplaten Oct 7, 2022
a7bb7f8
up
patrickvonplaten Oct 7, 2022
0a9bcd9
uP
patrickvonplaten Oct 7, 2022
6d0bbba
finish
patrickvonplaten Oct 7, 2022
e599033
Update src/diffusers/pipelines/stable_diffusion/safety_checker.py
patrickvonplaten Oct 7, 2022
dfaabd6
finish
patrickvonplaten Oct 7, 2022
3d406af
Merge branch 'piEspositoMain' of https://github.com/huggingface/diffu…
patrickvonplaten Oct 7, 2022
bbc05c5
uP
patrickvonplaten Oct 7, 2022
91f8a59
Merge branch 'main' into piEspositoMain
patrickvonplaten Oct 7, 2022
dcf7e0f
make style
patrickvonplaten Oct 10, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 21 additions & 1 deletion src/diffusers/pipeline_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,19 @@
from .configuration_utils import ConfigMixin
from .dynamic_modules_utils import get_class_from_dynamic_module
from .schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
from .utils import CONFIG_NAME, DIFFUSERS_CACHE, ONNX_WEIGHTS_NAME, WEIGHTS_NAME, BaseOutput, logging
from .utils import (
CONFIG_NAME,
DIFFUSERS_CACHE,
ONNX_WEIGHTS_NAME,
WEIGHTS_NAME,
BaseOutput,
is_transformers_available,
logging,
)


if is_transformers_available():
from transformers import PreTrainedModel


INDEX_FILE = "diffusion_pytorch_model.bin"
Expand Down Expand Up @@ -338,6 +350,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
custom_pipeline = kwargs.pop("custom_pipeline", None)
provider = kwargs.pop("provider", None)
sess_options = kwargs.pop("sess_options", None)
device_map = kwargs.pop("device_map", None)

# 1. Download the checkpoints and configs
# use snapshot download here to get it working from from_pretrained
Expand Down Expand Up @@ -463,6 +476,13 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
loading_kwargs["provider"] = provider
loading_kwargs["sess_options"] = sess_options

if (
issubclass(class_obj, diffusers.ModelMixin)
or is_transformers_available()
and issubclass(class_obj, PreTrainedModel)
):
loading_kwargs["device_map"] = device_map

# check if the module is in a subdirectory
if os.path.isdir(os.path.join(cached_folder, name)):
loaded_sub_model = load_method(os.path.join(cached_folder, name), **loading_kwargs)
Expand Down
6 changes: 4 additions & 2 deletions src/diffusers/pipelines/stable_diffusion/safety_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ def cosine_distance(image_embeds, text_embeds):
class StableDiffusionSafetyChecker(PreTrainedModel):
config_class = CLIPConfig

_no_split_modules = ["CLIPEncoderLayer"]
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Needed to allow the use of device_map="auto".


def __init__(self, config: CLIPConfig):
super().__init__(config)

Expand All @@ -28,8 +30,8 @@ def __init__(self, config: CLIPConfig):
self.concept_embeds = nn.Parameter(torch.ones(17, config.projection_dim), requires_grad=False)
self.special_care_embeds = nn.Parameter(torch.ones(3, config.projection_dim), requires_grad=False)

self.register_buffer("concept_embeds_weights", torch.ones(17))
self.register_buffer("special_care_embeds_weights", torch.ones(3))
self.concept_embeds_weights = nn.Parameter(torch.ones(17), requires_grad=False)
self.special_care_embeds_weights = nn.Parameter(torch.ones(3), requires_grad=False)

@torch.no_grad()
def forward(self, clip_input, images):
Expand Down
54 changes: 54 additions & 0 deletions tests/test_pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,15 @@
import os
import random
import tempfile
import tracemalloc
import unittest

import numpy as np
import torch

import accelerate

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The imports need to be sorted to make the linter happy; running isort locally gives this order:

import os
import random
import tempfile
import tracemalloc
import unittest

import accelerate
import numpy as np
import PIL
import torch

import transformers

import PIL
import transformers

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

see above comment for isort import grouping and ordering.

from diffusers import (
AutoencoderKL,
DDIMPipeline,
Expand Down Expand Up @@ -50,6 +53,7 @@
from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
from diffusers.utils import CONFIG_NAME, WEIGHTS_NAME, floats_tensor, load_image, slow, torch_device
from diffusers.utils.testing_utils import get_tests_dir
from packaging import version
from PIL import Image
from transformers import CLIPFeatureExtractor, CLIPModel, CLIPTextConfig, CLIPTextModel, CLIPTokenizer

Expand Down Expand Up @@ -2034,3 +2038,53 @@ def test_callback_fn(step: int, timestep: int, latents: np.ndarray) -> None:
pipe(prompt=prompt, num_inference_steps=5, guidance_scale=7.5, callback=test_callback_fn, callback_steps=1)
assert test_callback_fn.has_been_called
assert number_of_steps == 6

@slow
@unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU")
def test_stable_diffusion_accelerate_load_works(self):
    """Smoke test: load the pipeline with `device_map="auto"` and move it to `torch_device`.

    The test makes no assertion on the result; it passes when loading and the
    device move complete without raising. It is skipped when the installed
    `transformers` is older than 4.23 or `accelerate` is older than 0.14.
    """
    # Report unmet version requirements as skips: a bare `return` would be
    # counted as a pass even though nothing was exercised.
    if version.parse(version.parse(transformers.__version__).base_version) < version.parse("4.23"):
        raise unittest.SkipTest("Loading with device_map requires transformers >= 4.23")

    if version.parse(version.parse(accelerate.__version__).base_version) < version.parse("0.14"):
        raise unittest.SkipTest("Loading with device_map requires accelerate >= 0.14")

    model_id = "CompVis/stable-diffusion-v1-4"
    _ = StableDiffusionPipeline.from_pretrained(
        model_id, revision="fp16", torch_dtype=torch.float16, use_auth_token=True, device_map="auto"
    ).to(torch_device)

@slow
@unittest.skipIf(torch_device == "cpu", "This test is supposed to run on GPU")
def test_stable_diffusion_accelerate_load_reduces_memory_footprint(self):
    """Check that loading with `device_map="auto"` has a lower traced peak than a plain load.

    The pipeline is loaded twice, first without and then with `device_map="auto"`,
    and the peak values reported by `tracemalloc.get_traced_memory()` are compared.
    The test is skipped when the installed `transformers` is older than 4.23 or
    `accelerate` is older than 0.14.

    NOTE(review): tracemalloc only traces memory blocks allocated through Python's
    allocators; confirm that this captures the weight-loading footprint the test
    is meant to compare.
    """
    # Report unmet version requirements as skips: a bare `return` would be
    # counted as a pass even though nothing was measured.
    if version.parse(version.parse(transformers.__version__).base_version) < version.parse("4.23"):
        raise unittest.SkipTest("Loading with device_map requires transformers >= 4.23")

    if version.parse(version.parse(accelerate.__version__).base_version) < version.parse("0.14"):
        raise unittest.SkipTest("Loading with device_map requires accelerate >= 0.14")

    pipeline_id = "CompVis/stable-diffusion-v1-4"

    # Start from a clean slate so leftovers from earlier tests do not skew the peaks.
    torch.cuda.empty_cache()
    gc.collect()

    # Baseline: regular load followed by a move to the target device.
    tracemalloc.start()
    try:
        pipeline_normal_load = StableDiffusionPipeline.from_pretrained(
            pipeline_id, revision="fp16", torch_dtype=torch.float16, use_auth_token=True
        )
        pipeline_normal_load.to(torch_device)
        _, peak_normal = tracemalloc.get_traced_memory()
    finally:
        # Always stop tracing, otherwise a failed load leaves tracemalloc
        # enabled for every test that runs afterwards.
        tracemalloc.stop()

    # Release the baseline pipeline before the second measurement.
    del pipeline_normal_load
    torch.cuda.empty_cache()
    gc.collect()

    # Candidate: load with device_map="auto".
    tracemalloc.start()
    try:
        _ = StableDiffusionPipeline.from_pretrained(
            pipeline_id, revision="fp16", torch_dtype=torch.float16, use_auth_token=True, device_map="auto"
        )
        _, peak_accelerate = tracemalloc.get_traced_memory()
    finally:
        tracemalloc.stop()

    assert peak_accelerate < peak_normal, (
        f"Expected a lower traced peak with device_map='auto' ({peak_accelerate} bytes) "
        f"than with a regular load ({peak_normal} bytes)"
    )