From ab54e0e4cdc11c017be53dbff86c3806351991df Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Fri, 16 Dec 2022 12:51:27 +0100 Subject: [PATCH 1/5] [Dtype] Align automatic dtype --- src/diffusers/modeling_utils.py | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/src/diffusers/modeling_utils.py b/src/diffusers/modeling_utils.py index edc519db6e13..7d56c427ac14 100644 --- a/src/diffusers/modeling_utils.py +++ b/src/diffusers/modeling_utils.py @@ -518,20 +518,6 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P model = cls.from_config(config, **unused_kwargs) state_dict = load_state_dict(model_file) - dtype = set(v.dtype for v in state_dict.values()) - - if len(dtype) > 1 and torch.float32 not in dtype: - raise ValueError( - f"The weights of the model file {model_file} have a mixture of incompatible dtypes {dtype}. Please" - f" make sure that {model_file} weights have only one dtype." - ) - elif len(dtype) > 1 and torch.float32 in dtype: - dtype = torch.float32 - else: - dtype = dtype.pop() - - # move model to correct dtype - model = model.to(dtype) model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model( model, From 0d489696343154ecf629fa1fefb4e329a7d69daa Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Fri, 16 Dec 2022 13:02:57 +0100 Subject: [PATCH 2/5] up --- tests/test_modeling_common.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 68ab914b4209..e4c32c0e78c1 100644 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -70,9 +70,9 @@ def test_from_save_pretrained_dtype(self): with tempfile.TemporaryDirectory() as tmpdirname: model.to(dtype) model.save_pretrained(tmpdirname) - new_model = self.model_class.from_pretrained(tmpdirname, low_cpu_mem_usage=True) + new_model = self.model_class.from_pretrained(tmpdirname, low_cpu_mem_usage=True, torch_dtype=dtype) assert new_model.dtype == dtype - new_model = self.model_class.from_pretrained(tmpdirname, low_cpu_mem_usage=False) + new_model = self.model_class.from_pretrained(tmpdirname, low_cpu_mem_usage=False, torch_dtype=dtype) assert new_model.dtype == dtype def test_determinism(self): From f33cccf28aa8d5e77963146d137b45ea7620a5b1 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Fri, 16 Dec 2022 13:22:54 +0100 Subject: [PATCH 3/5] up --- src/diffusers/modeling_utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/diffusers/modeling_utils.py b/src/diffusers/modeling_utils.py index 7d56c427ac14..16d607f10112 100644 --- a/src/diffusers/modeling_utils.py +++ b/src/diffusers/modeling_utils.py @@ -489,10 +489,11 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P # move the parms from meta device to cpu for param_name, param in state_dict.items(): set_module_tensor_to_device(model, param_name, param_device, value=param) + # TODO(Patrick) - check whether dtype conversion should be handled here or kwarg is added else: # else let accelerate handle loading and dispatching. # Load weights and dispatch according to the device_map # by deafult the device_map is None and the weights are loaded on the CPU - accelerate.load_checkpoint_and_dispatch(model, model_file, device_map) + accelerate.load_checkpoint_and_dispatch(model, model_file, device_map, dtype=torch_dtype) loading_info = { "missing_keys": [], From 1fd054c3acb1e581b6f650bd6a45572da67d522c Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 20 Dec 2022 16:25:20 +0100 Subject: [PATCH 4/5] fix --- src/diffusers/modeling_utils.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/diffusers/modeling_utils.py b/src/diffusers/modeling_utils.py index b1a2b764314b..e7de5258ede8 100644 --- a/src/diffusers/modeling_utils.py +++ b/src/diffusers/modeling_utils.py @@ -16,6 +16,7 @@ import os from functools import partial +import inspect from typing import Callable, List, Optional, Tuple, Union import torch @@ -489,8 +490,11 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P state_dict = load_state_dict(model_file) # move the parms from meta device to cpu for param_name, param in state_dict.items(): - set_module_tensor_to_device(model, param_name, param_device, value=param) - # TODO(Patrick) - check whether dtype conversion should be handled here or kwarg is added + accepts_dtype = "dtype" in set(inspect.signature(set_module_tensor_to_device).parameters.keys()) + if accepts_dtype: + set_module_tensor_to_device(model, param_name, param_device, value=param, dtype=torch_dtype) + else: + set_module_tensor_to_device(model, param_name, param_device, value=param) else: # else let accelerate handle loading and dispatching. # Load weights and dispatch according to the device_map # by deafult the device_map is None and the weights are loaded on the CPU From a61b208e91a3c1a71cc16378b316614f960a6906 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 20 Dec 2022 17:03:23 +0100 Subject: [PATCH 5/5] re-add accelerate --- .github/workflows/nightly_tests.yml | 4 +++- .github/workflows/pr_tests.yml | 2 ++ .github/workflows/push_tests.yml | 4 +++- src/diffusers/modeling_utils.py | 2 +- 4 files changed, 9 insertions(+), 3 deletions(-) diff --git a/.github/workflows/nightly_tests.yml b/.github/workflows/nightly_tests.yml index 4aee44c56c3e..fb0ce92cb61c 100644 --- a/.github/workflows/nightly_tests.yml +++ b/.github/workflows/nightly_tests.yml @@ -62,6 +62,7 @@ jobs: run: | python -m pip install -e .[quality,test] python -m pip install -U git+https://github.com/huggingface/transformers + python -m pip install git+https://github.com/huggingface/accelerate - name: Environment run: | @@ -134,6 +135,7 @@ jobs: ${CONDA_RUN} python -m pip install --upgrade pip ${CONDA_RUN} python -m pip install -e .[quality,test] ${CONDA_RUN} python -m pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cpu + ${CONDA_RUN} python -m pip install git+https://github.com/huggingface/accelerate - name: Environment shell: arch -arch arm64 bash {0} @@ -157,4 +159,4 @@ jobs: uses: actions/upload-artifact@v2 with: name: torch_mps_test_reports - path: reports \ No newline at end of file + path: reports diff --git a/.github/workflows/pr_tests.yml b/.github/workflows/pr_tests.yml index 93bbdae388e6..082b12404a85 100644 --- a/.github/workflows/pr_tests.yml +++ b/.github/workflows/pr_tests.yml @@ -60,6 +60,7 @@ jobs: apt-get update && apt-get install libsndfile1-dev -y python -m pip install -e .[quality,test] python -m pip install -U git+https://github.com/huggingface/transformers + python -m pip install git+https://github.com/huggingface/accelerate - name: Environment run: | @@ -126,6 +127,7 @@ jobs: ${CONDA_RUN} python -m pip install --upgrade pip ${CONDA_RUN} python -m pip install -e .[quality,test] ${CONDA_RUN} python -m pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cpu + ${CONDA_RUN} python -m pip install git+https://github.com/huggingface/accelerate ${CONDA_RUN} python -m pip install -U git+https://github.com/huggingface/transformers - name: Environment diff --git a/.github/workflows/push_tests.yml b/.github/workflows/push_tests.yml index df3a3bf0fdf2..2d4875b80ced 100644 --- a/.github/workflows/push_tests.yml +++ b/.github/workflows/push_tests.yml @@ -62,6 +62,7 @@ jobs: run: | python -m pip install -e .[quality,test] python -m pip install -U git+https://github.com/huggingface/transformers + python -m pip install git+https://github.com/huggingface/accelerate - name: Environment run: | @@ -130,6 +131,7 @@ jobs: - name: Install dependencies run: | python -m pip install -e .[quality,test,training] + python -m pip install git+https://github.com/huggingface/accelerate python -m pip install -U git+https://github.com/huggingface/transformers - name: Environment @@ -151,4 +153,4 @@ jobs: uses: actions/upload-artifact@v2 with: name: examples_test_reports - path: reports \ No newline at end of file + path: reports diff --git a/src/diffusers/modeling_utils.py b/src/diffusers/modeling_utils.py index e7de5258ede8..9afece3244e9 100644 --- a/src/diffusers/modeling_utils.py +++ b/src/diffusers/modeling_utils.py @@ -14,9 +14,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +import inspect import os from functools import partial -import inspect from typing import Callable, List, Optional, Tuple, Union import torch