diff --git a/tensorflow_addons/activations/BUILD b/tensorflow_addons/activations/BUILD
index a67b39be21..3c59412f12 100644
--- a/tensorflow_addons/activations/BUILD
+++ b/tensorflow_addons/activations/BUILD
@@ -9,6 +9,7 @@ py_library(
         "gelu.py",
         "hardshrink.py",
         "lisht.py",
+        "mish.py",
         "rrelu.py",
         "softshrink.py",
         "sparsemax.py",
@@ -86,6 +87,19 @@ py_test(
     ],
 )
 
+py_test(
+    name = "mish_test",
+    size = "small",
+    srcs = [
+        "mish_test.py",
+    ],
+    main = "mish_test.py",
+    srcs_version = "PY2AND3",
+    deps = [
+        ":activations",
+    ],
+)
+
 py_test(
     name = "softshrink_test",
     size = "small",
diff --git a/tensorflow_addons/activations/README.md b/tensorflow_addons/activations/README.md
index 3a0cfb1323..bb4a2d5c8c 100644
--- a/tensorflow_addons/activations/README.md
+++ b/tensorflow_addons/activations/README.md
@@ -6,6 +6,7 @@
 | gelu | @AakashKumarNain @WindQAQ | aakashnain@outlook.com windqaq@gmail.com |
 | hardshrink| @WindQAQ | windqaq@gmail.com |
 | lisht | @WindQAQ | windqaq@gmail.com |
+| mish | @digantamisra98 @WindQAQ | mishradiganta91@gmail.com, windqaq@gmail.com |
 | softshrink| @WindQAQ | windqaq@gmail.com |
 | sparsemax | @AndreasMadsen | amwwebdk+github@gmail.com |
 | tanhshrink| @fsx950223 | fsx950223@gmail.com |
@@ -17,12 +18,12 @@
 | gelu | gelu | https://arxiv.org/abs/1606.08415 |
 | hardshrink| hardshrink | |
 | lisht | lisht | https://arxiv.org/abs/1901.05894 |
+| mish | mish | https://arxiv.org/abs/1908.08681 |
 | softshrink| softshrink | |
 | sparsemax | sparsemax | https://arxiv.org/abs/1602.02068 |
 | tanhshrink| tanhshrink | |
 | rrelu | rrelu | https://arxiv.org/abs/1505.00853 |
 
-
 ## Contribution Guidelines
 #### Standard API
 In order to conform with the current API standard, all activations
diff --git a/tensorflow_addons/activations/__init__.py b/tensorflow_addons/activations/__init__.py
index dcbf4ad004..31b5d88b9a 100644
--- a/tensorflow_addons/activations/__init__.py
+++ b/tensorflow_addons/activations/__init__.py
@@ -21,6 +21,7 @@
 from tensorflow_addons.activations.gelu import gelu
 from tensorflow_addons.activations.hardshrink import hardshrink
 from tensorflow_addons.activations.lisht import lisht
+from tensorflow_addons.activations.mish import mish
 from tensorflow_addons.activations.softshrink import softshrink
 from tensorflow_addons.activations.rrelu import rrelu
 from tensorflow_addons.activations.sparsemax import sparsemax
diff --git a/tensorflow_addons/activations/activations_test.py b/tensorflow_addons/activations/activations_test.py
index 58b946577d..c3ce9db8f9 100644
--- a/tensorflow_addons/activations/activations_test.py
+++ b/tensorflow_addons/activations/activations_test.py
@@ -26,8 +26,8 @@
 
 class ActivationsTest(tf.test.TestCase):
     ALL_ACTIVATIONS = [
-        "gelu", "hardshrink", "lisht", "softshrink", "sparsemax", "rrelu",
-        "tanhshrink"
+        "gelu", "hardshrink", "lisht", "mish", "rrelu", "softshrink",
+        "sparsemax", "tanhshrink"
     ]
 
     def test_serialization(self):
diff --git a/tensorflow_addons/activations/mish.py b/tensorflow_addons/activations/mish.py
new file mode 100644
index 0000000000..8b64b3ee66
--- /dev/null
+++ b/tensorflow_addons/activations/mish.py
@@ -0,0 +1,49 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+from tensorflow_addons.utils import keras_utils
+from tensorflow_addons.utils.resource_loader import get_path_to_datafile
+
+_activation_ops_so = tf.load_op_library(
+    get_path_to_datafile("custom_ops/activations/_activation_ops.so"))
+
+
+@keras_utils.register_keras_custom_object
+@tf.function
+def mish(x):
+    """Mish: A Self Regularized Non-Monotonic Neural Activation Function.
+
+    Computes mish activation: x * tanh(softplus(x))
+
+    See [Mish: A Self Regularized Non-Monotonic Neural Activation Function](https://arxiv.org/abs/1908.08681).
+
+    Args:
+        x: A `Tensor`. Must be one of the following types:
+            `float16`, `float32`, `float64`.
+    Returns:
+        A `Tensor`. Has the same type as `x`.
+    """
+    x = tf.convert_to_tensor(x)
+    return _activation_ops_so.addons_mish(x)
+
+
+@tf.RegisterGradient("Addons>Mish")
+def _mish_grad(op, grad):
+    return _activation_ops_so.addons_mish_grad(grad, op.inputs[0])
diff --git a/tensorflow_addons/activations/mish_test.py b/tensorflow_addons/activations/mish_test.py
new file mode 100644
index 0000000000..74e374ed75
--- /dev/null
+++ b/tensorflow_addons/activations/mish_test.py
@@ -0,0 +1,59 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl.testing import parameterized
+
+import numpy as np
+import tensorflow as tf
+from tensorflow_addons.activations import mish
+from tensorflow_addons.utils import test_utils
+
+
+@test_utils.run_all_in_graph_and_eager_modes
+class MishTest(tf.test.TestCase, parameterized.TestCase):
+    @parameterized.named_parameters(("float16", np.float16),
+                                    ("float32", np.float32),
+                                    ("float64", np.float64))
+    def test_mish(self, dtype):
+        x = tf.constant([-2.0, -1.0, 0.0, 1.0, 2.0], dtype=dtype)
+        expected_result = tf.constant(
+            [-0.2525015, -0.30340144, 0.0, 0.86509836, 1.943959], dtype=dtype)
+        self.assertAllCloseAccordingToType(mish(x), expected_result)
+
+    @parameterized.named_parameters(("float32", np.float32),
+                                    ("float64", np.float64))
+    def test_theoretical_gradients(self, dtype):
+        # Only test theoretical gradients for float32 and float64
+        # because of the instability of float16 while computing jacobian
+        x = tf.constant([-2.0, -1.0, 0.0, 1.0, 2.0], dtype=dtype)
+
+        theoretical, numerical = tf.test.compute_gradient(mish, [x])
+        self.assertAllCloseAccordingToType(theoretical, numerical, atol=1e-4)
+
+    def test_unknown_shape(self):
+        fn = mish.get_concrete_function(
+            tf.TensorSpec(shape=None, dtype=tf.float32))
+
+        for shape in [(1,), (1, 2), (1, 2, 3), (1, 2, 3, 4)]:
+            x = tf.ones(shape=shape, dtype=tf.float32)
+            self.assertAllClose(fn(x), mish(x))
+
+
+if __name__ == "__main__":
+    tf.test.main()
diff --git a/tensorflow_addons/custom_ops/activations/BUILD b/tensorflow_addons/custom_ops/activations/BUILD
index f12442cefe..a51c5d1dec 100644
--- a/tensorflow_addons/custom_ops/activations/BUILD
+++ b/tensorflow_addons/custom_ops/activations/BUILD
@@ -13,6 +13,8 @@ custom_op_library(
         "cc/kernels/hardshrink_op.h",
         "cc/kernels/lisht_op.cc",
         "cc/kernels/lisht_op.h",
+        "cc/kernels/mish_op.cc",
+        "cc/kernels/mish_op.h",
         "cc/kernels/rrelu_op.cc",
         "cc/kernels/rrelu_op.h",
         "cc/kernels/softshrink_op.cc",
@@ -22,6 +24,7 @@
         "cc/ops/gelu_op.cc",
         "cc/ops/hardshrink_op.cc",
         "cc/ops/lisht_op.cc",
+        "cc/ops/mish_op.cc",
         "cc/ops/rrelu_op.cc",
         "cc/ops/softshrink_op.cc",
         "cc/ops/tanhshrink_op.cc",
@@ -33,6 +36,8 @@
         "cc/kernels/hardshrink_op_gpu.cu.cc",
         "cc/kernels/lisht_op.h",
         "cc/kernels/lisht_op_gpu.cu.cc",
+        "cc/kernels/mish_op.h",
+        "cc/kernels/mish_op_gpu.cu.cc",
         "cc/kernels/rrelu_op.h",
         "cc/kernels/rrelu_op_gpu.cu.cc",
         "cc/kernels/softshrink_op.h",
diff --git a/tensorflow_addons/custom_ops/activations/cc/kernels/mish_op.cc b/tensorflow_addons/custom_ops/activations/cc/kernels/mish_op.cc
new file mode 100644
index 0000000000..edef3a0839
--- /dev/null
+++ b/tensorflow_addons/custom_ops/activations/cc/kernels/mish_op.cc
@@ -0,0 +1,79 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#define EIGEN_USE_THREADS
+
+#include "tensorflow_addons/custom_ops/activations/cc/kernels/mish_op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
+namespace tensorflow {
+namespace addons {
+
+using CPUDevice = Eigen::ThreadPoolDevice;
+
+#define REGISTER_MISH_KERNELS(type)                                          \
+  REGISTER_KERNEL_BUILDER(                                                   \
+      Name("Addons>Mish").Device(DEVICE_CPU).TypeConstraint<type>("T"),      \
+      MishOp<CPUDevice, type>);                                              \
+  REGISTER_KERNEL_BUILDER(                                                   \
+      Name("Addons>MishGrad").Device(DEVICE_CPU).TypeConstraint<type>("T"),  \
+      MishGradOp<CPUDevice, type>);
+
+// Mish only makes sense with floating points.
+TF_CALL_GPU_NUMBER_TYPES(REGISTER_MISH_KERNELS);
+#undef REGISTER_MISH_KERNELS
+
+#if GOOGLE_CUDA
+
+using GPUDevice = Eigen::GpuDevice;
+
+// Forward declarations of the functor specializations for GPU.
+namespace functor {
+#define DECLARE_GPU_SPEC(T)                                           \
+  template <>                                                         \
+  void Mish<GPUDevice, T>::operator()(                                \
+      const GPUDevice& d, typename TTypes<T>::ConstTensor features,   \
+      typename TTypes<T>::Tensor activations);                        \
+  extern template struct Mish<GPUDevice, T>;                          \
+                                                                      \
+  template <>                                                         \
+  void MishGrad<GPUDevice, T>::operator()(                            \
+      const GPUDevice& d, typename TTypes<T>::ConstTensor gradients,  \
+      typename TTypes<T>::ConstTensor features,                       \
+      typename TTypes<T>::Tensor backprops);                          \
+  extern template struct MishGrad<GPUDevice, T>;
+
+TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
+#undef DECLARE_GPU_SPEC
+}  // namespace functor
+
+// Registration of the GPU implementations.
+#define REGISTER_MISH_GPU_KERNELS(type)                                      \
+  REGISTER_KERNEL_BUILDER(                                                   \
+      Name("Addons>Mish").Device(DEVICE_GPU).TypeConstraint<type>("T"),      \
+      MishOp<GPUDevice, type>);                                              \
+  REGISTER_KERNEL_BUILDER(                                                   \
+      Name("Addons>MishGrad").Device(DEVICE_GPU).TypeConstraint<type>("T"),  \
+      MishGradOp<GPUDevice, type>);
+
+TF_CALL_GPU_NUMBER_TYPES(REGISTER_MISH_GPU_KERNELS);
+#undef REGISTER_MISH_GPU_KERNELS
+
+#endif  // GOOGLE_CUDA
+
+}  // namespace addons
+}  // namespace tensorflow
diff --git a/tensorflow_addons/custom_ops/activations/cc/kernels/mish_op.h b/tensorflow_addons/custom_ops/activations/cc/kernels/mish_op.h
new file mode 100644
index 0000000000..c111f99ef7
--- /dev/null
+++ b/tensorflow_addons/custom_ops/activations/cc/kernels/mish_op.h
@@ -0,0 +1,130 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_ADDONS_ACTIVATIONS_KERNELS_MISH_OP_H_
+#define TENSORFLOW_ADDONS_ACTIVATIONS_KERNELS_MISH_OP_H_
+
+#define EIGEN_USE_THREADS
+
+#include "tensorflow/core/framework/numeric_op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
+namespace tensorflow {
+namespace addons {
+namespace functor {
+
+// Functor used by MishOp to do the computations.
+template <typename Device, typename T>
+struct Mish {
+  // Computes Mish activation.
+  //
+  // features: any shape.
+  // activations: same shape as "features".
+  void operator()(const Device& d, typename TTypes<T>::ConstTensor features,
+                  typename TTypes<T>::Tensor activations) {
+    // softplus implementation
+    // https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/kernels/softplus_op.h
+    static const T threshold =
+        Eigen::numext::log(Eigen::NumTraits<T>::epsilon()) + T(2);
+    const auto& too_large = features > features.constant(-threshold);
+    const auto& too_small = features < features.constant(threshold);
+    const auto& features_exp = features.exp();
+    const auto& sp = too_large.select(
+        features,
+        too_small.select(features_exp,
+                         (features_exp + features.constant(T(1))).log()));
+    activations.device(d) = features * sp.tanh();
+  }
+};
+
+// Functor used by MishGradOp to do the computations.
+template <typename Device, typename T>
+struct MishGrad {
+  // Computes MishGrad backprops.
+  //
+  // gradients: gradients backpropagated to the Mish op.
+  // features: inputs that were passed to the Mish op.
+  // backprops: gradients to backpropagate to the Mish inputs.
+  void operator()(const Device& d, typename TTypes<T>::ConstTensor gradients,
+                  typename TTypes<T>::ConstTensor features,
+                  typename TTypes<T>::Tensor backprops) {
+    // softplus implementation
+    // https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/kernels/softplus_op.h
+    static const T threshold =
+        Eigen::numext::log(Eigen::NumTraits<T>::epsilon()) + T(2);
+    const auto& too_large = features > features.constant(-threshold);
+    const auto& too_small = features < features.constant(threshold);
+    const auto& features_exp = features.exp();
+    const auto& sp = too_large.select(
+        features,
+        too_small.select(features_exp,
+                         (features_exp + features.constant(T(1))).log()));
+
+    const auto& grad_sp = static_cast<T>(1) - (-sp).exp();
+    const auto& tsp = sp.tanh();
+    const auto& grad_tsp = ((static_cast<T>(1) - tsp * tsp) * grad_sp);
+    const auto& grad = features * grad_tsp + tsp;
+    backprops.device(d) = gradients * grad;
+  }
+};
+
+}  // namespace functor
+
+template <typename Device, typename T>
+class MishOp : public UnaryElementWiseOp<T, MishOp<Device, T>> {
+ public:
+  explicit MishOp(OpKernelConstruction* context)
+      : UnaryElementWiseOp<T, MishOp<Device, T>>::UnaryElementWiseOp(context) {}
+
+  void Operate(OpKernelContext* context, const Tensor& input, Tensor* output) {
+    functor::Mish<Device, T> functor;
+    functor(context->eigen_device<Device>(), input.flat<T>(),
+            output->flat<T>());
+  }
+};
+
+template <typename Device, typename T>
+class MishGradOp : public BinaryElementWiseOp<T, MishGradOp<Device, T>> {
+ public:
+  explicit MishGradOp(OpKernelConstruction* context)
+      : BinaryElementWiseOp<T, MishGradOp<Device, T>>::BinaryElementWiseOp(
+            context) {}
+
+  void OperateNoTemplate(OpKernelContext* context, const Tensor& g,
+                         const Tensor& a, Tensor* output);
+
+  template <int NDIMS>
+  void Operate(OpKernelContext* context, const Tensor& g, const Tensor& a,
+               Tensor* output) {
+    OperateNoTemplate(context, g, a, output);
+  }
+};
+
+template <typename Device, typename T>
+void MishGradOp<Device, T>::OperateNoTemplate(OpKernelContext* context,
+                                              const Tensor& g, const Tensor& a,
+                                              Tensor* output) {
+  functor::MishGrad<Device, T> functor;
+  functor(context->eigen_device<Device>(), g.flat<T>(), a.flat<T>(),
+          output->flat<T>());
+}
+
+}  // namespace addons
+}  // namespace tensorflow
+
+#undef EIGEN_USE_THREADS
+
+#endif  // TENSORFLOW_ADDONS_ACTIVATIONS_KERNELS_MISH_OP_H_
diff --git a/tensorflow_addons/custom_ops/activations/cc/kernels/mish_op_gpu.cu.cc b/tensorflow_addons/custom_ops/activations/cc/kernels/mish_op_gpu.cu.cc
new file mode 100644
index 0000000000..e2fc143cee
--- /dev/null
+++ b/tensorflow_addons/custom_ops/activations/cc/kernels/mish_op_gpu.cu.cc
@@ -0,0 +1,38 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#if GOOGLE_CUDA
+
+#define EIGEN_USE_GPU
+
+#include "tensorflow_addons/custom_ops/activations/cc/kernels/mish_op.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "third_party/eigen3/Eigen/Core"
+
+namespace tensorflow {
+namespace addons {
+
+using GPUDevice = Eigen::GpuDevice;
+
+#define DEFINE_GPU_KERNELS(T)                       \
+  template struct functor::Mish<GPUDevice, T>;      \
+  template struct functor::MishGrad<GPUDevice, T>;
+
+TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_KERNELS);
+
+}  // namespace addons
+}  // namespace tensorflow
+
+#endif  // GOOGLE_CUDA
diff --git a/tensorflow_addons/custom_ops/activations/cc/ops/mish_op.cc b/tensorflow_addons/custom_ops/activations/cc/ops/mish_op.cc
new file mode 100644
index 0000000000..2f8a548889
--- /dev/null
+++ b/tensorflow_addons/custom_ops/activations/cc/ops/mish_op.cc
@@ -0,0 +1,37 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/shape_inference.h"
+
+namespace tensorflow {
+namespace addons {
+
+REGISTER_OP("Addons>Mish")
+    .Input("features: T")
+    .Output("activations: T")
+    .Attr("T: {half, float, double}")
+    .SetShapeFn(shape_inference::UnchangedShape);
+
+REGISTER_OP("Addons>MishGrad")
+    .Input("gradients: T")
+    .Input("features: T")
+    .Output("backprops: T")
+    .Attr("T: {half, float, double}")
+    .SetShapeFn(shape_inference::MergeBothInputsShapeFn);
+
+}  // namespace addons
+}  // namespace tensorflow
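A minimal usage sketch of the Python API this patch adds, not part of the diff itself. It assumes a tensorflow-addons build that includes the Addons>Mish kernels above; the expected values are the ones used in mish_test.py, and the by-name Keras usage is an inference from the keras_utils.register_keras_custom_object registration in mish.py:

import tensorflow as tf
from tensorflow_addons.activations import mish

x = tf.constant([-2.0, -1.0, 0.0, 1.0, 2.0])

# The custom op computes mish(x) = x * tanh(softplus(x)).
print(mish(x))  # ~[-0.2525015, -0.30340144, 0.0, 0.86509836, 1.943959]

# Pure-TF reference expression, handy as a cross-check against the kernel.
print(x * tf.math.tanh(tf.math.softplus(x)))

# The registered Addons>MishGrad kernel supplies the gradient.
with tf.GradientTape() as tape:
    tape.watch(x)
    y = mish(x)
print(tape.gradient(y, x))

# Because mish is registered as a Keras custom object, it should also be
# resolvable by name once tensorflow_addons is imported.
layer = tf.keras.layers.Dense(8, activation="mish")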