Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 25 additions & 9 deletions src/diffusers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,12 @@
logger = logging.get_logger(__name__)


if is_torch_version(">=", "1.9.0"):
_LOW_CPU_MEM_USAGE_DEFAULT = True
else:
_LOW_CPU_MEM_USAGE_DEFAULT = False


def get_parameter_device(parameter: torch.nn.Module):
try:
return next(parameter.parameters()).device
Expand Down Expand Up @@ -278,11 +284,11 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For
more information about each option see [designing a device
map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
fast_load (`bool`, *optional*, defaults to `True`):
low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
Speed up model loading by not initializing the weights and only loading the pre-trained weights. This
also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch,
this argument will be ignored and the model will be loaded normally.
setting this argument to `True` will raise an error.

<Tip>

Expand Down Expand Up @@ -311,16 +317,26 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
torch_dtype = kwargs.pop("torch_dtype", None)
subfolder = kwargs.pop("subfolder", None)
device_map = kwargs.pop("device_map", None)
fast_load = kwargs.pop("fast_load", True)
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)

# Check if we can handle device_map and dispatching the weights
if device_map is not None and not is_torch_version(">=", "1.9.0"):
raise NotImplementedError("Loading and dispatching requires torch >= 1.9.0")
raise NotImplementedError(
"Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or sets"
" `device_map=None`."
)

# Fast init is only possible if torch version is >= 1.9.0
_INIT_EMPTY_WEIGHTS = fast_load or device_map is not None
if _INIT_EMPTY_WEIGHTS and not is_torch_version(">=", "1.9.0"):
logger.warn("Loading with `fast_load` requires torch >= 1.9.0. Falling back to normal loading.")
if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
raise NotImplementedError(
"Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or sets"
" `low_cpu_mem_usage=False`."
)

if low_cpu_mem_usage is False and device_map is not None:
raise ValueError(
f"You cannot set `low_cpu_mem_usage` to False while using device_map={device_map} for loading and"
" dispatching. Please make sure to set `low_cpu_mem_usage=True`."
)

user_agent = {
"diffusers": __version__,
Expand Down Expand Up @@ -403,7 +419,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P

# restore default dtype

if _INIT_EMPTY_WEIGHTS:
if low_cpu_mem_usage:
# Instantiate model with empty weights
with accelerate.init_empty_weights():
model, unused_kwargs = cls.from_config(
Expand Down
44 changes: 36 additions & 8 deletions src/diffusers/pipeline_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

import diffusers
import PIL
from accelerate.utils.versions import is_torch_version
from huggingface_hub import snapshot_download
from packaging import version
from PIL import Image
Expand All @@ -33,6 +34,7 @@
from .configuration_utils import ConfigMixin
from .dynamic_modules_utils import get_class_from_dynamic_module
from .hub_utils import http_user_agent
from .modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT
from .schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
from .utils import (
CONFIG_NAME,
Expand Down Expand Up @@ -328,6 +330,19 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
Mirror source to accelerate downloads in China. If you are from China and have an accessibility
problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
Please refer to the mirror site for more information. specify the folder name here.
device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
A map that specifies where each submodule should go. It doesn't need to be refined to each
parameter/buffer name, once a given module name is inside, every submodule of it will be sent to the
same device.

To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For
more information about each option see [designing a device
map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
Speed up model loading by not initializing the weights and only loading the pre-trained weights. This
also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch,
setting this argument to `True` will raise an error.

kwargs (remaining dictionary of keyword arguments, *optional*):
Can be used to overwrite load - and saveable variables - *i.e.* the pipeline components - of the
Expand Down Expand Up @@ -380,7 +395,25 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
provider = kwargs.pop("provider", None)
sess_options = kwargs.pop("sess_options", None)
device_map = kwargs.pop("device_map", None)
fast_load = kwargs.pop("fast_load", True)
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)

if device_map is not None and not is_torch_version(">=", "1.9.0"):
raise NotImplementedError(
"Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or sets"
" `device_map=None`."
)

if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
raise NotImplementedError(
"Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or sets"
" `low_cpu_mem_usage=False`."
)

if low_cpu_mem_usage is False and device_map is not None:
raise ValueError(
f"You cannot set `low_cpu_mem_usage` to False while using device_map={device_map} for loading and"
" dispatching. Please make sure to set `low_cpu_mem_usage=True`."
)

# 1. Download the checkpoints and configs
# use snapshot download here to get it working from from_pretrained
Expand Down Expand Up @@ -573,17 +606,12 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
and version.parse(version.parse(transformers.__version__).base_version) >= version.parse("4.20.0")
)

if is_diffusers_model:
loading_kwargs["fast_load"] = fast_load

# When loading a transformers model, if the device_map is None, the weights will be initialized as opposed to diffusers.
# To make default loading faster we set the `low_cpu_mem_usage=fast_load` flag which is `True` by default.
# To make default loading faster we set the `low_cpu_mem_usage=low_cpu_mem_usage` flag which is `True` by default.
# This makes sure that the weights won't be initialized which significantly speeds up loading.
if is_transformers_model and device_map is None:
loading_kwargs["low_cpu_mem_usage"] = fast_load

if is_diffusers_model or is_transformers_model:
loading_kwargs["device_map"] = device_map
loading_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage

# check if the module is in a subdirectory
if os.path.isdir(os.path.join(cached_folder, name)):
Expand Down
8 changes: 4 additions & 4 deletions tests/models/test_models_unet_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def test_from_pretrained_accelerate(self):

@unittest.skipIf(torch_device != "cuda", "This test is supposed to run on GPU")
def test_from_pretrained_accelerate_wont_change_results(self):
# by defautl model loading will use accelerate as `fast_load=True`
# by defautl model loading will use accelerate as `low_cpu_mem_usage=True`
model_accelerate, _ = UNet2DModel.from_pretrained("fusing/unet-ldm-dummy-update", output_loading_info=True)
model_accelerate.to(torch_device)
model_accelerate.eval()
Expand All @@ -156,7 +156,7 @@ def test_from_pretrained_accelerate_wont_change_results(self):
gc.collect()

model_normal_load, _ = UNet2DModel.from_pretrained(
"fusing/unet-ldm-dummy-update", output_loading_info=True, fast_init=False
"fusing/unet-ldm-dummy-update", output_loading_info=True, low_cpu_mem_usage=False
)
model_normal_load.to(torch_device)
model_normal_load.eval()
Expand All @@ -170,7 +170,7 @@ def test_memory_footprint_gets_reduced(self):
gc.collect()

tracemalloc.start()
# by defautl model loading will use accelerate as `fast_load=True`
# by defautl model loading will use accelerate as `low_cpu_mem_usage=True`
model_accelerate, _ = UNet2DModel.from_pretrained("fusing/unet-ldm-dummy-update", output_loading_info=True)
model_accelerate.to(torch_device)
model_accelerate.eval()
Expand All @@ -181,7 +181,7 @@ def test_memory_footprint_gets_reduced(self):
gc.collect()

model_normal_load, _ = UNet2DModel.from_pretrained(
"fusing/unet-ldm-dummy-update", output_loading_info=True, fast_init=False
"fusing/unet-ldm-dummy-update", output_loading_info=True, low_cpu_mem_usage=False
)
model_normal_load.to(torch_device)
model_normal_load.eval()
Expand Down
12 changes: 6 additions & 6 deletions tests/pipelines/stable_diffusion/test_stable_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -823,23 +823,23 @@ def test_callback_fn(step: int, timestep: int, latents: torch.FloatTensor) -> No
assert test_callback_fn.has_been_called
assert number_of_steps == 51

def test_stable_diffusion_fast_load(self):
def test_stable_diffusion_low_cpu_mem_usage(self):
pipeline_id = "CompVis/stable-diffusion-v1-4"

start_time = time.time()
pipeline_fast_load = StableDiffusionPipeline.from_pretrained(
pipeline_low_cpu_mem_usage = StableDiffusionPipeline.from_pretrained(
pipeline_id, revision="fp16", torch_dtype=torch.float16
)
pipeline_fast_load.to(torch_device)
fast_load_time = time.time() - start_time
pipeline_low_cpu_mem_usage.to(torch_device)
low_cpu_mem_usage_time = time.time() - start_time

start_time = time.time()
_ = StableDiffusionPipeline.from_pretrained(
pipeline_id, revision="fp16", torch_dtype=torch.float16, use_auth_token=True, fast_load=False
pipeline_id, revision="fp16", torch_dtype=torch.float16, use_auth_token=True, low_cpu_mem_usage=False
)
normal_load_time = time.time() - start_time

assert 2 * fast_load_time < normal_load_time
assert 2 * low_cpu_mem_usage_time < normal_load_time

@unittest.skipIf(torch_device == "cpu", "This test is supposed to run on GPU")
def test_stable_diffusion_pipeline_with_unet_on_gpu_only(self):
Expand Down