diff --git a/tensorflow_addons/optimizers/__init__.py b/tensorflow_addons/optimizers/__init__.py
index e3a6b996dc..5e1d90f2e4 100644
--- a/tensorflow_addons/optimizers/__init__.py
+++ b/tensorflow_addons/optimizers/__init__.py
@@ -26,6 +26,9 @@ from tensorflow_addons.optimizers.cyclical_learning_rate import (
     ExponentialCyclicalLearningRate,
 )
+from tensorflow_addons.optimizers.discriminative_layer_training import (
+    MultiOptimizer,
+)
 from tensorflow_addons.optimizers.lamb import LAMB
 from tensorflow_addons.optimizers.lazy_adam import LazyAdam
 from tensorflow_addons.optimizers.lookahead import Lookahead
diff --git a/tensorflow_addons/optimizers/discriminative_layer_training.py b/tensorflow_addons/optimizers/discriminative_layer_training.py
new file mode 100644
index 0000000000..e56387d682
--- /dev/null
+++ b/tensorflow_addons/optimizers/discriminative_layer_training.py
@@ -0,0 +1,166 @@
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Discriminative Layer Training Optimizer for TensorFlow."""
+
+from typing import Union
+
+import tensorflow as tf
+from typeguard import typechecked
+
+
+@tf.keras.utils.register_keras_serializable(package="Addons")
+class MultiOptimizer(tf.keras.optimizers.Optimizer):
+    """Multi Optimizer Wrapper for Discriminative Layer Training.
+
+    Creates a wrapper around a set of instantiated optimizer layer pairs. Generally useful for
+    transfer learning of deep networks.
+
+    Each optimizer will optimize only the weights associated with its paired layer. This can be used
+    to implement discriminative layer training by assigning different learning rates to each optimizer
+    layer pair. (Optimizer, list(Layers)) pairs are also supported. Please note that the layers must be
+    instantiated before instantiating the optimizer.
+
+    Args:
+        optimizers_and_layers: a list of tuples of an optimizer and a layer or model. Each tuple should contain
+            exactly one instantiated optimizer and one object that subclasses `tf.keras.Model` or
+            `tf.keras.layers.Layer`. Nested layers and models will be automatically discovered.
+            Alternatively, in place of a single layer, you can pass a list of layers.
+        optimizer_specs: specialized list for serialization. Should be left as None in almost all cases. If you are
+            loading a serialized version of this optimizer, please use `tf.keras.models.load_model` after saving a
+            model compiled with this optimizer.
+
+    Usage:
+
+    ```python
+    model = get_model()
+
+    opt1 = tf.keras.optimizers.Adam(learning_rate=1e-4)
+    opt2 = tf.keras.optimizers.Adam(learning_rate=1e-2)
+
+    opt_layer_pairs = [(opt1, model.layers[0]), (opt2, model.layers[1:])]
+
+    loss = tf.keras.losses.MSE
+    optimizer = tfa.optimizers.MultiOptimizer(opt_layer_pairs)
+
+    model.compile(optimizer=optimizer, loss=loss)
+
+    model.fit(x, y)
+    ```
+
+    References:
+
+    [Universal Language Model Fine-tuning for Text Classification](https://arxiv.org/abs/1801.06146)
+    [Collaborative Layer-wise Discriminative Learning in Deep Neural Networks](https://arxiv.org/abs/1607.05440)
+
+    Notes:
+
+    Currently, `MultiOptimizer` does not support callbacks that modify optimizers. However, you can
+    instantiate optimizer layer pairs with `tf.keras.optimizers.schedules.LearningRateSchedule`
+    instead of a static learning rate.
+
+    This code should function on CPU, GPU, and TPU. Apply the `with strategy.scope()` context as you
+    would with any other optimizer.
+    """
+
+    @typechecked
+    def __init__(
+        self,
+        optimizers_and_layers: Union[list, None] = None,
+        optimizer_specs: Union[list, None] = None,
+        name: str = "MultiOptimizer",
+        **kwargs
+    ):
+
+        super().__init__(name, **kwargs)
+
+        if optimizer_specs is None and optimizers_and_layers is not None:
+            self.optimizer_specs = [
+                self.create_optimizer_spec(opt, layer)
+                for opt, layer in optimizers_and_layers
+            ]
+
+        elif optimizer_specs is not None and optimizers_and_layers is None:
+            self.optimizer_specs = [
+                self.maybe_initialize_optimizer_spec(spec) for spec in optimizer_specs
+            ]
+
+        else:
+            raise RuntimeError(
+                "You must specify either a list of optimizers and layers or a list of optimizer_specs"
+            )
+
+    def apply_gradients(self, grads_and_vars, name=None, **kwargs):
+        """Wrapped `apply_gradients` method.
+
+        Returns an op, grouped with `tf.group`, that applies each sub-optimizer's gradients.
+        Variable names are used rather than `var.ref()` to enable serialization and deserialization.
+ """ + + for spec in self.optimizer_specs: + spec["gv"] = [] + + for grad, var in tuple(grads_and_vars): + for spec in self.optimizer_specs: + for name in spec["weights"]: + if var.name == name: + spec["gv"].append((grad, var)) + + return tf.group( + [ + spec["optimizer"].apply_gradients(spec["gv"], **kwargs) + for spec in self.optimizer_specs + ] + ) + + def get_config(self): + config = super(MultiOptimzer, self).get_config() + config.update({"optimizer_specs": self.optimizer_specs}) + return config + + @classmethod + def create_optimizer_spec(cls, optimizer_instance, layer): + + assert isinstance( + optimizer_instance, tf.keras.optimizers.Optimizer + ), "Object passed is not an instance of tf.keras.optimizers.Optimizer" + + assert isinstance(layer, tf.keras.layers.Layer) or isinstance( + layer, tf.keras.Model + ), "Object passed is not an instance of tf.keras.layers.Layer nor tf.keras.Model" + + if type(layer) == list: + weights = [var.name for sublayer in layer for var in sublayer.weights] + else: + weights = [var.name for var in layer.weights] + + return { + "optimizer": optimizer_instance, + "weights": weights, + } + + @classmethod + def maybe_initialize_optimizer_spec(cls, optimizer_spec): + if type(optimizer_spec["optimizer"]) == dict: + optimizer_spec["optimizer"] = tf.keras.optimizers.deserialize( + optimizer_spec["optimizer"] + ) + + return optimizer_spec + + def __repr__(self): + return "Multi Optimizer with %i optimizer layer pairs" % len( + self.optimizer_specs + ) diff --git a/tensorflow_addons/optimizers/tests/discriminative_layer_training_test.py b/tensorflow_addons/optimizers/tests/discriminative_layer_training_test.py new file mode 100644 index 0000000000..8a93cfb903 --- /dev/null +++ b/tensorflow_addons/optimizers/tests/discriminative_layer_training_test.py @@ -0,0 +1,113 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for Discriminative Layer Training Optimizer for TensorFlow.""" + +import pytest +import numpy as np +import tensorflow as tf + +from tensorflow_addons.optimizers.discriminative_layer_training import MultiOptimzer +from tensorflow_addons.utils import test_utils + + +def _dtypes_to_test(use_gpu): + # Based on issue #347 in the following link, + # "https://github.com/tensorflow/addons/issues/347" + # tf.half is not registered for 'ResourceScatterUpdate' OpKernel + # for 'GPU' devices. + # So we have to remove tf.half when testing with gpu. 
+ # The function "_DtypesToTest" is from + # "https://github.com/tensorflow/tensorflow/blob/5d4a6cee737a1dc6c20172a1dc1 + # 5df10def2df72/tensorflow/python/kernel_tests/conv_ops_3d_test.py#L53-L62" + # TODO(WindQAQ): Clean up this in TF2.4 + + if use_gpu: + return [tf.float32, tf.float64] + else: + return [tf.half, tf.float32, tf.float64] + + +@pytest.mark.with_device(["cpu", "gpu"]) +@pytest.mark.parametrize("dtype", [tf.float16, tf.float32, tf.float64]) +@pytest.mark.parametrize("serialize", [True, False]) +def test_fit_layer_optimizer(dtype, device, serialize): + # Test ensures that each optimizer is only optimizing its own layer with its learning rate + + if "gpu" in device and dtype == tf.float16: + pytest.xfail("See https://github.com/tensorflow/addons/issues/347") + + model = tf.keras.Sequential( + [tf.keras.Input(shape=[1]), tf.keras.layers.Dense(1), tf.keras.layers.Dense(1)] + ) + + x = np.array(np.ones([100])) + y = np.array(np.ones([100])) + + weights_before_train = ( + model.layers[0].weights[0].numpy(), + model.layers[1].weights[0].numpy(), + ) + + opt1 = tf.keras.optimizers.Adam(learning_rate=1e-3) + opt2 = tf.keras.optimizers.SGD(learning_rate=0) + + opt_layer_pairs = [(opt1, model.layers[0]), (opt2, model.layers[1])] + + loss = tf.keras.losses.MSE + optimizer = MultiOptimzer(opt_layer_pairs) + + model.compile(optimizer=optimizer, loss=loss) + + # serialize whole model including optimizer, clear the session, then reload the whole model. + if serialize: + model.save("test", save_format="tf") + tf.keras.backend.clear_session() + model = tf.keras.models.load_model("test") + + model.fit(x, y, batch_size=8, epochs=10) + + weights_after_train = ( + model.layers[0].weights[0].numpy(), + model.layers[1].weights[0].numpy(), + ) + + with np.testing.assert_raises(AssertionError): + # expect weights to be different for layer 1 + test_utils.assert_allclose_according_to_type( + weights_before_train[0], weights_after_train[0] + ) + + # expect weights to be same for layer 2 + test_utils.assert_allclose_according_to_type( + weights_before_train[1], weights_after_train[1] + ) + + +def test_serialization(): + + model = tf.keras.Sequential( + [tf.keras.Input(shape=[1]), tf.keras.layers.Dense(1), tf.keras.layers.Dense(1)] + ) + + opt1 = tf.keras.optimizers.Adam(learning_rate=1e-3) + opt2 = tf.keras.optimizers.SGD(learning_rate=0) + + opt_layer_pairs = [(opt1, model.layers[0]), (opt2, model.layers[1])] + + optimizer = MultiOptimzer(opt_layer_pairs) + config = tf.keras.optimizers.serialize(optimizer) + + new_optimizer = tf.keras.optimizers.deserialize(config) + assert new_optimizer.get_config() == optimizer.get_config()