diff --git a/tensorflow_addons/optimizers/lamb.py b/tensorflow_addons/optimizers/lamb.py
index 32995f5c45..553ea27697 100644
--- a/tensorflow_addons/optimizers/lamb.py
+++ b/tensorflow_addons/optimizers/lamb.py
@@ -19,6 +19,8 @@
 """
 
 import re
+import warnings
+
 from typing import Optional, Union, Callable, List
 
 from typeguard import typechecked
@@ -41,7 +43,7 @@ def __init__(
         beta_1: FloatTensorLike = 0.9,
         beta_2: FloatTensorLike = 0.999,
         epsilon: FloatTensorLike = 1e-6,
-        weight_decay_rate: FloatTensorLike = 0.0,
+        weight_decay: FloatTensorLike = 0.0,
         exclude_from_weight_decay: Optional[List[str]] = None,
         exclude_from_layer_adaptation: Optional[List[str]] = None,
         name: str = "LAMB",
@@ -58,7 +60,7 @@ def __init__(
             beta_2: A `float` value or a constant `float` tensor.
               The exponential decay rate for the 2nd moment estimates.
             epsilon: A small constant for numerical stability.
-            weight_decay_rate: weight decay rate.
+            weight_decay: weight decay.
             exclude_from_weight_decay: List of regex patterns of variables
               excluded from weight decay. Variables whose name contain a
               substring matching the pattern will be excluded.
@@ -74,6 +76,16 @@ def __init__(
              decay of learning rate. `lr` is included for backward
              compatibility, recommended to use `learning_rate` instead.
        """
+
+        if "weight_decay_rate" in kwargs:
+            warnings.warn(
+                "weight_decay_rate has been renamed to weight_decay, "
+                "and will be deprecated in Addons 0.18.",
+                DeprecationWarning,
+            )
+            weight_decay = kwargs["weight_decay_rate"]
+            del kwargs["weight_decay_rate"]
+
         super().__init__(name, **kwargs)
 
         # Just adding the square of the weights to the loss function is *not*
@@ -82,7 +94,7 @@ def __init__(
         #
         # Instead we want to decay the weights in a manner that doesn't interact
         # with the m/v parameters.
-        self._set_hyper("weight_decay_rate", weight_decay_rate)
+        self._set_hyper("weight_decay", weight_decay)
         self._set_hyper("learning_rate", kwargs.get("lr", learning_rate))
 
         # This is learning rate decay for using keras learning rate schedule.
@@ -112,12 +124,12 @@ def _prepare_local(self, var_device, var_dtype, apply_state):
         local_step = tf.cast(self.iterations + 1, var_dtype)
         beta_1_t = tf.identity(self._get_hyper("beta_1", var_dtype))
         beta_2_t = tf.identity(self._get_hyper("beta_2", var_dtype))
-        weight_decay_rate = tf.identity(self._get_hyper("weight_decay_rate", var_dtype))
+        weight_decay = tf.identity(self._get_hyper("weight_decay", var_dtype))
         beta_1_power = tf.pow(beta_1_t, local_step)
         beta_2_power = tf.pow(beta_2_t, local_step)
         apply_state[(var_device, var_dtype)].update(
             dict(
-                weight_decay_rate=weight_decay_rate,
+                weight_decay=weight_decay,
                 epsilon=tf.convert_to_tensor(self.epsilon, var_dtype),
                 beta_1_t=beta_1_t,
                 beta_1_power=beta_1_power,
@@ -153,7 +165,7 @@ def _resource_apply_dense(self, grad, var, apply_state=None):
 
         var_name = self._get_variable_name(var.name)
         if self._do_use_weight_decay(var_name):
-            update += coefficients["weight_decay_rate"] * var
+            update += coefficients["weight_decay"] * var
 
         ratio = 1.0
         if self._do_layer_adaptation(var_name):
@@ -196,7 +208,7 @@ def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
 
         var_name = self._get_variable_name(var.name)
         if self._do_use_weight_decay(var_name):
-            update += coefficients["weight_decay_rate"] * var
+            update += coefficients["weight_decay"] * var
 
         ratio = 1.0
         if self._do_layer_adaptation(var_name):
@@ -218,9 +230,7 @@ def get_config(self):
         config.update(
             {
                 "learning_rate": self._serialize_hyperparameter("learning_rate"),
-                "weight_decay_rate": self._serialize_hyperparameter(
-                    "weight_decay_rate"
-                ),
+                "weight_decay": self._serialize_hyperparameter("weight_decay"),
                 "decay": self._serialize_hyperparameter("decay"),
                 "beta_1": self._serialize_hyperparameter("beta_1"),
                 "beta_2": self._serialize_hyperparameter("beta_2"),
diff --git a/tensorflow_addons/optimizers/tests/lamb_test.py b/tensorflow_addons/optimizers/tests/lamb_test.py
index 9c497600e2..631aed99a5 100644
--- a/tensorflow_addons/optimizers/tests/lamb_test.py
+++ b/tensorflow_addons/optimizers/tests/lamb_test.py
@@ -138,7 +138,7 @@ def test_basic_with_learning_rate_decay():
         beta_1=beta_1,
         beta_2=beta_2,
         epsilon=epsilon,
-        weight_decay_rate=lamb_wd,
+        weight_decay=lamb_wd,
         decay=decay,
     )
 
@@ -280,7 +280,7 @@ def test_minimize_mean_square_loss_with_weight_decay():
     def loss():
         return tf.reduce_mean(tf.square(x - w))
 
-    opt = lamb.LAMB(0.02, weight_decay_rate=0.01)
+    opt = lamb.LAMB(0.02, weight_decay=0.01)
 
     # Run 200 steps
     for _ in range(200):
@@ -334,7 +334,7 @@ def test_get_config():
 
 
 def test_exclude_weight_decay():
-    opt = lamb.LAMB(0.01, weight_decay_rate=0.01, exclude_from_weight_decay=["var1"])
+    opt = lamb.LAMB(0.01, weight_decay=0.01, exclude_from_weight_decay=["var1"])
     assert opt._do_use_weight_decay("var0")
     assert not opt._do_use_weight_decay("var1")
     assert not opt._do_use_weight_decay("var1_weight")
@@ -352,3 +352,10 @@ def test_serialization():
     config = tf.keras.optimizers.serialize(optimizer)
     new_optimizer = tf.keras.optimizers.deserialize(config)
     assert new_optimizer.get_config() == optimizer.get_config()
+
+
+def test_weight_decay_rate_deprecation():
+    with pytest.deprecated_call():
+        opt = lamb.LAMB(0.01, weight_decay_rate=0.01)
+    config = opt.get_config()
+    assert config["weight_decay"] == 0.01
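Usage sketch (not part of the diff): with this change, the old `weight_decay_rate` keyword is still accepted but emits a `DeprecationWarning` and is forwarded to `weight_decay`, while new code should pass `weight_decay` directly. A minimal example, assuming a TensorFlow Addons build that includes this patch:

    import warnings
    import tensorflow_addons as tfa

    # New spelling: the hyperparameter is now serialized as "weight_decay".
    opt = tfa.optimizers.LAMB(learning_rate=0.01, weight_decay=0.01)
    assert opt.get_config()["weight_decay"] == 0.01

    # Old spelling: still constructs the optimizer, but warns and is mapped
    # onto `weight_decay` before the base optimizer sees the kwargs.
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        legacy = tfa.optimizers.LAMB(learning_rate=0.01, weight_decay_rate=0.01)
    assert any(issubclass(w.category, DeprecationWarning) for w in caught)
    assert legacy.get_config()["weight_decay"] == 0.01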