|
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

from tensorflow_addons.utils.python import keras_utils


@tf.function
@keras_utils.register_keras_custom_object
def sparsemax(logits, axis=-1, name=None):
  """Sparsemax activation function [1].

  For each batch `i` and class `j` we have
    $$sparsemax[i, j] = max(logits[i, j] - tau(logits[i, :]), 0)$$

  [1]: https://arxiv.org/abs/1602.02068

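  For example, applying the formula by hand to the logits `[-1.0, 0.0, 1.0]`
  gives `tau = 0.0`, so `sparsemax([-1.0, 0.0, 1.0]) = [0.0, 0.0, 1.0]`;
  unlike softmax, the smaller logits receive exactly zero probability.
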
  Args:
    logits: Input tensor.
    axis: Integer, axis along which the sparsemax operation is applied.
    name: A name for the operation (optional).
  Returns:
    Tensor, output of sparsemax transformation. Has the same type and
    shape as `logits`.
  Raises:
    ValueError: In case `dim(logits) == 1`.
  """
  logits = tf.convert_to_tensor(logits, name="logits")

  # We need its original shape for shape inference.
  shape = logits.get_shape()
  rank = shape.rank
  is_last_axis = (axis == -1) or (axis == rank - 1)

  if is_last_axis:
    output = _compute_2d_sparsemax(logits, name=name)
    output.set_shape(shape)
    return output

  # If axis is not the last dimension, we have to do a transpose so that we
  # can still perform the sparsemax on the last dimension.

  # Swap logits' `axis` dimension with its last dimension.
  rank_op = tf.rank(logits)
  axis_norm = axis % rank
  logits = _swap_axis(logits, axis_norm, tf.math.subtract(rank_op, 1))

  # Do the actual sparsemax on the last dimension.
  output = _compute_2d_sparsemax(logits)
  output = _swap_axis(
      output, axis_norm, tf.math.subtract(rank_op, 1), name=name)

  # Make shape inference work since transpose may erase its static shape.
  output.set_shape(shape)
  return output


def _swap_axis(logits, dim_index, last_index, **kwargs):
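  """Swaps the `dim_index` and `last_index` axes of `logits`."""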
  return tf.transpose(
      logits,
      tf.concat([
          tf.range(dim_index), [last_index],
          tf.range(dim_index + 1, last_index), [dim_index]
      ], 0), **kwargs)


@tf.function
def _compute_2d_sparsemax(logits, name=None):
  """Performs the sparsemax operation when axis=-1."""
  shape_op = tf.shape(logits)
  obs = tf.math.reduce_prod(shape_op[:-1])
  dims = shape_op[-1]

  # In the paper, they call the logits z.
  # The mean(logits) can be subtracted from logits to make the algorithm
  # more numerically stable; the instability in this algorithm comes mostly
  # from the z_cumsum. Subtracting the mean will cause z_cumsum to be close
  # to zero. However, in practice the numerical instability issues are very
  # minor and subtracting the mean causes extra issues with inf and nan
  # input.
  # Reshape to [obs, dims] as it is almost free and means the remaining
  # code doesn't need to worry about the rank.
  z = tf.reshape(logits, [obs, dims])

  # Sort z.
  z_sorted, _ = tf.nn.top_k(z, k=dims)

  # Calculate k(z).
  z_cumsum = tf.math.cumsum(z_sorted, axis=-1)
  k = tf.range(1, tf.cast(dims, logits.dtype) + 1, dtype=logits.dtype)
  z_check = 1 + k * z_sorted > z_cumsum
  # Because the z_check vector is always [1,1,...1,0,0,...0], finding the
  # (index + 1) of the last `1` is the same as just summing the number of 1s.
  k_z = tf.math.reduce_sum(tf.cast(z_check, tf.int32), axis=-1)
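  # Example (hypothetical row): z = [-1, 0, 1] gives z_sorted = [1, 0, -1],
  # z_cumsum = [1, 1, 0] and z_check = [True, False, False], so k_z = 1.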

  # Calculate tau(z).
  # If there are inf values or all values are -inf, the k_z will be zero;
  # this is mathematically invalid and will also cause the gather_nd to fail.
  # Prevent this issue for now by setting k_z = 1 if k_z = 0; this is then
  # fixed later (see p_safe) by returning p = nan. This results in the same
  # behavior as softmax.
  k_z_safe = tf.math.maximum(k_z, 1)
  indices = tf.stack(
      [tf.range(0, obs), tf.reshape(k_z_safe, [-1]) - 1], axis=1)
  tau_sum = tf.gather_nd(z_cumsum, indices)
  tau_z = (tau_sum - 1) / tf.cast(k_z, logits.dtype)
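  # Continuing the example above: tau_sum = z_cumsum[k_z - 1] = 1, so
  # tau_z = (1 - 1) / 1 = 0 and p = max(z - tau_z, 0) = [0, 0, 1].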

  # Calculate p.
  p = tf.math.maximum(
      tf.cast(0, logits.dtype), z - tf.expand_dims(tau_z, -1))
  # If k_z = 0 or if z = nan, then the input is invalid.
  p_safe = tf.where(
      tf.math.logical_or(
          tf.math.equal(k_z, 0), tf.math.is_nan(z_cumsum[:, -1])),
      tf.fill([obs, dims], tf.cast(float("nan"), logits.dtype)), p)

  # Reshape back to the original size.
  p_safe = tf.reshape(p_safe, shape_op, name=name)
  return p_safe