Skip to content
81 changes: 81 additions & 0 deletions tensorflow_addons/layers/normalizations_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ def _test_specific_layer(self, inputs, axis, groups, center, scale):

def _create_and_fit_Sequential_model(self, layer, shape):
# Helperfunction for quick evaluation
np.random.seed(0x2020)
model = tf.keras.models.Sequential()
model.add(layer)
model.add(tf.keras.layers.Dense(32))
Expand Down Expand Up @@ -233,6 +234,7 @@ def test_regularizations(self):
def test_groupnorm_conv(self):
# Check if Axis is working for CONV nets
# Testing for 1 == LayerNorm, 5 == GroupNorm, -1 == InstanceNorm
np.random.seed(0x2020)
groups = [-1, 5, 1]
for i in groups:
model = tf.keras.models.Sequential()
Expand All @@ -246,6 +248,85 @@ def test_groupnorm_conv(self):
model.fit(x=x, y=y, epochs=1)
self.assertTrue(hasattr(model.layers[0], "gamma"))

def test_groupnorm_correctness_1d(self):
np.random.seed(0x2020)
model = tf.keras.models.Sequential()
norm = GroupNormalization(input_shape=(10,), groups=2)
model.add(norm)
model.compile(loss="mse", optimizer="rmsprop")

x = np.random.normal(loc=5.0, scale=10.0, size=(1000, 10))
model.fit(x, x, epochs=5, verbose=0)
out = model.predict(x)
out -= self.evaluate(norm.beta)
out /= self.evaluate(norm.gamma)

self.assertAllClose(out.mean(), 0.0, atol=1e-1)
self.assertAllClose(out.std(), 1.0, atol=1e-1)

def test_groupnorm_2d_different_groups(self):
np.random.seed(0x2020)
groups = [2, 1, 10]
for i in groups:
model = tf.keras.models.Sequential()
norm = GroupNormalization(axis=1, groups=i, input_shape=(10, 3))
model.add(norm)
# centered and variance are 5.0 and 10.0, respectively
model.compile(loss="mse", optimizer="rmsprop")
x = np.random.normal(loc=5.0, scale=10.0, size=(1000, 10, 3))
model.fit(x, x, epochs=5, verbose=0)
out = model.predict(x)
out -= np.reshape(self.evaluate(norm.beta), (1, 10, 1))
out /= np.reshape(self.evaluate(norm.gamma), (1, 10, 1))

self.assertAllClose(
out.mean(axis=(0, 1), dtype=np.float32), (0.0, 0.0, 0.0), atol=1e-1
)
self.assertAllClose(
out.std(axis=(0, 1), dtype=np.float32), (1.0, 1.0, 1.0), atol=1e-1
)

def test_groupnorm_convnet(self):
np.random.seed(0x2020)
model = tf.keras.models.Sequential()
norm = GroupNormalization(axis=1, input_shape=(3, 4, 4), groups=3)
model.add(norm)
model.compile(loss="mse", optimizer="sgd")

# centered = 5.0, variance = 10.0
x = np.random.normal(loc=5.0, scale=10.0, size=(1000, 3, 4, 4))
model.fit(x, x, epochs=4, verbose=0)
out = model.predict(x)
out -= np.reshape(self.evaluate(norm.beta), (1, 3, 1, 1))
out /= np.reshape(self.evaluate(norm.gamma), (1, 3, 1, 1))

self.assertAllClose(
np.mean(out, axis=(0, 2, 3), dtype=np.float32), (0.0, 0.0, 0.0), atol=1e-1
)
self.assertAllClose(
np.std(out, axis=(0, 2, 3), dtype=np.float32), (1.0, 1.0, 1.0), atol=1e-1
)

def test_groupnorm_convnet_no_center_no_scale(self):
np.random.seed(0x2020)
model = tf.keras.models.Sequential()
norm = GroupNormalization(
axis=-1, groups=2, center=False, scale=False, input_shape=(3, 4, 4)
)
model.add(norm)
model.compile(loss="mse", optimizer="sgd")
# centered and variance are 5.0 and 10.0, respectively
x = np.random.normal(loc=5.0, scale=10.0, size=(1000, 3, 4, 4))
model.fit(x, x, epochs=4, verbose=0)
out = model.predict(x)

self.assertAllClose(
np.mean(out, axis=(0, 2, 3), dtype=np.float32), (0.0, 0.0, 0.0), atol=1e-1
)
self.assertAllClose(
np.std(out, axis=(0, 2, 3), dtype=np.float32), (1.0, 1.0, 1.0), atol=1e-1
)


if __name__ == "__main__":
tf.test.main()