diff --git a/setup.py b/setup.py
index affb2e06fc56..f51e044e9628 100644
--- a/setup.py
+++ b/setup.py
@@ -104,6 +104,7 @@
     "torch>=1.4",
     "torchvision",
     "transformers>=4.21.0",
+    "accelerate>=0.12.0"
 ]

 # this is a lookup table with items like:
diff --git a/src/diffusers/modeling_utils.py b/src/diffusers/modeling_utils.py
index 659f2ee8a66a..4d609043d731 100644
--- a/src/diffusers/modeling_utils.py
+++ b/src/diffusers/modeling_utils.py
@@ -21,6 +21,7 @@
 import torch
 from torch import Tensor, device

+from diffusers.utils import is_accelerate_available
 from huggingface_hub import hf_hub_download
 from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
 from requests import HTTPError
@@ -293,33 +294,13 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
         from_auto_class = kwargs.pop("_from_auto", False)
         torch_dtype = kwargs.pop("torch_dtype", None)
         subfolder = kwargs.pop("subfolder", None)
+        device_map = kwargs.pop("device_map", None)

         user_agent = {"file_type": "model", "framework": "pytorch", "from_auto_class": from_auto_class}

         # Load config if we don't provide a configuration
         config_path = pretrained_model_name_or_path

-        model, unused_kwargs = cls.from_config(
-            config_path,
-            cache_dir=cache_dir,
-            return_unused_kwargs=True,
-            force_download=force_download,
-            resume_download=resume_download,
-            proxies=proxies,
-            local_files_only=local_files_only,
-            use_auth_token=use_auth_token,
-            revision=revision,
-            subfolder=subfolder,
-            **kwargs,
-        )
-        if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
-            raise ValueError(
-                f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}."
-            )
-        elif torch_dtype is not None:
-            model = model.to(torch_dtype)
-
-        model.register_to_config(_name_or_path=pretrained_model_name_or_path)

         # This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the
         # Load model
         pretrained_model_name_or_path = str(pretrained_model_name_or_path)
@@ -391,25 +372,81 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
             )

         # restore default dtype
-        state_dict = load_state_dict(model_file)
-        model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
-            model,
-            state_dict,
-            model_file,
-            pretrained_model_name_or_path,
-            ignore_mismatched_sizes=ignore_mismatched_sizes,
-        )
-
-        # Set model in evaluation mode to deactivate DropOut modules by default
-        model.eval()
+        if device_map == "auto":
+            if is_accelerate_available():
+                import accelerate
+            else:
+                raise ImportError("Please install accelerate via `pip install accelerate`")
+
+            with accelerate.init_empty_weights():
+                model, unused_kwargs = cls.from_config(
+                    config_path,
+                    cache_dir=cache_dir,
+                    return_unused_kwargs=True,
+                    force_download=force_download,
+                    resume_download=resume_download,
+                    proxies=proxies,
+                    local_files_only=local_files_only,
+                    use_auth_token=use_auth_token,
+                    revision=revision,
+                    subfolder=subfolder,
+                    device_map=device_map,
+                    **kwargs,
+                )
+
+            accelerate.load_checkpoint_and_dispatch(model, model_file, device_map)
+
+            loading_info = {
+                "missing_keys": [],
+                "unexpected_keys": [],
+                "mismatched_keys": [],
+                "error_msgs": [],
+            }
+        else:
+            model, unused_kwargs = cls.from_config(
+                config_path,
+                cache_dir=cache_dir,
+                return_unused_kwargs=True,
+                force_download=force_download,
+                resume_download=resume_download,
+                proxies=proxies,
+                local_files_only=local_files_only,
+                use_auth_token=use_auth_token,
+                revision=revision,
+                subfolder=subfolder,
+                device_map=device_map,
+                **kwargs,
+            )
+
+            state_dict = load_state_dict(model_file)
+            model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
+                model,
+                state_dict,
+                model_file,
+                pretrained_model_name_or_path,
+                ignore_mismatched_sizes=ignore_mismatched_sizes,
+            )

-        if output_loading_info:
             loading_info = {
                 "missing_keys": missing_keys,
                 "unexpected_keys": unexpected_keys,
                 "mismatched_keys": mismatched_keys,
                 "error_msgs": error_msgs,
             }
+
+        if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
+            raise ValueError(
+                f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}."
+            )
+        elif torch_dtype is not None:
+            model = model.to(torch_dtype)
+
+        model.register_to_config(_name_or_path=pretrained_model_name_or_path)
+
+        # Set model in evaluation mode to deactivate DropOut modules by default
+        model.eval()
+        if output_loading_info:
             return model, loading_info

         return model
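The modeling_utils change above boils down to accelerate's two-step "big model" loading: build the module with empty (meta) weights, then materialize the checkpoint directly onto whatever devices the device map assigns. A minimal self-contained sketch of that pattern, assuming accelerate is installed; the toy module and the "ckpt.pt" path are placeholders for illustration, not diffusers code:

import torch

import accelerate

with accelerate.init_empty_weights():
    # Parameters are created on the "meta" device: shapes and dtypes only,
    # no storage is allocated for the weights yet.
    model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.ReLU())

# Materialize the weights straight onto the mapped devices, skipping the
# usual "load the full state_dict on CPU, then move" round trip.
# "ckpt.pt" is a placeholder checkpoint path.
accelerate.load_checkpoint_and_dispatch(model, "ckpt.pt", device_map="auto")

Note that the device_map="auto" branch never calls _load_pretrained_model, which is why it returns a loading_info dict of empty lists.
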
diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py
index c1285bb8c23d..83f0d5c92bc8 100644
--- a/src/diffusers/utils/__init__.py
+++ b/src/diffusers/utils/__init__.py
@@ -23,6 +23,7 @@
     USE_TF,
     USE_TORCH,
     DummyObject,
+    is_accelerate_available,
     is_flax_available,
     is_inflect_available,
     is_modelcards_available,
diff --git a/src/diffusers/utils/import_utils.py b/src/diffusers/utils/import_utils.py
index de344d074da0..b2aabee70c92 100644
--- a/src/diffusers/utils/import_utils.py
+++ b/src/diffusers/utils/import_utils.py
@@ -159,6 +159,13 @@
 except importlib_metadata.PackageNotFoundError:
     _scipy_available = False

+_accelerate_available = importlib.util.find_spec("accelerate") is not None
+try:
+    _accelerate_version = importlib_metadata.version("accelerate")
+    logger.debug(f"Successfully imported accelerate version {_accelerate_version}")
+except importlib_metadata.PackageNotFoundError:
+    _accelerate_available = False
+

 def is_torch_available():
     return _torch_available
@@ -196,6 +203,10 @@
 def is_scipy_available():
     return _scipy_available


+def is_accelerate_available():
+    return _accelerate_available
+
+
 # docstyle-ignore
 FLAX_IMPORT_ERROR = """
 {0} requires the FLAX library but it was not found in your environment. Checkout the instructions on the
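The import_utils.py change follows the file's existing availability-probe pattern: importlib.util.find_spec detects the package without importing it, and the importlib_metadata lookup catches installs whose metadata is missing. Callers are then expected to gate the real import on the helper, exactly as the modeling_utils hunk does. A short sketch of that caller-side guard, mirroring the code above rather than introducing any new API:

from diffusers.utils import is_accelerate_available

if is_accelerate_available():
    import accelerate  # safe: the probe already found the package
else:
    raise ImportError("Please install accelerate via `pip install accelerate`")
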
diff --git a/tests/test_models_unet.py b/tests/test_models_unet.py
index 734fb5924d84..9d40331ea600 100644
--- a/tests/test_models_unet.py
+++ b/tests/test_models_unet.py
@@ -13,7 +13,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import gc
 import math
+import tracemalloc
 import unittest

 import torch
@@ -133,6 +135,74 @@ def test_from_pretrained_hub(self):

         assert image is not None, "Make sure output is not None"

+    @unittest.skipIf(torch_device == "cpu", "This test is supposed to run on GPU")
+    def test_from_pretrained_accelerate(self):
+        model, _ = UNet2DModel.from_pretrained(
+            "fusing/unet-ldm-dummy-update", output_loading_info=True, device_map="auto"
+        )
+        model.to(torch_device)
+        image = model(**self.dummy_input).sample
+
+        assert image is not None, "Make sure output is not None"
+
+    @unittest.skipIf(torch_device == "cpu", "This test is supposed to run on GPU")
+    def test_from_pretrained_accelerate_wont_change_results(self):
+        model_accelerate, _ = UNet2DModel.from_pretrained(
+            "fusing/unet-ldm-dummy-update", output_loading_info=True, device_map="auto"
+        )
+        model_accelerate.to(torch_device)
+        model_accelerate.eval()
+
+        noise = torch.randn(
+            1,
+            model_accelerate.config.in_channels,
+            model_accelerate.config.sample_size,
+            model_accelerate.config.sample_size,
+            generator=torch.manual_seed(0),
+        )
+        noise = noise.to(torch_device)
+        time_step = torch.tensor([10] * noise.shape[0]).to(torch_device)
+
+        arr_accelerate = model_accelerate(noise, time_step)["sample"]
+
+        # two models don't need to stay in the device at the same time
+        del model_accelerate
+        torch.cuda.empty_cache()
+        gc.collect()
+
+        model_normal_load, _ = UNet2DModel.from_pretrained("fusing/unet-ldm-dummy-update", output_loading_info=True)
+        model_normal_load.to(torch_device)
+        model_normal_load.eval()
+        arr_normal_load = model_normal_load(noise, time_step)["sample"]
+
+        assert torch.allclose(arr_accelerate, arr_normal_load, rtol=1e-3)
+
+    @unittest.skipIf(torch_device == "cpu", "This test is supposed to run on GPU")
+    def test_memory_footprint_gets_reduced(self):
+        torch.cuda.empty_cache()
+        gc.collect()
+
+        tracemalloc.start()
+        model_accelerate, _ = UNet2DModel.from_pretrained(
+            "fusing/unet-ldm-dummy-update", output_loading_info=True, device_map="auto"
+        )
+        model_accelerate.to(torch_device)
+        model_accelerate.eval()
+        _, peak_accelerate = tracemalloc.get_traced_memory()
+
+        del model_accelerate
+        torch.cuda.empty_cache()
+        gc.collect()
+
+        model_normal_load, _ = UNet2DModel.from_pretrained("fusing/unet-ldm-dummy-update", output_loading_info=True)
+        model_normal_load.to(torch_device)
+        model_normal_load.eval()
+        _, peak_normal = tracemalloc.get_traced_memory()
+
+        tracemalloc.stop()
+
+        assert peak_accelerate < peak_normal
+
     def test_output_pretrained(self):
         model = UNet2DModel.from_pretrained("fusing/unet-ldm-dummy-update")
         model.eval()
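The new memory test measures peaks with tracemalloc, which records allocations made through the Python allocator. A standalone sketch of the measurement technique, as generic Python rather than anything diffusers-specific:

import tracemalloc

tracemalloc.start()
payload = [0] * 1_000_000  # stand-in for "load a model"
current, peak = tracemalloc.get_traced_memory()
tracemalloc.stop()

print(f"current={current} bytes, peak={peak} bytes")

Because tracemalloc only sees Python-level allocations, the test compares the relative peaks of the two loading paths rather than asserting absolute model sizes. The three GPU-gated tests can be selected with, for example, pytest tests/test_models_unet.py -k "accelerate or memory_footprint".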