diff --git a/README.md b/README.md index c165bf06e9..4a1e2fc174 100644 --- a/README.md +++ b/README.md @@ -17,6 +17,9 @@ developments that cannot be integrated into core TensorFlow |:----------------------- |:----------- |:---------------------------- | | tfa.activations | Sparsemax | https://arxiv.org/abs/1602.02068 | | tfa.image | transform | | +| tfa.layers | GroupNormalization | https://arxiv.org/abs/1803.08494 | +| tfa.layers | InstanceNormalization | https://arxiv.org/abs/1607.08022 | +| tfa.layers | LayerNormalization | https://arxiv.org/abs/1607.06450 | | tfa.layers | Maxout | https://arxiv.org/abs/1302.4389 | | tfa.layers | PoinareNormalize | https://arxiv.org/abs/1705.08039 | | tfa.layers | WeightNormalization | https://arxiv.org/abs/1602.07868 | diff --git a/tensorflow_addons/layers/BUILD b/tensorflow_addons/layers/BUILD index b8de59fbe0..3ae079bbf1 100644 --- a/tensorflow_addons/layers/BUILD +++ b/tensorflow_addons/layers/BUILD @@ -8,6 +8,7 @@ py_library( "__init__.py", "python/__init__.py", "python/maxout.py", + "python/normalizations.py", "python/poincare.py", "python/sparsemax.py", "python/wrappers.py", @@ -29,7 +30,7 @@ py_test( srcs_version = "PY2AND3", deps = [ ":layers_py", - ], + ] ) py_test( @@ -55,18 +56,18 @@ py_test( srcs_version = "PY2AND3", deps = [ ":layers_py", - ], + ] ) py_test( - name = "poincare_py_test", - size = "small", + name = "layers_normalizations_py_test", + size= "small", srcs = [ - "python/poincare_test.py", + "python/normalizations_test.py", ], - main = "python/poincare_test.py", + main = "python/normalizations_test.py", srcs_version = "PY2AND3", deps = [ - ":layers_py", - ], + ":layers_py", + ] ) diff --git a/tensorflow_addons/layers/README.md b/tensorflow_addons/layers/README.md index c9832c87c1..9e34f0ac5d 100644 --- a/tensorflow_addons/layers/README.md +++ b/tensorflow_addons/layers/README.md @@ -3,6 +3,9 @@ ## Contents | Layer | Reference | |:----------------------- |:-----------------------------| +| GroupNormalization | https://arxiv.org/abs/1803.08494 | +| InstanceNormalization | https://arxiv.org/abs/1607.08022 | +| LayerNormalization | https://arxiv.org/abs/1607.06450 | | Maxout | https://arxiv.org/abs/1302.4389 | | PoinareNormalize | https://arxiv.org/abs/1705.08039 | | WeightNormalization | https://arxiv.org/abs/1602.07868 | diff --git a/tensorflow_addons/layers/__init__.py b/tensorflow_addons/layers/__init__.py index 0e06709ac7..c5e0497726 100644 --- a/tensorflow_addons/layers/__init__.py +++ b/tensorflow_addons/layers/__init__.py @@ -19,6 +19,9 @@ from __future__ import print_function from tensorflow_addons.layers.python.maxout import Maxout +from tensorflow_addons.layers.python.normalizations import GroupNormalization +from tensorflow_addons.layers.python.normalizations import InstanceNormalization +from tensorflow_addons.layers.python.normalizations import LayerNormalization from tensorflow_addons.layers.python.poincare import PoincareNormalize from tensorflow_addons.layers.python.sparsemax import Sparsemax from tensorflow_addons.layers.python.wrappers import WeightNormalization diff --git a/tensorflow_addons/layers/python/normalizations.py b/tensorflow_addons/layers/python/normalizations.py new file mode 100644 index 0000000000..2a07a3d802 --- /dev/null +++ b/tensorflow_addons/layers/python/normalizations.py @@ -0,0 +1,361 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Original implementation from keras_contrib/layer/normalization
+# =============================================================================
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import logging
+import tensorflow as tf
+from tensorflow_addons.utils.python import keras_utils
+
+
+@keras_utils.register_keras_custom_object
+class GroupNormalization(tf.keras.layers.Layer):
+    """Group normalization layer.
+
+    Group Normalization divides the channels into groups and computes
+    within each group the mean and variance for normalization.
+    Empirically, its accuracy is more stable than batch norm in a wide
+    range of small batch sizes, if the learning rate is adjusted linearly
+    with the batch size.
+
+    Relation to Layer Normalization:
+    If the number of groups is set to 1, then this operation becomes identical
+    to Layer Normalization.
+
+    Relation to Instance Normalization:
+    If the number of groups is set to the input dimension (i.e. the
+    number of groups is equal to the number of channels), then this
+    operation becomes identical to Instance Normalization.
+
+    Arguments
+        groups: Integer, the number of groups for Group Normalization.
+            Can be in the range [1, N] where N is the input dimension.
+            The input dimension must be divisible by the number of groups.
+        axis: Integer, the axis that should be normalized.
+        epsilon: Small float added to variance to avoid dividing by zero.
+        center: If True, add offset of `beta` to normalized tensor.
+            If False, `beta` is ignored.
+        scale: If True, multiply by `gamma`.
+            If False, `gamma` is not used.
+        beta_initializer: Initializer for the beta weight.
+        gamma_initializer: Initializer for the gamma weight.
+        beta_regularizer: Optional regularizer for the beta weight.
+        gamma_regularizer: Optional regularizer for the gamma weight.
+        beta_constraint: Optional constraint for the beta weight.
+        gamma_constraint: Optional constraint for the gamma weight.
+
+    Input shape
+        Arbitrary. Use the keyword argument `input_shape`
+        (tuple of integers, does not include the samples axis)
+        when using this layer as the first layer in a model.
+
+    Output shape
+        Same shape as input.
+
+ References + - [Group Normalization](https://arxiv.org/abs/1803.08494) + """ + + def __init__(self, + groups=2, + axis=-1, + epsilon=1e-5, + center=True, + scale=True, + beta_initializer='zeros', + gamma_initializer='ones', + beta_regularizer=None, + gamma_regularizer=None, + beta_constraint=None, + gamma_constraint=None, + **kwargs): + super(GroupNormalization, self).__init__(**kwargs) + self.supports_masking = True + self.groups = groups + self.axis = axis + self.epsilon = epsilon + self.center = center + self.scale = scale + self.beta_initializer = tf.keras.initializers.get(beta_initializer) + self.gamma_initializer = tf.keras.initializers.get(gamma_initializer) + self.beta_regularizer = tf.keras.regularizers.get(beta_regularizer) + self.gamma_regularizer = tf.keras.regularizers.get(gamma_regularizer) + self.beta_constraint = tf.keras.constraints.get(beta_constraint) + self.gamma_constraint = tf.keras.constraints.get(gamma_constraint) + self._check_axis() + + def build(self, input_shape): + + self._check_if_input_shape_is_none(input_shape) + self._set_number_of_groups_for_instance_norm(input_shape) + self._check_size_of_dimensions(input_shape) + self._create_input_spec(input_shape) + + self._add_gamma_weight(input_shape) + self._add_beta_weight(input_shape) + self.built = True + super(GroupNormalization, self).build(input_shape) + + def call(self, inputs): + + input_shape = tf.keras.backend.int_shape(inputs) + tensor_input_shape = tf.shape(inputs) + + reshaped_inputs, group_shape = self._reshape_into_groups( + inputs, input_shape, tensor_input_shape) + + normalized_inputs = self._apply_normalization(reshaped_inputs, + input_shape) + + outputs = tf.reshape(normalized_inputs, tensor_input_shape) + + return outputs + + def get_config(self): + config = { + 'groups': + self.groups, + 'axis': + self.axis, + 'epsilon': + self.epsilon, + 'center': + self.center, + 'scale': + self.scale, + 'beta_initializer': + tf.keras.initializers.serialize(self.beta_initializer), + 'gamma_initializer': + tf.keras.initializers.serialize(self.gamma_initializer), + 'beta_regularizer': + tf.keras.regularizers.serialize(self.beta_regularizer), + 'gamma_regularizer': + tf.keras.regularizers.serialize(self.gamma_regularizer), + 'beta_constraint': + tf.keras.constraints.serialize(self.beta_constraint), + 'gamma_constraint': + tf.keras.constraints.serialize(self.gamma_constraint) + } + base_config = super(GroupNormalization, self).get_config() + return dict(list(base_config.items()) + list(config.items())) + + def compute_output_shape(self, input_shape): + return input_shape + + def _reshape_into_groups(self, inputs, input_shape, tensor_input_shape): + + group_shape = [tensor_input_shape[i] for i in range(len(input_shape))] + group_shape[self.axis] = input_shape[self.axis] // self.groups + group_shape.insert(1, self.groups) + group_shape = tf.stack(group_shape) + reshaped_inputs = tf.reshape(inputs, group_shape) + return reshaped_inputs, group_shape + + def _apply_normalization(self, reshaped_inputs, input_shape): + + group_shape = tf.keras.backend.int_shape(reshaped_inputs) + group_reduction_axes = list(range(len(group_shape))) + # Remember the ordering of the tensor is [batch, group , steps]. 
+        # Skip the first two axes (batch and group) when calculating the
+        # mean and variance for the normalization
+        mean, variance = tf.nn.moments(
+            reshaped_inputs, group_reduction_axes[2:], keepdims=True)
+
+        gamma, beta = self._get_reshaped_weights(input_shape)
+        normalized_inputs = tf.nn.batch_normalization(
+            reshaped_inputs,
+            mean=mean,
+            variance=variance,
+            scale=gamma,
+            offset=beta,
+            variance_epsilon=self.epsilon)
+        return normalized_inputs
+
+    def _get_reshaped_weights(self, input_shape):
+        broadcast_shape = self._create_broadcast_shape(input_shape)
+        gamma = None
+        beta = None
+        if self.scale:
+            gamma = tf.reshape(self.gamma, broadcast_shape)
+
+        if self.center:
+            beta = tf.reshape(self.beta, broadcast_shape)
+        return gamma, beta
+
+    def _check_if_input_shape_is_none(self, input_shape):
+        dim = input_shape[self.axis]
+        if dim is None:
+            raise ValueError('Axis ' + str(self.axis) + ' of '
+                             'input tensor should have a defined dimension '
+                             'but the layer received an input with shape ' +
+                             str(input_shape) + '.')
+
+    def _set_number_of_groups_for_instance_norm(self, input_shape):
+        dim = input_shape[self.axis]
+
+        if self.groups == -1:
+            self.groups = dim
+
+    def _check_size_of_dimensions(self, input_shape):
+
+        dim = input_shape[self.axis]
+        if dim < self.groups:
+            raise ValueError(
+                'Number of groups (' + str(self.groups) + ') cannot be '
+                'more than the number of channels (' + str(dim) + ').')
+
+        if dim % self.groups != 0:
+            raise ValueError(
+                'Number of groups (' + str(self.groups) + ') must be a '
+                'multiple of the number of channels (' + str(dim) + ').')
+
+    def _check_axis(self):
+
+        if self.axis == 0:
+            raise ValueError(
+                "You are trying to normalize your batch axis. Do you want to "
+                "use tf.keras.layers.BatchNormalization instead?")
+
+    def _create_input_spec(self, input_shape):
+
+        dim = input_shape[self.axis]
+        self.input_spec = tf.keras.layers.InputSpec(
+            ndim=len(input_shape), axes={self.axis: dim})
+
+    def _add_gamma_weight(self, input_shape):
+
+        dim = input_shape[self.axis]
+        shape = (dim,)
+
+        if self.scale:
+            self.gamma = self.add_weight(
+                shape=shape,
+                name='gamma',
+                initializer=self.gamma_initializer,
+                regularizer=self.gamma_regularizer,
+                constraint=self.gamma_constraint)
+        else:
+            self.gamma = None
+
+    def _add_beta_weight(self, input_shape):
+
+        dim = input_shape[self.axis]
+        shape = (dim,)
+
+        if self.center:
+            self.beta = self.add_weight(
+                shape=shape,
+                name='beta',
+                initializer=self.beta_initializer,
+                regularizer=self.beta_regularizer,
+                constraint=self.beta_constraint)
+        else:
+            self.beta = None
+
+    def _create_broadcast_shape(self, input_shape):
+        broadcast_shape = [1] * len(input_shape)
+        broadcast_shape[self.axis] = input_shape[self.axis] // self.groups
+        broadcast_shape.insert(1, self.groups)
+        return broadcast_shape
+
+
+@keras_utils.register_keras_custom_object
+class LayerNormalization(GroupNormalization):
+    """Layer normalization layer.
+
+    Layer Normalization is a special case of `GroupNormalization`, since it
+    normalizes all features of a layer. The group size is 1.
+    Empirically, its accuracy is more stable than batch norm in a wide
+    range of small batch sizes, if the learning rate is adjusted linearly
+    with the batch size.
+
+    Arguments
+        axis: Integer, the axis that should be normalized.
+        epsilon: Small float added to variance to avoid dividing by zero.
+        center: If True, add offset of `beta` to normalized tensor.
+            If False, `beta` is ignored.
+        scale: If True, multiply by `gamma`.
+            If False, `gamma` is not used.
+        beta_initializer: Initializer for the beta weight.
+        gamma_initializer: Initializer for the gamma weight.
+        beta_regularizer: Optional regularizer for the beta weight.
+        gamma_regularizer: Optional regularizer for the gamma weight.
+        beta_constraint: Optional constraint for the beta weight.
+        gamma_constraint: Optional constraint for the gamma weight.
+
+    Input shape
+        Arbitrary. Use the keyword argument `input_shape`
+        (tuple of integers, does not include the samples axis)
+        when using this layer as the first layer in a model.
+
+    Output shape
+        Same shape as input.
+
+    References
+        - [Layer Normalization](https://arxiv.org/abs/1607.06450)
+    """
+
+    def __init__(self, **kwargs):
+        if "groups" in kwargs:
+            logging.warning("The given value for groups will be overwritten.")
+        kwargs["groups"] = 1
+        super(LayerNormalization, self).__init__(**kwargs)
+
+
+@keras_utils.register_keras_custom_object
+class InstanceNormalization(GroupNormalization):
+    """Instance normalization layer.
+
+    Instance Normalization is a special case of `GroupNormalization`, since it
+    normalizes all features of one channel. The group size is equal to the
+    channel size. Empirically, its accuracy is more stable than batch norm in a
+    wide range of small batch sizes, if the learning rate is adjusted linearly
+    with the batch size.
+
+    Arguments
+        axis: Integer, the axis that should be normalized.
+        epsilon: Small float added to variance to avoid dividing by zero.
+        center: If True, add offset of `beta` to normalized tensor.
+            If False, `beta` is ignored.
+        scale: If True, multiply by `gamma`.
+            If False, `gamma` is not used.
+        beta_initializer: Initializer for the beta weight.
+        gamma_initializer: Initializer for the gamma weight.
+        beta_regularizer: Optional regularizer for the beta weight.
+        gamma_regularizer: Optional regularizer for the gamma weight.
+        beta_constraint: Optional constraint for the beta weight.
+        gamma_constraint: Optional constraint for the gamma weight.
+
+    Input shape
+        Arbitrary. Use the keyword argument `input_shape`
+        (tuple of integers, does not include the samples axis)
+        when using this layer as the first layer in a model.
+
+    Output shape
+        Same shape as input.
+
+    References
+        - [Instance Normalization: The Missing Ingredient for Fast Stylization]
+        (https://arxiv.org/abs/1607.08022)
+    """
+
+    def __init__(self, **kwargs):
+        if "groups" in kwargs:
+            logging.warning("The given value for groups will be overwritten.")
+
+        kwargs["groups"] = -1
+        super(InstanceNormalization, self).__init__(**kwargs)
diff --git a/tensorflow_addons/layers/python/normalizations_test.py b/tensorflow_addons/layers/python/normalizations_test.py
new file mode 100644
index 0000000000..f3bf95afae
--- /dev/null
+++ b/tensorflow_addons/layers/python/normalizations_test.py
@@ -0,0 +1,282 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+import tensorflow as tf
+
+from tensorflow_addons.layers.python.normalizations import GroupNormalization
+from tensorflow_addons.layers.python.normalizations import InstanceNormalization
+from tensorflow_addons.layers.python.normalizations import LayerNormalization
+from tensorflow_addons.utils.python import test_utils
+
+
+class NormalizationTest(tf.test.TestCase):
+
+    # ------------Tests to ensure proper inheritance. If these succeed you can
+    # test for InstanceNorm and LayerNorm by setting GroupNorm groups = -1 or 1
+    def test_inheritance(self):
+        self.assertTrue(issubclass(LayerNormalization, GroupNormalization))
+        self.assertTrue(issubclass(InstanceNormalization, GroupNormalization))
+        self.assertTrue(LayerNormalization.build == GroupNormalization.build)
+        self.assertTrue(
+            InstanceNormalization.build == GroupNormalization.build)
+        self.assertTrue(LayerNormalization.call == GroupNormalization.call)
+        self.assertTrue(InstanceNormalization.call == GroupNormalization.call)
+
+    def test_groups_after_init(self):
+        layers = InstanceNormalization()
+        self.assertTrue(layers.groups == -1)
+        layers = LayerNormalization()
+        self.assertTrue(layers.groups == 1)
+
+    # ------------------------------------------------------------------------------
+
+    def test_reshape(self):
+        def run_reshape_test(axis, group, input_shape, expected_shape):
+            group_layer = GroupNormalization(groups=group, axis=axis)
+            group_layer._set_number_of_groups_for_instance_norm(input_shape)
+
+            inputs = np.ones(input_shape)
+            tensor_input_shape = tf.convert_to_tensor(input_shape)
+            reshaped_inputs, group_shape = group_layer._reshape_into_groups(
+                inputs, (10, 10, 10), tensor_input_shape)
+            for i in range(len(expected_shape)):
+                self.assertEqual(int(group_shape[i]), expected_shape[i])
+
+        input_shape = (10, 10, 10)
+        expected_shape = [10, 5, 10, 2]
+        run_reshape_test(2, 5, input_shape, expected_shape)
+
+        input_shape = (10, 10, 10)
+        expected_shape = [10, 2, 5, 10]
+        run_reshape_test(1, 2, input_shape, expected_shape)
+
+        input_shape = (10, 10, 10)
+        expected_shape = [10, 10, 1, 10]
+        run_reshape_test(1, -1, input_shape, expected_shape)
+
+        input_shape = (10, 10, 10)
+        expected_shape = [10, 1, 10, 10]
+        run_reshape_test(1, 1, input_shape, expected_shape)
+
+    def test_feature_input(self):
+        shape = (10, 100)
+        for center in [True, False]:
+            for scale in [True, False]:
+                for groups in [-1, 1, 2, 5]:
+                    self._test_random_shape_on_all_axis_except_batch(
+                        shape, groups, center, scale)
+
+    def test_picture_input(self):
+        shape = (10, 30, 30, 3)
+        for center in [True, False]:
+            for scale in [True, False]:
+                for groups in [-1, 1, 3]:
+                    self._test_random_shape_on_all_axis_except_batch(
+                        shape, groups, center, scale)
+
+    def _test_random_shape_on_all_axis_except_batch(self, shape, groups,
+                                                    center, scale):
+        inputs = tf.random.normal(shape)
+        for axis in range(1, len(shape)):
+            self._test_specific_layer(inputs, axis, groups, center, scale)
+
+    def _test_specific_layer(self, inputs, axis, groups, center, scale):
+
+        input_shape = inputs.shape
+
+        # Get the output from the Keras model
+        layer = GroupNormalization(
+            axis=axis, groups=groups, center=center, scale=scale)
+        model = tf.keras.models.Sequential()
+        model.add(layer)
+        outputs = model.predict(inputs)
+        self.assertFalse(np.isnan(outputs).any())
+
+        # Create shapes
+        if groups == -1:
+            groups = input_shape[axis]
+        np_inputs = inputs.numpy()
+        reshaped_dims = list(np_inputs.shape)
+        reshaped_dims[axis] = reshaped_dims[axis] // groups
+        reshaped_dims.insert(1, groups)
+        reshaped_inputs = np.reshape(np_inputs, tuple(reshaped_dims))
+
+        # Calculate mean and variance
+        mean = np.mean(
+            reshaped_inputs,
+            axis=tuple(range(2, len(reshaped_dims))),
+            keepdims=True)
+        variance = np.var(
+            reshaped_inputs,
+            axis=tuple(range(2, len(reshaped_dims))),
+            keepdims=True)
+
+        # Get gamma and beta as initialized by the layer
+        gamma, beta = layer._get_reshaped_weights(input_shape)
+        if gamma is None:
+            gamma = 1.0
+        if beta is None:
+            beta = 0.0
+
+        # Get the reference output from NumPy
+        zeroed = reshaped_inputs - mean
+        rsqrt = 1 / np.sqrt(variance + 1e-5)
+        output_test = gamma * zeroed * rsqrt + beta
+
+        # Compare the outputs
+        output_test = np.reshape(output_test, input_shape.as_list())
+        self.assertAlmostEqual(np.mean(output_test - outputs), 0, places=7)
+
+    def _create_and_fit_Sequential_model(self, layer, shape):
+        # Helper function for quick evaluation
+        model = tf.keras.models.Sequential()
+        model.add(layer)
+        model.add(tf.keras.layers.Dense(32))
+        model.add(tf.keras.layers.Dense(1))
+
+        model.compile(
+            optimizer=tf.keras.optimizers.RMSprop(0.01),
+            loss="categorical_crossentropy")
+        layer_shape = (10,) + shape
+        input_batch = np.random.rand(*layer_shape)
+        output_batch = np.random.rand(*(10, 1))
+        model.fit(x=input_batch, y=output_batch, epochs=1, batch_size=1)
+        return model
+
+    @test_utils.run_in_graph_and_eager_modes
+    def test_weights(self):
+        # Check if weights get initialized correctly
+        layer = GroupNormalization(groups=1, scale=False, center=False)
+        layer.build((None, 3, 4))
+        self.assertEqual(len(layer.trainable_weights), 0)
+        self.assertEqual(len(layer.weights), 0)
+
+        layer = LayerNormalization()
+        layer.build((None, 3, 4))
+        self.assertEqual(len(layer.trainable_weights), 2)
+        self.assertEqual(len(layer.weights), 2)
+
+        layer = InstanceNormalization()
+        layer.build((None, 3, 4))
+        self.assertEqual(len(layer.trainable_weights), 2)
+        self.assertEqual(len(layer.weights), 2)
+
+    def test_apply_normalization(self):
+
+        input_shape = (1, 4)
+        expected_shape = (1, 2, 2)
+        reshaped_inputs = tf.constant([[[2.0, 2.0], [3.0, 3.0]]])
+        layer = GroupNormalization(groups=2, axis=1, scale=False, center=False)
+        normalized_input = layer._apply_normalization(reshaped_inputs,
+                                                      input_shape)
+        self.assertTrue(
+            tf.reduce_all(
+                tf.equal(normalized_input,
+                         tf.constant([[[0.0, 0.0], [0.0, 0.0]]]))))
+
+    def test_axis_error(self):
+
+        with self.assertRaises(ValueError):
+            GroupNormalization(axis=0)
+
+    @test_utils.run_in_graph_and_eager_modes
+    def test_groupnorm_flat(self):
+        # Check basic usage of groupnorm_flat
+        # Testing for 1 == LayerNorm, 16 == GroupNorm, -1 == InstanceNorm
+
+        groups = [-1, 16, 1]
+        shape = (64,)
+        for i in groups:
+            model = self._create_and_fit_Sequential_model(
+                GroupNormalization(groups=i), shape)
+            self.assertTrue(hasattr(model.layers[0], 'gamma'))
+            self.assertTrue(hasattr(model.layers[0], 'beta'))
+
+    @test_utils.run_in_graph_and_eager_modes
+    def test_layernorm_flat(self):
+        # Check basic usage of layernorm
+
+        model = self._create_and_fit_Sequential_model(
+            LayerNormalization(), (64,))
+        self.assertTrue(hasattr(model.layers[0], 'gamma'))
+        self.assertTrue(hasattr(model.layers[0], 'beta'))
+
+    @test_utils.run_in_graph_and_eager_modes
+    def test_instancenorm_flat(self):
+        # Check basic usage of instancenorm
+
+        model = self._create_and_fit_Sequential_model(
+            InstanceNormalization(), (64,))
+        self.assertTrue(hasattr(model.layers[0], 'gamma'))
+        self.assertTrue(hasattr(model.layers[0], 'beta'))
+
+    @test_utils.run_in_graph_and_eager_modes
+    def test_initializer(self):
+        # Check that the initializers and constraints for gamma and beta work
+
+        layer = GroupNormalization(
+            groups=32,
+            beta_initializer='random_normal',
+            beta_constraint='NonNeg',
+            gamma_initializer='random_normal',
+            gamma_constraint='NonNeg')
+
+        model = self._create_and_fit_Sequential_model(layer, (64,))
+
+        weights = np.array(model.layers[0].get_weights())
+        negative = weights[weights < 0.0]
+        self.assertTrue(len(negative) == 0)
+
+    @test_utils.run_in_graph_and_eager_modes
+    def test_regularizations(self):
+
+        layer = GroupNormalization(
+            gamma_regularizer='l1', beta_regularizer='l1', groups=4, axis=2)
+        layer.build((None, 4, 4))
+        self.assertEqual(len(layer.losses), 2)
+        max_norm = tf.keras.constraints.max_norm
+        layer = GroupNormalization(
+            gamma_constraint=max_norm, beta_constraint=max_norm)
+        layer.build((None, 3, 4))
+        self.assertEqual(layer.gamma.constraint, max_norm)
+        self.assertEqual(layer.beta.constraint, max_norm)
+
+    @test_utils.run_in_graph_and_eager_modes
+    def test_groupnorm_conv(self):
+        # Check that the axis argument works for conv nets
+        # Testing for 1 == LayerNorm, 5 == GroupNorm, -1 == InstanceNorm
+
+        groups = [-1, 5, 1]
+        for i in groups:
+            model = tf.keras.models.Sequential()
+            model.add(
+                GroupNormalization(axis=1, groups=i, input_shape=(20, 20, 3)))
+            model.add(tf.keras.layers.Conv2D(5, (1, 1), padding='same'))
+            model.add(tf.keras.layers.Flatten())
+            model.add(tf.keras.layers.Dense(1, activation='softmax'))
+            model.compile(
+                optimizer=tf.keras.optimizers.RMSprop(0.01), loss='mse')
+            x = np.random.randint(1000, size=(10, 20, 20, 3))
+            y = np.random.randint(1000, size=(10, 1))
+            model.fit(x=x, y=y, epochs=1)
+            self.assertTrue(hasattr(model.layers[0], 'gamma'))
+
+
+if __name__ == "__main__":
+    tf.test.main()
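A minimal usage sketch of the three layers added by this patch, mirroring the patterns in normalizations_test.py. It assumes tensorflow_addons has been built with this change applied; the surrounding Keras model and its shapes are illustrative only and are not part of the diff.

    import numpy as np
    import tensorflow as tf
    from tensorflow_addons.layers import GroupNormalization
    from tensorflow_addons.layers import InstanceNormalization
    from tensorflow_addons.layers import LayerNormalization

    model = tf.keras.models.Sequential([
        tf.keras.layers.Conv2D(16, 3, padding='same', input_shape=(20, 20, 3)),
        # 16 channels split into 4 groups of 4; groups must divide the size
        # of the normalized axis.
        GroupNormalization(groups=4, axis=-1),
        tf.keras.layers.Conv2D(16, 3, padding='same'),
        # Equivalent to GroupNormalization with groups=-1, one group per channel.
        InstanceNormalization(),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(32),
        # Equivalent to GroupNormalization with groups=1, all features in one group.
        LayerNormalization(),
        tf.keras.layers.Dense(1),
    ])
    model.compile(optimizer='rmsprop', loss='mse')
    model.fit(np.random.rand(8, 20, 20, 3), np.random.rand(8, 1), epochs=1)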