109 changes: 61 additions & 48 deletions tensorflow_addons/layers/wrappers.py
@@ -58,7 +58,6 @@ class WeightNormalization(tf.keras.layers.Wrapper):
def __init__(self, layer, data_init=True, **kwargs):
super(WeightNormalization, self).__init__(layer, **kwargs)
self.data_init = data_init
self._initialized = False
self._track_trackable(layer, name='layer')

def build(self, input_shape):
@@ -69,85 +68,99 @@ def build(self, input_shape):
if not self.layer.built:
self.layer.build(input_shape)

if not hasattr(self.layer, 'kernel'):
raise ValueError('`WeightNormalization` must wrap a layer that'
' contains a `kernel` for weights')
if not hasattr(self.layer, 'kernel'):
raise ValueError('`WeightNormalization` must wrap a layer that'
' contains a `kernel` for weights')

# The kernel's filter or unit dimension is -1
self.layer_depth = int(self.layer.kernel.shape[-1])
self.kernel_norm_axes = list(range(self.layer.kernel.shape.rank - 1))

self.g = self.add_variable(
name='g',
shape=(self.layer_depth,),
initializer='ones',
dtype=self.layer.kernel.dtype,
trainable=True)
self.v = self.layer.kernel

self._initialized = self.add_variable(
name='initialized',
shape=None,
initializer='zeros',
dtype=tf.dtypes.bool,
trainable=False)

# The kernel's filter or unit dimension is -1
self.layer_depth = int(self.layer.kernel.shape[-1])
self.kernel_norm_axes = list(
range(self.layer.kernel.shape.rank - 1))

self.v = self.layer.kernel
self.g = self.add_variable(
name="g",
shape=(self.layer_depth,),
initializer=tf.keras.initializers.get('ones'),
dtype=self.layer.kernel.dtype,
trainable=True)
if self.data_init:
self._naked_layer = tf.keras.layers.deserialize(
tf.keras.layers.serialize(self.layer))
Member
Just wondering if you could comment on this. Creating a copy of the layer for the data init, but haven't seen this trick before. I guess no other way to copy a layer?

Looks great... thank you very much!

Member Author
Sure, will do, Sean
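
For reference, a minimal sketch (not from this PR) of the serialize/deserialize round-trip discussed above: tf.keras.layers.serialize returns the layer's config, and tf.keras.layers.deserialize builds a fresh layer from it, so the copy shares no variables with the original until weights are transferred explicitly.

import tensorflow as tf

dense = tf.keras.layers.Dense(2, activation='relu')
dense.build((None, 4))

# Round-trip through the config to get an independent copy of the layer.
copy = tf.keras.layers.deserialize(tf.keras.layers.serialize(dense))
copy.build((None, 4))
copy.set_weights(dense.get_weights())  # the copy now holds its own variables with the same values
copy.activation = None                 # e.g. strip the activation, as the wrapper does for data init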

self._naked_layer.build(input_shape)
self._naked_layer.set_weights(self.layer.get_weights())
self._naked_layer.activation = None

super(WeightNormalization, self).build()
self.built = True

def call(self, inputs):
"""Call `Layer`"""
if not self._initialized:
self._initialize_weights(inputs)

self._compute_weights() # Recompute weights for each forward pass
output = self.layer(inputs)
return output
def _do_nothing():
return inputs

def compute_output_shape(self, input_shape):
return tf.TensorShape(
self.layer.compute_output_shape(input_shape).as_list())
def _update_weights():
self._initialize_weights(inputs)
return inputs

def _compute_weights(self):
"""Generate normalized weights.
inputs = tf.cond(self._initialized, _do_nothing, _update_weights)

This method will update the value of self.layer.kernel with the
normalized value, so that the layer is ready for call().
"""
with tf.name_scope('compute_weights'):
# Replace kernel by normalized weight variable.
self.layer.kernel = tf.nn.l2_normalize(
self.v, axis=self.kernel_norm_axes) * self.g

return self.layer(inputs)
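
The tf.cond above replaces the old Python-level `if not self._initialized:` check so the one-time initialization still runs, exactly once, when the call is traced into a graph. A standalone sketch of the same pattern, with illustrative names that are not from the PR:

import tensorflow as tf

initialized = tf.Variable(False, trainable=False, dtype=tf.bool)

@tf.function
def forward(inputs):
    def do_nothing():
        return inputs

    def init_once():
        # any one-time, data-dependent setup would go here
        initialized.assign(True)
        return inputs

    # Both branches are traced, but only the selected one executes per call,
    # so init_once runs on the first call and is skipped afterwards.
    inputs = tf.cond(initialized, do_nothing, init_once)
    return inputs * 2.0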

def compute_output_shape(self, input_shape):
return tf.TensorShape(
self.layer.compute_output_shape(input_shape).as_list())

def _initialize_weights(self, inputs):
"""Initialize weight g.

The initial value of g can come either from the initial value of v,
or from the input values when self.data_init is True.
"""
if self.data_init:
self._data_dep_init(inputs)
else:
self._init_norm()
self._initialized = True
with tf.control_dependencies([
tf.debugging.assert_equal( # pylint: disable=bad-continuation
self._initialized,
False,
message='The layer has been initialized.')
]):
if self.data_init:
self._data_dep_init(inputs)
else:
self._init_norm()
self._initialized.assign(True)

def _init_norm(self):
"""Set the weight g with the norm of the weight vector."""
with tf.name_scope('init_norm'):
flat = tf.reshape(self.v, [-1, self.layer_depth])
self.g.assign(
tf.reshape(tf.linalg.norm(flat, axis=0), (self.layer_depth,)))
v_flat = tf.reshape(self.v, [-1, self.layer_depth])
v_norm = tf.linalg.norm(v_flat, axis=0)
self.g.assign(tf.reshape(v_norm, (self.layer_depth,)))

# TODO: Get data init to work with tf_function compile #428
def _data_dep_init(self, inputs):
"""Data dependent initialization."""

with tf.name_scope('data_dep_init'):
# Generate data dependent init values
existing_activation = self.layer.activation
self.layer.activation = None
x_init = self.layer(inputs)
x_init = self._naked_layer(inputs)
data_norm_axes = list(range(x_init.shape.rank - 1))
m_init, v_init = tf.nn.moments(x_init, data_norm_axes)
scale_init = 1. / tf.math.sqrt(v_init + 1e-10)

# Assign data dependent init values
self.g = self.g * scale_init
if hasattr(self.layer, 'bias'):
self.layer.bias = -m_init * scale_init
self.layer.activation = existing_activation
# Assign data dependent init values
self.g.assign(self.g * scale_init)
if hasattr(self.layer, 'bias'):
self.layer.bias.assign(-m_init * scale_init)

def get_config(self):
config = {'data_init': self.data_init}
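As a rough NumPy sketch of the arithmetic the wrapper implements (illustrative names, not the PR's code): the kernel is reparameterized as kernel = g * v / ||v||, with the norm taken over every axis except the last, and data-dependent initialization rescales g and the bias from the statistics of the first batch's pre-activation outputs.

import numpy as np

v = np.random.randn(4, 2)                        # stand-in for a Dense kernel (4 inputs, 2 units)
v_norm = np.linalg.norm(v.reshape(-1, v.shape[-1]), axis=0)
g = v_norm.copy()                                # _init_norm: g starts as the per-unit norm of v
kernel = g * v / v_norm                          # reparameterized kernel; equals v here by construction

# Data-dependent init (data_init=True): push a batch through the layer with the
# activation stripped, then rescale so the initial pre-activations are roughly
# zero-mean and unit-variance.
x = np.random.randn(16, 4)
pre_act = x @ kernel
mean, var = pre_act.mean(axis=0), pre_act.var(axis=0)
scale = 1.0 / np.sqrt(var + 1e-10)
g = g * scale
bias = -mean * scale
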
13 changes: 3 additions & 10 deletions tensorflow_addons/layers/wrappers_test.py
@@ -26,16 +26,14 @@

@test_utils.run_all_in_graph_and_eager_modes
class WeightNormalizationTest(tf.test.TestCase):
# TODO: Get data init to work with tf_function compile #428
def test_weightnorm_dense_train(self):
model = tf.keras.models.Sequential()
model.add(
wrappers.WeightNormalization(
tf.keras.layers.Dense(2), input_shape=(3, 4)))
model.compile(
optimizer=tf.keras.optimizers.RMSprop(learning_rate=0.001),
loss='mse',
experimental_run_tf_function=False)
loss='mse')
model.fit(
np.random.random((10, 3, 4)),
np.random.random((10, 3, 2)),
@@ -60,7 +58,6 @@ def test_weightnorm_dense_train_notinit(self):
self.assertTrue(hasattr(model.layers[0], 'g'))

def test_weightnorm_conv2d(self):
# TODO: Get data init to work with tf_function compile #428
model = tf.keras.models.Sequential()
model.add(
wrappers.WeightNormalization(
@@ -70,8 +67,7 @@ def test_weightnorm_conv2d(self):
model.add(tf.keras.layers.Activation('relu'))
model.compile(
optimizer=tf.keras.optimizers.RMSprop(learning_rate=0.001),
loss='mse',
experimental_run_tf_function=False)
loss='mse')
model.fit(
np.random.random((2, 4, 4, 3)),
np.random.random((2, 4, 4, 5)),
@@ -105,10 +101,7 @@ def test_weightnorm_keras(self):
'layer': tf.keras.layers.Dense(2),
'input_shape': (3, 4)
},
input_data=input_data,
# TODO: Fix the bug thats causing layer test to run a
# graph Tensor in eager mode.
validate_training=False)
input_data=input_data)


if __name__ == "__main__":