
Commit 42bb459

[Low cpu memory] Correct naming and improve default usage (#1122)
* correct naming
* finish
* Apply suggestions from code review
* Apply suggestions from code review

Co-authored-by: Suraj Patil <[email protected]>
1 parent 988c822 commit 42bb459

File tree

4 files changed, +71 -27 lines changed


src/diffusers/modeling_utils.py

Lines changed: 25 additions & 9 deletions
@@ -35,6 +35,12 @@
 logger = logging.get_logger(__name__)


+if is_torch_version(">=", "1.9.0"):
+    _LOW_CPU_MEM_USAGE_DEFAULT = True
+else:
+    _LOW_CPU_MEM_USAGE_DEFAULT = False
+
+
 def get_parameter_device(parameter: torch.nn.Module):
     try:
         return next(parameter.parameters()).device
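
Note: with this hunk, users on torch >= 1.9.0 get low-CPU-memory loading without passing anything. A minimal usage sketch (the checkpoint id is the dummy model used in the tests below):

    from diffusers import UNet2DModel

    # torch >= 1.9.0: low_cpu_mem_usage defaults to True, so this already
    # loads via accelerate's empty-weights path
    model = UNet2DModel.from_pretrained("fusing/unet-ldm-dummy-update")

    # opting out restores the old behaviour of fully initializing the
    # weights before the checkpoint is loaded
    model = UNet2DModel.from_pretrained("fusing/unet-ldm-dummy-update", low_cpu_mem_usage=False)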
@@ -278,11 +284,11 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
                 To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For
                 more information about each option see [designing a device
                 map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
-            fast_load (`bool`, *optional*, defaults to `True`):
+            low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
                 Speed up model loading by not initializing the weights and only loading the pre-trained weights. This
                 also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
                 model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch,
-                this argument will be ignored and the model will be loaded normally.
+                setting this argument to `True` will raise an error.

         <Tip>
@@ -311,16 +317,26 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
         torch_dtype = kwargs.pop("torch_dtype", None)
         subfolder = kwargs.pop("subfolder", None)
         device_map = kwargs.pop("device_map", None)
-        fast_load = kwargs.pop("fast_load", True)
+        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)

         # Check if we can handle device_map and dispatching the weights
         if device_map is not None and not is_torch_version(">=", "1.9.0"):
-            raise NotImplementedError("Loading and dispatching requires torch >= 1.9.0")
+            raise NotImplementedError(
+                "Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set"
+                " `device_map=None`."
+            )

-        # Fast init is only possible if torch version is >= 1.9.0
-        _INIT_EMPTY_WEIGHTS = fast_load or device_map is not None
-        if _INIT_EMPTY_WEIGHTS and not is_torch_version(">=", "1.9.0"):
-            logger.warn("Loading with `fast_load` requires torch >= 1.9.0. Falling back to normal loading.")
+        if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
+            raise NotImplementedError(
+                "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
+                " `low_cpu_mem_usage=False`."
+            )
+
+        if low_cpu_mem_usage is False and device_map is not None:
+            raise ValueError(
+                f"You cannot set `low_cpu_mem_usage` to `False` while using device_map={device_map} for loading and"
+                " dispatching. Please make sure to set `low_cpu_mem_usage=True`."
+            )

         user_agent = {
             "diffusers": __version__,
@@ -403,7 +419,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P

         # restore default dtype

-        if _INIT_EMPTY_WEIGHTS:
+        if low_cpu_mem_usage:
             # Instantiate model with empty weights
             with accelerate.init_empty_weights():
                 model, unused_kwargs = cls.from_config(
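
Note: `accelerate.init_empty_weights()` is what keeps peak usage near 1x model size: inside the context manager, parameters are created on PyTorch's `meta` device, so no host memory is allocated until the checkpoint tensors are loaded in. A standalone sketch:

    import torch
    from accelerate import init_empty_weights

    with init_empty_weights():
        layer = torch.nn.Linear(4096, 4096)  # no 64 MB weight allocation here

    print(layer.weight.device)  # device(type='meta')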

src/diffusers/pipeline_utils.py

Lines changed: 36 additions & 8 deletions
@@ -25,6 +25,7 @@

 import diffusers
 import PIL
+from accelerate.utils.versions import is_torch_version
 from huggingface_hub import snapshot_download
 from packaging import version
 from PIL import Image
@@ -33,6 +34,7 @@
 from .configuration_utils import ConfigMixin
 from .dynamic_modules_utils import get_class_from_dynamic_module
 from .hub_utils import http_user_agent
+from .modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT
 from .schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
 from .utils import (
     CONFIG_NAME,
@@ -328,6 +330,19 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
                 Mirror source to accelerate downloads in China. If you are from China and have an accessibility
                 problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
                 Please refer to the mirror site for more information. specify the folder name here.
+            device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
+                A map that specifies where each submodule should go. It doesn't need to be refined to each
+                parameter/buffer name, once a given module name is inside, every submodule of it will be sent to the
+                same device.
+
+                To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For
+                more information about each option see [designing a device
+                map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
+            low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
+                Speed up model loading by not initializing the weights and only loading the pre-trained weights. This
+                also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
+                model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch,
+                setting this argument to `True` will raise an error.

             kwargs (remaining dictionary of keyword arguments, *optional*):
                 Can be used to overwrite load - and saveable variables - *i.e.* the pipeline components - of the
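
Note: a usage sketch combining the two newly documented arguments (the pipeline id is the one used in the tests below; `device_map="auto"` requires accelerate and torch >= 1.9.0):

    import torch
    from diffusers import StableDiffusionPipeline

    # low_cpu_mem_usage defaults to True on torch >= 1.9.0, and "auto" lets
    # Accelerate decide where each submodule (unet, vae, text encoder, ...) goes
    pipe = StableDiffusionPipeline.from_pretrained(
        "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16, device_map="auto"
    )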
@@ -380,7 +395,25 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
         provider = kwargs.pop("provider", None)
         sess_options = kwargs.pop("sess_options", None)
         device_map = kwargs.pop("device_map", None)
-        fast_load = kwargs.pop("fast_load", True)
+        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
+
+        if device_map is not None and not is_torch_version(">=", "1.9.0"):
+            raise NotImplementedError(
+                "Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set"
+                " `device_map=None`."
+            )
+
+        if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
+            raise NotImplementedError(
+                "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
+                " `low_cpu_mem_usage=False`."
+            )
+
+        if low_cpu_mem_usage is False and device_map is not None:
+            raise ValueError(
+                f"You cannot set `low_cpu_mem_usage` to False while using device_map={device_map} for loading and"
+                " dispatching. Please make sure to set `low_cpu_mem_usage=True`."
+            )

         # 1. Download the checkpoints and configs
         # use snapshot download here to get it working from from_pretrained
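
Note: because old-torch users now get an exception rather than a silent fallback, callers that must keep working on torch < 1.9.0 can catch it. A defensive sketch (the retry strategy is an assumption, not something this commit prescribes):

    from diffusers import StableDiffusionPipeline

    pipeline_id = "CompVis/stable-diffusion-v1-4"
    try:
        pipe = StableDiffusionPipeline.from_pretrained(pipeline_id, low_cpu_mem_usage=True)
    except NotImplementedError:
        # torch < 1.9.0: fall back to the slower, fully initializing load path
        pipe = StableDiffusionPipeline.from_pretrained(pipeline_id, low_cpu_mem_usage=False)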
@@ -573,17 +606,12 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
             and version.parse(version.parse(transformers.__version__).base_version) >= version.parse("4.20.0")
         )

-        if is_diffusers_model:
-            loading_kwargs["fast_load"] = fast_load
-
         # When loading a transformers model, if the device_map is None, the weights will be initialized as opposed to diffusers.
-        # To make default loading faster we set the `low_cpu_mem_usage=fast_load` flag which is `True` by default.
+        # To make default loading faster we set the `low_cpu_mem_usage=low_cpu_mem_usage` flag which is `True` by default.
         # This makes sure that the weights won't be initialized which significantly speeds up loading.
-        if is_transformers_model and device_map is None:
-            loading_kwargs["low_cpu_mem_usage"] = fast_load
-
         if is_diffusers_model or is_transformers_model:
             loading_kwargs["device_map"] = device_map
+            loading_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage

         # check if the module is in a subdirectory
         if os.path.isdir(os.path.join(cached_folder, name)):

tests/models/test_models_unet_2d.py

Lines changed: 4 additions & 4 deletions
@@ -133,7 +133,7 @@ def test_from_pretrained_accelerate(self):

     @unittest.skipIf(torch_device != "cuda", "This test is supposed to run on GPU")
     def test_from_pretrained_accelerate_wont_change_results(self):
-        # by defautl model loading will use accelerate as `fast_load=True`
+        # by default model loading will use accelerate as `low_cpu_mem_usage=True`
         model_accelerate, _ = UNet2DModel.from_pretrained("fusing/unet-ldm-dummy-update", output_loading_info=True)
         model_accelerate.to(torch_device)
         model_accelerate.eval()
@@ -156,7 +156,7 @@ def test_from_pretrained_accelerate_wont_change_results(self):
         gc.collect()

         model_normal_load, _ = UNet2DModel.from_pretrained(
-            "fusing/unet-ldm-dummy-update", output_loading_info=True, fast_init=False
+            "fusing/unet-ldm-dummy-update", output_loading_info=True, low_cpu_mem_usage=False
         )
         model_normal_load.to(torch_device)
         model_normal_load.eval()
@@ -170,7 +170,7 @@ def test_memory_footprint_gets_reduced(self):
         gc.collect()

         tracemalloc.start()
-        # by defautl model loading will use accelerate as `fast_load=True`
+        # by default model loading will use accelerate as `low_cpu_mem_usage=True`
         model_accelerate, _ = UNet2DModel.from_pretrained("fusing/unet-ldm-dummy-update", output_loading_info=True)
         model_accelerate.to(torch_device)
         model_accelerate.eval()
@@ -181,7 +181,7 @@ def test_memory_footprint_gets_reduced(self):
         gc.collect()

         model_normal_load, _ = UNet2DModel.from_pretrained(
-            "fusing/unet-ldm-dummy-update", output_loading_info=True, fast_init=False
+            "fusing/unet-ldm-dummy-update", output_loading_info=True, low_cpu_mem_usage=False
         )
         model_normal_load.to(torch_device)
         model_normal_load.eval()
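
Note: the renamed flag keeps the memory guarantee these tests check. A minimal sketch of the same peak-memory comparison outside the test suite, mirroring `test_memory_footprint_gets_reduced` (assumes torch >= 1.9.0 so the first load takes the accelerate path by default):

    import tracemalloc
    from diffusers import UNet2DModel

    tracemalloc.start()
    UNet2DModel.from_pretrained("fusing/unet-ldm-dummy-update")
    _, peak_low_mem = tracemalloc.get_traced_memory()
    tracemalloc.stop()

    tracemalloc.start()
    UNet2DModel.from_pretrained("fusing/unet-ldm-dummy-update", low_cpu_mem_usage=False)
    _, peak_normal = tracemalloc.get_traced_memory()
    tracemalloc.stop()

    assert peak_low_mem < peak_normal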

tests/pipelines/stable_diffusion/test_stable_diffusion.py

Lines changed: 6 additions & 6 deletions
@@ -823,23 +823,23 @@ def test_callback_fn(step: int, timestep: int, latents: torch.FloatTensor) -> No
         assert test_callback_fn.has_been_called
         assert number_of_steps == 51

-    def test_stable_diffusion_fast_load(self):
+    def test_stable_diffusion_low_cpu_mem_usage(self):
         pipeline_id = "CompVis/stable-diffusion-v1-4"

         start_time = time.time()
-        pipeline_fast_load = StableDiffusionPipeline.from_pretrained(
+        pipeline_low_cpu_mem_usage = StableDiffusionPipeline.from_pretrained(
             pipeline_id, revision="fp16", torch_dtype=torch.float16
         )
-        pipeline_fast_load.to(torch_device)
-        fast_load_time = time.time() - start_time
+        pipeline_low_cpu_mem_usage.to(torch_device)
+        low_cpu_mem_usage_time = time.time() - start_time

         start_time = time.time()
         _ = StableDiffusionPipeline.from_pretrained(
-            pipeline_id, revision="fp16", torch_dtype=torch.float16, use_auth_token=True, fast_load=False
+            pipeline_id, revision="fp16", torch_dtype=torch.float16, use_auth_token=True, low_cpu_mem_usage=False
         )
         normal_load_time = time.time() - start_time

-        assert 2 * fast_load_time < normal_load_time
+        assert 2 * low_cpu_mem_usage_time < normal_load_time

     @unittest.skipIf(torch_device == "cpu", "This test is supposed to run on GPU")
     def test_stable_diffusion_pipeline_with_unet_on_gpu_only(self):
