Commit 682ec5a (parent 1bd2ba6)

implement sparsemax and sparsemax_loss

* ENH: Including a Sparsemax layer and a SparsemaxLoss class

File tree

17 files changed: +1027 −0 lines changed


BUILD

Lines changed: 1 addition & 0 deletions

@@ -6,6 +6,7 @@ sh_binary(
         "MANIFEST.in",
         "setup.py",
         "tensorflow_addons/__init__.py",
+        "//tensorflow_addons/activations:activations_py",
         "//tensorflow_addons/custom_ops:custom_ops_py",
         "//tensorflow_addons/layers:layers_py",
         "//tensorflow_addons/losses:losses_py",
tensorflow_addons/activations/BUILD

Lines changed: 29 additions & 0 deletions

@@ -0,0 +1,29 @@
licenses(["notice"])  # Apache 2.0

package(default_visibility = ["//visibility:public"])

py_library(
    name = "activations_py",
    srcs = [
        "__init__.py",
        "python/__init__.py",
        "python/sparsemax.py"
    ],
    srcs_version = "PY2AND3",
    deps = [
        "//tensorflow_addons/utils:utils_py",
    ],
)

py_test(
    name = "sparsemax_py_test",
    size = "small",
    srcs = [
        "python/sparsemax_test.py",
    ],
    main = "python/sparsemax_test.py",
    srcs_version = "PY2AND3",
    deps = [
        ":activations_py",
    ],
)
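
With these targets in place, the test should be runnable from the repository root with `bazel test //tensorflow_addons/activations:sparsemax_py_test`, assuming a configured Bazel workspace.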
tensorflow_addons/activations/__init__.py

Lines changed: 23 additions & 0 deletions

@@ -0,0 +1,23 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
A module containing activation routines.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow_addons.activations.python.sparsemax import sparsemax
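
This re-export makes the function importable as `from tensorflow_addons.activations import sparsemax`, so callers presumably never need to reach into the internal `python` subpackage.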

tensorflow_addons/activations/python/__init__.py

Whitespace-only changes.
tensorflow_addons/activations/python/sparsemax.py

Lines changed: 141 additions & 0 deletions

@@ -0,0 +1,141 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

from tensorflow_addons.utils.python import keras_utils


@keras_utils.register_keras_custom_object
def sparsemax(logits, axis=-1, name=None):
    """Sparsemax activation function [1].

    For each batch `i` and class `j` we have
      $$sparsemax[i, j] = max(logits[i, j] - tau(logits[i, :]), 0)$$

    [1]: https://arxiv.org/abs/1602.02068

    Args:
        logits: Input tensor.
        axis: Integer, axis along which the sparsemax operation is applied.
        name: A name for the operation (optional).
    Returns:
        Tensor, output of the sparsemax transformation. Has the same type
        and shape as `logits`.
    Raises:
        ValueError: In case `dim(logits) == 1`.
    """
    logits = tf.convert_to_tensor(logits, name="logits")

    # We need the original shape for shape inference.
    shape = logits.get_shape()
    rank = shape.rank
    is_last_axis = (axis == -1) or (axis == rank - 1)

    if is_last_axis:
        output = _compute_2d_sparsemax(logits, name=name)
        output.set_shape(shape)
        return output

    # If `axis` is not the last dimension, we have to do a transpose so
    # that we can still perform sparsemax on the last dimension.

    # Swap the `axis` dimension of logits with its last dimension.
    rank_op = tf.rank(logits)
    axis_norm = axis % rank
    logits = _swap_axis(logits, axis_norm, tf.math.subtract(rank_op, 1))

    # Do the actual sparsemax on the last dimension.
    output = _compute_2d_sparsemax(logits)
    output = _swap_axis(output, axis_norm, tf.math.subtract(rank_op, 1),
                        name=name)

    # Make shape inference work, since the transpose may erase the static
    # shape.
    output.set_shape(shape)
    return output


def _swap_axis(logits, dim_index, last_index, **kwargs):
    return tf.transpose(
        logits,
        tf.concat([
            tf.range(dim_index), [last_index],
            tf.range(dim_index + 1, last_index), [dim_index]
        ], 0),
        **kwargs)


@tf.function
def _compute_2d_sparsemax(logits, name=None):
    """Performs the sparsemax operation when axis=-1."""
    shape_op = tf.shape(logits)
    obs = tf.math.reduce_prod(shape_op[:-1])
    dims = shape_op[-1]

    # In the paper, the logits are called z.
    # The mean(logits) can be subtracted from the logits to make the
    # algorithm more numerically stable. The instability in this algorithm
    # comes mostly from z_cumsum, and subtracting the mean keeps z_cumsum
    # close to zero. However, in practice the numerical instability issues
    # are very minor and subtracting the mean causes extra issues with inf
    # and nan input.
    # Reshape to [obs, dims] as it is almost free and means the remaining
    # code doesn't need to worry about the rank.
    z = tf.reshape(logits, [obs, dims])

    # Sort z.
    z_sorted, _ = tf.nn.top_k(z, k=dims)

    # Calculate k(z).
    z_cumsum = tf.math.cumsum(z_sorted, axis=-1)
    k = tf.range(
        1, tf.cast(dims, logits.dtype) + 1, dtype=logits.dtype)
    z_check = 1 + k * z_sorted > z_cumsum
    # Because the z_check vector is always [1, 1, ..., 1, 0, 0, ..., 0],
    # finding the (index + 1) of the last `1` is the same as just summing
    # the number of ones.
    k_z = tf.math.reduce_sum(tf.cast(z_check, tf.int32), axis=-1)

    # Calculate tau(z).
    # If there are inf values, or all values are -inf, then k_z will be
    # zero; this is mathematically invalid and will also cause the
    # gather_nd to fail. Prevent this issue for now by setting k_z = 1 if
    # k_z = 0; this is then fixed later (see p_safe) by returning p = nan,
    # which gives the same behavior as softmax.
    k_z_safe = tf.math.maximum(k_z, 1)
    indices = tf.stack([
        tf.range(0, obs),
        tf.reshape(k_z_safe, [-1]) - 1
    ], axis=1)
    tau_sum = tf.gather_nd(z_cumsum, indices)
    tau_z = (tau_sum - 1) / tf.cast(k_z, logits.dtype)

    # Calculate p.
    p = tf.math.maximum(
        tf.cast(0, logits.dtype), z - tf.expand_dims(tau_z, -1))
    # If k_z = 0 or z = nan, then the input is invalid.
    p_safe = tf.where(
        tf.math.logical_or(
            tf.math.equal(k_z, 0),
            tf.math.is_nan(z_cumsum[:, -1])
        ),
        tf.fill([obs, dims], tf.cast(float("nan"), logits.dtype)),
        p)

    # Reshape back to the original size.
    p_safe = tf.reshape(p_safe, shape_op, name=name)
    return p_safe
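
For orientation, a minimal usage sketch (not part of this commit; it assumes TF 2.x eager execution with this change installed). The printed values follow from hand-applying the k(z) and tau(z) steps above, and the exact zeros in the output are what distinguish sparsemax from softmax:

# A minimal usage sketch of the new activation.
import tensorflow as tf
from tensorflow_addons.activations import sparsemax

logits = tf.constant([[0.5, 1.0, 1.2],
                      [1.0, 2.0, 3.0]])

# Worked tau(z) for the first row: z_sorted = [1.2, 1.0, 0.5],
# z_cumsum = [1.2, 2.2, 2.7], and 1 + k * z_sorted > z_cumsum holds only
# for k = 1, 2, so k(z) = 2 and tau(z) = (2.2 - 1) / 2 = 0.6.
print(sparsemax(logits).numpy())
# -> [[0.  0.4 0.6]    max(z - 0.6, 0): two classes share the mass
#     [0.  0.  1. ]]   a large margin collapses to a one-hot vector

# Unlike softmax, which assigns nonzero probability everywhere:
print(tf.nn.softmax(logits).numpy())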
