From 1b54fd7f1181342ba754287ef98908a7cefa14c2 Mon Sep 17 00:00:00 2001 From: Tzu-Wei Sung Date: Sat, 15 Jun 2019 18:29:28 +0800 Subject: [PATCH] fix shape checking --- tensorflow_addons/image/dense_image_warp.py | 71 ++++++++++++------- .../image/dense_image_warp_test.py | 6 +- 2 files changed, 50 insertions(+), 27 deletions(-) diff --git a/tensorflow_addons/image/dense_image_warp.py b/tensorflow_addons/image/dense_image_warp.py index 74570feb72..cfc0b7a1e1 100644 --- a/tensorflow_addons/image/dense_image_warp.py +++ b/tensorflow_addons/image/dense_image_warp.py @@ -55,6 +55,15 @@ def interpolate_bilinear(grid, query_points, indexing="ij", name=None): if len(query_points.shape) != 3: raise ValueError("Query points must be 3 dimensional.") + if query_points.shape[2] is not None and query_points.shape[2] != 2: + raise ValueError("Query points must be size 2 in dim 2.") + + if grid.shape[1] is not None and grid.shape[1] < 2: + raise ValueError("Grid height must be at least 2.") + + if grid.shape[2] is not None and grid.shape[2] < 2: + raise ValueError("Grid width must be at least 2.") + grid_shape = tf.shape(grid) query_shape = tf.shape(query_points) @@ -62,24 +71,33 @@ def interpolate_bilinear(grid, query_points, indexing="ij", name=None): grid_shape[2], grid_shape[3]) shape = [batch_size, height, width, channels] - num_queries = query_shape[1] + + # pylint: disable=bad-continuation + with tf.control_dependencies([ + tf.debugging.assert_equal( + query_shape[2], + 2, + message="Query points must be size 2 in dim 2.") + ]): + num_queries = query_shape[1] + # pylint: enable=bad-continuation query_type = query_points.dtype grid_type = grid.dtype - tf.debugging.assert_equal( - query_shape[2], 2, message="Query points must be size 2 in dim 2.") - - tf.debugging.assert_greater_equal( - height, 2, message="Grid height must be at least 2."), - tf.debugging.assert_greater_equal( - width, 2, message="Grid width must be at least 2.") - - alphas = [] - floors = [] - ceils = [] - index_order = [0, 1] if indexing == "ij" else [1, 0] - unstacked_query_points = tf.unstack(query_points, axis=2) + # pylint: disable=bad-continuation + with tf.control_dependencies([ + tf.debugging.assert_greater_equal( + height, 2, message="Grid height must be at least 2."), + tf.debugging.assert_greater_equal( + width, 2, message="Grid width must be at least 2."), + ]): + alphas = [] + floors = [] + ceils = [] + index_order = [0, 1] if indexing == "ij" else [1, 0] + unstacked_query_points = tf.unstack(query_points, axis=2) + # pylint: enable=bad-continuation for dim in index_order: with tf.name_scope("dim-" + str(dim)): @@ -112,16 +130,21 @@ def interpolate_bilinear(grid, query_points, indexing="ij", name=None): alpha = tf.expand_dims(alpha, 2) alphas.append(alpha) - tf.debugging.assert_less_equal( - tf.cast(batch_size * height * width, dtype=tf.dtypes.float32), - np.iinfo(np.int32).max / 8.0, - message="The image size or batch size is sufficiently large " - "that the linearized addresses used by tf.gather " - "may exceed the int32 limit.") - flattened_grid = tf.reshape(grid, - [batch_size * height * width, channels]) - batch_offsets = tf.reshape( - tf.range(batch_size) * height * width, [batch_size, 1]) + # pylint: disable=bad-continuation + with tf.control_dependencies([ + tf.debugging.assert_less_equal( + tf.cast( + batch_size * height * width, dtype=tf.dtypes.float32), + np.iinfo(np.int32).max / 8.0, + message="The image size or batch size is sufficiently " + "large that the linearized addresses used by tf.gather " + "may exceed the int32 limit.") + ]): + flattened_grid = tf.reshape( + grid, [batch_size * height * width, channels]) + batch_offsets = tf.reshape( + tf.range(batch_size) * height * width, [batch_size, 1]) + # pylint: enable=bad-continuation # This wraps tf.gather. We reshape the image data such that the # batch, y, and x coordinates are pulled into the first dimension. diff --git a/tensorflow_addons/image/dense_image_warp_test.py b/tensorflow_addons/image/dense_image_warp_test.py index 872dd3f6b5..8cc858a91d 100644 --- a/tensorflow_addons/image/dense_image_warp_test.py +++ b/tensorflow_addons/image/dense_image_warp_test.py @@ -214,12 +214,12 @@ def test_gradients_exist(self): for _ in range(10): sess.run(opt_func) - # TODO: run in both graph and eager modes + @test_utils.run_in_graph_and_eager_modes def test_size_exception(self): """Make sure it throws an exception for images that are too small.""" shape = [1, 2, 1, 1] - with self.assertRaisesRegexp(tf.errors.InvalidArgumentError, - "Grid width must be at least 2."): + errors = (ValueError, tf.errors.InvalidArgumentError) + with self.assertRaisesRegexp(errors, "Grid width must be at least 2."): self._check_interpolation_correctness(shape, "float32", "float32")