diff --git a/tensorflow_addons/activations/BUILD b/tensorflow_addons/activations/BUILD
index e5f9640bb4..a1972f34c9 100644
--- a/tensorflow_addons/activations/BUILD
+++ b/tensorflow_addons/activations/BUILD
@@ -8,6 +8,7 @@ py_library(
         "__init__.py",
         "gelu.py",
         "hardshrink.py",
+        "lisht.py",
         "sparsemax.py",
         "tanhshrink.py",
     ],
@@ -20,7 +21,7 @@ py_library(
 
 py_test(
     name = "sparsemax_test",
-    size = "medium",
+    size = "small",
     srcs = [
         "sparsemax_test.py",
     ],
@@ -33,7 +34,7 @@ py_test(
 
 py_test(
     name = "gelu_test",
-    size = "medium",
+    size = "small",
     srcs = [
         "gelu_test.py",
     ],
@@ -46,7 +47,7 @@ py_test(
 
 py_test(
     name = "hardshrink_test",
-    size = "medium",
+    size = "small",
     srcs = [
         "hardshrink_test.py",
     ],
@@ -57,9 +58,22 @@ py_test(
     ],
 )
 
+py_test(
+    name = "lisht_test",
+    size = "small",
+    srcs = [
+        "lisht_test.py",
+    ],
+    main = "lisht_test.py",
+    srcs_version = "PY2AND3",
+    deps = [
+        ":activations",
+    ],
+)
+
 py_test(
     name = "tanhshrink_test",
-    size = "medium",
+    size = "small",
     srcs = [
         "tanhshrink_test.py",
     ],
diff --git a/tensorflow_addons/activations/README.md b/tensorflow_addons/activations/README.md
index d34c12b962..ede5eb30fb 100644
--- a/tensorflow_addons/activations/README.md
+++ b/tensorflow_addons/activations/README.md
@@ -4,17 +4,19 @@
 | Submodule | Maintainers               | Contact Info                              |
 |:----------|:--------------------------|:------------------------------------------|
 | gelu      | @AakashKumarNain @WindQAQ | aakashnain@outlook.com windqaq@gmail.com  |
-| hardshrink| @WindQAQ | windqaq@gmail.com
+| hardshrink| @WindQAQ | windqaq@gmail.com |
+| lisht | @WindQAQ | windqaq@gmail.com |
 | sparsemax | @AndreasMadsen | amwwebdk+github@gmail.com |
-| tanhshrink | @fsx950223 | fsx950223@gmail.com |
+| tanhshrink| @fsx950223 | fsx950223@gmail.com |
 
 ## Contents
 | Submodule | Activation | Reference |
 |:----------|:-----------|:---------------------------------|
 | gelu | gelu | https://arxiv.org/abs/1606.08415 |
 | hardshrink| hardshrink | |
-| sparsemax | Sparsemax | https://arxiv.org/abs/1602.02068 |
-| tanhshrink | Tanhshrink | |
+| lisht | lisht | https://arxiv.org/abs/1901.05894 |
+| sparsemax | sparsemax | https://arxiv.org/abs/1602.02068 |
+| tanhshrink| tanhshrink | |
 
 ## Contribution Guidelines
@@ -22,7 +24,6 @@
 In order to conform with the current API standard, all activations
 must:
  * Be a `tf.function`.
- * Have the signature `fn(input, axis=-1, name=None)`.
  * [Register as a keras global object](https://github.com/tensorflow/addons/blob/master/tensorflow_addons/utils/python/keras_utils.py)
   so it can be serialized properly.
  * Add the addon to the `py_library` in this sub-package's BUILD file.
diff --git a/tensorflow_addons/activations/__init__.py b/tensorflow_addons/activations/__init__.py
index 4208b57817..313a78a1e3 100644
--- a/tensorflow_addons/activations/__init__.py
+++ b/tensorflow_addons/activations/__init__.py
@@ -20,5 +20,6 @@
 
 from tensorflow_addons.activations.gelu import gelu
 from tensorflow_addons.activations.hardshrink import hardshrink
+from tensorflow_addons.activations.lisht import lisht
 from tensorflow_addons.activations.sparsemax import sparsemax
 from tensorflow_addons.activations.tanhshrink import tanhshrink
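With the BUILD, README, and `__init__.py` wiring above, the new activation is importable from the package root. A quick smoke check of the expected behavior (a sketch assuming an installed build of this branch; not part of the diff):

```python
import tensorflow as tf
import tensorflow_addons as tfa

x = tf.constant([-1.0, 0.0, 1.0])
print(tfa.activations.lisht(x))  # approx. [0.7615942, 0.0, 0.7615942]
```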
diff --git a/tensorflow_addons/activations/lisht.py b/tensorflow_addons/activations/lisht.py
new file mode 100644
index 0000000000..cbef569792
--- /dev/null
+++ b/tensorflow_addons/activations/lisht.py
@@ -0,0 +1,49 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+from tensorflow_addons.utils import keras_utils
+from tensorflow_addons.utils.resource_loader import get_path_to_datafile
+
+_activation_ops_so = tf.load_op_library(
+    get_path_to_datafile("custom_ops/activations/_activation_ops.so"))
+
+
+@keras_utils.register_keras_custom_object
+@tf.function
+def lisht(x):
+    """LiSHT: Non-Parametric Linearly Scaled Hyperbolic Tangent Activation Function.
+
+    Computes linearly scaled hyperbolic tangent (LiSHT): `x * tanh(x)`
+
+    See [LiSHT: Non-Parametric Linearly Scaled Hyperbolic Tangent Activation Function for Neural Networks](https://arxiv.org/abs/1901.05894).
+
+    Args:
+        x: A `Tensor`. Must be one of the following types:
+            `float16`, `float32`, `float64`.
+    Returns:
+        A `Tensor`. Has the same type as `x`.
+    """
+    x = tf.convert_to_tensor(x)
+    return _activation_ops_so.addons_lisht(x)
+
+
+@tf.RegisterGradient("Addons>Lisht")
+def _lisht_grad(op, grad):
+    return _activation_ops_so.addons_lisht_grad(grad, op.inputs[0])
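`lisht` dispatches to the compiled `Addons>Lisht` kernel loaded from `_activation_ops.so`. For readers without the custom-op build, a pure-TensorFlow sketch of the same math (an illustration, not the PR's implementation):

```python
import tensorflow as tf

def lisht_reference(x):
    # LiSHT scales the input by its own tanh: x * tanh(x).
    x = tf.convert_to_tensor(x)
    return x * tf.tanh(x)
```

Unlike the custom op, this version relies on TensorFlow's autodiff instead of the hand-written `Addons>LishtGrad` kernel registered above.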
diff --git a/tensorflow_addons/activations/lisht_test.py b/tensorflow_addons/activations/lisht_test.py
new file mode 100644
index 0000000000..b4e7fd2dfc
--- /dev/null
+++ b/tensorflow_addons/activations/lisht_test.py
@@ -0,0 +1,73 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl.testing import parameterized
+
+import numpy as np
+import tensorflow as tf
+from tensorflow_addons.activations import lisht
+from tensorflow_addons.utils import test_utils
+
+
+@test_utils.run_all_in_graph_and_eager_modes
+class LishtTest(tf.test.TestCase, parameterized.TestCase):
+    @parameterized.named_parameters(("float16", np.float16),
+                                    ("float32", np.float32),
+                                    ("float64", np.float64))
+    def test_lisht(self, dtype):
+        x = tf.constant([-2.0, -1.0, 0.0, 1.0, 2.0], dtype=dtype)
+        expected_result = tf.constant(
+            [1.9280552, 0.7615942, 0.0, 0.7615942, 1.9280552], dtype=dtype)
+        self.assertAllCloseAccordingToType(lisht(x), expected_result)
+
+    @parameterized.named_parameters(("float32", np.float32),
+                                    ("float64", np.float64))
+    def test_theoretical_gradients(self, dtype):
+        # Only test theoretical gradients for float32 and float64
+        # because of the instability of float16 while computing jacobian
+        x = tf.constant([-2.0, -1.0, 0.0, 1.0, 2.0], dtype=dtype)
+
+        theoretical, numerical = tf.test.compute_gradient(lisht, [x])
+        self.assertAllCloseAccordingToType(
+            theoretical, numerical, rtol=5e-4, atol=5e-4)
+
+    def test_unknown_shape(self):
+        fn = lisht.get_concrete_function(
+            tf.TensorSpec(shape=None, dtype=tf.float32))
+
+        for shape in [(1,), (1, 2), (1, 2, 3), (1, 2, 3, 4)]:
+            x = tf.ones(shape=shape, dtype=tf.float32)
+            self.assertAllClose(fn(x), lisht(x))
+
+    def test_serialization(self):
+        config = tf.keras.activations.serialize(lisht)
+        fn = tf.keras.activations.deserialize(config)
+        self.assertEqual(fn, lisht)
+
+    def test_serialization_with_layers(self):
+        layer = tf.keras.layers.Dense(3, activation=lisht)
+        config = tf.keras.layers.serialize(layer)
+        deserialized_layer = tf.keras.layers.deserialize(config)
+        self.assertEqual(deserialized_layer.__class__.__name__,
+                         layer.__class__.__name__)
+        self.assertEqual(deserialized_layer.activation.__name__, "lisht")
+
+
+if __name__ == "__main__":
+    tf.test.main()
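The hard-coded expectations in `test_lisht` follow directly from `x * tanh(x)`; a quick NumPy check (illustrative only) reproduces them:

```python
import numpy as np

x = np.array([-2.0, -1.0, 0.0, 1.0, 2.0])
print(x * np.tanh(x))
# [1.9280552 0.7615942 0.        0.7615942 1.9280552]
```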
diff --git a/tensorflow_addons/custom_ops/activations/BUILD b/tensorflow_addons/custom_ops/activations/BUILD
index b61d7a6fa3..567b06250d 100644
--- a/tensorflow_addons/custom_ops/activations/BUILD
+++ b/tensorflow_addons/custom_ops/activations/BUILD
@@ -49,6 +49,28 @@ cc_library(
     alwayslink = 1,
 )
 
+cc_library(
+    name = "lisht_op_gpu",
+    srcs = [
+        "cc/kernels/lisht_op.h",
+        "cc/kernels/lisht_op_gpu.cu.cc",
+    ],
+    copts = if_cuda_is_configured([
+        "-DGOOGLE_CUDA=1",
+        "-x cuda",
+        "-nvcc_options=relaxed-constexpr",
+        "-nvcc_options=ftz=true",
+    ]),
+    deps = [
+        "@local_config_tf//:libtensorflow_framework",
+        "@local_config_tf//:tf_header_lib",
+    ] + if_cuda_is_configured([
+        "@local_config_cuda//cuda:cuda_libs",
+        "@local_config_cuda//cuda:cuda_headers",
+    ]),
+    alwayslink = 1,
+)
+
 cc_library(
     name = "tanhshrink_op_gpu",
     srcs = [
@@ -78,10 +100,13 @@ cc_binary(
         "cc/kernels/gelu_op.h",
         "cc/kernels/hardshrink_op.cc",
         "cc/kernels/hardshrink_op.h",
+        "cc/kernels/lisht_op.cc",
+        "cc/kernels/lisht_op.h",
         "cc/kernels/tanhshrink_op.cc",
         "cc/kernels/tanhshrink_op.h",
         "cc/ops/gelu_op.cc",
         "cc/ops/hardshrink_op.cc",
+        "cc/ops/lisht_op.cc",
         "cc/ops/tanhshrink_op.cc",
     ],
     copts = [
@@ -96,6 +121,7 @@ cc_binary(
     ] + if_cuda_is_configured([
         ":gelu_op_gpu",
         ":hardshrink_op_gpu",
+        ":lisht_op_gpu",
         ":tanhshrink_op_gpu",
     ]),
 )
diff --git a/tensorflow_addons/custom_ops/activations/cc/kernels/lisht_op.cc b/tensorflow_addons/custom_ops/activations/cc/kernels/lisht_op.cc
new file mode 100644
index 0000000000..05d56a043f
--- /dev/null
+++ b/tensorflow_addons/custom_ops/activations/cc/kernels/lisht_op.cc
@@ -0,0 +1,79 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#define EIGEN_USE_THREADS
+
+#include "tensorflow_addons/custom_ops/activations/cc/kernels/lisht_op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
+namespace tensorflow {
+namespace addons {
+
+using CPUDevice = Eigen::ThreadPoolDevice;
+
+#define REGISTER_LISHT_KERNELS(type)                                          \
+  REGISTER_KERNEL_BUILDER(                                                    \
+      Name("Addons>Lisht").Device(DEVICE_CPU).TypeConstraint<type>("T"),      \
+      LishtOp<CPUDevice, type>);                                              \
+  REGISTER_KERNEL_BUILDER(                                                    \
+      Name("Addons>LishtGrad").Device(DEVICE_CPU).TypeConstraint<type>("T"),  \
+      LishtGradOp<CPUDevice, type>);
+
+// Lisht only makes sense with floating points.
+TF_CALL_GPU_NUMBER_TYPES(REGISTER_LISHT_KERNELS);
+#undef REGISTER_LISHT_KERNELS
+
+#if GOOGLE_CUDA
+
+using GPUDevice = Eigen::GpuDevice;
+
+// Forward declarations of the functor specializations for GPU.
+namespace functor {
+#define DECLARE_GPU_SPEC(T)                                            \
+  template <>                                                          \
+  void Lisht<GPUDevice, T>::operator()(                                \
+      const GPUDevice& d, typename TTypes<T>::ConstTensor features,    \
+      typename TTypes<T>::Tensor activations);                         \
+  extern template struct Lisht<GPUDevice, T>;                          \
+                                                                       \
+  template <>                                                          \
+  void LishtGrad<GPUDevice, T>::operator()(                            \
+      const GPUDevice& d, typename TTypes<T>::ConstTensor gradients,   \
+      typename TTypes<T>::ConstTensor features,                        \
+      typename TTypes<T>::Tensor backprops);                           \
+  extern template struct LishtGrad<GPUDevice, T>;
+
+TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
+#undef DECLARE_GPU_SPEC
+}  // namespace functor
+
+// Registration of the GPU implementations.
+#define REGISTER_LISHT_GPU_KERNELS(type)                                      \
+  REGISTER_KERNEL_BUILDER(                                                    \
+      Name("Addons>Lisht").Device(DEVICE_GPU).TypeConstraint<type>("T"),      \
+      LishtOp<GPUDevice, type>);                                              \
+  REGISTER_KERNEL_BUILDER(                                                    \
+      Name("Addons>LishtGrad").Device(DEVICE_GPU).TypeConstraint<type>("T"),  \
+      LishtGradOp<GPUDevice, type>);
+
+TF_CALL_GPU_NUMBER_TYPES(REGISTER_LISHT_GPU_KERNELS);
+#undef REGISTER_LISHT_GPU_KERNELS
+
+#endif  // GOOGLE_CUDA
+
+}  // namespace addons
+}  // namespace tensorflow
diff --git a/tensorflow_addons/custom_ops/activations/cc/kernels/lisht_op.h b/tensorflow_addons/custom_ops/activations/cc/kernels/lisht_op.h
new file mode 100644
index 0000000000..a3b5e85ca0
--- /dev/null
+++ b/tensorflow_addons/custom_ops/activations/cc/kernels/lisht_op.h
@@ -0,0 +1,106 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_ADDONS_ACTIVATIONS_KERNELS_LISHT_OP_H_
+#define TENSORFLOW_ADDONS_ACTIVATIONS_KERNELS_LISHT_OP_H_
+
+#define EIGEN_USE_THREADS
+
+#include "tensorflow/core/framework/numeric_op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
+namespace tensorflow {
+namespace addons {
+namespace functor {
+
+// Functor used by LishtOp to do the computations.
+template <typename Device, typename T>
+struct Lisht {
+  // Computes Lisht activation.
+  //
+  // features: any shape.
+  // activations: same shape as "features".
+  void operator()(const Device& d, typename TTypes<T>::ConstTensor features,
+                  typename TTypes<T>::Tensor activations) {
+    activations.device(d) = features * features.tanh();
+  }
+};
+
+// Functor used by LishtGradOp to do the computations.
+template <typename Device, typename T>
+struct LishtGrad {
+  // Computes LishtGrad backprops.
+  //
+  // gradients: gradients backpropagated to the Lisht op.
+  // features: inputs that were passed to the Lisht op.
+  // backprops: gradients to backpropagate to the Lisht inputs.
+  void operator()(const Device& d, typename TTypes<T>::ConstTensor gradients,
+                  typename TTypes<T>::ConstTensor features,
+                  typename TTypes<T>::Tensor backprops) {
+    const auto g = features.tanh();
+    backprops.device(d) =
+        gradients * (g + features * (static_cast<T>(1.0) - g.square()));
+  }
+};
+
+}  // namespace functor
+
+template <typename Device, typename T>
+class LishtOp : public UnaryElementWiseOp<T, LishtOp<Device, T>> {
+ public:
+  explicit LishtOp(OpKernelConstruction* context)
+      : UnaryElementWiseOp<T, LishtOp<Device, T>>::UnaryElementWiseOp(context) {
+  }
+
+  void Operate(OpKernelContext* context, const Tensor& input, Tensor* output) {
+    functor::Lisht<Device, T> functor;
+    functor(context->eigen_device<Device>(), input.flat<T>(),
+            output->flat<T>());
+  }
+};
+
+template <typename Device, typename T>
+class LishtGradOp : public BinaryElementWiseOp<T, LishtGradOp<Device, T>> {
+ public:
+  explicit LishtGradOp(OpKernelConstruction* context)
+      : BinaryElementWiseOp<T, LishtGradOp<Device, T>>::BinaryElementWiseOp(
+            context) {}
+
+  void OperateNoTemplate(OpKernelContext* context, const Tensor& g,
+                         const Tensor& a, Tensor* output);
+
+  template <int NDIMS>
+  void Operate(OpKernelContext* context, const Tensor& g, const Tensor& a,
+               Tensor* output) {
+    OperateNoTemplate(context, g, a, output);
+  }
+};
+
+template <typename Device, typename T>
+void LishtGradOp<Device, T>::OperateNoTemplate(OpKernelContext* context,
+                                               const Tensor& g, const Tensor& a,
+                                               Tensor* output) {
+  functor::LishtGrad<Device, T> functor;
+  functor(context->eigen_device<Device>(), g.flat<T>(), a.flat<T>(),
+          output->flat<T>());
+}
+
+}  // namespace addons
+}  // namespace tensorflow
+
+#undef EIGEN_USE_THREADS
+
+#endif  // TENSORFLOW_ADDONS_ACTIVATIONS_KERNELS_LISHT_OP_H_
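The `LishtGrad` functor encodes the analytic derivative: with `g = tanh(x)`, `d/dx [x * tanh(x)] = g + x * (1 - g^2)`. A finite-difference spot check in NumPy (illustrative; the test above verifies the same thing via `tf.test.compute_gradient`):

```python
import numpy as np

def lisht_grad(x):
    # g + x * (1 - g**2) with g = tanh(x), matching the LishtGrad functor.
    g = np.tanh(x)
    return g + x * (1.0 - g ** 2)

x = np.array([-2.0, -1.0, 0.0, 1.0, 2.0])
eps = 1e-6
numeric = ((x + eps) * np.tanh(x + eps) -
           (x - eps) * np.tanh(x - eps)) / (2 * eps)
print(np.max(np.abs(lisht_grad(x) - numeric)))  # on the order of 1e-10
```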
diff --git a/tensorflow_addons/custom_ops/activations/cc/kernels/lisht_op_gpu.cu.cc b/tensorflow_addons/custom_ops/activations/cc/kernels/lisht_op_gpu.cu.cc
new file mode 100644
index 0000000000..66e0f979a1
--- /dev/null
+++ b/tensorflow_addons/custom_ops/activations/cc/kernels/lisht_op_gpu.cu.cc
@@ -0,0 +1,38 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#if GOOGLE_CUDA
+
+#define EIGEN_USE_GPU
+
+#include "tensorflow_addons/custom_ops/activations/cc/kernels/lisht_op.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "third_party/eigen3/Eigen/Core"
+
+namespace tensorflow {
+namespace addons {
+
+using GPUDevice = Eigen::GpuDevice;
+
+#define DEFINE_GPU_KERNELS(T)                        \
+  template struct functor::Lisht<GPUDevice, T>;      \
+  template struct functor::LishtGrad<GPUDevice, T>;
+
+TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_KERNELS);
+
+}  // namespace addons
+}  // namespace tensorflow
+
+#endif  // GOOGLE_CUDA
diff --git a/tensorflow_addons/custom_ops/activations/cc/ops/lisht_op.cc b/tensorflow_addons/custom_ops/activations/cc/ops/lisht_op.cc
new file mode 100644
index 0000000000..1a5b1712e9
--- /dev/null
+++ b/tensorflow_addons/custom_ops/activations/cc/ops/lisht_op.cc
@@ -0,0 +1,37 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/shape_inference.h"
+
+namespace tensorflow {
+namespace addons {
+
+REGISTER_OP("Addons>Lisht")
+    .Input("features: T")
+    .Output("activations: T")
+    .Attr("T: {half, float, double}")
+    .SetShapeFn(shape_inference::UnchangedShape);
+
+REGISTER_OP("Addons>LishtGrad")
+    .Input("gradients: T")
+    .Input("features: T")
+    .Output("backprops: T")
+    .Attr("T: {half, float, double}")
+    .SetShapeFn(shape_inference::MergeBothInputsShapeFn);
+
+}  // namespace addons
+}  // namespace tensorflow
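Because `lisht` is registered as a Keras custom object, it can be used anywhere Keras accepts an activation and survives config round-trips, as `test_serialization_with_layers` exercises. A minimal usage sketch (layer sizes are placeholders of mine):

```python
import tensorflow as tf
from tensorflow_addons.activations import lisht

model = tf.keras.Sequential([
    tf.keras.layers.Dense(64, activation=lisht, input_shape=(32,)),
    tf.keras.layers.Dense(10),
])
restored = tf.keras.Sequential.from_config(model.get_config())
assert restored.layers[0].activation.__name__ == "lisht"
```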