Commit 32c2be5

Add init_weights method to FlaxMixin

1 parent 83a7bb2

File tree

1 file changed (+82, -1 lines)


src/diffusers/modeling_flax_utils.py

Lines changed: 82 additions & 1 deletion
@@ -20,7 +20,7 @@
 import jax
 import jax.numpy as jnp
 import msgpack.exceptions
-from flax.core.frozen_dict import FrozenDict
+from flax.core.frozen_dict import FrozenDict, unfreeze
 from flax.serialization import from_bytes, to_bytes
 from flax.traverse_util import flatten_dict, unflatten_dict
 from huggingface_hub import hf_hub_download
@@ -183,6 +183,9 @@ def to_fp16(self, params: Union[Dict, FrozenDict], mask: Any = None):
         ```"""
         return self._cast_floating_to(params, jnp.float16, mask)
 
+    def init_weights(self, rng: jax.random.PRNGKey) -> Dict:
+        raise NotImplementedError(f"init method has to be implemented for {self}")
+
     @classmethod
     def from_pretrained(
         cls,
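
The new `init_weights` hook is the contract concrete models are expected to fulfil: take a PRNG key and return the model's parameter PyTree, so that `from_pretrained` can recover the expected parameter shapes via `jax.eval_shape`. A minimal sketch of what an implementation might look like; `FlaxToyModel`, its `nn.Dense` module, and the dummy input shape are illustrative assumptions, not part of this commit:

    import jax
    import jax.numpy as jnp
    import flax.linen as nn

    class FlaxToyModel:
        # Hypothetical stand-in; a real diffusers model would also inherit from
        # FlaxModelMixin (and ConfigMixin) and wrap its actual network module.
        def __init__(self):
            self.module = nn.Dense(features=4)

        def init_weights(self, rng: jax.random.PRNGKey) -> dict:
            # Build the parameter PyTree from a dummy input; from_pretrained only needs
            # the shapes, which it gets via jax.eval_shape(model.init_weights, prng_key).
            dummy_input = jnp.zeros((1, 8), dtype=jnp.float32)
            return self.module.init(rng, dummy_input)["params"]

    model = FlaxToyModel()
    params_shape_tree = jax.eval_shape(model.init_weights, jax.random.PRNGKey(0))

Because `jax.eval_shape` only traces the computation, no parameter memory is allocated when the loader compares checkpoint shapes against the model's expected shapes.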
@@ -272,6 +275,7 @@ def from_pretrained(
         ```"""
         config = kwargs.pop("config", None)
         cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
+        ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
         force_download = kwargs.pop("force_download", False)
         resume_download = kwargs.pop("resume_download", False)
         proxies = kwargs.pop("proxies", None)
@@ -280,6 +284,7 @@ def from_pretrained(
         revision = kwargs.pop("revision", None)
         from_auto_class = kwargs.pop("_from_auto", False)
         subfolder = kwargs.pop("subfolder", None)
+        prng_key = kwargs.pop("prng_key", None)
 
         user_agent = {"file_type": "model", "framework": "flax", "from_auto_class": from_auto_class}
 
@@ -394,6 +399,82 @@ def from_pretrained(
         # flatten dicts
         state = flatten_dict(state)
 
+        prng_key = prng_key if prng_key is not None else jax.random.PRNGKey(0)
+        params_shape_tree = jax.eval_shape(model.init_weights, prng_key)
+        required_params = set(flatten_dict(unfreeze(params_shape_tree)).keys())
+
+        random_state = flatten_dict(unfreeze(params_shape_tree))
+
+        missing_keys = required_params - set(state.keys())
+        unexpected_keys = set(state.keys()) - required_params
+
+        if missing_keys:
+            logger.warning(
+                f"The checkpoint {pretrained_model_name_or_path} is missing required keys: {missing_keys}. "
+                "Make sure to call model.init_weights to initialize the missing weights."
+            )
+            cls._missing_keys = missing_keys
+
+        # Mismatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not
+        # matching the weights in the model.
+        mismatched_keys = []
+        for key in state.keys():
+            if key in random_state and state[key].shape != random_state[key].shape:
+                if ignore_mismatched_sizes:
+                    mismatched_keys.append((key, state[key].shape, random_state[key].shape))
+                    state[key] = random_state[key]
+                else:
+                    raise ValueError(
+                        f"Trying to load the pretrained weight for {key} failed: checkpoint has shape "
+                        f"{state[key].shape} which is incompatible with the model shape {random_state[key].shape}. "
+                        "Use `ignore_mismatched_sizes=True` if you really want to load this checkpoint inside this "
+                        "model."
+                    )
+
+        # remove unexpected keys to not be saved again
+        for unexpected_key in unexpected_keys:
+            del state[unexpected_key]
+
+        if len(unexpected_keys) > 0:
+            logger.warning(
+                f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
+                f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
+                f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task or"
+                " with another architecture (e.g. initializing a BertForSequenceClassification model from a"
+                " BertForPreTraining model).\n- This IS NOT expected if you are initializing"
+                f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly identical"
+                " (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)."
+            )
+        else:
+            logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
+
+        if len(missing_keys) > 0:
+            logger.warning(
+                f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
+                f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably"
+                " TRAIN this model on a down-stream task to be able to use it for predictions and inference."
+            )
+        elif len(mismatched_keys) == 0:
+            logger.info(
+                f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at"
+                f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the checkpoint"
+                f" was trained on, you can already use {model.__class__.__name__} for predictions without further"
+                " training."
+            )
+        if len(mismatched_keys) > 0:
+            mismatched_warning = "\n".join(
+                [
+                    f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
+                    for key, shape1, shape2 in mismatched_keys
+                ]
+            )
+            logger.warning(
+                f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
+                f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not"
+                f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be able"
+                " to use it for predictions and inference."
+            )
+
         # dictionary of key: dtypes for the model params
         param_dtypes = jax.tree_map(lambda x: x.dtype, state)
         # extract keys of parameters not in jnp.float32
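
From the caller's side, the two keyword arguments introduced above are simply forwarded to `from_pretrained`. A hedged usage sketch: `FlaxUNet2DConditionModel` stands in for any `FlaxModelMixin` subclass that implements `init_weights` (it may postdate this commit), the repo id is a placeholder, and the tuple unpacking assumes the diffusers Flax convention of returning the model together with its parameters:

    import jax
    from diffusers import FlaxUNet2DConditionModel  # assumed example class, not part of this commit

    # "some-org/some-flax-checkpoint" is a placeholder repo id, not a real checkpoint.
    model, params = FlaxUNet2DConditionModel.from_pretrained(
        "some-org/some-flax-checkpoint",
        prng_key=jax.random.PRNGKey(42),   # seeds jax.eval_shape(model.init_weights, prng_key); defaults to PRNGKey(0)
        ignore_mismatched_sizes=True,      # re-initialize weights whose shapes differ instead of raising
    )

If `ignore_mismatched_sizes` is left at its default of `False`, a shape mismatch between the checkpoint and the model raises the `ValueError` shown in the diff; missing and unexpected keys are only warned about.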
