Skip to content

Commit cb09478

Browse files
WindQAQSquadrick
authored andcommitted
Make tests compatible with graph mode (#291)
1 parent 894a074 commit cb09478

File tree

2 files changed

+12
-15
lines changed

2 files changed

+12
-15
lines changed

tensorflow_addons/image/distort_image_ops.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -139,9 +139,6 @@ def adjust_hsv_in_yiq(image,
139139
with tf.name_scope(name or "adjust_hsv_in_yiq"):
140140
image = tf.convert_to_tensor(image, name="image")
141141

142-
if image.shape.rank < 3:
143-
raise ValueError("input must be at least rank 3.")
144-
145142
# Remember original dtype to so we can convert back if needed
146143
orig_dtype = image.dtype
147144
flt_image = tf.image.convert_image_dtype(image, tf.dtypes.float32)

tensorflow_addons/image/distort_image_ops_test.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -102,18 +102,18 @@ def test_adjust_random_hue_in_yiq(self):
102102
y_tf = self._adjust_hue_in_yiq_tf(x_np, delta_h)
103103
self.assertAllClose(y_tf, y_np, rtol=2e-4, atol=1e-4)
104104

105-
# TODO: run in both graph and eager modes
105+
@test_utils.run_in_graph_and_eager_modes
106106
def test_invalid_shapes(self):
107107
x_np = np.random.rand(2, 3) * 255.
108108
delta_h = np.random.rand() * 2.0 - 1.0
109109
with self.assertRaisesRegexp(ValueError,
110-
"input must be at least rank 3."):
111-
self._adjust_hue_in_yiq_tf(x_np, delta_h)
110+
"Shape must be at least rank 3"):
111+
self.evaluate(self._adjust_hue_in_yiq_tf(x_np, delta_h))
112112
x_np = np.random.rand(4, 2, 4) * 255.
113113
delta_h = np.random.rand() * 2.0 - 1.0
114114
with self.assertRaisesOpError("input must have 3 channels "
115115
"but instead has 4 channels"):
116-
self._adjust_hue_in_yiq_tf(x_np, delta_h)
116+
self.evaluate(self._adjust_hue_in_yiq_tf(x_np, delta_h))
117117

118118

119119
class AdjustValueInYiqTest(tf.test.TestCase):
@@ -163,18 +163,18 @@ def test_adjust_random_value_in_yiq(self):
163163
y_tf = self._adjust_value_in_yiq_tf(x_np, scale)
164164
self.assertAllClose(y_tf, y_np, rtol=2e-4, atol=1e-4)
165165

166-
# TODO: run in both graph and eager modes
166+
@test_utils.run_in_graph_and_eager_modes
167167
def test_invalid_shapes(self):
168168
x_np = np.random.rand(2, 3) * 255.
169169
scale = np.random.rand() * 2.0 - 1.0
170170
with self.assertRaisesRegexp(ValueError,
171-
"input must be at least rank 3."):
172-
self._adjust_value_in_yiq_tf(x_np, scale)
171+
"Shape must be at least rank 3"):
172+
self.evaluate(self._adjust_value_in_yiq_tf(x_np, scale))
173173
x_np = np.random.rand(4, 2, 4) * 255.
174174
scale = np.random.rand() * 2.0 - 1.0
175175
with self.assertRaisesOpError("input must have 3 channels "
176176
"but instead has 4 channels"):
177-
self._adjust_value_in_yiq_tf(x_np, scale)
177+
self.evaluate(self._adjust_value_in_yiq_tf(x_np, scale))
178178

179179

180180
class AdjustSaturationInYiqTest(tf.test.TestCase):
@@ -228,18 +228,18 @@ def test_adjust_random_saturation_in_yiq(self):
228228
y_tf = self._adjust_saturation_in_yiq_tf(x_np, scale)
229229
self.assertAllClose(y_tf, y_baseline, rtol=2e-4, atol=1e-4)
230230

231-
# TODO: run in both graph and eager modes
231+
@test_utils.run_in_graph_and_eager_modes
232232
def test_invalid_shapes(self):
233233
x_np = np.random.rand(2, 3) * 255.
234234
scale = np.random.rand() * 2.0 - 1.0
235235
with self.assertRaisesRegexp(ValueError,
236-
"input must be at least rank 3."):
237-
self._adjust_saturation_in_yiq_tf(x_np, scale)
236+
"Shape must be at least rank 3"):
237+
self.evaluate(self._adjust_saturation_in_yiq_tf(x_np, scale))
238238
x_np = np.random.rand(4, 2, 4) * 255.
239239
scale = np.random.rand() * 2.0 - 1.0
240240
with self.assertRaisesOpError("input must have 3 channels "
241241
"but instead has 4 channels"):
242-
self._adjust_saturation_in_yiq_tf(x_np, scale)
242+
self.evaluate(self._adjust_saturation_in_yiq_tf(x_np, scale))
243243

244244

245245
# TODO: get rid of sessions

0 commit comments

Comments
 (0)