From de72760091f760f675f572e12362517935811305 Mon Sep 17 00:00:00 2001 From: Dheeraj Date: Mon, 11 Nov 2019 00:55:15 +0530 Subject: [PATCH 1/4] Build data_init layer under name_scope The original wrapped layer and the non-trainable layer created for data dependent initialization had a clash in their namespaces. Creating the second layer under a name scope of 'data_dep_init' fixes the issue. --- tensorflow_addons/layers/wrappers.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/tensorflow_addons/layers/wrappers.py b/tensorflow_addons/layers/wrappers.py index 5ba4d12fbc..be7f5340f0 100644 --- a/tensorflow_addons/layers/wrappers.py +++ b/tensorflow_addons/layers/wrappers.py @@ -93,12 +93,13 @@ def build(self, input_shape): if self.data_init: # Used for data initialization in self._data_dep_init. - layer_config = tf.keras.layers.serialize(self.layer) - layer_config['config']['trainable'] = False - self._naked_clone_layer = tf.keras.layers.deserialize(layer_config) - self._naked_clone_layer.build(input_shape) - self._naked_clone_layer.set_weights(self.layer.get_weights()) - self._naked_clone_layer.activation = None + with tf.name_scope('data_dep_init'): + layer_config = tf.keras.layers.serialize(self.layer) + layer_config['config']['trainable'] = False + self._naked_clone_layer = tf.keras.layers.deserialize(layer_config) + self._naked_clone_layer.build(input_shape) + self._naked_clone_layer.set_weights(self.layer.get_weights()) + self._naked_clone_layer.activation = None self.built = True From d29c51dd50be723a65a34dc51d2512dd4e2b863a Mon Sep 17 00:00:00 2001 From: Dheeraj Date: Mon, 11 Nov 2019 01:32:38 +0530 Subject: [PATCH 2/4] Lint --- tensorflow_addons/layers/wrappers.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tensorflow_addons/layers/wrappers.py b/tensorflow_addons/layers/wrappers.py index be7f5340f0..98ddc736ca 100644 --- a/tensorflow_addons/layers/wrappers.py +++ b/tensorflow_addons/layers/wrappers.py @@ -96,7 +96,8 @@ def build(self, input_shape): with tf.name_scope('data_dep_init'): layer_config = tf.keras.layers.serialize(self.layer) layer_config['config']['trainable'] = False - self._naked_clone_layer = tf.keras.layers.deserialize(layer_config) + self._naked_clone_layer = tf.keras.layers.deserialize( + layer_config) self._naked_clone_layer.build(input_shape) self._naked_clone_layer.set_weights(self.layer.get_weights()) self._naked_clone_layer.activation = None From 833f807018fab9237eec8b532b64acbf98fb4959 Mon Sep 17 00:00:00 2001 From: Dheeraj Date: Tue, 12 Nov 2019 22:42:30 +0530 Subject: [PATCH 3/4] Add test for saving --- tensorflow_addons/layers/wrappers_test.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/tensorflow_addons/layers/wrappers_test.py b/tensorflow_addons/layers/wrappers_test.py index 0ff8e417a4..17270f9ab5 100644 --- a/tensorflow_addons/layers/wrappers_test.py +++ b/tensorflow_addons/layers/wrappers_test.py @@ -81,6 +81,17 @@ def test_weightnorm_with_time_dist(self): out = tf.keras.layers.TimeDistributed(b)(inputs) model = tf.keras.Model(inputs, out) + def test_save_file_h5(self): + conv = tf.keras.layers.Conv1D(1, 1) + wn_conv = wrappers.WeightNormalization(conv) + model = tf.keras.Sequential(layers=[wn_conv]) + model.build([1, 2, 3]) + model.save_weights('/tmp/model.h5') + + import os + os.remove('/tmp/model.h5') + # TODO: Find a better way to test this + if __name__ == "__main__": tf.test.main() From df7ccb3e7b3353a587238bf7ec6ef39f94ebc85c Mon Sep 17 00:00:00 2001 From: Dheeraj Date: Wed, 13 Nov 2019 15:05:11 +0530 Subject: [PATCH 4/4] Use create_tempfile --- tensorflow_addons/layers/wrappers_test.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/tensorflow_addons/layers/wrappers_test.py b/tensorflow_addons/layers/wrappers_test.py index 17270f9ab5..ecb6d354c7 100644 --- a/tensorflow_addons/layers/wrappers_test.py +++ b/tensorflow_addons/layers/wrappers_test.py @@ -82,15 +82,12 @@ def test_weightnorm_with_time_dist(self): model = tf.keras.Model(inputs, out) def test_save_file_h5(self): + self.create_tempfile('wrapper_test_model.h5') conv = tf.keras.layers.Conv1D(1, 1) wn_conv = wrappers.WeightNormalization(conv) model = tf.keras.Sequential(layers=[wn_conv]) model.build([1, 2, 3]) - model.save_weights('/tmp/model.h5') - - import os - os.remove('/tmp/model.h5') - # TODO: Find a better way to test this + model.save_weights('wrapper_test_model.h5') if __name__ == "__main__":