diff --git a/tensorflow_addons/image/dense_image_warp.py b/tensorflow_addons/image/dense_image_warp.py index fe6af1d753..74570feb72 100644 --- a/tensorflow_addons/image/dense_image_warp.py +++ b/tensorflow_addons/image/dense_image_warp.py @@ -21,7 +21,6 @@ import tensorflow as tf -@tf.function def interpolate_bilinear(grid, query_points, indexing="ij", name=None): """Similar to Matlab's interp2 function. @@ -48,30 +47,28 @@ def interpolate_bilinear(grid, query_points, indexing="ij", name=None): with tf.name_scope(name or "interpolate_bilinear"): grid = tf.convert_to_tensor(grid) query_points = tf.convert_to_tensor(query_points) - shape = grid.get_shape().as_list() - if len(shape) != 4: + + if len(grid.shape) != 4: msg = "Grid must be 4 dimensional. Received size: " - raise ValueError(msg + str(grid.get_shape())) + raise ValueError(msg + str(grid.shape)) + + if len(query_points.shape) != 3: + raise ValueError("Query points must be 3 dimensional.") + + grid_shape = tf.shape(grid) + query_shape = tf.shape(query_points) - batch_size, height, width, channels = (tf.shape(grid)[0], - tf.shape(grid)[1], - tf.shape(grid)[2], - tf.shape(grid)[3]) + batch_size, height, width, channels = (grid_shape[0], grid_shape[1], + grid_shape[2], grid_shape[3]) shape = [batch_size, height, width, channels] + num_queries = query_shape[1] + query_type = query_points.dtype grid_type = grid.dtype tf.debugging.assert_equal( - len(query_points.get_shape()), - 3, - message="Query points must be 3 dimensional.") - tf.debugging.assert_equal( - tf.shape(query_points)[2], - 2, - message="Query points must be size 2 in dim 2.") - - num_queries = tf.shape(query_points)[1] + query_shape[2], 2, message="Query points must be size 2 in dim 2.") tf.debugging.assert_greater_equal( height, 2, message="Grid height must be at least 2."), diff --git a/tensorflow_addons/image/dense_image_warp_test.py b/tensorflow_addons/image/dense_image_warp_test.py index 8570aed2a1..6810d6a936 100644 --- a/tensorflow_addons/image/dense_image_warp_test.py +++ b/tensorflow_addons/image/dense_image_warp_test.py @@ -134,7 +134,7 @@ def _check_zero_flow_correctness(self, shape, image_type, flow_type): self.assertAllClose(rand_image, interp) - # TODO: run in both graph and eager modes + @test_utils.run_in_graph_and_eager_modes def test_zero_flows(self): """Apply _check_zero_flow_correctness() for a few sizes and types.""" shapes_to_try = [[3, 4, 5, 6], [1, 2, 2, 1]]