diff --git a/tensorflow_addons/image/BUILD b/tensorflow_addons/image/BUILD index 36b8e1e022..852279257f 100644 --- a/tensorflow_addons/image/BUILD +++ b/tensorflow_addons/image/BUILD @@ -6,6 +6,7 @@ py_library( name = "image", srcs = ([ "__init__.py", + "dense_image_warp.py", "distort_image_ops.py", "transform_ops.py", ]), @@ -17,6 +18,19 @@ py_library( srcs_version = "PY2AND3", ) +py_test( + name = "dense_image_warp_test", + size = "small", + srcs = [ + "dense_image_warp_test.py", + ], + main = "dense_image_warp_test.py", + srcs_version = "PY2AND3", + deps = [ + ":image", + ], +) + py_test( name = "distort_image_ops_test", size = "small", diff --git a/tensorflow_addons/image/README.md b/tensorflow_addons/image/README.md index a07d608dee..ae1049dbb2 100644 --- a/tensorflow_addons/image/README.md +++ b/tensorflow_addons/image/README.md @@ -3,12 +3,15 @@ ## Maintainers | Submodule | Maintainers | Contact Info | |:---------- |:----------- |:--------------| +| dense_image_warp | | | | distort_image_ops | | | | transform_ops | | | ## Components | Submodule | Image Processing Function | Reference | |:---------- |:----------- |:----------- | +| dense_image_warp | dense_image_warp | | +| dense_image_warp | interpolate_bilinear | | | distort_image_ops | adjust_hsv_in_yiq | | | distort_image_ops | random_hsv_in_yiq | | | transform_ops | angles_to_projective_transforms | | diff --git a/tensorflow_addons/image/__init__.py b/tensorflow_addons/image/__init__.py index eab041cffe..280981ddd3 100644 --- a/tensorflow_addons/image/__init__.py +++ b/tensorflow_addons/image/__init__.py @@ -17,6 +17,8 @@ from __future__ import division from __future__ import print_function +from tensorflow_addons.image.dense_image_warp import dense_image_warp +from tensorflow_addons.image.dense_image_warp import interpolate_bilinear from tensorflow_addons.image.distort_image_ops import adjust_hsv_in_yiq from tensorflow_addons.image.distort_image_ops import random_hsv_in_yiq from 
tensorflow_addons.image.transform_ops import rotate
diff --git a/tensorflow_addons/image/dense_image_warp.py b/tensorflow_addons/image/dense_image_warp.py
new file mode 100644
index 0000000000..dcffc72401
--- /dev/null
+++ b/tensorflow_addons/image/dense_image_warp.py
@@ -0,0 +1,223 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Image warping using per-pixel flow vectors."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+import tensorflow as tf
+
+
+@tf.function
+def interpolate_bilinear(grid,
+                         query_points,
+                         name="interpolate_bilinear",
+                         indexing="ij"):
+    """Similar to Matlab's interp2 function.
+
+    Finds values for query points on a grid using bilinear interpolation.
+    Query points that fall outside the grid are clamped, so they take the
+    value interpolated from the nearest pixels at the grid boundary.
+
+    Args:
+      grid: a 4-D float `Tensor` of shape `[batch, height, width, channels]`.
+      query_points: a 3-D float `Tensor` of N points with shape
+        `[batch, N, 2]`.
+      name: a name for the operation (optional).
+      indexing: whether the query points are specified as row and column (ij),
+        or Cartesian coordinates (xy).
+
+    Returns:
+      values: a 3-D `Tensor` with shape `[batch, N, channels]`
+
+    Raises:
+      ValueError: if the indexing mode is invalid, or if `grid` is not
+        4-dimensional.
+      tf.errors.InvalidArgumentError: if `query_points` is not 3-D with size
+        2 in its last dimension, if the grid height or width is less than
+        2, or if the grid is too large for int32 linearized addresses.
+    """
+    if indexing != "ij" and indexing != "xy":
+        raise ValueError("Indexing mode must be 'ij' or 'xy'")
+
+    with tf.name_scope(name):
+        grid = tf.convert_to_tensor(grid)
+        query_points = tf.convert_to_tensor(query_points)
+        static_shape = grid.get_shape().as_list()
+        if len(static_shape) != 4:
+            msg = "Grid must be 4 dimensional. Received size: "
+            raise ValueError(msg + str(grid.get_shape()))
+
+        batch_size, height, width, channels = (tf.shape(grid)[0],
+                                               tf.shape(grid)[1],
+                                               tf.shape(grid)[2],
+                                               tf.shape(grid)[3])
+
+        shape = [batch_size, height, width, channels]
+        query_type = query_points.dtype
+        grid_type = grid.dtype
+
+        tf.debugging.assert_equal(
+            len(query_points.get_shape()),
+            3,
+            message="Query points must be 3 dimensional.")
+        tf.debugging.assert_equal(
+            tf.shape(query_points)[2],
+            2,
+            message="Query points must be size 2 in dim 2.")
+
+        num_queries = tf.shape(query_points)[1]
+
+        tf.debugging.assert_greater_equal(
+            height, 2, message="Grid height must be at least 2.")
+        tf.debugging.assert_greater_equal(
+            width, 2, message="Grid width must be at least 2.")
+
+        alphas = []
+        floors = []
+        ceils = []
+        index_order = [0, 1] if indexing == "ij" else [1, 0]
+        unstacked_query_points = tf.unstack(query_points, axis=2)
+
+        # index_order lists the query components in (row, column) order, so
+        # its i-th entry is bounded by grid axis i + 1 in both indexing modes.
+        for i, dim in enumerate(index_order):
+            with tf.name_scope("dim-" + str(dim)):
+                queries = unstacked_query_points[dim]
+
+                size_in_indexing_dimension = shape[i + 1]
+
+                # max_floor is size_in_indexing_dimension - 2 so that max_floor + 1
+                # is still a valid index into the grid.
+                max_floor = tf.cast(size_in_indexing_dimension - 2, query_type)
+                min_floor = tf.constant(0.0, dtype=query_type)
+                floor = tf.math.minimum(
+                    tf.math.maximum(min_floor, tf.math.floor(queries)),
+                    max_floor)
+                int_floor = tf.cast(floor, tf.dtypes.int32)
+                floors.append(int_floor)
+                ceil = int_floor + 1
+                ceils.append(ceil)
+
+                # alpha has the same type as the grid, as we will directly use alpha
+                # when taking linear combinations of pixel values from the image.
+                alpha = tf.cast(queries - floor, grid_type)
+                min_alpha = tf.constant(0.0, dtype=grid_type)
+                max_alpha = tf.constant(1.0, dtype=grid_type)
+                alpha = tf.math.minimum(
+                    tf.math.maximum(min_alpha, alpha), max_alpha)
+
+                # Expand alpha to [b, n, 1] so we can use broadcasting
+                # (since the alpha values don't depend on the channel).
+                alpha = tf.expand_dims(alpha, 2)
+                alphas.append(alpha)
+
+        # Form the element count in float32: the int32 product of the shape
+        # values could itself wrap around before being compared to the limit.
+        num_grid_points = (tf.cast(batch_size, tf.dtypes.float32) *
+                           tf.cast(height, tf.dtypes.float32) *
+                           tf.cast(width, tf.dtypes.float32))
+        tf.debugging.assert_less_equal(
+            num_grid_points,
+            np.iinfo(np.int32).max / 8.0,
+            message="The image size or batch size is sufficiently large "
+            "that the linearized addresses used by tf.gather "
+            "may exceed the int32 limit.")
+        flattened_grid = tf.reshape(grid,
+                                    [batch_size * height * width, channels])
+        batch_offsets = tf.reshape(
+            tf.range(batch_size) * height * width, [batch_size, 1])
+
+        # This wraps tf.gather. We reshape the image data such that the
+        # batch, y, and x coordinates are pulled into the first dimension.
+        # Then we gather. Finally, we reshape the output back. It's possible this
+        # code would be made simpler by using tf.gather_nd.
+        def gather(y_coords, x_coords, name):
+            with tf.name_scope("gather-" + name):
+                linear_coordinates = (
+                    batch_offsets + y_coords * width + x_coords)
+                gathered_values = tf.gather(flattened_grid, linear_coordinates)
+                return tf.reshape(gathered_values,
+                                  [batch_size, num_queries, channels])
+
+        # grab the pixel values in the 4 corners around each query point
+        top_left = gather(floors[0], floors[1], "top_left")
+        top_right = gather(floors[0], ceils[1], "top_right")
+        bottom_left = gather(ceils[0], floors[1], "bottom_left")
+        bottom_right = gather(ceils[0], ceils[1], "bottom_right")
+
+        # now, do the actual interpolation
+        with tf.name_scope("interpolate"):
+            interp_top = alphas[1] * (top_right - top_left) + top_left
+            interp_bottom = alphas[1] * (
+                bottom_right - bottom_left) + bottom_left
+            interp = alphas[0] * (interp_bottom - interp_top) + interp_top
+
+        return interp
+
+
+@tf.function
+def dense_image_warp(image, flow, name="dense_image_warp"):
+    """Image warping using per-pixel flow vectors.
+
+    Apply a non-linear warp to the image, where the warp is specified by a
+    dense flow field of offset vectors that define the correspondences of
+    pixel values in the output image back to locations in the source image.
+    Specifically, the pixel value at output[b, j, i, c] is
+    image[b, j - flow[b, j, i, 0], i - flow[b, j, i, 1], c].
+
+    The locations specified by this formula do not necessarily map to an int
+    index. Therefore, the pixel value is obtained by bilinear
+    interpolation of the 4 nearest pixels around
+    (b, j - flow[b, j, i, 0], i - flow[b, j, i, 1]). For locations outside
+    of the image, we use the nearest pixel values at the image boundary.
+
+    Note that image and flow can be of type tf.half, tf.float32, or
+    tf.float64, and do not necessarily have to be the same type.
+
+    Args:
+      image: 4-D float `Tensor` with shape `[batch, height, width, channels]`.
+      flow: A 4-D float `Tensor` with shape `[batch, height, width, 2]`.
+      name: A name for the operation (optional).
+
+    Returns:
+      A 4-D float `Tensor` with shape `[batch, height, width, channels]`
+      and same type as input image.
+
+    Raises:
+      ValueError: if the inputs have the wrong number of dimensions.
+      tf.errors.InvalidArgumentError: if height < 2 or width < 2.
+    """
+    with tf.name_scope(name):
+        batch_size, height, width, channels = (tf.shape(image)[0],
+                                               tf.shape(image)[1],
+                                               tf.shape(image)[2],
+                                               tf.shape(image)[3])
+
+        # The flow is defined on the image grid. Turn the flow into a list of
+        # query points in the grid space.
+        grid_x, grid_y = tf.meshgrid(tf.range(width), tf.range(height))
+        stacked_grid = tf.cast(tf.stack([grid_y, grid_x], axis=2), flow.dtype)
+        batched_grid = tf.expand_dims(stacked_grid, axis=0)
+        query_points_on_grid = batched_grid - flow
+        query_points_flattened = tf.reshape(query_points_on_grid,
+                                            [batch_size, height * width, 2])
+        # Compute values at the query points, then reshape the result back to
+        # the image grid.
+        interpolated = interpolate_bilinear(image, query_points_flattened)
+        interpolated = tf.reshape(interpolated,
+                                  [batch_size, height, width, channels])
+        return interpolated
diff --git a/tensorflow_addons/image/dense_image_warp_test.py b/tensorflow_addons/image/dense_image_warp_test.py
new file mode 100644
index 0000000000..6810d6a936
--- /dev/null
+++ b/tensorflow_addons/image/dense_image_warp_test.py
@@ -0,0 +1,227 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for dense_image_warp."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import math
+
+import numpy as np
+import tensorflow as tf
+
+from tensorflow_addons.image import dense_image_warp
+from tensorflow_addons.image import interpolate_bilinear
+from tensorflow_addons.utils import test_utils
+
+
+class DenseImageWarpTest(tf.test.TestCase):
+    """Tests for interpolate_bilinear and dense_image_warp."""
+
+    def setUp(self):
+        # Run the base-class setUp first so its per-test initialisation is
+        # not skipped, then fix the NumPy seed used by the random inputs.
+        super(DenseImageWarpTest, self).setUp()
+        np.random.seed(0)
+
+    @test_utils.run_in_graph_and_eager_modes
+    def test_interpolate_small_grid_ij(self):
+        """Check hand-computed values on a 3x3 grid with ij indexing."""
+        grid = tf.constant([[0., 1., 2.], [3., 4., 5.], [6., 7., 8.]],
+                           shape=[1, 3, 3, 1])
+        query_points = tf.constant([[0., 0.], [1., 0.], [2., 0.5], [1.5, 1.5]],
+                                   shape=[1, 4, 2])
+        expected_results = np.reshape(np.array([0., 3., 6.5, 6.]), [1, 4, 1])
+
+        interp = interpolate_bilinear(grid, query_points)
+
+        self.assertAllClose(expected_results, interp)
+
+    @test_utils.run_in_graph_and_eager_modes
+    def test_interpolate_small_grid_xy(self):
+        """Check the same 3x3 grid values with xy-ordered query points."""
+        grid = tf.constant([[0., 1., 2.], [3., 4., 5.], [6., 7., 8.]],
+                           shape=[1, 3, 3, 1])
+        query_points = tf.constant(
+            [[0., 0.], [0., 1.], [0.5, 2.0], [1.5, 1.5]], shape=[1, 4, 2])
+        expected_results = np.reshape(np.array([0., 3., 6.5, 6.]), [1, 4, 1])
+
+        interp = interpolate_bilinear(grid, query_points, indexing="xy")
+
+        self.assertAllClose(expected_results, interp)
+
+    @test_utils.run_in_graph_and_eager_modes
+    def test_interpolate_small_grid_batched(self):
+        """Check hand-computed values for a batch of two 2x2 grids."""
+        grid = tf.constant([[[0., 1.], [3., 4.]], [[5., 6.], [7., 8.]]],
+                           shape=[2, 2, 2, 1])
+        query_points = tf.constant([[[0., 0.], [1., 0.], [0.5, 0.5]],
+                                    [[0.5, 0.], [1., 0.], [1., 1.]]])
+        expected_results = np.reshape(
+            np.array([[0., 3., 2.], [6., 7., 8.]]), [2, 3, 1])
+
+        interp = interpolate_bilinear(grid, query_points)
+
+        self.assertAllClose(expected_results, interp)
+
+    def _get_random_image_and_flows(self, shape, image_type, flow_type):
+        """Return a random normal image and flow field as NumPy arrays."""
+        batch_size, height, width, num_channels = shape
+        image_shape = [batch_size, height, width, num_channels]
+        image = np.random.normal(size=image_shape)
+        flow_shape = [batch_size, height, width, 2]
+        flows = np.random.normal(size=flow_shape) * 3
+        return image.astype(image_type), flows.astype(flow_type)
+
+    def _assert_correct_interpolation_value(
+            self, image, flows, pred_interpolation, batch_index, y_index,
+            x_index, low_precision=False):
+        """Assert that the tf interpolation matches hand-computed value."""
+        height = image.shape[1]
+        width = image.shape[2]
+        displacement = flows[batch_index, y_index, x_index, :]
+        float_y = y_index - displacement[0]
+        float_x = x_index - displacement[1]
+        floor_y = max(min(height - 2, math.floor(float_y)), 0)
+        floor_x = max(min(width - 2, math.floor(float_x)), 0)
+        ceil_y = floor_y + 1
+        ceil_x = floor_x + 1
+
+        alpha_y = min(max(0.0, float_y - floor_y), 1.0)
+        alpha_x = min(max(0.0, float_x - floor_x), 1.0)
+
+        floor_y = int(floor_y)
+        floor_x = int(floor_x)
+        ceil_y = int(ceil_y)
+        ceil_x = int(ceil_x)
+
+        top_left = image[batch_index, floor_y, floor_x, :]
+        top_right = image[batch_index, floor_y, ceil_x, :]
+        bottom_left = image[batch_index, ceil_y, floor_x, :]
+        bottom_right = image[batch_index, ceil_y, ceil_x, :]
+
+        interp_top = alpha_x * (top_right - top_left) + top_left
+        interp_bottom = alpha_x * (bottom_right - bottom_left) + bottom_left
+        interp = alpha_y * (interp_bottom - interp_top) + interp_top
+        atol = 1e-6
+        rtol = 1e-6
+        if low_precision:
+            atol = 1e-2
+            rtol = 1e-3
+        self.assertAllClose(
+            interp,
+            pred_interpolation[batch_index, y_index, x_index, :],
+            atol=atol,
+            rtol=rtol)
+
+    def _check_zero_flow_correctness(self, shape, image_type, flow_type):
+        """Assert using zero flows doesn't change the input image."""
+        rand_image, rand_flows = self._get_random_image_and_flows(
+            shape, image_type, flow_type)
+        rand_flows *= 0
+
+        interp = dense_image_warp(
+            image=tf.convert_to_tensor(rand_image),
+            flow=tf.convert_to_tensor(rand_flows))
+
+        self.assertAllClose(rand_image, interp)
+
+    @test_utils.run_in_graph_and_eager_modes
+    def test_zero_flows(self):
+        """Apply _check_zero_flow_correctness() for a few sizes and types."""
+        shapes_to_try = [[3, 4, 5, 6], [1, 2, 2, 1]]
+        for shape in shapes_to_try:
+            self._check_zero_flow_correctness(
+                shape, image_type="float32", flow_type="float32")
+
+    def _check_interpolation_correctness(self, shape, image_type, flow_type,
+                                         num_probes=5):
+        """Interpolate, then assert correctness at a few random locations."""
+        low_precision = image_type == "float16" or flow_type == "float16"
+        rand_image, rand_flows = self._get_random_image_and_flows(
+            shape, image_type, flow_type)
+
+        interp = dense_image_warp(
+            image=tf.convert_to_tensor(rand_image),
+            flow=tf.convert_to_tensor(rand_flows))
+
+        for _ in range(num_probes):
+            batch_index = np.random.randint(0, shape[0])
+            y_index = np.random.randint(0, shape[1])
+            x_index = np.random.randint(0, shape[2])
+
+            self._assert_correct_interpolation_value(
+                rand_image,
+                rand_flows,
+                interp,
+                batch_index,
+                y_index,
+                x_index,
+                low_precision=low_precision)
+
+    @test_utils.run_in_graph_and_eager_modes
+    def test_interpolation(self):
+        """Apply _check_interpolation_correctness() for a few sizes and
+        types."""
+        shapes_to_try = [[3, 4, 5, 6], [1, 5, 5, 3], [1, 2, 2, 1]]
+        for im_type in ["float32", "float64", "float16"]:
+            for flow_type in ["float32", "float64", "float16"]:
+                for shape in shapes_to_try:
+                    self._check_interpolation_correctness(
+                        shape, im_type, flow_type)
+
+    # TODO: switch to TF2 later.
+    @test_utils.run_deprecated_v1
+    def test_gradients_exist(self):
+        """Check that backprop can run.
+
+        The correctness of the gradients is assumed, since the forward
+        propagation is tested to be correct and we only use built-in tf
+        ops. However, we perform a simple test to make sure that
+        backprop can actually run. We treat the flows as a tf.Variable
+        and optimize them to minimize the difference between the
+        interpolated image and the input image.
+        """
+        batch_size, height, width, num_channels = [4, 5, 6, 7]
+        image_shape = [batch_size, height, width, num_channels]
+        image = tf.random.normal(image_shape)
+        flow_shape = [batch_size, height, width, 2]
+        init_flows = np.float32(np.random.normal(size=flow_shape) * 0.25)
+        flows = tf.Variable(init_flows)
+
+        interp = dense_image_warp(image, flows)
+        loss = tf.math.reduce_mean(tf.math.square(interp - image))
+
+        optimizer = tf.optimizers.Adam(1.0)
+        grad = tf.gradients(loss, [flows])
+        opt_func = optimizer.apply_gradients(zip(grad, [flows]))
+        init_op = tf.compat.v1.global_variables_initializer()
+
+        with self.cached_session() as sess:
+            sess.run(init_op)
+            for _ in range(10):
+                sess.run(opt_func)
+
+    # TODO: run in both graph and eager modes
+    def test_size_exception(self):
+        """Make sure it throws an exception for images that are too small."""
+        shape = [1, 2, 1, 1]
+        with self.assertRaisesRegexp(tf.errors.InvalidArgumentError,
+                                     "Grid width must be at least 2."):
+            self._check_interpolation_correctness(shape, "float32", "float32")
+
+
+if __name__ == "__main__":
+    tf.test.main()