From 3b4a1cf2fd1e31545581271e745cbf5a7c315854 Mon Sep 17 00:00:00 2001 From: leondgarse Date: Mon, 6 Dec 2021 14:26:51 +0800 Subject: [PATCH] Save exclude_from_weight_decay in config for LAMB Before: ```py import tensorflow_addons as tfa mm = keras.models.Sequential([keras.layers.Input([32, 32, 3]), keras.layers.Flatten(), keras.layers.Dense(10)]) mm.compile(optimizer=tfa.optimizers.LAMB(learning_rate=0.1, weight_decay_rate=0.02, exclude_from_weight_decay=['/gamma', '/beta'])) mm.save('aa.h5') bb = keras.models.load_model('aa.h5') print(bb.optimizer.exclude_from_weight_decay, bb.optimizer.exclude_from_layer_adaptation) # None None ``` After: ```py import tensorflow_addons as tfa mm = keras.models.Sequential([keras.layers.Input([32, 32, 3]), keras.layers.Flatten(), keras.layers.Dense(10)]) mm.compile(optimizer=tfa.optimizers.LAMB(learning_rate=0.1, weight_decay_rate=0.02, exclude_from_weight_decay=['/gamma', '/beta'])) mm.save('aa.h5') bb = keras.models.load_model('aa.h5') print(bb.optimizer.exclude_from_weight_decay, bb.optimizer.exclude_from_layer_adaptation) # ['/gamma', '/beta'] ['/gamma', '/beta'] ``` --- tensorflow_addons/optimizers/lamb.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tensorflow_addons/optimizers/lamb.py b/tensorflow_addons/optimizers/lamb.py index 553ea27697..d3f9abbd75 100644 --- a/tensorflow_addons/optimizers/lamb.py +++ b/tensorflow_addons/optimizers/lamb.py @@ -235,6 +235,8 @@ def get_config(self): "beta_1": self._serialize_hyperparameter("beta_1"), "beta_2": self._serialize_hyperparameter("beta_2"), "epsilon": self.epsilon, + "exclude_from_weight_decay": self.exclude_from_weight_decay, + "exclude_from_layer_adaptation": self.exclude_from_layer_adaptation, } ) return config