From b35e84d0d5d64632331e78c9899b6a855c6a26ee Mon Sep 17 00:00:00 2001 From: Tzu-Wei Sung Date: Sat, 17 Aug 2019 16:38:36 +0800 Subject: [PATCH 01/17] add CPU and GPU kernel for gelu --- .../custom_ops/activations/BUILD | 49 +++++++++ .../activations/cc/kernels/gelu_op.cc | 75 +++++++++++++ .../activations/cc/kernels/gelu_op.h | 67 ++++++++++++ .../activations/cc/kernels/gelu_op_functor.h | 52 +++++++++ .../activations/cc/kernels/gelu_op_gpu.cu.cc | 101 ++++++++++++++++++ .../custom_ops/activations/cc/ops/gelu_op.cc | 35 ++++++ 6 files changed, 379 insertions(+) create mode 100644 tensorflow_addons/custom_ops/activations/BUILD create mode 100644 tensorflow_addons/custom_ops/activations/cc/kernels/gelu_op.cc create mode 100644 tensorflow_addons/custom_ops/activations/cc/kernels/gelu_op.h create mode 100644 tensorflow_addons/custom_ops/activations/cc/kernels/gelu_op_functor.h create mode 100644 tensorflow_addons/custom_ops/activations/cc/kernels/gelu_op_gpu.cu.cc create mode 100644 tensorflow_addons/custom_ops/activations/cc/ops/gelu_op.cc diff --git a/tensorflow_addons/custom_ops/activations/BUILD b/tensorflow_addons/custom_ops/activations/BUILD new file mode 100644 index 0000000000..8d567af309 --- /dev/null +++ b/tensorflow_addons/custom_ops/activations/BUILD @@ -0,0 +1,49 @@ +licenses(["notice"]) # Apache 2.0 + +package(default_visibility = ["//visibility:public"]) + +load("@local_config_tf//:build_defs.bzl", "D_GLIBCXX_USE_CXX11_ABI") +load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda_is_configured", "if_cuda") + +cc_library( + name = "gelu_op_gpu", + srcs = [ + "cc/kernels/gelu_op.h", + "cc/kernels/gelu_op_functor.h", + "cc/kernels/gelu_op_gpu.cu.cc", + ], + copts = if_cuda_is_configured([ + "-DGOOGLE_CUDA=1", + "-x cuda", + "-nvcc_options=relaxed-constexpr", + "-nvcc_options=ftz=true", + ]), + deps = [ + "@local_config_tf//:libtensorflow_framework", + "@local_config_tf//:tf_header_lib", + ] + if_cuda_is_configured([ + "@local_config_cuda//cuda:cuda_libs", + "@local_config_cuda//cuda:cuda_headers", + ]), + alwayslink = 1, +) + +cc_binary( + name = "_activation_ops.so", + srcs = [ + "cc/kernels/gelu_op.cc", + "cc/kernels/gelu_op.h", + "cc/kernels/gelu_op_functor.h", + "cc/ops/gelu_op.cc", + ], + copts = [ + "-pthread", + "-std=c++11", + D_GLIBCXX_USE_CXX11_ABI, + ] + if_cuda(["-DGOOGLE_CUDA=1"]), + linkshared = 1, + deps = [ + "@local_config_tf//:libtensorflow_framework", + "@local_config_tf//:tf_header_lib", + ] + if_cuda_is_configured([":gelu_op_gpu"]), +) diff --git a/tensorflow_addons/custom_ops/activations/cc/kernels/gelu_op.cc b/tensorflow_addons/custom_ops/activations/cc/kernels/gelu_op.cc new file mode 100644 index 0000000000..01d6a07b71 --- /dev/null +++ b/tensorflow_addons/custom_ops/activations/cc/kernels/gelu_op.cc @@ -0,0 +1,75 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#define EIGEN_USE_THREADS + +#include "tensorflow_addons/custom_ops/activations/cc/kernels/gelu_op.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" + +namespace tensorflow { + +using CPUDevice = Eigen::ThreadPoolDevice; + +#ifdef GOOGLE_CUDA +using GPUDevice = Eigen::GpuDevice; +#endif + +#define REGISTER_GELU_KERNELS(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("Gelu").Device(DEVICE_CPU).TypeConstraint("T"), \ + GeluOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("GeluGrad").Device(DEVICE_CPU).TypeConstraint("T"), \ + GeluGradOp); + +TF_CALL_GPU_NUMBER_TYPES(REGISTER_GELU_KERNELS); +#undef REGISTER_GELU_KERNELS + +#ifdef GOOGLE_CUDA + +namespace functor { +#define DECLARE_GPU_SPEC(T) \ + template <> \ + void Gelu::operator()(const GPUDevice& d, \ + typename TTypes::ConstTensor features, \ + typename TTypes::Tensor activations); \ + extern template struct Gelu; \ + \ + template <> \ + void GeluGrad::operator()(const GPUDevice& d, \ + typename TTypes::ConstTensor gradients, \ + typename TTypes::ConstTensor features, \ + typename TTypes::Tensor backprops); \ + extern template struct GeluGrad; + +TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC); +} // namespace functor + +#define REGISTER_GPU_KERNELS(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("Gelu").Device(DEVICE_GPU).TypeConstraint("T"), \ + GeluOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("GeluGrad").Device(DEVICE_GPU).TypeConstraint("T"), \ + GeluGradOp); + +TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS); +#undef REGISTER_GPU_KERNELS + +#endif // GOOGLE_CUDA + +} // namespace tensorflow diff --git a/tensorflow_addons/custom_ops/activations/cc/kernels/gelu_op.h b/tensorflow_addons/custom_ops/activations/cc/kernels/gelu_op.h new file mode 100644 index 0000000000..29ce8ab23f --- /dev/null +++ b/tensorflow_addons/custom_ops/activations/cc/kernels/gelu_op.h @@ -0,0 +1,67 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_ADDONS_GELU_OP_H_ +#define TENSORFLOW_ADDONS_GELU_OP_H_ + +#define EIGEN_USE_THREADS + +#include "tensorflow_addons/custom_ops/activations/cc/kernels/gelu_op_functor.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "tensorflow/core/framework/numeric_op.h" +#include "tensorflow/core/framework/op_kernel.h" + +namespace tensorflow { + +template +class GeluOp : public UnaryElementWiseOp> { + public: + using UnaryElementWiseOp>::UnaryElementWiseOp; + + void Operate(OpKernelContext* context, const Tensor& input, Tensor* output) { + functor::Gelu functor; + functor(context->eigen_device(), input.flat(), output->flat()); + } +}; + +template +class GeluGradOp : public BinaryElementWiseOp> { + public: + using BinaryElementWiseOp>::BinaryElementWiseOp; + + void OperateNoTemplate(OpKernelContext* context, const Tensor& g, + const Tensor& a, Tensor* output); + + template + void Operate(OpKernelContext* context, const Tensor& g, const Tensor& a, + Tensor* output) { + OperateNoTemplate(context, g, a, output); + } +}; + +template +void GeluGradOp::OperateNoTemplate(OpKernelContext* context, + const Tensor& g, const Tensor& a, + Tensor* output) { + functor::GeluGrad functor; + functor(context->eigen_device(), g.flat(), a.flat(), + output->flat()); +} + +} // namespace tensorflow + +#undef EIGEN_USE_THREADS + +#endif // TENSORFLOW_ADDONS_GELU_OP_H_ diff --git a/tensorflow_addons/custom_ops/activations/cc/kernels/gelu_op_functor.h b/tensorflow_addons/custom_ops/activations/cc/kernels/gelu_op_functor.h new file mode 100644 index 0000000000..549d74c2ec --- /dev/null +++ b/tensorflow_addons/custom_ops/activations/cc/kernels/gelu_op_functor.h @@ -0,0 +1,52 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_ADDONS_GELU_OP_FUNCTOR_H_ +#define TENSORFLOW_ADDONS_GELU_OP_FUNCTOR_H_ + +#include +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "tensorflow/core/framework/tensor_types.h" + +namespace tensorflow { +namespace functor { + +template +struct Gelu { + void operator()(const Device& d, + typename TTypes::ConstTensor features, + typename TTypes::Tensor activations) { + const T kAlpha = static_cast(M_2_SQRTPI * M_SQRT1_2); + activations.device(d) = T(0.5) * features * (T(1) + (kAlpha * (features + T(0.044715) * features.cube())).tanh()); + } +}; + +template +struct GeluGrad { + void operator()(const Device& d, + typename TTypes::ConstTensor gradients, + typename TTypes::ConstTensor features, + typename TTypes::Tensor backprops) { + const T kAlpha = static_cast(M_2_SQRTPI * M_SQRT1_2); + const T kBeta = kAlpha * T(0.044715) * T(3); + const auto y = (kAlpha * ((T(0.044715) * features.cube()) + features)).tanh(); + backprops.device(d) = ((-features * (y * y) + features) * (kBeta * features.square() + kAlpha) + T(1) + y) * gradients * T(0.5); + } +}; + +} // namespace functor +} // namespace tensorflow + +#endif // TENSORFLOW_ADDONS_GELU_OP_FUNCTOR_H_ diff --git a/tensorflow_addons/custom_ops/activations/cc/kernels/gelu_op_gpu.cu.cc b/tensorflow_addons/custom_ops/activations/cc/kernels/gelu_op_gpu.cu.cc new file mode 100644 index 0000000000..08625cc6d1 --- /dev/null +++ b/tensorflow_addons/custom_ops/activations/cc/kernels/gelu_op_gpu.cu.cc @@ -0,0 +1,101 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#if GOOGLE_CUDA + +#define EIGEN_USE_GPU + +#include +#include "third_party/eigen3/Eigen/Core" +#include "tensorflow_addons/custom_ops/activations/cc/kernels/gelu_op_functor.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/util/gpu_kernel_helper.h" +#include "tensorflow/core/util/gpu_launch_config.h" + +namespace tensorflow { + +using GPUDevice = Eigen::GpuDevice; + +namespace functor { + +template +__global__ void GeluKernel(const int32 count, const T* input, T* output) { + + // output[i] = 0.5x * (1 + tanh(sqrt(2 / PI) * (x + 0.044715 * x^3))) + const T kAlpha = static_cast(M_2_SQRTPI * M_SQRT1_2); + GPU_1D_KERNEL_LOOP(i, count) { + T x = input[i]; + output[i] = T(0.5) * x * (T(1) + tanh(kAlpha * (x + T(0.044715) * (x * x * x)))); + } +} + +template +__global__ void GeluGradKernel(const int32 count, const T* gradients, const T* features, T* backprops) { + const T kAlpha = static_cast(M_2_SQRTPI * M_SQRT1_2); + const T kBeta = kAlpha * T(0.044715) * T(3); + GPU_1D_KERNEL_LOOP(i, count) { + T x = features[i]; + const T y = tanh(kAlpha * ((T(0.044715) * x * x * x) + x)); + backprops[i] = ((-x * (y * y) + x) * (kBeta * x * x + kAlpha) + T(1) + y) * gradients[i] * T(0.5); + } +} + +template +struct Gelu { + void operator()(const GPUDevice& d, + typename TTypes::ConstTensor features, + typename TTypes::Tensor activations) { + + const int32 count = features.size(); + if (count == 0) return; + + // GpuLaunchConfig config = GetGpuLaunchConfig(count, d); + GpuLaunchConfig config = GetGpuLaunchConfigFixedBlockSize(count, d, GeluKernel, 0, 1024); + + TF_CHECK_OK(GpuLaunchKernel( + GeluKernel, config.block_count, config.thread_per_block, 0, + d.stream(), count, features.data(), activations.data())); + } +}; + +template +struct GeluGrad { + void operator()(const GPUDevice& d, + typename TTypes::ConstTensor gradients, + typename TTypes::ConstTensor features, + typename TTypes::Tensor backprops) { + const int32 count = gradients.size(); + if (count == 0) return; + + // GpuLaunchConfig config = GetGpuLaunchConfig(count, d); + GpuLaunchConfig config = GetGpuLaunchConfigFixedBlockSize(count, d, GeluKernel, 0, 1024); + + TF_CHECK_OK(GpuLaunchKernel( + GeluGradKernel, config.block_count, config.thread_per_block, 0, + d.stream(), count, gradients.data(), features.data(), backprops.data())); + } +}; + +} // namespace functor + +#define DEFINE_GPU_KERNELS(T) \ + template struct functor::Gelu; \ + template struct functor::GeluGrad; + +TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_KERNELS); + +} // namespace tensorflow + +#endif // GOOGLE_CUDA diff --git a/tensorflow_addons/custom_ops/activations/cc/ops/gelu_op.cc b/tensorflow_addons/custom_ops/activations/cc/ops/gelu_op.cc new file mode 100644 index 0000000000..704ac9a397 --- /dev/null +++ b/tensorflow_addons/custom_ops/activations/cc/ops/gelu_op.cc @@ -0,0 +1,35 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/framework/common_shape_fns.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/shape_inference.h" + +namespace tensorflow { + +REGISTER_OP("Gelu") + .Input("features: T") + .Output("activations: T") + .Attr("T: {half, float, double}") + .SetShapeFn(shape_inference::UnchangedShape); + +REGISTER_OP("GeluGrad") + .Input("gradients: T") + .Input("features: T") + .Output("backprops: T") + .Attr("T: {half, float, double}") + .SetShapeFn(shape_inference::MergeBothInputsShapeFn); + +} // namespace tensorflow From 73e55a81e2748cd797d6afe917a25ab63a658774 Mon Sep 17 00:00:00 2001 From: Tzu-Wei Sung Date: Sat, 17 Aug 2019 20:08:43 +0800 Subject: [PATCH 02/17] add some documentations --- .../activations/cc/kernels/gelu_op.cc | 10 ++++---- .../activations/cc/kernels/gelu_op_functor.h | 20 ++++++++++++---- .../activations/cc/kernels/gelu_op_gpu.cu.cc | 24 ++++++++++++------- 3 files changed, 38 insertions(+), 16 deletions(-) diff --git a/tensorflow_addons/custom_ops/activations/cc/kernels/gelu_op.cc b/tensorflow_addons/custom_ops/activations/cc/kernels/gelu_op.cc index 01d6a07b71..7cdbb14d3b 100644 --- a/tensorflow_addons/custom_ops/activations/cc/kernels/gelu_op.cc +++ b/tensorflow_addons/custom_ops/activations/cc/kernels/gelu_op.cc @@ -24,10 +24,6 @@ namespace tensorflow { using CPUDevice = Eigen::ThreadPoolDevice; -#ifdef GOOGLE_CUDA -using GPUDevice = Eigen::GpuDevice; -#endif - #define REGISTER_GELU_KERNELS(type) \ REGISTER_KERNEL_BUILDER( \ Name("Gelu").Device(DEVICE_CPU).TypeConstraint("T"), \ @@ -36,11 +32,15 @@ using GPUDevice = Eigen::GpuDevice; Name("GeluGrad").Device(DEVICE_CPU).TypeConstraint("T"), \ GeluGradOp); +// Gelu only makes sense with floating points. TF_CALL_GPU_NUMBER_TYPES(REGISTER_GELU_KERNELS); #undef REGISTER_GELU_KERNELS #ifdef GOOGLE_CUDA +using GPUDevice = Eigen::GpuDevice; + +// Forward declarations of the functor specializations for GPU. namespace functor { #define DECLARE_GPU_SPEC(T) \ template <> \ @@ -57,8 +57,10 @@ namespace functor { extern template struct GeluGrad; TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC); +#undef DECLARE_GPU_DPEC } // namespace functor +// Registration of the GPU implementations. #define REGISTER_GPU_KERNELS(type) \ REGISTER_KERNEL_BUILDER( \ Name("Gelu").Device(DEVICE_GPU).TypeConstraint("T"), \ diff --git a/tensorflow_addons/custom_ops/activations/cc/kernels/gelu_op_functor.h b/tensorflow_addons/custom_ops/activations/cc/kernels/gelu_op_functor.h index 549d74c2ec..04b7c847cc 100644 --- a/tensorflow_addons/custom_ops/activations/cc/kernels/gelu_op_functor.h +++ b/tensorflow_addons/custom_ops/activations/cc/kernels/gelu_op_functor.h @@ -23,26 +23,38 @@ limitations under the License. namespace tensorflow { namespace functor { +// Functor used by GeluOp to do the computations. template struct Gelu { + // Computes Gelu activation. + // + // features: any shape. + // activations: same shape as "features". void operator()(const Device& d, typename TTypes::ConstTensor features, typename TTypes::Tensor activations) { const T kAlpha = static_cast(M_2_SQRTPI * M_SQRT1_2); - activations.device(d) = T(0.5) * features * (T(1) + (kAlpha * (features + T(0.044715) * features.cube())).tanh()); + activations.device(d) = static_cast(0.5) * features * (static_cast(1) + (kAlpha * (features + static_cast(0.044715) * features.cube())).tanh()); } }; +// Functor used by GeluGradOp to do the computations. 
template struct GeluGrad { + // Computes GeluGrad backprops. + // + // gradients: gradients backpropagated to the Gelu op. + // features: either the inputs that were passed to the Gelu or, or its + // outputs (using either one yields the same result here). + // backprops: gradients to backpropagate to the Gelu inputs. void operator()(const Device& d, typename TTypes::ConstTensor gradients, typename TTypes::ConstTensor features, typename TTypes::Tensor backprops) { const T kAlpha = static_cast(M_2_SQRTPI * M_SQRT1_2); - const T kBeta = kAlpha * T(0.044715) * T(3); - const auto y = (kAlpha * ((T(0.044715) * features.cube()) + features)).tanh(); - backprops.device(d) = ((-features * (y * y) + features) * (kBeta * features.square() + kAlpha) + T(1) + y) * gradients * T(0.5); + const T kBeta = kAlpha * static_cast(0.044715) * static_cast(3); + const auto y = (kAlpha * ((static_cast(0.044715) * features.cube()) + features)).tanh(); + backprops.device(d) = ((-features * y.square() + features) * (kBeta * features.square() + kAlpha) + static_cast(1) + y) * gradients * static_cast(0.5); } }; diff --git a/tensorflow_addons/custom_ops/activations/cc/kernels/gelu_op_gpu.cu.cc b/tensorflow_addons/custom_ops/activations/cc/kernels/gelu_op_gpu.cu.cc index 08625cc6d1..d15a7ec089 100644 --- a/tensorflow_addons/custom_ops/activations/cc/kernels/gelu_op_gpu.cu.cc +++ b/tensorflow_addons/custom_ops/activations/cc/kernels/gelu_op_gpu.cu.cc @@ -37,23 +37,27 @@ __global__ void GeluKernel(const int32 count, const T* input, T* output) { const T kAlpha = static_cast(M_2_SQRTPI * M_SQRT1_2); GPU_1D_KERNEL_LOOP(i, count) { T x = input[i]; - output[i] = T(0.5) * x * (T(1) + tanh(kAlpha * (x + T(0.044715) * (x * x * x)))); + output[i] = static_cast(0.5) * x * (static_cast(1) + tanh(kAlpha * (x + static_cast(0.044715) * (x * x * x)))); } } template __global__ void GeluGradKernel(const int32 count, const T* gradients, const T* features, T* backprops) { const T kAlpha = static_cast(M_2_SQRTPI * M_SQRT1_2); - const T kBeta = kAlpha * T(0.044715) * T(3); + const T kBeta = kAlpha * static_cast(0.044715) * static_cast(3); GPU_1D_KERNEL_LOOP(i, count) { T x = features[i]; - const T y = tanh(kAlpha * ((T(0.044715) * x * x * x) + x)); - backprops[i] = ((-x * (y * y) + x) * (kBeta * x * x + kAlpha) + T(1) + y) * gradients[i] * T(0.5); + const T y = tanh(kAlpha * ((static_cast(0.044715) * x * x * x) + x)); + backprops[i] = ((-x * (y * y) + x) * (kBeta * x * x + kAlpha) + static_cast(1) + y) * gradients[i] * static_cast(0.5); } } template struct Gelu { + // Computes Gelu activation. + // + // features: any shape. + // activations: same shape as "features". void operator()(const GPUDevice& d, typename TTypes::ConstTensor features, typename TTypes::Tensor activations) { @@ -61,8 +65,7 @@ struct Gelu { const int32 count = features.size(); if (count == 0) return; - // GpuLaunchConfig config = GetGpuLaunchConfig(count, d); - GpuLaunchConfig config = GetGpuLaunchConfigFixedBlockSize(count, d, GeluKernel, 0, 1024); + GpuLaunchConfig config = GetGpuLaunchConfig(count, d, GeluKernel, 0, 0); TF_CHECK_OK(GpuLaunchKernel( GeluKernel, config.block_count, config.thread_per_block, 0, @@ -72,6 +75,12 @@ struct Gelu { template struct GeluGrad { + // Computes GeluGrad backprop. + // + // gradients: gradient backpropagated to the Gelu op. + // features: either the inputs that were passed to the Gelu, or its outputs + // (using either one yields the same result here). + // backprops: gradient to backpropagate to the Gelu inputs. 
void operator()(const GPUDevice& d, typename TTypes::ConstTensor gradients, typename TTypes::ConstTensor features, @@ -79,8 +88,7 @@ struct GeluGrad { const int32 count = gradients.size(); if (count == 0) return; - // GpuLaunchConfig config = GetGpuLaunchConfig(count, d); - GpuLaunchConfig config = GetGpuLaunchConfigFixedBlockSize(count, d, GeluKernel, 0, 1024); + GpuLaunchConfig config = GetGpuLaunchConfig(count, d, GeluKernel, 0, 0); TF_CHECK_OK(GpuLaunchKernel( GeluGradKernel, config.block_count, config.thread_per_block, 0, From 921d6b40c5f6b918ab333c233a747630da15a8ae Mon Sep 17 00:00:00 2001 From: Tzu-Wei Sung Date: Sat, 17 Aug 2019 21:54:35 +0800 Subject: [PATCH 03/17] format codes --- .../activations/cc/kernels/gelu_op.cc | 62 ++++----- .../activations/cc/kernels/gelu_op.h | 45 +++---- .../activations/cc/kernels/gelu_op_functor.h | 65 +++++----- .../activations/cc/kernels/gelu_op_gpu.cu.cc | 120 +++++++++--------- .../custom_ops/activations/cc/ops/gelu_op.cc | 2 +- 5 files changed, 152 insertions(+), 142 deletions(-) diff --git a/tensorflow_addons/custom_ops/activations/cc/kernels/gelu_op.cc b/tensorflow_addons/custom_ops/activations/cc/kernels/gelu_op.cc index 7cdbb14d3b..11efb272ce 100644 --- a/tensorflow_addons/custom_ops/activations/cc/kernels/gelu_op.cc +++ b/tensorflow_addons/custom_ops/activations/cc/kernels/gelu_op.cc @@ -16,21 +16,21 @@ limitations under the License. #define EIGEN_USE_THREADS #include "tensorflow_addons/custom_ops/activations/cc/kernels/gelu_op.h" -#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" namespace tensorflow { using CPUDevice = Eigen::ThreadPoolDevice; -#define REGISTER_GELU_KERNELS(type) \ - REGISTER_KERNEL_BUILDER( \ - Name("Gelu").Device(DEVICE_CPU).TypeConstraint("T"), \ - GeluOp); \ - REGISTER_KERNEL_BUILDER( \ - Name("GeluGrad").Device(DEVICE_CPU).TypeConstraint("T"), \ - GeluGradOp); +#define REGISTER_GELU_KERNELS(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("Gelu").Device(DEVICE_CPU).TypeConstraint("T"), \ + GeluOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("GeluGrad").Device(DEVICE_CPU).TypeConstraint("T"), \ + GeluGradOp); // Gelu only makes sense with floating points. TF_CALL_GPU_NUMBER_TYPES(REGISTER_GELU_KERNELS); @@ -42,36 +42,36 @@ using GPUDevice = Eigen::GpuDevice; // Forward declarations of the functor specializations for GPU. 
namespace functor { -#define DECLARE_GPU_SPEC(T) \ - template <> \ - void Gelu::operator()(const GPUDevice& d, \ - typename TTypes::ConstTensor features, \ - typename TTypes::Tensor activations); \ - extern template struct Gelu; \ - \ - template <> \ - void GeluGrad::operator()(const GPUDevice& d, \ - typename TTypes::ConstTensor gradients, \ - typename TTypes::ConstTensor features, \ - typename TTypes::Tensor backprops); \ - extern template struct GeluGrad; +#define DECLARE_GPU_SPEC(T) \ + template <> \ + void Gelu::operator()( \ + const GPUDevice& d, typename TTypes::ConstTensor features, \ + typename TTypes::Tensor activations); \ + extern template struct Gelu; \ + \ + template <> \ + void GeluGrad::operator()( \ + const GPUDevice& d, typename TTypes::ConstTensor gradients, \ + typename TTypes::ConstTensor features, \ + typename TTypes::Tensor backprops); \ + extern template struct GeluGrad; TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC); #undef DECLARE_GPU_DPEC -} // namespace functor +} // namespace functor // Registration of the GPU implementations. -#define REGISTER_GPU_KERNELS(type) \ - REGISTER_KERNEL_BUILDER( \ - Name("Gelu").Device(DEVICE_GPU).TypeConstraint("T"), \ - GeluOp); \ - REGISTER_KERNEL_BUILDER( \ - Name("GeluGrad").Device(DEVICE_GPU).TypeConstraint("T"), \ - GeluGradOp); +#define REGISTER_GPU_KERNELS(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("Gelu").Device(DEVICE_GPU).TypeConstraint("T"), \ + GeluOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("GeluGrad").Device(DEVICE_GPU).TypeConstraint("T"), \ + GeluGradOp); TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS); #undef REGISTER_GPU_KERNELS -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA -} // namespace tensorflow +} // namespace tensorflow diff --git a/tensorflow_addons/custom_ops/activations/cc/kernels/gelu_op.h b/tensorflow_addons/custom_ops/activations/cc/kernels/gelu_op.h index 29ce8ab23f..fc7988a2bf 100644 --- a/tensorflow_addons/custom_ops/activations/cc/kernels/gelu_op.h +++ b/tensorflow_addons/custom_ops/activations/cc/kernels/gelu_op.h @@ -18,50 +18,51 @@ limitations under the License. 
#define EIGEN_USE_THREADS -#include "tensorflow_addons/custom_ops/activations/cc/kernels/gelu_op_functor.h" -#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/core/framework/numeric_op.h" #include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow_addons/custom_ops/activations/cc/kernels/gelu_op_functor.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" namespace tensorflow { template class GeluOp : public UnaryElementWiseOp> { - public: - using UnaryElementWiseOp>::UnaryElementWiseOp; + public: + using UnaryElementWiseOp>::UnaryElementWiseOp; - void Operate(OpKernelContext* context, const Tensor& input, Tensor* output) { - functor::Gelu functor; - functor(context->eigen_device(), input.flat(), output->flat()); - } + void Operate(OpKernelContext* context, const Tensor& input, Tensor* output) { + functor::Gelu functor; + functor(context->eigen_device(), input.flat(), + output->flat()); + } }; template class GeluGradOp : public BinaryElementWiseOp> { - public: - using BinaryElementWiseOp>::BinaryElementWiseOp; + public: + using BinaryElementWiseOp>::BinaryElementWiseOp; - void OperateNoTemplate(OpKernelContext* context, const Tensor& g, - const Tensor& a, Tensor* output); + void OperateNoTemplate(OpKernelContext* context, const Tensor& g, + const Tensor& a, Tensor* output); - template - void Operate(OpKernelContext* context, const Tensor& g, const Tensor& a, - Tensor* output) { - OperateNoTemplate(context, g, a, output); - } + template + void Operate(OpKernelContext* context, const Tensor& g, const Tensor& a, + Tensor* output) { + OperateNoTemplate(context, g, a, output); + } }; template void GeluGradOp::OperateNoTemplate(OpKernelContext* context, const Tensor& g, const Tensor& a, Tensor* output) { - functor::GeluGrad functor; - functor(context->eigen_device(), g.flat(), a.flat(), - output->flat()); + functor::GeluGrad functor; + functor(context->eigen_device(), g.flat(), a.flat(), + output->flat()); } -} // namespace tensorflow +} // namespace tensorflow #undef EIGEN_USE_THREADS -#endif // TENSORFLOW_ADDONS_GELU_OP_H_ +#endif // TENSORFLOW_ADDONS_GELU_OP_H_ diff --git a/tensorflow_addons/custom_ops/activations/cc/kernels/gelu_op_functor.h b/tensorflow_addons/custom_ops/activations/cc/kernels/gelu_op_functor.h index 04b7c847cc..f52a2f7ffc 100644 --- a/tensorflow_addons/custom_ops/activations/cc/kernels/gelu_op_functor.h +++ b/tensorflow_addons/custom_ops/activations/cc/kernels/gelu_op_functor.h @@ -17,8 +17,8 @@ limitations under the License. #define TENSORFLOW_ADDONS_GELU_OP_FUNCTOR_H_ #include -#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/core/framework/tensor_types.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" namespace tensorflow { namespace functor { @@ -26,39 +26,46 @@ namespace functor { // Functor used by GeluOp to do the computations. template struct Gelu { - // Computes Gelu activation. - // - // features: any shape. - // activations: same shape as "features". - void operator()(const Device& d, - typename TTypes::ConstTensor features, - typename TTypes::Tensor activations) { - const T kAlpha = static_cast(M_2_SQRTPI * M_SQRT1_2); - activations.device(d) = static_cast(0.5) * features * (static_cast(1) + (kAlpha * (features + static_cast(0.044715) * features.cube())).tanh()); - } + // Computes Gelu activation. + // + // features: any shape. + // activations: same shape as "features". 
+ void operator()(const Device& d, typename TTypes::ConstTensor features, + typename TTypes::Tensor activations) { + const T kAlpha = static_cast(M_2_SQRTPI * M_SQRT1_2); + activations.device(d) = + static_cast(0.5) * features * + (static_cast(1) + + (kAlpha * (features + static_cast(0.044715) * features.cube())) + .tanh()); + } }; // Functor used by GeluGradOp to do the computations. template struct GeluGrad { - // Computes GeluGrad backprops. - // - // gradients: gradients backpropagated to the Gelu op. - // features: either the inputs that were passed to the Gelu or, or its - // outputs (using either one yields the same result here). - // backprops: gradients to backpropagate to the Gelu inputs. - void operator()(const Device& d, - typename TTypes::ConstTensor gradients, - typename TTypes::ConstTensor features, - typename TTypes::Tensor backprops) { - const T kAlpha = static_cast(M_2_SQRTPI * M_SQRT1_2); - const T kBeta = kAlpha * static_cast(0.044715) * static_cast(3); - const auto y = (kAlpha * ((static_cast(0.044715) * features.cube()) + features)).tanh(); - backprops.device(d) = ((-features * y.square() + features) * (kBeta * features.square() + kAlpha) + static_cast(1) + y) * gradients * static_cast(0.5); - } + // Computes GeluGrad backprops. + // + // gradients: gradients backpropagated to the Gelu op. + // features: either the inputs that were passed to the Gelu or, or its + // outputs (using either one yields the same result here). + // backprops: gradients to backpropagate to the Gelu inputs. + void operator()(const Device& d, typename TTypes::ConstTensor gradients, + typename TTypes::ConstTensor features, + typename TTypes::Tensor backprops) { + const T kAlpha = static_cast(M_2_SQRTPI * M_SQRT1_2); + const T kBeta = kAlpha * static_cast(0.044715) * static_cast(3); + const auto y = + (kAlpha * ((static_cast(0.044715) * features.cube()) + features)) + .tanh(); + backprops.device(d) = ((-features * y.square() + features) * + (kBeta * features.square() + kAlpha) + + static_cast(1) + y) * + gradients * static_cast(0.5); + } }; -} // namespace functor -} // namespace tensorflow +} // namespace functor +} // namespace tensorflow -#endif // TENSORFLOW_ADDONS_GELU_OP_FUNCTOR_H_ +#endif // TENSORFLOW_ADDONS_GELU_OP_FUNCTOR_H_ diff --git a/tensorflow_addons/custom_ops/activations/cc/kernels/gelu_op_gpu.cu.cc b/tensorflow_addons/custom_ops/activations/cc/kernels/gelu_op_gpu.cu.cc index d15a7ec089..3634148f8a 100644 --- a/tensorflow_addons/custom_ops/activations/cc/kernels/gelu_op_gpu.cu.cc +++ b/tensorflow_addons/custom_ops/activations/cc/kernels/gelu_op_gpu.cu.cc @@ -18,11 +18,11 @@ limitations under the License. 
#define EIGEN_USE_GPU #include -#include "third_party/eigen3/Eigen/Core" -#include "tensorflow_addons/custom_ops/activations/cc/kernels/gelu_op_functor.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/util/gpu_kernel_helper.h" #include "tensorflow/core/util/gpu_launch_config.h" +#include "tensorflow_addons/custom_ops/activations/cc/kernels/gelu_op_functor.h" +#include "third_party/eigen3/Eigen/Core" namespace tensorflow { @@ -32,78 +32,80 @@ namespace functor { template __global__ void GeluKernel(const int32 count, const T* input, T* output) { - - // output[i] = 0.5x * (1 + tanh(sqrt(2 / PI) * (x + 0.044715 * x^3))) - const T kAlpha = static_cast(M_2_SQRTPI * M_SQRT1_2); - GPU_1D_KERNEL_LOOP(i, count) { - T x = input[i]; - output[i] = static_cast(0.5) * x * (static_cast(1) + tanh(kAlpha * (x + static_cast(0.044715) * (x * x * x)))); - } + // output[i] = 0.5x * (1 + tanh(sqrt(2 / PI) * (x + 0.044715 * x^3))) + const T kAlpha = static_cast(M_2_SQRTPI * M_SQRT1_2); + GPU_1D_KERNEL_LOOP(i, count) { + T x = input[i]; + output[i] = static_cast(0.5) * x * + (static_cast(1) + + tanh(kAlpha * (x + static_cast(0.044715) * (x * x * x)))); + } } template -__global__ void GeluGradKernel(const int32 count, const T* gradients, const T* features, T* backprops) { - const T kAlpha = static_cast(M_2_SQRTPI * M_SQRT1_2); - const T kBeta = kAlpha * static_cast(0.044715) * static_cast(3); - GPU_1D_KERNEL_LOOP(i, count) { - T x = features[i]; - const T y = tanh(kAlpha * ((static_cast(0.044715) * x * x * x) + x)); - backprops[i] = ((-x * (y * y) + x) * (kBeta * x * x + kAlpha) + static_cast(1) + y) * gradients[i] * static_cast(0.5); - } +__global__ void GeluGradKernel(const int32 count, const T* gradients, + const T* features, T* backprops) { + const T kAlpha = static_cast(M_2_SQRTPI * M_SQRT1_2); + const T kBeta = kAlpha * static_cast(0.044715) * static_cast(3); + GPU_1D_KERNEL_LOOP(i, count) { + T x = features[i]; + const T y = tanh(kAlpha * ((static_cast(0.044715) * x * x * x) + x)); + backprops[i] = ((-x * (y * y) + x) * (kBeta * x * x + kAlpha) + + static_cast(1) + y) * + gradients[i] * static_cast(0.5); + } } template struct Gelu { - // Computes Gelu activation. - // - // features: any shape. - // activations: same shape as "features". - void operator()(const GPUDevice& d, - typename TTypes::ConstTensor features, - typename TTypes::Tensor activations) { - - const int32 count = features.size(); - if (count == 0) return; - - GpuLaunchConfig config = GetGpuLaunchConfig(count, d, GeluKernel, 0, 0); - - TF_CHECK_OK(GpuLaunchKernel( - GeluKernel, config.block_count, config.thread_per_block, 0, - d.stream(), count, features.data(), activations.data())); - } + // Computes Gelu activation. + // + // features: any shape. + // activations: same shape as "features". + void operator()(const GPUDevice& d, typename TTypes::ConstTensor features, + typename TTypes::Tensor activations) { + const int32 count = features.size(); + if (count == 0) return; + + GpuLaunchConfig config = GetGpuLaunchConfig(count, d, GeluKernel, 0, 0); + + TF_CHECK_OK(GpuLaunchKernel(GeluKernel, config.block_count, + config.thread_per_block, 0, d.stream(), count, + features.data(), activations.data())); + } }; template struct GeluGrad { - // Computes GeluGrad backprop. - // - // gradients: gradient backpropagated to the Gelu op. - // features: either the inputs that were passed to the Gelu, or its outputs - // (using either one yields the same result here). 
- // backprops: gradient to backpropagate to the Gelu inputs. - void operator()(const GPUDevice& d, - typename TTypes::ConstTensor gradients, - typename TTypes::ConstTensor features, - typename TTypes::Tensor backprops) { - const int32 count = gradients.size(); - if (count == 0) return; - - GpuLaunchConfig config = GetGpuLaunchConfig(count, d, GeluKernel, 0, 0); - - TF_CHECK_OK(GpuLaunchKernel( - GeluGradKernel, config.block_count, config.thread_per_block, 0, - d.stream(), count, gradients.data(), features.data(), backprops.data())); - } + // Computes GeluGrad backprop. + // + // gradients: gradient backpropagated to the Gelu op. + // features: either the inputs that were passed to the Gelu, or its outputs + // (using either one yields the same result here). + // backprops: gradient to backpropagate to the Gelu inputs. + void operator()(const GPUDevice& d, typename TTypes::ConstTensor gradients, + typename TTypes::ConstTensor features, + typename TTypes::Tensor backprops) { + const int32 count = gradients.size(); + if (count == 0) return; + + GpuLaunchConfig config = GetGpuLaunchConfig(count, d, GeluKernel, 0, 0); + + TF_CHECK_OK(GpuLaunchKernel(GeluGradKernel, config.block_count, + config.thread_per_block, 0, d.stream(), count, + gradients.data(), features.data(), + backprops.data())); + } }; -} // namespace functor +} // namespace functor -#define DEFINE_GPU_KERNELS(T) \ - template struct functor::Gelu; \ - template struct functor::GeluGrad; +#define DEFINE_GPU_KERNELS(T) \ + template struct functor::Gelu; \ + template struct functor::GeluGrad; TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_KERNELS); -} // namespace tensorflow +} // namespace tensorflow -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA diff --git a/tensorflow_addons/custom_ops/activations/cc/ops/gelu_op.cc b/tensorflow_addons/custom_ops/activations/cc/ops/gelu_op.cc index 704ac9a397..e716cc5a86 100644 --- a/tensorflow_addons/custom_ops/activations/cc/ops/gelu_op.cc +++ b/tensorflow_addons/custom_ops/activations/cc/ops/gelu_op.cc @@ -32,4 +32,4 @@ REGISTER_OP("GeluGrad") .Attr("T: {half, float, double}") .SetShapeFn(shape_inference::MergeBothInputsShapeFn); -} // namespace tensorflow +} // namespace tensorflow From 799b610233a2f3479ec0c7aceb672136c73ee266 Mon Sep 17 00:00:00 2001 From: Tzu-Wei Sung Date: Sun, 18 Aug 2019 16:55:03 +0800 Subject: [PATCH 04/17] support original (non-approximate) gelu --- .../activations/cc/kernels/gelu_op.cc | 4 +- .../activations/cc/kernels/gelu_op.h | 26 +++++++-- .../activations/cc/kernels/gelu_op_functor.h | 53 +++++++++++------ .../activations/cc/kernels/gelu_op_gpu.cu.cc | 58 ++++++++++++++----- .../custom_ops/activations/cc/ops/gelu_op.cc | 2 + 5 files changed, 105 insertions(+), 38 deletions(-) diff --git a/tensorflow_addons/custom_ops/activations/cc/kernels/gelu_op.cc b/tensorflow_addons/custom_ops/activations/cc/kernels/gelu_op.cc index 11efb272ce..5ccacd665f 100644 --- a/tensorflow_addons/custom_ops/activations/cc/kernels/gelu_op.cc +++ b/tensorflow_addons/custom_ops/activations/cc/kernels/gelu_op.cc @@ -46,13 +46,13 @@ namespace functor { template <> \ void Gelu::operator()( \ const GPUDevice& d, typename TTypes::ConstTensor features, \ - typename TTypes::Tensor activations); \ + bool approximate, typename TTypes::Tensor activations); \ extern template struct Gelu; \ \ template <> \ void GeluGrad::operator()( \ const GPUDevice& d, typename TTypes::ConstTensor gradients, \ - typename TTypes::ConstTensor features, \ + typename TTypes::ConstTensor features, bool approximate, \ typename 
TTypes::Tensor backprops); \ extern template struct GeluGrad; diff --git a/tensorflow_addons/custom_ops/activations/cc/kernels/gelu_op.h b/tensorflow_addons/custom_ops/activations/cc/kernels/gelu_op.h index fc7988a2bf..dcc4ba5611 100644 --- a/tensorflow_addons/custom_ops/activations/cc/kernels/gelu_op.h +++ b/tensorflow_addons/custom_ops/activations/cc/kernels/gelu_op.h @@ -28,37 +28,51 @@ namespace tensorflow { template class GeluOp : public UnaryElementWiseOp> { public: - using UnaryElementWiseOp>::UnaryElementWiseOp; + explicit GeluOp(OpKernelConstruction* context) + : UnaryElementWiseOp>::UnaryElementWiseOp(context) { + OP_REQUIRES_OK(context, context->GetAttr("approximate", &approximate_)); + } void Operate(OpKernelContext* context, const Tensor& input, Tensor* output) { functor::Gelu functor; - functor(context->eigen_device(), input.flat(), + functor(context->eigen_device(), input.flat(), approximate_, output->flat()); } + + private: + bool approximate_; }; template class GeluGradOp : public BinaryElementWiseOp> { public: - using BinaryElementWiseOp>::BinaryElementWiseOp; + explicit GeluGradOp(OpKernelConstruction* context) + : BinaryElementWiseOp>::BinaryElementWiseOp( + context) { + OP_REQUIRES_OK(context, context->GetAttr("approximate", &approximate_)); + } void OperateNoTemplate(OpKernelContext* context, const Tensor& g, - const Tensor& a, Tensor* output); + const Tensor& a, bool approximate, Tensor* output); template void Operate(OpKernelContext* context, const Tensor& g, const Tensor& a, Tensor* output) { - OperateNoTemplate(context, g, a, output); + OperateNoTemplate(context, g, a, approximate_, output); } + + private: + bool approximate_; }; template void GeluGradOp::OperateNoTemplate(OpKernelContext* context, const Tensor& g, const Tensor& a, + bool approximate, Tensor* output) { functor::GeluGrad functor; functor(context->eigen_device(), g.flat(), a.flat(), - output->flat()); + approximate, output->flat()); } } // namespace tensorflow diff --git a/tensorflow_addons/custom_ops/activations/cc/kernels/gelu_op_functor.h b/tensorflow_addons/custom_ops/activations/cc/kernels/gelu_op_functor.h index f52a2f7ffc..d0003f0f1c 100644 --- a/tensorflow_addons/custom_ops/activations/cc/kernels/gelu_op_functor.h +++ b/tensorflow_addons/custom_ops/activations/cc/kernels/gelu_op_functor.h @@ -29,15 +29,23 @@ struct Gelu { // Computes Gelu activation. // // features: any shape. + // approximate: whether to enable approximation. // activations: same shape as "features". void operator()(const Device& d, typename TTypes::ConstTensor features, - typename TTypes::Tensor activations) { - const T kAlpha = static_cast(M_2_SQRTPI * M_SQRT1_2); - activations.device(d) = - static_cast(0.5) * features * - (static_cast(1) + - (kAlpha * (features + static_cast(0.044715) * features.cube())) - .tanh()); + bool approximate, typename TTypes::Tensor activations) { + if (approximate) { + const T kAlpha = static_cast(M_2_SQRTPI * M_SQRT1_2); + + activations.device(d) = + static_cast(0.5) * features * + (static_cast(1) + + (kAlpha * (features + static_cast(0.044715) * features.cube())) + .tanh()); + } else { + activations.device(d) = + static_cast(0.5) * features * + (static_cast(1) + (features * static_cast(M_SQRT1_2)).erf()); + } } }; @@ -49,19 +57,30 @@ struct GeluGrad { // gradients: gradients backpropagated to the Gelu op. // features: either the inputs that were passed to the Gelu or, or its // outputs (using either one yields the same result here). + // approximate: whether to enable approximation. 
// backprops: gradients to backpropagate to the Gelu inputs. void operator()(const Device& d, typename TTypes::ConstTensor gradients, - typename TTypes::ConstTensor features, + typename TTypes::ConstTensor features, bool approximate, typename TTypes::Tensor backprops) { - const T kAlpha = static_cast(M_2_SQRTPI * M_SQRT1_2); - const T kBeta = kAlpha * static_cast(0.044715) * static_cast(3); - const auto y = - (kAlpha * ((static_cast(0.044715) * features.cube()) + features)) - .tanh(); - backprops.device(d) = ((-features * y.square() + features) * - (kBeta * features.square() + kAlpha) + - static_cast(1) + y) * - gradients * static_cast(0.5); + if (approximate) { + const T kAlpha = static_cast(M_2_SQRTPI * M_SQRT1_2); + const T kBeta = kAlpha * static_cast(0.044715) * static_cast(3); + const auto y = + (kAlpha * ((static_cast(0.044715) * features.cube()) + features)) + .tanh(); + backprops.device(d) = ((-features * y.square() + features) * + (kBeta * features.square() + kAlpha) + + static_cast(1) + y) * + gradients * static_cast(0.5); + } else { + const T kAlpha = static_cast(M_2_SQRTPI * M_SQRT1_2 * 0.5); + backprops.device(d) = + gradients * (kAlpha * features * + (-features.square() * static_cast(0.5)).exp() + + (static_cast(0.5) * + (static_cast(1) + + (features * static_cast(M_SQRT1_2)).erf()))); + } } }; diff --git a/tensorflow_addons/custom_ops/activations/cc/kernels/gelu_op_gpu.cu.cc b/tensorflow_addons/custom_ops/activations/cc/kernels/gelu_op_gpu.cu.cc index 3634148f8a..5a9fcc7b96 100644 --- a/tensorflow_addons/custom_ops/activations/cc/kernels/gelu_op_gpu.cu.cc +++ b/tensorflow_addons/custom_ops/activations/cc/kernels/gelu_op_gpu.cu.cc @@ -22,7 +22,7 @@ limitations under the License. #include "tensorflow/core/util/gpu_kernel_helper.h" #include "tensorflow/core/util/gpu_launch_config.h" #include "tensorflow_addons/custom_ops/activations/cc/kernels/gelu_op_functor.h" -#include "third_party/eigen3/Eigen/Core" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" namespace tensorflow { @@ -31,7 +31,8 @@ using GPUDevice = Eigen::GpuDevice; namespace functor { template -__global__ void GeluKernel(const int32 count, const T* input, T* output) { +__global__ void ApproximateGeluKernel(const int32 count, const T* input, + T* output) { // output[i] = 0.5x * (1 + tanh(sqrt(2 / PI) * (x + 0.044715 * x^3))) const T kAlpha = static_cast(M_2_SQRTPI * M_SQRT1_2); GPU_1D_KERNEL_LOOP(i, count) { @@ -43,8 +44,19 @@ __global__ void GeluKernel(const int32 count, const T* input, T* output) { } template -__global__ void GeluGradKernel(const int32 count, const T* gradients, - const T* features, T* backprops) { +__global__ void GeluKernel(const int32 count, const T* input, T* output) { + // output[i] = x * P(X <= x) = x * normcdf(x) = 0.5x * (1 + erf(x / sqrt(2)) + GPU_1D_KERNEL_LOOP(i, count) { + T x = input[i]; + output[i] = + static_cast(0.5) * x * + (static_cast(1) + Eigen::numext::erf(x * static_cast(M_SQRT1_2))); + } +} + +template +__global__ void ApproximateGeluGradKernel(const int32 count, const T* gradients, + const T* features, T* backprops) { const T kAlpha = static_cast(M_2_SQRTPI * M_SQRT1_2); const T kBeta = kAlpha * static_cast(0.044715) * static_cast(3); GPU_1D_KERNEL_LOOP(i, count) { @@ -56,20 +68,37 @@ __global__ void GeluGradKernel(const int32 count, const T* gradients, } } +template +__global__ void GeluGradKernel(const int32 count, const T* gradients, + const T* features, T* backprops) { + const T kAlpha = static_cast(M_2_SQRTPI * M_SQRT1_2 * 0.5); + GPU_1D_KERNEL_LOOP(i, 
count) { + T x = features[i]; + backprops[i] = + gradients[i] * (kAlpha * x * exp(-x * x * static_cast(0.5)) + + (static_cast(0.5) * + (static_cast(1) + + Eigen::numext::erf(x * static_cast(M_SQRT1_2))))); + } +} + template struct Gelu { // Computes Gelu activation. // // features: any shape. + // approximate: whether to enable approximation. // activations: same shape as "features". void operator()(const GPUDevice& d, typename TTypes::ConstTensor features, - typename TTypes::Tensor activations) { + bool approximate, typename TTypes::Tensor activations) { const int32 count = features.size(); if (count == 0) return; - GpuLaunchConfig config = GetGpuLaunchConfig(count, d, GeluKernel, 0, 0); + auto kernel = approximate ? ApproximateGeluKernel : GeluKernel; + + GpuLaunchConfig config = GetGpuLaunchConfig(count, d, kernel, 0, 0); - TF_CHECK_OK(GpuLaunchKernel(GeluKernel, config.block_count, + TF_CHECK_OK(GpuLaunchKernel(kernel, config.block_count, config.thread_per_block, 0, d.stream(), count, features.data(), activations.data())); } @@ -82,19 +111,22 @@ struct GeluGrad { // gradients: gradient backpropagated to the Gelu op. // features: either the inputs that were passed to the Gelu, or its outputs // (using either one yields the same result here). + // approximate: whether to enable approximation. // backprops: gradient to backpropagate to the Gelu inputs. void operator()(const GPUDevice& d, typename TTypes::ConstTensor gradients, - typename TTypes::ConstTensor features, + typename TTypes::ConstTensor features, bool approximate, typename TTypes::Tensor backprops) { const int32 count = gradients.size(); if (count == 0) return; - GpuLaunchConfig config = GetGpuLaunchConfig(count, d, GeluKernel, 0, 0); + auto kernel = + approximate ? ApproximateGeluGradKernel : GeluGradKernel; - TF_CHECK_OK(GpuLaunchKernel(GeluGradKernel, config.block_count, - config.thread_per_block, 0, d.stream(), count, - gradients.data(), features.data(), - backprops.data())); + GpuLaunchConfig config = GetGpuLaunchConfig(count, d, kernel, 0, 0); + + TF_CHECK_OK(GpuLaunchKernel( + kernel, config.block_count, config.thread_per_block, 0, d.stream(), + count, gradients.data(), features.data(), backprops.data())); } }; diff --git a/tensorflow_addons/custom_ops/activations/cc/ops/gelu_op.cc b/tensorflow_addons/custom_ops/activations/cc/ops/gelu_op.cc index e716cc5a86..03406894b8 100644 --- a/tensorflow_addons/custom_ops/activations/cc/ops/gelu_op.cc +++ b/tensorflow_addons/custom_ops/activations/cc/ops/gelu_op.cc @@ -23,6 +23,7 @@ REGISTER_OP("Gelu") .Input("features: T") .Output("activations: T") .Attr("T: {half, float, double}") + .Attr("approximate: bool = true") .SetShapeFn(shape_inference::UnchangedShape); REGISTER_OP("GeluGrad") @@ -30,6 +31,7 @@ REGISTER_OP("GeluGrad") .Input("features: T") .Output("backprops: T") .Attr("T: {half, float, double}") + .Attr("approximate: bool = true") .SetShapeFn(shape_inference::MergeBothInputsShapeFn); } // namespace tensorflow From 261cd49bdb72ca22a2f2f7a8061aeece69ad2d5f Mon Sep 17 00:00:00 2001 From: Tzu-Wei Sung Date: Sun, 18 Aug 2019 19:58:32 +0800 Subject: [PATCH 05/17] GPUDevice is super fast --- .../custom_ops/activations/BUILD | 2 - .../activations/cc/kernels/gelu_op.h | 64 +++++++++- .../activations/cc/kernels/gelu_op_functor.h | 90 -------------- .../activations/cc/kernels/gelu_op_gpu.cu.cc | 111 +----------------- 4 files changed, 65 insertions(+), 202 deletions(-) delete mode 100644 tensorflow_addons/custom_ops/activations/cc/kernels/gelu_op_functor.h diff --git 
a/tensorflow_addons/custom_ops/activations/BUILD b/tensorflow_addons/custom_ops/activations/BUILD index 8d567af309..a199fbc689 100644 --- a/tensorflow_addons/custom_ops/activations/BUILD +++ b/tensorflow_addons/custom_ops/activations/BUILD @@ -9,7 +9,6 @@ cc_library( name = "gelu_op_gpu", srcs = [ "cc/kernels/gelu_op.h", - "cc/kernels/gelu_op_functor.h", "cc/kernels/gelu_op_gpu.cu.cc", ], copts = if_cuda_is_configured([ @@ -33,7 +32,6 @@ cc_binary( srcs = [ "cc/kernels/gelu_op.cc", "cc/kernels/gelu_op.h", - "cc/kernels/gelu_op_functor.h", "cc/ops/gelu_op.cc", ], copts = [ diff --git a/tensorflow_addons/custom_ops/activations/cc/kernels/gelu_op.h b/tensorflow_addons/custom_ops/activations/cc/kernels/gelu_op.h index dcc4ba5611..d33b8044db 100644 --- a/tensorflow_addons/custom_ops/activations/cc/kernels/gelu_op.h +++ b/tensorflow_addons/custom_ops/activations/cc/kernels/gelu_op.h @@ -20,10 +20,72 @@ limitations under the License. #include "tensorflow/core/framework/numeric_op.h" #include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow_addons/custom_ops/activations/cc/kernels/gelu_op_functor.h" #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" namespace tensorflow { +namespace functor { + +// Functor used by GeluOp to do the computations. +template +struct Gelu { + // Computes Gelu activation. + // + // features: any shape. + // approximate: whether to enable approximation. + // activations: same shape as "features". + void operator()(const Device& d, typename TTypes::ConstTensor features, + bool approximate, typename TTypes::Tensor activations) { + if (approximate) { + // y = 0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.44715 * x^3))) + activations.device(d) = + static_cast(0.5) * features * + (static_cast(1) + + (static_cast(M_2_SQRTPI * M_SQRT1_2) * + (features + static_cast(0.044715) * features.cube())) + .tanh()); + } else { + // y = x * normcdf(x) = 0.5 * x * (1 + erf(x / sqrt(2))) + activations.device(d) = + static_cast(0.5) * features * + (static_cast(1) + (features * static_cast(M_SQRT1_2)).erf()); + } + } +}; + +// Functor used by GeluGradOp to do the computations. +template +struct GeluGrad { + // Computes GeluGrad backprops. + // + // gradients: gradients backpropagated to the Gelu op. + // features: either the inputs that were passed to the Gelu or, or its + // outputs (using either one yields the same result here). + // approximate: whether to enable approximation. + // backprops: gradients to backpropagate to the Gelu inputs. 
+ void operator()(const Device& d, typename TTypes::ConstTensor gradients, + typename TTypes::ConstTensor features, bool approximate, + typename TTypes::Tensor backprops) { + if (approximate) { + const T kAlpha = static_cast(M_2_SQRTPI * M_SQRT1_2); + const T kBeta = kAlpha * static_cast(0.044715) * static_cast(3); + const auto y = + (kAlpha * ((static_cast(0.044715) * features.cube()) + features)) + .tanh(); + backprops.device(d) = ((-features * y.square() + features) * + (kBeta * features.square() + kAlpha) + + static_cast(1) + y) * + gradients * static_cast(0.5); + } else { + backprops.device(d) = + gradients * (static_cast(M_2_SQRTPI * M_SQRT1_2 * 0.5) * features * + (-features.square() * static_cast(0.5)).exp() + + (static_cast(0.5) * + (static_cast(1) + + (features * static_cast(M_SQRT1_2)).erf()))); + } + } +}; +} template class GeluOp : public UnaryElementWiseOp> { diff --git a/tensorflow_addons/custom_ops/activations/cc/kernels/gelu_op_functor.h b/tensorflow_addons/custom_ops/activations/cc/kernels/gelu_op_functor.h deleted file mode 100644 index d0003f0f1c..0000000000 --- a/tensorflow_addons/custom_ops/activations/cc/kernels/gelu_op_functor.h +++ /dev/null @@ -1,90 +0,0 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_ADDONS_GELU_OP_FUNCTOR_H_ -#define TENSORFLOW_ADDONS_GELU_OP_FUNCTOR_H_ - -#include -#include "tensorflow/core/framework/tensor_types.h" -#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" - -namespace tensorflow { -namespace functor { - -// Functor used by GeluOp to do the computations. -template -struct Gelu { - // Computes Gelu activation. - // - // features: any shape. - // approximate: whether to enable approximation. - // activations: same shape as "features". - void operator()(const Device& d, typename TTypes::ConstTensor features, - bool approximate, typename TTypes::Tensor activations) { - if (approximate) { - const T kAlpha = static_cast(M_2_SQRTPI * M_SQRT1_2); - - activations.device(d) = - static_cast(0.5) * features * - (static_cast(1) + - (kAlpha * (features + static_cast(0.044715) * features.cube())) - .tanh()); - } else { - activations.device(d) = - static_cast(0.5) * features * - (static_cast(1) + (features * static_cast(M_SQRT1_2)).erf()); - } - } -}; - -// Functor used by GeluGradOp to do the computations. -template -struct GeluGrad { - // Computes GeluGrad backprops. - // - // gradients: gradients backpropagated to the Gelu op. - // features: either the inputs that were passed to the Gelu or, or its - // outputs (using either one yields the same result here). - // approximate: whether to enable approximation. - // backprops: gradients to backpropagate to the Gelu inputs. 
- void operator()(const Device& d, typename TTypes::ConstTensor gradients, - typename TTypes::ConstTensor features, bool approximate, - typename TTypes::Tensor backprops) { - if (approximate) { - const T kAlpha = static_cast(M_2_SQRTPI * M_SQRT1_2); - const T kBeta = kAlpha * static_cast(0.044715) * static_cast(3); - const auto y = - (kAlpha * ((static_cast(0.044715) * features.cube()) + features)) - .tanh(); - backprops.device(d) = ((-features * y.square() + features) * - (kBeta * features.square() + kAlpha) + - static_cast(1) + y) * - gradients * static_cast(0.5); - } else { - const T kAlpha = static_cast(M_2_SQRTPI * M_SQRT1_2 * 0.5); - backprops.device(d) = - gradients * (kAlpha * features * - (-features.square() * static_cast(0.5)).exp() + - (static_cast(0.5) * - (static_cast(1) + - (features * static_cast(M_SQRT1_2)).erf()))); - } - } -}; - -} // namespace functor -} // namespace tensorflow - -#endif // TENSORFLOW_ADDONS_GELU_OP_FUNCTOR_H_ diff --git a/tensorflow_addons/custom_ops/activations/cc/kernels/gelu_op_gpu.cu.cc b/tensorflow_addons/custom_ops/activations/cc/kernels/gelu_op_gpu.cu.cc index 5a9fcc7b96..37d21e66e0 100644 --- a/tensorflow_addons/custom_ops/activations/cc/kernels/gelu_op_gpu.cu.cc +++ b/tensorflow_addons/custom_ops/activations/cc/kernels/gelu_op_gpu.cu.cc @@ -17,121 +17,14 @@ limitations under the License. #define EIGEN_USE_GPU -#include +#include "tensorflow_addons/custom_ops/activations/cc/kernels/gelu_op.h" #include "tensorflow/core/framework/register_types.h" -#include "tensorflow/core/util/gpu_kernel_helper.h" -#include "tensorflow/core/util/gpu_launch_config.h" -#include "tensorflow_addons/custom_ops/activations/cc/kernels/gelu_op_functor.h" -#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "third_party/eigen3/Eigen/Core" namespace tensorflow { using GPUDevice = Eigen::GpuDevice; -namespace functor { - -template -__global__ void ApproximateGeluKernel(const int32 count, const T* input, - T* output) { - // output[i] = 0.5x * (1 + tanh(sqrt(2 / PI) * (x + 0.044715 * x^3))) - const T kAlpha = static_cast(M_2_SQRTPI * M_SQRT1_2); - GPU_1D_KERNEL_LOOP(i, count) { - T x = input[i]; - output[i] = static_cast(0.5) * x * - (static_cast(1) + - tanh(kAlpha * (x + static_cast(0.044715) * (x * x * x)))); - } -} - -template -__global__ void GeluKernel(const int32 count, const T* input, T* output) { - // output[i] = x * P(X <= x) = x * normcdf(x) = 0.5x * (1 + erf(x / sqrt(2)) - GPU_1D_KERNEL_LOOP(i, count) { - T x = input[i]; - output[i] = - static_cast(0.5) * x * - (static_cast(1) + Eigen::numext::erf(x * static_cast(M_SQRT1_2))); - } -} - -template -__global__ void ApproximateGeluGradKernel(const int32 count, const T* gradients, - const T* features, T* backprops) { - const T kAlpha = static_cast(M_2_SQRTPI * M_SQRT1_2); - const T kBeta = kAlpha * static_cast(0.044715) * static_cast(3); - GPU_1D_KERNEL_LOOP(i, count) { - T x = features[i]; - const T y = tanh(kAlpha * ((static_cast(0.044715) * x * x * x) + x)); - backprops[i] = ((-x * (y * y) + x) * (kBeta * x * x + kAlpha) + - static_cast(1) + y) * - gradients[i] * static_cast(0.5); - } -} - -template -__global__ void GeluGradKernel(const int32 count, const T* gradients, - const T* features, T* backprops) { - const T kAlpha = static_cast(M_2_SQRTPI * M_SQRT1_2 * 0.5); - GPU_1D_KERNEL_LOOP(i, count) { - T x = features[i]; - backprops[i] = - gradients[i] * (kAlpha * x * exp(-x * x * static_cast(0.5)) + - (static_cast(0.5) * - (static_cast(1) + - Eigen::numext::erf(x * static_cast(M_SQRT1_2))))); - } 
-}
-
-template <typename T>
-struct Gelu<GPUDevice, T> {
-  // Computes Gelu activation.
-  //
-  // features: any shape.
-  // approximate: whether to enable approximation.
-  // activations: same shape as "features".
-  void operator()(const GPUDevice& d, typename TTypes<T>::ConstTensor features,
-                  bool approximate, typename TTypes<T>::Tensor activations) {
-    const int32 count = features.size();
-    if (count == 0) return;
-
-    auto kernel = approximate ? ApproximateGeluKernel<T> : GeluKernel<T>;
-
-    GpuLaunchConfig config = GetGpuLaunchConfig(count, d, kernel, 0, 0);
-
-    TF_CHECK_OK(GpuLaunchKernel(kernel, config.block_count,
-                                config.thread_per_block, 0, d.stream(), count,
-                                features.data(), activations.data()));
-  }
-};
-
-template <typename T>
-struct GeluGrad<GPUDevice, T> {
-  // Computes GeluGrad backprop.
-  //
-  // gradients: gradient backpropagated to the Gelu op.
-  // features: either the inputs that were passed to the Gelu, or its outputs
-  //   (using either one yields the same result here).
-  // approximate: whether to enable approximation.
-  // backprops: gradient to backpropagate to the Gelu inputs.
-  void operator()(const GPUDevice& d, typename TTypes<T>::ConstTensor gradients,
-                  typename TTypes<T>::ConstTensor features, bool approximate,
-                  typename TTypes<T>::Tensor backprops) {
-    const int32 count = gradients.size();
-    if (count == 0) return;
-
-    auto kernel =
-        approximate ? ApproximateGeluGradKernel<T> : GeluGradKernel<T>;
-
-    GpuLaunchConfig config = GetGpuLaunchConfig(count, d, kernel, 0, 0);
-
-    TF_CHECK_OK(GpuLaunchKernel(
-        kernel, config.block_count, config.thread_per_block, 0, d.stream(),
-        count, gradients.data(), features.data(), backprops.data()));
-  }
-};
-
-}  // namespace functor
-
 #define DEFINE_GPU_KERNELS(T)                      \
   template struct functor::Gelu<GPUDevice, T>;     \
   template struct functor::GeluGrad<GPUDevice, T>;

From 705a08060f3a80a31b122542be43f8637cdac47f Mon Sep 17 00:00:00 2001
From: Tzu-Wei Sung
Date: Sun, 18 Aug 2019 20:28:11 +0800
Subject: [PATCH 06/17] fix typo

---
 tensorflow_addons/custom_ops/activations/cc/kernels/gelu_op.cc | 2 +-
 tensorflow_addons/custom_ops/activations/cc/kernels/gelu_op.h  | 3 ++-
 2 files changed, 3 insertions(+), 2 deletions(-)

diff --git a/tensorflow_addons/custom_ops/activations/cc/kernels/gelu_op.cc b/tensorflow_addons/custom_ops/activations/cc/kernels/gelu_op.cc
index 5ccacd665f..c27cf6f181 100644
--- a/tensorflow_addons/custom_ops/activations/cc/kernels/gelu_op.cc
+++ b/tensorflow_addons/custom_ops/activations/cc/kernels/gelu_op.cc
@@ -57,7 +57,7 @@ namespace functor {
   extern template struct GeluGrad<GPUDevice, T>;
 
 TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
-#undef DECLARE_GPU_DPEC
+#undef DECLARE_GPU_SPEC
 }  // namespace functor
 
 // Registration of the GPU implementations.
diff --git a/tensorflow_addons/custom_ops/activations/cc/kernels/gelu_op.h b/tensorflow_addons/custom_ops/activations/cc/kernels/gelu_op.h
index d33b8044db..174227ef5d 100644
--- a/tensorflow_addons/custom_ops/activations/cc/kernels/gelu_op.h
+++ b/tensorflow_addons/custom_ops/activations/cc/kernels/gelu_op.h
@@ -85,7 +85,8 @@ struct GeluGrad {
     }
   }
 };
-}
+
+} // namespace functor
 
 template <typename Device, typename T>
 class GeluOp : public UnaryElementWiseOp<T, GeluOp<Device, T>> {

From e95065020a2310eed1b31497284fc3d017fcceb0 Mon Sep 17 00:00:00 2001
From: Tzu-Wei Sung
Date: Sun, 18 Aug 2019 20:31:55 +0800
Subject: [PATCH 07/17] format codes

---
 .../custom_ops/activations/cc/kernels/gelu_op.cc | 6 +++---
 .../custom_ops/activations/cc/kernels/gelu_op.h  | 2 +-
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/tensorflow_addons/custom_ops/activations/cc/kernels/gelu_op.cc b/tensorflow_addons/custom_ops/activations/cc/kernels/gelu_op.cc
index c27cf6f181..e6a8788519 100644
--- a/tensorflow_addons/custom_ops/activations/cc/kernels/gelu_op.cc
+++ b/tensorflow_addons/custom_ops/activations/cc/kernels/gelu_op.cc
@@ -61,7 +61,7 @@ TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
 }  // namespace functor
 
 // Registration of the GPU implementations.
-#define REGISTER_GPU_KERNELS(type)                                    \
+#define REGISTER_GELU_GPU_KERNELS(type)                               \
   REGISTER_KERNEL_BUILDER(                                            \
       Name("Gelu").Device(DEVICE_GPU).TypeConstraint<type>("T"),      \
       GeluOp<GPUDevice, type>);                                       \
@@ -69,8 +69,8 @@ TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
       Name("GeluGrad").Device(DEVICE_GPU).TypeConstraint<type>("T"),  \
       GeluGradOp<GPUDevice, type>);
 
-TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS);
-#undef REGISTER_GPU_KERNELS
+TF_CALL_GPU_NUMBER_TYPES(REGISTER_GELU_GPU_KERNELS);
+#undef REGISTER_GELU_GPU_KERNELS
 
 #endif  // GOOGLE_CUDA
 
diff --git a/tensorflow_addons/custom_ops/activations/cc/kernels/gelu_op.h b/tensorflow_addons/custom_ops/activations/cc/kernels/gelu_op.h
index 174227ef5d..6228bec00e 100644
--- a/tensorflow_addons/custom_ops/activations/cc/kernels/gelu_op.h
+++ b/tensorflow_addons/custom_ops/activations/cc/kernels/gelu_op.h
@@ -86,7 +86,7 @@ struct GeluGrad {
   }
 };
 
-} // namespace functor
+}  // namespace functor
 
 template <typename Device, typename T>
 class GeluOp : public UnaryElementWiseOp<T, GeluOp<Device, T>> {

From d72774a19b726cc1c9ccff4d4347bf344b2666b5 Mon Sep 17 00:00:00 2001
From: Tzu-Wei Sung
Date: Tue, 20 Aug 2019 14:55:38 +0800
Subject: [PATCH 08/17] python API for gelu

---
 tensorflow_addons/activations/gelu.py | 38 +++++++++++++++++++++++++++
 1 file changed, 38 insertions(+)
 create mode 100644 tensorflow_addons/activations/gelu.py

diff --git a/tensorflow_addons/activations/gelu.py b/tensorflow_addons/activations/gelu.py
new file mode 100644
index 0000000000..667fa68a40
--- /dev/null
+++ b/tensorflow_addons/activations/gelu.py
@@ -0,0 +1,38 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+from tensorflow_addons.utils import keras_utils
+from tensorflow_addons.utils.resource_loader import get_path_to_datafile
+
+
+_activation_ops_so = tf.load_op_library(
+    get_path_to_datafile("custom_ops/activations/_activation_ops.so"))
+
+
+@keras_utils.register_keras_custom_object
+@tf.function
+def gelu(x, approximate=True):
+    x = tf.convert_to_tensor(x)
+    return _activation_ops_so.gelu(x, approximate)
+
+
+@tf.RegisterGradient("Gelu")
+def _gelu_grad(op, grad):
+    return _activation_ops_so.gelu_grad(grad, op.inputs[0], op.get_attr("approximate"))

From 0103fcd5d9d75ec22319f290a5672e6a58f15a7d Mon Sep 17 00:00:00 2001
From: Tzu-Wei Sung
Date: Tue, 20 Aug 2019 14:55:54 +0800
Subject: [PATCH 09/17] unittests for gelu

---
 tensorflow_addons/activations/gelu_test.py | 108 +++++++++++++++++++++
 1 file changed, 108 insertions(+)
 create mode 100644 tensorflow_addons/activations/gelu_test.py

diff --git a/tensorflow_addons/activations/gelu_test.py b/tensorflow_addons/activations/gelu_test.py
new file mode 100644
index 0000000000..19cc9df528
--- /dev/null
+++ b/tensorflow_addons/activations/gelu_test.py
@@ -0,0 +1,108 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl.testing import parameterized
+
+import math
+
+import numpy as np
+import tensorflow as tf
+from tensorflow_addons.activations import gelu
+from tensorflow_addons.utils import test_utils
+
+
+def _ref_gelu(x, approximate=True):
+    x = tf.convert_to_tensor(x)
+    if approximate:
+        pi = tf.cast(math.pi, x.dtype)
+        coeff = tf.cast(0.044715, x.dtype)
+        return 0.5 * x * (1.0 + tf.tanh(tf.sqrt(2.0 / pi) * (x + coeff * tf.pow(x, 3))))
+    else:
+        return 0.5 * x * (1.0 + tf.math.erf(x / tf.cast(tf.sqrt(2.0), x.dtype)))
+
+
+@test_utils.run_all_in_graph_and_eager_modes
+class TestGelu(tf.test.TestCase, parameterized.TestCase):
+
+    @parameterized.named_parameters(
+        ("float16", np.float16),
+        ("float32", np.float32),
+        ("float64", np.float64)
+    )
+    def test_gelu(self, dtype):
+        x = np.random.rand(2, 3, 4).astype(dtype)
+        self.assertAllCloseAccordingToType(gelu.gelu(x),
+                                           _ref_gelu(x))
+        self.assertAllCloseAccordingToType(gelu.gelu(x, False),
+                                           _ref_gelu(x, False))
+
+    @parameterized.named_parameters(
+        ("float16", np.float16),
+        ("float32", np.float32),
+        ("float64", np.float64)
+    )
+    def test_gradients(self, dtype):
+        x = tf.constant([1.0, 2.0, 3.0], dtype=dtype)
+
+        for approximate in [True, False]:
+            with self.subTest(approximate=approximate):
+                with tf.GradientTape(persistent=True) as tape:
+                    tape.watch(x)
+                    y_ref = _ref_gelu(x, approximate)
+                    y = gelu.gelu(x, approximate)
+                grad_ref = tape.gradient(y_ref, x)
+                grad = tape.gradient(y, x)
+                self.assertAllCloseAccordingToType(grad, grad_ref)
+
+    @parameterized.named_parameters(
+        ("float32", np.float32),
+        ("float64", np.float64)
+    )
+    def test_theoretical_gradients(self, dtype):
+        # Only test theoretical gradients for float32 and float64
+        # because of the instability of float16 while computing jacobian
+        x = tf.constant([1.0, 2.0, 3.0], dtype=dtype)
+        theoretical, numerical = tf.test.compute_gradient(gelu.gelu, [x])
+        self.assertAllCloseAccordingToType(theoretical, numerical, atol=1e-4)
+
+    def test_unknown_shape(self):
+        fn = gelu.gelu.get_concrete_function(
+            tf.TensorSpec(shape=None, dtype=tf.float32))
+
+        for shape in [(1,), (1, 2), (1, 2, 3), (1, 2, 3, 4)]:
+            x = tf.ones(shape=shape, dtype=tf.float32)
+            self.assertAllClose(fn(x), gelu.gelu(x))
+
+    def test_serialization(self):
+        ref_fn = gelu.gelu
+        config = tf.keras.activations.serialize(ref_fn)
+        fn = tf.keras.activations.deserialize(config)
+        self.assertEqual(fn, ref_fn)
+
+    def test_serialization_with_layers(self):
+        layer = tf.keras.layers.Dense(3, activation=gelu.gelu)
+        config = tf.keras.layers.serialize(layer)
+        deserialized_layer = tf.keras.layers.deserialize(config)
+        self.assertEqual(deserialized_layer.__class__.__name__,
+                         layer.__class__.__name__)
+        self.assertEqual(deserialized_layer.activation.__name__, "gelu")
+
+
+if __name__ == "__main__":
+    tf.test.main()

From 15d2c284c1cb6768faf3722f9208bfd1232eddb8 Mon Sep 17 00:00:00 2001
From: Tzu-Wei Sung
Date: Tue, 20 Aug 2019 14:56:07 +0800
Subject: [PATCH 10/17] update BUILD file

---
 tensorflow_addons/activations/BUILD | 19 +++++++++++++++++--
 1 file changed, 17 insertions(+), 2 deletions(-)

diff --git a/tensorflow_addons/activations/BUILD b/tensorflow_addons/activations/BUILD
index d454860322..34e87c6298 100644
--- a/tensorflow_addons/activations/BUILD
+++ b/tensorflow_addons/activations/BUILD
@@ -6,12 +6,14 @@ py_library(
"activations", srcs = [ "__init__.py", + "gelu.py", "sparsemax.py", ], - srcs_version = "PY2AND3", - deps = [ + data = [ + "//tensorflow_addons/custom_ops/activations:_activation_ops.so", "//tensorflow_addons/utils", ], + srcs_version = "PY2AND3", ) py_test( @@ -26,3 +28,16 @@ py_test( ":activations", ], ) + +py_test( + name = "gelu_test", + size = "large", + srcs = [ + "gelu_test.py", + ], + main = "gelu_test.py", + srcs_version = "PY2AND3", + deps = [ + ":activations", + ], +) From b85b7a33ba084f2d9ba0508b90f266cfd23a97b0 Mon Sep 17 00:00:00 2001 From: Tzu-Wei Sung Date: Tue, 20 Aug 2019 15:06:07 +0800 Subject: [PATCH 11/17] lint --- tensorflow_addons/activations/gelu.py | 4 +-- tensorflow_addons/activations/gelu_test.py | 38 +++++++++------------- 2 files changed, 18 insertions(+), 24 deletions(-) diff --git a/tensorflow_addons/activations/gelu.py b/tensorflow_addons/activations/gelu.py index 667fa68a40..2f8706ab00 100644 --- a/tensorflow_addons/activations/gelu.py +++ b/tensorflow_addons/activations/gelu.py @@ -21,7 +21,6 @@ from tensorflow_addons.utils import keras_utils from tensorflow_addons.utils.resource_loader import get_path_to_datafile - _activation_ops_so = tf.load_op_library( get_path_to_datafile("custom_ops/activations/_activation_ops.so")) @@ -35,4 +34,5 @@ def gelu(x, approximate=True): @tf.RegisterGradient("Gelu") def _gelu_grad(op, grad): - return _activation_ops_so.gelu_grad(grad, op.inputs[0], op.get_attr("approximate")) + return _activation_ops_so.gelu_grad(grad, op.inputs[0], + op.get_attr("approximate")) diff --git a/tensorflow_addons/activations/gelu_test.py b/tensorflow_addons/activations/gelu_test.py index 19cc9df528..583a7ef3ce 100644 --- a/tensorflow_addons/activations/gelu_test.py +++ b/tensorflow_addons/activations/gelu_test.py @@ -32,31 +32,27 @@ def _ref_gelu(x, approximate=True): if approximate: pi = tf.cast(math.pi, x.dtype) coeff = tf.cast(0.044715, x.dtype) - return 0.5 * x * (1.0 + tf.tanh(tf.sqrt(2.0 / pi) * (x + coeff * tf.pow(x, 3)))) + return 0.5 * x * ( + 1.0 + tf.tanh(tf.sqrt(2.0 / pi) * (x + coeff * tf.pow(x, 3)))) else: - return 0.5 * x * (1.0 + tf.math.erf(x / tf.cast(tf.sqrt(2.0), x.dtype))) + return 0.5 * x * ( + 1.0 + tf.math.erf(x / tf.cast(tf.sqrt(2.0), x.dtype))) @test_utils.run_all_in_graph_and_eager_modes class TestGelu(tf.test.TestCase, parameterized.TestCase): - - @parameterized.named_parameters( - ("float16", np.float16), - ("float32", np.float32), - ("float64", np.float64) - ) + @parameterized.named_parameters(("float16", np.float16), + ("float32", np.float32), + ("float64", np.float64)) def test_gelu(self, dtype): x = np.random.rand(2, 3, 4).astype(dtype) - self.assertAllCloseAccordingToType(gelu.gelu(x), - _ref_gelu(x)) - self.assertAllCloseAccordingToType(gelu.gelu(x, False), - _ref_gelu(x, False)) - - @parameterized.named_parameters( - ("float16", np.float16), - ("float32", np.float32), - ("float64", np.float64) - ) + self.assertAllCloseAccordingToType(gelu.gelu(x), _ref_gelu(x)) + self.assertAllCloseAccordingToType( + gelu.gelu(x, False), _ref_gelu(x, False)) + + @parameterized.named_parameters(("float16", np.float16), + ("float32", np.float32), + ("float64", np.float64)) def test_gradients(self, dtype): x = tf.constant([1.0, 2.0, 3.0], dtype=dtype) @@ -70,10 +66,8 @@ def test_gradients(self, dtype): grad = tape.gradient(y, x) self.assertAllCloseAccordingToType(grad, grad_ref) - @parameterized.named_parameters( - ("float32", np.float32), - ("float64", np.float64) - ) + @parameterized.named_parameters(("float32", np.float32), + 
("float64", np.float64)) def test_theoretical_gradients(self, dtype): # Only test theoretical gradients for float32 and float64 # because of the instability of float16 while computing jacobian From 8d0f2e5008fa10fd82d3bd6c3730cca7b1efee14 Mon Sep 17 00:00:00 2001 From: Tzu-Wei Sung Date: Wed, 21 Aug 2019 10:50:49 +0800 Subject: [PATCH 12/17] update init and README --- tensorflow_addons/activations/README.md | 14 ++++++++------ tensorflow_addons/activations/__init__.py | 1 + tensorflow_addons/activations/gelu_test.py | 17 ++++++++--------- 3 files changed, 17 insertions(+), 15 deletions(-) diff --git a/tensorflow_addons/activations/README.md b/tensorflow_addons/activations/README.md index 4ab59b23bb..f6d51ea9db 100644 --- a/tensorflow_addons/activations/README.md +++ b/tensorflow_addons/activations/README.md @@ -1,14 +1,16 @@ # Addons - Activations ## Maintainers -| Submodule | Maintainers | Contact Info | -|:---------- |:------------- |:--------------| -| sparsemax | @AndreasMadsen | amwwebdk+github@gmail.com | +| Submodule | Maintainers | Contact Info | +|:----------|:--------------------------|:-----------------------------------------| +| sparsemax | @AndreasMadsen | amwwebdk+github@gmail.com | +| gelu | @AakashKumarNain @WindQAQ | aakashnain@outlook.com windqaq@gmail.com | ## Contents -| Submodule | Activation | Reference | -|:----------------------- |:-------------------|:---------------| -| sparsemax | Sparsemax | https://arxiv.org/abs/1602.02068 | +| Submodule | Activation | Reference | +|:----------|:-----------|:---------------------------------| +| gelu | gelu | https://arxiv.org/abs/1606.08415 | +| sparsemax | Sparsemax | https://arxiv.org/abs/1602.02068 | ## Contribution Guidelines diff --git a/tensorflow_addons/activations/__init__.py b/tensorflow_addons/activations/__init__.py index 5792d00356..45903a3975 100644 --- a/tensorflow_addons/activations/__init__.py +++ b/tensorflow_addons/activations/__init__.py @@ -18,4 +18,5 @@ from __future__ import division from __future__ import print_function +from tensorflow_addons.activations.gelu import gelu from tensorflow_addons.activations.sparsemax import sparsemax diff --git a/tensorflow_addons/activations/gelu_test.py b/tensorflow_addons/activations/gelu_test.py index 583a7ef3ce..846f217635 100644 --- a/tensorflow_addons/activations/gelu_test.py +++ b/tensorflow_addons/activations/gelu_test.py @@ -46,9 +46,8 @@ class TestGelu(tf.test.TestCase, parameterized.TestCase): ("float64", np.float64)) def test_gelu(self, dtype): x = np.random.rand(2, 3, 4).astype(dtype) - self.assertAllCloseAccordingToType(gelu.gelu(x), _ref_gelu(x)) - self.assertAllCloseAccordingToType( - gelu.gelu(x, False), _ref_gelu(x, False)) + self.assertAllCloseAccordingToType(gelu(x), _ref_gelu(x)) + self.assertAllCloseAccordingToType(gelu(x, False), _ref_gelu(x, False)) @parameterized.named_parameters(("float16", np.float16), ("float32", np.float32), @@ -61,7 +60,7 @@ def test_gradients(self, dtype): with tf.GradientTape(persistent=True) as tape: tape.watch(x) y_ref = _ref_gelu(x, approximate) - y = gelu.gelu(x, approximate) + y = gelu(x, approximate) grad_ref = tape.gradient(y_ref, x) grad = tape.gradient(y, x) self.assertAllCloseAccordingToType(grad, grad_ref) @@ -72,25 +71,25 @@ def test_theoretical_gradients(self, dtype): # Only test theoretical gradients for float32 and float64 # because of the instability of float16 while computing jacobian x = tf.constant([1.0, 2.0, 3.0], dtype=dtype) - theoretical, numerical = tf.test.compute_gradient(gelu.gelu, [x]) + 
+        theoretical, numerical = tf.test.compute_gradient(gelu, [x])
         self.assertAllCloseAccordingToType(theoretical, numerical, atol=1e-4)
 
     def test_unknown_shape(self):
-        fn = gelu.gelu.get_concrete_function(
+        fn = gelu.get_concrete_function(
             tf.TensorSpec(shape=None, dtype=tf.float32))
 
         for shape in [(1,), (1, 2), (1, 2, 3), (1, 2, 3, 4)]:
             x = tf.ones(shape=shape, dtype=tf.float32)
-            self.assertAllClose(fn(x), gelu.gelu(x))
+            self.assertAllClose(fn(x), gelu(x))
 
     def test_serialization(self):
-        ref_fn = gelu.gelu
+        ref_fn = gelu
         config = tf.keras.activations.serialize(ref_fn)
         fn = tf.keras.activations.deserialize(config)
         self.assertEqual(fn, ref_fn)
 
     def test_serialization_with_layers(self):
-        layer = tf.keras.layers.Dense(3, activation=gelu.gelu)
+        layer = tf.keras.layers.Dense(3, activation=gelu)
         config = tf.keras.layers.serialize(layer)
         deserialized_layer = tf.keras.layers.deserialize(config)
         self.assertEqual(deserialized_layer.__class__.__name__,
                          layer.__class__.__name__)
         self.assertEqual(deserialized_layer.activation.__name__, "gelu")

From 240cbfbbb050ec5ed618584d3f76122bafa2d3ce Mon Sep 17 00:00:00 2001
From: Tzu-Wei Sung
Date: Wed, 21 Aug 2019 10:55:07 +0800
Subject: [PATCH 13/17] alphabetical order

---
 tensorflow_addons/activations/README.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tensorflow_addons/activations/README.md b/tensorflow_addons/activations/README.md
index f6d51ea9db..500eee194b 100644
--- a/tensorflow_addons/activations/README.md
+++ b/tensorflow_addons/activations/README.md
@@ -3,8 +3,8 @@
 ## Maintainers
 | Submodule | Maintainers | Contact Info |
 |:----------|:--------------------------|:-----------------------------------------|
-| sparsemax | @AndreasMadsen | amwwebdk+github@gmail.com |
 | gelu | @AakashKumarNain @WindQAQ | aakashnain@outlook.com windqaq@gmail.com |
+| sparsemax | @AndreasMadsen | amwwebdk+github@gmail.com |
 
 ## Contents

From 517f59efa983ebea33ac788245195c9ddf7b1727 Mon Sep 17 00:00:00 2001
From: Tzu-Wei Sung
Date: Wed, 21 Aug 2019 14:21:07 +0800
Subject: [PATCH 14/17] update docs

---
 tensorflow_addons/activations/gelu.py            | 17 +++++++++++++++
 .../custom_ops/activations/cc/kernels/gelu_op.h  |  2 +-
 2 files changed, 18 insertions(+), 1 deletion(-)

diff --git a/tensorflow_addons/activations/gelu.py b/tensorflow_addons/activations/gelu.py
index 2f8706ab00..539afbbe1c 100644
--- a/tensorflow_addons/activations/gelu.py
+++ b/tensorflow_addons/activations/gelu.py
@@ -28,6 +28,23 @@
 @keras_utils.register_keras_custom_object
 @tf.function
 def gelu(x, approximate=True):
+    """Gaussian Error Linear Unit.
+
+    Computes gaussian error linear:
+    `0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3)))` or
+    `x * P(X <= x) = 0.5 * x * (1 + erf(x / sqrt(2)))`, where P(X) ~ N(0, 1),
+    depending on whether approximation is enabled.
+
+    See [Gaussian Error Linear Units (GELUs)](https://arxiv.org/abs/1606.08415)
+    and [BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding](https://arxiv.org/abs/1810.04805).
+
+    Args:
+        x: A `Tensor`. Must be one of the following types:
+            `float16`, `float32`, `float64`.
+        approximate: bool, whether to enable approximation.
+    Returns:
+        A `Tensor`. Has the same type as `x`.
+ """ x = tf.convert_to_tensor(x) return _activation_ops_so.gelu(x, approximate) diff --git a/tensorflow_addons/custom_ops/activations/cc/kernels/gelu_op.h b/tensorflow_addons/custom_ops/activations/cc/kernels/gelu_op.h index 6228bec00e..c4766d6c1b 100644 --- a/tensorflow_addons/custom_ops/activations/cc/kernels/gelu_op.h +++ b/tensorflow_addons/custom_ops/activations/cc/kernels/gelu_op.h @@ -36,7 +36,7 @@ struct Gelu { void operator()(const Device& d, typename TTypes::ConstTensor features, bool approximate, typename TTypes::Tensor activations) { if (approximate) { - // y = 0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.44715 * x^3))) + // y = 0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3))) activations.device(d) = static_cast(0.5) * features * (static_cast(1) + From 2f2bb8833636a0afa9a98626f0bbf52ca5222949 Mon Sep 17 00:00:00 2001 From: Tzu-Wei Sung Date: Wed, 21 Aug 2019 21:58:31 +0800 Subject: [PATCH 15/17] update docs --- tensorflow_addons/custom_ops/activations/cc/kernels/gelu_op.h | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tensorflow_addons/custom_ops/activations/cc/kernels/gelu_op.h b/tensorflow_addons/custom_ops/activations/cc/kernels/gelu_op.h index c4766d6c1b..a0469f3571 100644 --- a/tensorflow_addons/custom_ops/activations/cc/kernels/gelu_op.h +++ b/tensorflow_addons/custom_ops/activations/cc/kernels/gelu_op.h @@ -58,8 +58,7 @@ struct GeluGrad { // Computes GeluGrad backprops. // // gradients: gradients backpropagated to the Gelu op. - // features: either the inputs that were passed to the Gelu or, or its - // outputs (using either one yields the same result here). + // features: the inputs that were passed to the Gelu op. // approximate: whether to enable approximation. // backprops: gradients to backpropagate to the Gelu inputs. 
   void operator()(const Device& d, typename TTypes<T>::ConstTensor gradients,

From 4d4faa4d33cb1f979d4190aede698f6725ce04e0 Mon Sep 17 00:00:00 2001
From: Tzu-Wei Sung
Date: Fri, 23 Aug 2019 13:53:05 +0800
Subject: [PATCH 16/17] test gradients on non-approximate gelu

---
 tensorflow_addons/activations/gelu_test.py | 9 +++++++--
 1 file changed, 7 insertions(+), 2 deletions(-)

diff --git a/tensorflow_addons/activations/gelu_test.py b/tensorflow_addons/activations/gelu_test.py
index 846f217635..8b3e87c2d0 100644
--- a/tensorflow_addons/activations/gelu_test.py
+++ b/tensorflow_addons/activations/gelu_test.py
@@ -71,8 +71,13 @@ def test_theoretical_gradients(self, dtype):
         # Only test theoretical gradients for float32 and float64
         # because of the instability of float16 while computing jacobian
         x = tf.constant([1.0, 2.0, 3.0], dtype=dtype)
-        theoretical, numerical = tf.test.compute_gradient(gelu, [x])
-        self.assertAllCloseAccordingToType(theoretical, numerical, atol=1e-4)
+
+        for approximate in [True, False]:
+            with self.subTest(approximate=approximate):
+                theoretical, numerical = tf.test.compute_gradient(
+                    lambda x: gelu(x, approximate=approximate), [x])
+                self.assertAllCloseAccordingToType(
+                    theoretical, numerical, atol=1e-4)
 
     def test_unknown_shape(self):
         fn = gelu.get_concrete_function(

From 0a93ff27dc03f2ff07282402f9e2f1ec1dc399c0 Mon Sep 17 00:00:00 2001
From: Tzu-Wei Sung
Date: Sun, 25 Aug 2019 14:09:23 +0800
Subject: [PATCH 17/17] change test name

---
 tensorflow_addons/activations/gelu_test.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tensorflow_addons/activations/gelu_test.py b/tensorflow_addons/activations/gelu_test.py
index 8b3e87c2d0..f510715593 100644
--- a/tensorflow_addons/activations/gelu_test.py
+++ b/tensorflow_addons/activations/gelu_test.py
@@ -40,7 +40,7 @@ def _ref_gelu(x, approximate=True):
 
 
 @test_utils.run_all_in_graph_and_eager_modes
-class TestGelu(tf.test.TestCase, parameterized.TestCase):
+class GeluTest(tf.test.TestCase, parameterized.TestCase):
     @parameterized.named_parameters(("float16", np.float16),
                                     ("float32", np.float32),
                                     ("float64", np.float64))
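
For readers trying the series out, the snippet below is a minimal usage sketch rather than part of the patches. It assumes a tensorflow_addons build that bundles `_activation_ops.so` and exports `gelu` from `tensorflow_addons.activations`, which is what patches 08 and 12 above arrange; `reference_gelu` is a hypothetical pure-TensorFlow helper added here only to mirror the formulas the C++/CUDA kernels compute.

    # Minimal usage sketch, assuming the custom op is built and `gelu` is
    # exported from tensorflow_addons.activations (per patches 08 and 12).
    import math

    import numpy as np
    import tensorflow as tf
    from tensorflow_addons.activations import gelu


    def reference_gelu(x, approximate=True):
        # Pure-TensorFlow mirror of the kernel formulas:
        #   approximate: 0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3)))
        #   exact:       0.5 * x * (1 + erf(x / sqrt(2)))
        # The exact form's derivative,
        #   0.5 * (1 + erf(x / sqrt(2))) + x * exp(-x^2 / 2) / sqrt(2 * pi),
        # is what the GeluGrad kernels above evaluate.
        x = tf.convert_to_tensor(x)
        if approximate:
            return 0.5 * x * (1.0 + tf.tanh(
                math.sqrt(2.0 / math.pi) * (x + 0.044715 * tf.pow(x, 3))))
        return 0.5 * x * (1.0 + tf.math.erf(x / math.sqrt(2.0)))


    x = tf.constant(np.linspace(-3.0, 3.0, 7), dtype=tf.float32)
    print(gelu(x))                      # custom-op kernel, tanh approximation
    print(gelu(x, approximate=False))   # custom-op kernel, exact erf form
    print(reference_gelu(x))            # should closely match gelu(x)

    # The activation is a plain callable, so it can be dropped into Keras:
    model = tf.keras.Sequential([
        tf.keras.layers.Dense(16, activation=gelu, input_shape=(8,)),
        tf.keras.layers.Dense(1),
    ])

The reference function is only for eyeballing parity; the unit tests in gelu_test.py above perform the same comparison across dtypes and against numerical gradients.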