From 308c27b0f73fe76fbccedbe7b849160992c89320 Mon Sep 17 00:00:00 2001 From: Tzu-Wei Sung Date: Wed, 11 Sep 2019 15:41:06 +0800 Subject: [PATCH 1/6] add hardshrink kernel --- tensorflow_addons/activations/BUILD | 16 ++- tensorflow_addons/activations/README.md | 2 + tensorflow_addons/activations/__init__.py | 1 + tensorflow_addons/activations/hardshrink.py | 51 +++++++ .../activations/hardshrink_test.py | 97 +++++++++++++ .../custom_ops/activations/BUILD | 30 +++- .../activations/cc/kernels/hardshrink_op.cc | 77 ++++++++++ .../activations/cc/kernels/hardshrink_op.h | 131 ++++++++++++++++++ .../cc/kernels/hardshrink_op_gpu_cu.cc | 36 +++++ .../activations/cc/ops/hardshrink_op.cc | 39 ++++++ 10 files changed, 478 insertions(+), 2 deletions(-) create mode 100644 tensorflow_addons/activations/hardshrink.py create mode 100644 tensorflow_addons/activations/hardshrink_test.py create mode 100644 tensorflow_addons/custom_ops/activations/cc/kernels/hardshrink_op.cc create mode 100644 tensorflow_addons/custom_ops/activations/cc/kernels/hardshrink_op.h create mode 100644 tensorflow_addons/custom_ops/activations/cc/kernels/hardshrink_op_gpu_cu.cc create mode 100644 tensorflow_addons/custom_ops/activations/cc/ops/hardshrink_op.cc diff --git a/tensorflow_addons/activations/BUILD b/tensorflow_addons/activations/BUILD index 34e87c6298..64c1f984d7 100644 --- a/tensorflow_addons/activations/BUILD +++ b/tensorflow_addons/activations/BUILD @@ -7,6 +7,7 @@ py_library( srcs = [ "__init__.py", "gelu.py", + "hardshrink.py", "sparsemax.py", ], data = [ @@ -31,7 +32,7 @@ py_test( py_test( name = "gelu_test", - size = "large", + size = "medium", srcs = [ "gelu_test.py", ], @@ -41,3 +42,16 @@ py_test( ":activations", ], ) + +py_test( + name = "hardshrink_test", + size = "medium", + srcs = [ + "hardshrink_test.py", + ], + main = "hardshrink_test.py", + srcs_version = "PY2AND3", + deps = [ + ":activations", + ], +) diff --git a/tensorflow_addons/activations/README.md b/tensorflow_addons/activations/README.md index 500eee194b..3531deb89f 100644 --- a/tensorflow_addons/activations/README.md +++ b/tensorflow_addons/activations/README.md @@ -4,12 +4,14 @@ | Submodule | Maintainers | Contact Info | |:----------|:--------------------------|:-----------------------------------------| | gelu | @AakashKumarNain @WindQAQ | aakashnain@outlook.com windqaq@gmail.com | +| hardshrink| @WindQAQ | windqaq@gmail.com | sparsemax | @AndreasMadsen | amwwebdk+github@gmail.com | ## Contents | Submodule | Activation | Reference | |:----------|:-----------|:---------------------------------| | gelu | gelu | https://arxiv.org/abs/1606.08415 | +| hardshrink| hardshrnk | | | sparsemax | Sparsemax | https://arxiv.org/abs/1602.02068 | diff --git a/tensorflow_addons/activations/__init__.py b/tensorflow_addons/activations/__init__.py index 45903a3975..5e1123b1b3 100644 --- a/tensorflow_addons/activations/__init__.py +++ b/tensorflow_addons/activations/__init__.py @@ -19,4 +19,5 @@ from __future__ import print_function from tensorflow_addons.activations.gelu import gelu +from tensorflow_addons.activations.hardshrink import hardshrink from tensorflow_addons.activations.sparsemax import sparsemax diff --git a/tensorflow_addons/activations/hardshrink.py b/tensorflow_addons/activations/hardshrink.py new file mode 100644 index 0000000000..640fdef623 --- /dev/null +++ b/tensorflow_addons/activations/hardshrink.py @@ -0,0 +1,51 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf +from tensorflow_addons.utils import keras_utils +from tensorflow_addons.utils.resource_loader import get_path_to_datafile + +_activation_ops_so = tf.load_op_library( + get_path_to_datafile("custom_ops/activations/_activation_ops.so")) + + +@keras_utils.register_keras_custom_object +@tf.function +def hardshrink(x, lower=-1.0, upper=1.0): + """Hard shrink function. + + Computes hard shrink function: + `x if x < lower or x > upper else 0`. + + Args: + x: A `Tensor`. Must be one of the following types: + `float16`, `float32`, `float64`. + lower: `float`, lower bound for setting values to zeros. + upper: `float`, upper bound for setting values to zeros. + Returns: + A `Tensor`. Has the same type as `x`. + """ + x = tf.convert_to_tensor(x) + return _activation_ops_so.hardshrink(x, lower, upper) + + +@tf.RegisterGradient("Hardshrink") +def _hardshrink_grad(op, grad): + return _activation_ops_so.hardshrink_grad(grad, op.inputs[0], + op.get_attr("lower"), op.get_attr("upper")) diff --git a/tensorflow_addons/activations/hardshrink_test.py b/tensorflow_addons/activations/hardshrink_test.py new file mode 100644 index 0000000000..64f7a53f7b --- /dev/null +++ b/tensorflow_addons/activations/hardshrink_test.py @@ -0,0 +1,97 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import parameterized + +import numpy as np +import tensorflow as tf +from tensorflow_addons.activations import hardshrink +from tensorflow_addons.utils import test_utils + + +def _ref_hardshrink(x, lower=-1.0, upper=1.0): + x = tf.convert_to_tensor(x) + return tf.where(tf.math.logical_or(x < lower, x > upper), x, 0.0) + + +@test_utils.run_all_in_graph_and_eager_modes +class HardshrinkTest(tf.test.TestCase, parameterized.TestCase): + def test_invalid(self): + with self.assertRaisesOpError("lower must be less than or equal to upper."): + y = hardshrink(tf.ones(shape=(1, 2, 3)), lower=2.0, upper=-2.0) + self.evaluate(y) + + @parameterized.named_parameters(("float16", np.float16), + ("float32", np.float32), + ("float64", np.float64)) + def test_hardshrink(self, dtype): + x = (np.random.rand(2, 3, 4) * 2.0 - 1.0).astype(dtype) + self.assertAllCloseAccordingToType(hardshrink(x), _ref_hardshrink(x)) + self.assertAllCloseAccordingToType(hardshrink(x, -2.0, 2.0), _ref_hardshrink(x, -2.0, 2.0)) + + @parameterized.named_parameters(("float16", np.float16), + ("float32", np.float32), + ("float64", np.float64)) + def test_gradients(self, dtype): + x = tf.constant([-1.5, -0.5, 0.5, 1.5], dtype=dtype) + + with tf.GradientTape(persistent=True) as tape: + tape.watch(x) + y_ref = _ref_hardshrink(x) + y = hardshrink(x) + grad_ref = tape.gradient(y_ref, x) + grad = tape.gradient(y, x) + self.assertAllCloseAccordingToType(grad, grad_ref) + + @parameterized.named_parameters(("float32", np.float32), + ("float64", np.float64)) + def test_theoretical_gradients(self, dtype): + # Only test theoretical gradients for float32 and float64 + # because of the instability of float16 while computing jacobian + x = tf.constant([-1.5, -0.5, 0.5, 1.5], dtype=dtype) + + theoretical, numerical = tf.test.compute_gradient( + lambda x: hardshrink(x), [x]) + self.assertAllCloseAccordingToType(theoretical, numerical, atol=1e-4) + + def test_unknown_shape(self): + fn = hardshrink.get_concrete_function( + tf.TensorSpec(shape=None, dtype=tf.float32)) + + for shape in [(1,), (1, 2), (1, 2, 3), (1, 2, 3, 4)]: + x = tf.ones(shape=shape, dtype=tf.float32) + self.assertAllClose(fn(x), hardshrink(x)) + + def test_serialization(self): + ref_fn = hardshrink + config = tf.keras.activations.serialize(ref_fn) + fn = tf.keras.activations.deserialize(config) + self.assertEqual(fn, ref_fn) + + def test_serialization_with_layers(self): + layer = tf.keras.layers.Dense(3, activation=hardshrink) + config = tf.keras.layers.serialize(layer) + deserialized_layer = tf.keras.layers.deserialize(config) + self.assertEqual(deserialized_layer.__class__.__name__, + layer.__class__.__name__) + self.assertEqual(deserialized_layer.activation.__name__, "hardshrink") + + +if __name__ == "__main__": + tf.test.main() diff --git a/tensorflow_addons/custom_ops/activations/BUILD b/tensorflow_addons/custom_ops/activations/BUILD index a199fbc689..f4778cf23c 100644 --- a/tensorflow_addons/custom_ops/activations/BUILD +++ b/tensorflow_addons/custom_ops/activations/BUILD @@ -27,12 +27,37 @@ cc_library( alwayslink = 1, ) +cc_library( + name = "hardshrink_op_gpu", + srcs = [ + "cc/kernels/hardshrink_op.h", + "cc/kernels/hardshrink_op_gpu.cu.cc", + ], + copts = if_cuda_is_configured([ + "-DGOOGLE_CUDA=1", + "-x cuda", + "-nvcc_options=relaxed-constexpr", + 
"-nvcc_options=ftz=true", + ]), + deps = [ + "@local_config_tf//:libtensorflow_framework", + "@local_config_tf//:tf_header_lib", + ] + if_cuda_is_configured([ + "@local_config_cuda//cuda:cuda_libs", + "@local_config_cuda//cuda:cuda_headers", + ]), + alwayslink = 1, +) + cc_binary( name = "_activation_ops.so", srcs = [ "cc/kernels/gelu_op.cc", "cc/kernels/gelu_op.h", + "cc/kernels/hardshrink_op.cc", + "cc/kernels/hardshrink_op.h", "cc/ops/gelu_op.cc", + "cc/ops/hardshrink_op.cc", ], copts = [ "-pthread", @@ -43,5 +68,8 @@ cc_binary( deps = [ "@local_config_tf//:libtensorflow_framework", "@local_config_tf//:tf_header_lib", - ] + if_cuda_is_configured([":gelu_op_gpu"]), + ] + if_cuda_is_configured([ + ":gelu_op_gpu", + ":hardshrink_op_gpu", + ]), ) diff --git a/tensorflow_addons/custom_ops/activations/cc/kernels/hardshrink_op.cc b/tensorflow_addons/custom_ops/activations/cc/kernels/hardshrink_op.cc new file mode 100644 index 0000000000..d4bd4cf163 --- /dev/null +++ b/tensorflow_addons/custom_ops/activations/cc/kernels/hardshrink_op.cc @@ -0,0 +1,77 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#define EIGEN_USE_THREADS + +#include "tensorflow_addons/custom_ops/activations/cc/kernels/hardshrink_op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" + +namespace tensorflow { + +using CPUDevice = Eigen::ThreadPoolDevice; + +#define REGISTER_HARDSHRINK_KERNELS(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("Hardshrink").Device(DEVICE_CPU).TypeConstraint("T"), \ + HardshrinkOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("HardshrinkGrad").Device(DEVICE_CPU).TypeConstraint("T"), \ + HardshrinkGradOp); + +// Hardshrink only makes sense with floating points. +TF_CALL_GPU_NUMBER_TYPES(REGISTER_HARDSHRINK_KERNELS); +#undef REGISTER_HARDSHRINK_KERNELS + +#if GOOGLE_CUDA + +using GPUDevice = Eigen::GpuDevice; + +// Forward declarations of the functor specializations for GPU. +namespace functor { +#define DECLARE_GPU_SPEC(T) \ + template <> \ + void Hardshrink::operator()( \ + const GPUDevice& d, typename TTypes::ConstTensor features, \ + T lower, T upper, typename TTypes::Tensor activations); \ + extern template struct Hardshrink; \ + \ + template <> \ + void HardshrinkGrad::operator()( \ + const GPUDevice& d, typename TTypes::ConstTensor gradients, \ + typename TTypes::ConstTensor features, T lower, T upper, \ + typename TTypes::Tensor backprops); \ + extern template struct HardshrinkGrad; + +TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC); +#undef DECLARE_GPU_SPEC +} // namespace functor + +// Registration of the GPU implementations. 
+#define REGISTER_HARDSHRINK_GPU_KERNELS(type)                          \
+  REGISTER_KERNEL_BUILDER(                                             \
+      Name("Hardshrink").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
+      HardshrinkOp<GPUDevice, type>);                                  \
+  REGISTER_KERNEL_BUILDER(                                             \
+      Name("Hardshrink").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
+      HardshrinkGradOp<GPUDevice, type>);
+
+TF_CALL_GPU_NUMBER_TYPES(REGISTER_HARDSHRINK_GPU_KERNELS);
+#undef REGISTER_HARDSHRINK_GPU_KERNELS
+
+#endif  // GOOGLE_CUDA
+
+}  // namespace tensorflow
diff --git a/tensorflow_addons/custom_ops/activations/cc/kernels/hardshrink_op.h b/tensorflow_addons/custom_ops/activations/cc/kernels/hardshrink_op.h
new file mode 100644
index 0000000000..76974769c4
--- /dev/null
+++ b/tensorflow_addons/custom_ops/activations/cc/kernels/hardshrink_op.h
@@ -0,0 +1,131 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_ADDONS_HARDSHRINK_OP_H_
+#define TENSORFLOW_ADDONS_HARDSHRINK_OP_H_
+
+#define EIGEN_USE_THREADS
+
+#include "tensorflow/core/framework/numeric_op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
+namespace tensorflow {
+namespace functor {
+
+// Functor used by HardshrinkOp to do the computations.
+template <typename Device, typename T>
+struct Hardshrink {
+  // Computes Hardshrink activation.
+  //
+  // features: any shape.
+  // lower: the lower bound for setting values to zeros.
+  // upper: the upper bound for setting values to zeros.
+  // activations: same shape as "features".
+  void operator()(const Device& d, typename TTypes<T>::ConstTensor features,
+                  T lower, T upper, typename TTypes<T>::Tensor activations) {
+    activations.device(d) = (features < lower || features > upper).select(features, features.constant(static_cast<T>(0)));
+  }
+};
+
+// Functor used by HardshrinkGradOp to do the computations.
+template <typename Device, typename T>
+struct HardshrinkGrad {
+  // Computes HardshrinkGrad backprops.
+  //
+  // gradients: gradients backpropagated to the Hardshrink op.
+  // features: the inputs that were passed to the Hardshrink op.
+  // lower: the lower bound for setting values to zeros.
+  // upper: the upper bound for setting values to zeros.
+  // backprops: gradients to backpropagate to the Hardshrink inputs.
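+  //
+  // Hardshrink is the identity outside [lower, upper] and zero inside, so
+  // the incoming gradients pass through wherever the input fell outside
+  // the interval and are dropped elsewhere.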
+ void operator()(const Device& d, typename TTypes::ConstTensor gradients, + typename TTypes::ConstTensor features, T lower, T upper, + typename TTypes::Tensor backprops) { + backprops.device(d) = (features < lower || features > upper).select(gradients, features.constant(static_cast(0))); + } +}; + +} // namespace functor + +template +class HardshrinkOp : public UnaryElementWiseOp> { + public: + explicit HardshrinkOp(OpKernelConstruction* context) + : UnaryElementWiseOp>::UnaryElementWiseOp(context) { + float lower, upper; + OP_REQUIRES_OK(context, context->GetAttr("lower", &lower)); + OP_REQUIRES_OK(context, context->GetAttr("upper", &upper)); + lower_ = static_cast(lower); + upper_ = static_cast(upper); + + OP_REQUIRES(context, lower_ <= upper_, errors::InvalidArgument("lower must be less than or equal to upper.")); + } + + void Operate(OpKernelContext* context, const Tensor& input, Tensor* output) { + functor::Hardshrink functor; + functor(context->eigen_device(), input.flat(), lower_, upper_, + output->flat()); + } + + private: + T lower_; + T upper_; +}; + +template +class HardshrinkGradOp : public BinaryElementWiseOp> { + public: + explicit HardshrinkGradOp(OpKernelConstruction* context) + : BinaryElementWiseOp>::BinaryElementWiseOp( + context) { + float lower, upper; + OP_REQUIRES_OK(context, context->GetAttr("lower", &lower)); + OP_REQUIRES_OK(context, context->GetAttr("upper", &upper)); + lower_ = static_cast(lower); + upper_ = static_cast(upper); + + OP_REQUIRES(context, lower_ <= upper_, errors::InvalidArgument("lower must be less than or equal to upper.")); + } + + void OperateNoTemplate(OpKernelContext* context, const Tensor& g, + const Tensor& a, T lower, T upper, Tensor* output); + + template + void Operate(OpKernelContext* context, const Tensor& g, const Tensor& a, + Tensor* output) { + OperateNoTemplate(context, g, a, lower_, upper_, output); + } + + private: + T lower_; + T upper_; +}; + +template +void HardshrinkGradOp::OperateNoTemplate(OpKernelContext* context, + const Tensor& g, const Tensor& a, + T lower, T upper, + Tensor* output) { + functor::HardshrinkGrad functor; + functor(context->eigen_device(), g.flat(), a.flat(), + lower, upper, output->flat()); +} + +} // namespace tensorflow + +#undef EIGEN_USE_THREADS + +#endif // TENSORFLOW_ADDONS_HARDSHRINK_OP_H_ diff --git a/tensorflow_addons/custom_ops/activations/cc/kernels/hardshrink_op_gpu_cu.cc b/tensorflow_addons/custom_ops/activations/cc/kernels/hardshrink_op_gpu_cu.cc new file mode 100644 index 0000000000..dc90c5af4b --- /dev/null +++ b/tensorflow_addons/custom_ops/activations/cc/kernels/hardshrink_op_gpu_cu.cc @@ -0,0 +1,36 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#if GOOGLE_CUDA + +#define EIGEN_USE_GPU + +#include "tensorflow_addons/custom_ops/activations/cc/kernels/hardshrink_op.h" +#include "tensorflow/core/framework/register_types.h" +#include "third_party/eigen3/Eigen/Core" + +namespace tensorflow { + +using GPUDevice = Eigen::GpuDevice; + +#define DEFINE_GPU_KERNELS(T) \ + template struct functor::Hardshrink; \ + template struct functor::HardshrinkGrad; + +TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_KERNELS); + +} // namespace tensorflow + +#endif // GOOGLE_CUDA diff --git a/tensorflow_addons/custom_ops/activations/cc/ops/hardshrink_op.cc b/tensorflow_addons/custom_ops/activations/cc/ops/hardshrink_op.cc new file mode 100644 index 0000000000..b48f07b917 --- /dev/null +++ b/tensorflow_addons/custom_ops/activations/cc/ops/hardshrink_op.cc @@ -0,0 +1,39 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/common_shape_fns.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/shape_inference.h" + +namespace tensorflow { + +REGISTER_OP("Hardshrink") + .Input("features: T") + .Output("activations: T") + .Attr("T: {half, float, double}") + .Attr("lower: float = -1.0") + .Attr("upper: float = 1.0") + .SetShapeFn(shape_inference::UnchangedShape); + +REGISTER_OP("HardshrinkGrad") + .Input("gradients: T") + .Input("features: T") + .Output("backprops: T") + .Attr("T: {half, float, double}") + .Attr("lower: float = -1.0") + .Attr("upper: float = 1.0") + .SetShapeFn(shape_inference::MergeBothInputsShapeFn); + +} // namespace tensorflow From dcc9ed99d6fa0ccff14b02b2c3a0540aca4b2af1 Mon Sep 17 00:00:00 2001 From: Tzu-Wei Sung Date: Wed, 11 Sep 2019 16:04:25 +0800 Subject: [PATCH 2/6] fix typo --- tensorflow_addons/activations/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow_addons/activations/README.md b/tensorflow_addons/activations/README.md index 3531deb89f..b61926dc8e 100644 --- a/tensorflow_addons/activations/README.md +++ b/tensorflow_addons/activations/README.md @@ -11,7 +11,7 @@ | Submodule | Activation | Reference | |:----------|:-----------|:---------------------------------| | gelu | gelu | https://arxiv.org/abs/1606.08415 | -| hardshrink| hardshrnk | | +| hardshrink| hardshrink | | | sparsemax | Sparsemax | https://arxiv.org/abs/1602.02068 | From 10811c317f76f7e6322aa27d259edab2b50cfe34 Mon Sep 17 00:00:00 2001 From: Tzu-Wei Sung Date: Wed, 11 Sep 2019 16:41:31 +0800 Subject: [PATCH 3/6] fix cc name --- .../kernels/{hardshrink_op_gpu_cu.cc => hardshrink_op_gpu.cu.cc} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename tensorflow_addons/custom_ops/activations/cc/kernels/{hardshrink_op_gpu_cu.cc => hardshrink_op_gpu.cu.cc} (100%) diff --git a/tensorflow_addons/custom_ops/activations/cc/kernels/hardshrink_op_gpu_cu.cc 
similarity index 100%
rename from tensorflow_addons/custom_ops/activations/cc/kernels/hardshrink_op_gpu_cu.cc
rename to tensorflow_addons/custom_ops/activations/cc/kernels/hardshrink_op_gpu.cu.cc
From adb6d9ac20b7d789c046bc9847d627566e10483d Mon Sep 17 00:00:00 2001
From: Tzu-Wei Sung
Date: Wed, 11 Sep 2019 17:24:14 +0800
Subject: [PATCH 4/6] format codes

---
 tensorflow_addons/activations/hardshrink.py  |  3 +-
 .../activations/hardshrink_test.py           |  6 ++-
 .../activations/cc/kernels/hardshrink_op.cc  | 38 ++++++++---------
 .../activations/cc/kernels/hardshrink_op.h   | 42 ++++++++++++-------
 .../cc/kernels/hardshrink_op_gpu.cu.cc       |  2 +-
 5 files changed, 52 insertions(+), 39 deletions(-)

diff --git a/tensorflow_addons/activations/hardshrink.py b/tensorflow_addons/activations/hardshrink.py
index 640fdef623..998a2b4899 100644
--- a/tensorflow_addons/activations/hardshrink.py
+++ b/tensorflow_addons/activations/hardshrink.py
@@ -48,4 +48,5 @@ def hardshrink(x, lower=-1.0, upper=1.0):
 @tf.RegisterGradient("Hardshrink")
 def _hardshrink_grad(op, grad):
     return _activation_ops_so.hardshrink_grad(grad, op.inputs[0],
-                                              op.get_attr("lower"), op.get_attr("upper"))
+                                              op.get_attr("lower"),
+                                              op.get_attr("upper"))
diff --git a/tensorflow_addons/activations/hardshrink_test.py b/tensorflow_addons/activations/hardshrink_test.py
index 64f7a53f7b..a16b9be3b9 100644
--- a/tensorflow_addons/activations/hardshrink_test.py
+++ b/tensorflow_addons/activations/hardshrink_test.py
@@ -33,7 +33,8 @@ def _ref_hardshrink(x, lower=-1.0, upper=1.0):
 @test_utils.run_all_in_graph_and_eager_modes
 class HardshrinkTest(tf.test.TestCase, parameterized.TestCase):
     def test_invalid(self):
-        with self.assertRaisesOpError("lower must be less than or equal to upper."):
+        with self.assertRaisesOpError(
+                "lower must be less than or equal to upper."):  # pylint: disable=bad-continuation
             y = hardshrink(tf.ones(shape=(1, 2, 3)), lower=2.0, upper=-2.0)
             self.evaluate(y)
 
@@ -43,7 +44,8 @@ def test_invalid(self):
     def test_hardshrink(self, dtype):
         x = (np.random.rand(2, 3, 4) * 2.0 - 1.0).astype(dtype)
         self.assertAllCloseAccordingToType(hardshrink(x), _ref_hardshrink(x))
-        self.assertAllCloseAccordingToType(hardshrink(x, -2.0, 2.0), _ref_hardshrink(x, -2.0, 2.0))
+        self.assertAllCloseAccordingToType(
+            hardshrink(x, -2.0, 2.0), _ref_hardshrink(x, -2.0, 2.0))
 
     @parameterized.named_parameters(("float16", np.float16),
                                     ("float32", np.float32),
diff --git a/tensorflow_addons/custom_ops/activations/cc/kernels/hardshrink_op.cc b/tensorflow_addons/custom_ops/activations/cc/kernels/hardshrink_op.cc
index d4bd4cf163..eaeacda25a 100644
--- a/tensorflow_addons/custom_ops/activations/cc/kernels/hardshrink_op.cc
+++ b/tensorflow_addons/custom_ops/activations/cc/kernels/hardshrink_op.cc
@@ -25,10 +25,10 @@ namespace tensorflow {
 using CPUDevice = Eigen::ThreadPoolDevice;
 
 #define REGISTER_HARDSHRINK_KERNELS(type)                                  \
-  REGISTER_KERNEL_BUILDER(                                               \
+  REGISTER_KERNEL_BUILDER(                                                 \
       Name("Hardshrink").Device(DEVICE_CPU).TypeConstraint<type>("T"),     \
       HardshrinkOp<CPUDevice, type>);                                      \
-  REGISTER_KERNEL_BUILDER(                                               \
+  REGISTER_KERNEL_BUILDER(                                                 \
       Name("HardshrinkGrad").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
       HardshrinkGradOp<CPUDevice, type>);
 
@@ -42,18 +42,18 @@ using GPUDevice = Eigen::GpuDevice;
 
 // Forward declarations of the functor specializations for GPU.
 namespace functor {
-#define DECLARE_GPU_SPEC(T)                                            \
-  template <>                                                          \
-  void Hardshrink<GPUDevice, T>::operator()(                           \
-      const GPUDevice& d, typename TTypes<T>::ConstTensor features,    \
-      T lower, T upper, typename TTypes<T>::Tensor activations);       \
-  extern template struct Hardshrink<GPUDevice, T>;                     \
-                                                                       \
-  template <>                                                          \
-  void HardshrinkGrad<GPUDevice, T>::operator()(                       \
-      const GPUDevice& d, typename TTypes<T>::ConstTensor gradients,   \
-      typename TTypes<T>::ConstTensor features, T lower, T upper,      \
-      typename TTypes<T>::Tensor backprops);                           \
+#define DECLARE_GPU_SPEC(T)                                                  \
+  template <>                                                                \
+  void Hardshrink<GPUDevice, T>::operator()(                                 \
+      const GPUDevice& d, typename TTypes<T>::ConstTensor features, T lower, \
+      T upper, typename TTypes<T>::Tensor activations);                      \
+  extern template struct Hardshrink<GPUDevice, T>;                           \
+                                                                             \
+  template <>                                                                \
+  void HardshrinkGrad<GPUDevice, T>::operator()(                             \
+      const GPUDevice& d, typename TTypes<T>::ConstTensor gradients,         \
+      typename TTypes<T>::ConstTensor features, T lower, T upper,            \
+      typename TTypes<T>::Tensor backprops);                                 \
   extern template struct HardshrinkGrad<GPUDevice, T>;
 
 TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
@@ -61,11 +61,11 @@ TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
 }  // namespace functor
 
 // Registration of the GPU implementations.
-#define REGISTER_HARDSHRINK_GPU_KERNELS(type)                          \
-  REGISTER_KERNEL_BUILDER(                                             \
-      Name("Hardshrink").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
-      HardshrinkOp<GPUDevice, type>);                                  \
-  REGISTER_KERNEL_BUILDER(                                             \
+#define REGISTER_HARDSHRINK_GPU_KERNELS(type)                            \
+  REGISTER_KERNEL_BUILDER(                                               \
+      Name("Hardshrink").Device(DEVICE_GPU).TypeConstraint<type>("T"),   \
+      HardshrinkOp<GPUDevice, type>);                                    \
+  REGISTER_KERNEL_BUILDER(                                               \
       Name("Hardshrink").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
       HardshrinkGradOp<GPUDevice, type>);
 
diff --git a/tensorflow_addons/custom_ops/activations/cc/kernels/hardshrink_op.h b/tensorflow_addons/custom_ops/activations/cc/kernels/hardshrink_op.h
index 76974769c4..29ddd12759 100644
--- a/tensorflow_addons/custom_ops/activations/cc/kernels/hardshrink_op.h
+++ b/tensorflow_addons/custom_ops/activations/cc/kernels/hardshrink_op.h
@@ -37,7 +37,9 @@ struct Hardshrink {
   // activations: same shape as "features".
void operator()(const Device& d, typename TTypes::ConstTensor features, T lower, T upper, typename TTypes::Tensor activations) { - activations.device(d) = (features < lower || features > upper).select(features, features.constant(static_cast(0))); + activations.device(d) = + (features < lower || features > upper) + .select(features, features.constant(static_cast(0))); } }; @@ -54,7 +56,9 @@ struct HardshrinkGrad { void operator()(const Device& d, typename TTypes::ConstTensor gradients, typename TTypes::ConstTensor features, T lower, T upper, typename TTypes::Tensor backprops) { - backprops.device(d) = (features < lower || features > upper).select(gradients, features.constant(static_cast(0))); + backprops.device(d) = + (features < lower || features > upper) + .select(gradients, features.constant(static_cast(0))); } }; @@ -64,14 +68,17 @@ template class HardshrinkOp : public UnaryElementWiseOp> { public: explicit HardshrinkOp(OpKernelConstruction* context) - : UnaryElementWiseOp>::UnaryElementWiseOp(context) { - float lower, upper; + : UnaryElementWiseOp>::UnaryElementWiseOp( + context) { + float lower, upper; OP_REQUIRES_OK(context, context->GetAttr("lower", &lower)); OP_REQUIRES_OK(context, context->GetAttr("upper", &upper)); lower_ = static_cast(lower); upper_ = static_cast(upper); - - OP_REQUIRES(context, lower_ <= upper_, errors::InvalidArgument("lower must be less than or equal to upper.")); + + OP_REQUIRES( + context, lower_ <= upper_, + errors::InvalidArgument("lower must be less than or equal to upper.")); } void Operate(OpKernelContext* context, const Tensor& input, Tensor* output) { @@ -86,18 +93,21 @@ class HardshrinkOp : public UnaryElementWiseOp> { }; template -class HardshrinkGradOp : public BinaryElementWiseOp> { +class HardshrinkGradOp + : public BinaryElementWiseOp> { public: explicit HardshrinkGradOp(OpKernelConstruction* context) - : BinaryElementWiseOp>::BinaryElementWiseOp( - context) { - float lower, upper; + : BinaryElementWiseOp< + T, HardshrinkGradOp>::BinaryElementWiseOp(context) { + float lower, upper; OP_REQUIRES_OK(context, context->GetAttr("lower", &lower)); OP_REQUIRES_OK(context, context->GetAttr("upper", &upper)); lower_ = static_cast(lower); upper_ = static_cast(upper); - OP_REQUIRES(context, lower_ <= upper_, errors::InvalidArgument("lower must be less than or equal to upper.")); + OP_REQUIRES( + context, lower_ <= upper_, + errors::InvalidArgument("lower must be less than or equal to upper.")); } void OperateNoTemplate(OpKernelContext* context, const Tensor& g, @@ -116,12 +126,12 @@ class HardshrinkGradOp : public BinaryElementWiseOp void HardshrinkGradOp::OperateNoTemplate(OpKernelContext* context, - const Tensor& g, const Tensor& a, - T lower, T upper, - Tensor* output) { + const Tensor& g, + const Tensor& a, T lower, + T upper, Tensor* output) { functor::HardshrinkGrad functor; - functor(context->eigen_device(), g.flat(), a.flat(), - lower, upper, output->flat()); + functor(context->eigen_device(), g.flat(), a.flat(), lower, + upper, output->flat()); } } // namespace tensorflow diff --git a/tensorflow_addons/custom_ops/activations/cc/kernels/hardshrink_op_gpu.cu.cc b/tensorflow_addons/custom_ops/activations/cc/kernels/hardshrink_op_gpu.cu.cc index dc90c5af4b..ac5d03ee37 100644 --- a/tensorflow_addons/custom_ops/activations/cc/kernels/hardshrink_op_gpu.cu.cc +++ b/tensorflow_addons/custom_ops/activations/cc/kernels/hardshrink_op_gpu.cu.cc @@ -25,7 +25,7 @@ namespace tensorflow { using GPUDevice = Eigen::GpuDevice; -#define DEFINE_GPU_KERNELS(T) \ 
+#define DEFINE_GPU_KERNELS(T)                             \
   template struct functor::Hardshrink<GPUDevice, T>;      \
   template struct functor::HardshrinkGrad<GPUDevice, T>;
 
 TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_KERNELS);
From 3fb57d2d0409981aa7b6fcd1f0e0cffd66d14112 Mon Sep 17 00:00:00 2001
From: Tzu-Wei Sung
Date: Wed, 11 Sep 2019 17:47:20 +0800
Subject: [PATCH 5/6] fix typo

---
 .../custom_ops/activations/cc/kernels/hardshrink_op.cc | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tensorflow_addons/custom_ops/activations/cc/kernels/hardshrink_op.cc b/tensorflow_addons/custom_ops/activations/cc/kernels/hardshrink_op.cc
index eaeacda25a..750a8cd2f7 100644
--- a/tensorflow_addons/custom_ops/activations/cc/kernels/hardshrink_op.cc
+++ b/tensorflow_addons/custom_ops/activations/cc/kernels/hardshrink_op.cc
@@ -66,7 +66,7 @@ TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
       Name("Hardshrink").Device(DEVICE_GPU).TypeConstraint<type>("T"),   \
       HardshrinkOp<GPUDevice, type>);                                    \
   REGISTER_KERNEL_BUILDER(                                               \
-      Name("Hardshrink").Device(DEVICE_GPU).TypeConstraint<type>("T"),   \
+      Name("HardshrinkGrad").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
       HardshrinkGradOp<GPUDevice, type>);
 
 TF_CALL_GPU_NUMBER_TYPES(REGISTER_HARDSHRINK_GPU_KERNELS);
From 5d962521c5bf60955dae7c21658549a6ad8db856 Mon Sep 17 00:00:00 2001
From: Tzu-Wei Sung
Date: Wed, 11 Sep 2019 17:55:15 +0800
Subject: [PATCH 6/6] make linter happy

---
 .../custom_ops/activations/cc/kernels/hardshrink_op.cc | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/tensorflow_addons/custom_ops/activations/cc/kernels/hardshrink_op.cc b/tensorflow_addons/custom_ops/activations/cc/kernels/hardshrink_op.cc
index 750a8cd2f7..54d0795ba2 100644
--- a/tensorflow_addons/custom_ops/activations/cc/kernels/hardshrink_op.cc
+++ b/tensorflow_addons/custom_ops/activations/cc/kernels/hardshrink_op.cc
@@ -61,11 +61,11 @@ TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
 }  // namespace functor
 
 // Registration of the GPU implementations.
-#define REGISTER_HARDSHRINK_GPU_KERNELS(type)                            \
-  REGISTER_KERNEL_BUILDER(                                               \
-      Name("Hardshrink").Device(DEVICE_GPU).TypeConstraint<type>("T"),   \
-      HardshrinkOp<GPUDevice, type>);                                    \
-  REGISTER_KERNEL_BUILDER(                                               \
+#define REGISTER_HARDSHRINK_GPU_KERNELS(type)                              \
+  REGISTER_KERNEL_BUILDER(                                                 \
+      Name("Hardshrink").Device(DEVICE_GPU).TypeConstraint<type>("T"),     \
+      HardshrinkOp<GPUDevice, type>);                                      \
+  REGISTER_KERNEL_BUILDER(                                                 \
       Name("HardshrinkGrad").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
       HardshrinkGradOp<GPUDevice, type>);
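
For reference, a minimal usage sketch of the op added by this series, assuming
the addons package has been built with these custom ops; the tensor values are
illustrative, while hardshrink and its defaults (lower=-1.0, upper=1.0) come
from hardshrink.py in PATCH 1/6:

    import tensorflow as tf
    from tensorflow_addons.activations import hardshrink

    x = tf.constant([-2.0, -0.5, 0.0, 0.5, 2.0])

    # Values strictly inside [lower, upper] are zeroed; the rest pass through.
    print(hardshrink(x))                           # [-2.  0.   0.  0.   2.]
    print(hardshrink(x, lower=-0.25, upper=0.25))  # [-2. -0.5  0.  0.5  2.]

    # The registered HardshrinkGrad kernel passes gradients through only
    # where the input fell outside [lower, upper].
    with tf.GradientTape() as tape:
        tape.watch(x)
        y = hardshrink(x)
    print(tape.gradient(y, x))                     # [1.  0.  0.  0.  1.]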