Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion .github/workflows/nightly_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ jobs:
run: |
python -m pip install -e .[quality,test]
python -m pip install -U git+https://github.com/huggingface/transformers
python -m pip install git+https://github.com/huggingface/accelerate

- name: Environment
run: |
Expand Down Expand Up @@ -134,6 +135,7 @@ jobs:
${CONDA_RUN} python -m pip install --upgrade pip
${CONDA_RUN} python -m pip install -e .[quality,test]
${CONDA_RUN} python -m pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cpu
${CONDA_RUN} python -m pip install git+https://github.com/huggingface/accelerate

- name: Environment
shell: arch -arch arm64 bash {0}
Expand All @@ -157,4 +159,4 @@ jobs:
uses: actions/upload-artifact@v2
with:
name: torch_mps_test_reports
path: reports
path: reports
2 changes: 2 additions & 0 deletions .github/workflows/pr_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ jobs:
apt-get update && apt-get install libsndfile1-dev -y
python -m pip install -e .[quality,test]
python -m pip install -U git+https://github.com/huggingface/transformers
python -m pip install git+https://github.com/huggingface/accelerate

- name: Environment
run: |
Expand Down Expand Up @@ -126,6 +127,7 @@ jobs:
${CONDA_RUN} python -m pip install --upgrade pip
${CONDA_RUN} python -m pip install -e .[quality,test]
${CONDA_RUN} python -m pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cpu
${CONDA_RUN} python -m pip install git+https://github.com/huggingface/accelerate
${CONDA_RUN} python -m pip install -U git+https://github.com/huggingface/transformers

- name: Environment
Expand Down
4 changes: 3 additions & 1 deletion .github/workflows/push_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ jobs:
run: |
python -m pip install -e .[quality,test]
python -m pip install -U git+https://github.com/huggingface/transformers
python -m pip install git+https://github.com/huggingface/accelerate

- name: Environment
run: |
Expand Down Expand Up @@ -130,6 +131,7 @@ jobs:
- name: Install dependencies
run: |
python -m pip install -e .[quality,test,training]
python -m pip install git+https://github.com/huggingface/accelerate
python -m pip install -U git+https://github.com/huggingface/transformers

- name: Environment
Expand All @@ -151,4 +153,4 @@ jobs:
uses: actions/upload-artifact@v2
with:
name: examples_test_reports
path: reports
path: reports
23 changes: 7 additions & 16 deletions src/diffusers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
import os
from functools import partial
from typing import Callable, List, Optional, Tuple, Union
Expand Down Expand Up @@ -489,11 +490,15 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
state_dict = load_state_dict(model_file)
# move the parms from meta device to cpu
for param_name, param in state_dict.items():
set_module_tensor_to_device(model, param_name, param_device, value=param)
accepts_dtype = "dtype" in set(inspect.signature(set_module_tensor_to_device).parameters.keys())
if accepts_dtype:
set_module_tensor_to_device(model, param_name, param_device, value=param, dtype=torch_dtype)
else:
set_module_tensor_to_device(model, param_name, param_device, value=param)
else: # else let accelerate handle loading and dispatching.
# Load weights and dispatch according to the device_map
# by deafult the device_map is None and the weights are loaded on the CPU
accelerate.load_checkpoint_and_dispatch(model, model_file, device_map)
accelerate.load_checkpoint_and_dispatch(model, model_file, device_map, dtype=torch_dtype)

loading_info = {
"missing_keys": [],
Expand All @@ -519,20 +524,6 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
model = cls.from_config(config, **unused_kwargs)

state_dict = load_state_dict(model_file)
dtype = set(v.dtype for v in state_dict.values())

if len(dtype) > 1 and torch.float32 not in dtype:
raise ValueError(
f"The weights of the model file {model_file} have a mixture of incompatible dtypes {dtype}. Please"
f" make sure that {model_file} weights have only one dtype."
)
elif len(dtype) > 1 and torch.float32 in dtype:
dtype = torch.float32
else:
dtype = dtype.pop()

# move model to correct dtype
model = model.to(dtype)

model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
model,
Expand Down
4 changes: 2 additions & 2 deletions tests/test_modeling_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,9 @@ def test_from_save_pretrained_dtype(self):
with tempfile.TemporaryDirectory() as tmpdirname:
model.to(dtype)
model.save_pretrained(tmpdirname)
new_model = self.model_class.from_pretrained(tmpdirname, low_cpu_mem_usage=True)
new_model = self.model_class.from_pretrained(tmpdirname, low_cpu_mem_usage=True, torch_dtype=dtype)
assert new_model.dtype == dtype
new_model = self.model_class.from_pretrained(tmpdirname, low_cpu_mem_usage=False)
new_model = self.model_class.from_pretrained(tmpdirname, low_cpu_mem_usage=False, torch_dtype=dtype)
assert new_model.dtype == dtype

def test_determinism(self):
Expand Down