|
| 1 | +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | +# ============================================================================== |
| 15 | + |
| 16 | +import tensorflow as tf |
| 17 | +from tensorflow.keras import ( |
| 18 | + activations, |
| 19 | + initializers, |
| 20 | + regularizers, |
| 21 | + constraints, |
| 22 | +) |
| 23 | +from tensorflow.keras import backend as K |
| 24 | +from tensorflow.keras.layers import InputSpec |
| 25 | +from typeguard import typechecked |
| 26 | + |
| 27 | +from tensorflow_addons.utils import types |
| 28 | + |
| 29 | + |
| 30 | +def _scale_noise(x): |
| 31 | + return tf.sign(x) * tf.sqrt(tf.abs(x)) |
| 32 | + |
| 33 | + |
| 34 | +@tf.keras.utils.register_keras_serializable(package="Addons") |
| 35 | +class NoisyDense(tf.keras.layers.Dense): |
| 36 | + r"""Noisy dense layer that injects random noise to the weights of dense layer. |
| 37 | +
|
| 38 | + Noisy dense layers are fully connected layers whose weights and biases are |
| 39 | + augmented by factorised Gaussian noise. The factorised Gaussian noise is |
| 40 | + controlled through gradient descent by a second weights layer. |
| 41 | +
|
| 42 | + A `NoisyDense` layer implements the operation: |
| 43 | + $$ |
| 44 | + \mathrm{NoisyDense}(x) = |
| 45 | + \mathrm{activation}(\mathrm{dot}(x, \mu + (\sigma \cdot \epsilon)) |
| 46 | + + \mathrm{bias}) |
| 47 | + $$ |
| 48 | + where $\mu$ is the standard weights layer, $\epsilon$ is the factorised |
| 49 | + Gaussian noise, and $\sigma$ is a second weights layer which controls |
| 50 | + $\epsilon$. |
| 51 | +
|
| 52 | + Note: bias only added if `use_bias` is `True`. |
| 53 | +
|
| 54 | + Example: |
| 55 | +
|
| 56 | + >>> # Create a `Sequential` model and add a NoisyDense |
| 57 | + >>> # layer as the first layer. |
| 58 | + >>> model = tf.keras.models.Sequential() |
| 59 | + >>> model.add(tf.keras.Input(shape=(16,))) |
| 60 | + >>> model.add(NoisyDense(32, activation='relu')) |
| 61 | + >>> # Now the model will take as input arrays of shape (None, 16) |
| 62 | + >>> # and output arrays of shape (None, 32). |
| 63 | + >>> # Note that after the first layer, you don't need to specify |
| 64 | + >>> # the size of the input anymore: |
| 65 | + >>> model.add(NoisyDense(32)) |
| 66 | + >>> model.output_shape |
| 67 | + (None, 32) |
| 68 | +
|
| 69 | + Arguments: |
| 70 | + units: Positive integer, dimensionality of the output space. |
| 71 | + sigma: A float between 0-1 used as a standard deviation figure and is |
| 72 | + applied to the gaussian noise layer (`sigma_kernel` and `sigma_bias`). |
| 73 | + activation: Activation function to use. |
| 74 | + If you don't specify anything, no activation is applied |
| 75 | + (ie. "linear" activation: `a(x) = x`). |
| 76 | + use_bias: Boolean, whether the layer uses a bias vector. |
| 77 | + kernel_regularizer: Regularizer function applied to |
| 78 | + the `kernel` weights matrix. |
| 79 | + bias_regularizer: Regularizer function applied to the bias vector. |
| 80 | + activity_regularizer: Regularizer function applied to |
| 81 | + the output of the layer (its "activation"). |
| 82 | + kernel_constraint: Constraint function applied to |
| 83 | + the `kernel` weights matrix. |
| 84 | + bias_constraint: Constraint function applied to the bias vector. |
| 85 | +
|
| 86 | + Input shape: |
| 87 | + N-D tensor with shape: `(batch_size, ..., input_dim)`. |
| 88 | + The most common situation would be |
| 89 | + a 2D input with shape `(batch_size, input_dim)`. |
| 90 | +
|
| 91 | + Output shape: |
| 92 | + N-D tensor with shape: `(batch_size, ..., units)`. |
| 93 | + For instance, for a 2D input with shape `(batch_size, input_dim)`, |
| 94 | + the output would have shape `(batch_size, units)`. |
| 95 | +
|
| 96 | + References: |
| 97 | + - [Noisy Networks for Explanation](https://arxiv.org/pdf/1706.10295.pdf) |
| 98 | + """ |
| 99 | + |
| 100 | + @typechecked |
| 101 | + def __init__( |
| 102 | + self, |
| 103 | + units: int, |
| 104 | + sigma: float = 0.5, |
| 105 | + activation: types.Activation = None, |
| 106 | + use_bias: bool = True, |
| 107 | + kernel_regularizer: types.Regularizer = None, |
| 108 | + bias_regularizer: types.Regularizer = None, |
| 109 | + activity_regularizer: types.Regularizer = None, |
| 110 | + kernel_constraint: types.Constraint = None, |
| 111 | + bias_constraint: types.Constraint = None, |
| 112 | + **kwargs |
| 113 | + ): |
| 114 | + super().__init__( |
| 115 | + units=units, |
| 116 | + activation=activation, |
| 117 | + use_bias=use_bias, |
| 118 | + kernel_regularizer=kernel_regularizer, |
| 119 | + bias_regularizer=bias_regularizer, |
| 120 | + activity_regularizer=activity_regularizer, |
| 121 | + kernel_constraint=kernel_constraint, |
| 122 | + bias_constraint=bias_constraint, |
| 123 | + **kwargs, |
| 124 | + ) |
| 125 | + delattr(self, "kernel_initializer") |
| 126 | + delattr(self, "bias_initializer") |
| 127 | + self.sigma = sigma |
| 128 | + |
| 129 | + def build(self, input_shape): |
| 130 | + # Make sure dtype is correct |
| 131 | + dtype = tf.dtypes.as_dtype(self.dtype or K.floatx()) |
| 132 | + if not (dtype.is_floating or dtype.is_complex): |
| 133 | + raise TypeError( |
| 134 | + "Unable to build `Dense` layer with non-floating point " |
| 135 | + "dtype %s" % (dtype,) |
| 136 | + ) |
| 137 | + |
| 138 | + input_shape = tf.TensorShape(input_shape) |
| 139 | + self.last_dim = tf.compat.dimension_value(input_shape[-1]) |
| 140 | + sqrt_dim = self.last_dim ** (1 / 2) |
| 141 | + if self.last_dim is None: |
| 142 | + raise ValueError( |
| 143 | + "The last dimension of the inputs to `Dense` " |
| 144 | + "should be defined. Found `None`." |
| 145 | + ) |
| 146 | + self.input_spec = InputSpec(min_ndim=2, axes={-1: self.last_dim}) |
| 147 | + |
| 148 | + sigma_init = initializers.Constant(value=self.sigma / sqrt_dim) |
| 149 | + mu_init = initializers.RandomUniform(minval=-1 / sqrt_dim, maxval=1 / sqrt_dim) |
| 150 | + |
| 151 | + # Learnable parameters |
| 152 | + self.sigma_kernel = self.add_weight( |
| 153 | + "sigma_kernel", |
| 154 | + shape=[self.last_dim, self.units], |
| 155 | + initializer=sigma_init, |
| 156 | + regularizer=self.kernel_regularizer, |
| 157 | + constraint=self.kernel_constraint, |
| 158 | + dtype=self.dtype, |
| 159 | + trainable=True, |
| 160 | + ) |
| 161 | + |
| 162 | + self.mu_kernel = self.add_weight( |
| 163 | + "mu_kernel", |
| 164 | + shape=[self.last_dim, self.units], |
| 165 | + initializer=mu_init, |
| 166 | + regularizer=self.kernel_regularizer, |
| 167 | + constraint=self.kernel_constraint, |
| 168 | + dtype=self.dtype, |
| 169 | + trainable=True, |
| 170 | + ) |
| 171 | + |
| 172 | + if self.use_bias: |
| 173 | + self.sigma_bias = self.add_weight( |
| 174 | + "sigma_bias", |
| 175 | + shape=[ |
| 176 | + self.units, |
| 177 | + ], |
| 178 | + initializer=sigma_init, |
| 179 | + regularizer=self.bias_regularizer, |
| 180 | + constraint=self.bias_constraint, |
| 181 | + dtype=self.dtype, |
| 182 | + trainable=True, |
| 183 | + ) |
| 184 | + |
| 185 | + self.mu_bias = self.add_weight( |
| 186 | + "mu_bias", |
| 187 | + shape=[ |
| 188 | + self.units, |
| 189 | + ], |
| 190 | + initializer=mu_init, |
| 191 | + regularizer=self.bias_regularizer, |
| 192 | + constraint=self.bias_constraint, |
| 193 | + dtype=self.dtype, |
| 194 | + trainable=True, |
| 195 | + ) |
| 196 | + else: |
| 197 | + self.sigma_bias = None |
| 198 | + self.mu_bias = None |
| 199 | + self._reset_noise() |
| 200 | + self.built = True |
| 201 | + |
| 202 | + @property |
| 203 | + def kernel(self): |
| 204 | + return self.mu_kernel + (self.sigma_kernel * self.eps_kernel) |
| 205 | + |
| 206 | + @property |
| 207 | + def bias(self): |
| 208 | + if self.use_bias: |
| 209 | + return self.mu_bias + (self.sigma_bias * self.eps_bias) |
| 210 | + |
| 211 | + def _reset_noise(self): |
| 212 | + """Create the factorised Gaussian noise.""" |
| 213 | + |
| 214 | + dtype = self._compute_dtype_object |
| 215 | + |
| 216 | + # Generate random noise |
| 217 | + eps_i = tf.random.normal([self.last_dim, self.units], dtype=dtype) |
| 218 | + eps_j = tf.random.normal( |
| 219 | + [ |
| 220 | + self.units, |
| 221 | + ], |
| 222 | + dtype=dtype, |
| 223 | + ) |
| 224 | + |
| 225 | + # Scale the random noise |
| 226 | + self.eps_kernel = _scale_noise(eps_i) * _scale_noise(eps_j) |
| 227 | + self.eps_bias = _scale_noise(eps_j) |
| 228 | + |
| 229 | + def _remove_noise(self): |
| 230 | + """Remove the factorised Gaussian noise.""" |
| 231 | + |
| 232 | + dtype = self._compute_dtype_object |
| 233 | + self.eps_kernel = tf.zeros([self.last_dim, self.units], dtype=dtype) |
| 234 | + self.eps_bias = tf.zeros([self.units], dtype=dtype) |
| 235 | + |
| 236 | + def call(self, inputs, reset_noise=True, remove_noise=False): |
| 237 | + # Generate fixed parameters added as the noise |
| 238 | + if remove_noise: |
| 239 | + self._remove_noise() |
| 240 | + elif reset_noise: |
| 241 | + self._reset_noise() |
| 242 | + |
| 243 | + # TODO(WindQAQ): Replace this with `dense()` once public. |
| 244 | + return super().call(inputs) |
| 245 | + |
| 246 | + def get_config(self): |
| 247 | + # TODO(WindQAQ): Get rid of this hacky way. |
| 248 | + config = super(tf.keras.layers.Dense, self).get_config() |
| 249 | + config.update( |
| 250 | + { |
| 251 | + "units": self.units, |
| 252 | + "sigma": self.sigma, |
| 253 | + "activation": activations.serialize(self.activation), |
| 254 | + "use_bias": self.use_bias, |
| 255 | + "kernel_regularizer": regularizers.serialize(self.kernel_regularizer), |
| 256 | + "bias_regularizer": regularizers.serialize(self.bias_regularizer), |
| 257 | + "activity_regularizer": regularizers.serialize( |
| 258 | + self.activity_regularizer |
| 259 | + ), |
| 260 | + "kernel_constraint": constraints.serialize(self.kernel_constraint), |
| 261 | + "bias_constraint": constraints.serialize(self.bias_constraint), |
| 262 | + } |
| 263 | + ) |
| 264 | + return config |
0 commit comments