Skip to content

Commit 201daef

Browse files
committed
ENH: implement PoincareNormalize
1 parent d4b3114 commit 201daef

File tree

3 files changed

+182
-0
lines changed

3 files changed

+182
-0
lines changed

tensorflow_addons/layers/BUILD

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ py_library(
77
srcs = ([
88
"__init__.py",
99
"python/__init__.py",
10+
"python/poincare.py",
1011
"python/wrappers.py",
1112
]),
1213
srcs_version = "PY2AND3",
@@ -22,4 +23,17 @@ py_test(
2223
":layers_py",
2324
],
2425
srcs_version = "PY2AND3",
26+
)
27+
28+
py_test(
29+
name = "poincare_py_test",
30+
size = "small",
31+
srcs = [
32+
"python/poincare_test.py",
33+
],
34+
main = "python/poincare_test.py",
35+
deps = [
36+
":layers_py",
37+
],
38+
srcs_version = "PY2AND3",
2539
)
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
"""Implementing PoincareNormalize layer."""
16+
17+
from __future__ import absolute_import
18+
from __future__ import division
19+
from __future__ import print_function
20+
21+
from tensorflow.python.framework import ops
22+
from tensorflow.python.keras.engine.base_layer import Layer
23+
from tensorflow.python.ops import math_ops
24+
25+
26+
class PoincareNormalize(Layer):
27+
"""Project into the Poincare ball with norm <= 1.0 - epsilon.
28+
29+
https://en.wikipedia.org/wiki/Poincare_ball_model
30+
31+
Used in
32+
Poincare Embeddings for Learning Hierarchical Representations
33+
Maximilian Nickel, Douwe Kiela
34+
https://arxiv.org/pdf/1705.08039.pdf
35+
36+
For a 1-D tensor with `axis = 0`, computes
37+
38+
(x * (1 - epsilon)) / ||x|| if ||x|| > 1 - epsilon
39+
output =
40+
x otherwise
41+
42+
For `x` with more dimensions, independently normalizes each 1-D slice along
43+
dimension `axis`.
44+
45+
Arguments:
46+
axis: Axis along which to normalize. A scalar or a vector of
47+
integers.
48+
epsilon: A small deviation from the edge of the unit sphere for numerical
49+
stability.
50+
"""
51+
52+
def __init__(self, axis=1, epsilon=1e-5, **kwargs):
53+
super(PoincareNormalize, self).__init__(**kwargs)
54+
self.axis = axis
55+
self.epsilon = epsilon
56+
57+
def call(self, inputs):
58+
x = ops.convert_to_tensor(inputs)
59+
square_sum = math_ops.reduce_sum(
60+
math_ops.square(x), self.axis, keepdims=True)
61+
x_inv_norm = math_ops.rsqrt(square_sum)
62+
x_inv_norm = math_ops.minimum((1. - self.epsilon) * x_inv_norm, 1.)
63+
outputs = math_ops.multiply(x, x_inv_norm)
64+
return outputs
65+
66+
def compute_output_shape(self, input_shape):
67+
return input_shape
68+
69+
def get_config(self):
70+
config = {'axis': self.axis, 'epsilon': self.epsilon}
71+
base_config = super(PoincareNormalize, self).get_config()
72+
return dict(list(base_config.items()) + list(config.items()))
Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
"""Tests for PoincareNormalize layer."""
16+
17+
from __future__ import absolute_import
18+
from __future__ import division
19+
from __future__ import print_function
20+
21+
import numpy as np
22+
23+
from tensorflow.python.keras import testing_utils
24+
from tensorflow.python.keras.utils import generic_utils
25+
from tensorflow.python.platform import test
26+
from tensorflow_addons.layers.python.poincare import PoincareNormalize
27+
28+
29+
class PoincareNormalizeTest(test.TestCase):
30+
def _PoincareNormalize(self, x, dim, epsilon=1e-5):
31+
if isinstance(dim, list):
32+
norm = np.linalg.norm(x, axis=tuple(dim))
33+
for d in dim:
34+
norm = np.expand_dims(norm, d)
35+
norm_x = ((1. - epsilon) * x) / norm
36+
else:
37+
norm = np.expand_dims(
38+
np.apply_along_axis(np.linalg.norm, dim, x), dim)
39+
norm_x = ((1. - epsilon) * x) / norm
40+
return np.where(norm > 1.0 - epsilon, norm_x, x)
41+
42+
def testPoincareNormalize(self):
43+
x_shape = [20, 7, 3]
44+
epsilon = 1e-5
45+
tol = 1e-6
46+
np.random.seed(1)
47+
inputs = np.random.random_sample(x_shape).astype(np.float32)
48+
49+
for dim in range(len(x_shape)):
50+
outputs_expected = self._PoincareNormalize(inputs, dim, epsilon)
51+
52+
with generic_utils.custom_object_scope({
53+
'PoincareNormalize':
54+
PoincareNormalize
55+
}):
56+
outputs = testing_utils.layer_test(
57+
PoincareNormalize,
58+
kwargs={
59+
'axis': dim,
60+
'epsilon': epsilon
61+
},
62+
input_data=inputs,
63+
expected_output=outputs_expected)
64+
for y in outputs_expected, outputs:
65+
norm = np.linalg.norm(y, axis=dim)
66+
self.assertLessEqual(norm.max(), 1. - epsilon + tol)
67+
68+
def testPoincareNormalizeDimArray(self):
69+
x_shape = [20, 7, 3]
70+
epsilon = 1e-5
71+
tol = 1e-6
72+
np.random.seed(1)
73+
inputs = np.random.random_sample(x_shape).astype(np.float32)
74+
dim = [1, 2]
75+
76+
outputs_expected = self._PoincareNormalize(inputs, dim, epsilon)
77+
78+
with generic_utils.custom_object_scope({
79+
'PoincareNormalize':
80+
PoincareNormalize
81+
}):
82+
outputs = testing_utils.layer_test(
83+
PoincareNormalize,
84+
kwargs={
85+
'axis': dim,
86+
'epsilon': epsilon
87+
},
88+
input_data=inputs,
89+
expected_output=outputs_expected)
90+
for y in outputs_expected, outputs:
91+
norm = np.linalg.norm(y, axis=tuple(dim))
92+
self.assertLessEqual(norm.max(), 1. - epsilon + tol)
93+
94+
95+
if __name__ == '__main__':
96+
test.main()

0 commit comments

Comments
 (0)