From 30143476a3ff04495a0c3e11e4f23a553f4611b3 Mon Sep 17 00:00:00 2001 From: Licht Takeuchi Date: Thu, 8 Oct 2020 04:06:54 +0900 Subject: [PATCH 01/13] Add DeformableConv2D layer --- tensorflow_addons/custom_ops/layers/BUILD | 14 + .../layers/cc/kernels/deformable_conv2d_op.cc | 777 ++++++++++++++++++ .../layers/cc/kernels/deformable_conv2d_op.h | 271 ++++++ .../layers/cc/ops/deformable_conv2d_op.cc | 246 ++++++ tensorflow_addons/layers/BUILD | 1 + tensorflow_addons/layers/__init__.py | 1 + tensorflow_addons/layers/deformable_conv2d.py | 332 ++++++++ .../layers/tests/deformable_conv2d_test.py | 420 ++++++++++ 8 files changed, 2062 insertions(+) create mode 100644 tensorflow_addons/custom_ops/layers/cc/kernels/deformable_conv2d_op.cc create mode 100644 tensorflow_addons/custom_ops/layers/cc/kernels/deformable_conv2d_op.h create mode 100644 tensorflow_addons/custom_ops/layers/cc/ops/deformable_conv2d_op.cc create mode 100644 tensorflow_addons/layers/deformable_conv2d.py create mode 100644 tensorflow_addons/layers/tests/deformable_conv2d_test.py diff --git a/tensorflow_addons/custom_ops/layers/BUILD b/tensorflow_addons/custom_ops/layers/BUILD index 5ff49efe99..e9755770f8 100644 --- a/tensorflow_addons/custom_ops/layers/BUILD +++ b/tensorflow_addons/custom_ops/layers/BUILD @@ -19,3 +19,17 @@ custom_op_library( "cc/kernels/correlation_cost_op_gpu.cu.cc", ], ) + +custom_op_library( + name = "_deformable_conv2d_ops.so", + srcs = [ + "cc/kernels/deformable_conv2d_op.cc", + "cc/kernels/deformable_conv2d_op.h", + "cc/ops/deformable_conv2d_op.cc", + ], + cuda_deps = [ + "@cub_archive//:cub", + ], + cuda_srcs = [ + ], +) diff --git a/tensorflow_addons/custom_ops/layers/cc/kernels/deformable_conv2d_op.cc b/tensorflow_addons/custom_ops/layers/cc/kernels/deformable_conv2d_op.cc new file mode 100644 index 0000000000..21d136ace6 --- /dev/null +++ b/tensorflow_addons/custom_ops/layers/cc/kernels/deformable_conv2d_op.cc @@ -0,0 +1,777 @@ +// Copyright 2020 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// ============================================================================= + +#define EIGEN_USE_THREADS + +#if GOOGLE_CUDA +#define EIGEN_USE_GPU +#endif // GOOGLE_CUDA + +#include "tensorflow_addons/custom_ops/layers/cc/kernels/deformable_conv2d_op.h" + +#include "tensorflow/core/framework/common_shape_fns.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/util/work_sharder.h" + +namespace tensorflow { +namespace addons { + +using CPUDevice = Eigen::ThreadPoolDevice; +using GPUDevice = Eigen::GpuDevice; + +namespace functor { + +template +struct DeformableConv2DFunctor + : public DeformableConv2DFunctorBase { + using DeformableConv2DFunctorBase::_input_tensor; + using DeformableConv2DFunctorBase::_filter_tensor; + using DeformableConv2DFunctorBase::_bias_tensor; + using DeformableConv2DFunctorBase::_offset_tensor; + using DeformableConv2DFunctorBase::_mask_tensor; + using DeformableConv2DFunctorBase::_column_buffer_tensor; + using DeformableConv2DFunctorBase::p; + + DeformableConv2DFunctor( + typename TTypes::ConstTensor input_tensor, + typename TTypes::ConstTensor filter_tensor, + typename TTypes::ConstTensor bias_tensor, + typename TTypes::ConstTensor offset_tensor, + typename TTypes::ConstTensor mask_tensor, + typename TTypes::Tensor column_buffer_tensor, + typename TTypes::Tensor output_tensor, + DeformableConv2DParams _p) + : DeformableConv2DFunctorBase( + input_tensor, filter_tensor, bias_tensor, offset_tensor, + mask_tensor, column_buffer_tensor, _p), + _output_tensor(output_tensor) { + _output_tensor.setZero(); + } + + Status operator()(OpKernelContext *context) { + const auto use_bias = _bias_tensor.dimension(0) > 0; + const auto batches = p.input_batches / p.parallel_imgs; + + auto filter_tensor = _filter_tensor.reshape( + Shape5D({p.weight_groups, p.output_channels / p.weight_groups, + p.filter_channels, p.filter_rows, p.filter_cols})); + + auto output_tensor = _output_tensor.reshape( + Shape5D({batches, p.weight_groups, p.output_channels / p.weight_groups, + p.parallel_imgs * p.output_rows, p.output_cols})); + + // input_channels * filter_rows * filter_cols / weight_groups == + // filter_channels * filter_rows * filter_cols + const auto elems = p.filter_channels * p.filter_rows * p.filter_cols; + const auto rows = p.output_channels / p.weight_groups; + const auto cols = p.parallel_imgs * p.output_rows * p.output_cols; + + auto column_buffer_tensor = + _column_buffer_tensor.reshape(Shape3D({p.weight_groups, elems, cols})); + + for (auto b = 0; b < batches; b++) { + auto output_tensor_batch = output_tensor.chip(b, 0); + + this->DeformableIm2Col(b); + + for (auto g = 0; g < p.weight_groups; g++) { + EigenTensor filter_mtx = + filter_tensor.chip(g, 0).reshape(Shape2D({rows, elems})); + EigenTensor column_buffer_mtx = + column_buffer_tensor.chip(g, 0); + + auto mtx_shape = Shape2D({rows, cols}); + Eigen::array, 1> product_dims = { + Eigen::IndexPair(1, 0)}; + + EigenTensor mul = + filter_mtx.contract(column_buffer_mtx, product_dims); + + output_tensor_batch.chip(g, 0).reshape(mtx_shape) += mul; + } + } + + auto output_tensor_transposed = + output_tensor + .reshape(Shape5D({batches, p.output_channels, p.parallel_imgs, + p.output_rows, p.output_cols})) + .shuffle(Shape5D({0, 2, 1, 3, 4})) + .reshape(Shape4D({p.input_batches, p.output_channels, p.output_rows, + p.output_cols})); + + _output_tensor = output_tensor_transposed.eval(); + + if (use_bias) { + auto bias_tensor_broadcasted = + _bias_tensor.reshape(Shape4D({1, p.output_channels, 1, 1})) 
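+              // Broadcasts the per-channel bias from (1, C_out, 1, 1) over the
+              // batch and spatial dimensions of the NCHW output.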
+ .broadcast( + Shape4D({p.input_batches, 1, p.output_rows, p.output_cols})); + + _output_tensor += bias_tensor_broadcasted; + } + + return Status::OK(); + } + + typename TTypes::Tensor _output_tensor; +}; + +template +struct DeformableConv2DGradFunctor + : public DeformableConv2DFunctorBase { + using DeformableConv2DFunctorBase::_input_tensor; + using DeformableConv2DFunctorBase::_filter_tensor; + using DeformableConv2DFunctorBase::_bias_tensor; + using DeformableConv2DFunctorBase::_offset_tensor; + using DeformableConv2DFunctorBase::_mask_tensor; + using DeformableConv2DFunctorBase::_column_buffer_tensor; + using DeformableConv2DFunctorBase::p; + + DeformableConv2DGradFunctor( + typename TTypes::ConstTensor input_tensor, + typename TTypes::ConstTensor filter_tensor, + typename TTypes::ConstTensor bias_tensor, + typename TTypes::ConstTensor offset_tensor, + typename TTypes::ConstTensor mask_tensor, + typename TTypes::ConstTensor output_grad_tensor, + typename TTypes::Tensor input_grad_tensor, + typename TTypes::Tensor filter_grad_tensor, + typename TTypes::Tensor bias_grad_tensor, + typename TTypes::Tensor offset_grad_tensor, + typename TTypes::Tensor mask_grad_tensor, + typename TTypes::Tensor column_buffer_tensor, + DeformableConv2DParams _p) + : DeformableConv2DFunctorBase( + input_tensor, filter_tensor, bias_tensor, offset_tensor, + mask_tensor, column_buffer_tensor, _p), + _output_grad_tensor(output_grad_tensor), + _input_grad_tensor(input_grad_tensor), + _filter_grad_tensor(filter_grad_tensor), + _bias_grad_tensor(bias_grad_tensor), + _offset_grad_tensor(offset_grad_tensor), + _mask_grad_tensor(mask_grad_tensor) { + _input_grad_tensor.setZero(); + _filter_grad_tensor.setZero(); + _column_buffer_tensor.setZero(); + } + + Status operator()(OpKernelContext *context) { + const auto use_bias = _bias_tensor.dimension(0) > 0; + + ComputeInputOffsetMaskGrad(); + + ComputeFilterGrad(); + + if (use_bias) { + _bias_grad_tensor.setConstant(Dtype(1)); + _bias_grad_tensor *= + _output_grad_tensor.sum(Eigen::array({0, 2, 3})); + } + + return Status::OK(); + } + + void ComputeFilterGrad() { + const auto batches = p.input_batches / p.parallel_imgs; + + auto filter_grad_tensor = _filter_grad_tensor.reshape( + Shape5D({p.weight_groups, p.output_channels / p.weight_groups, + p.filter_channels, p.filter_rows, p.filter_cols})); + + EigenTensor output_grad_tensor = + _output_grad_tensor + .reshape(Shape5D({batches, p.parallel_imgs, p.output_channels, + p.output_rows, p.output_cols})) + .shuffle(Shape5D({0, 2, 1, 3, 4})) + .reshape(Shape5D({batches, p.weight_groups, + p.output_channels / p.weight_groups, + p.parallel_imgs * p.output_rows, p.output_cols})); + + // input_channels * filter_rows * filter_cols / weight_groups == + // filter_channels * filter_rows * filter_cols + const auto elems = p.filter_channels * p.filter_rows * p.filter_cols; + const auto rows = p.output_channels / p.weight_groups; + const auto cols = p.parallel_imgs * p.output_rows * p.output_cols; + + auto column_buffer_tensor = + _column_buffer_tensor.reshape(Shape3D({p.weight_groups, elems, cols})); + + for (auto b = 0; b < batches; b++) { + auto output_grad_tensor_batch = output_grad_tensor.chip(b, 0); + + this->DeformableIm2Col(b); + + for (auto g = 0; g < p.weight_groups; g++) { + EigenTensor column_buffer_mtx = + column_buffer_tensor.chip(g, 0).shuffle(Shape2D({1, 0})); + + EigenTensor output_grad_mtx = + output_grad_tensor_batch.chip(g, 0).reshape(Shape2D({rows, cols})); + + Eigen::array, 1> product_dims = { + Eigen::IndexPair(1, 
0)}; + + EigenTensor mul = + output_grad_mtx.contract(column_buffer_mtx, product_dims); + + filter_grad_tensor.chip(g, 0).reshape(Shape2D({rows, elems})) += mul; + } + } + } + + void ComputeInputOffsetMaskGrad() { + auto batches = p.input_batches / p.parallel_imgs; + + EigenTensor filter_tensor = _filter_tensor.reshape( + Shape5D({p.weight_groups, p.output_channels / p.weight_groups, + p.filter_channels, p.filter_rows, p.filter_cols})); + + EigenTensor output_grad_tensor = + _output_grad_tensor + .reshape(Shape5D({batches, p.parallel_imgs, p.output_channels, + p.output_rows, p.output_cols})) + .shuffle(Shape5D({0, 2, 1, 3, 4})) + .reshape(Shape5D({batches, p.weight_groups, + p.output_channels / p.weight_groups, + p.parallel_imgs * p.output_rows, p.output_cols})); + + // input_channels * filter_rows * filter_cols / weight_groups == + // filter_channels * filter_rows * filter_cols + const auto rows = p.filter_channels * p.filter_rows * p.filter_cols; + const auto elems = p.output_channels / p.weight_groups; + const auto cols = p.parallel_imgs * p.output_rows * p.output_cols; + + auto column_buffer_tensor = + _column_buffer_tensor.reshape(Shape3D({p.weight_groups, rows, cols})); + + for (auto b = 0; b < batches; b++) { + _column_buffer_tensor.setZero(); + + auto output_grad_tensor_chipped = output_grad_tensor.chip(b, 0); + for (int g = 0; g < p.weight_groups; g++) { + EigenTensor filter_mtx = filter_tensor.chip(g, 0) + .reshape(Shape2D({elems, rows})) + .shuffle(Shape2D({1, 0})); + EigenTensor output_grad_mtx = + output_grad_tensor_chipped.chip(g, 0).reshape( + Shape2D({elems, cols})); + + Eigen::array, 1> product_dims = { + Eigen::IndexPair(1, 0)}; + + EigenTensor mul = + filter_mtx.contract(output_grad_mtx, product_dims); + + column_buffer_tensor.chip(g, 0) = mul; + } + + DeformableCol2ImForOffsetAndMask(b); + + DeformableCol2ImForInput(b); + } + } + + void DeformableCol2ImForOffsetAndMask(int32 b) { + auto use_mask = _mask_tensor.dimension(0) > 0; + auto batches = p.input_batches / p.parallel_imgs; + auto num_kernels = p.output_rows * p.output_cols * 2 * p.filter_rows * + p.filter_cols * p.offset_groups * p.parallel_imgs; + auto offset_channels = 2 * p.filter_rows * p.filter_cols * p.offset_groups; + + EigenTensor input_tensor = + _input_tensor + .reshape(Shape5D({batches, p.parallel_imgs, p.input_channels, + p.input_rows, p.input_cols})) + .chip(b, 0); + + EigenTensor offset_tensor = + _offset_tensor + .reshape(Shape8D({batches, p.parallel_imgs, p.offset_groups, + p.filter_rows, p.filter_cols, 2, p.output_rows, + p.output_cols})) + .chip(b, 0); + + EigenTensor mask_tensor = + use_mask ? 
static_cast>( + _mask_tensor + .reshape(Shape7D({batches, p.parallel_imgs, + p.offset_groups, p.filter_rows, + p.filter_cols, p.output_rows, + p.output_cols})) + .chip(b, 0)) + : _mask_tensor.reshape(Shape6D({0, 0, 0, 0, 0, 0})); + + EigenTensor column_buffer_tensor = _column_buffer_tensor.reshape( + Shape6D({p.input_channels, p.filter_rows, p.filter_cols, + p.parallel_imgs, p.output_rows, p.output_cols})); + + for (auto k = 0; k < num_kernels; k++) { + auto offset_grad_value = Dtype(0); + auto mask_grad_value = Dtype(0); + + const auto offset_channel_step = p.filter_rows * p.filter_cols; + + const auto current_output_col = k % p.output_cols; + const auto current_output_row = (k / p.output_cols) % p.output_rows; + const auto current_filter_col = + (k / (2 * p.output_rows * p.output_cols)) % p.filter_cols; + const auto current_filter_row = + (k / (2 * p.output_rows * p.output_cols * p.filter_cols)) % + p.filter_rows; + const auto current_offset_channel = + (k / (p.output_rows * p.output_cols)) % offset_channels; + const auto current_batch = + k / (p.output_rows * p.output_cols * offset_channels); + + auto current_actual_batch = b * p.parallel_imgs + current_batch; + + const auto current_offset_group = + current_offset_channel / (2 * offset_channel_step); + + const auto channels_per_offset_group = p.input_channels / p.offset_groups; + const auto offset_channel_diff = + current_offset_channel - + current_offset_group * 2 * offset_channel_step; + const auto is_y_direction = offset_channel_diff % 2 == 0; + + for (auto selected_offset_channel = (offset_channel_diff / 2); + selected_offset_channel < + channels_per_offset_group * offset_channel_step; + selected_offset_channel += offset_channel_step) { + const auto selected_filter_col = + selected_offset_channel % p.filter_cols; + const auto selected_filter_row = + (selected_offset_channel / p.filter_cols) % p.filter_rows; + const auto input_channel_diff = + (selected_offset_channel / (p.filter_cols * p.filter_rows)); + + const auto offset_h = offset_tensor( + current_batch, current_offset_group, selected_filter_row, + selected_filter_col, 0, current_output_row, current_output_col); + const auto offset_w = offset_tensor( + current_batch, current_offset_group, selected_filter_row, + selected_filter_col, 1, current_output_row, current_output_col); + const auto mask = use_mask + ? 
static_cast(mask_tensor( + current_batch, current_offset_group, + selected_filter_row, selected_filter_col, + current_output_row, current_output_col)) + : Dtype(1); + + const auto y = (current_output_row * p.stride_rows - p.padding_rows) + + selected_filter_row * p.dilation_rows + offset_h; + const auto x = (current_output_col * p.stride_cols - p.padding_cols) + + selected_filter_col * p.dilation_cols + offset_w; + + const auto selected_input_channel = + input_channel_diff + + current_offset_group * channels_per_offset_group; + + auto filter_data = column_buffer_tensor( + selected_input_channel, selected_filter_row, selected_filter_col, + current_batch, current_output_row, current_output_col); + + const auto weight = GetCoordinateWeight( + current_actual_batch, selected_input_channel, y, x, is_y_direction); + + offset_grad_value += mask * weight * filter_data; + + if (is_y_direction) { + mask_grad_value += filter_data * this->BilinearInterpolate( + current_actual_batch, + selected_input_channel, y, x); + } + } + + _offset_grad_tensor(current_actual_batch, current_offset_channel, + current_output_row, current_output_col) = + offset_grad_value; + + if (use_mask && is_y_direction) { + auto current_mask_channel = + (current_offset_group * p.filter_rows + current_filter_row) * + p.filter_cols + + current_filter_col; + + _mask_grad_tensor(current_actual_batch, current_mask_channel, + current_output_row, current_output_col) = + mask_grad_value; + } + } + } + + void DeformableCol2ImForInput(int32 b) { + auto use_mask = _mask_tensor.dimension(0) > 0; + auto batches = p.input_batches / p.parallel_imgs; + auto num_kernels = p.input_channels * p.filter_rows * p.filter_cols * + p.output_rows * p.output_cols * p.parallel_imgs; + + EigenTensor offset_tensor = + _offset_tensor + .reshape(Shape8D({batches, p.parallel_imgs, p.offset_groups, + p.filter_rows, p.filter_cols, 2, p.output_rows, + p.output_cols})) + .chip(b, 0); + + EigenTensor mask_tensor = + use_mask ? static_cast>( + _mask_tensor + .reshape(Shape7D({batches, p.parallel_imgs, + p.offset_groups, p.filter_rows, + p.filter_cols, p.output_rows, + p.output_cols})) + .chip(b, 0)) + : _mask_tensor.reshape(Shape6D({0, 0, 0, 0, 0, 0})); + + EigenTensor column_buffer_tensor_flattened = + _column_buffer_tensor.reshape(Shape1D({num_kernels})); + + for (auto k = 0; k < num_kernels; k++) { + const auto current_output_col = k % p.output_cols; + const auto current_output_row = (k / p.output_cols) % p.output_rows; + const auto current_batch = + (k / (p.output_rows * p.output_cols)) % p.parallel_imgs; + + const auto current_filter_col = + (k / (p.output_rows * p.output_cols * p.parallel_imgs)) % + p.filter_cols; + const auto current_filter_row = (k / (p.output_rows * p.output_cols * + p.parallel_imgs * p.filter_cols)) % + p.filter_rows; + const auto current_channel = + k / (p.output_rows * p.output_cols * p.parallel_imgs * p.filter_rows * + p.filter_cols); + + const auto current_offset_group = + current_channel / (p.input_channels / p.offset_groups); + + auto mask = use_mask ? 
mask_tensor(current_batch, current_offset_group, + current_filter_row, current_filter_col, + current_output_row, current_output_col) + : Dtype(1); + + auto offset_h = offset_tensor(current_batch, current_offset_group, + current_filter_row, current_filter_col, 0, + current_output_row, current_output_col); + auto offset_w = offset_tensor(current_batch, current_offset_group, + current_filter_row, current_filter_col, 1, + current_output_row, current_output_col); + + const auto y = (current_output_row * p.stride_rows - p.padding_rows) + + current_filter_row * p.dilation_rows + offset_h; + const auto x = (current_output_col * p.stride_cols - p.padding_cols) + + current_filter_col * p.dilation_cols + offset_w; + + for (auto dy = -1; dy <= 1; dy++) { + for (auto dx = -1; dx <= 1; dx++) { + auto current_input_row = int(y) + dy; + auto current_input_col = int(x) + dx; + if (p.input_rows > current_input_row && current_input_row >= 0 && + p.input_cols > current_input_col && current_input_col >= 0 && + std::abs(y - current_input_row) < 1 && + std::abs(x - current_input_col) < 1) { + auto weight = (1.0 - std::abs(y - current_input_row)) * + (1.0 - std::abs(x - current_input_col)); + + auto current_actual_batch = b * p.parallel_imgs + current_batch; + + _input_grad_tensor(current_actual_batch, current_channel, + current_input_row, current_input_col) += + mask * weight * column_buffer_tensor_flattened(k); + } + } + } + } + } + + Dtype GetCoordinateWeight(int32 batch, int32 channel, Dtype y, Dtype x, + bool is_y_direction) { + EigenTensor img = _input_tensor.chip(batch, 0).chip(channel, 0); + + auto max_height = img.dimension(0); + auto max_width = img.dimension(1); + + int y_low = floor(y); + int x_low = floor(x); + int y_high = y_low + 1; + int x_high = x_low + 1; + + bool valid_y_low = max_height > y_low && y_low >= 0; + bool valid_y_high = max_height > y_high && y_high >= 0; + bool valid_x_low = max_width > x_low && x_low >= 0; + bool valid_x_high = max_width > x_high && x_high >= 0; + + auto v_yx = Dtype(0); + if (valid_y_low && valid_x_low) { + v_yx = img(y_low, x_low); + } + + auto v_yX = Dtype(0); + if (valid_y_low && valid_x_high) { + v_yX = img(y_low, x_high); + } + + auto v_Yx = Dtype(0); + if (valid_y_high && valid_x_low) { + v_Yx = img(y_high, x_low); + } + + auto v_YX = Dtype(0); + if (valid_y_high && valid_x_high) { + v_YX = img(y_high, x_high); + } + + if (is_y_direction) { + auto dx = x - x_low; + return (v_YX - v_yX) * dx + (v_Yx - v_yx) * (1 - dx); + } else { + auto dy = y - y_low; + return (v_YX - v_Yx) * dy + (v_yX - v_yx) * (1 - dy); + } + } + + typename TTypes::ConstTensor _output_grad_tensor; + typename TTypes::Tensor _input_grad_tensor; + typename TTypes::Tensor _filter_grad_tensor; + typename TTypes::Tensor _bias_grad_tensor; + typename TTypes::Tensor _offset_grad_tensor; + typename TTypes::Tensor _mask_grad_tensor; +}; + +} // end namespace functor + +template +class DeformableConv2DOpBase : public OpKernel { + public: + explicit DeformableConv2DOpBase(OpKernelConstruction *context) + : OpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("strides", &strides)); + OP_REQUIRES_OK(context, context->GetAttr("weight_groups", &weight_groups)); + OP_REQUIRES_OK(context, context->GetAttr("offset_groups", &offset_groups)); + OP_REQUIRES_OK(context, context->GetAttr("padding", &padding)); + OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilations)); + string data_format_str; + OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format_str)); + 
FormatFromString(data_format_str, &data_format); + + p = DeformableConv2DParams{}; + } + + void Compute(OpKernelContext *context) override { + const Tensor &input_tensor = context->input(0); + const Tensor &filter_tensor = context->input(1); + + const TensorShape &input_shape = input_tensor.shape(); + const TensorShape &filter_shape = filter_tensor.shape(); + + auto input_batches = input_shape.dim_size(0); + auto input_channels = input_shape.dim_size(1); + auto input_rows = input_shape.dim_size(2); + auto input_cols = input_shape.dim_size(3); + + auto output_channels = filter_shape.dim_size(0); + auto filter_channels = filter_shape.dim_size(1); + auto filter_rows = filter_shape.dim_size(2); + auto filter_cols = filter_shape.dim_size(3); + + auto dilation_rows = dilations[0]; + auto dilation_cols = dilations[1]; + + auto stride_rows = strides[0]; + auto stride_cols = strides[1]; + + auto parallel_imgs = GetParallelImgs(input_batches); + + int64 output_rows, output_cols; + int64 padding_rows, padding_cols; + OP_REQUIRES_OK( + context, GetWindowedOutputSizeV2(input_rows, filter_rows, dilation_rows, + stride_rows, padding, &output_rows, + &padding_rows)); + OP_REQUIRES_OK( + context, GetWindowedOutputSizeV2(input_cols, filter_cols, dilation_cols, + stride_cols, padding, &output_cols, + &padding_cols)); + + p.input_batches = input_batches; + p.input_channels = input_channels; + p.input_rows = input_rows; + p.input_cols = input_cols; + p.filter_channels = filter_channels; + p.filter_rows = filter_rows; + p.filter_cols = filter_cols; + p.padding_rows = padding_rows; + p.padding_cols = padding_cols; + p.stride_rows = stride_rows; + p.stride_cols = stride_cols; + p.dilation_rows = dilation_rows; + p.dilation_cols = dilation_cols; + p.output_channels = output_channels; + p.output_rows = output_rows; + p.output_cols = output_cols; + p.parallel_imgs = parallel_imgs; + p.weight_groups = weight_groups; + p.offset_groups = offset_groups; + } + + int GetParallelImgs(int n) { + for (auto k = kMaxParallelImgs; k > 1; --k) { + if (n % k == 0) { + return k; + } + } + return 1; + } + + protected: + TensorFormat data_format; + DeformableConv2DParams p; + + private: + std::vector strides; + int32 weight_groups; + int32 offset_groups; + Padding padding; + std::vector dilations; +}; + +template +class DeformableConv2DOp : public DeformableConv2DOpBase { + using DeformableConv2DOpBase::data_format; + using DeformableConv2DOpBase::p; + + public: + explicit DeformableConv2DOp(OpKernelConstruction *context) + : DeformableConv2DOpBase(context) {} + + void Compute(OpKernelContext *context) override { + DeformableConv2DOpBase::Compute(context); + + const Tensor &input_tensor = context->input(0); + const Tensor &filter_tensor = context->input(1); + const Tensor &bias_tensor = context->input(2); + const Tensor &offset_tensor = context->input(3); + const Tensor &mask_tensor = context->input(4); + + TensorShape column_buffer_shape( + {p.input_channels * p.filter_rows * p.filter_cols, p.parallel_imgs, + p.output_rows, p.output_cols}); + Tensor column_buffer_tensor; + OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum::value, + column_buffer_shape, + &column_buffer_tensor)); + + TensorShape output_shape = + ShapeFromFormat(data_format, p.input_batches, p.output_rows, + p.output_cols, p.output_channels); + Tensor *output_tensor = nullptr; + OP_REQUIRES_OK(context, + context->allocate_output(0, output_shape, &output_tensor)); + + functor::DeformableConv2DFunctor deformableConv2DFunc( + input_tensor.tensor(), 
filter_tensor.tensor(), + bias_tensor.tensor(), offset_tensor.tensor(), + mask_tensor.tensor(), column_buffer_tensor.tensor(), + output_tensor->tensor(), p); + Status s = deformableConv2DFunc(context); + + OP_REQUIRES_OK(context, s); + } +}; + +template +class DeformableConv2DGradOp : public DeformableConv2DOpBase { + using DeformableConv2DOpBase::data_format; + using DeformableConv2DOpBase::p; + + public: + explicit DeformableConv2DGradOp(OpKernelConstruction *context) + : DeformableConv2DOpBase(context) {} + + void Compute(OpKernelContext *context) override { + DeformableConv2DOpBase::Compute(context); + + const Tensor &input_tensor = context->input(0); + const Tensor &filter_tensor = context->input(1); + const Tensor &bias_tensor = context->input(2); + const Tensor &offset_tensor = context->input(3); + const Tensor &mask_tensor = context->input(4); + const Tensor &output_grad_tensor = context->input(5); + + const TensorShape &input_shape = input_tensor.shape(); + const TensorShape &filter_shape = filter_tensor.shape(); + const TensorShape &bias_shape = bias_tensor.shape(); + const TensorShape &offset_shape = offset_tensor.shape(); + const TensorShape &mask_shape = mask_tensor.shape(); + + TensorShape column_buffer_shape( + {p.input_channels * p.filter_rows * p.filter_cols, p.parallel_imgs, + p.output_rows, p.output_cols}); + Tensor column_buffer_tensor; + OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum::value, + column_buffer_shape, + &column_buffer_tensor)); + + TensorShape output_shape = + ShapeFromFormat(data_format, p.input_batches, p.output_rows, + p.output_cols, p.output_channels); + + Tensor *input_grad_tensor = nullptr; + OP_REQUIRES_OK( + context, context->allocate_output(0, input_shape, &input_grad_tensor)); + Tensor *filter_grad_tensor = nullptr; + OP_REQUIRES_OK(context, context->allocate_output(1, filter_shape, + &filter_grad_tensor)); + Tensor *bias_grad_tensor = nullptr; + OP_REQUIRES_OK(context, + context->allocate_output(2, bias_shape, &bias_grad_tensor)); + Tensor *offset_grad_tensor = nullptr; + OP_REQUIRES_OK(context, context->allocate_output(3, offset_shape, + &offset_grad_tensor)); + Tensor *mask_grad_tensor = nullptr; + OP_REQUIRES_OK(context, + context->allocate_output(4, mask_shape, &mask_grad_tensor)); + + functor::DeformableConv2DGradFunctor deformableConv2DGradFunc( + input_tensor.tensor(), filter_tensor.tensor(), + bias_tensor.tensor(), offset_tensor.tensor(), + mask_tensor.tensor(), output_grad_tensor.tensor(), + input_grad_tensor->tensor(), filter_grad_tensor->tensor(), + bias_grad_tensor->tensor(), offset_grad_tensor->tensor(), + mask_grad_tensor->tensor(), column_buffer_tensor.tensor(), + p); + Status s = deformableConv2DGradFunc(context); + + OP_REQUIRES_OK(context, s); + } +}; + +// Register the CPU kernels. 
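+// Only CPU kernels are registered below: the TF_CALL_* macros instantiate the
+// float and double specializations, and the BUILD target currently lists no
+// cuda_srcs.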
+#define REGISTER_DEFORMABLECONV2D_OP_CPU(T) \ + REGISTER_KERNEL_BUILDER(Name("Addons>DeformableConv2D") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T"), \ + DeformableConv2DOp) \ + REGISTER_KERNEL_BUILDER(Name("Addons>DeformableConv2DGrad") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T"), \ + DeformableConv2DGradOp) + +TF_CALL_float(REGISTER_DEFORMABLECONV2D_OP_CPU); +TF_CALL_double(REGISTER_DEFORMABLECONV2D_OP_CPU); +#undef REGISTER_DEFORMABLECONV2D_OP_CPU + +} // namespace addons +} // namespace tensorflow diff --git a/tensorflow_addons/custom_ops/layers/cc/kernels/deformable_conv2d_op.h b/tensorflow_addons/custom_ops/layers/cc/kernels/deformable_conv2d_op.h new file mode 100644 index 0000000000..fe9bba1f32 --- /dev/null +++ b/tensorflow_addons/custom_ops/layers/cc/kernels/deformable_conv2d_op.h @@ -0,0 +1,271 @@ +// Copyright 2020 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +#ifndef TENSORFLOW_ADDONS_LAYERS_KERNELS_DEFORMABLECONV2D_OP_H_ +#define TENSORFLOW_ADDONS_LAYERS_KERNELS_DEFORMABLECONV2D_OP_H_ + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/util/tensor_format.h" + +namespace tensorflow { +namespace addons { +using Shape8D = Eigen::array; +using Shape7D = Eigen::array; +using Shape6D = Eigen::array; +using Shape5D = Eigen::array; +using Shape4D = Eigen::array; +using Shape3D = Eigen::array; +using Shape2D = Eigen::array; +using Shape1D = Eigen::array; + +template +using EigenTensor = Eigen::Tensor; +template +using EigenTensorRef = + Eigen::TensorRef>; + +static const int kMaxParallelImgs = 32; + +struct DeformableConv2DParams { + int32 input_batches; + int32 input_channels; + int32 input_rows; + int32 input_cols; + int32 filter_channels; + int32 filter_rows; + int32 filter_cols; + int32 padding_rows; + int32 padding_cols; + int32 stride_rows; + int32 stride_cols; + int32 dilation_rows; + int32 dilation_cols; + int32 output_channels; + int32 output_rows; + int32 output_cols; + int32 parallel_imgs; + int32 weight_groups; + int32 offset_groups; +}; + +namespace functor { + +template +struct DeformableConv2DFunctorBase { + DeformableConv2DFunctorBase( + typename TTypes::ConstTensor input_tensor, + typename TTypes::ConstTensor filter_tensor, + typename TTypes::ConstTensor bias_tensor, + typename TTypes::ConstTensor offset_tensor, + typename TTypes::ConstTensor mask_tensor, + typename TTypes::Tensor column_buffer_tensor, + DeformableConv2DParams _p) + : _input_tensor(input_tensor), + _filter_tensor(filter_tensor), + _bias_tensor(bias_tensor), + _offset_tensor(offset_tensor), + _mask_tensor(mask_tensor), + _column_buffer_tensor(column_buffer_tensor), + p(_p) {} + + virtual Status operator()(OpKernelContext* context) = 0; + + Dtype BilinearInterpolate(int32 batch, int32 channel, Dtype y, Dtype x) { + EigenTensor img = _input_tensor.chip(batch, 0).chip(channel, 0); + + auto max_height = img.dimension(0); + auto 
max_width = img.dimension(1); + + if (y <= -1 || max_height <= y || x <= -1 || max_width <= x) { + return Dtype(0); + } + + int y_low = floor(y); + int x_low = floor(x); + int y_high = y_low + 1; + int w_high = x_low + 1; + + auto v1 = Dtype(0); + if (y_low >= 0 && x_low >= 0) { + v1 = img(y_low, x_low); + } + + auto v2 = Dtype(0); + if (y_low >= 0 && w_high <= max_width - 1) { + v2 = img(y_low, w_high); + } + + auto v3 = Dtype(0); + if (y_high <= max_height - 1 && x_low >= 0) { + v3 = img(y_high, x_low); + } + + auto v4 = Dtype(0); + if (y_high <= max_height - 1 && w_high <= max_width - 1) { + v4 = img(y_high, w_high); + } + + auto lh = y - y_low; + auto lw = x - x_low; + auto hh = 1 - lh; + auto hw = 1 - lw; + + auto w1 = hh * hw; + auto w2 = hh * lw; + auto w3 = lh * hw; + auto w4 = lh * lw; + + return w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4; + } + + void DeformableIm2Col(int32 b) { + auto use_mask = _mask_tensor.dimension(0) > 0; + auto num_kernels = + p.input_channels * p.output_rows * p.output_cols * p.parallel_imgs; + auto batches = p.input_batches / p.parallel_imgs; + + EigenTensor offset_tensor = + _offset_tensor + .reshape(Shape8D({batches, p.parallel_imgs, p.offset_groups, + p.filter_rows, p.filter_cols, 2, p.output_rows, + p.output_cols})) + .chip(b, 0); + + EigenTensor mask_tensor = + use_mask ? static_cast>( + _mask_tensor + .reshape(Shape7D({batches, p.parallel_imgs, + p.offset_groups, p.filter_rows, + p.filter_cols, p.output_rows, + p.output_cols})) + .chip(b, 0)) + : _mask_tensor.reshape(Shape6D({0, 0, 0, 0, 0, 0})); + + for (auto k = 0; k < num_kernels; k++) { + const auto current_output_col = k % p.output_cols; + const auto current_output_row = (k / p.output_cols) % p.output_rows; + const auto current_batch = + (k / (p.output_rows * p.output_cols)) % p.parallel_imgs; + const auto current_input_channel = + k / (p.output_rows * p.output_cols * p.parallel_imgs); + const auto current_output_channel = + current_input_channel * p.filter_rows * p.filter_cols; + + auto current_actual_batch = b * p.parallel_imgs + current_batch; + + const auto group_index = + current_input_channel / (p.input_channels / p.offset_groups); + + EigenTensor offset_tensor_chipped = + offset_tensor.chip(current_batch, 0).chip(group_index, 0); + + EigenTensor mask_tensor_chipped = + use_mask + ? static_cast>( + mask_tensor.chip(current_batch, 0).chip(group_index, 0)) + : mask_tensor.reshape(Shape4D({0, 0, 0, 0})); + + auto column_buffer_tensor_channel = current_output_channel; + for (auto current_filter_row = 0; current_filter_row < p.filter_rows; + current_filter_row++) { + for (auto current_filter_col = 0; current_filter_col < p.filter_cols; + current_filter_col++) { + auto offset_h = + offset_tensor_chipped(current_filter_row, current_filter_col, 0, + current_output_row, current_output_col); + auto offset_w = + offset_tensor_chipped(current_filter_row, current_filter_col, 1, + current_output_row, current_output_col); + + auto mask = use_mask ? 
mask_tensor_chipped( + current_filter_row, current_filter_col, + current_output_row, current_output_col) + : Dtype(1); + + auto y = (current_output_row * p.stride_rows - p.padding_rows) + + current_filter_row * p.dilation_rows + offset_h; + auto x = (current_output_col * p.stride_cols - p.padding_cols) + + current_filter_col * p.dilation_cols + offset_w; + + _column_buffer_tensor(column_buffer_tensor_channel, current_batch, + current_output_row, current_output_col) = + mask * BilinearInterpolate(current_actual_batch, + current_input_channel, y, x); + column_buffer_tensor_channel++; + } + } + } + } + + typename TTypes::ConstTensor _input_tensor; + typename TTypes::ConstTensor _filter_tensor; + typename TTypes::ConstTensor _bias_tensor; + typename TTypes::ConstTensor _offset_tensor; + typename TTypes::ConstTensor _mask_tensor; + typename TTypes::Tensor _column_buffer_tensor; + DeformableConv2DParams p; +}; + +template +struct DeformableConv2DFunctor + : public DeformableConv2DFunctorBase { + DeformableConv2DFunctor( + typename TTypes::ConstTensor input_tensor, + typename TTypes::ConstTensor filter_tensor, + typename TTypes::ConstTensor bias_tensor, + typename TTypes::ConstTensor offset_tensor, + typename TTypes::ConstTensor mask_tensor, + typename TTypes::Tensor column_buffer_tensor, + typename TTypes::Tensor output_tensor, + DeformableConv2DParams _p); + + Status operator()(OpKernelContext* context); + + typename TTypes::Tensor _output_tensor; +}; + +template +struct DeformableConv2DGradFunctor + : public DeformableConv2DFunctorBase { + DeformableConv2DGradFunctor( + typename TTypes::ConstTensor input_tensor, + typename TTypes::ConstTensor filter_tensor, + typename TTypes::ConstTensor bias_tensor, + typename TTypes::ConstTensor offset_tensor, + typename TTypes::ConstTensor mask_tensor, + typename TTypes::ConstTensor output_grad_tensor, + typename TTypes::Tensor input_grad_tensor, + typename TTypes::Tensor filter_grad_tensor, + typename TTypes::Tensor bias_grad_tensor, + typename TTypes::Tensor offset_grad_tensor, + typename TTypes::Tensor mask_grad_tensor, + typename TTypes::Tensor column_buffer_tensor, + DeformableConv2DParams p); + + Status operator()(OpKernelContext* context); + + typename TTypes::ConstTensor _output_grad_tensor; + typename TTypes::Tensor _input_grad_tensor; + typename TTypes::Tensor _filter_grad_tensor; + typename TTypes::Tensor _bias_grad_tensor; + typename TTypes::Tensor _offset_grad_tensor; + typename TTypes::Tensor _mask_grad_tensor; +}; + +} // namespace functor +} // namespace addons +} // namespace tensorflow + +#endif // TENSORFLOW_ADDONS_LAYERS_KERNELS_DEFORMABLECONV2D_OP_H_ diff --git a/tensorflow_addons/custom_ops/layers/cc/ops/deformable_conv2d_op.cc b/tensorflow_addons/custom_ops/layers/cc/ops/deformable_conv2d_op.cc new file mode 100644 index 0000000000..5cb9561f32 --- /dev/null +++ b/tensorflow_addons/custom_ops/layers/cc/ops/deformable_conv2d_op.cc @@ -0,0 +1,246 @@ +// Copyright 2020 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +#include "tensorflow/core/framework/common_shape_fns.h" +#include "tensorflow/core/framework/op.h" + +namespace tensorflow { +namespace addons { + +using ::tensorflow::shape_inference::DimensionHandle; +using ::tensorflow::shape_inference::InferenceContext; +using ::tensorflow::shape_inference::ShapeHandle; + +REGISTER_OP("Addons>DeformableConv2D") + .Input("input: T") + .Input("filter: T") + .Input("bias: T") + .Input("offset: T") + .Input("mask: T") + .Output("output: T") + .Attr("strides: list(int)") + .Attr("weight_groups: int") + .Attr("offset_groups: int") + .Attr(GetPaddingAttrString()) + .Attr("dilations: list(int)") + .Attr("data_format: { 'NCHW' }") + .Attr("T: {float, double}") + .SetShapeFn([](InferenceContext *c) { + ShapeHandle input_shape; + ShapeHandle filter_shape; + ShapeHandle bias_shape; + ShapeHandle offset_shape; + ShapeHandle mask_shape; + std::vector strides; + std::vector dilations; + std::string data_format_str; + TensorFormat data_format; + int32 weight_groups; + int32 offset_groups; + Padding padding; + + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input_shape)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 4, &filter_shape)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &bias_shape)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 4, &offset_shape)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 4, &mask_shape)); + TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides)); + TF_RETURN_IF_ERROR(c->GetAttr("weight_groups", &weight_groups)); + TF_RETURN_IF_ERROR(c->GetAttr("offset_groups", &offset_groups)); + TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding)); + TF_RETURN_IF_ERROR(c->GetAttr("dilations", &dilations)); + TF_RETURN_IF_ERROR(c->GetAttr("data_format", &data_format_str)); + FormatFromString(data_format_str, &data_format); + if (strides.size() != 2 || dilations.size() != 2) { + return errors::InvalidArgument("strides/dilations size must be 2."); + } + + DimensionHandle input_batches_dim = c->Dim(input_shape, 0); + DimensionHandle input_channels_dim = c->Dim(input_shape, 1); + DimensionHandle input_rows_dim = c->Dim(input_shape, 2); + DimensionHandle input_cols_dim = c->Dim(input_shape, 3); + + DimensionHandle bias_dim = c->Dim(bias_shape, 0); + + DimensionHandle output_channels_dim = c->Dim(filter_shape, 0); + DimensionHandle filter_channels_dim = c->Dim(filter_shape, 1); + DimensionHandle filter_rows_dim = c->Dim(filter_shape, 2); + DimensionHandle filter_cols_dim = c->Dim(filter_shape, 3); + + DimensionHandle offset_batches_dim = c->Dim(offset_shape, 0); + DimensionHandle offset_channels_dim = c->Dim(offset_shape, 1); + DimensionHandle offset_heights_dim = c->Dim(offset_shape, 2); + DimensionHandle offset_weights_dim = c->Dim(offset_shape, 3); + + DimensionHandle mask_batches_dim = c->Dim(mask_shape, 0); + DimensionHandle mask_channels_dim = c->Dim(mask_shape, 1); + DimensionHandle mask_heights_dim = c->Dim(mask_shape, 2); + DimensionHandle mask_weights_dim = c->Dim(mask_shape, 3); + + bool use_mask = InferenceContext::Value(mask_batches_dim) != 0; + bool use_bias = InferenceContext::Value(bias_dim) != 0; + + auto input_batches = InferenceContext::Value(input_batches_dim); + auto input_rows = InferenceContext::Value(input_rows_dim); + auto input_cols = InferenceContext::Value(input_cols_dim); + + auto output_channels = 
InferenceContext::Value(output_channels_dim);
+
+      auto filter_rows = InferenceContext::Value(filter_rows_dim);
+      auto filter_cols = InferenceContext::Value(filter_cols_dim);
+
+      auto stride_rows = strides[0];
+      auto stride_cols = strides[1];
+      auto dilation_rows = dilations[0];
+      auto dilation_cols = dilations[1];
+
+      DimensionHandle tmp;
+
+      if (use_bias) {
+        TF_RETURN_IF_ERROR(c->WithValue(bias_dim, output_channels, &tmp));
+      }
+
+      TF_RETURN_IF_ERROR(
+          c->Divide(output_channels_dim, weight_groups, true, &tmp));
+      TF_RETURN_IF_ERROR(
+          c->Divide(input_channels_dim, offset_groups, true, &tmp));
+
+      TF_RETURN_IF_ERROR(c->Multiply(filter_channels_dim, weight_groups, &tmp));
+      TF_RETURN_IF_ERROR(c->Merge(input_channels_dim, tmp, &tmp));
+
+      TF_RETURN_IF_ERROR(c->WithValue(offset_batches_dim, input_batches, &tmp));
+
+      if (use_mask) {
+        TF_RETURN_IF_ERROR(c->WithValue(mask_batches_dim, input_batches, &tmp));
+      }
+
+      if (InferenceContext::ValueKnown(filter_rows_dim) &&
+          InferenceContext::ValueKnown(filter_cols_dim)) {
+        auto filter_area = filter_rows * filter_cols * offset_groups;
+        TF_RETURN_IF_ERROR(
+            c->WithValue(offset_channels_dim, 2 * filter_area, &tmp));
+        if (use_mask) {
+          TF_RETURN_IF_ERROR(
+              c->WithValue(mask_channels_dim, filter_area, &tmp));
+        }
+      }
+
+      if (!InferenceContext::ValueKnown(input_rows_dim) ||
+          !InferenceContext::ValueKnown(input_cols_dim) ||
+          !InferenceContext::ValueKnown(filter_rows_dim) ||
+          !InferenceContext::ValueKnown(filter_cols_dim)) {
+        c->set_output(0, c->MakeShape({input_batches_dim, output_channels_dim,
+                                       InferenceContext::kUnknownDim,
+                                       InferenceContext::kUnknownDim}));
+        return Status::OK();
+      }
+
+      auto effective_filter_rows =
+          filter_rows + (filter_rows - 1) * (dilation_rows - 1);
+      auto effective_filter_cols =
+          filter_cols + (filter_cols - 1) * (dilation_cols - 1);
+
+      int64 output_rows, output_cols;
+      int64 padding_before, padding_after;
+      TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerbose(
+          input_rows, effective_filter_rows, stride_rows, padding, &output_rows,
+          &padding_before, &padding_after));
+      TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerbose(
+          input_cols, effective_filter_cols, stride_cols, padding, &output_cols,
+          &padding_before, &padding_after));
+
+      TF_RETURN_IF_ERROR(c->WithValue(offset_heights_dim, output_rows, &tmp));
+      TF_RETURN_IF_ERROR(c->WithValue(offset_weights_dim, output_cols, &tmp));
+      if (use_mask) {
+        TF_RETURN_IF_ERROR(c->WithValue(mask_heights_dim, output_rows, &tmp));
+        TF_RETURN_IF_ERROR(c->WithValue(mask_weights_dim, output_cols, &tmp));
+      }
+
+      c->set_output(0, c->MakeShape({input_batches_dim, output_channels_dim,
+                                     output_rows, output_cols}));
+
+      return Status::OK();
+    })
+    .Doc(R"doc(Compute Modulated Deformable Convolution.
+
+This layer implements the operation from
+Deformable ConvNets v2: More Deformable, Better Results (Zhu et al.).
+
+input: A `Tensor` of the format specified by `data_format`.
+filter: A `Tensor` of the convolution kernel weights. Its shape is
+  `(output_channel, input_channel // weight_groups, kernel_height, kernel_width)`.
+bias: A `Tensor` of the convolution bias. A `(0,)`-shape `Tensor` is passed
+  when bias is disabled on the Python side.
+offset: A `Tensor` of the offsets that are applied for each position
+  in the convolution kernel. The channel size must be
+  `2 * kernel_height * kernel_width * offset_groups`.
+mask: A `Tensor` of the modulation scalars that are applied for each position
+  in the convolution kernel.
The channel size must be
+  `kernel_height * kernel_width * offset_groups` if the modulation mode is
+  enabled on the Python side. A `(0,)`-shape `Tensor` is passed when the
+  modulation mode is disabled on the Python side.
+strides: A list of 2 integers, specifying the strides of the convolution
+  along the height and width.
+weight_groups: An integer specifying the number of groups in which the input is
+  split along the channel axis. Each group is convolved separately with
+  `filters / weight_groups` filters. The output is the concatenation of all
+  the groups' results along the channel axis. Input channels and output
+  channels must both be divisible by `weight_groups`.
+offset_groups: An integer specifying the number of groups in which the input is
+  split along the channel axis. Each group is convolved separately with
+  its group offset.
+padding: A string specifying the padding type.
+  Possible values are:
+    "VALID"
+    "SAME"
+dilations: A list of 2 integers, specifying the dilation rate to use
+  for dilated convolution.
+data_format: Specifies the data format.
+  Possible values are:
+    "NCHW" float [batch, channels, height, width]
+  Defaults to `"NCHW"`.
+)doc");
+
+REGISTER_OP("Addons>DeformableConv2DGrad")
+    .Input("input: T")
+    .Input("filter: T")
+    .Input("bias: T")
+    .Input("offset: T")
+    .Input("mask: T")
+    .Input("output_grad: T")
+    .Output("input_grad: T")
+    .Output("filter_grad: T")
+    .Output("bias_grad: T")
+    .Output("offset_grad: T")
+    .Output("mask_grad: T")
+    .Attr("strides: list(int)")
+    .Attr("weight_groups: int")
+    .Attr("offset_groups: int")
+    .Attr(GetPaddingAttrString())
+    .Attr("dilations: list(int)")
+    .Attr("data_format: { 'NCHW' }")
+    .Attr("T: {float, double}")
+    .SetShapeFn([](InferenceContext *c) {
+      c->set_output(0, c->input(0));
+      c->set_output(1, c->input(1));
+      c->set_output(2, c->input(2));
+      c->set_output(3, c->input(3));
+      c->set_output(4, c->input(4));
+      return Status::OK();
+    })
+    .Doc(R"doc(DeformableConv2DGrad op.)doc");
+
+}  // namespace addons
+}  // namespace tensorflow
diff --git a/tensorflow_addons/layers/BUILD b/tensorflow_addons/layers/BUILD
index 958906cea8..9f420029bc 100644
--- a/tensorflow_addons/layers/BUILD
+++ b/tensorflow_addons/layers/BUILD
@@ -7,6 +7,7 @@ py_library(
     srcs = glob(["*.py"]),
     data = [
         "//tensorflow_addons/custom_ops/layers:_correlation_cost_ops.so",
+        "//tensorflow_addons/custom_ops/layers:_deformable_conv2d_ops.so",
     ],
     deps = [
         "//tensorflow_addons/activations",
diff --git a/tensorflow_addons/layers/__init__.py b/tensorflow_addons/layers/__init__.py
index 78c972e294..b06fbe2d06 100644
--- a/tensorflow_addons/layers/__init__.py
+++ b/tensorflow_addons/layers/__init__.py
@@ -41,3 +41,4 @@
 from tensorflow_addons.layers.stochastic_depth import StochasticDepth
 from tensorflow_addons.layers.noisy_dense import NoisyDense
 from tensorflow_addons.layers.crf import CRF
+from tensorflow_addons.layers.deformable_conv2d import DeformableConv2D
diff --git a/tensorflow_addons/layers/deformable_conv2d.py b/tensorflow_addons/layers/deformable_conv2d.py
new file mode 100644
index 0000000000..3155fce91c
--- /dev/null
+++ b/tensorflow_addons/layers/deformable_conv2d.py
@@ -0,0 +1,332 @@
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import typing
+
+import tensorflow as tf
+from typeguard import typechecked
+from tensorflow_addons.utils import types
+from tensorflow_addons.utils.resource_loader import LazySO
+from tensorflow.python.keras.utils import conv_utils
+
+_deformable_conv2d_ops_so = LazySO("custom_ops/layers/_deformable_conv2d_ops.so")
+
+
+@typechecked
+def _deformable_conv2d(
+    input_tensor: tf.Tensor,
+    filter_tensor: tf.Tensor,
+    bias_tensor: tf.Tensor,
+    offset_tensor: tf.Tensor,
+    mask_tensor: tf.Tensor,
+    strides: typing.Union[tuple, list],
+    dilations: typing.Union[tuple, list],
+    weight_groups: int,
+    offset_groups: int,
+    padding: str,
+):
+    with tf.name_scope("deformable_conv2d"):
+        return _deformable_conv2d_ops_so.ops.addons_deformable_conv2d(
+            input=input_tensor,
+            filter=filter_tensor,
+            bias=bias_tensor,
+            offset=offset_tensor,
+            mask=mask_tensor,
+            strides=strides,
+            weight_groups=weight_groups,
+            offset_groups=offset_groups,
+            padding=padding,
+            data_format="NCHW",
+            dilations=dilations,
+        )
+
+
+@tf.RegisterGradient("Addons>DeformableConv2D")
+def _deformable_conv2d_grad(op, grad):
+    input = op.inputs[0]
+    filter = op.inputs[1]
+    bias = op.inputs[2]
+    offset = op.inputs[3]
+    mask = op.inputs[4]
+    strides = op.get_attr("strides")
+    weight_groups = op.get_attr("weight_groups")
+    offset_groups = op.get_attr("offset_groups")
+    padding = op.get_attr("padding")
+    dilations = op.get_attr("dilations")
+    data_format = op.get_attr("data_format")
+
+    data_grad = _deformable_conv2d_ops_so.ops.addons_deformable_conv2d_grad(
+        input,
+        filter,
+        bias,
+        offset,
+        mask,
+        grad,
+        strides=strides,
+        weight_groups=weight_groups,
+        offset_groups=offset_groups,
+        padding=padding,
+        dilations=dilations,
+        data_format=data_format,
+    )
+    return data_grad
+
+
+@tf.keras.utils.register_keras_serializable(package="Addons")
+class DeformableConv2D(tf.keras.layers.Layer):
+    @typechecked
+    def __init__(
+        self,
+        filters: int,
+        kernel_size: typing.Union[int, tuple, list] = (3, 3),
+        strides: typing.Union[int, tuple, list] = (1, 1),
+        padding: str = "valid",
+        data_format: str = "channels_first",
+        dilation_rate: typing.Union[int, tuple, list] = (1, 1),
+        weight_groups: int = 1,
+        offset_groups: int = 1,
+        use_mask: bool = False,
+        use_bias: bool = False,
+        kernel_initializer: types.Initializer = None,
+        bias_initializer: types.Initializer = None,
+        kernel_regularizer: types.Regularizer = None,
+        bias_regularizer: types.Regularizer = None,
+        kernel_constraint: types.Constraint = None,
+        bias_constraint: types.Constraint = None,
+        **kwargs
+    ):
+        """Modulated Deformable Convolution Layer.
+
+        This layer implements the operation from [Deformable ConvNets v2:
+        More Deformable, Better Results](https://arxiv.org/abs/1811.11168)
+        (Zhu et al.).
+
+        Arguments:
+            filters: Integer, the dimensionality of the output space (i.e. the number of
+                output filters in the convolution).
+            kernel_size: An integer or tuple/list of 2 integers, specifying the height
+                and width of the 2D convolution window.
Can be a single integer to specify
+                the same value for all spatial dimensions.
+            strides: An integer or tuple/list of 2 integers, specifying the strides of
+                the convolution along the height and width. Can be a single integer to
+                specify the same value for all spatial dimensions. Specifying any stride
+                value != 1 is incompatible with specifying any `dilation_rate` value != 1.
+            padding: One of `"valid"` or `"same"` (case-insensitive).
+            data_format: Specifies the data format.
+                Possible values are:
+                    "channels_first" float [batch, channels, height, width]
+                Defaults to `"channels_first"`.
+            dilation_rate: An integer or tuple/list of 2 integers, specifying the
+                dilation rate to use for dilated convolution. Can be a single integer to
+                specify the same value for all spatial dimensions.
+            weight_groups: A positive integer specifying the number of groups in which
+                the input is split along the channel axis. Each group is convolved
+                separately with `filters / weight_groups` filters. The output is the
+                concatenation of all the groups' results along the channel axis. Input
+                channels and `filters` must both be divisible by `weight_groups`.
+            offset_groups: An integer specifying the number of groups in which the input
+                is split along the channel axis. Each group is convolved separately with
+                its group offset.
+            use_mask: Boolean, whether the layer uses a modulation input.
+            use_bias: Boolean, whether the layer uses a bias vector.
+            kernel_initializer: Initializer for the `kernel` weights matrix (see
+                `keras.initializers`).
+            bias_initializer: Initializer for the bias vector (see
+                `keras.initializers`).
+            kernel_regularizer: Regularizer function applied to the `kernel` weights
+                matrix (see `keras.regularizers`).
+            bias_regularizer: Regularizer function applied to the bias vector (see
+                `keras.regularizers`).
+            kernel_constraint: Constraint function applied to the kernel matrix (see
+                `keras.constraints`).
+            bias_constraint: Constraint function applied to the bias vector (see
+                `keras.constraints`).
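+
+        Example (a minimal usage sketch; the offset and mask shapes follow the
+        rules above, assuming `padding="same"` with unit strides and dilation
+        so the output spatial dimensions equal the input's):
+
+        >>> batch, in_ch, height, width = 2, 4, 8, 8
+        >>> kh, kw, offset_groups = 3, 3, 1
+        >>> inputs = tf.random.uniform((batch, in_ch, height, width))
+        >>> # Zero offsets plus an all-ones mask reduce to a plain convolution.
+        >>> offset = tf.zeros((batch, 2 * offset_groups * kh * kw, height, width))
+        >>> mask = tf.ones((batch, offset_groups * kh * kw, height, width))
+        >>> layer = DeformableConv2D(
+        ...     filters=6, kernel_size=(kh, kw), padding="same", use_mask=True
+        ... )
+        >>> layer([inputs, offset, mask]).shape
+        TensorShape([2, 6, 8, 8])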
+ """ + super(DeformableConv2D, self).__init__(**kwargs) + + self.filters = filters + self.kernel_size = conv_utils.normalize_tuple(kernel_size, 2, "kernel_size") + self.strides = conv_utils.normalize_tuple(strides, 2, "strides") + self.padding = conv_utils.normalize_padding(padding) + self.data_format = conv_utils.normalize_data_format(data_format) + self.dilation_rate = conv_utils.normalize_tuple( + dilation_rate, 2, "dilation_rate" + ) + self.weight_groups = weight_groups + self.offset_groups = offset_groups + self.use_mask = use_mask + self.use_bias = use_bias + self.kernel_initializer = tf.keras.initializers.get(kernel_initializer) + self.bias_initializer = tf.keras.initializers.get(bias_initializer) + self.kernel_regularizer = tf.keras.regularizers.get(kernel_regularizer) + self.bias_regularizer = tf.keras.regularizers.get(bias_regularizer) + self.kernel_constraint = tf.keras.constraints.get(kernel_constraint) + self.bias_constraint = tf.keras.constraints.get(bias_constraint) + + if self.padding == "causal": + raise ValueError("Causal padding is not supported.") + + if self.data_format != "channels_first": + raise ValueError("`channels_last` data format is not supported.") + + if self.filters % self.weight_groups != 0: + raise ValueError("filters must be divisible by weight_group.") + + self.filter_weights = None + self.filter_bias = None + + def _validate_shapes(self, shapes): + if type(shapes) is not list: + raise ValueError("DeformableConv2D input must be list of Tensor.") + elif self.use_mask and len(shapes) != 3: + raise ValueError("DeformableConv2D input must be 3-length list of Tensor.") + elif not self.use_mask and len(shapes) != 2: + raise ValueError("DeformableConv2D input must be 2-length list of Tensor.") + + def build(self, shapes): + self._validate_shapes(shapes) + + input_shape = shapes[0] + offset_shape = shapes[1] + mask_shape = shapes[2] if self.use_mask else None + + exp_off_c = self.offset_groups * 2 * self.kernel_size[0] * self.kernel_size[1] + + off_b, off_c, off_h, off_w = offset_shape + in_b, in_c, in_h, in_w = input_shape + + out_h = conv_utils.conv_output_length( + in_h, + self.kernel_size[0], + padding=self.padding, + stride=self.strides[0], + dilation=self.dilation_rate[0], + ) + out_w = conv_utils.conv_output_length( + in_w, + self.kernel_size[1], + padding=self.padding, + stride=self.strides[1], + dilation=self.dilation_rate[1], + ) + + if off_b != in_b or off_c != exp_off_c or off_h != out_h or off_w != out_w: + raise ValueError( + f"DeformableConv2D Offset shape must be [{in_b}, {exp_off_c}, {out_h}, {out_w}]." + ) + + if mask_shape is not None: + exp_mask_c = exp_off_c // 2 + + mask_b, mask_c, mask_h, mask_w = mask_shape + + if ( + mask_b != in_b + or mask_c != exp_mask_c + or mask_h != out_h + or mask_w != out_w + ): + raise ValueError( + f"DeformableConv2D Mask shape must be [{in_b}, {exp_mask_c}, {out_h}, {out_w}]." 
+ ) + + # Channel first + shape = ( + self.filters, + input_shape[1] // self.weight_groups, + self.kernel_size[0], + self.kernel_size[1], + ) + + self.filter_weights = self.add_weight( + name="filter", + shape=shape, + initializer=self.kernel_initializer, + regularizer=self.kernel_regularizer, + constraint=self.kernel_constraint, + trainable=True, + ) + + if self.use_bias: + self.filter_bias = self.add_weight( + name="bias", + shape=(self.filters,), + initializer=self.bias_initializer, + regularizer=self.bias_regularizer, + constraint=self.bias_constraint, + trainable=True, + ) + else: + self.filter_bias = tf.zeros((0,)) + + self.built = True + + def compute_output_shape(self, shapes): + self._validate_shapes(shapes) + + input_shape = shapes[0] + in_b, _, in_h, in_w = input_shape + + out_h = conv_utils.conv_output_length( + in_h, + self.kernel_size[0], + padding=self.padding, + stride=self.strides[0], + dilation=self.dilation_rate[0], + ) + out_w = conv_utils.conv_output_length( + in_w, + self.kernel_size[1], + padding=self.padding, + stride=self.strides[1], + dilation=self.dilation_rate[1], + ) + + return tf.TensorShape([in_b, self.filters, out_h, out_w]) + + def call(self, inputs, **kwargs): + input_tensor = inputs[0] + offset_tensor = inputs[1] + mask_tensor = inputs[2] if self.use_mask else tf.zeros((0, 0, 0, 0)) + + return _deformable_conv2d( + input_tensor=tf.convert_to_tensor(input_tensor), + filter_tensor=tf.convert_to_tensor(self.filter_weights), + bias_tensor=tf.convert_to_tensor(self.filter_bias), + offset_tensor=tf.convert_to_tensor(offset_tensor), + mask_tensor=tf.convert_to_tensor(mask_tensor), + strides=self.strides, + weight_groups=self.weight_groups, + offset_groups=self.offset_groups, + padding="SAME" if self.padding == "same" else "VALID", + dilations=self.dilation_rate, + ) + + def get_config(self): + config = { + "kernel_size": self.kernel_size, + "filters": self.filters, + "strides": self.strides, + "padding": self.padding, + "data_format": self.data_format, + "dilation_rate": self.dilation_rate, + "weight_groups": self.weight_groups, + "offset_groups": self.offset_groups, + "use_mask": self.use_mask, + "use_bias": self.use_bias, + } + base_config = super().get_config() + return {**base_config, **config} diff --git a/tensorflow_addons/layers/tests/deformable_conv2d_test.py b/tensorflow_addons/layers/tests/deformable_conv2d_test.py new file mode 100644 index 0000000000..ba3dceec61 --- /dev/null +++ b/tensorflow_addons/layers/tests/deformable_conv2d_test.py @@ -0,0 +1,420 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + + +import pytest +import numpy as np +import tensorflow as tf +from tensorflow.python.keras.utils import conv_utils +from tensorflow_addons.layers.deformable_conv2d import DeformableConv2D + + +def _get_padding_length( + padding, filter_size, dilation_rate, stride, input_size, output_size +): + effective_filter_size = (filter_size - 1) * dilation_rate + 1 + + pad = 0 + if padding == "same": + pad = ((output_size - 1) * stride + effective_filter_size - input_size) // 2 + + return pad + + +def _bilinear_interpolate(img, y, x): + max_height, max_width = img.shape + + if y <= -1 or max_height <= y or x <= -1 or max_width <= x: + return 0.0 + + y_low = int(np.floor(y)) + x_low = int(np.floor(x)) + y_high = y_low + 1 + w_high = x_low + 1 + + v1 = 0.0 + if y_low >= 0 and x_low >= 0: + v1 = img[y_low, x_low] + + v2 = 0.0 + if y_low >= 0 and w_high <= max_width - 1: + v2 = img[y_low, w_high] + + v3 = 0.0 + if y_high <= max_height - 1 and x_low >= 0: + v3 = img[y_high, x_low] + + v4 = 0.0 + if y_high <= max_height - 1 and w_high <= max_width - 1: + v4 = img[y_high, w_high] + + lh = y - y_low + lw = x - x_low + hh = 1 - lh + hw = 1 - lw + + w1 = hh * hw + w2 = hh * lw + w3 = lh * hw + w4 = lh * lw + + return w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4 + + +def _expected( + input_tensor, + filter_tensor, + offset_tensor, + mask_tensor, + bias, + strides, + weight_groups, + offset_groups, + padding, + dilation_rate, +): + input_tensor = input_tensor.numpy() + filter_tensor = filter_tensor.numpy() + offset_tensor = offset_tensor.numpy() + mask_tensor = mask_tensor.numpy() + bias = bias.numpy() + + padding = conv_utils.normalize_padding(padding) + + stride_rows, stride_cols = conv_utils.normalize_tuple(strides, 2, "strides") + dilation_rows, dilation_cols = conv_utils.normalize_tuple( + dilation_rate, 2, "dilation_rate" + ) + filter_rows, filter_cols = filter_tensor.shape[-2:] + + batches, input_channels, input_rows, input_cols = input_tensor.shape + output_channels = filter_tensor.shape[0] + + output_rows = conv_utils.conv_output_length( + input_rows, + filter_rows, + padding=padding, + stride=stride_rows, + dilation=dilation_rows, + ) + output_cols = conv_utils.conv_output_length( + input_cols, + filter_cols, + padding=padding, + stride=stride_cols, + dilation=dilation_cols, + ) + + padding_rows = _get_padding_length( + padding, filter_rows, dilation_rows, stride_rows, input_rows, output_rows + ) + padding_cols = _get_padding_length( + padding, filter_cols, dilation_cols, stride_cols, input_cols, output_cols + ) + + input_channels_per_offset_group = input_channels // offset_groups + + input_channels_per_weight_groups = filter_tensor.shape[1] + output_channels_per_weight_groups = output_channels // weight_groups + + offset_tensor = offset_tensor.reshape((batches, -1, 2, output_rows, output_cols)) + + output = np.zeros((batches, output_channels, output_rows, output_cols)) + + for batch in range(batches): + for output_channel in range(output_channels): + for output_row in range(output_rows): + for output_col in range(output_cols): + for filter_row in range(filter_rows): + for filter_col in range(filter_cols): + for input_channel in range( + input_channels_per_weight_groups + ): + weight_group = ( + output_channel // output_channels_per_weight_groups + ) + current_input_channel = ( + weight_group * input_channels_per_weight_groups + + input_channel + ) + + offset_group = ( + current_input_channel + // 
input_channels_per_offset_group + ) + offset_idx = ( + offset_group * (filter_rows * filter_cols) + + filter_row * filter_cols + + filter_col + ) + + dy = offset_tensor[ + batch, offset_idx, 0, output_row, output_col + ] + dx = offset_tensor[ + batch, offset_idx, 1, output_row, output_col + ] + + mask = mask_tensor[ + batch, offset_idx, output_row, output_col + ] + + y = ( + stride_rows * output_row + - padding_rows + + dilation_rows * filter_row + + dy + ) + x = ( + stride_cols * output_col + - padding_cols + + dilation_cols * filter_col + + dx + ) + + output[ + batch, output_channel, output_row, output_col + ] += ( + mask + * filter_tensor[ + output_channel, + input_channel, + filter_row, + filter_col, + ] + * _bilinear_interpolate( + input_tensor[ + batch, current_input_channel, :, : + ], + y, + x, + ) + ) + + output += bias.reshape((1, output_channels, 1, 1)) + return output + + +@pytest.mark.with_device(["cpu"]) +@pytest.mark.usefixtures("maybe_run_functions_eagerly") +def test_forward_simple(data_format): + if data_format == "channels_last": + return + + batches = 1 + input_channels = 6 + filters = 2 + weight_groups = 2 + offset_groups = 3 + + strides = (2, 1) + padding = "same" + dilation_rate = (2, 1) + kernel_size = (3, 2) + + input_rows, input_cols = 5, 4 + filter_rows, filter_cols = kernel_size + stride_rows, stride_cols = strides + dilation_rows, dilation_cols = dilation_rate + + output_rows = conv_utils.conv_output_length( + input_rows, + filter_rows, + padding=padding, + stride=stride_rows, + dilation=dilation_rows, + ) + output_cols = conv_utils.conv_output_length( + input_cols, + filter_cols, + padding=padding, + stride=stride_cols, + dilation=dilation_cols, + ) + + offsets = offset_groups * filter_rows * filter_cols + + input_tensor = tf.random.uniform([batches, input_channels, input_rows, input_cols]) + offset_tensor = tf.random.uniform([batches, 2 * offsets, output_rows, output_cols]) + mask_tensor = tf.random.uniform([batches, offsets, output_rows, output_cols]) + + conv = DeformableConv2D( + filters=filters, + kernel_size=kernel_size, + strides=strides, + padding=padding, + dilation_rate=dilation_rate, + weight_groups=weight_groups, + offset_groups=offset_groups, + use_mask=True, + use_bias=True, + ) + + actual = conv([input_tensor, offset_tensor, mask_tensor]) + + filter_tensor = conv.filter_weights + bias = conv.filter_bias + + expected = _expected( + input_tensor, + filter_tensor, + offset_tensor, + mask_tensor, + bias, + strides, + weight_groups, + offset_groups, + padding, + dilation_rate, + ) + + np.testing.assert_allclose(actual.numpy(), expected) + + +@pytest.mark.with_device(["cpu"]) +def test_gradients(data_format): + if data_format == "channels_last": + return + + batches = 1 + input_channels = 6 + filters = 2 + weight_groups = 2 + offset_groups = 3 + + strides = (2, 1) + padding = "same" + dilation_rate = (2, 1) + kernel_size = (3, 2) + + input_rows, input_cols = 5, 4 + filter_rows, filter_cols = kernel_size + stride_rows, stride_cols = strides + dilation_rows, dilation_cols = dilation_rate + + output_rows = conv_utils.conv_output_length( + input_rows, + filter_rows, + padding=padding, + stride=stride_rows, + dilation=dilation_rows, + ) + output_cols = conv_utils.conv_output_length( + input_cols, + filter_cols, + padding=padding, + stride=stride_cols, + dilation=dilation_cols, + ) + + offsets = offset_groups * filter_rows * filter_cols + + input_tensor = tf.random.uniform([batches, input_channels, input_rows, input_cols]) + offset_tensor = 
tf.random.uniform([batches, 2 * offsets, output_rows, output_cols]) + mask_tensor = tf.random.uniform([batches, offsets, output_rows, output_cols]) + + conv = DeformableConv2D( + filters=filters, + kernel_size=kernel_size, + strides=strides, + padding=padding, + dilation_rate=dilation_rate, + weight_groups=weight_groups, + offset_groups=offset_groups, + use_mask=True, + use_bias=True, + ) + + def conv_fn(input_tensor, offset_tensor, mask_tensor): + return conv([input_tensor, offset_tensor, mask_tensor]) + + theoretical, numerical = tf.test.compute_gradient( + conv_fn, [input_tensor, offset_tensor, mask_tensor] + ) + + np.testing.assert_allclose(theoretical[0], numerical[0], atol=1e-3) + np.testing.assert_allclose(theoretical[1], numerical[1], atol=1e-3) + np.testing.assert_allclose(theoretical[2], numerical[2], atol=1e-3) + + +@pytest.mark.with_device(["cpu"]) +def test_keras(data_format): + if data_format == "channels_last": + return + + batches = 1 + input_channels = 6 + filters = 2 + weight_groups = 2 + offset_groups = 3 + + strides = (2, 1) + padding = "same" + dilation_rate = (2, 1) + kernel_size = (3, 2) + + input_rows, input_cols = 5, 4 + filter_rows, filter_cols = kernel_size + stride_rows, stride_cols = strides + dilation_rows, dilation_cols = dilation_rate + + output_rows = conv_utils.conv_output_length( + input_rows, + filter_rows, + padding=padding, + stride=stride_rows, + dilation=dilation_rows, + ) + output_cols = conv_utils.conv_output_length( + input_cols, + filter_cols, + padding=padding, + stride=stride_cols, + dilation=dilation_cols, + ) + + offsets = offset_groups * filter_rows * filter_cols + + input_tensor = tf.random.uniform([batches, input_channels, input_rows, input_cols]) + offset_tensor = tf.random.uniform([batches, 2 * offsets, output_rows, output_cols]) + mask_tensor = tf.random.uniform([batches, offsets, output_rows, output_cols]) + + input_a = tf.keras.Input([input_channels, input_rows, input_cols]) + input_b = tf.keras.Input([2 * offsets, output_rows, output_cols]) + input_c = tf.keras.Input([offsets, output_rows, output_cols]) + + conv = DeformableConv2D( + filters=filters, + kernel_size=kernel_size, + strides=strides, + padding=padding, + dilation_rate=dilation_rate, + weight_groups=weight_groups, + offset_groups=offset_groups, + use_mask=True, + use_bias=True, + ) + + expected_output_shape = tuple( + conv.compute_output_shape([input_a.shape, input_b.shape, input_c.shape]) + ) + + x = [input_a, input_b, input_c] + y = conv(x) + model = tf.keras.models.Model(x, y) + actual_output = model([input_tensor, offset_tensor, mask_tensor]) + + assert tf.keras.backend.dtype(y[0]) == "float32" + assert actual_output.shape[1:] == expected_output_shape[1:] From 5066b2a37e226a1b1929c53a7684e7b4bf0ebccc Mon Sep 17 00:00:00 2001 From: Licht Takeuchi Date: Thu, 8 Oct 2020 12:22:54 +0900 Subject: [PATCH 02/13] Fix headers --- .../custom_ops/layers/cc/kernels/deformable_conv2d_op.cc | 1 + .../custom_ops/layers/cc/ops/deformable_conv2d_op.cc | 1 + 2 files changed, 2 insertions(+) diff --git a/tensorflow_addons/custom_ops/layers/cc/kernels/deformable_conv2d_op.cc b/tensorflow_addons/custom_ops/layers/cc/kernels/deformable_conv2d_op.cc index 21d136ace6..8f182c7561 100644 --- a/tensorflow_addons/custom_ops/layers/cc/kernels/deformable_conv2d_op.cc +++ b/tensorflow_addons/custom_ops/layers/cc/kernels/deformable_conv2d_op.cc @@ -22,6 +22,7 @@ #include "tensorflow_addons/custom_ops/layers/cc/kernels/deformable_conv2d_op.h" #include "tensorflow/core/framework/common_shape_fns.h" 
+#include "tensorflow/core/framework/kernel_shape_util.h"
 #include "tensorflow/core/framework/register_types.h"
 #include "tensorflow/core/util/work_sharder.h"
 
diff --git a/tensorflow_addons/custom_ops/layers/cc/ops/deformable_conv2d_op.cc b/tensorflow_addons/custom_ops/layers/cc/ops/deformable_conv2d_op.cc
index 5cb9561f32..0d862f3de5 100644
--- a/tensorflow_addons/custom_ops/layers/cc/ops/deformable_conv2d_op.cc
+++ b/tensorflow_addons/custom_ops/layers/cc/ops/deformable_conv2d_op.cc
@@ -14,6 +14,7 @@
 // =============================================================================
 
 #include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/kernel_shape_util.h"
 #include "tensorflow/core/framework/op.h"
 
 namespace tensorflow {

From 04b283da3e6ee88441b69e72d46d3c0295295509 Mon Sep 17 00:00:00 2001
From: Licht Takeuchi
Date: Thu, 8 Oct 2020 12:34:14 +0900
Subject: [PATCH 03/13] Remove f-string expressions from Python code

---
 tensorflow_addons/layers/deformable_conv2d.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/tensorflow_addons/layers/deformable_conv2d.py b/tensorflow_addons/layers/deformable_conv2d.py
index 3155fce91c..b4d97248bf 100644
--- a/tensorflow_addons/layers/deformable_conv2d.py
+++ b/tensorflow_addons/layers/deformable_conv2d.py
@@ -225,7 +225,9 @@ def build(self, shapes):
 
         if off_b != in_b or off_c != exp_off_c or off_h != out_h or off_w != out_w:
             raise ValueError(
-                f"DeformableConv2D Offset shape must be [{in_b}, {exp_off_c}, {out_h}, {out_w}]."
+                "DeformableConv2D Offset shape must be [{}, {}, {}, {}].".format(
+                    in_b, exp_off_c, out_h, out_w
+                )
             )
 
         if mask_shape is not None:
@@ -240,7 +242,9 @@
                 or mask_w != out_w
             ):
                 raise ValueError(
-                    f"DeformableConv2D Mask shape must be [{in_b}, {exp_mask_c}, {out_h}, {out_w}]."
+                    "DeformableConv2D Mask shape must be [{}, {}, {}, {}].".format(
+                        in_b, exp_mask_c, out_h, out_w
+                    )
                 )
 
         # Channel first

From 44bf7756a6fbe9f8fe05160fe6f4701fc03573f3 Mon Sep 17 00:00:00 2001
From: Licht Takeuchi
Date: Thu, 8 Oct 2020 12:45:06 +0900
Subject: [PATCH 04/13] Remove TensorFlow Python internal reference

---
 tensorflow_addons/layers/deformable_conv2d.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tensorflow_addons/layers/deformable_conv2d.py b/tensorflow_addons/layers/deformable_conv2d.py
index b4d97248bf..da47b11a18 100644
--- a/tensorflow_addons/layers/deformable_conv2d.py
+++ b/tensorflow_addons/layers/deformable_conv2d.py
@@ -19,7 +19,7 @@
 from typeguard import typechecked
 from tensorflow_addons.utils import types
 from tensorflow_addons.utils.resource_loader import LazySO
-from tensorflow.python.keras.utils import conv_utils
+from tensorflow.keras.utils import conv_utils
 
 _deformable_conv2d_ops_so = LazySO("custom_ops/layers/_deformable_conv2d_ops.so")

From 536fc16c613926208c5bc4354eaa72e09ef5a73e Mon Sep 17 00:00:00 2001
From: Licht Takeuchi
Date: Thu, 8 Oct 2020 12:47:43 +0900
Subject: [PATCH 05/13] Revert "Remove TensorFlow Python internal reference"

This reverts commit 44bf7756a6fbe9f8fe05160fe6f4701fc03573f3.
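`tensorflow.keras.utils` does not expose `conv_utils` as public API, so the
import introduced in the previous commit fails at import time; the following
commits copy the required helpers into `tensorflow_addons.utils.keras_utils`
instead.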
---
 tensorflow_addons/layers/deformable_conv2d.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tensorflow_addons/layers/deformable_conv2d.py b/tensorflow_addons/layers/deformable_conv2d.py
index da47b11a18..b4d97248bf 100644
--- a/tensorflow_addons/layers/deformable_conv2d.py
+++ b/tensorflow_addons/layers/deformable_conv2d.py
@@ -19,7 +19,7 @@
 from typeguard import typechecked
 from tensorflow_addons.utils import types
 from tensorflow_addons.utils.resource_loader import LazySO
-from tensorflow.keras.utils import conv_utils
+from tensorflow.python.keras.utils import conv_utils
 
 _deformable_conv2d_ops_so = LazySO("custom_ops/layers/_deformable_conv2d_ops.so")

From a64e05bc55b525c243b708cba4fb94fed74eef05 Mon Sep 17 00:00:00 2001
From: Licht Takeuchi
Date: Thu, 8 Oct 2020 22:32:32 +0900
Subject: [PATCH 06/13] Use keras_utils instead of TensorFlow private API

---
 tensorflow_addons/layers/deformable_conv2d.py |  2 +-
 tensorflow_addons/utils/keras_utils.py        | 42 +++++++++++++++++++
 2 files changed, 43 insertions(+), 1 deletion(-)

diff --git a/tensorflow_addons/layers/deformable_conv2d.py b/tensorflow_addons/layers/deformable_conv2d.py
index b4d97248bf..68e16bc6ea 100644
--- a/tensorflow_addons/layers/deformable_conv2d.py
+++ b/tensorflow_addons/layers/deformable_conv2d.py
@@ -19,7 +19,7 @@
 from typeguard import typechecked
 from tensorflow_addons.utils import types
 from tensorflow_addons.utils.resource_loader import LazySO
-from tensorflow.python.keras.utils import conv_utils
+import tensorflow_addons.utils.keras_utils as conv_utils
 
 _deformable_conv2d_ops_so = LazySO("custom_ops/layers/_deformable_conv2d_ops.so")

diff --git a/tensorflow_addons/utils/keras_utils.py b/tensorflow_addons/utils/keras_utils.py
index e480527c21..fc0cc32759 100644
--- a/tensorflow_addons/utils/keras_utils.py
+++ b/tensorflow_addons/utils/keras_utils.py
@@ -68,6 +68,20 @@ def get_config(self):
         return {**base_config, **config}
 
 
+def normalize_padding(value):
+    """A copy of tensorflow.python.keras.utils.conv_utils."""
+    if isinstance(value, (list, tuple)):
+        return value
+    padding = value.lower()
+    if padding not in {"valid", "same", "causal"}:
+        raise ValueError(
+            "The `padding` argument must be a list/tuple or one of "
+            '"valid", "same" (or "causal", only for `Conv1D`). '
+            "Received: " + str(padding)
+        )
+    return padding
+
+
 def normalize_data_format(value):
     if value is None:
         value = tf.keras.backend.image_data_format()
@@ -143,6 +157,34 @@ def normalize_tuple(value, n, name):
     return value_tuple
 
 
+def conv_output_length(input_length, filter_size, padding, stride, dilation=1):
+    """Determines output length of a convolution given input length.
+
+    A copy of tensorflow.python.keras.utils.conv_utils.
+
+    Arguments:
+        input_length: integer.
+        filter_size: integer.
+        padding: one of "same", "valid", "full", "causal"
+        stride: integer.
+        dilation: dilation rate, integer.
+
+    Returns:
+        The output length (integer).
+    """
+    if input_length is None:
+        return None
+    assert padding in {"same", "valid", "full", "causal"}
+    dilated_filter_size = filter_size + (filter_size - 1) * (dilation - 1)
+    if padding in ["same", "causal"]:
+        output_length = input_length
+    elif padding == "valid":
+        output_length = input_length - dilated_filter_size + 1
+    elif padding == "full":
+        output_length = input_length + dilated_filter_size - 1
+    return (output_length + stride - 1) // stride
+
+
 def _hasattr(obj, attr_name):
     # If possible, avoid retrieving the attribute as the object might run some
    # lazy computation in it.
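For reference, a quick sketch (not part of the patches) of how the helpers
copied into `keras_utils` behave, using the shapes exercised by the tests;
the module path follows the commit above:

    import tensorflow_addons.utils.keras_utils as conv_utils

    # "same" padding keeps the spatial extent before striding:
    # ceil(5 / 2) == 3
    assert conv_utils.conv_output_length(5, 3, padding="same", stride=2, dilation=2) == 3

    # "valid" padding first subtracts the dilated kernel extent:
    # 3 + (3 - 1) * (2 - 1) == 5 taps, so (5 - 5 + 1 + 1) // 2 == 1
    assert conv_utils.conv_output_length(5, 3, padding="valid", stride=2, dilation=2) == 1

    # scalars are broadcast to n-tuples, padding strings are lower-cased
    assert conv_utils.normalize_tuple(3, 2, "kernel_size") == (3, 3)
    assert conv_utils.normalize_padding("SAME") == "same"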
From 480bbb2c50598aa022d88d91cc25ed511b4ede14 Mon Sep 17 00:00:00 2001 From: Licht Takeuchi Date: Thu, 8 Oct 2020 22:36:36 +0900 Subject: [PATCH 07/13] Use keras_utils when testing instead of TensorFlow private API --- tensorflow_addons/layers/tests/deformable_conv2d_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow_addons/layers/tests/deformable_conv2d_test.py b/tensorflow_addons/layers/tests/deformable_conv2d_test.py index ba3dceec61..5a18adaa88 100644 --- a/tensorflow_addons/layers/tests/deformable_conv2d_test.py +++ b/tensorflow_addons/layers/tests/deformable_conv2d_test.py @@ -17,7 +17,7 @@ import pytest import numpy as np import tensorflow as tf -from tensorflow.python.keras.utils import conv_utils +import tensorflow_addons.utils.keras_utils as conv_utils from tensorflow_addons.layers.deformable_conv2d import DeformableConv2D From 7e9654f2884acb63fcfb737299b51ab329ec4d02 Mon Sep 17 00:00:00 2001 From: Licht Takeuchi Date: Thu, 22 Oct 2020 11:27:09 +0900 Subject: [PATCH 08/13] Refactor and add GPU kernel --- tensorflow_addons/custom_ops/layers/BUILD | 2 + .../layers/cc/kernels/deformable_conv2d_op.cc | 915 +++++++----------- .../layers/cc/kernels/deformable_conv2d_op.h | 716 ++++++++++---- .../cc/kernels/deformable_conv2d_op_gpu.cu.cc | 385 ++++++++ .../layers/tests/deformable_conv2d_test.py | 6 +- 5 files changed, 1276 insertions(+), 748 deletions(-) create mode 100644 tensorflow_addons/custom_ops/layers/cc/kernels/deformable_conv2d_op_gpu.cu.cc diff --git a/tensorflow_addons/custom_ops/layers/BUILD b/tensorflow_addons/custom_ops/layers/BUILD index e9755770f8..f45c0a2133 100644 --- a/tensorflow_addons/custom_ops/layers/BUILD +++ b/tensorflow_addons/custom_ops/layers/BUILD @@ -31,5 +31,7 @@ custom_op_library( "@cub_archive//:cub", ], cuda_srcs = [ + "cc/kernels/deformable_conv2d_op.h", + "cc/kernels/deformable_conv2d_op_gpu.cu.cc", ], ) diff --git a/tensorflow_addons/custom_ops/layers/cc/kernels/deformable_conv2d_op.cc b/tensorflow_addons/custom_ops/layers/cc/kernels/deformable_conv2d_op.cc index 8f182c7561..ddfceb5ca5 100644 --- a/tensorflow_addons/custom_ops/layers/cc/kernels/deformable_conv2d_op.cc +++ b/tensorflow_addons/custom_ops/layers/cc/kernels/deformable_conv2d_op.cc @@ -23,8 +23,6 @@ #include "tensorflow/core/framework/common_shape_fns.h" #include "tensorflow/core/framework/kernel_shape_util.h" -#include "tensorflow/core/framework/register_types.h" -#include "tensorflow/core/util/work_sharder.h" namespace tensorflow { namespace addons { @@ -32,524 +30,306 @@ namespace addons { using CPUDevice = Eigen::ThreadPoolDevice; using GPUDevice = Eigen::GpuDevice; -namespace functor { - -template -struct DeformableConv2DFunctor - : public DeformableConv2DFunctorBase { - using DeformableConv2DFunctorBase::_input_tensor; - using DeformableConv2DFunctorBase::_filter_tensor; - using DeformableConv2DFunctorBase::_bias_tensor; - using DeformableConv2DFunctorBase::_offset_tensor; - using DeformableConv2DFunctorBase::_mask_tensor; - using DeformableConv2DFunctorBase::_column_buffer_tensor; - using DeformableConv2DFunctorBase::p; - - DeformableConv2DFunctor( - typename TTypes::ConstTensor input_tensor, - typename TTypes::ConstTensor filter_tensor, - typename TTypes::ConstTensor bias_tensor, - typename TTypes::ConstTensor offset_tensor, - typename TTypes::ConstTensor mask_tensor, - typename TTypes::Tensor column_buffer_tensor, - typename TTypes::Tensor output_tensor, - DeformableConv2DParams _p) - : DeformableConv2DFunctorBase( - input_tensor, filter_tensor, 
bias_tensor, offset_tensor, - mask_tensor, column_buffer_tensor, _p), - _output_tensor(output_tensor) { - _output_tensor.setZero(); - } - - Status operator()(OpKernelContext *context) { - const auto use_bias = _bias_tensor.dimension(0) > 0; - const auto batches = p.input_batches / p.parallel_imgs; - - auto filter_tensor = _filter_tensor.reshape( - Shape5D({p.weight_groups, p.output_channels / p.weight_groups, - p.filter_channels, p.filter_rows, p.filter_cols})); - - auto output_tensor = _output_tensor.reshape( - Shape5D({batches, p.weight_groups, p.output_channels / p.weight_groups, - p.parallel_imgs * p.output_rows, p.output_cols})); - - // input_channels * filter_rows * filter_cols / weight_groups == - // filter_channels * filter_rows * filter_cols - const auto elems = p.filter_channels * p.filter_rows * p.filter_cols; - const auto rows = p.output_channels / p.weight_groups; - const auto cols = p.parallel_imgs * p.output_rows * p.output_cols; - - auto column_buffer_tensor = - _column_buffer_tensor.reshape(Shape3D({p.weight_groups, elems, cols})); - - for (auto b = 0; b < batches; b++) { - auto output_tensor_batch = output_tensor.chip(b, 0); - - this->DeformableIm2Col(b); - - for (auto g = 0; g < p.weight_groups; g++) { - EigenTensor filter_mtx = - filter_tensor.chip(g, 0).reshape(Shape2D({rows, elems})); - EigenTensor column_buffer_mtx = - column_buffer_tensor.chip(g, 0); - - auto mtx_shape = Shape2D({rows, cols}); - Eigen::array, 1> product_dims = { - Eigen::IndexPair(1, 0)}; - - EigenTensor mul = - filter_mtx.contract(column_buffer_mtx, product_dims); - - output_tensor_batch.chip(g, 0).reshape(mtx_shape) += mul; - } - } - - auto output_tensor_transposed = - output_tensor - .reshape(Shape5D({batches, p.output_channels, p.parallel_imgs, - p.output_rows, p.output_cols})) - .shuffle(Shape5D({0, 2, 1, 3, 4})) - .reshape(Shape4D({p.input_batches, p.output_channels, p.output_rows, - p.output_cols})); - - _output_tensor = output_tensor_transposed.eval(); - - if (use_bias) { - auto bias_tensor_broadcasted = - _bias_tensor.reshape(Shape4D({1, p.output_channels, 1, 1})) - .broadcast( - Shape4D({p.input_batches, 1, p.output_rows, p.output_cols})); - - _output_tensor += bias_tensor_broadcasted; - } - - return Status::OK(); - } - - typename TTypes::Tensor _output_tensor; -}; - -template -struct DeformableConv2DGradFunctor - : public DeformableConv2DFunctorBase { - using DeformableConv2DFunctorBase::_input_tensor; - using DeformableConv2DFunctorBase::_filter_tensor; - using DeformableConv2DFunctorBase::_bias_tensor; - using DeformableConv2DFunctorBase::_offset_tensor; - using DeformableConv2DFunctorBase::_mask_tensor; - using DeformableConv2DFunctorBase::_column_buffer_tensor; - using DeformableConv2DFunctorBase::p; - - DeformableConv2DGradFunctor( - typename TTypes::ConstTensor input_tensor, - typename TTypes::ConstTensor filter_tensor, - typename TTypes::ConstTensor bias_tensor, - typename TTypes::ConstTensor offset_tensor, - typename TTypes::ConstTensor mask_tensor, - typename TTypes::ConstTensor output_grad_tensor, - typename TTypes::Tensor input_grad_tensor, - typename TTypes::Tensor filter_grad_tensor, - typename TTypes::Tensor bias_grad_tensor, - typename TTypes::Tensor offset_grad_tensor, - typename TTypes::Tensor mask_grad_tensor, - typename TTypes::Tensor column_buffer_tensor, - DeformableConv2DParams _p) - : DeformableConv2DFunctorBase( - input_tensor, filter_tensor, bias_tensor, offset_tensor, - mask_tensor, column_buffer_tensor, _p), - _output_grad_tensor(output_grad_tensor), - 
_input_grad_tensor(input_grad_tensor), - _filter_grad_tensor(filter_grad_tensor), - _bias_grad_tensor(bias_grad_tensor), - _offset_grad_tensor(offset_grad_tensor), - _mask_grad_tensor(mask_grad_tensor) { - _input_grad_tensor.setZero(); - _filter_grad_tensor.setZero(); - _column_buffer_tensor.setZero(); - } - - Status operator()(OpKernelContext *context) { - const auto use_bias = _bias_tensor.dimension(0) > 0; - - ComputeInputOffsetMaskGrad(); - - ComputeFilterGrad(); - - if (use_bias) { - _bias_grad_tensor.setConstant(Dtype(1)); - _bias_grad_tensor *= - _output_grad_tensor.sum(Eigen::array({0, 2, 3})); - } - - return Status::OK(); - } - - void ComputeFilterGrad() { - const auto batches = p.input_batches / p.parallel_imgs; - - auto filter_grad_tensor = _filter_grad_tensor.reshape( - Shape5D({p.weight_groups, p.output_channels / p.weight_groups, - p.filter_channels, p.filter_rows, p.filter_cols})); - - EigenTensor output_grad_tensor = - _output_grad_tensor - .reshape(Shape5D({batches, p.parallel_imgs, p.output_channels, - p.output_rows, p.output_cols})) - .shuffle(Shape5D({0, 2, 1, 3, 4})) - .reshape(Shape5D({batches, p.weight_groups, - p.output_channels / p.weight_groups, - p.parallel_imgs * p.output_rows, p.output_cols})); - - // input_channels * filter_rows * filter_cols / weight_groups == - // filter_channels * filter_rows * filter_cols - const auto elems = p.filter_channels * p.filter_rows * p.filter_cols; - const auto rows = p.output_channels / p.weight_groups; - const auto cols = p.parallel_imgs * p.output_rows * p.output_cols; - - auto column_buffer_tensor = - _column_buffer_tensor.reshape(Shape3D({p.weight_groups, elems, cols})); - - for (auto b = 0; b < batches; b++) { - auto output_grad_tensor_batch = output_grad_tensor.chip(b, 0); - - this->DeformableIm2Col(b); - - for (auto g = 0; g < p.weight_groups; g++) { - EigenTensor column_buffer_mtx = - column_buffer_tensor.chip(g, 0).shuffle(Shape2D({1, 0})); - - EigenTensor output_grad_mtx = - output_grad_tensor_batch.chip(g, 0).reshape(Shape2D({rows, cols})); - - Eigen::array, 1> product_dims = { - Eigen::IndexPair(1, 0)}; +#define EXTERN_TEMPLATE(T) \ + extern template Status Transpose( \ + OpKernelContext * ctx, const Tensor &in, \ + const gtl::ArraySlice perm, Tensor *out); +TF_CALL_float(EXTERN_TEMPLATE); +TF_CALL_double(EXTERN_TEMPLATE); +#undef EXTERN_TEMPLATE - EigenTensor mul = - output_grad_mtx.contract(column_buffer_mtx, product_dims); - - filter_grad_tensor.chip(g, 0).reshape(Shape2D({rows, elems})) += mul; - } - } - } - - void ComputeInputOffsetMaskGrad() { - auto batches = p.input_batches / p.parallel_imgs; - - EigenTensor filter_tensor = _filter_tensor.reshape( - Shape5D({p.weight_groups, p.output_channels / p.weight_groups, - p.filter_channels, p.filter_rows, p.filter_cols})); - - EigenTensor output_grad_tensor = - _output_grad_tensor - .reshape(Shape5D({batches, p.parallel_imgs, p.output_channels, - p.output_rows, p.output_cols})) - .shuffle(Shape5D({0, 2, 1, 3, 4})) - .reshape(Shape5D({batches, p.weight_groups, - p.output_channels / p.weight_groups, - p.parallel_imgs * p.output_rows, p.output_cols})); - - // input_channels * filter_rows * filter_cols / weight_groups == - // filter_channels * filter_rows * filter_cols - const auto rows = p.filter_channels * p.filter_rows * p.filter_cols; - const auto elems = p.output_channels / p.weight_groups; - const auto cols = p.parallel_imgs * p.output_rows * p.output_cols; - - auto column_buffer_tensor = - _column_buffer_tensor.reshape(Shape3D({p.weight_groups, rows, cols})); - - 
for (auto b = 0; b < batches; b++) { - _column_buffer_tensor.setZero(); - - auto output_grad_tensor_chipped = output_grad_tensor.chip(b, 0); - for (int g = 0; g < p.weight_groups; g++) { - EigenTensor filter_mtx = filter_tensor.chip(g, 0) - .reshape(Shape2D({elems, rows})) - .shuffle(Shape2D({1, 0})); - EigenTensor output_grad_mtx = - output_grad_tensor_chipped.chip(g, 0).reshape( - Shape2D({elems, cols})); - - Eigen::array, 1> product_dims = { - Eigen::IndexPair(1, 0)}; - - EigenTensor mul = - filter_mtx.contract(output_grad_mtx, product_dims); - - column_buffer_tensor.chip(g, 0) = mul; - } - - DeformableCol2ImForOffsetAndMask(b); - - DeformableCol2ImForInput(b); - } - } - - void DeformableCol2ImForOffsetAndMask(int32 b) { - auto use_mask = _mask_tensor.dimension(0) > 0; - auto batches = p.input_batches / p.parallel_imgs; - auto num_kernels = p.output_rows * p.output_cols * 2 * p.filter_rows * - p.filter_cols * p.offset_groups * p.parallel_imgs; - auto offset_channels = 2 * p.filter_rows * p.filter_cols * p.offset_groups; - - EigenTensor input_tensor = - _input_tensor - .reshape(Shape5D({batches, p.parallel_imgs, p.input_channels, - p.input_rows, p.input_cols})) - .chip(b, 0); - - EigenTensor offset_tensor = - _offset_tensor - .reshape(Shape8D({batches, p.parallel_imgs, p.offset_groups, - p.filter_rows, p.filter_cols, 2, p.output_rows, - p.output_cols})) - .chip(b, 0); - - EigenTensor mask_tensor = - use_mask ? static_cast>( - _mask_tensor - .reshape(Shape7D({batches, p.parallel_imgs, - p.offset_groups, p.filter_rows, - p.filter_cols, p.output_rows, - p.output_cols})) - .chip(b, 0)) - : _mask_tensor.reshape(Shape6D({0, 0, 0, 0, 0, 0})); - - EigenTensor column_buffer_tensor = _column_buffer_tensor.reshape( - Shape6D({p.input_channels, p.filter_rows, p.filter_cols, - p.parallel_imgs, p.output_rows, p.output_cols})); - - for (auto k = 0; k < num_kernels; k++) { - auto offset_grad_value = Dtype(0); - auto mask_grad_value = Dtype(0); - - const auto offset_channel_step = p.filter_rows * p.filter_cols; - - const auto current_output_col = k % p.output_cols; - const auto current_output_row = (k / p.output_cols) % p.output_rows; - const auto current_filter_col = - (k / (2 * p.output_rows * p.output_cols)) % p.filter_cols; - const auto current_filter_row = - (k / (2 * p.output_rows * p.output_cols * p.filter_cols)) % - p.filter_rows; - const auto current_offset_channel = - (k / (p.output_rows * p.output_cols)) % offset_channels; - const auto current_batch = - k / (p.output_rows * p.output_cols * offset_channels); - - auto current_actual_batch = b * p.parallel_imgs + current_batch; - - const auto current_offset_group = - current_offset_channel / (2 * offset_channel_step); - - const auto channels_per_offset_group = p.input_channels / p.offset_groups; - const auto offset_channel_diff = - current_offset_channel - - current_offset_group * 2 * offset_channel_step; - const auto is_y_direction = offset_channel_diff % 2 == 0; - - for (auto selected_offset_channel = (offset_channel_diff / 2); - selected_offset_channel < - channels_per_offset_group * offset_channel_step; - selected_offset_channel += offset_channel_step) { - const auto selected_filter_col = - selected_offset_channel % p.filter_cols; - const auto selected_filter_row = - (selected_offset_channel / p.filter_cols) % p.filter_rows; - const auto input_channel_diff = - (selected_offset_channel / (p.filter_cols * p.filter_rows)); - - const auto offset_h = offset_tensor( - current_batch, current_offset_group, selected_filter_row, - selected_filter_col, 
0, current_output_row, current_output_col); - const auto offset_w = offset_tensor( - current_batch, current_offset_group, selected_filter_row, - selected_filter_col, 1, current_output_row, current_output_col); - const auto mask = use_mask - ? static_cast(mask_tensor( - current_batch, current_offset_group, - selected_filter_row, selected_filter_col, - current_output_row, current_output_col)) - : Dtype(1); - - const auto y = (current_output_row * p.stride_rows - p.padding_rows) + - selected_filter_row * p.dilation_rows + offset_h; - const auto x = (current_output_col * p.stride_cols - p.padding_cols) + - selected_filter_col * p.dilation_cols + offset_w; - - const auto selected_input_channel = - input_channel_diff + - current_offset_group * channels_per_offset_group; - - auto filter_data = column_buffer_tensor( - selected_input_channel, selected_filter_row, selected_filter_col, - current_batch, current_output_row, current_output_col); - - const auto weight = GetCoordinateWeight( - current_actual_batch, selected_input_channel, y, x, is_y_direction); - - offset_grad_value += mask * weight * filter_data; - - if (is_y_direction) { - mask_grad_value += filter_data * this->BilinearInterpolate( - current_actual_batch, - selected_input_channel, y, x); - } - } - - _offset_grad_tensor(current_actual_batch, current_offset_channel, - current_output_row, current_output_col) = - offset_grad_value; - - if (use_mask && is_y_direction) { - auto current_mask_channel = - (current_offset_group * p.filter_rows + current_filter_row) * - p.filter_cols + - current_filter_col; +namespace functor { - _mask_grad_tensor(current_actual_batch, current_mask_channel, - current_output_row, current_output_col) = - mask_grad_value; - } - } +#define EXTERN_TEMPLATE(T) \ + extern template struct DeformableConv2DForwardFunctor; \ + extern template struct DeformableConv2DGradFunctor; +TF_CALL_float(EXTERN_TEMPLATE); +TF_CALL_double(EXTERN_TEMPLATE); +#undef EXTERN_TEMPLATE + +#define IM2COL(T) \ + template <> \ + void DeformableConv2DFunctorBase::DeformableIm2Col( \ + OpKernelContext *context, int32 b) { \ + auto num_kernels = \ + p.input_channels * p.output_rows * p.output_cols * p.parallel_imgs; \ + \ + const auto offset_eigen_tensor = offset_tensor.tensor(); \ + \ + const auto mask_eigen_tensor = \ + p.use_mask ? 
mask_tensor.tensor() \ + : mask_tensor.shaped({0, 0, 0, 0, 0, 0, 0}); \ + \ + const auto input_eigen_tensor = input_tensor.tensor(); \ + \ + auto column_buffer_eigen_tensor = column_buffer_tensor.tensor(); \ + \ + for (auto k = 0; k < num_kernels; k++) { \ + const auto current_output_col = k % p.output_cols; \ + const auto current_output_row = (k / p.output_cols) % p.output_rows; \ + const auto current_batch = \ + (k / (p.output_rows * p.output_cols)) % p.parallel_imgs; \ + const auto current_input_channel = \ + k / (p.output_rows * p.output_cols * p.parallel_imgs); \ + const auto current_output_channel = \ + current_input_channel * p.filter_rows * p.filter_cols; \ + \ + const auto current_actual_batch = b * p.parallel_imgs + current_batch; \ + \ + const auto group_index = \ + current_input_channel / (p.input_channels / p.offset_groups); \ + \ + auto column_buffer_tensor_channel = current_output_channel; \ + for (auto current_filter_row = 0; current_filter_row < p.filter_rows; \ + current_filter_row++) { \ + for (auto current_filter_col = 0; current_filter_col < p.filter_cols; \ + current_filter_col++) { \ + auto offset_h = offset_eigen_tensor( \ + b, current_batch, group_index, current_filter_row, \ + current_filter_col, 0, current_output_row, current_output_col); \ + auto offset_w = offset_eigen_tensor( \ + b, current_batch, group_index, current_filter_row, \ + current_filter_col, 1, current_output_row, current_output_col); \ + \ + auto mask = p.use_mask ? mask_eigen_tensor( \ + b, current_batch, group_index, \ + current_filter_row, current_filter_col, \ + current_output_row, current_output_col) \ + : T(1); \ + \ + auto y = (current_output_row * p.stride_rows - p.padding_rows) + \ + current_filter_row * p.dilation_rows + offset_h; \ + auto x = (current_output_col * p.stride_cols - p.padding_cols) + \ + current_filter_col * p.dilation_cols + offset_w; \ + \ + column_buffer_eigen_tensor(column_buffer_tensor_channel, \ + current_batch, current_output_row, \ + current_output_col) = \ + mask * BilinearInterpolate(input_eigen_tensor, b, \ + current_actual_batch, \ + current_input_channel, y, x); \ + column_buffer_tensor_channel++; \ + } \ + } \ + } \ } - - void DeformableCol2ImForInput(int32 b) { - auto use_mask = _mask_tensor.dimension(0) > 0; - auto batches = p.input_batches / p.parallel_imgs; - auto num_kernels = p.input_channels * p.filter_rows * p.filter_cols * - p.output_rows * p.output_cols * p.parallel_imgs; - - EigenTensor offset_tensor = - _offset_tensor - .reshape(Shape8D({batches, p.parallel_imgs, p.offset_groups, - p.filter_rows, p.filter_cols, 2, p.output_rows, - p.output_cols})) - .chip(b, 0); - - EigenTensor mask_tensor = - use_mask ? 
static_cast>( - _mask_tensor - .reshape(Shape7D({batches, p.parallel_imgs, - p.offset_groups, p.filter_rows, - p.filter_cols, p.output_rows, - p.output_cols})) - .chip(b, 0)) - : _mask_tensor.reshape(Shape6D({0, 0, 0, 0, 0, 0})); - - EigenTensor column_buffer_tensor_flattened = - _column_buffer_tensor.reshape(Shape1D({num_kernels})); - - for (auto k = 0; k < num_kernels; k++) { - const auto current_output_col = k % p.output_cols; - const auto current_output_row = (k / p.output_cols) % p.output_rows; - const auto current_batch = - (k / (p.output_rows * p.output_cols)) % p.parallel_imgs; - - const auto current_filter_col = - (k / (p.output_rows * p.output_cols * p.parallel_imgs)) % - p.filter_cols; - const auto current_filter_row = (k / (p.output_rows * p.output_cols * - p.parallel_imgs * p.filter_cols)) % - p.filter_rows; - const auto current_channel = - k / (p.output_rows * p.output_cols * p.parallel_imgs * p.filter_rows * - p.filter_cols); - - const auto current_offset_group = - current_channel / (p.input_channels / p.offset_groups); - - auto mask = use_mask ? mask_tensor(current_batch, current_offset_group, - current_filter_row, current_filter_col, - current_output_row, current_output_col) - : Dtype(1); - - auto offset_h = offset_tensor(current_batch, current_offset_group, - current_filter_row, current_filter_col, 0, - current_output_row, current_output_col); - auto offset_w = offset_tensor(current_batch, current_offset_group, - current_filter_row, current_filter_col, 1, - current_output_row, current_output_col); - - const auto y = (current_output_row * p.stride_rows - p.padding_rows) + - current_filter_row * p.dilation_rows + offset_h; - const auto x = (current_output_col * p.stride_cols - p.padding_cols) + - current_filter_col * p.dilation_cols + offset_w; - - for (auto dy = -1; dy <= 1; dy++) { - for (auto dx = -1; dx <= 1; dx++) { - auto current_input_row = int(y) + dy; - auto current_input_col = int(x) + dx; - if (p.input_rows > current_input_row && current_input_row >= 0 && - p.input_cols > current_input_col && current_input_col >= 0 && - std::abs(y - current_input_row) < 1 && - std::abs(x - current_input_col) < 1) { - auto weight = (1.0 - std::abs(y - current_input_row)) * - (1.0 - std::abs(x - current_input_col)); - - auto current_actual_batch = b * p.parallel_imgs + current_batch; - - _input_grad_tensor(current_actual_batch, current_channel, - current_input_row, current_input_col) += - mask * weight * column_buffer_tensor_flattened(k); - } - } - } - } +TF_CALL_float(IM2COL); +TF_CALL_double(IM2COL); +#undef IM2COL + +#define COL2IM_OFFSET_AND_MASK(T) \ + template <> \ + void \ + DeformableConv2DGradFunctor::DeformableCol2ImForOffsetAndMask( \ + OpKernelContext *context, int32 b) { \ + const auto num_kernels = p.output_rows * p.output_cols * 2 * \ + p.filter_rows * p.filter_cols * p.offset_groups * \ + p.parallel_imgs; \ + \ + const auto offset_eigen_tensor = offset_tensor.template tensor(); \ + \ + const auto mask_eigen_tensor = \ + p.use_mask ? 
mask_tensor.template tensor() \ + : mask_tensor.template shaped({0, 0, 0, 0, 0, 0, 0}); \ + \ + const auto column_buffer_eigen_tensor = \ + column_buffer_tensor.template shaped( \ + {p.input_channels, p.filter_rows, p.filter_cols, p.parallel_imgs, \ + p.output_rows, p.output_cols}); \ + \ + auto offset_grad_eigen_tensor = offset_grad_tensor.tensor(); \ + auto mask_grad_eigen_tensor = mask_grad_tensor.tensor(); \ + \ + const auto input_eigen_tensor = input_tensor.tensor(); \ + \ + for (auto k = 0; k < num_kernels; k++) { \ + auto offset_grad_value = T(0); \ + auto mask_grad_value = T(0); \ + \ + const auto offset_channels = \ + 2 * p.filter_rows * p.filter_cols * p.offset_groups; \ + \ + const auto offset_channel_step = p.filter_rows * p.filter_cols; \ + \ + const auto current_output_col = k % p.output_cols; \ + const auto current_output_row = (k / p.output_cols) % p.output_rows; \ + const auto current_filter_col = \ + (k / (2 * p.output_rows * p.output_cols)) % p.filter_cols; \ + const auto current_filter_row = \ + (k / (2 * p.output_rows * p.output_cols * p.filter_cols)) % \ + p.filter_rows; \ + const auto current_offset_channel = \ + (k / (p.output_rows * p.output_cols)) % offset_channels; \ + const auto current_batch = \ + k / (p.output_rows * p.output_cols * offset_channels); \ + \ + const auto current_actual_batch = b * p.parallel_imgs + current_batch; \ + \ + const auto current_offset_group = \ + current_offset_channel / (2 * offset_channel_step); \ + \ + const auto channels_per_offset_group = \ + p.input_channels / p.offset_groups; \ + const auto offset_channel_diff = \ + current_offset_channel - \ + current_offset_group * 2 * offset_channel_step; \ + const auto is_y_direction = offset_channel_diff % 2 == 0; \ + \ + for (auto selected_offset_channel = (offset_channel_diff / 2); \ + selected_offset_channel < \ + channels_per_offset_group * offset_channel_step; \ + selected_offset_channel += offset_channel_step) { \ + const auto selected_filter_col = \ + selected_offset_channel % p.filter_cols; \ + const auto selected_filter_row = \ + (selected_offset_channel / p.filter_cols) % p.filter_rows; \ + const auto input_channel_diff = \ + (selected_offset_channel / (p.filter_cols * p.filter_rows)); \ + \ + const auto offset_h = offset_eigen_tensor( \ + b, current_batch, current_offset_group, selected_filter_row, \ + selected_filter_col, 0, current_output_row, current_output_col); \ + const auto offset_w = offset_eigen_tensor( \ + b, current_batch, current_offset_group, selected_filter_row, \ + selected_filter_col, 1, current_output_row, current_output_col); \ + const auto mask = \ + p.use_mask \ + ? 
mask_eigen_tensor(b, current_batch, current_offset_group, \ + selected_filter_row, selected_filter_col, \ + current_output_row, current_output_col) \ + : T(1); \ + \ + const auto y = (current_output_row * p.stride_rows - p.padding_rows) + \ + selected_filter_row * p.dilation_rows + offset_h; \ + const auto x = (current_output_col * p.stride_cols - p.padding_cols) + \ + selected_filter_col * p.dilation_cols + offset_w; \ + \ + const auto selected_input_channel = \ + input_channel_diff + \ + current_offset_group * channels_per_offset_group; \ + \ + const auto filter_data = column_buffer_eigen_tensor( \ + selected_input_channel, selected_filter_row, selected_filter_col, \ + current_batch, current_output_row, current_output_col); \ + \ + const auto weight = GetCoordinateWeight( \ + input_eigen_tensor, b, current_actual_batch, \ + selected_input_channel, y, x, is_y_direction); \ + \ + offset_grad_value += mask * weight * filter_data; \ + \ + if (is_y_direction) { \ + mask_grad_value += \ + filter_data * BilinearInterpolate( \ + input_eigen_tensor, b, current_actual_batch, \ + selected_input_channel, y, x); \ + } \ + } \ + \ + offset_grad_eigen_tensor(current_actual_batch, current_offset_channel, \ + current_output_row, current_output_col) = \ + offset_grad_value; \ + \ + if (p.use_mask && is_y_direction) { \ + const auto current_mask_channel = \ + (current_offset_group * p.filter_rows + current_filter_row) * \ + p.filter_cols + \ + current_filter_col; \ + \ + mask_grad_eigen_tensor(current_actual_batch, current_mask_channel, \ + current_output_row, current_output_col) = \ + mask_grad_value; \ + } \ + } \ } - - Dtype GetCoordinateWeight(int32 batch, int32 channel, Dtype y, Dtype x, - bool is_y_direction) { - EigenTensor img = _input_tensor.chip(batch, 0).chip(channel, 0); - - auto max_height = img.dimension(0); - auto max_width = img.dimension(1); - - int y_low = floor(y); - int x_low = floor(x); - int y_high = y_low + 1; - int x_high = x_low + 1; - - bool valid_y_low = max_height > y_low && y_low >= 0; - bool valid_y_high = max_height > y_high && y_high >= 0; - bool valid_x_low = max_width > x_low && x_low >= 0; - bool valid_x_high = max_width > x_high && x_high >= 0; - - auto v_yx = Dtype(0); - if (valid_y_low && valid_x_low) { - v_yx = img(y_low, x_low); - } - - auto v_yX = Dtype(0); - if (valid_y_low && valid_x_high) { - v_yX = img(y_low, x_high); - } - - auto v_Yx = Dtype(0); - if (valid_y_high && valid_x_low) { - v_Yx = img(y_high, x_low); - } - - auto v_YX = Dtype(0); - if (valid_y_high && valid_x_high) { - v_YX = img(y_high, x_high); - } - - if (is_y_direction) { - auto dx = x - x_low; - return (v_YX - v_yX) * dx + (v_Yx - v_yx) * (1 - dx); - } else { - auto dy = y - y_low; - return (v_YX - v_Yx) * dy + (v_yX - v_yx) * (1 - dy); - } +TF_CALL_float(COL2IM_OFFSET_AND_MASK); +TF_CALL_double(COL2IM_OFFSET_AND_MASK); +#undef COL2IM_OFFSET_AND_MASK + +#define COL2IM_INPUT(T) \ + template <> \ + void DeformableConv2DGradFunctor::DeformableCol2ImForInput( \ + OpKernelContext *context, int32 b) { \ + const auto num_kernels = p.input_channels * p.filter_rows * \ + p.filter_cols * p.output_rows * p.output_cols * \ + p.parallel_imgs; \ + \ + const auto offset_eigen_tensor = offset_tensor.template tensor(); \ + \ + const auto mask_eigen_tensor = \ + p.use_mask ? 
mask_tensor.template tensor() \ + : mask_tensor.template shaped({0, 0, 0, 0, 0, 0, 0}); \ + \ + const auto column_buffer_tensor_flattened = \ + column_buffer_tensor.template shaped({num_kernels}); \ + \ + auto input_grad_eigen_tensor = input_grad_tensor.tensor(); \ + \ + for (auto k = 0; k < num_kernels; k++) { \ + const auto current_output_col = k % p.output_cols; \ + const auto current_output_row = (k / p.output_cols) % p.output_rows; \ + const auto current_batch = \ + (k / (p.output_rows * p.output_cols)) % p.parallel_imgs; \ + \ + const auto current_filter_col = \ + (k / (p.output_rows * p.output_cols * p.parallel_imgs)) % \ + p.filter_cols; \ + const auto current_filter_row = \ + (k / (p.output_rows * p.output_cols * p.parallel_imgs * \ + p.filter_cols)) % \ + p.filter_rows; \ + const auto current_channel = \ + k / (p.output_rows * p.output_cols * p.parallel_imgs * \ + p.filter_rows * p.filter_cols); \ + \ + const auto current_offset_group = \ + current_channel / (p.input_channels / p.offset_groups); \ + \ + const auto mask = \ + p.use_mask \ + ? mask_eigen_tensor(b, current_batch, current_offset_group, \ + current_filter_row, current_filter_col, \ + current_output_row, current_output_col) \ + : T(1); \ + \ + const auto offset_h = offset_eigen_tensor( \ + b, current_batch, current_offset_group, current_filter_row, \ + current_filter_col, 0, current_output_row, current_output_col); \ + const auto offset_w = offset_eigen_tensor( \ + b, current_batch, current_offset_group, current_filter_row, \ + current_filter_col, 1, current_output_row, current_output_col); \ + \ + const auto y = (current_output_row * p.stride_rows - p.padding_rows) + \ + current_filter_row * p.dilation_rows + offset_h; \ + const auto x = (current_output_col * p.stride_cols - p.padding_cols) + \ + current_filter_col * p.dilation_cols + offset_w; \ + \ + for (auto dy = -1; dy <= 1; dy++) { \ + for (auto dx = -1; dx <= 1; dx++) { \ + const auto current_input_row = int(y) + dy; \ + const auto current_input_col = int(x) + dx; \ + \ + if (p.input_rows > current_input_row && current_input_row >= 0 && \ + p.input_cols > current_input_col && current_input_col >= 0 && \ + std::abs(y - current_input_row) < 1 && \ + std::abs(x - current_input_col) < 1) { \ + const auto weight = (1.0 - std::abs(y - current_input_row)) * \ + (1.0 - std::abs(x - current_input_col)); \ + \ + const auto current_actual_batch = \ + b * p.parallel_imgs + current_batch; \ + \ + input_grad_eigen_tensor(current_actual_batch, current_channel, \ + current_input_row, current_input_col) += \ + mask * weight * column_buffer_tensor_flattened(k); \ + } \ + } \ + } \ + } \ } - - typename TTypes::ConstTensor _output_grad_tensor; - typename TTypes::Tensor _input_grad_tensor; - typename TTypes::Tensor _filter_grad_tensor; - typename TTypes::Tensor _bias_grad_tensor; - typename TTypes::Tensor _offset_grad_tensor; - typename TTypes::Tensor _mask_grad_tensor; -}; +TF_CALL_float(COL2IM_INPUT); +TF_CALL_double(COL2IM_INPUT); +#undef COL2IM_INPUT } // end namespace functor @@ -574,26 +354,29 @@ class DeformableConv2DOpBase : public OpKernel { const Tensor &input_tensor = context->input(0); const Tensor &filter_tensor = context->input(1); + const Tensor &bias_tensor = context->input(2); + const Tensor &mask_tensor = context->input(4); + const TensorShape &input_shape = input_tensor.shape(); const TensorShape &filter_shape = filter_tensor.shape(); - auto input_batches = input_shape.dim_size(0); - auto input_channels = input_shape.dim_size(1); - auto input_rows = 
input_shape.dim_size(2); - auto input_cols = input_shape.dim_size(3); + const auto input_batches = input_shape.dim_size(0); + const auto input_channels = input_shape.dim_size(1); + const auto input_rows = input_shape.dim_size(2); + const auto input_cols = input_shape.dim_size(3); - auto output_channels = filter_shape.dim_size(0); - auto filter_channels = filter_shape.dim_size(1); - auto filter_rows = filter_shape.dim_size(2); - auto filter_cols = filter_shape.dim_size(3); + const auto output_channels = filter_shape.dim_size(0); + const auto filter_channels = filter_shape.dim_size(1); + const auto filter_rows = filter_shape.dim_size(2); + const auto filter_cols = filter_shape.dim_size(3); - auto dilation_rows = dilations[0]; - auto dilation_cols = dilations[1]; + const auto dilation_rows = dilations[0]; + const auto dilation_cols = dilations[1]; - auto stride_rows = strides[0]; - auto stride_cols = strides[1]; + const auto stride_rows = strides[0]; + const auto stride_cols = strides[1]; - auto parallel_imgs = GetParallelImgs(input_batches); + const auto parallel_imgs = GetParallelImgs(input_batches); int64 output_rows, output_cols; int64 padding_rows, padding_cols; @@ -625,6 +408,9 @@ class DeformableConv2DOpBase : public OpKernel { p.parallel_imgs = parallel_imgs; p.weight_groups = weight_groups; p.offset_groups = offset_groups; + p.batches = p.input_batches / p.parallel_imgs; + p.use_mask = mask_tensor.NumElements() > 0; + p.use_bias = bias_tensor.NumElements() > 0; } int GetParallelImgs(int n) { @@ -649,12 +435,12 @@ class DeformableConv2DOpBase : public OpKernel { }; template -class DeformableConv2DOp : public DeformableConv2DOpBase { +class DeformableConv2DForwardOp : public DeformableConv2DOpBase { using DeformableConv2DOpBase::data_format; using DeformableConv2DOpBase::p; public: - explicit DeformableConv2DOp(OpKernelConstruction *context) + explicit DeformableConv2DForwardOp(OpKernelConstruction *context) : DeformableConv2DOpBase(context) {} void Compute(OpKernelContext *context) override { @@ -681,11 +467,9 @@ class DeformableConv2DOp : public DeformableConv2DOpBase { OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output_tensor)); - functor::DeformableConv2DFunctor deformableConv2DFunc( - input_tensor.tensor(), filter_tensor.tensor(), - bias_tensor.tensor(), offset_tensor.tensor(), - mask_tensor.tensor(), column_buffer_tensor.tensor(), - output_tensor->tensor(), p); + functor::DeformableConv2DForwardFunctor deformableConv2DFunc( + &input_tensor, &filter_tensor, &bias_tensor, &offset_tensor, + &mask_tensor, &column_buffer_tensor, output_tensor, &p); Status s = deformableConv2DFunc(context); OP_REQUIRES_OK(context, s); @@ -725,6 +509,25 @@ class DeformableConv2DGradOp : public DeformableConv2DOpBase { column_buffer_shape, &column_buffer_tensor)); + Tensor output_grad_tensor_reshaped; + CHECK(output_grad_tensor_reshaped.CopyFrom( + output_grad_tensor, + TensorShape({p.batches, p.parallel_imgs, p.output_channels, + p.output_rows, p.output_cols}))); + + TensorShape output_grad_tensor_transposed_shape( + {p.batches, p.output_channels, p.parallel_imgs, p.output_rows, + p.output_cols}); + Tensor output_grad_tensor_transposed; + OP_REQUIRES_OK(context, + context->allocate_temp(DataTypeToEnum::value, + output_grad_tensor_transposed_shape, + &output_grad_tensor_transposed)); + OP_REQUIRES_OK(context, + Transpose(context, output_grad_tensor_reshaped, + {0, 2, 1, 3, 4}, + &output_grad_tensor_transposed)); + TensorShape output_shape = ShapeFromFormat(data_format, 
p.input_batches, p.output_rows, p.output_cols, p.output_channels); @@ -746,13 +549,10 @@ class DeformableConv2DGradOp : public DeformableConv2DOpBase { context->allocate_output(4, mask_shape, &mask_grad_tensor)); functor::DeformableConv2DGradFunctor deformableConv2DGradFunc( - input_tensor.tensor(), filter_tensor.tensor(), - bias_tensor.tensor(), offset_tensor.tensor(), - mask_tensor.tensor(), output_grad_tensor.tensor(), - input_grad_tensor->tensor(), filter_grad_tensor->tensor(), - bias_grad_tensor->tensor(), offset_grad_tensor->tensor(), - mask_grad_tensor->tensor(), column_buffer_tensor.tensor(), - p); + &input_tensor, &filter_tensor, &bias_tensor, &offset_tensor, + &mask_tensor, &output_grad_tensor_transposed, input_grad_tensor, + filter_grad_tensor, bias_grad_tensor, offset_grad_tensor, + mask_grad_tensor, &column_buffer_tensor, &p); Status s = deformableConv2DGradFunc(context); OP_REQUIRES_OK(context, s); @@ -760,19 +560,34 @@ class DeformableConv2DGradOp : public DeformableConv2DOpBase { }; // Register the CPU kernels. -#define REGISTER_DEFORMABLECONV2D_OP_CPU(T) \ - REGISTER_KERNEL_BUILDER(Name("Addons>DeformableConv2D") \ - .Device(DEVICE_CPU) \ - .TypeConstraint("T"), \ - DeformableConv2DOp) \ - REGISTER_KERNEL_BUILDER(Name("Addons>DeformableConv2DGrad") \ - .Device(DEVICE_CPU) \ - .TypeConstraint("T"), \ +#define REGISTER_DEFORMABLECONV2D_OP_CPU(T) \ + REGISTER_KERNEL_BUILDER(Name("Addons>DeformableConv2D") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T"), \ + DeformableConv2DForwardOp) \ + REGISTER_KERNEL_BUILDER(Name("Addons>DeformableConv2DGrad") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T"), \ DeformableConv2DGradOp) TF_CALL_float(REGISTER_DEFORMABLECONV2D_OP_CPU); TF_CALL_double(REGISTER_DEFORMABLECONV2D_OP_CPU); #undef REGISTER_DEFORMABLECONV2D_OP_CPU +// Register the GPU kernels. 
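+// (The GPU specializations of the functors registered below are compiled
+// separately from cc/kernels/deformable_conv2d_op_gpu.cu.cc, which the BUILD
+// rule adds to cuda_srcs.)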
+#define REGISTER_DEFORMABLECONV2D_OP_GPU(T)                         \
+  REGISTER_KERNEL_BUILDER(Name("Addons>DeformableConv2D")           \
+                              .Device(DEVICE_GPU)                   \
+                              .TypeConstraint<T>("T"),              \
+                          DeformableConv2DForwardOp<GPUDevice, T>)  \
+  REGISTER_KERNEL_BUILDER(Name("Addons>DeformableConv2DGrad")       \
+                              .Device(DEVICE_GPU)                   \
+                              .TypeConstraint<T>("T"),              \
+                          DeformableConv2DGradOp<GPUDevice, T>)
+
+TF_CALL_float(REGISTER_DEFORMABLECONV2D_OP_GPU);
+TF_CALL_double(REGISTER_DEFORMABLECONV2D_OP_GPU);
+#undef REGISTER_DEFORMABLECONV2D_OP_GPU
+
 }  // namespace addons
 }  // namespace tensorflow
diff --git a/tensorflow_addons/custom_ops/layers/cc/kernels/deformable_conv2d_op.h b/tensorflow_addons/custom_ops/layers/cc/kernels/deformable_conv2d_op.h
index fe9bba1f32..205a6812b6 100644
--- a/tensorflow_addons/custom_ops/layers/cc/kernels/deformable_conv2d_op.h
+++ b/tensorflow_addons/custom_ops/layers/cc/kernels/deformable_conv2d_op.h
@@ -17,25 +17,11 @@
 #define TENSORFLOW_ADDONS_LAYERS_KERNELS_DEFORMABLECONV2D_OP_H_
 
 #include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/kernels/batch_matmul_op_impl.h"
 #include "tensorflow/core/util/tensor_format.h"
 
 namespace tensorflow {
 namespace addons {
-using Shape8D = Eigen::array<Eigen::DenseIndex, 8>;
-using Shape7D = Eigen::array<Eigen::DenseIndex, 7>;
-using Shape6D = Eigen::array<Eigen::DenseIndex, 6>;
-using Shape5D = Eigen::array<Eigen::DenseIndex, 5>;
-using Shape4D = Eigen::array<Eigen::DenseIndex, 4>;
-using Shape3D = Eigen::array<Eigen::DenseIndex, 3>;
-using Shape2D = Eigen::array<Eigen::DenseIndex, 2>;
-using Shape1D = Eigen::array<Eigen::DenseIndex, 1>;
-
-template <typename Dtype, int NDIMS>
-using EigenTensor = Eigen::Tensor<Dtype, NDIMS, Eigen::RowMajor>;
-template <typename Dtype, int NDIMS>
-using EigenTensorRef =
-    Eigen::TensorRef<EigenTensor<Dtype, NDIMS>>;
-
 static const int kMaxParallelImgs = 32;
 
 struct DeformableConv2DParams {
@@ -58,210 +44,550 @@ struct DeformableConv2DParams {
   int32 parallel_imgs;
   int32 weight_groups;
   int32 offset_groups;
+  int32 batches;
+  bool use_mask;
+  bool use_bias;
 };
 
+template <typename Device, typename T>
+Status TensorSetZero(OpKernelContext *ctx, Tensor *value) {
+  const auto d = ctx->template eigen_device<Device>();
+  auto out = value->flat<T>();
+
+  const bool use_64bit = out.size() > Eigen::NumTraits<int>::highest();
+
+  if (!use_64bit && Eigen::internal::is_same<Device, CPUDevice>::value) {
+    To32Bit(out).device(d) = To32Bit(out).constant(T(0));
+  } else {
+    out.device(d) = out.constant(T(0));
+  }
+
+  return Status::OK();
+}
+
+template <typename Device, typename T>
+Status AddToTensor(OpKernelContext *ctx, Tensor *sum, const Tensor *current,
+                   const Tensor *add) {
+  const auto d = ctx->template eigen_device<Device>();
+
+  auto out = sum->flat<T>();
+  auto a = current->flat<T>();
+  auto b = add->flat<T>();
+
+  const bool use_64bit = out.size() > Eigen::NumTraits<int>::highest();
+
+  if (!use_64bit && Eigen::internal::is_same<Device, CPUDevice>::value) {
+    To32Bit(out).device(d) = To32Bit(a) + To32Bit(b);
+  } else {
+    out.device(d) = a + b;
+  }
+
+  return Status::OK();
+}
+
+template <typename Device, typename T, int NDIMS>
+Status Transpose(OpKernelContext *ctx, const Tensor &in,
+                 const gtl::ArraySlice<int32> perm, Tensor *out) {
+  const auto d = ctx->template eigen_device<Device>();
+
+  Eigen::array<int, NDIMS> p;
+  for (int i = 0; i < NDIMS; ++i) {
+    p[i] = perm[i];
+  }
+
+  auto x = typename TTypes<T, NDIMS>::ConstTensor(
+      reinterpret_cast<const T *>(in.tensor_data().data()),
+      in.shape().AsEigenDSizes<NDIMS>());
+  auto y = typename TTypes<T, NDIMS>::Tensor(
+      reinterpret_cast<T *>(const_cast<char *>(out->tensor_data().data())),
+      out->shape().AsEigenDSizes<NDIMS>());
+
+  const bool use_64bit = x.size() > Eigen::NumTraits<int>::highest();
+
+  if (!use_64bit && Eigen::internal::is_same<Device, CPUDevice>::value) {
+    To32Bit(y).device(d) = To32Bit(x).shuffle(p);
+  } else {
+    y.device(d) = x.shuffle(p);
+  }
+
+  return Status::OK();
+}
+
+template <typename Device, typename T>
+Status CopySliceToElement(OpKernelContext *ctx, const Tensor &parent,
+                          Tensor *element, int64 index) {
+  const auto d = ctx->template eigen_device<Device>();
+
+  auto out = element->flat<T>();
+  auto in = parent.flat_outer_dims<T>();
+
+  const bool use_64bit = in.size() > Eigen::NumTraits<int>::highest();
+
+  if (!use_64bit && Eigen::internal::is_same<Device, CPUDevice>::value) {
+    To32Bit(out).device(d) = To32Bit(in).chip(index, 0);
+  } else {
+    out.device(d) = in.chip(index, 0);
+  }
+
+  return Status::OK();
+}
+
+template <typename Device, typename T>
+Status CopyElementToSlice(OpKernelContext *ctx, Tensor element, Tensor *parent,
+                          int64 index) {
+  const auto d = ctx->template eigen_device<Device>();
+
+  auto out = parent->flat_outer_dims<T>();
+  auto in = element.flat<T>();
+
+  const bool use_64bit = out.size() > Eigen::NumTraits<int>::highest();
+
+  if (!use_64bit && Eigen::internal::is_same<Device, CPUDevice>::value) {
+    To32Bit(out).chip(index, 0).device(d) = To32Bit(in);
+  } else {
+    out.chip(index, 0).device(d) = in;
+  }
+
+  return Status::OK();
+}
+
 namespace functor {
 
-template <typename Device, typename Dtype>
-struct DeformableConv2DFunctorBase {
-  DeformableConv2DFunctorBase(
-      typename TTypes<Dtype, 4>::ConstTensor input_tensor,
-      typename TTypes<Dtype, 4>::ConstTensor filter_tensor,
-      typename TTypes<Dtype, 1>::ConstTensor bias_tensor,
-      typename TTypes<Dtype, 4>::ConstTensor offset_tensor,
-      typename TTypes<Dtype, 4>::ConstTensor mask_tensor,
-      typename TTypes<Dtype, 4>::Tensor column_buffer_tensor,
-      DeformableConv2DParams _p)
-      : _input_tensor(input_tensor),
-        _filter_tensor(filter_tensor),
-        _bias_tensor(bias_tensor),
-        _offset_tensor(offset_tensor),
-        _mask_tensor(mask_tensor),
-        _column_buffer_tensor(column_buffer_tensor),
-        p(_p) {}
-
-  virtual Status operator()(OpKernelContext* context) = 0;
-
-  Dtype BilinearInterpolate(int32 batch, int32 channel, Dtype y, Dtype x) {
-    EigenTensor<Dtype, 2> img = _input_tensor.chip(batch, 0).chip(channel, 0);
-
-    auto max_height = img.dimension(0);
-    auto max_width = img.dimension(1);
-
-    if (y <= -1 || max_height <= y || x <= -1 || max_width <= x) {
-      return Dtype(0);
-    }
+template <typename T>
+EIGEN_DEVICE_FUNC T BilinearInterpolate(typename TTypes<T, 5>::Tensor img,
+                                        int32 b, int32 batch, int32 channel,
+                                        T y, T x) {
+  const auto max_height = img.dimension(3);
+  const auto max_width = img.dimension(4);
 
-    int y_low = floor(y);
-    int x_low = floor(x);
-    int y_high = y_low + 1;
-    int w_high = x_low + 1;
+  if (y <= -1 || max_height <= y || x <= -1 || max_width <= x) {
+    return T(0);
+  }
 
-    auto v1 = Dtype(0);
-    if (y_low >= 0 && x_low >= 0) {
-      v1 = img(y_low, x_low);
-    }
+  int y_low = floor(y);
+  int x_low = floor(x);
+  int y_high = y_low + 1;
+  int w_high = x_low + 1;
 
-    auto v2 = Dtype(0);
-    if (y_low >= 0 && w_high <= max_width - 1) {
-      v2 = img(y_low, w_high);
-    }
+  auto v1 = T(0);
+  if (y_low >= 0 && x_low >= 0) {
+    v1 = img(b, batch, channel, y_low, x_low);
+  }
 
-    auto v3 = Dtype(0);
-    if (y_high <= max_height - 1 && x_low >= 0) {
-      v3 = img(y_high, x_low);
-    }
+  auto v2 = T(0);
+  if (y_low >= 0 && w_high <= max_width - 1) {
+    v2 = img(b, batch, channel, y_low, w_high);
+  }
 
-    auto v4 = Dtype(0);
-    if (y_high <= max_height - 1 && w_high <= max_width - 1) {
-      v4 = img(y_high, w_high);
-    }
+  auto v3 = T(0);
+  if (y_high <= max_height - 1 && x_low >= 0) {
+    v3 = img(b, batch, channel, y_high, x_low);
+  }
 
-    auto lh = y - y_low;
-    auto lw = x - x_low;
-    auto hh = 1 - lh;
-    auto hw = 1 - lw;
+  auto v4 = T(0);
+  if (y_high <= max_height - 1 && w_high <= max_width - 1) {
+    v4 = img(b, batch, channel, y_high, w_high);
+  }
 
-    auto w1 = hh * hw;
-    auto w2 = hh * lw;
-    auto w3 = lh * hw;
-    auto w4 = lh * lw;
+  auto lh = y - y_low;
+  auto lw = x - x_low;
+  auto hh = 1 - lh;
+  auto hw = 1 - lw;
+
+  auto w1 = hh * hw;
+  auto w2 = hh * lw;
+  auto w3 = lh * hw;
+  auto w4 = lh * lw;
+
+  return w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4;
+} + +template +EIGEN_DEVICE_FUNC T GetCoordinateWeight(typename TTypes::Tensor img, + int32 b, int32 batch, int32 channel, + T y, T x, bool is_y_direction) { + const auto max_height = img.dimension(3); + const auto max_width = img.dimension(4); + + const int y_low = floor(y); + const int x_low = floor(x); + const int y_high = y_low + 1; + const int x_high = x_low + 1; + + const bool valid_y_low = max_height > y_low && y_low >= 0; + const bool valid_y_high = max_height > y_high && y_high >= 0; + const bool valid_x_low = max_width > x_low && x_low >= 0; + const bool valid_x_high = max_width > x_high && x_high >= 0; + + auto v_yx = T(0); + if (valid_y_low && valid_x_low) { + v_yx = img(b, batch, channel, y_low, x_low); + } - return w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4; + auto v_yX = T(0); + if (valid_y_low && valid_x_high) { + v_yX = img(b, batch, channel, y_low, x_high); } - void DeformableIm2Col(int32 b) { - auto use_mask = _mask_tensor.dimension(0) > 0; - auto num_kernels = - p.input_channels * p.output_rows * p.output_cols * p.parallel_imgs; - auto batches = p.input_batches / p.parallel_imgs; - - EigenTensor offset_tensor = - _offset_tensor - .reshape(Shape8D({batches, p.parallel_imgs, p.offset_groups, - p.filter_rows, p.filter_cols, 2, p.output_rows, - p.output_cols})) - .chip(b, 0); - - EigenTensor mask_tensor = - use_mask ? static_cast>( - _mask_tensor - .reshape(Shape7D({batches, p.parallel_imgs, - p.offset_groups, p.filter_rows, - p.filter_cols, p.output_rows, - p.output_cols})) - .chip(b, 0)) - : _mask_tensor.reshape(Shape6D({0, 0, 0, 0, 0, 0})); - - for (auto k = 0; k < num_kernels; k++) { - const auto current_output_col = k % p.output_cols; - const auto current_output_row = (k / p.output_cols) % p.output_rows; - const auto current_batch = - (k / (p.output_rows * p.output_cols)) % p.parallel_imgs; - const auto current_input_channel = - k / (p.output_rows * p.output_cols * p.parallel_imgs); - const auto current_output_channel = - current_input_channel * p.filter_rows * p.filter_cols; - - auto current_actual_batch = b * p.parallel_imgs + current_batch; - - const auto group_index = - current_input_channel / (p.input_channels / p.offset_groups); - - EigenTensor offset_tensor_chipped = - offset_tensor.chip(current_batch, 0).chip(group_index, 0); - - EigenTensor mask_tensor_chipped = - use_mask - ? static_cast>( - mask_tensor.chip(current_batch, 0).chip(group_index, 0)) - : mask_tensor.reshape(Shape4D({0, 0, 0, 0})); - - auto column_buffer_tensor_channel = current_output_channel; - for (auto current_filter_row = 0; current_filter_row < p.filter_rows; - current_filter_row++) { - for (auto current_filter_col = 0; current_filter_col < p.filter_cols; - current_filter_col++) { - auto offset_h = - offset_tensor_chipped(current_filter_row, current_filter_col, 0, - current_output_row, current_output_col); - auto offset_w = - offset_tensor_chipped(current_filter_row, current_filter_col, 1, - current_output_row, current_output_col); - - auto mask = use_mask ? 
mask_tensor_chipped( - current_filter_row, current_filter_col, - current_output_row, current_output_col) - : Dtype(1); - - auto y = (current_output_row * p.stride_rows - p.padding_rows) + - current_filter_row * p.dilation_rows + offset_h; - auto x = (current_output_col * p.stride_cols - p.padding_cols) + - current_filter_col * p.dilation_cols + offset_w; - - _column_buffer_tensor(column_buffer_tensor_channel, current_batch, - current_output_row, current_output_col) = - mask * BilinearInterpolate(current_actual_batch, - current_input_channel, y, x); - column_buffer_tensor_channel++; - } - } + auto v_Yx = T(0); + if (valid_y_high && valid_x_low) { + v_Yx = img(b, batch, channel, y_high, x_low); + } + + auto v_YX = T(0); + if (valid_y_high && valid_x_high) { + v_YX = img(b, batch, channel, y_high, x_high); + } + + if (is_y_direction) { + const auto dx = x - x_low; + return (v_YX - v_yX) * dx + (v_Yx - v_yx) * (1 - dx); + } else { + const auto dy = y - y_low; + return (v_YX - v_Yx) * dy + (v_yX - v_yx) * (1 - dy); + } +} + +template +struct DeformableConv2DFunctorBase { + DeformableConv2DFunctorBase(const Tensor *_input_tensor, + const Tensor *_filter_tensor, + const Tensor *_bias_tensor, + const Tensor *_offset_tensor, + const Tensor *_mask_tensor, + Tensor *_column_buffer_tensor, + DeformableConv2DParams *_p) + : input_tensor(_input_tensor->dtype()), + filter_tensor(_filter_tensor->dtype()), + bias_tensor(_bias_tensor->dtype()), + offset_tensor(_offset_tensor->dtype()), + mask_tensor(_mask_tensor->dtype()), + column_buffer_tensor(_column_buffer_tensor->dtype()), + p(*_p) { + CHECK(input_tensor.CopyFrom( + *_input_tensor, + TensorShape({p.batches, p.parallel_imgs, p.input_channels, p.input_rows, + p.input_cols}))); + CHECK(filter_tensor.CopyFrom( + *_filter_tensor, + TensorShape({p.weight_groups, p.output_channels / p.weight_groups, + p.filter_channels * p.filter_rows * p.filter_cols}))); + CHECK(bias_tensor.CopyFrom(*_bias_tensor, _bias_tensor->shape())); + + CHECK(offset_tensor.CopyFrom( + *_offset_tensor, + TensorShape({p.batches, p.parallel_imgs, p.offset_groups, p.filter_rows, + p.filter_cols, 2, p.output_rows, p.output_cols}))); + + if (p.use_mask) { + CHECK(mask_tensor.CopyFrom( + *_mask_tensor, + TensorShape({p.batches, p.parallel_imgs, p.offset_groups, + p.filter_rows, p.filter_cols, p.output_rows, + p.output_cols}))); + } else { + CHECK(mask_tensor.CopyFrom(*_mask_tensor, + TensorShape({0, 0, 0, 0, 0, 0, 0}))); } + + CHECK(column_buffer_tensor.CopyFrom( + *_column_buffer_tensor, + TensorShape({p.input_channels * p.filter_rows * p.filter_cols, + p.parallel_imgs, p.output_rows, p.output_cols}))); } - typename TTypes::ConstTensor _input_tensor; - typename TTypes::ConstTensor _filter_tensor; - typename TTypes::ConstTensor _bias_tensor; - typename TTypes::ConstTensor _offset_tensor; - typename TTypes::ConstTensor _mask_tensor; - typename TTypes::Tensor _column_buffer_tensor; + virtual Status operator()(OpKernelContext *context) = 0; + + void DeformableIm2Col(OpKernelContext *context, int32 b); + + Tensor input_tensor; + Tensor filter_tensor; + Tensor bias_tensor; + Tensor offset_tensor; + Tensor mask_tensor; + Tensor column_buffer_tensor; DeformableConv2DParams p; }; -template -struct DeformableConv2DFunctor - : public DeformableConv2DFunctorBase { - DeformableConv2DFunctor( - typename TTypes::ConstTensor input_tensor, - typename TTypes::ConstTensor filter_tensor, - typename TTypes::ConstTensor bias_tensor, - typename TTypes::ConstTensor offset_tensor, - typename TTypes::ConstTensor 
mask_tensor, - typename TTypes::Tensor column_buffer_tensor, - typename TTypes::Tensor output_tensor, - DeformableConv2DParams _p); - - Status operator()(OpKernelContext* context); - - typename TTypes::Tensor _output_tensor; +template +struct DeformableConv2DForwardFunctor + : public DeformableConv2DFunctorBase { + using DeformableConv2DFunctorBase::input_tensor; + using DeformableConv2DFunctorBase::filter_tensor; + using DeformableConv2DFunctorBase::bias_tensor; + using DeformableConv2DFunctorBase::offset_tensor; + using DeformableConv2DFunctorBase::mask_tensor; + using DeformableConv2DFunctorBase::column_buffer_tensor; + using DeformableConv2DFunctorBase::p; + + DeformableConv2DForwardFunctor( + const Tensor *_input_tensor, const Tensor *_filter_tensor, + const Tensor *_bias_tensor, const Tensor *_offset_tensor, + const Tensor *_mask_tensor, Tensor *_column_buffer_tensor, + Tensor *_output_tensor, DeformableConv2DParams *_p) + : DeformableConv2DFunctorBase( + _input_tensor, _filter_tensor, _bias_tensor, _offset_tensor, + _mask_tensor, _column_buffer_tensor, _p), + output_tensor(_output_tensor->dtype()) { + CHECK(output_tensor.CopyFrom(*_output_tensor, _output_tensor->shape())); + } + + Status operator()(OpKernelContext *context) { + // input_channels * filter_rows * filter_cols / weight_groups == + // filter_channels * filter_rows * filter_cols + const auto elems = p.filter_channels * p.filter_rows * p.filter_cols; + const auto rows = p.output_channels / p.weight_groups; + const auto cols = p.parallel_imgs * p.output_rows * p.output_cols; + + Tensor output_tmp_tensor; + TF_RETURN_IF_ERROR(context->allocate_temp( + DataTypeToEnum::value, + TensorShape({p.batches, p.output_channels, p.parallel_imgs, + p.output_rows, p.output_cols}), + &output_tmp_tensor)); + + Tensor output_tmp_mtx_tensor; + TF_RETURN_IF_ERROR(context->allocate_temp( + DataTypeToEnum::value, TensorShape({p.weight_groups, rows, cols}), + &output_tmp_mtx_tensor)); + + Tensor output_tmp_tensor_reshaped(output_tmp_tensor.dtype()); + CHECK(output_tmp_tensor_reshaped.CopyFrom( + output_tmp_tensor, + TensorShape({p.batches, p.weight_groups, rows, cols}))); + + Tensor column_buffer_tensor_reshaped(column_buffer_tensor.dtype()); + CHECK(column_buffer_tensor_reshaped.CopyFrom( + column_buffer_tensor, TensorShape({p.weight_groups, elems, cols}))); + + for (auto b = 0; b < p.batches; b++) { + this->DeformableIm2Col(context, b); + + auto lhs = filter_tensor; + auto rhs = column_buffer_tensor_reshaped; + auto out = output_tmp_mtx_tensor; + + MatMulBCast bcast(lhs.shape().dim_sizes(), rhs.shape().dim_sizes()); + + LaunchBatchMatMul::Launch(context, lhs, rhs, false, false, + false, false, bcast, &out); + + TF_RETURN_IF_ERROR(CopyElementToSlice( + context, out, &output_tmp_tensor_reshaped, b)); + } + + Tensor output_tensor_reshaped(output_tensor.dtype()); + CHECK(output_tensor_reshaped.CopyFrom( + output_tensor, + TensorShape({p.batches, p.parallel_imgs, p.output_channels, + p.output_rows, p.output_cols}))); + TF_RETURN_IF_ERROR(Transpose( + context, output_tmp_tensor, {0, 2, 1, 3, 4}, &output_tensor_reshaped)); + + if (p.use_bias) { + Eigen::DSizes four_dims(1, p.output_channels, 1, 1); + Eigen::DSizes broadcast_dims(p.input_batches, 1, p.output_rows, + p.output_cols); + + const auto d = context->eigen_device(); + + auto out = output_tensor.tensor(); + auto add = bias_tensor.template tensor(); + + const bool use_64bit = out.size() > Eigen::NumTraits::highest(); + + if (!use_64bit && + Eigen::internal::is_same::value) { + 
To32Bit(out).device(d) += + To32Bit(add).reshape(four_dims).broadcast(broadcast_dims); + } else { + out.device(d) += add.reshape(four_dims).broadcast(broadcast_dims); + } + } + + return Status::OK(); + } + + Tensor output_tensor; }; -template +template struct DeformableConv2DGradFunctor - : public DeformableConv2DFunctorBase { + : public DeformableConv2DFunctorBase { + using DeformableConv2DFunctorBase::input_tensor; + using DeformableConv2DFunctorBase::filter_tensor; + using DeformableConv2DFunctorBase::bias_tensor; + using DeformableConv2DFunctorBase::offset_tensor; + using DeformableConv2DFunctorBase::mask_tensor; + using DeformableConv2DFunctorBase::column_buffer_tensor; + using DeformableConv2DFunctorBase::p; + DeformableConv2DGradFunctor( - typename TTypes::ConstTensor input_tensor, - typename TTypes::ConstTensor filter_tensor, - typename TTypes::ConstTensor bias_tensor, - typename TTypes::ConstTensor offset_tensor, - typename TTypes::ConstTensor mask_tensor, - typename TTypes::ConstTensor output_grad_tensor, - typename TTypes::Tensor input_grad_tensor, - typename TTypes::Tensor filter_grad_tensor, - typename TTypes::Tensor bias_grad_tensor, - typename TTypes::Tensor offset_grad_tensor, - typename TTypes::Tensor mask_grad_tensor, - typename TTypes::Tensor column_buffer_tensor, - DeformableConv2DParams p); - - Status operator()(OpKernelContext* context); - - typename TTypes::ConstTensor _output_grad_tensor; - typename TTypes::Tensor _input_grad_tensor; - typename TTypes::Tensor _filter_grad_tensor; - typename TTypes::Tensor _bias_grad_tensor; - typename TTypes::Tensor _offset_grad_tensor; - typename TTypes::Tensor _mask_grad_tensor; + const Tensor *_input_tensor, const Tensor *_filter_tensor, + const Tensor *_bias_tensor, const Tensor *_offset_tensor, + const Tensor *_mask_tensor, Tensor *_output_grad_tensor, + Tensor *_input_grad_tensor, Tensor *_filter_grad_tensor, + Tensor *_bias_grad_tensor, Tensor *_offset_grad_tensor, + Tensor *_mask_grad_tensor, Tensor *_column_buffer_tensor, + DeformableConv2DParams *_p) + : DeformableConv2DFunctorBase( + _input_tensor, _filter_tensor, _bias_tensor, _offset_tensor, + _mask_tensor, _column_buffer_tensor, _p), + output_grad_tensor(_output_grad_tensor->dtype()), + input_grad_tensor(_input_grad_tensor->dtype()), + filter_grad_tensor(_filter_grad_tensor->dtype()), + bias_grad_tensor(_bias_grad_tensor->dtype()), + offset_grad_tensor(_offset_grad_tensor->dtype()), + mask_grad_tensor(_mask_grad_tensor->dtype()) { + CHECK(output_grad_tensor.CopyFrom(*_output_grad_tensor, + _output_grad_tensor->shape())); + CHECK(input_grad_tensor.CopyFrom(*_input_grad_tensor, + _input_grad_tensor->shape())); + CHECK(filter_grad_tensor.CopyFrom( + *_filter_grad_tensor, + TensorShape({p.weight_groups, p.output_channels / p.weight_groups, + p.filter_channels * p.filter_rows * p.filter_cols}))); + CHECK(bias_grad_tensor.CopyFrom(*_bias_grad_tensor, + _bias_grad_tensor->shape())); + CHECK(offset_grad_tensor.CopyFrom(*_offset_grad_tensor, + _offset_grad_tensor->shape())); + CHECK(mask_grad_tensor.CopyFrom(*_mask_grad_tensor, + _mask_grad_tensor->shape())); + } + + Status operator()(OpKernelContext *context) { + TF_RETURN_IF_ERROR(TensorSetZero(context, &input_grad_tensor)); + TF_RETURN_IF_ERROR(TensorSetZero(context, &filter_grad_tensor)); + TF_RETURN_IF_ERROR( + TensorSetZero(context, &column_buffer_tensor)); + + ComputeInputOffsetMaskGrad(context); + + ComputeFilterGrad(context); + + if (p.use_bias) { + const auto d = context->eigen_device(); + + auto out = 
bias_grad_tensor.tensor(); + auto in = output_grad_tensor.tensor(); + + const bool use_64bit = in.size() > Eigen::NumTraits::highest(); + + const Eigen::array axes({0, 2, 3, 4}); + + if (!use_64bit && + Eigen::internal::is_same::value) { + To32Bit(out).device(d) = To32Bit(in).sum(axes); + } else { + out.device(d) = in.sum(axes); + } + } + + return Status::OK(); + } + + void ComputeFilterGrad(OpKernelContext *context) { + // input_channels * filter_rows * filter_cols / weight_groups == + // filter_channels * filter_rows * filter_cols + const auto cols = p.filter_channels * p.filter_rows * p.filter_cols; + const auto rows = p.output_channels / p.weight_groups; + const auto elems = p.parallel_imgs * p.output_rows * p.output_cols; + + Tensor output_grad_tensor_reshaped(output_grad_tensor.dtype()); + CHECK(output_grad_tensor_reshaped.CopyFrom( + output_grad_tensor, + TensorShape({p.batches, p.weight_groups, rows, elems}))); + + Tensor column_buffer_tensor_reshaped(column_buffer_tensor.dtype()); + CHECK(column_buffer_tensor_reshaped.CopyFrom( + column_buffer_tensor, TensorShape({p.weight_groups, elems, cols}))); + + Tensor matmul_lhs_tmp_tensor; + OP_REQUIRES_OK(context, context->allocate_temp( + DataTypeToEnum::value, + TensorShape({p.weight_groups, rows, elems}), + &matmul_lhs_tmp_tensor)); + + Tensor matmul_out_tmp_tensor; + OP_REQUIRES_OK(context, context->allocate_temp( + DataTypeToEnum::value, + TensorShape({p.weight_groups, rows, cols}), + &matmul_out_tmp_tensor)); + + for (auto b = 0; b < p.batches; b++) { + this->DeformableIm2Col(context, b); + + auto lhs = matmul_lhs_tmp_tensor; + auto rhs = column_buffer_tensor_reshaped; + auto out = matmul_out_tmp_tensor; + + OP_REQUIRES_OK(context, + CopySliceToElement( + context, output_grad_tensor_reshaped, &lhs, b)); + + MatMulBCast bcast(lhs.shape().dim_sizes(), rhs.shape().dim_sizes()); + + LaunchBatchMatMul::Launch(context, lhs, rhs, false, false, + false, true, bcast, &out); + + OP_REQUIRES_OK(context, + AddToTensor(context, &filter_grad_tensor, + &filter_grad_tensor, &out)); + } + } + + void ComputeInputOffsetMaskGrad(OpKernelContext *context) { + // input_channels * filter_rows * filter_cols / weight_groups == + // filter_channels * filter_rows * filter_cols + const auto rows = p.filter_channels * p.filter_rows * p.filter_cols; + const auto elems = p.output_channels / p.weight_groups; + const auto cols = p.parallel_imgs * p.output_rows * p.output_cols; + + Tensor output_grad_tensor_reshaped(output_grad_tensor.dtype()); + CHECK(output_grad_tensor_reshaped.CopyFrom( + output_grad_tensor, + TensorShape({p.batches, p.weight_groups, elems, cols}))); + + Tensor column_buffer_tensor_reshaped(column_buffer_tensor.dtype()); + CHECK(column_buffer_tensor_reshaped.CopyFrom( + column_buffer_tensor, TensorShape({p.weight_groups, rows, cols}))); + + Tensor matmul_rhs_tmp_tensor; + OP_REQUIRES_OK(context, context->allocate_temp( + DataTypeToEnum::value, + TensorShape({p.weight_groups, elems, cols}), + &matmul_rhs_tmp_tensor)); + + for (auto b = 0; b < p.batches; b++) { + auto lhs = filter_tensor; + auto rhs = matmul_rhs_tmp_tensor; + auto out = column_buffer_tensor_reshaped; + + OP_REQUIRES_OK(context, + CopySliceToElement( + context, output_grad_tensor_reshaped, &rhs, b)); + + MatMulBCast bcast(lhs.shape().dim_sizes(), rhs.shape().dim_sizes()); + + LaunchBatchMatMul::Launch(context, lhs, rhs, false, false, + true, false, bcast, &out); + + DeformableCol2ImForOffsetAndMask(context, b); + + DeformableCol2ImForInput(context, b); + } + } + + void 
DeformableCol2ImForOffsetAndMask(OpKernelContext *context, int32 b); + + void DeformableCol2ImForInput(OpKernelContext *context, int32 b); + + Tensor output_grad_tensor; + Tensor input_grad_tensor; + Tensor filter_grad_tensor; + Tensor bias_grad_tensor; + Tensor offset_grad_tensor; + Tensor mask_grad_tensor; }; } // namespace functor diff --git a/tensorflow_addons/custom_ops/layers/cc/kernels/deformable_conv2d_op_gpu.cu.cc b/tensorflow_addons/custom_ops/layers/cc/kernels/deformable_conv2d_op_gpu.cu.cc new file mode 100644 index 0000000000..e4f02c6268 --- /dev/null +++ b/tensorflow_addons/custom_ops/layers/cc/kernels/deformable_conv2d_op_gpu.cu.cc @@ -0,0 +1,385 @@ +// Copyright 2020 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +#if GOOGLE_CUDA +#define EIGEN_USE_GPU + +#include "tensorflow/core/framework/common_shape_fns.h" +#include "tensorflow/core/framework/kernel_shape_util.h" +#include "tensorflow/core/util/gpu_kernel_helper.h" +#include "tensorflow_addons/custom_ops/layers/cc/kernels/deformable_conv2d_op.h" + +namespace tensorflow { +namespace addons { + +using GPUDevice = Eigen::GpuDevice; + +namespace functor { + +template +__global__ void DeformableIm2ColKernel( + int32 b, int32 num_kernels, DeformableConv2DParams p, + typename TTypes::Tensor input_eigen_tensor, + typename TTypes::Tensor offset_eigen_tensor, + typename TTypes::Tensor mask_eigen_tensor, + typename TTypes::Tensor column_buffer_eigen_tensor) { + CUDA_1D_KERNEL_LOOP(k, num_kernels) { + const auto current_output_col = k % p.output_cols; + const auto current_output_row = (k / p.output_cols) % p.output_rows; + const auto current_batch = + (k / (p.output_rows * p.output_cols)) % p.parallel_imgs; + const auto current_input_channel = + k / (p.output_rows * p.output_cols * p.parallel_imgs); + const auto current_output_channel = + current_input_channel * p.filter_rows * p.filter_cols; + + const auto current_actual_batch = b * p.parallel_imgs + current_batch; + + const auto group_index = + current_input_channel / (p.input_channels / p.offset_groups); + + auto column_buffer_tensor_channel = current_output_channel; + for (auto current_filter_row = 0; current_filter_row < p.filter_rows; + current_filter_row++) { + for (auto current_filter_col = 0; current_filter_col < p.filter_cols; + current_filter_col++) { + auto offset_h = offset_eigen_tensor( + b, current_batch, group_index, current_filter_row, + current_filter_col, 0, current_output_row, current_output_col); + auto offset_w = offset_eigen_tensor( + b, current_batch, group_index, current_filter_row, + current_filter_col, 1, current_output_row, current_output_col); + + auto mask = p.use_mask ? 
mask_eigen_tensor( + b, current_batch, group_index, + current_filter_row, current_filter_col, + current_output_row, current_output_col) + : T(1); + + auto y = (current_output_row * p.stride_rows - p.padding_rows) + + current_filter_row * p.dilation_rows + offset_h; + auto x = (current_output_col * p.stride_cols - p.padding_cols) + + current_filter_col * p.dilation_cols + offset_w; + + column_buffer_eigen_tensor(column_buffer_tensor_channel, current_batch, + current_output_row, current_output_col) = + mask * BilinearInterpolate(input_eigen_tensor, b, + current_actual_batch, + current_input_channel, y, x); + column_buffer_tensor_channel++; + } + } + } +} + +template +__global__ void DeformableCol2ImForOffsetAndMaskKernel( + int32 b, int32 num_kernels, DeformableConv2DParams p, + typename TTypes::Tensor input_eigen_tensor, + typename TTypes::Tensor offset_eigen_tensor, + typename TTypes::Tensor mask_eigen_tensor, + typename TTypes::Tensor offset_grad_eigen_tensor, + typename TTypes::Tensor mask_grad_eigen_tensor, + typename TTypes::Tensor column_buffer_eigen_tensor) { + CUDA_1D_KERNEL_LOOP(k, num_kernels) { + auto offset_grad_value = T(0); + auto mask_grad_value = T(0); + + const auto offset_channels = + 2 * p.filter_rows * p.filter_cols * p.offset_groups; + + const auto offset_channel_step = p.filter_rows * p.filter_cols; + + const auto current_output_col = k % p.output_cols; + const auto current_output_row = (k / p.output_cols) % p.output_rows; + const auto current_filter_col = + (k / (2 * p.output_rows * p.output_cols)) % p.filter_cols; + const auto current_filter_row = + (k / (2 * p.output_rows * p.output_cols * p.filter_cols)) % + p.filter_rows; + const auto current_offset_channel = + (k / (p.output_rows * p.output_cols)) % offset_channels; + const auto current_batch = + k / (p.output_rows * p.output_cols * offset_channels); + + const auto current_actual_batch = b * p.parallel_imgs + current_batch; + + const auto current_offset_group = + current_offset_channel / (2 * offset_channel_step); + + const auto channels_per_offset_group = p.input_channels / p.offset_groups; + const auto offset_channel_diff = + current_offset_channel - current_offset_group * 2 * offset_channel_step; + const auto is_y_direction = offset_channel_diff % 2 == 0; + + for (auto selected_offset_channel = (offset_channel_diff / 2); + selected_offset_channel < + channels_per_offset_group * offset_channel_step; + selected_offset_channel += offset_channel_step) { + const auto selected_filter_col = selected_offset_channel % p.filter_cols; + const auto selected_filter_row = + (selected_offset_channel / p.filter_cols) % p.filter_rows; + const auto input_channel_diff = + (selected_offset_channel / (p.filter_cols * p.filter_rows)); + + const auto offset_h = offset_eigen_tensor( + b, current_batch, current_offset_group, selected_filter_row, + selected_filter_col, 0, current_output_row, current_output_col); + const auto offset_w = offset_eigen_tensor( + b, current_batch, current_offset_group, selected_filter_row, + selected_filter_col, 1, current_output_row, current_output_col); + const auto mask = + p.use_mask + ? 
mask_eigen_tensor(b, current_batch, current_offset_group, + selected_filter_row, selected_filter_col, + current_output_row, current_output_col) + : T(1); + + const auto y = (current_output_row * p.stride_rows - p.padding_rows) + + selected_filter_row * p.dilation_rows + offset_h; + const auto x = (current_output_col * p.stride_cols - p.padding_cols) + + selected_filter_col * p.dilation_cols + offset_w; + + const auto selected_input_channel = + input_channel_diff + current_offset_group * channels_per_offset_group; + + const auto filter_data = column_buffer_eigen_tensor( + selected_input_channel, selected_filter_row, selected_filter_col, + current_batch, current_output_row, current_output_col); + + const auto weight = + GetCoordinateWeight(input_eigen_tensor, b, current_actual_batch, + selected_input_channel, y, x, is_y_direction); + + offset_grad_value += mask * weight * filter_data; + + if (is_y_direction) { + mask_grad_value += + filter_data * BilinearInterpolate(input_eigen_tensor, b, + current_actual_batch, + selected_input_channel, y, x); + } + } + + offset_grad_eigen_tensor(current_actual_batch, current_offset_channel, + current_output_row, current_output_col) = + offset_grad_value; + + if (p.use_mask && is_y_direction) { + const auto current_mask_channel = + (current_offset_group * p.filter_rows + current_filter_row) * + p.filter_cols + + current_filter_col; + + mask_grad_eigen_tensor(current_actual_batch, current_mask_channel, + current_output_row, current_output_col) = + mask_grad_value; + } + } +} + +template +__global__ void DeformableCol2ImForInputKernel( + int32 b, int32 num_kernels, DeformableConv2DParams p, + typename TTypes::Tensor offset_eigen_tensor, + typename TTypes::Tensor mask_eigen_tensor, + typename TTypes::Tensor input_grad_eigen_tensor, + typename TTypes::Tensor column_buffer_tensor_flattened) { + CUDA_1D_KERNEL_LOOP(k, num_kernels) { + const auto current_output_col = k % p.output_cols; + const auto current_output_row = (k / p.output_cols) % p.output_rows; + const auto current_batch = + (k / (p.output_rows * p.output_cols)) % p.parallel_imgs; + + const auto current_filter_col = + (k / (p.output_rows * p.output_cols * p.parallel_imgs)) % p.filter_cols; + const auto current_filter_row = (k / (p.output_rows * p.output_cols * + p.parallel_imgs * p.filter_cols)) % + p.filter_rows; + const auto current_channel = + k / (p.output_rows * p.output_cols * p.parallel_imgs * p.filter_rows * + p.filter_cols); + + const auto current_offset_group = + current_channel / (p.input_channels / p.offset_groups); + + const auto mask = + p.use_mask ? 
mask_eigen_tensor(b, current_batch, current_offset_group, + current_filter_row, current_filter_col, + current_output_row, current_output_col) + : T(1); + + const auto offset_h = offset_eigen_tensor( + b, current_batch, current_offset_group, current_filter_row, + current_filter_col, 0, current_output_row, current_output_col); + const auto offset_w = offset_eigen_tensor( + b, current_batch, current_offset_group, current_filter_row, + current_filter_col, 1, current_output_row, current_output_col); + + const auto y = (current_output_row * p.stride_rows - p.padding_rows) + + current_filter_row * p.dilation_rows + offset_h; + const auto x = (current_output_col * p.stride_cols - p.padding_cols) + + current_filter_col * p.dilation_cols + offset_w; + + for (auto dy = -1; dy <= 1; dy++) { + for (auto dx = -1; dx <= 1; dx++) { + const auto current_input_row = int(y) + dy; + const auto current_input_col = int(x) + dx; + + if (p.input_rows > current_input_row && current_input_row >= 0 && + p.input_cols > current_input_col && current_input_col >= 0 && + std::abs(y - current_input_row) < 1 && + std::abs(x - current_input_col) < 1) { + const auto weight = (1.0 - std::abs(y - current_input_row)) * + (1.0 - std::abs(x - current_input_col)); + + const auto current_actual_batch = b * p.parallel_imgs + current_batch; + + auto *ptr = input_grad_eigen_tensor.data(); + + const auto ptr_pos = + ((current_actual_batch * p.input_channels + current_channel) * + p.input_rows + + current_input_row) * + p.input_cols + + current_input_col; + + GpuAtomicAdd(ptr + ptr_pos, + mask * weight * column_buffer_tensor_flattened(k)); + } + } + } + } +} + +#define IM2COL(T) \ + template <> \ + void DeformableConv2DFunctorBase::DeformableIm2Col( \ + OpKernelContext *context, int32 b) { \ + auto num_kernels = \ + p.input_channels * p.output_rows * p.output_cols * p.parallel_imgs; \ + \ + const auto offset_eigen_tensor = offset_tensor.tensor(); \ + \ + const auto mask_eigen_tensor = \ + p.use_mask ? mask_tensor.tensor() \ + : mask_tensor.shaped({0, 0, 0, 0, 0, 0, 0}); \ + \ + const auto input_eigen_tensor = input_tensor.tensor(); \ + \ + auto column_buffer_eigen_tensor = column_buffer_tensor.tensor(); \ + \ + auto device = context->template eigen_device(); \ + GpuLaunchConfig config = GetGpuLaunchConfig(num_kernels, device); \ + TF_CHECK_OK(GpuLaunchKernel(DeformableIm2ColKernel, config.block_count, \ + config.thread_per_block, 0, device.stream(), \ + b, num_kernels, p, input_eigen_tensor, \ + offset_eigen_tensor, mask_eigen_tensor, \ + column_buffer_eigen_tensor)); \ + } +TF_CALL_float(IM2COL); +TF_CALL_double(IM2COL); +#undef IM2COL + +#define COL2IM_OFFSET_AND_MASK(T) \ + template <> \ + void \ + DeformableConv2DGradFunctor::DeformableCol2ImForOffsetAndMask( \ + OpKernelContext *context, int32 b) { \ + const auto num_kernels = p.output_rows * p.output_cols * 2 * \ + p.filter_rows * p.filter_cols * p.offset_groups * \ + p.parallel_imgs; \ + \ + const auto offset_eigen_tensor = offset_tensor.template tensor(); \ + \ + const auto mask_eigen_tensor = \ + p.use_mask ? 
mask_tensor.template tensor() \ + : mask_tensor.template shaped({0, 0, 0, 0, 0, 0, 0}); \ + \ + const auto column_buffer_eigen_tensor = \ + column_buffer_tensor.template shaped( \ + {p.input_channels, p.filter_rows, p.filter_cols, p.parallel_imgs, \ + p.output_rows, p.output_cols}); \ + \ + auto offset_grad_eigen_tensor = offset_grad_tensor.tensor(); \ + auto mask_grad_eigen_tensor = mask_grad_tensor.tensor(); \ + \ + const auto input_eigen_tensor = input_tensor.tensor(); \ + \ + auto device = context->template eigen_device(); \ + GpuLaunchConfig config = GetGpuLaunchConfig(num_kernels, device); \ + TF_CHECK_OK(GpuLaunchKernel( \ + DeformableCol2ImForOffsetAndMaskKernel, config.block_count, \ + config.thread_per_block, 0, device.stream(), b, num_kernels, p, \ + input_eigen_tensor, offset_eigen_tensor, mask_eigen_tensor, \ + offset_grad_eigen_tensor, mask_grad_eigen_tensor, \ + column_buffer_eigen_tensor)); \ + } +TF_CALL_float(COL2IM_OFFSET_AND_MASK); +TF_CALL_double(COL2IM_OFFSET_AND_MASK); +#undef COL2IM_OFFSET_AND_MASK + +#define COL2IM_INPUT(T) \ + template <> \ + void DeformableConv2DGradFunctor::DeformableCol2ImForInput( \ + OpKernelContext *context, int32 b) { \ + const auto num_kernels = p.input_channels * p.filter_rows * \ + p.filter_cols * p.output_rows * p.output_cols * \ + p.parallel_imgs; \ + \ + const auto offset_eigen_tensor = offset_tensor.template tensor(); \ + \ + const auto mask_eigen_tensor = \ + p.use_mask ? mask_tensor.template tensor() \ + : mask_tensor.template shaped({0, 0, 0, 0, 0, 0, 0}); \ + \ + const auto column_buffer_tensor_flattened = \ + column_buffer_tensor.template shaped({num_kernels}); \ + \ + auto input_grad_eigen_tensor = input_grad_tensor.tensor(); \ + \ + auto device = context->template eigen_device(); \ + GpuLaunchConfig config = GetGpuLaunchConfig(num_kernels, device); \ + TF_CHECK_OK(GpuLaunchKernel( \ + DeformableCol2ImForInputKernel, config.block_count, \ + config.thread_per_block, 0, device.stream(), b, num_kernels, p, \ + offset_eigen_tensor, mask_eigen_tensor, input_grad_eigen_tensor, \ + column_buffer_tensor_flattened)); \ + } +TF_CALL_float(COL2IM_INPUT); +TF_CALL_double(COL2IM_INPUT); +#undef COL2IM_INPUT + +#define EXPLICIT_TEMPLATE(T) \ + template struct DeformableConv2DForwardFunctor; \ + template struct DeformableConv2DGradFunctor; +TF_CALL_float(EXPLICIT_TEMPLATE); +TF_CALL_double(EXPLICIT_TEMPLATE); +#undef EXPLICIT_TEMPLATE + +} // end namespace functor + +#define EXPLICIT_TEMPLATE(T) \ + template Status Transpose( \ + OpKernelContext * ctx, const Tensor &in, \ + const gtl::ArraySlice perm, Tensor *out); +TF_CALL_float(EXPLICIT_TEMPLATE); +TF_CALL_double(EXPLICIT_TEMPLATE); +#undef EXPLICIT_TEMPLATE + +} // namespace addons +} // namespace tensorflow + +#endif // GOOGLE_CUDA diff --git a/tensorflow_addons/layers/tests/deformable_conv2d_test.py b/tensorflow_addons/layers/tests/deformable_conv2d_test.py index 5a18adaa88..6899468624 100644 --- a/tensorflow_addons/layers/tests/deformable_conv2d_test.py +++ b/tensorflow_addons/layers/tests/deformable_conv2d_test.py @@ -207,7 +207,7 @@ def _expected( return output -@pytest.mark.with_device(["cpu"]) +@pytest.mark.with_device(["cpu", "gpu"]) @pytest.mark.usefixtures("maybe_run_functions_eagerly") def test_forward_simple(data_format): if data_format == "channels_last": @@ -283,7 +283,7 @@ def test_forward_simple(data_format): np.testing.assert_allclose(actual.numpy(), expected) -@pytest.mark.with_device(["cpu"]) +@pytest.mark.with_device(["cpu", "gpu"]) def test_gradients(data_format): if 
data_format == "channels_last":
         return
 
@@ -349,7 +349,7 @@ def conv_fn(input_tensor, offset_tensor, mask_tensor):
     np.testing.assert_allclose(theoretical[2], numerical[2], atol=1e-3)
 
 
-@pytest.mark.with_device(["cpu"])
+@pytest.mark.with_device(["cpu", "gpu"])
 def test_keras(data_format):
     if data_format == "channels_last":
         return

From 6e02ca7ef7e1d01063946b6a42aa54ab6f9970cc Mon Sep 17 00:00:00 2001
From: Licht Takeuchi
Date: Fri, 30 Oct 2020 00:33:08 +0900
Subject: [PATCH 09/13] Register GPU kernel iff GOOGLE_CUDA is defined

---
 .../custom_ops/layers/cc/kernels/deformable_conv2d_op.cc | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/tensorflow_addons/custom_ops/layers/cc/kernels/deformable_conv2d_op.cc b/tensorflow_addons/custom_ops/layers/cc/kernels/deformable_conv2d_op.cc
index ddfceb5ca5..1f13180a3b 100644
--- a/tensorflow_addons/custom_ops/layers/cc/kernels/deformable_conv2d_op.cc
+++ b/tensorflow_addons/custom_ops/layers/cc/kernels/deformable_conv2d_op.cc
@@ -575,6 +575,8 @@ TF_CALL_double(REGISTER_DEFORMABLECONV2D_OP_CPU);
 #undef REGISTER_DEFORMABLECONV2D_OP_CPU
 
 // Register the GPU kernels.
+#if GOOGLE_CUDA
+
 #define REGISTER_DEFORMABLECONV2D_OP_GPU(T)                        \
   REGISTER_KERNEL_BUILDER(Name("Addons>DeformableConv2D")          \
                               .Device(DEVICE_GPU)                  \
@@ -589,5 +591,7 @@ TF_CALL_float(REGISTER_DEFORMABLECONV2D_OP_GPU);
 TF_CALL_double(REGISTER_DEFORMABLECONV2D_OP_GPU);
 #undef REGISTER_DEFORMABLECONV2D_OP_GPU
 
+#endif  // GOOGLE_CUDA
+
 }  // namespace addons
 }  // namespace tensorflow

From 99034d6ecdbe6d88b54e10a18f2d48e95f2ae058 Mon Sep 17 00:00:00 2001
From: Licht Takeuchi
Date: Fri, 30 Oct 2020 01:24:30 +0900
Subject: [PATCH 10/13] Declare extern template iff GOOGLE_CUDA is defined

---
 .../custom_ops/layers/cc/kernels/deformable_conv2d_op.cc | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/tensorflow_addons/custom_ops/layers/cc/kernels/deformable_conv2d_op.cc b/tensorflow_addons/custom_ops/layers/cc/kernels/deformable_conv2d_op.cc
index 1f13180a3b..197116d6bc 100644
--- a/tensorflow_addons/custom_ops/layers/cc/kernels/deformable_conv2d_op.cc
+++ b/tensorflow_addons/custom_ops/layers/cc/kernels/deformable_conv2d_op.cc
@@ -30,6 +30,7 @@ namespace addons {
 using CPUDevice = Eigen::ThreadPoolDevice;
 using GPUDevice = Eigen::GpuDevice;
 
+#if GOOGLE_CUDA
 #define EXTERN_TEMPLATE(T)                           \
   extern template Status Transpose<GPUDevice, T, 5>( \
       OpKernelContext * ctx, const Tensor &in,       \
       const gtl::ArraySlice<int32> perm, Tensor *out);
 TF_CALL_float(EXTERN_TEMPLATE);
@@ -37,15 +38,18 @@ using GPUDevice = Eigen::GpuDevice;
 TF_CALL_double(EXTERN_TEMPLATE);
 #undef EXTERN_TEMPLATE
+#endif  // GOOGLE_CUDA
 
 namespace functor {
 
+#if GOOGLE_CUDA
 #define EXTERN_TEMPLATE(T)                                              \
   extern template struct DeformableConv2DForwardFunctor<GPUDevice, T>;  \
   extern template struct DeformableConv2DGradFunctor<GPUDevice, T>;
 TF_CALL_float(EXTERN_TEMPLATE);
 TF_CALL_double(EXTERN_TEMPLATE);
 #undef EXTERN_TEMPLATE
+#endif  // GOOGLE_CUDA
 
 #define IM2COL(T)                                                  \
   template <>                                                      \

From e751f57a99990606652c68bdd45400c6484f6aec Mon Sep 17 00:00:00 2001
From: Licht Takeuchi
Date: Fri, 30 Oct 2020 18:03:54 +0900
Subject: [PATCH 11/13] Bug fix and add more tests

---
 .../layers/cc/kernels/deformable_conv2d_op.h  |   2 +-
 tensorflow_addons/layers/deformable_conv2d.py |   6 +-
 .../layers/tests/deformable_conv2d_test.py    | 223 ++++++++++++++++--
 3 files changed, 212 insertions(+), 19 deletions(-)

diff --git a/tensorflow_addons/custom_ops/layers/cc/kernels/deformable_conv2d_op.h b/tensorflow_addons/custom_ops/layers/cc/kernels/deformable_conv2d_op.h
index 205a6812b6..542e3edb25 100644
--- a/tensorflow_addons/custom_ops/layers/cc/kernels/deformable_conv2d_op.h
+++ b/tensorflow_addons/custom_ops/layers/cc/kernels/deformable_conv2d_op.h @@ -500,7 +500,7 @@ struct DeformableConv2DGradFunctor Tensor column_buffer_tensor_reshaped(column_buffer_tensor.dtype()); CHECK(column_buffer_tensor_reshaped.CopyFrom( - column_buffer_tensor, TensorShape({p.weight_groups, elems, cols}))); + column_buffer_tensor, TensorShape({p.weight_groups, cols, elems}))); Tensor matmul_lhs_tmp_tensor; OP_REQUIRES_OK(context, context->allocate_temp( diff --git a/tensorflow_addons/layers/deformable_conv2d.py b/tensorflow_addons/layers/deformable_conv2d.py index 68e16bc6ea..7d248d5aac 100644 --- a/tensorflow_addons/layers/deformable_conv2d.py +++ b/tensorflow_addons/layers/deformable_conv2d.py @@ -187,6 +187,7 @@ def __init__( self.filter_weights = None self.filter_bias = None + self.null_mask = None def _validate_shapes(self, shapes): if type(shapes) is not list: @@ -276,6 +277,9 @@ def build(self, shapes): else: self.filter_bias = tf.zeros((0,)) + if not self.use_mask: + self.null_mask = tf.zeros((0, 0, 0, 0)) + self.built = True def compute_output_shape(self, shapes): @@ -304,7 +308,7 @@ def compute_output_shape(self, shapes): def call(self, inputs, **kwargs): input_tensor = inputs[0] offset_tensor = inputs[1] - mask_tensor = inputs[2] if self.use_mask else tf.zeros((0, 0, 0, 0)) + mask_tensor = inputs[2] if self.use_mask else self.null_mask return _deformable_conv2d( input_tensor=tf.convert_to_tensor(input_tensor), diff --git a/tensorflow_addons/layers/tests/deformable_conv2d_test.py b/tensorflow_addons/layers/tests/deformable_conv2d_test.py index 6899468624..c7dbb5d922 100644 --- a/tensorflow_addons/layers/tests/deformable_conv2d_test.py +++ b/tensorflow_addons/layers/tests/deformable_conv2d_test.py @@ -18,7 +18,10 @@ import numpy as np import tensorflow as tf import tensorflow_addons.utils.keras_utils as conv_utils -from tensorflow_addons.layers.deformable_conv2d import DeformableConv2D +from tensorflow_addons.layers.deformable_conv2d import ( + DeformableConv2D, + _deformable_conv2d, +) def _get_padding_length( @@ -129,10 +132,13 @@ def _expected( input_channels_per_weight_groups = filter_tensor.shape[1] output_channels_per_weight_groups = output_channels // weight_groups - offset_tensor = offset_tensor.reshape((batches, -1, 2, output_rows, output_cols)) - output = np.zeros((batches, output_channels, output_rows, output_cols)) + if output.size == 0: + return output + + offset_tensor = offset_tensor.reshape((batches, -1, 2, output_rows, output_cols)) + for batch in range(batches): for output_channel in range(output_channels): for output_row in range(output_rows): @@ -167,9 +173,13 @@ def _expected( batch, offset_idx, 1, output_row, output_col ] - mask = mask_tensor[ - batch, offset_idx, output_row, output_col - ] + mask = ( + mask_tensor[ + batch, offset_idx, output_row, output_col + ] + if mask_tensor.size > 0 + else 1 + ) y = ( stride_rows * output_row @@ -203,24 +213,26 @@ def _expected( ) ) - output += bias.reshape((1, output_channels, 1, 1)) + if bias.size > 0: + output += bias.reshape((1, output_channels, 1, 1)) + return output @pytest.mark.with_device(["cpu", "gpu"]) @pytest.mark.usefixtures("maybe_run_functions_eagerly") -def test_forward_simple(data_format): +@pytest.mark.parametrize("padding", ["same", "valid"]) +@pytest.mark.parametrize("batches", [0, 1, 2]) +def test_forward(data_format, padding, batches): if data_format == "channels_last": return - batches = 1 input_channels = 6 filters = 2 weight_groups = 2 offset_groups = 3 strides = (2, 1) - padding = "same" 
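+    # padding and batches are supplied by the parametrize decorators above.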
dilation_rate = (2, 1) kernel_size = (3, 2) @@ -280,22 +292,98 @@ def test_forward_simple(data_format): dilation_rate, ) - np.testing.assert_allclose(actual.numpy(), expected) + np.testing.assert_allclose(actual.numpy(), expected, rtol=1e-5) @pytest.mark.with_device(["cpu", "gpu"]) -def test_gradients(data_format): +@pytest.mark.usefixtures("maybe_run_functions_eagerly") +@pytest.mark.parametrize("padding", ["same", "valid"]) +@pytest.mark.parametrize("batches", [0, 1, 2]) +def test_forward_no_mask(data_format, padding, batches): + if data_format == "channels_last": + return + + input_channels = 6 + filters = 2 + weight_groups = 2 + offset_groups = 3 + + strides = (2, 1) + dilation_rate = (2, 1) + kernel_size = (3, 2) + + input_rows, input_cols = 5, 4 + filter_rows, filter_cols = kernel_size + stride_rows, stride_cols = strides + dilation_rows, dilation_cols = dilation_rate + + output_rows = conv_utils.conv_output_length( + input_rows, + filter_rows, + padding=padding, + stride=stride_rows, + dilation=dilation_rows, + ) + output_cols = conv_utils.conv_output_length( + input_cols, + filter_cols, + padding=padding, + stride=stride_cols, + dilation=dilation_cols, + ) + + offsets = offset_groups * filter_rows * filter_cols + + input_tensor = tf.random.uniform([batches, input_channels, input_rows, input_cols]) + offset_tensor = tf.random.uniform([batches, 2 * offsets, output_rows, output_cols]) + + conv = DeformableConv2D( + filters=filters, + kernel_size=kernel_size, + strides=strides, + padding=padding, + dilation_rate=dilation_rate, + weight_groups=weight_groups, + offset_groups=offset_groups, + use_mask=False, + use_bias=True, + ) + + actual = conv([input_tensor, offset_tensor]) + + filter_tensor = conv.filter_weights + mask_tensor = conv.null_mask + bias = conv.filter_bias + + expected = _expected( + input_tensor, + filter_tensor, + offset_tensor, + mask_tensor, + bias, + strides, + weight_groups, + offset_groups, + padding, + dilation_rate, + ) + + np.testing.assert_allclose(actual.numpy(), expected, rtol=1e-5) + + +@pytest.mark.with_device(["cpu", "gpu"]) +@pytest.mark.parametrize("padding", ["same", "valid"]) +@pytest.mark.parametrize("batches", [0, 1, 2]) +def test_gradients(data_format, padding, batches): if data_format == "channels_last": return - batches = 1 input_channels = 6 filters = 2 weight_groups = 2 offset_groups = 3 strides = (2, 1) - padding = "same" dilation_rate = (2, 1) kernel_size = (3, 2) @@ -337,16 +425,117 @@ def test_gradients(data_format): use_bias=True, ) - def conv_fn(input_tensor, offset_tensor, mask_tensor): - return conv([input_tensor, offset_tensor, mask_tensor]) + conv.build([tf.shape(input_tensor), tf.shape(offset_tensor), tf.shape(mask_tensor)]) + + def conv_fn(input_tensor, filter_weights, filter_bias, offset_tensor, mask_tensor): + return _deformable_conv2d( + input_tensor=tf.convert_to_tensor(input_tensor), + filter_tensor=tf.convert_to_tensor(filter_weights), + bias_tensor=tf.convert_to_tensor(filter_bias), + offset_tensor=tf.convert_to_tensor(offset_tensor), + mask_tensor=tf.convert_to_tensor(mask_tensor), + strides=conv.strides, + weight_groups=conv.weight_groups, + offset_groups=conv.offset_groups, + padding="SAME" if conv.padding == "same" else "VALID", + dilations=conv.dilation_rate, + ) + + theoretical, numerical = tf.test.compute_gradient( + conv_fn, + [ + input_tensor, + conv.filter_weights, + conv.filter_bias, + offset_tensor, + mask_tensor, + ], + ) + + np.testing.assert_allclose(theoretical[0], numerical[0], atol=1e-3) + 
np.testing.assert_allclose(theoretical[1], numerical[1], atol=1e-3) + np.testing.assert_allclose(theoretical[2], numerical[2], atol=1e-3) + np.testing.assert_allclose(theoretical[3], numerical[3], atol=1e-3) + np.testing.assert_allclose(theoretical[4], numerical[4], atol=1e-3) + + +@pytest.mark.with_device(["cpu", "gpu"]) +@pytest.mark.parametrize("padding", ["same", "valid"]) +@pytest.mark.parametrize("batches", [0, 1, 2]) +def test_gradients_no_mask(data_format, padding, batches): + if data_format == "channels_last": + return + + input_channels = 6 + filters = 2 + weight_groups = 2 + offset_groups = 3 + + strides = (2, 1) + dilation_rate = (2, 1) + kernel_size = (3, 2) + + input_rows, input_cols = 5, 4 + filter_rows, filter_cols = kernel_size + stride_rows, stride_cols = strides + dilation_rows, dilation_cols = dilation_rate + + output_rows = conv_utils.conv_output_length( + input_rows, + filter_rows, + padding=padding, + stride=stride_rows, + dilation=dilation_rows, + ) + output_cols = conv_utils.conv_output_length( + input_cols, + filter_cols, + padding=padding, + stride=stride_cols, + dilation=dilation_cols, + ) + + offsets = offset_groups * filter_rows * filter_cols + + input_tensor = tf.random.uniform([batches, input_channels, input_rows, input_cols]) + offset_tensor = tf.random.uniform([batches, 2 * offsets, output_rows, output_cols]) + + conv = DeformableConv2D( + filters=filters, + kernel_size=kernel_size, + strides=strides, + padding=padding, + dilation_rate=dilation_rate, + weight_groups=weight_groups, + offset_groups=offset_groups, + use_mask=False, + use_bias=True, + ) + + conv.build([tf.shape(input_tensor), tf.shape(offset_tensor)]) + + def conv_fn(input_tensor, filter_weights, filter_bias, offset_tensor): + return _deformable_conv2d( + input_tensor=tf.convert_to_tensor(input_tensor), + filter_tensor=tf.convert_to_tensor(filter_weights), + bias_tensor=tf.convert_to_tensor(filter_bias), + offset_tensor=tf.convert_to_tensor(offset_tensor), + mask_tensor=tf.convert_to_tensor(conv.null_mask), + strides=conv.strides, + weight_groups=conv.weight_groups, + offset_groups=conv.offset_groups, + padding="SAME" if conv.padding == "same" else "VALID", + dilations=conv.dilation_rate, + ) theoretical, numerical = tf.test.compute_gradient( - conv_fn, [input_tensor, offset_tensor, mask_tensor] + conv_fn, [input_tensor, conv.filter_weights, conv.filter_bias, offset_tensor] ) np.testing.assert_allclose(theoretical[0], numerical[0], atol=1e-3) np.testing.assert_allclose(theoretical[1], numerical[1], atol=1e-3) np.testing.assert_allclose(theoretical[2], numerical[2], atol=1e-3) + np.testing.assert_allclose(theoretical[3], numerical[3], atol=1e-3) @pytest.mark.with_device(["cpu", "gpu"]) From da64fb4760adcd4f42b642668dd02846257230c5 Mon Sep 17 00:00:00 2001 From: Licht Takeuchi Date: Mon, 2 Nov 2020 15:24:37 +0900 Subject: [PATCH 12/13] Parallelize CPU kernel --- .../layers/cc/kernels/deformable_conv2d_op.cc | 423 ++++++++++-------- 1 file changed, 228 insertions(+), 195 deletions(-) diff --git a/tensorflow_addons/custom_ops/layers/cc/kernels/deformable_conv2d_op.cc b/tensorflow_addons/custom_ops/layers/cc/kernels/deformable_conv2d_op.cc index 197116d6bc..fb5b250ac6 100644 --- a/tensorflow_addons/custom_ops/layers/cc/kernels/deformable_conv2d_op.cc +++ b/tensorflow_addons/custom_ops/layers/cc/kernels/deformable_conv2d_op.cc @@ -21,6 +21,9 @@ #include "tensorflow_addons/custom_ops/layers/cc/kernels/deformable_conv2d_op.h" +#include +#include + #include 
"tensorflow/core/framework/common_shape_fns.h" #include "tensorflow/core/framework/kernel_shape_util.h" @@ -68,54 +71,63 @@ TF_CALL_double(EXTERN_TEMPLATE); \ auto column_buffer_eigen_tensor = column_buffer_tensor.tensor(); \ \ - for (auto k = 0; k < num_kernels; k++) { \ - const auto current_output_col = k % p.output_cols; \ - const auto current_output_row = (k / p.output_cols) % p.output_rows; \ - const auto current_batch = \ - (k / (p.output_rows * p.output_cols)) % p.parallel_imgs; \ - const auto current_input_channel = \ - k / (p.output_rows * p.output_cols * p.parallel_imgs); \ - const auto current_output_channel = \ - current_input_channel * p.filter_rows * p.filter_cols; \ - \ - const auto current_actual_batch = b * p.parallel_imgs + current_batch; \ - \ - const auto group_index = \ - current_input_channel / (p.input_channels / p.offset_groups); \ - \ - auto column_buffer_tensor_channel = current_output_channel; \ - for (auto current_filter_row = 0; current_filter_row < p.filter_rows; \ - current_filter_row++) { \ - for (auto current_filter_col = 0; current_filter_col < p.filter_cols; \ - current_filter_col++) { \ - auto offset_h = offset_eigen_tensor( \ - b, current_batch, group_index, current_filter_row, \ - current_filter_col, 0, current_output_row, current_output_col); \ - auto offset_w = offset_eigen_tensor( \ - b, current_batch, group_index, current_filter_row, \ - current_filter_col, 1, current_output_row, current_output_col); \ - \ - auto mask = p.use_mask ? mask_eigen_tensor( \ - b, current_batch, group_index, \ - current_filter_row, current_filter_col, \ - current_output_row, current_output_col) \ - : T(1); \ - \ - auto y = (current_output_row * p.stride_rows - p.padding_rows) + \ - current_filter_row * p.dilation_rows + offset_h; \ - auto x = (current_output_col * p.stride_cols - p.padding_cols) + \ - current_filter_col * p.dilation_cols + offset_w; \ - \ - column_buffer_eigen_tensor(column_buffer_tensor_channel, \ - current_batch, current_output_row, \ - current_output_col) = \ - mask * BilinearInterpolate(input_eigen_tensor, b, \ - current_actual_batch, \ - current_input_channel, y, x); \ - column_buffer_tensor_channel++; \ + const auto cost = p.filter_rows * p.filter_cols; \ + const auto work = [&](Eigen::Index start, Eigen::Index end) -> void { \ + for (Eigen::Index k = start; k < end; ++k) { \ + const auto current_output_col = k % p.output_cols; \ + const auto current_output_row = (k / p.output_cols) % p.output_rows; \ + const auto current_batch = \ + (k / (p.output_rows * p.output_cols)) % p.parallel_imgs; \ + const auto current_input_channel = \ + k / (p.output_rows * p.output_cols * p.parallel_imgs); \ + const auto current_output_channel = \ + current_input_channel * p.filter_rows * p.filter_cols; \ + \ + const auto current_actual_batch = b * p.parallel_imgs + current_batch; \ + \ + const auto group_index = \ + current_input_channel / (p.input_channels / p.offset_groups); \ + \ + auto column_buffer_tensor_channel = current_output_channel; \ + for (auto current_filter_row = 0; current_filter_row < p.filter_rows; \ + current_filter_row++) { \ + for (auto current_filter_col = 0; \ + current_filter_col < p.filter_cols; current_filter_col++) { \ + auto offset_h = \ + offset_eigen_tensor(b, current_batch, group_index, \ + current_filter_row, current_filter_col, 0, \ + current_output_row, current_output_col); \ + auto offset_w = \ + offset_eigen_tensor(b, current_batch, group_index, \ + current_filter_row, current_filter_col, 1, \ + current_output_row, 
current_output_col); \ + \ + auto mask = p.use_mask \ + ? mask_eigen_tensor( \ + b, current_batch, group_index, \ + current_filter_row, current_filter_col, \ + current_output_row, current_output_col) \ + : T(1); \ + \ + auto y = (current_output_row * p.stride_rows - p.padding_rows) + \ + current_filter_row * p.dilation_rows + offset_h; \ + auto x = (current_output_col * p.stride_cols - p.padding_cols) + \ + current_filter_col * p.dilation_cols + offset_w; \ + \ + column_buffer_eigen_tensor(column_buffer_tensor_channel, \ + current_batch, current_output_row, \ + current_output_col) = \ + mask * BilinearInterpolate(input_eigen_tensor, b, \ + current_actual_batch, \ + current_input_channel, y, x); \ + column_buffer_tensor_channel++; \ + } \ } \ } \ - } \ + }; \ + auto thread_pool = \ + context->device()->tensorflow_cpu_worker_threads()->workers; \ + thread_pool->ParallelFor(num_kernels, cost, work); \ } TF_CALL_float(IM2COL); TF_CALL_double(IM2COL); @@ -146,105 +158,114 @@ TF_CALL_double(IM2COL); \ const auto input_eigen_tensor = input_tensor.tensor(); \ \ - for (auto k = 0; k < num_kernels; k++) { \ - auto offset_grad_value = T(0); \ - auto mask_grad_value = T(0); \ - \ - const auto offset_channels = \ - 2 * p.filter_rows * p.filter_cols * p.offset_groups; \ - \ - const auto offset_channel_step = p.filter_rows * p.filter_cols; \ - \ - const auto current_output_col = k % p.output_cols; \ - const auto current_output_row = (k / p.output_cols) % p.output_rows; \ - const auto current_filter_col = \ - (k / (2 * p.output_rows * p.output_cols)) % p.filter_cols; \ - const auto current_filter_row = \ - (k / (2 * p.output_rows * p.output_cols * p.filter_cols)) % \ - p.filter_rows; \ - const auto current_offset_channel = \ - (k / (p.output_rows * p.output_cols)) % offset_channels; \ - const auto current_batch = \ - k / (p.output_rows * p.output_cols * offset_channels); \ - \ - const auto current_actual_batch = b * p.parallel_imgs + current_batch; \ - \ - const auto current_offset_group = \ - current_offset_channel / (2 * offset_channel_step); \ - \ - const auto channels_per_offset_group = \ - p.input_channels / p.offset_groups; \ - const auto offset_channel_diff = \ - current_offset_channel - \ - current_offset_group * 2 * offset_channel_step; \ - const auto is_y_direction = offset_channel_diff % 2 == 0; \ - \ - for (auto selected_offset_channel = (offset_channel_diff / 2); \ - selected_offset_channel < \ - channels_per_offset_group * offset_channel_step; \ - selected_offset_channel += offset_channel_step) { \ - const auto selected_filter_col = \ - selected_offset_channel % p.filter_cols; \ - const auto selected_filter_row = \ - (selected_offset_channel / p.filter_cols) % p.filter_rows; \ - const auto input_channel_diff = \ - (selected_offset_channel / (p.filter_cols * p.filter_rows)); \ - \ - const auto offset_h = offset_eigen_tensor( \ - b, current_batch, current_offset_group, selected_filter_row, \ - selected_filter_col, 0, current_output_row, current_output_col); \ - const auto offset_w = offset_eigen_tensor( \ - b, current_batch, current_offset_group, selected_filter_row, \ - selected_filter_col, 1, current_output_row, current_output_col); \ - const auto mask = \ - p.use_mask \ - ? 
@@ -146,105 +158,114 @@ TF_CALL_double(IM2COL);
 \
     const auto input_eigen_tensor = input_tensor.tensor(); \
 \
-    for (auto k = 0; k < num_kernels; k++) { \
-      auto offset_grad_value = T(0); \
-      auto mask_grad_value = T(0); \
- \
-      const auto offset_channels = \
-          2 * p.filter_rows * p.filter_cols * p.offset_groups; \
- \
-      const auto offset_channel_step = p.filter_rows * p.filter_cols; \
- \
-      const auto current_output_col = k % p.output_cols; \
-      const auto current_output_row = (k / p.output_cols) % p.output_rows; \
-      const auto current_filter_col = \
-          (k / (2 * p.output_rows * p.output_cols)) % p.filter_cols; \
-      const auto current_filter_row = \
-          (k / (2 * p.output_rows * p.output_cols * p.filter_cols)) % \
-          p.filter_rows; \
-      const auto current_offset_channel = \
-          (k / (p.output_rows * p.output_cols)) % offset_channels; \
-      const auto current_batch = \
-          k / (p.output_rows * p.output_cols * offset_channels); \
- \
-      const auto current_actual_batch = b * p.parallel_imgs + current_batch; \
- \
-      const auto current_offset_group = \
-          current_offset_channel / (2 * offset_channel_step); \
- \
-      const auto channels_per_offset_group = \
-          p.input_channels / p.offset_groups; \
-      const auto offset_channel_diff = \
-          current_offset_channel - \
-          current_offset_group * 2 * offset_channel_step; \
-      const auto is_y_direction = offset_channel_diff % 2 == 0; \
- \
-      for (auto selected_offset_channel = (offset_channel_diff / 2); \
-           selected_offset_channel < \
-           channels_per_offset_group * offset_channel_step; \
-           selected_offset_channel += offset_channel_step) { \
-        const auto selected_filter_col = \
-            selected_offset_channel % p.filter_cols; \
-        const auto selected_filter_row = \
-            (selected_offset_channel / p.filter_cols) % p.filter_rows; \
-        const auto input_channel_diff = \
-            (selected_offset_channel / (p.filter_cols * p.filter_rows)); \
- \
-        const auto offset_h = offset_eigen_tensor( \
-            b, current_batch, current_offset_group, selected_filter_row, \
-            selected_filter_col, 0, current_output_row, current_output_col); \
-        const auto offset_w = offset_eigen_tensor( \
-            b, current_batch, current_offset_group, selected_filter_row, \
-            selected_filter_col, 1, current_output_row, current_output_col); \
-        const auto mask = \
-            p.use_mask \
-                ? mask_eigen_tensor(b, current_batch, current_offset_group, \
-                                    selected_filter_row, selected_filter_col, \
-                                    current_output_row, current_output_col) \
-                : T(1); \
- \
-        const auto y = (current_output_row * p.stride_rows - p.padding_rows) + \
-                       selected_filter_row * p.dilation_rows + offset_h; \
-        const auto x = (current_output_col * p.stride_cols - p.padding_cols) + \
-                       selected_filter_col * p.dilation_cols + offset_w; \
- \
-        const auto selected_input_channel = \
-            input_channel_diff + \
-            current_offset_group * channels_per_offset_group; \
- \
-        const auto filter_data = column_buffer_eigen_tensor( \
-            selected_input_channel, selected_filter_row, selected_filter_col, \
-            current_batch, current_output_row, current_output_col); \
- \
-        const auto weight = GetCoordinateWeight( \
-            input_eigen_tensor, b, current_actual_batch, \
-            selected_input_channel, y, x, is_y_direction); \
- \
-        offset_grad_value += mask * weight * filter_data; \
- \
-        if (is_y_direction) { \
-          mask_grad_value += \
-              filter_data * BilinearInterpolate( \
-                                input_eigen_tensor, b, current_actual_batch, \
-                                selected_input_channel, y, x); \
+    const auto cost = p.input_channels / p.offset_groups; \
+    const auto work = [&](Eigen::Index start, Eigen::Index end) -> void { \
+      for (Eigen::Index k = start; k < end; ++k) { \
+        auto offset_grad_value = T(0); \
+        auto mask_grad_value = T(0); \
+ \
+        const auto offset_channels = \
+            2 * p.filter_rows * p.filter_cols * p.offset_groups; \
+ \
+        const auto offset_channel_step = p.filter_rows * p.filter_cols; \
+ \
+        const auto current_output_col = k % p.output_cols; \
+        const auto current_output_row = (k / p.output_cols) % p.output_rows; \
+        const auto current_filter_col = \
+            (k / (2 * p.output_rows * p.output_cols)) % p.filter_cols; \
+        const auto current_filter_row = \
+            (k / (2 * p.output_rows * p.output_cols * p.filter_cols)) % \
+            p.filter_rows; \
+        const auto current_offset_channel = \
+            (k / (p.output_rows * p.output_cols)) % offset_channels; \
+        const auto current_batch = \
+            k / (p.output_rows * p.output_cols * offset_channels); \
+ \
+        const auto current_actual_batch = b * p.parallel_imgs + current_batch; \
+ \
+        const auto current_offset_group = \
+            current_offset_channel / (2 * offset_channel_step); \
+ \
+        const auto channels_per_offset_group = \
+            p.input_channels / p.offset_groups; \
+        const auto offset_channel_diff = \
+            current_offset_channel - \
+            current_offset_group * 2 * offset_channel_step; \
+        const auto is_y_direction = offset_channel_diff % 2 == 0; \
+ \
+        for (auto selected_offset_channel = (offset_channel_diff / 2); \
+             selected_offset_channel < \
+             channels_per_offset_group * offset_channel_step; \
+             selected_offset_channel += offset_channel_step) { \
+          const auto selected_filter_col = \
+              selected_offset_channel % p.filter_cols; \
+          const auto selected_filter_row = \
+              (selected_offset_channel / p.filter_cols) % p.filter_rows; \
+          const auto input_channel_diff = \
+              (selected_offset_channel / (p.filter_cols * p.filter_rows)); \
+ \
+          const auto offset_h = offset_eigen_tensor( \
+              b, current_batch, current_offset_group, selected_filter_row, \
+              selected_filter_col, 0, current_output_row, current_output_col); \
+          const auto offset_w = offset_eigen_tensor( \
+              b, current_batch, current_offset_group, selected_filter_row, \
+              selected_filter_col, 1, current_output_row, current_output_col); \
+          const auto mask = \
+              p.use_mask ? mask_eigen_tensor( \
+                               b, current_batch, current_offset_group, \
+                               selected_filter_row, selected_filter_col, \
+                               current_output_row, current_output_col) \
+                         : T(1); \
+ \
+          const auto y = \
+              (current_output_row * p.stride_rows - p.padding_rows) + \
+              selected_filter_row * p.dilation_rows + offset_h; \
+          const auto x = \
+              (current_output_col * p.stride_cols - p.padding_cols) + \
+              selected_filter_col * p.dilation_cols + offset_w; \
+ \
+          const auto selected_input_channel = \
+              input_channel_diff + \
+              current_offset_group * channels_per_offset_group; \
+ \
+          const auto filter_data = column_buffer_eigen_tensor( \
+              selected_input_channel, selected_filter_row, \
+              selected_filter_col, current_batch, current_output_row, \
+              current_output_col); \
+ \
+          const auto weight = GetCoordinateWeight( \
+              input_eigen_tensor, b, current_actual_batch, \
+              selected_input_channel, y, x, is_y_direction); \
+ \
+          offset_grad_value += mask * weight * filter_data; \
+ \
+          if (is_y_direction) { \
+            mask_grad_value += \
+                filter_data * BilinearInterpolate( \
+                                  input_eigen_tensor, b, current_actual_batch, \
+                                  selected_input_channel, y, x); \
+          } \
         } \
-      } \
 \
-      offset_grad_eigen_tensor(current_actual_batch, current_offset_channel, \
-                               current_output_row, current_output_col) = \
-          offset_grad_value; \
+        offset_grad_eigen_tensor(current_actual_batch, current_offset_channel, \
+                                 current_output_row, current_output_col) = \
+            offset_grad_value; \
 \
-      if (p.use_mask && is_y_direction) { \
-        const auto current_mask_channel = \
-            (current_offset_group * p.filter_rows + current_filter_row) * \
-                p.filter_cols + \
-            current_filter_col; \
+        if (p.use_mask && is_y_direction) { \
+          const auto current_mask_channel = \
+              (current_offset_group * p.filter_rows + current_filter_row) * \
+                  p.filter_cols + \
+              current_filter_col; \
 \
-        mask_grad_eigen_tensor(current_actual_batch, current_mask_channel, \
-                               current_output_row, current_output_col) = \
-            mask_grad_value; \
+          mask_grad_eigen_tensor(current_actual_batch, current_mask_channel, \
+                                 current_output_row, current_output_col) = \
+              mask_grad_value; \
+        } \
       } \
-    } \
+    }; \
+    auto thread_pool = \
+        context->device()->tensorflow_cpu_worker_threads()->workers; \
+    thread_pool->ParallelFor(num_kernels, cost, work); \
   }
 
 TF_CALL_float(COL2IM_OFFSET_AND_MASK);
 TF_CALL_double(COL2IM_OFFSET_AND_MASK);
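As in the im2col hunk, each flat index `k` is decoded back into coordinates by a chain of `%` and `/` in the same dimension order that produced `num_kernels`; every `k` therefore owns exactly one `(batch, offset_channel, row, col)` output cell, which is why this gradient kernel, unlike the input-gradient one below, needs no locking. The decoding in isolation (a sketch; `Coords` and the dimension names are illustrative, not from the patch):

#include <cstdint>

struct Coords {
  int64_t col, row, batch, channel;
};

// Inverse of flat = ((channel * batches + batch) * rows + row) * cols + col:
// col varies fastest, channel slowest -- the same order the kernels use.
Coords DecodeFlatIndex(int64_t flat, int64_t cols, int64_t rows,
                       int64_t batches) {
  return Coords{
      flat % cols,                        // col
      (flat / cols) % rows,               // row
      (flat / (cols * rows)) % batches,   // batch
      flat / (cols * rows * batches),     // channel
  };
}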
@@ -269,67 +290,79 @@ TF_CALL_double(COL2IM_OFFSET_AND_MASK);
 \
     auto input_grad_eigen_tensor = input_grad_tensor.tensor<T, 4>(); \
 \
-    for (auto k = 0; k < num_kernels; k++) { \
-      const auto current_output_col = k % p.output_cols; \
-      const auto current_output_row = (k / p.output_cols) % p.output_rows; \
-      const auto current_batch = \
-          (k / (p.output_rows * p.output_cols)) % p.parallel_imgs; \
- \
-      const auto current_filter_col = \
-          (k / (p.output_rows * p.output_cols * p.parallel_imgs)) % \
-          p.filter_cols; \
-      const auto current_filter_row = \
-          (k / (p.output_rows * p.output_cols * p.parallel_imgs * \
-                p.filter_cols)) % \
-          p.filter_rows; \
-      const auto current_channel = \
-          k / (p.output_rows * p.output_cols * p.parallel_imgs * \
-               p.filter_rows * p.filter_cols); \
- \
-      const auto current_offset_group = \
-          current_channel / (p.input_channels / p.offset_groups); \
- \
-      const auto mask = \
-          p.use_mask \
-              ? mask_eigen_tensor(b, current_batch, current_offset_group, \
-                                  current_filter_row, current_filter_col, \
-                                  current_output_row, current_output_col) \
-              : T(1); \
- \
-      const auto offset_h = offset_eigen_tensor( \
-          b, current_batch, current_offset_group, current_filter_row, \
-          current_filter_col, 0, current_output_row, current_output_col); \
-      const auto offset_w = offset_eigen_tensor( \
-          b, current_batch, current_offset_group, current_filter_row, \
-          current_filter_col, 1, current_output_row, current_output_col); \
- \
-      const auto y = (current_output_row * p.stride_rows - p.padding_rows) + \
-                     current_filter_row * p.dilation_rows + offset_h; \
-      const auto x = (current_output_col * p.stride_cols - p.padding_cols) + \
-                     current_filter_col * p.dilation_cols + offset_w; \
+    const auto cost = 3 * 3; \
+    std::array<std::mutex, 100> mutex_array; \
+ \
+    const auto work = [&](Eigen::Index start, Eigen::Index end) -> void { \
+      for (Eigen::Index k = start; k < end; ++k) { \
+        const auto current_output_col = k % p.output_cols; \
+        const auto current_output_row = (k / p.output_cols) % p.output_rows; \
+        const auto current_batch = \
+            (k / (p.output_rows * p.output_cols)) % p.parallel_imgs; \
+ \
+        const auto current_filter_col = \
+            (k / (p.output_rows * p.output_cols * p.parallel_imgs)) % \
+            p.filter_cols; \
+        const auto current_filter_row = \
+            (k / (p.output_rows * p.output_cols * p.parallel_imgs * \
+                  p.filter_cols)) % \
+            p.filter_rows; \
+        const auto current_channel = \
+            k / (p.output_rows * p.output_cols * p.parallel_imgs * \
+                 p.filter_rows * p.filter_cols); \
+ \
+        const auto current_offset_group = \
+            current_channel / (p.input_channels / p.offset_groups); \
 \
-      for (auto dy = -1; dy <= 1; dy++) { \
-        for (auto dx = -1; dx <= 1; dx++) { \
-          const auto current_input_row = int(y) + dy; \
-          const auto current_input_col = int(x) + dx; \
- \
-          if (p.input_rows > current_input_row && current_input_row >= 0 && \
-              p.input_cols > current_input_col && current_input_col >= 0 && \
-              std::abs(y - current_input_row) < 1 && \
-              std::abs(x - current_input_col) < 1) { \
-            const auto weight = (1.0 - std::abs(y - current_input_row)) * \
-                                (1.0 - std::abs(x - current_input_col)); \
+        const auto mask = \
+            p.use_mask \
+                ? mask_eigen_tensor(b, current_batch, current_offset_group, \
+                                    current_filter_row, current_filter_col, \
+                                    current_output_row, current_output_col) \
+                : T(1); \
 \
-            const auto current_actual_batch = \
-                b * p.parallel_imgs + current_batch; \
+        const auto offset_h = offset_eigen_tensor( \
+            b, current_batch, current_offset_group, current_filter_row, \
+            current_filter_col, 0, current_output_row, current_output_col); \
+        const auto offset_w = offset_eigen_tensor( \
+            b, current_batch, current_offset_group, current_filter_row, \
+            current_filter_col, 1, current_output_row, current_output_col); \
 \
-            input_grad_eigen_tensor(current_actual_batch, current_channel, \
-                                    current_input_row, current_input_col) += \
-                mask * weight * column_buffer_tensor_flattened(k); \
+        const auto y = (current_output_row * p.stride_rows - p.padding_rows) + \
+                       current_filter_row * p.dilation_rows + offset_h; \
+        const auto x = (current_output_col * p.stride_cols - p.padding_cols) + \
+                       current_filter_col * p.dilation_cols + offset_w; \
+ \
+        for (auto dy = -1; dy <= 1; dy++) { \
+          for (auto dx = -1; dx <= 1; dx++) { \
+            const auto current_input_row = int(y) + dy; \
+            const auto current_input_col = int(x) + dx; \
+ \
+            if (p.input_rows > current_input_row && current_input_row >= 0 && \
+                p.input_cols > current_input_col && current_input_col >= 0 && \
+                std::abs(y - current_input_row) < 1 && \
+                std::abs(x - current_input_col) < 1) { \
+              const auto weight = (1.0 - std::abs(y - current_input_row)) * \
+                                  (1.0 - std::abs(x - current_input_col)); \
+ \
+              const auto current_actual_batch = \
+                  b * p.parallel_imgs + current_batch; \
+ \
+              std::lock_guard<std::mutex> lock( \
+                  mutex_array[(current_input_row * p.input_cols + \
+                               current_input_col) % \
+                              100]); \
+              input_grad_eigen_tensor(current_actual_batch, current_channel, \
+                                      current_input_row, current_input_col) += \
+                  mask * weight * column_buffer_tensor_flattened(k); \
+            } \
          } \
        } \
      } \
-    } \
+    }; \
+    auto thread_pool = \
+        context->device()->tensorflow_cpu_worker_threads()->workers; \
+    thread_pool->ParallelFor(num_kernels, cost, work); \
   }
 
 TF_CALL_float(COL2IM_INPUT);
 TF_CALL_double(COL2IM_INPUT);
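The input-gradient kernel is the one pass where shards can collide: several `k` scatter-add into the same input pixel, so the hunk above guards each `+=` with one of 100 mutexes selected by the pixel's flat index (the `cost = 3 * 3` estimate matches the 3x3 bilinear neighborhood each iteration touches). The striped-lock idea in isolation (a sketch; `kStripes`, `ScatterAdd`, and the row-major buffer are illustrative, and the stripe array is static here where the patch keeps it local to the kernel):

#include <array>
#include <cstdint>
#include <mutex>

constexpr int kStripes = 100;  // same stripe count the patch uses
static std::array<std::mutex, kStripes> stripes;

// Race-free += into a shared buffer: updates to the same pixel always map to
// the same stripe and serialize, while distinct pixels usually land on
// distinct stripes and proceed in parallel.
void ScatterAdd(float* grad, int64_t row, int64_t col, int64_t num_cols,
                float value) {
  const int64_t flat = row * num_cols + col;
  std::lock_guard<std::mutex> lock(stripes[flat % kStripes]);
  grad[flat] += value;
}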
From 310f04642fed063e47b80f09b87c204eb696dcca Mon Sep 17 00:00:00 2001
From: Licht Takeuchi
Date: Sat, 12 Jun 2021 01:22:52 +0900
Subject: [PATCH 13/13] Now works on the latest TensorFlow

---
 .../custom_ops/layers/cc/kernels/deformable_conv2d_op.cc | 3 ++-
 .../custom_ops/layers/cc/kernels/deformable_conv2d_op.h  | 3 ++-
 2 files changed, 4 insertions(+), 2 deletions(-)

diff --git a/tensorflow_addons/custom_ops/layers/cc/kernels/deformable_conv2d_op.cc b/tensorflow_addons/custom_ops/layers/cc/kernels/deformable_conv2d_op.cc
index fb5b250ac6..06f8cdd629 100644
--- a/tensorflow_addons/custom_ops/layers/cc/kernels/deformable_conv2d_op.cc
+++ b/tensorflow_addons/custom_ops/layers/cc/kernels/deformable_conv2d_op.cc
@@ -22,10 +22,11 @@
 #include "tensorflow_addons/custom_ops/layers/cc/kernels/deformable_conv2d_op.h"
 
 #include
-#include
 
 #include "tensorflow/core/framework/common_shape_fns.h"
 #include "tensorflow/core/framework/kernel_shape_util.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/util/work_sharder.h"
 
 namespace tensorflow {
 namespace addons {
diff --git a/tensorflow_addons/custom_ops/layers/cc/kernels/deformable_conv2d_op.h b/tensorflow_addons/custom_ops/layers/cc/kernels/deformable_conv2d_op.h
index 542e3edb25..44cf11e304 100644
--- a/tensorflow_addons/custom_ops/layers/cc/kernels/deformable_conv2d_op.h
+++ b/tensorflow_addons/custom_ops/layers/cc/kernels/deformable_conv2d_op.h
@@ -17,7 +17,8 @@
 #define TENSORFLOW_ADDONS_LAYERS_KERNELS_DEFORMABLECONV2D_OP_H_
 
 #include "tensorflow/core/framework/op_kernel.h"
-#include "tensorflow/core/kernels/batch_matmul_op_impl.h"
+#include "tensorflow/core/kernels/matmul_op_impl.h"
+#include "tensorflow/core/util/matmul_bcast.h"
 #include "tensorflow/core/util/tensor_format.h"
 
 namespace tensorflow {