Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 30 additions & 18 deletions tensorflow_addons/layers/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,32 +92,39 @@ def build(self, input_shape):
trainable=False)

if self.data_init:
self._naked_layer = tf.keras.layers.deserialize(
tf.keras.layers.serialize(self.layer))
self._naked_layer.build(input_shape)
self._naked_layer.set_weights(self.layer.get_weights())
self._naked_layer.activation = None
# Used for data initialization in self._data_dep_init.
layer_config = tf.keras.layers.serialize(self.layer)
layer_config['config']['trainable'] = False
self._naked_clone_layer = tf.keras.layers.deserialize(layer_config)
self._naked_clone_layer.build(input_shape)
self._naked_clone_layer.set_weights(self.layer.get_weights())
self._naked_clone_layer.activation = None

self.built = True

def call(self, inputs):
"""Call `Layer`"""

def _do_nothing():
return inputs
return tf.identity(self.g)

def _update_weights():
self._initialize_weights(inputs)
return inputs
# Ensure we read `self.g` after _update_weights.
with tf.control_dependencies(self._initialize_weights(inputs)):
return tf.identity(self.g)

inputs = tf.cond(self._initialized, _do_nothing, _update_weights)
g = tf.cond(self._initialized, _do_nothing, _update_weights)

with tf.name_scope('compute_weights'):
# Replace kernel by normalized weight variable.
self.layer.kernel = tf.nn.l2_normalize(
self.v, axis=self.kernel_norm_axes) * self.g
self.v, axis=self.kernel_norm_axes) * g

return self.layer(inputs)
# Ensure we calculate result after updating kernel.
update_kernel = tf.identity(self.layer.kernel)
with tf.control_dependencies([update_kernel]):
outputs = self.layer(inputs)
return outputs

def compute_output_shape(self, input_shape):
return tf.TensorShape(
Expand All @@ -136,31 +143,36 @@ def _initialize_weights(self, inputs):
message='The layer has been initialized.')
]):
if self.data_init:
self._data_dep_init(inputs)
assign_tensors = self._data_dep_init(inputs)
else:
self._init_norm()
self._initialized.assign(True)
assign_tensors = self._init_norm()
assign_tensors.append(self._initialized.assign(True))
return assign_tensors

def _init_norm(self):
"""Set the weight g with the norm of the weight vector."""
with tf.name_scope('init_norm'):
v_flat = tf.reshape(self.v, [-1, self.layer_depth])
v_norm = tf.linalg.norm(v_flat, axis=0)
self.g.assign(tf.reshape(v_norm, (self.layer_depth,)))
g_tensor = self.g.assign(tf.reshape(v_norm, (self.layer_depth,)))
return [g_tensor]

def _data_dep_init(self, inputs):
"""Data dependent initialization."""
with tf.name_scope('data_dep_init'):
# Generate data dependent init values
x_init = self._naked_layer(inputs)
x_init = self._naked_clone_layer(inputs)
data_norm_axes = list(range(x_init.shape.rank - 1))
m_init, v_init = tf.nn.moments(x_init, data_norm_axes)
scale_init = 1. / tf.math.sqrt(v_init + 1e-10)

# Assign data dependent init values
self.g.assign(self.g * scale_init)
g_tensor = self.g.assign(self.g * scale_init)
if hasattr(self.layer, 'bias'):
self.layer.bias.assign(-m_init * scale_init)
bias_tensor = self.layer.bias.assign(-m_init * scale_init)
return [g_tensor, bias_tensor]
else:
return [g_tensor]

def get_config(self):
config = {'data_init': self.data_init}
Expand Down
114 changes: 42 additions & 72 deletions tensorflow_addons/layers/wrappers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,82 +26,52 @@

@test_utils.run_all_in_graph_and_eager_modes
class WeightNormalizationTest(tf.test.TestCase):
def test_weightnorm_dense_train(self):
model = tf.keras.models.Sequential()
model.add(
wrappers.WeightNormalization(
tf.keras.layers.Dense(2), input_shape=(3, 4)))
model.compile(
optimizer=tf.keras.optimizers.RMSprop(learning_rate=0.001),
loss='mse')
model.fit(
np.random.random((10, 3, 4)),
np.random.random((10, 3, 2)),
epochs=3,
batch_size=10)
self.assertTrue(hasattr(model.layers[0], 'g'))

def test_weightnorm_dense_train_notinit(self):
model = tf.keras.models.Sequential()
model.add(
wrappers.WeightNormalization(
tf.keras.layers.Dense(2), input_shape=(3, 4), data_init=False))

model.compile(
optimizer=tf.keras.optimizers.RMSprop(learning_rate=0.001),
loss='mse')
model.fit(
np.random.random((10, 3, 4)),
np.random.random((10, 3, 2)),
epochs=3,
batch_size=10)
self.assertTrue(hasattr(model.layers[0], 'g'))

def test_weightnorm_conv2d(self):
model = tf.keras.models.Sequential()
model.add(
wrappers.WeightNormalization(
tf.keras.layers.Conv2D(5, (2, 2), padding='same'),
input_shape=(4, 4, 3)))

model.add(tf.keras.layers.Activation('relu'))
model.compile(
optimizer=tf.keras.optimizers.RMSprop(learning_rate=0.001),
loss='mse')
model.fit(
np.random.random((2, 4, 4, 3)),
np.random.random((2, 4, 4, 5)),
epochs=3,
batch_size=10)

self.assertTrue(hasattr(model.layers[0], 'g'))

def test_weightnorm_applylayer(self):
images = tf.random.uniform((2, 4, 4, 3))
wn_wrapper = wrappers.WeightNormalization(
tf.keras.layers.Conv2D(32, [2, 2]), input_shape=(4, 4, 3))
wn_wrapper.apply(images)
self.assertTrue(hasattr(wn_wrapper, 'g'))

def test_weightnorm_nonlayer(self):
images = tf.random.uniform((2, 4, 43))
with self.assertRaises(AssertionError):
wrappers.WeightNormalization(images)

def test_weightnorm_nokernel(self):
with self.assertRaises(ValueError):
wrappers.WeightNormalization(tf.keras.layers.MaxPooling2D(
2, 2)).build((2, 2))

def test_weightnorm_keras(self):
input_data = np.random.random((10, 3, 4)).astype(np.float32)
def test_weightnorm(self):
test_utils.layer_test(
wrappers.WeightNormalization,
kwargs={
'layer': tf.keras.layers.Conv2D(5, (2, 2)),
},
input_shape=(2, 4, 4, 3))

def _check_data_init(self, data_init, input_data, expected_output):
layer = tf.keras.layers.Dense(
input_data.shape[-1],
activation=None,
kernel_initializer='identity',
bias_initializer='zeros')
test_utils.layer_test(
wrappers.WeightNormalization,
kwargs={
'layer': tf.keras.layers.Dense(2),
'input_shape': (3, 4)
'layer': layer,
'data_init': data_init,
},
input_data=input_data)
input_data=input_data,
expected_output=expected_output)

def test_weightnorm_with_data_init_is_false(self):
input_data = np.array([[[-4, -4], [4, 4]]], dtype=np.float32)
self._check_data_init(
data_init=False, input_data=input_data, expected_output=input_data)

def test_weightnorm_with_data_init_is_true(self):
input_data = np.array([[[-4, -4], [4, 4]]], dtype=np.float32)
self._check_data_init(
data_init=True,
input_data=input_data,
expected_output=input_data / 4)

def test_weightnorm_non_layer(self):
images = tf.random.uniform((2, 4, 43))
with self.assertRaises(AssertionError):
wrappers.WeightNormalization(images)

def test_weightnorm_non_kernel_layer(self):
images = tf.random.uniform((2, 2, 2))
with self.assertRaisesRegexp(ValueError, 'contains a `kernel`'):
non_kernel_layer = tf.keras.layers.MaxPooling2D(2, 2)
wn_wrapper = wrappers.WeightNormalization(non_kernel_layer)
wn_wrapper(images)


if __name__ == "__main__":
Expand Down