diff --git a/WORKSPACE b/WORKSPACE
index 24062f1570..103360fc5d 100644
--- a/WORKSPACE
+++ b/WORKSPACE
@@ -1,6 +1,19 @@
+load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
 load("//build_deps/tf_dependency:tf_configure.bzl", "tf_configure")
 load("//build_deps/gpu:cuda_configure.bzl", "cuda_configure")
+
+http_archive(
+    name = "cub_archive",
+    build_file = "//build_deps/gpu:cub.BUILD",
+    sha256 = "6bfa06ab52a650ae7ee6963143a0bbc667d6504822cbd9670369b598f18c58c3",
+    strip_prefix = "cub-1.8.0",
+    urls = [
+        "https://storage.googleapis.com/mirror.tensorflow.org/github.com/NVlabs/cub/archive/1.8.0.zip",
+        "https://github.com/NVlabs/cub/archive/1.8.0.zip",
+    ],
+)
+
 tf_configure(
     name = "local_config_tf",
 )
diff --git a/build_deps/gpu/cub.BUILD b/build_deps/gpu/cub.BUILD
new file mode 100644
index 0000000000..cdc9e4f377
--- /dev/null
+++ b/build_deps/gpu/cub.BUILD
@@ -0,0 +1,25 @@
+# Description: CUB library which is a set of primitives for GPU programming.
+
+load("@local_config_cuda//cuda:build_defs.bzl", "cuda_default_copts", "if_cuda")
+
+package(
+    default_visibility = ["//visibility:public"],
+)
+
+licenses(["notice"])  # BSD
+
+filegroup(
+    name = "cub_header_files",
+    srcs = glob([
+        "cub/**",
+    ]),
+)
+
+cc_library(
+    name = "cub",
+    hdrs = if_cuda([":cub_header_files"]),
+    include_prefix = "gpu",
+    deps = [
+        "@local_config_cuda//cuda:cuda_headers",
+    ],
+)
diff --git a/tensorflow_addons/custom_ops/README.md b/tensorflow_addons/custom_ops/README.md
index 522be99119..0c91c9e02c 100644
--- a/tensorflow_addons/custom_ops/README.md
+++ b/tensorflow_addons/custom_ops/README.md
@@ -6,3 +6,4 @@
 | Image | Ops for image manipulation |
 | Seq2seq | Ops for seq2seq encoder-decoder framework |
 | Text | Ops for text processing |
+| Layers | Ops for model layers |
diff --git a/tensorflow_addons/custom_ops/layers/BUILD b/tensorflow_addons/custom_ops/layers/BUILD
new file mode 100644
index 0000000000..ed0c567f59
--- /dev/null
+++ b/tensorflow_addons/custom_ops/layers/BUILD
@@ -0,0 +1,49 @@
+licenses(["notice"])  # Apache 2.0
+
+package(default_visibility = ["//visibility:public"])
+
+load("@local_config_tf//:build_defs.bzl", "D_GLIBCXX_USE_CXX11_ABI")
+load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda_is_configured", "if_cuda")
+
+cc_binary(
+    name = "_correlation_cost_ops.so",
+    srcs = [
+        "cc/kernels/correlation_cost_op.cc",
+        "cc/kernels/correlation_cost_op.h",
+        "cc/kernels/correlation_cost_op_gpu.cu.cc",
+        "cc/ops/correlation_cost_op.cc",
+    ],
+    copts = [
+        "-pthread",
+        "-std=c++11",
+        D_GLIBCXX_USE_CXX11_ABI,
+    ],
+    linkshared = 1,
+    deps = [
+        "@local_config_tf//:libtensorflow_framework",
+        "@local_config_tf//:tf_header_lib",
+    ] + if_cuda_is_configured([":correlation_cost_ops_gpu"]),
+)
+
+cc_library(
+    name = "correlation_cost_ops_gpu",
+    srcs = [
+        "cc/kernels/correlation_cost_op.h",
+        "cc/kernels/correlation_cost_op_gpu.cu.cc",
+    ],
+    copts = if_cuda_is_configured([
+        "-DGOOGLE_CUDA=1",
+        "-x cuda",
+        "-nvcc_options=relaxed-constexpr",
+        "-nvcc_options=ftz=true",
+    ]),
+    deps = [
+        "@local_config_tf//:libtensorflow_framework",
+        "@local_config_tf//:tf_header_lib",
+    ] + if_cuda_is_configured([
+        "@local_config_cuda//cuda:cuda_libs",
+        "@local_config_cuda//cuda:cuda_headers",
+        "@cub_archive//:cub",
+    ]),
+    alwayslink = 1,
+)
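
The `_correlation_cost_ops.so` target above is consumed at runtime through `tf.load_op_library`; a minimal sketch (it mirrors the loader call in `optical_flow.py` further down — the local name `_so` is just for illustration):

```python
import tensorflow as tf
from tensorflow_addons.utils.resource_loader import get_path_to_datafile

# Load the shared object built by the cc_binary target above; the generated
# Python wrappers for the registered ops hang off the returned module.
_so = tf.load_op_library(
    get_path_to_datafile("custom_ops/layers/_correlation_cost_ops.so"))

cost_op = _so.correlation_cost            # forward op
cost_grad_op = _so.correlation_cost_grad  # gradient op
```
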
diff --git a/tensorflow_addons/custom_ops/layers/cc/kernels/correlation_cost_op.cc b/tensorflow_addons/custom_ops/layers/cc/kernels/correlation_cost_op.cc
new file mode 100644
index 0000000000..e1f4b1cdbc
--- /dev/null
+++ b/tensorflow_addons/custom_ops/layers/cc/kernels/correlation_cost_op.cc
@@ -0,0 +1,349 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#define EIGEN_USE_THREADS
+
+#if GOOGLE_CUDA
+#define EIGEN_USE_GPU
+#endif  // GOOGLE_CUDA
+
+#include "tensorflow_addons/custom_ops/layers/cc/kernels/correlation_cost_op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/util/tensor_format.h"
+
+namespace tensorflow {
+
+typedef Eigen::ThreadPoolDevice CPUDevice;
+typedef Eigen::GpuDevice GPUDevice;
+
+namespace functor {
+
+template <typename Dtype>
+struct CorrelationCostFunctor<CPUDevice, Dtype> {
+  Status operator()(OpKernelContext* context, const Tensor& input_a_t,
+                    const Tensor& input_b_t, Tensor* output_t,
+                    /* params */
+                    int kernel_size, int max_displacement, int stride_1,
+                    int stride_2, int pad, TensorFormat data_format) {
+    const int32 oN = GetTensorDim(*output_t, FORMAT_NCHW, 'N');
+    // const int32 oC = GetTensorDim(*output_t, FORMAT_NCHW, 'C');
+    const int32 oH = GetTensorDim(*output_t, FORMAT_NCHW, 'H');
+    const int32 oW = GetTensorDim(*output_t, FORMAT_NCHW, 'W');
+    const int32 iH = GetTensorDim(input_a_t, data_format, 'H');
+    const int32 iW = GetTensorDim(input_a_t, data_format, 'W');
+    const int32 iC = GetTensorDim(input_a_t, data_format, 'C');
+
+    const int K = kernel_size * kernel_size * iC;
+
+    const auto input_a = input_a_t.tensor<Dtype, 4>();
+    const auto input_b = input_b_t.tensor<Dtype, 4>();
+    auto output = output_t->tensor<Dtype, 4>();
+    output.setZero();
+
+    const int kernel_rad = (kernel_size - 1) / 2;
+    const int displacement_rad = max_displacement / stride_2;
+    const int displacement_size = 2 * displacement_rad + 1;
+
+    const bool is_NCHW = (data_format == FORMAT_NCHW);
+
+    for (int n = 0; n < oN; ++n) {
+      for (int h = 0; h < oH; ++h) {
+        const int h1 = (h - pad) * stride_1 + max_displacement + kernel_rad;
+        for (int w = 0; w < oW; ++w) {
+          const int w1 = (w - pad) * stride_1 + max_displacement + kernel_rad;
+
+          for (int tj = -displacement_rad; tj <= displacement_rad; ++tj) {
+            for (int ti = -displacement_rad; ti <= displacement_rad; ++ti) {
+              const int tc = (tj + displacement_rad) * displacement_size +
+                             (ti + displacement_rad);
+
+              const int w2 = w1 + ti * stride_2;
+              const int h2 = h1 + tj * stride_2;
+
+              for (int j = -kernel_rad; j <= kernel_rad; ++j) {
+                // out-of-bound test
+                if ((h1 + j < 0) || (h1 + j >= iH)) continue;
+                if ((h2 + j < 0) || (h2 + j >= iH)) continue;
+                for (int i = -kernel_rad; i <= kernel_rad; ++i) {
+                  if ((w1 + i < 0) || (w1 + i >= iW)) continue;
+                  if ((w2 + i < 0) || (w2 + i >= iW)) continue;
+                  for (int c = 0; c < iC; ++c) {
+                    // eq. (1) in FlowNet: Learning Optical Flow with
+                    // Convolutional Networks
+                    if (is_NCHW) {
+                      output(n, tc, h, w) += input_a(n, c, h1 + j, w1 + i) *
+                                             input_b(n, c, h2 + j, w2 + i);
+                    } else {
+                      output(n, tc, h, w) += input_a(n, h1 + j, w1 + i, c) *
+                                             input_b(n, h2 + j, w2 + i, c);
+                    }
+                  }
+                }
+              }
+              output(n, tc, h, w) /= K;
+            }
+          }
+        }
+      }
+    }
+    return Status::OK();
+  }
+};
+
+template <typename Dtype>
+struct CorrelationCostGradFunctor<CPUDevice, Dtype> {
+  Status operator()(OpKernelContext* context, const Tensor& input_a_t,
+                    const Tensor& input_b_t, const Tensor& topdiff_t,
+                    Tensor* output_a_gradient_t, Tensor* output_b_gradient_t,
+                    /* params */
+                    int kernel_size, int max_displacement, int stride_1,
+                    int stride_2, int pad, TensorFormat data_format) {
+    const int32 iN = GetTensorDim(input_a_t, data_format, 'N');
+    const int32 iC = GetTensorDim(input_a_t, data_format, 'C');
+    const int32 iH = GetTensorDim(input_a_t, data_format, 'H');
+    const int32 iW = GetTensorDim(input_a_t, data_format, 'W');
+
+    // topdiff is NCHW
+    // const int32 oC = GetTensorDim(topdiff_t, FORMAT_NCHW, 'C');
+    const int32 oH = GetTensorDim(topdiff_t, FORMAT_NCHW, 'H');
+    const int32 oW = GetTensorDim(topdiff_t, FORMAT_NCHW, 'W');
+
+    const auto topdiff = topdiff_t.tensor<Dtype, 4>();
+    const auto input_a = input_a_t.tensor<Dtype, 4>();
+    const auto input_b = input_b_t.tensor<Dtype, 4>();
+    auto output_a_gradient = output_a_gradient_t->tensor<Dtype, 4>();
+    auto output_b_gradient = output_b_gradient_t->tensor<Dtype, 4>();
+    output_a_gradient.setZero();
+    output_b_gradient.setZero();
+
+    const int kernel_rad = (kernel_size - 1) / 2;
+    const int displacement_rad = max_displacement / stride_2;
+    const int displacement_size = 2 * displacement_rad + 1;
+    const int K = kernel_size * kernel_size * iC;
+
+    const bool is_NCHW = (data_format == FORMAT_NCHW);
+
+    for (int n = 0; n < iN; ++n) {
+      for (int h = 0; h < oH; ++h) {
+        const int h1 = (h - pad) * stride_1 + max_displacement + kernel_rad;
+        for (int w = 0; w < oW; ++w) {
+          const int w1 = (w - pad) * stride_1 + max_displacement + kernel_rad;
+
+          for (int tj = -displacement_rad; tj <= displacement_rad; ++tj) {
+            for (int ti = -displacement_rad; ti <= displacement_rad; ++ti) {
+              const int tc = (tj + displacement_rad) * displacement_size +
+                             (ti + displacement_rad);
+
+              const int w2 = w1 + ti * stride_2;
+              const int h2 = h1 + tj * stride_2;
+
+              for (int j = -kernel_rad; j <= kernel_rad; ++j) {
+                // out-of-bound test
+                if ((h1 + j < 0) || (h1 + j >= iH)) continue;
+                if ((h2 + j < 0) || (h2 + j >= iH)) continue;
+                for (int i = -kernel_rad; i <= kernel_rad; ++i) {
+                  if ((w1 + i < 0) || (w1 + i >= iW)) continue;
+                  if ((w2 + i < 0) || (w2 + i >= iW)) continue;
+                  for (int c = 0; c < iC; ++c) {
+                    // eq. (1) in FlowNet: Learning Optical Flow with
+                    // Convolutional Networks
+                    if (is_NCHW) {
+                      output_a_gradient(n, c, h1 + j, w1 + i) +=
+                          topdiff(n, tc, h, w) *
+                          input_b(n, c, h2 + j, w2 + i) / K;
+                      output_b_gradient(n, c, h2 + j, w2 + i) +=
+                          topdiff(n, tc, h, w) *
+                          input_a(n, c, h1 + j, w1 + i) / K;
+                    } else {
+                      output_a_gradient(n, h1 + j, w1 + i, c) +=
+                          topdiff(n, tc, h, w) *
+                          input_b(n, h2 + j, w2 + i, c) / K;
+                      output_b_gradient(n, h2 + j, w2 + i, c) +=
+                          topdiff(n, tc, h, w) *
+                          input_a(n, h1 + j, w1 + i, c) / K;
+                    }
+                  }
+                }
+              }
+            }
+          }
+        }
+      }
+    }
+    return Status::OK();
+  }
+};
+
+}  // end namespace functor
+
+template <typename Device, typename T>
+class CorrelationCostOp : public OpKernel {
+ public:
+  explicit CorrelationCostOp(OpKernelConstruction* context)
+      : OpKernel(context) {
+    OP_REQUIRES_OK(context, context->GetAttr("kernel_size", &kernel_size));
+    OP_REQUIRES_OK(context,
+                   context->GetAttr("max_displacement", &max_displacement));
+    OP_REQUIRES_OK(context, context->GetAttr("stride_1", &stride_1));
+    OP_REQUIRES_OK(context, context->GetAttr("stride_2", &stride_2));
+    OP_REQUIRES_OK(context, context->GetAttr("pad", &pad));
+    string data_format;
+    OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
+    OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
+                errors::InvalidArgument("Invalid data format"));
+    OP_REQUIRES(context, kernel_size % 2 != 0,
+                errors::InvalidArgument("kernel_size must be odd"));
+  }
+
+  void Compute(OpKernelContext* context) override {
+    const Tensor& input_a_t = context->input(0);
+    const Tensor& input_b_t = context->input(1);
+
+    // we didn't check the batch dimension during "SetShapeFn"
+    OP_REQUIRES(context, input_a_t.shape() == input_b_t.shape(),
+                errors::InvalidArgument("Input shapes have to be the same"));
+
+    const int32 N = GetTensorDim(input_a_t, data_format_, 'N');
+    const int32 H = GetTensorDim(input_a_t, data_format_, 'H');
+    const int32 W = GetTensorDim(input_a_t, data_format_, 'W');
+
+    // output channels are d**2 where d = 2r + 1
+    const int32 r = max_displacement / stride_2;
+    const int32 d = 2 * r + 1;
+    const int32 border = max_displacement + (kernel_size - 1) / 2;
+
+    const int32 Cout = d * d;
+    const int32 Hout = static_cast<int>(
+        ceil(static_cast<float>(((H + 2 * pad) - border * 2)) /
+             static_cast<float>(stride_1)));
+    const int32 Wout = static_cast<int>(
+        ceil(static_cast<float>(((W + 2 * pad) - border * 2)) /
+             static_cast<float>(stride_1)));
+
+    OP_REQUIRES(context, Hout >= 1,
+                errors::InvalidArgument(
+                    "Neighborhood and kernel don't fit in input height."));
+    OP_REQUIRES(context, Wout >= 1,
+                errors::InvalidArgument(
+                    "Neighborhood and kernel don't fit in input width."));
+
+    Tensor* output_t;
+    OP_REQUIRES_OK(context,
+                   context->allocate_output(
+                       0, TensorShape({N, Cout, Hout, Wout}), &output_t));
+
+    functor::CorrelationCostFunctor<Device, T> correlationCostFunc;
+    Status s = correlationCostFunc(context, input_a_t, input_b_t, output_t,
+                                   /* params */
+                                   kernel_size, max_displacement, stride_1,
+                                   stride_2, pad, data_format_);
+
+    OP_REQUIRES_OK(context, s);
+  }
+
+ private:
+  int kernel_size;
+  int max_displacement;
+  int stride_1;
+  int stride_2;
+  int pad;
+  TensorFormat data_format_;
+};
+
+template <typename Device, typename T>
+class CorrelationCostGradOp : public OpKernel {
+ public:
+  explicit CorrelationCostGradOp(OpKernelConstruction* context)
+      : OpKernel(context) {
+    OP_REQUIRES_OK(context, context->GetAttr("kernel_size", &kernel_size));
+    OP_REQUIRES_OK(context,
+                   context->GetAttr("max_displacement", &max_displacement));
+    OP_REQUIRES_OK(context, context->GetAttr("stride_1", &stride_1));
+    OP_REQUIRES_OK(context, context->GetAttr("stride_2", &stride_2));
+    OP_REQUIRES_OK(context, context->GetAttr("pad", &pad));
+    string data_format;
+    OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
+    OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
+                errors::InvalidArgument("Invalid data format"));
+    OP_REQUIRES(context, kernel_size % 2 != 0,
+                errors::InvalidArgument("kernel_size must be odd"));
+  }
+
+  void Compute(OpKernelContext* context) override {
+    const Tensor& input_a_t = context->input(0);
+    const Tensor& input_b_t = context->input(1);
+    const Tensor& topdiff_t = context->input(2);
+
+    OP_REQUIRES(context, input_a_t.shape() == input_b_t.shape(),
+                errors::InvalidArgument("Input shapes have to be the same"));
+
+    // Allocate the memory for the bottom diffs
+    Tensor* output_a_gradient_t;
+    OP_REQUIRES_OK(context, context->allocate_output(0, input_a_t.shape(),
+                                                     &output_a_gradient_t));
+    Tensor* output_b_gradient_t;
+    OP_REQUIRES_OK(context, context->allocate_output(1, input_b_t.shape(),
+                                                     &output_b_gradient_t));
+
+    functor::CorrelationCostGradFunctor<Device, T> correlationCostGrad;
+    Status s = correlationCostGrad(context, input_a_t, input_b_t, topdiff_t,
+                                   output_a_gradient_t, output_b_gradient_t,
+                                   /* params */
+                                   kernel_size, max_displacement, stride_1,
+                                   stride_2, pad, data_format_);
+
+    OP_REQUIRES_OK(context, s);
+  }
+
+ private:
+  int kernel_size;
+  int max_displacement;
+  int stride_1;
+  int stride_2;
+  int pad;
+  TensorFormat data_format_;
+};
+
+// Register the CPU kernels.
+#define REGISTER_CORRELATIONCOST_OP_CPU(T)                                   \
+  REGISTER_KERNEL_BUILDER(                                                   \
+      Name("CorrelationCost").Device(DEVICE_CPU).TypeConstraint<T>("T"),     \
+      CorrelationCostOp<CPUDevice, T>)                                       \
+  REGISTER_KERNEL_BUILDER(                                                   \
+      Name("CorrelationCostGrad").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
+      CorrelationCostGradOp<CPUDevice, T>)
+
+TF_CALL_float(REGISTER_CORRELATIONCOST_OP_CPU);
+#undef REGISTER_CORRELATIONCOST_OP_CPU
+
+// Register the GPU kernels.
+#ifdef GOOGLE_CUDA
+
+#define REGISTER_CORRELATIONCOST_OP_GPU(T)                                   \
+  REGISTER_KERNEL_BUILDER(                                                   \
+      Name("CorrelationCost").Device(DEVICE_GPU).TypeConstraint<T>("T"),     \
+      CorrelationCostOp<GPUDevice, T>)                                       \
+  REGISTER_KERNEL_BUILDER(                                                   \
+      Name("CorrelationCostGrad").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
+      CorrelationCostGradOp<GPUDevice, T>)
+
+TF_CALL_float(REGISTER_CORRELATIONCOST_OP_GPU);
+#undef REGISTER_CORRELATIONCOST_OP_GPU
+
+#endif  // GOOGLE_CUDA
+
+}  // end namespace tensorflow
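
For readers, a NumPy transcription of the CPU functor's loop nest above (NCHW path) — a reference sketch for illustration only, not part of this diff; `correlation_cost_nchw` is a made-up name:

```python
import numpy as np

def correlation_cost_nchw(a, b, kernel_size, max_displacement, stride_1,
                          stride_2, pad):
    """Reference for the CPU functor; a, b are float arrays of shape NCHW."""
    n_, c_, h_, w_ = a.shape
    kernel_rad = (kernel_size - 1) // 2
    r = max_displacement // stride_2          # displacement_rad
    d = 2 * r + 1                             # displacement_size
    border = max_displacement + kernel_rad
    out_h = int(np.ceil(((h_ + 2 * pad) - 2 * border) / stride_1))
    out_w = int(np.ceil(((w_ + 2 * pad) - 2 * border) / stride_1))
    k = kernel_size * kernel_size * c_
    out = np.zeros((n_, d * d, out_h, out_w), a.dtype)
    for n in range(n_):
        for h in range(out_h):
            h1 = (h - pad) * stride_1 + max_displacement + kernel_rad
            for w in range(out_w):
                w1 = (w - pad) * stride_1 + max_displacement + kernel_rad
                for tj in range(-r, r + 1):
                    for ti in range(-r, r + 1):
                        tc = (tj + r) * d + (ti + r)
                        h2, w2 = h1 + tj * stride_2, w1 + ti * stride_2
                        acc = 0.0
                        for j in range(-kernel_rad, kernel_rad + 1):
                            # out-of-bound test, as in the C++ loop
                            if not (0 <= h1 + j < h_ and 0 <= h2 + j < h_):
                                continue
                            for i in range(-kernel_rad, kernel_rad + 1):
                                if not (0 <= w1 + i < w_ and 0 <= w2 + i < w_):
                                    continue
                                acc += np.dot(a[n, :, h1 + j, w1 + i],
                                              b[n, :, h2 + j, w2 + i])
                        out[n, tc, h, w] = acc / k
    return out
```

With `kernel_size=1`, `max_displacement=2`, `stride_1=1`, `stride_2=2`, `pad=4` this reproduces the 9-channel cost volume exercised by the tests further down.
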
diff --git a/tensorflow_addons/custom_ops/layers/cc/kernels/correlation_cost_op.h b/tensorflow_addons/custom_ops/layers/cc/kernels/correlation_cost_op.h
new file mode 100644
index 0000000000..056c0cfc64
--- /dev/null
+++ b/tensorflow_addons/custom_ops/layers/cc/kernels/correlation_cost_op.h
@@ -0,0 +1,47 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORRELATION_COST_OP_H_
+#define TENSORFLOW_CORRELATION_COST_OP_H_
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/util/tensor_format.h"
+
+namespace tensorflow {
+namespace functor {
+
+template <typename Device, typename Dtype>
+struct CorrelationCostFunctor {
+  Status operator()(OpKernelContext* context, const Tensor& input_a_t,
+                    const Tensor& input_b_t, Tensor* output_t,
+                    /* params */
+                    int kernel_size, int max_displacement, int stride_1,
+                    int stride_2, int pad, TensorFormat data_format);
+};
+
+template <typename Device, typename Dtype>
+struct CorrelationCostGradFunctor {
+  Status operator()(OpKernelContext* context, const Tensor& input_a_t,
+                    const Tensor& input_b_t, const Tensor& topdiff_t,
+                    Tensor* output_a_gradient_t, Tensor* output_b_gradient_t,
+                    /* params */
+                    int kernel_size, int max_displacement, int stride_1,
+                    int stride_2, int pad, TensorFormat data_format);
+};
+
+}  // end namespace functor
+}  // end namespace tensorflow
+
+#endif  // TENSORFLOW_CORRELATION_COST_OP_H_
diff --git a/tensorflow_addons/custom_ops/layers/cc/kernels/correlation_cost_op_gpu.cu.cc b/tensorflow_addons/custom_ops/layers/cc/kernels/correlation_cost_op_gpu.cu.cc
new file mode 100644
index 0000000000..27d3375043
--- /dev/null
+++ b/tensorflow_addons/custom_ops/layers/cc/kernels/correlation_cost_op_gpu.cu.cc
@@ -0,0 +1,475 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#if GOOGLE_CUDA
+
+#define EIGEN_USE_GPU
+
+#include "tensorflow_addons/custom_ops/layers/cc/kernels/correlation_cost_op.h"
+#include "gpu/cub/device/device_reduce.cuh"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/util/tensor_format.h"
+
+namespace tensorflow {
+
+typedef Eigen::GpuDevice GPUDevice;
+
+namespace functor {
+
+namespace {
+
+/* There are two ways to implement the correlation layer:
+   - pad first and then compute the cross-correlation costs (faster)
+   - use if-else guards in the computation of the cross-correlation costs
+
+This implementation is inspired by
+https://github.com/NVIDIA/flownet2-pytorch
+*/
+
+template <unsigned int THREADS_PER_BLOCK>
+__global__ void pad_and_transpose(const float *input, float *output, int C,
+                                  int H, int W, int P) {
+  // NCHW -> pad(NHWC)
+  const int n = blockIdx.x;
+  const int h = blockIdx.y;
+  const int w = blockIdx.z;
+  const int c0 = threadIdx.x;
+  const int pW = (W + 2 * P);
+  const int pH = (H + 2 * P);
+
+  float value;
+  for (int c = c0; c < C; c += THREADS_PER_BLOCK) {
+    value = input[n * (C * H * W) + c * (H * W) + h * W + w];
+    output[n * (C * pH * pW) + (h + P) * (pW * C) + (w + P) * C + c] = value;
+  }
+}
+
+template <unsigned int THREADS_PER_BLOCK>
+__global__ void pad_and_no_transpose(const float *input, float *output, int C,
+                                     int H, int W, int P) {
+  // NHWC -> pad(NHWC)
+  const int n = blockIdx.x;
+  const int h = blockIdx.y;
+  const int w = blockIdx.z;
+  const int c0 = threadIdx.x;
+  const int pW = (W + 2 * P);
+  const int pH = (H + 2 * P);
+
+  float value;
+  for (int c = c0; c < C; c += THREADS_PER_BLOCK) {
+    value = input[n * (C * H * W) + h * (W * C) + w * C + c];
+    output[n * (C * pH * pW) + (h + P) * (pW * C) + (w + P) * C + c] = value;
+  }
+}
+
+template <unsigned int THREADS_PER_BLOCK>
+__global__ void Correlation_forward(float *output, int Cout, int Hout,
+                                    int Wout, float *pInput1, int Cin, int Hin,
+                                    int Win, float *pInput2, int pad,
+                                    int kernel_size, int max_displacement,
+                                    int stride1, int stride2) {
+  const int pWin = Win + 2 * pad;
+  const int pHin = Hin + 2 * pad;
+
+  const int kernel_rad = (kernel_size - 1) / 2;
+  const int displacement_rad = max_displacement / stride2;
+  const int displacement_size = 2 * displacement_rad + 1;
+
+  const int n = blockIdx.x;
+  const int h1 = blockIdx.y * stride1 + max_displacement + kernel_rad;
+  const int w1 = blockIdx.z * stride1 + max_displacement + kernel_rad;
+  const int c = threadIdx.x;
+
+  const int K = kernel_size * kernel_size * Cin;
+
+  typedef cub::WarpReduce<float> WarpReduce;
+  __shared__ typename WarpReduce::TempStorage temp_sum_storage;
+  float thread_accumulation = 0;
+
+  for (int tj = -displacement_rad; tj <= displacement_rad; ++tj) {
+    for (int ti = -displacement_rad; ti <= displacement_rad; ++ti) {
+      thread_accumulation = 0;
+      int w2 = w1 + ti * stride2;
+      int h2 = h1 + tj * stride2;
+
+      for (int j = -kernel_rad; j <= kernel_rad; ++j) {
+        for (int i = -kernel_rad; i <= kernel_rad; ++i) {
+          for (int ch = c; ch < Cin; ch += THREADS_PER_BLOCK) {
+            const int indx1 = n * (pHin * pWin * Cin) +
+                              (h1 + j) * (pWin * Cin) + (w1 + i) * Cin + ch;
+            const int indx2 = n * (pHin * pWin * Cin) +
+                              (h2 + j) * (pWin * Cin) + (w2 + i) * Cin + ch;
+            thread_accumulation += pInput1[indx1] * pInput2[indx2];
+          }
+        }
+      }
+      __syncthreads();
+
+      // THREADS_PER_BLOCK == 32, hence there is only one warp per block
+      const float reduce_sum =
+          WarpReduce(temp_sum_storage).Sum(thread_accumulation);
+      if (c == 0) {
+        const int tc = (tj + displacement_rad) * displacement_size +
+                       (ti + displacement_rad);
+        const int tindx = n * (Cout * Hout * Wout) + tc * (Hout * Wout) +
+                          blockIdx.y * Wout + blockIdx.z;
+        output[tindx] = reduce_sum / K;
+      }
+    }
+  }
+}
+
+template <unsigned int THREADS_PER_BLOCK>
+__global__ void Correlation_backward_input1(
+    int item, float *gradInput1, int Cin, int Hin, int Win,
+    const float *gradOutput, int Cout, int Hout, int Wout,
+    const float *rInput2, int pad_size, int kernel_size, int max_displacement,
+    int stride1, int stride2, bool is_NCHW) {
+  const int n = item;
+  const int h = blockIdx.x * stride1 + pad_size;
+  const int w = blockIdx.y * stride1 + pad_size;
+  const int c = blockIdx.z;
+  const int t0 = threadIdx.x;
+
+  const int kernel_rad = (kernel_size - 1) / 2;
+  const int displacement_rad = max_displacement / stride2;
+  const int displacement_size = 2 * displacement_rad + 1;
+
+  int Wmin = (w - kernel_rad - max_displacement) / stride1;
+  int Hmin = (h - kernel_rad - max_displacement) / stride1;
+
+  int Wmax = (w + kernel_rad - max_displacement) / stride1;
+  int Hmax = (h + kernel_rad - max_displacement) / stride1;
+
+  if (Wmax < 0 || Hmax < 0 || Wmin >= Wout || Hmin >= Hout) {
+    // assumes gradInput1 is pre-allocated and zero filled
+    return;
+  }
+
+  if (Wmin > Wmax || Hmin > Hmax) {
+    // assumes gradInput1 is pre-allocated and zero filled
+    return;
+  }
+
+  Wmin = max(0, Wmin);
+  Wmax = min(Wout - 1, Wmax);
+
+  Hmin = max(0, Hmin);
+  Hmax = min(Hout - 1, Hmax);
+
+  const int pWin = Win + 2 * pad_size;
+  const int pHin = Hin + 2 * pad_size;
+  const float nelems = kernel_size * kernel_size * Cin;
+
+  typedef cub::WarpReduce<float> WarpReduce;
+  __shared__ typename WarpReduce::TempStorage temp_sum_storage;
+  float thread_accumulation = 0;
+
+  for (int tc = t0; tc < Cout; tc += THREADS_PER_BLOCK) {
+    int i2 = (tc % displacement_size - displacement_rad) * stride2;
+    int j2 = (tc / displacement_size - displacement_rad) * stride2;
+
+    int indx2 =
+        n * (pHin * pWin * Cin) + (h + j2) * (pWin * Cin) + (w + i2) * Cin + c;
+
+    float val2 = rInput2[indx2];
+
+    for (int j = Hmin; j <= Hmax; ++j) {
+      for (int i = Wmin; i <= Wmax; ++i) {
+        const int tindx =
+            n * (Cout * Hout * Wout) + tc * (Hout * Wout) + j * Wout + i;
+        thread_accumulation += gradOutput[tindx] * val2;
+      }
+    }
+  }
+  __syncthreads();
+
+  // THREADS_PER_BLOCK == 32, hence there is only one warp per block
+  const float reduce_sum =
+      WarpReduce(temp_sum_storage).Sum(thread_accumulation);
+  if (t0 == 0) {
+    if (is_NCHW) {
+      const int indx1 = n * (Cin * Hin * Win) + c * (Hin * Win) +
+                        (h - pad_size) * Win + (w - pad_size);
+      gradInput1[indx1] = reduce_sum / nelems;
+    } else {
+      const int indx1 = n * (Cin * Hin * Win) + (h - pad_size) * (Win * Cin) +
+                        (w - pad_size) * Cin + c;
+      gradInput1[indx1] = reduce_sum / nelems;
+    }
+  }
+}
+
+template <unsigned int THREADS_PER_BLOCK>
+__global__ void Correlation_backward_input2(
+    int item, float *gradInput2, int Cin, int Hin, int Win,
+    const float *gradOutput, int Cout, int Hout, int Wout,
+    const float *rInput1, int pad_size, int kernel_size, int max_displacement,
+    int stride1, int stride2, bool is_NCHW) {
+  const int n = item;
+  const int h = blockIdx.x * stride1 + pad_size;
+  const int w = blockIdx.y * stride1 + pad_size;
+  const int c = blockIdx.z;
+  const int t0 = threadIdx.x;
+
+  const int kernel_rad = (kernel_size - 1) / 2;
+  const int displacement_rad = max_displacement / stride2;
+  const int displacement_size = 2 * displacement_rad + 1;
+
+  const int pWin = Win + 2 * pad_size;
+  const int pHin = Hin + 2 * pad_size;
+  const float nelems = kernel_size * kernel_size * Cin;
+
+  typedef cub::WarpReduce<float> WarpReduce;
+  __shared__ typename WarpReduce::TempStorage temp_sum_storage;
+  float thread_accumulation = 0;
+
+  for (int tc = t0; tc < Cout; tc += THREADS_PER_BLOCK) {
+    const int i2 = (tc % displacement_size - displacement_rad) * stride2;
+    const int j2 = (tc / displacement_size - displacement_rad) * stride2;
+
+    int Wmin = (w - kernel_rad - max_displacement - i2) / stride1;
+    int Hmin = (h - kernel_rad - max_displacement - j2) / stride1;
+
+    int Wmax = (w + kernel_rad - max_displacement - i2) / stride1;
+    int Hmax = (h + kernel_rad - max_displacement - j2) / stride1;
+
+    if (Wmax < 0 || Hmax < 0 || Wmin >= Wout || Hmin >= Hout) {
+      // assumes gradInput2 is pre-allocated and zero filled
+      continue;
+    }
+
+    if (Wmin > Wmax || Hmin > Hmax) {
+      // assumes gradInput2 is pre-allocated and zero filled
+      continue;
+    }
+
+    Wmin = max(0, Wmin);
+    Wmax = min(Wout - 1, Wmax);
+
+    Hmin = max(0, Hmin);
+    Hmax = min(Hout - 1, Hmax);
+
+    const int indx1 =
+        n * (pHin * pWin * Cin) + (h - j2) * (pWin * Cin) + (w - i2) * Cin + c;
+    const float val1 = rInput1[indx1];
+
+    for (int j = Hmin; j <= Hmax; ++j) {
+      for (int i = Wmin; i <= Wmax; ++i) {
+        const int tindx =
+            n * (Cout * Hout * Wout) + tc * (Hout * Wout) + j * Wout + i;
+        thread_accumulation += gradOutput[tindx] * val1;
+      }
+    }
+  }
+  __syncthreads();
+
+  const float reduce_sum =
+      WarpReduce(temp_sum_storage).Sum(thread_accumulation);
+  if (t0 == 0) {
+    if (is_NCHW) {
+      const int indx2 = n * (Cin * Hin * Win) + c * (Hin * Win) +
+                        (h - pad_size) * Win + (w - pad_size);
+      gradInput2[indx2] = reduce_sum / nelems;
+    } else {
+      const int indx2 = n * (Cin * Hin * Win) + (h - pad_size) * (Win * Cin) +
+                        (w - pad_size) * Cin + c;
+      gradInput2[indx2] = reduce_sum / nelems;
+    }
+  }
+}
+
+}  // namespace
+
+template <typename Dtype>
+struct CorrelationCostFunctor<GPUDevice, Dtype> {
+  Status operator()(OpKernelContext *context, const Tensor &input_a_t,
+                    const Tensor &input_b_t, Tensor *output_t,
+                    /* params */
+                    int kernel_size, int max_displacement, int stride_1,
+                    int stride_2, int pad, TensorFormat data_format) {
+    // do not change: the CUDA kernels expect THREADS_PER_BLOCK == 32
+    const int THREADS_PER_BLOCK = 32;
+
+    const int32 N = GetTensorDim(input_a_t, data_format, 'N');
+    const int32 iC = GetTensorDim(input_a_t, data_format, 'C');
+    const int32 iH = GetTensorDim(input_a_t, data_format, 'H');
+    const int32 iW = GetTensorDim(input_a_t, data_format, 'W');
+
+    Tensor padded_a_t;
+    Tensor padded_b_t;
+    TensorShape padded_shape({N, iH + 2 * pad, iW + 2 * pad, iC});
+    Status s;
+    s = context->allocate_temp(DataTypeToEnum<Dtype>::value, padded_shape,
+                               &padded_a_t);
+    if (!TF_PREDICT_TRUE(s.ok())) {
+      return s;
+    }
+    s = context->allocate_temp(DataTypeToEnum<Dtype>::value, padded_shape,
+                               &padded_b_t);
+    if (!TF_PREDICT_TRUE(s.ok())) {
+      return s;
+    }
+
+    dim3 blocks_grid(N, iH, iW);
+    dim3 threads_block(THREADS_PER_BLOCK);
+
+    // the output is always NCHW (python transposes it to NHWC)
+    const int32 oC = GetTensorDim(*output_t, FORMAT_NCHW, 'C');
+    const int32 oH = GetTensorDim(*output_t, FORMAT_NCHW, 'H');
+    const int32 oW = GetTensorDim(*output_t, FORMAT_NCHW, 'W');
+
+    // set everything to zero (we zero-pad)
+    cudaMemset(padded_a_t.flat<Dtype>().data(), 0,
+               padded_a_t.NumElements() * sizeof(Dtype));
+    cudaMemset(padded_b_t.flat<Dtype>().data(), 0,
+               padded_b_t.NumElements() * sizeof(Dtype));
+    cudaMemset(output_t->flat<Dtype>().data(), 0,
+               output_t->NumElements() * sizeof(Dtype));
+
+    const bool is_NCHW = (data_format == FORMAT_NCHW);
+    if (is_NCHW) {
+      pad_and_transpose<THREADS_PER_BLOCK><<<blocks_grid, threads_block>>>(
+          input_a_t.flat<Dtype>().data(), padded_a_t.flat<Dtype>().data(), iC,
+          iH, iW, pad);
+      pad_and_transpose<THREADS_PER_BLOCK><<<blocks_grid, threads_block>>>(
+          input_b_t.flat<Dtype>().data(), padded_b_t.flat<Dtype>().data(), iC,
+          iH, iW, pad);
+    } else {
+      pad_and_no_transpose<THREADS_PER_BLOCK><<<blocks_grid, threads_block>>>(
+          input_a_t.flat<Dtype>().data(), padded_a_t.flat<Dtype>().data(), iC,
+          iH, iW, pad);
+      pad_and_no_transpose<THREADS_PER_BLOCK><<<blocks_grid, threads_block>>>(
+          input_b_t.flat<Dtype>().data(), padded_b_t.flat<Dtype>().data(), iC,
+          iH, iW, pad);
+    }
+
+    const GPUDevice &d = context->eigen_gpu_device();
+
+    dim3 threadsPerBlock(THREADS_PER_BLOCK);
+    dim3 totalBlocksCorr(N, oH, oW);
+
+    Correlation_forward<THREADS_PER_BLOCK>
+        <<<totalBlocksCorr, threadsPerBlock, 0, d.stream()>>>(
+            output_t->flat<Dtype>().data(), oC, oH, oW,
+            padded_a_t.flat<Dtype>().data(), iC, iH, iW,
+            padded_b_t.flat<Dtype>().data(), pad, kernel_size,
+            max_displacement, stride_1, stride_2);
+
+    return Status::OK();
+  }
+};
+
+template <typename Dtype>
+struct CorrelationCostGradFunctor<GPUDevice, Dtype> {
+  Status operator()(OpKernelContext *context, const Tensor &input_a_t,
+                    const Tensor &input_b_t, const Tensor &topdiff_t,
+                    Tensor *output_a_gradient_t, Tensor *output_b_gradient_t,
+                    /* params */
+                    int kernel_size, int max_displacement, int stride_1,
+                    int stride_2, int pad, TensorFormat data_format) {
+    // do not change: the CUDA kernels expect THREADS_PER_BLOCK == 32
+    const int THREADS_PER_BLOCK = 32;
+
+    const int32 N = GetTensorDim(input_a_t, data_format, 'N');
+    const int32 iC = GetTensorDim(input_a_t, data_format, 'C');
+    const int32 iH = GetTensorDim(input_a_t, data_format, 'H');
+    const int32 iW = GetTensorDim(input_a_t, data_format, 'W');
+
+    Tensor padded_a_t;
+    Tensor padded_b_t;
+    TensorShape padded_shape({N, iH + 2 * pad, iW + 2 * pad, iC});
+    Status s;
+    s = context->allocate_temp(DataTypeToEnum<Dtype>::value, padded_shape,
+                               &padded_a_t);
+    if (!TF_PREDICT_TRUE(s.ok())) {
+      return s;
+    }
+    s = context->allocate_temp(DataTypeToEnum<Dtype>::value, padded_shape,
+                               &padded_b_t);
+    if (!TF_PREDICT_TRUE(s.ok())) {
+      return s;
+    }
+
+    dim3 blocks_grid(N, iH, iW);
+    dim3 threads_block(THREADS_PER_BLOCK);
+
+    // topdiff is NCHW
+    const int32 oC = GetTensorDim(topdiff_t, FORMAT_NCHW, 'C');
+    const int32 oH = GetTensorDim(topdiff_t, FORMAT_NCHW, 'H');
+    const int32 oW = GetTensorDim(topdiff_t, FORMAT_NCHW, 'W');
+
+    // set everything to zero (we zero-pad)
+    cudaMemset(padded_a_t.flat<Dtype>().data(), 0,
+               padded_a_t.NumElements() * sizeof(Dtype));
+    cudaMemset(padded_b_t.flat<Dtype>().data(), 0,
+               padded_b_t.NumElements() * sizeof(Dtype));
+    cudaMemset(output_a_gradient_t->flat<Dtype>().data(), 0,
+               output_a_gradient_t->NumElements() * sizeof(Dtype));
+    cudaMemset(output_b_gradient_t->flat<Dtype>().data(), 0,
+               output_b_gradient_t->NumElements() * sizeof(Dtype));
+
+    const bool is_NCHW = (data_format == FORMAT_NCHW);
+    if (is_NCHW) {
+      pad_and_transpose<THREADS_PER_BLOCK><<<blocks_grid, threads_block>>>(
+          input_a_t.flat<Dtype>().data(), padded_a_t.flat<Dtype>().data(), iC,
+          iH, iW, pad);
+      pad_and_transpose<THREADS_PER_BLOCK><<<blocks_grid, threads_block>>>(
+          input_b_t.flat<Dtype>().data(), padded_b_t.flat<Dtype>().data(), iC,
+          iH, iW, pad);
+    } else {
+      pad_and_no_transpose<THREADS_PER_BLOCK><<<blocks_grid, threads_block>>>(
+          input_a_t.flat<Dtype>().data(), padded_a_t.flat<Dtype>().data(), iC,
+          iH, iW, pad);
+      pad_and_no_transpose<THREADS_PER_BLOCK><<<blocks_grid, threads_block>>>(
+          input_b_t.flat<Dtype>().data(), padded_b_t.flat<Dtype>().data(), iC,
+          iH, iW, pad);
+    }
+
+    const GPUDevice &d = context->eigen_gpu_device();
+
+    dim3 threadsPerBlock(THREADS_PER_BLOCK);
+    dim3 totalBlocksCorr(iH, iW, iC);
+
+    for (int n = 0; n < N; ++n) {
+      Correlation_backward_input1<THREADS_PER_BLOCK>
+          <<<totalBlocksCorr, threadsPerBlock, 0, d.stream()>>>(
+              n, output_a_gradient_t->flat<Dtype>().data(), iC, iH, iW,
+              topdiff_t.flat<Dtype>().data(), oC, oH, oW,
+              padded_b_t.flat<Dtype>().data(), pad, kernel_size,
+              max_displacement, stride_1, stride_2, is_NCHW);
+    }
+
+    for (int n = 0; n < N; ++n) {
+      Correlation_backward_input2<THREADS_PER_BLOCK>
+          <<<totalBlocksCorr, threadsPerBlock, 0, d.stream()>>>(
+              n, output_b_gradient_t->flat<Dtype>().data(), iC, iH, iW,
+              topdiff_t.flat<Dtype>().data(), oC, oH, oW,
+              padded_a_t.flat<Dtype>().data(), pad, kernel_size,
+              max_displacement, stride_1, stride_2, is_NCHW);
+    }
+
+    return Status::OK();
+  }
+};
+
+template struct CorrelationCostFunctor<GPUDevice, float>;
+template struct CorrelationCostGradFunctor<GPUDevice, float>;
+
+}  // end namespace functor
+}  // namespace tensorflow
+
+#endif  // GOOGLE_CUDA
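
The two `pad_*` kernels only zero-pad the spatial borders and make channels innermost (NHWC), so the correlation kernels can index one uniform padded layout regardless of the input format. A NumPy sketch of that contract, for illustration only:

```python
import numpy as np

def pad_and_transpose(x_nchw, pad):
    # NCHW -> zero-padded NHWC, matching the pad_and_transpose kernel.
    x_nhwc = np.transpose(x_nchw, (0, 2, 3, 1))
    return np.pad(x_nhwc, ((0, 0), (pad, pad), (pad, pad), (0, 0)))

def pad_and_no_transpose(x_nhwc, pad):
    # NHWC -> zero-padded NHWC, matching the pad_and_no_transpose kernel.
    return np.pad(x_nhwc, ((0, 0), (pad, pad), (pad, pad), (0, 0)))
```
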
diff --git a/tensorflow_addons/custom_ops/layers/cc/ops/correlation_cost_op.cc b/tensorflow_addons/custom_ops/layers/cc/ops/correlation_cost_op.cc
new file mode 100644
index 0000000000..61fd9d7f0f
--- /dev/null
+++ b/tensorflow_addons/custom_ops/layers/cc/ops/correlation_cost_op.cc
@@ -0,0 +1,133 @@
+// Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/shape_inference.h"
+
+namespace tensorflow {
+
+using ::tensorflow::shape_inference::InferenceContext;
+using ::tensorflow::shape_inference::ShapeHandle;
+
+// --------------------------------------------------------------------------
+
+REGISTER_OP("CorrelationCost")
+    .Input("input_a: T")
+    .Input("input_b: T")
+    .Output("output: T")
+    .Attr("kernel_size: int")
+    .Attr("max_displacement: int")
+    .Attr("stride_1: int")
+    .Attr("stride_2: int")
+    .Attr("pad: int")
+    .Attr("data_format: {'NHWC', 'NCHW'} = 'NHWC'")
+    .Attr("T: realnumbertype")
+    .SetShapeFn([](InferenceContext* c) {
+      ShapeHandle input_a, input_b, input;
+
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input_a));
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 4, &input_b));
+      TF_RETURN_IF_ERROR(c->Merge(input_a, input_b, &input));
+
+      // get input shapes
+      int32 B, H, W;
+      B = c->Value(c->Dim(input, 0));
+      string data_format;
+      Status s = c->GetAttr("data_format", &data_format);
+      if (s.ok() && data_format == "NCHW") {
+        H = c->Value(c->Dim(input, 2));
+        W = c->Value(c->Dim(input, 3));
+      } else {
+        H = c->Value(c->Dim(input, 1));
+        W = c->Value(c->Dim(input, 2));
+      }
+
+      int32 kernel_size;
+      int32 max_displacement;
+      int32 stride_1;
+      int32 stride_2;
+      int32 pad;
+
+      TF_RETURN_IF_ERROR(c->GetAttr("kernel_size", &kernel_size));
+      TF_RETURN_IF_ERROR(c->GetAttr("max_displacement", &max_displacement));
+      // stride in input
+      TF_RETURN_IF_ERROR(c->GetAttr("stride_1", &stride_1));
+      // stride in patch
+      TF_RETURN_IF_ERROR(c->GetAttr("stride_2", &stride_2));
+      TF_RETURN_IF_ERROR(c->GetAttr("pad", &pad));
+
+      // output channels are d**2 where d = 2r + 1
+      const int32 r = max_displacement / stride_2;
+      const int32 d = 2 * r + 1;
+      const int32 border = max_displacement + (kernel_size - 1) / 2;
+
+      const int32 Cout = d * d;
+      // for spatial dimensions, we pad the inputs
+      const int32 Hout = static_cast<int>(
+          ceil(static_cast<float>(((H + 2 * pad) - border * 2)) /
+               static_cast<float>(stride_1)));
+      const int32 Wout = static_cast<int>(
+          ceil(static_cast<float>(((W + 2 * pad) - border * 2)) /
+               static_cast<float>(stride_1)));
+
+      // Note, the output is always NCHW (even when the input is NHWC)
+      c->set_output(0, c->MakeShape({B, Cout, Hout, Wout}));
+      return Status::OK();
+    })
+    .Doc(R"Doc(
+Compute Correlation costs.
+
+This layer implements the correlation operation from
+FlowNet: Learning Optical Flow with Convolutional Networks (Fischer et al.)
+
+input_a: A `Tensor` of the format specified by `data_format`.
+input_b: A `Tensor` of the format specified by `data_format`.
+kernel_size: An integer specifying the height and width of the
+    patch used to compute the per-patch costs.
+max_displacement: An integer specifying the maximum search radius
+    for each position.
+stride_1: An integer specifying the stride length in the input.
+stride_2: An integer specifying the stride length in the patch.
+pad: An integer specifying the paddings in height and width.
+data_format: Specifies the data format.
+    Possible values are:
+        "NHWC" float [batch, height, width, channels]
+        "NCHW" float [batch, channels, height, width]
+    Defaults to `"NHWC"`.
+)Doc");
+
+REGISTER_OP("CorrelationCostGrad")
+    .Input("orig_input_a: T")
+    .Input("orig_input_b: T")
+    .Input("top_diff: T")
+    .Output("bottom_diff_a: T")
+    .Output("bottom_diff_b: T")
+    .Attr("T: realnumbertype")
+    .Attr("kernel_size: int")
+    .Attr("max_displacement: int")
+    .Attr("stride_1: int")
+    .Attr("stride_2: int")
+    .Attr("pad: int")
+    .Attr("data_format: {'NHWC', 'NCHW'} = 'NHWC'")
+    .SetShapeFn([](InferenceContext* c) {
+      ShapeHandle shp_hnd;
+      TF_RETURN_IF_ERROR(c->Merge(c->input(0), c->input(1), &shp_hnd));
+      c->set_output(0, shp_hnd);
+      c->set_output(1, shp_hnd);
+      return Status::OK();
+    })
+    .Doc(R"doc(CorrelationCostGrad op.)doc");
+
+}  // namespace tensorflow
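
The shape function in Python form, handy for sanity checks (a sketch; `correlation_cost_output_shape` is a hypothetical helper, not part of the op):

```python
import math

def correlation_cost_output_shape(batch, height, width, kernel_size,
                                  max_displacement, stride_1, stride_2, pad):
    """Mirrors the C++ SetShapeFn; the raw op output is always NCHW."""
    r = max_displacement // stride_2
    d = 2 * r + 1
    border = max_displacement + (kernel_size - 1) // 2
    out_h = int(math.ceil((height + 2 * pad - 2 * border) / stride_1))
    out_w = int(math.ceil((width + 2 * pad - 2 * border) / stride_1))
    return (batch, d * d, out_h, out_w)

# The parameters and input size used in the tests below:
assert correlation_cost_output_shape(
    2, 3, 4, kernel_size=1, max_displacement=2,
    stride_1=1, stride_2=2, pad=4) == (2, 9, 7, 8)
```
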
diff --git a/tensorflow_addons/layers/BUILD b/tensorflow_addons/layers/BUILD
index e05719a245..59aeb562b5 100644
--- a/tensorflow_addons/layers/BUILD
+++ b/tensorflow_addons/layers/BUILD
@@ -8,10 +8,14 @@ py_library(
         "__init__.py",
         "maxout.py",
         "normalizations.py",
+        "optical_flow.py",
         "poincare.py",
         "sparsemax.py",
         "wrappers.py",
     ],
+    data = [
+        "//tensorflow_addons/custom_ops/layers:_correlation_cost_ops.so",
+    ],
     srcs_version = "PY2AND3",
     deps = [
         "//tensorflow_addons/activations",
@@ -70,3 +74,16 @@ py_test(
         ":layers",
     ],
 )
+
+py_test(
+    name = "optical_flow_test",
+    size = "small",
+    srcs = [
+        "optical_flow_test.py",
+    ],
+    main = "optical_flow_test.py",
+    srcs_version = "PY2AND3",
+    deps = [
+        ":layers",
+    ],
+)
diff --git a/tensorflow_addons/layers/README.md b/tensorflow_addons/layers/README.md
index 94cf0b55d8..4e4e4b48dc 100644
--- a/tensorflow_addons/layers/README.md
+++ b/tensorflow_addons/layers/README.md
@@ -5,6 +5,7 @@
 |:---------- |:----------- |:------------- |
 | maxout | | |
 | normalizations | @smokrow | moritz.kroeger@tu-dortmund.de |
+| opticalflow | | |
 | poincare | | |
 | sparsemax | @AndreasMadsen | amwwebdk+github@gmail.com |
 | wrappers | @seanpmorgan | seanmorgan@outlook.com |
@@ -15,6 +16,7 @@
 | maxout | Maxout | https://arxiv.org/abs/1302.4389 |
 | normalizations | GroupNormalization | https://arxiv.org/abs/1803.08494 |
 | normalizations | InstanceNormalization | https://arxiv.org/abs/1607.08022 |
+| opticalflow | CorrelationCost | https://arxiv.org/abs/1504.06852 |
 | poincare | PoincareNormalize | https://arxiv.org/abs/1705.08039 |
 | sparsemax| Sparsemax | https://arxiv.org/abs/1602.02068 |
 | wrappers | WeightNormalization | https://arxiv.org/abs/1602.07868 |
diff --git a/tensorflow_addons/layers/__init__.py b/tensorflow_addons/layers/__init__.py
index e5acf9666a..382f2aa80e 100644
--- a/tensorflow_addons/layers/__init__.py
+++ b/tensorflow_addons/layers/__init__.py
@@ -21,6 +21,7 @@
 from tensorflow_addons.layers.maxout import Maxout
 from tensorflow_addons.layers.normalizations import GroupNormalization
 from tensorflow_addons.layers.normalizations import InstanceNormalization
+from tensorflow_addons.layers.optical_flow import CorrelationCost
 from tensorflow_addons.layers.poincare import PoincareNormalize
 from tensorflow_addons.layers.sparsemax import Sparsemax
 from tensorflow_addons.layers.wrappers import WeightNormalization
diff --git a/tensorflow_addons/layers/optical_flow.py b/tensorflow_addons/layers/optical_flow.py
new file mode 100644
index 0000000000..c2c843548a
--- /dev/null
+++ b/tensorflow_addons/layers/optical_flow.py
@@ -0,0 +1,213 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Tensorflow op performing the correlation cost operation."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+from tensorflow_addons.utils import keras_utils
+from tensorflow_addons.utils.resource_loader import get_path_to_datafile
+
+_correlation_cost_op_so = tf.load_op_library(
+    get_path_to_datafile("custom_ops/layers/_correlation_cost_ops.so"))
+
+
+@tf.function
+def correlation_cost(input_a,
+                     input_b,
+                     kernel_size,
+                     max_displacement,
+                     stride_1,
+                     stride_2,
+                     pad,
+                     data_format='channels_last',
+                     name=None):
+    """Correlation Cost Volume computation.
+
+    "FlowNet: Learning Optical Flow with Convolutional Networks"
+    Philipp Fischer, Alexey Dosovitskiy, Eddy Ilg, Philip Hausser,
+    Caner Hazirbas, Vladimir Golkov, Patrick van der Smagt,
+    Daniel Cremers, Thomas Brox. https://arxiv.org/abs/1504.06852
+
+    Computes a cost volume using correlation for two inputs. For feature
+    maps A and B with width w, height h and c channels, it computes
+
+        output(a, b) = sum_{l in [-k, k]**2}  < A(a + l), B(b + l) >
+
+    where the patches of size K = 2k + 1 are centered at positions a and b,
+    respectively.
+
+    The output shape is [B, C', H', W'], where
+
+        r = max_displacement // stride_2
+        bd = max_displacement + (kernel_size - 1) // 2
+        C' = (2 * r + 1) ** 2
+        H' = H + 2 * (pad - bd) // stride_1
+        W' = W + 2 * (pad - bd) // stride_1
+
+    Note: When the data_format requests "channels_last", an additional
+    explicit transpose operation is executed.
+
+    Args:
+      input_a: A `Tensor` of the format specified by `data_format`.
+      input_b: A `Tensor` of the format specified by `data_format`.
+      kernel_size: An integer specifying the height and width of the
+          patch used to compute the per-patch costs.
+      max_displacement: An integer specifying the maximum search radius
+          for each position.
+      stride_1: An integer specifying the stride length in the input.
+      stride_2: An integer specifying the stride length in the patch.
+      pad: An integer specifying the paddings in height and width.
+      data_format: Specifies the data format.
+          Possible values are:
+              "channels_last" float [batch, height, width, channels]
+              "channels_first" float [batch, channels, height, width]
+          Defaults to `"channels_last"`.
+      name: A name for the operation (optional).
+
+    Returns:
+      A `Tensor` of the format specified by `data_format`.
+    """
+
+    with tf.name_scope(name or "correlation_cost"):
+        op_call = _correlation_cost_op_so.correlation_cost
+
+        if data_format == "channels_last":
+            op_data_format = "NHWC"
+        elif data_format == "channels_first":
+            op_data_format = "NCHW"
+        else:
+            raise ValueError("`data_format` must be either `channels_last` or "
+                             "`channels_first`")
+
+        ret = op_call(
+            input_a,
+            input_b,
+            kernel_size=kernel_size,
+            max_displacement=max_displacement,
+            stride_1=stride_1,
+            stride_2=stride_2,
+            pad=pad,
+            data_format=op_data_format)
+        if data_format == 'channels_last':
+            # this is easier to maintain without
+            # specializing an additional cuda kernel
+            return tf.transpose(ret, [0, 2, 3, 1])
+        return ret
+
+
+@tf.RegisterGradient("CorrelationCost")
+def _correlation_cost_grad(op, grad_output):
+    kernel_size = op.get_attr("kernel_size")
+    max_displacement = op.get_attr("max_displacement")
+    stride_1 = op.get_attr("stride_1")
+    stride_2 = op.get_attr("stride_2")
+    pad = op.get_attr("pad")
+    data_format = op.get_attr("data_format")
+
+    input_a = tf.convert_to_tensor(op.inputs[0], name="input_a")
+    input_b = tf.convert_to_tensor(op.inputs[1], name="input_b")
+    grad_output_tensor = tf.convert_to_tensor(grad_output, name="grad_output")
+
+    op_call = _correlation_cost_op_so.correlation_cost_grad
+    grads = op_call(
+        input_a,
+        input_b,
+        grad_output_tensor,
+        kernel_size=kernel_size,
+        max_displacement=max_displacement,
+        stride_1=stride_1,
+        stride_2=stride_2,
+        pad=pad,
+        data_format=data_format)
+
+    grad_input_a = tf.convert_to_tensor(grads[0], name="grad_input_a")
+    grad_input_b = tf.convert_to_tensor(grads[1], name="grad_input_b")
+    return [grad_input_a, grad_input_b]
+
+
+@keras_utils.register_keras_custom_object
+class CorrelationCost(tf.keras.layers.Layer):
+    def __init__(self, kernel_size, max_displacement, stride_1, stride_2,
+                 pad, data_format, **kwargs):
+        self.kernel_size = kernel_size
+        self.max_displacement = max_displacement
+        self.stride_1 = stride_1
+        self.stride_2 = stride_2
+        self.pad = pad
+
+        if data_format != "channels_last" and data_format != "channels_first":
+            raise ValueError("`data_format` must be either `channels_last` or "
+                             "`channels_first`, instead got %s" % data_format)
+
+        self.data_format = data_format
+
+        super(CorrelationCost, self).__init__(**kwargs)
+
+    def build(self, input_shape):
+        if not isinstance(input_shape, list):
+            raise ValueError("Input must be a list of two Tensors to process")
+        super(CorrelationCost, self).build(input_shape)
+
+    def call(self, inputs):
+        if not isinstance(inputs, list):
+            raise ValueError("Input must be a list of two Tensors to process")
+
+        input_a = tf.convert_to_tensor(inputs[0])
+        input_b = tf.convert_to_tensor(inputs[1])
+
+        return correlation_cost(
+            input_a,
+            input_b,
+            kernel_size=self.kernel_size,
+            max_displacement=self.max_displacement,
+            stride_1=self.stride_1,
+            stride_2=self.stride_2,
+            pad=self.pad,
+            data_format=self.data_format)
+
+    def compute_output_shape(self, input_shape):
+        assert isinstance(input_shape, list)
+        n = input_shape[0][0]
+        r = self.max_displacement // self.stride_2
+        bd = self.max_displacement + (self.kernel_size - 1) // 2
+        output_c = (2 * r + 1)**2
+
+        if self.data_format == "channels_first":
+            output_h = input_shape[0][1] + 2 * (self.pad - bd) // self.stride_1
+            output_w = input_shape[0][2] + 2 * (self.pad - bd) // self.stride_1
+            return [n, output_c, output_h, output_w]
+
+        elif self.data_format == "channels_last":
+            output_h = input_shape[0][0] + 2 * (self.pad - bd) // self.stride_1
+            output_w = input_shape[0][1] + 2 * (self.pad - bd) // self.stride_1
+            return [n, output_h, output_w, output_c]
+        else:
+            raise ValueError("`data_format` must be either `channels_last` or "
+                             "`channels_first`")
+
+    def get_config(self):
+        config = {
+            'kernel_size': self.kernel_size,
+            'max_displacement': self.max_displacement,
+            'stride_1': self.stride_1,
+            'stride_2': self.stride_2,
+            'pad': self.pad,
+            'data_format': self.data_format
+        }
+
+        base_config = super(CorrelationCost, self).get_config()
+        return dict(list(base_config.items()) + list(config.items()))
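
A short usage sketch of the Python API added above (the output shape follows the docstring formula; for `channels_last` the result comes back as NHWC after the internal transpose):

```python
import tensorflow as tf
from tensorflow_addons.layers.optical_flow import CorrelationCost, correlation_cost

a = tf.random.normal([2, 5, 6, 3])  # NHWC
b = tf.random.normal([2, 5, 6, 3])

# Functional form.
cost = correlation_cost(
    a, b, kernel_size=1, max_displacement=2,
    stride_1=1, stride_2=2, pad=4, data_format="channels_last")
print(cost.shape)  # (2, 9, 10, 9): H' = 9, W' = 10, C' = (2r + 1)**2 = 9

# Keras layer form; takes a list of two tensors.
layer = CorrelationCost(
    kernel_size=1, max_displacement=2, stride_1=1, stride_2=2,
    pad=4, data_format="channels_last")
out = layer([a, b])
```
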
diff --git a/tensorflow_addons/layers/optical_flow_test.py b/tensorflow_addons/layers/optical_flow_test.py
new file mode 100644
index 0000000000..060f572c5f
--- /dev/null
+++ b/tensorflow_addons/layers/optical_flow_test.py
@@ -0,0 +1,197 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+import tensorflow as tf
+from tensorflow_addons.layers.optical_flow import correlation_cost, CorrelationCost
+from tensorflow_addons.utils import test_utils
+
+
+@test_utils.run_all_in_graph_and_eager_modes
+class CorrelationCostTest(tf.test.TestCase):
+    def _forward(self, input_a, input_b, kernel_size, max_displacement,
+                 stride_1, stride_2, pad, data_format):
+
+        input_a_op = tf.convert_to_tensor(input_a, dtype=tf.float32)
+        input_b_op = tf.convert_to_tensor(input_b, dtype=tf.float32)
+
+        output = correlation_cost(
+            input_a_op,
+            input_b_op,
+            kernel_size=kernel_size,
+            max_displacement=max_displacement,
+            stride_1=stride_1,
+            stride_2=stride_2,
+            pad=pad,
+            data_format=data_format)
+
+        return output
+
+    def _forward_simple(self, data_format):
+        # We are just testing where the output has vanishing values.
+        with test_utils.use_gpu():
+            val = [[[[0, -6, 9, 5], [1, -5, 10, 3], [2, -4, 11, 1]],
+                    [[3, -3, 12, -1], [4, -2, 13, -3], [5, -1, 14, -5]]],
+                   [[[6, 0, 15, -7], [7, 1, 16, -9], [8, 2, 17, -11]],
+                    [[9, 3, 18, -13], [10, 4, 19, -15], [11, 5, 20, -17]]]]
+
+            input_a = tf.constant(np.array(val), dtype=tf.float32)
+            valb = np.array(val).transpose(2, 3, 0, 1).reshape(2, 2, 3, 4)
+            input_b = tf.constant(valb, dtype=tf.float32)
+
+            if data_format == 'channels_last':
+                input_a = tf.transpose(input_a, [0, 2, 3, 1])
+                input_b = tf.transpose(input_b, [0, 2, 3, 1])
+
+            input_a_tensor = tf.convert_to_tensor(input_a, dtype=tf.float32)
+            input_b_tensor = tf.convert_to_tensor(input_b, dtype=tf.float32)
+
+            kernel_size = 1
+            max_displacement = 2
+            stride_1 = 1
+            stride_2 = 2
+            pad = 4
+
+            actual = self._forward(
+                input_a_tensor,
+                input_b_tensor,
+                kernel_size=kernel_size,
+                max_displacement=max_displacement,
+                stride_1=stride_1,
+                stride_2=stride_2,
+                pad=pad,
+                data_format=data_format)
+
+            if data_format == 'channels_last':
+                # NHWC -> NCHW
+                actual = tf.transpose(actual, [0, 3, 1, 2])
+
+            # We can test fixed ids, as the output is independent of data_format
+            expected_ids = np.concatenate([np.zeros(464), np.ones(464)])
+            self.assertAllClose(
+                tf.where(tf.equal(actual, 0))[:, 0], expected_ids)
+
+            counts = [54, 52, 54, 50, 44, 50, 54, 52, 54]
+            expected_ids = np.concatenate(
+                [k * np.ones(v) for k, v in enumerate(counts)])
+            expected_ids = np.concatenate([expected_ids, expected_ids])
+            self.assertAllClose(
+                tf.where(tf.equal(actual, 0))[:, 1], expected_ids)
+            self.assertEqual(actual.shape, (2, 9, 7, 8))
+
+    def _gradients(self, data_format):
+        with test_utils.use_gpu():
+            batch, channels, height, width = 2, 3, 5, 6
+            input_a = np.random.randn(batch, channels, height,
+                                      width).astype(np.float32)
+            input_b = np.random.randn(batch, channels, height,
+                                      width).astype(np.float32)
+
+            kernel_size = 1
+            max_displacement = 2
+            stride_1 = 1
+            stride_2 = 2
+            pad = 4
+
+            if data_format == 'channels_last':
+                input_a = tf.transpose(input_a, [0, 2, 3, 1])
+                input_b = tf.transpose(input_b, [0, 2, 3, 1])
+
+            input_a_op = tf.convert_to_tensor(input_a)
+            input_b_op = tf.convert_to_tensor(input_b)
+
+            def correlation_fn(input_a, input_b):
+                return correlation_cost(
+                    input_a,
+                    input_b,
+                    kernel_size=kernel_size,
+                    max_displacement=max_displacement,
+                    stride_1=stride_1,
+                    stride_2=stride_2,
+                    pad=pad,
+                    data_format=data_format)
+
+            theoretical, numerical = tf.test.compute_gradient(
+                correlation_fn, [input_a_op, input_b_op])
+
+            self.assertAllClose(theoretical[0], numerical[0], atol=1e-3)
+
+    def _keras(self, data_format):
+        # Unable to use `layer_test` as this layer has multiple inputs.
+        with test_utils.use_gpu():
+            val_a = [[[[0, -6, 9, 5], [1, -5, 10, 3], [2, -4, 11, 1]],
+                      [[3, -3, 12, -1], [4, -2, 13, -3], [5, -1, 14, -5]]],
+                     [[[6, 0, 15, -7], [7, 1, 16, -9], [8, 2, 17, -11]],
+                      [[9, 3, 18, -13], [10, 4, 19, -15], [11, 5, 20, -17]]]]
+            val_b = np.array(val_a).transpose(2, 3, 0, 1).reshape(2, 2, 3, 4)
+
+            # yapf: disable
+            input_a = tf.keras.Input(shape=(2, 3, 4,))
+            input_b = tf.keras.Input(shape=(2, 3, 4,))
+
+            layer = CorrelationCost(
+                kernel_size=1,
+                max_displacement=2,
+                stride_1=1,
+                stride_2=2,
+                pad=4,
+                data_format=data_format)
+
+            expected_output_shape = tuple(
+                layer.compute_output_shape([(2, 3, 4,), (2, 3, 4,)]))[1:]
+            # yapf: enable
+
+            x = [input_a, input_b]
+            y = layer(x)
+            model = tf.keras.models.Model(x, y)
+            actual_output = model.predict([val_a, val_b])
+
+            expected_output_type = 'float32'
+            if tf.keras.backend.dtype(y[0]) != expected_output_type:
+                raise AssertionError(
+                    "Inferred output type %s does not equal "
+                    "expected output type %s" % (tf.keras.backend.dtype(y[0]),
+                                                 expected_output_type))
+
+            if actual_output[0].shape != expected_output_shape:
+                raise AssertionError(
+                    "Expected shape %s does not equal output shape "
+                    "%s" % (actual_output[0].shape, expected_output_shape))
+
+    def testForwardNCHW(self):
+        self._forward_simple(data_format='channels_first')
+
+    def testForwardNHWC(self):
+        self._forward_simple(data_format='channels_last')
+
+    def testBackwardNCHW(self):
+        self._gradients(data_format='channels_first')
+
+    def testBackwardNHWC(self):
+        self._gradients(data_format='channels_last')
+
+    def testKerasNCHW(self):
+        self._keras(data_format='channels_first')
+
+    def testKerasNHWC(self):
+        self._keras(data_format='channels_last')
+
+
+if __name__ == "__main__":
+    tf.test.main()
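
Beyond the unit tests, the op can also be cross-checked against the NumPy reference sketched after the CPU kernel above — a sketch assuming `correlation_cost_nchw` from that snippet is in scope:

```python
import numpy as np
import tensorflow as tf
from tensorflow_addons.layers.optical_flow import correlation_cost

a = np.random.randn(1, 3, 5, 6).astype(np.float32)  # NCHW
b = np.random.randn(1, 3, 5, 6).astype(np.float32)

got = correlation_cost(
    tf.constant(a), tf.constant(b), kernel_size=1, max_displacement=2,
    stride_1=1, stride_2=2, pad=4, data_format="channels_first")
np.testing.assert_allclose(
    got.numpy(), correlation_cost_nchw(a, b, 1, 2, 1, 2, 4), atol=1e-5)
```
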