diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 879f140855..2070520eb7 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -105,6 +105,8 @@ /tensorflow_addons/layers/tests/esn_test.py @pedrolarben /tensorflow_addons/layers/snake.py @failure-to-thrive /tensorflow_addons/layers/tests/snake_test.py @failure-to-thrive +/tensorflow_addons/layers/noisy_dense.py @leonshams +/tensorflow_addons/layers/tests/noisy_dense_test.py @leonshams /tensorflow_addons/losses/contrastive.py @windqaq /tensorflow_addons/losses/tests/contrastive_test.py @windqaq diff --git a/tensorflow_addons/layers/__init__.py b/tensorflow_addons/layers/__init__.py index 594d025a9f..f8a0d9a11a 100644 --- a/tensorflow_addons/layers/__init__.py +++ b/tensorflow_addons/layers/__init__.py @@ -38,3 +38,4 @@ from tensorflow_addons.layers.tlu import TLU from tensorflow_addons.layers.wrappers import WeightNormalization from tensorflow_addons.layers.esn import ESN +from tensorflow_addons.layers.noisy_dense import NoisyDense diff --git a/tensorflow_addons/layers/noisy_dense.py b/tensorflow_addons/layers/noisy_dense.py new file mode 100644 index 0000000000..647b28db7d --- /dev/null +++ b/tensorflow_addons/layers/noisy_dense.py @@ -0,0 +1,264 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
# ==============================================================================

import tensorflow as tf
from tensorflow.keras import (
    activations,
    initializers,
    regularizers,
    constraints,
)
from tensorflow.keras import backend as K
from tensorflow.keras.layers import InputSpec
from typeguard import typechecked

from tensorflow_addons.utils import types


def _scale_noise(x):
    """Return `sign(x) * sqrt(|x|)`, the scaling applied to raw Gaussian noise."""
    return tf.sign(x) * tf.sqrt(tf.abs(x))


@tf.keras.utils.register_keras_serializable(package="Addons")
class NoisyDense(tf.keras.layers.Dense):
    r"""Noisy dense layer that injects random noise to the weights of dense layer.

    Noisy dense layers are fully connected layers whose weights and biases are
    augmented by factorised Gaussian noise. The factorised Gaussian noise is
    controlled through gradient descent by a second weights layer.

    A `NoisyDense` layer implements the operation:
    $$
    \mathrm{NoisyDense}(x) =
    \mathrm{activation}(\mathrm{dot}(x, \mu + (\sigma \cdot \epsilon))
    + \mathrm{bias})
    $$
    where $\mu$ is the standard weights layer, $\epsilon$ is the factorised
    Gaussian noise, and $\sigma$ is a second weights layer which controls
    $\epsilon$.

    Note: bias only added if `use_bias` is `True`.

    Example:

    >>> # Create a `Sequential` model and add a NoisyDense
    >>> # layer as the first layer.
    >>> model = tf.keras.models.Sequential()
    >>> model.add(tf.keras.Input(shape=(16,)))
    >>> model.add(NoisyDense(32, activation='relu'))
    >>> # Now the model will take as input arrays of shape (None, 16)
    >>> # and output arrays of shape (None, 32).
    >>> # Note that after the first layer, you don't need to specify
    >>> # the size of the input anymore:
    >>> model.add(NoisyDense(32))
    >>> model.output_shape
    (None, 32)

    Arguments:
      units: Positive integer, dimensionality of the output space.
      sigma: A float between 0-1 used as a standard deviation figure and is
        applied to the gaussian noise layer (`sigma_kernel` and `sigma_bias`).
      activation: Activation function to use.
        If you don't specify anything, no activation is applied
        (ie. "linear" activation: `a(x) = x`).
      use_bias: Boolean, whether the layer uses a bias vector.
      kernel_regularizer: Regularizer function applied to
        the `kernel` weights matrix.
      bias_regularizer: Regularizer function applied to the bias vector.
      activity_regularizer: Regularizer function applied to
        the output of the layer (its "activation").
      kernel_constraint: Constraint function applied to
        the `kernel` weights matrix.
      bias_constraint: Constraint function applied to the bias vector.

    Input shape:
      N-D tensor with shape: `(batch_size, ..., input_dim)`.
      The most common situation would be
      a 2D input with shape `(batch_size, input_dim)`.

    Output shape:
      N-D tensor with shape: `(batch_size, ..., units)`.
      For instance, for a 2D input with shape `(batch_size, input_dim)`,
      the output would have shape `(batch_size, units)`.

    References:
      - [Noisy Networks for Exploration](https://arxiv.org/pdf/1706.10295.pdf)
    """

    @typechecked
    def __init__(
        self,
        units: int,
        sigma: float = 0.5,
        activation: types.Activation = None,
        use_bias: bool = True,
        kernel_regularizer: types.Regularizer = None,
        bias_regularizer: types.Regularizer = None,
        activity_regularizer: types.Regularizer = None,
        kernel_constraint: types.Constraint = None,
        bias_constraint: types.Constraint = None,
        **kwargs
    ):
        super().__init__(
            units=units,
            activation=activation,
            use_bias=use_bias,
            kernel_regularizer=kernel_regularizer,
            bias_regularizer=bias_regularizer,
            activity_regularizer=activity_regularizer,
            kernel_constraint=kernel_constraint,
            bias_constraint=bias_constraint,
            **kwargs,
        )
        # The initializers set up by `Dense` are unused here: `build` creates
        # its own initializers for the mu/sigma weights, so drop the attributes.
        delattr(self, "kernel_initializer")
        delattr(self, "bias_initializer")
        self.sigma = sigma

    def build(self, input_shape):
        """Create the `mu_*` / `sigma_*` weights and draw the initial noise.

        Arguments:
          input_shape: Shape of the layer input; its last dimension must be
            defined.

        Raises:
          TypeError: If the layer dtype is neither floating point nor complex.
          ValueError: If the last dimension of `input_shape` is `None`.
        """
        # Make sure dtype is correct
        dtype = tf.dtypes.as_dtype(self.dtype or K.floatx())
        if not (dtype.is_floating or dtype.is_complex):
            raise TypeError(
                "Unable to build `Dense` layer with non-floating point "
                "dtype %s" % (dtype,)
            )

        input_shape = tf.TensorShape(input_shape)
        self.last_dim = tf.compat.dimension_value(input_shape[-1])
        if self.last_dim is None:
            raise ValueError(
                "The last dimension of the inputs to `Dense` "
                "should be defined. Found `None`."
            )
        # Computed only after the `None` check above; exponentiating `None`
        # would raise a TypeError and mask the intended ValueError.
        sqrt_dim = self.last_dim ** (1 / 2)
        self.input_spec = InputSpec(min_ndim=2, axes={-1: self.last_dim})

        # sigma weights start at the constant sigma / sqrt(last_dim); mu weights
        # are drawn uniformly from [-1 / sqrt(last_dim), 1 / sqrt(last_dim)].
        sigma_init = initializers.Constant(value=self.sigma / sqrt_dim)
        mu_init = initializers.RandomUniform(minval=-1 / sqrt_dim, maxval=1 / sqrt_dim)

        # Learnable parameters
        self.sigma_kernel = self.add_weight(
            "sigma_kernel",
            shape=[self.last_dim, self.units],
            initializer=sigma_init,
            regularizer=self.kernel_regularizer,
            constraint=self.kernel_constraint,
            dtype=self.dtype,
            trainable=True,
        )

        self.mu_kernel = self.add_weight(
            "mu_kernel",
            shape=[self.last_dim, self.units],
            initializer=mu_init,
            regularizer=self.kernel_regularizer,
            constraint=self.kernel_constraint,
            dtype=self.dtype,
            trainable=True,
        )

        if self.use_bias:
            self.sigma_bias = self.add_weight(
                "sigma_bias",
                shape=[self.units],
                initializer=sigma_init,
                regularizer=self.bias_regularizer,
                constraint=self.bias_constraint,
                dtype=self.dtype,
                trainable=True,
            )

            self.mu_bias = self.add_weight(
                "mu_bias",
                shape=[self.units],
                initializer=mu_init,
                regularizer=self.bias_regularizer,
                constraint=self.bias_constraint,
                dtype=self.dtype,
                trainable=True,
            )
        else:
            self.sigma_bias = None
            self.mu_bias = None
        # Populate `eps_kernel` / `eps_bias` so the `kernel` and `bias`
        # properties are usable as soon as the layer is built.
        self._reset_noise()
        self.built = True

    @property
    def kernel(self):
        """Effective kernel: `mu_kernel + sigma_kernel * eps_kernel`."""
        return self.mu_kernel + (self.sigma_kernel * self.eps_kernel)

    @property
    def bias(self):
        """Effective bias `mu_bias + sigma_bias * eps_bias`, or `None` without bias."""
        if self.use_bias:
            return self.mu_bias + (self.sigma_bias * self.eps_bias)
        return None

    def _reset_noise(self):
        """Create the factorised Gaussian noise."""

        dtype = self._compute_dtype_object

        # Generate random noise
        eps_i = tf.random.normal([self.last_dim, self.units], dtype=dtype)
        eps_j = tf.random.normal([self.units], dtype=dtype)

        # Scale the random noise
        self.eps_kernel = _scale_noise(eps_i) * _scale_noise(eps_j)
        self.eps_bias = _scale_noise(eps_j)

    def _remove_noise(self):
        """Remove the factorised Gaussian noise."""

        dtype = self._compute_dtype_object
        self.eps_kernel = tf.zeros([self.last_dim, self.units], dtype=dtype)
        self.eps_bias = tf.zeros([self.units], dtype=dtype)

    def call(self, inputs, reset_noise=True, remove_noise=False):
        """Apply the layer, optionally refreshing or zeroing the noise first.

        Arguments:
          inputs: Input tensor forwarded to `Dense.call`.
          reset_noise: If `True` (and `remove_noise` is `False`), draw fresh
            noise before computing the output.
          remove_noise: If `True`, set the noise to zeros before computing the
            output. Takes precedence over `reset_noise`.

        If both flags are `False`, the noise currently stored on the layer is
        reused.

        Returns:
          The result of `Dense.call(inputs)`.
        """
        # Generate fixed parameters added as the noise
        if remove_noise:
            self._remove_noise()
        elif reset_noise:
            self._reset_noise()

        # TODO(WindQAQ): Replace this with `dense()` once public.
        return super().call(inputs)

    def get_config(self):
        """Return the serialisable configuration of the layer.

        `Dense.get_config` is skipped on purpose by calling the implementation
        of the class after `Dense` in the MRO, presumably because it would
        serialise the initializers that `__init__` deletes.
        """
        # TODO(WindQAQ): Get rid of this hacky way.
        config = super(tf.keras.layers.Dense, self).get_config()
        config.update(
            {
                "units": self.units,
                "sigma": self.sigma,
                "activation": activations.serialize(self.activation),
                "use_bias": self.use_bias,
                "kernel_regularizer": regularizers.serialize(self.kernel_regularizer),
                "bias_regularizer": regularizers.serialize(self.bias_regularizer),
                "activity_regularizer": regularizers.serialize(
                    self.activity_regularizer
                ),
                "kernel_constraint": constraints.serialize(self.kernel_constraint),
                "bias_constraint": constraints.serialize(self.bias_constraint),
            }
        )
        return config


# diff --git a/tensorflow_addons/layers/tests/noisy_dense_test.py b/tensorflow_addons/layers/tests/noisy_dense_test.py
# new file mode 100644
# index 0000000000..9f76307518
# --- /dev/null
# +++ b/tensorflow_addons/layers/tests/noisy_dense_test.py
# @@ -0,0 +1,141 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests NoisyDense layer."""


import pytest
import numpy as np

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.mixed_precision.experimental import Policy

from tensorflow_addons.utils import test_utils
from tensorflow_addons.layers.noisy_dense import NoisyDense


@pytest.mark.parametrize(
    "input_shape", [(3, 2), (3, 4, 2), (None, None, 2), (3, 4, 5, 2)]
)
def test_noisy_dense(input_shape):
    """Run the generic layer test for several input ranks."""
    test_utils.layer_test(NoisyDense, kwargs={"units": 3}, input_shape=input_shape)


@pytest.mark.usefixtures("maybe_run_functions_eagerly")
@pytest.mark.parametrize("dtype", ["float16", "float32", "float64"])
def test_noisy_dense_dtype(dtype):
    """The output dtype matches the dtype the layer was created with."""
    inputs = tf.convert_to_tensor(
        np.random.randint(low=0, high=7, size=(2, 2)), dtype=dtype
    )
    layer = NoisyDense(5, dtype=dtype, name="noisy_dense_" + dtype)
    outputs = layer(inputs)
    np.testing.assert_array_equal(outputs.dtype, dtype)


@pytest.mark.usefixtures("maybe_run_functions_eagerly")
def test_noisy_dense_with_policy():
    """Under `mixed_float16`, outputs are float16 while weights stay float32."""
    inputs = tf.convert_to_tensor(np.random.randint(low=0, high=7, size=(2, 2)))
    layer = NoisyDense(5, dtype=Policy("mixed_float16"), name="noisy_dense_policy")
    outputs = layer(inputs)
    output_signature = layer.compute_output_signature(
        tf.TensorSpec(dtype="float16", shape=(2, 2))
    )
    np.testing.assert_array_equal(output_signature.dtype, tf.dtypes.float16)
    np.testing.assert_array_equal(output_signature.shape, (2, 5))
    np.testing.assert_array_equal(outputs.dtype, "float16")
    np.testing.assert_array_equal(layer.mu_kernel.dtype, "float32")
    np.testing.assert_array_equal(layer.sigma_kernel.dtype, "float32")


@pytest.mark.usefixtures("maybe_run_functions_eagerly")
def test_noisy_dense_regularization():
    """Kernel, bias and activity regularizers together yield five losses."""
    layer = NoisyDense(
        3,
        kernel_regularizer=keras.regularizers.l1(0.01),
        bias_regularizer="l1",
        activity_regularizer="l2",
        name="noisy_dense_reg",
    )
    layer(keras.backend.variable(np.ones((2, 4))))
    # Two kernel weights + two bias weights + one activity regularizer.
    np.testing.assert_array_equal(5, len(layer.losses))


@pytest.mark.usefixtures("maybe_run_functions_eagerly")
def test_noisy_dense_constraints():
    """Constraints are attached to both the mu and sigma weights."""
    k_constraint = keras.constraints.max_norm(0.01)
    b_constraint = keras.constraints.max_norm(0.01)
    layer = NoisyDense(
        3,
        kernel_constraint=k_constraint,
        bias_constraint=b_constraint,
        name="noisy_dense_constriants",
    )
    layer(keras.backend.variable(np.ones((2, 4))))
    np.testing.assert_array_equal(layer.mu_kernel.constraint, k_constraint)
    np.testing.assert_array_equal(layer.sigma_kernel.constraint, k_constraint)
    np.testing.assert_array_equal(layer.mu_bias.constraint, b_constraint)
    np.testing.assert_array_equal(layer.sigma_bias.constraint, b_constraint)


@pytest.mark.usefixtures("maybe_run_functions_eagerly")
def test_noisy_dense_automatic_reset_noise():
    """Calling the layer with default flags draws new noise each time."""
    inputs = tf.convert_to_tensor(np.random.randint(low=0, high=7, size=(2, 2)))
    layer = NoisyDense(5, name="noise_dense_auto_reset_noise")
    layer(inputs)
    initial_eps_kernel = layer.eps_kernel
    initial_eps_bias = layer.eps_bias
    layer(inputs)
    new_eps_kernel = layer.eps_kernel
    new_eps_bias = layer.eps_bias
    # `assert_array_equal` must fail, i.e. the noise must have changed.
    np.testing.assert_raises(
        AssertionError,
        np.testing.assert_array_equal,
        initial_eps_kernel,
        new_eps_kernel,
    )
    np.testing.assert_raises(
        AssertionError,
        np.testing.assert_array_equal,
        initial_eps_bias,
        new_eps_bias,
    )


@pytest.mark.usefixtures("maybe_run_functions_eagerly")
def test_noisy_dense_remove_noise():
    """`remove_noise=True` replaces the stored noise with zeros."""
    inputs = tf.convert_to_tensor(np.random.randint(low=0, high=7, size=(2, 2)))
    layer = NoisyDense(5, name="noise_dense_manual_reset_noise")
    layer(inputs)
    initial_eps_kernel = layer.eps_kernel
    initial_eps_bias = layer.eps_bias
    layer(inputs, reset_noise=False, remove_noise=True)
    new_eps_kernel = layer.eps_kernel
    new_eps_bias = layer.eps_bias
    kernel_zeros = tf.zeros(initial_eps_kernel.shape, dtype=initial_eps_kernel.dtype)
    # Use the bias tensor's own dtype rather than the kernel's.
    bias_zeros = tf.zeros(initial_eps_bias.shape, dtype=initial_eps_bias.dtype)
    # The noise must differ from the initial draw ...
    np.testing.assert_raises(
        AssertionError,
        np.testing.assert_array_equal,
        initial_eps_kernel,
        new_eps_kernel,
    )
    np.testing.assert_raises(
        AssertionError,
        np.testing.assert_array_equal,
        initial_eps_bias,
        new_eps_bias,
    )
    # ... and must now be exactly zero.
    np.testing.assert_array_equal(kernel_zeros, new_eps_kernel)
    np.testing.assert_array_equal(bias_zeros, new_eps_bias)