Skip to content

Commit 73e0bc6

Browse files
committed
No allow mismatched sizes
1 parent 803da8f commit 73e0bc6

File tree

1 file changed

+4
-15
lines changed

1 file changed

+4
-15
lines changed

src/diffusers/modeling_flax_utils.py

Lines changed: 4 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -230,10 +230,6 @@ def from_pretrained(
230230
cache_dir (`Union[str, os.PathLike]`, *optional*):
231231
Path to a directory in which a downloaded pretrained model configuration should be cached if the
232232
standard cache should not be used.
233-
ignore_mismatched_sizes (`bool`, *optional*, defaults to `False`):
234-
Whether or not to raise an error if some of the weights from the checkpoint do not have the same size
235-
as the weights of the model (if for instance, you are instantiating a model with 10 labels from a
236-
checkpoint with 3 labels).
237233
force_download (`bool`, *optional*, defaults to `False`):
238234
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
239235
cached versions if they exist.
@@ -275,7 +271,6 @@ def from_pretrained(
275271
```"""
276272
config = kwargs.pop("config", None)
277273
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
278-
ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
279274
force_download = kwargs.pop("force_download", False)
280275
resume_download = kwargs.pop("resume_download", False)
281276
proxies = kwargs.pop("proxies", None)
@@ -419,16 +414,10 @@ def from_pretrained(
419414
mismatched_keys = []
420415
for key in state.keys():
421416
if key in shape_state and state[key].shape != shape_state[key].shape:
422-
if ignore_mismatched_sizes:
423-
mismatched_keys.append((key, state[key].shape, shape_state[key].shape))
424-
state[key] = shape_state[key]
425-
else:
426-
raise ValueError(
427-
f"Trying to load the pretrained weight for {key} failed: checkpoint has shape "
428-
f"{state[key].shape} which is incompatible with the model shape {shape_state[key].shape}. "
429-
"Using `ignore_mismatched_sizes=True` if you really want to load this checkpoint inside this "
430-
"model."
431-
)
417+
raise ValueError(
418+
f"Trying to load the pretrained weight for {key} failed: checkpoint has shape "
419+
f"{state[key].shape} which is incompatible with the model shape {shape_state[key].shape}. "
420+
)
432421

433422
# remove unexpected keys to not be saved again
434423
for unexpected_key in unexpected_keys:

0 commit comments

Comments
 (0)