diff --git a/tensorflow_addons/layers/wrappers.py b/tensorflow_addons/layers/wrappers.py index 166587d2b8..c724ae5059 100644 --- a/tensorflow_addons/layers/wrappers.py +++ b/tensorflow_addons/layers/wrappers.py @@ -58,31 +58,34 @@ def __init__(self, layer, data_init=True, **kwargs): super(WeightNormalization, self).__init__(layer, **kwargs) self.data_init = data_init self._track_trackable(layer, name='layer') + self.is_rnn = isinstance(self.layer, tf.keras.layers.RNN) def build(self, input_shape): """Build `Layer`""" - input_shape = tf.TensorShape(input_shape).as_list() + input_shape = tf.TensorShape(input_shape) self.input_spec = tf.keras.layers.InputSpec( shape=[None] + input_shape[1:]) if not self.layer.built: self.layer.build(input_shape) - if not hasattr(self.layer, 'kernel'): + kernel_layer = self.layer.cell if self.is_rnn else self.layer + + if not hasattr(kernel_layer, 'kernel'): raise ValueError('`WeightNormalization` must wrap a layer that' ' contains a `kernel` for weights') # The kernel's filter or unit dimension is -1 - self.layer_depth = int(self.layer.kernel.shape[-1]) - self.kernel_norm_axes = list(range(self.layer.kernel.shape.rank - 1)) + self.layer_depth = int(kernel_layer.kernel.shape[-1]) + self.kernel_norm_axes = list(range(kernel_layer.kernel.shape.rank - 1)) self.g = self.add_weight( name='g', shape=(self.layer_depth,), initializer='ones', - dtype=self.layer.kernel.dtype, + dtype=kernel_layer.kernel.dtype, trainable=True) - self.v = self.layer.kernel + self.v = kernel_layer.kernel self._initialized = self.add_weight( name='initialized', @@ -100,7 +103,10 @@ def build(self, input_shape): layer_config) self._naked_clone_layer.build(input_shape) self._naked_clone_layer.set_weights(self.layer.get_weights()) - self._naked_clone_layer.activation = None + if self.is_rnn: + self._naked_clone_layer.cell.activation = None + else: + self._naked_clone_layer.activation = None self.built = True diff --git a/tensorflow_addons/layers/wrappers_test.py b/tensorflow_addons/layers/wrappers_test.py index 309bcde9b4..77da7af9ab 100644 --- a/tensorflow_addons/layers/wrappers_test.py +++ b/tensorflow_addons/layers/wrappers_test.py @@ -89,6 +89,13 @@ def test_weightnorm_with_time_dist(self): out = tf.keras.layers.TimeDistributed(b)(inputs) model = tf.keras.Model(inputs, out) + def test_weightnorm_with_rnn(self): + inputs = tf.keras.layers.Input(shape=(None, 3)) + rnn_layer = tf.keras.layers.SimpleRNN(4) + wt_rnn = wrappers.WeightNormalization(rnn_layer) + dense = tf.keras.layers.Dense(1) + model = tf.keras.models.Sequential(layers=[inputs, wt_rnn, dense]) + def test_save_file_h5(self): self.create_tempfile('wrapper_test_model.h5') conv = tf.keras.layers.Conv1D(1, 1)