diff --git a/tensorflow_addons/layers/wrappers.py b/tensorflow_addons/layers/wrappers.py index badc6596df..5ba4d12fbc 100644 --- a/tensorflow_addons/layers/wrappers.py +++ b/tensorflow_addons/layers/wrappers.py @@ -62,7 +62,8 @@ def __init__(self, layer, data_init=True, **kwargs): def build(self, input_shape): """Build `Layer`""" input_shape = tf.TensorShape(input_shape).as_list() - self.input_spec = tf.keras.layers.InputSpec(shape=input_shape) + self.input_spec = tf.keras.layers.InputSpec( + shape=[None] + input_shape[1:]) if not self.layer.built: self.layer.build(input_shape) diff --git a/tensorflow_addons/layers/wrappers_test.py b/tensorflow_addons/layers/wrappers_test.py index b4bdb9c494..0ff8e417a4 100644 --- a/tensorflow_addons/layers/wrappers_test.py +++ b/tensorflow_addons/layers/wrappers_test.py @@ -73,6 +73,14 @@ def test_weightnorm_non_kernel_layer(self): wn_wrapper = wrappers.WeightNormalization(non_kernel_layer) wn_wrapper(images) + def test_weightnorm_with_time_dist(self): + batch_shape = (32, 16, 64, 64, 3) + inputs = tf.keras.layers.Input(batch_shape=batch_shape) + a = tf.keras.layers.Conv2D(3, 5) + b = wrappers.WeightNormalization(a) + out = tf.keras.layers.TimeDistributed(b)(inputs) + model = tf.keras.Model(inputs, out) + if __name__ == "__main__": tf.test.main()