File tree Expand file tree Collapse file tree 1 file changed +1
-2
lines changed Expand file tree Collapse file tree 1 file changed +1
-2
lines changed Original file line number Diff line number Diff line change @@ -284,7 +284,6 @@ def from_pretrained(
284284 revision = kwargs .pop ("revision" , None )
285285 from_auto_class = kwargs .pop ("_from_auto" , False )
286286 subfolder = kwargs .pop ("subfolder" , None )
287- prng_key = kwargs .pop ("prng_key" , None )
288287
289288 user_agent = {"file_type" : "model" , "framework" : "flax" , "from_auto_class" : from_auto_class }
290289
@@ -399,7 +398,7 @@ def from_pretrained(
399398 # flatten dicts
400399 state = flatten_dict (state )
401400
402- prng_key = prng_key if prng_key is not None else jax .random .PRNGKey (0 )
401+ prng_key = jax .random .PRNGKey (0 )
403402 params_shape_tree = jax .eval_shape (model .init_weights , prng_key )
404403 required_params = set (flatten_dict (unfreeze (params_shape_tree )).keys ())
405404
You can’t perform that action at this time.
0 commit comments