Skip to content

Commit 695dc19

Browse files
WindQAQseanpmorgan
authored andcommitted
fix graph mode tests (#395)
1 parent a0b3beb commit 695dc19

File tree

2 files changed

+36
-19
lines changed

2 files changed

+36
-19
lines changed

tensorflow_addons/custom_ops/image/cc/ops/distort_image_ops.cc

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,9 @@ limitations under the License.
1919

2020
namespace tensorflow {
2121

22+
using shape_inference::DimensionHandle;
2223
using shape_inference::InferenceContext;
24+
using shape_inference::ShapeHandle;
2325

2426
// --------------------------------------------------------------------------
2527
REGISTER_OP("AdjustHsvInYiq")
@@ -30,6 +32,15 @@ REGISTER_OP("AdjustHsvInYiq")
3032
.Output("output: T")
3133
.Attr("T: {uint8, int8, int16, int32, int64, half, float, double}")
3234
.SetShapeFn([](InferenceContext* c) {
35+
ShapeHandle images, delta_h, scale_s, scale_v;
36+
37+
TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 3, &images));
38+
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &delta_h));
39+
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &scale_s));
40+
TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &scale_v));
41+
42+
DimensionHandle channels;
43+
TF_RETURN_IF_ERROR(c->WithValue(c->Dim(c->input(0), -1), 3, &channels));
3344
return shape_inference::UnchangedShapeWithRankAtLeast(c, 3);
3445
})
3546
.Doc(R"Doc(

tensorflow_addons/image/distort_image_ops_test.py

Lines changed: 25 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,10 @@
2323
import tensorflow as tf
2424
from tensorflow_addons.image import distort_image_ops
2525

26-
# from tensorflow_addons.utils import test_utils
26+
from tensorflow_addons.utils import test_utils
2727

2828

29-
# TODO: #373 Get this to run in graph mode as well
30-
# @test_utils.run_all_in_graph_and_eager_modes
29+
@test_utils.run_all_in_graph_and_eager_modes
3130
class AdjustHueInYiqTest(tf.test.TestCase):
3231
def _adjust_hue_in_yiq_np(self, x_np, delta_h):
3332
"""Rotate hue in YIQ space.
@@ -101,15 +100,18 @@ def test_adjust_random_hue_in_yiq(self):
101100
y_tf = self._adjust_hue_in_yiq_tf(x_np, delta_h)
102101
self.assertAllClose(y_tf, y_np, rtol=2e-4, atol=1e-4)
103102

104-
def test_invalid_shapes(self):
103+
def test_invalid_rank(self):
104+
msg = "Shape must be at least rank 3 but is rank 2"
105105
x_np = np.random.rand(2, 3) * 255.
106106
delta_h = np.random.rand() * 2.0 - 1.0
107-
with self.assertRaises((tf.errors.InvalidArgumentError, ValueError)):
107+
with self.assertRaisesRegex(ValueError, msg):
108108
self.evaluate(self._adjust_hue_in_yiq_tf(x_np, delta_h))
109+
110+
def test_invalid_channels(self):
111+
msg = "Dimension must be 3 but is 4"
109112
x_np = np.random.rand(4, 2, 4) * 255.
110113
delta_h = np.random.rand() * 2.0 - 1.0
111-
with self.assertRaisesOpError("input must have 3 channels "
112-
"but instead has 4 channels"):
114+
with self.assertRaisesRegex(ValueError, msg):
113115
self.evaluate(self._adjust_hue_in_yiq_tf(x_np, delta_h))
114116

115117
def test_adjust_hsv_in_yiq_unknown_shape(self):
@@ -132,8 +134,7 @@ def test_random_hsv_in_yiq_unknown_shape(self):
132134
self.assertAllEqual(fn(image_tf), fn(image_tf))
133135

134136

135-
# TODO: #373 Get this to run in graph mode as well
136-
# @test_utils.run_all_in_graph_and_eager_modes
137+
@test_utils.run_all_in_graph_and_eager_modes
137138
class AdjustValueInYiqTest(tf.test.TestCase):
138139
def _adjust_value_in_yiq_np(self, x_np, scale):
139140
return x_np * scale
@@ -180,20 +181,22 @@ def test_adjust_random_value_in_yiq(self):
180181
y_tf = self._adjust_value_in_yiq_tf(x_np, scale)
181182
self.assertAllClose(y_tf, y_np, rtol=2e-4, atol=1e-4)
182183

183-
def test_invalid_shapes(self):
184+
def test_invalid_rank(self):
185+
msg = "Shape must be at least rank 3 but is rank 2"
184186
x_np = np.random.rand(2, 3) * 255.
185187
scale = np.random.rand() * 2.0 - 1.0
186-
with self.assertRaises((tf.errors.InvalidArgumentError, ValueError)):
188+
with self.assertRaisesRegex(ValueError, msg):
187189
self.evaluate(self._adjust_value_in_yiq_tf(x_np, scale))
190+
191+
def test_invalid_channels(self):
192+
msg = "Dimension must be 3 but is 4"
188193
x_np = np.random.rand(4, 2, 4) * 255.
189194
scale = np.random.rand() * 2.0 - 1.0
190-
with self.assertRaisesOpError("input must have 3 channels "
191-
"but instead has 4 channels"):
195+
with self.assertRaisesRegex(ValueError, msg):
192196
self.evaluate(self._adjust_value_in_yiq_tf(x_np, scale))
193197

194198

195-
# TODO: #373 Get this to run in graph mode as well
196-
# @test_utils.run_all_in_graph_and_eager_modes
199+
@test_utils.run_all_in_graph_and_eager_modes
197200
class AdjustSaturationInYiqTest(tf.test.TestCase):
198201
def _adjust_saturation_in_yiq_tf(self, x_np, scale):
199202
x = tf.constant(x_np)
@@ -244,15 +247,18 @@ def test_adjust_random_saturation_in_yiq(self):
244247
y_tf = self._adjust_saturation_in_yiq_tf(x_np, scale)
245248
self.assertAllClose(y_tf, y_baseline, rtol=2e-4, atol=1e-4)
246249

247-
def test_invalid_shapes(self):
250+
def test_invalid_rank(self):
251+
msg = "Shape must be at least rank 3 but is rank 2"
248252
x_np = np.random.rand(2, 3) * 255.
249253
scale = np.random.rand() * 2.0 - 1.0
250-
with self.assertRaises((tf.errors.InvalidArgumentError, ValueError)):
254+
with self.assertRaisesRegex(ValueError, msg):
251255
self.evaluate(self._adjust_saturation_in_yiq_tf(x_np, scale))
256+
257+
def test_invalid_channels(self):
258+
msg = "Dimension must be 3 but is 4"
252259
x_np = np.random.rand(4, 2, 4) * 255.
253260
scale = np.random.rand() * 2.0 - 1.0
254-
with self.assertRaisesOpError("input must have 3 channels "
255-
"but instead has 4 channels"):
261+
with self.assertRaisesRegex(ValueError, msg):
256262
self.evaluate(self._adjust_saturation_in_yiq_tf(x_np, scale))
257263

258264

0 commit comments

Comments
 (0)