diff --git a/tensorflow_addons/custom_ops/layers/BUILD b/tensorflow_addons/custom_ops/layers/BUILD index 44d877105d..539e9d9146 100644 --- a/tensorflow_addons/custom_ops/layers/BUILD +++ b/tensorflow_addons/custom_ops/layers/BUILD @@ -33,3 +33,19 @@ custom_op_library( "cc/kernels/embedding_bag_backward_kernels.cu.cc", ], ) + +custom_op_library( + name = "_deformable_conv2d_ops.so", + srcs = [ + "cc/kernels/deformable_conv2d_op.cc", + "cc/kernels/deformable_conv2d_op.h", + "cc/ops/deformable_conv2d_op.cc", + ], + cuda_deps = [ + "@cub_archive//:cub", + ], + cuda_srcs = [ + "cc/kernels/deformable_conv2d_op.h", + "cc/kernels/deformable_conv2d_op_gpu.cu.cc", + ], +) diff --git a/tensorflow_addons/custom_ops/layers/cc/kernels/deformable_conv2d_op.cc b/tensorflow_addons/custom_ops/layers/cc/kernels/deformable_conv2d_op.cc new file mode 100644 index 0000000000..06f8cdd629 --- /dev/null +++ b/tensorflow_addons/custom_ops/layers/cc/kernels/deformable_conv2d_op.cc @@ -0,0 +1,635 @@ +// Copyright 2020 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +#define EIGEN_USE_THREADS + +#if GOOGLE_CUDA +#define EIGEN_USE_GPU +#endif // GOOGLE_CUDA + +#include "tensorflow_addons/custom_ops/layers/cc/kernels/deformable_conv2d_op.h" + +#include + +#include "tensorflow/core/framework/common_shape_fns.h" +#include "tensorflow/core/framework/kernel_shape_util.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/util/work_sharder.h" + +namespace tensorflow { +namespace addons { + +using CPUDevice = Eigen::ThreadPoolDevice; +using GPUDevice = Eigen::GpuDevice; + +#if GOOGLE_CUDA +#define EXTERN_TEMPLATE(T) \ + extern template Status Transpose( \ + OpKernelContext * ctx, const Tensor &in, \ + const gtl::ArraySlice perm, Tensor *out); +TF_CALL_float(EXTERN_TEMPLATE); +TF_CALL_double(EXTERN_TEMPLATE); +#undef EXTERN_TEMPLATE +#endif // GOOGLE_CUDA + +namespace functor { + +#if GOOGLE_CUDA +#define EXTERN_TEMPLATE(T) \ + extern template struct DeformableConv2DForwardFunctor; \ + extern template struct DeformableConv2DGradFunctor; +TF_CALL_float(EXTERN_TEMPLATE); +TF_CALL_double(EXTERN_TEMPLATE); +#undef EXTERN_TEMPLATE +#endif // GOOGLE_CUDA + +#define IM2COL(T) \ + template <> \ + void DeformableConv2DFunctorBase::DeformableIm2Col( \ + OpKernelContext *context, int32 b) { \ + auto num_kernels = \ + p.input_channels * p.output_rows * p.output_cols * p.parallel_imgs; \ + \ + const auto offset_eigen_tensor = offset_tensor.tensor(); \ + \ + const auto mask_eigen_tensor = \ + p.use_mask ? 
mask_tensor.tensor() \ + : mask_tensor.shaped({0, 0, 0, 0, 0, 0, 0}); \ + \ + const auto input_eigen_tensor = input_tensor.tensor(); \ + \ + auto column_buffer_eigen_tensor = column_buffer_tensor.tensor(); \ + \ + const auto cost = p.filter_rows * p.filter_cols; \ + const auto work = [&](Eigen::Index start, Eigen::Index end) -> void { \ + for (Eigen::Index k = start; k < end; ++k) { \ + const auto current_output_col = k % p.output_cols; \ + const auto current_output_row = (k / p.output_cols) % p.output_rows; \ + const auto current_batch = \ + (k / (p.output_rows * p.output_cols)) % p.parallel_imgs; \ + const auto current_input_channel = \ + k / (p.output_rows * p.output_cols * p.parallel_imgs); \ + const auto current_output_channel = \ + current_input_channel * p.filter_rows * p.filter_cols; \ + \ + const auto current_actual_batch = b * p.parallel_imgs + current_batch; \ + \ + const auto group_index = \ + current_input_channel / (p.input_channels / p.offset_groups); \ + \ + auto column_buffer_tensor_channel = current_output_channel; \ + for (auto current_filter_row = 0; current_filter_row < p.filter_rows; \ + current_filter_row++) { \ + for (auto current_filter_col = 0; \ + current_filter_col < p.filter_cols; current_filter_col++) { \ + auto offset_h = \ + offset_eigen_tensor(b, current_batch, group_index, \ + current_filter_row, current_filter_col, 0, \ + current_output_row, current_output_col); \ + auto offset_w = \ + offset_eigen_tensor(b, current_batch, group_index, \ + current_filter_row, current_filter_col, 1, \ + current_output_row, current_output_col); \ + \ + auto mask = p.use_mask \ + ? mask_eigen_tensor( \ + b, current_batch, group_index, \ + current_filter_row, current_filter_col, \ + current_output_row, current_output_col) \ + : T(1); \ + \ + auto y = (current_output_row * p.stride_rows - p.padding_rows) + \ + current_filter_row * p.dilation_rows + offset_h; \ + auto x = (current_output_col * p.stride_cols - p.padding_cols) + \ + current_filter_col * p.dilation_cols + offset_w; \ + \ + column_buffer_eigen_tensor(column_buffer_tensor_channel, \ + current_batch, current_output_row, \ + current_output_col) = \ + mask * BilinearInterpolate(input_eigen_tensor, b, \ + current_actual_batch, \ + current_input_channel, y, x); \ + column_buffer_tensor_channel++; \ + } \ + } \ + } \ + }; \ + auto thread_pool = \ + context->device()->tensorflow_cpu_worker_threads()->workers; \ + thread_pool->ParallelFor(num_kernels, cost, work); \ + } +TF_CALL_float(IM2COL); +TF_CALL_double(IM2COL); +#undef IM2COL + +#define COL2IM_OFFSET_AND_MASK(T) \ + template <> \ + void \ + DeformableConv2DGradFunctor::DeformableCol2ImForOffsetAndMask( \ + OpKernelContext *context, int32 b) { \ + const auto num_kernels = p.output_rows * p.output_cols * 2 * \ + p.filter_rows * p.filter_cols * p.offset_groups * \ + p.parallel_imgs; \ + \ + const auto offset_eigen_tensor = offset_tensor.template tensor(); \ + \ + const auto mask_eigen_tensor = \ + p.use_mask ? 
mask_tensor.template tensor() \ + : mask_tensor.template shaped({0, 0, 0, 0, 0, 0, 0}); \ + \ + const auto column_buffer_eigen_tensor = \ + column_buffer_tensor.template shaped( \ + {p.input_channels, p.filter_rows, p.filter_cols, p.parallel_imgs, \ + p.output_rows, p.output_cols}); \ + \ + auto offset_grad_eigen_tensor = offset_grad_tensor.tensor(); \ + auto mask_grad_eigen_tensor = mask_grad_tensor.tensor(); \ + \ + const auto input_eigen_tensor = input_tensor.tensor(); \ + \ + const auto cost = p.input_channels / p.offset_groups; \ + const auto work = [&](Eigen::Index start, Eigen::Index end) -> void { \ + for (Eigen::Index k = start; k < end; ++k) { \ + auto offset_grad_value = T(0); \ + auto mask_grad_value = T(0); \ + \ + const auto offset_channels = \ + 2 * p.filter_rows * p.filter_cols * p.offset_groups; \ + \ + const auto offset_channel_step = p.filter_rows * p.filter_cols; \ + \ + const auto current_output_col = k % p.output_cols; \ + const auto current_output_row = (k / p.output_cols) % p.output_rows; \ + const auto current_filter_col = \ + (k / (2 * p.output_rows * p.output_cols)) % p.filter_cols; \ + const auto current_filter_row = \ + (k / (2 * p.output_rows * p.output_cols * p.filter_cols)) % \ + p.filter_rows; \ + const auto current_offset_channel = \ + (k / (p.output_rows * p.output_cols)) % offset_channels; \ + const auto current_batch = \ + k / (p.output_rows * p.output_cols * offset_channels); \ + \ + const auto current_actual_batch = b * p.parallel_imgs + current_batch; \ + \ + const auto current_offset_group = \ + current_offset_channel / (2 * offset_channel_step); \ + \ + const auto channels_per_offset_group = \ + p.input_channels / p.offset_groups; \ + const auto offset_channel_diff = \ + current_offset_channel - \ + current_offset_group * 2 * offset_channel_step; \ + const auto is_y_direction = offset_channel_diff % 2 == 0; \ + \ + for (auto selected_offset_channel = (offset_channel_diff / 2); \ + selected_offset_channel < \ + channels_per_offset_group * offset_channel_step; \ + selected_offset_channel += offset_channel_step) { \ + const auto selected_filter_col = \ + selected_offset_channel % p.filter_cols; \ + const auto selected_filter_row = \ + (selected_offset_channel / p.filter_cols) % p.filter_rows; \ + const auto input_channel_diff = \ + (selected_offset_channel / (p.filter_cols * p.filter_rows)); \ + \ + const auto offset_h = offset_eigen_tensor( \ + b, current_batch, current_offset_group, selected_filter_row, \ + selected_filter_col, 0, current_output_row, current_output_col); \ + const auto offset_w = offset_eigen_tensor( \ + b, current_batch, current_offset_group, selected_filter_row, \ + selected_filter_col, 1, current_output_row, current_output_col); \ + const auto mask = \ + p.use_mask ? 
mask_eigen_tensor( \ + b, current_batch, current_offset_group, \ + selected_filter_row, selected_filter_col, \ + current_output_row, current_output_col) \ + : T(1); \ + \ + const auto y = \ + (current_output_row * p.stride_rows - p.padding_rows) + \ + selected_filter_row * p.dilation_rows + offset_h; \ + const auto x = \ + (current_output_col * p.stride_cols - p.padding_cols) + \ + selected_filter_col * p.dilation_cols + offset_w; \ + \ + const auto selected_input_channel = \ + input_channel_diff + \ + current_offset_group * channels_per_offset_group; \ + \ + const auto filter_data = column_buffer_eigen_tensor( \ + selected_input_channel, selected_filter_row, \ + selected_filter_col, current_batch, current_output_row, \ + current_output_col); \ + \ + const auto weight = GetCoordinateWeight( \ + input_eigen_tensor, b, current_actual_batch, \ + selected_input_channel, y, x, is_y_direction); \ + \ + offset_grad_value += mask * weight * filter_data; \ + \ + if (is_y_direction) { \ + mask_grad_value += \ + filter_data * BilinearInterpolate( \ + input_eigen_tensor, b, current_actual_batch, \ + selected_input_channel, y, x); \ + } \ + } \ + \ + offset_grad_eigen_tensor(current_actual_batch, current_offset_channel, \ + current_output_row, current_output_col) = \ + offset_grad_value; \ + \ + if (p.use_mask && is_y_direction) { \ + const auto current_mask_channel = \ + (current_offset_group * p.filter_rows + current_filter_row) * \ + p.filter_cols + \ + current_filter_col; \ + \ + mask_grad_eigen_tensor(current_actual_batch, current_mask_channel, \ + current_output_row, current_output_col) = \ + mask_grad_value; \ + } \ + } \ + }; \ + auto thread_pool = \ + context->device()->tensorflow_cpu_worker_threads()->workers; \ + thread_pool->ParallelFor(num_kernels, cost, work); \ + } +TF_CALL_float(COL2IM_OFFSET_AND_MASK); +TF_CALL_double(COL2IM_OFFSET_AND_MASK); +#undef COL2IM_OFFSET_AND_MASK + +#define COL2IM_INPUT(T) \ + template <> \ + void DeformableConv2DGradFunctor::DeformableCol2ImForInput( \ + OpKernelContext *context, int32 b) { \ + const auto num_kernels = p.input_channels * p.filter_rows * \ + p.filter_cols * p.output_rows * p.output_cols * \ + p.parallel_imgs; \ + \ + const auto offset_eigen_tensor = offset_tensor.template tensor(); \ + \ + const auto mask_eigen_tensor = \ + p.use_mask ? mask_tensor.template tensor() \ + : mask_tensor.template shaped({0, 0, 0, 0, 0, 0, 0}); \ + \ + const auto column_buffer_tensor_flattened = \ + column_buffer_tensor.template shaped({num_kernels}); \ + \ + auto input_grad_eigen_tensor = input_grad_tensor.tensor(); \ + \ + const auto cost = 3 * 3; \ + std::array mutex_array; \ + \ + const auto work = [&](Eigen::Index start, Eigen::Index end) -> void { \ + for (Eigen::Index k = start; k < end; ++k) { \ + const auto current_output_col = k % p.output_cols; \ + const auto current_output_row = (k / p.output_cols) % p.output_rows; \ + const auto current_batch = \ + (k / (p.output_rows * p.output_cols)) % p.parallel_imgs; \ + \ + const auto current_filter_col = \ + (k / (p.output_rows * p.output_cols * p.parallel_imgs)) % \ + p.filter_cols; \ + const auto current_filter_row = \ + (k / (p.output_rows * p.output_cols * p.parallel_imgs * \ + p.filter_cols)) % \ + p.filter_rows; \ + const auto current_channel = \ + k / (p.output_rows * p.output_cols * p.parallel_imgs * \ + p.filter_rows * p.filter_cols); \ + \ + const auto current_offset_group = \ + current_channel / (p.input_channels / p.offset_groups); \ + \ + const auto mask = \ + p.use_mask \ + ? 
mask_eigen_tensor(b, current_batch, current_offset_group, \ + current_filter_row, current_filter_col, \ + current_output_row, current_output_col) \ + : T(1); \ + \ + const auto offset_h = offset_eigen_tensor( \ + b, current_batch, current_offset_group, current_filter_row, \ + current_filter_col, 0, current_output_row, current_output_col); \ + const auto offset_w = offset_eigen_tensor( \ + b, current_batch, current_offset_group, current_filter_row, \ + current_filter_col, 1, current_output_row, current_output_col); \ + \ + const auto y = (current_output_row * p.stride_rows - p.padding_rows) + \ + current_filter_row * p.dilation_rows + offset_h; \ + const auto x = (current_output_col * p.stride_cols - p.padding_cols) + \ + current_filter_col * p.dilation_cols + offset_w; \ + \ + for (auto dy = -1; dy <= 1; dy++) { \ + for (auto dx = -1; dx <= 1; dx++) { \ + const auto current_input_row = int(y) + dy; \ + const auto current_input_col = int(x) + dx; \ + \ + if (p.input_rows > current_input_row && current_input_row >= 0 && \ + p.input_cols > current_input_col && current_input_col >= 0 && \ + std::abs(y - current_input_row) < 1 && \ + std::abs(x - current_input_col) < 1) { \ + const auto weight = (1.0 - std::abs(y - current_input_row)) * \ + (1.0 - std::abs(x - current_input_col)); \ + \ + const auto current_actual_batch = \ + b * p.parallel_imgs + current_batch; \ + \ + std::lock_guard lock( \ + mutex_array[(current_input_row * p.input_cols + \ + current_input_col) % \ + 100]); \ + input_grad_eigen_tensor(current_actual_batch, current_channel, \ + current_input_row, current_input_col) += \ + mask * weight * column_buffer_tensor_flattened(k); \ + } \ + } \ + } \ + } \ + }; \ + auto thread_pool = \ + context->device()->tensorflow_cpu_worker_threads()->workers; \ + thread_pool->ParallelFor(num_kernels, cost, work); \ + } +TF_CALL_float(COL2IM_INPUT); +TF_CALL_double(COL2IM_INPUT); +#undef COL2IM_INPUT + +} // end namespace functor + +template +class DeformableConv2DOpBase : public OpKernel { + public: + explicit DeformableConv2DOpBase(OpKernelConstruction *context) + : OpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("strides", &strides)); + OP_REQUIRES_OK(context, context->GetAttr("weight_groups", &weight_groups)); + OP_REQUIRES_OK(context, context->GetAttr("offset_groups", &offset_groups)); + OP_REQUIRES_OK(context, context->GetAttr("padding", &padding)); + OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilations)); + string data_format_str; + OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format_str)); + FormatFromString(data_format_str, &data_format); + + p = DeformableConv2DParams{}; + } + + void Compute(OpKernelContext *context) override { + const Tensor &input_tensor = context->input(0); + const Tensor &filter_tensor = context->input(1); + + const Tensor &bias_tensor = context->input(2); + const Tensor &mask_tensor = context->input(4); + + const TensorShape &input_shape = input_tensor.shape(); + const TensorShape &filter_shape = filter_tensor.shape(); + + const auto input_batches = input_shape.dim_size(0); + const auto input_channels = input_shape.dim_size(1); + const auto input_rows = input_shape.dim_size(2); + const auto input_cols = input_shape.dim_size(3); + + const auto output_channels = filter_shape.dim_size(0); + const auto filter_channels = filter_shape.dim_size(1); + const auto filter_rows = filter_shape.dim_size(2); + const auto filter_cols = filter_shape.dim_size(3); + + const auto dilation_rows = dilations[0]; + const auto dilation_cols = 
dilations[1]; + + const auto stride_rows = strides[0]; + const auto stride_cols = strides[1]; + + const auto parallel_imgs = GetParallelImgs(input_batches); + + int64 output_rows, output_cols; + int64 padding_rows, padding_cols; + OP_REQUIRES_OK( + context, GetWindowedOutputSizeV2(input_rows, filter_rows, dilation_rows, + stride_rows, padding, &output_rows, + &padding_rows)); + OP_REQUIRES_OK( + context, GetWindowedOutputSizeV2(input_cols, filter_cols, dilation_cols, + stride_cols, padding, &output_cols, + &padding_cols)); + + p.input_batches = input_batches; + p.input_channels = input_channels; + p.input_rows = input_rows; + p.input_cols = input_cols; + p.filter_channels = filter_channels; + p.filter_rows = filter_rows; + p.filter_cols = filter_cols; + p.padding_rows = padding_rows; + p.padding_cols = padding_cols; + p.stride_rows = stride_rows; + p.stride_cols = stride_cols; + p.dilation_rows = dilation_rows; + p.dilation_cols = dilation_cols; + p.output_channels = output_channels; + p.output_rows = output_rows; + p.output_cols = output_cols; + p.parallel_imgs = parallel_imgs; + p.weight_groups = weight_groups; + p.offset_groups = offset_groups; + p.batches = p.input_batches / p.parallel_imgs; + p.use_mask = mask_tensor.NumElements() > 0; + p.use_bias = bias_tensor.NumElements() > 0; + } + + int GetParallelImgs(int n) { + for (auto k = kMaxParallelImgs; k > 1; --k) { + if (n % k == 0) { + return k; + } + } + return 1; + } + + protected: + TensorFormat data_format; + DeformableConv2DParams p; + + private: + std::vector strides; + int32 weight_groups; + int32 offset_groups; + Padding padding; + std::vector dilations; +}; + +template +class DeformableConv2DForwardOp : public DeformableConv2DOpBase { + using DeformableConv2DOpBase::data_format; + using DeformableConv2DOpBase::p; + + public: + explicit DeformableConv2DForwardOp(OpKernelConstruction *context) + : DeformableConv2DOpBase(context) {} + + void Compute(OpKernelContext *context) override { + DeformableConv2DOpBase::Compute(context); + + const Tensor &input_tensor = context->input(0); + const Tensor &filter_tensor = context->input(1); + const Tensor &bias_tensor = context->input(2); + const Tensor &offset_tensor = context->input(3); + const Tensor &mask_tensor = context->input(4); + + TensorShape column_buffer_shape( + {p.input_channels * p.filter_rows * p.filter_cols, p.parallel_imgs, + p.output_rows, p.output_cols}); + Tensor column_buffer_tensor; + OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum::value, + column_buffer_shape, + &column_buffer_tensor)); + + TensorShape output_shape = + ShapeFromFormat(data_format, p.input_batches, p.output_rows, + p.output_cols, p.output_channels); + Tensor *output_tensor = nullptr; + OP_REQUIRES_OK(context, + context->allocate_output(0, output_shape, &output_tensor)); + + functor::DeformableConv2DForwardFunctor deformableConv2DFunc( + &input_tensor, &filter_tensor, &bias_tensor, &offset_tensor, + &mask_tensor, &column_buffer_tensor, output_tensor, &p); + Status s = deformableConv2DFunc(context); + + OP_REQUIRES_OK(context, s); + } +}; + +template +class DeformableConv2DGradOp : public DeformableConv2DOpBase { + using DeformableConv2DOpBase::data_format; + using DeformableConv2DOpBase::p; + + public: + explicit DeformableConv2DGradOp(OpKernelConstruction *context) + : DeformableConv2DOpBase(context) {} + + void Compute(OpKernelContext *context) override { + DeformableConv2DOpBase::Compute(context); + + const Tensor &input_tensor = context->input(0); + const Tensor &filter_tensor = 
context->input(1); + const Tensor &bias_tensor = context->input(2); + const Tensor &offset_tensor = context->input(3); + const Tensor &mask_tensor = context->input(4); + const Tensor &output_grad_tensor = context->input(5); + + const TensorShape &input_shape = input_tensor.shape(); + const TensorShape &filter_shape = filter_tensor.shape(); + const TensorShape &bias_shape = bias_tensor.shape(); + const TensorShape &offset_shape = offset_tensor.shape(); + const TensorShape &mask_shape = mask_tensor.shape(); + + TensorShape column_buffer_shape( + {p.input_channels * p.filter_rows * p.filter_cols, p.parallel_imgs, + p.output_rows, p.output_cols}); + Tensor column_buffer_tensor; + OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum::value, + column_buffer_shape, + &column_buffer_tensor)); + + Tensor output_grad_tensor_reshaped; + CHECK(output_grad_tensor_reshaped.CopyFrom( + output_grad_tensor, + TensorShape({p.batches, p.parallel_imgs, p.output_channels, + p.output_rows, p.output_cols}))); + + TensorShape output_grad_tensor_transposed_shape( + {p.batches, p.output_channels, p.parallel_imgs, p.output_rows, + p.output_cols}); + Tensor output_grad_tensor_transposed; + OP_REQUIRES_OK(context, + context->allocate_temp(DataTypeToEnum::value, + output_grad_tensor_transposed_shape, + &output_grad_tensor_transposed)); + OP_REQUIRES_OK(context, + Transpose(context, output_grad_tensor_reshaped, + {0, 2, 1, 3, 4}, + &output_grad_tensor_transposed)); + + TensorShape output_shape = + ShapeFromFormat(data_format, p.input_batches, p.output_rows, + p.output_cols, p.output_channels); + + Tensor *input_grad_tensor = nullptr; + OP_REQUIRES_OK( + context, context->allocate_output(0, input_shape, &input_grad_tensor)); + Tensor *filter_grad_tensor = nullptr; + OP_REQUIRES_OK(context, context->allocate_output(1, filter_shape, + &filter_grad_tensor)); + Tensor *bias_grad_tensor = nullptr; + OP_REQUIRES_OK(context, + context->allocate_output(2, bias_shape, &bias_grad_tensor)); + Tensor *offset_grad_tensor = nullptr; + OP_REQUIRES_OK(context, context->allocate_output(3, offset_shape, + &offset_grad_tensor)); + Tensor *mask_grad_tensor = nullptr; + OP_REQUIRES_OK(context, + context->allocate_output(4, mask_shape, &mask_grad_tensor)); + + functor::DeformableConv2DGradFunctor deformableConv2DGradFunc( + &input_tensor, &filter_tensor, &bias_tensor, &offset_tensor, + &mask_tensor, &output_grad_tensor_transposed, input_grad_tensor, + filter_grad_tensor, bias_grad_tensor, offset_grad_tensor, + mask_grad_tensor, &column_buffer_tensor, &p); + Status s = deformableConv2DGradFunc(context); + + OP_REQUIRES_OK(context, s); + } +}; + +// Register the CPU kernels. +#define REGISTER_DEFORMABLECONV2D_OP_CPU(T) \ + REGISTER_KERNEL_BUILDER(Name("Addons>DeformableConv2D") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T"), \ + DeformableConv2DForwardOp) \ + REGISTER_KERNEL_BUILDER(Name("Addons>DeformableConv2DGrad") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T"), \ + DeformableConv2DGradOp) + +TF_CALL_float(REGISTER_DEFORMABLECONV2D_OP_CPU); +TF_CALL_double(REGISTER_DEFORMABLECONV2D_OP_CPU); +#undef REGISTER_DEFORMABLECONV2D_OP_CPU + +// Register the GPU kernels. 
+#if GOOGLE_CUDA
+
+#define REGISTER_DEFORMABLECONV2D_OP_GPU(T)                          \
+  REGISTER_KERNEL_BUILDER(Name("Addons>DeformableConv2D")            \
+                              .Device(DEVICE_GPU)                    \
+                              .TypeConstraint<T>("T"),               \
+                          DeformableConv2DForwardOp<GPUDevice, T>)   \
+  REGISTER_KERNEL_BUILDER(Name("Addons>DeformableConv2DGrad")        \
+                              .Device(DEVICE_GPU)                    \
+                              .TypeConstraint<T>("T"),               \
+                          DeformableConv2DGradOp<GPUDevice, T>)
+
+TF_CALL_float(REGISTER_DEFORMABLECONV2D_OP_GPU);
+TF_CALL_double(REGISTER_DEFORMABLECONV2D_OP_GPU);
+#undef REGISTER_DEFORMABLECONV2D_OP_GPU
+
+#endif  // GOOGLE_CUDA
+
+}  // namespace addons
+}  // namespace tensorflow
diff --git a/tensorflow_addons/custom_ops/layers/cc/kernels/deformable_conv2d_op.h b/tensorflow_addons/custom_ops/layers/cc/kernels/deformable_conv2d_op.h
new file mode 100644
index 0000000000..44cf11e304
--- /dev/null
+++ b/tensorflow_addons/custom_ops/layers/cc/kernels/deformable_conv2d_op.h
@@ -0,0 +1,598 @@
+// Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#ifndef TENSORFLOW_ADDONS_LAYERS_KERNELS_DEFORMABLECONV2D_OP_H_
+#define TENSORFLOW_ADDONS_LAYERS_KERNELS_DEFORMABLECONV2D_OP_H_
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/kernels/matmul_op_impl.h"
+#include "tensorflow/core/util/matmul_bcast.h"
+#include "tensorflow/core/util/tensor_format.h"
+
+namespace tensorflow {
+namespace addons {
+static const int kMaxParallelImgs = 32;
+
+struct DeformableConv2DParams {
+  int32 input_batches;
+  int32 input_channels;
+  int32 input_rows;
+  int32 input_cols;
+  int32 filter_channels;
+  int32 filter_rows;
+  int32 filter_cols;
+  int32 padding_rows;
+  int32 padding_cols;
+  int32 stride_rows;
+  int32 stride_cols;
+  int32 dilation_rows;
+  int32 dilation_cols;
+  int32 output_channels;
+  int32 output_rows;
+  int32 output_cols;
+  int32 parallel_imgs;
+  int32 weight_groups;
+  int32 offset_groups;
+  int32 batches;
+  bool use_mask;
+  bool use_bias;
+};
+
+template <typename Device, typename T>
+Status TensorSetZero(OpKernelContext *ctx, Tensor *value) {
+  const auto d = ctx->template eigen_device<Device>();
+  auto out = value->flat<T>();
+
+  const bool use_64bit = out.size() > Eigen::NumTraits<int32>::highest();
+
+  if (!use_64bit && Eigen::internal::is_same<Device, Eigen::GpuDevice>::value) {
+    To32Bit(out).device(d) = To32Bit(out).constant(T(0));
+  } else {
+    out.device(d) = out.constant(T(0));
+  }
+
+  return Status::OK();
+}
+
+template <typename Device, typename T>
+Status AddToTensor(OpKernelContext *ctx, Tensor *sum, const Tensor *current,
+                   const Tensor *add) {
+  const auto d = ctx->template eigen_device<Device>();
+
+  auto out = sum->flat<T>();
+  auto a = current->flat<T>();
+  auto b = add->flat<T>();
+
+  const bool use_64bit = out.size() > Eigen::NumTraits<int32>::highest();
+
+  if (!use_64bit && Eigen::internal::is_same<Device, Eigen::GpuDevice>::value) {
+    To32Bit(out).device(d) = To32Bit(a) + To32Bit(b);
+  } else {
+    out.device(d) = a + b;
+  }
+
+  return Status::OK();
+}
+
+template <typename Device, typename T, int NDIMS>
+Status Transpose(OpKernelContext *ctx, const Tensor &in,
+                 const gtl::ArraySlice<int32> perm, Tensor *out) {
+  const auto d = ctx->template eigen_device<Device>();
+
+  Eigen::array<int, NDIMS> p;
+  for (int i = 0; i < NDIMS; ++i) {
+    p[i] = perm[i];
+  }
+
+  auto x = typename TTypes<T, NDIMS>::ConstTensor(
+      reinterpret_cast<const T *>(in.tensor_data().data()),
+      in.shape().AsEigenDSizes<NDIMS>());
+  auto y = typename TTypes<T, NDIMS>::Tensor(
+      reinterpret_cast<T *>(const_cast<char *>(out->tensor_data().data())),
+      out->shape().AsEigenDSizes<NDIMS>());
+
+  const bool use_64bit = x.size() > Eigen::NumTraits<int32>::highest();
+
+  if (!use_64bit && Eigen::internal::is_same<Device, Eigen::GpuDevice>::value) {
+    To32Bit(y).device(d) = To32Bit(x).shuffle(p);
+  } else {
+    y.device(d) = x.shuffle(p);
+  }
+
+  return Status::OK();
+}
+
+template <typename Device, typename T>
+Status CopySliceToElement(OpKernelContext *ctx, const Tensor &parent,
+                          Tensor *element, int64 index) {
+  const auto d = ctx->template eigen_device<Device>();
+
+  auto out = element->flat<T>();
+  auto in = parent.flat_outer_dims<T>();
+
+  const bool use_64bit = in.size() > Eigen::NumTraits<int32>::highest();
+
+  if (!use_64bit && Eigen::internal::is_same<Device, Eigen::GpuDevice>::value) {
+    To32Bit(out).device(d) = To32Bit(in).chip(index, 0);
+  } else {
+    out.device(d) = in.chip(index, 0);
+  }
+
+  return Status::OK();
+}
+
+template <typename Device, typename T>
+Status CopyElementToSlice(OpKernelContext *ctx, Tensor element, Tensor *parent,
+                          int64 index) {
+  const auto d = ctx->template eigen_device<Device>();
+
+  auto out = parent->flat_outer_dims<T>();
+  auto in = element.flat<T>();
+
+  const bool use_64bit = out.size() > Eigen::NumTraits<int32>::highest();
+
+  if (!use_64bit && Eigen::internal::is_same<Device, Eigen::GpuDevice>::value) {
+    To32Bit(out).chip(index, 0).device(d) = To32Bit(in);
+  } else {
+    out.chip(index, 0).device(d) = in;
+  }
+
+  return Status::OK();
+}
+
+namespace functor {
+
+template <typename T>
+EIGEN_DEVICE_FUNC T BilinearInterpolate(typename TTypes<T, 5>::Tensor img,
+                                        int32 b, int32 batch, int32 channel,
+                                        T y, T x) {
+  const auto max_height = img.dimension(3);
+  const auto max_width = img.dimension(4);
+
+  if (y <= -1 || max_height <= y || x <= -1 || max_width <= x) {
+    return T(0);
+  }
+
+  int y_low = floor(y);
+  int x_low = floor(x);
+  int y_high = y_low + 1;
+  int w_high = x_low + 1;
+
+  auto v1 = T(0);
+  if (y_low >= 0 && x_low >= 0) {
+    v1 = img(b, batch, channel, y_low, x_low);
+  }
+
+  auto v2 = T(0);
+  if (y_low >= 0 && w_high <= max_width - 1) {
+    v2 = img(b, batch, channel, y_low, w_high);
+  }
+
+  auto v3 = T(0);
+  if (y_high <= max_height - 1 && x_low >= 0) {
+    v3 = img(b, batch, channel, y_high, x_low);
+  }
+
+  auto v4 = T(0);
+  if (y_high <= max_height - 1 && w_high <= max_width - 1) {
+    v4 = img(b, batch, channel, y_high, w_high);
+  }
+
+  auto lh = y - y_low;
+  auto lw = x - x_low;
+  auto hh = 1 - lh;
+  auto hw = 1 - lw;
+
+  auto w1 = hh * hw;
+  auto w2 = hh * lw;
+  auto w3 = lh * hw;
+  auto w4 = lh * lw;
+
+  return w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4;
+}
+
+template <typename T>
+EIGEN_DEVICE_FUNC T GetCoordinateWeight(typename TTypes<T, 5>::Tensor img,
+                                        int32 b, int32 batch, int32 channel,
+                                        T y, T x, bool is_y_direction) {
+  const auto max_height = img.dimension(3);
+  const auto max_width = img.dimension(4);
+
+  const int y_low = floor(y);
+  const int x_low = floor(x);
+  const int y_high = y_low + 1;
+  const int x_high = x_low + 1;
+
+  const bool valid_y_low = max_height > y_low && y_low >= 0;
+  const bool valid_y_high = max_height > y_high && y_high >= 0;
+  const bool valid_x_low = max_width > x_low && x_low >= 0;
+  const bool valid_x_high = max_width > x_high && x_high >= 0;
+
+  auto v_yx = T(0);
+  if (valid_y_low && valid_x_low) {
+    v_yx = img(b, batch, channel, y_low, x_low);
+  }
+
+  auto v_yX = T(0);
+  if (valid_y_low && valid_x_high) {
+    v_yX = img(b, batch, channel, y_low, x_high);
+  }
+
+
auto v_Yx = T(0); + if (valid_y_high && valid_x_low) { + v_Yx = img(b, batch, channel, y_high, x_low); + } + + auto v_YX = T(0); + if (valid_y_high && valid_x_high) { + v_YX = img(b, batch, channel, y_high, x_high); + } + + if (is_y_direction) { + const auto dx = x - x_low; + return (v_YX - v_yX) * dx + (v_Yx - v_yx) * (1 - dx); + } else { + const auto dy = y - y_low; + return (v_YX - v_Yx) * dy + (v_yX - v_yx) * (1 - dy); + } +} + +template +struct DeformableConv2DFunctorBase { + DeformableConv2DFunctorBase(const Tensor *_input_tensor, + const Tensor *_filter_tensor, + const Tensor *_bias_tensor, + const Tensor *_offset_tensor, + const Tensor *_mask_tensor, + Tensor *_column_buffer_tensor, + DeformableConv2DParams *_p) + : input_tensor(_input_tensor->dtype()), + filter_tensor(_filter_tensor->dtype()), + bias_tensor(_bias_tensor->dtype()), + offset_tensor(_offset_tensor->dtype()), + mask_tensor(_mask_tensor->dtype()), + column_buffer_tensor(_column_buffer_tensor->dtype()), + p(*_p) { + CHECK(input_tensor.CopyFrom( + *_input_tensor, + TensorShape({p.batches, p.parallel_imgs, p.input_channels, p.input_rows, + p.input_cols}))); + CHECK(filter_tensor.CopyFrom( + *_filter_tensor, + TensorShape({p.weight_groups, p.output_channels / p.weight_groups, + p.filter_channels * p.filter_rows * p.filter_cols}))); + CHECK(bias_tensor.CopyFrom(*_bias_tensor, _bias_tensor->shape())); + + CHECK(offset_tensor.CopyFrom( + *_offset_tensor, + TensorShape({p.batches, p.parallel_imgs, p.offset_groups, p.filter_rows, + p.filter_cols, 2, p.output_rows, p.output_cols}))); + + if (p.use_mask) { + CHECK(mask_tensor.CopyFrom( + *_mask_tensor, + TensorShape({p.batches, p.parallel_imgs, p.offset_groups, + p.filter_rows, p.filter_cols, p.output_rows, + p.output_cols}))); + } else { + CHECK(mask_tensor.CopyFrom(*_mask_tensor, + TensorShape({0, 0, 0, 0, 0, 0, 0}))); + } + + CHECK(column_buffer_tensor.CopyFrom( + *_column_buffer_tensor, + TensorShape({p.input_channels * p.filter_rows * p.filter_cols, + p.parallel_imgs, p.output_rows, p.output_cols}))); + } + + virtual Status operator()(OpKernelContext *context) = 0; + + void DeformableIm2Col(OpKernelContext *context, int32 b); + + Tensor input_tensor; + Tensor filter_tensor; + Tensor bias_tensor; + Tensor offset_tensor; + Tensor mask_tensor; + Tensor column_buffer_tensor; + DeformableConv2DParams p; +}; + +template +struct DeformableConv2DForwardFunctor + : public DeformableConv2DFunctorBase { + using DeformableConv2DFunctorBase::input_tensor; + using DeformableConv2DFunctorBase::filter_tensor; + using DeformableConv2DFunctorBase::bias_tensor; + using DeformableConv2DFunctorBase::offset_tensor; + using DeformableConv2DFunctorBase::mask_tensor; + using DeformableConv2DFunctorBase::column_buffer_tensor; + using DeformableConv2DFunctorBase::p; + + DeformableConv2DForwardFunctor( + const Tensor *_input_tensor, const Tensor *_filter_tensor, + const Tensor *_bias_tensor, const Tensor *_offset_tensor, + const Tensor *_mask_tensor, Tensor *_column_buffer_tensor, + Tensor *_output_tensor, DeformableConv2DParams *_p) + : DeformableConv2DFunctorBase( + _input_tensor, _filter_tensor, _bias_tensor, _offset_tensor, + _mask_tensor, _column_buffer_tensor, _p), + output_tensor(_output_tensor->dtype()) { + CHECK(output_tensor.CopyFrom(*_output_tensor, _output_tensor->shape())); + } + + Status operator()(OpKernelContext *context) { + // input_channels * filter_rows * filter_cols / weight_groups == + // filter_channels * filter_rows * filter_cols + const auto elems = p.filter_channels * 
p.filter_rows * p.filter_cols; + const auto rows = p.output_channels / p.weight_groups; + const auto cols = p.parallel_imgs * p.output_rows * p.output_cols; + + Tensor output_tmp_tensor; + TF_RETURN_IF_ERROR(context->allocate_temp( + DataTypeToEnum::value, + TensorShape({p.batches, p.output_channels, p.parallel_imgs, + p.output_rows, p.output_cols}), + &output_tmp_tensor)); + + Tensor output_tmp_mtx_tensor; + TF_RETURN_IF_ERROR(context->allocate_temp( + DataTypeToEnum::value, TensorShape({p.weight_groups, rows, cols}), + &output_tmp_mtx_tensor)); + + Tensor output_tmp_tensor_reshaped(output_tmp_tensor.dtype()); + CHECK(output_tmp_tensor_reshaped.CopyFrom( + output_tmp_tensor, + TensorShape({p.batches, p.weight_groups, rows, cols}))); + + Tensor column_buffer_tensor_reshaped(column_buffer_tensor.dtype()); + CHECK(column_buffer_tensor_reshaped.CopyFrom( + column_buffer_tensor, TensorShape({p.weight_groups, elems, cols}))); + + for (auto b = 0; b < p.batches; b++) { + this->DeformableIm2Col(context, b); + + auto lhs = filter_tensor; + auto rhs = column_buffer_tensor_reshaped; + auto out = output_tmp_mtx_tensor; + + MatMulBCast bcast(lhs.shape().dim_sizes(), rhs.shape().dim_sizes()); + + LaunchBatchMatMul::Launch(context, lhs, rhs, false, false, + false, false, bcast, &out); + + TF_RETURN_IF_ERROR(CopyElementToSlice( + context, out, &output_tmp_tensor_reshaped, b)); + } + + Tensor output_tensor_reshaped(output_tensor.dtype()); + CHECK(output_tensor_reshaped.CopyFrom( + output_tensor, + TensorShape({p.batches, p.parallel_imgs, p.output_channels, + p.output_rows, p.output_cols}))); + TF_RETURN_IF_ERROR(Transpose( + context, output_tmp_tensor, {0, 2, 1, 3, 4}, &output_tensor_reshaped)); + + if (p.use_bias) { + Eigen::DSizes four_dims(1, p.output_channels, 1, 1); + Eigen::DSizes broadcast_dims(p.input_batches, 1, p.output_rows, + p.output_cols); + + const auto d = context->eigen_device(); + + auto out = output_tensor.tensor(); + auto add = bias_tensor.template tensor(); + + const bool use_64bit = out.size() > Eigen::NumTraits::highest(); + + if (!use_64bit && + Eigen::internal::is_same::value) { + To32Bit(out).device(d) += + To32Bit(add).reshape(four_dims).broadcast(broadcast_dims); + } else { + out.device(d) += add.reshape(four_dims).broadcast(broadcast_dims); + } + } + + return Status::OK(); + } + + Tensor output_tensor; +}; + +template +struct DeformableConv2DGradFunctor + : public DeformableConv2DFunctorBase { + using DeformableConv2DFunctorBase::input_tensor; + using DeformableConv2DFunctorBase::filter_tensor; + using DeformableConv2DFunctorBase::bias_tensor; + using DeformableConv2DFunctorBase::offset_tensor; + using DeformableConv2DFunctorBase::mask_tensor; + using DeformableConv2DFunctorBase::column_buffer_tensor; + using DeformableConv2DFunctorBase::p; + + DeformableConv2DGradFunctor( + const Tensor *_input_tensor, const Tensor *_filter_tensor, + const Tensor *_bias_tensor, const Tensor *_offset_tensor, + const Tensor *_mask_tensor, Tensor *_output_grad_tensor, + Tensor *_input_grad_tensor, Tensor *_filter_grad_tensor, + Tensor *_bias_grad_tensor, Tensor *_offset_grad_tensor, + Tensor *_mask_grad_tensor, Tensor *_column_buffer_tensor, + DeformableConv2DParams *_p) + : DeformableConv2DFunctorBase( + _input_tensor, _filter_tensor, _bias_tensor, _offset_tensor, + _mask_tensor, _column_buffer_tensor, _p), + output_grad_tensor(_output_grad_tensor->dtype()), + input_grad_tensor(_input_grad_tensor->dtype()), + filter_grad_tensor(_filter_grad_tensor->dtype()), + 
bias_grad_tensor(_bias_grad_tensor->dtype()), + offset_grad_tensor(_offset_grad_tensor->dtype()), + mask_grad_tensor(_mask_grad_tensor->dtype()) { + CHECK(output_grad_tensor.CopyFrom(*_output_grad_tensor, + _output_grad_tensor->shape())); + CHECK(input_grad_tensor.CopyFrom(*_input_grad_tensor, + _input_grad_tensor->shape())); + CHECK(filter_grad_tensor.CopyFrom( + *_filter_grad_tensor, + TensorShape({p.weight_groups, p.output_channels / p.weight_groups, + p.filter_channels * p.filter_rows * p.filter_cols}))); + CHECK(bias_grad_tensor.CopyFrom(*_bias_grad_tensor, + _bias_grad_tensor->shape())); + CHECK(offset_grad_tensor.CopyFrom(*_offset_grad_tensor, + _offset_grad_tensor->shape())); + CHECK(mask_grad_tensor.CopyFrom(*_mask_grad_tensor, + _mask_grad_tensor->shape())); + } + + Status operator()(OpKernelContext *context) { + TF_RETURN_IF_ERROR(TensorSetZero(context, &input_grad_tensor)); + TF_RETURN_IF_ERROR(TensorSetZero(context, &filter_grad_tensor)); + TF_RETURN_IF_ERROR( + TensorSetZero(context, &column_buffer_tensor)); + + ComputeInputOffsetMaskGrad(context); + + ComputeFilterGrad(context); + + if (p.use_bias) { + const auto d = context->eigen_device(); + + auto out = bias_grad_tensor.tensor(); + auto in = output_grad_tensor.tensor(); + + const bool use_64bit = in.size() > Eigen::NumTraits::highest(); + + const Eigen::array axes({0, 2, 3, 4}); + + if (!use_64bit && + Eigen::internal::is_same::value) { + To32Bit(out).device(d) = To32Bit(in).sum(axes); + } else { + out.device(d) = in.sum(axes); + } + } + + return Status::OK(); + } + + void ComputeFilterGrad(OpKernelContext *context) { + // input_channels * filter_rows * filter_cols / weight_groups == + // filter_channels * filter_rows * filter_cols + const auto cols = p.filter_channels * p.filter_rows * p.filter_cols; + const auto rows = p.output_channels / p.weight_groups; + const auto elems = p.parallel_imgs * p.output_rows * p.output_cols; + + Tensor output_grad_tensor_reshaped(output_grad_tensor.dtype()); + CHECK(output_grad_tensor_reshaped.CopyFrom( + output_grad_tensor, + TensorShape({p.batches, p.weight_groups, rows, elems}))); + + Tensor column_buffer_tensor_reshaped(column_buffer_tensor.dtype()); + CHECK(column_buffer_tensor_reshaped.CopyFrom( + column_buffer_tensor, TensorShape({p.weight_groups, cols, elems}))); + + Tensor matmul_lhs_tmp_tensor; + OP_REQUIRES_OK(context, context->allocate_temp( + DataTypeToEnum::value, + TensorShape({p.weight_groups, rows, elems}), + &matmul_lhs_tmp_tensor)); + + Tensor matmul_out_tmp_tensor; + OP_REQUIRES_OK(context, context->allocate_temp( + DataTypeToEnum::value, + TensorShape({p.weight_groups, rows, cols}), + &matmul_out_tmp_tensor)); + + for (auto b = 0; b < p.batches; b++) { + this->DeformableIm2Col(context, b); + + auto lhs = matmul_lhs_tmp_tensor; + auto rhs = column_buffer_tensor_reshaped; + auto out = matmul_out_tmp_tensor; + + OP_REQUIRES_OK(context, + CopySliceToElement( + context, output_grad_tensor_reshaped, &lhs, b)); + + MatMulBCast bcast(lhs.shape().dim_sizes(), rhs.shape().dim_sizes()); + + LaunchBatchMatMul::Launch(context, lhs, rhs, false, false, + false, true, bcast, &out); + + OP_REQUIRES_OK(context, + AddToTensor(context, &filter_grad_tensor, + &filter_grad_tensor, &out)); + } + } + + void ComputeInputOffsetMaskGrad(OpKernelContext *context) { + // input_channels * filter_rows * filter_cols / weight_groups == + // filter_channels * filter_rows * filter_cols + const auto rows = p.filter_channels * p.filter_rows * p.filter_cols; + const auto elems = p.output_channels / 
p.weight_groups; + const auto cols = p.parallel_imgs * p.output_rows * p.output_cols; + + Tensor output_grad_tensor_reshaped(output_grad_tensor.dtype()); + CHECK(output_grad_tensor_reshaped.CopyFrom( + output_grad_tensor, + TensorShape({p.batches, p.weight_groups, elems, cols}))); + + Tensor column_buffer_tensor_reshaped(column_buffer_tensor.dtype()); + CHECK(column_buffer_tensor_reshaped.CopyFrom( + column_buffer_tensor, TensorShape({p.weight_groups, rows, cols}))); + + Tensor matmul_rhs_tmp_tensor; + OP_REQUIRES_OK(context, context->allocate_temp( + DataTypeToEnum::value, + TensorShape({p.weight_groups, elems, cols}), + &matmul_rhs_tmp_tensor)); + + for (auto b = 0; b < p.batches; b++) { + auto lhs = filter_tensor; + auto rhs = matmul_rhs_tmp_tensor; + auto out = column_buffer_tensor_reshaped; + + OP_REQUIRES_OK(context, + CopySliceToElement( + context, output_grad_tensor_reshaped, &rhs, b)); + + MatMulBCast bcast(lhs.shape().dim_sizes(), rhs.shape().dim_sizes()); + + LaunchBatchMatMul::Launch(context, lhs, rhs, false, false, + true, false, bcast, &out); + + DeformableCol2ImForOffsetAndMask(context, b); + + DeformableCol2ImForInput(context, b); + } + } + + void DeformableCol2ImForOffsetAndMask(OpKernelContext *context, int32 b); + + void DeformableCol2ImForInput(OpKernelContext *context, int32 b); + + Tensor output_grad_tensor; + Tensor input_grad_tensor; + Tensor filter_grad_tensor; + Tensor bias_grad_tensor; + Tensor offset_grad_tensor; + Tensor mask_grad_tensor; +}; + +} // namespace functor +} // namespace addons +} // namespace tensorflow + +#endif // TENSORFLOW_ADDONS_LAYERS_KERNELS_DEFORMABLECONV2D_OP_H_ diff --git a/tensorflow_addons/custom_ops/layers/cc/kernels/deformable_conv2d_op_gpu.cu.cc b/tensorflow_addons/custom_ops/layers/cc/kernels/deformable_conv2d_op_gpu.cu.cc new file mode 100644 index 0000000000..e4f02c6268 --- /dev/null +++ b/tensorflow_addons/custom_ops/layers/cc/kernels/deformable_conv2d_op_gpu.cu.cc @@ -0,0 +1,385 @@ +// Copyright 2020 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// =============================================================================
+
+#if GOOGLE_CUDA
+#define EIGEN_USE_GPU
+
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/kernel_shape_util.h"
+#include "tensorflow/core/util/gpu_kernel_helper.h"
+#include "tensorflow_addons/custom_ops/layers/cc/kernels/deformable_conv2d_op.h"
+
+namespace tensorflow {
+namespace addons {
+
+using GPUDevice = Eigen::GpuDevice;
+
+namespace functor {
+
+template <typename T>
+__global__ void DeformableIm2ColKernel(
+    int32 b, int32 num_kernels, DeformableConv2DParams p,
+    typename TTypes<T, 5>::Tensor input_eigen_tensor,
+    typename TTypes<T, 8>::Tensor offset_eigen_tensor,
+    typename TTypes<T, 7>::Tensor mask_eigen_tensor,
+    typename TTypes<T, 4>::Tensor column_buffer_eigen_tensor) {
+  CUDA_1D_KERNEL_LOOP(k, num_kernels) {
+    const auto current_output_col = k % p.output_cols;
+    const auto current_output_row = (k / p.output_cols) % p.output_rows;
+    const auto current_batch =
+        (k / (p.output_rows * p.output_cols)) % p.parallel_imgs;
+    const auto current_input_channel =
+        k / (p.output_rows * p.output_cols * p.parallel_imgs);
+    const auto current_output_channel =
+        current_input_channel * p.filter_rows * p.filter_cols;
+
+    const auto current_actual_batch = b * p.parallel_imgs + current_batch;
+
+    const auto group_index =
+        current_input_channel / (p.input_channels / p.offset_groups);
+
+    auto column_buffer_tensor_channel = current_output_channel;
+    for (auto current_filter_row = 0; current_filter_row < p.filter_rows;
+         current_filter_row++) {
+      for (auto current_filter_col = 0; current_filter_col < p.filter_cols;
+           current_filter_col++) {
+        auto offset_h = offset_eigen_tensor(
+            b, current_batch, group_index, current_filter_row,
+            current_filter_col, 0, current_output_row, current_output_col);
+        auto offset_w = offset_eigen_tensor(
+            b, current_batch, group_index, current_filter_row,
+            current_filter_col, 1, current_output_row, current_output_col);
+
+        auto mask = p.use_mask ?
mask_eigen_tensor( + b, current_batch, group_index, + current_filter_row, current_filter_col, + current_output_row, current_output_col) + : T(1); + + auto y = (current_output_row * p.stride_rows - p.padding_rows) + + current_filter_row * p.dilation_rows + offset_h; + auto x = (current_output_col * p.stride_cols - p.padding_cols) + + current_filter_col * p.dilation_cols + offset_w; + + column_buffer_eigen_tensor(column_buffer_tensor_channel, current_batch, + current_output_row, current_output_col) = + mask * BilinearInterpolate(input_eigen_tensor, b, + current_actual_batch, + current_input_channel, y, x); + column_buffer_tensor_channel++; + } + } + } +} + +template +__global__ void DeformableCol2ImForOffsetAndMaskKernel( + int32 b, int32 num_kernels, DeformableConv2DParams p, + typename TTypes::Tensor input_eigen_tensor, + typename TTypes::Tensor offset_eigen_tensor, + typename TTypes::Tensor mask_eigen_tensor, + typename TTypes::Tensor offset_grad_eigen_tensor, + typename TTypes::Tensor mask_grad_eigen_tensor, + typename TTypes::Tensor column_buffer_eigen_tensor) { + CUDA_1D_KERNEL_LOOP(k, num_kernels) { + auto offset_grad_value = T(0); + auto mask_grad_value = T(0); + + const auto offset_channels = + 2 * p.filter_rows * p.filter_cols * p.offset_groups; + + const auto offset_channel_step = p.filter_rows * p.filter_cols; + + const auto current_output_col = k % p.output_cols; + const auto current_output_row = (k / p.output_cols) % p.output_rows; + const auto current_filter_col = + (k / (2 * p.output_rows * p.output_cols)) % p.filter_cols; + const auto current_filter_row = + (k / (2 * p.output_rows * p.output_cols * p.filter_cols)) % + p.filter_rows; + const auto current_offset_channel = + (k / (p.output_rows * p.output_cols)) % offset_channels; + const auto current_batch = + k / (p.output_rows * p.output_cols * offset_channels); + + const auto current_actual_batch = b * p.parallel_imgs + current_batch; + + const auto current_offset_group = + current_offset_channel / (2 * offset_channel_step); + + const auto channels_per_offset_group = p.input_channels / p.offset_groups; + const auto offset_channel_diff = + current_offset_channel - current_offset_group * 2 * offset_channel_step; + const auto is_y_direction = offset_channel_diff % 2 == 0; + + for (auto selected_offset_channel = (offset_channel_diff / 2); + selected_offset_channel < + channels_per_offset_group * offset_channel_step; + selected_offset_channel += offset_channel_step) { + const auto selected_filter_col = selected_offset_channel % p.filter_cols; + const auto selected_filter_row = + (selected_offset_channel / p.filter_cols) % p.filter_rows; + const auto input_channel_diff = + (selected_offset_channel / (p.filter_cols * p.filter_rows)); + + const auto offset_h = offset_eigen_tensor( + b, current_batch, current_offset_group, selected_filter_row, + selected_filter_col, 0, current_output_row, current_output_col); + const auto offset_w = offset_eigen_tensor( + b, current_batch, current_offset_group, selected_filter_row, + selected_filter_col, 1, current_output_row, current_output_col); + const auto mask = + p.use_mask + ? 
mask_eigen_tensor(b, current_batch, current_offset_group, + selected_filter_row, selected_filter_col, + current_output_row, current_output_col) + : T(1); + + const auto y = (current_output_row * p.stride_rows - p.padding_rows) + + selected_filter_row * p.dilation_rows + offset_h; + const auto x = (current_output_col * p.stride_cols - p.padding_cols) + + selected_filter_col * p.dilation_cols + offset_w; + + const auto selected_input_channel = + input_channel_diff + current_offset_group * channels_per_offset_group; + + const auto filter_data = column_buffer_eigen_tensor( + selected_input_channel, selected_filter_row, selected_filter_col, + current_batch, current_output_row, current_output_col); + + const auto weight = + GetCoordinateWeight(input_eigen_tensor, b, current_actual_batch, + selected_input_channel, y, x, is_y_direction); + + offset_grad_value += mask * weight * filter_data; + + if (is_y_direction) { + mask_grad_value += + filter_data * BilinearInterpolate(input_eigen_tensor, b, + current_actual_batch, + selected_input_channel, y, x); + } + } + + offset_grad_eigen_tensor(current_actual_batch, current_offset_channel, + current_output_row, current_output_col) = + offset_grad_value; + + if (p.use_mask && is_y_direction) { + const auto current_mask_channel = + (current_offset_group * p.filter_rows + current_filter_row) * + p.filter_cols + + current_filter_col; + + mask_grad_eigen_tensor(current_actual_batch, current_mask_channel, + current_output_row, current_output_col) = + mask_grad_value; + } + } +} + +template +__global__ void DeformableCol2ImForInputKernel( + int32 b, int32 num_kernels, DeformableConv2DParams p, + typename TTypes::Tensor offset_eigen_tensor, + typename TTypes::Tensor mask_eigen_tensor, + typename TTypes::Tensor input_grad_eigen_tensor, + typename TTypes::Tensor column_buffer_tensor_flattened) { + CUDA_1D_KERNEL_LOOP(k, num_kernels) { + const auto current_output_col = k % p.output_cols; + const auto current_output_row = (k / p.output_cols) % p.output_rows; + const auto current_batch = + (k / (p.output_rows * p.output_cols)) % p.parallel_imgs; + + const auto current_filter_col = + (k / (p.output_rows * p.output_cols * p.parallel_imgs)) % p.filter_cols; + const auto current_filter_row = (k / (p.output_rows * p.output_cols * + p.parallel_imgs * p.filter_cols)) % + p.filter_rows; + const auto current_channel = + k / (p.output_rows * p.output_cols * p.parallel_imgs * p.filter_rows * + p.filter_cols); + + const auto current_offset_group = + current_channel / (p.input_channels / p.offset_groups); + + const auto mask = + p.use_mask ? 
mask_eigen_tensor(b, current_batch, current_offset_group, + current_filter_row, current_filter_col, + current_output_row, current_output_col) + : T(1); + + const auto offset_h = offset_eigen_tensor( + b, current_batch, current_offset_group, current_filter_row, + current_filter_col, 0, current_output_row, current_output_col); + const auto offset_w = offset_eigen_tensor( + b, current_batch, current_offset_group, current_filter_row, + current_filter_col, 1, current_output_row, current_output_col); + + const auto y = (current_output_row * p.stride_rows - p.padding_rows) + + current_filter_row * p.dilation_rows + offset_h; + const auto x = (current_output_col * p.stride_cols - p.padding_cols) + + current_filter_col * p.dilation_cols + offset_w; + + for (auto dy = -1; dy <= 1; dy++) { + for (auto dx = -1; dx <= 1; dx++) { + const auto current_input_row = int(y) + dy; + const auto current_input_col = int(x) + dx; + + if (p.input_rows > current_input_row && current_input_row >= 0 && + p.input_cols > current_input_col && current_input_col >= 0 && + std::abs(y - current_input_row) < 1 && + std::abs(x - current_input_col) < 1) { + const auto weight = (1.0 - std::abs(y - current_input_row)) * + (1.0 - std::abs(x - current_input_col)); + + const auto current_actual_batch = b * p.parallel_imgs + current_batch; + + auto *ptr = input_grad_eigen_tensor.data(); + + const auto ptr_pos = + ((current_actual_batch * p.input_channels + current_channel) * + p.input_rows + + current_input_row) * + p.input_cols + + current_input_col; + + GpuAtomicAdd(ptr + ptr_pos, + mask * weight * column_buffer_tensor_flattened(k)); + } + } + } + } +} + +#define IM2COL(T) \ + template <> \ + void DeformableConv2DFunctorBase::DeformableIm2Col( \ + OpKernelContext *context, int32 b) { \ + auto num_kernels = \ + p.input_channels * p.output_rows * p.output_cols * p.parallel_imgs; \ + \ + const auto offset_eigen_tensor = offset_tensor.tensor(); \ + \ + const auto mask_eigen_tensor = \ + p.use_mask ? mask_tensor.tensor() \ + : mask_tensor.shaped({0, 0, 0, 0, 0, 0, 0}); \ + \ + const auto input_eigen_tensor = input_tensor.tensor(); \ + \ + auto column_buffer_eigen_tensor = column_buffer_tensor.tensor(); \ + \ + auto device = context->template eigen_device(); \ + GpuLaunchConfig config = GetGpuLaunchConfig(num_kernels, device); \ + TF_CHECK_OK(GpuLaunchKernel(DeformableIm2ColKernel, config.block_count, \ + config.thread_per_block, 0, device.stream(), \ + b, num_kernels, p, input_eigen_tensor, \ + offset_eigen_tensor, mask_eigen_tensor, \ + column_buffer_eigen_tensor)); \ + } +TF_CALL_float(IM2COL); +TF_CALL_double(IM2COL); +#undef IM2COL + +#define COL2IM_OFFSET_AND_MASK(T) \ + template <> \ + void \ + DeformableConv2DGradFunctor::DeformableCol2ImForOffsetAndMask( \ + OpKernelContext *context, int32 b) { \ + const auto num_kernels = p.output_rows * p.output_cols * 2 * \ + p.filter_rows * p.filter_cols * p.offset_groups * \ + p.parallel_imgs; \ + \ + const auto offset_eigen_tensor = offset_tensor.template tensor(); \ + \ + const auto mask_eigen_tensor = \ + p.use_mask ? 
mask_tensor.template tensor() \ + : mask_tensor.template shaped({0, 0, 0, 0, 0, 0, 0}); \ + \ + const auto column_buffer_eigen_tensor = \ + column_buffer_tensor.template shaped( \ + {p.input_channels, p.filter_rows, p.filter_cols, p.parallel_imgs, \ + p.output_rows, p.output_cols}); \ + \ + auto offset_grad_eigen_tensor = offset_grad_tensor.tensor(); \ + auto mask_grad_eigen_tensor = mask_grad_tensor.tensor(); \ + \ + const auto input_eigen_tensor = input_tensor.tensor(); \ + \ + auto device = context->template eigen_device(); \ + GpuLaunchConfig config = GetGpuLaunchConfig(num_kernels, device); \ + TF_CHECK_OK(GpuLaunchKernel( \ + DeformableCol2ImForOffsetAndMaskKernel, config.block_count, \ + config.thread_per_block, 0, device.stream(), b, num_kernels, p, \ + input_eigen_tensor, offset_eigen_tensor, mask_eigen_tensor, \ + offset_grad_eigen_tensor, mask_grad_eigen_tensor, \ + column_buffer_eigen_tensor)); \ + } +TF_CALL_float(COL2IM_OFFSET_AND_MASK); +TF_CALL_double(COL2IM_OFFSET_AND_MASK); +#undef COL2IM_OFFSET_AND_MASK + +#define COL2IM_INPUT(T) \ + template <> \ + void DeformableConv2DGradFunctor::DeformableCol2ImForInput( \ + OpKernelContext *context, int32 b) { \ + const auto num_kernels = p.input_channels * p.filter_rows * \ + p.filter_cols * p.output_rows * p.output_cols * \ + p.parallel_imgs; \ + \ + const auto offset_eigen_tensor = offset_tensor.template tensor(); \ + \ + const auto mask_eigen_tensor = \ + p.use_mask ? mask_tensor.template tensor() \ + : mask_tensor.template shaped({0, 0, 0, 0, 0, 0, 0}); \ + \ + const auto column_buffer_tensor_flattened = \ + column_buffer_tensor.template shaped({num_kernels}); \ + \ + auto input_grad_eigen_tensor = input_grad_tensor.tensor(); \ + \ + auto device = context->template eigen_device(); \ + GpuLaunchConfig config = GetGpuLaunchConfig(num_kernels, device); \ + TF_CHECK_OK(GpuLaunchKernel( \ + DeformableCol2ImForInputKernel, config.block_count, \ + config.thread_per_block, 0, device.stream(), b, num_kernels, p, \ + offset_eigen_tensor, mask_eigen_tensor, input_grad_eigen_tensor, \ + column_buffer_tensor_flattened)); \ + } +TF_CALL_float(COL2IM_INPUT); +TF_CALL_double(COL2IM_INPUT); +#undef COL2IM_INPUT + +#define EXPLICIT_TEMPLATE(T) \ + template struct DeformableConv2DForwardFunctor; \ + template struct DeformableConv2DGradFunctor; +TF_CALL_float(EXPLICIT_TEMPLATE); +TF_CALL_double(EXPLICIT_TEMPLATE); +#undef EXPLICIT_TEMPLATE + +} // end namespace functor + +#define EXPLICIT_TEMPLATE(T) \ + template Status Transpose( \ + OpKernelContext * ctx, const Tensor &in, \ + const gtl::ArraySlice perm, Tensor *out); +TF_CALL_float(EXPLICIT_TEMPLATE); +TF_CALL_double(EXPLICIT_TEMPLATE); +#undef EXPLICIT_TEMPLATE + +} // namespace addons +} // namespace tensorflow + +#endif // GOOGLE_CUDA diff --git a/tensorflow_addons/custom_ops/layers/cc/ops/deformable_conv2d_op.cc b/tensorflow_addons/custom_ops/layers/cc/ops/deformable_conv2d_op.cc new file mode 100644 index 0000000000..0d862f3de5 --- /dev/null +++ b/tensorflow_addons/custom_ops/layers/cc/ops/deformable_conv2d_op.cc @@ -0,0 +1,247 @@ +// Copyright 2020 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +#include "tensorflow/core/framework/common_shape_fns.h" +#include "tensorflow/core/framework/kernel_shape_util.h" +#include "tensorflow/core/framework/op.h" + +namespace tensorflow { +namespace addons { + +using ::tensorflow::shape_inference::DimensionHandle; +using ::tensorflow::shape_inference::InferenceContext; +using ::tensorflow::shape_inference::ShapeHandle; + +REGISTER_OP("Addons>DeformableConv2D") + .Input("input: T") + .Input("filter: T") + .Input("bias: T") + .Input("offset: T") + .Input("mask: T") + .Output("output: T") + .Attr("strides: list(int)") + .Attr("weight_groups: int") + .Attr("offset_groups: int") + .Attr(GetPaddingAttrString()) + .Attr("dilations: list(int)") + .Attr("data_format: { 'NCHW' }") + .Attr("T: {float, double}") + .SetShapeFn([](InferenceContext *c) { + ShapeHandle input_shape; + ShapeHandle filter_shape; + ShapeHandle bias_shape; + ShapeHandle offset_shape; + ShapeHandle mask_shape; + std::vector strides; + std::vector dilations; + std::string data_format_str; + TensorFormat data_format; + int32 weight_groups; + int32 offset_groups; + Padding padding; + + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input_shape)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 4, &filter_shape)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &bias_shape)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 4, &offset_shape)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 4, &mask_shape)); + TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides)); + TF_RETURN_IF_ERROR(c->GetAttr("weight_groups", &weight_groups)); + TF_RETURN_IF_ERROR(c->GetAttr("offset_groups", &offset_groups)); + TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding)); + TF_RETURN_IF_ERROR(c->GetAttr("dilations", &dilations)); + TF_RETURN_IF_ERROR(c->GetAttr("data_format", &data_format_str)); + FormatFromString(data_format_str, &data_format); + if (strides.size() != 2 || dilations.size() != 2) { + return errors::InvalidArgument("strides/dilations size must be 2."); + } + + DimensionHandle input_batches_dim = c->Dim(input_shape, 0); + DimensionHandle input_channels_dim = c->Dim(input_shape, 1); + DimensionHandle input_rows_dim = c->Dim(input_shape, 2); + DimensionHandle input_cols_dim = c->Dim(input_shape, 3); + + DimensionHandle bias_dim = c->Dim(bias_shape, 0); + + DimensionHandle output_channels_dim = c->Dim(filter_shape, 0); + DimensionHandle filter_channels_dim = c->Dim(filter_shape, 1); + DimensionHandle filter_rows_dim = c->Dim(filter_shape, 2); + DimensionHandle filter_cols_dim = c->Dim(filter_shape, 3); + + DimensionHandle offset_batches_dim = c->Dim(offset_shape, 0); + DimensionHandle offset_channels_dim = c->Dim(offset_shape, 1); + DimensionHandle offset_heights_dim = c->Dim(offset_shape, 2); + DimensionHandle offset_weights_dim = c->Dim(offset_shape, 3); + + DimensionHandle mask_batches_dim = c->Dim(mask_shape, 0); + DimensionHandle mask_channels_dim = c->Dim(mask_shape, 1); + DimensionHandle mask_heights_dim = c->Dim(mask_shape, 2); + DimensionHandle mask_weights_dim = c->Dim(mask_shape, 3); + + bool 
use_mask = InferenceContext::Value(mask_batches_dim) != 0; + bool use_bias = InferenceContext::Value(bias_dim) != 0; + + auto input_batches = InferenceContext::Value(input_batches_dim); + auto input_rows = InferenceContext::Value(input_rows_dim); + auto input_cols = InferenceContext::Value(input_cols_dim); + + auto output_channels = InferenceContext::Value(output_channels_dim); + + auto filter_rows = InferenceContext::Value(filter_rows_dim); + auto filter_cols = InferenceContext::Value(filter_cols_dim); + + auto stride_rows = strides[0]; + auto stride_cols = strides[1]; + auto dilation_rows = dilations[0]; + auto dilation_cols = dilations[1]; + + DimensionHandle tmp; + + if (use_bias) { + TF_RETURN_IF_ERROR(c->WithValue(bias_dim, output_channels, &tmp)); + } + + TF_RETURN_IF_ERROR( + c->Divide(output_channels_dim, weight_groups, true, &tmp)); + TF_RETURN_IF_ERROR( + c->Divide(input_channels_dim, offset_groups, true, &tmp)); + + TF_RETURN_IF_ERROR(c->Multiply(filter_channels_dim, weight_groups, &tmp)); + TF_RETURN_IF_ERROR(c->Merge(input_channels_dim, tmp, &tmp)); + + TF_RETURN_IF_ERROR(c->WithValue(offset_batches_dim, input_batches, &tmp)); + + if (use_mask) { + TF_RETURN_IF_ERROR(c->WithValue(mask_batches_dim, input_batches, &tmp)); + } + + if (InferenceContext::ValueKnown(filter_rows_dim) && + InferenceContext::ValueKnown(filter_cols_dim)) { + auto filter_area = filter_rows * filter_cols * offset_groups; + TF_RETURN_IF_ERROR( + c->WithValue(offset_channels_dim, 2 * filter_area, &tmp)); + if (use_mask) { + TF_RETURN_IF_ERROR( + c->WithValue(mask_channels_dim, filter_area, &tmp)); + } + } + + if (!InferenceContext::ValueKnown(input_rows_dim) || + !InferenceContext::ValueKnown(input_cols_dim) || + !InferenceContext::ValueKnown(filter_rows_dim) || + !InferenceContext::ValueKnown(filter_cols_dim)) { + c->set_output(0, c->MakeShape({input_batches_dim, output_channels_dim, + InferenceContext::kUnknownDim, + InferenceContext::kUnknownDim})); + return Status::OK(); + } + + auto effective_filter_rows = + filter_rows + (filter_rows - 1) * (dilation_rows - 1); + auto effective_filter_cols = + filter_cols + (filter_cols - 1) * (dilation_cols - 1); + + int64 output_rows, output_cols; + int64 padding_before, padding_after; + TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerbose( + input_rows, effective_filter_rows, stride_rows, padding, &output_rows, + &padding_before, &padding_after)); + TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerbose( + input_cols, effective_filter_cols, stride_cols, padding, &output_cols, + &padding_before, &padding_after)); + + TF_RETURN_IF_ERROR(c->WithValue(offset_heights_dim, output_rows, &tmp)); + TF_RETURN_IF_ERROR(c->WithValue(offset_weights_dim, output_cols, &tmp)); + if (use_mask) { + TF_RETURN_IF_ERROR(c->WithValue(mask_heights_dim, output_rows, &tmp)); + TF_RETURN_IF_ERROR(c->WithValue(mask_weights_dim, output_cols, &tmp)); + } + + c->set_output(0, c->MakeShape({input_batches_dim, output_channels_dim, + output_rows, output_cols})); + + return Status::OK(); + }) + .Doc(R"doc(Compute Modulated Deformable Convolution. + +This layer implements the operation from +Deformable ConvNets v2: More Deformable, Better Results (Zhu et al.) + +input: A `Tensor` of the format specified by `data_format`. +filter: A `Tensor` of the convolution kernel weights. Its shape is + `(output_channel, input_channel // weight_groups, kernel_height, kernel_width)`. +bias: A `Tensor` of the convolution bias. + `(0,)`-shape `Tensor` is passed when bias is disabled on Python side.
+offset: A `Tensor` of the offsets which are applied for each position + in the convolution kernel. The channel size must be + `2 * kernel_height * kernel_width * offset_groups`. +mask: A `Tensor` of the modulation values which are applied for each position + in the convolution kernel. The channel size must be + `kernel_height * kernel_width * offset_groups` if the modulation mode is + enabled on Python side. `(0,)`-shape `Tensor` is passed when the modulation + mode is disabled on Python side. +strides: A list of 2 integers, specifying the strides of the convolution + along the height and width. +weight_groups: An integer specifying the number of groups in which the input is + split along the channel axis. Each group is convolved separately with + `filters / weight_groups` filters. The output is the concatenation of all + the groups' results along the channel axis. Input channels and output + channels must both be divisible by `weight_groups`. +offset_groups: An integer specifying the number of groups in which the input is + split along the channel axis. Each group is convolved separately with + its group offset. +padding: A string specifying the padding type. + Possible values are: + "VALID" + "SAME" +dilations: A list of 2 integers, specifying the dilation rate to use + for dilated convolution. +data_format: Specifies the data format. + Possible value is: + "NCHW" float [batch, channels, height, width] + Defaults to `"NCHW"`. +)doc"); + +REGISTER_OP("Addons>DeformableConv2DGrad") + .Input("input: T") + .Input("filter: T") + .Input("bias: T") + .Input("offset: T") + .Input("mask: T") + .Input("output_grad: T") + .Output("input_grad: T") + .Output("filter_grad: T") + .Output("bias_grad: T") + .Output("offset_grad: T") + .Output("mask_grad: T") + .Attr("strides: list(int)") + .Attr("weight_groups: int") + .Attr("offset_groups: int") + .Attr(GetPaddingAttrString()) + .Attr("dilations: list(int)") + .Attr("data_format: { 'NCHW' }") + .Attr("T: {float, double}") + .SetShapeFn([](InferenceContext *c) { + c->set_output(0, c->input(0)); + c->set_output(1, c->input(1)); + c->set_output(2, c->input(2)); + c->set_output(3, c->input(3)); + c->set_output(4, c->input(4)); + return Status::OK(); + }) + .Doc(R"doc(DeformableConv2DGrad op.)doc"); + +} // namespace addons +} // namespace tensorflow diff --git a/tensorflow_addons/layers/BUILD b/tensorflow_addons/layers/BUILD index 3bbd16779e..56f18b9eca 100644 --- a/tensorflow_addons/layers/BUILD +++ b/tensorflow_addons/layers/BUILD @@ -9,6 +9,7 @@ py_library( data = [ "//tensorflow_addons/custom_ops/layers:_correlation_cost_ops.so", "//tensorflow_addons/custom_ops/layers:_embedding_bag_ops.so", + "//tensorflow_addons/custom_ops/layers:_deformable_conv2d_ops.so", ], deps = [ "//tensorflow_addons/activations", diff --git a/tensorflow_addons/layers/__init__.py b/tensorflow_addons/layers/__init__.py index 32072f826a..27ba5d88bd 100644 --- a/tensorflow_addons/layers/__init__.py +++ b/tensorflow_addons/layers/__init__.py @@ -44,3 +44,4 @@ from tensorflow_addons.layers.stochastic_depth import StochasticDepth from tensorflow_addons.layers.noisy_dense import NoisyDense from tensorflow_addons.layers.crf import CRF +from tensorflow_addons.layers.deformable_conv2d import DeformableConv2D diff --git a/tensorflow_addons/layers/deformable_conv2d.py b/tensorflow_addons/layers/deformable_conv2d.py new file mode 100644 index 0000000000..7d248d5aac --- /dev/null +++ b/tensorflow_addons/layers/deformable_conv2d.py @@ -0,0 +1,340 @@ +# Copyright 2020 The TensorFlow Authors.
All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import typing + +import tensorflow as tf +from typeguard import typechecked +from tensorflow_addons.utils import types +from tensorflow_addons.utils.resource_loader import LazySO +import tensorflow_addons.utils.keras_utils as conv_utils + +_deformable_conv2d_ops_so = LazySO("custom_ops/layers/_deformable_conv2d_ops.so") + + +@typechecked +def _deformable_conv2d( + input_tensor: tf.Tensor, + filter_tensor: tf.Tensor, + bias_tensor: tf.Tensor, + offset_tensor: tf.Tensor, + mask_tensor: tf.Tensor, + strides: typing.Union[tuple, list], + dilations: typing.Union[tuple, list], + weight_groups: int, + offset_groups: int, + padding: str, +): + with tf.name_scope("deformable_conv2d"): + return _deformable_conv2d_ops_so.ops.addons_deformable_conv2d( + input=input_tensor, + filter=filter_tensor, + bias=bias_tensor, + offset=offset_tensor, + mask=mask_tensor, + strides=strides, + weight_groups=weight_groups, + offset_groups=offset_groups, + padding=padding, + data_format="NCHW", + dilations=dilations, + ) + + +@tf.RegisterGradient("Addons>DeformableConv2D") +def _deformable_conv2d_grad(op, grad): + input = op.inputs[0] + filter = op.inputs[1] + bias = op.inputs[2] + offset = op.inputs[3] + mask = op.inputs[4] + strides = op.get_attr("strides") + weight_groups = op.get_attr("weight_groups") + offset_groups = op.get_attr("offset_groups") + padding = op.get_attr("padding") + dilations = op.get_attr("dilations") + data_format = op.get_attr("data_format") + + data_grad = _deformable_conv2d_ops_so.ops.addons_deformable_conv2d_grad( + input, + filter, + bias, + offset, + mask, + grad, + strides=strides, + weight_groups=weight_groups, + offset_groups=offset_groups, + padding=padding, + dilations=dilations, + data_format=data_format, + ) + return data_grad + + +@tf.keras.utils.register_keras_serializable(package="Addons") +class DeformableConv2D(tf.keras.layers.Layer): + @typechecked + def __init__( + self, + filters: int, + kernel_size: typing.Union[int, tuple, list] = (3, 3), + strides: typing.Union[int, tuple, list] = (1, 1), + padding: str = "valid", + data_format: str = "channels_first", + dilation_rate: typing.Union[int, tuple, list] = (1, 1), + weight_groups: int = 1, + offset_groups: int = 1, + use_mask: bool = False, + use_bias: bool = False, + kernel_initializer: types.Initializer = None, + bias_initializer: types.Initializer = None, + kernel_regularizer: types.Regularizer = None, + bias_regularizer: types.Regularizer = None, + kernel_constraint: types.Constraint = None, + bias_constraint: types.Constraint = None, + **kwargs + ): + """Modulated Deformable Convolution Layer. + + This layer implements the operation from [Deformable ConvNets v2: More Deformable, Better Results] + (https://arxiv.org/abs/1811.11168) (Zhu et al.). + + Arguments: + filters: Integer, the dimensionality of the output space (i.e. the number of + output filters in the convolution).
+ kernel_size: An integer or tuple/list of 2 integers, specifying the height + and width of the 2D convolution window. Can be a single integer to specify + the same value for all spatial dimensions. + strides: An integer or tuple/list of 2 integers, specifying the strides of + the convolution along the height and width. Can be a single integer to + specify the same value for all spatial dimensions. Specifying any stride + value != 1 is incompatible with specifying any `dilation_rate` value != 1. + padding: one of `"valid"` or `"same"` (case-insensitive). + data_format: Specifies the data format. + Possible value is: + "channels_first" float [batch, channels, height, width] + Defaults to `"channels_first"`. + dilation_rate: An integer or tuple/list of 2 integers, specifying the + dilation rate to use for dilated convolution. Can be a single integer to + specify the same value for all spatial dimensions. + weight_groups: A positive integer specifying the number of groups in which the + input is split along the channel axis. Each group is convolved separately + with `filters / weight_groups` filters. The output is the concatenation of all + the `weight_groups` results along the channel axis. Input channels and `filters` + must both be divisible by `weight_groups`. + offset_groups: An integer specifying the number of groups in which the input is + split along the channel axis. Each group is convolved separately with + its group offset. + use_mask: Boolean, whether the layer uses a modulation input. + use_bias: Boolean, whether the layer uses a bias vector. + kernel_initializer: Initializer for the `kernel` weights matrix (see + `keras.initializers`). + bias_initializer: Initializer for the bias vector (see + `keras.initializers`). + kernel_regularizer: Regularizer function applied to the `kernel` weights + matrix (see `keras.regularizers`). + bias_regularizer: Regularizer function applied to the bias vector (see + `keras.regularizers`). + activity_regularizer: Regularizer function applied to the output of the + layer (its "activation") (see `keras.regularizers`). + kernel_constraint: Constraint function applied to the kernel matrix (see + `keras.constraints`). + bias_constraint: Constraint function applied to the bias vector (see + `keras.constraints`).
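+
+        Example (an illustrative sketch only, assuming `tensorflow` is imported
+        as `tf`; the tensor names and sizes below are arbitrary and simply
+        follow the shape rules described in the arguments above):
+
+            # Illustrative values: NCHW ("channels_first") input with
+            # 5 images, 64 channels and a 32x32 spatial size.
+            inputs = tf.random.uniform([5, 64, 32, 32])
+            # With kernel_size=(3, 3), offset_groups=1, strides=(1, 1) and
+            # padding="same", the output spatial size equals the input's, so
+            # the offset input needs 2 * 1 * 3 * 3 = 18 channels and the mask
+            # input needs 1 * 3 * 3 = 9 channels.
+            offsets = tf.zeros([5, 18, 32, 32])
+            mask = tf.ones([5, 9, 32, 32])
+            conv = DeformableConv2D(
+                filters=16, kernel_size=(3, 3), padding="same", use_mask=True
+            )
+            outputs = conv([inputs, offsets, mask])  # shape: (5, 16, 32, 32)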
+ """ + super(DeformableConv2D, self).__init__(**kwargs) + + self.filters = filters + self.kernel_size = conv_utils.normalize_tuple(kernel_size, 2, "kernel_size") + self.strides = conv_utils.normalize_tuple(strides, 2, "strides") + self.padding = conv_utils.normalize_padding(padding) + self.data_format = conv_utils.normalize_data_format(data_format) + self.dilation_rate = conv_utils.normalize_tuple( + dilation_rate, 2, "dilation_rate" + ) + self.weight_groups = weight_groups + self.offset_groups = offset_groups + self.use_mask = use_mask + self.use_bias = use_bias + self.kernel_initializer = tf.keras.initializers.get(kernel_initializer) + self.bias_initializer = tf.keras.initializers.get(bias_initializer) + self.kernel_regularizer = tf.keras.regularizers.get(kernel_regularizer) + self.bias_regularizer = tf.keras.regularizers.get(bias_regularizer) + self.kernel_constraint = tf.keras.constraints.get(kernel_constraint) + self.bias_constraint = tf.keras.constraints.get(bias_constraint) + + if self.padding == "causal": + raise ValueError("Causal padding is not supported.") + + if self.data_format != "channels_first": + raise ValueError("`channels_last` data format is not supported.") + + if self.filters % self.weight_groups != 0: + raise ValueError("filters must be divisible by weight_group.") + + self.filter_weights = None + self.filter_bias = None + self.null_mask = None + + def _validate_shapes(self, shapes): + if type(shapes) is not list: + raise ValueError("DeformableConv2D input must be list of Tensor.") + elif self.use_mask and len(shapes) != 3: + raise ValueError("DeformableConv2D input must be 3-length list of Tensor.") + elif not self.use_mask and len(shapes) != 2: + raise ValueError("DeformableConv2D input must be 2-length list of Tensor.") + + def build(self, shapes): + self._validate_shapes(shapes) + + input_shape = shapes[0] + offset_shape = shapes[1] + mask_shape = shapes[2] if self.use_mask else None + + exp_off_c = self.offset_groups * 2 * self.kernel_size[0] * self.kernel_size[1] + + off_b, off_c, off_h, off_w = offset_shape + in_b, in_c, in_h, in_w = input_shape + + out_h = conv_utils.conv_output_length( + in_h, + self.kernel_size[0], + padding=self.padding, + stride=self.strides[0], + dilation=self.dilation_rate[0], + ) + out_w = conv_utils.conv_output_length( + in_w, + self.kernel_size[1], + padding=self.padding, + stride=self.strides[1], + dilation=self.dilation_rate[1], + ) + + if off_b != in_b or off_c != exp_off_c or off_h != out_h or off_w != out_w: + raise ValueError( + "DeformableConv2D Offset shape must be [{}, {}, {}, {}].".format( + in_b, exp_off_c, out_h, out_w + ) + ) + + if mask_shape is not None: + exp_mask_c = exp_off_c // 2 + + mask_b, mask_c, mask_h, mask_w = mask_shape + + if ( + mask_b != in_b + or mask_c != exp_mask_c + or mask_h != out_h + or mask_w != out_w + ): + raise ValueError( + "DeformableConv2D Mask shape must be [{}, {}, {}, {}].".format( + in_b, exp_mask_c, out_h, out_w + ) + ) + + # Channel first + shape = ( + self.filters, + input_shape[1] // self.weight_groups, + self.kernel_size[0], + self.kernel_size[1], + ) + + self.filter_weights = self.add_weight( + name="filter", + shape=shape, + initializer=self.kernel_initializer, + regularizer=self.kernel_regularizer, + constraint=self.kernel_constraint, + trainable=True, + ) + + if self.use_bias: + self.filter_bias = self.add_weight( + name="bias", + shape=(self.filters,), + initializer=self.bias_initializer, + regularizer=self.bias_regularizer, + constraint=self.bias_constraint, + trainable=True, 
+ ) + else: + self.filter_bias = tf.zeros((0,)) + + if not self.use_mask: + self.null_mask = tf.zeros((0, 0, 0, 0)) + + self.built = True + + def compute_output_shape(self, shapes): + self._validate_shapes(shapes) + + input_shape = shapes[0] + in_b, _, in_h, in_w = input_shape + + out_h = conv_utils.conv_output_length( + in_h, + self.kernel_size[0], + padding=self.padding, + stride=self.strides[0], + dilation=self.dilation_rate[0], + ) + out_w = conv_utils.conv_output_length( + in_w, + self.kernel_size[1], + padding=self.padding, + stride=self.strides[1], + dilation=self.dilation_rate[1], + ) + + return tf.TensorShape([in_b, self.filters, out_h, out_w]) + + def call(self, inputs, **kwargs): + input_tensor = inputs[0] + offset_tensor = inputs[1] + mask_tensor = inputs[2] if self.use_mask else self.null_mask + + return _deformable_conv2d( + input_tensor=tf.convert_to_tensor(input_tensor), + filter_tensor=tf.convert_to_tensor(self.filter_weights), + bias_tensor=tf.convert_to_tensor(self.filter_bias), + offset_tensor=tf.convert_to_tensor(offset_tensor), + mask_tensor=tf.convert_to_tensor(mask_tensor), + strides=self.strides, + weight_groups=self.weight_groups, + offset_groups=self.offset_groups, + padding="SAME" if self.padding == "same" else "VALID", + dilations=self.dilation_rate, + ) + + def get_config(self): + config = { + "kernel_size": self.kernel_size, + "filters": self.filters, + "strides": self.strides, + "padding": self.padding, + "data_format": self.data_format, + "dilation_rate": self.dilation_rate, + "weight_groups": self.weight_groups, + "offset_groups": self.offset_groups, + "use_mask": self.use_mask, + "use_bias": self.use_bias, + } + base_config = super().get_config() + return {**base_config, **config} diff --git a/tensorflow_addons/layers/tests/deformable_conv2d_test.py b/tensorflow_addons/layers/tests/deformable_conv2d_test.py new file mode 100644 index 0000000000..c7dbb5d922 --- /dev/null +++ b/tensorflow_addons/layers/tests/deformable_conv2d_test.py @@ -0,0 +1,609 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + + +import pytest +import numpy as np +import tensorflow as tf +import tensorflow_addons.utils.keras_utils as conv_utils +from tensorflow_addons.layers.deformable_conv2d import ( + DeformableConv2D, + _deformable_conv2d, +) + + +def _get_padding_length( + padding, filter_size, dilation_rate, stride, input_size, output_size +): + effective_filter_size = (filter_size - 1) * dilation_rate + 1 + + pad = 0 + if padding == "same": + pad = ((output_size - 1) * stride + effective_filter_size - input_size) // 2 + + return pad + + +def _bilinear_interpolate(img, y, x): + max_height, max_width = img.shape + + if y <= -1 or max_height <= y or x <= -1 or max_width <= x: + return 0.0 + + y_low = int(np.floor(y)) + x_low = int(np.floor(x)) + y_high = y_low + 1 + w_high = x_low + 1 + + v1 = 0.0 + if y_low >= 0 and x_low >= 0: + v1 = img[y_low, x_low] + + v2 = 0.0 + if y_low >= 0 and w_high <= max_width - 1: + v2 = img[y_low, w_high] + + v3 = 0.0 + if y_high <= max_height - 1 and x_low >= 0: + v3 = img[y_high, x_low] + + v4 = 0.0 + if y_high <= max_height - 1 and w_high <= max_width - 1: + v4 = img[y_high, w_high] + + lh = y - y_low + lw = x - x_low + hh = 1 - lh + hw = 1 - lw + + w1 = hh * hw + w2 = hh * lw + w3 = lh * hw + w4 = lh * lw + + return w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4 + + +def _expected( + input_tensor, + filter_tensor, + offset_tensor, + mask_tensor, + bias, + strides, + weight_groups, + offset_groups, + padding, + dilation_rate, +): + input_tensor = input_tensor.numpy() + filter_tensor = filter_tensor.numpy() + offset_tensor = offset_tensor.numpy() + mask_tensor = mask_tensor.numpy() + bias = bias.numpy() + + padding = conv_utils.normalize_padding(padding) + + stride_rows, stride_cols = conv_utils.normalize_tuple(strides, 2, "strides") + dilation_rows, dilation_cols = conv_utils.normalize_tuple( + dilation_rate, 2, "dilation_rate" + ) + filter_rows, filter_cols = filter_tensor.shape[-2:] + + batches, input_channels, input_rows, input_cols = input_tensor.shape + output_channels = filter_tensor.shape[0] + + output_rows = conv_utils.conv_output_length( + input_rows, + filter_rows, + padding=padding, + stride=stride_rows, + dilation=dilation_rows, + ) + output_cols = conv_utils.conv_output_length( + input_cols, + filter_cols, + padding=padding, + stride=stride_cols, + dilation=dilation_cols, + ) + + padding_rows = _get_padding_length( + padding, filter_rows, dilation_rows, stride_rows, input_rows, output_rows + ) + padding_cols = _get_padding_length( + padding, filter_cols, dilation_cols, stride_cols, input_cols, output_cols + ) + + input_channels_per_offset_group = input_channels // offset_groups + + input_channels_per_weight_groups = filter_tensor.shape[1] + output_channels_per_weight_groups = output_channels // weight_groups + + output = np.zeros((batches, output_channels, output_rows, output_cols)) + + if output.size == 0: + return output + + offset_tensor = offset_tensor.reshape((batches, -1, 2, output_rows, output_cols)) + + for batch in range(batches): + for output_channel in range(output_channels): + for output_row in range(output_rows): + for output_col in range(output_cols): + for filter_row in range(filter_rows): + for filter_col in range(filter_cols): + for input_channel in range( + input_channels_per_weight_groups + ): + weight_group = ( + output_channel // output_channels_per_weight_groups + ) + current_input_channel = ( + weight_group * input_channels_per_weight_groups + + input_channel + ) + + 
offset_group = ( + current_input_channel + // input_channels_per_offset_group + ) + offset_idx = ( + offset_group * (filter_rows * filter_cols) + + filter_row * filter_cols + + filter_col + ) + + dy = offset_tensor[ + batch, offset_idx, 0, output_row, output_col + ] + dx = offset_tensor[ + batch, offset_idx, 1, output_row, output_col + ] + + mask = ( + mask_tensor[ + batch, offset_idx, output_row, output_col + ] + if mask_tensor.size > 0 + else 1 + ) + + y = ( + stride_rows * output_row + - padding_rows + + dilation_rows * filter_row + + dy + ) + x = ( + stride_cols * output_col + - padding_cols + + dilation_cols * filter_col + + dx + ) + + output[ + batch, output_channel, output_row, output_col + ] += ( + mask + * filter_tensor[ + output_channel, + input_channel, + filter_row, + filter_col, + ] + * _bilinear_interpolate( + input_tensor[ + batch, current_input_channel, :, : + ], + y, + x, + ) + ) + + if bias.size > 0: + output += bias.reshape((1, output_channels, 1, 1)) + + return output + + +@pytest.mark.with_device(["cpu", "gpu"]) +@pytest.mark.usefixtures("maybe_run_functions_eagerly") +@pytest.mark.parametrize("padding", ["same", "valid"]) +@pytest.mark.parametrize("batches", [0, 1, 2]) +def test_forward(data_format, padding, batches): + if data_format == "channels_last": + return + + input_channels = 6 + filters = 2 + weight_groups = 2 + offset_groups = 3 + + strides = (2, 1) + dilation_rate = (2, 1) + kernel_size = (3, 2) + + input_rows, input_cols = 5, 4 + filter_rows, filter_cols = kernel_size + stride_rows, stride_cols = strides + dilation_rows, dilation_cols = dilation_rate + + output_rows = conv_utils.conv_output_length( + input_rows, + filter_rows, + padding=padding, + stride=stride_rows, + dilation=dilation_rows, + ) + output_cols = conv_utils.conv_output_length( + input_cols, + filter_cols, + padding=padding, + stride=stride_cols, + dilation=dilation_cols, + ) + + offsets = offset_groups * filter_rows * filter_cols + + input_tensor = tf.random.uniform([batches, input_channels, input_rows, input_cols]) + offset_tensor = tf.random.uniform([batches, 2 * offsets, output_rows, output_cols]) + mask_tensor = tf.random.uniform([batches, offsets, output_rows, output_cols]) + + conv = DeformableConv2D( + filters=filters, + kernel_size=kernel_size, + strides=strides, + padding=padding, + dilation_rate=dilation_rate, + weight_groups=weight_groups, + offset_groups=offset_groups, + use_mask=True, + use_bias=True, + ) + + actual = conv([input_tensor, offset_tensor, mask_tensor]) + + filter_tensor = conv.filter_weights + bias = conv.filter_bias + + expected = _expected( + input_tensor, + filter_tensor, + offset_tensor, + mask_tensor, + bias, + strides, + weight_groups, + offset_groups, + padding, + dilation_rate, + ) + + np.testing.assert_allclose(actual.numpy(), expected, rtol=1e-5) + + +@pytest.mark.with_device(["cpu", "gpu"]) +@pytest.mark.usefixtures("maybe_run_functions_eagerly") +@pytest.mark.parametrize("padding", ["same", "valid"]) +@pytest.mark.parametrize("batches", [0, 1, 2]) +def test_forward_no_mask(data_format, padding, batches): + if data_format == "channels_last": + return + + input_channels = 6 + filters = 2 + weight_groups = 2 + offset_groups = 3 + + strides = (2, 1) + dilation_rate = (2, 1) + kernel_size = (3, 2) + + input_rows, input_cols = 5, 4 + filter_rows, filter_cols = kernel_size + stride_rows, stride_cols = strides + dilation_rows, dilation_cols = dilation_rate + + output_rows = conv_utils.conv_output_length( + input_rows, + filter_rows, + padding=padding, + 
stride=stride_rows, + dilation=dilation_rows, + ) + output_cols = conv_utils.conv_output_length( + input_cols, + filter_cols, + padding=padding, + stride=stride_cols, + dilation=dilation_cols, + ) + + offsets = offset_groups * filter_rows * filter_cols + + input_tensor = tf.random.uniform([batches, input_channels, input_rows, input_cols]) + offset_tensor = tf.random.uniform([batches, 2 * offsets, output_rows, output_cols]) + + conv = DeformableConv2D( + filters=filters, + kernel_size=kernel_size, + strides=strides, + padding=padding, + dilation_rate=dilation_rate, + weight_groups=weight_groups, + offset_groups=offset_groups, + use_mask=False, + use_bias=True, + ) + + actual = conv([input_tensor, offset_tensor]) + + filter_tensor = conv.filter_weights + mask_tensor = conv.null_mask + bias = conv.filter_bias + + expected = _expected( + input_tensor, + filter_tensor, + offset_tensor, + mask_tensor, + bias, + strides, + weight_groups, + offset_groups, + padding, + dilation_rate, + ) + + np.testing.assert_allclose(actual.numpy(), expected, rtol=1e-5) + + +@pytest.mark.with_device(["cpu", "gpu"]) +@pytest.mark.parametrize("padding", ["same", "valid"]) +@pytest.mark.parametrize("batches", [0, 1, 2]) +def test_gradients(data_format, padding, batches): + if data_format == "channels_last": + return + + input_channels = 6 + filters = 2 + weight_groups = 2 + offset_groups = 3 + + strides = (2, 1) + dilation_rate = (2, 1) + kernel_size = (3, 2) + + input_rows, input_cols = 5, 4 + filter_rows, filter_cols = kernel_size + stride_rows, stride_cols = strides + dilation_rows, dilation_cols = dilation_rate + + output_rows = conv_utils.conv_output_length( + input_rows, + filter_rows, + padding=padding, + stride=stride_rows, + dilation=dilation_rows, + ) + output_cols = conv_utils.conv_output_length( + input_cols, + filter_cols, + padding=padding, + stride=stride_cols, + dilation=dilation_cols, + ) + + offsets = offset_groups * filter_rows * filter_cols + + input_tensor = tf.random.uniform([batches, input_channels, input_rows, input_cols]) + offset_tensor = tf.random.uniform([batches, 2 * offsets, output_rows, output_cols]) + mask_tensor = tf.random.uniform([batches, offsets, output_rows, output_cols]) + + conv = DeformableConv2D( + filters=filters, + kernel_size=kernel_size, + strides=strides, + padding=padding, + dilation_rate=dilation_rate, + weight_groups=weight_groups, + offset_groups=offset_groups, + use_mask=True, + use_bias=True, + ) + + conv.build([tf.shape(input_tensor), tf.shape(offset_tensor), tf.shape(mask_tensor)]) + + def conv_fn(input_tensor, filter_weights, filter_bias, offset_tensor, mask_tensor): + return _deformable_conv2d( + input_tensor=tf.convert_to_tensor(input_tensor), + filter_tensor=tf.convert_to_tensor(filter_weights), + bias_tensor=tf.convert_to_tensor(filter_bias), + offset_tensor=tf.convert_to_tensor(offset_tensor), + mask_tensor=tf.convert_to_tensor(mask_tensor), + strides=conv.strides, + weight_groups=conv.weight_groups, + offset_groups=conv.offset_groups, + padding="SAME" if conv.padding == "same" else "VALID", + dilations=conv.dilation_rate, + ) + + theoretical, numerical = tf.test.compute_gradient( + conv_fn, + [ + input_tensor, + conv.filter_weights, + conv.filter_bias, + offset_tensor, + mask_tensor, + ], + ) + + np.testing.assert_allclose(theoretical[0], numerical[0], atol=1e-3) + np.testing.assert_allclose(theoretical[1], numerical[1], atol=1e-3) + np.testing.assert_allclose(theoretical[2], numerical[2], atol=1e-3) + np.testing.assert_allclose(theoretical[3], 
numerical[3], atol=1e-3) + np.testing.assert_allclose(theoretical[4], numerical[4], atol=1e-3) + + +@pytest.mark.with_device(["cpu", "gpu"]) +@pytest.mark.parametrize("padding", ["same", "valid"]) +@pytest.mark.parametrize("batches", [0, 1, 2]) +def test_gradients_no_mask(data_format, padding, batches): + if data_format == "channels_last": + return + + input_channels = 6 + filters = 2 + weight_groups = 2 + offset_groups = 3 + + strides = (2, 1) + dilation_rate = (2, 1) + kernel_size = (3, 2) + + input_rows, input_cols = 5, 4 + filter_rows, filter_cols = kernel_size + stride_rows, stride_cols = strides + dilation_rows, dilation_cols = dilation_rate + + output_rows = conv_utils.conv_output_length( + input_rows, + filter_rows, + padding=padding, + stride=stride_rows, + dilation=dilation_rows, + ) + output_cols = conv_utils.conv_output_length( + input_cols, + filter_cols, + padding=padding, + stride=stride_cols, + dilation=dilation_cols, + ) + + offsets = offset_groups * filter_rows * filter_cols + + input_tensor = tf.random.uniform([batches, input_channels, input_rows, input_cols]) + offset_tensor = tf.random.uniform([batches, 2 * offsets, output_rows, output_cols]) + + conv = DeformableConv2D( + filters=filters, + kernel_size=kernel_size, + strides=strides, + padding=padding, + dilation_rate=dilation_rate, + weight_groups=weight_groups, + offset_groups=offset_groups, + use_mask=False, + use_bias=True, + ) + + conv.build([tf.shape(input_tensor), tf.shape(offset_tensor)]) + + def conv_fn(input_tensor, filter_weights, filter_bias, offset_tensor): + return _deformable_conv2d( + input_tensor=tf.convert_to_tensor(input_tensor), + filter_tensor=tf.convert_to_tensor(filter_weights), + bias_tensor=tf.convert_to_tensor(filter_bias), + offset_tensor=tf.convert_to_tensor(offset_tensor), + mask_tensor=tf.convert_to_tensor(conv.null_mask), + strides=conv.strides, + weight_groups=conv.weight_groups, + offset_groups=conv.offset_groups, + padding="SAME" if conv.padding == "same" else "VALID", + dilations=conv.dilation_rate, + ) + + theoretical, numerical = tf.test.compute_gradient( + conv_fn, [input_tensor, conv.filter_weights, conv.filter_bias, offset_tensor] + ) + + np.testing.assert_allclose(theoretical[0], numerical[0], atol=1e-3) + np.testing.assert_allclose(theoretical[1], numerical[1], atol=1e-3) + np.testing.assert_allclose(theoretical[2], numerical[2], atol=1e-3) + np.testing.assert_allclose(theoretical[3], numerical[3], atol=1e-3) + + +@pytest.mark.with_device(["cpu", "gpu"]) +def test_keras(data_format): + if data_format == "channels_last": + return + + batches = 1 + input_channels = 6 + filters = 2 + weight_groups = 2 + offset_groups = 3 + + strides = (2, 1) + padding = "same" + dilation_rate = (2, 1) + kernel_size = (3, 2) + + input_rows, input_cols = 5, 4 + filter_rows, filter_cols = kernel_size + stride_rows, stride_cols = strides + dilation_rows, dilation_cols = dilation_rate + + output_rows = conv_utils.conv_output_length( + input_rows, + filter_rows, + padding=padding, + stride=stride_rows, + dilation=dilation_rows, + ) + output_cols = conv_utils.conv_output_length( + input_cols, + filter_cols, + padding=padding, + stride=stride_cols, + dilation=dilation_cols, + ) + + offsets = offset_groups * filter_rows * filter_cols + + input_tensor = tf.random.uniform([batches, input_channels, input_rows, input_cols]) + offset_tensor = tf.random.uniform([batches, 2 * offsets, output_rows, output_cols]) + mask_tensor = tf.random.uniform([batches, offsets, output_rows, output_cols]) + + input_a = 
tf.keras.Input([input_channels, input_rows, input_cols]) + input_b = tf.keras.Input([2 * offsets, output_rows, output_cols]) + input_c = tf.keras.Input([offsets, output_rows, output_cols]) + + conv = DeformableConv2D( + filters=filters, + kernel_size=kernel_size, + strides=strides, + padding=padding, + dilation_rate=dilation_rate, + weight_groups=weight_groups, + offset_groups=offset_groups, + use_mask=True, + use_bias=True, + ) + + expected_output_shape = tuple( + conv.compute_output_shape([input_a.shape, input_b.shape, input_c.shape]) + ) + + x = [input_a, input_b, input_c] + y = conv(x) + model = tf.keras.models.Model(x, y) + actual_output = model([input_tensor, offset_tensor, mask_tensor]) + + assert tf.keras.backend.dtype(y[0]) == "float32" + assert actual_output.shape[1:] == expected_output_shape[1:] diff --git a/tensorflow_addons/utils/keras_utils.py b/tensorflow_addons/utils/keras_utils.py index acb5f925ba..0e4a58fa01 100644 --- a/tensorflow_addons/utils/keras_utils.py +++ b/tensorflow_addons/utils/keras_utils.py @@ -68,6 +68,20 @@ def get_config(self): return {**base_config, **config} +def normalize_padding(value): + """A copy of tensorflow.python.keras.util.""" + if isinstance(value, (list, tuple)): + return value + padding = value.lower() + if padding not in {"valid", "same", "causal"}: + raise ValueError( + "The `padding` argument must be a list/tuple or one of " + '"valid", "same" (or "causal", only for `Conv1D). ' + "Received: " + str(padding) + ) + return padding + + def normalize_data_format(value): if value is None: value = tf.keras.backend.image_data_format() @@ -143,6 +157,34 @@ def normalize_tuple(value, n, name): return value_tuple +def conv_output_length(input_length, filter_size, padding, stride, dilation=1): + """Determines output length of a convolution given input length. + + A copy of tensorflow.python.keras.util. + + Arguments: + input_length: integer. + filter_size: integer. + padding: one of "same", "valid", "full", "causal" + stride: integer. + dilation: dilation rate, integer. + + Returns: + The output length (integer). + """ + if input_length is None: + return None + assert padding in {"same", "valid", "full", "causal"} + dilated_filter_size = filter_size + (filter_size - 1) * (dilation - 1) + if padding in ["same", "causal"]: + output_length = input_length + elif padding == "valid": + output_length = input_length - dilated_filter_size + 1 + elif padding == "full": + output_length = input_length + dilated_filter_size - 1 + return (output_length + stride - 1) // stride + + def _hasattr(obj, attr_name): # If possible, avoid retrieving the attribute as the object might run some # lazy computation in it.