diff --git a/tensorflow_addons/layers/BUILD b/tensorflow_addons/layers/BUILD
index 42991189ed..feed49a9a3 100644
--- a/tensorflow_addons/layers/BUILD
+++ b/tensorflow_addons/layers/BUILD
@@ -11,6 +11,7 @@ py_library(
         "normalizations.py",
         "optical_flow.py",
         "poincare.py",
+        "polynomial.py",
         "sparsemax.py",
         "tlu.py",
         "wrappers.py",
@@ -24,6 +25,20 @@ py_library(
     ],
 )
 
+py_test(
+    name = "polynomial_test",
+    size = "small",
+    srcs = [
+        "polynomial_test.py",
+    ],
+    main = "polynomial_test.py",
+    python_version = "PY3",
+    srcs_version = "PY2AND3",
+    deps = [
+        ":layers",
+    ],
+)
+
 py_test(
     name = "gelu_test",
     size = "small",
diff --git a/tensorflow_addons/layers/README.md b/tensorflow_addons/layers/README.md
index 1c8babc943..19c109a22c 100644
--- a/tensorflow_addons/layers/README.md
+++ b/tensorflow_addons/layers/README.md
@@ -8,6 +8,7 @@
 | normalizations | @smokrow | moritz.kroeger@tu-dortmund.de |
 | opticalflow | @fsx950223 | fsx950223@gmail.com |
 | poincare | @rahulunair | rahulunair@gmail.com |
+| polynomial | @tanzheny | tanzheny@google.com |
 | sparsemax | @AndreasMadsen | amwwebdk+github@gmail.com |
 | tlu | @AakashKumarNain | aakashnain@outlook.com |
 | wrappers | @seanpmorgan | seanmorgan@outlook.com |
@@ -21,6 +22,7 @@
 | normalizations | InstanceNormalization | https://arxiv.org/abs/1607.08022 |
 | opticalflow | CorrelationCost | https://arxiv.org/abs/1504.06852 |
 | poincare | PoincareNormalize | https://arxiv.org/abs/1705.08039 |
+| polynomial | PolynomialCrossing | https://arxiv.org/pdf/1708.05123 |
 | sparsemax| Sparsemax | https://arxiv.org/abs/1602.02068 |
 | tlu | TLU | https://arxiv.org/abs/1911.09737 |
 | wrappers | WeightNormalization | https://arxiv.org/abs/1602.07868 |
diff --git a/tensorflow_addons/layers/__init__.py b/tensorflow_addons/layers/__init__.py
index 24e55cf156..f2dc97c7a4 100644
--- a/tensorflow_addons/layers/__init__.py
+++ b/tensorflow_addons/layers/__init__.py
@@ -20,6 +20,7 @@
 from tensorflow_addons.layers.normalizations import InstanceNormalization
 from tensorflow_addons.layers.optical_flow import CorrelationCost
 from tensorflow_addons.layers.poincare import PoincareNormalize
+from tensorflow_addons.layers.polynomial import PolynomialCrossing
 from tensorflow_addons.layers.sparsemax import Sparsemax
 from tensorflow_addons.layers.tlu import TLU
 from tensorflow_addons.layers.wrappers import WeightNormalization
diff --git a/tensorflow_addons/layers/polynomial.py b/tensorflow_addons/layers/polynomial.py
new file mode 100644
index 0000000000..f42c638f76
--- /dev/null
+++ b/tensorflow_addons/layers/polynomial.py
@@ -0,0 +1,162 @@
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Implements Polynomial Crossing Layer."""
+
+import tensorflow as tf
+from typeguard import typechecked
+
+from tensorflow_addons.utils import types
+
+
+@tf.keras.utils.register_keras_serializable(package="Addons")
+class PolynomialCrossing(tf.keras.layers.Layer):
+    """Layer for Deep & Cross Network to learn explicit feature interactions.
+
+    A layer that applies feature crossing to learn certain explicit
+    bounded-degree feature interactions more efficiently. The `call` method
+    accepts `inputs` as a tuple or list of 2 tensors. The first input `x0` is
+    the input to the first `PolynomialCrossing` layer in the stack, i.e. the
+    input to the network (usually after the embedding layer); the second input
+    `xi` is the output of the previous `PolynomialCrossing` layer in the stack,
+    i.e. the i-th `PolynomialCrossing` layer.
+
+    The output is `y = x0 * (xi . W) + bias + xi`, where `.` is matmul and `*` is element-wise.
+
+    References
+        See [R. Wang](https://arxiv.org/pdf/1708.05123.pdf)
+
+    Example:
+
+    ```python
+    # after embedding layer in a functional model:
+    input = tf.keras.Input(shape=(None,), name='index', dtype=tf.int64)
+    x0 = tf.keras.layers.Embedding(input_dim=32, output_dim=6)(input)
+    x1 = PolynomialCrossing(projection_dim=None)((x0, x0))
+    x2 = PolynomialCrossing(projection_dim=None)((x0, x1))
+    logits = tf.keras.layers.Dense(units=10)(x2)
+    model = tf.keras.Model(input, logits)
+    ```
+
+    Arguments:
+        projection_dim: projection dimension. Default is `None` such that a full
+          (`input_dim` by `input_dim`) matrix is used.
+        use_bias: whether to calculate the bias/intercept for this layer. If set to
+          False, no bias/intercept will be used in calculations, e.g., the data is
+          already centered.
+        kernel_initializer: Initializer instance to use on the kernel matrix.
+        bias_initializer: Initializer instance to use on the bias vector.
+        kernel_regularizer: Regularizer instance to use on the kernel matrix.
+        bias_regularizer: Regularizer instance to use on the bias vector.
+
+    Input shape:
+        A tuple of 2 (batch_size, `input_dim`) dimensional inputs.
+
+    Output shape:
+        A single (batch_size, `input_dim`) dimensional output.
+ """ + + @typechecked + def __init__( + self, + projection_dim: int = None, + use_bias: bool = True, + kernel_initializer: types.Initializer = "truncated_normal", + bias_initializer: types.Initializer = "zeros", + kernel_regularizer: types.Regularizer = None, + bias_regularizer: types.Regularizer = None, + **kwargs, + ): + super(PolynomialCrossing, self).__init__(**kwargs) + + self.projection_dim = projection_dim + self.use_bias = use_bias + self.kernel_initializer = tf.keras.initializers.get(kernel_initializer) + self.bias_initializer = tf.keras.initializers.get(bias_initializer) + self.kernel_regularizer = tf.keras.regularizers.get(kernel_regularizer) + self.bias_regularizer = tf.keras.regularizers.get(bias_regularizer) + + self.supports_masking = True + + def build(self, input_shape): + if not isinstance(input_shape, (tuple, list)) or len(input_shape) != 2: + raise ValueError( + "Input shapes must be a tuple or list of size 2, " + "got {}".format(input_shape) + ) + last_dim = input_shape[-1][-1] + if self.projection_dim is None: + kernel_shape = [last_dim, last_dim] + else: + if self.projection_dim != last_dim: + raise ValueError( + "The case where `projection_dim` != last " + "dimension of the inputs is not supported yet, got " + "`projection_dim` {}, and last dimension of input " + "{}".format(self.projection_dim, last_dim) + ) + kernel_shape = [last_dim, self.projection_dim] + self.kernel = self.add_weight( + "kernel", + shape=kernel_shape, + initializer=self.kernel_initializer, + regularizer=self.kernel_regularizer, + dtype=self.dtype, + trainable=True, + ) + if self.use_bias: + self.bias = self.add_weight( + "bias", + shape=[last_dim], + initializer=self.bias_initializer, + regularizer=self.bias_regularizer, + dtype=self.dtype, + trainable=True, + ) + self.built = True + + def call(self, inputs): + if not isinstance(inputs, (tuple, list)) or len(inputs) != 2: + raise ValueError( + "Inputs to the layer must be a tuple or list of size 2, " + "got {}".format(inputs) + ) + x0, x = inputs + outputs = x0 * tf.matmul(x, self.kernel) + x + if self.use_bias: + outputs = tf.add(outputs, self.bias) + return outputs + + def get_config(self): + config = { + "projection_dim": self.projection_dim, + "use_bias": self.use_bias, + "kernel_initializer": tf.keras.initializers.serialize( + self.kernel_initializer + ), + "bias_initializer": tf.keras.initializers.serialize(self.bias_initializer), + "kernel_regularizer": tf.keras.regularizers.serialize( + self.kernel_regularizer + ), + "bias_regularizer": tf.keras.regularizers.serialize(self.bias_regularizer), + } + base_config = super(PolynomialCrossing, self).get_config() + return dict(list(base_config.items()) + list(config.items())) + + def compute_output_shape(self, input_shape): + if not isinstance(input_shape, (tuple, list)): + raise ValueError( + "A `PolynomialCrossing` layer should be called " "on a list of inputs." + ) + return input_shape[0] diff --git a/tensorflow_addons/layers/polynomial_test.py b/tensorflow_addons/layers/polynomial_test.py new file mode 100644 index 0000000000..392eff4a85 --- /dev/null +++ b/tensorflow_addons/layers/polynomial_test.py @@ -0,0 +1,59 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for PolynomialCrossing layer."""
+
+import numpy as np
+import tensorflow as tf
+
+from tensorflow_addons.layers.polynomial import PolynomialCrossing
+from tensorflow_addons.utils import test_utils
+
+
+@test_utils.run_all_in_graph_and_eager_modes
+class PolynomialCrossingTest(tf.test.TestCase):
+    # Do not use layer_test due to multiple inputs.
+
+    def test_full_matrix(self):
+        x0 = np.asarray([[0.1, 0.2, 0.3]]).astype(np.float32)
+        x = np.asarray([[0.4, 0.5, 0.6]]).astype(np.float32)
+        layer = PolynomialCrossing(projection_dim=None, kernel_initializer="ones")
+        output = layer([x0, x])
+        self.evaluate(tf.compat.v1.global_variables_initializer())
+        self.assertAllClose(np.asarray([[0.55, 0.8, 1.05]]), output)
+
+    def test_invalid_proj_dim(self):
+        with self.assertRaisesRegexp(ValueError, r"is not supported yet"):
+            x0 = np.random.random((12, 5))
+            x = np.random.random((12, 5))
+            layer = PolynomialCrossing(projection_dim=6)
+            layer([x0, x])
+
+    def test_invalid_inputs(self):
+        with self.assertRaisesRegexp(ValueError, r"must be a tuple or list of size 2"):
+            x0 = np.random.random((12, 5))
+            x = np.random.random((12, 5))
+            x1 = np.random.random((12, 5))
+            layer = PolynomialCrossing(projection_dim=6)
+            layer([x0, x, x1])
+
+    def test_serialization(self):
+        layer = PolynomialCrossing(projection_dim=None)
+        serialized_layer = tf.keras.layers.serialize(layer)
+        new_layer = tf.keras.layers.deserialize(serialized_layer)
+        self.assertEqual(layer.get_config(), new_layer.get_config())
+
+
+if __name__ == "__main__":
+    tf.test.main()
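For reviewers who want to exercise the layer end to end, here is a minimal usage sketch. It is not part of the diff above; the batch size, input dimension, and random data are made up for illustration. It mirrors the stacking pattern from the class docstring, feeding the original input `x0` together with the previous layer's output into each `PolynomialCrossing`:

```python
import numpy as np
import tensorflow as tf

from tensorflow_addons.layers import PolynomialCrossing

# Illustrative dimensions only -- not taken from the PR.
batch_size, input_dim = 8, 6

# Dense features standing in for the output of an embedding layer.
x0 = tf.keras.Input(shape=(input_dim,))

# Each crossing layer receives (x0, xi): the original input plus the
# previous layer's output, per the layer's `call` contract.
x1 = PolynomialCrossing(projection_dim=None)((x0, x0))
x2 = PolynomialCrossing(projection_dim=None)((x0, x1))
logits = tf.keras.layers.Dense(units=1)(x2)

model = tf.keras.Model(inputs=x0, outputs=logits)
model.compile(optimizer="adam", loss="mse")
model.fit(
    np.random.rand(batch_size, input_dim),
    np.random.rand(batch_size, 1),
    epochs=1,
    verbose=0,
)
```

As a hand check of the expected value in `test_full_matrix`: with a ones kernel, `tf.matmul(x, kernel)` gives `[1.5, 1.5, 1.5]`, so `x0 * [1.5, 1.5, 1.5] + x` is `[0.15 + 0.4, 0.3 + 0.5, 0.45 + 0.6] = [0.55, 0.8, 1.05]`, matching the assertion (the bias is zero-initialized).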