Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 24 additions & 2 deletions tensorflow_addons/custom_ops/image/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,28 @@ package(default_visibility = ["//visibility:public"])
load("@local_config_tf//:build_defs.bzl", "D_GLIBCXX_USE_CXX11_ABI")
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda_is_configured", "if_cuda")

cc_library(
name = "distort_image_ops_gpu",
srcs = [
"cc/kernels/adjust_hsv_in_yiq_op.h",
"cc/kernels/adjust_hsv_in_yiq_op_gpu.cu.cc",
],
copts = if_cuda_is_configured([
"-DGOOGLE_CUDA=1",
"-x cuda",
"-nvcc_options=relaxed-constexpr",
"-nvcc_options=ftz=true",
]),
deps = [
"@local_config_tf//:libtensorflow_framework",
"@local_config_tf//:tf_header_lib",
] + if_cuda_is_configured([
"@local_config_cuda//cuda:cuda_libs",
"@local_config_cuda//cuda:cuda_headers",
]),
alwayslink = 1,
)

cc_binary(
name = "_distort_image_ops.so",
srcs = [
Expand All @@ -16,12 +38,12 @@ cc_binary(
"-pthread",
"-std=c++11",
D_GLIBCXX_USE_CXX11_ABI,
],
] + if_cuda(["-DGOOGLE_CUDA=1"]),
linkshared = 1,
deps = [
"@local_config_tf//:libtensorflow_framework",
"@local_config_tf//:tf_header_lib",
],
] + if_cuda_is_configured([":distort_image_ops_gpu"]),
)

cc_library(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,21 @@ limitations under the License.
#define EIGEN_USE_GPU

#include "tensorflow_addons/custom_ops/image/cc/kernels/adjust_hsv_in_yiq_op.h"
#include "tensorflow/core/kernels/gpu_utils.h"
#include "tensorflow/core/platform/stream_executor.h"
#include "tensorflow/core/util/cuda_kernel_helper.h"
#include "tensorflow/core/util/gpu_kernel_helper.h"

namespace tensorflow {

namespace {

template <typename T>
inline se::DeviceMemory<T> AsDeviceMemory(const T* cuda_memory, uint64 size) {
se::DeviceMemoryBase wrapped(const_cast<T*>(cuda_memory), size * sizeof(T));
se::DeviceMemory<T> typed(wrapped);
return typed;
}
} // namespace

Copy link
Member Author

@WindQAQ WindQAQ Aug 4, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These are copied from https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/kernels/gpu_utils.h#L37 because tensorflow/core/kernels/* are not included core TF package. Please correct me if I'm wrong.

namespace internal {

__global__ void compute_tranformation_matrix_cuda(const float* const delta_h,
Expand Down