diff --git a/tensorflow_addons/layers/wrappers.py b/tensorflow_addons/layers/wrappers.py index fd101b219a..a5df48664e 100644 --- a/tensorflow_addons/layers/wrappers.py +++ b/tensorflow_addons/layers/wrappers.py @@ -92,11 +92,13 @@ def build(self, input_shape): trainable=False) if self.data_init: - self._naked_layer = tf.keras.layers.deserialize( - tf.keras.layers.serialize(self.layer)) - self._naked_layer.build(input_shape) - self._naked_layer.set_weights(self.layer.get_weights()) - self._naked_layer.activation = None + # Used for data initialization in self._data_dep_init. + layer_config = tf.keras.layers.serialize(self.layer) + layer_config['config']['trainable'] = False + self._naked_clone_layer = tf.keras.layers.deserialize(layer_config) + self._naked_clone_layer.build(input_shape) + self._naked_clone_layer.set_weights(self.layer.get_weights()) + self._naked_clone_layer.activation = None self.built = True @@ -104,20 +106,25 @@ def call(self, inputs): """Call `Layer`""" def _do_nothing(): - return inputs + return tf.identity(self.g) def _update_weights(): - self._initialize_weights(inputs) - return inputs + # Ensure we read `self.g` after _update_weights. + with tf.control_dependencies(self._initialize_weights(inputs)): + return tf.identity(self.g) - inputs = tf.cond(self._initialized, _do_nothing, _update_weights) + g = tf.cond(self._initialized, _do_nothing, _update_weights) with tf.name_scope('compute_weights'): # Replace kernel by normalized weight variable. self.layer.kernel = tf.nn.l2_normalize( - self.v, axis=self.kernel_norm_axes) * self.g + self.v, axis=self.kernel_norm_axes) * g - return self.layer(inputs) + # Ensure we calculate result after updating kernel. + update_kernel = tf.identity(self.layer.kernel) + with tf.control_dependencies([update_kernel]): + outputs = self.layer(inputs) + return outputs def compute_output_shape(self, input_shape): return tf.TensorShape( @@ -136,31 +143,36 @@ def _initialize_weights(self, inputs): message='The layer has been initialized.') ]): if self.data_init: - self._data_dep_init(inputs) + assign_tensors = self._data_dep_init(inputs) else: - self._init_norm() - self._initialized.assign(True) + assign_tensors = self._init_norm() + assign_tensors.append(self._initialized.assign(True)) + return assign_tensors def _init_norm(self): """Set the weight g with the norm of the weight vector.""" with tf.name_scope('init_norm'): v_flat = tf.reshape(self.v, [-1, self.layer_depth]) v_norm = tf.linalg.norm(v_flat, axis=0) - self.g.assign(tf.reshape(v_norm, (self.layer_depth,))) + g_tensor = self.g.assign(tf.reshape(v_norm, (self.layer_depth,))) + return [g_tensor] def _data_dep_init(self, inputs): """Data dependent initialization.""" with tf.name_scope('data_dep_init'): # Generate data dependent init values - x_init = self._naked_layer(inputs) + x_init = self._naked_clone_layer(inputs) data_norm_axes = list(range(x_init.shape.rank - 1)) m_init, v_init = tf.nn.moments(x_init, data_norm_axes) scale_init = 1. / tf.math.sqrt(v_init + 1e-10) # Assign data dependent init values - self.g.assign(self.g * scale_init) + g_tensor = self.g.assign(self.g * scale_init) if hasattr(self.layer, 'bias'): - self.layer.bias.assign(-m_init * scale_init) + bias_tensor = self.layer.bias.assign(-m_init * scale_init) + return [g_tensor, bias_tensor] + else: + return [g_tensor] def get_config(self): config = {'data_init': self.data_init} diff --git a/tensorflow_addons/layers/wrappers_test.py b/tensorflow_addons/layers/wrappers_test.py index 9d83bbec50..b4bdb9c494 100644 --- a/tensorflow_addons/layers/wrappers_test.py +++ b/tensorflow_addons/layers/wrappers_test.py @@ -26,82 +26,52 @@ @test_utils.run_all_in_graph_and_eager_modes class WeightNormalizationTest(tf.test.TestCase): - def test_weightnorm_dense_train(self): - model = tf.keras.models.Sequential() - model.add( - wrappers.WeightNormalization( - tf.keras.layers.Dense(2), input_shape=(3, 4))) - model.compile( - optimizer=tf.keras.optimizers.RMSprop(learning_rate=0.001), - loss='mse') - model.fit( - np.random.random((10, 3, 4)), - np.random.random((10, 3, 2)), - epochs=3, - batch_size=10) - self.assertTrue(hasattr(model.layers[0], 'g')) - - def test_weightnorm_dense_train_notinit(self): - model = tf.keras.models.Sequential() - model.add( - wrappers.WeightNormalization( - tf.keras.layers.Dense(2), input_shape=(3, 4), data_init=False)) - - model.compile( - optimizer=tf.keras.optimizers.RMSprop(learning_rate=0.001), - loss='mse') - model.fit( - np.random.random((10, 3, 4)), - np.random.random((10, 3, 2)), - epochs=3, - batch_size=10) - self.assertTrue(hasattr(model.layers[0], 'g')) - - def test_weightnorm_conv2d(self): - model = tf.keras.models.Sequential() - model.add( - wrappers.WeightNormalization( - tf.keras.layers.Conv2D(5, (2, 2), padding='same'), - input_shape=(4, 4, 3))) - - model.add(tf.keras.layers.Activation('relu')) - model.compile( - optimizer=tf.keras.optimizers.RMSprop(learning_rate=0.001), - loss='mse') - model.fit( - np.random.random((2, 4, 4, 3)), - np.random.random((2, 4, 4, 5)), - epochs=3, - batch_size=10) - - self.assertTrue(hasattr(model.layers[0], 'g')) - - def test_weightnorm_applylayer(self): - images = tf.random.uniform((2, 4, 4, 3)) - wn_wrapper = wrappers.WeightNormalization( - tf.keras.layers.Conv2D(32, [2, 2]), input_shape=(4, 4, 3)) - wn_wrapper.apply(images) - self.assertTrue(hasattr(wn_wrapper, 'g')) - - def test_weightnorm_nonlayer(self): - images = tf.random.uniform((2, 4, 43)) - with self.assertRaises(AssertionError): - wrappers.WeightNormalization(images) - - def test_weightnorm_nokernel(self): - with self.assertRaises(ValueError): - wrappers.WeightNormalization(tf.keras.layers.MaxPooling2D( - 2, 2)).build((2, 2)) - - def test_weightnorm_keras(self): - input_data = np.random.random((10, 3, 4)).astype(np.float32) + def test_weightnorm(self): + test_utils.layer_test( + wrappers.WeightNormalization, + kwargs={ + 'layer': tf.keras.layers.Conv2D(5, (2, 2)), + }, + input_shape=(2, 4, 4, 3)) + + def _check_data_init(self, data_init, input_data, expected_output): + layer = tf.keras.layers.Dense( + input_data.shape[-1], + activation=None, + kernel_initializer='identity', + bias_initializer='zeros') test_utils.layer_test( wrappers.WeightNormalization, kwargs={ - 'layer': tf.keras.layers.Dense(2), - 'input_shape': (3, 4) + 'layer': layer, + 'data_init': data_init, }, - input_data=input_data) + input_data=input_data, + expected_output=expected_output) + + def test_weightnorm_with_data_init_is_false(self): + input_data = np.array([[[-4, -4], [4, 4]]], dtype=np.float32) + self._check_data_init( + data_init=False, input_data=input_data, expected_output=input_data) + + def test_weightnorm_with_data_init_is_true(self): + input_data = np.array([[[-4, -4], [4, 4]]], dtype=np.float32) + self._check_data_init( + data_init=True, + input_data=input_data, + expected_output=input_data / 4) + + def test_weightnorm_non_layer(self): + images = tf.random.uniform((2, 4, 43)) + with self.assertRaises(AssertionError): + wrappers.WeightNormalization(images) + + def test_weightnorm_non_kernel_layer(self): + images = tf.random.uniform((2, 2, 2)) + with self.assertRaisesRegexp(ValueError, 'contains a `kernel`'): + non_kernel_layer = tf.keras.layers.MaxPooling2D(2, 2) + wn_wrapper = wrappers.WeightNormalization(non_kernel_layer) + wn_wrapper(images) if __name__ == "__main__":