
Commit b6a1091

AndreasMadsen authored and seanpmorgan committed
implement sparsemax and sparsemax_loss (#65)
* ENH: implement sparsemax and sparsemax_loss
1 parent 440bd6a commit b6a1091

File tree

20 files changed: 1056 additions & 0 deletions


BUILD

Lines changed: 1 addition & 0 deletions
@@ -6,6 +6,7 @@ sh_binary(
     "MANIFEST.in",
     "setup.py",
     "//tensorflow_addons",
+    "//tensorflow_addons/activations:activations_py",
     "//tensorflow_addons/custom_ops:custom_ops_py",
     "//tensorflow_addons/layers:layers_py",
     "//tensorflow_addons/losses:losses_py",

README.md

Lines changed: 2 additions & 0 deletions
@@ -12,11 +12,13 @@ developments that cannot be integrated into core TensorFlow
 ## Contents
 | Sub-Package | Addon | Reference |
 |:----------------------- |:----------- |:---------------------------- |
+| tfa.activations | Sparsemax | https://arxiv.org/abs/1602.02068 |
 | tfa.image | transform | |
 | tfa.layers | Maxout | https://arxiv.org/abs/1302.4389 |
 | tfa.layers | PoincareNormalize | https://arxiv.org/abs/1705.08039 |
 | tfa.layers | WeightNormalization | https://arxiv.org/abs/1602.07868 |
 | tfa.losses | LiftedStructLoss | https://arxiv.org/abs/1511.06452 |
+| tfa.losses | SparsemaxLoss | https://arxiv.org/abs/1602.02068 |
 | tfa.losses | TripletSemiHardLoss | https://arxiv.org/abs/1503.03832 |
 | tfa.optimizers | LazyAdamOptimizer | https://arxiv.org/abs/1412.6980 |
 | tfa.text | skip_gram_sample | https://arxiv.org/abs/1301.3781 |
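
A minimal usage sketch of the new public API (not part of the diff; it assumes the package is built from this commit and run eagerly):

import tensorflow as tf
from tensorflow_addons.activations import sparsemax

logits = tf.constant([[1.0, 2.0, 5.0]])
# Unlike softmax, sparsemax can assign exact zeros: here tau(z) = 4,
# so only the largest logit survives the thresholding.
print(sparsemax(logits).numpy())      # [[0. 0. 1.]]
print(tf.nn.softmax(logits).numpy())  # every entry strictly positive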
tensorflow_addons/activations/BUILD

Lines changed: 29 additions & 0 deletions

@@ -0,0 +1,29 @@
licenses(["notice"])  # Apache 2.0

package(default_visibility = ["//visibility:public"])

py_library(
    name = "activations_py",
    srcs = [
        "__init__.py",
        "python/__init__.py",
        "python/sparsemax.py",
    ],
    srcs_version = "PY2AND3",
    deps = [
        "//tensorflow_addons/utils:utils_py",
    ],
)

py_test(
    name = "sparsemax_py_test",
    size = "small",
    srcs = [
        "python/sparsemax_test.py",
    ],
    main = "python/sparsemax_test.py",
    srcs_version = "PY2AND3",
    deps = [
        ":activations_py",
    ],
)
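
The new test target can be exercised directly with standard Bazel usage, e.g. `bazel test //tensorflow_addons/activations:sparsemax_py_test` (the target name comes from the py_test rule above).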
tensorflow_addons/activations/README.md

Lines changed: 27 additions & 0 deletions

@@ -0,0 +1,27 @@
# Addons - Activations

## Contents
| Activation | Reference |
|:----------------------- |:-----------------------------|
| Sparsemax | https://arxiv.org/abs/1602.02068 |

## Contribution Guidelines
#### Standard API
In order to conform with the current API standard, all activations must:
* Be a `tf.function`.
* Have the signature `fn(input, axis=-1, name=None)`.
* [Register as a keras global object](https://github.com/tensorflow/addons/blob/master/tensorflow_addons/utils/python/keras_utils.py)
  so it can be serialized properly.
* Add the addon to the `py_library` in this sub-package's BUILD file.

#### Testing Requirements
* Simple unittests that demonstrate the activation is behaving as expected.
* When applicable, run all unittests with TensorFlow's
  `@run_all_in_graph_and_eager_modes` decorator.
* Add a `py_test` to this sub-package's BUILD file.

#### Documentation Requirements
* Update the table of contents in the project's central README.
* Update the table of contents in this sub-package's README.
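
As a concrete illustration of the Standard API requirements above, a minimal sketch of a conforming activation (the name `my_identity` and its body are hypothetical; only the decorators and signature follow the listed rules):

import tensorflow as tf
from tensorflow_addons.utils.python import keras_utils

@tf.function
@keras_utils.register_keras_custom_object
def my_identity(input, axis=-1, name=None):
    # A tf.function with the standard fn(input, axis=-1, name=None)
    # signature, registered so Keras can serialize it by name.
    return tf.identity(input, name=name)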
tensorflow_addons/activations/__init__.py

Lines changed: 21 additions & 0 deletions

@@ -0,0 +1,21 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A module containing activation routines."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow_addons.activations.python.sparsemax import sparsemax

tensorflow_addons/activations/python/__init__.py

Whitespace-only changes.
tensorflow_addons/activations/python/sparsemax.py

Lines changed: 136 additions & 0 deletions

@@ -0,0 +1,136 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

from tensorflow_addons.utils.python import keras_utils


@tf.function
@keras_utils.register_keras_custom_object
def sparsemax(logits, axis=-1, name=None):
    """Sparsemax activation function [1].

    For each batch `i` and class `j` we have
    $$sparsemax[i, j] = max(logits[i, j] - tau(logits[i, :]), 0)$$

    [1]: https://arxiv.org/abs/1602.02068

    Args:
        logits: Input tensor.
        axis: Integer, axis along which the sparsemax operation is applied.
        name: A name for the operation (optional).
    Returns:
        Tensor, output of sparsemax transformation. Has the same type and
        shape as `logits`.
    Raises:
        ValueError: In case `dim(logits) == 1`.
    """
    logits = tf.convert_to_tensor(logits, name="logits")

    # We need its original shape for shape inference.
    shape = logits.get_shape()
    rank = shape.rank
    is_last_axis = (axis == -1) or (axis == rank - 1)

    if is_last_axis:
        output = _compute_2d_sparsemax(logits, name=name)
        output.set_shape(shape)
        return output

    # If axis is not the last dimension, we have to do a transpose so that
    # we can still perform sparsemax on the last dimension.

    # Swap logits' target dimension with its last dimension.
    rank_op = tf.rank(logits)
    axis_norm = axis % rank
    logits = _swap_axis(logits, axis_norm, tf.math.subtract(rank_op, 1))

    # Do the actual sparsemax on the last dimension.
    output = _compute_2d_sparsemax(logits)
    output = _swap_axis(
        output, axis_norm, tf.math.subtract(rank_op, 1), name=name)

    # Make shape inference work since transpose may erase its static shape.
    output.set_shape(shape)
    return output


def _swap_axis(logits, dim_index, last_index, **kwargs):
    # Build a permutation that exchanges axes `dim_index` and `last_index`
    # while leaving all other axes in place.
    return tf.transpose(
        logits,
        tf.concat([
            tf.range(dim_index), [last_index],
            tf.range(dim_index + 1, last_index), [dim_index]
        ], 0), **kwargs)


@tf.function
def _compute_2d_sparsemax(logits, name=None):
    """Performs the sparsemax operation when axis=-1."""
    shape_op = tf.shape(logits)
    obs = tf.math.reduce_prod(shape_op[:-1])
    dims = shape_op[-1]

    # In the paper, they call the logits z.
    # The mean(logits) can be subtracted from logits to make the algorithm
    # more numerically stable. The instability in this algorithm comes mostly
    # from the z_cumsum. Subtracting the mean will cause z_cumsum to be close
    # to zero. However, in practice the numerical instability issues are very
    # minor and subtracting the mean causes extra issues with inf and nan
    # input.
    # Reshape to [obs, dims] as it is almost free and means the remaining
    # code doesn't need to worry about the rank.
    z = tf.reshape(logits, [obs, dims])

    # Sort z.
    z_sorted, _ = tf.nn.top_k(z, k=dims)

    # Calculate k(z).
    z_cumsum = tf.math.cumsum(z_sorted, axis=-1)
    k = tf.range(1, tf.cast(dims, logits.dtype) + 1, dtype=logits.dtype)
    z_check = 1 + k * z_sorted > z_cumsum
    # Because the z_check vector is always [1,1,...1,0,0,...0], finding the
    # (index + 1) of the last `1` is the same as just summing the number of 1s.
    k_z = tf.math.reduce_sum(tf.cast(z_check, tf.int32), axis=-1)
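    # Worked example (hypothetical row, for illustration): z = [5, 2, 1]
    # gives z_sorted = [5, 2, 1], z_cumsum = [5, 7, 8], k = [1, 2, 3], so
    # 1 + k * z_sorted = [6, 5, 4] and z_check = [1, 0, 0], hence k_z = 1.
    # Below, this yields tau_z = (5 - 1) / 1 = 4 and p = [1, 0, 0].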
    # Calculate tau(z).
    # If there are inf values or all values are -inf, k_z will be zero; this
    # is mathematically invalid and will also cause the gather_nd to fail.
    # Prevent this issue for now by setting k_z = 1 if k_z = 0; this is then
    # fixed later (see p_safe) by returning p = nan, which matches the
    # behavior of softmax.
    k_z_safe = tf.math.maximum(k_z, 1)
    indices = tf.stack(
        [tf.range(0, obs), tf.reshape(k_z_safe, [-1]) - 1], axis=1)
    tau_sum = tf.gather_nd(z_cumsum, indices)
    tau_z = (tau_sum - 1) / tf.cast(k_z, logits.dtype)

    # Calculate p.
    p = tf.math.maximum(
        tf.cast(0, logits.dtype), z - tf.expand_dims(tau_z, -1))
    # If k_z = 0 or if z = nan, then the input is invalid.
    p_safe = tf.where(
        tf.math.logical_or(
            tf.math.equal(k_z, 0), tf.math.is_nan(z_cumsum[:, -1])),
        tf.fill([obs, dims], tf.cast(float("nan"), logits.dtype)), p)

    # Reshape back to the original size.
    p_safe = tf.reshape(p_safe, shape_op, name=name)
    return p_safe
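
To make the vectorized TensorFlow implementation above easier to check by hand, here is a NumPy sketch of the same projection for a single row (written for this review, not part of the commit):

import numpy as np

def sparsemax_ref(z):
    """Reference sparsemax for one 1-D row of logits."""
    z_sorted = np.sort(z)[::-1]  # descending sort
    z_cumsum = np.cumsum(z_sorted)
    k = np.arange(1, len(z) + 1)
    # k(z): number of leading entries with 1 + k * z_sorted > cumulative sum.
    k_z = np.sum(1 + k * z_sorted > z_cumsum)
    tau_z = (z_cumsum[k_z - 1] - 1) / k_z
    return np.maximum(z - tau_z, 0.0)

# sparsemax_ref(np.array([5.0, 2.0, 1.0])) -> array([1., 0., 0.])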
