From d74d701584e7f0e85dc7f450619e1a9d06713a78 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Yan=20Facai=20=28=E9=A2=9C=E5=8F=91=E6=89=8D=29?= Date: Thu, 17 Jan 2019 13:51:21 +0800 Subject: [PATCH 1/3] ENH: implement PoincareNormalize --- tensorflow_addons/layers/BUILD | 16 +++- tensorflow_addons/layers/python/poincare.py | 72 ++++++++++++++ .../layers/python/poincare_test.py | 96 +++++++++++++++++++ 3 files changed, 183 insertions(+), 1 deletion(-) create mode 100644 tensorflow_addons/layers/python/poincare.py create mode 100644 tensorflow_addons/layers/python/poincare_test.py diff --git a/tensorflow_addons/layers/BUILD b/tensorflow_addons/layers/BUILD index 1d5c07d687..0de36d33e0 100644 --- a/tensorflow_addons/layers/BUILD +++ b/tensorflow_addons/layers/BUILD @@ -7,6 +7,7 @@ py_library( srcs = ([ "__init__.py", "python/__init__.py", + "python/poincare.py", "python/wrappers.py", ]), srcs_version = "PY2AND3", @@ -22,4 +23,17 @@ py_test( ":layers_py", ], srcs_version = "PY2AND3", -) \ No newline at end of file +) + +py_test( + name = "poincare_py_test", + size = "small", + srcs = [ + "python/poincare_test.py", + ], + main = "python/poincare_test.py", + deps = [ + ":layers_py", + ], + srcs_version = "PY2AND3", +) diff --git a/tensorflow_addons/layers/python/poincare.py b/tensorflow_addons/layers/python/poincare.py new file mode 100644 index 0000000000..e9b846b838 --- /dev/null +++ b/tensorflow_addons/layers/python/poincare.py @@ -0,0 +1,72 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Implementing PoincareNormalize layer.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import ops +from tensorflow.python.keras.engine.base_layer import Layer +from tensorflow.python.ops import math_ops + + +class PoincareNormalize(Layer): + """Project into the Poincare ball with norm <= 1.0 - epsilon. + + https://en.wikipedia.org/wiki/Poincare_ball_model + + Used in + Poincare Embeddings for Learning Hierarchical Representations + Maximilian Nickel, Douwe Kiela + https://arxiv.org/pdf/1705.08039.pdf + + For a 1-D tensor with `axis = 0`, computes + + (x * (1 - epsilon)) / ||x|| if ||x|| > 1 - epsilon + output = + x otherwise + + For `x` with more dimensions, independently normalizes each 1-D slice along + dimension `axis`. + + Arguments: + axis: Axis along which to normalize. A scalar or a vector of + integers. + epsilon: A small deviation from the edge of the unit sphere for numerical + stability. + """ + + def __init__(self, axis=1, epsilon=1e-5, **kwargs): + super(PoincareNormalize, self).__init__(**kwargs) + self.axis = axis + self.epsilon = epsilon + + def call(self, inputs): + x = ops.convert_to_tensor(inputs) + square_sum = math_ops.reduce_sum( + math_ops.square(x), self.axis, keepdims=True) + x_inv_norm = math_ops.rsqrt(square_sum) + x_inv_norm = math_ops.minimum((1. - self.epsilon) * x_inv_norm, 1.) + outputs = math_ops.multiply(x, x_inv_norm) + return outputs + + def compute_output_shape(self, input_shape): + return input_shape + + def get_config(self): + config = {'axis': self.axis, 'epsilon': self.epsilon} + base_config = super(PoincareNormalize, self).get_config() + return dict(list(base_config.items()) + list(config.items())) diff --git a/tensorflow_addons/layers/python/poincare_test.py b/tensorflow_addons/layers/python/poincare_test.py new file mode 100644 index 0000000000..fe9fe9f656 --- /dev/null +++ b/tensorflow_addons/layers/python/poincare_test.py @@ -0,0 +1,96 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for PoincareNormalize layer.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.python.keras import testing_utils +from tensorflow.python.keras.utils import generic_utils +from tensorflow.python.platform import test +from tensorflow_addons.layers.python.poincare import PoincareNormalize + + +class PoincareNormalizeTest(test.TestCase): + def _PoincareNormalize(self, x, dim, epsilon=1e-5): + if isinstance(dim, list): + norm = np.linalg.norm(x, axis=tuple(dim)) + for d in dim: + norm = np.expand_dims(norm, d) + norm_x = ((1. - epsilon) * x) / norm + else: + norm = np.expand_dims( + np.apply_along_axis(np.linalg.norm, dim, x), dim) + norm_x = ((1. - epsilon) * x) / norm + return np.where(norm > 1.0 - epsilon, norm_x, x) + + def testPoincareNormalize(self): + x_shape = [20, 7, 3] + epsilon = 1e-5 + tol = 1e-6 + np.random.seed(1) + inputs = np.random.random_sample(x_shape).astype(np.float32) + + for dim in range(len(x_shape)): + outputs_expected = self._PoincareNormalize(inputs, dim, epsilon) + + with generic_utils.custom_object_scope({ + 'PoincareNormalize': + PoincareNormalize + }): + outputs = testing_utils.layer_test( + PoincareNormalize, + kwargs={ + 'axis': dim, + 'epsilon': epsilon + }, + input_data=inputs, + expected_output=outputs_expected) + for y in outputs_expected, outputs: + norm = np.linalg.norm(y, axis=dim) + self.assertLessEqual(norm.max(), 1. - epsilon + tol) + + def testPoincareNormalizeDimArray(self): + x_shape = [20, 7, 3] + epsilon = 1e-5 + tol = 1e-6 + np.random.seed(1) + inputs = np.random.random_sample(x_shape).astype(np.float32) + dim = [1, 2] + + outputs_expected = self._PoincareNormalize(inputs, dim, epsilon) + + with generic_utils.custom_object_scope({ + 'PoincareNormalize': + PoincareNormalize + }): + outputs = testing_utils.layer_test( + PoincareNormalize, + kwargs={ + 'axis': dim, + 'epsilon': epsilon + }, + input_data=inputs, + expected_output=outputs_expected) + for y in outputs_expected, outputs: + norm = np.linalg.norm(y, axis=tuple(dim)) + self.assertLessEqual(norm.max(), 1. - epsilon + tol) + + +if __name__ == '__main__': + test.main() From f264c7d8de9a30e61b3ca83de5afa1ca83e5a16c Mon Sep 17 00:00:00 2001 From: Sean Morgan Date: Fri, 25 Jan 2019 15:36:26 -0500 Subject: [PATCH 2/3] Register custom layer --- tensorflow_addons/layers/python/poincare.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tensorflow_addons/layers/python/poincare.py b/tensorflow_addons/layers/python/poincare.py index e9b846b838..037e7b0c82 100644 --- a/tensorflow_addons/layers/python/poincare.py +++ b/tensorflow_addons/layers/python/poincare.py @@ -19,6 +19,7 @@ from __future__ import print_function from tensorflow.python.framework import ops +from tensorflow.python.keras.utils import generic_utils from tensorflow.python.keras.engine.base_layer import Layer from tensorflow.python.ops import math_ops @@ -70,3 +71,7 @@ def get_config(self): config = {'axis': self.axis, 'epsilon': self.epsilon} base_config = super(PoincareNormalize, self).get_config() return dict(list(base_config.items()) + list(config.items())) + + +generic_utils._GLOBAL_CUSTOM_OBJECTS['PoincareNormalize'] = PoincareNormalize + From 40c722b9446d95bbef20978ccf9508fe3966e5ae Mon Sep 17 00:00:00 2001 From: Sean Morgan Date: Fri, 25 Jan 2019 15:37:23 -0500 Subject: [PATCH 3/3] Modify tests for pre-registered layer --- .../layers/python/poincare_test.py | 53 ++++++++----------- 1 file changed, 22 insertions(+), 31 deletions(-) diff --git a/tensorflow_addons/layers/python/poincare_test.py b/tensorflow_addons/layers/python/poincare_test.py index fe9fe9f656..81be19c249 100644 --- a/tensorflow_addons/layers/python/poincare_test.py +++ b/tensorflow_addons/layers/python/poincare_test.py @@ -21,7 +21,6 @@ import numpy as np from tensorflow.python.keras import testing_utils -from tensorflow.python.keras.utils import generic_utils from tensorflow.python.platform import test from tensorflow_addons.layers.python.poincare import PoincareNormalize @@ -49,21 +48,17 @@ def testPoincareNormalize(self): for dim in range(len(x_shape)): outputs_expected = self._PoincareNormalize(inputs, dim, epsilon) - with generic_utils.custom_object_scope({ - 'PoincareNormalize': - PoincareNormalize - }): - outputs = testing_utils.layer_test( - PoincareNormalize, - kwargs={ - 'axis': dim, - 'epsilon': epsilon - }, - input_data=inputs, - expected_output=outputs_expected) - for y in outputs_expected, outputs: - norm = np.linalg.norm(y, axis=dim) - self.assertLessEqual(norm.max(), 1. - epsilon + tol) + outputs = testing_utils.layer_test( + PoincareNormalize, + kwargs={ + 'axis': dim, + 'epsilon': epsilon + }, + input_data=inputs, + expected_output=outputs_expected) + for y in outputs_expected, outputs: + norm = np.linalg.norm(y, axis=dim) + self.assertLessEqual(norm.max(), 1. - epsilon + tol) def testPoincareNormalizeDimArray(self): x_shape = [20, 7, 3] @@ -75,21 +70,17 @@ def testPoincareNormalizeDimArray(self): outputs_expected = self._PoincareNormalize(inputs, dim, epsilon) - with generic_utils.custom_object_scope({ - 'PoincareNormalize': - PoincareNormalize - }): - outputs = testing_utils.layer_test( - PoincareNormalize, - kwargs={ - 'axis': dim, - 'epsilon': epsilon - }, - input_data=inputs, - expected_output=outputs_expected) - for y in outputs_expected, outputs: - norm = np.linalg.norm(y, axis=tuple(dim)) - self.assertLessEqual(norm.max(), 1. - epsilon + tol) + outputs = testing_utils.layer_test( + PoincareNormalize, + kwargs={ + 'axis': dim, + 'epsilon': epsilon + }, + input_data=inputs, + expected_output=outputs_expected) + for y in outputs_expected, outputs: + norm = np.linalg.norm(y, axis=tuple(dim)) + self.assertLessEqual(norm.max(), 1. - epsilon + tol) if __name__ == '__main__':