Skip to content

Commit 803da8f

Browse files
committed
PRNGKey(0) for jax.eval_shape
1 parent 305a544 commit 803da8f

File tree

1 file changed

+1
-2
lines changed

1 file changed

+1
-2
lines changed

src/diffusers/modeling_flax_utils.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -284,7 +284,6 @@ def from_pretrained(
284284
revision = kwargs.pop("revision", None)
285285
from_auto_class = kwargs.pop("_from_auto", False)
286286
subfolder = kwargs.pop("subfolder", None)
287-
prng_key = kwargs.pop("prng_key", None)
288287

289288
user_agent = {"file_type": "model", "framework": "flax", "from_auto_class": from_auto_class}
290289

@@ -399,7 +398,7 @@ def from_pretrained(
399398
# flatten dicts
400399
state = flatten_dict(state)
401400

402-
prng_key = prng_key if prng_key is not None else jax.random.PRNGKey(0)
401+
prng_key = jax.random.PRNGKey(0)
403402
params_shape_tree = jax.eval_shape(model.init_weights, prng_key)
404403
required_params = set(flatten_dict(unfreeze(params_shape_tree)).keys())
405404

0 commit comments

Comments
 (0)