diff --git a/tensorflow_addons/optimizers/discriminative_layer_training.py b/tensorflow_addons/optimizers/discriminative_layer_training.py
index 4231fc7da3..1b8b697ba8 100644
--- a/tensorflow_addons/optimizers/discriminative_layer_training.py
+++ b/tensorflow_addons/optimizers/discriminative_layer_training.py
@@ -143,7 +143,15 @@ def apply_gradients(self, grads_and_vars, **kwargs):
 
     def get_config(self):
         config = super(MultiOptimizer, self).get_config()
-        config.update({"optimizer_specs": self.optimizer_specs})
+        optimizer_specs_without_gv = []
+        for optimizer_spec in self.optimizer_specs:
+            optimizer_specs_without_gv.append(
+                {
+                    "optimizer": optimizer_spec["optimizer"],
+                    "weights": optimizer_spec["weights"],
+                }
+            )
+        config.update({"optimizer_specs": optimizer_specs_without_gv})
         return config
 
     @classmethod
diff --git a/tensorflow_addons/optimizers/tests/discriminative_layer_training_test.py b/tensorflow_addons/optimizers/tests/discriminative_layer_training_test.py
index f8a76fdbf0..e53d75a157 100644
--- a/tensorflow_addons/optimizers/tests/discriminative_layer_training_test.py
+++ b/tensorflow_addons/optimizers/tests/discriminative_layer_training_test.py
@@ -286,3 +286,38 @@ def test_serialization():
     new_optimizer = tf.keras.optimizers.deserialize(config)
 
     assert new_optimizer.get_config() == optimizer.get_config()
+
+
+def test_serialization_after_training(tmpdir):
+    x = np.array(np.ones([100]))
+    y = np.array(np.ones([100]))
+    model = tf.keras.Sequential(
+        [tf.keras.Input(shape=[1]), tf.keras.layers.Dense(1), tf.keras.layers.Dense(1)]
+    )
+
+    opt1 = tf.keras.optimizers.Adam(learning_rate=1e-3)
+    opt2 = tf.keras.optimizers.SGD(learning_rate=0)
+
+    opt_layer_pairs = [(opt1, model.layers[0]), (opt2, model.layers[1])]
+
+    optimizer = MultiOptimizer(opt_layer_pairs)
+
+    # Train the model for a few epochs.
+    model.compile(loss="categorical_crossentropy", optimizer=optimizer)
+    model.fit(x, y)
+
+    # Verify the optimizer can still be serialized (saved).
+    model.save(str(tmpdir))
+    loaded_model = tf.keras.models.load_model(str(tmpdir))
+    old_config = model.optimizer.get_config()
+    new_config = loaded_model.optimizer.get_config()
+    # Verify the loaded model has the same optimizer as before.
+    assert len(old_config["optimizer_specs"]) == len(new_config["optimizer_specs"])
+    for old_optimizer_spec, new_optimizer_spec in zip(
+        old_config["optimizer_specs"], new_config["optimizer_specs"]
+    ):
+        assert old_optimizer_spec["weights"] == new_optimizer_spec["weights"]
+        assert (
+            old_optimizer_spec["optimizer"].get_config()
+            == new_optimizer_spec["optimizer"].get_config()
+        )
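
Note for reviewers (not part of the patch): during training, apply_gradients caches gradient/variable pairs on each optimizer spec, and get_config() previously copied those live tensors into the returned config, so a MultiOptimizer serialized fine when freshly constructed but not after a single fit() step. The hunk above keeps only the "optimizer" and "weights" entries of each spec. Below is a minimal sketch of the round-trip this enables, assuming the conventional tensorflow_addons alias tfa and an illustrative one-feature model; the mse loss, shapes, and epochs=1 are assumptions for the sketch, not taken from the diff.

import numpy as np
import tensorflow as tf
import tensorflow_addons as tfa

# Two Dense layers, each driven by its own optimizer via MultiOptimizer.
# Model shape and loss are illustrative only.
model = tf.keras.Sequential(
    [tf.keras.Input(shape=[1]), tf.keras.layers.Dense(1), tf.keras.layers.Dense(1)]
)
optimizer = tfa.optimizers.MultiOptimizer(
    [
        (tf.keras.optimizers.Adam(learning_rate=1e-3), model.layers[0]),
        (tf.keras.optimizers.SGD(learning_rate=0.0), model.layers[1]),
    ]
)
model.compile(loss="mse", optimizer=optimizer)
model.fit(np.ones([100, 1]), np.ones([100, 1]), epochs=1, verbose=0)

# With the fix, get_config() no longer carries the cached gradient/variable
# pairs, so serializing after training succeeds instead of raising.
config = tf.keras.optimizers.serialize(model.optimizer)
restored = tf.keras.optimizers.deserialize(config)
print(type(restored).__name__)  # MultiOptimizer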