diff --git a/tensorflow_addons/layers/__init__.py b/tensorflow_addons/layers/__init__.py
index 9606352b2d..576ae7a4b3 100644
--- a/tensorflow_addons/layers/__init__.py
+++ b/tensorflow_addons/layers/__init__.py
@@ -17,6 +17,7 @@
 from tensorflow_addons.layers.gelu import GELU
 from tensorflow_addons.layers.maxout import Maxout
 from tensorflow_addons.layers.multihead_attention import MultiHeadAttention
+from tensorflow_addons.layers.normalizations import FilterResponseNormalization
 from tensorflow_addons.layers.normalizations import GroupNormalization
 from tensorflow_addons.layers.normalizations import InstanceNormalization
 from tensorflow_addons.layers.optical_flow import CorrelationCost
diff --git a/tensorflow_addons/layers/normalizations.py b/tensorflow_addons/layers/normalizations.py
index 7322701f5f..ca2cad07fc 100644
--- a/tensorflow_addons/layers/normalizations.py
+++ b/tensorflow_addons/layers/normalizations.py
@@ -321,3 +321,204 @@
     def __init__(self, **kwargs):
         kwargs["groups"] = -1
         super().__init__(**kwargs)
+
+
+@tf.keras.utils.register_keras_serializable(package="Addons")
+class FilterResponseNormalization(tf.keras.layers.Layer):
+    """Filter response normalization layer.
+
+    Filter Response Normalization (FRN) is a normalization method that
+    enables models trained with per-channel normalization to achieve high
+    accuracy. It outperforms other normalization techniques at small batch
+    sizes and is on par with Batch Normalization at larger batch sizes.
+
+    Arguments
+        axis: List of axes that should be normalized. This should represent
+            the spatial dimensions.
+        epsilon: Small positive float value added to the variance to avoid
+            division by zero.
+        beta_initializer: Initializer for the beta weight.
+        gamma_initializer: Initializer for the gamma weight.
+        beta_regularizer: Optional regularizer for the beta weight.
+        gamma_regularizer: Optional regularizer for the gamma weight.
+        beta_constraint: Optional constraint for the beta weight.
+        gamma_constraint: Optional constraint for the gamma weight.
+        learned_epsilon: (bool) Whether to add a learnable epsilon parameter.
+        learned_epsilon_constraint: Optional constraint for the learned
+            epsilon weight.
+        name: Optional name for the layer.
+
+    Input shape
+        Arbitrary. Use the keyword argument `input_shape`
+        (tuple of integers, does not include the samples axis)
+        when using this layer as the first layer in a model. As of now, the
+        layer only works on 4-D tensors of shape `[N, H, W, C]`.
+
+        TODO: Add support for the NCHW data format and FC layers.
+
+    Output shape
+        Same shape as input.
+
+    References
+        - [Filter Response Normalization Layer: Eliminating Batch Dependence
+          in the Training of Deep Neural Networks]
+          (https://arxiv.org/abs/1911.09737)
+    """
+
+    def __init__(
+        self,
+        epsilon: float = 1e-6,
+        axis: list = [1, 2],
+        beta_initializer: types.Initializer = "zeros",
+        gamma_initializer: types.Initializer = "ones",
+        beta_regularizer: types.Regularizer = None,
+        gamma_regularizer: types.Regularizer = None,
+        beta_constraint: types.Constraint = None,
+        gamma_constraint: types.Constraint = None,
+        learned_epsilon: bool = False,
+        learned_epsilon_constraint: types.Constraint = None,
+        name: str = None,
+        **kwargs
+    ):
+        super().__init__(name=name, **kwargs)
+        self.epsilon = tf.math.abs(tf.cast(epsilon, dtype=self.dtype))
+        self.beta_initializer = tf.keras.initializers.get(beta_initializer)
+        self.gamma_initializer = tf.keras.initializers.get(gamma_initializer)
+        self.beta_regularizer = tf.keras.regularizers.get(beta_regularizer)
+        self.gamma_regularizer = tf.keras.regularizers.get(gamma_regularizer)
+        self.beta_constraint = tf.keras.constraints.get(beta_constraint)
+        self.gamma_constraint = tf.keras.constraints.get(gamma_constraint)
+        self.use_eps_learned = learned_epsilon
+        self.supports_masking = True
+
+        if self.use_eps_learned:
+            self.eps_learned_initializer = tf.keras.initializers.Constant(1e-4)
+            self.eps_learned_constraint = tf.keras.constraints.get(
+                learned_epsilon_constraint
+            )
+            self.eps_learned = self.add_weight(
+                shape=(1,),
+                name="learned_epsilon",
+                dtype=self.dtype,
+                initializer=self.eps_learned_initializer,
+                regularizer=None,
+                constraint=self.eps_learned_constraint,
+            )
+        else:
+            self.eps_learned_initializer = None
+            self.eps_learned_constraint = None
+
+        self._check_axis(axis)
+
+    def build(self, input_shape):
+        if len(tf.TensorShape(input_shape)) != 4:
+            raise ValueError("Only 4-D tensors (CNNs) are supported as of now.")
+        self._check_if_input_shape_is_none(input_shape)
+        self._create_input_spec(input_shape)
+        self._add_gamma_weight(input_shape)
+        self._add_beta_weight(input_shape)
+        super().build(input_shape)
+
+    def call(self, inputs):
+        epsilon = self.epsilon
+        if self.use_eps_learned:
+            epsilon += tf.math.abs(self.eps_learned)
+        # nu2 is the mean squared value over the spatial dimensions.
+        nu2 = tf.reduce_mean(tf.square(inputs), axis=self.axis, keepdims=True)
+        normalized_inputs = inputs * tf.math.rsqrt(nu2 + epsilon)
+        return self.gamma * normalized_inputs + self.beta
+
+    def compute_output_shape(self, input_shape):
+        return input_shape
+
+    def get_config(self):
+        config = {
+            "axis": self.axis,
+            "epsilon": self.epsilon,
+            "learned_epsilon": self.use_eps_learned,
+            "beta_initializer": tf.keras.initializers.serialize(
+                self.beta_initializer
+            ),
+            "gamma_initializer": tf.keras.initializers.serialize(
+                self.gamma_initializer
+            ),
+            "beta_regularizer": tf.keras.regularizers.serialize(
+                self.beta_regularizer
+            ),
+            "gamma_regularizer": tf.keras.regularizers.serialize(
+                self.gamma_regularizer
+            ),
+            "beta_constraint": tf.keras.constraints.serialize(self.beta_constraint),
+            "gamma_constraint": tf.keras.constraints.serialize(self.gamma_constraint),
+            "learned_epsilon_constraint": tf.keras.constraints.serialize(
+                self.eps_learned_constraint
+            ),
+        }
+        base_config = super().get_config()
+        return dict(**base_config, **config)
+
+    def _create_input_spec(self, input_shape):
+        ndims = len(tf.TensorShape(input_shape))
+        # Resolve negative axis values.
+        for idx, x in enumerate(self.axis):
+            if x < 0:
+                self.axis[idx] = ndims + x
+
+        # Validate axes
+        for x in self.axis:
+            if x < 0 or x >= ndims:
+                raise ValueError("Invalid axis: %d" % x)
ValueError("Invalid axis: %d" % x) + + if len(self.axis) != len(set(self.axis)): + raise ValueError("Duplicate axis: %s" % self.axis) + + axis_to_dim = {x: input_shape[x] for x in self.axis} + self.input_spec = tf.keras.layers.InputSpec(ndim=ndims, axes=axis_to_dim) + + def _check_axis(self, axis): + if not isinstance(axis, list): + raise TypeError( + """Expected a list of values but got {}.""".format(type(axis)) + ) + else: + self.axis = axis + + if self.axis != [1, 2]: + raise ValueError( + """FilterResponseNormalization operates on per-channel basis. + Axis values should be a list of spatial dimensions.""" + ) + + def _check_if_input_shape_is_none(self, input_shape): + dim1, dim2 = input_shape[self.axis[0]], input_shape[self.axis[1]] + if dim1 is None or dim2 is None: + raise ValueError( + """Axis {} of input tensor should have a defined dimension but + the layer received an input with shape {}.""".format( + self.axis, input_shape + ) + ) + + def _add_gamma_weight(self, input_shape): + # Get the channel dimension + dim = input_shape[-1] + shape = [1, 1, 1, dim] + # Initialize gamma with shape (1, 1, 1, C) + self.gamma = self.add_weight( + shape=shape, + name="gamma", + dtype=self.dtype, + initializer=self.gamma_initializer, + regularizer=self.gamma_regularizer, + constraint=self.gamma_constraint, + ) + + def _add_beta_weight(self, input_shape): + # Get the channel dimension + dim = input_shape[-1] + shape = [1, 1, 1, dim] + # Initialize beta with shape (1, 1, 1, C) + self.beta = self.add_weight( + shape=shape, + name="beta", + dtype=self.dtype, + initializer=self.beta_initializer, + regularizer=self.beta_regularizer, + constraint=self.beta_constraint, + ) diff --git a/tensorflow_addons/layers/normalizations_test.py b/tensorflow_addons/layers/normalizations_test.py index 61b21e8b40..3f209c9264 100644 --- a/tensorflow_addons/layers/normalizations_test.py +++ b/tensorflow_addons/layers/normalizations_test.py @@ -19,6 +19,7 @@ import numpy as np import tensorflow as tf +from tensorflow_addons.layers.normalizations import FilterResponseNormalization from tensorflow_addons.layers.normalizations import GroupNormalization from tensorflow_addons.layers.normalizations import InstanceNormalization from tensorflow_addons.utils import test_utils @@ -331,5 +332,121 @@ def test_groupnorm_convnet_no_center_no_scale(self): ) +def calculate_frn( + x, beta=0.2, gamma=1, eps=1e-6, learned_epsilon=False, dtype=np.float32 +): + if learned_epsilon: + eps = eps + 1e-4 + eps = tf.cast(eps, dtype=dtype) + nu2 = tf.reduce_mean(tf.square(x), axis=[1, 2], keepdims=True) + x = x * tf.math.rsqrt(nu2 + tf.abs(eps)) + return gamma * x + beta + + +def set_random_seed(): + seed = 0x2020 + np.random.seed(seed) + tf.random.set_seed(seed) + + +@pytest.mark.usefixtures("maybe_run_functions_eagerly") +@pytest.mark.parametrize("dtype", [np.float16, np.float32, np.float64]) +def test_with_beta(dtype): + set_random_seed() + inputs = np.random.rand(28, 28, 1).astype(dtype) + inputs = np.expand_dims(inputs, axis=0) + frn = FilterResponseNormalization( + beta_initializer="ones", gamma_initializer="ones", dtype=dtype + ) + frn.build((None, 28, 28, 1)) + observed = frn(inputs) + expected = calculate_frn(inputs, beta=1, gamma=1, dtype=dtype) + np.testing.assert_allclose(expected[0], observed[0]) + + +@pytest.mark.usefixtures("maybe_run_functions_eagerly") +@pytest.mark.parametrize("dtype", [np.float16, np.float32, np.float64]) +def test_with_gamma(dtype): + set_random_seed() + inputs = np.random.rand(28, 28, 1).astype(dtype) + inputs 
+    frn = FilterResponseNormalization(
+        beta_initializer="zeros", gamma_initializer="ones", dtype=dtype
+    )
+    frn.build((None, 28, 28, 1))
+    observed = frn(inputs)
+    expected = calculate_frn(inputs, beta=0, gamma=1, dtype=dtype)
+    np.testing.assert_allclose(expected[0], observed[0])
+
+
+@pytest.mark.usefixtures("maybe_run_functions_eagerly")
+@pytest.mark.parametrize("dtype", [np.float16, np.float32, np.float64])
+def test_with_epsilon(dtype):
+    set_random_seed()
+    inputs = np.random.rand(28, 28, 1).astype(dtype)
+    inputs = np.expand_dims(inputs, axis=0)
+    frn = FilterResponseNormalization(
+        beta_initializer=tf.keras.initializers.Constant(0.5),
+        gamma_initializer="ones",
+        learned_epsilon=True,
+        dtype=dtype,
+    )
+    frn.build((None, 28, 28, 1))
+    observed = frn(inputs)
+    expected = calculate_frn(
+        inputs, beta=0.5, gamma=1, learned_epsilon=True, dtype=dtype
+    )
+    np.testing.assert_allclose(expected[0], observed[0])
+
+
+@pytest.mark.usefixtures("maybe_run_functions_eagerly")
+@pytest.mark.parametrize("dtype", [np.float16, np.float32, np.float64])
+def test_keras_model(dtype):
+    set_random_seed()
+    frn = FilterResponseNormalization(
+        beta_initializer="ones", gamma_initializer="ones", dtype=dtype
+    )
+    random_inputs = np.random.rand(10, 32, 32, 3).astype(dtype)
+    random_labels = np.random.randint(2, size=(10,)).astype(dtype)
+    input_layer = tf.keras.layers.Input(shape=(32, 32, 3))
+    x = frn(input_layer)
+    x = tf.keras.layers.Flatten()(x)
+    out = tf.keras.layers.Dense(1, activation="sigmoid")(x)
+    model = tf.keras.models.Model(input_layer, out)
+    model.compile(loss="binary_crossentropy", optimizer="sgd")
+    model.fit(random_inputs, random_labels, epochs=2)
+
+
+@pytest.mark.parametrize("dtype", [np.float16, np.float32, np.float64])
+def test_serialization(dtype):
+    frn = FilterResponseNormalization(
+        beta_initializer="ones", gamma_initializer="ones", dtype=dtype
+    )
+    serialized_frn = tf.keras.layers.serialize(frn)
+    new_layer = tf.keras.layers.deserialize(serialized_frn)
+    assert frn.get_config() == new_layer.get_config()
+
+
+@pytest.mark.usefixtures("maybe_run_functions_eagerly")
+def test_eps_grads():
+    # The learned epsilon should receive gradients and move away from its
+    # initial value after training. (This test always runs in float32, so it
+    # is not parametrized over dtypes.)
+    set_random_seed()
+    random_inputs = np.random.rand(10, 32, 32, 3).astype(np.float32)
+    random_labels = np.random.randint(2, size=(10,)).astype(np.float32)
+    input_layer = tf.keras.layers.Input(shape=(32, 32, 3))
+    frn = FilterResponseNormalization(
+        beta_initializer="ones", gamma_initializer="ones", learned_epsilon=True
+    )
+    initial_eps_value = frn.eps_learned.numpy()[0]
+    x = frn(input_layer)
+    x = tf.keras.layers.Flatten()(x)
+    out = tf.keras.layers.Dense(1, activation="sigmoid")(x)
+    model = tf.keras.models.Model(input_layer, out)
+    model.compile(loss="binary_crossentropy", optimizer="sgd")
+    model.fit(random_inputs, random_labels, epochs=1)
+    final_eps_value = frn.eps_learned.numpy()[0]
+    assert initial_eps_value != final_eps_value
+
+
 if __name__ == "__main__":
     sys.exit(pytest.main([__file__]))
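For reviewers who want to try the layer by hand, a minimal usage sketch follows (illustrative only, not part of the diff; shapes and values are arbitrary). It exercises the same path as `test_keras_model` above: the layer normalizes each channel by the mean squared value over the spatial axes, y = gamma * x / sqrt(mean(x^2) + epsilon) + beta, so the output shape always matches the input.

# Minimal usage sketch for FilterResponseNormalization (assumes this PR is
# installed so the layer is importable from tensorflow_addons.layers).
import numpy as np
import tensorflow as tf

from tensorflow_addons.layers import FilterResponseNormalization

# NHWC input, as the layer currently requires.
x = np.random.rand(2, 8, 8, 3).astype(np.float32)

# With learned_epsilon=True the layer adds a trainable epsilon on top of the
# fixed one, as exercised by test_eps_grads above.
frn = FilterResponseNormalization(learned_epsilon=True)
y = frn(x)

assert y.shape == x.shape  # FRN preserves the input shape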