From 5707078f4a78f80b756d01a6385eda80c1148b07 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Fri, 3 Apr 2020 09:44:50 +0000 Subject: [PATCH 1/2] Fix LAMB optimizer regex parsing --- tensorflow_addons/optimizers/lamb.py | 10 +++++----- tensorflow_addons/optimizers/lamb_test.py | 16 ++++++++++++++++ 2 files changed, 21 insertions(+), 5 deletions(-) diff --git a/tensorflow_addons/optimizers/lamb.py b/tensorflow_addons/optimizers/lamb.py index d5a5807048..1e121e172d 100644 --- a/tensorflow_addons/optimizers/lamb.py +++ b/tensorflow_addons/optimizers/lamb.py @@ -19,7 +19,7 @@ """ import re -from typing import Optional, Union, Callable +from typing import Optional, Union, Callable, List from typeguard import typechecked import tensorflow as tf @@ -42,8 +42,8 @@ def __init__( beta_2: FloatTensorLike = 0.999, epsilon: FloatTensorLike = 1e-6, weight_decay_rate: FloatTensorLike = 0.0, - exclude_from_weight_decay: Optional[str] = None, - exclude_from_layer_adaptation: Optional[str] = None, + exclude_from_weight_decay: Optional[List[str]] = None, + exclude_from_layer_adaptation: Optional[List[str]] = None, name: str = "LAMB", **kwargs ): @@ -59,10 +59,10 @@ def __init__( The exponential decay rate for the 2nd moment estimates. epsilon: A small constant for numerical stability. weight_decay_rate: weight decay rate. - exclude_from_weight_decay: comma separated name patterns of + exclude_from_weight_decay: List of regex patterns of variables excluded from weight decay. Variables whose name contain a substring matching the pattern will be excluded. - exclude_from_layer_adaptation: comma separated name patterns of + exclude_from_layer_adaptation: List of regex patterns of variables excluded from layer adaptation. Variables whose name contain a substring matching the pattern will be excluded. name: Optional name for the operations created when applying diff --git a/tensorflow_addons/optimizers/lamb_test.py b/tensorflow_addons/optimizers/lamb_test.py index 18ab6622a4..41546e3136 100644 --- a/tensorflow_addons/optimizers/lamb_test.py +++ b/tensorflow_addons/optimizers/lamb_test.py @@ -401,6 +401,22 @@ def test_get_config(self): config = opt.get_config() self.assertEqual(config["learning_rate"], 1e-4) +<<<<<<< HEAD if __name__ == "__main__": tf.test.main() +======= + def test_exclude_weight_decay(self): + opt = lamb.LAMB( + 0.01, weight_decay_rate=0.01, exclude_from_weight_decay=["var1"] + ) + assert opt._do_use_weight_decay("var0") + assert not opt._do_use_weight_decay("var1") + assert not opt._do_use_weight_decay("var1_weight") + + def test_exclude_layer_adaptation(self): + opt = lamb.LAMB(0.01, exclude_from_layer_adaptation=["var1"]) + assert opt._do_layer_adaptation("var0") + assert not opt._do_layer_adaptation("var1") + assert not opt._do_layer_adaptation("var1_weight") +>>>>>>> ce16e62... Fix LAMB optimizer regex parsing (#1532) From 1ae8548267b4c935d811ebff5d7a63a5d21b8911 Mon Sep 17 00:00:00 2001 From: gabrieldemarmiesse Date: Fri, 3 Apr 2020 09:51:29 +0000 Subject: [PATCH 2/2] Fix conflict. --- tensorflow_addons/optimizers/lamb_test.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/tensorflow_addons/optimizers/lamb_test.py b/tensorflow_addons/optimizers/lamb_test.py index 41546e3136..987bb982b3 100644 --- a/tensorflow_addons/optimizers/lamb_test.py +++ b/tensorflow_addons/optimizers/lamb_test.py @@ -401,11 +401,6 @@ def test_get_config(self): config = opt.get_config() self.assertEqual(config["learning_rate"], 1e-4) -<<<<<<< HEAD - -if __name__ == "__main__": - tf.test.main() -======= def test_exclude_weight_decay(self): opt = lamb.LAMB( 0.01, weight_decay_rate=0.01, exclude_from_weight_decay=["var1"] @@ -419,4 +414,7 @@ def test_exclude_layer_adaptation(self): assert opt._do_layer_adaptation("var0") assert not opt._do_layer_adaptation("var1") assert not opt._do_layer_adaptation("var1_weight") ->>>>>>> ce16e62... Fix LAMB optimizer regex parsing (#1532) + + +if __name__ == "__main__": + tf.test.main()