diff --git a/tensorflow_addons/custom_ops/image/BUILD b/tensorflow_addons/custom_ops/image/BUILD
index 7002d86e4a..80c055a661 100644
--- a/tensorflow_addons/custom_ops/image/BUILD
+++ b/tensorflow_addons/custom_ops/image/BUILD
@@ -35,3 +35,16 @@ custom_op_library(
         "cc/kernels/image_projective_transform_op_gpu.cu.cc",
     ],
 )
+
+custom_op_library(
+    name = "_resampler_ops.so",
+    srcs = [
+        "cc/kernels/resampler_ops.cc",
+        "cc/kernels/resampler_ops.h",
+        "cc/ops/resampler_ops.cc",
+    ],
+    cuda_srcs = [
+        "cc/kernels/resampler_ops.h",
+        "cc/kernels/resampler_ops_gpu.cu.cc",
+    ],
+)
diff --git a/tensorflow_addons/custom_ops/image/cc/kernels/resampler_ops.cc b/tensorflow_addons/custom_ops/image/cc/kernels/resampler_ops.cc
new file mode 100644
index 0000000000..0a892330ea
--- /dev/null
+++ b/tensorflow_addons/custom_ops/image/cc/kernels/resampler_ops.cc
@@ -0,0 +1,417 @@
+// Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#define EIGEN_USE_THREADS
+
+#include <cmath>
+#include <memory>
+#include <string>
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/util/work_sharder.h"
+#include "tensorflow_addons/custom_ops/image/cc/kernels/resampler_ops.h"
+
+namespace tensorflow {
+
+namespace addons {
+
+using CPUDevice = Eigen::ThreadPoolDevice;
+using GPUDevice = Eigen::GpuDevice;
+
+namespace functor {
+
+template <typename T>
+struct Resampler2DFunctor<CPUDevice, T> {
+  void operator()(OpKernelContext* ctx, const CPUDevice& d,
+                  const T* __restrict__ data, const T* __restrict__ warp,
+                  T* __restrict__ output, const int batch_size,
+                  const int data_height, const int data_width,
+                  const int data_channels, const int num_sampling_points) {
+    const int warp_batch_stride = num_sampling_points * 2;
+    const int data_batch_stride = data_height * data_width * data_channels;
+    const int output_batch_stride = num_sampling_points * data_channels;
+    const T zero = static_cast<T>(0.0);
+    const T one = static_cast<T>(1.0);
+
+    auto resample_batches = [&](const int start, const int limit) {
+      for (int batch_id = start; batch_id < limit; ++batch_id) {
+        // Utility lambdas to access data points and set output values.
+        // The functions take care of performing the relevant pointer
+        // arithmetic, abstracting away the low level details in the
+        // main loop over samples. Note that data is stored in NHWC format.
+        auto set_output = [&](const int sample_id, const int channel,
+                              const T value) {
+          output[batch_id * output_batch_stride + sample_id * data_channels +
+                 channel] = value;
+        };
+
+        auto get_data_point = [&](const int x, const int y, const int chan) {
+          const bool point_is_in_range =
+              (x >= 0 && y >= 0 && x <= data_width - 1 &&
+               y <= data_height - 1);
+          return point_is_in_range
+                     ? data[batch_id * data_batch_stride +
+                            data_channels * (y * data_width + x) + chan]
+                     : zero;
+        };
+
+        for (int sample_id = 0; sample_id < num_sampling_points; ++sample_id) {
+          const T x = warp[batch_id * warp_batch_stride + sample_id * 2];
+          const T y = warp[batch_id * warp_batch_stride + sample_id * 2 + 1];
+          // The interpolation function:
+          // a) implicitly pads the input data with 0s (hence the unusual
+          // checks with {x,y} > -1)
+          // b) returns 0 when sampling outside the (padded) image.
+          // The effect is that the sampled signal smoothly goes to 0 outside
+          // the original input domain, rather than presenting a jump
+          // discontinuity at the image boundaries.
+          if (x > static_cast<T>(-1.0) && y > static_cast<T>(-1.0) &&
+              x < static_cast<T>(data_width) &&
+              y < static_cast<T>(data_height)) {
+            // Precompute floor (f) and ceil (c) values for x and y.
+            const int fx = std::floor(static_cast<float>(x));
+            const int fy = std::floor(static_cast<float>(y));
+            const int cx = fx + 1;
+            const int cy = fy + 1;
+            const T dx = static_cast<T>(cx) - x;
+            const T dy = static_cast<T>(cy) - y;
+
+            for (int chan = 0; chan < data_channels; ++chan) {
+              const T img_fxfy = dx * dy * get_data_point(fx, fy, chan);
+              const T img_cxcy =
+                  (one - dx) * (one - dy) * get_data_point(cx, cy, chan);
+              const T img_fxcy =
+                  dx * (one - dy) * get_data_point(fx, cy, chan);
+              const T img_cxfy =
+                  (one - dx) * dy * get_data_point(cx, fy, chan);
+              set_output(sample_id, chan,
+                         img_fxfy + img_cxcy + img_fxcy + img_cxfy);
+            }
+          } else {
+            for (int chan = 0; chan < data_channels; ++chan) {
+              set_output(sample_id, chan, zero);
+            }
+          }
+        }
+      }
+    };
+    // Rough estimate of work for each batch entry.
+    // From third_party/tensorflow/core/util/work_sharder.cc we gather that an
+    // estimate of the cost of each work unit is needed to correctly shard
+    // the workload. Shard assumes each cost unit is 1ns, minimum cost per
+    // shard being 10us.
+    const int64 cost =
+        static_cast<int64>(num_sampling_points) * data_channels * 1000;
+    auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads());
+    Shard(worker_threads.num_threads, worker_threads.workers, batch_size,
+          cost, resample_batches);
+  }
+};
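+
+// Worked example of the bilinear weights above (editorial illustration, not
+// part of the original kernel): for a query point (x, y) = (1.3, 2.6) we get
+// fx = 1, fy = 2, cx = 2, cy = 3, dx = cx - x = 0.7, dy = cy - y = 0.4.
+// The four corners are then blended as
+//   output = 0.7 * 0.4 * data(1, 2)   // dx * dy          -> floor/floor
+//          + 0.3 * 0.6 * data(2, 3)   // (1-dx) * (1-dy)  -> ceil/ceil
+//          + 0.7 * 0.6 * data(1, 3)   // dx * (1-dy)      -> floor/ceil
+//          + 0.3 * 0.4 * data(2, 2)   // (1-dx) * dy      -> ceil/floor
+// The four weights sum to 1, so sampling a constant image returns that
+// constant wherever the 2x2 neighborhood lies fully inside the image.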
+
+}  // namespace functor
+
+template <typename Device, typename T>
+class ResamplerOp : public OpKernel {
+ public:
+  explicit ResamplerOp(OpKernelConstruction* context) : OpKernel(context) {}
+
+  void Compute(OpKernelContext* ctx) override {
+    const Tensor& data = ctx->input(0);
+    const Tensor& warp = ctx->input(1);
+
+    const TensorShape& data_shape = data.shape();
+    OP_REQUIRES(ctx, data_shape.dims() == 4,
+                errors::Unimplemented(
+                    "Only bilinear interpolation is currently supported. The "
+                    "input data shape must be [batch_size, data_height, "
+                    "data_width, data_channels], but is: ",
+                    data_shape.DebugString()));
+    const TensorShape& warp_shape = warp.shape();
+    OP_REQUIRES(
+        ctx, TensorShapeUtils::IsMatrixOrHigher(warp_shape),
+        errors::InvalidArgument("warp should be at least a matrix, got shape ",
+                                warp_shape.DebugString()));
+    OP_REQUIRES(ctx, warp_shape.dim_size(warp_shape.dims() - 1) == 2,
+                errors::Unimplemented(
+                    "Only bilinear interpolation is supported, warping "
+                    "coordinates must be 2D; warp shape last entry should be "
+                    "2, but shape vector is: ",
+                    warp_shape.DebugString()));
+    OP_REQUIRES(ctx, data_shape.dim_size(0) == warp_shape.dim_size(0),
+                errors::InvalidArgument(
+                    "Batch size of data and warp tensor must be the same, but "
+                    "input shapes are: ",
+                    data_shape.DebugString(), ", ", warp_shape.DebugString()));
+    const int batch_size = data_shape.dim_size(0);
+    const int data_height = data_shape.dim_size(1);
+    const int data_width = data_shape.dim_size(2);
+    const int data_channels = data_shape.dim_size(3);
+    TensorShape output_shape = warp.shape();
+    output_shape.set_dim(output_shape.dims() - 1, data_channels);
+    const int num_sampling_points = warp.NumElements() / batch_size / 2;
+    Tensor* output = nullptr;
+    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, output_shape, &output));
+
+    // Execute kernel only for nonempty output; otherwise Eigen crashes on
+    // GPU.
+    if (num_sampling_points > 0) {
+      functor::Resampler2DFunctor<Device, T>()(
+          ctx, ctx->eigen_device<Device>(), data.flat<T>().data(),
+          warp.flat<T>().data(), output->flat<T>().data(), batch_size,
+          data_height, data_width, data_channels, num_sampling_points);
+    }
+  }
+
+ private:
+  TF_DISALLOW_COPY_AND_ASSIGN(ResamplerOp);
+};
+
+#define REGISTER(TYPE)                                                       \
+  REGISTER_KERNEL_BUILDER(                                                   \
+      Name("Addons>Resampler").Device(DEVICE_CPU).TypeConstraint<TYPE>("T"), \
+      ResamplerOp<CPUDevice, TYPE>);
+
+TF_CALL_half(REGISTER);
+TF_CALL_float(REGISTER);
+TF_CALL_double(REGISTER);
+#undef REGISTER
+
+#if GOOGLE_CUDA
+#define REGISTER(TYPE)                                                       \
+  REGISTER_KERNEL_BUILDER(                                                   \
+      Name("Addons>Resampler").Device(DEVICE_GPU).TypeConstraint<TYPE>("T"), \
+      ResamplerOp<GPUDevice, TYPE>)
+TF_CALL_float(REGISTER);
+TF_CALL_double(REGISTER);
+#undef REGISTER
+#endif  // GOOGLE_CUDA
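+
+// Shape contract of the op above, spelled out on a concrete case (editorial
+// note, not part of the original kernel): with data of shape
+// [batch_size=10, height=9, width=7, channels=5] and warp of shape
+// [10, 8, 4, 2], num_sampling_points = 10*8*4*2 / 10 / 2 = 32 per batch
+// entry, and output_shape replaces warp's trailing 2 with the channel count,
+// giving [10, 8, 4, 5].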
The " + "input data shape must be [batch_size, data_height, " + "data_width, data_channels], but is: ", + data_shape.DebugString())); + const TensorShape& warp_shape = warp.shape(); + OP_REQUIRES( + ctx, TensorShapeUtils::IsMatrixOrHigher(warp_shape), + errors::InvalidArgument("warp should be at least a matrix, got shape ", + "warp should be at least a matrix, got shape ", + warp_shape.DebugString())); + OP_REQUIRES(ctx, warp_shape.dim_size(warp_shape.dims() - 1) == 2, + errors::Unimplemented( + "Only bilinear interpolation is supported, warping " + "coordinates must be 2D; warp shape last entry should be " + "2, but shape vector is: ", + warp_shape.DebugString())); + OP_REQUIRES(ctx, data_shape.dim_size(0) == warp_shape.dim_size(0), + errors::InvalidArgument( + "Batch size of data and warp tensor must be the same, but " + "input shapes are: ", + data_shape.DebugString(), ", ", warp_shape.DebugString())); + const int batch_size = data_shape.dim_size(0); + const int data_height = data_shape.dim_size(1); + const int data_width = data_shape.dim_size(2); + const int data_channels = data_shape.dim_size(3); + TensorShape output_shape = warp.shape(); + output_shape.set_dim(output_shape.dims() - 1, data_channels); + const int num_sampling_points = warp.NumElements() / batch_size / 2; + Tensor* output = nullptr; + OP_REQUIRES_OK(ctx, ctx->allocate_output(0, output_shape, &output)); + + // Execute kernel only for nonempty output; otherwise Eigen crashes on GPU. + if (num_sampling_points > 0) { + functor::Resampler2DFunctor()( + ctx, ctx->eigen_device(), data.flat().data(), + warp.flat().data(), output->flat().data(), batch_size, + data_height, data_width, data_channels, num_sampling_points); + } + } + + private: + TF_DISALLOW_COPY_AND_ASSIGN(ResamplerOp); +}; + +#define REGISTER(TYPE) \ + REGISTER_KERNEL_BUILDER( \ + Name("Addons>Resampler").Device(DEVICE_CPU).TypeConstraint("T"), \ + ResamplerOp); + +TF_CALL_half(REGISTER); +TF_CALL_float(REGISTER); +TF_CALL_double(REGISTER); +#undef REGISTER + +#if GOOGLE_CUDA +#define REGISTER(TYPE) \ + REGISTER_KERNEL_BUILDER( \ + Name("Addons>Resampler").Device(DEVICE_GPU).TypeConstraint("T"), \ + ResamplerOp) +TF_CALL_float(REGISTER); +TF_CALL_double(REGISTER); +#undef REGISTER +#endif // GOOGLE_CUDA + +namespace functor { + +template +struct ResamplerGrad2DFunctor { + void operator()(OpKernelContext* ctx, const CPUDevice& d, + const T* __restrict__ data, const T* __restrict__ warp, + const T* __restrict__ grad_output, T* __restrict__ grad_data, + T* __restrict__ grad_warp, const int batch_size, + const int data_height, const int data_width, + const int data_channels, const int num_sampling_points) { + // Set gradients to 0, because the kernel incrementally updates the + // tensor entries by adding partial contributions. 
+
+}  // namespace functor
+
+template <typename Device, typename T>
+class ResamplerGradOp : public OpKernel {
+ public:
+  explicit ResamplerGradOp(OpKernelConstruction* context)
+      : OpKernel(context) {}
+
+  void Compute(OpKernelContext* ctx) override {
+    const Tensor& data = ctx->input(0);
+    const Tensor& warp = ctx->input(1);
+    const Tensor& grad_output = ctx->input(2);
+
+    const TensorShape& data_shape = data.shape();
+    OP_REQUIRES(ctx, data_shape.dims() == 4,
+                errors::Unimplemented(
+                    "Only bilinear interpolation is supported, the input data "
+                    "tensor must be a batch of 2d data; data shape should have "
+                    "4 entries corresponding to [batch_size, data_height, "
+                    "data_width, data_channels], but is: ",
+                    data_shape.DebugString()));
+    const int batch_size = data_shape.dim_size(0);
+    const int data_height = data_shape.dim_size(1);
+    const int data_width = data_shape.dim_size(2);
+    const int data_channels = data_shape.dim_size(3);
+    const TensorShape& warp_shape = warp.shape();
+    OP_REQUIRES(
+        ctx, TensorShapeUtils::IsMatrixOrHigher(warp_shape),
+        errors::InvalidArgument("warp should be at least a matrix, got shape ",
+                                warp_shape.DebugString()));
+    OP_REQUIRES(ctx, warp_shape.dim_size(warp_shape.dims() - 1) == 2,
+                errors::Unimplemented(
+                    "Only bilinear interpolation is supported, warping "
+                    "coordinates must be 2D; warp shape last entry should be "
+                    "2, but shape vector is: ",
+                    warp_shape.DebugString()));
+    const TensorShape& grad_output_shape = grad_output.shape();
+    TensorShape resampler_output_shape = warp.shape();
+    resampler_output_shape.set_dim(resampler_output_shape.dims() - 1,
+                                   data_channels);
+    OP_REQUIRES(ctx, grad_output_shape == resampler_output_shape,
+                errors::InvalidArgument(
+                    "grad_output shape is not consistent with data and warp "
+                    "shapes; it should be ",
+                    resampler_output_shape.DebugString(), " but is ",
+                    grad_output_shape.DebugString()));
+    const int num_sampling_points = warp.NumElements() / batch_size / 2;
+    Tensor* grad_data = nullptr;
+    Tensor* grad_warp = nullptr;
+    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, data.shape(), &grad_data));
+    OP_REQUIRES_OK(ctx, ctx->allocate_output(1, warp.shape(), &grad_warp));
+    // Execute kernel only for nonempty output; otherwise Eigen crashes on
+    // GPU.
+    if (num_sampling_points > 0) {
+      functor::ResamplerGrad2DFunctor<Device, T>()(
+          ctx, ctx->eigen_device<Device>(), data.flat<T>().data(),
+          warp.flat<T>().data(), grad_output.flat<T>().data(),
+          grad_data->flat<T>().data(), grad_warp->flat<T>().data(), batch_size,
+          data_height, data_width, data_channels, num_sampling_points);
+    }
+  }
+
+ private:
+  TF_DISALLOW_COPY_AND_ASSIGN(ResamplerGradOp);
+};
+
+#define REGISTER(TYPE)                                    \
+  REGISTER_KERNEL_BUILDER(Name("Addons>ResamplerGrad")    \
+                              .Device(DEVICE_CPU)         \
+                              .TypeConstraint<TYPE>("T"), \
+                          ResamplerGradOp<CPUDevice, TYPE>);
+
+TF_CALL_half(REGISTER);
+TF_CALL_float(REGISTER);
+TF_CALL_double(REGISTER);
+#undef REGISTER
+
+#if GOOGLE_CUDA
+#define REGISTER(TYPE)                                    \
+  REGISTER_KERNEL_BUILDER(Name("Addons>ResamplerGrad")    \
+                              .Device(DEVICE_GPU)         \
+                              .TypeConstraint<TYPE>("T"), \
+                          ResamplerGradOp<GPUDevice, TYPE>)
+// Disable half and double precision since atomicAdds are not supported
+// TF_CALL_half(REGISTER);
+// TF_CALL_double(REGISTER);
+TF_CALL_float(REGISTER);
+
+#undef REGISTER
+#endif  // GOOGLE_CUDA
+
+}  // end namespace addons
+}  // namespace tensorflow
\ No newline at end of file
diff --git a/tensorflow_addons/custom_ops/image/cc/kernels/resampler_ops.h b/tensorflow_addons/custom_ops/image/cc/kernels/resampler_ops.h
new file mode 100644
index 0000000000..0aa53f51aa
--- /dev/null
+++ b/tensorflow_addons/custom_ops/image/cc/kernels/resampler_ops.h
@@ -0,0 +1,53 @@
+// Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#ifndef TENSORFLOW_ADDONS_IMAGE_KERNELS_RESAMPLER_OPS_H_
+#define TENSORFLOW_ADDONS_IMAGE_KERNELS_RESAMPLER_OPS_H_
+
+#if PLATFORM_WINDOWS
+#define __restrict__ __restrict
+#endif
+
+#include "tensorflow/core/framework/op_kernel.h"
+
+namespace tensorflow {
+namespace addons {
+namespace functor {
+
+// Helper functor for the Resampler Op in 2D
+template <typename Device, typename T>
+struct Resampler2DFunctor {
+  void operator()(OpKernelContext* ctx, const Device& d,
+                  const T* __restrict__ data, const T* __restrict__ warp,
+                  T* __restrict__ output, const int batch_size,
+                  const int data_height, const int data_width,
+                  const int data_channels, const int num_sampling_points);
+};
+
+// Helper functor for the Resampler Gradient Op in 2D
+template <typename Device, typename T>
+struct ResamplerGrad2DFunctor {
+  void operator()(OpKernelContext* ctx, const Device& d,
+                  const T* __restrict__ data, const T* __restrict__ warp,
+                  const T* __restrict__ grad_output, T* __restrict__ grad_data,
+                  T* __restrict__ grad_warp, const int batch_size,
+                  const int data_height, const int data_width,
+                  const int data_channels, const int num_sampling_points);
+};
+
+}  // namespace functor
+}  // namespace addons
+}  // namespace tensorflow
+#endif  // TENSORFLOW_ADDONS_IMAGE_KERNELS_RESAMPLER_OPS_H_
\ No newline at end of file
diff --git a/tensorflow_addons/custom_ops/image/cc/kernels/resampler_ops_gpu.cu.cc b/tensorflow_addons/custom_ops/image/cc/kernels/resampler_ops_gpu.cu.cc
new file mode 100644
index 0000000000..cd8d49733e
--- /dev/null
+++ b/tensorflow_addons/custom_ops/image/cc/kernels/resampler_ops_gpu.cu.cc
@@ -0,0 +1,281 @@
+// Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#if GOOGLE_CUDA
+
+#define EIGEN_USE_GPU
+
+#include <stdio.h>
+
+#include <cmath>
+
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/util/gpu_kernel_helper.h"
+#include "tensorflow_addons/custom_ops/image/cc/kernels/resampler_ops.h"
+
+namespace tensorflow {
+namespace addons {
+
+using GPUDevice = Eigen::GpuDevice;
+
+namespace {
+
+#define GET_DATA_POINT(x, y)                                                 \
+  data[batch_id * data_batch_stride + data_channels * (y * data_width + x) + \
+       chan]
+
+template <typename T>
+__global__ void Resampler2DKernel(const T* __restrict__ data,
+                                  const T* __restrict__ warp,
+                                  T* __restrict__ output,
+                                  const int batch_size, const int data_height,
+                                  const int data_width,
+                                  const int data_channels,
+                                  const int num_sampling_points) {
+  const int output_data_size =
+      batch_size * num_sampling_points * data_channels;
+  CUDA_1D_KERNEL_LOOP(index, output_data_size) {
+    const int out_index = index;
+
+    // Get (idxSample, channel, point) from the index.
+    // Use this formula
+    //   index = batch_id * num_sampling_points * num_chans +
+    //           sample_id * num_chans + chan_id,
+    // with sample_id = [0, ..., num_sampling_points)
+    const int data_batch_stride = data_height * data_width * data_channels;
+    const int warp_batch_stride = num_sampling_points * 2;
+    const int output_batch_stride = num_sampling_points * data_channels;
+
+    const int batch_id = index / output_batch_stride;
+    const int index_in_batch = index % output_batch_stride;
+    const int chan = index_in_batch % data_channels;
+    const int sample_id = index_in_batch / data_channels;
+
+    // Get coords of 2D point where data will be resampled
+    const T x = warp[batch_id * warp_batch_stride + sample_id * 2];
+    const T y = warp[batch_id * warp_batch_stride + sample_id * 2 + 1];
+    const T zero = static_cast<T>(0.0);
+    const T one = static_cast<T>(1.0);
+    // The interpolation function:
+    // a) implicitly pads the input data with 0s (hence the unusual checks
+    // with {x,y} > -1)
+    // b) returns 0 when sampling outside the (padded) image.
+    // The effect is that the sampled signal smoothly goes to 0 outside
+    // the original input domain, rather than presenting a jump
+    // discontinuity at the image boundaries.
+    if (x > static_cast<T>(-1.0) && y > static_cast<T>(-1.0) &&
+        x < static_cast<T>(data_width) && y < static_cast<T>(data_height)) {
+      // Precompute floor (f) and ceil (c) values for x and y.
+      const int fx = std::floor(static_cast<float>(x));
+      const int fy = std::floor(static_cast<float>(y));
+      const int cx = fx + 1;
+      const int cy = fy + 1;
+      const T dx = static_cast<T>(cx) - x;
+      const T dy = static_cast<T>(cy) - y;
+
+      const T img_fxfy =
+          (fx >= 0 && fy >= 0) ? dx * dy * GET_DATA_POINT(fx, fy) : zero;
+
+      const T img_cxcy = (cx <= data_width - 1 && cy <= data_height - 1)
+                             ? (one - dx) * (one - dy) * GET_DATA_POINT(cx, cy)
+                             : zero;
+
+      const T img_fxcy = (fx >= 0 && cy <= data_height - 1)
+                             ? dx * (one - dy) * GET_DATA_POINT(fx, cy)
+                             : zero;
+
+      const T img_cxfy = (cx <= data_width - 1 && fy >= 0)
+                             ? (one - dx) * dy * GET_DATA_POINT(cx, fy)
+                             : zero;
+
+      output[out_index] = img_fxfy + img_cxcy + img_fxcy + img_cxfy;
+    } else {
+      output[out_index] = zero;
+    }
+  }
+}
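+
+// Worked example of the index decomposition above (editorial note, not part
+// of the original kernel): with num_sampling_points = 32 and
+// data_channels = 5, output_batch_stride = 160. A flat index of 489 then
+// yields batch_id = 489 / 160 = 3, index_in_batch = 489 % 160 = 9,
+// chan = 9 % 5 = 4 and sample_id = 9 / 5 = 1, i.e. the 4th channel of the
+// 1st sampling point in the 3rd batch entry (all zero-based).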
+
+}  // namespace
+
+namespace functor {
+
+template <typename T>
+struct Resampler2DFunctor<GPUDevice, T> {
+  void operator()(OpKernelContext* ctx, const GPUDevice& d,
+                  const T* __restrict__ data, const T* __restrict__ warp,
+                  T* __restrict__ output, const int batch_size,
+                  const int data_height, const int data_width,
+                  const int data_channels, const int num_sampling_points) {
+    const int output_data_size =
+        batch_size * num_sampling_points * data_channels;
+    GpuLaunchConfig config = GetGpuLaunchConfig(output_data_size, d);
+    TF_CHECK_OK(GpuLaunchKernel(
+        Resampler2DKernel<T>, config.block_count, config.thread_per_block, 0,
+        d.stream(), data, warp, output, batch_size, data_height, data_width,
+        data_channels, num_sampling_points));
+  }
+};
+
+// TODO(fviola): gcudacc fails at compile time with Eigen::half.
+// template struct Resampler2DFunctor<GPUDevice, Eigen::half>;
+template struct Resampler2DFunctor<GPUDevice, float>;
+template struct Resampler2DFunctor<GPUDevice, double>;
+
+}  // namespace functor
+
+namespace {
+
+#define UPDATE_GRAD_DATA_POINT(x, y, v)                                \
+  atomicAdd(grad_data + (batch_id * data_batch_stride +                \
+                         data_channels * (y * data_width + x) + chan), \
+            v)
+
+template <typename T>
+__global__ void ResamplerGrad2DKernel(
+    const T* __restrict__ data, const T* __restrict__ warp,
+    const T* __restrict__ grad_output, T* __restrict__ grad_data,
+    T* __restrict__ grad_warp, const int batch_size, const int data_height,
+    const int data_width, const int data_channels,
+    const int num_sampling_points) {
+  const int resampler_output_size =
+      batch_size * num_sampling_points * data_channels;
+  CUDA_1D_KERNEL_LOOP(index, resampler_output_size) {
+    const int out_index = index;
+
+    // Get (idxSample, channel, point) from the index.
+    // Use this formula
+    //   index = batch_id * num_sampling_points * num_chans +
+    //           sample_id * num_chans + chan_id,
+    // with sample_id = [0, ..., num_sampling_points)
+    const int data_batch_stride = data_height * data_width * data_channels;
+    const int warp_batch_stride = num_sampling_points * 2;
+    const int output_batch_stride = num_sampling_points * data_channels;
+
+    const int batch_id = index / output_batch_stride;
+    const int index_in_batch = index % output_batch_stride;
+    const int chan = index_in_batch % data_channels;
+    const int sample_id = index_in_batch / data_channels;
+
+    // Get coords of 2D point where data will be resampled
+    const int warp_id_x = batch_id * warp_batch_stride + sample_id * 2;
+    const int warp_id_y = warp_id_x + 1;
+    const T x = warp[warp_id_x];
+    const T y = warp[warp_id_y];
+    const T zero = static_cast<T>(0.0);
+    const T one = static_cast<T>(1.0);
+
+    // Get grad output
+    const T grad_output_value = grad_output[out_index];
+    // The interpolation function whose gradient this kernel implements:
+    // a) implicitly pads the input data with 0s (hence the unusual checks
+    // with {x,y} > -1)
+    // b) returns 0 when sampling outside the (padded) image.
+    // The effect is that the sampled signal smoothly goes to 0 outside
+    // the original input domain, rather than presenting a jump
+    // discontinuity at the image boundaries.
+    if (x > static_cast<T>(-1.0) && y > static_cast<T>(-1.0) &&
+        x < static_cast<T>(data_width) && y < static_cast<T>(data_height)) {
+      // Precompute floor (f) and ceil (c) values for x and y.
+      const int fx = std::floor(static_cast<float>(x));
+      const int fy = std::floor(static_cast<float>(y));
+      const int cx = fx + 1;
+      const int cy = fy + 1;
+      const T dx = static_cast<T>(cx) - x;
+      const T dy = static_cast<T>(cy) - y;
+
+      const T img_fxfy = (fx >= 0 && fy >= 0) ? GET_DATA_POINT(fx, fy) : zero;
+
+      const T img_cxcy = (cx <= data_width - 1 && cy <= data_height - 1)
+                             ? GET_DATA_POINT(cx, cy)
+                             : zero;
+
+      const T img_fxcy =
+          (fx >= 0 && cy <= data_height - 1) ? GET_DATA_POINT(fx, cy) : zero;
+
+      const T img_cxfy =
+          (cx <= data_width - 1 && fy >= 0) ? GET_DATA_POINT(cx, fy) : zero;
+
+      // Update partial gradients wrt relevant warp field entries
+      atomicAdd(grad_warp + warp_id_x,
+                grad_output_value * ((one - dy) * (img_cxcy - img_fxcy) +
+                                     dy * (img_cxfy - img_fxfy)));
+      atomicAdd(grad_warp + warp_id_y,
+                grad_output_value * ((one - dx) * (img_cxcy - img_cxfy) +
+                                     dx * (img_fxcy - img_fxfy)));
+
+      // Update partial gradients wrt sampled data
+      if (fx >= 0 && fy >= 0) {
+        UPDATE_GRAD_DATA_POINT(fx, fy, grad_output_value * dx * dy);
+      }
+      if (cx <= data_width - 1 && cy <= data_height - 1) {
+        UPDATE_GRAD_DATA_POINT(cx, cy,
+                               grad_output_value * (one - dx) * (one - dy));
+      }
+      if (fx >= 0 && cy <= data_height - 1) {
+        UPDATE_GRAD_DATA_POINT(fx, cy, grad_output_value * dx * (one - dy));
+      }
+      if (cx <= data_width - 1 && fy >= 0) {
+        UPDATE_GRAD_DATA_POINT(cx, fy, grad_output_value * (one - dx) * dy);
+      }
+    }
+  }
+}
+
+#undef GET_DATA_POINT
+#undef UPDATE_GRAD_DATA_POINT
+
+}  // namespace
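+
+// Editorial note on the atomicAdd calls above: many sampling points can hit
+// the same input pixel, and CUDA_1D_KERNEL_LOOP runs them concurrently, so
+// the scatter into grad_data and grad_warp must be atomic to avoid lost
+// updates. This is also why only the float kernel is instantiated below:
+// hardware atomicAdd on double requires compute capability >= 6.0, and the
+// half variant is similarly restricted, so those types stay CPU-only here.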
+
+namespace functor {
+
+template <typename T>
+struct ResamplerGrad2DFunctor<GPUDevice, T> {
+  void operator()(OpKernelContext* ctx, const GPUDevice& d,
+                  const T* __restrict__ data, const T* __restrict__ warp,
+                  const T* __restrict__ grad_output, T* __restrict__ grad_data,
+                  T* __restrict__ grad_warp, const int batch_size,
+                  const int data_height, const int data_width,
+                  const int data_channels, const int num_sampling_points) {
+    // Set gradients to 0, because the kernel incrementally updates the
+    // tensor entries by adding partial contributions.
+    const int grad_warp_size = batch_size * num_sampling_points * 2;
+    const int grad_data_size =
+        batch_size * data_height * data_width * data_channels;
+
+    GpuLaunchConfig config = GetGpuLaunchConfig(grad_warp_size, d);
+    TF_CHECK_OK(GpuLaunchKernel(SetZero<T>, config.block_count,
+                                config.thread_per_block, 0, d.stream(),
+                                grad_warp_size, grad_warp));
+
+    config = GetGpuLaunchConfig(grad_data_size, d);
+    TF_CHECK_OK(GpuLaunchKernel(SetZero<T>, config.block_count,
+                                config.thread_per_block, 0, d.stream(),
+                                grad_data_size, grad_data));
+
+    const int resampler_output_size =
+        batch_size * num_sampling_points * data_channels;
+    config = GetGpuLaunchConfig(resampler_output_size, d);
+    TF_CHECK_OK(GpuLaunchKernel(ResamplerGrad2DKernel<T>, config.block_count,
+                                config.thread_per_block, 0, d.stream(), data,
+                                warp, grad_output, grad_data, grad_warp,
+                                batch_size, data_height, data_width,
+                                data_channels, num_sampling_points));
+  }
+};
+
+template struct ResamplerGrad2DFunctor<GPUDevice, float>;
+
+}  // namespace functor
+}  // namespace addons
+}  // namespace tensorflow
+#endif  // GOOGLE_CUDA
\ No newline at end of file
diff --git a/tensorflow_addons/custom_ops/image/cc/ops/resampler_ops.cc b/tensorflow_addons/custom_ops/image/cc/ops/resampler_ops.cc
new file mode 100644
index 0000000000..e2837eac55
--- /dev/null
+++ b/tensorflow_addons/custom_ops/image/cc/ops/resampler_ops.cc
@@ -0,0 +1,65 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/shape_inference.h"
+
+namespace tensorflow {
+namespace addons {
+
+using shape_inference::DimensionHandle;
+using shape_inference::InferenceContext;
+using shape_inference::ShapeHandle;
+
+// --------------------------------------------------------------------------
+REGISTER_OP("Addons>Resampler")
+    .Input("data: T")
+    .Input("warp: T")
+    .Output("output: T")
+    .Attr("T: {half, float, double}")
+    .SetShapeFn([](InferenceContext* c) {
+      ShapeHandle data;
+      ShapeHandle warp;
+      TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &data));
+      TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &warp));
+
+      ShapeHandle output;  // will be warp[:-1] + [data[-1]]
+      TF_RETURN_IF_ERROR(c->Subshape(warp, 0, -1, &output));
+      TF_RETURN_IF_ERROR(
+          c->Concatenate(output, c->Vector(c->Dim(data, -1)), &output));
+
+      c->set_output(0, output);
+      return Status::OK();
+    })
+    .Doc(R"doc(Resampler op.)doc");
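+
+// Editorial example of the shape function above: for data with shape
+// [10, 9, 7, 5] and warp with shape [10, 8, 4, 2], Subshape(warp, 0, -1)
+// yields [10, 8, 4] and concatenating data's last dimension gives an
+// inferred output shape of [10, 8, 4, 5], matching what the kernel computes
+// at runtime.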
+==============================================================================*/ + +#include "tensorflow/core/framework/common_shape_fns.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/shape_inference.h" + +namespace tensorflow { +namespace addons { + +using shape_inference::DimensionHandle; +using shape_inference::InferenceContext; +using shape_inference::ShapeHandle; + +// -------------------------------------------------------------------------- +REGISTER_OP("Addons>Resampler") + .Input("data: T") + .Input("warp: T") + .Output("output: T") + .Attr("T: {half, float, double}") + .SetShapeFn([](InferenceContext* c) { + ShapeHandle data; + ShapeHandle warp; + TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &data)); + TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &warp)); + + ShapeHandle output; // will be warp[:-1] + [data[-1]] + TF_RETURN_IF_ERROR(c->Subshape(warp, 0, -1, &output)); + TF_RETURN_IF_ERROR( + c->Concatenate(output, c->Vector(c->Dim(data, -1)), &output)); + + c->set_output(0, output); + return Status::OK(); + }) + .Doc(R"doc(Resampler op.)doc"); + +// -------------------------------------------------------------------------- +REGISTER_OP("Addons>ResamplerGrad") + .Input("data: T") + .Input("warp: T") + .Input("grad_output: T") + .Output("grad_data: T") + .Output("grad_warp: T") + .Attr("T: {half, float, double}") + .SetShapeFn([](InferenceContext* c) { + c->set_output(0, c->input(0)); + c->set_output(1, c->input(1)); + return Status::OK(); + }) + .Doc(R"doc(Resampler Grad op.)doc"); + +} // namespace addon +} // namespace tensorflow \ No newline at end of file diff --git a/tensorflow_addons/image/BUILD b/tensorflow_addons/image/BUILD index 27b42df54d..c09ae77e41 100644 --- a/tensorflow_addons/image/BUILD +++ b/tensorflow_addons/image/BUILD @@ -16,11 +16,13 @@ py_library( "sparse_image_warp.py", "interpolate_spline.py", "connected_components.py", + "resampler_ops.py", ]), data = [ ":sparse_image_warp_test_data", "//tensorflow_addons/custom_ops/image:_distort_image_ops.so", "//tensorflow_addons/custom_ops/image:_image_ops.so", + "//tensorflow_addons/custom_ops/image:_resampler_ops.so", "//tensorflow_addons/utils", ], srcs_version = "PY2AND3", @@ -160,3 +162,16 @@ py_test( ":image", ], ) + +py_test( + name = "resampler_ops_test", + size = "medium", + srcs = [ + "resampler_ops_test.py", + ], + main = "resampler_ops_test.py", + srcs_version = "PY2AND3", + deps = [ + ":image", + ], +) diff --git a/tensorflow_addons/image/README.md b/tensorflow_addons/image/README.md index 6742c14792..ec75671cb0 100644 --- a/tensorflow_addons/image/README.md +++ b/tensorflow_addons/image/README.md @@ -8,6 +8,7 @@ | distance_transform_ops | @mels630 | mels630@gmail.com | | distort_image_ops | @WindQAQ | windqaq@gmail.com | | filters | @Mainak431 | mainakdutta76@gmail.com | +| resampler_ops | @autoih | ihjhuo@gmail.com | | transform_ops | @mels630 | mels630@gmail.com | | translate_ops | @sayoojbk | sayoojbk@gmail.com | @@ -22,6 +23,7 @@ | distort_image_ops | random_hsv_in_yiq | | | filters | mean_filter2d | | | filters | median_filter2d | | +| resampler_ops | resampler | | | transform_ops | angles_to_projective_transforms | | | transform_ops | compose_transforms | | | transform_ops | matrices_to_flat_transforms | | diff --git a/tensorflow_addons/image/__init__.py b/tensorflow_addons/image/__init__.py index ced3acc14f..fa804dbdd2 100644 --- a/tensorflow_addons/image/__init__.py +++ b/tensorflow_addons/image/__init__.py @@ -17,6 +17,7 @@ from __future__ 
+
+
+@tf.RegisterGradient("Addons>Resampler")
+def _resampler_grad(op, grad_output):
+    data, warp = op.inputs
+    grad_output_tensor = tf.convert_to_tensor(grad_output, name="grad_output")
+    return _resampler_ops.addons_resampler_grad(data, warp,
+                                                grad_output_tensor)
+
+
+tf.no_gradient("Addons>ResamplerGrad")
diff --git a/tensorflow_addons/image/resampler_ops_test.py b/tensorflow_addons/image/resampler_ops_test.py
new file mode 100644
index 0000000000..156cbb29e1
--- /dev/null
+++ b/tensorflow_addons/image/resampler_ops_test.py
@@ -0,0 +1,248 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Tests for resampler."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl.testing import parameterized
+
+import numpy as np
+import tensorflow as tf
+from tensorflow_addons.image import resampler_ops
+from tensorflow_addons.utils import test_utils
+
+
+def _bilinearly_interpolate(data, x, y):
+    """Performs bilinear interpolation of grid data at user defined
+    coordinates.
+
+    This interpolation function:
+    a) implicitly pads the input data with 0s.
+    b) returns 0 when sampling outside the (padded) image.
+    The effect is that the sampled signal smoothly goes to 0 outside the
+    original input domain, rather than producing a jump discontinuity at
+    the image boundaries.
+    Args:
+      data: numpy array of shape `[data_height, data_width]` containing data
+        samples assumed to be defined at the corresponding pixel coordinates.
+      x: numpy array of shape `[warp_height, warp_width]` containing
+        x coordinates at which interpolation will be performed.
+      y: numpy array of shape `[warp_height, warp_width]` containing
+        y coordinates at which interpolation will be performed.
+    Returns:
+      Numpy array of shape `[warp_height, warp_width]` containing interpolated
+      values.
+    """
+    shape = x.shape
+    x = np.asarray(x) + 1
+    y = np.asarray(y) + 1
+    data = np.pad(data, 1, "constant", constant_values=0)
+
+    x_0 = np.floor(x).astype(int)
+    x_1 = x_0 + 1
+    y_0 = np.floor(y).astype(int)
+    y_1 = y_0 + 1
+
+    x_0 = np.clip(x_0, 0, data.shape[1] - 1)
+    x_1 = np.clip(x_1, 0, data.shape[1] - 1)
+    y_0 = np.clip(y_0, 0, data.shape[0] - 1)
+    y_1 = np.clip(y_1, 0, data.shape[0] - 1)
+
+    i_a = data[y_0, x_0]
+    i_b = data[y_1, x_0]
+    i_c = data[y_0, x_1]
+    i_d = data[y_1, x_1]
+
+    w_a = (x_1 - x) * (y_1 - y)
+    w_b = (x_1 - x) * (y - y_0)
+    w_c = (x - x_0) * (y_1 - y)
+    w_d = (x - x_0) * (y - y_0)
+
+    samples = (w_a * i_a + w_b * i_b + w_c * i_c + w_d * i_d)
+    samples = samples.reshape(shape)
+
+    return samples
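+
+# Editorial note on the helper above: shifting x and y by +1 and padding the
+# data with one row/column of zeros on every side reproduces the op's
+# implicit zero padding, e.g. a query at x = -0.5 lands at 0.5 in the padded
+# image and blends the zero border with the first real column, which is
+# exactly the smooth falloff the kernel implements.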
+
+
+def _make_warp(batch_size, warp_height, warp_width, dtype):
+    """Creates batch of warping coordinates."""
+    x, y = np.meshgrid(
+        np.linspace(0, warp_width - 1, warp_width),
+        np.linspace(0, warp_height - 1, warp_height))
+    warp = np.concatenate((x.reshape([warp_height, warp_width, 1]),
+                           y.reshape([warp_height, warp_width, 1])), 2)
+    warp = np.tile(
+        warp.reshape([1, warp_height, warp_width, 2]), [batch_size, 1, 1, 1])
+    warp += np.random.randn(*warp.shape)
+    return warp.astype(dtype)
+
+
+@test_utils.run_all_in_graph_and_eager_modes
+class ResamplerTest(tf.test.TestCase, parameterized.TestCase):
+    @parameterized.named_parameters(("float32", np.float32),
+                                    ("float64", np.float64))
+    def test_op_forward_pass_gpu(self, dtype):
+        if not tf.test.is_gpu_available():
+            self.skipTest("gpu is not available.")
+        self._test_op_forward_pass(True, dtype)
+
+    @parameterized.named_parameters(("float16", np.float16),
+                                    ("float32", np.float32),
+                                    ("float64", np.float64))
+    def test_op_forward_pass_cpu(self, dtype):
+        self._test_op_forward_pass(False, dtype)
+
+    @parameterized.named_parameters(("float32", np.float32),
+                                    ("float64", np.float64))
+    def test_op_backward_pass_gpu(self, dtype):
+        if not tf.test.is_gpu_available():
+            self.skipTest("gpu is not available.")
+        self._test_op_backward_pass(True, dtype)
+
+    @parameterized.named_parameters(("float16", np.float16),
+                                    ("float32", np.float32),
+                                    ("float64", np.float64))
+    def test_op_backward_pass_cpu(self, dtype):
+        self._test_op_backward_pass(False, dtype)
+
+    def _test_op_forward_pass(self, on_gpu, dtype):
+        np.random.seed(0)
+        data_width = 7
+        data_height = 9
+        data_channels = 5
+        warp_width = 4
+        warp_height = 8
+        batch_size = 10
+
+        warp = _make_warp(batch_size, warp_height, warp_width, dtype)
+        data_shape = (batch_size, data_height, data_width, data_channels)
+        data = np.random.rand(*data_shape).astype(dtype)
+        use_gpu = on_gpu and tf.test.is_gpu_available()
+        with test_utils.device(use_gpu):
+            data_ph = tf.constant(data)
+            warp_ph = tf.constant(warp)
+            outputs = self.evaluate(
+                resampler_ops.resampler(data=data_ph, warp=warp_ph))
+            self.assertEqual(outputs.shape,
+                             (10, warp_height, warp_width, data_channels))
+
+        # Generate reference output via bilinear interpolation in numpy
+        reference_output = np.zeros_like(outputs)
+        for batch in range(batch_size):
+            for c in range(data_channels):
+                reference_output[batch, :, :, c] = _bilinearly_interpolate(
+                    data[batch, :, :, c], warp[batch, :, :, 0],
+                    warp[batch, :, :, 1])
+
+        self.assertAllCloseAccordingToType(
+            outputs, reference_output, half_rtol=5e-3, half_atol=5e-3)
+ """ + shape = x.shape + x = np.asarray(x) + 1 + y = np.asarray(y) + 1 + data = np.pad(data, 1, "constant", constant_values=0) + + x_0 = np.floor(x).astype(int) + x_1 = x_0 + 1 + y_0 = np.floor(y).astype(int) + y_1 = y_0 + 1 + + x_0 = np.clip(x_0, 0, data.shape[1] - 1) + x_1 = np.clip(x_1, 0, data.shape[1] - 1) + y_0 = np.clip(y_0, 0, data.shape[0] - 1) + y_1 = np.clip(y_1, 0, data.shape[0] - 1) + + i_a = data[y_0, x_0] + i_b = data[y_1, x_0] + i_c = data[y_0, x_1] + i_d = data[y_1, x_1] + + w_a = (x_1 - x) * (y_1 - y) + w_b = (x_1 - x) * (y - y_0) + w_c = (x - x_0) * (y_1 - y) + w_d = (x - x_0) * (y - y_0) + + samples = (w_a * i_a + w_b * i_b + w_c * i_c + w_d * i_d) + samples = samples.reshape(shape) + + return samples + + +def _make_warp(batch_size, warp_height, warp_width, dtype): + """Creates batch of warping coordinates.""" + x, y = np.meshgrid( + np.linspace(0, warp_width - 1, warp_width), + np.linspace(0, warp_height - 1, warp_height)) + warp = np.concatenate((x.reshape([warp_height, warp_width, 1]), + y.reshape([warp_height, warp_width, 1])), 2) + warp = np.tile( + warp.reshape([1, warp_height, warp_width, 2]), [batch_size, 1, 1, 1]) + warp += np.random.randn(*warp.shape) + return warp.astype(dtype) + + +@test_utils.run_all_in_graph_and_eager_modes +class ResamplerTest(tf.test.TestCase, parameterized.TestCase): + @parameterized.named_parameters(("float32", np.float32), + ("float64", np.float64)) + def test_op_forward_pass_gpu(self, dtype): + if not tf.test.is_gpu_available(): + self.skipTest("gpu is not available.") + self._test_op_forward_pass(True, dtype) + + @parameterized.named_parameters(("float16", np.float16), + ("float32", np.float32), + ("float64", np.float64)) + def test_op_forward_pass_cpu(self, dtype): + self._test_op_forward_pass(False, dtype) + + @parameterized.named_parameters(("float32", np.float32), + ("float64", np.float64)) + def test_op_backward_pass_gpu(self, dtype): + if not tf.test.is_gpu_available(): + self.skipTest("gpu is not available.") + self._test_op_backward_pass(True, dtype) + + @parameterized.named_parameters(("float16", np.float16), + ("float32", np.float32), + ("float64", np.float64)) + def test_op_backward_pass_cpu(self, dtype): + self._test_op_backward_pass(False, dtype) + + def _test_op_forward_pass(self, on_gpu, dtype): + np.random.seed(0) + data_width = 7 + data_height = 9 + data_channels = 5 + warp_width = 4 + warp_height = 8 + batch_size = 10 + + warp = _make_warp(batch_size, warp_height, warp_width, dtype) + data_shape = (batch_size, data_height, data_width, data_channels) + data = np.random.rand(*data_shape).astype(dtype) + use_gpu = on_gpu and tf.test.is_gpu_available() + with test_utils.device(use_gpu): + data_ph = tf.constant(data) + warp_ph = tf.constant(warp) + outputs = self.evaluate( + resampler_ops.resampler(data=data_ph, warp=warp_ph)) + self.assertEqual(outputs.shape, + (10, warp_height, warp_width, data_channels)) + + # Generate reference output via bilinear interpolation in numpy + reference_output = np.zeros_like(outputs) + for batch in range(batch_size): + for c in range(data_channels): + reference_output[batch, :, :, c] = _bilinearly_interpolate( + data[batch, :, :, c], warp[batch, :, :, 0], + warp[batch, :, :, 1]) + + self.assertAllCloseAccordingToType( + outputs, reference_output, half_rtol=5e-3, half_atol=5e-3) + + def _test_op_backward_pass(self, on_gpu, dtype): + np.random.seed(13) + data_width = 5 + data_height = 4 + data_channels = 3 + warp_width = 2 + warp_height = 6 + batch_size = 3 + + warp = 
+
+    def test_op_errors(self):
+        batch_size = 10
+        data_height = 9
+        data_width = 7
+        data_depth = 3
+        data_channels = 5
+        warp_width = 4
+        warp_height = 8
+
+        # Input data shape is not defined over a 2D grid, i.e. its shape is
+        # not like (batch_size, data_height, data_width, data_channels).
+        data_shape = (batch_size, data_height, data_width, data_depth,
+                      data_channels)
+        data = np.zeros(data_shape)
+        warp_shape = (batch_size, warp_height, warp_width, 2)
+        warp = np.zeros(warp_shape)
+
+        # pylint: disable=bad-continuation
+        with self.assertRaisesRegexp(
+                tf.errors.UnimplementedError,
+                "Only bilinear interpolation is currently supported."):
+            # pylint: enable=bad-continuation
+            self.evaluate(resampler_ops.resampler(data, warp))
+
+        # Warp tensor must be at least a matrix, with shape [batch_size, 2].
+        data_shape = (batch_size, data_height, data_width, data_channels)
+        data = np.zeros(data_shape)
+        warp_shape = (batch_size,)
+        warp = np.zeros(warp_shape)
+
+        with self.assertRaisesRegexp(tf.errors.InvalidArgumentError,
+                                     "warp should be at least a matrix"):
+            self.evaluate(resampler_ops.resampler(data, warp))
+
+        # The batch size of the data and warp tensors must be the same.
+        data_shape = (batch_size, data_height, data_width, data_channels)
+        data = np.zeros(data_shape)
+        warp_shape = (batch_size + 1, warp_height, warp_width, 2)
+        warp = np.zeros(warp_shape)
+
+        with self.assertRaisesRegexp(tf.errors.InvalidArgumentError,
+                                     "Batch size of data and warp tensor"):
+            self.evaluate(resampler_ops.resampler(data, warp))
+
+        # The warp tensor must contain 2D coordinates, i.e. its shape last
+        # dimension must be 2.
+        data_shape = (batch_size, data_height, data_width, data_channels)
+        data = np.zeros(data_shape)
+        warp_shape = (batch_size, warp_height, warp_width, 3)
+        warp = np.zeros(warp_shape)
+
+        # pylint: disable=bad-continuation
+        with self.assertRaisesRegexp(
+                tf.errors.UnimplementedError,
+                "Only bilinear interpolation is supported, warping"):
+            # pylint: enable=bad-continuation
+            self.evaluate(resampler_ops.resampler(data, warp))
+
+
+if __name__ == "__main__":
+    tf.test.main()
diff --git a/tensorflow_addons/image/sparse_image_warp.py b/tensorflow_addons/image/sparse_image_warp.py
index b1697eab0f..5a0cbe42bc 100644
--- a/tensorflow_addons/image/sparse_image_warp.py
+++ b/tensorflow_addons/image/sparse_image_warp.py
@@ -187,7 +187,7 @@ def sparse_image_warp(image,
         dest_control_point_locations, control_point_flows, image_height,
         image_width, boundary_points_per_edge)
-    flattened_flows = interpolate_spline.interpolate_spline(
+    flattened_flows = interpolate_spline(
         dest_control_point_locations, control_point_flows,
         flattened_grid_locations, interpolation_order, regularization_weight)