From e352207da32e4670a36a295ea477c476118cb0d9 Mon Sep 17 00:00:00 2001 From: Nir Tzachar Date: Wed, 13 Nov 2019 12:19:53 +0200 Subject: [PATCH] Fix wrapped layer has use_biase=False --- tensorflow_addons/layers/wrappers.py | 2 +- tensorflow_addons/layers/wrappers_test.py | 8 ++++++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/tensorflow_addons/layers/wrappers.py b/tensorflow_addons/layers/wrappers.py index 5ba4d12fbc..07641e20b9 100644 --- a/tensorflow_addons/layers/wrappers.py +++ b/tensorflow_addons/layers/wrappers.py @@ -168,7 +168,7 @@ def _data_dep_init(self, inputs): # Assign data dependent init values g_tensor = self.g.assign(self.g * scale_init) - if hasattr(self.layer, 'bias'): + if hasattr(self.layer, 'bias') and self.layer.bias is not None: bias_tensor = self.layer.bias.assign(-m_init * scale_init) return [g_tensor, bias_tensor] else: diff --git a/tensorflow_addons/layers/wrappers_test.py b/tensorflow_addons/layers/wrappers_test.py index 0ff8e417a4..e135cd6b39 100644 --- a/tensorflow_addons/layers/wrappers_test.py +++ b/tensorflow_addons/layers/wrappers_test.py @@ -34,6 +34,14 @@ def test_weightnorm(self): }, input_shape=(2, 4, 4, 3)) + def test_weightnorm_no_bias(self): + test_utils.layer_test( + wrappers.WeightNormalization, + kwargs={ + 'layer': tf.keras.layers.Dense(5, use_bias=False), + }, + input_shape=(2, 4)) + def _check_data_init(self, data_init, input_data, expected_output): layer = tf.keras.layers.Dense( input_data.shape[-1],