Skip to content

Commit 1c3c072

Browse files
authored
Added support for noisy dense layers. (#2099)
* Create noisy_dense.py * Create noisy_dense_test.py * Update __init__.py * Fix minor typo * Update noisy_dense_test.py * Update comments * Update comments * Update noisy_dense.py * fix typo * Update noisy_dense.py * Update noisy_dense_test.py * Fix compliance issues * Fix compliance issues * Update comments * Fix typo * Update CODEOWNERS * Update CODEOWNERS * add use bias to config * Update noisy_dense.py * Update CODEOWNERS * Revert "Update CODEOWNERS" This reverts commit 82e979f. * Update noisy_dense.py * Update noisy_dense.py * Update noisy_dense.py * Update noisy_dense.py * Revert "Update CODEOWNERS" This reverts commit 840ab1c. * Revert "Revert "Update CODEOWNERS"" This reverts commit 7852e62. * Update noisy_dense.py * Code reformatted with updated black * Update noisy_dense.py * Update noisy_dense.py * Update noisy_dense.py * Added support for manual noise reset * support for noise removal * tests for noise removal * use typecheck and remove unicode, * fix typo and code cleanup * control noise removal through call * Inherit from Dense instead of Layer * Added missing comment * Documentation and test improvement * fix typo * minor formatting changes * minor formatting fix Co-authored-by: schaall <[email protected]>
1 parent 8320f16 commit 1c3c072

File tree

4 files changed

+408
-0
lines changed

4 files changed

+408
-0
lines changed

.github/CODEOWNERS

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,8 @@
105105
/tensorflow_addons/layers/tests/esn_test.py @pedrolarben
106106
/tensorflow_addons/layers/snake.py @failure-to-thrive
107107
/tensorflow_addons/layers/tests/snake_test.py @failure-to-thrive
108+
/tensorflow_addons/layers/noisy_dense.py @leonshams
109+
/tensorflow_addons/layers/tests/noisy_dense_test.py @leonshams
108110

109111
/tensorflow_addons/losses/contrastive.py @windqaq
110112
/tensorflow_addons/losses/tests/contrastive_test.py @windqaq

tensorflow_addons/layers/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,3 +38,4 @@
3838
from tensorflow_addons.layers.tlu import TLU
3939
from tensorflow_addons.layers.wrappers import WeightNormalization
4040
from tensorflow_addons.layers.esn import ESN
41+
from tensorflow_addons.layers.noisy_dense import NoisyDense
Lines changed: 264 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,264 @@
1+
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
16+
import tensorflow as tf
17+
from tensorflow.keras import (
18+
activations,
19+
initializers,
20+
regularizers,
21+
constraints,
22+
)
23+
from tensorflow.keras import backend as K
24+
from tensorflow.keras.layers import InputSpec
25+
from typeguard import typechecked
26+
27+
from tensorflow_addons.utils import types
28+
29+
30+
def _scale_noise(x):
    """Return `sign(x) * sqrt(|x|)`, the signed square root of `x`."""
    magnitude = tf.sqrt(tf.abs(x))
    return tf.sign(x) * magnitude
32+
33+
34+
@tf.keras.utils.register_keras_serializable(package="Addons")
class NoisyDense(tf.keras.layers.Dense):
    r"""Noisy dense layer that injects random noise to the weights of dense layer.

    Noisy dense layers are fully connected layers whose weights and biases are
    augmented by factorised Gaussian noise. The factorised Gaussian noise is
    controlled through gradient descent by a second weights layer.

    A `NoisyDense` layer implements the operation:
    $$
    \mathrm{NoisyDense}(x) =
    \mathrm{activation}(\mathrm{dot}(x, \mu + (\sigma \cdot \epsilon))
    + \mathrm{bias})
    $$
    where $\mu$ is the standard weights layer, $\epsilon$ is the factorised
    Gaussian noise, and $\sigma$ is a second weights layer which controls
    $\epsilon$.

    Note: bias only added if `use_bias` is `True`.

    Example:

    >>> # Create a `Sequential` model and add a NoisyDense
    >>> # layer as the first layer.
    >>> model = tf.keras.models.Sequential()
    >>> model.add(tf.keras.Input(shape=(16,)))
    >>> model.add(NoisyDense(32, activation='relu'))
    >>> # Now the model will take as input arrays of shape (None, 16)
    >>> # and output arrays of shape (None, 32).
    >>> # Note that after the first layer, you don't need to specify
    >>> # the size of the input anymore:
    >>> model.add(NoisyDense(32))
    >>> model.output_shape
    (None, 32)

    Arguments:
      units: Positive integer, dimensionality of the output space.
      sigma: A float between 0-1 used as a standard deviation figure and is
        applied to the gaussian noise layer (`sigma_kernel` and `sigma_bias`).
      activation: Activation function to use.
        If you don't specify anything, no activation is applied
        (ie. "linear" activation: `a(x) = x`).
      use_bias: Boolean, whether the layer uses a bias vector.
      kernel_regularizer: Regularizer function applied to
        the `kernel` weights matrix.
      bias_regularizer: Regularizer function applied to the bias vector.
      activity_regularizer: Regularizer function applied to
        the output of the layer (its "activation").
      kernel_constraint: Constraint function applied to
        the `kernel` weights matrix.
      bias_constraint: Constraint function applied to the bias vector.

    Call arguments:
      inputs: Input tensor.
      reset_noise: Boolean, whether to draw fresh noise before this call.
        Ignored when `remove_noise` is `True`. Defaults to `True`.
      remove_noise: Boolean, whether to zero the noise for this call so that
        only the `mu` weights are used. Defaults to `False`.

    Input shape:
      N-D tensor with shape: `(batch_size, ..., input_dim)`.
      The most common situation would be
      a 2D input with shape `(batch_size, input_dim)`.

    Output shape:
      N-D tensor with shape: `(batch_size, ..., units)`.
      For instance, for a 2D input with shape `(batch_size, input_dim)`,
      the output would have shape `(batch_size, units)`.

    References:
      - [Noisy Networks for Exploration](https://arxiv.org/pdf/1706.10295.pdf)
    """

    @typechecked
    def __init__(
        self,
        units: int,
        sigma: float = 0.5,
        activation: types.Activation = None,
        use_bias: bool = True,
        kernel_regularizer: types.Regularizer = None,
        bias_regularizer: types.Regularizer = None,
        activity_regularizer: types.Regularizer = None,
        kernel_constraint: types.Constraint = None,
        bias_constraint: types.Constraint = None,
        **kwargs
    ):
        super().__init__(
            units=units,
            activation=activation,
            use_bias=use_bias,
            kernel_regularizer=kernel_regularizer,
            bias_regularizer=bias_regularizer,
            activity_regularizer=activity_regularizer,
            kernel_constraint=kernel_constraint,
            bias_constraint=bias_constraint,
            **kwargs,
        )
        # The initializers are fixed in `build()` (derived from `sigma` and the
        # input dimension), so the attributes set by `Dense` are dropped.
        delattr(self, "kernel_initializer")
        delattr(self, "bias_initializer")
        self.sigma = sigma

    def build(self, input_shape):
        """Create the `mu`/`sigma` weights and the initial noise.

        Raises:
          TypeError: if the layer dtype is neither floating nor complex.
          ValueError: if the last dimension of `input_shape` is undefined.
        """
        # Make sure dtype is correct
        dtype = tf.dtypes.as_dtype(self.dtype or K.floatx())
        if not (dtype.is_floating or dtype.is_complex):
            raise TypeError(
                "Unable to build `Dense` layer with non-floating point "
                "dtype %s" % (dtype,)
            )

        input_shape = tf.TensorShape(input_shape)
        self.last_dim = tf.compat.dimension_value(input_shape[-1])
        # Validate before any arithmetic on `last_dim`; otherwise `None ** 0.5`
        # raises a TypeError and this ValueError is never reached.
        if self.last_dim is None:
            raise ValueError(
                "The last dimension of the inputs to `Dense` "
                "should be defined. Found `None`."
            )
        sqrt_dim = self.last_dim ** (1 / 2)
        self.input_spec = InputSpec(min_ndim=2, axes={-1: self.last_dim})

        sigma_init = initializers.Constant(value=self.sigma / sqrt_dim)
        mu_init = initializers.RandomUniform(minval=-1 / sqrt_dim, maxval=1 / sqrt_dim)

        # Learnable parameters
        self.sigma_kernel = self.add_weight(
            "sigma_kernel",
            shape=[self.last_dim, self.units],
            initializer=sigma_init,
            regularizer=self.kernel_regularizer,
            constraint=self.kernel_constraint,
            dtype=self.dtype,
            trainable=True,
        )

        self.mu_kernel = self.add_weight(
            "mu_kernel",
            shape=[self.last_dim, self.units],
            initializer=mu_init,
            regularizer=self.kernel_regularizer,
            constraint=self.kernel_constraint,
            dtype=self.dtype,
            trainable=True,
        )

        if self.use_bias:
            self.sigma_bias = self.add_weight(
                "sigma_bias",
                shape=[
                    self.units,
                ],
                initializer=sigma_init,
                regularizer=self.bias_regularizer,
                constraint=self.bias_constraint,
                dtype=self.dtype,
                trainable=True,
            )

            self.mu_bias = self.add_weight(
                "mu_bias",
                shape=[
                    self.units,
                ],
                initializer=mu_init,
                regularizer=self.bias_regularizer,
                constraint=self.bias_constraint,
                dtype=self.dtype,
                trainable=True,
            )
        else:
            self.sigma_bias = None
            self.mu_bias = None
        self._reset_noise()
        self.built = True

    @property
    def kernel(self):
        """Effective kernel: `mu_kernel + sigma_kernel * eps_kernel`."""
        return self.mu_kernel + (self.sigma_kernel * self.eps_kernel)

    @property
    def bias(self):
        """Effective bias `mu_bias + sigma_bias * eps_bias`, or `None` without bias."""
        if self.use_bias:
            return self.mu_bias + (self.sigma_bias * self.eps_bias)

    def _reset_noise(self):
        """Create the factorised Gaussian noise."""

        dtype = self._compute_dtype_object

        # Factorised noise uses one sample per input unit and one per output
        # unit (not one per weight); the kernel noise is their outer product.
        eps_in = _scale_noise(tf.random.normal([self.last_dim, 1], dtype=dtype))
        eps_out = _scale_noise(
            tf.random.normal(
                [
                    self.units,
                ],
                dtype=dtype,
            )
        )

        # Broadcasting [last_dim, 1] * [units] yields [last_dim, units].
        self.eps_kernel = eps_in * eps_out
        self.eps_bias = eps_out

    def _remove_noise(self):
        """Remove the factorised Gaussian noise."""

        dtype = self._compute_dtype_object
        self.eps_kernel = tf.zeros([self.last_dim, self.units], dtype=dtype)
        self.eps_bias = tf.zeros([self.units], dtype=dtype)

    def call(self, inputs, reset_noise=True, remove_noise=False):
        """Apply the layer, optionally resampling or zeroing the noise first.

        `remove_noise` takes precedence over `reset_noise`. When both are
        `False`, the noise from the previous call (or from `build`) is reused.
        """
        # Generate fixed parameters added as the noise
        if remove_noise:
            self._remove_noise()
        elif reset_noise:
            self._reset_noise()

        # TODO(WindQAQ): Replace this with `dense()` once public.
        return super().call(inputs)

    def get_config(self):
        """Return the serializable configuration of the layer."""
        # TODO(WindQAQ): Get rid of this hacky way.
        # Deliberately skips `Dense.get_config()` in the MRO -- presumably
        # because it would serialize the initializer attributes deleted in
        # `__init__`.
        config = super(tf.keras.layers.Dense, self).get_config()
        config.update(
            {
                "units": self.units,
                "sigma": self.sigma,
                "activation": activations.serialize(self.activation),
                "use_bias": self.use_bias,
                "kernel_regularizer": regularizers.serialize(self.kernel_regularizer),
                "bias_regularizer": regularizers.serialize(self.bias_regularizer),
                "activity_regularizer": regularizers.serialize(
                    self.activity_regularizer
                ),
                "kernel_constraint": constraints.serialize(self.kernel_constraint),
                "bias_constraint": constraints.serialize(self.bias_constraint),
            }
        )
        return config

0 commit comments

Comments
 (0)