Skip to content
Merged
14 changes: 14 additions & 0 deletions tensorflow_addons/activations/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ py_library(
"gelu.py",
"hardshrink.py",
"lisht.py",
"mish.py",
"rrelu.py",
"softshrink.py",
"sparsemax.py",
Expand Down Expand Up @@ -86,6 +87,19 @@ py_test(
],
)

py_test(
name = "mish_test",
size = "small",
srcs = [
"mish_test.py",
],
main = "mish_test.py",
srcs_version = "PY2AND3",
deps = [
":activations",
],
)

py_test(
name = "softshrink_test",
size = "small",
Expand Down
3 changes: 2 additions & 1 deletion tensorflow_addons/activations/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
| gelu | @AakashKumarNain @WindQAQ | [email protected] [email protected] |
| hardshrink| @WindQAQ | [email protected] |
| lisht | @WindQAQ | [email protected] |
| mish | @digantamisra98 @WindQAQ | [email protected], [email protected] |
| softshrink| @WindQAQ | [email protected] |
| sparsemax | @AndreasMadsen | [email protected] |
| tanhshrink| @fsx950223 | [email protected] |
Expand All @@ -17,12 +18,12 @@
| gelu | gelu | https://arxiv.org/abs/1606.08415 |
| hardshrink| hardshrink | |
| lisht | lisht | https://arxiv.org/abs/1901.05894 |
| mish | mish | https://arxiv.org/abs/1908.08681 |
| softshrink| softshrink | |
| sparsemax | sparsemax | https://arxiv.org/abs/1602.02068 |
| tanhshrink| tanhshrink | |
| rrelu | rrelu | https://arxiv.org/abs/1505.00853 |


## Contribution Guidelines
#### Standard API
In order to conform with the current API standard, all activations
Expand Down
1 change: 1 addition & 0 deletions tensorflow_addons/activations/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from tensorflow_addons.activations.gelu import gelu
from tensorflow_addons.activations.hardshrink import hardshrink
from tensorflow_addons.activations.lisht import lisht
from tensorflow_addons.activations.mish import mish
from tensorflow_addons.activations.softshrink import softshrink
from tensorflow_addons.activations.rrelu import rrelu
from tensorflow_addons.activations.sparsemax import sparsemax
Expand Down
4 changes: 2 additions & 2 deletions tensorflow_addons/activations/activations_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@
class ActivationsTest(tf.test.TestCase):

ALL_ACTIVATIONS = [
"gelu", "hardshrink", "lisht", "softshrink", "sparsemax", "rrelu",
"tanhshrink"
"gelu", "hardshrink", "lisht", "mish", "rrelu", "softshrink",
"sparsemax", "tanhshrink"
]

def test_serialization(self):
Expand Down
49 changes: 49 additions & 0 deletions tensorflow_addons/activations/mish.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
from tensorflow_addons.utils import keras_utils
from tensorflow_addons.utils.resource_loader import get_path_to_datafile

_activation_ops_so = tf.load_op_library(
get_path_to_datafile("custom_ops/activations/_activation_ops.so"))


@keras_utils.register_keras_custom_object
@tf.function
def mish(x):
"""Mish: A Self Regularized Non-Monotonic Neural Activation Function.

Computes mish activation: x * tanh(softplus(x))

See [Mish: A Self Regularized Non-Monotonic Neural Activation Function](https://arxiv.org/abs/1908.08681).

Args:
x: A `Tensor`. Must be one of the following types:
`float16`, `float32`, `float64`.
Returns:
A `Tensor`. Has the same type as `x`.
"""
x = tf.convert_to_tensor(x)
return _activation_ops_so.addons_mish(x)


@tf.RegisterGradient("Addons>Mish")
def _mish_grad(op, grad):
return _activation_ops_so.addons_mish_grad(grad, op.inputs[0])
59 changes: 59 additions & 0 deletions tensorflow_addons/activations/mish_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl.testing import parameterized

import numpy as np
import tensorflow as tf
from tensorflow_addons.activations import mish
from tensorflow_addons.utils import test_utils


@test_utils.run_all_in_graph_and_eager_modes
class MishTest(tf.test.TestCase, parameterized.TestCase):
@parameterized.named_parameters(("float16", np.float16),
("float32", np.float32),
("float64", np.float64))
def test_mish(self, dtype):
x = tf.constant([-2.0, -1.0, 0.0, 1.0, 2.0], dtype=dtype)
expected_result = tf.constant(
[-0.2525015, -0.30340144, 0.0, 0.86509836, 1.943959], dtype=dtype)
self.assertAllCloseAccordingToType(mish(x), expected_result)

@parameterized.named_parameters(("float32", np.float32),
("float64", np.float64))
def test_theoretical_gradients(self, dtype):
# Only test theoretical gradients for float32 and float64
# because of the instability of float16 while computing jacobian
x = tf.constant([-2.0, -1.0, 0.0, 1.0, 2.0], dtype=dtype)

theoretical, numerical = tf.test.compute_gradient(mish, [x])
self.assertAllCloseAccordingToType(theoretical, numerical, atol=1e-4)

def test_unknown_shape(self):
fn = mish.get_concrete_function(
tf.TensorSpec(shape=None, dtype=tf.float32))

for shape in [(1,), (1, 2), (1, 2, 3), (1, 2, 3, 4)]:
x = tf.ones(shape=shape, dtype=tf.float32)
self.assertAllClose(fn(x), mish(x))


if __name__ == "__main__":
tf.test.main()
5 changes: 5 additions & 0 deletions tensorflow_addons/custom_ops/activations/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ custom_op_library(
"cc/kernels/hardshrink_op.h",
"cc/kernels/lisht_op.cc",
"cc/kernels/lisht_op.h",
"cc/kernels/mish_op.cc",
"cc/kernels/mish_op.h",
"cc/kernels/rrelu_op.cc",
"cc/kernels/rrelu_op.h",
"cc/kernels/softshrink_op.cc",
Expand All @@ -22,6 +24,7 @@ custom_op_library(
"cc/ops/gelu_op.cc",
"cc/ops/hardshrink_op.cc",
"cc/ops/lisht_op.cc",
"cc/ops/mish_op.cc",
"cc/ops/rrelu_op.cc",
"cc/ops/softshrink_op.cc",
"cc/ops/tanhshrink_op.cc",
Expand All @@ -33,6 +36,8 @@ custom_op_library(
"cc/kernels/hardshrink_op_gpu.cu.cc",
"cc/kernels/lisht_op.h",
"cc/kernels/lisht_op_gpu.cu.cc",
"cc/kernels/mish_op.h",
"cc/kernels/mish_op_gpu.cu.cc",
"cc/kernels/rrelu_op.h",
"cc/kernels/rrelu_op_gpu.cu.cc",
"cc/kernels/softshrink_op.h",
Expand Down
79 changes: 79 additions & 0 deletions tensorflow_addons/custom_ops/activations/cc/kernels/mish_op.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define EIGEN_USE_THREADS

#include "tensorflow_addons/custom_ops/activations/cc/kernels/mish_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"

namespace tensorflow {
namespace addons {

using CPUDevice = Eigen::ThreadPoolDevice;

#define REGISTER_MISH_KERNELS(type) \
REGISTER_KERNEL_BUILDER( \
Name("Addons>Mish").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
MishOp<CPUDevice, type>); \
REGISTER_KERNEL_BUILDER( \
Name("Addons>MishGrad").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
MishGradOp<CPUDevice, type>);

// Mish only makes sense with floating points.
TF_CALL_GPU_NUMBER_TYPES(REGISTER_MISH_KERNELS);
#undef REGISTER_MISH_KERNELS

#if GOOGLE_CUDA

using GPUDevice = Eigen::GpuDevice;

// Forward declarations of the functor specializations for GPU.
namespace functor {
#define DECLARE_GPU_SPEC(T) \
template <> \
void Mish<GPUDevice, T>::operator()( \
const GPUDevice& d, typename TTypes<T>::ConstTensor features, \
typename TTypes<T>::Tensor activations); \
extern template struct Mish<GPUDevice, T>; \
\
template <> \
void MishGrad<GPUDevice, T>::operator()( \
const GPUDevice& d, typename TTypes<T>::ConstTensor gradients, \
typename TTypes<T>::ConstTensor features, \
typename TTypes<T>::Tensor backprops); \
extern template struct MishGrad<GPUDevice, T>;

TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
#undef DECLARE_GPU_SPEC
} // namespace functor

// Registration of the GPU implementations.
#define REGISTER_MISH_GPU_KERNELS(type) \
REGISTER_KERNEL_BUILDER( \
Name("Addons>Mish").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
MishOp<GPUDevice, type>); \
REGISTER_KERNEL_BUILDER( \
Name("Addons>MishGrad").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
MishGradOp<GPUDevice, type>);

TF_CALL_GPU_NUMBER_TYPES(REGISTER_MISH_GPU_KERNELS);
#undef REGISTER_MISH_GPU_KERNELS

#endif // GOOGLE_CUDA

} // namespace addons
} // namespace tensorflow
Loading