1 file changed: 0 additions, 23 deletions

@@ -482,29 +482,6 @@ def from_pretrained(
482482 " training."
483483 )
484484
485- # dictionary of key: dtypes for the model params
486- param_dtypes = jax .tree_map (lambda x : x .dtype , state )
487- # extract keys of parameters not in jnp.float32
488- fp16_params = [k for k in param_dtypes if param_dtypes [k ] == jnp .float16 ]
489- bf16_params = [k for k in param_dtypes if param_dtypes [k ] == jnp .bfloat16 ]
490-
491- # raise a warning if any of the parameters are not in jnp.float32
492- if len (fp16_params ) > 0 :
493- logger .warning (
494- f"Some of the weights of { model .__class__ .__name__ } were initialized in float16 precision from "
495- f"the model checkpoint at { pretrained_model_name_or_path } :\n { fp16_params } \n "
496- "You should probably UPCAST the model weights to float32 if this was not intended. "
497- "See [`~ModelMixin.to_fp32`] for further information on how to do this."
498- )
499-
500- if len (bf16_params ) > 0 :
501- logger .warning (
502- f"Some of the weights of { model .__class__ .__name__ } were initialized in bfloat16 precision from "
503- f"the model checkpoint at { pretrained_model_name_or_path } :\n { bf16_params } \n "
504- "You should probably UPCAST the model weights to float32 if this was not intended. "
505- "See [`~ModelMixin.to_fp32`] for further information on how to do this."
506- )
507-
508485 return model , unflatten_dict (state )
509486
510487 def save_pretrained (
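For reference, the deleted block performed a precision sanity check on the restored parameters before returning them. Below is a minimal standalone sketch of the same idea, not taken from the diff: it assumes `state` is a flat dict mapping parameter names to JAX arrays (the nested dict that `from_pretrained` returns would need flattening first), and the helper name `find_non_fp32_params` is hypothetical.

import jax
import jax.numpy as jnp

def find_non_fp32_params(state):
    # Map every parameter array to its dtype, preserving the dict keys.
    param_dtypes = jax.tree_util.tree_map(lambda x: x.dtype, state)
    # Collect the names of parameters stored in half precision rather than float32.
    fp16_params = [k for k, dtype in param_dtypes.items() if dtype == jnp.float16]
    bf16_params = [k for k, dtype in param_dtypes.items() if dtype == jnp.bfloat16]
    return fp16_params, bf16_params

The removed code logged a warning for each non-empty list and pointed users at `~ModelMixin.to_fp32` to upcast; a caller of this sketch could do the same.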