diff --git a/tensorflow_addons/custom_ops/seq2seq/BUILD b/tensorflow_addons/custom_ops/seq2seq/BUILD index aac484b93d..4f95a02373 100644 --- a/tensorflow_addons/custom_ops/seq2seq/BUILD +++ b/tensorflow_addons/custom_ops/seq2seq/BUILD @@ -3,23 +3,45 @@ licenses(["notice"]) # Apache 2.0 package(default_visibility = ["//visibility:public"]) load("@local_config_tf//:build_defs.bzl", "D_GLIBCXX_USE_CXX11_ABI") +load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda_is_configured", "if_cuda") cc_binary( name = "_beam_search_ops.so", srcs = [ "cc/kernels/beam_search_ops.cc", "cc/kernels/beam_search_ops.h", - # "cc/kernels/beam_search_ops_gpu.cu.cc", "cc/ops/beam_search_ops.cc", ], copts = [ "-pthread", "-std=c++11", D_GLIBCXX_USE_CXX11_ABI, - ], + ] + if_cuda(["-DGOOGLE_CUDA=1"]), linkshared = 1, deps = [ "@local_config_tf//:libtensorflow_framework", "@local_config_tf//:tf_header_lib", + ] + if_cuda_is_configured([":beam_search_ops_gpu"]), +) + +cc_library( + name = "beam_search_ops_gpu", + srcs = [ + "cc/kernels/beam_search_ops.h", + "cc/kernels/beam_search_ops_gpu.cu.cc", ], + copts = if_cuda_is_configured([ + "-DGOOGLE_CUDA=1", + "-x cuda", + "-nvcc_options=relaxed-constexpr", + "-nvcc_options=ftz=true", + ]), + deps = [ + "@local_config_tf//:libtensorflow_framework", + "@local_config_tf//:tf_header_lib", + ] + if_cuda_is_configured([ + "@local_config_cuda//cuda:cuda_libs", + "@local_config_cuda//cuda:cuda_headers", + ]), + alwayslink = 1, ) diff --git a/tensorflow_addons/custom_ops/seq2seq/cc/kernels/beam_search_ops.h b/tensorflow_addons/custom_ops/seq2seq/cc/kernels/beam_search_ops.h index e809b6f985..f11bb6cd8f 100644 --- a/tensorflow_addons/custom_ops/seq2seq/cc/kernels/beam_search_ops.h +++ b/tensorflow_addons/custom_ops/seq2seq/cc/kernels/beam_search_ops.h @@ -37,4 +37,4 @@ struct GatherTree { } // namespace functor } // namespace tensorflow -#endif // TENSORFLOW_CONTRIB_SEQ2SEQ_KERNELS_BEAM_SEARCH_OPS_H_ \ No newline at end of file +#endif // TENSORFLOW_CONTRIB_SEQ2SEQ_KERNELS_BEAM_SEARCH_OPS_H_ diff --git a/tensorflow_addons/custom_ops/seq2seq/cc/kernels/beam_search_ops_gpu.cu.cc b/tensorflow_addons/custom_ops/seq2seq/cc/kernels/beam_search_ops_gpu.cu.cc index bff67438aa..b6018be293 100644 --- a/tensorflow_addons/custom_ops/seq2seq/cc/kernels/beam_search_ops_gpu.cu.cc +++ b/tensorflow_addons/custom_ops/seq2seq/cc/kernels/beam_search_ops_gpu.cu.cc @@ -18,7 +18,7 @@ limitations under the License. #define EIGEN_USE_GPU #include "tensorflow_addons/custom_ops/seq2seq/cc/kernels/beam_search_ops.h" -#include "tensorflow/core/util/cuda_kernel_helper.h" +#include "tensorflow/core/util/gpu_kernel_helper.h" namespace tensorflow { namespace functor { diff --git a/tensorflow_addons/seq2seq/beam_search_ops_test.py b/tensorflow_addons/seq2seq/beam_search_ops_test.py index a8fd760d08..14e7621e2c 100644 --- a/tensorflow_addons/seq2seq/beam_search_ops_test.py +++ b/tensorflow_addons/seq2seq/beam_search_ops_test.py @@ -71,9 +71,6 @@ def testBadParentValuesOnCPU(self): self.evaluate(beams) def testBadParentValuesOnGPU(self): - # TODO: Fix #348 issue - self.skipTest('Wait #348 to be fixed') - # Only want to run this test on CUDA devices, as gather_tree is not # registered for SYCL devices. if not tf.test.is_gpu_available(cuda_only=True):