From 93330af8be4d53f7680f5226fb9ee5b5da19af59 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Yan=20Facai=20=28=E9=A2=9C=E5=8F=91=E6=89=8D=29?=
Date: Wed, 28 Aug 2019 14:25:33 +0800
Subject: [PATCH 1/3] TST: test cases pass

---
 tensorflow_addons/layers/wrappers.py      | 4 ++--
 tensorflow_addons/layers/wrappers_test.py | 9 +++------
 2 files changed, 5 insertions(+), 8 deletions(-)

diff --git a/tensorflow_addons/layers/wrappers.py b/tensorflow_addons/layers/wrappers.py
index bc9a145087..5eaafffc49 100644
--- a/tensorflow_addons/layers/wrappers.py
+++ b/tensorflow_addons/layers/wrappers.py
@@ -144,9 +144,9 @@ def _data_dep_init(self, inputs):
             scale_init = 1. / tf.math.sqrt(v_init + 1e-10)
 
         # Assign data dependent init values
-        self.g = self.g * scale_init
+        self.g.assign(self.g * scale_init)
         if hasattr(self.layer, 'bias'):
-            self.layer.bias = -m_init * scale_init
+            self.layer.bias.assign(-m_init * scale_init)
         self.layer.activation = existing_activation
 
     def get_config(self):
diff --git a/tensorflow_addons/layers/wrappers_test.py b/tensorflow_addons/layers/wrappers_test.py
index 54e13719ec..24bbefc398 100644
--- a/tensorflow_addons/layers/wrappers_test.py
+++ b/tensorflow_addons/layers/wrappers_test.py
@@ -34,8 +34,7 @@ def test_weightnorm_dense_train(self):
                 tf.keras.layers.Dense(2), input_shape=(3, 4)))
         model.compile(
             optimizer=tf.keras.optimizers.RMSprop(learning_rate=0.001),
-            loss='mse',
-            experimental_run_tf_function=False)
+            loss='mse')
         model.fit(
             np.random.random((10, 3, 4)),
             np.random.random((10, 3, 2)),
@@ -103,12 +102,10 @@ def test_weightnorm_keras(self):
             wrappers.WeightNormalization,
             kwargs={
                 'layer': tf.keras.layers.Dense(2),
+                'data_init': False,
                 'input_shape': (3, 4)
             },
-            input_data=input_data,
-            # TODO: Fix the bug thats causing layer test to run a
-            # graph Tensor in eager mode.
-            validate_training=False)
+            input_data=input_data)
 
 
 if __name__ == "__main__":
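Note on PATCH 1/3: the `_data_dep_init` change swaps plain Python attribute assignment for `tf.Variable.assign`. Rebinding `self.g` to `self.g * scale_init` replaces the variable with an ordinary tensor, so the update is invisible to checkpointing and to the optimizer; `assign` writes the new value into the existing variable instead. A minimal standalone sketch of the difference (the variable names are illustrative, not taken from the patch):

    import tensorflow as tf

    g = tf.Variable(tf.ones([2]), name='g')
    scale_init = tf.constant([0.5, 2.0])

    # In-place update: g stays the same tf.Variable, so checkpoints
    # and the optimizer keep tracking it.
    g.assign(g * scale_init)
    print(isinstance(g, tf.Variable))   # True

    # Rebinding: the result is a plain tensor, and the original
    # variable is no longer reachable through the attribute.
    g = g * scale_init
    print(isinstance(g, tf.Variable))   # False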
From ae3cff2ae64dc86b77fc15ace71e4328eea2201d Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Yan=20Facai=20=28=E9=A2=9C=E5=8F=91=E6=89=8D=29?=
Date: Wed, 28 Aug 2019 18:16:19 +0800
Subject: [PATCH 2/3] BUG: fix related bugs

---
 tensorflow_addons/layers/wrappers.py      | 77 +++++++++----------
 tensorflow_addons/layers/wrappers_test.py |  5 +--
 2 files changed, 31 insertions(+), 51 deletions(-)

diff --git a/tensorflow_addons/layers/wrappers.py b/tensorflow_addons/layers/wrappers.py
index 5eaafffc49..0c1ef63970 100644
--- a/tensorflow_addons/layers/wrappers.py
+++ b/tensorflow_addons/layers/wrappers.py
@@ -69,76 +69,60 @@ def build(self, input_shape):
         if not self.layer.built:
             self.layer.build(input_shape)
 
-            if not hasattr(self.layer, 'kernel'):
-                raise ValueError('`WeightNormalization` must wrap a layer that'
-                                 ' contains a `kernel` for weights')
+        if not hasattr(self.layer, 'kernel'):
+            raise ValueError('`WeightNormalization` must wrap a layer that'
+                             ' contains a `kernel` for weights')
 
-            # The kernel's filter or unit dimension is -1
-            self.layer_depth = int(self.layer.kernel.shape[-1])
-            self.kernel_norm_axes = list(
-                range(self.layer.kernel.shape.rank - 1))
+        # The kernel's filter or unit dimension is -1
+        self.layer_depth = int(self.layer.kernel.shape[-1])
+        self.kernel_norm_axes = list(range(self.layer.kernel.shape.rank - 1))
 
-            self.v = self.layer.kernel
-            self.g = self.add_variable(
-                name="g",
-                shape=(self.layer_depth,),
-                initializer=tf.keras.initializers.get('ones'),
-                dtype=self.layer.kernel.dtype,
-                trainable=True)
+        self.g = self.add_variable(
+            name='g',
+            shape=(self.layer_depth,),
+            initializer=tf.keras.initializers.get('ones'),
+            dtype=self.layer.kernel.dtype,
+            trainable=True)
+        self.v = self.layer.kernel
 
-        super(WeightNormalization, self).build()
+        self.built = True
 
     def call(self, inputs):
         """Call `Layer`"""
         if not self._initialized:
-            self._initialize_weights(inputs)
+            if self.data_init:
+                self._data_dep_init(inputs)
+            else:
+                self._init_norm()
+            self._initialized = True
 
-        self._compute_weights()  # Recompute weights for each forward pass
-        output = self.layer(inputs)
-        return output
+        with tf.name_scope('compute_weights'):
+            # Replace kernel by normalized weight variable.
+            self.layer.kernel = tf.nn.l2_normalize(
+                self.v,
+                axis=self.kernel_norm_axes) * self.g
+            return self.layer(inputs)
 
     def compute_output_shape(self, input_shape):
         return tf.TensorShape(
             self.layer.compute_output_shape(input_shape).as_list())
 
-    def _compute_weights(self):
-        """Generate normalized weights.
-
-        This method will update the value of self.layer.kernel with the
-        normalized value, so that the layer is ready for call().
-        """
-        with tf.name_scope('compute_weights'):
-            self.layer.kernel = tf.nn.l2_normalize(
-                self.v, axis=self.kernel_norm_axes) * self.g
-
-    def _initialize_weights(self, inputs):
-        """Initialize weight g.
-
-        The initial value of g could either from the initial value in v,
-        or by the input value if self.data_init is True.
-        """
-        if self.data_init:
-            self._data_dep_init(inputs)
-        else:
-            self._init_norm()
-        self._initialized = True
-
     def _init_norm(self):
         """Set the weight g with the norm of the weight vector."""
         with tf.name_scope('init_norm'):
-            flat = tf.reshape(self.v, [-1, self.layer_depth])
-            self.g.assign(
-                tf.reshape(tf.linalg.norm(flat, axis=0), (self.layer_depth,)))
+            v_flat = tf.reshape(self.v, [-1, self.layer_depth])
+            v_norm = tf.linalg.norm(v_flat, axis=0)
+            self.g.assign(tf.reshape(v_norm, (self.layer_depth,)))
 
-    # TODO: Get data init to work with tf_function compile #428
     def _data_dep_init(self, inputs):
         """Data dependent initialization."""
-
         with tf.name_scope('data_dep_init'):
             # Generate data dependent init values
             existing_activation = self.layer.activation
             self.layer.activation = None
             x_init = self.layer(inputs)
+            self.layer.activation = existing_activation
+
             data_norm_axes = list(range(x_init.shape.rank - 1))
             m_init, v_init = tf.nn.moments(x_init, data_norm_axes)
             scale_init = 1. / tf.math.sqrt(v_init + 1e-10)
@@ -147,7 +131,6 @@ def _data_dep_init(self, inputs):
         self.g.assign(self.g * scale_init)
         if hasattr(self.layer, 'bias'):
             self.layer.bias.assign(-m_init * scale_init)
-        self.layer.activation = existing_activation
 
     def get_config(self):
         config = {'data_init': self.data_init}
diff --git a/tensorflow_addons/layers/wrappers_test.py b/tensorflow_addons/layers/wrappers_test.py
index 24bbefc398..9472cd3293 100644
--- a/tensorflow_addons/layers/wrappers_test.py
+++ b/tensorflow_addons/layers/wrappers_test.py
@@ -26,7 +26,6 @@
 
 @test_utils.run_all_in_graph_and_eager_modes
 class WeightNormalizationTest(tf.test.TestCase):
-    # TODO: Get data init to work with tf_function compile #428
     def test_weightnorm_dense_train(self):
         model = tf.keras.models.Sequential()
         model.add(
@@ -59,7 +58,6 @@ def test_weightnorm_dense_train_notinit(self):
         self.assertTrue(hasattr(model.layers[0], 'g'))
 
     def test_weightnorm_conv2d(self):
-        # TODO: Get data init to work with tf_function compile #428
         model = tf.keras.models.Sequential()
         model.add(
             wrappers.WeightNormalization(
@@ -69,8 +67,7 @@ def test_weightnorm_conv2d(self):
         model.add(tf.keras.layers.Activation('relu'))
         model.compile(
             optimizer=tf.keras.optimizers.RMSprop(learning_rate=0.001),
-            loss='mse',
-            experimental_run_tf_function=False)
+            loss='mse')
         model.fit(
             np.random.random((2, 4, 4, 3)),
             np.random.random((2, 4, 4, 5)),
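Note on PATCH 2/3: `call` now recomputes the kernel on every forward pass as w = g * v / ||v||, with the norm taken over all axes except the last (the filter/unit axis), and `_init_norm` starts g at ||v|| so the first pass reproduces the wrapped layer's original kernel. A small sketch of that reparameterization with a stand-in dense kernel (shapes are illustrative):

    import tensorflow as tf

    v = tf.random.normal([4, 2])   # stand-in for layer.kernel: (in, out)
    layer_depth = v.shape[-1]

    # _init_norm: g starts as the per-unit norm of v.
    g = tf.linalg.norm(tf.reshape(v, [-1, layer_depth]), axis=0)

    # compute_weights: kernel = g * v / ||v||, normalized over all
    # axes except the last.
    kernel = tf.nn.l2_normalize(v, axis=0) * g

    # With g = ||v||, the reparameterized kernel equals v exactly.
    print(tf.reduce_max(tf.abs(kernel - v)).numpy())   # ~0.0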
- """ - if self.data_init: - self._data_dep_init(inputs) - else: - self._init_norm() - self._initialized = True - def _init_norm(self): """Set the weight g with the norm of the weight vector.""" with tf.name_scope('init_norm'): - flat = tf.reshape(self.v, [-1, self.layer_depth]) - self.g.assign( - tf.reshape(tf.linalg.norm(flat, axis=0), (self.layer_depth,))) + v_flat = tf.reshape(self.v, [-1, self.layer_depth]) + v_norm = tf.linalg.norm(v_flat, axis=0) + self.g.assign(tf.reshape(v_norm, (self.layer_depth,))) - # TODO: Get data init to work with tf_function compile #428 def _data_dep_init(self, inputs): """Data dependent initialization.""" - with tf.name_scope('data_dep_init'): # Generate data dependent init values existing_activation = self.layer.activation self.layer.activation = None x_init = self.layer(inputs) + self.layer.activation = existing_activation + data_norm_axes = list(range(x_init.shape.rank - 1)) m_init, v_init = tf.nn.moments(x_init, data_norm_axes) scale_init = 1. / tf.math.sqrt(v_init + 1e-10) @@ -147,7 +131,6 @@ def _data_dep_init(self, inputs): self.g.assign(self.g * scale_init) if hasattr(self.layer, 'bias'): self.layer.bias.assign(-m_init * scale_init) - self.layer.activation = existing_activation def get_config(self): config = {'data_init': self.data_init} diff --git a/tensorflow_addons/layers/wrappers_test.py b/tensorflow_addons/layers/wrappers_test.py index 24bbefc398..9472cd3293 100644 --- a/tensorflow_addons/layers/wrappers_test.py +++ b/tensorflow_addons/layers/wrappers_test.py @@ -26,7 +26,6 @@ @test_utils.run_all_in_graph_and_eager_modes class WeightNormalizationTest(tf.test.TestCase): - # TODO: Get data init to work with tf_function compile #428 def test_weightnorm_dense_train(self): model = tf.keras.models.Sequential() model.add( @@ -59,7 +58,6 @@ def test_weightnorm_dense_train_notinit(self): self.assertTrue(hasattr(model.layers[0], 'g')) def test_weightnorm_conv2d(self): - # TODO: Get data init to work with tf_function compile #428 model = tf.keras.models.Sequential() model.add( wrappers.WeightNormalization( @@ -69,8 +67,7 @@ def test_weightnorm_conv2d(self): model.add(tf.keras.layers.Activation('relu')) model.compile( optimizer=tf.keras.optimizers.RMSprop(learning_rate=0.001), - loss='mse', - experimental_run_tf_function=False) + loss='mse') model.fit( np.random.random((2, 4, 4, 3)), np.random.random((2, 4, 4, 5)), From 39e74fc3aa2b30d4921b0313093c73d0e30bfb8b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Yan=20Facai=20=28=E9=A2=9C=E5=8F=91=E6=89=8D=29?= Date: Wed, 28 Aug 2019 20:08:38 +0800 Subject: [PATCH 3/3] BUG: fix test_weightnorm_keras --- tensorflow_addons/layers/wrappers.py | 68 ++++++++++++++++------- tensorflow_addons/layers/wrappers_test.py | 1 - 2 files changed, 49 insertions(+), 20 deletions(-) diff --git a/tensorflow_addons/layers/wrappers.py b/tensorflow_addons/layers/wrappers.py index 0c1ef63970..fd101b219a 100644 --- a/tensorflow_addons/layers/wrappers.py +++ b/tensorflow_addons/layers/wrappers.py @@ -58,7 +58,6 @@ class WeightNormalization(tf.keras.layers.Wrapper): def __init__(self, layer, data_init=True, **kwargs): super(WeightNormalization, self).__init__(layer, **kwargs) self.data_init = data_init - self._initialized = False self._track_trackable(layer, name='layer') def build(self, input_shape): @@ -80,33 +79,68 @@ def build(self, input_shape): self.g = self.add_variable( name='g', shape=(self.layer_depth,), - initializer=tf.keras.initializers.get('ones'), + initializer='ones', dtype=self.layer.kernel.dtype, trainable=True) self.v = 
diff --git a/tensorflow_addons/layers/wrappers_test.py b/tensorflow_addons/layers/wrappers_test.py
index 9472cd3293..9d83bbec50 100644
--- a/tensorflow_addons/layers/wrappers_test.py
+++ b/tensorflow_addons/layers/wrappers_test.py
@@ -99,7 +99,6 @@ def test_weightnorm_keras(self):
             wrappers.WeightNormalization,
             kwargs={
                 'layer': tf.keras.layers.Dense(2),
-                'data_init': False,
                 'input_shape': (3, 4)
             },
             input_data=input_data)
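With the three patches applied, the wrapper trains under the default compiled Keras loop and the `experimental_run_tf_function=False` workaround is gone from the tests. A minimal end-to-end sketch mirroring `test_weightnorm_dense_train` above (the epoch and batch-size values are illustrative, since those arguments are not shown in the diff):

    import numpy as np
    import tensorflow as tf
    from tensorflow_addons.layers import wrappers

    model = tf.keras.models.Sequential()
    model.add(
        wrappers.WeightNormalization(
            tf.keras.layers.Dense(2), input_shape=(3, 4)))
    model.compile(
        optimizer=tf.keras.optimizers.RMSprop(learning_rate=0.001),
        loss='mse')
    # Data-dependent init runs on the first batch, inside the graph.
    model.fit(
        np.random.random((10, 3, 4)),
        np.random.random((10, 3, 2)),
        epochs=1, batch_size=5)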