
Commit dec18c8

[Flax] dont warn for bf16 weights (#923)
dont warn for bf16 weights
1 parent 25dfd0f commit dec18c8

File tree

1 file changed: +0 additions, -23 deletions


src/diffusers/modeling_flax_utils.py

Lines changed: 0 additions & 23 deletions
@@ -482,29 +482,6 @@ def from_pretrained(
                 " training."
             )
 
-        # dictionary of key: dtypes for the model params
-        param_dtypes = jax.tree_map(lambda x: x.dtype, state)
-        # extract keys of parameters not in jnp.float32
-        fp16_params = [k for k in param_dtypes if param_dtypes[k] == jnp.float16]
-        bf16_params = [k for k in param_dtypes if param_dtypes[k] == jnp.bfloat16]
-
-        # raise a warning if any of the parameters are not in jnp.float32
-        if len(fp16_params) > 0:
-            logger.warning(
-                f"Some of the weights of {model.__class__.__name__} were initialized in float16 precision from "
-                f"the model checkpoint at {pretrained_model_name_or_path}:\n{fp16_params}\n"
-                "You should probably UPCAST the model weights to float32 if this was not intended. "
-                "See [`~ModelMixin.to_fp32`] for further information on how to do this."
-            )
-
-        if len(bf16_params) > 0:
-            logger.warning(
-                f"Some of the weights of {model.__class__.__name__} were initialized in bfloat16 precision from "
-                f"the model checkpoint at {pretrained_model_name_or_path}:\n{bf16_params}\n"
-                "You should probably UPCAST the model weights to float32 if this was not intended. "
-                "See [`~ModelMixin.to_fp32`] for further information on how to do this."
-            )
-
         return model, unflatten_dict(state)
 
     def save_pretrained(
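Note: with this change, from_pretrained no longer warns when a checkpoint's weights are stored in float16 or bfloat16. A user who still wants that check can reproduce the removed dtype inspection on the returned params dict themselves. A minimal sketch, assuming the params are a nested dict of JAX arrays as returned by the Flax from_pretrained (the helper name and the commented-out model/checkpoint below are placeholders, not part of this commit):

import jax
import jax.numpy as jnp
from flax.traverse_util import flatten_dict

def half_precision_params(params):
    # Map every parameter leaf to its dtype, then keep the keys stored in
    # float16 or bfloat16 (mirrors the logic of the removed warning).
    dtypes = jax.tree_util.tree_map(lambda x: x.dtype, flatten_dict(params))
    return [k for k, dtype in dtypes.items() if dtype in (jnp.float16, jnp.bfloat16)]

# Example usage (placeholder model class and checkpoint name):
# model, params = FlaxSomeModel.from_pretrained("some/checkpoint")
# print(half_precision_params(params))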
