From 18e3edb2b1a020c941b9844073de399e917dc3ec Mon Sep 17 00:00:00 2001 From: tanzhenyu Date: Mon, 3 Feb 2020 00:52:13 -0800 Subject: [PATCH 1/8] Add PolynomialCrossing to Addons This is from internal design to open-source the DCN network from paper https://arxiv.org/abs/1708.05123 --- tensorflow_addons/layers/BUILD | 15 ++ tensorflow_addons/layers/README.md | 2 + tensorflow_addons/layers/__init__.py | 1 + tensorflow_addons/layers/polynomial.py | 150 ++++++++++++++++++++ tensorflow_addons/layers/polynomial_test.py | 48 +++++++ 5 files changed, 216 insertions(+) create mode 100644 tensorflow_addons/layers/polynomial.py create mode 100644 tensorflow_addons/layers/polynomial_test.py diff --git a/tensorflow_addons/layers/BUILD b/tensorflow_addons/layers/BUILD index 42991189ed..feed49a9a3 100644 --- a/tensorflow_addons/layers/BUILD +++ b/tensorflow_addons/layers/BUILD @@ -11,6 +11,7 @@ py_library( "normalizations.py", "optical_flow.py", "poincare.py", + "polynomial.py", "sparsemax.py", "tlu.py", "wrappers.py", @@ -24,6 +25,20 @@ py_library( ], ) +py_test( + name = "polynomial_test", + size = "small", + srcs = [ + "polynomial_test.py", + ], + main = "polynomial_test.py", + python_version = "PY3", + srcs_version = "PY2AND3", + deps = [ + ":layers", + ], +) + py_test( name = "gelu_test", size = "small", diff --git a/tensorflow_addons/layers/README.md b/tensorflow_addons/layers/README.md index 1c8babc943..19c109a22c 100644 --- a/tensorflow_addons/layers/README.md +++ b/tensorflow_addons/layers/README.md @@ -8,6 +8,7 @@ | normalizations | @smokrow | moritz.kroeger@tu-dortmund.de | | opticalflow | @fsx950223 | fsx950223@gmail.com | | poincare | @rahulunair | rahulunair@gmail.com | +| polynomial | @tanzheny | tanzheny@google.com | sparsemax | @AndreasMadsen | amwwebdk+github@gmail.com | | tlu | @AakashKumarNain | aakashnain@outlook.com | | wrappers | @seanpmorgan | seanmorgan@outlook.com | @@ -21,6 +22,7 @@ | normalizations | InstanceNormalization | 
https://arxiv.org/abs/1607.08022 | | opticalflow | CorrelationCost | https://arxiv.org/abs/1504.06852 | | poincare | PoincareNormalize | https://arxiv.org/abs/1705.08039 | +| polynomial | PolynomialCrossing | https://arxiv.org/pdf/1708.05123 | | sparsemax| Sparsemax | https://arxiv.org/abs/1602.02068 | | tlu | TLU | https://arxiv.org/abs/1911.09737 | | wrappers | WeightNormalization | https://arxiv.org/abs/1602.07868 | diff --git a/tensorflow_addons/layers/__init__.py b/tensorflow_addons/layers/__init__.py index 24e55cf156..f2dc97c7a4 100644 --- a/tensorflow_addons/layers/__init__.py +++ b/tensorflow_addons/layers/__init__.py @@ -20,6 +20,7 @@ from tensorflow_addons.layers.normalizations import InstanceNormalization from tensorflow_addons.layers.optical_flow import CorrelationCost from tensorflow_addons.layers.poincare import PoincareNormalize +from tensorflow_addons.layers.polynomial import PolynomialCrossing from tensorflow_addons.layers.sparsemax import Sparsemax from tensorflow_addons.layers.tlu import TLU from tensorflow_addons.layers.wrappers import WeightNormalization diff --git a/tensorflow_addons/layers/polynomial.py b/tensorflow_addons/layers/polynomial.py new file mode 100644 index 0000000000..de9bae0f4c --- /dev/null +++ b/tensorflow_addons/layers/polynomial.py @@ -0,0 +1,150 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Implements Polynomial Crossing Layer.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow.compat.v2 as tf +from tensorflow_addons.utils import keras_utils + + +@keras_utils.register_keras_custom_object +class PolynomialCrossing(tf.keras.layers.Layer): + """Layer for Deep & Cross Network to learn explicit feature interactions. + + A layer that applies feature crossing in learning certain explicit + bounded-degree feature interactions more efficiently. The `call` method + accepts `inputs` as a tuple of size 2 tensors. The first input `x0` should be + the input to the first `PolynomialCrossing` layer in the stack, or the input + to the network (usually after the embedding layer), the second input `xi` + is the output of the previous `PolynomialCrossing` layer in the stack, i.e., + the i-th `PolynomialCrossing` layer. + + The output is y = x0 * (W .* x) + bias + xi, where .* designates dot product. + + References + See [R. Wang](https://arxiv.org/pdf/1708.05123.pdf) + + Example: + + ```python + # after embedding layer in a functional model: + input = tf.keras.Input(shape=(None,), name='index', dtype=tf.int64) + x0 = tf.keras.layers.Embedding(input_dim=32, output_dim=6)) + x1 = PolynomialCrossing(projection_dim=None)((x0, x0)) + x2 = PolynomialCrossing(projection_dim=None)((x0, x1)) + logits = tf.keras.layers.Dense(units=10)(x2) + model = tf.keras.Model(input, logits) + ``` + + Attributes: + projection_dim: project dimension. Default is `None` such that a full + (`input_dim` by `input_dim`) matrix is used. + use_bias: whether to calculate the bias/intercept for this layer. If set to + False, no bias/intercept will be used in calculations, e.g., the data is + already centered. + kernel_initializer: Initializer instance to use on the kernel matrix. 
+ bias_initializer: Initializer instance to use on the bias vector. + kernel_regularizer: Regularizer instance to use on the kernel matrix. + bias_regularizer: Regularizer instance to use on bias vector. + + Input shape: A tuple of 2 (batch_size, `input_dim`) dimensional inputs. + Output shape: A single (batch_size, `input_dim`) dimensional output. + """ + + def __init__(self, + projection_dim=None, + use_bias=True, + kernel_initializer='truncated_normal', + bias_initializer='zeros', + kernel_regularizer=None, + bias_regularizer=None, + **kwargs): + super(PolynomialCrossing, self).__init__(**kwargs) + + self.projection_dim = projection_dim + self.use_bias = use_bias + self.kernel_initializer = tf.keras.initializers.get(kernel_initializer) + self.bias_initializer = tf.keras.initializers.get(bias_initializer) + self.kernel_regularizer = tf.keras.regularizers.get(kernel_regularizer) + self.bias_regularizer = tf.keras.regularizers.get(bias_regularizer) + + self.supports_masking = True + + def build(self, input_shape): + if not isinstance(input_shape, (tuple, list)) or len(input_shape) != 2: + raise ValueError('Input shapes must be a tuple or list of size 2, ' + 'got {}'.format(input_shape)) + last_dim = input_shape[-1][-1] + if self.projection_dim is None: + kernel_shape = [last_dim, last_dim] + else: + if self.projection_dim != last_dim: + raise ValueError('The case where `projection_dim` != last ' + 'dimension of the inputs is not supported yet, got ' + '`projection_dim` {}, and last dimension of input ' + '{}'.format(self.projection_dim, last_dim)) + kernel_shape = [last_dim, self.projection_dim] + self.kernel = self.add_weight( + 'kernel', + shape=kernel_shape, + initializer=self.kernel_initializer, + regularizer=self.kernel_regularizer, + dtype=self.dtype, + trainable=True) + if self.use_bias: + self.bias = self.add_weight( + 'bias', + shape=[last_dim], + initializer=self.bias_initializer, + regularizer=self.bias_regularizer, + dtype=self.dtype, + trainable=True) 
+ + def call(self, inputs): + if not isinstance(inputs, (tuple, list)) or len(inputs) != 2: + raise ValueError('Inputs to the layer must be a tuple or list of size 2, ' + 'got {}'.format(inputs)) + x0, x = inputs + outputs = x0 * tf.matmul(x, self.kernel) + x + if self.use_bias: + outputs = tf.add(outputs, self.bias) + return outputs + + def get_config(self): + config = { + 'projection_dim': + self.projection_dim, + 'use_bias': + self.use_bias, + 'kernel_initializer': + tf.keras.initializers.serialize(self.kernel_initializer), + 'bias_initializer': + tf.keras.initializers.serialize(self.bias_initializer), + 'kernel_regularizer': + tf.keras.regularizers.serialize(self.kernel_regularizer), + 'bias_regularizer': + tf.keras.regularizers.serialize(self.bias_regularizer) + } + base_config = super(PolynomialCrossing, self).get_config() + return dict(list(base_config.items()) + list(config.items())) + + def compute_output_shape(self, input_shape): + if not isinstance(input_shape, (tuple, list)): + raise ValueError('A `PolynomialCrossing` layer should be called ' + 'on a list of inputs.') + return input_shape[0] \ No newline at end of file diff --git a/tensorflow_addons/layers/polynomial_test.py b/tensorflow_addons/layers/polynomial_test.py new file mode 100644 index 0000000000..4dda05f02c --- /dev/null +++ b/tensorflow_addons/layers/polynomial_test.py @@ -0,0 +1,48 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for PolynomialCrossing layer.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +import tensorflow.compat.v2 as tf + +from tensorflow_addons.layers.polynomial import PolynomialCrossing +from tensorflow_addons.utils import test_utils + + +@test_utils.run_all_in_graph_and_eager_modes +class PolynomialCrossingTest(tf.test.TestCase): + # Do not use layer_test due to multiple inputs. + + def test_full_matrix(self): + x0 = np.random.random((12, 5)) + x = np.random.random((12, 5)) + layer = PolynomialCrossing(projection_dim=None) + layer([x0, x]) + + def test_invalid_proj_dim(self): + with self.assertRaisesRegexp(ValueError, r'is not supported yet'): + x0 = np.random.random((12, 5)) + x = np.random.random((12, 5)) + layer = PolynomialCrossing(projection_dim=6) + layer([x0, x]) + + +if __name__ == '__main__': + tf.enable_v2_behavior() + tf.test.main() \ No newline at end of file From 5dbe61c2ea2cd5177b29d4e903d559eaadc8cd2d Mon Sep 17 00:00:00 2001 From: tanzhenyu Date: Mon, 3 Feb 2020 01:12:41 -0800 Subject: [PATCH 2/8] Add PolynomialCrossing to Addons 2 Update commit per Addons format. 
--- tensorflow_addons/layers/polynomial.py | 245 ++++++++++---------- tensorflow_addons/layers/polynomial_test.py | 31 ++- 2 files changed, 138 insertions(+), 138 deletions(-) diff --git a/tensorflow_addons/layers/polynomial.py b/tensorflow_addons/layers/polynomial.py index de9bae0f4c..c18a8ef9bb 100644 --- a/tensorflow_addons/layers/polynomial.py +++ b/tensorflow_addons/layers/polynomial.py @@ -14,137 +14,142 @@ # ============================================================================== """Implements Polynomial Crossing Layer.""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function +import tensorflow as tf +from typeguard import typechecked -import tensorflow.compat.v2 as tf -from tensorflow_addons.utils import keras_utils +from tensorflow_addons.utils import types -@keras_utils.register_keras_custom_object +@tf.keras.utils.register_keras_serializable(package='Addons') class PolynomialCrossing(tf.keras.layers.Layer): - """Layer for Deep & Cross Network to learn explicit feature interactions. - - A layer that applies feature crossing in learning certain explicit - bounded-degree feature interactions more efficiently. The `call` method - accepts `inputs` as a tuple of size 2 tensors. The first input `x0` should be - the input to the first `PolynomialCrossing` layer in the stack, or the input - to the network (usually after the embedding layer), the second input `xi` - is the output of the previous `PolynomialCrossing` layer in the stack, i.e., - the i-th `PolynomialCrossing` layer. - - The output is y = x0 * (W .* x) + bias + xi, where .* designates dot product. - - References - See [R. 
Wang](https://arxiv.org/pdf/1708.05123.pdf) - - Example: - - ```python - # after embedding layer in a functional model: - input = tf.keras.Input(shape=(None,), name='index', dtype=tf.int64) - x0 = tf.keras.layers.Embedding(input_dim=32, output_dim=6)) - x1 = PolynomialCrossing(projection_dim=None)((x0, x0)) - x2 = PolynomialCrossing(projection_dim=None)((x0, x1)) - logits = tf.keras.layers.Dense(units=10)(x2) - model = tf.keras.Model(input, logits) - ``` - - Attributes: - projection_dim: project dimension. Default is `None` such that a full - (`input_dim` by `input_dim`) matrix is used. - use_bias: whether to calculate the bias/intercept for this layer. If set to - False, no bias/intercept will be used in calculations, e.g., the data is - already centered. - kernel_initializer: Initializer instance to use on the kernel matrix. - bias_initializer: Initializer instance to use on the bias vector. - kernel_regularizer: Regularizer instance to use on the kernel matrix. - bias_regularizer: Regularizer instance to use on bias vector. - - Input shape: A tuple of 2 (batch_size, `input_dim`) dimensional inputs. - Output shape: A single (batch_size, `input_dim`) dimensional output. - """ - - def __init__(self, - projection_dim=None, - use_bias=True, - kernel_initializer='truncated_normal', - bias_initializer='zeros', - kernel_regularizer=None, - bias_regularizer=None, - **kwargs): - super(PolynomialCrossing, self).__init__(**kwargs) - - self.projection_dim = projection_dim - self.use_bias = use_bias - self.kernel_initializer = tf.keras.initializers.get(kernel_initializer) - self.bias_initializer = tf.keras.initializers.get(bias_initializer) - self.kernel_regularizer = tf.keras.regularizers.get(kernel_regularizer) - self.bias_regularizer = tf.keras.regularizers.get(bias_regularizer) - - self.supports_masking = True + """Layer for Deep & Cross Network to learn explicit feature interactions. 
+ + A layer that applies feature crossing in learning certain explicit + bounded-degree feature interactions more efficiently. The `call` method + accepts `inputs` as a tuple of size 2 tensors. The first input `x0` should be + the input to the first `PolynomialCrossing` layer in the stack, or the input + to the network (usually after the embedding layer), the second input `xi` + is the output of the previous `PolynomialCrossing` layer in the stack, i.e., + the i-th `PolynomialCrossing` layer. + + The output is y = x0 * (W .* x) + bias + xi, where .* designates dot product. + + References + See [R. Wang](https://arxiv.org/pdf/1708.05123.pdf) + + Example: + + ```python + # after embedding layer in a functional model: + input = tf.keras.Input(shape=(None,), name='index', dtype=tf.int64) + x0 = tf.keras.layers.Embedding(input_dim=32, output_dim=6)) + x1 = PolynomialCrossing(projection_dim=None)((x0, x0)) + x2 = PolynomialCrossing(projection_dim=None)((x0, x1)) + logits = tf.keras.layers.Dense(units=10)(x2) + model = tf.keras.Model(input, logits) + ``` + + Arguments: + projection_dim: project dimension. Default is `None` such that a full + (`input_dim` by `input_dim`) matrix is used. + use_bias: whether to calculate the bias/intercept for this layer. If set to + False, no bias/intercept will be used in calculations, e.g., the data is + already centered. + kernel_initializer: Initializer instance to use on the kernel matrix. + bias_initializer: Initializer instance to use on the bias vector. + kernel_regularizer: Regularizer instance to use on the kernel matrix. + bias_regularizer: Regularizer instance to use on bias vector. + + Input shape: + A tuple of 2 (batch_size, `input_dim`) dimensional inputs. + + Output shape: + A single (batch_size, `input_dim`) dimensional output. 
+ """ + + @typechecked + def __init__( + self, + projection_dim :int = None, + use_bias :bool = True, + kernel_initializer: types.Initializer = 'truncated_normal', + bias_initializer : types.Initializer = 'zeros', + kernel_regularizer : types.Regularizer = None, + bias_regularizer: types.Regularizer = None, + **kwargs + ): + super(PolynomialCrossing, self).__init__(**kwargs) + + self.projection_dim = projection_dim + self.use_bias = use_bias + self.kernel_initializer = tf.keras.initializers.get(kernel_initializer) + self.bias_initializer = tf.keras.initializers.get(bias_initializer) + self.kernel_regularizer = tf.keras.regularizers.get(kernel_regularizer) + self.bias_regularizer = tf.keras.regularizers.get(bias_regularizer) + + self.supports_masking = True def build(self, input_shape): - if not isinstance(input_shape, (tuple, list)) or len(input_shape) != 2: - raise ValueError('Input shapes must be a tuple or list of size 2, ' - 'got {}'.format(input_shape)) - last_dim = input_shape[-1][-1] - if self.projection_dim is None: - kernel_shape = [last_dim, last_dim] - else: - if self.projection_dim != last_dim: - raise ValueError('The case where `projection_dim` != last ' - 'dimension of the inputs is not supported yet, got ' - '`projection_dim` {}, and last dimension of input ' - '{}'.format(self.projection_dim, last_dim)) - kernel_shape = [last_dim, self.projection_dim] - self.kernel = self.add_weight( - 'kernel', - shape=kernel_shape, - initializer=self.kernel_initializer, - regularizer=self.kernel_regularizer, - dtype=self.dtype, - trainable=True) - if self.use_bias: - self.bias = self.add_weight( - 'bias', - shape=[last_dim], - initializer=self.bias_initializer, - regularizer=self.bias_regularizer, + if not isinstance(input_shape, (tuple, list)) or len(input_shape) != 2: + raise ValueError('Input shapes must be a tuple or list of size 2, ' + 'got {}'.format(input_shape)) + last_dim = input_shape[-1][-1] + if self.projection_dim is None: + kernel_shape = [last_dim, 
last_dim] + else: + if self.projection_dim != last_dim: + raise ValueError('The case where `projection_dim` != last ' + 'dimension of the inputs is not supported yet, got ' + '`projection_dim` {}, and last dimension of input ' + '{}'.format(self.projection_dim, last_dim)) + kernel_shape = [last_dim, self.projection_dim] + self.kernel = self.add_weight( + 'kernel', + shape=kernel_shape, + initializer=self.kernel_initializer, + regularizer=self.kernel_regularizer, dtype=self.dtype, trainable=True) + if self.use_bias: + self.bias = self.add_weight( + 'bias', + shape=[last_dim], + initializer=self.bias_initializer, + regularizer=self.bias_regularizer, + dtype=self.dtype, + trainable=True) + self.built = True def call(self, inputs): - if not isinstance(inputs, (tuple, list)) or len(inputs) != 2: - raise ValueError('Inputs to the layer must be a tuple or list of size 2, ' - 'got {}'.format(inputs)) - x0, x = inputs - outputs = x0 * tf.matmul(x, self.kernel) + x - if self.use_bias: - outputs = tf.add(outputs, self.bias) - return outputs + if not isinstance(inputs, (tuple, list)) or len(inputs) != 2: + raise ValueError('Inputs to the layer must be a tuple or list of size 2, ' + 'got {}'.format(inputs)) + x0, x = inputs + outputs = x0 * tf.matmul(x, self.kernel) + x + if self.use_bias: + outputs = tf.add(outputs, self.bias) + return outputs def get_config(self): - config = { - 'projection_dim': - self.projection_dim, - 'use_bias': - self.use_bias, - 'kernel_initializer': - tf.keras.initializers.serialize(self.kernel_initializer), - 'bias_initializer': - tf.keras.initializers.serialize(self.bias_initializer), - 'kernel_regularizer': - tf.keras.regularizers.serialize(self.kernel_regularizer), - 'bias_regularizer': - tf.keras.regularizers.serialize(self.bias_regularizer) - } - base_config = super(PolynomialCrossing, self).get_config() - return dict(list(base_config.items()) + list(config.items())) + config = { + 'projection_dim': + self.projection_dim, + 'use_bias': + 
self.use_bias, + 'kernel_initializer': + tf.keras.initializers.serialize(self.kernel_initializer), + 'bias_initializer': + tf.keras.initializers.serialize(self.bias_initializer), + 'kernel_regularizer': + tf.keras.regularizers.serialize(self.kernel_regularizer), + 'bias_regularizer': + tf.keras.regularizers.serialize(self.bias_regularizer) + } + base_config = super(PolynomialCrossing, self).get_config() + return dict(list(base_config.items()) + list(config.items())) def compute_output_shape(self, input_shape): - if not isinstance(input_shape, (tuple, list)): - raise ValueError('A `PolynomialCrossing` layer should be called ' - 'on a list of inputs.') - return input_shape[0] \ No newline at end of file + if not isinstance(input_shape, (tuple, list)): + raise ValueError('A `PolynomialCrossing` layer should be called ' + 'on a list of inputs.') + return input_shape[0] \ No newline at end of file diff --git a/tensorflow_addons/layers/polynomial_test.py b/tensorflow_addons/layers/polynomial_test.py index 4dda05f02c..96be731c8c 100644 --- a/tensorflow_addons/layers/polynomial_test.py +++ b/tensorflow_addons/layers/polynomial_test.py @@ -14,12 +14,8 @@ # ============================================================================== """Tests for PolynomialCrossing layer.""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - import numpy as np -import tensorflow.compat.v2 as tf +import tensorflow as tf from tensorflow_addons.layers.polynomial import PolynomialCrossing from tensorflow_addons.utils import test_utils @@ -27,22 +23,21 @@ @test_utils.run_all_in_graph_and_eager_modes class PolynomialCrossingTest(tf.test.TestCase): - # Do not use layer_test due to multiple inputs. + # Do not use layer_test due to multiple inputs. 
+ def test_full_matrix(self): + x0 = np.random.random((12, 5)) + x = np.random.random((12, 5)) + layer = PolynomialCrossing(projection_dim=None) + layer([x0, x]) - def test_invalid_proj_dim(self): - with self.assertRaisesRegexp(ValueError, r'is not supported yet'): - x0 = np.random.random((12, 5)) - x = np.random.random((12, 5)) - layer = PolynomialCrossing(projection_dim=6) - layer([x0, x]) + def test_invalid_proj_dim(self): + with self.assertRaisesRegexp(ValueError, r'is not supported yet'): + x0 = np.random.random((12, 5)) + x = np.random.random((12, 5)) + layer = PolynomialCrossing(projection_dim=6) + layer([x0, x]) if __name__ == '__main__': - tf.enable_v2_behavior() tf.test.main() \ No newline at end of file From 806cadf72c551e9044a800924860630ac2806ad3 Mon Sep 17 00:00:00 2001 From: tanzhenyu Date: Mon, 3 Feb 2020 07:02:58 -0800 Subject: [PATCH 3/8] more indentation fix More indentation fix. --- tensorflow_addons/layers/polynomial.py | 221 ++++++++++---------- tensorflow_addons/layers/polynomial_test.py | 6 +- 2 files changed, 117 insertions(+), 110 deletions(-) diff --git a/tensorflow_addons/layers/polynomial.py b/tensorflow_addons/layers/polynomial.py index c18a8ef9bb..f42c638f76 100644 --- a/tensorflow_addons/layers/polynomial.py +++ b/tensorflow_addons/layers/polynomial.py @@ -20,7 +20,7 @@ from tensorflow_addons.utils import types -@tf.keras.utils.register_keras_serializable(package='Addons') +@tf.keras.utils.register_keras_serializable(package="Addons") class PolynomialCrossing(tf.keras.layers.Layer): """Layer for Deep & Cross Network to learn explicit feature interactions. @@ -35,121 +35,128 @@ class PolynomialCrossing(tf.keras.layers.Layer): The output is y = x0 * (W .* x) + bias + xi, where .* designates dot product. References - See [R. 
Wang](https://arxiv.org/pdf/1708.05123.pdf) + See [R. Wang](https://arxiv.org/pdf/1708.05123.pdf) Example: - ```python - # after embedding layer in a functional model: - input = tf.keras.Input(shape=(None,), name='index', dtype=tf.int64) - x0 = tf.keras.layers.Embedding(input_dim=32, output_dim=6)) - x1 = PolynomialCrossing(projection_dim=None)((x0, x0)) - x2 = PolynomialCrossing(projection_dim=None)((x0, x1)) - logits = tf.keras.layers.Dense(units=10)(x2) - model = tf.keras.Model(input, logits) - ``` + ```python + # after embedding layer in a functional model: + input = tf.keras.Input(shape=(None,), name='index', dtype=tf.int64) + x0 = tf.keras.layers.Embedding(input_dim=32, output_dim=6)) + x1 = PolynomialCrossing(projection_dim=None)((x0, x0)) + x2 = PolynomialCrossing(projection_dim=None)((x0, x1)) + logits = tf.keras.layers.Dense(units=10)(x2) + model = tf.keras.Model(input, logits) + ``` Arguments: - projection_dim: project dimension. Default is `None` such that a full - (`input_dim` by `input_dim`) matrix is used. - use_bias: whether to calculate the bias/intercept for this layer. If set to - False, no bias/intercept will be used in calculations, e.g., the data is - already centered. - kernel_initializer: Initializer instance to use on the kernel matrix. - bias_initializer: Initializer instance to use on the bias vector. - kernel_regularizer: Regularizer instance to use on the kernel matrix. - bias_regularizer: Regularizer instance to use on bias vector. - - Input shape: + projection_dim: project dimension. Default is `None` such that a full + (`input_dim` by `input_dim`) matrix is used. + use_bias: whether to calculate the bias/intercept for this layer. If set to + False, no bias/intercept will be used in calculations, e.g., the data is + already centered. + kernel_initializer: Initializer instance to use on the kernel matrix. + bias_initializer: Initializer instance to use on the bias vector. 
+ kernel_regularizer: Regularizer instance to use on the kernel matrix. + bias_regularizer: Regularizer instance to use on bias vector. + + Input shape: A tuple of 2 (batch_size, `input_dim`) dimensional inputs. Output shape: A single (batch_size, `input_dim`) dimensional output. """ - @typechecked - def __init__( - self, - projection_dim :int = None, - use_bias :bool = True, - kernel_initializer: types.Initializer = 'truncated_normal', - bias_initializer : types.Initializer = 'zeros', - kernel_regularizer : types.Regularizer = None, - bias_regularizer: types.Regularizer = None, - **kwargs - ): - super(PolynomialCrossing, self).__init__(**kwargs) - - self.projection_dim = projection_dim - self.use_bias = use_bias - self.kernel_initializer = tf.keras.initializers.get(kernel_initializer) - self.bias_initializer = tf.keras.initializers.get(bias_initializer) - self.kernel_regularizer = tf.keras.regularizers.get(kernel_regularizer) - self.bias_regularizer = tf.keras.regularizers.get(bias_regularizer) - - self.supports_masking = True - - def build(self, input_shape): - if not isinstance(input_shape, (tuple, list)) or len(input_shape) != 2: - raise ValueError('Input shapes must be a tuple or list of size 2, ' - 'got {}'.format(input_shape)) - last_dim = input_shape[-1][-1] - if self.projection_dim is None: - kernel_shape = [last_dim, last_dim] - else: - if self.projection_dim != last_dim: - raise ValueError('The case where `projection_dim` != last ' - 'dimension of the inputs is not supported yet, got ' - '`projection_dim` {}, and last dimension of input ' - '{}'.format(self.projection_dim, last_dim)) - kernel_shape = [last_dim, self.projection_dim] - self.kernel = self.add_weight( - 'kernel', - shape=kernel_shape, - initializer=self.kernel_initializer, - regularizer=self.kernel_regularizer, - dtype=self.dtype, - trainable=True) - if self.use_bias: - self.bias = self.add_weight( - 'bias', - shape=[last_dim], - initializer=self.bias_initializer, - 
regularizer=self.bias_regularizer, + @typechecked + def __init__( + self, + projection_dim: int = None, + use_bias: bool = True, + kernel_initializer: types.Initializer = "truncated_normal", + bias_initializer: types.Initializer = "zeros", + kernel_regularizer: types.Regularizer = None, + bias_regularizer: types.Regularizer = None, + **kwargs, + ): + super(PolynomialCrossing, self).__init__(**kwargs) + + self.projection_dim = projection_dim + self.use_bias = use_bias + self.kernel_initializer = tf.keras.initializers.get(kernel_initializer) + self.bias_initializer = tf.keras.initializers.get(bias_initializer) + self.kernel_regularizer = tf.keras.regularizers.get(kernel_regularizer) + self.bias_regularizer = tf.keras.regularizers.get(bias_regularizer) + + self.supports_masking = True + + def build(self, input_shape): + if not isinstance(input_shape, (tuple, list)) or len(input_shape) != 2: + raise ValueError( + "Input shapes must be a tuple or list of size 2, " + "got {}".format(input_shape) + ) + last_dim = input_shape[-1][-1] + if self.projection_dim is None: + kernel_shape = [last_dim, last_dim] + else: + if self.projection_dim != last_dim: + raise ValueError( + "The case where `projection_dim` != last " + "dimension of the inputs is not supported yet, got " + "`projection_dim` {}, and last dimension of input " + "{}".format(self.projection_dim, last_dim) + ) + kernel_shape = [last_dim, self.projection_dim] + self.kernel = self.add_weight( + "kernel", + shape=kernel_shape, + initializer=self.kernel_initializer, + regularizer=self.kernel_regularizer, dtype=self.dtype, - trainable=True) - self.built = True - - def call(self, inputs): - if not isinstance(inputs, (tuple, list)) or len(inputs) != 2: - raise ValueError('Inputs to the layer must be a tuple or list of size 2, ' - 'got {}'.format(inputs)) - x0, x = inputs - outputs = x0 * tf.matmul(x, self.kernel) + x - if self.use_bias: - outputs = tf.add(outputs, self.bias) - return outputs - - def get_config(self): - 
config = { - 'projection_dim': - self.projection_dim, - 'use_bias': - self.use_bias, - 'kernel_initializer': - tf.keras.initializers.serialize(self.kernel_initializer), - 'bias_initializer': - tf.keras.initializers.serialize(self.bias_initializer), - 'kernel_regularizer': - tf.keras.regularizers.serialize(self.kernel_regularizer), - 'bias_regularizer': - tf.keras.regularizers.serialize(self.bias_regularizer) - } - base_config = super(PolynomialCrossing, self).get_config() - return dict(list(base_config.items()) + list(config.items())) - - def compute_output_shape(self, input_shape): - if not isinstance(input_shape, (tuple, list)): - raise ValueError('A `PolynomialCrossing` layer should be called ' - 'on a list of inputs.') - return input_shape[0] \ No newline at end of file + trainable=True, + ) + if self.use_bias: + self.bias = self.add_weight( + "bias", + shape=[last_dim], + initializer=self.bias_initializer, + regularizer=self.bias_regularizer, + dtype=self.dtype, + trainable=True, + ) + self.built = True + + def call(self, inputs): + if not isinstance(inputs, (tuple, list)) or len(inputs) != 2: + raise ValueError( + "Inputs to the layer must be a tuple or list of size 2, " + "got {}".format(inputs) + ) + x0, x = inputs + outputs = x0 * tf.matmul(x, self.kernel) + x + if self.use_bias: + outputs = tf.add(outputs, self.bias) + return outputs + + def get_config(self): + config = { + "projection_dim": self.projection_dim, + "use_bias": self.use_bias, + "kernel_initializer": tf.keras.initializers.serialize( + self.kernel_initializer + ), + "bias_initializer": tf.keras.initializers.serialize(self.bias_initializer), + "kernel_regularizer": tf.keras.regularizers.serialize( + self.kernel_regularizer + ), + "bias_regularizer": tf.keras.regularizers.serialize(self.bias_regularizer), + } + base_config = super(PolynomialCrossing, self).get_config() + return dict(list(base_config.items()) + list(config.items())) + + def compute_output_shape(self, input_shape): + if not 
isinstance(input_shape, (tuple, list)): + raise ValueError( + "A `PolynomialCrossing` layer should be called " "on a list of inputs." + ) + return input_shape[0] diff --git a/tensorflow_addons/layers/polynomial_test.py b/tensorflow_addons/layers/polynomial_test.py index 96be731c8c..677f903a94 100644 --- a/tensorflow_addons/layers/polynomial_test.py +++ b/tensorflow_addons/layers/polynomial_test.py @@ -32,12 +32,12 @@ def test_full_matrix(self): layer([x0, x]) def test_invalid_proj_dim(self): - with self.assertRaisesRegexp(ValueError, r'is not supported yet'): + with self.assertRaisesRegexp(ValueError, r"is not supported yet"): x0 = np.random.random((12, 5)) x = np.random.random((12, 5)) layer = PolynomialCrossing(projection_dim=6) layer([x0, x]) -if __name__ == '__main__': - tf.test.main() \ No newline at end of file +if __name__ == "__main__": + tf.test.main() From a0872ac884ea2bcc69af725991befc27d8136fad Mon Sep 17 00:00:00 2001 From: tanzhenyu Date: Mon, 3 Feb 2020 21:19:58 -0800 Subject: [PATCH 4/8] Update polynomial_test.py Added a few more tests. --- tensorflow_addons/layers/polynomial_test.py | 21 ++++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/tensorflow_addons/layers/polynomial_test.py b/tensorflow_addons/layers/polynomial_test.py index 677f903a94..9af09274f7 100644 --- a/tensorflow_addons/layers/polynomial_test.py +++ b/tensorflow_addons/layers/polynomial_test.py @@ -26,10 +26,11 @@ class PolynomialCrossingTest(tf.test.TestCase): # Do not use layer_test due to multiple inputs. 
def test_full_matrix(self): - x0 = np.random.random((12, 5)) - x = np.random.random((12, 5)) + x0 = np.asarray([[.1, .2, .3]]).astype(np.float32) + x = np.asarray([[.4, .5, .6]]).astype(np.float32) layer = PolynomialCrossing(projection_dim=None) - layer([x0, x]) + output = layer([x0, x]) + self.assertAllClose(np.asarray([[.55, .8, 1.05]]), output) def test_invalid_proj_dim(self): with self.assertRaisesRegexp(ValueError, r"is not supported yet"): @@ -38,6 +39,20 @@ def test_invalid_proj_dim(self): layer = PolynomialCrossing(projection_dim=6) layer([x0, x]) + def test_invalid_inputs(self): + with self.assertRaisesRegexp(ValueError, r"must be a tuple or list of size 2"): + x0 = np.random.random((12, 5)) + x = np.random.random((12, 5)) + x1 = np.random.random((12, 5)) + layer = PolynomialCrossing(projection_dim=6) + layer([x0, x, x1]) + + def test_serialization(self): + layer = PolynomialCrossing(projection_dim=None) + serialized_layer = tf.keras.layers.serialize(layer) + new_layer = tf.keras.layers.deserialize(serialized_layer) + self.assertEqual(layer.get_config(), new_layer.get_config()) + if __name__ == "__main__": tf.test.main() From 276956f7865a72755497e74f269c7e1730f644b9 Mon Sep 17 00:00:00 2001 From: tanzhenyu Date: Mon, 3 Feb 2020 21:22:05 -0800 Subject: [PATCH 5/8] Update polynomial_test.py black the test --- tensorflow_addons/layers/polynomial_test.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tensorflow_addons/layers/polynomial_test.py b/tensorflow_addons/layers/polynomial_test.py index 9af09274f7..8ccd699c63 100644 --- a/tensorflow_addons/layers/polynomial_test.py +++ b/tensorflow_addons/layers/polynomial_test.py @@ -26,11 +26,11 @@ class PolynomialCrossingTest(tf.test.TestCase): # Do not use layer_test due to multiple inputs. 
def test_full_matrix(self): - x0 = np.asarray([[.1, .2, .3]]).astype(np.float32) - x = np.asarray([[.4, .5, .6]]).astype(np.float32) + x0 = np.asarray([[0.1, 0.2, 0.3]]).astype(np.float32) + x = np.asarray([[0.4, 0.5, 0.6]]).astype(np.float32) layer = PolynomialCrossing(projection_dim=None) output = layer([x0, x]) - self.assertAllClose(np.asarray([[.55, .8, 1.05]]), output) + self.assertAllClose(np.asarray([[0.55, 0.8, 1.05]]), output) def test_invalid_proj_dim(self): with self.assertRaisesRegexp(ValueError, r"is not supported yet"): From f9bc09fd80950cb06b8a63fc394955dd1137a799 Mon Sep 17 00:00:00 2001 From: tanzhenyu Date: Mon, 3 Feb 2020 22:08:38 -0800 Subject: [PATCH 6/8] Update polynomial_test.py adding var init. --- tensorflow_addons/layers/polynomial_test.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tensorflow_addons/layers/polynomial_test.py b/tensorflow_addons/layers/polynomial_test.py index 8ccd699c63..aff6209202 100644 --- a/tensorflow_addons/layers/polynomial_test.py +++ b/tensorflow_addons/layers/polynomial_test.py @@ -30,6 +30,7 @@ def test_full_matrix(self): x = np.asarray([[0.4, 0.5, 0.6]]).astype(np.float32) layer = PolynomialCrossing(projection_dim=None) output = layer([x0, x]) + self.evaluate(tf.compat.v1.global_variables_initializer()) self.assertAllClose(np.asarray([[0.55, 0.8, 1.05]]), output) def test_invalid_proj_dim(self): From 9d0982bc2f6af2c13f983130cabda49272820e6e Mon Sep 17 00:00:00 2001 From: tanzhenyu Date: Tue, 4 Feb 2020 01:06:18 -0800 Subject: [PATCH 7/8] Update polynomial_test.py fix kernel initializer --- tensorflow_addons/layers/polynomial_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow_addons/layers/polynomial_test.py b/tensorflow_addons/layers/polynomial_test.py index aff6209202..cc98874385 100644 --- a/tensorflow_addons/layers/polynomial_test.py +++ b/tensorflow_addons/layers/polynomial_test.py @@ -28,7 +28,7 @@ class PolynomialCrossingTest(tf.test.TestCase): def 
test_full_matrix(self): x0 = np.asarray([[0.1, 0.2, 0.3]]).astype(np.float32) x = np.asarray([[0.4, 0.5, 0.6]]).astype(np.float32) - layer = PolynomialCrossing(projection_dim=None) + layer = PolynomialCrossing(projection_dim=None, kernel_initializer='ones') output = layer([x0, x]) self.evaluate(tf.compat.v1.global_variables_initializer()) self.assertAllClose(np.asarray([[0.55, 0.8, 1.05]]), output) From daef7c8f87b3a1b3925f775b3d99636aaf8d659e Mon Sep 17 00:00:00 2001 From: tanzhenyu Date: Tue, 4 Feb 2020 01:08:48 -0800 Subject: [PATCH 8/8] Update polynomial_test.py update. --- tensorflow_addons/layers/polynomial_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow_addons/layers/polynomial_test.py b/tensorflow_addons/layers/polynomial_test.py index cc98874385..392eff4a85 100644 --- a/tensorflow_addons/layers/polynomial_test.py +++ b/tensorflow_addons/layers/polynomial_test.py @@ -28,7 +28,7 @@ class PolynomialCrossingTest(tf.test.TestCase): def test_full_matrix(self): x0 = np.asarray([[0.1, 0.2, 0.3]]).astype(np.float32) x = np.asarray([[0.4, 0.5, 0.6]]).astype(np.float32) - layer = PolynomialCrossing(projection_dim=None, kernel_initializer='ones') + layer = PolynomialCrossing(projection_dim=None, kernel_initializer="ones") output = layer([x0, x]) self.evaluate(tf.compat.v1.global_variables_initializer()) self.assertAllClose(np.asarray([[0.55, 0.8, 1.05]]), output)