From a38f391224c42da09da9e8858f62d040c67a7711 Mon Sep 17 00:00:00 2001
From: Tzu-Wei Sung <windqaq@gmail.com>
Date: Wed, 25 Sep 2019 22:18:46 -0700
Subject: [PATCH 01/13] add mish

---
 tensorflow_addons/activations/BUILD                |  14 +++
 tensorflow_addons/activations/README.md            |   2 +
 tensorflow_addons/activations/__init__.py          |   1 +
 tensorflow_addons/activations/mish.py              |  49 ++++++++
 tensorflow_addons/activations/mish_test.py         |  75 ++++++++++++
 .../custom_ops/activations/BUILD                   |  26 +++++
 .../activations/cc/kernels/mish_op.cc              |  79 +++++++++++++
 .../activations/cc/kernels/mish_op.h               | 110 ++++++++++++++++++
 .../activations/cc/kernels/mish_op_gpu.cu.cc       |  38 ++++++
 .../custom_ops/activations/cc/ops/mish_op.cc       |  37 ++++++
 10 files changed, 431 insertions(+)
 create mode 100644 tensorflow_addons/activations/mish.py
 create mode 100644 tensorflow_addons/activations/mish_test.py
 create mode 100644 tensorflow_addons/custom_ops/activations/cc/kernels/mish_op.cc
 create mode 100644 tensorflow_addons/custom_ops/activations/cc/kernels/mish_op.h
 create mode 100644 tensorflow_addons/custom_ops/activations/cc/kernels/mish_op_gpu.cu.cc
 create mode 100644 tensorflow_addons/custom_ops/activations/cc/ops/mish_op.cc

diff --git a/tensorflow_addons/activations/BUILD b/tensorflow_addons/activations/BUILD
index e5f9640bb4..605416ac0f 100644
--- a/tensorflow_addons/activations/BUILD
+++ b/tensorflow_addons/activations/BUILD
@@ -8,6 +8,7 @@ py_library(
         "__init__.py",
         "gelu.py",
         "hardshrink.py",
+        "mish.py",
         "sparsemax.py",
         "tanhshrink.py",
     ],
@@ -57,6 +58,19 @@ py_test(
     ],
 )
 
+py_test(
+    name = "mish_test",
+    size = "medium",
+    srcs = [
+        "mish_test.py",
+    ],
+    main = "mish_test.py",
+    srcs_version = "PY2AND3",
+    deps = [
+        ":activations",
+    ],
+)
+
 py_test(
     name = "tanhshrink_test",
     size = "medium",
diff --git a/tensorflow_addons/activations/README.md b/tensorflow_addons/activations/README.md
index d34c12b962..25342db39e 100644
--- a/tensorflow_addons/activations/README.md
+++ b/tensorflow_addons/activations/README.md
@@ -5,6 +5,7 @@
 |:----------|:--------------------------|:-----------------------------------------|
 | gelu | @AakashKumarNain @WindQAQ | aakashnain@outlook.com windqaq@gmail.com |
 | hardshrink| @WindQAQ | windqaq@gmail.com
+| mish | @WindQAQ | windqaq@gmail.com |
 | sparsemax | @AndreasMadsen | amwwebdk+github@gmail.com |
 | tanhshrink | @fsx950223 | fsx950223@gmail.com |
 
@@ -13,6 +14,7 @@
 |:----------|:-----------|:---------------------------------|
 | gelu | gelu | https://arxiv.org/abs/1606.08415 |
 | hardshrink| hardshrink | |
+| mish | mish | https://arxiv.org/abs/1908.08681 |
 | sparsemax | Sparsemax | https://arxiv.org/abs/1602.02068 |
 | tanhshrink | Tanhshrink | |
 
diff --git a/tensorflow_addons/activations/__init__.py b/tensorflow_addons/activations/__init__.py
index 4208b57817..e714ffda26 100644
--- a/tensorflow_addons/activations/__init__.py
+++ b/tensorflow_addons/activations/__init__.py
@@ -20,5 +20,6 @@
 
 from tensorflow_addons.activations.gelu import gelu
 from tensorflow_addons.activations.hardshrink import hardshrink
+from tensorflow_addons.activations.mish import mish
 from tensorflow_addons.activations.sparsemax import sparsemax
 from tensorflow_addons.activations.tanhshrink import tanhshrink
diff --git a/tensorflow_addons/activations/mish.py b/tensorflow_addons/activations/mish.py
new file mode 100644
index 0000000000..8b64b3ee66
--- /dev/null
+++ b/tensorflow_addons/activations/mish.py
@@ -0,0 +1,49 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf +from tensorflow_addons.utils import keras_utils +from tensorflow_addons.utils.resource_loader import get_path_to_datafile + +_activation_ops_so = tf.load_op_library( + get_path_to_datafile("custom_ops/activations/_activation_ops.so")) + + +@keras_utils.register_keras_custom_object +@tf.function +def mish(x): + """Mish: A Self Regularized Non-Monotonic Neural Activation Function. + + Computes mish activation: x * tanh(softplus(x)) + + See [Mish: A Self Regularized Non-Monotonic Neural Activation Function](https://arxiv.org/abs/1908.08681). + + Args: + x: A `Tensor`. Must be one of the following types: + `float16`, `float32`, `float64`. + Returns: + A `Tensor`. Has the same type as `x`. + """ + x = tf.convert_to_tensor(x) + return _activation_ops_so.addons_mish(x) + + +@tf.RegisterGradient("Addons>Mish") +def _mish_grad(op, grad): + return _activation_ops_so.addons_mish_grad(grad, op.inputs[0]) diff --git a/tensorflow_addons/activations/mish_test.py b/tensorflow_addons/activations/mish_test.py new file mode 100644 index 0000000000..a7bdf1c1b8 --- /dev/null +++ b/tensorflow_addons/activations/mish_test.py @@ -0,0 +1,75 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import parameterized + +import math + +import numpy as np +import tensorflow as tf +from tensorflow_addons.activations import mish +from tensorflow_addons.utils import test_utils + + +@test_utils.run_all_in_graph_and_eager_modes +class MishTest(tf.test.TestCase, parameterized.TestCase): + @parameterized.named_parameters(("float16", np.float16), + ("float32", np.float32), + ("float64", np.float64)) + def test_mish(self, dtype): + x = tf.constant([-2.0, -1.0, 0.0, 1.0, 2.0], dtype=dtype) + expected_result = tf.constant( + [-0.2525015, -0.30340144, 0.0, 0.86509836, 1.943959], dtype=dtype) + self.assertAllCloseAccordingToType(mish(x), expected_result) + + @parameterized.named_parameters(("float32", np.float32), + ("float64", np.float64)) + def test_theoretical_gradients(self, dtype): + # Only test theoretical gradients for float32 and float64 + # because of the instability of float16 while computing jacobian + x = tf.constant([-2.0, -1.0, 0.0, 1.0, 2.0], dtype=dtype) + + theoretical, numerical = tf.test.compute_gradient(mish, [x]) + self.assertAllCloseAccordingToType( + theoretical, numerical, atol=1e-4) + + def test_unknown_shape(self): + fn = mish.get_concrete_function( + tf.TensorSpec(shape=None, dtype=tf.float32)) + + for shape in [(1,), (1, 2), (1, 2, 3), (1, 2, 3, 4)]: + x = tf.ones(shape=shape, dtype=tf.float32) + self.assertAllClose(fn(x), mish(x)) + + def test_serialization(self): + config = tf.keras.activations.serialize(mish) + fn = tf.keras.activations.deserialize(config) + self.assertEqual(fn, mish) + + def test_serialization_with_layers(self): + layer = tf.keras.layers.Dense(3, activation=mish) + config = tf.keras.layers.serialize(layer) + deserialized_layer = tf.keras.layers.deserialize(config) + self.assertEqual(deserialized_layer.__class__.__name__, + layer.__class__.__name__) + self.assertEqual(deserialized_layer.activation.__name__, "mish") + + +if __name__ == "__main__": + tf.test.main() diff --git a/tensorflow_addons/custom_ops/activations/BUILD b/tensorflow_addons/custom_ops/activations/BUILD index b61d7a6fa3..41f1e04f79 100644 --- a/tensorflow_addons/custom_ops/activations/BUILD +++ b/tensorflow_addons/custom_ops/activations/BUILD @@ -49,6 +49,28 @@ cc_library( alwayslink = 1, ) +cc_library( + name = "mish_op_gpu", + srcs = [ + "cc/kernels/mish_op.h", + "cc/kernels/mish_op_gpu.cu.cc", + ], + copts = if_cuda_is_configured([ + "-DGOOGLE_CUDA=1", + "-x cuda", + "-nvcc_options=relaxed-constexpr", + "-nvcc_options=ftz=true", + ]), + deps = [ + "@local_config_tf//:libtensorflow_framework", + "@local_config_tf//:tf_header_lib", + ] + if_cuda_is_configured([ + "@local_config_cuda//cuda:cuda_libs", + "@local_config_cuda//cuda:cuda_headers", + ]), + alwayslink = 1, +) + cc_library( name = "tanhshrink_op_gpu", srcs = [ @@ -78,10 +100,13 @@ cc_binary( "cc/kernels/gelu_op.h", "cc/kernels/hardshrink_op.cc", "cc/kernels/hardshrink_op.h", + "cc/kernels/mish_op.cc", + "cc/kernels/mish_op.h", "cc/kernels/tanhshrink_op.cc", "cc/kernels/tanhshrink_op.h", "cc/ops/gelu_op.cc", "cc/ops/hardshrink_op.cc", + "cc/ops/mish_op.cc", "cc/ops/tanhshrink_op.cc", ], copts = [ @@ -96,6 +121,7 @@ cc_binary( ] + if_cuda_is_configured([ ":gelu_op_gpu", ":hardshrink_op_gpu", + ":mish_op_gpu", ":tanhshrink_op_gpu", ]), ) diff --git a/tensorflow_addons/custom_ops/activations/cc/kernels/mish_op.cc 
b/tensorflow_addons/custom_ops/activations/cc/kernels/mish_op.cc
new file mode 100644
index 0000000000..7e72489590
--- /dev/null
+++ b/tensorflow_addons/custom_ops/activations/cc/kernels/mish_op.cc
@@ -0,0 +1,79 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#define EIGEN_USE_THREADS
+
+#include "tensorflow_addons/custom_ops/activations/cc/kernels/mish_op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
+namespace tensorflow {
+namespace addons {
+
+using CPUDevice = Eigen::ThreadPoolDevice;
+
+#define REGISTER_MISH_KERNELS(type)                                          \
+  REGISTER_KERNEL_BUILDER(                                                   \
+      Name("Addons>Mish").Device(DEVICE_CPU).TypeConstraint<type>("T"),      \
+      MishOp<CPUDevice, type>);                                              \
+  REGISTER_KERNEL_BUILDER(                                                   \
+      Name("Addons>MishGrad").Device(DEVICE_CPU).TypeConstraint<type>("T"),  \
+      MishGradOp<CPUDevice, type>);
+
+// Mish only makes sense with floating points.
+TF_CALL_GPU_NUMBER_TYPES(REGISTER_MISH_KERNELS);
+#undef REGISTER_MISH_KERNELS
+
+#if GOOGLE_CUDA
+
+using GPUDevice = Eigen::GpuDevice;
+
+// Forward declarations of the functor specializations for GPU.
+namespace functor {
+#define DECLARE_GPU_SPEC(T)                                            \
+  template <>                                                          \
+  void Mish<GPUDevice, T>::operator()(                                 \
+      const GPUDevice& d, typename TTypes<T>::ConstTensor features,    \
+      typename TTypes<T>::Tensor activations);                         \
+  extern template struct Mish<GPUDevice, T>;                           \
+                                                                       \
+  template <>                                                          \
+  void MishGrad<GPUDevice, T>::operator()(                             \
+      const GPUDevice& d, typename TTypes<T>::ConstTensor gradients,   \
+      typename TTypes<T>::ConstTensor features,                        \
+      typename TTypes<T>::Tensor backprops);                           \
+  extern template struct MishGrad<GPUDevice, T>;
+
+TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
+#undef DECLARE_GPU_SPEC
+}  // namespace functor
+
+// Registration of the GPU implementations.
+#define REGISTER_MISH_GPU_KERNELS(type)                                      \
+  REGISTER_KERNEL_BUILDER(                                                   \
+      Name("Addons>Mish").Device(DEVICE_GPU).TypeConstraint<type>("T"),      \
+      MishOp<GPUDevice, type>);                                              \
+  REGISTER_KERNEL_BUILDER(                                                   \
+      Name("Addons>MishGrad").Device(DEVICE_GPU).TypeConstraint<type>("T"),  \
+      MishGradOp<GPUDevice, type>);
+
+TF_CALL_GPU_NUMBER_TYPES(REGISTER_MISH_GPU_KERNELS);
+#undef REGISTER_MISH_GPU_KERNELS
+
+#endif  // GOOGLE_CUDA
+
+}  // namespace addons
+}  // namespace tensorflow
diff --git a/tensorflow_addons/custom_ops/activations/cc/kernels/mish_op.h b/tensorflow_addons/custom_ops/activations/cc/kernels/mish_op.h
new file mode 100644
index 0000000000..38aacfe548
--- /dev/null
+++ b/tensorflow_addons/custom_ops/activations/cc/kernels/mish_op.h
@@ -0,0 +1,110 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_ADDONS_ACTIVATIONS_KERNELS_MISH_OP_H_
+#define TENSORFLOW_ADDONS_ACTIVATIONS_KERNELS_MISH_OP_H_
+
+#define EIGEN_USE_THREADS
+
+#include "tensorflow/core/framework/numeric_op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
+namespace tensorflow {
+namespace addons {
+namespace functor {
+
+// Functor used by MishOp to do the computations.
+template <typename Device, typename T>
+struct Mish {
+  // Computes Mish activation.
+  //
+  // features: any shape.
+  // activations: same shape as "features".
+  void operator()(const Device& d, typename TTypes<T>::ConstTensor features,
+                  typename TTypes<T>::Tensor activations) {
+    activations.device(d) = features * features.exp().log1p().tanh();
+  }
+};
+
+// Functor used by MishGradOp to do the computations.
+template <typename Device, typename T>
+struct MishGrad {
+  // Computes MishGrad backprops.
+  //
+  // gradients: gradients backpropagated to the Mish op.
+  // features: inputs that were passed to the Mish op.
+  // backprops: gradients to backpropagate to the Mish inputs.
+  void operator()(const Device& d, typename TTypes<T>::ConstTensor gradients,
+                  typename TTypes<T>::ConstTensor features,
+                  typename TTypes<T>::Tensor backprops) {
+    auto& e = features.exp();
+    auto& es = e.square();
+    auto& omega = static_cast<T>(4) * (features + static_cast<T>(1)) +
+                  static_cast<T>(4) * es + e.cube() +
+                  e * (static_cast<T>(4) * features + static_cast<T>(6));
+    auto& delta = static_cast<T>(2) * e + es + static_cast<T>(2);
+    backprops.device(d) = gradients * e * omega / delta.square();
+  }
+};
+
+}  // namespace functor
+
+template <typename Device, typename T>
+class MishOp : public UnaryElementWiseOp<T, MishOp<Device, T>> {
+ public:
+  explicit MishOp(OpKernelConstruction* context)
+      : UnaryElementWiseOp<T, MishOp<Device, T>>::UnaryElementWiseOp(context) {
+  }
+
+  void Operate(OpKernelContext* context, const Tensor& input, Tensor* output) {
+    functor::Mish<Device, T> functor;
+    functor(context->eigen_device<Device>(), input.flat<T>(), output->flat<T>());
+  }
+};
+
+template <typename Device, typename T>
+class MishGradOp : public BinaryElementWiseOp<T, MishGradOp<Device, T>> {
+ public:
+  explicit MishGradOp(OpKernelConstruction* context)
+      : BinaryElementWiseOp<T, MishGradOp<Device, T>>::BinaryElementWiseOp(
+            context) {
+  }
+
+  void OperateNoTemplate(OpKernelContext* context, const Tensor& g,
+                         const Tensor& a, Tensor* output);
+
+  template <int NDIMS>
+  void Operate(OpKernelContext* context, const Tensor& g, const Tensor& a,
+               Tensor* output) {
+    OperateNoTemplate(context, g, a, output);
+  }
+};
+
+template <typename Device, typename T>
+void MishGradOp<Device, T>::OperateNoTemplate(OpKernelContext* context,
+                                              const Tensor& g, const Tensor& a,
+                                              Tensor* output) {
+  functor::MishGrad<Device, T> functor;
+  functor(context->eigen_device<Device>(), g.flat<T>(), a.flat<T>(),
+          output->flat<T>());
+}
+
+}  // namespace addons
+}  // namespace tensorflow
+
+#undef EIGEN_USE_THREADS
+
+#endif  // TENSORFLOW_ADDONS_ACTIVATIONS_KERNELS_MISH_OP_H_
diff --git a/tensorflow_addons/custom_ops/activations/cc/kernels/mish_op_gpu.cu.cc b/tensorflow_addons/custom_ops/activations/cc/kernels/mish_op_gpu.cu.cc
new file mode 100644
index 0000000000..e2fc143cee
--- /dev/null
+++ b/tensorflow_addons/custom_ops/activations/cc/kernels/mish_op_gpu.cu.cc
@@ -0,0 +1,38 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#if GOOGLE_CUDA
+
+#define EIGEN_USE_GPU
+
+#include "tensorflow_addons/custom_ops/activations/cc/kernels/mish_op.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "third_party/eigen3/Eigen/Core"
+
+namespace tensorflow {
+namespace addons {
+
+using GPUDevice = Eigen::GpuDevice;
+
+#define DEFINE_GPU_KERNELS(T)                      \
+  template struct functor::Mish<GPUDevice, T>;     \
+  template struct functor::MishGrad<GPUDevice, T>;
+
+TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_KERNELS);
+
+}  // namespace addons
+}  // namespace tensorflow
+
+#endif  // GOOGLE_CUDA
diff --git a/tensorflow_addons/custom_ops/activations/cc/ops/mish_op.cc b/tensorflow_addons/custom_ops/activations/cc/ops/mish_op.cc
new file mode 100644
index 0000000000..2f8a548889
--- /dev/null
+++ b/tensorflow_addons/custom_ops/activations/cc/ops/mish_op.cc
@@ -0,0 +1,37 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/shape_inference.h"
+
+namespace tensorflow {
+namespace addons {
+
+REGISTER_OP("Addons>Mish")
+    .Input("features: T")
+    .Output("activations: T")
+    .Attr("T: {half, float, double}")
+    .SetShapeFn(shape_inference::UnchangedShape);
+
+REGISTER_OP("Addons>MishGrad")
+    .Input("gradients: T")
+    .Input("features: T")
+    .Output("backprops: T")
+    .Attr("T: {half, float, double}")
+    .SetShapeFn(shape_inference::MergeBothInputsShapeFn);
+
+}  // namespace addons
+}  // namespace tensorflow
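Note (illustrative, outside the patches): the expected values asserted in mish_test.py
follow directly from the definition mish(x) = x * tanh(softplus(x)) and can be
reproduced with public TensorFlow APIs alone, independent of the custom kernels above.
A minimal sketch, assuming only stock TensorFlow 2.x:

    import tensorflow as tf

    def mish_reference(x):
        # Pure-TF reference implementation; the Addons>Mish kernel registered
        # above should agree with this to within dtype precision.
        return x * tf.math.tanh(tf.math.softplus(x))

    x = tf.constant([-2.0, -1.0, 0.0, 1.0, 2.0])
    print(mish_reference(x).numpy())
    # ~[-0.2525015, -0.30340144, 0.0, 0.86509836, 1.943959], matching mish_test.py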
From 9810a66323b0acf74050578e5c37ad3e47397c77 Mon Sep 17 00:00:00 2001
From: Tzu-Wei Sung <windqaq@gmail.com>
Date: Wed, 25 Sep 2019 22:21:46 -0700
Subject: [PATCH 02/13] format code

---
 tensorflow_addons/activations/mish_test.py |  5 +---
 .../activations/cc/kernels/mish_op.cc      |  4 +--
 .../activations/cc/kernels/mish_op.h       | 25 +++++++++----------
 3 files changed, 15 insertions(+), 19 deletions(-)

diff --git a/tensorflow_addons/activations/mish_test.py b/tensorflow_addons/activations/mish_test.py
index a7bdf1c1b8..7fddb51e19 100644
--- a/tensorflow_addons/activations/mish_test.py
+++ b/tensorflow_addons/activations/mish_test.py
@@ -19,8 +19,6 @@
 
 from absl.testing import parameterized
 
-import math
-
 import numpy as np
 import tensorflow as tf
 from tensorflow_addons.activations import mish
@@ -46,8 +44,7 @@ def test_theoretical_gradients(self, dtype):
         x = tf.constant([-2.0, -1.0, 0.0, 1.0, 2.0], dtype=dtype)
 
         theoretical, numerical = tf.test.compute_gradient(mish, [x])
-        self.assertAllCloseAccordingToType(
-            theoretical, numerical, atol=1e-4)
+        self.assertAllCloseAccordingToType(theoretical, numerical, atol=1e-4)
 
     def test_unknown_shape(self):
         fn = mish.get_concrete_function(
diff --git a/tensorflow_addons/custom_ops/activations/cc/kernels/mish_op.cc b/tensorflow_addons/custom_ops/activations/cc/kernels/mish_op.cc
index 7e72489590..edef3a0839 100644
--- a/tensorflow_addons/custom_ops/activations/cc/kernels/mish_op.cc
+++ b/tensorflow_addons/custom_ops/activations/cc/kernels/mish_op.cc
@@ -47,13 +47,13 @@ namespace functor {
   template <>                                                          \
   void Mish<GPUDevice, T>::operator()(                                 \
       const GPUDevice& d, typename TTypes<T>::ConstTensor features,    \
-      typename TTypes<T>::Tensor activations);                        \
+      typename TTypes<T>::Tensor activations);                         \
   extern template struct Mish<GPUDevice, T>;                           \
                                                                        \
   template <>                                                          \
   void MishGrad<GPUDevice, T>::operator()(                             \
       const GPUDevice& d, typename TTypes<T>::ConstTensor gradients,   \
-      typename TTypes<T>::ConstTensor features,                       \
+      typename TTypes<T>::ConstTensor features,                        \
       typename TTypes<T>::Tensor backprops);                           \
   extern template struct MishGrad<GPUDevice, T>;
diff --git a/tensorflow_addons/custom_ops/activations/cc/kernels/mish_op.h b/tensorflow_addons/custom_ops/activations/cc/kernels/mish_op.h
index 38aacfe548..9f434fb300 100644
--- a/tensorflow_addons/custom_ops/activations/cc/kernels/mish_op.h
+++ b/tensorflow_addons/custom_ops/activations/cc/kernels/mish_op.h
@@ -35,7 +35,7 @@ struct Mish {
   // activations: same shape as "features".
   void operator()(const Device& d, typename TTypes<T>::ConstTensor features,
                   typename TTypes<T>::Tensor activations) {
-    activations.device(d) = features * features.exp().log1p().tanh(); 
+    activations.device(d) = features * features.exp().log1p().tanh();
   }
 };
 
@@ -50,13 +50,13 @@ struct MishGrad {
   void operator()(const Device& d, typename TTypes<T>::ConstTensor gradients,
                   typename TTypes<T>::ConstTensor features,
                   typename TTypes<T>::Tensor backprops) {
-    auto& e = features.exp(); 
-    auto& es = e.square(); 
-    auto& omega = static_cast<T>(4) * (features + static_cast<T>(1)) + 
-                  static_cast<T>(4) * es + e.cube() + 
-                  e * (static_cast<T>(4) * features + static_cast<T>(6)); 
-    auto& delta = static_cast<T>(2) * e + es + static_cast<T>(2); 
-    backprops.device(d) = gradients * e * omega / delta.square(); 
+    auto& e = features.exp();
+    auto& es = e.square();
+    auto& omega = static_cast<T>(4) * (features + static_cast<T>(1)) +
+                  static_cast<T>(4) * es + e.cube() +
+                  e * (static_cast<T>(4) * features + static_cast<T>(6));
+    auto& delta = static_cast<T>(2) * e + es + static_cast<T>(2);
+    backprops.device(d) = gradients * e * omega / delta.square();
   }
 };
 
@@ -66,12 +66,12 @@ template <typename Device, typename T>
 class MishOp : public UnaryElementWiseOp<T, MishOp<Device, T>> {
  public:
   explicit MishOp(OpKernelConstruction* context)
-      : UnaryElementWiseOp<T, MishOp<Device, T>>::UnaryElementWiseOp(context) {
-  }
+      : UnaryElementWiseOp<T, MishOp<Device, T>>::UnaryElementWiseOp(context) {}
 
   void Operate(OpKernelContext* context, const Tensor& input, Tensor* output) {
     functor::Mish<Device, T> functor;
-    functor(context->eigen_device<Device>(), input.flat<T>(), output->flat<T>());
+    functor(context->eigen_device<Device>(), input.flat<T>(),
+            output->flat<T>());
   }
 };
 
@@ -80,8 +80,7 @@ class MishGradOp : public BinaryElementWiseOp<T, MishGradOp<Device, T>> {
  public:
   explicit MishGradOp(OpKernelConstruction* context)
       : BinaryElementWiseOp<T, MishGradOp<Device, T>>::BinaryElementWiseOp(
-            context) {
-  }
+            context) {}
 
   void OperateNoTemplate(OpKernelContext* context, const Tensor& g,
                          const Tensor& a, Tensor* output);
From cda5a6f987ca67674c5552027991d2fa25ab51cf Mon Sep 17 00:00:00 2001
From: Tzu-Wei Sung <windqaq@gmail.com>
Date: Fri, 4 Oct 2019 17:54:11 -0700
Subject: [PATCH 03/13] update README

---
 tensorflow_addons/activations/README.md | 16 ++++------------
 1 file changed, 4 insertions(+), 12 deletions(-)

diff --git a/tensorflow_addons/activations/README.md b/tensorflow_addons/activations/README.md
index bfb54640ac..4010a23ce7 100644
--- a/tensorflow_addons/activations/README.md
+++ b/tensorflow_addons/activations/README.md
@@ -4,13 +4,9 @@
 | Submodule | Maintainers               | Contact Info                              |
 |:----------|:--------------------------|:-----------------------------------------|
 | gelu | @AakashKumarNain @WindQAQ | aakashnain@outlook.com windqaq@gmail.com |
-<<<<<<< HEAD
-| hardshrink| @WindQAQ | windqaq@gmail.com
-| mish | @WindQAQ | windqaq@gmail.com |
-=======
 | hardshrink| @WindQAQ | windqaq@gmail.com |
 | lisht | @WindQAQ | windqaq@gmail.com |
->>>>>>> a603625227f0c39de7f68fde93ff0e23b602f284
+| mish | @WindQAQ | windqaq@gmail.com |
 | sparsemax | @AndreasMadsen | amwwebdk+github@gmail.com |
 | tanhshrink| @fsx950223 | fsx950223@gmail.com |
 
@@ -19,16 +15,12 @@
 |:----------|:-----------|:---------------------------------|
 | gelu | gelu | https://arxiv.org/abs/1606.08415 |
 | hardshrink| hardshrink | |
-<<<<<<< HEAD
-| mish | mish | https://arxiv.org/abs/1908.08681 |
-| sparsemax | Sparsemax | https://arxiv.org/abs/1602.02068 |
-| tanhshrink | Tanhshrink | |
-=======
 | lisht | lisht | https://arxiv.org/abs/1901.05894 |
+| mish | mish | https://arxiv.org/abs/1908.08681 |
+| sparsemax | sparsemax | https://arxiv.org/abs/1602.02068 |
+| tanhshrink| tanhshrink | |
 | sparsemax | sparsemax | https://arxiv.org/abs/1602.02068 |
 | tanhshrink| tanhshrink | |
->>>>>>> a603625227f0c39de7f68fde93ff0e23b602f284
-
 ## Contribution Guidelines
 #### Standard API

From 7315412bb4afb4159224e6bc2a570f75650f4125 Mon Sep 17 00:00:00 2001
From: Tzu-Wei Sung <windqaq@gmail.com>
Date: Fri, 4 Oct 2019 17:54:21 -0700
Subject: [PATCH 04/13] update tests

---
 tensorflow_addons/activations/activations_test.py |  2 +-
 tensorflow_addons/activations/mish_test.py        | 13 -------------
 2 files changed, 1 insertion(+), 14 deletions(-)

diff --git a/tensorflow_addons/activations/activations_test.py b/tensorflow_addons/activations/activations_test.py
index 31a4b82196..45d7e95f38 100644
--- a/tensorflow_addons/activations/activations_test.py
+++ b/tensorflow_addons/activations/activations_test.py
@@ -26,7 +26,7 @@
 class ActivationsTest(tf.test.TestCase):
 
     ALL_ACTIVATIONS = [
-        "gelu", "hardshrink", "lisht", "sparsemax", "tanhshrink"
+        "gelu", "hardshrink", "lisht", "mish", "sparsemax", "tanhshrink"
     ]
 
     def test_serialization(self):
diff --git a/tensorflow_addons/activations/mish_test.py b/tensorflow_addons/activations/mish_test.py
index 7fddb51e19..74e374ed75 100644
--- a/tensorflow_addons/activations/mish_test.py
+++ b/tensorflow_addons/activations/mish_test.py
@@ -54,19 +54,6 @@ def test_unknown_shape(self):
             x = tf.ones(shape=shape, dtype=tf.float32)
             self.assertAllClose(fn(x), mish(x))
 
-    def test_serialization(self):
-        config = tf.keras.activations.serialize(mish)
-        fn = tf.keras.activations.deserialize(config)
-        self.assertEqual(fn, mish)
-
-    def test_serialization_with_layers(self):
-        layer = tf.keras.layers.Dense(3, activation=mish)
-        config = tf.keras.layers.serialize(layer)
-        deserialized_layer = tf.keras.layers.deserialize(config)
-        self.assertEqual(deserialized_layer.__class__.__name__,
-                         layer.__class__.__name__)
-        self.assertEqual(deserialized_layer.activation.__name__, "mish")
-
 
 if __name__ == "__main__":
     tf.test.main()

From 61c9cf68a24d3fbd0ba08466c73fac8a20cd6fcc Mon Sep 17 00:00:00 2001
From: Tzu-Wei Sung <windqaq@gmail.com>
Date: Fri, 4 Oct 2019 17:56:39 -0700
Subject: [PATCH 05/13] get rid of auto

---
 .../custom_ops/activations/cc/kernels/mish_op.h | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/tensorflow_addons/custom_ops/activations/cc/kernels/mish_op.h b/tensorflow_addons/custom_ops/activations/cc/kernels/mish_op.h
index 9f434fb300..b7917daf1a 100644
--- a/tensorflow_addons/custom_ops/activations/cc/kernels/mish_op.h
+++ b/tensorflow_addons/custom_ops/activations/cc/kernels/mish_op.h
@@ -50,12 +50,12 @@ struct MishGrad {
   void operator()(const Device& d, typename TTypes<T>::ConstTensor gradients,
                   typename TTypes<T>::ConstTensor features,
                   typename TTypes<T>::Tensor backprops) {
-    auto& e = features.exp();
-    auto& es = e.square();
-    auto& omega = static_cast<T>(4) * (features + static_cast<T>(1)) +
+    typename TTypes<T>::Tensor& e = features.exp();
+    typename TTypes<T>::Tensor& es = e.square();
+    typename TTypes<T>::Tensor& omega = static_cast<T>(4) * (features + static_cast<T>(1)) +
                   static_cast<T>(4) * es + e.cube() +
                   e * (static_cast<T>(4) * features + static_cast<T>(6));
-    auto& delta = static_cast<T>(2) * e + es + static_cast<T>(2);
+    typename TTypes<T>::Tensor& delta = static_cast<T>(2) * e + es + static_cast<T>(2);
     backprops.device(d) = gradients * e * omega / delta.square();
   }
 };
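Note (illustrative, outside the patches): the closed form implemented by the MishGrad
functor is d/dx[x * tanh(softplus(x))] = e^x * omega / delta^2, with
omega = 4(x + 1) + 4e^{2x} + e^{3x} + e^x(4x + 6) and delta = 2e^x + e^{2x} + 2, as in
the Mish paper. A quick NumPy check of that algebra against a central finite difference:

    import numpy as np

    def mish(x):
        return x * np.tanh(np.log1p(np.exp(x)))

    def mish_grad_closed_form(x):
        # Same omega/delta expression as the MishGrad functor above.
        e = np.exp(x)
        omega = 4.0 * (x + 1.0) + 4.0 * e**2 + e**3 + e * (4.0 * x + 6.0)
        delta = 2.0 * e + e**2 + 2.0
        return e * omega / delta**2

    x = np.linspace(-3.0, 3.0, 13)
    eps = 1e-6
    numerical = (mish(x + eps) - mish(x - eps)) / (2.0 * eps)
    assert np.allclose(mish_grad_closed_form(x), numerical, atol=1e-5)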
From 6fb3185d40a5fa98ebcc05152813d68e5981de1d Mon Sep 17 00:00:00 2001
From: Tzu-Wei Sung <windqaq@gmail.com>
Date: Fri, 4 Oct 2019 18:36:54 -0700
Subject: [PATCH 06/13] eval intermediate value

---
 .../custom_ops/activations/cc/kernels/mish_op.h | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/tensorflow_addons/custom_ops/activations/cc/kernels/mish_op.h b/tensorflow_addons/custom_ops/activations/cc/kernels/mish_op.h
index b7917daf1a..525c50621a 100644
--- a/tensorflow_addons/custom_ops/activations/cc/kernels/mish_op.h
+++ b/tensorflow_addons/custom_ops/activations/cc/kernels/mish_op.h
@@ -50,12 +50,12 @@ struct MishGrad {
   void operator()(const Device& d, typename TTypes<T>::ConstTensor gradients,
                   typename TTypes<T>::ConstTensor features,
                   typename TTypes<T>::Tensor backprops) {
-    typename TTypes<T>::Tensor& e = features.exp();
-    typename TTypes<T>::Tensor& es = e.square();
-    typename TTypes<T>::Tensor& omega = static_cast<T>(4) * (features + static_cast<T>(1)) +
+    const auto& e = features.exp().eval();
+    const auto& es = e.square().eval();
+    const auto& omega = static_cast<T>(4) * (features + static_cast<T>(1)) +
                   static_cast<T>(4) * es + e.cube() +
                   e * (static_cast<T>(4) * features + static_cast<T>(6));
-    typename TTypes<T>::Tensor& delta = static_cast<T>(2) * e + es + static_cast<T>(2);
+    const auto& delta = static_cast<T>(2) * e + es + static_cast<T>(2);
     backprops.device(d) = gradients * e * omega / delta.square();
   }
 };

From 6057c6947d7176411bd2161db89f0640915bbc09 Mon Sep 17 00:00:00 2001
From: Tzu-Wei Sung <windqaq@gmail.com>
Date: Fri, 4 Oct 2019 18:37:49 -0700
Subject: [PATCH 07/13] remove duplicated

---
 tensorflow_addons/activations/README.md | 2 --
 1 file changed, 2 deletions(-)

diff --git a/tensorflow_addons/activations/README.md b/tensorflow_addons/activations/README.md
index 4010a23ce7..82c98c1295 100644
--- a/tensorflow_addons/activations/README.md
+++ b/tensorflow_addons/activations/README.md
@@ -19,8 +19,6 @@
 | mish | mish | https://arxiv.org/abs/1908.08681 |
 | sparsemax | sparsemax | https://arxiv.org/abs/1602.02068 |
 | tanhshrink| tanhshrink | |
-| sparsemax | sparsemax | https://arxiv.org/abs/1602.02068 |
-| tanhshrink| tanhshrink | |
 
 ## Contribution Guidelines
 #### Standard API

From e66ed2d7691edd9d7395122efa77e4933644ffc6 Mon Sep 17 00:00:00 2001
From: Tzu-Wei Sung <windqaq@gmail.com>
Date: Fri, 4 Oct 2019 18:40:33 -0700
Subject: [PATCH 08/13] format codes

---
 tensorflow_addons/custom_ops/activations/cc/kernels/mish_op.h | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tensorflow_addons/custom_ops/activations/cc/kernels/mish_op.h b/tensorflow_addons/custom_ops/activations/cc/kernels/mish_op.h
index 525c50621a..0202adc52f 100644
--- a/tensorflow_addons/custom_ops/activations/cc/kernels/mish_op.h
+++ b/tensorflow_addons/custom_ops/activations/cc/kernels/mish_op.h
@@ -53,8 +53,8 @@ struct MishGrad {
     const auto& e = features.exp().eval();
     const auto& es = e.square().eval();
     const auto& omega = static_cast<T>(4) * (features + static_cast<T>(1)) +
-                  static_cast<T>(4) * es + e.cube() +
-                  e * (static_cast<T>(4) * features + static_cast<T>(6));
+                        static_cast<T>(4) * es + e.cube() +
+                        e * (static_cast<T>(4) * features + static_cast<T>(6));
     const auto& delta = static_cast<T>(2) * e + es + static_cast<T>(2);
     backprops.device(d) = gradients * e * omega / delta.square();
   }
From ce0856fe36e9b89e470cf02ec7a3c439e18039db Mon Sep 17 00:00:00 2001
From: Tzu-Wei Sung <windqaq@gmail.com>
Date: Sat, 5 Oct 2019 10:04:44 -0700
Subject: [PATCH 09/13] update maintainer

---
 tensorflow_addons/activations/README.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tensorflow_addons/activations/README.md b/tensorflow_addons/activations/README.md
index 82c98c1295..d7eae79348 100644
--- a/tensorflow_addons/activations/README.md
+++ b/tensorflow_addons/activations/README.md
@@ -6,7 +6,7 @@
 | gelu | @AakashKumarNain @WindQAQ | aakashnain@outlook.com windqaq@gmail.com |
 | hardshrink| @WindQAQ | windqaq@gmail.com |
 | lisht | @WindQAQ | windqaq@gmail.com |
-| mish | @WindQAQ | windqaq@gmail.com |
+| mish | @digantamisra98 @WindQAQ | mishradiganta91@gmail.com, windqaq@gmail.com |
 | sparsemax | @AndreasMadsen | amwwebdk+github@gmail.com |
 | tanhshrink| @fsx950223 | fsx950223@gmail.com |

From 3ae20a2b8f4636ab0333fc060a6cee49cc7efbbb Mon Sep 17 00:00:00 2001
From: Tzu-Wei Sung <windqaq@gmail.com>
Date: Thu, 10 Oct 2019 23:01:40 -0700
Subject: [PATCH 10/13] safely deal with overflow/underflow

---
 .../activations/cc/kernels/mish_op.h | 37 +++++++++++++++----
 1 file changed, 29 insertions(+), 8 deletions(-)

diff --git a/tensorflow_addons/custom_ops/activations/cc/kernels/mish_op.h b/tensorflow_addons/custom_ops/activations/cc/kernels/mish_op.h
index 0202adc52f..c111f99ef7 100644
--- a/tensorflow_addons/custom_ops/activations/cc/kernels/mish_op.h
+++ b/tensorflow_addons/custom_ops/activations/cc/kernels/mish_op.h
@@ -35,7 +35,18 @@ struct Mish {
   // activations: same shape as "features".
   void operator()(const Device& d, typename TTypes<T>::ConstTensor features,
                   typename TTypes<T>::Tensor activations) {
-    activations.device(d) = features * features.exp().log1p().tanh();
+    // softplus implementation
+    // https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/kernels/softplus_op.h
+    static const T threshold =
+        Eigen::numext::log(Eigen::NumTraits<T>::epsilon()) + T(2);
+    const auto& too_large = features > features.constant(-threshold);
+    const auto& too_small = features < features.constant(threshold);
+    const auto& features_exp = features.exp();
+    const auto& sp = too_large.select(
+        features,
+        too_small.select(features_exp,
+                         (features_exp + features.constant(T(1))).log()));
+    activations.device(d) = features * sp.tanh();
   }
 };
 
@@ -50,13 +61,23 @@ struct MishGrad {
   void operator()(const Device& d, typename TTypes<T>::ConstTensor gradients,
                   typename TTypes<T>::ConstTensor features,
                   typename TTypes<T>::Tensor backprops) {
-    const auto& e = features.exp().eval();
-    const auto& es = e.square().eval();
-    const auto& omega = static_cast<T>(4) * (features + static_cast<T>(1)) +
-                        static_cast<T>(4) * es + e.cube() +
-                        e * (static_cast<T>(4) * features + static_cast<T>(6));
-    const auto& delta = static_cast<T>(2) * e + es + static_cast<T>(2);
-    backprops.device(d) = gradients * e * omega / delta.square();
+    // softplus implementation
+    // https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/kernels/softplus_op.h
+    static const T threshold =
+        Eigen::numext::log(Eigen::NumTraits<T>::epsilon()) + T(2);
+    const auto& too_large = features > features.constant(-threshold);
+    const auto& too_small = features < features.constant(threshold);
+    const auto& features_exp = features.exp();
+    const auto& sp = too_large.select(
+        features,
+        too_small.select(features_exp,
+                         (features_exp + features.constant(T(1))).log()));
+
+    const auto& grad_sp = static_cast<T>(1) - (-sp).exp();
+    const auto& tsp = sp.tanh();
+    const auto& grad_tsp = ((static_cast<T>(1) - tsp * tsp) * grad_sp);
+    const auto& grad = features * grad_tsp + tsp;
+    backprops.device(d) = gradients * grad;
   }
 };
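Note (illustrative, outside the patches): the guard added above matters because a naive
softplus, log1p(exp(x)), overflows exp() well inside the representable float range
(float32 exp() overflows near x ~ 89); the branch thresholds mirror TensorFlow's own
softplus kernel linked in the comments. A NumPy sketch of the failure mode and the fix;
the "capped" clamp is only an illustrative trick to keep the unselected np.where
branches finite:

    import numpy as np

    def softplus_naive(x):
        return np.log1p(np.exp(x))  # exp() overflows float32 near x ~ 89

    def softplus_safe(x):
        # Same branching as the referenced TF kernel: large x -> x,
        # very negative x -> exp(x), otherwise log1p(exp(x)).
        threshold = np.log(np.finfo(x.dtype).eps) + 2.0  # ~ -13.9 for float32
        capped = np.minimum(x, -threshold)
        return np.where(x > -threshold, x,
                        np.where(x < threshold, np.exp(capped),
                                 np.log1p(np.exp(capped))))

    x = np.float32([-100.0, 0.0, 100.0])
    print(softplus_naive(x))  # [~0, 0.6931472, inf]   -- overflow at x = 100
    print(softplus_safe(x))   # [~0, 0.6931472, 100.0] -- finite everywhere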
From b850de6492732cb9aed8df8ea341ea64cb6dea70 Mon Sep 17 00:00:00 2001
From: Tzu-Wei Sung <windqaq@gmail.com>
Date: Thu, 10 Oct 2019 23:03:50 -0700
Subject: [PATCH 11/13] test values that are extremely close to zero

---
 tensorflow_addons/activations/mish_test.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tensorflow_addons/activations/mish_test.py b/tensorflow_addons/activations/mish_test.py
index 74e374ed75..2ffa2e7f9b 100644
--- a/tensorflow_addons/activations/mish_test.py
+++ b/tensorflow_addons/activations/mish_test.py
@@ -41,7 +41,7 @@ def test_mish(self, dtype):
     def test_theoretical_gradients(self, dtype):
         # Only test theoretical gradients for float32 and float64
         # because of the instability of float16 while computing jacobian
-        x = tf.constant([-2.0, -1.0, 0.0, 1.0, 2.0], dtype=dtype)
+        x = tf.constant([-1e-20, -2.0, -1.0, 0.0, 1.0, 2.0, 1e-20], dtype=dtype)
 
         theoretical, numerical = tf.test.compute_gradient(mish, [x])
         self.assertAllCloseAccordingToType(theoretical, numerical, atol=1e-4)

From 504e2e8c3d35b7e9a8c5d795207964b6b07268fa Mon Sep 17 00:00:00 2001
From: Tzu-Wei Sung <windqaq@gmail.com>
Date: Thu, 10 Oct 2019 23:11:17 -0700
Subject: [PATCH 12/13] format codes

---
 tensorflow_addons/activations/activations_test.py | 3 ++-
 tensorflow_addons/activations/mish_test.py        | 3 ++-
 2 files changed, 4 insertions(+), 2 deletions(-)

diff --git a/tensorflow_addons/activations/activations_test.py b/tensorflow_addons/activations/activations_test.py
index ac687ed6b8..b2df48e47f 100644
--- a/tensorflow_addons/activations/activations_test.py
+++ b/tensorflow_addons/activations/activations_test.py
@@ -26,7 +26,8 @@
 class ActivationsTest(tf.test.TestCase):
 
     ALL_ACTIVATIONS = [
-        "gelu", "hardshrink", "lisht", "mish", "softshrink", "sparsemax", "tanhshrink"
+        "gelu", "hardshrink", "lisht", "mish", "softshrink", "sparsemax",
+        "tanhshrink"
     ]
 
     def test_serialization(self):
diff --git a/tensorflow_addons/activations/mish_test.py b/tensorflow_addons/activations/mish_test.py
index 2ffa2e7f9b..2de6013055 100644
--- a/tensorflow_addons/activations/mish_test.py
+++ b/tensorflow_addons/activations/mish_test.py
@@ -41,7 +41,8 @@ def test_mish(self, dtype):
     def test_theoretical_gradients(self, dtype):
         # Only test theoretical gradients for float32 and float64
         # because of the instability of float16 while computing jacobian
-        x = tf.constant([-1e-20, -2.0, -1.0, 0.0, 1.0, 2.0, 1e-20], dtype=dtype)
+        x = tf.constant([-1e-20, -2.0, -1.0, 0.0, 1.0, 2.0, 1e-20],
+                        dtype=dtype)
 
         theoretical, numerical = tf.test.compute_gradient(mish, [x])
         self.assertAllCloseAccordingToType(theoretical, numerical, atol=1e-4)

From 163eadccab04d48470c47878d0522142e35312c3 Mon Sep 17 00:00:00 2001
From: Tzu-Wei Sung <windqaq@gmail.com>
Date: Fri, 11 Oct 2019 09:33:32 -0700
Subject: [PATCH 13/13] remove values close to zero

---
 tensorflow_addons/activations/mish_test.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/tensorflow_addons/activations/mish_test.py b/tensorflow_addons/activations/mish_test.py
index 2de6013055..74e374ed75 100644
--- a/tensorflow_addons/activations/mish_test.py
+++ b/tensorflow_addons/activations/mish_test.py
@@ -41,8 +41,7 @@ def test_mish(self, dtype):
     def test_theoretical_gradients(self, dtype):
         # Only test theoretical gradients for float32 and float64
         # because of the instability of float16 while computing jacobian
-        x = tf.constant([-1e-20, -2.0, -1.0, 0.0, 1.0, 2.0, 1e-20],
-                        dtype=dtype)
+        x = tf.constant([-2.0, -1.0, 0.0, 1.0, 2.0], dtype=dtype)
 
         theoretical, numerical = tf.test.compute_gradient(mish, [x])
         self.assertAllCloseAccordingToType(theoretical, numerical, atol=1e-4)
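Note (illustrative, outside the patches): once the series is applied and the custom op
is compiled into _activation_ops.so, the activation is used like any other Keras
activation. A minimal sketch, assuming a tensorflow-addons build that includes these
patches:

    import tensorflow as tf
    from tensorflow_addons.activations import mish

    # mish is registered as a Keras custom object in mish.py (patch 01), so it
    # also survives layer serialization/deserialization round-trips.
    model = tf.keras.Sequential([
        tf.keras.layers.Dense(64, activation=mish, input_shape=(32,)),
        tf.keras.layers.Dense(10),
    ])
    model.compile(
        optimizer="adam",
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True))
    model.summary()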