8 changes: 4 additions & 4 deletions tensorflow_addons/activations/gelu.py
@@ -46,10 +46,10 @@ def gelu(x, approximate=True):
         A `Tensor`. Has the same type as `x`.
     """
     x = tf.convert_to_tensor(x)
-    return _activation_ops_so.gelu(x, approximate)
+    return _activation_ops_so.addons_gelu(x, approximate)


-@tf.RegisterGradient("Gelu")
+@tf.RegisterGradient("Addons>Gelu")
 def _gelu_grad(op, grad):
-    return _activation_ops_so.gelu_grad(grad, op.inputs[0],
-                                        op.get_attr("approximate"))
+    return _activation_ops_so.addons_gelu_grad(grad, op.inputs[0],
+                                               op.get_attr("approximate"))
10 changes: 5 additions & 5 deletions tensorflow_addons/activations/hardshrink.py
@@ -42,11 +42,11 @@ def hardshrink(x, lower=-1.0, upper=1.0):
         A `Tensor`. Has the same type as `x`.
     """
     x = tf.convert_to_tensor(x)
-    return _activation_ops_so.hardshrink(x, lower, upper)
+    return _activation_ops_so.addons_hardshrink(x, lower, upper)


-@tf.RegisterGradient("Hardshrink")
+@tf.RegisterGradient("Addons>Hardshrink")
 def _hardshrink_grad(op, grad):
-    return _activation_ops_so.hardshrink_grad(grad, op.inputs[0],
-                                              op.get_attr("lower"),
-                                              op.get_attr("upper"))
+    return _activation_ops_so.addons_hardshrink_grad(grad, op.inputs[0],
+                                                     op.get_attr("lower"),
+                                                     op.get_attr("upper"))
28 changes: 15 additions & 13 deletions tensorflow_addons/custom_ops/activations/cc/kernels/gelu_op.cc
@@ -21,15 +21,16 @@ limitations under the License.
 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"

 namespace tensorflow {
+namespace addons {

 using CPUDevice = Eigen::ThreadPoolDevice;

-#define REGISTER_GELU_KERNELS(type)                                   \
-  REGISTER_KERNEL_BUILDER(                                            \
-      Name("Gelu").Device(DEVICE_CPU).TypeConstraint<type>("T"),      \
-      GeluOp<CPUDevice, type>);                                       \
-  REGISTER_KERNEL_BUILDER(                                            \
-      Name("GeluGrad").Device(DEVICE_CPU).TypeConstraint<type>("T"),  \
+#define REGISTER_GELU_KERNELS(type)                                         \
+  REGISTER_KERNEL_BUILDER(                                                  \
+      Name("Addons>Gelu").Device(DEVICE_CPU).TypeConstraint<type>("T"),     \
+      GeluOp<CPUDevice, type>);                                             \
+  REGISTER_KERNEL_BUILDER(                                                  \
+      Name("Addons>GeluGrad").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
       GeluGradOp<CPUDevice, type>);

 // Gelu only makes sense with floating points.
@@ -61,17 +62,18 @@ TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
 }  // namespace functor

 // Registration of the GPU implementations.
-#define REGISTER_GELU_GPU_KERNELS(type)                               \
-  REGISTER_KERNEL_BUILDER(                                            \
-      Name("Gelu").Device(DEVICE_GPU).TypeConstraint<type>("T"),      \
-      GeluOp<GPUDevice, type>);                                       \
-  REGISTER_KERNEL_BUILDER(                                            \
-      Name("GeluGrad").Device(DEVICE_GPU).TypeConstraint<type>("T"),  \
+#define REGISTER_GELU_GPU_KERNELS(type)                                     \
+  REGISTER_KERNEL_BUILDER(                                                  \
+      Name("Addons>Gelu").Device(DEVICE_GPU).TypeConstraint<type>("T"),     \
+      GeluOp<GPUDevice, type>);                                             \
+  REGISTER_KERNEL_BUILDER(                                                  \
+      Name("Addons>GeluGrad").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
       GeluGradOp<GPUDevice, type>);

 TF_CALL_GPU_NUMBER_TYPES(REGISTER_GELU_GPU_KERNELS);
 #undef REGISTER_GELU_GPU_KERNELS

 #endif  // GOOGLE_CUDA

-}  // namespace tensorflow
+}  // end namespace addons
+}  // namespace tensorflow
8 changes: 5 additions & 3 deletions tensorflow_addons/custom_ops/activations/cc/kernels/gelu_op.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/

-#ifndef TENSORFLOW_ADDONS_GELU_OP_H_
-#define TENSORFLOW_ADDONS_GELU_OP_H_
+#ifndef TENSORFLOW_ADDONS_ACTIVATIONS_KERNELS_GELU_OP_H_
+#define TENSORFLOW_ADDONS_ACTIVATIONS_KERNELS_GELU_OP_H_

 #define EIGEN_USE_THREADS

@@ -23,6 +23,7 @@ limitations under the License.
 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"

 namespace tensorflow {
+namespace addons {
 namespace functor {

 // Functor used by GeluOp to do the computations.
@@ -137,8 +138,9 @@ void GeluGradOp<Device, T>::OperateNoTemplate(OpKernelContext* context,
                             approximate, output->flat<T>());
 }

+}  // end namespace addons
 }  // namespace tensorflow

 #undef EIGEN_USE_THREADS

-#endif  // TENSORFLOW_ADDONS_GELU_OP_H_
+#endif  // TENSORFLOW_ADDONS_ACTIVATIONS_KERNELS_GELU_OP_H_
@@ -22,6 +22,7 @@ limitations under the License.
 #include "third_party/eigen3/Eigen/Core"

 namespace tensorflow {
+namespace addons {

 using GPUDevice = Eigen::GpuDevice;

@@ -31,6 +32,7 @@ using GPUDevice = Eigen::GpuDevice;

 TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_KERNELS);

+}  // end namespace addons
 }  // namespace tensorflow

 #endif  // GOOGLE_CUDA
@@ -21,16 +21,18 @@ limitations under the License.
 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"

 namespace tensorflow {
+namespace addons {

 using CPUDevice = Eigen::ThreadPoolDevice;

-#define REGISTER_HARDSHRINK_KERNELS(type)                                  \
-  REGISTER_KERNEL_BUILDER(                                                 \
-      Name("Hardshrink").Device(DEVICE_CPU).TypeConstraint<type>("T"),     \
-      HardshrinkOp<CPUDevice, type>);                                      \
-  REGISTER_KERNEL_BUILDER(                                                 \
-      Name("HardshrinkGrad").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
-      HardshrinkGradOp<CPUDevice, type>);
+#define REGISTER_HARDSHRINK_KERNELS(type)                                     \
+  REGISTER_KERNEL_BUILDER(                                                    \
+      Name("Addons>Hardshrink").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
+      HardshrinkOp<CPUDevice, type>);                                         \
+  REGISTER_KERNEL_BUILDER(Name("Addons>HardshrinkGrad")                       \
+                              .Device(DEVICE_CPU)                             \
+                              .TypeConstraint<type>("T"),                     \
+                          HardshrinkGradOp<CPUDevice, type>);

 // Hardshrink only makes sense with floating points.
 TF_CALL_GPU_NUMBER_TYPES(REGISTER_HARDSHRINK_KERNELS);
@@ -61,17 +63,19 @@ TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
 }  // namespace functor

 // Registration of the GPU implementations.
-#define REGISTER_HARDSHRINK_GPU_KERNELS(type)                              \
-  REGISTER_KERNEL_BUILDER(                                                 \
-      Name("Hardshrink").Device(DEVICE_GPU).TypeConstraint<type>("T"),     \
-      HardshrinkOp<GPUDevice, type>);                                      \
-  REGISTER_KERNEL_BUILDER(                                                 \
-      Name("HardshrinkGrad").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
-      HardshrinkGradOp<GPUDevice, type>);
+#define REGISTER_HARDSHRINK_GPU_KERNELS(type)                                 \
+  REGISTER_KERNEL_BUILDER(                                                    \
+      Name("Addons>Hardshrink").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
+      HardshrinkOp<GPUDevice, type>);                                         \
+  REGISTER_KERNEL_BUILDER(Name("Addons>HardshrinkGrad")                       \
+                              .Device(DEVICE_GPU)                             \
+                              .TypeConstraint<type>("T"),                     \
+                          HardshrinkGradOp<GPUDevice, type>);

 TF_CALL_GPU_NUMBER_TYPES(REGISTER_HARDSHRINK_GPU_KERNELS);
 #undef REGISTER_HARDSHRINK_GPU_KERNELS

 #endif  // GOOGLE_CUDA

+}  // end namespace addons
 }  // namespace tensorflow
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/

-#ifndef TENSORFLOW_ADDONS_HARDSHRINK_OP_H_
-#define TENSORFLOW_ADDONS_HARDSHRINK_OP_H_
+#ifndef TENSORFLOW_ADDONS_ACTIVATIONS_KERNELS_HARDSHRINK_OP_H_
+#define TENSORFLOW_ADDONS_ACTIVATIONS_KERNELS_HARDSHRINK_OP_H_

 #define EIGEN_USE_THREADS

@@ -24,6 +24,8 @@ limitations under the License.
 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"

 namespace tensorflow {
+namespace addons {
+
 namespace functor {

 // Functor used by HardshrinkOp to do the computations.
@@ -134,8 +136,9 @@ void HardshrinkGradOp<Device, T>::OperateNoTemplate(OpKernelContext* context,
                                   upper, output->flat<T>());
 }

+}  // end namespace addons
 }  // namespace tensorflow

 #undef EIGEN_USE_THREADS

-#endif  // TENSORFLOW_ADDONS_HARDSHRINK_OP_H_
+#endif  // TENSORFLOW_ADDONS_ACTIVATIONS_KERNELS_HARDSHRINK_OP_H_
@@ -22,6 +22,7 @@ limitations under the License.
 #include "third_party/eigen3/Eigen/Core"

 namespace tensorflow {
+namespace addons {

 using GPUDevice = Eigen::GpuDevice;

@@ -31,6 +32,7 @@ using GPUDevice = Eigen::GpuDevice;

 TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_KERNELS);

+}  // end namespace addons
 }  // namespace tensorflow

 #endif  // GOOGLE_CUDA
8 changes: 5 additions & 3 deletions tensorflow_addons/custom_ops/activations/cc/ops/gelu_op.cc
@@ -18,20 +18,22 @@ limitations under the License.
 #include "tensorflow/core/framework/shape_inference.h"

 namespace tensorflow {
+namespace addons {

-REGISTER_OP("Gelu")
+REGISTER_OP("Addons>Gelu")
     .Input("features: T")
     .Output("activations: T")
     .Attr("T: {half, float, double}")
     .Attr("approximate: bool = true")
     .SetShapeFn(shape_inference::UnchangedShape);

-REGISTER_OP("GeluGrad")
+REGISTER_OP("Addons>GeluGrad")
     .Input("gradients: T")
     .Input("features: T")
     .Output("backprops: T")
     .Attr("T: {half, float, double}")
     .Attr("approximate: bool = true")
     .SetShapeFn(shape_inference::MergeBothInputsShapeFn);

-}  // namespace tensorflow
+}  // end namespace addons
+}  // namespace tensorflow
@@ -19,15 +19,15 @@ limitations under the License.

 namespace tensorflow {

-REGISTER_OP("Hardshrink")
+REGISTER_OP("Addons>Hardshrink")
     .Input("features: T")
     .Output("activations: T")
     .Attr("T: {half, float, double}")
     .Attr("lower: float = -1.0")
     .Attr("upper: float = 1.0")
     .SetShapeFn(shape_inference::UnchangedShape);

-REGISTER_OP("HardshrinkGrad")
+REGISTER_OP("Addons>HardshrinkGrad")
     .Input("gradients: T")
     .Input("features: T")
     .Output("backprops: T")
@@ -28,6 +28,7 @@ limitations under the License.
 #include "tensorflow_addons/custom_ops/image/cc/kernels/adjust_hsv_in_yiq_op.h"

 namespace tensorflow {
+namespace addons {

 typedef Eigen::ThreadPoolDevice CPUDevice;
 typedef Eigen::GpuDevice GPUDevice;
@@ -138,7 +139,7 @@ class AdjustHsvInYiqOp<CPUDevice> : public AdjustHsvInYiqOpBase {
 };

 REGISTER_KERNEL_BUILDER(
-    Name("AdjustHsvInYiq").Device(DEVICE_CPU).TypeConstraint<float>("T"),
+    Name("Addons>AdjustHsvInYiq").Device(DEVICE_CPU).TypeConstraint<float>("T"),
     AdjustHsvInYiqOp<CPUDevice>);

 #if GOOGLE_CUDA
@@ -162,8 +163,9 @@ class AdjustHsvInYiqOp<GPUDevice> : public AdjustHsvInYiqOpBase {
 };

 REGISTER_KERNEL_BUILDER(
-    Name("AdjustHsvInYiq").Device(DEVICE_GPU).TypeConstraint<float>("T"),
+    Name("Addons>AdjustHsvInYiq").Device(DEVICE_GPU).TypeConstraint<float>("T"),
     AdjustHsvInYiqOp<GPUDevice>);
 #endif

+}  // end namespace addons
 }  // namespace tensorflow
@@ -11,8 +11,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-#ifndef TENSORFLOW_CONTRIB_IMAGE_KERNELS_ADJUST_HSV_IN_YIQ_OP_H_
-#define TENSORFLOW_CONTRIB_IMAGE_KERNELS_ADJUST_HSV_IN_YIQ_OP_H_
+#ifndef TENSORFLOW_ADDONS_IMAGE_KERNELS_ADJUST_HSV_IN_YIQ_OP_H_
+#define TENSORFLOW_ADDONS_IMAGE_KERNELS_ADJUST_HSV_IN_YIQ_OP_H_

 #if GOOGLE_CUDA
 #define EIGEN_USE_GPU
@@ -27,6 +27,7 @@ limitations under the License.
 #include "tensorflow/core/framework/types.h"

 namespace tensorflow {
+namespace addons {

 static constexpr int kChannelSize = 3;

@@ -82,6 +83,7 @@ struct AdjustHsvInYiqGPU {

 #endif  // GOOGLE_CUDA

+}  // end namespace addons
 }  // namespace tensorflow

-#endif  // TENSORFLOW_CONTRIB_IMAGE_KERNELS_ADJUST_HSV_IN_YIQ_OP_H_
+#endif  // TENSORFLOW_ADDONS_IMAGE_KERNELS_ADJUST_HSV_IN_YIQ_OP_H_
@@ -20,6 +20,7 @@ limitations under the License.
 #include "tensorflow/core/util/gpu_kernel_helper.h"

 namespace tensorflow {
+namespace addons {

 namespace {

@@ -90,5 +91,7 @@ void AdjustHsvInYiqGPU::operator()(OpKernelContext* ctx, int channel_count,
   }
 }
 }  // namespace functor
+}  // end namespace addons
 }  // namespace tensorflow

 #endif  // GOOGLE_CUDA
@@ -25,10 +25,12 @@ limitations under the License.

 namespace tensorflow {

-using tensorflow::functor::BlockedImageUnionFindFunctor;
-using tensorflow::functor::FindRootFunctor;
-using tensorflow::functor::ImageConnectedComponentsFunctor;
-using tensorflow::functor::TensorRangeFunctor;
+namespace addons {
+
+using tensorflow::addons::functor::BlockedImageUnionFindFunctor;
+using tensorflow::addons::functor::FindRootFunctor;
+using tensorflow::addons::functor::ImageConnectedComponentsFunctor;
+using tensorflow::addons::functor::TensorRangeFunctor;

 using OutputType = typename BlockedImageUnionFindFunctor<bool>::OutputType;

@@ -119,10 +121,10 @@ struct ImageConnectedComponentsFunctor<CPUDevice, T> {

 }  // end namespace functor

-#define REGISTER_IMAGE_CONNECTED_COMPONENTS(TYPE)             \
-  REGISTER_KERNEL_BUILDER(Name("ImageConnectedComponents")    \
-                              .Device(DEVICE_CPU)             \
-                              .TypeConstraint<TYPE>("dtype"), \
+#define REGISTER_IMAGE_CONNECTED_COMPONENTS(TYPE)                  \
+  REGISTER_KERNEL_BUILDER(Name("Addons>ImageConnectedComponents")  \
+                              .Device(DEVICE_CPU)                  \
+                              .TypeConstraint<TYPE>("dtype"),      \
                           ImageConnectedComponents<CPUDevice, TYPE>)
 // Connected components (arguably) make sense for number, bool, and string types
 TF_CALL_NUMBER_TYPES(REGISTER_IMAGE_CONNECTED_COMPONENTS);
@@ -135,4 +137,5 @@ TF_CALL_string(REGISTER_IMAGE_CONNECTED_COMPONENTS);
 // shared memory in CUDA thread blocks, instead of starting with single-pixel
 // blocks).

+}  // end namespace addons
 }  // end namespace tensorflow