diff --git a/tensorflow_addons/layers/normalizations.py b/tensorflow_addons/layers/normalizations.py index dd343ba8a5..1b0c5ac872 100644 --- a/tensorflow_addons/layers/normalizations.py +++ b/tensorflow_addons/layers/normalizations.py @@ -161,7 +161,7 @@ def _reshape_into_groups(self, inputs, input_shape, tensor_input_shape): group_shape = [tensor_input_shape[i] for i in range(len(input_shape))] group_shape[self.axis] = input_shape[self.axis] // self.groups - group_shape.insert(1, self.groups) + group_shape.insert(self.axis, self.groups) group_shape = tf.stack(group_shape) reshaped_inputs = tf.reshape(inputs, group_shape) return reshaped_inputs, group_shape @@ -169,11 +169,12 @@ def _reshape_into_groups(self, inputs, input_shape, tensor_input_shape): def _apply_normalization(self, reshaped_inputs, input_shape): group_shape = tf.keras.backend.int_shape(reshaped_inputs) - group_reduction_axes = list(range(len(group_shape))) - # Remember the ordering of the tensor is [batch, group , steps]. Jump - # the first 2 to calculate the variance and the mean + group_reduction_axes = list(range(1, len(group_shape))) + axis = -2 if self.axis == -1 else self.axis - 1 + group_reduction_axes.pop(axis) + mean, variance = tf.nn.moments( - reshaped_inputs, group_reduction_axes[2:], keepdims=True) + reshaped_inputs, group_reduction_axes, keepdims=True) gamma, beta = self._get_reshaped_weights(input_shape) normalized_inputs = tf.nn.batch_normalization( @@ -269,7 +270,7 @@ def _add_beta_weight(self, input_shape): def _create_broadcast_shape(self, input_shape): broadcast_shape = [1] * len(input_shape) broadcast_shape[self.axis] = input_shape[self.axis] // self.groups - broadcast_shape.insert(1, self.groups) + broadcast_shape.insert(self.axis, self.groups) return broadcast_shape diff --git a/tensorflow_addons/layers/normalizations_test.py b/tensorflow_addons/layers/normalizations_test.py index b71d4ab436..e4f43857f9 100644 --- a/tensorflow_addons/layers/normalizations_test.py +++ b/tensorflow_addons/layers/normalizations_test.py @@ -53,7 +53,7 @@ def run_reshape_test(axis, group, input_shape, expected_shape): self.assertEqual(int(group_shape[i]), expected_shape[i]) input_shape = (10, 10, 10) - expected_shape = [10, 5, 10, 2] + expected_shape = [10, 10, 5, 2] run_reshape_test(2, 5, input_shape, expected_shape) input_shape = (10, 10, 10) @@ -108,18 +108,18 @@ def _test_specific_layer(self, inputs, axis, groups, center, scale): np_inputs = inputs.numpy() reshaped_dims = list(np_inputs.shape) reshaped_dims[axis] = reshaped_dims[axis] // groups - reshaped_dims.insert(1, groups) + reshaped_dims.insert(axis, groups) reshaped_inputs = np.reshape(np_inputs, tuple(reshaped_dims)) + group_reduction_axes = list(range(1, len(reshaped_dims))) + axis = -2 if axis == -1 else axis - 1 + group_reduction_axes.pop(axis) + # Calculate mean and variance mean = np.mean( - reshaped_inputs, - axis=tuple(range(2, len(reshaped_dims))), - keepdims=True) + reshaped_inputs, axis=tuple(group_reduction_axes), keepdims=True) variance = np.var( - reshaped_inputs, - axis=tuple(range(2, len(reshaped_dims))), - keepdims=True) + reshaped_inputs, axis=tuple(group_reduction_axes), keepdims=True) # Get gamma and beta initalized by layer gamma, beta = layer._get_reshaped_weights(input_shape)