55 changes: 45 additions & 10 deletions tensorflow_addons/layers/wrappers.py
@@ -16,6 +16,8 @@
from __future__ import division
from __future__ import print_function

import logging

import tensorflow as tf


@@ -61,6 +63,11 @@ def __init__(self, layer, data_init=True, **kwargs):
self._init_critical_section = tf.CriticalSection(name='init_mutex')
self.is_rnn = isinstance(self.layer, tf.keras.layers.RNN)

        if self.data_init and self.is_rnn:
            logging.warning(
                "WeightNormalization: Using `data_init=True` with RNNs "
                "is advised against by the paper. Use `data_init=False`.")

def build(self, input_shape):
"""Build `Layer`"""
input_shape = tf.TensorShape(input_shape)
@@ -76,17 +83,22 @@ def build(self, input_shape):
raise ValueError('`WeightNormalization` must wrap a layer that'
' contains a `kernel` for weights')

if self.is_rnn:
kernel = kernel_layer.recurrent_kernel
else:
kernel = kernel_layer.kernel

        # The kernel's filter or unit dimension is the last axis (-1)
-        self.layer_depth = int(kernel_layer.kernel.shape[-1])
-        self.kernel_norm_axes = list(range(kernel_layer.kernel.shape.rank - 1))
+        self.layer_depth = int(kernel.shape[-1])
+        self.kernel_norm_axes = list(range(kernel.shape.rank - 1))
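
A quick shape sketch of what this picks up (illustrative, outside the diff): for a fused-gate RNN cell the recurrent kernel's unit axis spans all gates, so `layer_depth` is a multiple of the cell's units.

    import tensorflow as tf

    dense = tf.keras.layers.Dense(4)
    dense.build([None, 3])
    print(dense.kernel.shape)  # (3, 4): layer_depth 4, kernel_norm_axes [0]

    lstm = tf.keras.layers.LSTM(4)
    lstm.build([None, 5, 3])
    # (4, 16): the four LSTM gates are fused, so layer_depth is 16
    print(lstm.cell.recurrent_kernel.shape)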

self.g = self.add_weight(
name='g',
shape=(self.layer_depth,),
initializer='ones',
-            dtype=kernel_layer.kernel.dtype,
+            dtype=kernel.dtype,
trainable=True)
-        self.v = kernel_layer.kernel
+        self.v = kernel
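
For reference, `g` and `v` implement the reparameterization from the weight normalization paper (Salimans & Kingma, 2016):

    w = \frac{g}{\lVert v \rVert} \, v

where the norm is taken over `kernel_norm_axes` (every axis except the last), so each output unit keeps its own learned scale `g[i]` while `v` carries the direction.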

self._initialized = self.add_weight(
name='initialized',
@@ -104,9 +116,7 @@ def build(self, input_shape):
layer_config)
self._naked_clone_layer.build(input_shape)
self._naked_clone_layer.set_weights(self.layer.get_weights())
-            if self.is_rnn:
-                self._naked_clone_layer.cell.activation = None
-            else:
+            if not self.is_rnn:
self._naked_clone_layer.activation = None

self.built = True
@@ -127,11 +137,16 @@ def _update_weights():

with tf.name_scope('compute_weights'):
# Replace kernel by normalized weight variable.
-            self.layer.kernel = tf.nn.l2_normalize(
-                self.v, axis=self.kernel_norm_axes) * g
+            kernel = tf.nn.l2_normalize(self.v, axis=self.kernel_norm_axes) * g

if self.is_rnn:
self.layer.cell.recurrent_kernel = kernel
update_kernel = tf.identity(self.layer.cell.recurrent_kernel)
else:
self.layer.kernel = kernel
update_kernel = tf.identity(self.layer.kernel)

# Ensure we calculate result after updating kernel.
-            update_kernel = tf.identity(self.layer.kernel)
with tf.control_dependencies([update_kernel]):
outputs = self.layer(inputs)
return outputs
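
A tiny self-contained check of the math in `compute_weights` (illustrative, not from the diff): after `l2_normalize` over the non-unit axes, scaling by `g` gives each unit's kernel column a norm of exactly `g[i]`.

    import tensorflow as tf

    v = tf.random.normal([3, 4])            # stand-in for the wrapped kernel
    g = tf.constant([1.0, 2.0, 3.0, 4.0])   # one scale per output unit
    kernel = tf.nn.l2_normalize(v, axis=[0]) * g
    print(tf.norm(kernel, axis=0))          # ~[1.0, 2.0, 3.0, 4.0]
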
@@ -176,6 +191,14 @@ def _data_dep_init(self, inputs):
m_init, v_init = tf.nn.moments(x_init, data_norm_axes)
scale_init = 1. / tf.math.sqrt(v_init + 1e-10)

            # RNN cells fuse their per-gate kernels into one tiled kernel,
            # so repeat scale_init to match the fused kernel's shape.
            # Note: this is only to keep the op well-defined; the paper
            # advises against combining RNNs with data-dependent init.
if scale_init.shape[0] != self.g.shape[0]:
rep = int(self.g.shape[0] / scale_init.shape[0])
scale_init = tf.tile(scale_init, [rep])
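
To see why the tile is needed (numbers illustrative, not from the diff): an LSTM fuses four gates along the unit axis of its recurrent kernel, so `g` has `4 * units` entries while the data-dependent statistics have only `units`.

    import tensorflow as tf

    units = 3
    scale_init = tf.constant([0.5, 1.0, 2.0])  # per-unit stats, shape (3,)
    g_len = 4 * units                          # fused i/f/c/o gates, shape (12,)
    rep = g_len // int(scale_init.shape[0])    # 4
    scale_init = tf.tile(scale_init, [rep])    # shape (12,)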

# Assign data dependent init values
g_tensor = self.g.assign(self.g * scale_init)
if hasattr(self.layer, 'bias') and self.layer.bias is not None:
@@ -188,3 +211,15 @@ def get_config(self):
config = {'data_init': self.data_init}
base_config = super(WeightNormalization, self).get_config()
return dict(list(base_config.items()) + list(config.items()))

def remove(self):
kernel = tf.Variable(
tf.nn.l2_normalize(self.v, axis=self.kernel_norm_axes) * self.g,
name='recurrent_kernel' if self.is_rnn else 'kernel')

if self.is_rnn:
self.layer.cell.recurrent_kernel = kernel
else:
self.layer.kernel = kernel

return self.layer
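
A sketch of the intended round trip with the new `remove()` (layer choice and shapes are illustrative): train with the wrapper, then fold the normalized weights back into the plain layer for inference.

    import numpy as np
    import tensorflow as tf
    from tensorflow_addons.layers import WeightNormalization

    wn = WeightNormalization(tf.keras.layers.Dense(2), data_init=False)
    x = np.ones((4, 3), np.float32)
    y_wrapped = wn(x)      # first call builds the wrapper and initializes g
    plain = wn.remove()    # returns the inner Dense with a folded kernel
    y_plain = plain(x)     # should match y_wrapped (cf. test_removal below)
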
108 changes: 89 additions & 19 deletions tensorflow_addons/layers/wrappers_test.py
@@ -17,6 +17,8 @@
from __future__ import division
from __future__ import print_function

from absl.testing import parameterized

import numpy as np
import tensorflow as tf

@@ -25,16 +27,16 @@


@test_utils.run_all_in_graph_and_eager_modes
-class WeightNormalizationTest(tf.test.TestCase):
-    def test_weightnorm(self):
+class WeightNormalizationTest(tf.test.TestCase, parameterized.TestCase):
+    def test_basic(self):
test_utils.layer_test(
wrappers.WeightNormalization,
kwargs={
'layer': tf.keras.layers.Conv2D(5, (2, 2)),
},
input_shape=(2, 4, 4, 3))

-    def test_weightnorm_no_bias(self):
+    def test_no_bias(self):
test_utils.layer_test(
wrappers.WeightNormalization,
kwargs={
Expand All @@ -57,53 +59,121 @@ def _check_data_init(self, data_init, input_data, expected_output):
input_data=input_data,
expected_output=expected_output)

-    def test_weightnorm_with_data_init_is_false(self):
+    def test_with_data_init_is_false(self):
input_data = np.array([[[-4, -4], [4, 4]]], dtype=np.float32)
self._check_data_init(
data_init=False, input_data=input_data, expected_output=input_data)

-    def test_weightnorm_with_data_init_is_true(self):
+    def test_with_data_init_is_true(self):
input_data = np.array([[[-4, -4], [4, 4]]], dtype=np.float32)
self._check_data_init(
data_init=True,
input_data=input_data,
expected_output=input_data / 4)

-    def test_weightnorm_non_layer(self):
+    def test_non_layer(self):
images = tf.random.uniform((2, 4, 43))
with self.assertRaises(AssertionError):
wrappers.WeightNormalization(images)

-    def test_weightnorm_non_kernel_layer(self):
+    def test_non_kernel_layer(self):
images = tf.random.uniform((2, 2, 2))
        with self.assertRaisesRegex(ValueError, 'contains a `kernel`'):
non_kernel_layer = tf.keras.layers.MaxPooling2D(2, 2)
wn_wrapper = wrappers.WeightNormalization(non_kernel_layer)
wn_wrapper(images)

-    def test_weightnorm_with_time_dist(self):
+    def test_with_time_dist(self):
batch_shape = (32, 16, 64, 64, 3)
inputs = tf.keras.layers.Input(batch_shape=batch_shape)
a = tf.keras.layers.Conv2D(3, 5)
b = wrappers.WeightNormalization(a)
out = tf.keras.layers.TimeDistributed(b)(inputs)
model = tf.keras.Model(inputs, out)

-    def test_weightnorm_with_rnn(self):
-        inputs = tf.keras.layers.Input(shape=(None, 3))
-        rnn_layer = tf.keras.layers.SimpleRNN(4)
-        wt_rnn = wrappers.WeightNormalization(rnn_layer)
-        dense = tf.keras.layers.Dense(1)
-        model = tf.keras.models.Sequential(layers=[inputs, wt_rnn, dense])

-    def test_save_file_h5(self):
@parameterized.named_parameters(
["Dense", lambda: tf.keras.layers.Dense(1), False],
["SimpleRNN", lambda: tf.keras.layers.SimpleRNN(1), True],
["Conv2D", lambda: tf.keras.layers.Conv2D(3, 1), False],
["LSTM", lambda: tf.keras.layers.LSTM(1), True])
def test_serialization(self, base_layer, rnn):
base_layer = base_layer()
wn_layer = wrappers.WeightNormalization(base_layer, not rnn)
new_wn_layer = tf.keras.layers.deserialize(
tf.keras.layers.serialize(wn_layer))
self.assertEqual(wn_layer.data_init, new_wn_layer.data_init)
self.assertEqual(wn_layer.is_rnn, new_wn_layer.is_rnn)
self.assertEqual(wn_layer.is_rnn, rnn)
if not isinstance(base_layer, tf.keras.layers.LSTM):
# Issue with LSTM serialization, check with TF-core
# Before serialization: tensorflow.python.keras.layers.recurrent_v2.LSTM
# After serialization: tensorflow.python.keras.layers.recurrent.LSTM
            self.assertIsInstance(new_wn_layer.layer, base_layer.__class__)

@parameterized.named_parameters(
["Dense", lambda: tf.keras.layers.Dense(1), [25]],
["SimpleRNN", lambda: tf.keras.layers.SimpleRNN(1), [None, 10]],
["Conv2D", lambda: tf.keras.layers.Conv2D(3, 1), [3, 3, 1]],
["LSTM", lambda: tf.keras.layers.LSTM(1), [10, 10]])
def test_model_build(self, base_layer_fn, input_shape):
inputs = tf.keras.layers.Input(shape=input_shape)
for data_init in [True, False]:
base_layer = base_layer_fn()
wt_layer = wrappers.WeightNormalization(base_layer, data_init)
model = tf.keras.models.Sequential(layers=[inputs, wt_layer])
model.build()

@parameterized.named_parameters(
["Dense", lambda: tf.keras.layers.Dense(1), [25]],
["SimpleRNN", lambda: tf.keras.layers.SimpleRNN(1), [10, 10]],
["Conv2D", lambda: tf.keras.layers.Conv2D(3, 1), [3, 3, 1]],
["LSTM", lambda: tf.keras.layers.LSTM(1), [10, 10]])
def test_save_file_h5(self, base_layer, input_shape):
        filepath = self.create_tempfile('wrapper_test_model.h5').full_path
-        conv = tf.keras.layers.Conv1D(1, 1)
-        wn_conv = wrappers.WeightNormalization(conv)
+        base_layer = base_layer()
+        wn_conv = wrappers.WeightNormalization(base_layer)
        model = tf.keras.Sequential(layers=[wn_conv])
-        model.build([1, 2, 3])
+        model.build([None] + input_shape)
        model.save_weights(filepath)

@parameterized.named_parameters(
["Dense", lambda: tf.keras.layers.Dense(1), [25]],
["SimpleRNN", lambda: tf.keras.layers.SimpleRNN(1), [10, 10]],
["Conv2D", lambda: tf.keras.layers.Conv2D(3, 1), [3, 3, 1]],
["LSTM", lambda: tf.keras.layers.LSTM(1), [10, 10]])
def test_forward_pass(self, base_layer, input_shape):
sample_data = np.ones([1] + input_shape, dtype=np.float32)
base_layer = base_layer()
base_output = base_layer(sample_data)
wn_layer = wrappers.WeightNormalization(base_layer, False)
wn_output = wn_layer(sample_data)
self.evaluate(tf.compat.v1.global_variables_initializer())
self.assertAllClose(
self.evaluate(base_output), self.evaluate(wn_output))

@parameterized.named_parameters(
["Dense", lambda: tf.keras.layers.Dense(1), [25]],
["SimpleRNN", lambda: tf.keras.layers.SimpleRNN(1), [10, 10]],
["Conv2D", lambda: tf.keras.layers.Conv2D(3, 1), [3, 3, 1]],
["LSTM", lambda: tf.keras.layers.LSTM(1), [10, 10]])
def test_removal(self, base_layer_fn, input_shape):
sample_data = np.ones([1] + input_shape, dtype=np.float32)

for data_init in [True, False]:
base_layer = base_layer_fn()
wn_layer = wrappers.WeightNormalization(base_layer, data_init)
wn_output = wn_layer(sample_data)
self.evaluate(tf.compat.v1.global_variables_initializer())
with tf.control_dependencies([wn_output]):
wn_removed_layer = wn_layer.remove()
wn_removed_output = wn_removed_layer(sample_data)

self.evaluate(tf.compat.v1.global_variables_initializer())
self.assertAllClose(
self.evaluate(wn_removed_output), self.evaluate(wn_output))
            self.assertIsInstance(wn_removed_layer, base_layer.__class__)


if __name__ == "__main__":
tf.test.main()