From ca902832a93136ec492ef6d7ba35334120d0ec91 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 3 Nov 2022 17:46:21 +0100 Subject: [PATCH 1/4] correct naming --- src/diffusers/modeling_utils.py | 14 ++++++++++---- src/diffusers/pipeline_utils.py | 8 ++++---- tests/models/test_models_unet_2d.py | 8 ++++---- .../stable_diffusion/test_stable_diffusion.py | 12 ++++++------ 4 files changed, 24 insertions(+), 18 deletions(-) diff --git a/src/diffusers/modeling_utils.py b/src/diffusers/modeling_utils.py index f4697636719e..26613c396c1f 100644 --- a/src/diffusers/modeling_utils.py +++ b/src/diffusers/modeling_utils.py @@ -35,6 +35,12 @@ logger = logging.get_logger(__name__) +if is_torch_version(">=", "1.9.0"): + _LOW_CPU_MEM_USAGE_DEFAULT = True +else: + _LOW_CPU_MEM_USAGE_DEFAULT = False + + def get_parameter_device(parameter: torch.nn.Module): try: return next(parameter.parameters()).device @@ -278,7 +284,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For more information about each option see [designing a device map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map). - fast_load (`bool`, *optional*, defaults to `True`): + low_cpu_mem_usage (`bool`, *optional*, defaults to `True`): Speed up model loading by not initializing the weights and only loading the pre-trained weights. This also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch, @@ -311,16 +317,16 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P torch_dtype = kwargs.pop("torch_dtype", None) subfolder = kwargs.pop("subfolder", None) device_map = kwargs.pop("device_map", None) - fast_load = kwargs.pop("fast_load", True) + low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT) # Check if we can handle device_map and dispatching the weights if device_map is not None and not is_torch_version(">=", "1.9.0"): raise NotImplementedError("Loading and dispatching requires torch >= 1.9.0") # Fast init is only possible if torch version is >= 1.9.0 - _INIT_EMPTY_WEIGHTS = fast_load or device_map is not None + _INIT_EMPTY_WEIGHTS = low_cpu_mem_usage or device_map is not None if _INIT_EMPTY_WEIGHTS and not is_torch_version(">=", "1.9.0"): - logger.warn("Loading with `fast_load` requires torch >= 1.9.0. Falling back to normal loading.") + logger.warn("Loading with `low_cpu_mem_usage` requires torch >= 1.9.0. Falling back to normal loading.") user_agent = { "diffusers": __version__, diff --git a/src/diffusers/pipeline_utils.py b/src/diffusers/pipeline_utils.py index 5c248ec1a956..47b689b5c500 100644 --- a/src/diffusers/pipeline_utils.py +++ b/src/diffusers/pipeline_utils.py @@ -380,7 +380,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P provider = kwargs.pop("provider", None) sess_options = kwargs.pop("sess_options", None) device_map = kwargs.pop("device_map", None) - fast_load = kwargs.pop("fast_load", True) + low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", True) # 1. Download the checkpoints and configs # use snapshot download here to get it working from from_pretrained @@ -574,13 +574,13 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P ) if is_diffusers_model: - loading_kwargs["fast_load"] = fast_load + loading_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage # When loading a transformers model, if the device_map is None, the weights will be initialized as opposed to diffusers. - # To make default loading faster we set the `low_cpu_mem_usage=fast_load` flag which is `True` by default. + # To make default loading faster we set the `low_cpu_mem_usage=low_cpu_mem_usage` flag which is `True` by default. # This makes sure that the weights won't be initialized which significantly speeds up loading. if is_transformers_model and device_map is None: - loading_kwargs["low_cpu_mem_usage"] = fast_load + loading_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage if is_diffusers_model or is_transformers_model: loading_kwargs["device_map"] = device_map diff --git a/tests/models/test_models_unet_2d.py b/tests/models/test_models_unet_2d.py index feee72457769..71ddf1a13414 100644 --- a/tests/models/test_models_unet_2d.py +++ b/tests/models/test_models_unet_2d.py @@ -133,7 +133,7 @@ def test_from_pretrained_accelerate(self): @unittest.skipIf(torch_device != "cuda", "This test is supposed to run on GPU") def test_from_pretrained_accelerate_wont_change_results(self): - # by defautl model loading will use accelerate as `fast_load=True` + # by defautl model loading will use accelerate as `low_cpu_mem_usage=True` model_accelerate, _ = UNet2DModel.from_pretrained("fusing/unet-ldm-dummy-update", output_loading_info=True) model_accelerate.to(torch_device) model_accelerate.eval() @@ -156,7 +156,7 @@ def test_from_pretrained_accelerate_wont_change_results(self): gc.collect() model_normal_load, _ = UNet2DModel.from_pretrained( - "fusing/unet-ldm-dummy-update", output_loading_info=True, fast_init=False + "fusing/unet-ldm-dummy-update", output_loading_info=True, low_cpu_mem_usage=False ) model_normal_load.to(torch_device) model_normal_load.eval() @@ -170,7 +170,7 @@ def test_memory_footprint_gets_reduced(self): gc.collect() tracemalloc.start() - # by defautl model loading will use accelerate as `fast_load=True` + # by defautl model loading will use accelerate as `low_cpu_mem_usage=True` model_accelerate, _ = UNet2DModel.from_pretrained("fusing/unet-ldm-dummy-update", output_loading_info=True) model_accelerate.to(torch_device) model_accelerate.eval() @@ -181,7 +181,7 @@ def test_memory_footprint_gets_reduced(self): gc.collect() model_normal_load, _ = UNet2DModel.from_pretrained( - "fusing/unet-ldm-dummy-update", output_loading_info=True, fast_init=False + "fusing/unet-ldm-dummy-update", output_loading_info=True, low_cpu_mem_usage=False ) model_normal_load.to(torch_device) model_normal_load.eval() diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion.py b/tests/pipelines/stable_diffusion/test_stable_diffusion.py index 0f7798735585..b01094a607ae 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion.py @@ -823,23 +823,23 @@ def test_callback_fn(step: int, timestep: int, latents: torch.FloatTensor) -> No assert test_callback_fn.has_been_called assert number_of_steps == 51 - def test_stable_diffusion_fast_load(self): + def test_stable_diffusion_low_cpu_mem_usage(self): pipeline_id = "CompVis/stable-diffusion-v1-4" start_time = time.time() - pipeline_fast_load = StableDiffusionPipeline.from_pretrained( + pipeline_low_cpu_mem_usage = StableDiffusionPipeline.from_pretrained( pipeline_id, revision="fp16", torch_dtype=torch.float16 ) - pipeline_fast_load.to(torch_device) - fast_load_time = time.time() - start_time + pipeline_low_cpu_mem_usage.to(torch_device) + low_cpu_mem_usage_time = time.time() - start_time start_time = time.time() _ = StableDiffusionPipeline.from_pretrained( - pipeline_id, revision="fp16", torch_dtype=torch.float16, use_auth_token=True, fast_load=False + pipeline_id, revision="fp16", torch_dtype=torch.float16, use_auth_token=True, low_cpu_mem_usage=False ) normal_load_time = time.time() - start_time - assert 2 * fast_load_time < normal_load_time + assert 2 * low_cpu_mem_usage_time < normal_load_time @unittest.skipIf(torch_device == "cpu", "This test is supposed to run on GPU") def test_stable_diffusion_pipeline_with_unet_on_gpu_only(self): From 28c94cbac75909cb453f3fd1a6dd841a4183b2ea Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 3 Nov 2022 17:55:21 +0100 Subject: [PATCH 2/4] finish --- src/diffusers/modeling_utils.py | 26 +++++++++++++------- src/diffusers/pipeline_utils.py | 42 +++++++++++++++++++++++++++------ 2 files changed, 53 insertions(+), 15 deletions(-) diff --git a/src/diffusers/modeling_utils.py b/src/diffusers/modeling_utils.py index 26613c396c1f..4e9e8b42e58f 100644 --- a/src/diffusers/modeling_utils.py +++ b/src/diffusers/modeling_utils.py @@ -284,11 +284,11 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For more information about each option see [designing a device map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map). - low_cpu_mem_usage (`bool`, *optional*, defaults to `True`): + low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`): Speed up model loading by not initializing the weights and only loading the pre-trained weights. This also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch, - this argument will be ignored and the model will be loaded normally. + setting this argument to `True` will raise an error. @@ -321,12 +321,22 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P # Check if we can handle device_map and dispatching the weights if device_map is not None and not is_torch_version(">=", "1.9.0"): - raise NotImplementedError("Loading and dispatching requires torch >= 1.9.0") + raise NotImplementedError( + "Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or sets" + " `device_map=None`." + ) + + if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"): + raise NotImplementedError( + "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or sets" + " `low_cpu_mem_usage=False`." + ) - # Fast init is only possible if torch version is >= 1.9.0 - _INIT_EMPTY_WEIGHTS = low_cpu_mem_usage or device_map is not None - if _INIT_EMPTY_WEIGHTS and not is_torch_version(">=", "1.9.0"): - logger.warn("Loading with `low_cpu_mem_usage` requires torch >= 1.9.0. Falling back to normal loading.") + if low_cpu_mem_usage is False and device_map is not None: + raise ValueError( + "You cannot set `low_cpu_mem_usage` to False while using device_map={device_map} for loading and" + " dispatching. Please make sure to set `low_cpu_mem_usage=True`." + ) user_agent = { "diffusers": __version__, @@ -409,7 +419,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P # restore default dtype - if _INIT_EMPTY_WEIGHTS: + if low_cpu_mem_usage: # Instantiate model with empty weights with accelerate.init_empty_weights(): model, unused_kwargs = cls.from_config( diff --git a/src/diffusers/pipeline_utils.py b/src/diffusers/pipeline_utils.py index 47b689b5c500..b47bba25a5df 100644 --- a/src/diffusers/pipeline_utils.py +++ b/src/diffusers/pipeline_utils.py @@ -25,6 +25,7 @@ import diffusers import PIL +from accelerate.utils.versions import is_torch_version from huggingface_hub import snapshot_download from packaging import version from PIL import Image @@ -33,6 +34,7 @@ from .configuration_utils import ConfigMixin from .dynamic_modules_utils import get_class_from_dynamic_module from .hub_utils import http_user_agent +from .modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT from .schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME from .utils import ( CONFIG_NAME, @@ -328,6 +330,19 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P Mirror source to accelerate downloads in China. If you are from China and have an accessibility problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety. Please refer to the mirror site for more information. specify the folder name here. + device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*): + A map that specifies where each submodule should go. It doesn't need to be refined to each + parameter/buffer name, once a given module name is inside, every submodule of it will be sent to the + same device. + + To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For + more information about each option see [designing a device + map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map). + low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`): + Speed up model loading by not initializing the weights and only loading the pre-trained weights. This + also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the + model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch, + setting this argument to `True` will raise an error. kwargs (remaining dictionary of keyword arguments, *optional*): Can be used to overwrite load - and saveable variables - *i.e.* the pipeline components - of the @@ -380,7 +395,25 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P provider = kwargs.pop("provider", None) sess_options = kwargs.pop("sess_options", None) device_map = kwargs.pop("device_map", None) - low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", True) + low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT) + + if device_map is not None and not is_torch_version(">=", "1.9.0"): + raise NotImplementedError( + "Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or sets" + " `device_map=None`." + ) + + if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"): + raise NotImplementedError( + "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or sets" + " `low_cpu_mem_usage=False`." + ) + + if low_cpu_mem_usage is False and device_map is not None: + raise ValueError( + "You cannot set `low_cpu_mem_usage` to False while using device_map={device_map} for loading and" + " dispatching. Please make sure to set `low_cpu_mem_usage=True`." + ) # 1. Download the checkpoints and configs # use snapshot download here to get it working from from_pretrained @@ -573,17 +606,12 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P and version.parse(version.parse(transformers.__version__).base_version) >= version.parse("4.20.0") ) - if is_diffusers_model: - loading_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage - # When loading a transformers model, if the device_map is None, the weights will be initialized as opposed to diffusers. # To make default loading faster we set the `low_cpu_mem_usage=low_cpu_mem_usage` flag which is `True` by default. # This makes sure that the weights won't be initialized which significantly speeds up loading. - if is_transformers_model and device_map is None: - loading_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage - if is_diffusers_model or is_transformers_model: loading_kwargs["device_map"] = device_map + loading_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage # check if the module is in a subdirectory if os.path.isdir(os.path.join(cached_folder, name)): From 75223abc67ee5e6ac58be7d715a0f453485a05a5 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 3 Nov 2022 18:01:26 +0100 Subject: [PATCH 3/4] Apply suggestions from code review --- src/diffusers/modeling_utils.py | 2 +- src/diffusers/pipeline_utils.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/modeling_utils.py b/src/diffusers/modeling_utils.py index 4e9e8b42e58f..4cd772ec3a0d 100644 --- a/src/diffusers/modeling_utils.py +++ b/src/diffusers/modeling_utils.py @@ -334,7 +334,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P if low_cpu_mem_usage is False and device_map is not None: raise ValueError( - "You cannot set `low_cpu_mem_usage` to False while using device_map={device_map} for loading and" + f"You cannot set `low_cpu_mem_usage` to False while using device_map={device_map} for loading and" " dispatching. Please make sure to set `low_cpu_mem_usage=True`." ) diff --git a/src/diffusers/pipeline_utils.py b/src/diffusers/pipeline_utils.py index b47bba25a5df..72189f996b2d 100644 --- a/src/diffusers/pipeline_utils.py +++ b/src/diffusers/pipeline_utils.py @@ -411,7 +411,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P if low_cpu_mem_usage is False and device_map is not None: raise ValueError( - "You cannot set `low_cpu_mem_usage` to False while using device_map={device_map} for loading and" + f"You cannot set `low_cpu_mem_usage` to False while using device_map={device_map} for loading and" " dispatching. Please make sure to set `low_cpu_mem_usage=True`." ) From 7d1dbfc56216a6ab5cdbdd419b863b166b158b32 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 3 Nov 2022 18:09:21 +0100 Subject: [PATCH 4/4] Apply suggestions from code review Co-authored-by: Suraj Patil --- src/diffusers/modeling_utils.py | 6 +++--- src/diffusers/pipeline_utils.py | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/diffusers/modeling_utils.py b/src/diffusers/modeling_utils.py index 4cd772ec3a0d..9e05672bf163 100644 --- a/src/diffusers/modeling_utils.py +++ b/src/diffusers/modeling_utils.py @@ -322,19 +322,19 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P # Check if we can handle device_map and dispatching the weights if device_map is not None and not is_torch_version(">=", "1.9.0"): raise NotImplementedError( - "Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or sets" + "Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set" " `device_map=None`." ) if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"): raise NotImplementedError( - "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or sets" + "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set" " `low_cpu_mem_usage=False`." ) if low_cpu_mem_usage is False and device_map is not None: raise ValueError( - f"You cannot set `low_cpu_mem_usage` to False while using device_map={device_map} for loading and" + f"You cannot set `low_cpu_mem_usage` to `False` while using device_map={device_map} for loading and" " dispatching. Please make sure to set `low_cpu_mem_usage=True`." ) diff --git a/src/diffusers/pipeline_utils.py b/src/diffusers/pipeline_utils.py index 72189f996b2d..36c2d5b888ef 100644 --- a/src/diffusers/pipeline_utils.py +++ b/src/diffusers/pipeline_utils.py @@ -399,13 +399,13 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P if device_map is not None and not is_torch_version(">=", "1.9.0"): raise NotImplementedError( - "Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or sets" + "Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set" " `device_map=None`." ) if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"): raise NotImplementedError( - "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or sets" + "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set" " `low_cpu_mem_usage=False`." )