diff --git a/tensorflow_addons/activations/BUILD b/tensorflow_addons/activations/BUILD
index d454860322..34e87c6298 100644
--- a/tensorflow_addons/activations/BUILD
+++ b/tensorflow_addons/activations/BUILD
@@ -6,12 +6,14 @@ py_library(
     name = "activations",
     srcs = [
         "__init__.py",
+        "gelu.py",
         "sparsemax.py",
     ],
-    srcs_version = "PY2AND3",
-    deps = [
+    data = [
+        "//tensorflow_addons/custom_ops/activations:_activation_ops.so",
         "//tensorflow_addons/utils",
     ],
+    srcs_version = "PY2AND3",
 )
 
 py_test(
@@ -26,3 +28,16 @@ py_test(
         ":activations",
     ],
 )
+
+py_test(
+    name = "gelu_test",
+    size = "large",
+    srcs = [
+        "gelu_test.py",
+    ],
+    main = "gelu_test.py",
+    srcs_version = "PY2AND3",
+    deps = [
+        ":activations",
+    ],
+)
diff --git a/tensorflow_addons/activations/README.md b/tensorflow_addons/activations/README.md
index 4ab59b23bb..500eee194b 100644
--- a/tensorflow_addons/activations/README.md
+++ b/tensorflow_addons/activations/README.md
@@ -1,14 +1,16 @@
 # Addons - Activations
 
 ## Maintainers
-| Submodule | Maintainers | Contact Info |
-|:---------- |:------------- |:--------------|
-| sparsemax | @AndreasMadsen | amwwebdk+github@gmail.com |
+| Submodule | Maintainers               | Contact Info                              |
+|:----------|:--------------------------|:------------------------------------------|
+| gelu      | @AakashKumarNain @WindQAQ | aakashnain@outlook.com windqaq@gmail.com  |
+| sparsemax | @AndreasMadsen            | amwwebdk+github@gmail.com                 |
 
 ## Contents
-| Submodule | Activation | Reference |
-|:----------------------- |:-------------------|:---------------|
-| sparsemax | Sparsemax | https://arxiv.org/abs/1602.02068 |
+| Submodule | Activation | Reference                         |
+|:----------|:-----------|:----------------------------------|
+| gelu      | gelu       | https://arxiv.org/abs/1606.08415  |
+| sparsemax | Sparsemax  | https://arxiv.org/abs/1602.02068  |
 
 ## Contribution Guidelines
diff --git a/tensorflow_addons/activations/__init__.py b/tensorflow_addons/activations/__init__.py
index 5792d00356..45903a3975 100644
--- a/tensorflow_addons/activations/__init__.py
+++ b/tensorflow_addons/activations/__init__.py
@@ -18,4 +18,5 @@
 from __future__ import division
 from __future__ import print_function
 
+from tensorflow_addons.activations.gelu import gelu
 from tensorflow_addons.activations.sparsemax import sparsemax
diff --git a/tensorflow_addons/activations/gelu.py b/tensorflow_addons/activations/gelu.py
new file mode 100644
index 0000000000..539afbbe1c
--- /dev/null
+++ b/tensorflow_addons/activations/gelu.py
@@ -0,0 +1,55 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+from tensorflow_addons.utils import keras_utils
+from tensorflow_addons.utils.resource_loader import get_path_to_datafile
+
+_activation_ops_so = tf.load_op_library(
+    get_path_to_datafile("custom_ops/activations/_activation_ops.so"))
+
+
+@keras_utils.register_keras_custom_object
+@tf.function
+def gelu(x, approximate=True):
+    """Gaussian Error Linear Unit.
+
+    Computes gaussian error linear:
+    `0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3)))` or
+    `x * P(X <= x) = 0.5 * x * (1 + erf(x / sqrt(2)))`, where P(X) ~ N(0, 1),
+    depending on whether approximation is enabled.
+
+    See [Gaussian Error Linear Units (GELUs)](https://arxiv.org/abs/1606.08415)
+    and [BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding](https://arxiv.org/abs/1810.04805).
+
+    Args:
+        x: A `Tensor`. Must be one of the following types:
+            `float16`, `float32`, `float64`.
+        approximate: bool, whether to enable approximation.
+    Returns:
+        A `Tensor`. Has the same type as `x`.
+    """
+    x = tf.convert_to_tensor(x)
+    return _activation_ops_so.gelu(x, approximate)
+
+
+@tf.RegisterGradient("Gelu")
+def _gelu_grad(op, grad):
+    return _activation_ops_so.gelu_grad(grad, op.inputs[0],
+                                        op.get_attr("approximate"))
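For reference, a minimal usage sketch of the Python API added above (an editorial aid, not part of the diff). It assumes a `tensorflow_addons` build that ships the custom op, imported as `tfa`; the NumPy expression simply restates the tanh approximation from the docstring as a sanity check:

```python
import numpy as np
import tensorflow as tf
import tensorflow_addons as tfa

x = tf.constant([-1.0, 0.0, 1.0], dtype=tf.float32)

# Default uses the tanh approximation from the docstring.
y_approx = tfa.activations.gelu(x)                     # approximate=True
y_exact = tfa.activations.gelu(x, approximate=False)   # erf-based form

# Same tanh approximation written out in NumPy as a sanity check.
x_np = x.numpy()
ref = 0.5 * x_np * (1.0 + np.tanh(np.sqrt(2.0 / np.pi)
                                  * (x_np + 0.044715 * x_np**3)))
print(np.allclose(y_approx.numpy(), ref, atol=1e-6))
```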
diff --git a/tensorflow_addons/activations/gelu_test.py b/tensorflow_addons/activations/gelu_test.py
new file mode 100644
index 0000000000..f510715593
--- /dev/null
+++ b/tensorflow_addons/activations/gelu_test.py
@@ -0,0 +1,106 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl.testing import parameterized
+
+import math
+
+import numpy as np
+import tensorflow as tf
+from tensorflow_addons.activations import gelu
+from tensorflow_addons.utils import test_utils
+
+
+def _ref_gelu(x, approximate=True):
+    x = tf.convert_to_tensor(x)
+    if approximate:
+        pi = tf.cast(math.pi, x.dtype)
+        coeff = tf.cast(0.044715, x.dtype)
+        return 0.5 * x * (
+            1.0 + tf.tanh(tf.sqrt(2.0 / pi) * (x + coeff * tf.pow(x, 3))))
+    else:
+        return 0.5 * x * (
+            1.0 + tf.math.erf(x / tf.cast(tf.sqrt(2.0), x.dtype)))
+
+
+@test_utils.run_all_in_graph_and_eager_modes
+class GeluTest(tf.test.TestCase, parameterized.TestCase):
+    @parameterized.named_parameters(("float16", np.float16),
+                                    ("float32", np.float32),
+                                    ("float64", np.float64))
+    def test_gelu(self, dtype):
+        x = np.random.rand(2, 3, 4).astype(dtype)
+        self.assertAllCloseAccordingToType(gelu(x), _ref_gelu(x))
+        self.assertAllCloseAccordingToType(gelu(x, False), _ref_gelu(x, False))
+
+    @parameterized.named_parameters(("float16", np.float16),
+                                    ("float32", np.float32),
+                                    ("float64", np.float64))
+    def test_gradients(self, dtype):
+        x = tf.constant([1.0, 2.0, 3.0], dtype=dtype)
+
+        for approximate in [True, False]:
+            with self.subTest(approximate=approximate):
+                with tf.GradientTape(persistent=True) as tape:
+                    tape.watch(x)
+                    y_ref = _ref_gelu(x, approximate)
+                    y = gelu(x, approximate)
+                grad_ref = tape.gradient(y_ref, x)
+                grad = tape.gradient(y, x)
+                self.assertAllCloseAccordingToType(grad, grad_ref)
+
+    @parameterized.named_parameters(("float32", np.float32),
+                                    ("float64", np.float64))
+    def test_theoretical_gradients(self, dtype):
+        # Only test theoretical gradients for float32 and float64
+        # because of the instability of float16 while computing jacobian
+        x = tf.constant([1.0, 2.0, 3.0], dtype=dtype)
+
+        for approximate in [True, False]:
+            with self.subTest(approximate=approximate):
+                theoretical, numerical = tf.test.compute_gradient(
+                    lambda x: gelu(x, approximate=approximate), [x])
+                self.assertAllCloseAccordingToType(
+                    theoretical, numerical, atol=1e-4)
+
+    def test_unknown_shape(self):
+        fn = gelu.get_concrete_function(
+            tf.TensorSpec(shape=None, dtype=tf.float32))
+
+        for shape in [(1,), (1, 2), (1, 2, 3), (1, 2, 3, 4)]:
+            x = tf.ones(shape=shape, dtype=tf.float32)
+            self.assertAllClose(fn(x), gelu(x))
+
+    def test_serialization(self):
+        ref_fn = gelu
+        config = tf.keras.activations.serialize(ref_fn)
+        fn = tf.keras.activations.deserialize(config)
+        self.assertEqual(fn, ref_fn)
+
+    def test_serialization_with_layers(self):
+        layer = tf.keras.layers.Dense(3, activation=gelu)
+        config = tf.keras.layers.serialize(layer)
+        deserialized_layer = tf.keras.layers.deserialize(config)
+        self.assertEqual(deserialized_layer.__class__.__name__,
+                         layer.__class__.__name__)
+        self.assertEqual(deserialized_layer.activation.__name__, "gelu")
+
+
+if __name__ == "__main__":
+    tf.test.main()
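The serialization tests above also document the Keras integration: because `gelu.py` registers the function with `register_keras_custom_object`, it round-trips through Keras serialization under its name. A small sketch mirroring `test_serialization` and `test_serialization_with_layers`, again assuming an installed `tensorflow_addons` build:

```python
import tensorflow as tf
import tensorflow_addons as tfa  # importing registers "gelu" as a Keras custom object

# Mirrors test_serialization: the function survives a serialize/deserialize round trip.
config = tf.keras.activations.serialize(tfa.activations.gelu)
fn = tf.keras.activations.deserialize(config)
print(fn == tfa.activations.gelu)  # True

# Mirrors test_serialization_with_layers: a layer built with gelu keeps the activation.
layer = tf.keras.layers.Dense(3, activation=tfa.activations.gelu)
restored = tf.keras.layers.deserialize(tf.keras.layers.serialize(layer))
print(restored.activation.__name__)  # "gelu"
```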
diff --git a/tensorflow_addons/custom_ops/activations/BUILD b/tensorflow_addons/custom_ops/activations/BUILD
new file mode 100644
index 0000000000..a199fbc689
--- /dev/null
+++ b/tensorflow_addons/custom_ops/activations/BUILD
@@ -0,0 +1,47 @@
+licenses(["notice"])  # Apache 2.0
+
+package(default_visibility = ["//visibility:public"])
+
+load("@local_config_tf//:build_defs.bzl", "D_GLIBCXX_USE_CXX11_ABI")
+load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda_is_configured", "if_cuda")
+
+cc_library(
+    name = "gelu_op_gpu",
+    srcs = [
+        "cc/kernels/gelu_op.h",
+        "cc/kernels/gelu_op_gpu.cu.cc",
+    ],
+    copts = if_cuda_is_configured([
+        "-DGOOGLE_CUDA=1",
+        "-x cuda",
+        "-nvcc_options=relaxed-constexpr",
+        "-nvcc_options=ftz=true",
+    ]),
+    deps = [
+        "@local_config_tf//:libtensorflow_framework",
+        "@local_config_tf//:tf_header_lib",
+    ] + if_cuda_is_configured([
+        "@local_config_cuda//cuda:cuda_libs",
+        "@local_config_cuda//cuda:cuda_headers",
+    ]),
+    alwayslink = 1,
+)
+
+cc_binary(
+    name = "_activation_ops.so",
+    srcs = [
+        "cc/kernels/gelu_op.cc",
+        "cc/kernels/gelu_op.h",
+        "cc/ops/gelu_op.cc",
+    ],
+    copts = [
+        "-pthread",
+        "-std=c++11",
+        D_GLIBCXX_USE_CXX11_ABI,
+    ] + if_cuda(["-DGOOGLE_CUDA=1"]),
+    linkshared = 1,
+    deps = [
+        "@local_config_tf//:libtensorflow_framework",
+        "@local_config_tf//:tf_header_lib",
+    ] + if_cuda_is_configured([":gelu_op_gpu"]),
+)
"if_cuda_is_configured", "if_cuda") + +cc_library( + name = "gelu_op_gpu", + srcs = [ + "cc/kernels/gelu_op.h", + "cc/kernels/gelu_op_gpu.cu.cc", + ], + copts = if_cuda_is_configured([ + "-DGOOGLE_CUDA=1", + "-x cuda", + "-nvcc_options=relaxed-constexpr", + "-nvcc_options=ftz=true", + ]), + deps = [ + "@local_config_tf//:libtensorflow_framework", + "@local_config_tf//:tf_header_lib", + ] + if_cuda_is_configured([ + "@local_config_cuda//cuda:cuda_libs", + "@local_config_cuda//cuda:cuda_headers", + ]), + alwayslink = 1, +) + +cc_binary( + name = "_activation_ops.so", + srcs = [ + "cc/kernels/gelu_op.cc", + "cc/kernels/gelu_op.h", + "cc/ops/gelu_op.cc", + ], + copts = [ + "-pthread", + "-std=c++11", + D_GLIBCXX_USE_CXX11_ABI, + ] + if_cuda(["-DGOOGLE_CUDA=1"]), + linkshared = 1, + deps = [ + "@local_config_tf//:libtensorflow_framework", + "@local_config_tf//:tf_header_lib", + ] + if_cuda_is_configured([":gelu_op_gpu"]), +) diff --git a/tensorflow_addons/custom_ops/activations/cc/kernels/gelu_op.cc b/tensorflow_addons/custom_ops/activations/cc/kernels/gelu_op.cc new file mode 100644 index 0000000000..e6a8788519 --- /dev/null +++ b/tensorflow_addons/custom_ops/activations/cc/kernels/gelu_op.cc @@ -0,0 +1,77 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#define EIGEN_USE_THREADS + +#include "tensorflow_addons/custom_ops/activations/cc/kernels/gelu_op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" + +namespace tensorflow { + +using CPUDevice = Eigen::ThreadPoolDevice; + +#define REGISTER_GELU_KERNELS(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("Gelu").Device(DEVICE_CPU).TypeConstraint("T"), \ + GeluOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("GeluGrad").Device(DEVICE_CPU).TypeConstraint("T"), \ + GeluGradOp); + +// Gelu only makes sense with floating points. +TF_CALL_GPU_NUMBER_TYPES(REGISTER_GELU_KERNELS); +#undef REGISTER_GELU_KERNELS + +#ifdef GOOGLE_CUDA + +using GPUDevice = Eigen::GpuDevice; + +// Forward declarations of the functor specializations for GPU. +namespace functor { +#define DECLARE_GPU_SPEC(T) \ + template <> \ + void Gelu::operator()( \ + const GPUDevice& d, typename TTypes::ConstTensor features, \ + bool approximate, typename TTypes::Tensor activations); \ + extern template struct Gelu; \ + \ + template <> \ + void GeluGrad::operator()( \ + const GPUDevice& d, typename TTypes::ConstTensor gradients, \ + typename TTypes::ConstTensor features, bool approximate, \ + typename TTypes::Tensor backprops); \ + extern template struct GeluGrad; + +TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC); +#undef DECLARE_GPU_SPEC +} // namespace functor + +// Registration of the GPU implementations. 
+#define REGISTER_GELU_GPU_KERNELS(type)                               \
+  REGISTER_KERNEL_BUILDER(                                            \
+      Name("Gelu").Device(DEVICE_GPU).TypeConstraint<type>("T"),      \
+      GeluOp<GPUDevice, type>);                                       \
+  REGISTER_KERNEL_BUILDER(                                            \
+      Name("GeluGrad").Device(DEVICE_GPU).TypeConstraint<type>("T"),  \
+      GeluGradOp<GPUDevice, type>);
+
+TF_CALL_GPU_NUMBER_TYPES(REGISTER_GELU_GPU_KERNELS);
+#undef REGISTER_GELU_GPU_KERNELS
+
+#endif  // GOOGLE_CUDA
+
+}  // namespace tensorflow
diff --git a/tensorflow_addons/custom_ops/activations/cc/kernels/gelu_op.h b/tensorflow_addons/custom_ops/activations/cc/kernels/gelu_op.h
new file mode 100644
index 0000000000..a0469f3571
--- /dev/null
+++ b/tensorflow_addons/custom_ops/activations/cc/kernels/gelu_op.h
@@ -0,0 +1,144 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_ADDONS_GELU_OP_H_
+#define TENSORFLOW_ADDONS_GELU_OP_H_
+
+#define EIGEN_USE_THREADS
+
+#include "tensorflow/core/framework/numeric_op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
+namespace tensorflow {
+namespace functor {
+
+// Functor used by GeluOp to do the computations.
+template <typename Device, typename T>
+struct Gelu {
+  // Computes Gelu activation.
+  //
+  // features: any shape.
+  // approximate: whether to enable approximation.
+  // activations: same shape as "features".
+  void operator()(const Device& d, typename TTypes<T>::ConstTensor features,
+                  bool approximate, typename TTypes<T>::Tensor activations) {
+    if (approximate) {
+      // y = 0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3)))
+      activations.device(d) =
+          static_cast<T>(0.5) * features *
+          (static_cast<T>(1) +
+           (static_cast<T>(M_2_SQRTPI * M_SQRT1_2) *
+            (features + static_cast<T>(0.044715) * features.cube()))
+               .tanh());
+    } else {
+      // y = x * normcdf(x) = 0.5 * x * (1 + erf(x / sqrt(2)))
+      activations.device(d) =
+          static_cast<T>(0.5) * features *
+          (static_cast<T>(1) +
+           (features * static_cast<T>(M_SQRT1_2)).erf());
+    }
+  }
+};
+
+// Functor used by GeluGradOp to do the computations.
+template <typename Device, typename T>
+struct GeluGrad {
+  // Computes GeluGrad backprops.
+  //
+  // gradients: gradients backpropagated to the Gelu op.
+  // features: the inputs that were passed to the Gelu op.
+  // approximate: whether to enable approximation.
+  // backprops: gradients to backpropagate to the Gelu inputs.
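The algebra inside `GeluGrad` is compact; as an editorial aid (not part of the op), the NumPy sketch below writes out the same forward and derivative formulas for the approximate branch, with `alpha = sqrt(2 / pi)` standing in for `M_2_SQRTPI * M_SQRT1_2`, and checks them against a central finite difference:

```python
import numpy as np

SQRT_2_OVER_PI = np.sqrt(2.0 / np.pi)  # alpha, i.e. M_2_SQRTPI * M_SQRT1_2
COEFF = 0.044715

def gelu_approx(x):
    # Forward pass of the approximate branch in functor::Gelu.
    return 0.5 * x * (1.0 + np.tanh(SQRT_2_OVER_PI * (x + COEFF * x**3)))

def gelu_approx_grad(x):
    # Same algebra as the approximate branch in functor::GeluGrad:
    # 0.5 * [x * (1 - t^2) * (alpha + beta * x^2) + 1 + t], t = tanh(...)
    alpha = SQRT_2_OVER_PI
    beta = alpha * COEFF * 3.0          # kBeta in the C++ code
    t = np.tanh(alpha * (x + COEFF * x**3))
    return 0.5 * ((x - x * t**2) * (beta * x**2 + alpha) + 1.0 + t)

x = np.linspace(-3.0, 3.0, 7)
eps = 1e-5
numeric = (gelu_approx(x + eps) - gelu_approx(x - eps)) / (2.0 * eps)
print(np.allclose(gelu_approx_grad(x), numeric, atol=1e-6))
```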
+  void operator()(const Device& d, typename TTypes<T>::ConstTensor gradients,
+                  typename TTypes<T>::ConstTensor features, bool approximate,
+                  typename TTypes<T>::Tensor backprops) {
+    if (approximate) {
+      const T kAlpha = static_cast<T>(M_2_SQRTPI * M_SQRT1_2);
+      const T kBeta = kAlpha * static_cast<T>(0.044715) * static_cast<T>(3);
+      const auto y =
+          (kAlpha * ((static_cast<T>(0.044715) * features.cube()) + features))
+              .tanh();
+      backprops.device(d) = ((-features * y.square() + features) *
+                                 (kBeta * features.square() + kAlpha) +
+                             static_cast<T>(1) + y) *
+                            gradients * static_cast<T>(0.5);
+    } else {
+      backprops.device(d) =
+          gradients *
+          (static_cast<T>(M_2_SQRTPI * M_SQRT1_2 * 0.5) * features *
+               (-features.square() * static_cast<T>(0.5)).exp() +
+           (static_cast<T>(0.5) *
+            (static_cast<T>(1) +
+             (features * static_cast<T>(M_SQRT1_2)).erf())));
+    }
+  }
+};
+
+}  // namespace functor
+
+template <typename Device, typename T>
+class GeluOp : public UnaryElementWiseOp<T, GeluOp<Device, T>> {
+ public:
+  explicit GeluOp(OpKernelConstruction* context)
+      : UnaryElementWiseOp<T, GeluOp<Device, T>>::UnaryElementWiseOp(context) {
+    OP_REQUIRES_OK(context, context->GetAttr("approximate", &approximate_));
+  }
+
+  void Operate(OpKernelContext* context, const Tensor& input, Tensor* output) {
+    functor::Gelu<Device, T> functor;
+    functor(context->eigen_device<Device>(), input.flat<T>(), approximate_,
+            output->flat<T>());
+  }
+
+ private:
+  bool approximate_;
+};
+
+template <typename Device, typename T>
+class GeluGradOp : public BinaryElementWiseOp<T, GeluGradOp<Device, T>> {
+ public:
+  explicit GeluGradOp(OpKernelConstruction* context)
+      : BinaryElementWiseOp<T, GeluGradOp<Device, T>>::BinaryElementWiseOp(
+            context) {
+    OP_REQUIRES_OK(context, context->GetAttr("approximate", &approximate_));
+  }
+
+  void OperateNoTemplate(OpKernelContext* context, const Tensor& g,
+                         const Tensor& a, bool approximate, Tensor* output);
+
+  template <int NDIMS>
+  void Operate(OpKernelContext* context, const Tensor& g, const Tensor& a,
+               Tensor* output) {
+    OperateNoTemplate(context, g, a, approximate_, output);
+  }
+
+ private:
+  bool approximate_;
+};
+
+template <typename Device, typename T>
+void GeluGradOp<Device, T>::OperateNoTemplate(OpKernelContext* context,
+                                              const Tensor& g, const Tensor& a,
+                                              bool approximate,
+                                              Tensor* output) {
+  functor::GeluGrad<Device, T> functor;
+  functor(context->eigen_device<Device>(), g.flat<T>(), a.flat<T>(),
+          approximate, output->flat<T>());
+}
+
+}  // namespace tensorflow
+
+#undef EIGEN_USE_THREADS
+
+#endif  // TENSORFLOW_ADDONS_GELU_OP_H_
diff --git a/tensorflow_addons/custom_ops/activations/cc/kernels/gelu_op_gpu.cu.cc b/tensorflow_addons/custom_ops/activations/cc/kernels/gelu_op_gpu.cu.cc
new file mode 100644
index 0000000000..37d21e66e0
--- /dev/null
+++ b/tensorflow_addons/custom_ops/activations/cc/kernels/gelu_op_gpu.cu.cc
@@ -0,0 +1,36 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#if GOOGLE_CUDA
+
+#define EIGEN_USE_GPU
+
+#include "tensorflow_addons/custom_ops/activations/cc/kernels/gelu_op.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "third_party/eigen3/Eigen/Core"
+
+namespace tensorflow {
+
+using GPUDevice = Eigen::GpuDevice;
+
+#define DEFINE_GPU_KERNELS(T)                       \
+  template struct functor::Gelu<GPUDevice, T>;      \
+  template struct functor::GeluGrad<GPUDevice, T>;
+
+TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_KERNELS);
+
+}  // namespace tensorflow
+
+#endif  // GOOGLE_CUDA
diff --git a/tensorflow_addons/custom_ops/activations/cc/ops/gelu_op.cc b/tensorflow_addons/custom_ops/activations/cc/ops/gelu_op.cc
new file mode 100644
index 0000000000..03406894b8
--- /dev/null
+++ b/tensorflow_addons/custom_ops/activations/cc/ops/gelu_op.cc
@@ -0,0 +1,37 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/shape_inference.h"
+
+namespace tensorflow {
+
+REGISTER_OP("Gelu")
+    .Input("features: T")
+    .Output("activations: T")
+    .Attr("T: {half, float, double}")
+    .Attr("approximate: bool = true")
+    .SetShapeFn(shape_inference::UnchangedShape);
+
+REGISTER_OP("GeluGrad")
+    .Input("gradients: T")
+    .Input("features: T")
+    .Output("backprops: T")
+    .Attr("T: {half, float, double}")
+    .Attr("approximate: bool = true")
+    .SetShapeFn(shape_inference::MergeBothInputsShapeFn);
+
+}  // namespace tensorflow
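`SetShapeFn(shape_inference::UnchangedShape)` means `Gelu` keeps the input's shape and dtype, and gradients reach the `GeluGrad` kernel through the `@tf.RegisterGradient("Gelu")` hook in `gelu.py`. A short sketch of what that looks like from Python, again assuming a built `tensorflow_addons`:

```python
import tensorflow as tf
import tensorflow_addons as tfa

x = tf.random.normal([2, 3], dtype=tf.float64)

with tf.GradientTape() as tape:
    tape.watch(x)
    y = tfa.activations.gelu(x, approximate=False)

# UnchangedShape: the output matches the input's shape and dtype.
print(y.shape == x.shape, y.dtype == x.dtype)

# The gradient flows through the registered "Gelu" gradient, which calls
# the GeluGrad kernel with the saved "approximate" attribute.
dx = tape.gradient(y, x)
print(dx.shape)  # (2, 3)
```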