File tree Expand file tree Collapse file tree 1 file changed +5
-5
lines changed Expand file tree Collapse file tree 1 file changed +5
-5
lines changed Original file line number Diff line number Diff line change @@ -403,7 +403,7 @@ def from_pretrained(
403403 params_shape_tree = jax .eval_shape (model .init_weights , prng_key )
404404 required_params = set (flatten_dict (unfreeze (params_shape_tree )).keys ())
405405
406- random_state = flatten_dict (unfreeze (params_shape_tree ))
406+ shape_state = flatten_dict (unfreeze (params_shape_tree ))
407407
408408 missing_keys = required_params - set (state .keys ())
409409 unexpected_keys = set (state .keys ()) - required_params
@@ -419,14 +419,14 @@ def from_pretrained(
419419 # matching the weights in the model.
420420 mismatched_keys = []
421421 for key in state .keys ():
422- if key in random_state and state [key ].shape != random_state [key ].shape :
422+ if key in shape_state and state [key ].shape != shape_state [key ].shape :
423423 if ignore_mismatched_sizes :
424- mismatched_keys .append ((key , state [key ].shape , random_state [key ].shape ))
425- state [key ] = random_state [key ]
424+ mismatched_keys .append ((key , state [key ].shape , shape_state [key ].shape ))
425+ state [key ] = shape_state [key ]
426426 else :
427427 raise ValueError (
428428 f"Trying to load the pretrained weight for { key } failed: checkpoint has shape "
429- f"{ state [key ].shape } which is incompatible with the model shape { random_state [key ].shape } . "
429+ f"{ state [key ].shape } which is incompatible with the model shape { shape_state [key ].shape } . "
430430 "Using `ignore_mismatched_sizes=True` if you really want to load this checkpoint inside this "
431431 "model."
432432 )
You can’t perform that action at this time.
0 commit comments