Skip to content

Commit 9e09680

Browse files
committed
Migrate api style
1 parent 76fea54 commit 9e09680

File tree

7 files changed

+35
-27
lines changed

7 files changed

+35
-27
lines changed

tensorflow_addons/activations/BUILD

100644100755
File mode changed.

tensorflow_addons/activations/tanhshrink.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727

2828
@keras_utils.register_keras_custom_object
2929
@tf.function
30-
def tanhshrink(features, name="Tanhshrink"):
30+
def tanhshrink(features, name=None):
3131
"""Applies the element-wise function: x - tanh(x)
3232
3333
Args:
@@ -36,13 +36,11 @@ def tanhshrink(features, name="Tanhshrink"):
3636
Returns:
3737
A `Tensor`. Has the same type as `features`.
3838
"""
39-
with tf.name_scope(name) as name:
39+
with tf.name_scope(name or "tanhshrink") as name:
4040
features = tf.convert_to_tensor(features, name="features")
41-
if features.dtype.is_integer:
42-
features = tf.cast(features, tf.float32)
43-
return _activation_ops_so.tanhshrink(features, name=name)
41+
return _activation_ops_so.addons_tanhshrink(features, name=name)
4442

4543

46-
@tf.RegisterGradient("Tanhshrink")
44+
@tf.RegisterGradient("Addons>Tanhshrink")
4745
def _tanhshrink_grad(op, grad):
48-
return _activation_ops_so.tanhshrink_grad(grad, op.inputs[0])
46+
return _activation_ops_so.addons_tanhshrink_grad(grad, op.inputs[0])

tensorflow_addons/custom_ops/activations/BUILD

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,6 @@ cc_binary(
9696
] + if_cuda_is_configured([
9797
":gelu_op_gpu",
9898
":hardshrink_op_gpu",
99-
":tanhshrink_op_gpu"
99+
":tanhshrink_op_gpu",
100100
]),
101101
)

tensorflow_addons/custom_ops/activations/cc/kernels/tanhshrink_op.cc

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -21,16 +21,18 @@ limitations under the License.
2121
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
2222

2323
namespace tensorflow {
24+
namespace addons {
2425

2526
using CPUDevice = Eigen::ThreadPoolDevice;
2627

27-
#define REGISTER_TANHSHRINK_KERNELS(type) \
28-
REGISTER_KERNEL_BUILDER( \
29-
Name("Tanhshrink").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
30-
TanhshrinkOp<CPUDevice, type>); \
31-
REGISTER_KERNEL_BUILDER( \
32-
Name("TanhshrinkGrad").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
33-
TanhshrinkGradOp<CPUDevice, type>);
28+
#define REGISTER_TANHSHRINK_KERNELS(type) \
29+
REGISTER_KERNEL_BUILDER( \
30+
Name("Addons>Tanhshrink").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
31+
TanhshrinkOp<CPUDevice, type>); \
32+
REGISTER_KERNEL_BUILDER(Name("Addons>TanhshrinkGrad") \
33+
.Device(DEVICE_CPU) \
34+
.TypeConstraint<type>("T"), \
35+
TanhshrinkGradOp<CPUDevice, type>);
3436

3537
TF_CALL_GPU_NUMBER_TYPES(REGISTER_TANHSHRINK_KERNELS);
3638
#undef REGISTER_TANHSHRINK_KERNELS
@@ -60,17 +62,19 @@ TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
6062
} // namespace functor
6163

6264
// Registration of the GPU implementations.
63-
#define REGISTER_TANHSHRINK_GPU_KERNELS(type) \
64-
REGISTER_KERNEL_BUILDER( \
65-
Name("Tanhshrink").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
66-
TanhshrinkOp<GPUDevice, type>); \
67-
REGISTER_KERNEL_BUILDER( \
68-
Name("TanhshrinkGrad").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
69-
TanhshrinkGradOp<GPUDevice, type>);
65+
#define REGISTER_TANHSHRINK_GPU_KERNELS(type) \
66+
REGISTER_KERNEL_BUILDER( \
67+
Name("Addons>Tanhshrink").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
68+
TanhshrinkOp<GPUDevice, type>); \
69+
REGISTER_KERNEL_BUILDER(Name("Addons>TanhshrinkGrad") \
70+
.Device(DEVICE_GPU) \
71+
.TypeConstraint<type>("T"), \
72+
TanhshrinkGradOp<GPUDevice, type>);
7073

7174
TF_CALL_GPU_NUMBER_TYPES(REGISTER_TANHSHRINK_GPU_KERNELS);
7275
#undef REGISTER_TANHSHRINK_GPU_KERNELS
7376

7477
#endif // GOOGLE_CUDA
7578

79+
} // namespace addons
7680
} // namespace tensorflow

tensorflow_addons/custom_ops/activations/cc/kernels/tanhshrink_op.h

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
1313
limitations under the License.
1414
==============================================================================*/
1515

16-
#ifndef TENSORFLOW_ADDONS_TANHSHRINK_OP_H_
17-
#define TENSORFLOW_ADDONS_TANHSHRINK_OP_H_
16+
#ifndef TENSORFLOW_ADDONS_ACTIVATIONS_KERNELS_TANHSHRINK_OP_H_
17+
#define TENSORFLOW_ADDONS_ACTIVATIONS_KERNELS_TANHSHRINK_OP_H_
1818

1919
#define EIGEN_USE_THREADS
2020

@@ -23,6 +23,7 @@ limitations under the License.
2323
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
2424

2525
namespace tensorflow {
26+
namespace addons {
2627
namespace functor {
2728

2829
template <typename Device, typename T>
@@ -88,7 +89,8 @@ void TanhshrinkGradOp<Device, T>::OperateNoTemplate(OpKernelContext* context,
8889
functor(context->eigen_device<Device>(), g.flat<T>(), a.flat<T>(),
8990
output->flat<T>());
9091
}
91-
}
92+
} // namespace addons
93+
} // namespace tensorflow
9294

9395
#undef EIGEN_USE_THREADS
9496

tensorflow_addons/custom_ops/activations/cc/kernels/tanhshrink_op_gpu.cu.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ limitations under the License.
2222
#include "third_party/eigen3/Eigen/Core"
2323

2424
namespace tensorflow {
25+
namespace addons {
2526

2627
using GPUDevice = Eigen::GpuDevice;
2728

@@ -31,6 +32,7 @@ using GPUDevice = Eigen::GpuDevice;
3132

3233
TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_KERNELS);
3334

35+
} // namespace addons
3436
} // namespace tensorflow
3537

3638
#endif // GOOGLE_CUDA

tensorflow_addons/custom_ops/activations/cc/ops/tanhshrink_op.cc

100644100755
Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,18 +18,20 @@ limitations under the License.
1818
#include "tensorflow/core/framework/shape_inference.h"
1919

2020
namespace tensorflow {
21+
namespace addons {
2122

22-
REGISTER_OP("Tanhshrink")
23+
REGISTER_OP("Addons>Tanhshrink")
2324
.Input("features: T")
2425
.Output("activations: T")
2526
.Attr("T: {half, float, double}")
2627
.SetShapeFn(shape_inference::UnchangedShape);
2728

28-
REGISTER_OP("TanhshrinkGrad")
29+
REGISTER_OP("Addons>TanhshrinkGrad")
2930
.Input("gradients: T")
3031
.Input("features: T")
3132
.Output("backprops: T")
3233
.Attr("T: {half, float, double}")
3334
.SetShapeFn(shape_inference::MergeBothInputsShapeFn);
3435

36+
} // namespace addons
3537
} // namespace tensorflow

0 commit comments

Comments
 (0)