[activations] fused gelu kernel #427
Merged: seanpmorgan merged 17 commits into tensorflow:master from WindQAQ:activations/fused-gelu-kernel on Aug 26, 2019.
Commits (17, changes shown from all commits):
b35e84d add CPU and GPU kernel for gelu (WindQAQ)
73e55a8 add some documentations (WindQAQ)
921d6b4 format codes (WindQAQ)
799b610 support original (non-approximate) gelu (WindQAQ)
261cd49 GPUDevice is super fast (WindQAQ)
705a080 fix typo (WindQAQ)
e950650 format codes (WindQAQ)
d72774a python API for gelu (WindQAQ)
0103fcd unittests for gelu (WindQAQ)
15d2c28 update BUILD file (WindQAQ)
b85b7a3 lint (WindQAQ)
8d0f2e5 update init and README (WindQAQ)
240cbfb alphabetical order (WindQAQ)
517f59e update docs (WindQAQ)
2f2bb88 update docs (WindQAQ)
4d4faa4 test gradients on non-approximate gelu (WindQAQ)
0a93ff2 change test name (WindQAQ)
@@ -1,14 +1,16 @@
 # Addons - Activations

 ## Maintainers
-| Submodule | Maintainers | Contact Info |
-|:---------- |:------------- |:--------------|
-| sparsemax | @AndreasMadsen | [email protected] |
+| Submodule | Maintainers | Contact Info |
+|:----------|:--------------------------|:-----------------------------------------|
+| gelu | @AakashKumarNain @WindQAQ | [email protected] [email protected] |
+| sparsemax | @AndreasMadsen | [email protected] |

 ## Contents
-| Submodule | Activation | Reference |
-|:----------------------- |:-------------------|:---------------|
-| sparsemax | Sparsemax | https://arxiv.org/abs/1602.02068 |
+| Submodule | Activation | Reference |
+|:----------|:-----------|:---------------------------------|
+| gelu | gelu | https://arxiv.org/abs/1606.08415 |
+| sparsemax | Sparsemax | https://arxiv.org/abs/1602.02068 |


 ## Contribution Guidelines
@@ -0,0 +1,55 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
from tensorflow_addons.utils import keras_utils
from tensorflow_addons.utils.resource_loader import get_path_to_datafile

_activation_ops_so = tf.load_op_library(
    get_path_to_datafile("custom_ops/activations/_activation_ops.so"))


@keras_utils.register_keras_custom_object
@tf.function
def gelu(x, approximate=True):
    """Gaussian Error Linear Unit.

    Computes gaussian error linear:
    `0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3)))` or
    `x * P(X <= x) = 0.5 * x * (1 + erf(x / sqrt(2)))`, where P(X) ~ N(0, 1),
    depending on whether approximation is enabled.

    See [Gaussian Error Linear Units (GELUs)](https://arxiv.org/abs/1606.08415)
    and [BERT: Pre-training of Deep Bidirectional Transformers for Language
    Understanding](https://arxiv.org/abs/1810.04805).

    Args:
        x: A `Tensor`. Must be one of the following types:
            `float16`, `float32`, `float64`.
        approximate: bool, whether to enable approximation.
    Returns:
        A `Tensor`. Has the same type as `x`.
    """
    x = tf.convert_to_tensor(x)
    return _activation_ops_so.gelu(x, approximate)


@tf.RegisterGradient("Gelu")
def _gelu_grad(op, grad):
    return _activation_ops_so.gelu_grad(grad, op.inputs[0],
                                        op.get_attr("approximate"))
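For context, a minimal usage sketch of the new Python API (an illustration, not part of the diff; it assumes a tensorflow-addons build that ships the compiled _activation_ops.so, and the input values are arbitrary):

```python
# Hedged usage sketch of the fused op via the public API added in this PR.
import tensorflow as tf
from tensorflow_addons.activations import gelu

x = tf.constant([-1.0, 0.0, 1.0, 2.0])
print(gelu(x))                      # tanh-based approximation (default)
print(gelu(x, approximate=False))   # exact erf-based formulation
```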
@@ -0,0 +1,106 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl.testing import parameterized

import math

import numpy as np
import tensorflow as tf
from tensorflow_addons.activations import gelu
from tensorflow_addons.utils import test_utils


def _ref_gelu(x, approximate=True):
    x = tf.convert_to_tensor(x)
    if approximate:
        pi = tf.cast(math.pi, x.dtype)
        coeff = tf.cast(0.044715, x.dtype)
        return 0.5 * x * (
            1.0 + tf.tanh(tf.sqrt(2.0 / pi) * (x + coeff * tf.pow(x, 3))))
    else:
        return 0.5 * x * (
            1.0 + tf.math.erf(x / tf.cast(tf.sqrt(2.0), x.dtype)))


@test_utils.run_all_in_graph_and_eager_modes
class GeluTest(tf.test.TestCase, parameterized.TestCase):
    @parameterized.named_parameters(("float16", np.float16),
                                    ("float32", np.float32),
                                    ("float64", np.float64))
    def test_gelu(self, dtype):
        x = np.random.rand(2, 3, 4).astype(dtype)
        self.assertAllCloseAccordingToType(gelu(x), _ref_gelu(x))
        self.assertAllCloseAccordingToType(gelu(x, False), _ref_gelu(x, False))

    @parameterized.named_parameters(("float16", np.float16),
                                    ("float32", np.float32),
                                    ("float64", np.float64))
    def test_gradients(self, dtype):
        x = tf.constant([1.0, 2.0, 3.0], dtype=dtype)

        for approximate in [True, False]:
            with self.subTest(approximate=approximate):
                with tf.GradientTape(persistent=True) as tape:
                    tape.watch(x)
                    y_ref = _ref_gelu(x, approximate)
                    y = gelu(x, approximate)
                grad_ref = tape.gradient(y_ref, x)
                grad = tape.gradient(y, x)
                self.assertAllCloseAccordingToType(grad, grad_ref)

    @parameterized.named_parameters(("float32", np.float32),
                                    ("float64", np.float64))
    def test_theoretical_gradients(self, dtype):
        # Only test theoretical gradients for float32 and float64
        # because of the instability of float16 while computing jacobian
        x = tf.constant([1.0, 2.0, 3.0], dtype=dtype)

        for approximate in [True, False]:
            with self.subTest(approximate=approximate):
                theoretical, numerical = tf.test.compute_gradient(
                    lambda x: gelu(x, approximate=approximate), [x])
                self.assertAllCloseAccordingToType(
                    theoretical, numerical, atol=1e-4)

    def test_unknown_shape(self):
        fn = gelu.get_concrete_function(
            tf.TensorSpec(shape=None, dtype=tf.float32))

        for shape in [(1,), (1, 2), (1, 2, 3), (1, 2, 3, 4)]:
            x = tf.ones(shape=shape, dtype=tf.float32)
            self.assertAllClose(fn(x), gelu(x))

    def test_serialization(self):
        ref_fn = gelu
        config = tf.keras.activations.serialize(ref_fn)
        fn = tf.keras.activations.deserialize(config)
        self.assertEqual(fn, ref_fn)

    def test_serialization_with_layers(self):
        layer = tf.keras.layers.Dense(3, activation=gelu)
        config = tf.keras.layers.serialize(layer)
        deserialized_layer = tf.keras.layers.deserialize(config)
        self.assertEqual(deserialized_layer.__class__.__name__,
                         layer.__class__.__name__)
        self.assertEqual(deserialized_layer.activation.__name__, "gelu")


if __name__ == "__main__":
    tf.test.main()
@@ -0,0 +1,47 @@
licenses(["notice"])  # Apache 2.0

package(default_visibility = ["//visibility:public"])

load("@local_config_tf//:build_defs.bzl", "D_GLIBCXX_USE_CXX11_ABI")
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda_is_configured", "if_cuda")

cc_library(
    name = "gelu_op_gpu",
    srcs = [
        "cc/kernels/gelu_op.h",
        "cc/kernels/gelu_op_gpu.cu.cc",
    ],
    copts = if_cuda_is_configured([
        "-DGOOGLE_CUDA=1",
        "-x cuda",
        "-nvcc_options=relaxed-constexpr",
        "-nvcc_options=ftz=true",
    ]),
    deps = [
        "@local_config_tf//:libtensorflow_framework",
        "@local_config_tf//:tf_header_lib",
    ] + if_cuda_is_configured([
        "@local_config_cuda//cuda:cuda_libs",
        "@local_config_cuda//cuda:cuda_headers",
    ]),
    alwayslink = 1,
)

cc_binary(
    name = "_activation_ops.so",
    srcs = [
        "cc/kernels/gelu_op.cc",
        "cc/kernels/gelu_op.h",
        "cc/ops/gelu_op.cc",
    ],
    copts = [
        "-pthread",
        "-std=c++11",
        D_GLIBCXX_USE_CXX11_ABI,
    ] + if_cuda(["-DGOOGLE_CUDA=1"]),
    linkshared = 1,
    deps = [
        "@local_config_tf//:libtensorflow_framework",
        "@local_config_tf//:tf_header_lib",
    ] + if_cuda_is_configured([":gelu_op_gpu"]),
)
tensorflow_addons/custom_ops/activations/cc/kernels/gelu_op.cc (77 additions, 0 deletions)
@@ -0,0 +1,77 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define EIGEN_USE_THREADS

#include "tensorflow_addons/custom_ops/activations/cc/kernels/gelu_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"

namespace tensorflow {

using CPUDevice = Eigen::ThreadPoolDevice;

#define REGISTER_GELU_KERNELS(type)                                   \
  REGISTER_KERNEL_BUILDER(                                            \
      Name("Gelu").Device(DEVICE_CPU).TypeConstraint<type>("T"),      \
      GeluOp<CPUDevice, type>);                                       \
  REGISTER_KERNEL_BUILDER(                                            \
      Name("GeluGrad").Device(DEVICE_CPU).TypeConstraint<type>("T"),  \
      GeluGradOp<CPUDevice, type>);

// Gelu only makes sense with floating points.
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GELU_KERNELS);
#undef REGISTER_GELU_KERNELS

#ifdef GOOGLE_CUDA

using GPUDevice = Eigen::GpuDevice;

// Forward declarations of the functor specializations for GPU.
namespace functor {
#define DECLARE_GPU_SPEC(T)                                           \
  template <>                                                         \
  void Gelu<GPUDevice, T>::operator()(                                \
      const GPUDevice& d, typename TTypes<T>::ConstTensor features,   \
      bool approximate, typename TTypes<T>::Tensor activations);      \
  extern template struct Gelu<GPUDevice, T>;                          \
                                                                      \
  template <>                                                         \
  void GeluGrad<GPUDevice, T>::operator()(                            \
      const GPUDevice& d, typename TTypes<T>::ConstTensor gradients,  \
      typename TTypes<T>::ConstTensor features, bool approximate,     \
      typename TTypes<T>::Tensor backprops);                          \
  extern template struct GeluGrad<GPUDevice, T>;

TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
#undef DECLARE_GPU_SPEC
}  // namespace functor

// Registration of the GPU implementations.
#define REGISTER_GELU_GPU_KERNELS(type)                               \
  REGISTER_KERNEL_BUILDER(                                            \
      Name("Gelu").Device(DEVICE_GPU).TypeConstraint<type>("T"),      \
      GeluOp<GPUDevice, type>);                                       \
  REGISTER_KERNEL_BUILDER(                                            \
      Name("GeluGrad").Device(DEVICE_GPU).TypeConstraint<type>("T"),  \
      GeluGradOp<GPUDevice, type>);

TF_CALL_GPU_NUMBER_TYPES(REGISTER_GELU_GPU_KERNELS);
#undef REGISTER_GELU_GPU_KERNELS

#endif  // GOOGLE_CUDA

}  // namespace tensorflow
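For readers without the functor header at hand (gelu_op.h is not shown in this diff), here is a hedged NumPy sketch, not the shipped implementation, of what the exact-mode Gelu/GeluGrad kernels are expected to compute:

```python
# Hedged NumPy reference for the exact (non-approximate) path only; the real
# functors live in gelu_op.h and may differ in detail.
# Forward: y = x * Phi(x); gradient: dy/dx = Phi(x) + x * phi(x),
# where Phi/phi are the standard normal CDF/PDF.
import numpy as np
from scipy.special import erf

def gelu_exact(x):
    return 0.5 * x * (1.0 + erf(x / np.sqrt(2.0)))

def gelu_exact_grad(grad, x):
    cdf = 0.5 * (1.0 + erf(x / np.sqrt(2.0)))
    pdf = np.exp(-0.5 * x * x) / np.sqrt(2.0 * np.pi)
    return grad * (cdf + x * pdf)  # chain rule applied to incoming gradients
```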
Review conversation:
What about testing integer and quantization types?
For integer types, I believe that Gelu will simply behave as Relu.
I do not really get your point. For int32, do you mean we first cast it to float, do the computation in float, and finally cast back to int32? If so, it seems odd that users wouldn't simply cast int32 to float themselves and cast the output back to int32. Actually, most activation ops in core TF (and PyTorch) support only floating-point inputs. ReLU/ReLU6 is an exception because cwiseMax/cwiseMin can run on non-floating dtypes.
https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/ops/nn_ops.cc#L1053-L1144
BTW, after a rough computation with the Google calculator, I found there is a gap between ReLU and GeLU for integer inputs. When the input is 2, the approximate version gives 1.95459769409 and the non-approximate version gives 1.9544997361. Looking deeper at the definition of GeLU: when x = 2, gelu(2) = 2 * normcdf(2) ≈ 2 * 0.9772 != 2.
(Calculator screenshots: approximate and non-approximate results.)
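For reference, the two quoted values can be reproduced with the formulas from the docstring (plain Python, no TF required):

```python
# Scalar check of gelu(2.0) in both modes, using the docstring formulas.
import math

x = 2.0
approx = 0.5 * x * (1.0 + math.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)))
exact = 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))
print(approx)  # ~1.954598 (approximate / tanh version)
print(exact)   # ~1.954500 (non-approximate / erf version)
```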
@mostafaelhoushi
> For integer types, I believe that Gelu will simply behave as Relu.

I don't think this is true. Can you elaborate a bit on this?
Thanks @WindQAQ and @AakashKumarNain for your feedback.
I meant that if both the input and output are constrained to be integer, then Gelu will behave as Relu. E.g., for the example that @WindQAQ mentioned: gelu(2) ≈ 2 * 0.9772 ≈ 1.9545, but if the activations are constrained to be integer then we need to round the output to the nearest integer, so round(gelu(2)) = 2 = relu(2). However, @WindQAQ mentioned an important point: most activation ops in core TF (and PyTorch) support only floating-point inputs, with ReLU/ReLU6 as the exception. Hence, I think you may safely ignore this suggestion.
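To illustrate the point above (a hedged sketch, not part of the PR): on a small integer range, rounding the exact gelu output back to an integer does coincide with relu.

```python
# Compare round(gelu(n)) with relu(n) for small integers.
import math

def gelu_exact(x):
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))

for n in range(-5, 6):
    assert round(gelu_exact(float(n))) == max(n, 0)  # max(n, 0) == relu(n)
```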
okay! Thanks again for the review :-)