tensorflow_addons/layers/__init__.py  (2 changes: 1 addition & 1 deletion)

@@ -20,4 +20,4 @@
 from __future__ import print_function
 
 # Weight Normalization Wrapper
-from tensorflow_addons.layers.python.wrappers import WeightNorm
+from tensorflow_addons.layers.python.wrappers import WeightNormalization
tensorflow_addons/layers/python/wrappers.py  (20 changes: 10 additions & 10 deletions)

@@ -26,22 +26,22 @@
 from tensorflow.python.ops import variables as tf_variables
 
 
-class WeightNorm(Wrapper):
+class WeightNormalization(Wrapper):
     """ This wrapper reparameterizes a layer by decoupling the weight's
     magnitude and direction. This speeds up convergence by improving the
     conditioning of the optimization problem.
     Weight Normalization: A Simple Reparameterization to Accelerate
     Training of Deep Neural Networks: https://arxiv.org/abs/1602.07868
     Tim Salimans, Diederik P. Kingma (2016)
-    WeightNorm wrapper works for keras and tf layers.
+    WeightNormalization wrapper works for keras and tf layers.
     ```python
-    net = WeightNorm(tf.keras.layers.Conv2D(2, 2, activation='relu'),
+    net = WeightNormalization(tf.keras.layers.Conv2D(2, 2, activation='relu'),
                      input_shape=(32, 32, 3), data_init=True)(x)
-    net = WeightNorm(tf.keras.layers.Conv2D(16, 5, activation='relu'),
+    net = WeightNormalization(tf.keras.layers.Conv2D(16, 5, activation='relu'),
                      data_init=True)(net)
-    net = WeightNorm(tf.keras.layers.Dense(120, activation='relu'),
+    net = WeightNormalization(tf.keras.layers.Dense(120, activation='relu'),
                      data_init=True)(net)
-    net = WeightNorm(tf.keras.layers.Dense(n_classes),
+    net = WeightNormalization(tf.keras.layers.Dense(n_classes),
                      data_init=True)(net)
     ```
     Arguments:
@@ -55,7 +55,7 @@ class WeightNorm(Wrapper):
     def __init__(self, layer, data_init=False, **kwargs):
         if not isinstance(layer, Layer):
             raise ValueError(
-                'Please initialize `WeightNorm` layer with a '
+                'Please initialize `WeightNormalization` layer with a '
                 '`Layer` instance. You passed: {input}'.format(input=layer))
 
         if not context.executing_eagerly() and data_init:
@@ -67,7 +67,7 @@ def __init__(self, layer, data_init=False, **kwargs):
         if data_init:
             self.initialized = False
 
-        super(WeightNorm, self).__init__(layer, **kwargs)
+        super(WeightNormalization, self).__init__(layer, **kwargs)
         self._track_checkpointable(layer, name='layer')
 
     def _compute_weights(self):
@@ -114,7 +114,7 @@ def build(self, input_shape):
 
         if not hasattr(self.layer, 'kernel'):
             raise ValueError(
-                '`WeightNorm` must wrap a layer that'
+                '`WeightNormalization` must wrap a layer that'
                 ' contains a `kernel` for weights'
             )
 
@@ -137,7 +137,7 @@ def build(self, input_shape):
 
         self.layer.built = True
 
-        super(WeightNorm, self).build()
+        super(WeightNormalization, self).build()
         self.built = True
 
     def call(self, inputs):
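For reference, the reparameterization that `_compute_weights` applies is the one from the paper cited in the docstring: each kernel is written as `w = g * v / ||v||`, where `v` carries the direction and `g` (the attribute the tests check for on the wrapped layer) carries the per-output-unit magnitude. A minimal NumPy sketch of that formula, assuming a Dense-style kernel normalized over its input axis; `weight_norm` here is an illustrative helper, not the wrapper's API:

```python
import numpy as np

def weight_norm(v, g):
    # w = g * v / ||v||: rescale each output unit's weight column so its
    # L2 norm equals g. Taking the norm over axis 0 (the input dimension)
    # is an assumption appropriate for a Dense kernel of shape (in, out).
    return g * v / np.linalg.norm(v, axis=0)

v = np.random.randn(4, 2)          # direction, same shape as the kernel
g = np.array([1.0, 2.0])           # learned per-output-unit magnitude
w = weight_norm(v, g)
print(np.linalg.norm(w, axis=0))   # -> approximately [1. 2.]
```

Because each column of `w` has norm exactly `g`, the optimizer can adjust magnitude and direction independently, which is the improved conditioning the docstring refers to.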
tensorflow_addons/layers/python/wrappers_test.py  (12 changes: 6 additions & 6 deletions)

@@ -29,12 +29,12 @@
 from tensorflow.python import keras
 
 
-class WeightNormTest(test.TestCase):
+class WeightNormalizationTest(test.TestCase):
 
     @tf_test_util.run_all_in_graph_and_eager_modes
     def test_weightnorm_dense_train(self):
         model = keras.models.Sequential()
-        model.add(wrappers.WeightNorm(
+        model.add(wrappers.WeightNormalization(
             keras.layers.Dense(2), input_shape=(3, 4)))
 
         model.compile(optimizer=RMSPropOptimizer(0.01), loss='mse')
@@ -48,7 +48,7 @@ def test_weightnorm_dense_train(self):
     @tf_test_util.run_all_in_graph_and_eager_modes
     def test_weightnorm_conv2d(self):
         model = keras.models.Sequential()
-        model.add(wrappers.WeightNorm(
+        model.add(wrappers.WeightNormalization(
             keras.layers.Conv2D(5, (2, 2), padding='same'),
             input_shape=(4, 4, 3)))
 
@@ -63,7 +63,7 @@ def test_weightnorm_conv2d(self):
     @tf_test_util.run_all_in_graph_and_eager_modes
     def test_weight_norm_tflayers(self):
         images = random_ops.random_uniform((2, 4, 4, 3))
-        wn_wrapper = wrappers.WeightNorm(layers.Conv2D(32, [2, 2]),
+        wn_wrapper = wrappers.WeightNormalization(layers.Conv2D(32, [2, 2]),
                                          input_shape=(4, 4, 3))
         wn_wrapper.apply(images)
         self.assertTrue(hasattr(wn_wrapper.layer, 'g'))
@@ -72,12 +72,12 @@ def test_weight_norm_tflayers(self):
     def test_weight_norm_nonlayer(self):
         images = random_ops.random_uniform((2, 4, 43))
         with self.assertRaises(ValueError):
-            wrappers.WeightNorm(images)
+            wrappers.WeightNormalization(images)
 
     @tf_test_util.run_all_in_graph_and_eager_modes
     def test_weight_norm_nokernel(self):
         with self.assertRaises(ValueError):
-            wrappers.WeightNorm(layers.MaxPooling2D(2, 2)).build((2, 2))
+            wrappers.WeightNormalization(layers.MaxPooling2D(2, 2)).build((2, 2))
 
 
 if __name__ == "__main__":
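Downstream code picks up the rename through the package export updated in `__init__.py`. A minimal before/after sketch, using the Dense size and input shape from the tests as placeholders:

```python
import tensorflow as tf
# Before this change: from tensorflow_addons.layers import WeightNorm
from tensorflow_addons.layers import WeightNormalization

# Wrap a Dense layer under the new name; everything else is unchanged.
inputs = tf.keras.layers.Input(shape=(3, 4))
outputs = WeightNormalization(tf.keras.layers.Dense(2))(inputs)
model = tf.keras.Model(inputs, outputs)
model.compile(optimizer='rmsprop', loss='mse')
```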