Skip to content

Commit f10576a

Browse files
authored
Flax from_pretrained: clean up mismatched_keys. (#630)
Flax from_pretrained: clean up `mismatched_keys`. Originally removed in 73e0bc6.
1 parent 84b9df5 commit f10576a

File tree

1 file changed

+1
-17
lines changed

1 file changed

+1
-17
lines changed

src/diffusers/modeling_flax_utils.py

Lines changed: 1 addition & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -436,9 +436,6 @@ def from_pretrained(
436436
)
437437
cls._missing_keys = missing_keys
438438

439-
# Mismatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not
440-
# matching the weights in the model.
441-
mismatched_keys = []
442439
for key in state.keys():
443440
if key in shape_state and state[key].shape != shape_state[key].shape:
444441
raise ValueError(
@@ -466,26 +463,13 @@ def from_pretrained(
466463
f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably"
467464
" TRAIN this model on a down-stream task to be able to use it for predictions and inference."
468465
)
469-
elif len(mismatched_keys) == 0:
466+
else:
470467
logger.info(
471468
f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at"
472469
f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the checkpoint"
473470
f" was trained on, you can already use {model.__class__.__name__} for predictions without further"
474471
" training."
475472
)
476-
if len(mismatched_keys) > 0:
477-
mismatched_warning = "\n".join(
478-
[
479-
f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
480-
for key, shape1, shape2 in mismatched_keys
481-
]
482-
)
483-
logger.warning(
484-
f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
485-
f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not"
486-
f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be able"
487-
" to use it for predictions and inference."
488-
)
489473

490474
# dictionary of key: dtypes for the model params
491475
param_dtypes = jax.tree_map(lambda x: x.dtype, state)

0 commit comments

Comments
 (0)