diff --git a/tensorflow_addons/optimizers/rectified_adam.py b/tensorflow_addons/optimizers/rectified_adam.py
index ee2c29efb7..97a318fb36 100644
--- a/tensorflow_addons/optimizers/rectified_adam.py
+++ b/tensorflow_addons/optimizers/rectified_adam.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 # ==============================================================================
 """Rectified Adam (RAdam) optimizer."""
+import warnings
 import tensorflow as tf
 
 from tensorflow_addons.utils.types import FloatTensorLike
@@ -79,7 +80,10 @@ def __init__(
         weight_decay: FloatTensorLike = 0.0,
         amsgrad: bool = False,
         sma_threshold: FloatTensorLike = 5.0,
-        total_steps: int = 0,
+        # A float for total_steps is accepted here to be able to load models
+        # created before https://github.com/tensorflow/addons/pull/1375 was
+        # merged. It should be removed in Addons 0.11.
+        total_steps: Union[int, float] = 0,
         warmup_proportion: FloatTensorLike = 0.1,
         min_lr: FloatTensorLike = 0.0,
         name: str = "RectifiedAdam",
@@ -123,7 +127,16 @@ def __init__(
         self._set_hyper("decay", self._initial_decay)
         self._set_hyper("weight_decay", weight_decay)
         self._set_hyper("sma_threshold", sma_threshold)
-        self._set_hyper("total_steps", float(total_steps))
+        if isinstance(total_steps, float):
+            warnings.warn(
+                "The parameter `total_steps` passed to the __init__ of RectifiedAdam "
+                "is a float. This behavior is deprecated and will raise an error in "
+                "Addons 0.11. Use an int instead. If you get this message when "
+                "loading a model, save it again and the `total_steps` parameter "
+                "will automatically be converted to an int.",
+                DeprecationWarning,
+            )
+        self._set_hyper("total_steps", int(total_steps))
         self._set_hyper("warmup_proportion", warmup_proportion)
         self._set_hyper("min_lr", min_lr)
         self.epsilon = epsilon or tf.keras.backend.epsilon()
diff --git a/tensorflow_addons/optimizers/rectified_adam_test.py b/tensorflow_addons/optimizers/rectified_adam_test.py
index 5950fcdd31..73e0a56b6b 100644
--- a/tensorflow_addons/optimizers/rectified_adam_test.py
+++ b/tensorflow_addons/optimizers/rectified_adam_test.py
@@ -172,5 +172,14 @@ def test_get_config(self):
         self.assertEqual(config["total_steps"], 0)
 
 
+def test_serialization():
+    optimizer = RectifiedAdam(
+        lr=1e-3, total_steps=10000, warmup_proportion=0.1, min_lr=1e-5,
+    )
+    config = tf.keras.optimizers.serialize(optimizer)
+    new_optimizer = tf.keras.optimizers.deserialize(config)
+    assert new_optimizer.get_config() == optimizer.get_config()
+
+
 if __name__ == "__main__":
     sys.exit(pytest.main([__file__]))
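
For context, a minimal sketch of how the deprecation path behaves once this patch is applied. This example is not part of the patch; it assumes a tensorflow_addons build that includes the change above, and the warning-capture approach is illustrative only:

import warnings

import tensorflow_addons as tfa

# A float total_steps (as configs serialized before #1375 contain) should
# emit a DeprecationWarning and be coerced to an int internally.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    opt = tfa.optimizers.RectifiedAdam(learning_rate=1e-3, total_steps=10000.0)

assert any(issubclass(w.category, DeprecationWarning) for w in caught)

# get_config() reports the coerced int, so re-saving the model persists an
# int and the warning disappears on subsequent loads.
assert opt.get_config()["total_steps"] == 10000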