Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions tensorflow_addons/activations/BUILD
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ py_library(
"gelu.py",
"hardshrink.py",
"sparsemax.py",
"tanhshrink.py",
],
data = [
"//tensorflow_addons/custom_ops/activations:_activation_ops.so",
Expand Down Expand Up @@ -55,3 +56,16 @@ py_test(
":activations",
],
)

py_test(
name = "tanhshrink_test",
size = "medium",
srcs = [
"tanhshrink_test.py",
],
main = "tanhshrink_test.py",
srcs_version = "PY2AND3",
deps = [
":activations",
],
)
2 changes: 2 additions & 0 deletions tensorflow_addons/activations/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,15 @@
| gelu | @AakashKumarNain @WindQAQ | [email protected] [email protected] |
| hardshrink| @WindQAQ | [email protected]
| sparsemax | @AndreasMadsen | [email protected] |
| tanhshrink | @fsx950223 | [email protected] |

## Contents
| Submodule | Activation | Reference |
|:----------|:-----------|:---------------------------------|
| gelu | gelu | https://arxiv.org/abs/1606.08415 |
| hardshrink| hardshrink | |
| sparsemax | Sparsemax | https://arxiv.org/abs/1602.02068 |
| tanhshrink | Tanhshrink | |


## Contribution Guidelines
Expand Down
1 change: 1 addition & 0 deletions tensorflow_addons/activations/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,4 @@
from tensorflow_addons.activations.gelu import gelu
from tensorflow_addons.activations.hardshrink import hardshrink
from tensorflow_addons.activations.sparsemax import sparsemax
from tensorflow_addons.activations.tanhshrink import tanhshrink
45 changes: 45 additions & 0 deletions tensorflow_addons/activations/tanhshrink.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
from tensorflow_addons.utils import keras_utils
from tensorflow_addons.utils.resource_loader import get_path_to_datafile

_activation_ops_so = tf.load_op_library(
get_path_to_datafile("custom_ops/activations/_activation_ops.so"))


@keras_utils.register_keras_custom_object
@tf.function
def tanhshrink(x):
"""Applies the element-wise function: x - tanh(x)

Args:
features: A `Tensor`. Must be one of the following types:
`float16`, `float32`, `float64`.
Returns:
A `Tensor`. Has the same type as `features`.
"""
x = tf.convert_to_tensor(x)
return _activation_ops_so.addons_tanhshrink(x)


@tf.RegisterGradient("Addons>Tanhshrink")
def _tanhshrink_grad(op, grad):
return _activation_ops_so.addons_tanhshrink_grad(grad, op.inputs[0])
62 changes: 62 additions & 0 deletions tensorflow_addons/activations/tanhshrink_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl.testing import parameterized

import numpy as np
import tensorflow as tf
from tensorflow_addons.activations import tanhshrink
from tensorflow_addons.utils import test_utils


def _ref_tanhshrink(x):
return x - tf.tanh(x)


@test_utils.run_all_in_graph_and_eager_modes
class TanhshrinkTest(tf.test.TestCase, parameterized.TestCase):
@parameterized.named_parameters(("float16", np.float16),
("float32", np.float32),
("float64", np.float64))
def test_tanhshrink(self, dtype):
x = tf.constant([1.0, 2.0, 3.0], dtype=dtype)
self.assertAllCloseAccordingToType(tanhshrink(x), _ref_tanhshrink(x))

@parameterized.named_parameters(("float16", np.float16),
("float32", np.float32),
("float64", np.float64))
def test_gradients(self, dtype):
x = tf.constant([1.0, 2.0, 3.0], dtype=dtype)
with tf.GradientTape(persistent=True) as tape:
tape.watch(x)
y_ref = _ref_tanhshrink(x)
y = tanhshrink(x)
grad_ref = tape.gradient(y_ref, x)
grad = tape.gradient(y, x)
self.assertAllCloseAccordingToType(grad, grad_ref)

def test_serialization(self):
ref_fn = tanhshrink
config = tf.keras.activations.serialize(ref_fn)
fn = tf.keras.activations.deserialize(config)
self.assertEqual(fn, ref_fn)


if __name__ == "__main__":
tf.test.main()
26 changes: 26 additions & 0 deletions tensorflow_addons/custom_ops/activations/BUILD
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -49,15 +49,40 @@ cc_library(
alwayslink = 1,
)

cc_library(
name = "tanhshrink_op_gpu",
srcs = [
"cc/kernels/tanhshrink_op.h",
"cc/kernels/tanhshrink_op_gpu.cu.cc",
],
copts = if_cuda_is_configured([
"-DGOOGLE_CUDA=1",
"-x cuda",
"-nvcc_options=relaxed-constexpr",
"-nvcc_options=ftz=true",
]),
deps = [
"@local_config_tf//:libtensorflow_framework",
"@local_config_tf//:tf_header_lib",
] + if_cuda_is_configured([
"@local_config_cuda//cuda:cuda_libs",
"@local_config_cuda//cuda:cuda_headers",
]),
alwayslink = 1,
)

cc_binary(
name = "_activation_ops.so",
srcs = [
"cc/kernels/gelu_op.cc",
"cc/kernels/gelu_op.h",
"cc/kernels/hardshrink_op.cc",
"cc/kernels/hardshrink_op.h",
"cc/kernels/tanhshrink_op.cc",
"cc/kernels/tanhshrink_op.h",
"cc/ops/gelu_op.cc",
"cc/ops/hardshrink_op.cc",
"cc/ops/tanhshrink_op.cc",
],
copts = [
"-pthread",
Expand All @@ -71,5 +96,6 @@ cc_binary(
] + if_cuda_is_configured([
":gelu_op_gpu",
":hardshrink_op_gpu",
":tanhshrink_op_gpu",
]),
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define EIGEN_USE_THREADS

#include "tensorflow_addons/custom_ops/activations/cc/kernels/tanhshrink_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"

namespace tensorflow {
namespace addons {

using CPUDevice = Eigen::ThreadPoolDevice;

#define REGISTER_TANHSHRINK_KERNELS(type) \
REGISTER_KERNEL_BUILDER( \
Name("Addons>Tanhshrink").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
TanhshrinkOp<CPUDevice, type>); \
REGISTER_KERNEL_BUILDER(Name("Addons>TanhshrinkGrad") \
.Device(DEVICE_CPU) \
.TypeConstraint<type>("T"), \
TanhshrinkGradOp<CPUDevice, type>);

TF_CALL_GPU_NUMBER_TYPES(REGISTER_TANHSHRINK_KERNELS);
#undef REGISTER_TANHSHRINK_KERNELS

#if GOOGLE_CUDA

using GPUDevice = Eigen::GpuDevice;

// Forward declarations of the functor specializations for GPU.
namespace functor {
#define DECLARE_GPU_SPEC(T) \
template <> \
void Tanhshrink<GPUDevice, T>::operator()( \
const GPUDevice& d, typename TTypes<T>::ConstTensor features, \
typename TTypes<T>::Tensor activations); \
extern template struct Tanhshrink<GPUDevice, T>; \
\
template <> \
void TanhshrinkGrad<GPUDevice, T>::operator()( \
const GPUDevice& d, typename TTypes<T>::ConstTensor gradients, \
typename TTypes<T>::ConstTensor features, \
typename TTypes<T>::Tensor backprops); \
extern template struct TanhshrinkGrad<GPUDevice, T>;

TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
#undef DECLARE_GPU_SPEC
} // namespace functor

// Registration of the GPU implementations.
#define REGISTER_TANHSHRINK_GPU_KERNELS(type) \
REGISTER_KERNEL_BUILDER( \
Name("Addons>Tanhshrink").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
TanhshrinkOp<GPUDevice, type>); \
REGISTER_KERNEL_BUILDER(Name("Addons>TanhshrinkGrad") \
.Device(DEVICE_GPU) \
.TypeConstraint<type>("T"), \
TanhshrinkGradOp<GPUDevice, type>);

TF_CALL_GPU_NUMBER_TYPES(REGISTER_TANHSHRINK_GPU_KERNELS);
#undef REGISTER_TANHSHRINK_GPU_KERNELS

#endif // GOOGLE_CUDA

} // namespace addons
} // namespace tensorflow
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_ADDONS_ACTIVATIONS_KERNELS_TANHSHRINK_OP_H_
#define TENSORFLOW_ADDONS_ACTIVATIONS_KERNELS_TANHSHRINK_OP_H_

#define EIGEN_USE_THREADS

#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"

namespace tensorflow {
namespace addons {
namespace functor {

template <typename Device, typename T>
struct Tanhshrink {
void operator()(const Device& d, typename TTypes<T>::ConstTensor features,
typename TTypes<T>::Tensor activations) {
activations.device(d) = features - features.tanh();
}
};

template <typename Device, typename T>
struct TanhshrinkGrad {
void operator()(const Device& d, typename TTypes<T>::ConstTensor gradients,
typename TTypes<T>::ConstTensor features,
typename TTypes<T>::Tensor backprops) {
backprops.device(d) = gradients * features.tanh().square();
}
};

} // namespace functor

template <typename Device, typename T>
class TanhshrinkOp : public UnaryElementWiseOp<T, TanhshrinkOp<Device, T>> {
public:
using UnaryElementWiseOp<T, TanhshrinkOp<Device, T>>::UnaryElementWiseOp;

void Operate(OpKernelContext* context, const Tensor& input, Tensor* output) {
functor::Tanhshrink<Device, T> functor;
functor(context->eigen_device<Device>(), input.flat<T>(),
output->flat<T>());
}
};

template <typename Device, typename T>
class TanhshrinkGradOp
: public BinaryElementWiseOp<T, TanhshrinkGradOp<Device, T>> {
public:
using BinaryElementWiseOp<T,
TanhshrinkGradOp<Device, T>>::BinaryElementWiseOp;

void OperateNoTemplate(OpKernelContext* context, const Tensor& g,
const Tensor& a, Tensor* output);

// INPUTS:
// g (gradients): backpropagated gradients
// a (inputs): the inputs that were passed to the Tanhshrink op.
// OUTPUT:
// gradients to backprop
template <int NDIMS>
void Operate(OpKernelContext* context, const Tensor& g, const Tensor& a,
Tensor* output) {
OperateNoTemplate(context, g, a, output);
}
};

template <typename Device, typename T>
void TanhshrinkGradOp<Device, T>::OperateNoTemplate(OpKernelContext* context,
const Tensor& g,
const Tensor& a,
Tensor* output) {
functor::TanhshrinkGrad<Device, T> functor;
functor(context->eigen_device<Device>(), g.flat<T>(), a.flat<T>(),
output->flat<T>());
}
} // namespace addons
} // namespace tensorflow

#undef EIGEN_USE_THREADS

#endif // TENSORFLOW_ADDONS_ACTIVATIONS_KERNELS_TANHSHRINK_OP_H_
Loading