Skip to content

Commit 305a544

Browse files
committed
Rn random_state -> shape_state
1 parent 32c2be5 commit 305a544

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

src/diffusers/modeling_flax_utils.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -403,7 +403,7 @@ def from_pretrained(
403403
params_shape_tree = jax.eval_shape(model.init_weights, prng_key)
404404
required_params = set(flatten_dict(unfreeze(params_shape_tree)).keys())
405405

406-
random_state = flatten_dict(unfreeze(params_shape_tree))
406+
shape_state = flatten_dict(unfreeze(params_shape_tree))
407407

408408
missing_keys = required_params - set(state.keys())
409409
unexpected_keys = set(state.keys()) - required_params
@@ -419,14 +419,14 @@ def from_pretrained(
419419
# matching the weights in the model.
420420
mismatched_keys = []
421421
for key in state.keys():
422-
if key in random_state and state[key].shape != random_state[key].shape:
422+
if key in shape_state and state[key].shape != shape_state[key].shape:
423423
if ignore_mismatched_sizes:
424-
mismatched_keys.append((key, state[key].shape, random_state[key].shape))
425-
state[key] = random_state[key]
424+
mismatched_keys.append((key, state[key].shape, shape_state[key].shape))
425+
state[key] = shape_state[key]
426426
else:
427427
raise ValueError(
428428
f"Trying to load the pretrained weight for {key} failed: checkpoint has shape "
429-
f"{state[key].shape} which is incompatible with the model shape {random_state[key].shape}. "
429+
f"{state[key].shape} which is incompatible with the model shape {shape_state[key].shape}. "
430430
"Using `ignore_mismatched_sizes=True` if you really want to load this checkpoint inside this "
431431
"model."
432432
)

0 commit comments

Comments
 (0)