Skip to content

Commit 320ad67

Browse files
WindQAQseanpmorgan
authored andcommitted
fix shape checking (#290)
1 parent 829cd88 commit 320ad67

File tree

2 files changed

+50
-27
lines changed

2 files changed

+50
-27
lines changed

tensorflow_addons/image/dense_image_warp.py

Lines changed: 47 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -55,31 +55,49 @@ def interpolate_bilinear(grid, query_points, indexing="ij", name=None):
5555
if len(query_points.shape) != 3:
5656
raise ValueError("Query points must be 3 dimensional.")
5757

58+
if query_points.shape[2] is not None and query_points.shape[2] != 2:
59+
raise ValueError("Query points must be size 2 in dim 2.")
60+
61+
if grid.shape[1] is not None and grid.shape[1] < 2:
62+
raise ValueError("Grid height must be at least 2.")
63+
64+
if grid.shape[2] is not None and grid.shape[2] < 2:
65+
raise ValueError("Grid width must be at least 2.")
66+
5867
grid_shape = tf.shape(grid)
5968
query_shape = tf.shape(query_points)
6069

6170
batch_size, height, width, channels = (grid_shape[0], grid_shape[1],
6271
grid_shape[2], grid_shape[3])
6372

6473
shape = [batch_size, height, width, channels]
65-
num_queries = query_shape[1]
74+
75+
# pylint: disable=bad-continuation
76+
with tf.control_dependencies([
77+
tf.debugging.assert_equal(
78+
query_shape[2],
79+
2,
80+
message="Query points must be size 2 in dim 2.")
81+
]):
82+
num_queries = query_shape[1]
83+
# pylint: enable=bad-continuation
6684

6785
query_type = query_points.dtype
6886
grid_type = grid.dtype
6987

70-
tf.debugging.assert_equal(
71-
query_shape[2], 2, message="Query points must be size 2 in dim 2.")
72-
73-
tf.debugging.assert_greater_equal(
74-
height, 2, message="Grid height must be at least 2."),
75-
tf.debugging.assert_greater_equal(
76-
width, 2, message="Grid width must be at least 2.")
77-
78-
alphas = []
79-
floors = []
80-
ceils = []
81-
index_order = [0, 1] if indexing == "ij" else [1, 0]
82-
unstacked_query_points = tf.unstack(query_points, axis=2)
88+
# pylint: disable=bad-continuation
89+
with tf.control_dependencies([
90+
tf.debugging.assert_greater_equal(
91+
height, 2, message="Grid height must be at least 2."),
92+
tf.debugging.assert_greater_equal(
93+
width, 2, message="Grid width must be at least 2."),
94+
]):
95+
alphas = []
96+
floors = []
97+
ceils = []
98+
index_order = [0, 1] if indexing == "ij" else [1, 0]
99+
unstacked_query_points = tf.unstack(query_points, axis=2)
100+
# pylint: enable=bad-continuation
83101

84102
for dim in index_order:
85103
with tf.name_scope("dim-" + str(dim)):
@@ -112,16 +130,21 @@ def interpolate_bilinear(grid, query_points, indexing="ij", name=None):
112130
alpha = tf.expand_dims(alpha, 2)
113131
alphas.append(alpha)
114132

115-
tf.debugging.assert_less_equal(
116-
tf.cast(batch_size * height * width, dtype=tf.dtypes.float32),
117-
np.iinfo(np.int32).max / 8.0,
118-
message="The image size or batch size is sufficiently large "
119-
"that the linearized addresses used by tf.gather "
120-
"may exceed the int32 limit.")
121-
flattened_grid = tf.reshape(grid,
122-
[batch_size * height * width, channels])
123-
batch_offsets = tf.reshape(
124-
tf.range(batch_size) * height * width, [batch_size, 1])
133+
# pylint: disable=bad-continuation
134+
with tf.control_dependencies([
135+
tf.debugging.assert_less_equal(
136+
tf.cast(
137+
batch_size * height * width, dtype=tf.dtypes.float32),
138+
np.iinfo(np.int32).max / 8.0,
139+
message="The image size or batch size is sufficiently "
140+
"large that the linearized addresses used by tf.gather "
141+
"may exceed the int32 limit.")
142+
]):
143+
flattened_grid = tf.reshape(
144+
grid, [batch_size * height * width, channels])
145+
batch_offsets = tf.reshape(
146+
tf.range(batch_size) * height * width, [batch_size, 1])
147+
# pylint: enable=bad-continuation
125148

126149
# This wraps tf.gather. We reshape the image data such that the
127150
# batch, y, and x coordinates are pulled into the first dimension.

tensorflow_addons/image/dense_image_warp_test.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -214,12 +214,12 @@ def test_gradients_exist(self):
214214
for _ in range(10):
215215
sess.run(opt_func)
216216

217-
# TODO: run in both graph and eager modes
217+
@test_utils.run_in_graph_and_eager_modes
218218
def test_size_exception(self):
219219
"""Make sure it throws an exception for images that are too small."""
220220
shape = [1, 2, 1, 1]
221-
with self.assertRaisesRegexp(tf.errors.InvalidArgumentError,
222-
"Grid width must be at least 2."):
221+
errors = (ValueError, tf.errors.InvalidArgumentError)
222+
with self.assertRaisesRegexp(errors, "Grid width must be at least 2."):
223223
self._check_interpolation_correctness(shape, "float32", "float32")
224224

225225

0 commit comments

Comments
 (0)