 import jax
 import jax.numpy as jnp
 import msgpack.exceptions
-from flax.core.frozen_dict import FrozenDict
+from flax.core.frozen_dict import FrozenDict, unfreeze
 from flax.serialization import from_bytes, to_bytes
 from flax.traverse_util import flatten_dict, unflatten_dict
 from huggingface_hub import hf_hub_download
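For reference, the Flax helpers imported above behave as follows; a minimal sketch (the `params` tree is a made-up example, not from this commit):

```python
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict, unfreeze
from flax.traverse_util import flatten_dict, unflatten_dict

# made-up parameter tree for illustration
params = FrozenDict({"block": {"dense": {"kernel": jnp.zeros((2, 3))}}})

# flatten_dict expects a plain dict, hence unfreeze; it maps the nested tree
# to a flat dict keyed by tuples of path segments, which is what enables the
# set arithmetic on parameter keys later in this diff.
flat = flatten_dict(unfreeze(params))
print(list(flat.keys()))       # [('block', 'dense', 'kernel')]
nested = unflatten_dict(flat)  # round-trips back to the nested structure
```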
@@ -183,6 +183,9 @@ def to_fp16(self, params: Union[Dict, FrozenDict], mask: Any = None):
183183 ```"""
184184 return self ._cast_floating_to (params , jnp .float16 , mask )
185185
186+ def init_weights (self , rng : jax .random .PRNGKey ) -> Dict :
187+ raise NotImplementedError (f"init method has to be implemented for { self } " )
188+
186189 @classmethod
187190 def from_pretrained (
188191 cls ,
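A minimal sketch of how a subclass could satisfy the new `init_weights` contract; the module, input shape, and class names below are illustrative assumptions, not part of this commit:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn
from flax.core.frozen_dict import unfreeze

class TinyModule(nn.Module):  # hypothetical model body
    @nn.compact
    def __call__(self, x):
        return nn.Dense(features=8)(x)

class FlaxTinyModel:  # stands in for a subclass of the mixin in this file
    def __init__(self):
        self.module = TinyModule()

    def init_weights(self, rng: jax.random.PRNGKey) -> dict:
        # a dummy input pins down every parameter shape
        sample = jnp.zeros((1, 4), dtype=jnp.float32)
        return unfreeze(self.module.init(rng, sample)["params"])
```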
@@ -272,6 +275,7 @@ def from_pretrained(
         ```"""
         config = kwargs.pop("config", None)
         cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
+        ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
         force_download = kwargs.pop("force_download", False)
         resume_download = kwargs.pop("resume_download", False)
         proxies = kwargs.pop("proxies", None)
@@ -280,6 +284,7 @@ def from_pretrained(
         revision = kwargs.pop("revision", None)
         from_auto_class = kwargs.pop("_from_auto", False)
         subfolder = kwargs.pop("subfolder", None)
+        prng_key = kwargs.pop("prng_key", None)
 
         user_agent = {"file_type": "model", "framework": "flax", "from_auto_class": from_auto_class}
 
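Taken together, the two new kwargs would be used roughly like this; the class and checkpoint names are placeholders, and the `(model, params)` return shown assumes the tuple convention of diffusers' Flax models:

```python
import jax

# hypothetical call site: "org/checkpoint" and FlaxSomeModel are placeholders
model, params = FlaxSomeModel.from_pretrained(
    "org/checkpoint",
    prng_key=jax.random.PRNGKey(42),  # seeds the shape inference; defaults to PRNGKey(0)
    ignore_mismatched_sizes=True,  # re-initialize shape-mismatched weights instead of raising
)
```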
@@ -394,6 +399,82 @@ def from_pretrained(
         # flatten dicts
         state = flatten_dict(state)
 
+        prng_key = prng_key if prng_key is not None else jax.random.PRNGKey(0)
+        params_shape_tree = jax.eval_shape(model.init_weights, prng_key)
+        required_params = set(flatten_dict(unfreeze(params_shape_tree)).keys())
+
+        random_state = flatten_dict(unfreeze(params_shape_tree))
+
+        missing_keys = required_params - set(state.keys())
+        unexpected_keys = set(state.keys()) - required_params
+
+        if missing_keys:
+            logger.warning(
+                f"The checkpoint {pretrained_model_name_or_path} is missing required keys: {missing_keys}. "
+                "Make sure to call model.init_weights to initialize the missing weights."
+            )
+            cls._missing_keys = missing_keys
+
+        # mismatched_keys contains tuples (key, shape1, shape2) of weights in the checkpoint whose shape does
+        # not match the shape of the corresponding weight in the model.
+        mismatched_keys = []
+        for key in state.keys():
+            if key in random_state and state[key].shape != random_state[key].shape:
+                if ignore_mismatched_sizes:
+                    mismatched_keys.append((key, state[key].shape, random_state[key].shape))
+                    state[key] = random_state[key]
+                else:
+                    raise ValueError(
+                        f"Trying to load the pretrained weight for {key} failed: checkpoint has shape "
+                        f"{state[key].shape} which is incompatible with the model shape {random_state[key].shape}. "
+                        "Use `ignore_mismatched_sizes=True` if you really want to load this checkpoint into this "
+                        "model."
+                    )
+
+        # remove unexpected keys so they are not saved again
+        for unexpected_key in unexpected_keys:
+            del state[unexpected_key]
+
+        if len(unexpected_keys) > 0:
+            logger.warning(
+                f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
+                f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
+                f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task or"
+                " with another architecture (e.g. initializing a BertForSequenceClassification model from a"
+                " BertForPreTraining model).\n- This IS NOT expected if you are initializing"
+                f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly identical"
+                " (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)."
+            )
+        else:
+            logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
+
+        if len(missing_keys) > 0:
+            logger.warning(
+                f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
+                f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably"
+                " TRAIN this model on a down-stream task to be able to use it for predictions and inference."
+            )
+        elif len(mismatched_keys) == 0:
+            logger.info(
+                f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at"
+                f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the checkpoint"
+                f" was trained on, you can already use {model.__class__.__name__} for predictions without further"
+                " training."
+            )
+        if len(mismatched_keys) > 0:
+            mismatched_warning = "\n".join(
+                [
+                    f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
+                    for key, shape1, shape2 in mismatched_keys
+                ]
+            )
+            logger.warning(
+                f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
+                f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not"
+                f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be able"
+                " to use it for predictions and inference."
+            )
+
         # dictionary of key: dtypes for the model params
         param_dtypes = jax.tree_map(lambda x: x.dtype, state)
         # extract keys of parameters not in jnp.float32
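The shape-comparison logic above leans on `jax.eval_shape`, which traces `init_weights` abstractly and returns `jax.ShapeDtypeStruct` leaves instead of allocating real arrays; a standalone illustration with a made-up initializer:

```python
import jax
import jax.numpy as jnp

def init_weights(rng):
    # stand-in for a model's init_weights: returns a parameter pytree
    k1, _ = jax.random.split(rng)
    return {"dense": {"kernel": jax.random.normal(k1, (4, 8)), "bias": jnp.zeros((8,))}}

# No device memory is allocated and no RNG work is done: the function is only
# traced, and each leaf comes back as a ShapeDtypeStruct carrying shape/dtype.
shape_tree = jax.eval_shape(init_weights, jax.random.PRNGKey(0))
print(shape_tree["dense"]["kernel"].shape)  # (4, 8)
print(shape_tree["dense"]["kernel"].dtype)  # float32
```

This is what lets the loader compute `required_params` and compare checkpoint shapes against model shapes before committing any memory to randomly initialized weights.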