From 3428f28fbaab02ca58658598084aa5d8b7681402 Mon Sep 17 00:00:00 2001 From: leondgarse Date: Wed, 15 Dec 2021 10:05:28 +0800 Subject: [PATCH 1/6] exclude_from_weight_decay for AdamW and SGDW --- .../optimizers/weight_decay_optimizers.py | 41 ++++++++++++++++--- 1 file changed, 36 insertions(+), 5 deletions(-) diff --git a/tensorflow_addons/optimizers/weight_decay_optimizers.py b/tensorflow_addons/optimizers/weight_decay_optimizers.py index bf26d03bfd..bd57549785 100644 --- a/tensorflow_addons/optimizers/weight_decay_optimizers.py +++ b/tensorflow_addons/optimizers/weight_decay_optimizers.py @@ -14,11 +14,12 @@ # ============================================================================== """Base class to make optimizers weight decay ready.""" +import re import tensorflow as tf from tensorflow_addons.utils.types import FloatTensorLike from typeguard import typechecked -from typing import Union, Callable, Type +from typing import Union, Callable, Type, Optional, List class DecoupledWeightDecayExtension: @@ -71,7 +72,12 @@ def __init__(self, weight_decay, *args, **kwargs): """ @typechecked - def __init__(self, weight_decay: Union[FloatTensorLike, Callable], **kwargs): + def __init__( + self, + weight_decay: Union[FloatTensorLike, Callable], + exclude_from_weight_decay: Optional[List[str]] = None, + **kwargs, + ): """Extension class that adds weight decay to an optimizer. Args: @@ -85,10 +91,16 @@ def __init__(self, weight_decay: Union[FloatTensorLike, Callable], **kwargs): super().__init__(**kwargs) self._decay_var_list = None # is set in minimize or apply_gradients self._set_hyper("weight_decay", wd) + self.exclude_from_weight_decay = exclude_from_weight_decay def get_config(self): config = super().get_config() - config.update({"weight_decay": self._serialize_hyperparameter("weight_decay")}) + config.update( + { + "weight_decay": self._serialize_hyperparameter("weight_decay"), + "exclude_from_weight_decay": self.exclude_from_weight_decay, + } + ) return config @classmethod @@ -173,7 +185,7 @@ def apply_gradients(self, grads_and_vars, name=None, decay_var_list=None, **kwar return super().apply_gradients(grads_and_vars, name=name, **kwargs) def _decay_weights_op(self, var, apply_state=None): - if not self._decay_var_list or var.ref() in self._decay_var_list: + if self._do_use_weight_decay(var): var_device, var_dtype = var.device, var.dtype.base_dtype coefficients = (apply_state or {}).get( (var_device, var_dtype) @@ -183,7 +195,7 @@ def _decay_weights_op(self, var, apply_state=None): return tf.no_op() def _decay_weights_sparse_op(self, var, indices, apply_state=None): - if not self._decay_var_list or var.ref() in self._decay_var_list: + if self._do_use_weight_decay(var): var_device, var_dtype = var.device, var.dtype.base_dtype coefficients = (apply_state or {}).get( (var_device, var_dtype) @@ -226,6 +238,25 @@ def _resource_apply_sparse(self, grad, var, indices, apply_state=None): grad, var, indices, apply_state=apply_state ) + def _do_use_weight_decay(self, var): + """Whether to use L2 weight decay for `var`.""" + if not self._decay_var_list or var.ref() in self._decay_var_list: + if self.exclude_from_weight_decay: + var_name = self._get_variable_name(var.name) + for r in self.exclude_from_weight_decay: + if re.search(r, var_name) is not None: + # print("Filtered:", var_name) + return False + return True + return False + + def _get_variable_name(self, param_name): + """Get the variable name from the tensor name.""" + m = re.match("^(.*):\\d+$", param_name) + if m is not None: + param_name = m.group(1) + return param_name + @typechecked def extend_with_decoupled_weight_decay( From d66d8247978697edd9cc8b9c8f805f7cf7ef5b08 Mon Sep 17 00:00:00 2001 From: leondgarse Date: Sat, 18 Dec 2021 12:53:01 +0800 Subject: [PATCH 2/6] exclude_from_weight_decay for AdamW and SGDW --- tensorflow_addons/optimizers/lamb.py | 39 +++++--------- .../optimizers/tests/lamb_test.py | 16 +++--- .../tests/weight_decay_optimizers_test.py | 54 ++++++++++++++++++- tensorflow_addons/optimizers/utils.py | 21 ++++++++ .../optimizers/weight_decay_optimizers.py | 22 +++----- 5 files changed, 102 insertions(+), 50 deletions(-) diff --git a/tensorflow_addons/optimizers/lamb.py b/tensorflow_addons/optimizers/lamb.py index d3f9abbd75..0dfc0413a2 100644 --- a/tensorflow_addons/optimizers/lamb.py +++ b/tensorflow_addons/optimizers/lamb.py @@ -18,7 +18,6 @@ 76 minutes](https://arxiv.org/abs/1904.00962). """ -import re import warnings from typing import Optional, Union, Callable, List @@ -26,6 +25,7 @@ import tensorflow as tf from tensorflow_addons.utils.types import FloatTensorLike +from tensorflow_addons.optimizers.utils import is_variable_excluded_by_regexes @tf.keras.utils.register_keras_serializable(package="Addons") @@ -163,12 +163,11 @@ def _resource_apply_dense(self, grad, var, apply_state=None): v_sqrt = tf.sqrt(v_t_hat) update = m_t_hat / (v_sqrt + coefficients["epsilon"]) - var_name = self._get_variable_name(var.name) - if self._do_use_weight_decay(var_name): + if self._do_use_weight_decay(var): update += coefficients["weight_decay"] * var ratio = 1.0 - if self._do_layer_adaptation(var_name): + if self._do_layer_adaptation(var): w_norm = tf.norm(var, ord=2) g_norm = tf.norm(update, ord=2) ratio = tf.where( @@ -206,12 +205,11 @@ def _resource_apply_sparse(self, grad, var, indices, apply_state=None): v_sqrt = tf.sqrt(v_t_hat) update = m_t_hat / (v_sqrt + coefficients["epsilon"]) - var_name = self._get_variable_name(var.name) - if self._do_use_weight_decay(var_name): + if self._do_use_weight_decay(var): update += coefficients["weight_decay"] * var ratio = 1.0 - if self._do_layer_adaptation(var_name): + if self._do_layer_adaptation(var): w_norm = tf.norm(var, ord=2) g_norm = tf.norm(update, ord=2) ratio = tf.where( @@ -241,26 +239,15 @@ def get_config(self): ) return config - def _do_use_weight_decay(self, param_name): + def _do_use_weight_decay(self, variable): """Whether to use L2 weight decay for `param_name`.""" - if self.exclude_from_weight_decay: - for r in self.exclude_from_weight_decay: - if re.search(r, param_name) is not None: - return False - return True + return not is_variable_excluded_by_regexes( + variable, self.exclude_from_weight_decay + ) - def _do_layer_adaptation(self, param_name): + def _do_layer_adaptation(self, variable): """Whether to do layer-wise learning rate adaptation for `param_name`.""" - if self.exclude_from_layer_adaptation: - for r in self.exclude_from_layer_adaptation: - if re.search(r, param_name) is not None: - return False - return True - - def _get_variable_name(self, param_name): - """Get the variable name from the tensor name.""" - m = re.match("^(.*):\\d+$", param_name) - if m is not None: - param_name = m.group(1) - return param_name + return not is_variable_excluded_by_regexes( + variable, self.exclude_from_layer_adaptation + ) diff --git a/tensorflow_addons/optimizers/tests/lamb_test.py b/tensorflow_addons/optimizers/tests/lamb_test.py index 631aed99a5..e80f7e4ae6 100644 --- a/tensorflow_addons/optimizers/tests/lamb_test.py +++ b/tensorflow_addons/optimizers/tests/lamb_test.py @@ -335,20 +335,22 @@ def test_get_config(): def test_exclude_weight_decay(): opt = lamb.LAMB(0.01, weight_decay=0.01, exclude_from_weight_decay=["var1"]) - assert opt._do_use_weight_decay("var0") - assert not opt._do_use_weight_decay("var1") - assert not opt._do_use_weight_decay("var1_weight") + assert opt._do_use_weight_decay(tf.Variable([], name="var0")) + assert not opt._do_use_weight_decay(tf.Variable([], name="var1")) + assert not opt._do_use_weight_decay(tf.Variable([], name="var1_weight")) def test_exclude_layer_adaptation(): opt = lamb.LAMB(0.01, exclude_from_layer_adaptation=["var1"]) - assert opt._do_layer_adaptation("var0") - assert not opt._do_layer_adaptation("var1") - assert not opt._do_layer_adaptation("var1_weight") + assert opt._do_layer_adaptation(tf.Variable([], name="var0")) + assert not opt._do_layer_adaptation(tf.Variable([], name="var1")) + assert not opt._do_layer_adaptation(tf.Variable([], name="var1_weight")) def test_serialization(): - optimizer = lamb.LAMB(1e-4) + optimizer = lamb.LAMB( + 1e-4, weight_decay_rate=0.01, exclude_from_weight_decay=["var1"] + ) config = tf.keras.optimizers.serialize(optimizer) new_optimizer = tf.keras.optimizers.deserialize(config) assert new_optimizer.get_config() == optimizer.get_config() diff --git a/tensorflow_addons/optimizers/tests/weight_decay_optimizers_test.py b/tensorflow_addons/optimizers/tests/weight_decay_optimizers_test.py index 6a832585c7..d9ac1fe0f1 100644 --- a/tensorflow_addons/optimizers/tests/weight_decay_optimizers_test.py +++ b/tensorflow_addons/optimizers/tests/weight_decay_optimizers_test.py @@ -80,6 +80,7 @@ def do_test( opt = optimizer(**optimizer_kwargs) # Create the update op. # Run 3 steps of the optimizer + optimizer_kwargs.pop("exclude_from_weight_decay", None) for _ in range(3): if do_decay_var_list: opt.apply_gradients( @@ -241,6 +242,31 @@ def test_basic_decay_var_list_adamw(dtype): ) +def test_exclude_weight_decay_adamw(): + optimizer = weight_decay_optimizers.AdamW( + learning_rate=1e-4, weight_decay=1e-4, exclude_from_weight_decay=["var1"] + ) + assert optimizer._do_use_weight_decay(tf.Variable([], name="var0")) + assert not optimizer._do_use_weight_decay(tf.Variable([], name="var1")) + assert not optimizer._do_use_weight_decay(tf.Variable([], name="var1_weight")) + + +@pytest.mark.parametrize("dtype", [(tf.half, 0), (tf.float32, 1), (tf.float64, 2)]) +def test_var_list_with_exclude_list_adamw(dtype): + do_test( + dtype, + weight_decay_optimizers.AdamW, + adamw_update_numpy, + do_decay_var_list=True, + learning_rate=0.001, + beta_1=0.9, + beta_2=0.999, + epsilon=1e-8, + weight_decay=WEIGHT_DECAY, + exclude_from_weight_decay=["var0_*", "var1_*"], + ) + + def test_keras_fit(): """Check if calling model.fit works.""" model = tf.keras.models.Sequential([tf.keras.layers.Dense(2)]) @@ -341,6 +367,30 @@ def test_basic_decay_var_list_sgdw(dtype): ) +def test_exclude_weight_decay_sgdw(): + optimizer = weight_decay_optimizers.SGDW( + learning_rate=0.01, weight_decay=1e-4, exclude_from_weight_decay=["var1"] + ) + assert optimizer._do_use_weight_decay(tf.Variable([], name="var0")) + assert not optimizer._do_use_weight_decay(tf.Variable([], name="var1")) + assert not optimizer._do_use_weight_decay(tf.Variable([], name="var1_weight")) + + +@pytest.mark.usefixtures("maybe_run_functions_eagerly") +@pytest.mark.parametrize("dtype", [(tf.half, 0), (tf.float32, 1), (tf.float64, 2)]) +def test_var_list_with_exclude_list_sgdw(dtype): + do_test( + dtype, + weight_decay_optimizers.SGDW, + sgdw_update_numpy, + do_decay_var_list=True, + learning_rate=0.001, + momentum=0.9, + weight_decay=WEIGHT_DECAY, + exclude_from_weight_decay=["var0_*", "var1_*"], + ) + + @pytest.mark.parametrize( "optimizer", [ @@ -379,7 +429,9 @@ def test_optimizer_sparse(dtype, optimizer): def test_serialization(): - optimizer = weight_decay_optimizers.AdamW(learning_rate=1e-4, weight_decay=1e-4) + optimizer = weight_decay_optimizers.AdamW( + learning_rate=1e-4, weight_decay=1e-4, exclude_from_weight_decay=["var1"] + ) config = tf.keras.optimizers.serialize(optimizer) new_optimizer = tf.keras.optimizers.deserialize(config) assert new_optimizer.get_config() == optimizer.get_config() diff --git a/tensorflow_addons/optimizers/utils.py b/tensorflow_addons/optimizers/utils.py index af365edbff..1948c59008 100644 --- a/tensorflow_addons/optimizers/utils.py +++ b/tensorflow_addons/optimizers/utils.py @@ -14,7 +14,9 @@ # ============================================================================== """Additional Utilities used for tfa.optimizers.""" +import re import tensorflow as tf +from typing import List def fit_bn(model, *args, **kwargs): @@ -51,3 +53,22 @@ def fit_bn(model, *args, **kwargs): model.trainable = _trainable model._metrics = _metrics + + +def get_variable_name(variable) -> str: + """Get the variable name from the variable tensor.""" + param_name = variable.name + m = re.match("^(.*):\\d+$", param_name) + if m is not None: + param_name = m.group(1) + return param_name + + +def is_variable_excluded_by_regexes(variable, exclude_regexes: List[str]) -> bool: + """Whether to use L2 weight decay for `param_name`.""" + if exclude_regexes: + var_name = get_variable_name(variable) + for r in exclude_regexes: + if re.search(r, var_name): + return True + return False diff --git a/tensorflow_addons/optimizers/weight_decay_optimizers.py b/tensorflow_addons/optimizers/weight_decay_optimizers.py index bd57549785..7c5d4c9735 100644 --- a/tensorflow_addons/optimizers/weight_decay_optimizers.py +++ b/tensorflow_addons/optimizers/weight_decay_optimizers.py @@ -14,9 +14,9 @@ # ============================================================================== """Base class to make optimizers weight decay ready.""" -import re import tensorflow as tf from tensorflow_addons.utils.types import FloatTensorLike +from tensorflow_addons.optimizers.utils import is_variable_excluded_by_regexes from typeguard import typechecked from typing import Union, Callable, Type, Optional, List @@ -84,6 +84,9 @@ def __init__( weight_decay: A `Tensor`, a floating point value, or a schedule that is a `tf.keras.optimizers.schedules.LearningRateSchedule` to decay the variable by, in the update step. + exclude_from_weight_decay: List of regex patterns of + variables excluded from weight decay. Variables whose name + contain a substring matching the pattern will be excluded. **kwargs: Optional list or tuple or set of `Variable` objects to decay. """ @@ -240,22 +243,9 @@ def _resource_apply_sparse(self, grad, var, indices, apply_state=None): def _do_use_weight_decay(self, var): """Whether to use L2 weight decay for `var`.""" - if not self._decay_var_list or var.ref() in self._decay_var_list: - if self.exclude_from_weight_decay: - var_name = self._get_variable_name(var.name) - for r in self.exclude_from_weight_decay: - if re.search(r, var_name) is not None: - # print("Filtered:", var_name) - return False + if self._decay_var_list and var.ref() in self._decay_var_list: return True - return False - - def _get_variable_name(self, param_name): - """Get the variable name from the tensor name.""" - m = re.match("^(.*):\\d+$", param_name) - if m is not None: - param_name = m.group(1) - return param_name + return not is_variable_excluded_by_regexes(var, self.exclude_from_weight_decay) @typechecked From cbc8935106bcc4e455fc8e64b29d1d48e51dbd57 Mon Sep 17 00:00:00 2001 From: leondgarse Date: Sat, 18 Dec 2021 13:09:05 +0800 Subject: [PATCH 3/6] exclude_from_weight_decay for AdamW and SGDW --- tensorflow_addons/optimizers/utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tensorflow_addons/optimizers/utils.py b/tensorflow_addons/optimizers/utils.py index 1948c59008..7baffa131a 100644 --- a/tensorflow_addons/optimizers/utils.py +++ b/tensorflow_addons/optimizers/utils.py @@ -67,7 +67,8 @@ def get_variable_name(variable) -> str: def is_variable_excluded_by_regexes(variable, exclude_regexes: List[str]) -> bool: """Whether to use L2 weight decay for `param_name`.""" if exclude_regexes: - var_name = get_variable_name(variable) + # var_name = get_variable_name(variable) + var_name = variable.name for r in exclude_regexes: if re.search(r, var_name): return True From ae7e3984041d4b9c3f8259a5ad117e0ca289b980 Mon Sep 17 00:00:00 2001 From: leondgarse Date: Sat, 18 Dec 2021 13:30:11 +0800 Subject: [PATCH 4/6] exclude_from_weight_decay for AdamW and SGDW --- tensorflow_addons/optimizers/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow_addons/optimizers/utils.py b/tensorflow_addons/optimizers/utils.py index 7baffa131a..a95c8eaea5 100644 --- a/tensorflow_addons/optimizers/utils.py +++ b/tensorflow_addons/optimizers/utils.py @@ -65,7 +65,7 @@ def get_variable_name(variable) -> str: def is_variable_excluded_by_regexes(variable, exclude_regexes: List[str]) -> bool: - """Whether to use L2 weight decay for `param_name`.""" + """Whether variable is excluded in exclude_regexes by its name.""" if exclude_regexes: # var_name = get_variable_name(variable) var_name = variable.name From 707bf8d8cc5ee254a35366a6c750d2605bfc91e9 Mon Sep 17 00:00:00 2001 From: leondgarse Date: Mon, 20 Dec 2021 09:30:48 +0800 Subject: [PATCH 5/6] exclude_from_weight_decay for AdamW and SGDW --- tensorflow_addons/optimizers/lamb.py | 6 +-- tensorflow_addons/optimizers/utils.py | 2 +- .../optimizers/weight_decay_optimizers.py | 48 ++++++++++++------- 3 files changed, 34 insertions(+), 22 deletions(-) diff --git a/tensorflow_addons/optimizers/lamb.py b/tensorflow_addons/optimizers/lamb.py index 0dfc0413a2..b166657251 100644 --- a/tensorflow_addons/optimizers/lamb.py +++ b/tensorflow_addons/optimizers/lamb.py @@ -25,7 +25,7 @@ import tensorflow as tf from tensorflow_addons.utils.types import FloatTensorLike -from tensorflow_addons.optimizers.utils import is_variable_excluded_by_regexes +from tensorflow_addons.optimizers.utils import is_variable_matched_by_regexes @tf.keras.utils.register_keras_serializable(package="Addons") @@ -241,13 +241,13 @@ def get_config(self): def _do_use_weight_decay(self, variable): """Whether to use L2 weight decay for `param_name`.""" - return not is_variable_excluded_by_regexes( + return not is_variable_matched_by_regexes( variable, self.exclude_from_weight_decay ) def _do_layer_adaptation(self, variable): """Whether to do layer-wise learning rate adaptation for `param_name`.""" - return not is_variable_excluded_by_regexes( + return not is_variable_matched_by_regexes( variable, self.exclude_from_layer_adaptation ) diff --git a/tensorflow_addons/optimizers/utils.py b/tensorflow_addons/optimizers/utils.py index a95c8eaea5..355dbdf74a 100644 --- a/tensorflow_addons/optimizers/utils.py +++ b/tensorflow_addons/optimizers/utils.py @@ -64,7 +64,7 @@ def get_variable_name(variable) -> str: return param_name -def is_variable_excluded_by_regexes(variable, exclude_regexes: List[str]) -> bool: +def is_variable_matched_by_regexes(variable, exclude_regexes: List[str]) -> bool: """Whether variable is excluded in exclude_regexes by its name.""" if exclude_regexes: # var_name = get_variable_name(variable) diff --git a/tensorflow_addons/optimizers/weight_decay_optimizers.py b/tensorflow_addons/optimizers/weight_decay_optimizers.py index 7c5d4c9735..3d882b0169 100644 --- a/tensorflow_addons/optimizers/weight_decay_optimizers.py +++ b/tensorflow_addons/optimizers/weight_decay_optimizers.py @@ -16,7 +16,7 @@ import tensorflow as tf from tensorflow_addons.utils.types import FloatTensorLike -from tensorflow_addons.optimizers.utils import is_variable_excluded_by_regexes +from tensorflow_addons.optimizers.utils import is_variable_matched_by_regexes from typeguard import typechecked from typing import Union, Callable, Type, Optional, List @@ -87,6 +87,8 @@ def __init__( exclude_from_weight_decay: List of regex patterns of variables excluded from weight decay. Variables whose name contain a substring matching the pattern will be excluded. + Note `decay_var_list` in `minimize` or `apply_gradients` takes + priority over `exclude_from_weight_decay` if specified. **kwargs: Optional list or tuple or set of `Variable` objects to decay. """ @@ -145,7 +147,8 @@ def minimize( grad_loss: Optional. A `Tensor` holding the gradient computed for `loss`. decay_var_list: Optional list of variables to be decayed. Defaults - to all variables in var_list. + to all variables in var_list. Note `decay_var_list` takes + priority over `exclude_from_weight_decay` if specified. name: Optional name for the returned operation. tape: (Optional) `tf.GradientTape`. If `loss` is provided as a `Tensor`, the tape that computed the `loss` must be provided. @@ -169,10 +172,11 @@ def apply_gradients(self, grads_and_vars, name=None, decay_var_list=None, **kwar Args: grads_and_vars: List of (gradient, variable) pairs. - name: Optional name for the returned operation. Default to the + name: Optional name for the returned operation. Default to the name passed to the `Optimizer` constructor. decay_var_list: Optional list of variables to be decayed. Defaults - to all variables in var_list. + to all variables in var_list. Note `decay_var_list` takes + priority over `exclude_from_weight_decay` if specified. **kwargs: Additional arguments to pass to the base optimizer's apply_gradient method, e.g., TF2.2 added an argument `experimental_aggregate_gradients`. @@ -245,7 +249,7 @@ def _do_use_weight_decay(self, var): """Whether to use L2 weight decay for `var`.""" if self._decay_var_list and var.ref() in self._decay_var_list: return True - return not is_variable_excluded_by_regexes(var, self.exclude_from_weight_decay) + return not is_variable_matched_by_regexes(var, self.exclude_from_weight_decay) @typechecked @@ -264,9 +268,13 @@ def extend_with_decoupled_weight_decay( The API of the new optimizer class slightly differs from the API of the base optimizer: - The first argument to the constructor is the weight decay rate. + - Optional keyword argument `exclude_from_weight_decay` accepts list of + regex patterns of variables excluded from weight decay. Variables whose + name contain a substring matching the pattern will be excluded. - `minimize` and `apply_gradients` accept the optional keyword argument `decay_var_list`, which specifies the variables that should be decayed. - If `None`, all variables that are optimized are decayed. + Note this takes priority over `exclude_from_weight_decay` if specified. + If both `None`, all variables that are optimized are decayed. Usage example: ```python @@ -397,12 +405,14 @@ def __init__( nesterov: boolean. Whether to apply Nesterov momentum. name: Optional name prefix for the operations created when applying gradients. Defaults to 'SGD'. - **kwargs: keyword arguments. Allowed to be {`clipnorm`, - `clipvalue`, `lr`, `decay`}. `clipnorm` is clip gradients by - norm; `clipvalue` is clip gradients by value, `decay` is - included for backward compatibility to allow time inverse decay - of learning rate. `lr` is included for backward compatibility, - recommended to use `learning_rate` instead. + **kwargs: keyword arguments. Allowed to be {`clipnorm`, `clipvalue`, + `lr`, `decay`, `exclude_from_weight_decay`}. `clipnorm` is clip + gradients by norm; `clipvalue` is clip gradients by value. + `decay` is included for backward compatibility to allow time + inverse decay of learning rate. `lr` is included for backward + compatibility, recommended to use `learning_rate` instead. + `exclude_from_weight_decay` accepts list of regex patterns of + variables excluded from weight decay. """ super().__init__( weight_decay, @@ -487,12 +497,14 @@ def __init__( beyond". name: Optional name for the operations created when applying gradients. Defaults to "AdamW". - **kwargs: keyword arguments. Allowed to be {`clipnorm`, - `clipvalue`, `lr`, `decay`}. `clipnorm` is clip gradients by - norm; `clipvalue` is clip gradients by value, `decay` is - included for backward compatibility to allow time inverse decay - of learning rate. `lr` is included for backward compatibility, - recommended to use `learning_rate` instead. + **kwargs: keyword arguments. Allowed to be {`clipnorm`, `clipvalue`, + `lr`, `decay`, `exclude_from_weight_decay`}. `clipnorm` is clip + gradients by norm; `clipvalue` is clip gradients by value. + `decay` is included for backward compatibility to allow time + inverse decay of learning rate. `lr` is included for backward + compatibility, recommended to use `learning_rate` instead. + `exclude_from_weight_decay` accepts list of regex patterns of + variables excluded from weight decay. """ super().__init__( weight_decay, From 6c512e1ec49ae14ca61f1bbf0fd3403204ba21e2 Mon Sep 17 00:00:00 2001 From: leondgarse Date: Mon, 20 Dec 2021 09:42:09 +0800 Subject: [PATCH 6/6] exclude_from_weight_decay for AdamW and SGDW --- tensorflow_addons/optimizers/utils.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tensorflow_addons/optimizers/utils.py b/tensorflow_addons/optimizers/utils.py index 355dbdf74a..91d0d92b1e 100644 --- a/tensorflow_addons/optimizers/utils.py +++ b/tensorflow_addons/optimizers/utils.py @@ -64,12 +64,12 @@ def get_variable_name(variable) -> str: return param_name -def is_variable_matched_by_regexes(variable, exclude_regexes: List[str]) -> bool: - """Whether variable is excluded in exclude_regexes by its name.""" - if exclude_regexes: +def is_variable_matched_by_regexes(variable, regexes: List[str]) -> bool: + """Whether variable is matched in regexes list by its name.""" + if regexes: # var_name = get_variable_name(variable) var_name = variable.name - for r in exclude_regexes: + for r in regexes: if re.search(r, var_name): return True return False