diff --git a/src/diffusers/modeling_utils.py b/src/diffusers/modeling_utils.py
index 5f79e7fe0155..bfcba2916a6b 100644
--- a/src/diffusers/modeling_utils.py
+++ b/src/diffusers/modeling_utils.py
@@ -472,6 +472,21 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
             model = cls.from_config(config, **unused_kwargs)
 
             state_dict = load_state_dict(model_file)
+            dtype = set(v.dtype for v in state_dict.values())
+
+            if len(dtype) > 1 and torch.float32 not in dtype:
+                raise ValueError(
+                    f"The weights of the model file {model_file} have a mixture of incompatible dtypes {dtype}. Please"
+                    f" make sure that {model_file} weights have only one dtype."
+                )
+            elif len(dtype) > 1 and torch.float32 in dtype:
+                dtype = torch.float32
+            else:
+                dtype = dtype.pop()
+
+            # move model to correct dtype
+            model = model.to(dtype)
+
             model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
                 model,
                 state_dict,
diff --git a/tests/models/test_models_unet_1d.py b/tests/models/test_models_unet_1d.py
index 089d935651a5..b494c231b5fe 100644
--- a/tests/models/test_models_unet_1d.py
+++ b/tests/models/test_models_unet_1d.py
@@ -63,8 +63,8 @@ def test_outputs_equivalence(self):
         super().test_outputs_equivalence()
 
     @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
-    def test_from_pretrained_save_pretrained(self):
-        super().test_from_pretrained_save_pretrained()
+    def test_from_save_pretrained(self):
+        super().test_from_save_pretrained()
 
     @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
     def test_model_from_pretrained(self):
@@ -183,8 +183,8 @@ def test_outputs_equivalence(self):
         super().test_outputs_equivalence()
 
     @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
-    def test_from_pretrained_save_pretrained(self):
-        super().test_from_pretrained_save_pretrained()
+    def test_from_save_pretrained(self):
+        super().test_from_save_pretrained()
 
     @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
     def test_model_from_pretrained(self):
diff --git a/tests/pipelines/versatile_diffusion/test_versatile_diffusion_mega.py b/tests/pipelines/versatile_diffusion/test_versatile_diffusion_mega.py
index ab4580dae1fe..ad24ec01f633 100644
--- a/tests/pipelines/versatile_diffusion/test_versatile_diffusion_mega.py
+++ b/tests/pipelines/versatile_diffusion/test_versatile_diffusion_mega.py
@@ -42,7 +42,7 @@ def tearDown(self):
         gc.collect()
         torch.cuda.empty_cache()
 
-    def test_from_pretrained_save_pretrained(self):
+    def test_from_save_pretrained(self):
         pipe = VersatileDiffusionPipeline.from_pretrained("shi-labs/versatile-diffusion", torch_dtype=torch.float16)
         pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py
index cad1887f4df8..68ab914b4209 100644
--- a/tests/test_modeling_common.py
+++ b/tests/test_modeling_common.py
@@ -27,7 +27,7 @@
 
 
 class ModelTesterMixin:
-    def test_from_pretrained_save_pretrained(self):
+    def test_from_save_pretrained(self):
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
 
         model = self.model_class(**init_dict)
@@ -57,6 +57,24 @@ def test_from_pretrained_save_pretrained(self):
         max_diff = (image - new_image).abs().sum().item()
         self.assertLessEqual(max_diff, 5e-5, "Models give different forward passes")
 
+    def test_from_save_pretrained_dtype(self):
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+
+        model = self.model_class(**init_dict)
+        model.to(torch_device)
+        model.eval()
+
+        for dtype in [torch.float32, torch.float16, torch.bfloat16]:
+            if torch_device == "mps" and dtype == torch.bfloat16:
+                continue
+            with tempfile.TemporaryDirectory() as tmpdirname:
+                model.to(dtype)
+                model.save_pretrained(tmpdirname)
+                new_model = self.model_class.from_pretrained(tmpdirname, low_cpu_mem_usage=True)
+                assert new_model.dtype == dtype
+                new_model = self.model_class.from_pretrained(tmpdirname, low_cpu_mem_usage=False)
+                assert new_model.dtype == dtype
+
     def test_determinism(self):
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
         model = self.model_class(**init_dict)
diff --git a/tests/test_pipelines.py b/tests/test_pipelines.py
index 033f363ff41f..ec44b69cb1f9 100644
--- a/tests/test_pipelines.py
+++ b/tests/test_pipelines.py
@@ -656,7 +656,7 @@ def test_warning_unused_kwargs(self):
 
         assert cap_logger.out == "Keyword arguments {'not_used': True} not recognized.\n"
 
-    def test_from_pretrained_save_pretrained(self):
+    def test_from_save_pretrained(self):
         # 1. Load models
         model = UNet2DModel(
             block_out_channels=(32, 64),
diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py
index 6a76581632ad..0243e8840522 100755
--- a/tests/test_scheduler.py
+++ b/tests/test_scheduler.py
@@ -333,7 +333,7 @@ def check_over_forward(self, time_step=0, **forward_kwargs):
 
         assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
 
-    def test_from_pretrained_save_pretrained(self):
+    def test_from_save_pretrained(self):
         kwargs = dict(self.forward_default_kwargs)
 
         num_inference_steps = kwargs.pop("num_inference_steps", None)
@@ -860,7 +860,7 @@ def check_over_configs(self, time_step=0, **config):
 
         assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
 
-    def test_from_pretrained_save_pretrained(self):
+    def test_from_save_pretrained(self):
         pass
 
     def check_over_forward(self, time_step=0, **forward_kwargs):
@@ -1037,7 +1037,7 @@ def check_over_configs(self, time_step=0, **config):
 
         assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
 
-    def test_from_pretrained_save_pretrained(self):
+    def test_from_save_pretrained(self):
         pass
 
     def check_over_forward(self, time_step=0, **forward_kwargs):
@@ -1717,7 +1717,7 @@ def check_over_configs(self, time_step=0, **config):
 
         assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
 
-    def test_from_pretrained_save_pretrained(self):
+    def test_from_save_pretrained(self):
         pass
 
     def check_over_forward(self, time_step=0, **forward_kwargs):
diff --git a/tests/test_scheduler_flax.py b/tests/test_scheduler_flax.py
index 5ada689b724d..da1042f3d698 100644
--- a/tests/test_scheduler_flax.py
+++ b/tests/test_scheduler_flax.py
@@ -126,7 +126,7 @@ def check_over_forward(self, time_step=0, **forward_kwargs):
 
         assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
 
-    def test_from_pretrained_save_pretrained(self):
+    def test_from_save_pretrained(self):
         kwargs = dict(self.forward_default_kwargs)
 
         num_inference_steps = kwargs.pop("num_inference_steps", None)
@@ -408,7 +408,7 @@ def check_over_configs(self, time_step=0, **config):
 
         assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
 
-    def test_from_pretrained_save_pretrained(self):
+    def test_from_save_pretrained(self):
         kwargs = dict(self.forward_default_kwargs)
 
         num_inference_steps = kwargs.pop("num_inference_steps", None)
@@ -690,7 +690,7 @@ def check_over_configs(self, time_step=0, **config):
 
         assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
 
-    def test_from_pretrained_save_pretrained(self):
+    def test_from_save_pretrained(self):
         pass
 
     def test_scheduler_outputs_equivalence(self):
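
Note (illustration only, not part of the patch): the dtype-resolution rule added in the modeling_utils.py hunk above can be exercised on its own with plain torch tensors. The helper name resolve_state_dict_dtype below is hypothetical and exists only for this sketch.

import torch

def resolve_state_dict_dtype(state_dict):
    # Collect every dtype that appears in the checkpoint weights.
    dtypes = set(v.dtype for v in state_dict.values())
    if len(dtypes) > 1 and torch.float32 not in dtypes:
        # A mixture of dtypes without float32 cannot be reconciled.
        raise ValueError(f"Incompatible mixture of dtypes: {dtypes}")
    elif len(dtypes) > 1 and torch.float32 in dtypes:
        # Mixed checkpoints that include float32 fall back to float32.
        return torch.float32
    # A single dtype is used as-is.
    return dtypes.pop()

# A uniformly fp16 checkpoint resolves to torch.float16, which is the behavior
# the new test_from_save_pretrained_dtype test checks after a save/load round trip.
assert resolve_state_dict_dtype({"w": torch.ones(2, dtype=torch.float16)}) == torch.float16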