diff --git a/tensorflow_addons/activations/BUILD b/tensorflow_addons/activations/BUILD
old mode 100644
new mode 100755
index 64c1f984d7..e5f9640bb4
--- a/tensorflow_addons/activations/BUILD
+++ b/tensorflow_addons/activations/BUILD
@@ -9,6 +9,7 @@ py_library(
         "gelu.py",
         "hardshrink.py",
         "sparsemax.py",
+        "tanhshrink.py",
     ],
     data = [
         "//tensorflow_addons/custom_ops/activations:_activation_ops.so",
@@ -55,3 +56,16 @@ py_test(
         ":activations",
     ],
 )
+
+py_test(
+    name = "tanhshrink_test",
+    size = "medium",
+    srcs = [
+        "tanhshrink_test.py",
+    ],
+    main = "tanhshrink_test.py",
+    srcs_version = "PY2AND3",
+    deps = [
+        ":activations",
+    ],
+)
diff --git a/tensorflow_addons/activations/README.md b/tensorflow_addons/activations/README.md
index b61926dc8e..d34c12b962
--- a/tensorflow_addons/activations/README.md
+++ b/tensorflow_addons/activations/README.md
@@ -6,6 +6,7 @@
 | gelu | @AakashKumarNain @WindQAQ | aakashnain@outlook.com windqaq@gmail.com |
 | hardshrink| @WindQAQ | windqaq@gmail.com |
 | sparsemax | @AndreasMadsen | amwwebdk+github@gmail.com |
+| tanhshrink | @fsx950223 | fsx950223@gmail.com |
 
 ## Contents
 | Submodule | Activation | Reference |
@@ -13,6 +14,7 @@
 | gelu | gelu | https://arxiv.org/abs/1606.08415 |
 | hardshrink| hardshrink | |
 | sparsemax | Sparsemax | https://arxiv.org/abs/1602.02068 |
+| tanhshrink | Tanhshrink | |
 
 ## Contribution Guidelines
diff --git a/tensorflow_addons/activations/__init__.py b/tensorflow_addons/activations/__init__.py
index 5e1123b1b3..4208b57817
--- a/tensorflow_addons/activations/__init__.py
+++ b/tensorflow_addons/activations/__init__.py
@@ -21,3 +21,4 @@
 from tensorflow_addons.activations.gelu import gelu
 from tensorflow_addons.activations.hardshrink import hardshrink
 from tensorflow_addons.activations.sparsemax import sparsemax
+from tensorflow_addons.activations.tanhshrink import tanhshrink
diff --git a/tensorflow_addons/activations/tanhshrink.py b/tensorflow_addons/activations/tanhshrink.py
new file mode 100755
index 0000000000..00287de903
--- /dev/null
+++ b/tensorflow_addons/activations/tanhshrink.py
@@ -0,0 +1,45 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+from tensorflow_addons.utils import keras_utils
+from tensorflow_addons.utils.resource_loader import get_path_to_datafile
+
+_activation_ops_so = tf.load_op_library(
+    get_path_to_datafile("custom_ops/activations/_activation_ops.so"))
+
+
+@keras_utils.register_keras_custom_object
+@tf.function
+def tanhshrink(x):
+    """Applies the element-wise function: x - tanh(x)
+
+    Args:
+        x: A `Tensor`. Must be one of the following types:
+            `float16`, `float32`, `float64`.
+    Returns:
+        A `Tensor`. Has the same type as `x`.
+    """
+    x = tf.convert_to_tensor(x)
+    return _activation_ops_so.addons_tanhshrink(x)
+
+
+@tf.RegisterGradient("Addons>Tanhshrink")
+def _tanhshrink_grad(op, grad):
+    return _activation_ops_so.addons_tanhshrink_grad(grad, op.inputs[0])
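Reviewer note: the wrapper above only converts its input and dispatches to the compiled kernel. Its semantics match this pure-TensorFlow sketch (illustrative only, not part of the patch):

import tensorflow as tf

# Reference behavior of the custom op: tanhshrink(x) = x - tanh(x).
def tanhshrink_reference(x):
    x = tf.convert_to_tensor(x)
    return x - tf.tanh(x)

print(tanhshrink_reference(tf.constant([1.0, 2.0, 3.0])).numpy())
# -> approximately [0.23840584 1.0359724  2.0049453]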
+ """ + x = tf.convert_to_tensor(x) + return _activation_ops_so.addons_tanhshrink(x) + + +@tf.RegisterGradient("Addons>Tanhshrink") +def _tanhshrink_grad(op, grad): + return _activation_ops_so.addons_tanhshrink_grad(grad, op.inputs[0]) diff --git a/tensorflow_addons/activations/tanhshrink_test.py b/tensorflow_addons/activations/tanhshrink_test.py new file mode 100755 index 0000000000..86d72629bb --- /dev/null +++ b/tensorflow_addons/activations/tanhshrink_test.py @@ -0,0 +1,62 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import parameterized + +import numpy as np +import tensorflow as tf +from tensorflow_addons.activations import tanhshrink +from tensorflow_addons.utils import test_utils + + +def _ref_tanhshrink(x): + return x - tf.tanh(x) + + +@test_utils.run_all_in_graph_and_eager_modes +class TanhshrinkTest(tf.test.TestCase, parameterized.TestCase): + @parameterized.named_parameters(("float16", np.float16), + ("float32", np.float32), + ("float64", np.float64)) + def test_tanhshrink(self, dtype): + x = tf.constant([1.0, 2.0, 3.0], dtype=dtype) + self.assertAllCloseAccordingToType(tanhshrink(x), _ref_tanhshrink(x)) + + @parameterized.named_parameters(("float16", np.float16), + ("float32", np.float32), + ("float64", np.float64)) + def test_gradients(self, dtype): + x = tf.constant([1.0, 2.0, 3.0], dtype=dtype) + with tf.GradientTape(persistent=True) as tape: + tape.watch(x) + y_ref = _ref_tanhshrink(x) + y = tanhshrink(x) + grad_ref = tape.gradient(y_ref, x) + grad = tape.gradient(y, x) + self.assertAllCloseAccordingToType(grad, grad_ref) + + def test_serialization(self): + ref_fn = tanhshrink + config = tf.keras.activations.serialize(ref_fn) + fn = tf.keras.activations.deserialize(config) + self.assertEqual(fn, ref_fn) + + +if __name__ == "__main__": + tf.test.main() diff --git a/tensorflow_addons/custom_ops/activations/BUILD b/tensorflow_addons/custom_ops/activations/BUILD old mode 100644 new mode 100755 index f4778cf23c..b61d7a6fa3 --- a/tensorflow_addons/custom_ops/activations/BUILD +++ b/tensorflow_addons/custom_ops/activations/BUILD @@ -49,6 +49,28 @@ cc_library( alwayslink = 1, ) +cc_library( + name = "tanhshrink_op_gpu", + srcs = [ + "cc/kernels/tanhshrink_op.h", + "cc/kernels/tanhshrink_op_gpu.cu.cc", + ], + copts = if_cuda_is_configured([ + "-DGOOGLE_CUDA=1", + "-x cuda", + "-nvcc_options=relaxed-constexpr", + "-nvcc_options=ftz=true", + ]), + deps = [ + "@local_config_tf//:libtensorflow_framework", + "@local_config_tf//:tf_header_lib", + ] + if_cuda_is_configured([ + "@local_config_cuda//cuda:cuda_libs", + "@local_config_cuda//cuda:cuda_headers", + ]), + alwayslink = 1, +) + cc_binary( name = "_activation_ops.so", srcs = [ @@ -56,8 +78,11 @@ cc_binary( "cc/kernels/gelu_op.h", 
"cc/kernels/hardshrink_op.cc", "cc/kernels/hardshrink_op.h", + "cc/kernels/tanhshrink_op.cc", + "cc/kernels/tanhshrink_op.h", "cc/ops/gelu_op.cc", "cc/ops/hardshrink_op.cc", + "cc/ops/tanhshrink_op.cc", ], copts = [ "-pthread", @@ -71,5 +96,6 @@ cc_binary( ] + if_cuda_is_configured([ ":gelu_op_gpu", ":hardshrink_op_gpu", + ":tanhshrink_op_gpu", ]), ) diff --git a/tensorflow_addons/custom_ops/activations/cc/kernels/tanhshrink_op.cc b/tensorflow_addons/custom_ops/activations/cc/kernels/tanhshrink_op.cc new file mode 100644 index 0000000000..55bc374804 --- /dev/null +++ b/tensorflow_addons/custom_ops/activations/cc/kernels/tanhshrink_op.cc @@ -0,0 +1,80 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#define EIGEN_USE_THREADS + +#include "tensorflow_addons/custom_ops/activations/cc/kernels/tanhshrink_op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" + +namespace tensorflow { +namespace addons { + +using CPUDevice = Eigen::ThreadPoolDevice; + +#define REGISTER_TANHSHRINK_KERNELS(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("Addons>Tanhshrink").Device(DEVICE_CPU).TypeConstraint("T"), \ + TanhshrinkOp); \ + REGISTER_KERNEL_BUILDER(Name("Addons>TanhshrinkGrad") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T"), \ + TanhshrinkGradOp); + +TF_CALL_GPU_NUMBER_TYPES(REGISTER_TANHSHRINK_KERNELS); +#undef REGISTER_TANHSHRINK_KERNELS + +#if GOOGLE_CUDA + +using GPUDevice = Eigen::GpuDevice; + +// Forward declarations of the functor specializations for GPU. +namespace functor { +#define DECLARE_GPU_SPEC(T) \ + template <> \ + void Tanhshrink::operator()( \ + const GPUDevice& d, typename TTypes::ConstTensor features, \ + typename TTypes::Tensor activations); \ + extern template struct Tanhshrink; \ + \ + template <> \ + void TanhshrinkGrad::operator()( \ + const GPUDevice& d, typename TTypes::ConstTensor gradients, \ + typename TTypes::ConstTensor features, \ + typename TTypes::Tensor backprops); \ + extern template struct TanhshrinkGrad; + +TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC); +#undef DECLARE_GPU_SPEC +} // namespace functor + +// Registration of the GPU implementations. 
diff --git a/tensorflow_addons/custom_ops/activations/cc/kernels/tanhshrink_op.h b/tensorflow_addons/custom_ops/activations/cc/kernels/tanhshrink_op.h
new file mode 100644
index 0000000000..f4b9f22373
--- /dev/null
+++ b/tensorflow_addons/custom_ops/activations/cc/kernels/tanhshrink_op.h
@@ -0,0 +1,96 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_ADDONS_ACTIVATIONS_KERNELS_TANHSHRINK_OP_H_
+#define TENSORFLOW_ADDONS_ACTIVATIONS_KERNELS_TANHSHRINK_OP_H_
+
+#define EIGEN_USE_THREADS
+
+#include "tensorflow/core/framework/numeric_op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
+namespace tensorflow {
+namespace addons {
+namespace functor {
+
+template <typename Device, typename T>
+struct Tanhshrink {
+  void operator()(const Device& d, typename TTypes<T>::ConstTensor features,
+                  typename TTypes<T>::Tensor activations) {
+    activations.device(d) = features - features.tanh();
+  }
+};
+
+template <typename Device, typename T>
+struct TanhshrinkGrad {
+  void operator()(const Device& d, typename TTypes<T>::ConstTensor gradients,
+                  typename TTypes<T>::ConstTensor features,
+                  typename TTypes<T>::Tensor backprops) {
+    backprops.device(d) = gradients * features.tanh().square();
+  }
+};
+
+}  // namespace functor
+
+template <typename Device, typename T>
+class TanhshrinkOp : public UnaryElementWiseOp<T, TanhshrinkOp<Device, T>> {
+ public:
+  using UnaryElementWiseOp<T, TanhshrinkOp<Device, T>>::UnaryElementWiseOp;
+
+  void Operate(OpKernelContext* context, const Tensor& input, Tensor* output) {
+    functor::Tanhshrink<Device, T> functor;
+    functor(context->eigen_device<Device>(), input.flat<T>(),
+            output->flat<T>());
+  }
+};
+
+template <typename Device, typename T>
+class TanhshrinkGradOp
+    : public BinaryElementWiseOp<T, TanhshrinkGradOp<Device, T>> {
+ public:
+  using BinaryElementWiseOp<T,
+                            TanhshrinkGradOp<Device, T>>::BinaryElementWiseOp;
+
+  void OperateNoTemplate(OpKernelContext* context, const Tensor& g,
+                         const Tensor& a, Tensor* output);
+
+  // INPUTS:
+  //   g (gradients): backpropagated gradients
+  //   a (inputs): the inputs that were passed to the Tanhshrink op.
+  // OUTPUT:
+  //   gradients to backprop
+  template <int NDIMS>
+  void Operate(OpKernelContext* context, const Tensor& g, const Tensor& a,
+               Tensor* output) {
+    OperateNoTemplate(context, g, a, output);
+  }
+};
+
+template <typename Device, typename T>
+void TanhshrinkGradOp<Device, T>::OperateNoTemplate(OpKernelContext* context,
+                                                    const Tensor& g,
+                                                    const Tensor& a,
+                                                    Tensor* output) {
+  functor::TanhshrinkGrad<Device, T> functor;
+  functor(context->eigen_device<Device>(), g.flat<T>(), a.flat<T>(),
+          output->flat<T>());
+}
+}  // namespace addons
+}  // namespace tensorflow
+
+#undef EIGEN_USE_THREADS
+
+#endif  // TENSORFLOW_ADDONS_ACTIVATIONS_KERNELS_TANHSHRINK_OP_H_
\ No newline at end of file
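Reviewer note: the backward functor's expression follows from d/dx tanh(x) = 1 - tanh²(x): the derivative of x - tanh(x) is 1 - (1 - tanh²(x)) = tanh²(x), hence `gradients * features.tanh().square()` above. A quick finite-difference check (illustrative only):

import numpy as np

x, eps = 1.5, 1e-6
f = lambda v: v - np.tanh(v)             # tanhshrink
numeric = (f(x + eps) - f(x - eps)) / (2 * eps)
analytic = np.tanh(x) ** 2               # what TanhshrinkGrad computes
print(numeric, analytic)                 # both ~ 0.8193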
diff --git a/tensorflow_addons/custom_ops/activations/cc/kernels/tanhshrink_op_gpu.cu.cc b/tensorflow_addons/custom_ops/activations/cc/kernels/tanhshrink_op_gpu.cu.cc
new file mode 100755
index 0000000000..5c9aa20ef0
--- /dev/null
+++ b/tensorflow_addons/custom_ops/activations/cc/kernels/tanhshrink_op_gpu.cu.cc
@@ -0,0 +1,38 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#if GOOGLE_CUDA
+
+#define EIGEN_USE_GPU
+
+#include "tensorflow_addons/custom_ops/activations/cc/kernels/tanhshrink_op.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "third_party/eigen3/Eigen/Core"
+
+namespace tensorflow {
+namespace addons {
+
+using GPUDevice = Eigen::GpuDevice;
+
+#define DEFINE_GPU_KERNELS(T)                        \
+  template struct functor::Tanhshrink<GPUDevice, T>; \
+  template struct functor::TanhshrinkGrad<GPUDevice, T>;
+
+TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_KERNELS);
+
+}  // namespace addons
+}  // namespace tensorflow
+
+#endif  // GOOGLE_CUDA
+==============================================================================*/ + +#include "tensorflow/core/framework/common_shape_fns.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/shape_inference.h" + +namespace tensorflow { +namespace addons { + +REGISTER_OP("Addons>Tanhshrink") + .Input("features: T") + .Output("activations: T") + .Attr("T: {half, float, double}") + .SetShapeFn(shape_inference::UnchangedShape); + +REGISTER_OP("Addons>TanhshrinkGrad") + .Input("gradients: T") + .Input("features: T") + .Output("backprops: T") + .Attr("T: {half, float, double}") + .SetShapeFn(shape_inference::MergeBothInputsShapeFn); + +} // namespace addons +} // namespace tensorflow \ No newline at end of file