@@ -561,13 +561,16 @@ def check_format(preset):
561
561
return "keras"
562
562
563
563
564
- def load_serialized_object (
565
- preset ,
566
- config_file = CONFIG_FILE ,
567
- config_overrides = {},
568
- ):
564
+ def load_serialized_object (preset , config_file = CONFIG_FILE , ** kwargs ):
565
+ kwargs = kwargs or {}
569
566
config = load_config (preset , config_file )
570
- config ["config" ] = {** config ["config" ], ** config_overrides }
567
+
568
+ # `dtype` in config might be a serialized `DTypePolicy` or `DTypePolicyMap`.
569
+ # Ensure that `dtype` is properly configured.
570
+ dtype = kwargs .pop ("dtype" , None )
571
+ config = set_dtype_in_config (config , dtype )
572
+
573
+ config ["config" ] = {** config ["config" ], ** kwargs }
571
574
return keras .saving .deserialize_keras_object (config )
572
575
573
576
@@ -590,3 +593,25 @@ def jax_memory_cleanup(layer):
590
593
for weight in layer .weights :
591
594
if getattr (weight , "_value" , None ) is not None :
592
595
weight ._value .delete ()
596
+
597
+
598
+ def set_dtype_in_config (config , dtype = None ):
599
+ if dtype is None :
600
+ return config
601
+
602
+ config = config .copy ()
603
+ if "dtype" not in config ["config" ]:
604
+ # Forward `dtype` to the config.
605
+ config ["config" ]["dtype" ] = dtype
606
+ elif (
607
+ "dtype" in config ["config" ]
608
+ and isinstance (config ["config" ]["dtype" ], dict )
609
+ and "DTypePolicyMap" in config ["config" ]["dtype" ]["class_name" ]
610
+ ):
611
+ # If it is `DTypePolicyMap` in `config`, forward `dtype` as its default
612
+ # policy.
613
+ policy_map_config = config ["config" ]["dtype" ]["config" ]
614
+ policy_map_config ["default_policy" ] = dtype
615
+ for k in policy_map_config ["policy_map" ].keys ():
616
+ policy_map_config ["policy_map" ][k ]["config" ]["source_name" ] = dtype
617
+ return config
0 commit comments