Skip to content

Commit 77f6872

Browse files
WindQAQseanpmorgan
authored andcommitted
build gpu kernel for distort ops (#394)
1 parent 695dc19 commit 77f6872

File tree

2 files changed

+35
-4
lines changed

2 files changed

+35
-4
lines changed

tensorflow_addons/custom_ops/image/BUILD

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,28 @@ package(default_visibility = ["//visibility:public"])
55
load("@local_config_tf//:build_defs.bzl", "D_GLIBCXX_USE_CXX11_ABI")
66
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda_is_configured", "if_cuda")
77

8+
cc_library(
9+
name = "distort_image_ops_gpu",
10+
srcs = [
11+
"cc/kernels/adjust_hsv_in_yiq_op.h",
12+
"cc/kernels/adjust_hsv_in_yiq_op_gpu.cu.cc",
13+
],
14+
copts = if_cuda_is_configured([
15+
"-DGOOGLE_CUDA=1",
16+
"-x cuda",
17+
"-nvcc_options=relaxed-constexpr",
18+
"-nvcc_options=ftz=true",
19+
]),
20+
deps = [
21+
"@local_config_tf//:libtensorflow_framework",
22+
"@local_config_tf//:tf_header_lib",
23+
] + if_cuda_is_configured([
24+
"@local_config_cuda//cuda:cuda_libs",
25+
"@local_config_cuda//cuda:cuda_headers",
26+
]),
27+
alwayslink = 1,
28+
)
29+
830
cc_binary(
931
name = "_distort_image_ops.so",
1032
srcs = [
@@ -16,12 +38,12 @@ cc_binary(
1638
"-pthread",
1739
"-std=c++11",
1840
D_GLIBCXX_USE_CXX11_ABI,
19-
],
41+
] + if_cuda(["-DGOOGLE_CUDA=1"]),
2042
linkshared = 1,
2143
deps = [
2244
"@local_config_tf//:libtensorflow_framework",
2345
"@local_config_tf//:tf_header_lib",
24-
],
46+
] + if_cuda_is_configured([":distort_image_ops_gpu"]),
2547
)
2648

2749
cc_library(

tensorflow_addons/custom_ops/image/cc/kernels/adjust_hsv_in_yiq_op_gpu.cu.cc

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,21 @@ limitations under the License.
1616
#define EIGEN_USE_GPU
1717

1818
#include "tensorflow_addons/custom_ops/image/cc/kernels/adjust_hsv_in_yiq_op.h"
19-
#include "tensorflow/core/kernels/gpu_utils.h"
2019
#include "tensorflow/core/platform/stream_executor.h"
21-
#include "tensorflow/core/util/cuda_kernel_helper.h"
20+
#include "tensorflow/core/util/gpu_kernel_helper.h"
2221

2322
namespace tensorflow {
2423

24+
namespace {
25+
26+
template <typename T>
27+
inline se::DeviceMemory<T> AsDeviceMemory(const T* cuda_memory, uint64 size) {
28+
se::DeviceMemoryBase wrapped(const_cast<T*>(cuda_memory), size * sizeof(T));
29+
se::DeviceMemory<T> typed(wrapped);
30+
return typed;
31+
}
32+
} // namespace
33+
2534
namespace internal {
2635

2736
__global__ void compute_tranformation_matrix_cuda(const float* const delta_h,

0 commit comments

Comments
 (0)