diff --git a/tensorflow_addons/layers/normalizations.py b/tensorflow_addons/layers/normalizations.py index 33d80a3f86..5264e1395a 100644 --- a/tensorflow_addons/layers/normalizations.py +++ b/tensorflow_addons/layers/normalizations.py @@ -155,9 +155,6 @@ def get_config(self): base_config = super().get_config() return {**base_config, **config} - def compute_output_shape(self, input_shape): - return input_shape - def _reshape_into_groups(self, inputs, input_shape, tensor_input_shape): group_shape = [tensor_input_shape[i] for i in range(len(input_shape))] @@ -447,9 +444,6 @@ def call(self, inputs): normalized_inputs = inputs * tf.math.rsqrt(nu2 + epsilon) return self.gamma * normalized_inputs + self.beta - def compute_output_shape(self, input_shape): - return input_shape - def get_config(self): config = { "axis": self.axis, diff --git a/tensorflow_addons/layers/tests/normalizations_test.py b/tensorflow_addons/layers/tests/normalizations_test.py index f77d725801..98442ca7f9 100644 --- a/tensorflow_addons/layers/tests/normalizations_test.py +++ b/tensorflow_addons/layers/tests/normalizations_test.py @@ -346,6 +346,52 @@ def test_groupnorm_convnet_no_center_no_scale(): ) +@pytest.mark.usefixtures("maybe_run_functions_eagerly") +@pytest.mark.parametrize("center", [True, False]) +@pytest.mark.parametrize("scale", [True, False]) +def test_group_norm_compute_output_shape(center, scale): + + target_variables_len = [center, scale].count(True) + target_trainable_variables_len = [center, scale].count(True) + layer1 = GroupNormalization(groups=2, center=center, scale=scale) + layer1.build(input_shape=[8, 28, 28, 16]) # build() + assert len(layer1.variables) == target_variables_len + assert len(layer1.trainable_variables) == target_trainable_variables_len + + layer2 = GroupNormalization(groups=2, center=center, scale=scale) + layer2.compute_output_shape(input_shape=[8, 28, 28, 16]) # compute_output_shape() + assert len(layer2.variables) == target_variables_len + assert len(layer2.trainable_variables) == target_trainable_variables_len + + layer3 = GroupNormalization(groups=2, center=center, scale=scale) + layer3(tf.random.normal(shape=[8, 28, 28, 16])) # call() + assert len(layer3.variables) == target_variables_len + assert len(layer3.trainable_variables) == target_trainable_variables_len + + +@pytest.mark.usefixtures("maybe_run_functions_eagerly") +@pytest.mark.parametrize("center", [True, False]) +@pytest.mark.parametrize("scale", [True, False]) +def test_instance_norm_compute_output_shape(center, scale): + + target_variables_len = [center, scale].count(True) + target_trainable_variables_len = [center, scale].count(True) + layer1 = InstanceNormalization(groups=2, center=center, scale=scale) + layer1.build(input_shape=[8, 28, 28, 16]) # build() + assert len(layer1.variables) == target_variables_len + assert len(layer1.trainable_variables) == target_trainable_variables_len + + layer2 = InstanceNormalization(groups=2, center=center, scale=scale) + layer2.compute_output_shape(input_shape=[8, 28, 28, 16]) # compute_output_shape() + assert len(layer2.variables) == target_variables_len + assert len(layer2.trainable_variables) == target_trainable_variables_len + + layer3 = InstanceNormalization(groups=2, center=center, scale=scale) + layer3(tf.random.normal(shape=[8, 28, 28, 16])) # call() + assert len(layer3.variables) == target_variables_len + assert len(layer3.trainable_variables) == target_trainable_variables_len + + def calculate_frn( x, beta=0.2, gamma=1, eps=1e-6, learned_epsilon=False, dtype=np.float32 ): @@ -471,3 +517,23 @@ def test_filter_response_normalization_save(tmpdir): model.save(filepath, save_format="h5") filepath = str(tmpdir / "test") model.save(filepath, save_format="tf") + + +@pytest.mark.usefixtures("maybe_run_functions_eagerly") +def test_filter_response_norm_compute_output_shape(): + target_variables_len = 2 + target_trainable_variables_len = 2 + layer1 = FilterResponseNormalization() + layer1.build(input_shape=[8, 28, 28, 16]) # build() + assert len(layer1.variables) == target_variables_len + assert len(layer1.trainable_variables) == target_trainable_variables_len + + layer2 = FilterResponseNormalization() + layer2.compute_output_shape(input_shape=[8, 28, 28, 16]) # compute_output_shape() + assert len(layer2.variables) == target_variables_len + assert len(layer2.trainable_variables) == target_trainable_variables_len + + layer3 = FilterResponseNormalization() + layer3(tf.random.normal(shape=[8, 28, 28, 16])) # call() + assert len(layer3.variables) == target_variables_len + assert len(layer3.trainable_variables) == target_trainable_variables_len