Skip to content

Commit c60fed8

Browse files
WindQAQseanpmorgan
authored andcommitted
shape checking in python code (#257)
1 parent d46dba1 commit c60fed8

File tree

2 files changed

+7
-3
lines changed

2 files changed

+7
-3
lines changed

tensorflow_addons/image/distort_image_ops.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,10 @@ def adjust_hsv_in_yiq(image,
138138
"""
139139
with tf.name_scope(name or "adjust_hsv_in_yiq"):
140140
image = tf.convert_to_tensor(image, name="image")
141+
142+
if image.shape.rank < 3:
143+
raise ValueError("input must be at least rank 3.")
144+
141145
# Remember original dtype to so we can convert back if needed
142146
orig_dtype = image.dtype
143147
flt_image = tf.image.convert_image_dtype(image, tf.dtypes.float32)

tensorflow_addons/image/distort_image_ops_test.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ def test_invalid_shapes(self):
107107
x_np = np.random.rand(2, 3) * 255.
108108
delta_h = np.random.rand() * 2.0 - 1.0
109109
with self.assertRaisesRegexp(ValueError,
110-
"Shape must be at least rank 3"):
110+
"input must be at least rank 3."):
111111
self._adjust_hue_in_yiq_tf(x_np, delta_h)
112112
x_np = np.random.rand(4, 2, 4) * 255.
113113
delta_h = np.random.rand() * 2.0 - 1.0
@@ -168,7 +168,7 @@ def test_invalid_shapes(self):
168168
x_np = np.random.rand(2, 3) * 255.
169169
scale = np.random.rand() * 2.0 - 1.0
170170
with self.assertRaisesRegexp(ValueError,
171-
"Shape must be at least rank 3"):
171+
"input must be at least rank 3."):
172172
self._adjust_value_in_yiq_tf(x_np, scale)
173173
x_np = np.random.rand(4, 2, 4) * 255.
174174
scale = np.random.rand() * 2.0 - 1.0
@@ -233,7 +233,7 @@ def test_invalid_shapes(self):
233233
x_np = np.random.rand(2, 3) * 255.
234234
scale = np.random.rand() * 2.0 - 1.0
235235
with self.assertRaisesRegexp(ValueError,
236-
"Shape must be at least rank 3"):
236+
"input must be at least rank 3."):
237237
self._adjust_saturation_in_yiq_tf(x_np, scale)
238238
x_np = np.random.rand(4, 2, 4) * 255.
239239
scale = np.random.rand() * 2.0 - 1.0

0 commit comments

Comments
 (0)