From ceb0859cb39f8da4fe75124233ca759834319fac Mon Sep 17 00:00:00 2001 From: Jared Nielsen Date: Wed, 1 Apr 2020 14:47:25 -0700 Subject: [PATCH 1/6] Fix type for LAMB optimizer exclude_from_weight_decay --- tensorflow_addons/optimizers/lamb.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tensorflow_addons/optimizers/lamb.py b/tensorflow_addons/optimizers/lamb.py index d5a5807048..b2a112fc45 100644 --- a/tensorflow_addons/optimizers/lamb.py +++ b/tensorflow_addons/optimizers/lamb.py @@ -42,8 +42,8 @@ def __init__( beta_2: FloatTensorLike = 0.999, epsilon: FloatTensorLike = 1e-6, weight_decay_rate: FloatTensorLike = 0.0, - exclude_from_weight_decay: Optional[str] = None, - exclude_from_layer_adaptation: Optional[str] = None, + exclude_from_weight_decay: List[str] = None, + exclude_from_layer_adaptation: List[str] = None, name: str = "LAMB", **kwargs ): @@ -59,10 +59,10 @@ def __init__( The exponential decay rate for the 2nd moment estimates. epsilon: A small constant for numerical stability. weight_decay_rate: weight decay rate. - exclude_from_weight_decay: comma separated name patterns of + exclude_from_weight_decay: List of regex patterns of variables excluded from weight decay. Variables whose name contain a substring matching the pattern will be excluded. - exclude_from_layer_adaptation: comma separated name patterns of + exclude_from_layer_adaptation: List of regex patterns of variables excluded from layer adaptation. Variables whose name contain a substring matching the pattern will be excluded. name: Optional name for the operations created when applying From 43cdde584b0c42a9a6484fd9a1a6e7d4d318185f Mon Sep 17 00:00:00 2001 From: Jared Nielsen Date: Wed, 1 Apr 2020 14:50:30 -0700 Subject: [PATCH 2/6] Add import --- tensorflow_addons/optimizers/lamb.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow_addons/optimizers/lamb.py b/tensorflow_addons/optimizers/lamb.py index b2a112fc45..a2389dbc46 100644 --- a/tensorflow_addons/optimizers/lamb.py +++ b/tensorflow_addons/optimizers/lamb.py @@ -19,7 +19,7 @@ """ import re -from typing import Optional, Union, Callable +from typing import Optional, Union, Callable, List from typeguard import typechecked import tensorflow as tf From 910a0a6fe68ffa9509c5f61a8ac2cf6dcb69daf9 Mon Sep 17 00:00:00 2001 From: Jared Nielsen Date: Wed, 1 Apr 2020 14:53:29 -0700 Subject: [PATCH 3/6] Add optional wrapper --- tensorflow_addons/optimizers/lamb.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow_addons/optimizers/lamb.py b/tensorflow_addons/optimizers/lamb.py index a2389dbc46..1e121e172d 100644 --- a/tensorflow_addons/optimizers/lamb.py +++ b/tensorflow_addons/optimizers/lamb.py @@ -42,8 +42,8 @@ def __init__( beta_2: FloatTensorLike = 0.999, epsilon: FloatTensorLike = 1e-6, weight_decay_rate: FloatTensorLike = 0.0, - exclude_from_weight_decay: List[str] = None, - exclude_from_layer_adaptation: List[str] = None, + exclude_from_weight_decay: Optional[List[str]] = None, + exclude_from_layer_adaptation: Optional[List[str]] = None, name: str = "LAMB", **kwargs ): From c79c65299e14727c821c602cca4834d5c7d1887b Mon Sep 17 00:00:00 2001 From: Jared Nielsen Date: Thu, 2 Apr 2020 11:01:12 -0700 Subject: [PATCH 4/6] Add test --- tensorflow_addons/optimizers/lamb_test.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/tensorflow_addons/optimizers/lamb_test.py b/tensorflow_addons/optimizers/lamb_test.py index 398c20ee94..88744dd31f 100644 --- a/tensorflow_addons/optimizers/lamb_test.py +++ b/tensorflow_addons/optimizers/lamb_test.py @@ -404,6 +404,14 @@ def test_get_config(self): config = opt.get_config() self.assertEqual(config["learning_rate"], 1e-4) + def test_exclude_weight_decay(self): + opt = lamb.LAMB( + 0.01, weight_decay_rate=0.01, exclude_from_weight_decay=["var1"] + ) + assert opt._do_use_weight_decay("var0") + assert not opt._do_use_weight_decay("var1") + assert not opt._do_use_weight_decay("var1_weight") + if __name__ == "__main__": sys.exit(pytest.main([__file__])) From d6e2523b352abd00a45336aa8d786f4525d8945c Mon Sep 17 00:00:00 2001 From: Jared Nielsen Date: Thu, 2 Apr 2020 14:42:34 -0700 Subject: [PATCH 5/6] Layer adaption test --- tensorflow_addons/optimizers/lamb_test.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tensorflow_addons/optimizers/lamb_test.py b/tensorflow_addons/optimizers/lamb_test.py index 54b6a6773f..7fdc93d0c7 100644 --- a/tensorflow_addons/optimizers/lamb_test.py +++ b/tensorflow_addons/optimizers/lamb_test.py @@ -409,3 +409,9 @@ def test_exclude_weight_decay(self): assert opt._do_use_weight_decay("var0") assert not opt._do_use_weight_decay("var1") assert not opt._do_use_weight_decay("var1_weight") + + def test_exclude_layer_adaptation(self): + opt = lamb.LAMB(0.01, exclude_from_layer_adaptation=["var1"]) + assert opt._do_layer_adaptation("var0") + assert not opt._do_layer_adaptation("var1") + assert not opt._do_layer_adaptation("var1") From 71bcfd888bb3266bf5ef8967ad9ac01d743d59ca Mon Sep 17 00:00:00 2001 From: Jared Nielsen Date: Thu, 2 Apr 2020 14:44:30 -0700 Subject: [PATCH 6/6] Typo --- tensorflow_addons/optimizers/lamb_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow_addons/optimizers/lamb_test.py b/tensorflow_addons/optimizers/lamb_test.py index 7fdc93d0c7..ede68e5f24 100644 --- a/tensorflow_addons/optimizers/lamb_test.py +++ b/tensorflow_addons/optimizers/lamb_test.py @@ -414,4 +414,4 @@ def test_exclude_layer_adaptation(self): opt = lamb.LAMB(0.01, exclude_from_layer_adaptation=["var1"]) assert opt._do_layer_adaptation("var0") assert not opt._do_layer_adaptation("var1") - assert not opt._do_layer_adaptation("var1") + assert not opt._do_layer_adaptation("var1_weight")