From a9a9536aa5ab561fcf6f8b979221e9590955ac77 Mon Sep 17 00:00:00 2001 From: PatWie Date: Thu, 25 Apr 2019 15:26:34 +0200 Subject: [PATCH 01/18] add correlation cost layer --- tensorflow_addons/custom_ops/README.md | 1 + .../custom_ops/opticalflow/BUILD | 25 + .../cc/kernels/correlation_cost_op.cc | 345 +++++++++++++ .../cc/kernels/correlation_cost_op.h | 47 ++ .../cc/kernels/correlation_cost_op_gpu.cu.cc | 477 ++++++++++++++++++ .../opticalflow/cc/ops/correlation_cost_op.cc | 133 +++++ .../opticalflow/cc/python/__init__.py | 19 + .../kernel_tests/correlation_cost_op_test.py | 228 +++++++++ .../cc/python/ops/correlation_cost_op.py | 126 +++++ 9 files changed, 1401 insertions(+) create mode 100644 tensorflow_addons/custom_ops/opticalflow/BUILD create mode 100644 tensorflow_addons/custom_ops/opticalflow/cc/kernels/correlation_cost_op.cc create mode 100644 tensorflow_addons/custom_ops/opticalflow/cc/kernels/correlation_cost_op.h create mode 100644 tensorflow_addons/custom_ops/opticalflow/cc/kernels/correlation_cost_op_gpu.cu.cc create mode 100644 tensorflow_addons/custom_ops/opticalflow/cc/ops/correlation_cost_op.cc create mode 100644 tensorflow_addons/custom_ops/opticalflow/cc/python/__init__.py create mode 100644 tensorflow_addons/custom_ops/opticalflow/cc/python/kernel_tests/correlation_cost_op_test.py create mode 100644 tensorflow_addons/custom_ops/opticalflow/cc/python/ops/correlation_cost_op.py diff --git a/tensorflow_addons/custom_ops/README.md b/tensorflow_addons/custom_ops/README.md index 522be99119..0d657954e6 100644 --- a/tensorflow_addons/custom_ops/README.md +++ b/tensorflow_addons/custom_ops/README.md @@ -6,3 +6,4 @@ | Image | Ops for image manipulation | | Seq2seq | Ops for seq2seq encoder-decoder framework | | Text | Ops for text processing | +| OpticalFlow | Ops for optical flow processing | diff --git a/tensorflow_addons/custom_ops/opticalflow/BUILD b/tensorflow_addons/custom_ops/opticalflow/BUILD new file mode 100644 index 0000000000..fa6e708b67 --- /dev/null +++ b/tensorflow_addons/custom_ops/opticalflow/BUILD @@ -0,0 +1,25 @@ +licenses(["notice"]) # Apache 2.0 + +package(default_visibility = ["//visibility:public"]) + +cc_binary( + name = "_distort_image_ops.so", + srcs = [ + "cc/kernels/correlation_cost_op.cc", + "cc/kernels/correlation_cost_op.h", + "cc/kernels/correlation_cost_op_gpu.cu.cc", + "cc/ops/correlation_cost_op.cc", + ], + copts = [ + "-pthread", + "-std=c++11", + "-D_GLIBCXX_USE_CXX11_ABI=0", + ], + linkshared = 1, + deps = [ + "@local_config_tf//:libtensorflow_framework", + "@local_config_tf//:tf_header_lib", + ]+ if_cuda([ + "@cub_archive//:cub" + ]), +) diff --git a/tensorflow_addons/custom_ops/opticalflow/cc/kernels/correlation_cost_op.cc b/tensorflow_addons/custom_ops/opticalflow/cc/kernels/correlation_cost_op.cc new file mode 100644 index 0000000000..6bfaf38828 --- /dev/null +++ b/tensorflow_addons/custom_ops/opticalflow/cc/kernels/correlation_cost_op.cc @@ -0,0 +1,345 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/correlation_cost/kernels/correlation_cost_op.h" + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/util/tensor_format.h" + +namespace tensorflow { + +typedef Eigen::ThreadPoolDevice CPUDevice; +typedef Eigen::GpuDevice GPUDevice; + +namespace functor { + +template +struct CorrelationCostFunctor { + Status operator()(OpKernelContext* context, const Tensor& input_a_t, + const Tensor& input_b_t, Tensor* output_t, + /* params */ + int kernel_size, int max_displacement, int stride_1, + int stride_2, int pad, TensorFormat data_format) { + const int32 oN = GetTensorDim(*output_t, FORMAT_NCHW, 'N'); + // const int32 oC = GetTensorDim(*output_t, FORMAT_NCHW, 'C'); + const int32 oH = GetTensorDim(*output_t, FORMAT_NCHW, 'H'); + const int32 oW = GetTensorDim(*output_t, FORMAT_NCHW, 'W'); + const int32 iH = GetTensorDim(input_a_t, data_format, 'H'); + const int32 iW = GetTensorDim(input_a_t, data_format, 'W'); + const int32 iC = GetTensorDim(input_a_t, data_format, 'C'); + + const int K = kernel_size * kernel_size * iC; + + const auto input_a = input_a_t.tensor(); + const auto input_b = input_b_t.tensor(); + auto output = output_t->tensor(); + output.setZero(); + + const int kernel_rad = (kernel_size - 1) / 2; + const int displacement_rad = max_displacement / stride_2; + const int displacement_size = 2 * displacement_rad + 1; + + const bool is_NCHW = (data_format == FORMAT_NCHW); + + for (int n = 0; n < oN; ++n) { + for (int h = 0; h < oH; ++h) { + const int h1 = (h - pad) * stride_1 + max_displacement + kernel_rad; + for (int w = 0; w < oW; ++w) { + const int w1 = (w - pad) * stride_1 + max_displacement + kernel_rad; + + for (int tj = -displacement_rad; tj <= displacement_rad; ++tj) { + for (int ti = -displacement_rad; ti <= displacement_rad; ++ti) { + const int tc = (tj + displacement_rad) * displacement_size + + (ti + displacement_rad); + + const int w2 = w1 + ti * stride_2; + const int h2 = h1 + tj * stride_2; + + for (int j = -kernel_rad; j <= kernel_rad; ++j) { + // out-of-bound test + if ((h1 + j < 0) || (h1 + j >= iH)) continue; + if ((h2 + j < 0) || (h2 + j >= iH)) continue; + for (int i = -kernel_rad; i <= kernel_rad; ++i) { + if ((w1 + i < 0) || (w1 + i >= iW)) continue; + if ((w2 + i < 0) || (w2 + i >= iW)) continue; + for (int c = 0; c < iC; ++c) { + // eq. 
(1) in FlowNet: Learning Optical Flow with + // Convolutional Networks + if (is_NCHW) { + output(n, tc, h, w) += input_a(n, c, h1 + j, w1 + i) * + input_b(n, c, h2 + j, w2 + i); + } else { + output(n, tc, h, w) += input_a(n, h1 + j, w1 + i, c) * + input_b(n, h2 + j, w2 + i, c); + } + } + } + } + output(n, tc, h, w) /= K; + } + } + } + } + } + return Status::OK(); + } +}; + +template +struct CorrelationCostGradFunctor { + Status operator()(OpKernelContext* context, const Tensor& input_a_t, + const Tensor& input_b_t, const Tensor& topdiff_t, + Tensor* output_a_gradient_t, Tensor* output_b_gradient_t, + /* params */ + int kernel_size, int max_displacement, int stride_1, + int stride_2, int pad, TensorFormat data_format) { + const int32 iN = GetTensorDim(input_a_t, data_format, 'N'); + const int32 iC = GetTensorDim(input_a_t, data_format, 'C'); + const int32 iH = GetTensorDim(input_a_t, data_format, 'H'); + const int32 iW = GetTensorDim(input_a_t, data_format, 'W'); + + // topdiff is NCHW + // const int32 oC = GetTensorDim(topdiff_t, FORMAT_NCHW, 'C'); + const int32 oH = GetTensorDim(topdiff_t, FORMAT_NCHW, 'H'); + const int32 oW = GetTensorDim(topdiff_t, FORMAT_NCHW, 'W'); + + const auto topdiff = topdiff_t.tensor(); + const auto input_a = input_a_t.tensor(); + const auto input_b = input_b_t.tensor(); + auto output_a_gradient = output_a_gradient_t->tensor(); + auto output_b_gradient = output_b_gradient_t->tensor(); + output_a_gradient.setZero(); + output_b_gradient.setZero(); + + const int kernel_rad = (kernel_size - 1) / 2; + const int displacement_rad = max_displacement / stride_2; + const int displacement_size = 2 * displacement_rad + 1; + const int K = kernel_size * kernel_size * iC; + + const bool is_NCHW = (data_format == FORMAT_NCHW); + + for (int n = 0; n < iN; ++n) { + for (int h = 0; h < oH; ++h) { + const int h1 = (h - pad) * stride_1 + max_displacement + kernel_rad; + for (int w = 0; w < oW; ++w) { + const int w1 = (w - pad) * stride_1 + max_displacement + kernel_rad; + + for (int tj = -displacement_rad; tj <= displacement_rad; ++tj) { + for (int ti = -displacement_rad; ti <= displacement_rad; ++ti) { + const int tc = (tj + displacement_rad) * displacement_size + + (ti + displacement_rad); + + const int w2 = w1 + ti * stride_2; + const int h2 = h1 + tj * stride_2; + + for (int j = -kernel_rad; j <= kernel_rad; ++j) { + // out-of-bound test + if ((h1 + j < 0) || (h1 + j >= iH)) continue; + if ((h2 + j < 0) || (h2 + j >= iH)) continue; + for (int i = -kernel_rad; i <= kernel_rad; ++i) { + if ((w1 + i < 0) || (w1 + i >= iW)) continue; + if ((w2 + i < 0) || (w2 + i >= iW)) continue; + for (int c = 0; c < iC; ++c) { + // eq. 
(1) in FlowNet: Learning Optical Flow with + // Convolutional Networks + if (is_NCHW) { + output_a_gradient(n, c, h1 + j, w1 + i) += + topdiff(n, tc, h, w) * input_b(n, c, h2 + j, w2 + i) / + K; + output_b_gradient(n, c, h2 + j, w2 + i) += + topdiff(n, tc, h, w) * input_a(n, c, h1 + j, w1 + i) / + K; + } else { + output_a_gradient(n, h1 + j, w1 + i, c) += + topdiff(n, tc, h, w) * input_b(n, h2 + j, w2 + i, c) / + K; + output_b_gradient(n, h2 + j, w2 + i, c) += + topdiff(n, tc, h, w) * input_a(n, h1 + j, w1 + i, c) / + K; + } + } + } + } + } + } + } + } + } + return Status::OK(); + } +}; + +} // end namespace functor + +template +class CorrelationCostOp : public OpKernel { + public: + explicit CorrelationCostOp(OpKernelConstruction* context) + : OpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("kernel_size", &kernel_size)); + OP_REQUIRES_OK(context, + context->GetAttr("max_displacement", &max_displacement)); + OP_REQUIRES_OK(context, context->GetAttr("stride_1", &stride_1)); + OP_REQUIRES_OK(context, context->GetAttr("stride_2", &stride_2)); + OP_REQUIRES_OK(context, context->GetAttr("pad", &pad)); + string data_format; + OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format)); + OP_REQUIRES(context, FormatFromString(data_format, &data_format_), + errors::InvalidArgument("Invalid data format")); + OP_REQUIRES(context, kernel_size % 2 != 0, + errors::InvalidArgument("kernel_size must be odd")); + } + + void Compute(OpKernelContext* context) override { + const Tensor& input_a_t = context->input(0); + const Tensor& input_b_t = context->input(1); + + // we didn't check the batch-dimension during "SetShapeFn" + OP_REQUIRES(context, input_a_t.shape() == input_b_t.shape(), + errors::InvalidArgument("Input shapes have to be the same")); + + const int32 N = GetTensorDim(input_a_t, data_format_, 'N'); + const int32 H = GetTensorDim(input_a_t, data_format_, 'H'); + const int32 W = GetTensorDim(input_a_t, data_format_, 'W'); + + // output channels are d**2 where, d = 2r + 1 + const int32 r = max_displacement / stride_2; + const int32 d = 2 * r + 1; + const int32 border = max_displacement + (kernel_size - 1) / 2; + + const int32 Cout = d * d; + const int32 Hout = + static_cast(ceil(static_cast(((H + 2 * pad) - border * 2)) / + static_cast(stride_1))); + const int32 Wout = + static_cast(ceil(static_cast(((W + 2 * pad) - border * 2)) / + static_cast(stride_1))); + + OP_REQUIRES(context, Hout >= 1, + errors::InvalidArgument( + "Neighborhood and kernel don't fit in input height.")); + OP_REQUIRES(context, Wout >= 1, + errors::InvalidArgument( + "Neighborhood and kernel don't fit in input width.")); + + Tensor* output_t; + OP_REQUIRES_OK( + context, context->allocate_output(0, TensorShape({N, Cout, Hout, Wout}), + &output_t)); + + + functor::CorrelationCostFunctor correlationCostFunc; + Status s = correlationCostFunc( + context, input_a_t, input_b_t, output_t, + /* params */ + kernel_size, max_displacement, stride_1, stride_2, pad, data_format_); + + OP_REQUIRES_OK(context, s); + } + + private: + int kernel_size; + int max_displacement; + int stride_1; + int stride_2; + int pad; + TensorFormat data_format_; +}; + +template +class CorrelationCostGradOp : public OpKernel { + public: + explicit CorrelationCostGradOp(OpKernelConstruction* context) + : OpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("kernel_size", &kernel_size)); + OP_REQUIRES_OK(context, + context->GetAttr("max_displacement", &max_displacement)); + OP_REQUIRES_OK(context, context->GetAttr("stride_1", 
&stride_1)); + OP_REQUIRES_OK(context, context->GetAttr("stride_2", &stride_2)); + OP_REQUIRES_OK(context, context->GetAttr("pad", &pad)); + string data_format; + OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format)); + OP_REQUIRES(context, FormatFromString(data_format, &data_format_), + errors::InvalidArgument("Invalid data format")); + OP_REQUIRES(context, kernel_size % 2 != 0, + errors::InvalidArgument("kernel_size must be odd")); + } + + void Compute(OpKernelContext* context) override { + const Tensor& input_a_t = context->input(0); + const Tensor& input_b_t = context->input(1); + const Tensor& topdiff_t = context->input(2); + + OP_REQUIRES(context, input_a_t.shape() == input_b_t.shape(), + errors::InvalidArgument("Input shapes have to be the same")); + + // Allocate the memory for the bottom diffs + Tensor* output_a_gradient_t; + OP_REQUIRES_OK(context, context->allocate_output(0, input_a_t.shape(), + &output_a_gradient_t)); + Tensor* output_b_gradient_t; + OP_REQUIRES_OK(context, context->allocate_output(1, input_b_t.shape(), + &output_b_gradient_t)); + + functor::CorrelationCostGradFunctor correlationCostGrad; + Status s = correlationCostGrad( + context, input_a_t, input_b_t, topdiff_t, + output_a_gradient_t, output_b_gradient_t, + /* params */ + kernel_size, max_displacement, stride_1, stride_2, pad, data_format_); + + OP_REQUIRES_OK(context, s); + } + + private: + int kernel_size; + int max_displacement; + int stride_1; + int stride_2; + int pad; + TensorFormat data_format_; +}; + +// Register the CPU kernels. +#define REGISTER_CORRELATIONCOST_OP_CPU(T) \ + REGISTER_KERNEL_BUILDER( \ + Name("CorrelationCost").Device(DEVICE_CPU).TypeConstraint("T"), \ + CorrelationCostOp) \ + REGISTER_KERNEL_BUILDER( \ + Name("CorrelationCostGrad").Device(DEVICE_CPU).TypeConstraint("T"), \ + CorrelationCostGradOp) + +TF_CALL_float(REGISTER_CORRELATIONCOST_OP_CPU); +#undef REGISTER_CORRELATIONCOST_OP_CPU + +// Register the GPU kernels. +#ifdef GOOGLE_CUDA + +#define REGISTER_CORRELATIONCOST_OP_GPU(T) \ + REGISTER_KERNEL_BUILDER( \ + Name("CorrelationCost").Device(DEVICE_GPU).TypeConstraint("T"), \ + CorrelationCostOp) \ + REGISTER_KERNEL_BUILDER( \ + Name("CorrelationCostGrad").Device(DEVICE_GPU).TypeConstraint("T"), \ + CorrelationCostGradOp) + +TF_CALL_float(REGISTER_CORRELATIONCOST_OP_GPU); +#undef REGISTER_CORRELATIONCOST_OP_GPU + +#endif // GOOGLE_CUDA + +} // end namespace tensorflow diff --git a/tensorflow_addons/custom_ops/opticalflow/cc/kernels/correlation_cost_op.h b/tensorflow_addons/custom_ops/opticalflow/cc/kernels/correlation_cost_op.h new file mode 100644 index 0000000000..056c0cfc64 --- /dev/null +++ b/tensorflow_addons/custom_ops/opticalflow/cc/kernels/correlation_cost_op.h @@ -0,0 +1,47 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORRELATION_COST_OP_H_ +#define TENSORFLOW_CORRELATION_COST_OP_H_ + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/util/tensor_format.h" + +namespace tensorflow { +namespace functor { + +template +struct CorrelationCostFunctor { + Status operator()(OpKernelContext* context, const Tensor& input_a_t, + const Tensor& input_b_t, Tensor* output_t, + /* params */ + int kernel_size, int max_displacement, int stride_1, + int stride_2, int pad, TensorFormat data_format); +}; + +template +struct CorrelationCostGradFunctor { + Status operator()(OpKernelContext* context, const Tensor& input_a_t, + const Tensor& input_b_t, const Tensor& topdiff_t, + Tensor* output_a_gradient_t, Tensor* output_b_gradient_t, + /* params */ + int kernel_size, int max_displacement, int stride_1, + int stride_2, int pad, TensorFormat data_format); +}; + +} // end namespace functor +} // end namespace tensorflow + +#endif // TENSORFLOW_CORRELATION_COST_OP_H_ diff --git a/tensorflow_addons/custom_ops/opticalflow/cc/kernels/correlation_cost_op_gpu.cu.cc b/tensorflow_addons/custom_ops/opticalflow/cc/kernels/correlation_cost_op_gpu.cu.cc new file mode 100644 index 0000000000..51451dfb06 --- /dev/null +++ b/tensorflow_addons/custom_ops/opticalflow/cc/kernels/correlation_cost_op_gpu.cu.cc @@ -0,0 +1,477 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#if GOOGLE_CUDA + +#define EIGEN_USE_GPU + +#include "tensorflow/contrib/correlation_cost/kernels/correlation_cost_op.h" + +#include "external/cub_archive/cub/device/device_reduce.cuh" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/util/cuda_kernel_helper.h" +#include "tensorflow/core/util/tensor_format.h" + +namespace tensorflow { + +typedef Eigen::GpuDevice GPUDevice; + +namespace functor { + +namespace { + +/* There are two ways to implement the correlation layer: +- pad first and then compute cross-correlation costs (faster) +- have if-else in the computation of cross-correlation costs + +This implementation is inspired from +https://github.com/NVIDIA/flownet2-pytorch +*/ + +template +__global__ void pad_and_transpose(const float *input, float *output, int C, + int H, int W, int P) { + // NCHW -> pad(NHWC) + const int n = blockIdx.x; + const int h = blockIdx.y; + const int w = blockIdx.z; + const int c0 = threadIdx.x; + const int pW = (W + 2 * P); + const int pH = (H + 2 * P); + + float value; + for (int c = c0; c < C; c += THREADS_PER_BLOCK) { + value = input[n * (C * H * W) + c * (H * W) + h * W + w]; + output[n * (C * pH * pW) + (h + P) * (pW * C) + (w + P) * C + c] = value; + } +} + +template +__global__ void pad_and_no_transpose(const float *input, float *output, int C, + int H, int W, int P) { + // NHWC -> pad(NHWC) + const int n = blockIdx.x; + const int h = blockIdx.y; + const int w = blockIdx.z; + const int c0 = threadIdx.x; + const int pW = (W + 2 * P); + const int pH = (H + 2 * P); + + float value; + for (int c = c0; c < C; c += THREADS_PER_BLOCK) { + value = input[n * (C * H * W) + h * (W * C) + w * C + c]; + output[n * (C * pH * pW) + (h + P) * (pW * C) + (w + P) * C + c] = value; + } +} + +template +__global__ void Correlation_forward(float *output, int Cout, int Hout, int Wout, + float *pInput1, int Cin, int Hin, int Win, + float *pInput2, int pad, int kernel_size, + int max_displacement, int stride1, + int stride2) { + const int pWin = Win + 2 * pad; + const int pHin = Hin + 2 * pad; + + const int kernel_rad = (kernel_size - 1) / 2; + const int displacement_rad = max_displacement / stride2; + const int displacement_size = 2 * displacement_rad + 1; + + const int n = blockIdx.x; + const int h1 = blockIdx.y * stride1 + max_displacement + kernel_rad; + const int w1 = blockIdx.z * stride1 + max_displacement + kernel_rad; + const int c = threadIdx.x; + + const int K = kernel_size * kernel_size * Cin; + + typedef cub::WarpReduce WarpReduce; + __shared__ typename WarpReduce::TempStorage temp_sum_storage; + float thread_accumulation = 0; + + for (int tj = -displacement_rad; tj <= displacement_rad; ++tj) { + for (int ti = -displacement_rad; ti <= displacement_rad; ++ti) { + thread_accumulation = 0; + int w2 = w1 + ti * stride2; + int h2 = h1 + tj * stride2; + + for (int j = -kernel_rad; j <= kernel_rad; ++j) { + for (int i = -kernel_rad; i <= kernel_rad; ++i) { + for (int ch = c; ch < Cin; ch += THREADS_PER_BLOCK) { + const int indx1 = n * (pHin * pWin * Cin) + + (h1 + j) * (pWin * Cin) + (w1 + i) * Cin + ch; + const int indx2 = n * (pHin * pWin * Cin) + + (h2 + j) * (pWin * Cin) + (w2 + i) * Cin + ch; + thread_accumulation += pInput1[indx1] * pInput2[indx2]; + } + } + } + __syncthreads(); + + // THREADS_PER_BLOCK==32, hence there is only one warp per block + const float reduce_sum = + 
WarpReduce(temp_sum_storage).Sum(thread_accumulation); + if (c == 0) { + const int tc = (tj + displacement_rad) * displacement_size + + (ti + displacement_rad); + const int tindx = n * (Cout * Hout * Wout) + tc * (Hout * Wout) + + blockIdx.y * Wout + blockIdx.z; + output[tindx] = reduce_sum / K; + } + } + } +} + +template +__global__ void Correlation_backward_input1( + int item, float *gradInput1, int Cin, int Hin, int Win, + const float *gradOutput, int Cout, int Hout, int Wout, const float *rInput2, + int pad_size, int kernel_size, int max_displacement, int stride1, + int stride2, bool is_NCHW) { + const int n = item; + const int h = blockIdx.x * stride1 + pad_size; + const int w = blockIdx.y * stride1 + pad_size; + const int c = blockIdx.z; + const int t0 = threadIdx.x; + + const int kernel_rad = (kernel_size - 1) / 2; + const int displacement_rad = max_displacement / stride2; + const int displacement_size = 2 * displacement_rad + 1; + + int Wmin = (w - kernel_rad - max_displacement) / stride1; + int Hmin = (h - kernel_rad - max_displacement) / stride1; + + int Wmax = (w + kernel_rad - max_displacement) / stride1; + int Hmax = (h + kernel_rad - max_displacement) / stride1; + + if (Wmax < 0 || Hmax < 0 || Wmin >= Wout || Hmin >= Hout) { + // assumes gradInput1 is pre-allocated and zero filled + return; + } + + if (Wmin > Wmax || Hmin > Hmax) { + // assumes gradInput1 is pre-allocated and zero filled + return; + } + + Wmin = max(0, Wmin); + Wmax = min(Wout - 1, Wmax); + + Hmin = max(0, Hmin); + Hmax = min(Hout - 1, Hmax); + + const int pWin = Win + 2 * pad_size; + const int pHin = Hin + 2 * pad_size; + const float nelems = kernel_size * kernel_size * Cin; + + typedef cub::WarpReduce WarpReduce; + __shared__ typename WarpReduce::TempStorage temp_sum_storage; + float thread_accumulation = 0; + + for (int tc = t0; tc < Cout; tc += THREADS_PER_BLOCK) { + int i2 = (tc % displacement_size - displacement_rad) * stride2; + int j2 = (tc / displacement_size - displacement_rad) * stride2; + + int indx2 = + n * (pHin * pWin * Cin) + (h + j2) * (pWin * Cin) + (w + i2) * Cin + c; + + float val2 = rInput2[indx2]; + + for (int j = Hmin; j <= Hmax; ++j) { + for (int i = Wmin; i <= Wmax; ++i) { + const int tindx = + n * (Cout * Hout * Wout) + tc * (Hout * Wout) + j * Wout + i; + thread_accumulation += gradOutput[tindx] * val2; + } + } + } + __syncthreads(); + + // THREADS_PER_BLOCK==32, hence there is only one warp per block + const float reduce_sum = + WarpReduce(temp_sum_storage).Sum(thread_accumulation); + if (t0 == 0) { + if (is_NCHW) { + const int indx1 = n * (Cin * Hin * Win) + c * (Hin * Win) + + (h - pad_size) * Win + (w - pad_size); + gradInput1[indx1] = reduce_sum / nelems; + } else { + const int indx1 = n * (Cin * Hin * Win) + (h - pad_size) * (Win * Cin) + + (w - pad_size) * Cin + c; + gradInput1[indx1] = reduce_sum / nelems; + } + } +} + +template +__global__ void Correlation_backward_input2( + int item, float *gradInput2, int Cin, int Hin, int Win, + const float *gradOutput, int Cout, int Hout, int Wout, const float *rInput1, + int pad_size, int kernel_size, int max_displacement, int stride1, + int stride2, bool is_NCHW) { + const int n = item; + const int h = blockIdx.x * stride1 + pad_size; + const int w = blockIdx.y * stride1 + pad_size; + const int c = blockIdx.z; + const int t0 = threadIdx.x; + + const int kernel_rad = (kernel_size - 1) / 2; + const int displacement_rad = max_displacement / stride2; + const int displacement_size = 2 * displacement_rad + 1; + + const int pWin = Win + 2 * 
pad_size; + const int pHin = Hin + 2 * pad_size; + const float nelems = kernel_size * kernel_size * Cin; + + typedef cub::WarpReduce WarpReduce; + __shared__ typename WarpReduce::TempStorage temp_sum_storage; + float thread_accumulation = 0; + + for (int tc = t0; tc < Cout; tc += THREADS_PER_BLOCK) { + const int i2 = (tc % displacement_size - displacement_rad) * stride2; + const int j2 = (tc / displacement_size - displacement_rad) * stride2; + + int Wmin = (w - kernel_rad - max_displacement - i2) / stride1; + int Hmin = (h - kernel_rad - max_displacement - j2) / stride1; + + int Wmax = (w + kernel_rad - max_displacement - i2) / stride1; + int Hmax = (h + kernel_rad - max_displacement - j2) / stride1; + + if (Wmax < 0 || Hmax < 0 || Wmin >= Wout || Hmin >= Hout) { + // assumes gradInput2 is pre-allocated and zero filled + continue; + } + + if (Wmin > Wmax || Hmin > Hmax) { + // assumes gradInput2 is pre-allocated and zero filled + continue; + } + + Wmin = max(0, Wmin); + Wmax = min(Wout - 1, Wmax); + + Hmin = max(0, Hmin); + Hmax = min(Hout - 1, Hmax); + + const int indx1 = + n * (pHin * pWin * Cin) + (h - j2) * (pWin * Cin) + (w - i2) * Cin + c; + const float val1 = rInput1[indx1]; + + for (int j = Hmin; j <= Hmax; ++j) { + for (int i = Wmin; i <= Wmax; ++i) { + const int tindx = + n * (Cout * Hout * Wout) + tc * (Hout * Wout) + j * Wout + i; + thread_accumulation += gradOutput[tindx] * val1; + } + } + } + __syncthreads(); + + const float reduce_sum = + WarpReduce(temp_sum_storage).Sum(thread_accumulation); + if (t0 == 0) { + if (is_NCHW) { + const int indx2 = n * (Cin * Hin * Win) + c * (Hin * Win) + + (h - pad_size) * (Win) + (w - pad_size); + gradInput2[indx2] = reduce_sum / nelems; + } else { + const int indx2 = n * (Cin * Hin * Win) + (h - pad_size) * (Win * Cin) + + (w - pad_size) * Cin + c; + gradInput2[indx2] = reduce_sum / nelems; + } + } +} + +}; // namespace + +template +struct CorrelationCostFunctor { + Status operator()(OpKernelContext *context, const Tensor &input_a_t, + const Tensor &input_b_t, Tensor *output_t, + /* params */ + int kernel_size, int max_displacement, int stride_1, + int stride_2, int pad, TensorFormat data_format) { + // do not change: the CUDA kernels expects THREADS_PER_BLOCK==32 + const int THREADS_PER_BLOCK = 32; + + const int32 N = GetTensorDim(input_a_t, data_format, 'N'); + const int32 iC = GetTensorDim(input_a_t, data_format, 'C'); + const int32 iH = GetTensorDim(input_a_t, data_format, 'H'); + const int32 iW = GetTensorDim(input_a_t, data_format, 'W'); + + Tensor padded_a_t; + Tensor padded_b_t; + TensorShape padded_shape({N, iH + 2 * pad, iW + 2 * pad, iC}); + Status s; + s = context->allocate_temp(DataTypeToEnum::value, padded_shape, + &padded_a_t); + if (!TF_PREDICT_TRUE(s.ok())) { + return s; + } + s = context->allocate_temp(DataTypeToEnum::value, padded_shape, + &padded_b_t); + if (!TF_PREDICT_TRUE(s.ok())) { + return s; + } + + dim3 blocks_grid(N, iH, iW); + dim3 threads_block(THREADS_PER_BLOCK); + + // the output is always NCHW (python transposes it to NHWC) + const int32 oC = GetTensorDim(*output_t, FORMAT_NCHW, 'C'); + const int32 oH = GetTensorDim(*output_t, FORMAT_NCHW, 'H'); + const int32 oW = GetTensorDim(*output_t, FORMAT_NCHW, 'W'); + + // set everything to zero (we zero-pad) + cudaMemset(padded_a_t.flat().data(), 0, + padded_a_t.NumElements() * sizeof(Dtype)); + cudaMemset(padded_b_t.flat().data(), 0, + padded_b_t.NumElements() * sizeof(Dtype)); + cudaMemset(output_t->flat().data(), 0, + output_t->NumElements() * sizeof(Dtype)); + + 
const bool is_NCHW = (data_format == FORMAT_NCHW); + if (is_NCHW) { + pad_and_transpose<<>>( + input_a_t.flat().data(), padded_a_t.flat().data(), iC, + iH, iW, pad); + pad_and_transpose<<>>( + input_b_t.flat().data(), padded_b_t.flat().data(), iC, + iH, iW, pad); + } else { + pad_and_no_transpose<<>>( + input_a_t.flat().data(), padded_a_t.flat().data(), iC, + iH, iW, pad); + pad_and_no_transpose<<>>( + input_b_t.flat().data(), padded_b_t.flat().data(), iC, + iH, iW, pad); + } + + const GPUDevice &d = context->eigen_gpu_device(); + + dim3 threadsPerBlock(THREADS_PER_BLOCK); + dim3 totalBlocksCorr(N, oH, oW); + + Correlation_forward + <<>>( + output_t->flat().data(), oC, oH, oW, + padded_a_t.flat().data(), iC, iH, iW, + padded_b_t.flat().data(), pad, kernel_size, max_displacement, + stride_1, stride_2); + + return Status::OK(); + } +}; + +template +struct CorrelationCostGradFunctor { + Status operator()(OpKernelContext *context, const Tensor &input_a_t, + const Tensor &input_b_t, const Tensor &topdiff_t, + Tensor *output_a_gradient_t, Tensor *output_b_gradient_t, + /* params */ + int kernel_size, int max_displacement, int stride_1, + int stride_2, int pad, TensorFormat data_format) { + // do not change: the CUDA kernels expects THREADS_PER_BLOCK==32 + const int THREADS_PER_BLOCK = 32; + + const int32 N = GetTensorDim(input_a_t, data_format, 'N'); + const int32 iC = GetTensorDim(input_a_t, data_format, 'C'); + const int32 iH = GetTensorDim(input_a_t, data_format, 'H'); + const int32 iW = GetTensorDim(input_a_t, data_format, 'W'); + + Tensor padded_a_t; + Tensor padded_b_t; + TensorShape padded_shape({N, iH + 2 * pad, iW + 2 * pad, iC}); + Status s; + s = context->allocate_temp(DataTypeToEnum::value, padded_shape, + &padded_a_t); + if (!TF_PREDICT_TRUE(s.ok())) { + return s; + } + s = context->allocate_temp(DataTypeToEnum::value, padded_shape, + &padded_b_t); + if (!TF_PREDICT_TRUE(s.ok())) { + return s; + } + + dim3 blocks_grid(N, iH, iW); + dim3 threads_block(THREADS_PER_BLOCK); + + // topdiff is NCHW + const int32 oC = GetTensorDim(topdiff_t, FORMAT_NCHW, 'C'); + const int32 oH = GetTensorDim(topdiff_t, FORMAT_NCHW, 'H'); + const int32 oW = GetTensorDim(topdiff_t, FORMAT_NCHW, 'W'); + + // set everything to zero (we zero-pad) + cudaMemset(padded_a_t.flat().data(), 0, + padded_a_t.NumElements() * sizeof(Dtype)); + cudaMemset(padded_b_t.flat().data(), 0, + padded_b_t.NumElements() * sizeof(Dtype)); + cudaMemset(output_a_gradient_t->flat().data(), 0, + output_a_gradient_t->NumElements() * sizeof(Dtype)); + cudaMemset(output_b_gradient_t->flat().data(), 0, + output_b_gradient_t->NumElements() * sizeof(Dtype)); + + const bool is_NCHW = (data_format == FORMAT_NCHW); + if (is_NCHW) { + pad_and_transpose<<>>( + input_a_t.flat().data(), padded_a_t.flat().data(), iC, + iH, iW, pad); + pad_and_transpose<<>>( + input_b_t.flat().data(), padded_b_t.flat().data(), iC, + iH, iW, pad); + } else { + pad_and_no_transpose<<>>( + input_a_t.flat().data(), padded_a_t.flat().data(), iC, + iH, iW, pad); + pad_and_no_transpose<<>>( + input_b_t.flat().data(), padded_b_t.flat().data(), iC, + iH, iW, pad); + } + + const GPUDevice &d = context->eigen_gpu_device(); + + dim3 threadsPerBlock(THREADS_PER_BLOCK); + dim3 totalBlocksCorr(iH, iW, iC); + + for (int n = 0; n < N; ++n) { + Correlation_backward_input1 + <<>>( + n, output_a_gradient_t->flat().data(), iC, iH, iW, + topdiff_t.flat().data(), oC, oH, oW, + padded_b_t.flat().data(), pad, kernel_size, + max_displacement, stride_1, stride_2, is_NCHW); + } + + for (int n = 
0; n < N; n++) { + Correlation_backward_input2 + <<>>( + n, output_b_gradient_t->flat().data(), iC, iH, iW, + topdiff_t.flat().data(), oC, oH, oW, + padded_a_t.flat().data(), pad, kernel_size, + max_displacement, stride_1, stride_2, is_NCHW); + } + + return Status::OK(); + } +}; + +template struct CorrelationCostFunctor; +template struct CorrelationCostGradFunctor; + +} // end namespace functor +} // namespace tensorflow + +#endif // GOOGLE_CUDA diff --git a/tensorflow_addons/custom_ops/opticalflow/cc/ops/correlation_cost_op.cc b/tensorflow_addons/custom_ops/opticalflow/cc/ops/correlation_cost_op.cc new file mode 100644 index 0000000000..9fb152354d --- /dev/null +++ b/tensorflow_addons/custom_ops/opticalflow/cc/ops/correlation_cost_op.cc @@ -0,0 +1,133 @@ +// Copyright 2018 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/shape_inference.h" + +namespace tensorflow { + +using ::tensorflow::shape_inference::InferenceContext; +using ::tensorflow::shape_inference::ShapeHandle; + +// -------------------------------------------------------------------------- + +REGISTER_OP("CorrelationCost") + .Input("input_a: T") + .Input("input_b: T") + .Output("output: T") + .Attr("kernel_size: int") + .Attr("max_displacement: int") + .Attr("stride_1: int") + .Attr("stride_2: int") + .Attr("pad: int") + .Attr("data_format: {'NHWC', 'NCHW'} = 'NHWC'") + .Attr("T: realnumbertype") + .SetShapeFn([](InferenceContext* c) { + ShapeHandle input_a, input_b, input; + + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input_a)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 4, &input_b)); + TF_RETURN_IF_ERROR(c->Merge(input_a, input_b, &input)); + + // get input shapes + int32 B, H, W; + B = c->Value(c->Dim(input, 0)); + string data_format; + Status s = c->GetAttr("data_format", &data_format); + if (s.ok() && data_format == "NCHW") { + H = c->Value(c->Dim(input, 2)); + W = c->Value(c->Dim(input, 3)); + } else { + H = c->Value(c->Dim(input, 1)); + W = c->Value(c->Dim(input, 2)); + } + + int32 kernel_size; + int32 max_displacement; + int32 stride_1; + int32 stride_2; + int32 pad; + + TF_RETURN_IF_ERROR(c->GetAttr("kernel_size", &kernel_size)); + TF_RETURN_IF_ERROR(c->GetAttr("max_displacement", &max_displacement)); + // stride in input + TF_RETURN_IF_ERROR(c->GetAttr("stride_1", &stride_1)); + // stride in patch + TF_RETURN_IF_ERROR(c->GetAttr("stride_2", &stride_2)); + TF_RETURN_IF_ERROR(c->GetAttr("pad", &pad)); + + // output channels are d**2 where, d = 2r + 1 + const int32 r = max_displacement / stride_2; + const int32 d = 2 * r + 1; + const int32 border = max_displacement + (kernel_size - 1) / 2; + + const int32 Cout = d * d; + // for spatial dimensions, we pad the inputs + const int32 Hout = static_cast( + ceil(static_cast(((H + 2 * pad) - border * 2)) / + static_cast(stride_1))); + const int32 Wout = 
static_cast( + ceil(static_cast(((W + 2 * pad) - border * 2)) / + static_cast(stride_1))); + + // Note, the output is always NCHW (even when input is NHWC) + c->set_output(0, c->MakeShape({B, Cout, Hout, Wout})); + return Status::OK(); + }) + .Doc(R"Doc( +Compute Correlation costs. + +This layer implements the correlation operation from +FlowNet: Learning Optical Flow with Convolutional Networks (Fischer et al.) + +input_a: A `Tensor` of the format specified by `data_format`. +input_b: A `Tensor` of the format specified by `data_format`. +kernel_size: An integer specifying the height and width of the + patch used to compute the per-patch costs. +max_displacement: An integer specifying the maximum search radius + for each position. +stride_1: An integer specifying the stride length in the input. +stride_2: An integer specifying the stride length in the patch. +pad: An integer specifying the paddings in height and width. +data_format: Specifies the data format. + Possible values are: + "NHWC" float [batch, height, width, channels] + "NCHW" float [batch, channels, height, width] + Defaults to `"NHWC"`. +)Doc"); + +REGISTER_OP("CorrelationCostGrad") + .Input("orig_input_a: T") + .Input("orig_input_b: T") + .Input("top_diff: T") + .Output("bottom_diff_a: T") + .Output("bottom_diff_b: T") + .Attr("T: realnumbertype") + .Attr("kernel_size: int") + .Attr("max_displacement: int") + .Attr("stride_1: int") + .Attr("stride_2: int") + .Attr("pad: int") + .Attr("data_format: {'NHWC', 'NCHW'} = 'NHWC'") + .SetShapeFn([](InferenceContext* c) { + ShapeHandle shp_hnd; + TF_RETURN_IF_ERROR(c->Merge(c->input(0), c->input(1), &shp_hnd)); + c->set_output(0, shp_hnd); + c->set_output(1, shp_hnd); + return Status::OK(); + }) + .Doc(R"doc(CorrelationCostGrad op.)doc"); + +} // namespace tensorflow diff --git a/tensorflow_addons/custom_ops/opticalflow/cc/python/__init__.py b/tensorflow_addons/custom_ops/opticalflow/cc/python/__init__.py new file mode 100644 index 0000000000..c7d6a8e997 --- /dev/null +++ b/tensorflow_addons/custom_ops/opticalflow/cc/python/__init__.py @@ -0,0 +1,19 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""ops module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function diff --git a/tensorflow_addons/custom_ops/opticalflow/cc/python/kernel_tests/correlation_cost_op_test.py b/tensorflow_addons/custom_ops/opticalflow/cc/python/kernel_tests/correlation_cost_op_test.py new file mode 100644 index 0000000000..e8b84b7000 --- /dev/null +++ b/tensorflow_addons/custom_ops/opticalflow/cc/python/kernel_tests/correlation_cost_op_test.py @@ -0,0 +1,228 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + + +from tensorflow.contrib.correlation_cost.python.ops import correlation_cost_op +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops +from tensorflow.python.framework import ops +from tensorflow.python.platform import test +from tensorflow.python.framework import constant_op + + +class CorrelationCostTest(test.TestCase): + + def _forward(self, input_a, input_b, + kernel_size, + max_displacement, + stride_1, + stride_2, + pad, + data_format, + use_gpu=False): + with self.test_session(use_gpu=use_gpu, force_gpu=use_gpu) as sess: + + input_a_op = ops.convert_to_tensor(input_a, dtype=dtypes.float32) + input_b_op = ops.convert_to_tensor(input_b, dtype=dtypes.float32) + + kernel_size = 1 + max_displacement = 2 + stride_1 = 1 + stride_2 = 2 + pad = 4 + + call_op = correlation_cost_op.correlation_cost + actual_op = call_op(input_a_op, input_b_op, + kernel_size=kernel_size, + max_displacement=max_displacement, + stride_1=stride_1, + stride_2=stride_2, + pad=pad, + data_format=data_format) + + return sess.run(actual_op) + + def _forward_both(self, data_format='NCHW'): + val = [[[[0, -6, 9, 5], + [1, -5, 10, 3], + [2, -4, 11, 1]], + [[3, -3, 12, -1], + [4, -2, 13, -3], + [5, -1, 14, -5]]], + [[[6, 0, 15, -7], + [7, 1, 16, -9], + [8, 2, 17, -11]], + [[9, 3, 18, -13], + [10, 4, 19, -15], + [11, 5, 20, -17]]]] + + input_a = constant_op.constant(np.array(val), dtype=dtypes.float32) + valb = np.array(val).transpose(2, 3, 0, 1).reshape(2, 2, 3, 4) + input_b = constant_op.constant(valb, dtype=dtypes.float32) + + kernel_size = 1 + max_displacement = 2 + stride_1 = 1 + stride_2 = 2 + pad = 4 + + if data_format == 'NHWC': + input_a = array_ops.transpose(input_a, [0, 2, 3, 1]) + input_b = array_ops.transpose(input_b, [0, 2, 3, 1]) + + actual_cpu = self._forward(input_a, input_b, + kernel_size=kernel_size, + max_displacement=max_displacement, + stride_1=stride_1, + stride_2=stride_2, + pad=pad, + data_format=data_format, + use_gpu=False) + + actual_gpu = self._forward(input_a, input_b, + kernel_size=kernel_size, + max_displacement=max_displacement, + stride_1=stride_1, + stride_2=stride_2, + pad=pad, + data_format=data_format, + use_gpu=True) + + self.assertEqual(actual_cpu.shape, actual_gpu.shape) + self.assertAllClose(actual_cpu, actual_gpu) + + def _forward_simple(self, data_format='NCHW', use_gpu=False): + # cumbersome calculation by hand for a fixed input + # we just test where zeros occurs and a few entries + val = [[[[0, -6, 9, 5], + [1, -5, 10, 3], + [2, -4, 11, 1]], + [[3, -3, 12, -1], + [4, -2, 13, -3], + [5, -1, 14, -5]]], + [[[6, 0, 15, -7], + [7, 1, 16, -9], + [8, 2, 17, -11]], + [[9, 3, 18, -13], + [10, 4, 19, -15], + [11, 5, 20, -17]]]] + + input_a = constant_op.constant(np.array(val), dtype=dtypes.float32) + valb = np.array(val).transpose(2, 3, 0, 1).reshape(2, 2, 3, 4) + input_b = constant_op.constant(valb, dtype=dtypes.float32) + 
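+    # Hand-derived sanity check (not computed by the test itself): with the
+    # parameters below (kernel_size=1, max_displacement=2, stride_1=1,
+    # stride_2=2, pad=4) and an input of height 3 and width 4, the shape
+    # function in ops/correlation_cost_op.cc gives
+    #   r = max_displacement / stride_2 = 1, so Cout = (2 * r + 1)**2 = 9
+    #   border = max_displacement + (kernel_size - 1) / 2 = 2
+    #   Hout = ceil((3 + 2 * 4 - 2 * 2) / 1) = 7
+    #   Wout = ceil((4 + 2 * 4 - 2 * 2) / 1) = 8
+    # which matches the (2, 9, 7, 8) shape assertion at the end of this test.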
+ if data_format == 'NHWC': + input_a = array_ops.transpose(input_a, [0, 2, 3, 1]) + input_b = array_ops.transpose(input_b, [0, 2, 3, 1]) + + input_a_op = ops.convert_to_tensor(input_a, dtype=dtypes.float32) + input_b_op = ops.convert_to_tensor(input_b, dtype=dtypes.float32) + + kernel_size = 1 + max_displacement = 2 + stride_1 = 1 + stride_2 = 2 + pad = 4 + + actual = self._forward(input_a_op, input_b_op, + kernel_size=kernel_size, + max_displacement=max_displacement, + stride_1=stride_1, + stride_2=stride_2, + pad=pad, + data_format=data_format, + use_gpu=use_gpu) + + if data_format == 'NHWC': + # NHWC -> NCHW + actual = actual.transpose(0, 3, 1, 2) + + # we just need to test fixed ids, as output is NCHW independently from data_format + expected_ids = np.concatenate([np.zeros(464,), np.ones(464,)]) + self.assertAllClose(np.where(actual == 0)[0], expected_ids) + + counts = [54, 52, 54, 50, 44, 50, 54, 52, 54] + expected_ids = np.concatenate([k * np.ones(v,) + for k, v in enumerate(counts)]) + expected_ids = np.concatenate([expected_ids, expected_ids]) + self.assertAllClose(np.where(actual == 0)[1], expected_ids) + self.assertEqual(actual.shape, (2, 9, 7, 8)) + + def _gradients(self, data_format='NCHW', use_gpu=False): + + batch, channels, height, width = 2, 3, 5, 6 + input_a = np.random.randn(batch, channels, height, width) + input_b = np.random.randn(batch, channels, height, width) + + kernel_size = 1 + max_displacement = 2 + stride_1 = 1 + stride_2 = 2 + pad = 4 + + if data_format == 'NHWC': + input_a = input_a.transpose(0, 2, 3, 1) + input_b = input_b.transpose(0, 2, 3, 1) + + with self.test_session(use_gpu=use_gpu, force_gpu=use_gpu): + + input_a_op = ops.convert_to_tensor(input_a, dtype=dtypes.float32) + input_b_op = ops.convert_to_tensor(input_b, dtype=dtypes.float32) + + call_op = correlation_cost_op.correlation_cost + actual_op = call_op(input_a_op, input_b_op, + kernel_size=kernel_size, + max_displacement=max_displacement, + stride_1=stride_1, + stride_2=stride_2, + pad=pad, + data_format=data_format) + + err_a = test.compute_gradient_error( + [input_a_op, input_b_op], + [input_a.shape, input_b.shape], + actual_op, actual_op.shape.as_list()) + + self.assertLess(err_a, 1e-4) + + def testForwardNCHW(self): + self._forward_simple(data_format='NCHW', use_gpu=False) + self._forward_simple(data_format='NCHW', use_gpu=True) + + def testForwardNHWC(self): + self._forward_simple(data_format='NHWC', use_gpu=False) + self._forward_simple(data_format='NHWC', use_gpu=True) + + def testForwardSame(self): + self._forward_both(data_format='NCHW') + self._forward_both(data_format='NHWC') + + def testBackwardNCHW(self): + self._gradients(data_format='NCHW', use_gpu=False) + self._gradients(data_format='NCHW', use_gpu=True) + + def testBackwardNHWC(self): + self._gradients(data_format='NHWC', use_gpu=False) + self._gradients(data_format='NHWC', use_gpu=True) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow_addons/custom_ops/opticalflow/cc/python/ops/correlation_cost_op.py b/tensorflow_addons/custom_ops/opticalflow/cc/python/ops/correlation_cost_op.py new file mode 100644 index 0000000000..ef70756738 --- /dev/null +++ b/tensorflow_addons/custom_ops/opticalflow/cc/python/ops/correlation_cost_op.py @@ -0,0 +1,126 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Tensorflow op performing correlation cost operation.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.correlation_cost.ops import gen_correlation_cost_op +from tensorflow.contrib.util import loader +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import resource_loader + +_correlation_cost_op_so = loader.load_op_library( + resource_loader.get_path_to_datafile("_correlation_cost_op.so")) + +# pylint: disable=redefined-builtin + + +def correlation_cost(input_a, + input_b, + kernel_size, + max_displacement, + stride_1, + stride_2, + pad, + data_format='NHWC', + name=None): + """Correlation Cost Volume computation. + + Computes a cost volume using correlation for two inputs. For feature + maps A, B with spatial dimensions w, h, c it computes + + output(a, b) = sum_{l in [-k,k]**2} < I(a+l), J(b+l) > + + where the patches of size K=2d + 1 are centered in position a resp. b. + + The output shape is [B, C', H', W'], where + + r = max_displacement / stride_2; + bd = max_displacement + (kernel_size - 1) / 2 + C' = (2 * r + 1) ** 2 + H' = H + 2 * (pad - bd) / stride_1 + W' = W + 2 * (pad - bd) / stride_1 + + Note: When the data_format requests "NHWC", an additional explicit + transpose operation is executed. + + Args: + input_a: A `Tensor` of the format specified by `data_format`. + input_b: A `Tensor` of the format specified by `data_format`. + kernel_size: An integer specifying the height and width of the + patch used to compute the per-patch costs. + max_displacement: An integer specifying the maximum search radius + for each position. + stride_1: An integer specifying the stride length in the input. + stride_2: An integer specifying the stride length in the patch. + pad: An integer specifying the paddings in height and width. + data_format: Specifies the data format. + Possible values are: + "NHWC" float [batch, height, width, channels] + "NCHW" float [batch, channels, height, width] + Defaults to `"NHWC"`. + name: A name for the operation (optional). + + Returns: + A `Tensor` of the format specified by `data_format`. 
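+
+  For example (a minimal usage sketch; the shapes and parameter values
+  here are illustrative, not prescriptive):
+
+    input_a = tf.random_normal([2, 64, 64, 3])  # NHWC
+    input_b = tf.random_normal([2, 64, 64, 3])
+    out = correlation_cost(input_a, input_b, kernel_size=1,
+                           max_displacement=2, stride_1=1, stride_2=2,
+                           pad=2, data_format='NHWC')
+    # r = 2 / 2 = 1, so C' = (2 * 1 + 1)**2 = 9, and with pad == bd the
+    # spatial size is preserved, so `out` has shape [2, 64, 64, 9].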
+ """ + + with ops.name_scope(name, "correlation_cost"): + op_call = gen_correlation_cost_op.correlation_cost + ret = op_call(input_a, input_b, + kernel_size=kernel_size, + max_displacement=max_displacement, + stride_1=stride_1, + stride_2=stride_2, + pad=pad, + data_format=data_format) + if data_format == 'NHWC': + # this is easier to maintain without + # specializing an additional cuda kernel + return array_ops.transpose(ret, [0, 2, 3, 1]) + return ret + + +correlation_cost_grad = gen_correlation_cost_op.correlation_cost_grad + + +@ops.RegisterGradient("CorrelationCost") +def _correlation_cost_grad(op, grad_output): + kernel_size = op.get_attr("kernel_size") + max_displacement = op.get_attr("max_displacement") + stride_1 = op.get_attr("stride_1") + stride_2 = op.get_attr("stride_2") + pad = op.get_attr("pad") + data_format = op.get_attr("data_format") + + input_a = ops.convert_to_tensor(op.inputs[0], name="input_a") + input_b = ops.convert_to_tensor(op.inputs[1], name="input_b") + grad_output_tensor = ops.convert_to_tensor(grad_output, name="grad_output") + + op_call = gen_correlation_cost_op.correlation_cost_grad + grads = op_call(input_a, input_b, grad_output_tensor, + kernel_size=kernel_size, + max_displacement=max_displacement, + stride_1=stride_1, + stride_2=stride_2, + pad=pad, + data_format=data_format) + + grad_input_a = ops.convert_to_tensor(grads[0], name="grad_input_a") + grad_input_b = ops.convert_to_tensor(grads[1], name="grad_input_b") + return [grad_input_a, grad_input_b] From 1096a45bcf8278737f515035773ceee50d91fb74 Mon Sep 17 00:00:00 2001 From: PatWie Date: Thu, 25 Apr 2019 15:33:23 +0200 Subject: [PATCH 02/18] move files around --- tensorflow_addons/opticalflow/BUILD | 29 +++++++++++++++++++ .../cc/python => opticalflow}/__init__.py | 2 ++ .../correlation_cost.py} | 0 .../correlation_cost_test.py} | 0 4 files changed, 31 insertions(+) create mode 100644 tensorflow_addons/opticalflow/BUILD rename tensorflow_addons/{custom_ops/opticalflow/cc/python => opticalflow}/__init__.py (91%) rename tensorflow_addons/{custom_ops/opticalflow/cc/python/ops/correlation_cost_op.py => opticalflow/correlation_cost.py} (100%) rename tensorflow_addons/{custom_ops/opticalflow/cc/python/kernel_tests/correlation_cost_op_test.py => opticalflow/correlation_cost_test.py} (100%) diff --git a/tensorflow_addons/opticalflow/BUILD b/tensorflow_addons/opticalflow/BUILD new file mode 100644 index 0000000000..7709dafb91 --- /dev/null +++ b/tensorflow_addons/opticalflow/BUILD @@ -0,0 +1,29 @@ +licenses(["notice"]) # Apache 2.0 + +package(default_visibility = ["//visibility:public"]) + +py_library( + name = "opticalflow", + srcs = ([ + "__init__.py", + "correlation_cost.py", + ]), + data = [ + "//tensorflow_addons/custom_ops/opticalflow:_correlation_cost_ops.so", + "//tensorflow_addons/utils", + ], + srcs_version = "PY2AND3", +) + +py_test( + name = "correlation_cost_test", + size = "small", + srcs = [ + "correlation_cost_test.py", + ], + main = "correlation_cost_test.py", + srcs_version = "PY2AND3", + deps = [ + ":opticalflow", + ], +) diff --git a/tensorflow_addons/custom_ops/opticalflow/cc/python/__init__.py b/tensorflow_addons/opticalflow/__init__.py similarity index 91% rename from tensorflow_addons/custom_ops/opticalflow/cc/python/__init__.py rename to tensorflow_addons/opticalflow/__init__.py index c7d6a8e997..d6565f7a6c 100644 --- a/tensorflow_addons/custom_ops/opticalflow/cc/python/__init__.py +++ b/tensorflow_addons/opticalflow/__init__.py @@ -17,3 +17,5 @@ from __future__ import absolute_import 
from __future__ import division from __future__ import print_function + +from tensorflow_addons.opticalflow.correlation_cost import correlation_cost diff --git a/tensorflow_addons/custom_ops/opticalflow/cc/python/ops/correlation_cost_op.py b/tensorflow_addons/opticalflow/correlation_cost.py similarity index 100% rename from tensorflow_addons/custom_ops/opticalflow/cc/python/ops/correlation_cost_op.py rename to tensorflow_addons/opticalflow/correlation_cost.py diff --git a/tensorflow_addons/custom_ops/opticalflow/cc/python/kernel_tests/correlation_cost_op_test.py b/tensorflow_addons/opticalflow/correlation_cost_test.py similarity index 100% rename from tensorflow_addons/custom_ops/opticalflow/cc/python/kernel_tests/correlation_cost_op_test.py rename to tensorflow_addons/opticalflow/correlation_cost_test.py From 3ee6f601a32329e4c34c4d13e58d56ad1e21335a Mon Sep 17 00:00:00 2001 From: PatWie Date: Thu, 2 May 2019 22:14:14 +0200 Subject: [PATCH 03/18] make code-format --- .../custom_ops/opticalflow/BUILD | 4 +- .../cc/kernels/correlation_cost_op.cc | 19 +- .../cc/kernels/correlation_cost_op_gpu.cu.cc | 36 +- .../opticalflow/correlation_cost.py | 161 ++++---- .../opticalflow/correlation_cost_test.py | 384 +++++++++--------- 5 files changed, 301 insertions(+), 303 deletions(-) diff --git a/tensorflow_addons/custom_ops/opticalflow/BUILD b/tensorflow_addons/custom_ops/opticalflow/BUILD index fa6e708b67..37f46b7be3 100644 --- a/tensorflow_addons/custom_ops/opticalflow/BUILD +++ b/tensorflow_addons/custom_ops/opticalflow/BUILD @@ -19,7 +19,7 @@ cc_binary( deps = [ "@local_config_tf//:libtensorflow_framework", "@local_config_tf//:tf_header_lib", - ]+ if_cuda([ - "@cub_archive//:cub" + ] + if_cuda([ + "@cub_archive//:cub", ]), ) diff --git a/tensorflow_addons/custom_ops/opticalflow/cc/kernels/correlation_cost_op.cc b/tensorflow_addons/custom_ops/opticalflow/cc/kernels/correlation_cost_op.cc index 6bfaf38828..e4da834c8a 100644 --- a/tensorflow_addons/custom_ops/opticalflow/cc/kernels/correlation_cost_op.cc +++ b/tensorflow_addons/custom_ops/opticalflow/cc/kernels/correlation_cost_op.cc @@ -241,12 +241,11 @@ class CorrelationCostOp : public OpKernel { context, context->allocate_output(0, TensorShape({N, Cout, Hout, Wout}), &output_t)); - functor::CorrelationCostFunctor correlationCostFunc; - Status s = correlationCostFunc( - context, input_a_t, input_b_t, output_t, - /* params */ - kernel_size, max_displacement, stride_1, stride_2, pad, data_format_); + Status s = correlationCostFunc(context, input_a_t, input_b_t, output_t, + /* params */ + kernel_size, max_displacement, stride_1, + stride_2, pad, data_format_); OP_REQUIRES_OK(context, s); } @@ -296,11 +295,11 @@ class CorrelationCostGradOp : public OpKernel { &output_b_gradient_t)); functor::CorrelationCostGradFunctor correlationCostGrad; - Status s = correlationCostGrad( - context, input_a_t, input_b_t, topdiff_t, - output_a_gradient_t, output_b_gradient_t, - /* params */ - kernel_size, max_displacement, stride_1, stride_2, pad, data_format_); + Status s = correlationCostGrad(context, input_a_t, input_b_t, topdiff_t, + output_a_gradient_t, output_b_gradient_t, + /* params */ + kernel_size, max_displacement, stride_1, + stride_2, pad, data_format_); OP_REQUIRES_OK(context, s); } diff --git a/tensorflow_addons/custom_ops/opticalflow/cc/kernels/correlation_cost_op_gpu.cu.cc b/tensorflow_addons/custom_ops/opticalflow/cc/kernels/correlation_cost_op_gpu.cu.cc index 51451dfb06..effd205ff5 100644 --- 
a/tensorflow_addons/custom_ops/opticalflow/cc/kernels/correlation_cost_op_gpu.cu.cc +++ b/tensorflow_addons/custom_ops/opticalflow/cc/kernels/correlation_cost_op_gpu.cu.cc @@ -364,12 +364,12 @@ struct CorrelationCostFunctor { dim3 threadsPerBlock(THREADS_PER_BLOCK); dim3 totalBlocksCorr(N, oH, oW); - Correlation_forward - <<>>( - output_t->flat().data(), oC, oH, oW, - padded_a_t.flat().data(), iC, iH, iW, - padded_b_t.flat().data(), pad, kernel_size, max_displacement, - stride_1, stride_2); + Correlation_forward< + THREADS_PER_BLOCK><<>>( + output_t->flat().data(), oC, oH, oW, + padded_a_t.flat().data(), iC, iH, iW, + padded_b_t.flat().data(), pad, kernel_size, max_displacement, + stride_1, stride_2); return Status::OK(); } @@ -447,21 +447,21 @@ struct CorrelationCostGradFunctor { dim3 totalBlocksCorr(iH, iW, iC); for (int n = 0; n < N; ++n) { - Correlation_backward_input1 - <<>>( - n, output_a_gradient_t->flat().data(), iC, iH, iW, - topdiff_t.flat().data(), oC, oH, oW, - padded_b_t.flat().data(), pad, kernel_size, - max_displacement, stride_1, stride_2, is_NCHW); + Correlation_backward_input1< + THREADS_PER_BLOCK><<>>( + n, output_a_gradient_t->flat().data(), iC, iH, iW, + topdiff_t.flat().data(), oC, oH, oW, + padded_b_t.flat().data(), pad, kernel_size, max_displacement, + stride_1, stride_2, is_NCHW); } for (int n = 0; n < N; n++) { - Correlation_backward_input2 - <<>>( - n, output_b_gradient_t->flat().data(), iC, iH, iW, - topdiff_t.flat().data(), oC, oH, oW, - padded_a_t.flat().data(), pad, kernel_size, - max_displacement, stride_1, stride_2, is_NCHW); + Correlation_backward_input2< + THREADS_PER_BLOCK><<>>( + n, output_b_gradient_t->flat().data(), iC, iH, iW, + topdiff_t.flat().data(), oC, oH, oW, + padded_a_t.flat().data(), pad, kernel_size, max_displacement, + stride_1, stride_2, is_NCHW); } return Status::OK(); diff --git a/tensorflow_addons/opticalflow/correlation_cost.py b/tensorflow_addons/opticalflow/correlation_cost.py index ef70756738..4b2e790165 100644 --- a/tensorflow_addons/opticalflow/correlation_cost.py +++ b/tensorflow_addons/opticalflow/correlation_cost.py @@ -39,61 +39,63 @@ def correlation_cost(input_a, pad, data_format='NHWC', name=None): - """Correlation Cost Volume computation. - - Computes a cost volume using correlation for two inputs. For feature - maps A, B with spatial dimensions w, h, c it computes - - output(a, b) = sum_{l in [-k,k]**2} < I(a+l), J(b+l) > - - where the patches of size K=2d + 1 are centered in position a resp. b. - - The output shape is [B, C', H', W'], where - - r = max_displacement / stride_2; - bd = max_displacement + (kernel_size - 1) / 2 - C' = (2 * r + 1) ** 2 - H' = H + 2 * (pad - bd) / stride_1 - W' = W + 2 * (pad - bd) / stride_1 - - Note: When the data_format requests "NHWC", an additional explicit - transpose operation is executed. - - Args: - input_a: A `Tensor` of the format specified by `data_format`. - input_b: A `Tensor` of the format specified by `data_format`. - kernel_size: An integer specifying the height and width of the - patch used to compute the per-patch costs. - max_displacement: An integer specifying the maximum search radius - for each position. - stride_1: An integer specifying the stride length in the input. - stride_2: An integer specifying the stride length in the patch. - pad: An integer specifying the paddings in height and width. - data_format: Specifies the data format. 
- Possible values are: - "NHWC" float [batch, height, width, channels] - "NCHW" float [batch, channels, height, width] - Defaults to `"NHWC"`. - name: A name for the operation (optional). - - Returns: - A `Tensor` of the format specified by `data_format`. - """ - - with ops.name_scope(name, "correlation_cost"): - op_call = gen_correlation_cost_op.correlation_cost - ret = op_call(input_a, input_b, - kernel_size=kernel_size, - max_displacement=max_displacement, - stride_1=stride_1, - stride_2=stride_2, - pad=pad, - data_format=data_format) - if data_format == 'NHWC': - # this is easier to maintain without - # specializing an additional cuda kernel - return array_ops.transpose(ret, [0, 2, 3, 1]) - return ret + """Correlation Cost Volume computation. + + Computes a cost volume using correlation for two inputs. For feature + maps A, B with spatial dimensions w, h, c it computes + + output(a, b) = sum_{l in [-k,k]**2} < I(a+l), J(b+l) > + + where the patches of size K=2d + 1 are centered in position a resp. b. + + The output shape is [B, C', H', W'], where + + r = max_displacement / stride_2; + bd = max_displacement + (kernel_size - 1) / 2 + C' = (2 * r + 1) ** 2 + H' = H + 2 * (pad - bd) / stride_1 + W' = W + 2 * (pad - bd) / stride_1 + + Note: When the data_format requests "NHWC", an additional explicit + transpose operation is executed. + + Args: + input_a: A `Tensor` of the format specified by `data_format`. + input_b: A `Tensor` of the format specified by `data_format`. + kernel_size: An integer specifying the height and width of the + patch used to compute the per-patch costs. + max_displacement: An integer specifying the maximum search radius + for each position. + stride_1: An integer specifying the stride length in the input. + stride_2: An integer specifying the stride length in the patch. + pad: An integer specifying the paddings in height and width. + data_format: Specifies the data format. + Possible values are: + "NHWC" float [batch, height, width, channels] + "NCHW" float [batch, channels, height, width] + Defaults to `"NHWC"`. + name: A name for the operation (optional). + + Returns: + A `Tensor` of the format specified by `data_format`. 
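As a quick check of the shape formulas in the docstring above, a short self-contained sketch (the helper name is ours, not part of this patch; integer division matches the op's behavior), using the same parameters as the unit tests further down:

    def corr_output_shape(height, width, kernel_size=1, max_displacement=2,
                          stride_1=1, stride_2=2, pad=4):
        # r, bd, C', H', W' exactly as defined in the docstring
        r = max_displacement // stride_2
        bd = max_displacement + (kernel_size - 1) // 2
        out_c = (2 * r + 1) ** 2
        out_h = height + 2 * (pad - bd) // stride_1
        out_w = width + 2 * (pad - bd) // stride_1
        return out_c, out_h, out_w

    # the 3x4 test inputs give the (2, 9, 7, 8) output asserted below
    assert corr_output_shape(3, 4) == (9, 7, 8)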
+ """ + + with ops.name_scope(name, "correlation_cost"): + op_call = gen_correlation_cost_op.correlation_cost + ret = op_call( + input_a, + input_b, + kernel_size=kernel_size, + max_displacement=max_displacement, + stride_1=stride_1, + stride_2=stride_2, + pad=pad, + data_format=data_format) + if data_format == 'NHWC': + # this is easier to maintain without + # specializing an additional cuda kernel + return array_ops.transpose(ret, [0, 2, 3, 1]) + return ret correlation_cost_grad = gen_correlation_cost_op.correlation_cost_grad @@ -101,26 +103,29 @@ def correlation_cost(input_a, @ops.RegisterGradient("CorrelationCost") def _correlation_cost_grad(op, grad_output): - kernel_size = op.get_attr("kernel_size") - max_displacement = op.get_attr("max_displacement") - stride_1 = op.get_attr("stride_1") - stride_2 = op.get_attr("stride_2") - pad = op.get_attr("pad") - data_format = op.get_attr("data_format") - - input_a = ops.convert_to_tensor(op.inputs[0], name="input_a") - input_b = ops.convert_to_tensor(op.inputs[1], name="input_b") - grad_output_tensor = ops.convert_to_tensor(grad_output, name="grad_output") - - op_call = gen_correlation_cost_op.correlation_cost_grad - grads = op_call(input_a, input_b, grad_output_tensor, - kernel_size=kernel_size, - max_displacement=max_displacement, - stride_1=stride_1, - stride_2=stride_2, - pad=pad, - data_format=data_format) - - grad_input_a = ops.convert_to_tensor(grads[0], name="grad_input_a") - grad_input_b = ops.convert_to_tensor(grads[1], name="grad_input_b") - return [grad_input_a, grad_input_b] + kernel_size = op.get_attr("kernel_size") + max_displacement = op.get_attr("max_displacement") + stride_1 = op.get_attr("stride_1") + stride_2 = op.get_attr("stride_2") + pad = op.get_attr("pad") + data_format = op.get_attr("data_format") + + input_a = ops.convert_to_tensor(op.inputs[0], name="input_a") + input_b = ops.convert_to_tensor(op.inputs[1], name="input_b") + grad_output_tensor = ops.convert_to_tensor(grad_output, name="grad_output") + + op_call = gen_correlation_cost_op.correlation_cost_grad + grads = op_call( + input_a, + input_b, + grad_output_tensor, + kernel_size=kernel_size, + max_displacement=max_displacement, + stride_1=stride_1, + stride_2=stride_2, + pad=pad, + data_format=data_format) + + grad_input_a = ops.convert_to_tensor(grads[0], name="grad_input_a") + grad_input_b = ops.convert_to_tensor(grads[1], name="grad_input_b") + return [grad_input_a, grad_input_b] diff --git a/tensorflow_addons/opticalflow/correlation_cost_test.py b/tensorflow_addons/opticalflow/correlation_cost_test.py index e8b84b7000..dce2d4a9fc 100644 --- a/tensorflow_addons/opticalflow/correlation_cost_test.py +++ b/tensorflow_addons/opticalflow/correlation_cost_test.py @@ -19,7 +19,6 @@ import numpy as np - from tensorflow.contrib.correlation_cost.python.ops import correlation_cost_op from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops @@ -29,200 +28,195 @@ class CorrelationCostTest(test.TestCase): - - def _forward(self, input_a, input_b, - kernel_size, - max_displacement, - stride_1, - stride_2, - pad, - data_format, - use_gpu=False): - with self.test_session(use_gpu=use_gpu, force_gpu=use_gpu) as sess: - - input_a_op = ops.convert_to_tensor(input_a, dtype=dtypes.float32) - input_b_op = ops.convert_to_tensor(input_b, dtype=dtypes.float32) - - kernel_size = 1 - max_displacement = 2 - stride_1 = 1 - stride_2 = 2 - pad = 4 - - call_op = correlation_cost_op.correlation_cost - actual_op = call_op(input_a_op, input_b_op, - 
kernel_size=kernel_size, - max_displacement=max_displacement, - stride_1=stride_1, - stride_2=stride_2, - pad=pad, - data_format=data_format) - - return sess.run(actual_op) - - def _forward_both(self, data_format='NCHW'): - val = [[[[0, -6, 9, 5], - [1, -5, 10, 3], - [2, -4, 11, 1]], - [[3, -3, 12, -1], - [4, -2, 13, -3], - [5, -1, 14, -5]]], - [[[6, 0, 15, -7], - [7, 1, 16, -9], - [8, 2, 17, -11]], - [[9, 3, 18, -13], - [10, 4, 19, -15], - [11, 5, 20, -17]]]] - - input_a = constant_op.constant(np.array(val), dtype=dtypes.float32) - valb = np.array(val).transpose(2, 3, 0, 1).reshape(2, 2, 3, 4) - input_b = constant_op.constant(valb, dtype=dtypes.float32) - - kernel_size = 1 - max_displacement = 2 - stride_1 = 1 - stride_2 = 2 - pad = 4 - - if data_format == 'NHWC': - input_a = array_ops.transpose(input_a, [0, 2, 3, 1]) - input_b = array_ops.transpose(input_b, [0, 2, 3, 1]) - - actual_cpu = self._forward(input_a, input_b, - kernel_size=kernel_size, - max_displacement=max_displacement, - stride_1=stride_1, - stride_2=stride_2, - pad=pad, - data_format=data_format, - use_gpu=False) - - actual_gpu = self._forward(input_a, input_b, - kernel_size=kernel_size, - max_displacement=max_displacement, - stride_1=stride_1, - stride_2=stride_2, - pad=pad, - data_format=data_format, - use_gpu=True) - - self.assertEqual(actual_cpu.shape, actual_gpu.shape) - self.assertAllClose(actual_cpu, actual_gpu) - - def _forward_simple(self, data_format='NCHW', use_gpu=False): - # cumbersome calculation by hand for a fixed input - # we just test where zeros occurs and a few entries - val = [[[[0, -6, 9, 5], - [1, -5, 10, 3], - [2, -4, 11, 1]], - [[3, -3, 12, -1], - [4, -2, 13, -3], - [5, -1, 14, -5]]], - [[[6, 0, 15, -7], - [7, 1, 16, -9], - [8, 2, 17, -11]], - [[9, 3, 18, -13], - [10, 4, 19, -15], - [11, 5, 20, -17]]]] - - input_a = constant_op.constant(np.array(val), dtype=dtypes.float32) - valb = np.array(val).transpose(2, 3, 0, 1).reshape(2, 2, 3, 4) - input_b = constant_op.constant(valb, dtype=dtypes.float32) - - if data_format == 'NHWC': - input_a = array_ops.transpose(input_a, [0, 2, 3, 1]) - input_b = array_ops.transpose(input_b, [0, 2, 3, 1]) - - input_a_op = ops.convert_to_tensor(input_a, dtype=dtypes.float32) - input_b_op = ops.convert_to_tensor(input_b, dtype=dtypes.float32) - - kernel_size = 1 - max_displacement = 2 - stride_1 = 1 - stride_2 = 2 - pad = 4 - - actual = self._forward(input_a_op, input_b_op, - kernel_size=kernel_size, - max_displacement=max_displacement, - stride_1=stride_1, - stride_2=stride_2, - pad=pad, - data_format=data_format, - use_gpu=use_gpu) - - if data_format == 'NHWC': - # NHWC -> NCHW - actual = actual.transpose(0, 3, 1, 2) - - # we just need to test fixed ids, as output is NCHW independently from data_format - expected_ids = np.concatenate([np.zeros(464,), np.ones(464,)]) - self.assertAllClose(np.where(actual == 0)[0], expected_ids) - - counts = [54, 52, 54, 50, 44, 50, 54, 52, 54] - expected_ids = np.concatenate([k * np.ones(v,) - for k, v in enumerate(counts)]) - expected_ids = np.concatenate([expected_ids, expected_ids]) - self.assertAllClose(np.where(actual == 0)[1], expected_ids) - self.assertEqual(actual.shape, (2, 9, 7, 8)) - - def _gradients(self, data_format='NCHW', use_gpu=False): - - batch, channels, height, width = 2, 3, 5, 6 - input_a = np.random.randn(batch, channels, height, width) - input_b = np.random.randn(batch, channels, height, width) - - kernel_size = 1 - max_displacement = 2 - stride_1 = 1 - stride_2 = 2 - pad = 4 - - if data_format == 'NHWC': - input_a 
= input_a.transpose(0, 2, 3, 1) - input_b = input_b.transpose(0, 2, 3, 1) - - with self.test_session(use_gpu=use_gpu, force_gpu=use_gpu): - - input_a_op = ops.convert_to_tensor(input_a, dtype=dtypes.float32) - input_b_op = ops.convert_to_tensor(input_b, dtype=dtypes.float32) - - call_op = correlation_cost_op.correlation_cost - actual_op = call_op(input_a_op, input_b_op, - kernel_size=kernel_size, - max_displacement=max_displacement, - stride_1=stride_1, - stride_2=stride_2, - pad=pad, - data_format=data_format) - - err_a = test.compute_gradient_error( - [input_a_op, input_b_op], - [input_a.shape, input_b.shape], - actual_op, actual_op.shape.as_list()) - - self.assertLess(err_a, 1e-4) - - def testForwardNCHW(self): - self._forward_simple(data_format='NCHW', use_gpu=False) - self._forward_simple(data_format='NCHW', use_gpu=True) - - def testForwardNHWC(self): - self._forward_simple(data_format='NHWC', use_gpu=False) - self._forward_simple(data_format='NHWC', use_gpu=True) - - def testForwardSame(self): - self._forward_both(data_format='NCHW') - self._forward_both(data_format='NHWC') - - def testBackwardNCHW(self): - self._gradients(data_format='NCHW', use_gpu=False) - self._gradients(data_format='NCHW', use_gpu=True) - - def testBackwardNHWC(self): - self._gradients(data_format='NHWC', use_gpu=False) - self._gradients(data_format='NHWC', use_gpu=True) + def _forward(self, + input_a, + input_b, + kernel_size, + max_displacement, + stride_1, + stride_2, + pad, + data_format, + use_gpu=False): + with self.test_session(use_gpu=use_gpu, force_gpu=use_gpu) as sess: + + input_a_op = ops.convert_to_tensor(input_a, dtype=dtypes.float32) + input_b_op = ops.convert_to_tensor(input_b, dtype=dtypes.float32) + + kernel_size = 1 + max_displacement = 2 + stride_1 = 1 + stride_2 = 2 + pad = 4 + + call_op = correlation_cost_op.correlation_cost + actual_op = call_op( + input_a_op, + input_b_op, + kernel_size=kernel_size, + max_displacement=max_displacement, + stride_1=stride_1, + stride_2=stride_2, + pad=pad, + data_format=data_format) + + return sess.run(actual_op) + + def _forward_both(self, data_format='NCHW'): + val = [[[[0, -6, 9, 5], [1, -5, 10, 3], [2, -4, 11, 1]], + [[3, -3, 12, -1], [4, -2, 13, -3], [5, -1, 14, -5]]], + [[[6, 0, 15, -7], [7, 1, 16, -9], [8, 2, 17, -11]], + [[9, 3, 18, -13], [10, 4, 19, -15], [11, 5, 20, -17]]]] + + input_a = constant_op.constant(np.array(val), dtype=dtypes.float32) + valb = np.array(val).transpose(2, 3, 0, 1).reshape(2, 2, 3, 4) + input_b = constant_op.constant(valb, dtype=dtypes.float32) + + kernel_size = 1 + max_displacement = 2 + stride_1 = 1 + stride_2 = 2 + pad = 4 + + if data_format == 'NHWC': + input_a = array_ops.transpose(input_a, [0, 2, 3, 1]) + input_b = array_ops.transpose(input_b, [0, 2, 3, 1]) + + actual_cpu = self._forward( + input_a, + input_b, + kernel_size=kernel_size, + max_displacement=max_displacement, + stride_1=stride_1, + stride_2=stride_2, + pad=pad, + data_format=data_format, + use_gpu=False) + + actual_gpu = self._forward( + input_a, + input_b, + kernel_size=kernel_size, + max_displacement=max_displacement, + stride_1=stride_1, + stride_2=stride_2, + pad=pad, + data_format=data_format, + use_gpu=True) + + self.assertEqual(actual_cpu.shape, actual_gpu.shape) + self.assertAllClose(actual_cpu, actual_gpu) + + def _forward_simple(self, data_format='NCHW', use_gpu=False): + # cumbersome calculation by hand for a fixed input + # we just test where zeros occurs and a few entries + val = [[[[0, -6, 9, 5], [1, -5, 10, 3], [2, -4, 11, 1]], + [[3, -3, 
12, -1], [4, -2, 13, -3], [5, -1, 14, -5]]], + [[[6, 0, 15, -7], [7, 1, 16, -9], [8, 2, 17, -11]], + [[9, 3, 18, -13], [10, 4, 19, -15], [11, 5, 20, -17]]]] + + input_a = constant_op.constant(np.array(val), dtype=dtypes.float32) + valb = np.array(val).transpose(2, 3, 0, 1).reshape(2, 2, 3, 4) + input_b = constant_op.constant(valb, dtype=dtypes.float32) + + if data_format == 'NHWC': + input_a = array_ops.transpose(input_a, [0, 2, 3, 1]) + input_b = array_ops.transpose(input_b, [0, 2, 3, 1]) + + input_a_op = ops.convert_to_tensor(input_a, dtype=dtypes.float32) + input_b_op = ops.convert_to_tensor(input_b, dtype=dtypes.float32) + + kernel_size = 1 + max_displacement = 2 + stride_1 = 1 + stride_2 = 2 + pad = 4 + + actual = self._forward( + input_a_op, + input_b_op, + kernel_size=kernel_size, + max_displacement=max_displacement, + stride_1=stride_1, + stride_2=stride_2, + pad=pad, + data_format=data_format, + use_gpu=use_gpu) + + if data_format == 'NHWC': + # NHWC -> NCHW + actual = actual.transpose(0, 3, 1, 2) + + # we just need to test fixed ids, as output is NCHW independently from data_format + expected_ids = np.concatenate([np.zeros(464,), np.ones(464,)]) + self.assertAllClose(np.where(actual == 0)[0], expected_ids) + + counts = [54, 52, 54, 50, 44, 50, 54, 52, 54] + expected_ids = np.concatenate( + [k * np.ones(v,) for k, v in enumerate(counts)]) + expected_ids = np.concatenate([expected_ids, expected_ids]) + self.assertAllClose(np.where(actual == 0)[1], expected_ids) + self.assertEqual(actual.shape, (2, 9, 7, 8)) + + def _gradients(self, data_format='NCHW', use_gpu=False): + + batch, channels, height, width = 2, 3, 5, 6 + input_a = np.random.randn(batch, channels, height, width) + input_b = np.random.randn(batch, channels, height, width) + + kernel_size = 1 + max_displacement = 2 + stride_1 = 1 + stride_2 = 2 + pad = 4 + + if data_format == 'NHWC': + input_a = input_a.transpose(0, 2, 3, 1) + input_b = input_b.transpose(0, 2, 3, 1) + + with self.test_session(use_gpu=use_gpu, force_gpu=use_gpu): + + input_a_op = ops.convert_to_tensor(input_a, dtype=dtypes.float32) + input_b_op = ops.convert_to_tensor(input_b, dtype=dtypes.float32) + + call_op = correlation_cost_op.correlation_cost + actual_op = call_op( + input_a_op, + input_b_op, + kernel_size=kernel_size, + max_displacement=max_displacement, + stride_1=stride_1, + stride_2=stride_2, + pad=pad, + data_format=data_format) + + err_a = test.compute_gradient_error([input_a_op, input_b_op], + [input_a.shape, input_b.shape], + actual_op, + actual_op.shape.as_list()) + + self.assertLess(err_a, 1e-4) + + def testForwardNCHW(self): + self._forward_simple(data_format='NCHW', use_gpu=False) + self._forward_simple(data_format='NCHW', use_gpu=True) + + def testForwardNHWC(self): + self._forward_simple(data_format='NHWC', use_gpu=False) + self._forward_simple(data_format='NHWC', use_gpu=True) + + def testForwardSame(self): + self._forward_both(data_format='NCHW') + self._forward_both(data_format='NHWC') + + def testBackwardNCHW(self): + self._gradients(data_format='NCHW', use_gpu=False) + self._gradients(data_format='NCHW', use_gpu=True) + + def testBackwardNHWC(self): + self._gradients(data_format='NHWC', use_gpu=False) + self._gradients(data_format='NHWC', use_gpu=True) if __name__ == "__main__": - test.main() + test.main() From 3f1958340ad986626ba82c4191f1157d1f51f28c Mon Sep 17 00:00:00 2001 From: PatWie Date: Thu, 2 May 2019 22:17:07 +0200 Subject: [PATCH 04/18] make sanity-check pass --- tensorflow_addons/custom_ops/opticalflow/BUILD | 9 
+++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/tensorflow_addons/custom_ops/opticalflow/BUILD b/tensorflow_addons/custom_ops/opticalflow/BUILD index 37f46b7be3..6cb1fcebf1 100644 --- a/tensorflow_addons/custom_ops/opticalflow/BUILD +++ b/tensorflow_addons/custom_ops/opticalflow/BUILD @@ -3,7 +3,7 @@ licenses(["notice"]) # Apache 2.0 package(default_visibility = ["//visibility:public"]) cc_binary( - name = "_distort_image_ops.so", + name = "_correlation_cost_ops.so", srcs = [ "cc/kernels/correlation_cost_op.cc", "cc/kernels/correlation_cost_op.h", @@ -19,7 +19,8 @@ cc_binary( deps = [ "@local_config_tf//:libtensorflow_framework", "@local_config_tf//:tf_header_lib", - ] + if_cuda([ - "@cub_archive//:cub", - ]), + ], + # + if_cuda([ + # "@cub_archive//:cub", + # ]), ) From 5c044469033fc265257ad7f149a113c7f0817a84 Mon Sep 17 00:00:00 2001 From: PatWie Date: Thu, 2 May 2019 22:39:05 +0200 Subject: [PATCH 05/18] correct imports --- .../opticalflow/cc/kernels/correlation_cost_op.cc | 7 ++++++- .../cc/kernels/correlation_cost_op_gpu.cu.cc | 3 +-- .../opticalflow/cc/ops/correlation_cost_op.cc | 2 +- tensorflow_addons/opticalflow/correlation_cost.py | 15 +++++++-------- .../opticalflow/correlation_cost_test.py | 6 +++--- 5 files changed, 18 insertions(+), 15 deletions(-) diff --git a/tensorflow_addons/custom_ops/opticalflow/cc/kernels/correlation_cost_op.cc b/tensorflow_addons/custom_ops/opticalflow/cc/kernels/correlation_cost_op.cc index e4da834c8a..73ddc92aa1 100644 --- a/tensorflow_addons/custom_ops/opticalflow/cc/kernels/correlation_cost_op.cc +++ b/tensorflow_addons/custom_ops/opticalflow/cc/kernels/correlation_cost_op.cc @@ -13,8 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/correlation_cost/kernels/correlation_cost_op.h" +#define EIGEN_USE_THREADS +#if GOOGLE_CUDA +#define EIGEN_USE_GPU +#endif // GOOGLE_CUDA + +#include "tensorflow_addons/custom_ops/opticalflow/cc/kernels/correlation_cost_op.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/types.h" diff --git a/tensorflow_addons/custom_ops/opticalflow/cc/kernels/correlation_cost_op_gpu.cu.cc b/tensorflow_addons/custom_ops/opticalflow/cc/kernels/correlation_cost_op_gpu.cu.cc index effd205ff5..d9087046e8 100644 --- a/tensorflow_addons/custom_ops/opticalflow/cc/kernels/correlation_cost_op_gpu.cu.cc +++ b/tensorflow_addons/custom_ops/opticalflow/cc/kernels/correlation_cost_op_gpu.cu.cc @@ -17,8 +17,7 @@ limitations under the License. #define EIGEN_USE_GPU -#include "tensorflow/contrib/correlation_cost/kernels/correlation_cost_op.h" - +#include "tensorflow_addons/custom_ops/opticalflow/cc/kernels/correlation_cost_op.h" #include "external/cub_archive/cub/device/device_reduce.cuh" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" diff --git a/tensorflow_addons/custom_ops/opticalflow/cc/ops/correlation_cost_op.cc b/tensorflow_addons/custom_ops/opticalflow/cc/ops/correlation_cost_op.cc index 9fb152354d..61fd9d7f0f 100644 --- a/tensorflow_addons/custom_ops/opticalflow/cc/ops/correlation_cost_op.cc +++ b/tensorflow_addons/custom_ops/opticalflow/cc/ops/correlation_cost_op.cc @@ -90,7 +90,7 @@ REGISTER_OP("CorrelationCost") Compute Correlation costs. 
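For intuition about what the registered op computes, a minimal NumPy sketch (our own illustration, not part of the patch): a single cost-volume entry is the inner product <I(a+l), J(b+l)> accumulated over one K x K x C patch pair, which the kernels above evaluate for every position and displacement:

    import numpy as np

    patch_a = np.random.randn(3, 3, 8)  # K x K x C window around position a
    patch_b = np.random.randn(3, 3, 8)  # K x K x C window around position b
    cost_entry = float(np.sum(patch_a * patch_b))  # one output element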
This layer implements the correlation operation from -FlowNet: Learning Optical Flow with Convolutional Networks (Fischer et al.) +FlowNet Learning Optical Flow with Convolutional Networks (Fischer et al.) input_a: A `Tensor` of the format specified by `data_format`. input_b: A `Tensor` of the format specified by `data_format`. diff --git a/tensorflow_addons/opticalflow/correlation_cost.py b/tensorflow_addons/opticalflow/correlation_cost.py index 4b2e790165..398023db5d 100644 --- a/tensorflow_addons/opticalflow/correlation_cost.py +++ b/tensorflow_addons/opticalflow/correlation_cost.py @@ -18,14 +18,13 @@ from __future__ import division from __future__ import print_function -from tensorflow.contrib.correlation_cost.ops import gen_correlation_cost_op -from tensorflow.contrib.util import loader +import tensorflow as tf from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops -from tensorflow.python.platform import resource_loader +from tensorflow_addons.utils.resource_loader import get_path_to_datafile -_correlation_cost_op_so = loader.load_op_library( - resource_loader.get_path_to_datafile("_correlation_cost_op.so")) +_correlation_cost_op_so = tf.load_op_library( + get_path_to_datafile("custom_ops/opticalflow/_correlation_cost_ops.so")) # pylint: disable=redefined-builtin @@ -81,7 +80,7 @@ def correlation_cost(input_a, """ with ops.name_scope(name, "correlation_cost"): - op_call = gen_correlation_cost_op.correlation_cost + op_call = _correlation_cost_op_so.correlation_cost ret = op_call( input_a, input_b, @@ -98,7 +97,7 @@ def correlation_cost(input_a, return ret -correlation_cost_grad = gen_correlation_cost_op.correlation_cost_grad +correlation_cost_grad = _correlation_cost_op_so.correlation_cost_grad @ops.RegisterGradient("CorrelationCost") @@ -114,7 +113,7 @@ def _correlation_cost_grad(op, grad_output): input_b = ops.convert_to_tensor(op.inputs[1], name="input_b") grad_output_tensor = ops.convert_to_tensor(grad_output, name="grad_output") - op_call = gen_correlation_cost_op.correlation_cost_grad + op_call = _correlation_cost_op_so.correlation_cost_grad grads = op_call( input_a, input_b, diff --git a/tensorflow_addons/opticalflow/correlation_cost_test.py b/tensorflow_addons/opticalflow/correlation_cost_test.py index dce2d4a9fc..5b075cd135 100644 --- a/tensorflow_addons/opticalflow/correlation_cost_test.py +++ b/tensorflow_addons/opticalflow/correlation_cost_test.py @@ -19,7 +19,7 @@ import numpy as np -from tensorflow.contrib.correlation_cost.python.ops import correlation_cost_op +from tensorflow_addons.opticalflow import correlation_cost from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops from tensorflow.python.framework import ops @@ -49,7 +49,7 @@ def _forward(self, stride_2 = 2 pad = 4 - call_op = correlation_cost_op.correlation_cost + call_op = correlation_cost actual_op = call_op( input_a_op, input_b_op, @@ -179,7 +179,7 @@ def _gradients(self, data_format='NCHW', use_gpu=False): input_a_op = ops.convert_to_tensor(input_a, dtype=dtypes.float32) input_b_op = ops.convert_to_tensor(input_b, dtype=dtypes.float32) - call_op = correlation_cost_op.correlation_cost + call_op = correlation_cost actual_op = call_op( input_a_op, input_b_op, From 1a21354d0f57dacddf9f3272cebe26ecedc27e79 Mon Sep 17 00:00:00 2001 From: PatWie Date: Thu, 2 May 2019 22:49:10 +0200 Subject: [PATCH 06/18] blast the eager-exec-crap away --- tensorflow_addons/opticalflow/correlation_cost_test.py | 2 ++ 1 file changed, 2 insertions(+) diff --git 
a/tensorflow_addons/opticalflow/correlation_cost_test.py b/tensorflow_addons/opticalflow/correlation_cost_test.py index 5b075cd135..24ac372292 100644 --- a/tensorflow_addons/opticalflow/correlation_cost_test.py +++ b/tensorflow_addons/opticalflow/correlation_cost_test.py @@ -25,6 +25,8 @@ from tensorflow.python.framework import ops from tensorflow.python.platform import test from tensorflow.python.framework import constant_op +import tensorflow as tf +tf.compat.v1.disable_eager_execution() class CorrelationCostTest(test.TestCase): From a8da252da1cf117caba7a82200b6cb4bbc1840ab Mon Sep 17 00:00:00 2001 From: Sean Morgan Date: Thu, 25 Jul 2019 12:21:13 -0400 Subject: [PATCH 07/18] Remove private API calls --- .../opticalflow/correlation_cost.py | 18 +++++----- .../opticalflow/correlation_cost_test.py | 36 +++++++++---------- 2 files changed, 24 insertions(+), 30 deletions(-) diff --git a/tensorflow_addons/opticalflow/correlation_cost.py b/tensorflow_addons/opticalflow/correlation_cost.py index 398023db5d..ef801590bc 100644 --- a/tensorflow_addons/opticalflow/correlation_cost.py +++ b/tensorflow_addons/opticalflow/correlation_cost.py @@ -19,8 +19,6 @@ from __future__ import print_function import tensorflow as tf -from tensorflow.python.framework import ops -from tensorflow.python.ops import array_ops from tensorflow_addons.utils.resource_loader import get_path_to_datafile _correlation_cost_op_so = tf.load_op_library( @@ -79,7 +77,7 @@ def correlation_cost(input_a, A `Tensor` of the format specified by `data_format`. """ - with ops.name_scope(name, "correlation_cost"): + with tf.name_scope(name or "correlation_cost"): op_call = _correlation_cost_op_so.correlation_cost ret = op_call( input_a, @@ -93,14 +91,14 @@ def correlation_cost(input_a, if data_format == 'NHWC': # this is easier to maintain without # specializing an additional cuda kernel - return array_ops.transpose(ret, [0, 2, 3, 1]) + return tf.transpose(ret, [0, 2, 3, 1]) return ret correlation_cost_grad = _correlation_cost_op_so.correlation_cost_grad -@ops.RegisterGradient("CorrelationCost") +@tf.RegisterGradient("CorrelationCost") def _correlation_cost_grad(op, grad_output): kernel_size = op.get_attr("kernel_size") max_displacement = op.get_attr("max_displacement") @@ -109,9 +107,9 @@ def _correlation_cost_grad(op, grad_output): pad = op.get_attr("pad") data_format = op.get_attr("data_format") - input_a = ops.convert_to_tensor(op.inputs[0], name="input_a") - input_b = ops.convert_to_tensor(op.inputs[1], name="input_b") - grad_output_tensor = ops.convert_to_tensor(grad_output, name="grad_output") + input_a = tf.convert_to_tensor(op.inputs[0], name="input_a") + input_b = tf.convert_to_tensor(op.inputs[1], name="input_b") + grad_output_tensor = tf.convert_to_tensor(grad_output, name="grad_output") op_call = _correlation_cost_op_so.correlation_cost_grad grads = op_call( @@ -125,6 +123,6 @@ def _correlation_cost_grad(op, grad_output): pad=pad, data_format=data_format) - grad_input_a = ops.convert_to_tensor(grads[0], name="grad_input_a") - grad_input_b = ops.convert_to_tensor(grads[1], name="grad_input_b") + grad_input_a = tf.convert_to_tensor(grads[0], name="grad_input_a") + grad_input_b = tf.convert_to_tensor(grads[1], name="grad_input_b") return [grad_input_a, grad_input_b] diff --git a/tensorflow_addons/opticalflow/correlation_cost_test.py b/tensorflow_addons/opticalflow/correlation_cost_test.py index 24ac372292..83b609843c 100644 --- a/tensorflow_addons/opticalflow/correlation_cost_test.py +++ 
b/tensorflow_addons/opticalflow/correlation_cost_test.py @@ -20,16 +20,12 @@ import numpy as np from tensorflow_addons.opticalflow import correlation_cost -from tensorflow.python.framework import dtypes -from tensorflow.python.ops import array_ops -from tensorflow.python.framework import ops from tensorflow.python.platform import test -from tensorflow.python.framework import constant_op import tensorflow as tf tf.compat.v1.disable_eager_execution() -class CorrelationCostTest(test.TestCase): +class CorrelationCostTest(tf.test.TestCase): def _forward(self, input_a, input_b, @@ -42,8 +38,8 @@ def _forward(self, use_gpu=False): with self.test_session(use_gpu=use_gpu, force_gpu=use_gpu) as sess: - input_a_op = ops.convert_to_tensor(input_a, dtype=dtypes.float32) - input_b_op = ops.convert_to_tensor(input_b, dtype=dtypes.float32) + input_a_op = tf.convert_to_tensor(input_a, dtype=tf.float32) + input_b_op = tf.convert_to_tensor(input_b, dtype=tf.float32) kernel_size = 1 max_displacement = 2 @@ -70,9 +66,9 @@ def _forward_both(self, data_format='NCHW'): [[[6, 0, 15, -7], [7, 1, 16, -9], [8, 2, 17, -11]], [[9, 3, 18, -13], [10, 4, 19, -15], [11, 5, 20, -17]]]] - input_a = constant_op.constant(np.array(val), dtype=dtypes.float32) + input_a = tf.constant(np.array(val), dtype=tf.float32) valb = np.array(val).transpose(2, 3, 0, 1).reshape(2, 2, 3, 4) - input_b = constant_op.constant(valb, dtype=dtypes.float32) + input_b = tf.constant(valb, dtype=tf.float32) kernel_size = 1 max_displacement = 2 @@ -81,8 +77,8 @@ def _forward_both(self, data_format='NCHW'): pad = 4 if data_format == 'NHWC': - input_a = array_ops.transpose(input_a, [0, 2, 3, 1]) - input_b = array_ops.transpose(input_b, [0, 2, 3, 1]) + input_a = tf.transpose(input_a, [0, 2, 3, 1]) + input_b = tf.transpose(input_b, [0, 2, 3, 1]) actual_cpu = self._forward( input_a, @@ -117,16 +113,16 @@ def _forward_simple(self, data_format='NCHW', use_gpu=False): [[[6, 0, 15, -7], [7, 1, 16, -9], [8, 2, 17, -11]], [[9, 3, 18, -13], [10, 4, 19, -15], [11, 5, 20, -17]]]] - input_a = constant_op.constant(np.array(val), dtype=dtypes.float32) + input_a = tf.constant(np.array(val), dtype=tf.float32) valb = np.array(val).transpose(2, 3, 0, 1).reshape(2, 2, 3, 4) - input_b = constant_op.constant(valb, dtype=dtypes.float32) + input_b = tf.constant(valb, dtype=tf.float32) if data_format == 'NHWC': - input_a = array_ops.transpose(input_a, [0, 2, 3, 1]) - input_b = array_ops.transpose(input_b, [0, 2, 3, 1]) + input_a = tf.transpose(input_a, [0, 2, 3, 1]) + input_b = tf.transpose(input_b, [0, 2, 3, 1]) - input_a_op = ops.convert_to_tensor(input_a, dtype=dtypes.float32) - input_b_op = ops.convert_to_tensor(input_b, dtype=dtypes.float32) + input_a_op = tf.convert_to_tensor(input_a, dtype=tf.float32) + input_b_op = tf.convert_to_tensor(input_b, dtype=tf.float32) kernel_size = 1 max_displacement = 2 @@ -178,8 +174,8 @@ def _gradients(self, data_format='NCHW', use_gpu=False): with self.test_session(use_gpu=use_gpu, force_gpu=use_gpu): - input_a_op = ops.convert_to_tensor(input_a, dtype=dtypes.float32) - input_b_op = ops.convert_to_tensor(input_b, dtype=dtypes.float32) + input_a_op = tf.convert_to_tensor(input_a, dtype=tf.float32) + input_b_op = tf.convert_to_tensor(input_b, dtype=tf.float32) call_op = correlation_cost actual_op = call_op( @@ -221,4 +217,4 @@ def testBackwardNHWC(self): if __name__ == "__main__": - test.main() + tf.test.main() From 433725685eeaa06dcbca763b65a5c77abc29bbda Mon Sep 17 00:00:00 2001 From: Sean Morgan Date: Thu, 25 Jul 2019 12:25:03 -0400 
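One detail of the public-API migration above worth spelling out: the TF2 `tf.name_scope` takes a single name, unlike the private `ops.name_scope(name, default_name)` helper it replaces, hence the `name or "correlation_cost"` idiom in the diff. A minimal sketch of the pattern:

    import tensorflow as tf

    @tf.function
    def scoped_identity(x, name=None):
        # `name or ...` supplies the default the old two-argument helper took
        with tf.name_scope(name or "scoped_identity"):
            return tf.identity(x)

    y = scoped_identity(tf.constant(1.0))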
Subject: [PATCH 08/18] Remove hardcoded test --- tensorflow_addons/opticalflow/correlation_cost_test.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/tensorflow_addons/opticalflow/correlation_cost_test.py b/tensorflow_addons/opticalflow/correlation_cost_test.py index 83b609843c..c87b19d5cd 100644 --- a/tensorflow_addons/opticalflow/correlation_cost_test.py +++ b/tensorflow_addons/opticalflow/correlation_cost_test.py @@ -41,12 +41,6 @@ def _forward(self, input_a_op = tf.convert_to_tensor(input_a, dtype=tf.float32) input_b_op = tf.convert_to_tensor(input_b, dtype=tf.float32) - kernel_size = 1 - max_displacement = 2 - stride_1 = 1 - stride_2 = 2 - pad = 4 - call_op = correlation_cost actual_op = call_op( input_a_op, From 5df628334fd4b86d3762668e773e5ead6e482bcf Mon Sep 17 00:00:00 2001 From: Sean Morgan Date: Thu, 25 Jul 2019 12:44:58 -0400 Subject: [PATCH 09/18] Move to TFA layers --- tensorflow_addons/custom_ops/README.md | 2 +- .../custom_ops/{opticalflow => layers}/BUILD | 0 .../cc/kernels/correlation_cost_op.cc | 2 +- .../cc/kernels/correlation_cost_op.h | 0 .../cc/kernels/correlation_cost_op_gpu.cu.cc | 2 +- .../cc/ops/correlation_cost_op.cc | 0 tensorflow_addons/layers/BUILD | 17 +++++++++++ .../correlation_cost.py | 2 +- .../correlation_cost_test.py | 21 ++++++++++++-- tensorflow_addons/opticalflow/BUILD | 29 ------------------- tensorflow_addons/opticalflow/__init__.py | 21 -------------- 11 files changed, 40 insertions(+), 56 deletions(-) rename tensorflow_addons/custom_ops/{opticalflow => layers}/BUILD (100%) rename tensorflow_addons/custom_ops/{opticalflow => layers}/cc/kernels/correlation_cost_op.cc (99%) rename tensorflow_addons/custom_ops/{opticalflow => layers}/cc/kernels/correlation_cost_op.h (100%) rename tensorflow_addons/custom_ops/{opticalflow => layers}/cc/kernels/correlation_cost_op_gpu.cu.cc (99%) rename tensorflow_addons/custom_ops/{opticalflow => layers}/cc/ops/correlation_cost_op.cc (100%) rename tensorflow_addons/{opticalflow => layers}/correlation_cost.py (98%) rename tensorflow_addons/{opticalflow => layers}/correlation_cost_test.py (91%) delete mode 100644 tensorflow_addons/opticalflow/BUILD delete mode 100644 tensorflow_addons/opticalflow/__init__.py diff --git a/tensorflow_addons/custom_ops/README.md b/tensorflow_addons/custom_ops/README.md index 0d657954e6..0c91c9e02c 100644 --- a/tensorflow_addons/custom_ops/README.md +++ b/tensorflow_addons/custom_ops/README.md @@ -6,4 +6,4 @@ | Image | Ops for image manipulation | | Seq2seq | Ops for seq2seq encoder-decoder framework | | Text | Ops for text processing | -| OpticalFlow | Ops for optical flow processing | +| Layers | Ops for model layers | diff --git a/tensorflow_addons/custom_ops/opticalflow/BUILD b/tensorflow_addons/custom_ops/layers/BUILD similarity index 100% rename from tensorflow_addons/custom_ops/opticalflow/BUILD rename to tensorflow_addons/custom_ops/layers/BUILD diff --git a/tensorflow_addons/custom_ops/opticalflow/cc/kernels/correlation_cost_op.cc b/tensorflow_addons/custom_ops/layers/cc/kernels/correlation_cost_op.cc similarity index 99% rename from tensorflow_addons/custom_ops/opticalflow/cc/kernels/correlation_cost_op.cc rename to tensorflow_addons/custom_ops/layers/cc/kernels/correlation_cost_op.cc index 73ddc92aa1..e1f4b1cdbc 100644 --- a/tensorflow_addons/custom_ops/opticalflow/cc/kernels/correlation_cost_op.cc +++ b/tensorflow_addons/custom_ops/layers/cc/kernels/correlation_cost_op.cc @@ -19,7 +19,7 @@ limitations under the License. 
#define EIGEN_USE_GPU #endif // GOOGLE_CUDA -#include "tensorflow_addons/custom_ops/opticalflow/cc/kernels/correlation_cost_op.h" +#include "tensorflow_addons/custom_ops/layers/cc/kernels/correlation_cost_op.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/types.h" diff --git a/tensorflow_addons/custom_ops/opticalflow/cc/kernels/correlation_cost_op.h b/tensorflow_addons/custom_ops/layers/cc/kernels/correlation_cost_op.h similarity index 100% rename from tensorflow_addons/custom_ops/opticalflow/cc/kernels/correlation_cost_op.h rename to tensorflow_addons/custom_ops/layers/cc/kernels/correlation_cost_op.h diff --git a/tensorflow_addons/custom_ops/opticalflow/cc/kernels/correlation_cost_op_gpu.cu.cc b/tensorflow_addons/custom_ops/layers/cc/kernels/correlation_cost_op_gpu.cu.cc similarity index 99% rename from tensorflow_addons/custom_ops/opticalflow/cc/kernels/correlation_cost_op_gpu.cu.cc rename to tensorflow_addons/custom_ops/layers/cc/kernels/correlation_cost_op_gpu.cu.cc index d9087046e8..db0691136c 100644 --- a/tensorflow_addons/custom_ops/opticalflow/cc/kernels/correlation_cost_op_gpu.cu.cc +++ b/tensorflow_addons/custom_ops/layers/cc/kernels/correlation_cost_op_gpu.cu.cc @@ -17,7 +17,7 @@ limitations under the License. #define EIGEN_USE_GPU -#include "tensorflow_addons/custom_ops/opticalflow/cc/kernels/correlation_cost_op.h" +#include "tensorflow_addons/custom_ops/layers/cc/kernels/correlation_cost_op.h" #include "external/cub_archive/cub/device/device_reduce.cuh" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" diff --git a/tensorflow_addons/custom_ops/opticalflow/cc/ops/correlation_cost_op.cc b/tensorflow_addons/custom_ops/layers/cc/ops/correlation_cost_op.cc similarity index 100% rename from tensorflow_addons/custom_ops/opticalflow/cc/ops/correlation_cost_op.cc rename to tensorflow_addons/custom_ops/layers/cc/ops/correlation_cost_op.cc diff --git a/tensorflow_addons/layers/BUILD b/tensorflow_addons/layers/BUILD index e05719a245..33bdbc6746 100644 --- a/tensorflow_addons/layers/BUILD +++ b/tensorflow_addons/layers/BUILD @@ -11,8 +11,12 @@ py_library( "poincare.py", "sparsemax.py", "wrappers.py", + "correlation_cost.py", ], srcs_version = "PY2AND3", + data = [ + "//tensorflow_addons/custom_ops/layers:_correlation_cost_ops.so", + ], deps = [ "//tensorflow_addons/activations", "//tensorflow_addons/utils", @@ -70,3 +74,16 @@ py_test( ":layers", ], ) + +py_test( + name = "correlation_cost_test", + size = "small", + srcs = [ + "correlation_cost_test.py", + ], + main = "correlation_cost_test.py", + srcs_version = "PY2AND3", + deps = [ + ":layers", + ], +) diff --git a/tensorflow_addons/opticalflow/correlation_cost.py b/tensorflow_addons/layers/correlation_cost.py similarity index 98% rename from tensorflow_addons/opticalflow/correlation_cost.py rename to tensorflow_addons/layers/correlation_cost.py index ef801590bc..8f0a5aa0a1 100644 --- a/tensorflow_addons/opticalflow/correlation_cost.py +++ b/tensorflow_addons/layers/correlation_cost.py @@ -22,7 +22,7 @@ from tensorflow_addons.utils.resource_loader import get_path_to_datafile _correlation_cost_op_so = tf.load_op_library( - get_path_to_datafile("custom_ops/opticalflow/_correlation_cost_ops.so")) + get_path_to_datafile("custom_ops/layers/_correlation_cost_ops.so")) # pylint: disable=redefined-builtin diff --git a/tensorflow_addons/opticalflow/correlation_cost_test.py 
b/tensorflow_addons/layers/correlation_cost_test.py similarity index 91% rename from tensorflow_addons/opticalflow/correlation_cost_test.py rename to tensorflow_addons/layers/correlation_cost_test.py index c87b19d5cd..7ef9a6d324 100644 --- a/tensorflow_addons/opticalflow/correlation_cost_test.py +++ b/tensorflow_addons/layers/correlation_cost_test.py @@ -18,10 +18,10 @@ from __future__ import print_function import numpy as np +import tensorflow as tf -from tensorflow_addons.opticalflow import correlation_cost +from tensorflow_addons.layers.correlation_cost import correlation_cost from tensorflow.python.platform import test -import tensorflow as tf tf.compat.v1.disable_eager_execution() @@ -171,6 +171,23 @@ def _gradients(self, data_format='NCHW', use_gpu=False): input_a_op = tf.convert_to_tensor(input_a, dtype=tf.float32) input_b_op = tf.convert_to_tensor(input_b, dtype=tf.float32) + # def correlation_fn(inputs): + # output = correlation_cost( + # inputs[0], + # inputs[1], + # kernel_size=kernel_size, + # max_displacement=max_displacement, + # stride_1=stride_1, + # stride_2=stride_2, + # pad=pad, + # data_format=data_format) + # return output + # + # theoretical, numerical = tf.test.compute_gradient( + # correlation_fn, [[input_a_op, input_b_op]]) + # + # self.assertAllClose(theoretical[0], numerical[0], 1e-4) + call_op = correlation_cost actual_op = call_op( input_a_op, diff --git a/tensorflow_addons/opticalflow/BUILD b/tensorflow_addons/opticalflow/BUILD deleted file mode 100644 index 7709dafb91..0000000000 --- a/tensorflow_addons/opticalflow/BUILD +++ /dev/null @@ -1,29 +0,0 @@ -licenses(["notice"]) # Apache 2.0 - -package(default_visibility = ["//visibility:public"]) - -py_library( - name = "opticalflow", - srcs = ([ - "__init__.py", - "correlation_cost.py", - ]), - data = [ - "//tensorflow_addons/custom_ops/opticalflow:_correlation_cost_ops.so", - "//tensorflow_addons/utils", - ], - srcs_version = "PY2AND3", -) - -py_test( - name = "correlation_cost_test", - size = "small", - srcs = [ - "correlation_cost_test.py", - ], - main = "correlation_cost_test.py", - srcs_version = "PY2AND3", - deps = [ - ":opticalflow", - ], -) diff --git a/tensorflow_addons/opticalflow/__init__.py b/tensorflow_addons/opticalflow/__init__.py deleted file mode 100644 index d6565f7a6c..0000000000 --- a/tensorflow_addons/opticalflow/__init__.py +++ /dev/null @@ -1,21 +0,0 @@ -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
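On the commented-out gradient-check block above: the TF2 `tf.test.compute_gradient` takes a callable plus a flat list of arguments and returns (theoretical, numerical) Jacobian lists, so the nested `[[input_a_op, input_b_op]]` form passes a single list argument; a later patch in this series flattens it. A minimal self-contained sketch of the working form:

    import tensorflow as tf

    def f(a, b):
        return a * b  # simple bilinear stand-in for correlation_cost

    x = [tf.constant([1.0, 2.0]), tf.constant([3.0, 4.0])]
    theoretical, numerical = tf.test.compute_gradient(f, x)
    # theoretical[i] and numerical[i] are Jacobians of f w.r.t. x[i]
    tf.debugging.assert_near(theoretical[0], numerical[0], atol=1e-3)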
-# ============================================================================== -"""ops module.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow_addons.opticalflow.correlation_cost import correlation_cost From 07b633213fc77bd68e0553f378ed89ca096c0ead Mon Sep 17 00:00:00 2001 From: Sean Morgan Date: Thu, 25 Jul 2019 14:57:56 -0400 Subject: [PATCH 10/18] Make tests eager executed --- .../layers/correlation_cost_test.py | 155 +++++------------- 1 file changed, 38 insertions(+), 117 deletions(-) diff --git a/tensorflow_addons/layers/correlation_cost_test.py b/tensorflow_addons/layers/correlation_cost_test.py index 7ef9a6d324..09e577df0a 100644 --- a/tensorflow_addons/layers/correlation_cost_test.py +++ b/tensorflow_addons/layers/correlation_cost_test.py @@ -21,8 +21,6 @@ import tensorflow as tf from tensorflow_addons.layers.correlation_cost import correlation_cost -from tensorflow.python.platform import test -tf.compat.v1.disable_eager_execution() class CorrelationCostTest(tf.test.TestCase): @@ -34,72 +32,24 @@ def _forward(self, stride_1, stride_2, pad, - data_format, - use_gpu=False): - with self.test_session(use_gpu=use_gpu, force_gpu=use_gpu) as sess: + data_format): - input_a_op = tf.convert_to_tensor(input_a, dtype=tf.float32) - input_b_op = tf.convert_to_tensor(input_b, dtype=tf.float32) - - call_op = correlation_cost - actual_op = call_op( - input_a_op, - input_b_op, - kernel_size=kernel_size, - max_displacement=max_displacement, - stride_1=stride_1, - stride_2=stride_2, - pad=pad, - data_format=data_format) - - return sess.run(actual_op) - - def _forward_both(self, data_format='NCHW'): - val = [[[[0, -6, 9, 5], [1, -5, 10, 3], [2, -4, 11, 1]], - [[3, -3, 12, -1], [4, -2, 13, -3], [5, -1, 14, -5]]], - [[[6, 0, 15, -7], [7, 1, 16, -9], [8, 2, 17, -11]], - [[9, 3, 18, -13], [10, 4, 19, -15], [11, 5, 20, -17]]]] - - input_a = tf.constant(np.array(val), dtype=tf.float32) - valb = np.array(val).transpose(2, 3, 0, 1).reshape(2, 2, 3, 4) - input_b = tf.constant(valb, dtype=tf.float32) - - kernel_size = 1 - max_displacement = 2 - stride_1 = 1 - stride_2 = 2 - pad = 4 - - if data_format == 'NHWC': - input_a = tf.transpose(input_a, [0, 2, 3, 1]) - input_b = tf.transpose(input_b, [0, 2, 3, 1]) - - actual_cpu = self._forward( - input_a, - input_b, - kernel_size=kernel_size, - max_displacement=max_displacement, - stride_1=stride_1, - stride_2=stride_2, - pad=pad, - data_format=data_format, - use_gpu=False) + input_a_op = tf.convert_to_tensor(input_a, dtype=tf.float32) + input_b_op = tf.convert_to_tensor(input_b, dtype=tf.float32) - actual_gpu = self._forward( - input_a, - input_b, + output = correlation_cost( + input_a_op, + input_b_op, kernel_size=kernel_size, max_displacement=max_displacement, stride_1=stride_1, stride_2=stride_2, pad=pad, - data_format=data_format, - use_gpu=True) + data_format=data_format) - self.assertEqual(actual_cpu.shape, actual_gpu.shape) - self.assertAllClose(actual_cpu, actual_gpu) + return output - def _forward_simple(self, data_format='NCHW', use_gpu=False): + def _forward_simple(self, data_format='NCHW'): # cumbersome calculation by hand for a fixed input # we just test where zeros occurs and a few entries val = [[[[0, -6, 9, 5], [1, -5, 10, 3], [2, -4, 11, 1]], @@ -115,8 +65,8 @@ def _forward_simple(self, data_format='NCHW', use_gpu=False): input_a = tf.transpose(input_a, [0, 2, 3, 1]) input_b = tf.transpose(input_b, [0, 2, 3, 1]) - input_a_op = tf.convert_to_tensor(input_a, 
dtype=tf.float32) - input_b_op = tf.convert_to_tensor(input_b, dtype=tf.float32) + input_a_tensor = tf.convert_to_tensor(input_a, dtype=tf.float32) + input_b_tensor = tf.convert_to_tensor(input_b, dtype=tf.float32) kernel_size = 1 max_displacement = 2 @@ -125,36 +75,35 @@ def _forward_simple(self, data_format='NCHW', use_gpu=False): pad = 4 actual = self._forward( - input_a_op, - input_b_op, + input_a_tensor, + input_b_tensor, kernel_size=kernel_size, max_displacement=max_displacement, stride_1=stride_1, stride_2=stride_2, pad=pad, - data_format=data_format, - use_gpu=use_gpu) + data_format=data_format) if data_format == 'NHWC': # NHWC -> NCHW - actual = actual.transpose(0, 3, 1, 2) + actual = tf.transpose(actual, [0, 3, 1, 2]) # we just need to test fixed ids, as output is NCHW independently from data_format expected_ids = np.concatenate([np.zeros(464,), np.ones(464,)]) - self.assertAllClose(np.where(actual == 0)[0], expected_ids) + self.assertAllClose(np.where(actual.numpy() == 0)[0], expected_ids) counts = [54, 52, 54, 50, 44, 50, 54, 52, 54] expected_ids = np.concatenate( [k * np.ones(v,) for k, v in enumerate(counts)]) expected_ids = np.concatenate([expected_ids, expected_ids]) - self.assertAllClose(np.where(actual == 0)[1], expected_ids) + self.assertAllClose(np.where(actual.numpy() == 0)[1], expected_ids) self.assertEqual(actual.shape, (2, 9, 7, 8)) - def _gradients(self, data_format='NCHW', use_gpu=False): + def _gradients(self, data_format='NCHW'): batch, channels, height, width = 2, 3, 5, 6 - input_a = np.random.randn(batch, channels, height, width) - input_b = np.random.randn(batch, channels, height, width) + input_a = tf.random.normal([batch, channels, height, width]) + input_b = tf.random.normal([batch, channels, height, width]) kernel_size = 1 max_displacement = 2 @@ -163,68 +112,40 @@ def _gradients(self, data_format='NCHW', use_gpu=False): pad = 4 if data_format == 'NHWC': - input_a = input_a.transpose(0, 2, 3, 1) - input_b = input_b.transpose(0, 2, 3, 1) - - with self.test_session(use_gpu=use_gpu, force_gpu=use_gpu): - - input_a_op = tf.convert_to_tensor(input_a, dtype=tf.float32) - input_b_op = tf.convert_to_tensor(input_b, dtype=tf.float32) - - # def correlation_fn(inputs): - # output = correlation_cost( - # inputs[0], - # inputs[1], - # kernel_size=kernel_size, - # max_displacement=max_displacement, - # stride_1=stride_1, - # stride_2=stride_2, - # pad=pad, - # data_format=data_format) - # return output - # - # theoretical, numerical = tf.test.compute_gradient( - # correlation_fn, [[input_a_op, input_b_op]]) - # - # self.assertAllClose(theoretical[0], numerical[0], 1e-4) - - call_op = correlation_cost - actual_op = call_op( - input_a_op, - input_b_op, + input_a = tf.transpose(input_a, [0, 2, 3, 1]) + input_b = tf.transpose(input_b, [0, 2, 3, 1]) + + input_a_op = tf.convert_to_tensor(input_a, dtype=tf.float32) + input_b_op = tf.convert_to_tensor(input_b, dtype=tf.float32) + + def correlation_fn(inputs): + output = correlation_cost( + inputs[0], + inputs[1], kernel_size=kernel_size, max_displacement=max_displacement, stride_1=stride_1, stride_2=stride_2, pad=pad, data_format=data_format) + return output - err_a = test.compute_gradient_error([input_a_op, input_b_op], - [input_a.shape, input_b.shape], - actual_op, - actual_op.shape.as_list()) + theoretical, numerical = tf.test.compute_gradient( + correlation_fn, [[input_a_op, input_b_op]]) - self.assertLess(err_a, 1e-4) + self.assertAllClose(theoretical[0], numerical[0], atol=1e-3) def testForwardNCHW(self): - 
self._forward_simple(data_format='NCHW', use_gpu=False) - self._forward_simple(data_format='NCHW', use_gpu=True) + self._forward_simple(data_format='NCHW') def testForwardNHWC(self): - self._forward_simple(data_format='NHWC', use_gpu=False) - self._forward_simple(data_format='NHWC', use_gpu=True) - - def testForwardSame(self): - self._forward_both(data_format='NCHW') - self._forward_both(data_format='NHWC') + self._forward_simple(data_format='NHWC') def testBackwardNCHW(self): - self._gradients(data_format='NCHW', use_gpu=False) - self._gradients(data_format='NCHW', use_gpu=True) + self._gradients(data_format='NCHW') def testBackwardNHWC(self): - self._gradients(data_format='NHWC', use_gpu=False) - self._gradients(data_format='NHWC', use_gpu=True) + self._gradients(data_format='NHWC') if __name__ == "__main__": From b4fa9cff4844229e9d891ca973b89d1fa3e27206 Mon Sep 17 00:00:00 2001 From: Sean Morgan Date: Thu, 25 Jul 2019 17:32:26 -0400 Subject: [PATCH 11/18] Run tests in eager and graph mode --- .../layers/correlation_cost_test.py | 26 ++++++++++--------- 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/tensorflow_addons/layers/correlation_cost_test.py b/tensorflow_addons/layers/correlation_cost_test.py index 09e577df0a..59dcf21079 100644 --- a/tensorflow_addons/layers/correlation_cost_test.py +++ b/tensorflow_addons/layers/correlation_cost_test.py @@ -19,10 +19,11 @@ import numpy as np import tensorflow as tf - from tensorflow_addons.layers.correlation_cost import correlation_cost +from tensorflow_addons.utils import test_utils +@test_utils.run_all_in_graph_and_eager_modes class CorrelationCostTest(tf.test.TestCase): def _forward(self, input_a, @@ -88,22 +89,24 @@ def _forward_simple(self, data_format='NCHW'): # NHWC -> NCHW actual = tf.transpose(actual, [0, 3, 1, 2]) - # we just need to test fixed ids, as output is NCHW independently from data_format + # We can test fixed ids, as output is independent from data_format expected_ids = np.concatenate([np.zeros(464,), np.ones(464,)]) - self.assertAllClose(np.where(actual.numpy() == 0)[0], expected_ids) + self.assertAllClose(tf.where(tf.equal(actual, 0))[:, 0], expected_ids) counts = [54, 52, 54, 50, 44, 50, 54, 52, 54] expected_ids = np.concatenate( [k * np.ones(v,) for k, v in enumerate(counts)]) expected_ids = np.concatenate([expected_ids, expected_ids]) - self.assertAllClose(np.where(actual.numpy() == 0)[1], expected_ids) + self.assertAllClose(tf.where(tf.equal(actual, 0))[:, 1], expected_ids) self.assertEqual(actual.shape, (2, 9, 7, 8)) def _gradients(self, data_format='NCHW'): batch, channels, height, width = 2, 3, 5, 6 - input_a = tf.random.normal([batch, channels, height, width]) - input_b = tf.random.normal([batch, channels, height, width]) + input_a = np.random.randn( + batch, channels, height, width).astype(np.float32) + input_b = np.random.randn( + batch, channels, height, width).astype(np.float32) kernel_size = 1 max_displacement = 2 @@ -118,20 +121,19 @@ def _gradients(self, data_format='NCHW'): input_a_op = tf.convert_to_tensor(input_a, dtype=tf.float32) input_b_op = tf.convert_to_tensor(input_b, dtype=tf.float32) - def correlation_fn(inputs): - output = correlation_cost( - inputs[0], - inputs[1], + def correlation_fn(input_a, input_b): + return correlation_cost( + input_a, + input_b, kernel_size=kernel_size, max_displacement=max_displacement, stride_1=stride_1, stride_2=stride_2, pad=pad, data_format=data_format) - return output theoretical, numerical = tf.test.compute_gradient( - correlation_fn, [[input_a_op, 
input_b_op]]) + correlation_fn, [input_a_op, input_b_op]) self.assertAllClose(theoretical[0], numerical[0], atol=1e-3) From c009b329281c581118b6438c4b3782eeda063b6e Mon Sep 17 00:00:00 2001 From: Sean Morgan Date: Thu, 25 Jul 2019 23:10:25 -0400 Subject: [PATCH 12/18] Set as a Keras Layer --- tensorflow_addons/layers/BUILD | 10 +- tensorflow_addons/layers/README.md | 2 + tensorflow_addons/layers/__init__.py | 1 + .../{correlation_cost.py => optical_flow.py} | 100 +++++++++++++++--- ...tion_cost_test.py => optical_flow_test.py} | 83 ++++++++++----- 5 files changed, 154 insertions(+), 42 deletions(-) rename tensorflow_addons/layers/{correlation_cost.py => optical_flow.py} (54%) rename tensorflow_addons/layers/{correlation_cost_test.py => optical_flow_test.py} (62%) diff --git a/tensorflow_addons/layers/BUILD b/tensorflow_addons/layers/BUILD index 33bdbc6746..59aeb562b5 100644 --- a/tensorflow_addons/layers/BUILD +++ b/tensorflow_addons/layers/BUILD @@ -8,15 +8,15 @@ py_library( "__init__.py", "maxout.py", "normalizations.py", + "optical_flow.py", "poincare.py", "sparsemax.py", "wrappers.py", - "correlation_cost.py", ], - srcs_version = "PY2AND3", data = [ "//tensorflow_addons/custom_ops/layers:_correlation_cost_ops.so", ], + srcs_version = "PY2AND3", deps = [ "//tensorflow_addons/activations", "//tensorflow_addons/utils", @@ -76,12 +76,12 @@ py_test( ) py_test( - name = "correlation_cost_test", + name = "optical_flow_test", size = "small", srcs = [ - "correlation_cost_test.py", + "optical_flow_test.py", ], - main = "correlation_cost_test.py", + main = "optical_flow_test.py", srcs_version = "PY2AND3", deps = [ ":layers", diff --git a/tensorflow_addons/layers/README.md b/tensorflow_addons/layers/README.md index 94cf0b55d8..4e4e4b48dc 100644 --- a/tensorflow_addons/layers/README.md +++ b/tensorflow_addons/layers/README.md @@ -5,6 +5,7 @@ |:---------- |:----------- |:------------- | | maxout | | | | normalizations | @smokrow | moritz.kroeger@tu-dortmund.de | +| opticalflow | | | | poincare | | | | sparsemax | @AndreasMadsen | amwwebdk+github@gmail.com | | wrappers | @seanpmorgan | seanmorgan@outlook.com | @@ -15,6 +16,7 @@ | maxout | Maxout | https://arxiv.org/abs/1302.4389 | | normalizations | GroupNormalization | https://arxiv.org/abs/1803.08494 | | normalizations | InstanceNormalization | https://arxiv.org/abs/1607.08022 | +| opticalflow | CorrelationCost | https://arxiv.org/abs/1504.06852 | | poincare | PoincareNormalize | https://arxiv.org/abs/1705.08039 | | sparsemax| Sparsemax | https://arxiv.org/abs/1602.02068 | | wrappers | WeightNormalization | https://arxiv.org/abs/1602.07868 | diff --git a/tensorflow_addons/layers/__init__.py b/tensorflow_addons/layers/__init__.py index e5acf9666a..382f2aa80e 100644 --- a/tensorflow_addons/layers/__init__.py +++ b/tensorflow_addons/layers/__init__.py @@ -21,6 +21,7 @@ from tensorflow_addons.layers.maxout import Maxout from tensorflow_addons.layers.normalizations import GroupNormalization from tensorflow_addons.layers.normalizations import InstanceNormalization +from tensorflow_addons.layers.optical_flow import CorrelationCost from tensorflow_addons.layers.poincare import PoincareNormalize from tensorflow_addons.layers.sparsemax import Sparsemax from tensorflow_addons.layers.wrappers import WeightNormalization diff --git a/tensorflow_addons/layers/correlation_cost.py b/tensorflow_addons/layers/optical_flow.py similarity index 54% rename from tensorflow_addons/layers/correlation_cost.py rename to tensorflow_addons/layers/optical_flow.py index 
8f0a5aa0a1..3496b37dda 100644 --- a/tensorflow_addons/layers/correlation_cost.py +++ b/tensorflow_addons/layers/optical_flow.py @@ -19,13 +19,12 @@ from __future__ import print_function import tensorflow as tf +from tensorflow_addons.utils import keras_utils from tensorflow_addons.utils.resource_loader import get_path_to_datafile _correlation_cost_op_so = tf.load_op_library( get_path_to_datafile("custom_ops/layers/_correlation_cost_ops.so")) -# pylint: disable=redefined-builtin - def correlation_cost(input_a, input_b, @@ -34,10 +33,15 @@ def correlation_cost(input_a, stride_1, stride_2, pad, - data_format='NHWC', + data_format='channels_last', name=None): """Correlation Cost Volume computation. + "FlowNet: Learning Optical Flow with Convolutional Networks" + Philipp Fischer, Alexey Dosovitskiy, Eddy Ilg, Philip Hausser, + Caner Hazirbas, Vladimir Golkov, Patrick van der Smagt, + Daniel Cremers, Thomas Brox. https://arxiv.org/abs/1504.06852 + Computes a cost volume using correlation for two inputs. For feature maps A, B with spatial dimensions w, h, c it computes @@ -53,7 +57,7 @@ def correlation_cost(input_a, H' = H + 2 * (pad - bd) / stride_1 W' = W + 2 * (pad - bd) / stride_1 - Note: When the data_format requests "NHWC", an additional explicit + Note: When the data_format requests "channels_last", an additional explicit transpose operation is executed. Args: @@ -68,9 +72,9 @@ def correlation_cost(input_a, pad: An integer specifying the paddings in height and width. data_format: Specifies the data format. Possible values are: - "NHWC" float [batch, height, width, channels] - "NCHW" float [batch, channels, height, width] - Defaults to `"NHWC"`. + "channels_last" float [batch, height, width, channels] + "channels_first" float [batch, channels, height, width] + Defaults to `"channels_last"`. name: A name for the operation (optional). 
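To make the data_format handling concrete, a usage sketch (ours; it assumes the compiled _correlation_cost_ops.so is on the data path and uses the import layout introduced by this patch):

    import tensorflow as tf
    from tensorflow_addons.layers.optical_flow import correlation_cost

    a = tf.random.normal([1, 8, 8, 3])  # channels_last: [batch, H, W, C]
    b = tf.random.normal([1, 8, 8, 3])
    cost = correlation_cost(a, b, kernel_size=1, max_displacement=2,
                            stride_1=1, stride_2=2, pad=2,
                            data_format='channels_last')
    # cost.shape == (1, 8, 8, 9); 9 = (2 * (2 // 2) + 1) ** 2 displacements,
    # transposed back to channels_last as described in the docstring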
Returns: @@ -79,6 +83,15 @@ def correlation_cost(input_a, with tf.name_scope(name or "correlation_cost"): op_call = _correlation_cost_op_so.correlation_cost + + if data_format == "channels_last": + op_data_format = "NHWC" + elif data_format == "channels_first": + op_data_format = "NCHW" + else: + raise ValueError("`data_format` must be either `channels_last` or" + "`channels_first`") + ret = op_call( input_a, input_b, @@ -87,17 +100,14 @@ def correlation_cost(input_a, stride_1=stride_1, stride_2=stride_2, pad=pad, - data_format=data_format) - if data_format == 'NHWC': + data_format=op_data_format) + if data_format == 'channels_last': # this is easier to maintain without # specializing an additional cuda kernel return tf.transpose(ret, [0, 2, 3, 1]) return ret -correlation_cost_grad = _correlation_cost_op_so.correlation_cost_grad - - @tf.RegisterGradient("CorrelationCost") def _correlation_cost_grad(op, grad_output): kernel_size = op.get_attr("kernel_size") @@ -126,3 +136,69 @@ def _correlation_cost_grad(op, grad_output): grad_input_a = tf.convert_to_tensor(grads[0], name="grad_input_a") grad_input_b = tf.convert_to_tensor(grads[1], name="grad_input_b") return [grad_input_a, grad_input_b] + + +@keras_utils.register_keras_custom_object +class CorrelationCost(tf.python.keras.layers.Layer): + def __init__(self, kernel_size, max_displacement, stride_1, stride_2, pad, + data_format, **kwargs): + self.kernel_size = kernel_size + self.max_displacement = max_displacement + self.stride_1 = stride_1 + self.stride_2 = stride_2 + self.pad = pad + self.data_format = data_format + super(CorrelationCost, self).__init__(**kwargs) + + def build(self, input_shape): + if not isinstance(input_shape, list): + raise ValueError("Input must be a list of two Tensors to process") + super(CorrelationCost, self).build(input_shape) + + def call(self, inputs): + if not isinstance(inputs, list): + raise ValueError("Input must be a list of two Tensors to process") + + input_a = tf.convert_to_tensor(inputs[0]) + input_b = tf.convert_to_tensor(inputs[1]) + + return correlation_cost( + input_a, + input_b, + kernel_size=self.kernel_size, + max_displacement=self.max_displacement, + stride_1=self.stride_1, + stride_2=self.stride_2, + pad=self.pad, + data_format=self.data_format) + + def compute_output_shape(self, input_shape): + assert isinstance(input_shape, list) + n = input_shape[0][0] + r = self.max_displacement / self.stride_2 + bd = self.max_displacement + (self.kernel_size - 1) / 2 + output_c = (2 * r + 1)**2 + + if self.data_format == "channels_first": + output_h = input_shape[0][1] + 2 * (self.pad - bd) / self.stride_1 + output_w = input_shape[0][2] + 2 * (self.pad - bd) / self.stride_1 + elif self.data_format == "channels_last": + output_h = input_shape[0][0] + 2 * (self.pad - bd) / self.stride_1 + output_w = input_shape[0][1] + 2 * (self.pad - bd) / self.stride_1 + else: + raise ValueError("`data_format` must be either `channels_last` or" + "`channels_first`") + return [n, output_c, output_h, output_w] + + def get_config(self): + config = { + 'kernel_size': self.kernel_size, + 'max_displacement': self.max_displacement, + 'stride_1': self.stride_1, + 'stride_2': self.stride_2, + 'pad': self.pad, + 'data_format': self.data_format + } + + base_config = super(CorrelationCost, self).get_config() + return dict(list(base_config.items()) + list(config.items())) diff --git a/tensorflow_addons/layers/correlation_cost_test.py b/tensorflow_addons/layers/optical_flow_test.py similarity index 62% rename from 
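A minimal functional-API sketch of the new layer (our own usage example, assuming the public tf.keras.layers/Model API rather than the private tf.python path used above; shapes are channels_first without the batch dimension):

    import tensorflow as tf
    from tensorflow_addons.layers import CorrelationCost

    a = tf.keras.Input(shape=(3, 5, 6))  # (C, H, W)
    b = tf.keras.Input(shape=(3, 5, 6))
    cost = CorrelationCost(kernel_size=1, max_displacement=2, stride_1=1,
                           stride_2=2, pad=4,
                           data_format="channels_first")([a, b])
    model = tf.keras.Model(inputs=[a, b], outputs=cost)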
tensorflow_addons/layers/correlation_cost_test.py rename to tensorflow_addons/layers/optical_flow_test.py index 59dcf21079..bbe76182d5 100644 --- a/tensorflow_addons/layers/correlation_cost_test.py +++ b/tensorflow_addons/layers/optical_flow_test.py @@ -19,21 +19,14 @@ import numpy as np import tensorflow as tf -from tensorflow_addons.layers.correlation_cost import correlation_cost +from tensorflow_addons.layers.optical_flow import correlation_cost, CorrelationCost from tensorflow_addons.utils import test_utils @test_utils.run_all_in_graph_and_eager_modes class CorrelationCostTest(tf.test.TestCase): - def _forward(self, - input_a, - input_b, - kernel_size, - max_displacement, - stride_1, - stride_2, - pad, - data_format): + def _forward(self, input_a, input_b, kernel_size, max_displacement, + stride_1, stride_2, pad, data_format): input_a_op = tf.convert_to_tensor(input_a, dtype=tf.float32) input_b_op = tf.convert_to_tensor(input_b, dtype=tf.float32) @@ -50,7 +43,7 @@ def _forward(self, return output - def _forward_simple(self, data_format='NCHW'): + def _forward_simple(self, data_format='channels_first'): # cumbersome calculation by hand for a fixed input # we just test where zeros occurs and a few entries val = [[[[0, -6, 9, 5], [1, -5, 10, 3], [2, -4, 11, 1]], @@ -62,7 +55,7 @@ def _forward_simple(self, data_format='NCHW'): valb = np.array(val).transpose(2, 3, 0, 1).reshape(2, 2, 3, 4) input_b = tf.constant(valb, dtype=tf.float32) - if data_format == 'NHWC': + if data_format == 'channels_last': input_a = tf.transpose(input_a, [0, 2, 3, 1]) input_b = tf.transpose(input_b, [0, 2, 3, 1]) @@ -85,7 +78,7 @@ def _forward_simple(self, data_format='NCHW'): pad=pad, data_format=data_format) - if data_format == 'NHWC': + if data_format == 'channels_last': # NHWC -> NCHW actual = tf.transpose(actual, [0, 3, 1, 2]) @@ -100,13 +93,13 @@ def _forward_simple(self, data_format='NCHW'): self.assertAllClose(tf.where(tf.equal(actual, 0))[:, 1], expected_ids) self.assertEqual(actual.shape, (2, 9, 7, 8)) - def _gradients(self, data_format='NCHW'): + def _gradients(self, data_format='channels_first'): batch, channels, height, width = 2, 3, 5, 6 - input_a = np.random.randn( - batch, channels, height, width).astype(np.float32) - input_b = np.random.randn( - batch, channels, height, width).astype(np.float32) + input_a = np.random.randn(batch, channels, height, + width).astype(np.float32) + input_b = np.random.randn(batch, channels, height, + width).astype(np.float32) kernel_size = 1 max_displacement = 2 @@ -114,12 +107,12 @@ def _gradients(self, data_format='NCHW'): stride_2 = 2 pad = 4 - if data_format == 'NHWC': + if data_format == 'channels_last': input_a = tf.transpose(input_a, [0, 2, 3, 1]) input_b = tf.transpose(input_b, [0, 2, 3, 1]) - input_a_op = tf.convert_to_tensor(input_a, dtype=tf.float32) - input_b_op = tf.convert_to_tensor(input_b, dtype=tf.float32) + input_a_op = tf.convert_to_tensor(input_a) + input_b_op = tf.convert_to_tensor(input_b) def correlation_fn(input_a, input_b): return correlation_cost( @@ -138,16 +131,56 @@ def correlation_fn(input_a, input_b): self.assertAllClose(theoretical[0], numerical[0], atol=1e-3) def testForwardNCHW(self): - self._forward_simple(data_format='NCHW') + self._forward_simple(data_format='channels_first') def testForwardNHWC(self): - self._forward_simple(data_format='NHWC') + self._forward_simple(data_format='channels_last') def testBackwardNCHW(self): - self._gradients(data_format='NCHW') + self._gradients(data_format='channels_first') def 
testBackwardNHWC(self): - self._gradients(data_format='NHWC') + self._gradients(data_format='channels_last') + + def testKerasLayer(self): + val_a = [[[[0, -6, 9, 5], [1, -5, 10, 3], [2, -4, 11, 1]], + [[3, -3, 12, -1], [4, -2, 13, -3], [5, -1, 14, -5]]], + [[[6, 0, 15, -7], [7, 1, 16, -9], [8, 2, 17, -11]], + [[9, 3, 18, -13], [10, 4, 19, -15], [11, 5, 20, -17]]]] + val_b = np.array(val_a).transpose(2, 3, 0, 1).reshape(2, 2, 3, 4) + + # yapf: disable + input_a = tf.keras.Input(shape=(2, 3, 4,)) + input_b = tf.keras.Input(shape=(2, 3, 4,)) + + layer = CorrelationCost( + kernel_size=1, + max_displacement=2, + stride_1=1, + stride_2=2, + pad=4, + data_format="channels_first") + + expected_output_shape = tuple( + layer.compute_output_shape([(2, 3, 4,), (2, 3, 4,)]))[1:] + # yapf: enable + + x = [input_a, input_b] + y = layer(x) + model = tf.python.keras.models.Model(x, y) + actual_output = model.predict([val_a, val_b]) + + expected_output_type = 'float32' + if tf.keras.backend.dtype(y[0]) != expected_output_type: + raise AssertionError( + "Inferred output type %s does not equal " + "expected output type %s" % (tf.keras.backend.dtype(y[0]), + expected_output_type)) + + if actual_output[0].shape != expected_output_shape: + raise AssertionError( + "Expected shape %s does not equal output shape" + "%s" % (actual_output[0].shape, expected_output_shape)) if __name__ == "__main__": From c112b1a9f8e88a23ff7f5b11d0c2548a6a81a2c7 Mon Sep 17 00:00:00 2001 From: Sean Morgan Date: Fri, 26 Jul 2019 09:10:44 -0400 Subject: [PATCH 13/18] Fix testing --- tensorflow_addons/custom_ops/layers/BUILD | 30 +++++++++-- .../cc/kernels/correlation_cost_op_gpu.cu.cc | 6 ++- tensorflow_addons/layers/optical_flow.py | 4 +- tensorflow_addons/layers/optical_flow_test.py | 52 +++++++++++++------ 4 files changed, 70 insertions(+), 22 deletions(-) diff --git a/tensorflow_addons/custom_ops/layers/BUILD b/tensorflow_addons/custom_ops/layers/BUILD index 6cb1fcebf1..ce44e51bef 100644 --- a/tensorflow_addons/custom_ops/layers/BUILD +++ b/tensorflow_addons/custom_ops/layers/BUILD @@ -2,6 +2,9 @@ licenses(["notice"]) # Apache 2.0 package(default_visibility = ["//visibility:public"]) +load("@local_config_tf//:build_defs.bzl", "D_GLIBCXX_USE_CXX11_ABI") +load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda_is_configured", "if_cuda") + cc_binary( name = "_correlation_cost_ops.so", srcs = [ @@ -13,14 +16,33 @@ cc_binary( copts = [ "-pthread", "-std=c++11", - "-D_GLIBCXX_USE_CXX11_ABI=0", + D_GLIBCXX_USE_CXX11_ABI, ], linkshared = 1, deps = [ "@local_config_tf//:libtensorflow_framework", "@local_config_tf//:tf_header_lib", + ] + if_cuda_is_configured([":correlation_cost_ops_gpu"]), +) + +cc_library( + name = "correlation_cost_ops_gpu", + srcs = [ + "cc/kernels/correlation_cost_op.h", + "cc/kernels/correlation_cost_op_gpu.cu.cc", ], - # + if_cuda([ - # "@cub_archive//:cub", - # ]), + copts = if_cuda_is_configured([ + "-DGOOGLE_CUDA=1", + "-x cuda", + "-nvcc_options=relaxed-constexpr", + "-nvcc_options=ftz=true", + ]), + deps = [ + "@local_config_tf//:libtensorflow_framework", + "@local_config_tf//:tf_header_lib", + ] + if_cuda_is_configured([ + "@local_config_cuda//cuda:cuda_libs", + "@local_config_cuda//cuda:cuda_headers", + ]), + alwayslink = 1, ) diff --git a/tensorflow_addons/custom_ops/layers/cc/kernels/correlation_cost_op_gpu.cu.cc b/tensorflow_addons/custom_ops/layers/cc/kernels/correlation_cost_op_gpu.cu.cc index db0691136c..ebacc2134e 100644 --- a/tensorflow_addons/custom_ops/layers/cc/kernels/correlation_cost_op_gpu.cu.cc 
+++ b/tensorflow_addons/custom_ops/layers/cc/kernels/correlation_cost_op_gpu.cu.cc @@ -17,11 +17,13 @@ limitations under the License. #define EIGEN_USE_GPU +// TODO: FIX CUDA Build +//#include "third_party/cub/device/device_reduce.cuh" +//#include "tensorflow/core/util/gpu_kernel_helper.h" + #include "tensorflow_addons/custom_ops/layers/cc/kernels/correlation_cost_op.h" -#include "external/cub_archive/cub/device/device_reduce.cuh" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" -#include "tensorflow/core/util/cuda_kernel_helper.h" #include "tensorflow/core/util/tensor_format.h" namespace tensorflow { diff --git a/tensorflow_addons/layers/optical_flow.py b/tensorflow_addons/layers/optical_flow.py index 3496b37dda..1cea06b931 100644 --- a/tensorflow_addons/layers/optical_flow.py +++ b/tensorflow_addons/layers/optical_flow.py @@ -182,13 +182,15 @@ def compute_output_shape(self, input_shape): if self.data_format == "channels_first": output_h = input_shape[0][1] + 2 * (self.pad - bd) / self.stride_1 output_w = input_shape[0][2] + 2 * (self.pad - bd) / self.stride_1 + return [n, output_c, output_h, output_w] + elif self.data_format == "channels_last": output_h = input_shape[0][0] + 2 * (self.pad - bd) / self.stride_1 output_w = input_shape[0][1] + 2 * (self.pad - bd) / self.stride_1 + return [n, output_h, output_w, output_c] else: raise ValueError("`data_format` must be either `channels_last` or" "`channels_first`") - return [n, output_c, output_h, output_w] def get_config(self): config = { diff --git a/tensorflow_addons/layers/optical_flow_test.py b/tensorflow_addons/layers/optical_flow_test.py index bbe76182d5..c39186ab11 100644 --- a/tensorflow_addons/layers/optical_flow_test.py +++ b/tensorflow_addons/layers/optical_flow_test.py @@ -43,9 +43,13 @@ def _forward(self, input_a, input_b, kernel_size, max_displacement, return output - def _forward_simple(self, data_format='channels_first'): + def _forward_simple(self, data_format, gpu): # cumbersome calculation by hand for a fixed input # we just test where zeros occurs and a few entries + + if gpu and tf.test.is_gpu_available(): + self.skipTest('FIX GPU BUILD') + val = [[[[0, -6, 9, 5], [1, -5, 10, 3], [2, -4, 11, 1]], [[3, -3, 12, -1], [4, -2, 13, -3], [5, -1, 14, -5]]], [[[6, 0, 15, -7], [7, 1, 16, -9], [8, 2, 17, -11]], @@ -93,7 +97,10 @@ def _forward_simple(self, data_format='channels_first'): self.assertAllClose(tf.where(tf.equal(actual, 0))[:, 1], expected_ids) self.assertEqual(actual.shape, (2, 9, 7, 8)) - def _gradients(self, data_format='channels_first'): + def _gradients(self, data_format, gpu): + + if gpu and tf.test.is_gpu_available(): + self.skipTest('FIX GPU BUILD') batch, channels, height, width = 2, 3, 5, 6 input_a = np.random.randn(batch, channels, height, @@ -130,19 +137,10 @@ def correlation_fn(input_a, input_b): self.assertAllClose(theoretical[0], numerical[0], atol=1e-3) - def testForwardNCHW(self): - self._forward_simple(data_format='channels_first') - - def testForwardNHWC(self): - self._forward_simple(data_format='channels_last') + def _keras(self, data_format, gpu): + if gpu and tf.test.is_gpu_available(): + self.skipTest('FIX GPU BUILD') - def testBackwardNCHW(self): - self._gradients(data_format='channels_first') - - def testBackwardNHWC(self): - self._gradients(data_format='channels_last') - - def testKerasLayer(self): val_a = [[[[0, -6, 9, 5], [1, -5, 10, 3], [2, -4, 11, 1]], [[3, -3, 12, -1], [4, -2, 13, -3], [5, -1, 14, -5]]], [[[6, 0, 15, -7], [7, 1, 16, -9], [8, 
2, 17, -11]], @@ -159,7 +157,7 @@ def testKerasLayer(self): stride_1=1, stride_2=2, pad=4, - data_format="channels_first") + data_format=data_format) expected_output_shape = tuple( layer.compute_output_shape([(2, 3, 4,), (2, 3, 4,)]))[1:] @@ -182,6 +180,30 @@ def testKerasLayer(self): "Expected shape %s does not equal output shape" "%s" % (actual_output[0].shape, expected_output_shape)) + def testForwardNCHW(self): + self._forward_simple(data_format='channels_first', gpu=False) + self._forward_simple(data_format='channels_first', gpu=True) + + def testForwardNHWC(self): + self._forward_simple(data_format='channels_last', gpu=False) + self._forward_simple(data_format='channels_last', gpu=True) + + def testBackwardNCHW(self): + self._gradients(data_format='channels_first', gpu=False) + self._gradients(data_format='channels_first', gpu=True) + + def testBackwardNHWC(self): + self._gradients(data_format='channels_last', gpu=False) + self._gradients(data_format='channels_last', gpu=True) + + def testKerasNCHW(self): + self._keras(data_format='channels_first', gpu=False) + self._keras(data_format='channels_first', gpu=True) + + def testKerasNHWC(self): + self._keras(data_format='channels_last', gpu=False) + self._keras(data_format='channels_last', gpu=True) + if __name__ == "__main__": tf.test.main() From 03673d8a30619475c4d730cb71d1d4b5e13af82d Mon Sep 17 00:00:00 2001 From: Sean Morgan Date: Fri, 26 Jul 2019 11:08:03 -0400 Subject: [PATCH 14/18] Add GPU kernel and testing --- WORKSPACE | 13 + build_deps/gpu/cub.BUILD | 26 ++ tensorflow_addons/custom_ops/layers/BUILD | 1 + .../cc/kernels/correlation_cost_op_gpu.cu.cc | 5 +- tensorflow_addons/layers/optical_flow.py | 1 + tensorflow_addons/layers/optical_flow_test.py | 246 +++++++++--------- 6 files changed, 159 insertions(+), 133 deletions(-) create mode 100644 build_deps/gpu/cub.BUILD diff --git a/WORKSPACE b/WORKSPACE index 24062f1570..103360fc5d 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -1,6 +1,19 @@ +load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") load("//build_deps/tf_dependency:tf_configure.bzl", "tf_configure") load("//build_deps/gpu:cuda_configure.bzl", "cuda_configure") + +http_archive( + name = "cub_archive", + build_file = "//build_deps/gpu:cub.BUILD", + sha256 = "6bfa06ab52a650ae7ee6963143a0bbc667d6504822cbd9670369b598f18c58c3", + strip_prefix = "cub-1.8.0", + urls = [ + "https://storage.googleapis.com/mirror.tensorflow.org/github.com/NVlabs/cub/archive/1.8.0.zip", + "https://github.com/NVlabs/cub/archive/1.8.0.zip", + ], +) + tf_configure( name = "local_config_tf", ) diff --git a/build_deps/gpu/cub.BUILD b/build_deps/gpu/cub.BUILD new file mode 100644 index 0000000000..5de7218cda --- /dev/null +++ b/build_deps/gpu/cub.BUILD @@ -0,0 +1,26 @@ +# Description: CUB library which is a set of primitives for GPU programming. 
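Should the mirrors rotate, the sha256 pinned in the http_archive rule above can be re-derived locally. A throwaway check, assuming network access and the second URL from WORKSPACE:

import hashlib
import urllib.request

url = "https://github.com/NVlabs/cub/archive/1.8.0.zip"
digest = hashlib.sha256(urllib.request.urlopen(url).read()).hexdigest()
print(digest)  # should equal the sha256 value pinned in WORKSPACE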
+ +load("@local_config_cuda//cuda:build_defs.bzl", "cuda_default_copts", "if_cuda") + +package( + default_visibility = ["//visibility:public"], +) + +licenses(["notice"]) # BSD + + +filegroup( + name = "cub_header_files", + srcs = glob([ + "cub/**", + ]), +) + +cc_library( + name = "cub", + hdrs = if_cuda([":cub_header_files"]), + include_prefix = "gpu", + deps = [ + "@local_config_cuda//cuda:cuda_headers", + ], +) \ No newline at end of file diff --git a/tensorflow_addons/custom_ops/layers/BUILD b/tensorflow_addons/custom_ops/layers/BUILD index ce44e51bef..ed0c567f59 100644 --- a/tensorflow_addons/custom_ops/layers/BUILD +++ b/tensorflow_addons/custom_ops/layers/BUILD @@ -43,6 +43,7 @@ cc_library( ] + if_cuda_is_configured([ "@local_config_cuda//cuda:cuda_libs", "@local_config_cuda//cuda:cuda_headers", + "@cub_archive//:cub", ]), alwayslink = 1, ) diff --git a/tensorflow_addons/custom_ops/layers/cc/kernels/correlation_cost_op_gpu.cu.cc b/tensorflow_addons/custom_ops/layers/cc/kernels/correlation_cost_op_gpu.cu.cc index ebacc2134e..7970c71e9a 100644 --- a/tensorflow_addons/custom_ops/layers/cc/kernels/correlation_cost_op_gpu.cu.cc +++ b/tensorflow_addons/custom_ops/layers/cc/kernels/correlation_cost_op_gpu.cu.cc @@ -17,10 +17,7 @@ limitations under the License. #define EIGEN_USE_GPU -// TODO: FIX CUDA Build -//#include "third_party/cub/device/device_reduce.cuh" -//#include "tensorflow/core/util/gpu_kernel_helper.h" - +#include "gpu/cub/device/device_reduce.cuh" #include "tensorflow_addons/custom_ops/layers/cc/kernels/correlation_cost_op.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" diff --git a/tensorflow_addons/layers/optical_flow.py b/tensorflow_addons/layers/optical_flow.py index 1cea06b931..cc41c2ef17 100644 --- a/tensorflow_addons/layers/optical_flow.py +++ b/tensorflow_addons/layers/optical_flow.py @@ -26,6 +26,7 @@ get_path_to_datafile("custom_ops/layers/_correlation_cost_ops.so")) +@tf.function def correlation_cost(input_a, input_b, kernel_size, diff --git a/tensorflow_addons/layers/optical_flow_test.py b/tensorflow_addons/layers/optical_flow_test.py index c39186ab11..7842608b50 100644 --- a/tensorflow_addons/layers/optical_flow_test.py +++ b/tensorflow_addons/layers/optical_flow_test.py @@ -43,88 +43,36 @@ def _forward(self, input_a, input_b, kernel_size, max_displacement, return output - def _forward_simple(self, data_format, gpu): + def _forward_simple(self, data_format): # cumbersome calculation by hand for a fixed input # we just test where zeros occurs and a few entries - if gpu and tf.test.is_gpu_available(): - self.skipTest('FIX GPU BUILD') + with test_utils.use_gpu(): + val = [[[[0, -6, 9, 5], [1, -5, 10, 3], [2, -4, 11, 1]], + [[3, -3, 12, -1], [4, -2, 13, -3], [5, -1, 14, -5]]], + [[[6, 0, 15, -7], [7, 1, 16, -9], [8, 2, 17, -11]], + [[9, 3, 18, -13], [10, 4, 19, -15], [11, 5, 20, -17]]]] - val = [[[[0, -6, 9, 5], [1, -5, 10, 3], [2, -4, 11, 1]], - [[3, -3, 12, -1], [4, -2, 13, -3], [5, -1, 14, -5]]], - [[[6, 0, 15, -7], [7, 1, 16, -9], [8, 2, 17, -11]], - [[9, 3, 18, -13], [10, 4, 19, -15], [11, 5, 20, -17]]]] + input_a = tf.constant(np.array(val), dtype=tf.float32) + valb = np.array(val).transpose(2, 3, 0, 1).reshape(2, 2, 3, 4) + input_b = tf.constant(valb, dtype=tf.float32) - input_a = tf.constant(np.array(val), dtype=tf.float32) - valb = np.array(val).transpose(2, 3, 0, 1).reshape(2, 2, 3, 4) - input_b = tf.constant(valb, dtype=tf.float32) + if data_format == 'channels_last': + input_a = tf.transpose(input_a, 
[0, 2, 3, 1]) + input_b = tf.transpose(input_b, [0, 2, 3, 1]) - if data_format == 'channels_last': - input_a = tf.transpose(input_a, [0, 2, 3, 1]) - input_b = tf.transpose(input_b, [0, 2, 3, 1]) + input_a_tensor = tf.convert_to_tensor(input_a, dtype=tf.float32) + input_b_tensor = tf.convert_to_tensor(input_b, dtype=tf.float32) - input_a_tensor = tf.convert_to_tensor(input_a, dtype=tf.float32) - input_b_tensor = tf.convert_to_tensor(input_b, dtype=tf.float32) + kernel_size = 1 + max_displacement = 2 + stride_1 = 1 + stride_2 = 2 + pad = 4 - kernel_size = 1 - max_displacement = 2 - stride_1 = 1 - stride_2 = 2 - pad = 4 - - actual = self._forward( - input_a_tensor, - input_b_tensor, - kernel_size=kernel_size, - max_displacement=max_displacement, - stride_1=stride_1, - stride_2=stride_2, - pad=pad, - data_format=data_format) - - if data_format == 'channels_last': - # NHWC -> NCHW - actual = tf.transpose(actual, [0, 3, 1, 2]) - - # We can test fixed ids, as output is independent from data_format - expected_ids = np.concatenate([np.zeros(464,), np.ones(464,)]) - self.assertAllClose(tf.where(tf.equal(actual, 0))[:, 0], expected_ids) - - counts = [54, 52, 54, 50, 44, 50, 54, 52, 54] - expected_ids = np.concatenate( - [k * np.ones(v,) for k, v in enumerate(counts)]) - expected_ids = np.concatenate([expected_ids, expected_ids]) - self.assertAllClose(tf.where(tf.equal(actual, 0))[:, 1], expected_ids) - self.assertEqual(actual.shape, (2, 9, 7, 8)) - - def _gradients(self, data_format, gpu): - - if gpu and tf.test.is_gpu_available(): - self.skipTest('FIX GPU BUILD') - - batch, channels, height, width = 2, 3, 5, 6 - input_a = np.random.randn(batch, channels, height, - width).astype(np.float32) - input_b = np.random.randn(batch, channels, height, - width).astype(np.float32) - - kernel_size = 1 - max_displacement = 2 - stride_1 = 1 - stride_2 = 2 - pad = 4 - - if data_format == 'channels_last': - input_a = tf.transpose(input_a, [0, 2, 3, 1]) - input_b = tf.transpose(input_b, [0, 2, 3, 1]) - - input_a_op = tf.convert_to_tensor(input_a) - input_b_op = tf.convert_to_tensor(input_b) - - def correlation_fn(input_a, input_b): - return correlation_cost( - input_a, - input_b, + actual = self._forward( + input_a_tensor, + input_b_tensor, kernel_size=kernel_size, max_displacement=max_displacement, stride_1=stride_1, @@ -132,77 +80,117 @@ def correlation_fn(input_a, input_b): pad=pad, data_format=data_format) - theoretical, numerical = tf.test.compute_gradient( - correlation_fn, [input_a_op, input_b_op]) - - self.assertAllClose(theoretical[0], numerical[0], atol=1e-3) - - def _keras(self, data_format, gpu): - if gpu and tf.test.is_gpu_available(): - self.skipTest('FIX GPU BUILD') - - val_a = [[[[0, -6, 9, 5], [1, -5, 10, 3], [2, -4, 11, 1]], - [[3, -3, 12, -1], [4, -2, 13, -3], [5, -1, 14, -5]]], - [[[6, 0, 15, -7], [7, 1, 16, -9], [8, 2, 17, -11]], - [[9, 3, 18, -13], [10, 4, 19, -15], [11, 5, 20, -17]]]] - val_b = np.array(val_a).transpose(2, 3, 0, 1).reshape(2, 2, 3, 4) - - # yapf: disable - input_a = tf.keras.Input(shape=(2, 3, 4,)) - input_b = tf.keras.Input(shape=(2, 3, 4,)) - - layer = CorrelationCost( - kernel_size=1, - max_displacement=2, - stride_1=1, - stride_2=2, - pad=4, - data_format=data_format) + if data_format == 'channels_last': + # NHWC -> NCHW + actual = tf.transpose(actual, [0, 3, 1, 2]) + + # We can test fixed ids, as output is independent from data_format + expected_ids = np.concatenate([np.zeros(464,), np.ones(464,)]) + self.assertAllClose(tf.where(tf.equal(actual, 0))[:, 0], expected_ids) + 
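The assertions here lean on tf.where's single-argument form: given a boolean tensor it returns one row of coordinates per True element, so column 0 is the batch index and column 1 the channel index of each zero entry. A tiny illustration:

import tensorflow as tf

x = tf.constant([[0.0, 1.0],
                 [2.0, 0.0]])
idx = tf.where(tf.equal(x, 0))  # coordinates of the zeros: [[0, 0], [1, 1]]
batch_ids = idx[:, 0]           # the axis-0 coordinate of each match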
+ counts = [54, 52, 54, 50, 44, 50, 54, 52, 54] + expected_ids = np.concatenate( + [k * np.ones(v,) for k, v in enumerate(counts)]) + expected_ids = np.concatenate([expected_ids, expected_ids]) + self.assertAllClose(tf.where(tf.equal(actual, 0))[:, 1], expected_ids) + self.assertEqual(actual.shape, (2, 9, 7, 8)) + + def _gradients(self, data_format): + with test_utils.use_gpu(): + batch, channels, height, width = 2, 3, 5, 6 + input_a = np.random.randn(batch, channels, height, + width).astype(np.float32) + input_b = np.random.randn(batch, channels, height, + width).astype(np.float32) + + kernel_size = 1 + max_displacement = 2 + stride_1 = 1 + stride_2 = 2 + pad = 4 + + if data_format == 'channels_last': + input_a = tf.transpose(input_a, [0, 2, 3, 1]) + input_b = tf.transpose(input_b, [0, 2, 3, 1]) + + input_a_op = tf.convert_to_tensor(input_a) + input_b_op = tf.convert_to_tensor(input_b) + + def correlation_fn(input_a, input_b): + return correlation_cost( + input_a, + input_b, + kernel_size=kernel_size, + max_displacement=max_displacement, + stride_1=stride_1, + stride_2=stride_2, + pad=pad, + data_format=data_format) + + theoretical, numerical = tf.test.compute_gradient( + correlation_fn, [input_a_op, input_b_op]) + + self.assertAllClose(theoretical[0], numerical[0], atol=1e-3) + + def _keras(self, data_format): + # Unable to use `layer_test` as this layer has multiple inputs + with test_utils.use_gpu(): + val_a = [[[[0, -6, 9, 5], [1, -5, 10, 3], [2, -4, 11, 1]], + [[3, -3, 12, -1], [4, -2, 13, -3], [5, -1, 14, -5]]], + [[[6, 0, 15, -7], [7, 1, 16, -9], [8, 2, 17, -11]], + [[9, 3, 18, -13], [10, 4, 19, -15], [11, 5, 20, -17]]]] + val_b = np.array(val_a).transpose(2, 3, 0, 1).reshape(2, 2, 3, 4) + + # yapf: disable + input_a = tf.keras.Input(shape=(2, 3, 4,)) + input_b = tf.keras.Input(shape=(2, 3, 4,)) + + layer = CorrelationCost( + kernel_size=1, + max_displacement=2, + stride_1=1, + stride_2=2, + pad=4, + data_format=data_format) - expected_output_shape = tuple( - layer.compute_output_shape([(2, 3, 4,), (2, 3, 4,)]))[1:] - # yapf: enable + expected_output_shape = tuple( + layer.compute_output_shape([(2, 3, 4,), (2, 3, 4,)]))[1:] + # yapf: enable - x = [input_a, input_b] - y = layer(x) - model = tf.python.keras.models.Model(x, y) - actual_output = model.predict([val_a, val_b]) + x = [input_a, input_b] + y = layer(x) + model = tf.python.keras.models.Model(x, y) + actual_output = model.predict([val_a, val_b]) - expected_output_type = 'float32' - if tf.keras.backend.dtype(y[0]) != expected_output_type: - raise AssertionError( - "Inferred output type %s does not equal " - "expected output type %s" % (tf.keras.backend.dtype(y[0]), - expected_output_type)) + expected_output_type = 'float32' + if tf.keras.backend.dtype(y[0]) != expected_output_type: + raise AssertionError( + "Inferred output type %s does not equal " + "expected output type %s" % (tf.keras.backend.dtype(y[0]), + expected_output_type)) - if actual_output[0].shape != expected_output_shape: - raise AssertionError( - "Expected shape %s does not equal output shape" - "%s" % (actual_output[0].shape, expected_output_shape)) + if actual_output[0].shape != expected_output_shape: + raise AssertionError( + "Expected shape %s does not equal output shape" + "%s" % (actual_output[0].shape, expected_output_shape)) def testForwardNCHW(self): - self._forward_simple(data_format='channels_first', gpu=False) - self._forward_simple(data_format='channels_first', gpu=True) + self._forward_simple(data_format='channels_first') def testForwardNHWC(self): 
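tf.test.compute_gradient returns matched pairs of Jacobians, one computed by backprop and one by finite differences, which is why a single assertAllClose with a loose atol suffices in _gradients. The same pattern on a toy function:

import tensorflow as tf

def f(x):
    return tf.reduce_sum(x * x)

theoretical, numerical = tf.test.compute_gradient(
    f, [tf.constant([1.0, 2.0, 3.0])])
# One Jacobian per input on each side; for f the entries are 2 * x, so
# theoretical[0] and numerical[0] agree up to finite-difference noise.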
- self._forward_simple(data_format='channels_last', gpu=False) - self._forward_simple(data_format='channels_last', gpu=True) + self._forward_simple(data_format='channels_last') def testBackwardNCHW(self): - self._gradients(data_format='channels_first', gpu=False) - self._gradients(data_format='channels_first', gpu=True) + self._gradients(data_format='channels_first') def testBackwardNHWC(self): - self._gradients(data_format='channels_last', gpu=False) - self._gradients(data_format='channels_last', gpu=True) + self._gradients(data_format='channels_last') def testKerasNCHW(self): - self._keras(data_format='channels_first', gpu=False) - self._keras(data_format='channels_first', gpu=True) + self._keras(data_format='channels_first') def testKerasNHWC(self): - self._keras(data_format='channels_last', gpu=False) - self._keras(data_format='channels_last', gpu=True) + self._keras(data_format='channels_last') if __name__ == "__main__": From 2b4ee8c771ff672e0f0387293b945b7ba14c2905 Mon Sep 17 00:00:00 2001 From: Sean Morgan Date: Fri, 26 Jul 2019 11:16:34 -0400 Subject: [PATCH 15/18] Lint --- build_deps/gpu/cub.BUILD | 3 +-- .../cc/kernels/correlation_cost_op_gpu.cu.cc | 2 +- tensorflow_addons/layers/optical_flow_test.py | 22 ++++++++++--------- 3 files changed, 14 insertions(+), 13 deletions(-) diff --git a/build_deps/gpu/cub.BUILD b/build_deps/gpu/cub.BUILD index 5de7218cda..cdc9e4f377 100644 --- a/build_deps/gpu/cub.BUILD +++ b/build_deps/gpu/cub.BUILD @@ -8,7 +8,6 @@ package( licenses(["notice"]) # BSD - filegroup( name = "cub_header_files", srcs = glob([ @@ -23,4 +22,4 @@ cc_library( deps = [ "@local_config_cuda//cuda:cuda_headers", ], -) \ No newline at end of file +) diff --git a/tensorflow_addons/custom_ops/layers/cc/kernels/correlation_cost_op_gpu.cu.cc b/tensorflow_addons/custom_ops/layers/cc/kernels/correlation_cost_op_gpu.cu.cc index 7970c71e9a..27d3375043 100644 --- a/tensorflow_addons/custom_ops/layers/cc/kernels/correlation_cost_op_gpu.cu.cc +++ b/tensorflow_addons/custom_ops/layers/cc/kernels/correlation_cost_op_gpu.cu.cc @@ -17,8 +17,8 @@ limitations under the License. 
#define EIGEN_USE_GPU -#include "gpu/cub/device/device_reduce.cuh" #include "tensorflow_addons/custom_ops/layers/cc/kernels/correlation_cost_op.h" +#include "gpu/cub/device/device_reduce.cuh" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/util/tensor_format.h" diff --git a/tensorflow_addons/layers/optical_flow_test.py b/tensorflow_addons/layers/optical_flow_test.py index 7842608b50..bc67224bf3 100644 --- a/tensorflow_addons/layers/optical_flow_test.py +++ b/tensorflow_addons/layers/optical_flow_test.py @@ -86,13 +86,15 @@ def _forward_simple(self, data_format): # We can test fixed ids, as output is independent from data_format expected_ids = np.concatenate([np.zeros(464,), np.ones(464,)]) - self.assertAllClose(tf.where(tf.equal(actual, 0))[:, 0], expected_ids) + self.assertAllClose( + tf.where(tf.equal(actual, 0))[:, 0], expected_ids) counts = [54, 52, 54, 50, 44, 50, 54, 52, 54] expected_ids = np.concatenate( [k * np.ones(v,) for k, v in enumerate(counts)]) expected_ids = np.concatenate([expected_ids, expected_ids]) - self.assertAllClose(tf.where(tf.equal(actual, 0))[:, 1], expected_ids) + self.assertAllClose( + tf.where(tf.equal(actual, 0))[:, 1], expected_ids) self.assertEqual(actual.shape, (2, 9, 7, 8)) def _gradients(self, data_format): @@ -118,14 +120,14 @@ def _gradients(self, data_format): def correlation_fn(input_a, input_b): return correlation_cost( - input_a, - input_b, - kernel_size=kernel_size, - max_displacement=max_displacement, - stride_1=stride_1, - stride_2=stride_2, - pad=pad, - data_format=data_format) + input_a, + input_b, + kernel_size=kernel_size, + max_displacement=max_displacement, + stride_1=stride_1, + stride_2=stride_2, + pad=pad, + data_format=data_format) theoretical, numerical = tf.test.compute_gradient( correlation_fn, [input_a_op, input_b_op]) From 335dcfe4e28164581a31342831fc848cf198faa8 Mon Sep 17 00:00:00 2001 From: Sean Morgan Date: Mon, 29 Jul 2019 09:23:45 -0400 Subject: [PATCH 16/18] Minor changes --- tensorflow_addons/layers/optical_flow.py | 10 ++++++++-- tensorflow_addons/layers/optical_flow_test.py | 8 +++----- 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/tensorflow_addons/layers/optical_flow.py b/tensorflow_addons/layers/optical_flow.py index cc41c2ef17..de65e6bf4a 100644 --- a/tensorflow_addons/layers/optical_flow.py +++ b/tensorflow_addons/layers/optical_flow.py @@ -1,4 +1,4 @@ -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
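Because the layer is registered as a Keras custom object and get_config serializes all five hyperparameters, a config round-trip should reproduce an equivalent layer. A quick sketch against the final state of this series (until the validation is fixed in the last two patches, constructing the layer can raise spuriously):

from tensorflow_addons.layers.optical_flow import CorrelationCost

layer = CorrelationCost(kernel_size=1, max_displacement=2, stride_1=1,
                        stride_2=2, pad=4, data_format="channels_first")
clone = CorrelationCost.from_config(layer.get_config())
assert clone.get_config() == layer.get_config()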
@@ -140,7 +140,7 @@ def _correlation_cost_grad(op, grad_output): @keras_utils.register_keras_custom_object -class CorrelationCost(tf.python.keras.layers.Layer): +class CorrelationCost(tf.keras.layers.Layer): def __init__(self, kernel_size, max_displacement, stride_1, stride_2, pad, data_format, **kwargs): self.kernel_size = kernel_size @@ -148,7 +148,13 @@ def __init__(self, kernel_size, max_displacement, stride_1, stride_2, pad, self.stride_1 = stride_1 self.stride_2 = stride_2 self.pad = pad + + if data_format != "channels_last" or data_format == "channels_first": + raise ValueError("`data_format` must be either `channels_last` or" + "`channels_first`") + self.data_format = data_format + super(CorrelationCost, self).__init__(**kwargs) def build(self, input_shape): diff --git a/tensorflow_addons/layers/optical_flow_test.py b/tensorflow_addons/layers/optical_flow_test.py index bc67224bf3..060f572c5f 100644 --- a/tensorflow_addons/layers/optical_flow_test.py +++ b/tensorflow_addons/layers/optical_flow_test.py @@ -1,4 +1,4 @@ -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -44,9 +44,7 @@ def _forward(self, input_a, input_b, kernel_size, max_displacement, return output def _forward_simple(self, data_format): - # cumbersome calculation by hand for a fixed input - # we just test where zeros occurs and a few entries - + # We are just testing where the output has vanishing values. with test_utils.use_gpu(): val = [[[[0, -6, 9, 5], [1, -5, 10, 3], [2, -4, 11, 1]], [[3, -3, 12, -1], [4, -2, 13, -3], [5, -1, 14, -5]]], @@ -135,7 +133,7 @@ def correlation_fn(input_a, input_b): self.assertAllClose(theoretical[0], numerical[0], atol=1e-3) def _keras(self, data_format): - # Unable to use `layer_test` as this layer has multiple inputs + # Unable to use `layer_test` as this layer has multiple inputs. 
with test_utils.use_gpu(): val_a = [[[[0, -6, 9, 5], [1, -5, 10, 3], [2, -4, 11, 1]], [[3, -3, 12, -1], [4, -2, 13, -3], [5, -1, 14, -5]]], From 14c8db10f181e822520e623f4e585d7c21b4bffc Mon Sep 17 00:00:00 2001 From: Sean Morgan Date: Mon, 29 Jul 2019 10:06:12 -0400 Subject: [PATCH 17/18] Update error --- tensorflow_addons/layers/optical_flow.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow_addons/layers/optical_flow.py b/tensorflow_addons/layers/optical_flow.py index de65e6bf4a..1c1f92513f 100644 --- a/tensorflow_addons/layers/optical_flow.py +++ b/tensorflow_addons/layers/optical_flow.py @@ -149,9 +149,9 @@ def __init__(self, kernel_size, max_displacement, stride_1, stride_2, pad, self.stride_2 = stride_2 self.pad = pad - if data_format != "channels_last" or data_format == "channels_first": + if data_format != "channels_last" or data_format != "channels_first": raise ValueError("`data_format` must be either `channels_last` or" - "`channels_first`") + "`channels_first`, instead got %s" % data_format) self.data_format = data_format From 205d5c7e0b5ab948e366a48c374144c4f20b4bf0 Mon Sep 17 00:00:00 2001 From: Sean Morgan Date: Mon, 29 Jul 2019 11:44:57 -0400 Subject: [PATCH 18/18] Fix mistake --- tensorflow_addons/layers/optical_flow.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow_addons/layers/optical_flow.py b/tensorflow_addons/layers/optical_flow.py index 1c1f92513f..c2c843548a 100644 --- a/tensorflow_addons/layers/optical_flow.py +++ b/tensorflow_addons/layers/optical_flow.py @@ -149,7 +149,7 @@ def __init__(self, kernel_size, max_displacement, stride_1, stride_2, pad, self.stride_2 = stride_2 self.pad = pad - if data_format != "channels_last" or data_format != "channels_first": + if data_format != "channels_last" and data_format != "channels_first": raise ValueError("`data_format` must be either `channels_last` or" "`channels_first`, instead got %s" % data_format)
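For the record, the data_format guard went through three forms in the last three patches, and only the final one accepts exactly the two valid strings. A compact way to see it, with an equivalent membership idiom:

for fmt in ("channels_last", "channels_first", "NHWC"):
    raises_16 = fmt != "channels_last" or fmt == "channels_first"
    raises_17 = fmt != "channels_last" or fmt != "channels_first"
    raises_18 = fmt != "channels_last" and fmt != "channels_first"
    print(fmt, raises_16, raises_17, raises_18)
# channels_last  False True False   (patch 16 passed only this value)
# channels_first True  True False   (patch 17 rejected everything)
# NHWC           True  True True    (patch 18 rejects only invalid input)
# raises_18 is equivalent to: fmt not in ("channels_last", "channels_first")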