@@ -436,9 +436,6 @@ def from_pretrained(
436436 )
437437 cls ._missing_keys = missing_keys
438438
439- # Mismatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not
440- # matching the weights in the model.
441- mismatched_keys = []
442439 for key in state .keys ():
443440 if key in shape_state and state [key ].shape != shape_state [key ].shape :
444441 raise ValueError (
@@ -466,26 +463,13 @@ def from_pretrained(
466463 f" { pretrained_model_name_or_path } and are newly initialized: { missing_keys } \n You should probably"
467464 " TRAIN this model on a down-stream task to be able to use it for predictions and inference."
468465 )
469- elif len ( mismatched_keys ) == 0 :
466+ else :
470467 logger .info (
471468 f"All the weights of { model .__class__ .__name__ } were initialized from the model checkpoint at"
472469 f" { pretrained_model_name_or_path } .\n If your task is similar to the task the model of the checkpoint"
473470 f" was trained on, you can already use { model .__class__ .__name__ } for predictions without further"
474471 " training."
475472 )
476- if len (mismatched_keys ) > 0 :
477- mismatched_warning = "\n " .join (
478- [
479- f"- { key } : found shape { shape1 } in the checkpoint and { shape2 } in the model instantiated"
480- for key , shape1 , shape2 in mismatched_keys
481- ]
482- )
483- logger .warning (
484- f"Some weights of { model .__class__ .__name__ } were not initialized from the model checkpoint at"
485- f" { pretrained_model_name_or_path } and are newly initialized because the shapes did not"
486- f" match:\n { mismatched_warning } \n You should probably TRAIN this model on a down-stream task to be able"
487- " to use it for predictions and inference."
488- )
489473
490474 # dictionary of key: dtypes for the model params
491475 param_dtypes = jax .tree_map (lambda x : x .dtype , state )
0 commit comments