diff --git a/tensorflow_addons/optimizers/discriminative_layer_training.py b/tensorflow_addons/optimizers/discriminative_layer_training.py
index 494c29d365..80add77378 100644
--- a/tensorflow_addons/optimizers/discriminative_layer_training.py
+++ b/tensorflow_addons/optimizers/discriminative_layer_training.py
@@ -14,7 +14,7 @@
 # ==============================================================================
 """Discriminative Layer Training Optimizer for TensorFlow."""
 
-from typing import Union
+from typing import List, Union
 
 import tensorflow as tf
 from typeguard import typechecked
 
@@ -24,55 +24,53 @@ class MultiOptimizer(tf.keras.optimizers.Optimizer):
     """Multi Optimizer Wrapper for Discriminative Layer Training.
 
-    Creates a wrapper around a set of instantiated optimizer layer pairs. Generally useful for transfer learning
-    of deep networks.
+    Creates a wrapper around a set of instantiated optimizer layer pairs.
+    Generally useful for transfer learning of deep networks.
 
-    Each optimizer will optimize only the weights associated with its paired layer. This can be used
-    to implement discriminative layer training by assigning different learning rates to each optimizer
-    layer pair. (Optimizer, list(Layers)) pairs are also supported. Please note that the layers must be
-    instantiated before instantiating the optimizer.
+    Each optimizer will optimize only the weights associated with its paired layer.
+    This can be used to implement discriminative layer training by assigning
+    different learning rates to each optimizer layer pair.
+    `(tf.keras.optimizers.Optimizer, List[tf.keras.layers.Layer])` pairs are also supported.
+    Please note that the layers must be instantiated before instantiating the optimizer.
 
     Args:
-        optimizers_and_layers: a list of tuples of an optimizer and a layer or model. Each tuple should contain
-            exactly 1 instantiated optimizer and 1 object that subclasses tf.keras.Model or tf.keras.Layer. Nested
-            layers and models will be automatically discovered. Alternatively, in place of a single layer, you can pass
-            a list of layers.
-        optimizer_specs: specialized list for serialization. Should be left as None for almost all cases. If you are
-            loading a serialized version of this optimizer, please use tf.keras.models.load_model after saving a
-            model compiled with this optimizer.
+        optimizers_and_layers: a list of tuples of an optimizer and a layer or model.
+            Each tuple should contain exactly 1 instantiated optimizer and 1 object that
+            subclasses `tf.keras.Model`, `tf.keras.Sequential` or `tf.keras.layers.Layer`.
+            Nested layers and models will be automatically discovered.
+            Alternatively, in place of a single layer, you can pass a list of layers.
+        optimizer_specs: specialized list for serialization.
+            Should be left as None for almost all cases.
+            If you are loading a serialized version of this optimizer,
+            please use `tf.keras.models.load_model` after saving a model compiled with this optimizer.
 
     Usage:
 
-    ```python
-    model = get_model()
-
-    opt1 = tf.keras.optimizers.Adam(learning_rate=1e-4)
-    opt2 = tf.keras.optimizers.Adam(learning_rate=1e-2)
-
-    opt_layer_pairs = [(opt1, model.layers[0]), (opt2, model.layers[1:])]
-
-    loss = tf.keras.losses.MSE
-    optimizer = tfa.optimizers.MultiOpt(opt_layer_pairs)
-
-    model.compile(optimizer=optimizer, loss = loss)
-
-    model.fit(x,y)
-    '''
+    >>> model = tf.keras.Sequential([
+    ...     tf.keras.Input(shape=(4,)),
+    ...     tf.keras.layers.Dense(8),
+    ...     tf.keras.layers.Dense(16),
+    ...     tf.keras.layers.Dense(32),
+    ... ])
+    >>> optimizers = [
+    ...     tf.keras.optimizers.Adam(learning_rate=1e-4),
+    ...     tf.keras.optimizers.Adam(learning_rate=1e-2)
+    ... ]
+    >>> optimizers_and_layers = [(optimizers[0], model.layers[0]), (optimizers[1], model.layers[1:])]
+    >>> optimizer = tfa.optimizers.MultiOptimizer(optimizers_and_layers)
+    >>> model.compile(optimizer=optimizer, loss="mse")
 
     Reference:
+        - [Universal Language Model Fine-tuning for Text Classification](https://arxiv.org/abs/1801.06146)
+        - [Collaborative Layer-wise Discriminative Learning in Deep Neural Networks](https://arxiv.org/abs/1607.05440)
-    - [Universal Language Model Fine-tuning for Text Classification](https://arxiv.org/abs/1801.06146)
-    - [Collaborative Layer-wise Discriminative Learning in Deep Neural Networks](https://arxiv.org/abs/1607.05440)
-
-    Notes:
-
-    Currently, MultiOpt does not support callbacks that modify optimizers. However, you can instantiate
-    optimizer layer pairs with tf.keras.optimizers.schedules.LearningRateSchedule instead of a static learning
-    rate.
+    Note: Currently, `tfa.optimizers.MultiOptimizer` does not support callbacks that modify optimizers.
+        However, you can instantiate optimizer layer pairs with
+        `tf.keras.optimizers.schedules.LearningRateSchedule`
+        instead of a static learning rate.
 
-    This code should function on CPU, GPU, and TPU. Apply the with strategy.scope() context as you
+    This code should function on CPU, GPU, and TPU. Apply the `tf.distribute.Strategy().scope()` context as you
     would with any other optimizer.
-
     """
 
     @typechecked
@@ -80,7 +78,7 @@ def __init__(
         self,
         optimizers_and_layers: Union[list, None] = None,
         optimizer_specs: Union[list, None] = None,
-        name: str = "MultiOptimzer",
+        name: str = "MultiOptimizer",
         **kwargs
     ):
 
@@ -88,8 +86,8 @@ def __init__(
         if optimizer_specs is None and optimizers_and_layers is not None:
             self.optimizer_specs = [
-                self.create_optimizer_spec(opt, layer)
-                for opt, layer in optimizers_and_layers
+                self.create_optimizer_spec(optimizer, layers_or_model)
+                for optimizer, layers_or_model in optimizers_and_layers
             ]
 
         elif optimizer_specs is not None and optimizers_and_layers is None:
@@ -99,14 +97,13 @@ def __init__(
 
         else:
             raise RuntimeError(
-                "You must specify either an list of optimizers and layers or a list of optimizer_specs"
+                "Must specify one of `optimizers_and_layers` or `optimizer_specs`."
             )
 
-    def apply_gradients(self, grads_and_vars, name=None, **kwargs):
+    def apply_gradients(self, grads_and_vars, **kwargs):
         """Wrapped apply_gradient method.
 
-        Returns a list of tf ops to be executed.
-        Name of variable is used rather than var.ref() to enable serialization and deserialization.
+        Returns an operation to be executed.
         """
 
         for spec in self.optimizer_specs:
@@ -131,29 +128,35 @@ def get_config(self):
         return config
 
     @classmethod
-    def create_optimizer_spec(cls, optimizer_instance, layer):
-
-        assert isinstance(
-            optimizer_instance, tf.keras.optimizers.Optimizer
-        ), "Object passed is not an instance of tf.keras.optimizers.Optimizer"
-
-        assert isinstance(layer, tf.keras.layers.Layer) or isinstance(
-            layer, tf.keras.Model
-        ), "Object passed is not an instance of tf.keras.layers.Layer nor tf.keras.Model"
+    def create_optimizer_spec(
+        cls,
+        optimizer: tf.keras.optimizers.Optimizer,
+        layers_or_model: Union[
+            tf.keras.Model,
+            tf.keras.Sequential,
+            tf.keras.layers.Layer,
+            List[tf.keras.layers.Layer],
+        ],
+    ):
+        """Creates a serializable optimizer spec.
-        if type(layer) == list:
-            weights = [var.name for sublayer in layer for var in sublayer.weights]
+        The name of each variable is used rather than `var.ref()` to enable serialization and deserialization.
+        """
+        if isinstance(layers_or_model, list):
+            weights = [
+                var.name for sublayer in layers_or_model for var in sublayer.weights
+            ]
         else:
-            weights = [var.name for var in layer.weights]
+            weights = [var.name for var in layers_or_model.weights]
 
         return {
-            "optimizer": optimizer_instance,
+            "optimizer": optimizer,
             "weights": weights,
         }
 
     @classmethod
     def maybe_initialize_optimizer_spec(cls, optimizer_spec):
-        if type(optimizer_spec["optimizer"]) == dict:
+        if isinstance(optimizer_spec["optimizer"], dict):
             optimizer_spec["optimizer"] = tf.keras.optimizers.deserialize(
                 optimizer_spec["optimizer"]
             )
diff --git a/tensorflow_addons/optimizers/tests/discriminative_layer_training_test.py b/tensorflow_addons/optimizers/tests/discriminative_layer_training_test.py
index 08a096b840..80a613e586 100644
--- a/tensorflow_addons/optimizers/tests/discriminative_layer_training_test.py
+++ b/tensorflow_addons/optimizers/tests/discriminative_layer_training_test.py
@@ -22,32 +22,19 @@
 from tensorflow_addons.utils import test_utils
 
 
-def _dtypes_to_test(use_gpu):
-    # Based on issue #347 in the following link,
-    # "https://github.com/tensorflow/addons/issues/347"
-    # tf.half is not registered for 'ResourceScatterUpdate' OpKernel
-    # for 'GPU' devices.
-    # So we have to remove tf.half when testing with gpu.
-    # The function "_DtypesToTest" is from
-    # "https://github.com/tensorflow/tensorflow/blob/5d4a6cee737a1dc6c20172a1dc1
-    # 5df10def2df72/tensorflow/python/kernel_tests/conv_ops_3d_test.py#L53-L62"
-    # TODO(WindQAQ): Clean up this in TF2.4
-
-    if use_gpu:
-        return [tf.float32, tf.float64]
-    else:
-        return [tf.half, tf.float32, tf.float64]
+def assert_list_allclose(a, b):
+    for x, y in zip(a, b):
+        np.testing.assert_allclose(x, y)
 
 
-@pytest.mark.with_device(["cpu", "gpu"])
-@pytest.mark.parametrize("dtype", [tf.float16, tf.float32, tf.float64])
-@pytest.mark.parametrize("serialize", [True, False])
-def test_fit_layer_optimizer(dtype, device, serialize):
-    # Test ensures that each optimizer is only optimizing its own layer with its learning rate
+def assert_list_not_allclose(a, b):
+    for x, y in zip(a, b):
+        test_utils.assert_not_allclose(x, y)
 
-    if "gpu" in device and dtype == tf.float16:
-        pytest.xfail("See https://github.com/tensorflow/addons/issues/347")
 
+@pytest.mark.with_device(["cpu", "gpu"])
+@pytest.mark.parametrize("serialize", [True, False])
+def test_fit_layer_optimizer(device, serialize, tmpdir):
     model = tf.keras.Sequential(
         [tf.keras.Input(shape=[1]), tf.keras.layers.Dense(1), tf.keras.layers.Dense(1)]
     )
@@ -55,10 +42,8 @@ def test_fit_layer_optimizer(dtype, device, serialize):
     x = np.array(np.ones([100]))
     y = np.array(np.ones([100]))
 
-    weights_before_train = (
-        model.layers[0].weights[0].numpy(),
-        model.layers[1].weights[0].numpy(),
-    )
+    dense1_weights_before_train = [weight.numpy() for weight in model.layers[0].weights]
+    dense2_weights_before_train = [weight.numpy() for weight in model.layers[1].weights]
 
     opt1 = tf.keras.optimizers.Adam(learning_rate=1e-3)
     opt2 = tf.keras.optimizers.SGD(learning_rate=0)
@@ -72,31 +57,205 @@ def test_fit_layer_optimizer(dtype, device, serialize):
     # serialize whole model including optimizer, clear the session, then reload the whole model.
     if serialize:
-        model.save("test", save_format="tf")
+        model.save(str(tmpdir), save_format="tf")
         tf.keras.backend.clear_session()
-        model = tf.keras.models.load_model("test")
+        model = tf.keras.models.load_model(str(tmpdir))
 
     model.fit(x, y, batch_size=8, epochs=10)
 
-    weights_after_train = (
-        model.layers[0].weights[0].numpy(),
-        model.layers[1].weights[0].numpy(),
-    )
+    dense1_weights_after_train = [weight.numpy() for weight in model.layers[0].weights]
+    dense2_weights_after_train = [weight.numpy() for weight in model.layers[1].weights]
+
+    assert_list_not_allclose(dense1_weights_before_train, dense1_weights_after_train)
+    assert_list_allclose(dense2_weights_before_train, dense2_weights_after_train)
 
-    with np.testing.assert_raises(AssertionError):
-        # expect weights to be different for layer 1
-        test_utils.assert_allclose_according_to_type(
-            weights_before_train[0], weights_after_train[0]
-        )
 
-    # expect weights to be same for layer 2
-    test_utils.assert_allclose_according_to_type(
-        weights_before_train[1], weights_after_train[1]
+def test_list_of_layers():
+    model = tf.keras.Sequential(
+        [
+            tf.keras.Input(shape=(4,)),
+            tf.keras.layers.Dense(16),
+            tf.keras.layers.Dense(16),
+            tf.keras.layers.Dense(32),
+            tf.keras.layers.Dense(32),
+        ]
     )
 
+    optimizers_and_layers = [
+        (tf.keras.optimizers.SGD(learning_rate=0.0), model.layers[0]),
+        (tf.keras.optimizers.Adam(), model.layers[1]),
+        (tf.keras.optimizers.Adam(), model.layers[2:]),
+    ]
 
-def test_serialization():
+    weights_before_train = [
+        [weight.numpy() for weight in layer.weights] for layer in model.layers
+    ]
+
+    multi_optimizer = MultiOptimizer(optimizers_and_layers)
+    model.compile(multi_optimizer, loss="mse")
+
+    x = np.ones((128, 4)).astype(np.float32)
+    y = np.ones((128, 32)).astype(np.float32)
+    model.fit(x, y, batch_size=32, epochs=10)
+
+    weights_after_train = [
+        [weight.numpy() for weight in layer.weights] for layer in model.layers
+    ]
+
+    assert_list_allclose(weights_before_train[0], weights_after_train[0])
+
+    for layer_before, layer_after in zip(
+        weights_before_train[1:], weights_after_train[1:]
+    ):
+        assert_list_not_allclose(layer_before, layer_after)
+
+
+def test_model():
+    inputs = tf.keras.Input(shape=(4,))
+    output = tf.keras.layers.Dense(16)(inputs)
+    output = tf.keras.layers.Dense(16)(output)
+    output = tf.keras.layers.Dense(32)(output)
+    output = tf.keras.layers.Dense(32)(output)
+    model = tf.keras.Model(inputs, output)
+
+    # Adam optimizer on the whole model and an additional SGD on the last layer.
+    optimizers_and_layers = [
+        (tf.keras.optimizers.Adam(), model),
+        (tf.keras.optimizers.SGD(), model.layers[-1]),
+    ]
+
+    multi_optimizer = MultiOptimizer(optimizers_and_layers)
+    model.compile(multi_optimizer, loss="mse")
+
+    x = np.ones((128, 4)).astype(np.float32)
+    y = np.ones((128, 32)).astype(np.float32)
+    model.fit(x, y, batch_size=32, epochs=10)
+
+
+def test_subclass_model():
+    class Block(tf.keras.Model):
+        def __init__(self, units):
+            super().__init__()
+            self.dense1 = tf.keras.layers.Dense(units)
+            self.dense2 = tf.keras.layers.Dense(units)
+
+        def call(self, x):
+            return self.dense2(self.dense1(x))
+
+    class Custom(tf.keras.Model):
+        def __init__(self):
+            super().__init__()
+            self.block1 = Block(16)
+            self.block2 = Block(32)
+
+        def call(self, x):
+            return self.block2(self.block1(x))
+
+    model = Custom()
+    model.build(input_shape=(None, 4))
+
+    optimizers_and_layers = [
+        (tf.keras.optimizers.SGD(learning_rate=0.0), model.block1),
+        (tf.keras.optimizers.Adam(), model.block2),
+    ]
+
+    block1_weights_before_train = [weight.numpy() for weight in model.block1.weights]
+    block2_weights_before_train = [weight.numpy() for weight in model.block2.weights]
+
+    multi_optimizer = MultiOptimizer(optimizers_and_layers)
+
+    x = np.ones((128, 4)).astype(np.float32)
+    y = np.ones((128, 32)).astype(np.float32)
+    mse = tf.keras.losses.MeanSquaredError()
+
+    for _ in range(10):
+        for i in range(0, 128, 32):
+            x_batch = x[i : i + 32]
+            y_batch = y[i : i + 32]
+            with tf.GradientTape() as tape:
+                loss = mse(y_batch, model(x_batch))
+
+            grads = tape.gradient(loss, model.trainable_variables)
+            multi_optimizer.apply_gradients(zip(grads, model.trainable_variables))
+
+    block1_weights_after_train = [weight.numpy() for weight in model.block1.weights]
+    block2_weights_after_train = [weight.numpy() for weight in model.block2.weights]
+
+    assert_list_allclose(block1_weights_before_train, block1_weights_after_train)
+    assert_list_not_allclose(block2_weights_before_train, block2_weights_after_train)
+
+
+def test_pretrained_model():
+    resnet = tf.keras.applications.ResNet50(include_top=False, weights=None)
+    dense = tf.keras.layers.Dense(32)
+    model = tf.keras.Sequential([resnet, dense])
+
+    resnet_weights_before_train = [
+        weight.numpy() for weight in resnet.trainable_weights
+    ]
+    dense_weights_before_train = [weight.numpy() for weight in dense.weights]
+
+    optimizers_and_layers = [(tf.keras.optimizers.SGD(), dense)]
+
+    multi_optimizer = MultiOptimizer(optimizers_and_layers)
+    model.compile(multi_optimizer, loss="mse")
+
+    x = np.ones((128, 32, 32, 3)).astype(np.float32)
+    y = np.ones((128, 32)).astype(np.float32)
+    model.fit(x, y, batch_size=32)
+
+    resnet_weights_after_train = [weight.numpy() for weight in resnet.trainable_weights]
+    dense_weights_after_train = [weight.numpy() for weight in dense.weights]
+
+    assert_list_allclose(resnet_weights_before_train, resnet_weights_after_train)
+    assert_list_not_allclose(dense_weights_before_train, dense_weights_after_train)
+
+
+def test_nested_model():
+    def get_model():
+        inputs = tf.keras.Input(shape=(4,))
+        outputs = tf.keras.layers.Dense(1)(inputs)
+        return tf.keras.Model(inputs, outputs)
+
+    model1 = get_model()
+    model2 = get_model()
+    model3 = get_model()
+
+    inputs = tf.keras.Input(shape=(4,))
+    y1 = model1(inputs)
+    y2 = model2(inputs)
+    y3 = model3(inputs)
+    outputs = tf.keras.layers.Average()([y1, y2, y3])
+    model = tf.keras.Model(inputs, outputs)
+
+    optimizers_and_layers = [
+        (tf.keras.optimizers.SGD(), model1),
+        (tf.keras.optimizers.SGD(learning_rate=0.0), model2),
+        (tf.keras.optimizers.SGD(), model3),
+    ]
+
+    model1_weights_before_train = [weight.numpy() for weight in model1.weights]
+    model2_weights_before_train = [weight.numpy() for weight in model2.weights]
+    model3_weights_before_train = [weight.numpy() for weight in model3.weights]
+
+    multi_optimizer = MultiOptimizer(optimizers_and_layers)
+
+    model.compile(multi_optimizer, loss="mse")
+
+    x = np.ones((128, 4)).astype(np.float32)
+    y = np.ones((128, 32)).astype(np.float32)
+    model.fit(x, y)
+
+    model1_weights_after_train = [weight.numpy() for weight in model1.weights]
+    model2_weights_after_train = [weight.numpy() for weight in model2.weights]
+    model3_weights_after_train = [weight.numpy() for weight in model3.weights]
+
+    assert_list_not_allclose(model1_weights_before_train, model1_weights_after_train)
+    assert_list_allclose(model2_weights_before_train, model2_weights_after_train)
+    assert_list_not_allclose(model3_weights_before_train, model3_weights_after_train)
+
+
+def test_serialization():
     model = tf.keras.Sequential(
         [tf.keras.Input(shape=[1]), tf.keras.layers.Dense(1), tf.keras.layers.Dense(1)]
     )
diff --git a/tensorflow_addons/utils/test_utils.py b/tensorflow_addons/utils/test_utils.py
index cdb33bcf82..c4206a3d31 100644
--- a/tensorflow_addons/utils/test_utils.py
+++ b/tensorflow_addons/utils/test_utils.py
@@ -217,6 +217,25 @@ def pytest_collection_modifyitems(items):
             item.add_marker(pytest.mark.skip("The gpu is not available."))
 
 
+def assert_not_allclose(a, b, **kwargs):
+    """Assert that two numpy arrays do not have near values.
+
+    Args:
+        a: the first value to compare.
+        b: the second value to compare.
+        **kwargs: additional keyword arguments to be passed to the underlying
+            `np.testing.assert_allclose` call.
+
+    Raises:
+        AssertionError: If `a` and `b` are unexpectedly close at all elements.
+    """
+    try:
+        np.testing.assert_allclose(a, b, **kwargs)
+    except AssertionError:
+        return
+    raise AssertionError("The two values are close at all elements")
+
+
 def assert_allclose_according_to_type(
     a,
     b,