Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion tensorflow_addons/activations/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ py_library(
srcs = [
"__init__.py",
"gelu.py",
"hardshrink.py",
"sparsemax.py",
],
data = [
Expand All @@ -31,7 +32,7 @@ py_test(

py_test(
name = "gelu_test",
size = "large",
size = "medium",
srcs = [
"gelu_test.py",
],
Expand All @@ -41,3 +42,16 @@ py_test(
":activations",
],
)

py_test(
name = "hardshrink_test",
size = "medium",
srcs = [
"hardshrink_test.py",
],
main = "hardshrink_test.py",
srcs_version = "PY2AND3",
deps = [
":activations",
],
)
2 changes: 2 additions & 0 deletions tensorflow_addons/activations/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,14 @@
| Submodule | Maintainers | Contact Info |
|:----------|:--------------------------|:-----------------------------------------|
| gelu | @AakashKumarNain @WindQAQ | [email protected] [email protected] |
| hardshrink| @WindQAQ | [email protected]
| sparsemax | @AndreasMadsen | [email protected] |

## Contents
| Submodule | Activation | Reference |
|:----------|:-----------|:---------------------------------|
| gelu | gelu | https://arxiv.org/abs/1606.08415 |
| hardshrink| hardshrink | |
| sparsemax | Sparsemax | https://arxiv.org/abs/1602.02068 |


Expand Down
1 change: 1 addition & 0 deletions tensorflow_addons/activations/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,5 @@
from __future__ import print_function

from tensorflow_addons.activations.gelu import gelu
from tensorflow_addons.activations.hardshrink import hardshrink
from tensorflow_addons.activations.sparsemax import sparsemax
52 changes: 52 additions & 0 deletions tensorflow_addons/activations/hardshrink.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
from tensorflow_addons.utils import keras_utils
from tensorflow_addons.utils.resource_loader import get_path_to_datafile

_activation_ops_so = tf.load_op_library(
get_path_to_datafile("custom_ops/activations/_activation_ops.so"))


@keras_utils.register_keras_custom_object
@tf.function
def hardshrink(x, lower=-1.0, upper=1.0):
"""Hard shrink function.

Computes hard shrink function:
`x if x < lower or x > upper else 0`.

Args:
x: A `Tensor`. Must be one of the following types:
`float16`, `float32`, `float64`.
lower: `float`, lower bound for setting values to zeros.
upper: `float`, upper bound for setting values to zeros.
Returns:
A `Tensor`. Has the same type as `x`.
"""
x = tf.convert_to_tensor(x)
return _activation_ops_so.hardshrink(x, lower, upper)


@tf.RegisterGradient("Hardshrink")
def _hardshrink_grad(op, grad):
return _activation_ops_so.hardshrink_grad(grad, op.inputs[0],
op.get_attr("lower"),
op.get_attr("upper"))
99 changes: 99 additions & 0 deletions tensorflow_addons/activations/hardshrink_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl.testing import parameterized

import numpy as np
import tensorflow as tf
from tensorflow_addons.activations import hardshrink
from tensorflow_addons.utils import test_utils


def _ref_hardshrink(x, lower=-1.0, upper=1.0):
x = tf.convert_to_tensor(x)
return tf.where(tf.math.logical_or(x < lower, x > upper), x, 0.0)


@test_utils.run_all_in_graph_and_eager_modes
class HardshrinkTest(tf.test.TestCase, parameterized.TestCase):
def test_invalid(self):
with self.assertRaisesOpError(
"lower must be less than or equal to upper."): # pylint: disable=bad-continuation
y = hardshrink(tf.ones(shape=(1, 2, 3)), lower=2.0, upper=-2.0)
self.evaluate(y)

@parameterized.named_parameters(("float16", np.float16),
("float32", np.float32),
("float64", np.float64))
def test_hardshrink(self, dtype):
x = (np.random.rand(2, 3, 4) * 2.0 - 1.0).astype(dtype)
self.assertAllCloseAccordingToType(hardshrink(x), _ref_hardshrink(x))
self.assertAllCloseAccordingToType(
hardshrink(x, -2.0, 2.0), _ref_hardshrink(x, -2.0, 2.0))

@parameterized.named_parameters(("float16", np.float16),
("float32", np.float32),
("float64", np.float64))
def test_gradients(self, dtype):
x = tf.constant([-1.5, -0.5, 0.5, 1.5], dtype=dtype)

with tf.GradientTape(persistent=True) as tape:
tape.watch(x)
y_ref = _ref_hardshrink(x)
y = hardshrink(x)
grad_ref = tape.gradient(y_ref, x)
grad = tape.gradient(y, x)
self.assertAllCloseAccordingToType(grad, grad_ref)

@parameterized.named_parameters(("float32", np.float32),
("float64", np.float64))
def test_theoretical_gradients(self, dtype):
# Only test theoretical gradients for float32 and float64
# because of the instability of float16 while computing jacobian
x = tf.constant([-1.5, -0.5, 0.5, 1.5], dtype=dtype)

theoretical, numerical = tf.test.compute_gradient(
lambda x: hardshrink(x), [x])
self.assertAllCloseAccordingToType(theoretical, numerical, atol=1e-4)

def test_unknown_shape(self):
fn = hardshrink.get_concrete_function(
tf.TensorSpec(shape=None, dtype=tf.float32))

for shape in [(1,), (1, 2), (1, 2, 3), (1, 2, 3, 4)]:
x = tf.ones(shape=shape, dtype=tf.float32)
self.assertAllClose(fn(x), hardshrink(x))

def test_serialization(self):
ref_fn = hardshrink
config = tf.keras.activations.serialize(ref_fn)
fn = tf.keras.activations.deserialize(config)
self.assertEqual(fn, ref_fn)

def test_serialization_with_layers(self):
layer = tf.keras.layers.Dense(3, activation=hardshrink)
config = tf.keras.layers.serialize(layer)
deserialized_layer = tf.keras.layers.deserialize(config)
self.assertEqual(deserialized_layer.__class__.__name__,
layer.__class__.__name__)
self.assertEqual(deserialized_layer.activation.__name__, "hardshrink")


if __name__ == "__main__":
tf.test.main()
30 changes: 29 additions & 1 deletion tensorflow_addons/custom_ops/activations/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,37 @@ cc_library(
alwayslink = 1,
)

cc_library(
name = "hardshrink_op_gpu",
srcs = [
"cc/kernels/hardshrink_op.h",
"cc/kernels/hardshrink_op_gpu.cu.cc",
],
copts = if_cuda_is_configured([
"-DGOOGLE_CUDA=1",
"-x cuda",
"-nvcc_options=relaxed-constexpr",
"-nvcc_options=ftz=true",
]),
deps = [
"@local_config_tf//:libtensorflow_framework",
"@local_config_tf//:tf_header_lib",
] + if_cuda_is_configured([
"@local_config_cuda//cuda:cuda_libs",
"@local_config_cuda//cuda:cuda_headers",
]),
alwayslink = 1,
)

cc_binary(
name = "_activation_ops.so",
srcs = [
"cc/kernels/gelu_op.cc",
"cc/kernels/gelu_op.h",
"cc/kernels/hardshrink_op.cc",
"cc/kernels/hardshrink_op.h",
"cc/ops/gelu_op.cc",
"cc/ops/hardshrink_op.cc",
],
copts = [
"-pthread",
Expand All @@ -43,5 +68,8 @@ cc_binary(
deps = [
"@local_config_tf//:libtensorflow_framework",
"@local_config_tf//:tf_header_lib",
] + if_cuda_is_configured([":gelu_op_gpu"]),
] + if_cuda_is_configured([
":gelu_op_gpu",
":hardshrink_op_gpu",
]),
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define EIGEN_USE_THREADS

#include "tensorflow_addons/custom_ops/activations/cc/kernels/hardshrink_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"

namespace tensorflow {

using CPUDevice = Eigen::ThreadPoolDevice;

#define REGISTER_HARDSHRINK_KERNELS(type) \
REGISTER_KERNEL_BUILDER( \
Name("Hardshrink").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Think we could change this to Addons>Hardshrink per tensorflow/community#126 ? It should work on the nightly and RC I believe.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder how to do this... I have tried to modify cc/kernels/hardshrink_op.cc and cc/ops/hardshrink_op.cc, but none of them gives the answer. I think I could just follow @tomerk's PR though.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Works for me.

HardshrinkOp<CPUDevice, type>); \
REGISTER_KERNEL_BUILDER( \
Name("HardshrinkGrad").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
HardshrinkGradOp<CPUDevice, type>);

// Hardshrink only makes sense with floating points.
TF_CALL_GPU_NUMBER_TYPES(REGISTER_HARDSHRINK_KERNELS);
#undef REGISTER_HARDSHRINK_KERNELS

#if GOOGLE_CUDA

using GPUDevice = Eigen::GpuDevice;

// Forward declarations of the functor specializations for GPU.
namespace functor {
#define DECLARE_GPU_SPEC(T) \
template <> \
void Hardshrink<GPUDevice, T>::operator()( \
const GPUDevice& d, typename TTypes<T>::ConstTensor features, T lower, \
T upper, typename TTypes<T>::Tensor activations); \
extern template struct Hardshrink<GPUDevice, T>; \
\
template <> \
void HardshrinkGrad<GPUDevice, T>::operator()( \
const GPUDevice& d, typename TTypes<T>::ConstTensor gradients, \
typename TTypes<T>::ConstTensor features, T lower, T upper, \
typename TTypes<T>::Tensor backprops); \
extern template struct HardshrinkGrad<GPUDevice, T>;

TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
#undef DECLARE_GPU_SPEC
} // namespace functor

// Registration of the GPU implementations.
#define REGISTER_HARDSHRINK_GPU_KERNELS(type) \
REGISTER_KERNEL_BUILDER( \
Name("Hardshrink").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
HardshrinkOp<GPUDevice, type>); \
REGISTER_KERNEL_BUILDER( \
Name("HardshrinkGrad").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
HardshrinkGradOp<GPUDevice, type>);

TF_CALL_GPU_NUMBER_TYPES(REGISTER_HARDSHRINK_GPU_KERNELS);
#undef REGISTER_HARDSHRINK_GPU_KERNELS

#endif // GOOGLE_CUDA

} // namespace tensorflow
Loading