diff --git a/tensorflow_addons/image/BUILD b/tensorflow_addons/image/BUILD index 52b68cf8f6..79f5b1fcd5 100644 --- a/tensorflow_addons/image/BUILD +++ b/tensorflow_addons/image/BUILD @@ -13,8 +13,11 @@ py_library( "transform_ops.py", "translate_ops.py", "utils.py", + "sparse_image_warp.py", + "interpolate_spline.py", ]), data = [ + ":sparse_image_warp_test_data", "//tensorflow_addons/custom_ops/image:_distort_image_ops.so", "//tensorflow_addons/custom_ops/image:_image_ops.so", "//tensorflow_addons/utils", @@ -22,6 +25,11 @@ py_library( srcs_version = "PY2AND3", ) +filegroup( + name = "sparse_image_warp_test_data", + srcs = glob(["test_data/*.png"]), +) + py_test( name = "dense_image_warp_test", size = "small", @@ -113,3 +121,29 @@ py_test( ":image", ], ) + +py_test( + name = "sparse_image_warp_test", + size = "medium", + srcs = [ + "sparse_image_warp_test.py", + ], + main = "sparse_image_warp_test.py", + srcs_version = "PY2AND3", + deps = [ + ":image", + ], +) + +py_test( + name = "interpolate_spline_test", + size = "medium", + srcs = [ + "interpolate_spline_test.py", + ], + main = "interpolate_spline_test.py", + srcs_version = "PY2AND3", + deps = [ + ":image", + ], +) diff --git a/tensorflow_addons/image/__init__.py b/tensorflow_addons/image/__init__.py index c2ec059e1a..d0d886735d 100644 --- a/tensorflow_addons/image/__init__.py +++ b/tensorflow_addons/image/__init__.py @@ -26,4 +26,6 @@ from tensorflow_addons.image.filters import median_filter2d from tensorflow_addons.image.transform_ops import rotate from tensorflow_addons.image.transform_ops import transform +from tensorflow_addons.image.sparse_image_warp import sparse_image_warp +from tensorflow_addons.image.interpolate_spline import interpolate_spline from tensorflow_addons.image.translate_ops import translate diff --git a/tensorflow_addons/image/interpolate_spline.py b/tensorflow_addons/image/interpolate_spline.py new file mode 100644 index 0000000000..84abf98209 --- /dev/null +++ b/tensorflow_addons/image/interpolate_spline.py @@ -0,0 +1,303 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Polyharmonic spline interpolation.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf + +EPSILON = 0.0000000001 + + +def _cross_squared_distance_matrix(x, y): + """Pairwise squared distance between two (batch) matrices' rows (2nd dim). + + Computes the pairwise distances between rows of x and rows of y + Args: + x: [batch_size, n, d] float `Tensor` + y: [batch_size, m, d] float `Tensor` + + Returns: + squared_dists: [batch_size, n, m] float `Tensor`, where + squared_dists[b,i,j] = ||x[b,i,:] - y[b,j,:]||^2 + """ + x_norm_squared = tf.reduce_sum(tf.square(x), 2) + y_norm_squared = tf.reduce_sum(tf.square(y), 2) + + # Expand so that we can broadcast. + x_norm_squared_tile = tf.expand_dims(x_norm_squared, 2) + y_norm_squared_tile = tf.expand_dims(y_norm_squared, 1) + + x_y_transpose = tf.matmul(x, y, adjoint_b=True) + + # squared_dists[b,i,j] = ||x_bi - y_bj||^2 = + # x_bi'x_bi- 2x_bi'x_bj + x_bj'x_bj + squared_dists = ( + x_norm_squared_tile - 2 * x_y_transpose + y_norm_squared_tile) + + return squared_dists + + +def _pairwise_squared_distance_matrix(x): + """Pairwise squared distance among a (batch) matrix's rows (2nd dim). + + This saves a bit of computation vs. using + _cross_squared_distance_matrix(x,x) + + Args: + x: `[batch_size, n, d]` float `Tensor` + + Returns: + squared_dists: `[batch_size, n, n]` float `Tensor`, where + squared_dists[b,i,j] = ||x[b,i,:] - x[b,j,:]||^2 + """ + + x_x_transpose = tf.matmul(x, x, adjoint_b=True) + x_norm_squared = tf.linalg.diag_part(x_x_transpose) + x_norm_squared_tile = tf.expand_dims(x_norm_squared, 2) + + # squared_dists[b,i,j] = ||x_bi - x_bj||^2 = + # = x_bi'x_bi- 2x_bi'x_bj + x_bj'x_bj + squared_dists = x_norm_squared_tile - 2 * x_x_transpose + tf.transpose( + x_norm_squared_tile, [0, 2, 1]) + + return squared_dists + + +def _solve_interpolation(train_points, train_values, order, + regularization_weight): + """Solve for interpolation coefficients. + + Computes the coefficients of the polyharmonic interpolant for the + 'training' data defined by (train_points, train_values) using the kernel + phi. + + Args: + train_points: `[b, n, d]` interpolation centers + train_values: `[b, n, k]` function values + order: order of the interpolation + regularization_weight: weight to place on smoothness regularization term + + Returns: + w: `[b, n, k]` weights on each interpolation center + v: `[b, d, k]` weights on each input dimension + Raises: + ValueError: if d or k is not fully specified. + """ + + # These dimensions are set dynamically at runtime. + b, n, _ = tf.unstack(tf.shape(train_points), num=3) + + d = train_points.shape[-1] + if d is None: + raise ValueError('The dimensionality of the input points (d) must be ' + 'statically-inferrable.') + + k = train_values.shape[-1] + if k is None: + raise ValueError('The dimensionality of the output values (k) must be ' + 'statically-inferrable.') + + # First, rename variables so that the notation (c, f, w, v, A, B, etc.) + # follows https://en.wikipedia.org/wiki/Polyharmonic_spline. + # To account for python style guidelines we use + # matrix_a for A and matrix_b for B. + + c = train_points + f = train_values + + # Next, construct the linear system. + with tf.name_scope('construct_linear_system'): + + matrix_a = _phi(_pairwise_squared_distance_matrix(c), + order) # [b, n, n] + if regularization_weight > 0: + batch_identity_matrix = tf.expand_dims(tf.eye(n, dtype=c.dtype), 0) + matrix_a += regularization_weight * batch_identity_matrix + + # Append ones to the feature values for the bias term + # in the linear model. + ones = tf.ones_like(c[..., :1], dtype=c.dtype) + matrix_b = tf.concat([c, ones], 2) # [b, n, d + 1] + + # [b, n + d + 1, n] + left_block = tf.concat( + [matrix_a, tf.transpose(matrix_b, [0, 2, 1])], 1) + + num_b_cols = matrix_b.get_shape()[2] # d + 1 + lhs_zeros = tf.zeros([b, num_b_cols, num_b_cols], train_points.dtype) + right_block = tf.concat([matrix_b, lhs_zeros], + 1) # [b, n + d + 1, d + 1] + lhs = tf.concat([left_block, right_block], + 2) # [b, n + d + 1, n + d + 1] + + rhs_zeros = tf.zeros([b, d + 1, k], train_points.dtype) + rhs = tf.concat([f, rhs_zeros], 1) # [b, n + d + 1, k] + + # Then, solve the linear system and unpack the results. + with tf.name_scope('solve_linear_system'): + w_v = tf.linalg.solve(lhs, rhs) + w = w_v[:, :n, :] + v = w_v[:, n:, :] + + return w, v + + +def _apply_interpolation(query_points, train_points, w, v, order): + """Apply polyharmonic interpolation model to data. + + Given coefficients w and v for the interpolation model, we evaluate + interpolated function values at query_points. + + Args: + query_points: `[b, m, d]` x values to evaluate the interpolation at + train_points: `[b, n, d]` x values that act as the interpolation centers + ( the c variables in the wikipedia article) + w: `[b, n, k]` weights on each interpolation center + v: `[b, d, k]` weights on each input dimension + order: order of the interpolation + + Returns: + Polyharmonic interpolation evaluated at points defined in query_points. + """ + + # First, compute the contribution from the rbf term. + pairwise_dists = _cross_squared_distance_matrix(query_points, train_points) + phi_pairwise_dists = _phi(pairwise_dists, order) + + rbf_term = tf.matmul(phi_pairwise_dists, w) + + # Then, compute the contribution from the linear term. + # Pad query_points with ones, for the bias term in the linear model. + query_points_pad = tf.concat([ + query_points, + tf.ones_like(query_points[..., :1], train_points.dtype) + ], 2) + linear_term = tf.matmul(query_points_pad, v) + + return rbf_term + linear_term + + +def _phi(r, order): + """Coordinate-wise nonlinearity used to define the order of the + interpolation. + + See https://en.wikipedia.org/wiki/Polyharmonic_spline for the definition. + + Args: + r: input op + order: interpolation order + + Returns: + phi_k evaluated coordinate-wise on r, for k = r + """ + + # using EPSILON prevents log(0), sqrt0), etc. + # sqrt(0) is well-defined, but its gradient is not + with tf.name_scope('phi'): + if order == 1: + r = tf.maximum(r, EPSILON) + r = tf.sqrt(r) + return r + elif order == 2: + return 0.5 * r * tf.math.log(tf.maximum(r, EPSILON)) + elif order == 4: + return 0.5 * tf.square(r) * tf.math.log(tf.maximum(r, EPSILON)) + elif order % 2 == 0: + r = tf.maximum(r, EPSILON) + return 0.5 * tf.pow(r, 0.5 * order) * tf.math.log(r) + else: + r = tf.maximum(r, EPSILON) + return tf.pow(r, 0.5 * order) + + +def interpolate_spline(train_points, + train_values, + query_points, + order, + regularization_weight=0.0, + name='interpolate_spline'): + r"""Interpolate signal using polyharmonic interpolation. + + The interpolant has the form + $$f(x) = \sum_{i = 1}^n w_i \phi(||x - c_i||) + v^T x + b.$$ + + This is a sum of two terms: (1) a weighted sum of radial basis function + (RBF) terms, with the centers \\(c_1, ... c_n\\), and (2) a linear term + with a bias. The \\(c_i\\) vectors are 'training' points. + In the code, b is absorbed into v + by appending 1 as a final dimension to x. The coefficients w and v are + estimated such that the interpolant exactly fits the value of the function + at the \\(c_i\\) points, the vector w is orthogonal to each \\(c_i\\), + and the vector w sums to 0. With these constraints, the coefficients + can be obtained by solving a linear system. + + \\(\phi\\) is an RBF, parametrized by an interpolation + order. Using order=2 produces the well-known thin-plate spline. + + We also provide the option to perform regularized interpolation. Here, the + interpolant is selected to trade off between the squared loss on the + training data and a certain measure of its curvature + ([details](https://en.wikipedia.org/wiki/Polyharmonic_spline)). + Using a regularization weight greater than zero has the effect that the + interpolant will no longer exactly fit the training data. However, it may + be less vulnerable to overfitting, particularly for high-order + interpolation. + + Note the interpolation procedure is differentiable with respect to all + inputs besides the order parameter. + + We support dynamically-shaped inputs, where batch_size, n, and m are None + at graph construction time. However, d and k must be known. + + Args: + train_points: `[batch_size, n, d]` float `Tensor` of n d-dimensional + locations. These do not need to be regularly-spaced. + train_values: `[batch_size, n, k]` float `Tensor` of n c-dimensional + values evaluated at train_points. + query_points: `[batch_size, m, d]` `Tensor` of m d-dimensional locations + where we will output the interpolant's values. + order: order of the interpolation. Common values are 1 for + \\(\phi(r) = r\\), 2 for \\(\phi(r) = r^2 * log(r)\\) + (thin-plate spline), or 3 for \\(\phi(r) = r^3\\). + regularization_weight: weight placed on the regularization term. + This will depend substantially on the problem, and it should always be + tuned. For many problems, it is reasonable to use no regularization. + If using a non-zero value, we recommend a small value like 0.001. + name: name prefix for ops created by this function + + Returns: + `[b, m, k]` float `Tensor` of query values. We use train_points and + train_values to perform polyharmonic interpolation. The query values are + the values of the interpolant evaluated at the locations specified in + query_points. + """ + with tf.name_scope(name or "interpolate_spline"): + train_points = tf.convert_to_tensor(train_points) + train_values = tf.convert_to_tensor(train_values) + query_points = tf.convert_to_tensor(query_points) + + # First, fit the spline to the observed data. + with tf.name_scope('solve'): + w, v = _solve_interpolation(train_points, train_values, order, + regularization_weight) + + # Then, evaluate the spline at the query locations. + with tf.name_scope('predict'): + query_values = _apply_interpolation(query_points, train_points, w, + v, order) + + return query_values diff --git a/tensorflow_addons/image/interpolate_spline_test.py b/tensorflow_addons/image/interpolate_spline_test.py new file mode 100644 index 0000000000..106edf8beb --- /dev/null +++ b/tensorflow_addons/image/interpolate_spline_test.py @@ -0,0 +1,364 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for interpolate_spline.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +from scipy import interpolate as sc_interpolate + +import tensorflow as tf +import tensorflow.compat.v1 as tf1 # TODO: locate placeholder +from tensorflow_addons.utils import test_utils +from tensorflow_addons.image import interpolate_spline + + +class _InterpolationProblem(object): + """Abstract class for interpolation problem descriptions.""" + + def get_problem(self, optimizable=False, extrapolate=True, + dtype='float32'): + """Make data for an interpolation problem where all x vectors are n-d. + + Args: + optimizable: If True, then make train_points a tf.Variable. + extrapolate: If False, then clamp the query_points values to be within + the max and min of train_points. + dtype: The data type to use. + + Returns: + query_points, query_values, train_points, train_values: training and + test tensors for interpolation problem + """ + + # The values generated here depend on a seed of 0. + np.random.seed(0) + + batch_size = 1 + num_training_points = 10 + num_query_points = 4 + + init_points = np.random.uniform( + size=[batch_size, num_training_points, self.DATA_DIM]) + + init_points = init_points.astype(dtype) + train_points = (tf.Variable(init_points) + if optimizable else tf.constant(init_points)) + train_values = self.tf_function(train_points) + + query_points_np = np.random.uniform( + size=[batch_size, num_query_points, self.DATA_DIM]) + query_points_np = query_points_np.astype(dtype) + if not extrapolate: + query_points_np = np.clip(query_points_np, np.min(init_points), + np.max(init_points)) + + query_points = tf.constant(query_points_np) + query_values = self.np_function(query_points_np) + + return query_points, query_values, train_points, train_values + + +class _QuadraticPlusSinProblem1D(_InterpolationProblem): + """1D interpolation problem used for regression testing.""" + DATA_DIM = 1 + HARDCODED_QUERY_VALUES = { + (1.0, 0.0): + [6.2647187603, -7.84362604077, -5.63690142322, 1.42928896387], + (1.0, 0.01): + [6.77688289946, -8.02163669853, -5.79491157027, 1.4063285693], + (2.0, 0.0): + [8.67110264937, -8.41281390883, -5.80190044693, 1.50155606059], + (2.0, 0.01): [ + 6.70797816797, -7.49709587663, -5.28965776238, 1.52284731741 + ], + (3.0, 0.0): [ + 9.37691802935, -8.50390141515, -5.80786417426, 1.63467762122 + ], + (3.0, 0.01): [ + 4.47106304758, -5.71266128361, -3.92529303296, 1.86755293857 + ], + (4.0, 0.0): [ + 9.58172461111, -8.51432104771, -5.80967675388, 1.63361164256 + ], + (4.0, 0.01): [ + -3.87902711352, -0.0253462273846, 1.79857618022, -0.769339675725 + ] + } + + def np_function(self, x): + """Takes np array, evaluates the test function, and returns np + array.""" + return np.sum( + np.power((x - 0.5), 3) - 0.25 * x + 10 * np.sin(x * 10), + axis=2, + keepdims=True) + + def tf_function(self, x): + """Takes tf tensor, evaluates the test function, and returns tf + tensor.""" + return tf.reduce_mean( + tf.pow((x - 0.5), 3) - 0.25 * x + 10 * tf.sin(x * 10), + 2, + keepdims=True) + + +class _QuadraticPlusSinProblemND(_InterpolationProblem): + """3D interpolation problem used for regression testing.""" + + DATA_DIM = 3 + HARDCODED_QUERY_VALUES = { + (1.0, 0.0): + [1.06609663962, 1.28894849357, 1.10882405595, 1.63966936885], + (1.0, 0.01): + [1.03123780748, 1.2952930985, 1.10366822954, 1.65265118569], + (2.0, 0.0): + [0.627787735064, 1.43802857251, 1.00194632358, 1.91667538215], + (2.0, 0.01): [ + 0.730159985046, 1.41702471595, 1.0065827217, 1.85758519312 + ], + (3.0, 0.0): [ + 0.350460417862, 1.67223539464, 1.00475331246, 2.31580322491 + ], + (3.0, 0.01): [ + 0.624557250556, 1.63138876667, 0.976588193162, 2.12511237866 + ], + (4.0, 0.0): [ + 0.898129669986, 1.24434133638, -0.938056116931, 1.59910338833 + ], + (4.0, 0.01): [ + 0.0930360338179, -3.38791305538, -1.00969032567, 0.745535080382 + ], + } + + def np_function(self, x): + """Takes np array, evaluates the test function, and returns np + array.""" + return np.sum( + np.square(x - 0.5) + 0.25 * x + 1 * np.sin(x * 15), + axis=2, + keepdims=True) + + def tf_function(self, x): + """Takes tf tensor, evaluates the test function, and returns tf + tensor.""" + return tf.reduce_sum( + tf.square(x - 0.5) + 0.25 * x + 1 * tf.sin(x * 15), + 2, + keepdims=True) + + +class InterpolateSplineTest(tf.test.TestCase): + def test_1d_linear_interpolation(self): + """For 1d linear interpolation, we can compare directly to scipy.""" + + tp = _QuadraticPlusSinProblem1D() + (query_points, _, train_points, train_values) = tp.get_problem( + extrapolate=False, dtype='float64') + interpolation_order = 1 + + with tf.name_scope('interpolator'): + interpolator = interpolate_spline( + train_points, train_values, query_points, interpolation_order) + with self.cached_session() as sess: + fetches = [ + query_points, train_points, train_values, interpolator + ] + query_points_, train_points_, train_values_, interp_ = sess.run( # pylint: disable=C0301 + fetches) + + # Just look at the first element of the minibatch. + # Also, trim the final singleton dimension. + interp_ = interp_[0, :, 0] + query_points_ = query_points_[0, :, 0] + train_points_ = train_points_[0, :, 0] + train_values_ = train_values_[0, :, 0] + + # Compute scipy interpolation. + scipy_interp_function = sc_interpolate.interp1d( + train_points_, train_values_, kind='linear') + + scipy_interpolation = scipy_interp_function(query_points_) + scipy_interpolation_on_train = scipy_interp_function( + train_points_) + + # Even with float64 precision, the interpolants disagree with scipy a + # bit due to the fact that we add the EPSILON to prevent sqrt(0), etc. + tol = 1e-3 + + self.assertAllClose( + train_values_, + scipy_interpolation_on_train, + atol=tol, + rtol=tol) + self.assertAllClose( + interp_, scipy_interpolation, atol=tol, rtol=tol) + + def test_1d_interpolation(self): + """Regression test for interpolation with 1-D points.""" + + tp = _QuadraticPlusSinProblem1D() + (query_points, _, train_points, + train_values) = tp.get_problem(dtype='float64') + + for order in (1, 2, 3): + for reg_weight in (0, 0.01): + interpolator = interpolate_spline(train_points, train_values, + query_points, order, + reg_weight) + + target_interpolation = tp.HARDCODED_QUERY_VALUES[(order, + reg_weight)] + target_interpolation = np.array(target_interpolation) + with self.cached_session() as sess: + interp_val = sess.run(interpolator) + self.assertAllClose(interp_val[0, :, 0], + target_interpolation) + + def test_nd_linear_interpolation(self): + """Regression test for interpolation with N-D points.""" + + tp = _QuadraticPlusSinProblemND() + (query_points, _, train_points, + train_values) = tp.get_problem(dtype='float64') + + for order in (1, 2, 3): + for reg_weight in (0, 0.01): + interpolator = interpolate_spline(train_points, train_values, + query_points, order, + reg_weight) + + target_interpolation = tp.HARDCODED_QUERY_VALUES[(order, + reg_weight)] + target_interpolation = np.array(target_interpolation) + with self.cached_session() as sess: + interp_val = sess.run(interpolator) + self.assertAllClose(interp_val[0, :, 0], + target_interpolation) + + @test_utils.run_deprecated_v1 + def test_nd_linear_interpolation_unspecified_shape(self): + """Ensure that interpolation supports dynamic batch_size and + num_points.""" + tp = _QuadraticPlusSinProblemND() + (query_points, _, train_points, + train_values) = tp.get_problem(dtype='float64') + + # Construct placeholders such that the batch size, number of train points, + # and number of query points are not known at graph construction time. + feature_dim = query_points.shape[-1] + value_dim = train_values.shape[-1] + train_points_ph = tf1.placeholder( + dtype=train_points.dtype, shape=[None, None, feature_dim]) + train_values_ph = tf1.placeholder( + dtype=train_values.dtype, shape=[None, None, value_dim]) + query_points_ph = tf1.placeholder( + dtype=query_points.dtype, shape=[None, None, feature_dim]) + + order = 1 + reg_weight = 0.01 + + interpolator = interpolate_spline(train_points_ph, train_values_ph, + query_points_ph, order, reg_weight) + + target_interpolation = tp.HARDCODED_QUERY_VALUES[(order, reg_weight)] + target_interpolation = np.array(target_interpolation) + with self.cached_session() as sess: + + (train_points_value, train_values_value, + query_points_value) = sess.run( + [train_points, train_values, query_points]) + + interp_val = sess.run( + interpolator, + feed_dict={ + train_points_ph: train_points_value, + train_values_ph: train_values_value, + query_points_ph: query_points_value + }) + self.assertAllClose(interp_val[0, :, 0], target_interpolation) + + def test_fully_unspecified_shape(self): + """Ensure that erreor is thrown when input/output dim unspecified.""" + self.skipTest("TODO: port to tf2.0 / eager") + tp = _QuadraticPlusSinProblemND() + (query_points, _, train_points, + train_values) = tp.get_problem(dtype='float64') + + # Construct placeholders such that the batch size, number of train points, + # and number of query points are not known at graph construction time. + feature_dim = query_points.shape[-1] + value_dim = train_values.shape[-1] + train_points_ph = tf1.placeholder( + dtype=train_points.dtype, shape=[None, None, feature_dim]) + train_points_ph_invalid = tf1.placeholder( + dtype=train_points.dtype, shape=[None, None, None]) + train_values_ph = tf1.placeholder( + dtype=train_values.dtype, shape=[None, None, value_dim]) + train_values_ph_invalid = tf1.placeholder( + dtype=train_values.dtype, shape=[None, None, None]) + query_points_ph = tf1.placeholder( + dtype=query_points.dtype, shape=[None, None, feature_dim]) + + order = 1 + reg_weight = 0.01 + + with self.assertRaises(ValueError): + _ = interpolate_spline(train_points_ph_invalid, train_values_ph, + query_points_ph, order, reg_weight) + + with self.assertRaises(ValueError): + _ = interpolate_spline(train_points_ph, train_values_ph_invalid, + query_points_ph, order, reg_weight) + + def test_interpolation_gradient(self): + """Make sure that backprop can run. Correctness of gradients is + assumed. + + Here, we create a use a small 'training' set and a more densely- + sampled set of query points, for which we know the true value in + advance. The goal is to choose x locations for the training data + such that interpolating using this training data yields the best + reconstruction for the function values at the query points. The + training data locations are optimized iteratively using gradient + descent. + """ + tp = _QuadraticPlusSinProblemND() + (query_points, query_values, train_points, + train_values) = tp.get_problem(optimizable=True) + + regularization = 0.001 + for interpolation_order in (1, 2, 3, 4): + optimizer = tf1.train.MomentumOptimizer(0.001, 0.9) + + @tf.function + def train_step(): + with tf.GradientTape() as gt: + interpolator = interpolate_spline( + train_points, train_values, query_points, + interpolation_order, regularization) + loss = tf.reduce_mean( + tf.square(query_values - interpolator)) + grad = gt.gradient(loss, [train_points]) + grad, _ = tf.clip_by_global_norm(grad, 1.0) + opt_func = optimizer.apply_gradients(zip(grad, [train_points])) + + for epoch in range(100): + train_step() + + +if __name__ == '__main__': + tf.test.main() diff --git a/tensorflow_addons/image/sparse_image_warp.py b/tensorflow_addons/image/sparse_image_warp.py new file mode 100644 index 0000000000..b1697eab0f --- /dev/null +++ b/tensorflow_addons/image/sparse_image_warp.py @@ -0,0 +1,200 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Image warping using sparse flow defined at control points.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +import tensorflow as tf +from tensorflow_addons.image import dense_image_warp +from tensorflow_addons.image import interpolate_spline + + +def _get_grid_locations(image_height, image_width): + """Wrapper for np.meshgrid.""" + + y_range = np.linspace(0, image_height - 1, image_height) + x_range = np.linspace(0, image_width - 1, image_width) + y_grid, x_grid = np.meshgrid(y_range, x_range, indexing='ij') + return np.stack((y_grid, x_grid), -1) + + +def _expand_to_minibatch(np_array, batch_size): + """Tile arbitrarily-sized np_array to include new batch dimension.""" + tiles = [batch_size] + [1] * np_array.ndim + return np.tile(np.expand_dims(np_array, 0), tiles) + + +def _get_boundary_locations(image_height, image_width, num_points_per_edge): + """Compute evenly-spaced indices along edge of image.""" + y_range = np.linspace(0, image_height - 1, num_points_per_edge + 2) + x_range = np.linspace(0, image_width - 1, num_points_per_edge + 2) + ys, xs = np.meshgrid(y_range, x_range, indexing='ij') + is_boundary = np.logical_or( + np.logical_or(xs == 0, xs == image_width - 1), + np.logical_or(ys == 0, ys == image_height - 1)) + return np.stack([ys[is_boundary], xs[is_boundary]], axis=-1) + + +def _add_zero_flow_controls_at_boundary(control_point_locations, + control_point_flows, image_height, + image_width, boundary_points_per_edge): + """Add control points for zero-flow boundary conditions. + + Augment the set of control points with extra points on the + boundary of the image that have zero flow. + + Args: + control_point_locations: input control points + control_point_flows: their flows + image_height: image height + image_width: image width + boundary_points_per_edge: number of points to add in the middle of each + edge (not including the corners). + The total number of points added is + 4 + 4*(boundary_points_per_edge). + + Returns: + merged_control_point_locations: augmented set of control point locations + merged_control_point_flows: augmented set of control point flows + """ + + batch_size = tf.compat.dimension_value(control_point_locations.shape[0]) + + boundary_point_locations = _get_boundary_locations( + image_height, image_width, boundary_points_per_edge) + + boundary_point_flows = np.zeros([boundary_point_locations.shape[0], 2]) + + type_to_use = control_point_locations.dtype + boundary_point_locations = tf.constant( + _expand_to_minibatch(boundary_point_locations, batch_size), + dtype=type_to_use) + + boundary_point_flows = tf.constant( + _expand_to_minibatch(boundary_point_flows, batch_size), + dtype=type_to_use) + + merged_control_point_locations = tf.concat( + [control_point_locations, boundary_point_locations], 1) + + merged_control_point_flows = tf.concat( + [control_point_flows, boundary_point_flows], 1) + + return merged_control_point_locations, merged_control_point_flows + + +def sparse_image_warp(image, + source_control_point_locations, + dest_control_point_locations, + interpolation_order=2, + regularization_weight=0.0, + num_boundary_points=0, + name='sparse_image_warp'): + """Image warping using correspondences between sparse control points. + + Apply a non-linear warp to the image, where the warp is specified by + the source and destination locations of a (potentially small) number of + control points. First, we use a polyharmonic spline + (`tf.contrib.image.interpolate_spline`) to interpolate the displacements + between the corresponding control points to a dense flow field. + Then, we warp the image using this dense flow field + (`tf.contrib.image.dense_image_warp`). + + Let t index our control points. For regularization_weight=0, we have: + warped_image[b, dest_control_point_locations[b, t, 0], + dest_control_point_locations[b, t, 1], :] = + image[b, source_control_point_locations[b, t, 0], + source_control_point_locations[b, t, 1], :]. + + For regularization_weight > 0, this condition is met approximately, since + regularized interpolation trades off smoothness of the interpolant vs. + reconstruction of the interpolant at the control points. + See `tf.contrib.image.interpolate_spline` for further documentation of the + interpolation_order and regularization_weight arguments. + + + Args: + image: `[batch, height, width, channels]` float `Tensor` + source_control_point_locations: `[batch, num_control_points, 2]` float + `Tensor` + dest_control_point_locations: `[batch, num_control_points, 2]` float + `Tensor` + interpolation_order: polynomial order used by the spline interpolation + regularization_weight: weight on smoothness regularizer in interpolation + num_boundary_points: How many zero-flow boundary points to include at + each image edge.Usage: + num_boundary_points=0: don't add zero-flow points + num_boundary_points=1: 4 corners of the image + num_boundary_points=2: 4 corners and one in the middle of each edge + (8 points total) + num_boundary_points=n: 4 corners and n-1 along each edge + name: A name for the operation (optional). + + Note that image and offsets can be of type tf.half, tf.float32, or + tf.float64, and do not necessarily have to be the same type. + + Returns: + warped_image: `[batch, height, width, channels]` float `Tensor` with same + type as input image. + flow_field: `[batch, height, width, 2]` float `Tensor` containing the + dense flow field produced by the interpolation. + """ + + image = tf.convert_to_tensor(image) + source_control_point_locations = tf.convert_to_tensor( + source_control_point_locations) + dest_control_point_locations = tf.convert_to_tensor( + dest_control_point_locations) + + control_point_flows = ( + dest_control_point_locations - source_control_point_locations) + + clamp_boundaries = num_boundary_points > 0 + boundary_points_per_edge = num_boundary_points - 1 + + with tf.name_scope(name or "sparse_image_warp"): + + batch_size, image_height, image_width, _ = image.get_shape().as_list() + + # This generates the dense locations where the interpolant + # will be evaluated. + grid_locations = _get_grid_locations(image_height, image_width) + + flattened_grid_locations = np.reshape(grid_locations, + [image_height * image_width, 2]) + + flattened_grid_locations = tf.constant( + _expand_to_minibatch(flattened_grid_locations, batch_size), + image.dtype) + + if clamp_boundaries: + (dest_control_point_locations, + control_point_flows) = _add_zero_flow_controls_at_boundary( + dest_control_point_locations, control_point_flows, + image_height, image_width, boundary_points_per_edge) + + flattened_flows = interpolate_spline.interpolate_spline( + dest_control_point_locations, control_point_flows, + flattened_grid_locations, interpolation_order, + regularization_weight) + + dense_flows = tf.reshape(flattened_flows, + [batch_size, image_height, image_width, 2]) + + warped_image = dense_image_warp(image, dense_flows) + + return warped_image, dense_flows diff --git a/tensorflow_addons/image/sparse_image_warp_test.py b/tensorflow_addons/image/sparse_image_warp_test.py new file mode 100644 index 0000000000..4c2659c7d5 --- /dev/null +++ b/tensorflow_addons/image/sparse_image_warp_test.py @@ -0,0 +1,254 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for sparse_image_warp.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +import tensorflow as tf +import tensorflow.compat.v1 as tf1 # TODO: port TF1 test files? +from tensorflow_addons.image.sparse_image_warp import _get_boundary_locations +from tensorflow_addons.image.sparse_image_warp import _get_grid_locations +from tensorflow_addons.image import sparse_image_warp +from tensorflow_addons.utils.resource_loader import get_path_to_datafile + + +class SparseImageWarpTest(tf.test.TestCase): + def setUp(self): + np.random.seed(0) + + def testGetBoundaryLocations(self): + image_height = 11 + image_width = 11 + num_points_per_edge = 4 + locs = _get_boundary_locations(image_height, image_width, + num_points_per_edge) + num_points = locs.shape[0] + self.assertEqual(num_points, 4 + 4 * num_points_per_edge) + locs = [(locs[i, 0], locs[i, 1]) for i in range(num_points)] + for i in (0, image_height - 1): + for j in (0, image_width - 1): + self.assertIn((i, j), locs, + '{},{} not in the locations'.format(i, j)) + + for i in (2, 4, 6, 8): + for j in (0, image_width - 1): + self.assertIn((i, j), locs, + '{},{} not in the locations'.format(i, j)) + + for i in (0, image_height - 1): + for j in (2, 4, 6, 8): + self.assertIn((i, j), locs, + '{},{} not in the locations'.format(i, j)) + + def testGetGridLocations(self): + image_height = 5 + image_width = 3 + grid = _get_grid_locations(image_height, image_width) + for i in range(image_height): + for j in range(image_width): + self.assertEqual(grid[i, j, 0], i) + self.assertEqual(grid[i, j, 1], j) + + def testZeroShift(self): + """Run assertZeroShift for various hyperparameters.""" + for order in (1, 2): + for regularization in (0, 0.01): + for num_boundary_points in (0, 1): + self.assertZeroShift(order, regularization, + num_boundary_points) + + def assertZeroShift(self, order, regularization, num_boundary_points): + """Check that warping with zero displacements doesn't change the + image.""" + batch_size = 1 + image_height = 4 + image_width = 4 + channels = 3 + + image = np.random.uniform( + size=[batch_size, image_height, image_width, channels]) + + input_image_op = tf.constant(np.float32(image)) + + control_point_locations = [[1., 1.], [2., 2.], [2., 1.]] + control_point_locations = tf.constant( + np.float32(np.expand_dims(control_point_locations, 0))) + + control_point_displacements = np.zeros( + control_point_locations.shape.as_list()) + control_point_displacements = tf.constant( + np.float32(control_point_displacements)) + + (warped_image_op, flow_field) = sparse_image_warp( + input_image_op, + control_point_locations, + control_point_locations + control_point_displacements, + interpolation_order=order, + regularization_weight=regularization, + num_boundary_points=num_boundary_points) + + with self.cached_session() as sess: + warped_image, input_image, _ = sess.run( + [warped_image_op, input_image_op, flow_field]) + + self.assertAllClose(warped_image, input_image) + + def testMoveSinglePixel(self): + """Run assertMoveSinglePixel for various hyperparameters and data + types.""" + for order in (1, 2): + for num_boundary_points in (1, 2): + for type_to_use in (tf.dtypes.float32, tf.dtypes.float64): + self.assertMoveSinglePixel(order, num_boundary_points, + type_to_use) + + def assertMoveSinglePixel(self, order, num_boundary_points, type_to_use): + """Move a single block in a small grid using warping.""" + batch_size = 1 + image_height = 7 + image_width = 7 + channels = 3 + + image = np.zeros([batch_size, image_height, image_width, channels]) + image[:, 3, 3, :] = 1.0 + input_image_op = tf.constant(image, dtype=type_to_use) + + # Place a control point at the one white pixel. + control_point_locations = [[3., 3.]] + control_point_locations = tf.constant( + np.float32(np.expand_dims(control_point_locations, 0)), + dtype=type_to_use) + # Shift it one pixel to the right. + control_point_displacements = [[0., 1.0]] + control_point_displacements = tf.constant( + np.float32(np.expand_dims(control_point_displacements, 0)), + dtype=type_to_use) + + (warped_image_op, flow_field) = sparse_image_warp( + input_image_op, + control_point_locations, + control_point_locations + control_point_displacements, + interpolation_order=order, + num_boundary_points=num_boundary_points) + + with self.cached_session() as sess: + warped_image, input_image, flow = sess.run( + [warped_image_op, input_image_op, flow_field]) + # Check that it moved the pixel correctly. + self.assertAllClose( + warped_image[0, 4, 5, :], + input_image[0, 4, 4, :], + atol=1e-5, + rtol=1e-5) + + # Test that there is no flow at the corners. + for i in (0, image_height - 1): + for j in (0, image_width - 1): + self.assertAllClose( + flow[0, i, j, :], np.zeros([2]), atol=1e-5, rtol=1e-5) + + def load_image(self, image_file, sess): + image_op = tf.image.decode_png( + tf.io.read_file(image_file), dtype=tf.dtypes.uint8, + channels=4)[:, :, 0:3] + return sess.run(image_op) + + def testSmileyFace(self): + """Check warping accuracy by comparing to hardcoded warped images.""" + + input_file = get_path_to_datafile( + "image/test_data/Yellow_Smiley_Face.png") + with self.cached_session() as sess: + input_image = self.load_image(input_file, sess) + control_points = np.asarray([[64, 59], [180 - 64, 59], [39, 111], + [180 - 39, 111], [90, 143], [58, 134], + [180 - 58, 134]]) # pyformat: disable + control_point_displacements = np.asarray([[-10.5, 10.5], [10.5, 10.5], + [0, 0], [0, 0], [0, -10], + [-20, 10.25], [10, 10.75]]) + control_points_op = tf.constant( + np.expand_dims(np.float32(control_points[:, [1, 0]]), 0)) + control_point_displacements_op = tf.constant( + np.expand_dims( + np.float32(control_point_displacements[:, [1, 0]]), 0)) + float_image = np.expand_dims(np.float32(input_image) / 255, 0) + input_image_op = tf.constant(float_image) + + for interpolation_order in (1, 2, 3): + for num_boundary_points in (0, 1, 4): + warp_op, _ = sparse_image_warp( + input_image_op, + control_points_op, + control_points_op + control_point_displacements_op, + interpolation_order=interpolation_order, + num_boundary_points=num_boundary_points) + with self.cached_session() as sess: + warped_image = sess.run(warp_op) + out_image = np.uint8(warped_image[0, :, :, :] * 255) + target_file = get_path_to_datafile( + "image/test_data/Yellow_Smiley_Face_Warp-interp" + + "-{}-clamp-{}.png".format(interpolation_order, + num_boundary_points)) + + target_image = self.load_image(target_file, sess) + + # Check that the target_image and out_image difference is no + # bigger than 2 (on a scale of 0-255). Due to differences in + # floating point computation on different devices, the float + # output in warped_image may get rounded to a different int + # than that in the saved png file loaded into target_image. + self.assertAllClose( + target_image, out_image, atol=2, rtol=1e-3) + + def testThatBackpropRuns(self): + """Run optimization to ensure that gradients can be computed.""" + self.skipTest("TODO: port to tf2.0 / eager") + batch_size = 1 + image_height = 9 + image_width = 12 + image = tf.Variable( + np.float32( + np.random.uniform( + size=[batch_size, image_height, image_width, 3]))) + control_point_locations = [[3., 3.]] + control_point_locations = tf.constant( + np.float32(np.expand_dims(control_point_locations, 0))) + control_point_displacements = [[0.25, -0.5]] + control_point_displacements = tf.constant( + np.float32(np.expand_dims(control_point_displacements, 0))) + warped_image, _ = sparse_image_warp( + image, + control_point_locations, + control_point_locations + control_point_displacements, + num_boundary_points=3) + + loss = tf.reduce_mean(tf.abs(warped_image - image)) + optimizer = tf1.train.MomentumOptimizer(0.001, 0.9) + grad = tf.gradients(loss, [image]) + grad, _ = tf.clip_by_global_norm(grad, 1.0) + opt_func = optimizer.apply_gradients(zip(grad, [image])) + init_op = tf1.variables.global_variables_initializer( + ) # TODO: fix TF1 ref. + + with self.cached_session() as sess: + sess.run(init_op) + for _ in range(5): + sess.run([loss, opt_func]) + + +if __name__ == "__main__": + tf.test.main() diff --git a/tensorflow_addons/image/test_data/Yellow_Smiley_Face.png b/tensorflow_addons/image/test_data/Yellow_Smiley_Face.png new file mode 100644 index 0000000000..7e303881e2 Binary files /dev/null and b/tensorflow_addons/image/test_data/Yellow_Smiley_Face.png differ diff --git a/tensorflow_addons/image/test_data/Yellow_Smiley_Face_Warp-interp-1-clamp-0.png b/tensorflow_addons/image/test_data/Yellow_Smiley_Face_Warp-interp-1-clamp-0.png new file mode 100644 index 0000000000..7fd9e4e6d6 Binary files /dev/null and b/tensorflow_addons/image/test_data/Yellow_Smiley_Face_Warp-interp-1-clamp-0.png differ diff --git a/tensorflow_addons/image/test_data/Yellow_Smiley_Face_Warp-interp-1-clamp-1.png b/tensorflow_addons/image/test_data/Yellow_Smiley_Face_Warp-interp-1-clamp-1.png new file mode 100644 index 0000000000..86d225e5d2 Binary files /dev/null and b/tensorflow_addons/image/test_data/Yellow_Smiley_Face_Warp-interp-1-clamp-1.png differ diff --git a/tensorflow_addons/image/test_data/Yellow_Smiley_Face_Warp-interp-1-clamp-4.png b/tensorflow_addons/image/test_data/Yellow_Smiley_Face_Warp-interp-1-clamp-4.png new file mode 100644 index 0000000000..37e8ffae11 Binary files /dev/null and b/tensorflow_addons/image/test_data/Yellow_Smiley_Face_Warp-interp-1-clamp-4.png differ diff --git a/tensorflow_addons/image/test_data/Yellow_Smiley_Face_Warp-interp-2-clamp-0.png b/tensorflow_addons/image/test_data/Yellow_Smiley_Face_Warp-interp-2-clamp-0.png new file mode 100644 index 0000000000..e49b581612 Binary files /dev/null and b/tensorflow_addons/image/test_data/Yellow_Smiley_Face_Warp-interp-2-clamp-0.png differ diff --git a/tensorflow_addons/image/test_data/Yellow_Smiley_Face_Warp-interp-2-clamp-1.png b/tensorflow_addons/image/test_data/Yellow_Smiley_Face_Warp-interp-2-clamp-1.png new file mode 100644 index 0000000000..df3cf20043 Binary files /dev/null and b/tensorflow_addons/image/test_data/Yellow_Smiley_Face_Warp-interp-2-clamp-1.png differ diff --git a/tensorflow_addons/image/test_data/Yellow_Smiley_Face_Warp-interp-2-clamp-4.png b/tensorflow_addons/image/test_data/Yellow_Smiley_Face_Warp-interp-2-clamp-4.png new file mode 100644 index 0000000000..e1799a87c8 Binary files /dev/null and b/tensorflow_addons/image/test_data/Yellow_Smiley_Face_Warp-interp-2-clamp-4.png differ diff --git a/tensorflow_addons/image/test_data/Yellow_Smiley_Face_Warp-interp-3-clamp-0.png b/tensorflow_addons/image/test_data/Yellow_Smiley_Face_Warp-interp-3-clamp-0.png new file mode 100644 index 0000000000..2c346e0ce5 Binary files /dev/null and b/tensorflow_addons/image/test_data/Yellow_Smiley_Face_Warp-interp-3-clamp-0.png differ diff --git a/tensorflow_addons/image/test_data/Yellow_Smiley_Face_Warp-interp-3-clamp-1.png b/tensorflow_addons/image/test_data/Yellow_Smiley_Face_Warp-interp-3-clamp-1.png new file mode 100644 index 0000000000..6f8b65451c Binary files /dev/null and b/tensorflow_addons/image/test_data/Yellow_Smiley_Face_Warp-interp-3-clamp-1.png differ diff --git a/tensorflow_addons/image/test_data/Yellow_Smiley_Face_Warp-interp-3-clamp-4.png b/tensorflow_addons/image/test_data/Yellow_Smiley_Face_Warp-interp-3-clamp-4.png new file mode 100644 index 0000000000..8e78146d95 Binary files /dev/null and b/tensorflow_addons/image/test_data/Yellow_Smiley_Face_Warp-interp-3-clamp-4.png differ diff --git a/tensorflow_addons/seq2seq/attention_wrapper.py b/tensorflow_addons/seq2seq/attention_wrapper.py index 3abe62bd8c..54b1bf1be8 100644 --- a/tensorflow_addons/seq2seq/attention_wrapper.py +++ b/tensorflow_addons/seq2seq/attention_wrapper.py @@ -1899,8 +1899,8 @@ def call(self, inputs, state, **kwargs): # previous attention value. cell_inputs = self._cell_input_fn(inputs, state.attention) cell_state = state.cell_state - cell_output, next_cell_state = self._cell( - cell_inputs, cell_state, **kwargs) + cell_output, next_cell_state = self._cell(cell_inputs, cell_state, + **kwargs) cell_batch_size = (tf.compat.dimension_value(cell_output.shape[0]) or tf.shape(cell_output)[0])