diff --git a/tensorflow_addons/layers/wrappers.py b/tensorflow_addons/layers/wrappers.py
index c83c8d2b51..2852d33a34 100644
--- a/tensorflow_addons/layers/wrappers.py
+++ b/tensorflow_addons/layers/wrappers.py
@@ -16,6 +16,8 @@
 from __future__ import division
 from __future__ import print_function
 
+import logging
+
 import tensorflow as tf
 
 
@@ -61,6 +63,11 @@ def __init__(self, layer, data_init=True, **kwargs):
         self._init_critical_section = tf.CriticalSection(name='init_mutex')
         self.is_rnn = isinstance(self.layer, tf.keras.layers.RNN)
 
+        if self.data_init and self.is_rnn:
+            logging.warning(
+                "WeightNormalization: Using `data_init=True` with RNNs "
+                "is advised against by the paper. Use `data_init=False`.")
+
     def build(self, input_shape):
         """Build `Layer`"""
         input_shape = tf.TensorShape(input_shape)
@@ -76,17 +83,22 @@ def build(self, input_shape):
             raise ValueError('`WeightNormalization` must wrap a layer that'
                              ' contains a `kernel` for weights')
 
+        if self.is_rnn:
+            kernel = kernel_layer.recurrent_kernel
+        else:
+            kernel = kernel_layer.kernel
+
         # The kernel's filter or unit dimension is -1
-        self.layer_depth = int(kernel_layer.kernel.shape[-1])
-        self.kernel_norm_axes = list(range(kernel_layer.kernel.shape.rank - 1))
+        self.layer_depth = int(kernel.shape[-1])
+        self.kernel_norm_axes = list(range(kernel.shape.rank - 1))
 
         self.g = self.add_weight(
             name='g',
             shape=(self.layer_depth,),
             initializer='ones',
-            dtype=kernel_layer.kernel.dtype,
+            dtype=kernel.dtype,
             trainable=True)
-        self.v = kernel_layer.kernel
+        self.v = kernel
 
         self._initialized = self.add_weight(
             name='initialized',
@@ -104,9 +116,7 @@ def build(self, input_shape):
                 layer_config)
             self._naked_clone_layer.build(input_shape)
             self._naked_clone_layer.set_weights(self.layer.get_weights())
-            if self.is_rnn:
-                self._naked_clone_layer.cell.activation = None
-            else:
+            if not self.is_rnn:
                 self._naked_clone_layer.activation = None
 
         self.built = True
@@ -127,11 +137,16 @@ def _update_weights():
 
         with tf.name_scope('compute_weights'):
             # Replace kernel by normalized weight variable.
-            self.layer.kernel = tf.nn.l2_normalize(
-                self.v, axis=self.kernel_norm_axes) * g
+            kernel = tf.nn.l2_normalize(self.v, axis=self.kernel_norm_axes) * g
+
+            if self.is_rnn:
+                self.layer.cell.recurrent_kernel = kernel
+                update_kernel = tf.identity(self.layer.cell.recurrent_kernel)
+            else:
+                self.layer.kernel = kernel
+                update_kernel = tf.identity(self.layer.kernel)
 
             # Ensure we calculate result after updating kernel.
-            update_kernel = tf.identity(self.layer.kernel)
             with tf.control_dependencies([update_kernel]):
                 outputs = self.layer(inputs)
                 return outputs
@@ -176,6 +191,14 @@ def _data_dep_init(self, inputs):
             m_init, v_init = tf.nn.moments(x_init, data_norm_axes)
             scale_init = 1. / tf.math.sqrt(v_init + 1e-10)
 
+            # RNNs have fused kernels that are tiled
+            # Repeat scale_init to match the shape of fused kernel
+            # Note: This is only to support the operation,
+            # the paper advises against RNN+data_dep_init
+            if scale_init.shape[0] != self.g.shape[0]:
+                rep = int(self.g.shape[0] / scale_init.shape[0])
+                scale_init = tf.tile(scale_init, [rep])
+
             # Assign data dependent init values
             g_tensor = self.g.assign(self.g * scale_init)
             if hasattr(self.layer, 'bias') and self.layer.bias is not None:
@@ -188,3 +211,15 @@ def get_config(self):
         config = {'data_init': self.data_init}
         base_config = super(WeightNormalization, self).get_config()
         return dict(list(base_config.items()) + list(config.items()))
+
+    def remove(self):
+        kernel = tf.Variable(
+            tf.nn.l2_normalize(self.v, axis=self.kernel_norm_axes) * self.g,
+            name='recurrent_kernel' if self.is_rnn else 'kernel')
+
+        if self.is_rnn:
+            self.layer.cell.recurrent_kernel = kernel
+        else:
+            self.layer.kernel = kernel
+
+        return self.layer
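For reviewers, a minimal usage sketch of the RNN path added above: wrap the layer with `data_init=False` (as the new warning advises), run a forward pass, then fold the normalization back in with the new `remove()`. The layer sizes, input shape, and eager-mode (TF 2.x) assertion are illustrative assumptions, not part of the patch.

    import numpy as np
    import tensorflow as tf
    from tensorflow_addons.layers import WeightNormalization

    # Wrap an RNN; data_init=False avoids the warning added in __init__,
    # since the paper advises against data-dependent init for RNNs.
    wn_rnn = WeightNormalization(tf.keras.layers.SimpleRNN(4), data_init=False)

    x = np.ones((2, 10, 3), dtype=np.float32)  # (batch, time, features)
    y = wn_rnn(x)  # call() normalizes the cell's recurrent_kernel on the fly

    # remove() folds g * v / ||v|| back into recurrent_kernel and returns
    # the bare SimpleRNN, mirroring test_removal below.
    plain_rnn = wn_rnn.remove()
    np.testing.assert_allclose(y.numpy(), plain_rnn(x).numpy(), atol=1e-5)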
diff --git a/tensorflow_addons/layers/wrappers_test.py b/tensorflow_addons/layers/wrappers_test.py
index 77da7af9ab..455e46bda8 100644
--- a/tensorflow_addons/layers/wrappers_test.py
+++ b/tensorflow_addons/layers/wrappers_test.py
@@ -17,6 +17,8 @@
 from __future__ import division
 from __future__ import print_function
 
+from absl.testing import parameterized
+
 import numpy as np
 import tensorflow as tf
 
@@ -25,8 +27,8 @@
 
 
 @test_utils.run_all_in_graph_and_eager_modes
-class WeightNormalizationTest(tf.test.TestCase):
-    def test_weightnorm(self):
+class WeightNormalizationTest(tf.test.TestCase, parameterized.TestCase):
+    def test_basic(self):
         test_utils.layer_test(
             wrappers.WeightNormalization,
             kwargs={
@@ -34,7 +36,7 @@
             },
             input_shape=(2, 4, 4, 3))
 
-    def test_weightnorm_no_bias(self):
+    def test_no_bias(self):
         test_utils.layer_test(
             wrappers.WeightNormalization,
             kwargs={
@@ -57,31 +59,31 @@ def _check_data_init(self, data_init, input_data, expected_output):
             input_data=input_data,
             expected_output=expected_output)
 
-    def test_weightnorm_with_data_init_is_false(self):
+    def test_with_data_init_is_false(self):
         input_data = np.array([[[-4, -4], [4, 4]]], dtype=np.float32)
         self._check_data_init(
             data_init=False, input_data=input_data, expected_output=input_data)
 
-    def test_weightnorm_with_data_init_is_true(self):
+    def test_with_data_init_is_true(self):
         input_data = np.array([[[-4, -4], [4, 4]]], dtype=np.float32)
         self._check_data_init(
             data_init=True,
             input_data=input_data,
             expected_output=input_data / 4)
 
-    def test_weightnorm_non_layer(self):
+    def test_non_layer(self):
         images = tf.random.uniform((2, 4, 43))
         with self.assertRaises(AssertionError):
             wrappers.WeightNormalization(images)
 
-    def test_weightnorm_non_kernel_layer(self):
+    def test_non_kernel_layer(self):
         images = tf.random.uniform((2, 2, 2))
         with self.assertRaisesRegexp(ValueError, 'contains a `kernel`'):
             non_kernel_layer = tf.keras.layers.MaxPooling2D(2, 2)
             wn_wrapper = wrappers.WeightNormalization(non_kernel_layer)
             wn_wrapper(images)
 
-    def test_weightnorm_with_time_dist(self):
+    def test_with_time_dist(self):
         batch_shape = (32, 16, 64, 64, 3)
         inputs = tf.keras.layers.Input(batch_shape=batch_shape)
         a = tf.keras.layers.Conv2D(3, 5)
@@ -89,21 +91,89 @@ def test_weightnorm_with_time_dist(self):
         out = tf.keras.layers.TimeDistributed(b)(inputs)
         model = tf.keras.Model(inputs, out)
 
-    def test_weightnorm_with_rnn(self):
-        inputs = tf.keras.layers.Input(shape=(None, 3))
-        rnn_layer = tf.keras.layers.SimpleRNN(4)
-        wt_rnn = wrappers.WeightNormalization(rnn_layer)
-        dense = tf.keras.layers.Dense(1)
-        model = tf.keras.models.Sequential(layers=[inputs, wt_rnn, dense])
-
-    def test_save_file_h5(self):
+    @parameterized.named_parameters(
+        ["Dense", lambda: tf.keras.layers.Dense(1), False],
+        ["SimpleRNN", lambda: tf.keras.layers.SimpleRNN(1), True],
+        ["Conv2D", lambda: tf.keras.layers.Conv2D(3, 1), False],
+        ["LSTM", lambda: tf.keras.layers.LSTM(1), True])
+    def test_serialization(self, base_layer, rnn):
+        base_layer = base_layer()
+        wn_layer = wrappers.WeightNormalization(base_layer, not rnn)
+        new_wn_layer = tf.keras.layers.deserialize(
+            tf.keras.layers.serialize(wn_layer))
+        self.assertEqual(wn_layer.data_init, new_wn_layer.data_init)
+        self.assertEqual(wn_layer.is_rnn, new_wn_layer.is_rnn)
+        self.assertEqual(wn_layer.is_rnn, rnn)
+        if not isinstance(base_layer, tf.keras.layers.LSTM):
+            # Issue with LSTM serialization, check with TF-core
+            # Before serialization: tensorflow.python.keras.layers.recurrent_v2.LSTM
+            # After serialization: tensorflow.python.keras.layers.recurrent.LSTM
+            self.assertTrue(
+                isinstance(new_wn_layer.layer, base_layer.__class__))
+
+    @parameterized.named_parameters(
+        ["Dense", lambda: tf.keras.layers.Dense(1), [25]],
+        ["SimpleRNN", lambda: tf.keras.layers.SimpleRNN(1), [None, 10]],
+        ["Conv2D", lambda: tf.keras.layers.Conv2D(3, 1), [3, 3, 1]],
+        ["LSTM", lambda: tf.keras.layers.LSTM(1), [10, 10]])
+    def test_model_build(self, base_layer_fn, input_shape):
+        inputs = tf.keras.layers.Input(shape=input_shape)
+        for data_init in [True, False]:
+            base_layer = base_layer_fn()
+            wt_layer = wrappers.WeightNormalization(base_layer, data_init)
+            model = tf.keras.models.Sequential(layers=[inputs, wt_layer])
+            model.build()
+
+    @parameterized.named_parameters(
+        ["Dense", lambda: tf.keras.layers.Dense(1), [25]],
+        ["SimpleRNN", lambda: tf.keras.layers.SimpleRNN(1), [10, 10]],
+        ["Conv2D", lambda: tf.keras.layers.Conv2D(3, 1), [3, 3, 1]],
+        ["LSTM", lambda: tf.keras.layers.LSTM(1), [10, 10]])
+    def test_save_file_h5(self, base_layer, input_shape):
         self.create_tempfile('wrapper_test_model.h5')
-        conv = tf.keras.layers.Conv1D(1, 1)
-        wn_conv = wrappers.WeightNormalization(conv)
+        base_layer = base_layer()
+        wn_conv = wrappers.WeightNormalization(base_layer)
         model = tf.keras.Sequential(layers=[wn_conv])
-        model.build([1, 2, 3])
+        model.build([None] + input_shape)
         model.save_weights('wrapper_test_model.h5')
 
+    @parameterized.named_parameters(
+        ["Dense", lambda: tf.keras.layers.Dense(1), [25]],
+        ["SimpleRNN", lambda: tf.keras.layers.SimpleRNN(1), [10, 10]],
+        ["Conv2D", lambda: tf.keras.layers.Conv2D(3, 1), [3, 3, 1]],
+        ["LSTM", lambda: tf.keras.layers.LSTM(1), [10, 10]])
+    def test_forward_pass(self, base_layer, input_shape):
+        sample_data = np.ones([1] + input_shape, dtype=np.float32)
+        base_layer = base_layer()
+        base_output = base_layer(sample_data)
+        wn_layer = wrappers.WeightNormalization(base_layer, False)
+        wn_output = wn_layer(sample_data)
+        self.evaluate(tf.compat.v1.global_variables_initializer())
+        self.assertAllClose(
+            self.evaluate(base_output), self.evaluate(wn_output))
+
+    @parameterized.named_parameters(
+        ["Dense", lambda: tf.keras.layers.Dense(1), [25]],
+        ["SimpleRNN", lambda: tf.keras.layers.SimpleRNN(1), [10, 10]],
+        ["Conv2D", lambda: tf.keras.layers.Conv2D(3, 1), [3, 3, 1]],
+        ["LSTM", lambda: tf.keras.layers.LSTM(1), [10, 10]])
+    def test_removal(self, base_layer_fn, input_shape):
+        sample_data = np.ones([1] + input_shape, dtype=np.float32)
+
+        for data_init in [True, False]:
+            base_layer = base_layer_fn()
+            wn_layer = wrappers.WeightNormalization(base_layer, data_init)
+            wn_output = wn_layer(sample_data)
+            self.evaluate(tf.compat.v1.global_variables_initializer())
+            with tf.control_dependencies([wn_output]):
+                wn_removed_layer = wn_layer.remove()
+                wn_removed_output = wn_removed_layer(sample_data)
+
+            self.evaluate(tf.compat.v1.global_variables_initializer())
+            self.assertAllClose(
+                self.evaluate(wn_removed_output), self.evaluate(wn_output))
+            self.assertTrue(isinstance(wn_removed_layer, base_layer.__class__))
+
 
 if __name__ == "__main__":
     tf.test.main()
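As background on the `tf.tile` repeat in `_data_dep_init`: Keras RNN cells fuse their gate kernels, so `g` (sized to the kernel's last dimension) is an integer multiple of the per-unit `scale_init` computed from the output moments. A short sketch of the shapes involved, assuming TF 2.x and an LSTM with 1 unit as in the parameterized tests; the `tf.ones` stand-in replaces the actual data statistic.

    import tensorflow as tf

    lstm = tf.keras.layers.LSTM(1)
    lstm.build((None, 10, 10))  # (batch, time, features)

    # The four LSTM gates are fused: recurrent_kernel has shape
    # (units, 4 * units), so g gets shape (4,) while the output
    # moments yield a scale_init of shape (units,) = (1,).
    print(lstm.cell.recurrent_kernel.shape)  # (1, 4)

    scale_init = tf.ones([1])                # stand-in for the data statistic
    rep = 4 // 1                             # g.shape[0] / scale_init.shape[0]
    scale_init = tf.tile(scale_init, [rep])  # shape (4,) now matches g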