From 75bdc8f626d57bfc2f2efe5f81c1744962fdb1d3 Mon Sep 17 00:00:00 2001 From: Tzu-Wei Sung Date: Mon, 5 Aug 2019 16:01:26 +0800 Subject: [PATCH] fix graph mode tests --- .../image/cc/ops/distort_image_ops.cc | 11 +++++ .../image/distort_image_ops_test.py | 44 +++++++++++-------- 2 files changed, 36 insertions(+), 19 deletions(-) diff --git a/tensorflow_addons/custom_ops/image/cc/ops/distort_image_ops.cc b/tensorflow_addons/custom_ops/image/cc/ops/distort_image_ops.cc index 82357ea606..bbaa506ff3 100644 --- a/tensorflow_addons/custom_ops/image/cc/ops/distort_image_ops.cc +++ b/tensorflow_addons/custom_ops/image/cc/ops/distort_image_ops.cc @@ -19,7 +19,9 @@ limitations under the License. namespace tensorflow { +using shape_inference::DimensionHandle; using shape_inference::InferenceContext; +using shape_inference::ShapeHandle; // -------------------------------------------------------------------------- REGISTER_OP("AdjustHsvInYiq") @@ -30,6 +32,15 @@ REGISTER_OP("AdjustHsvInYiq") .Output("output: T") .Attr("T: {uint8, int8, int16, int32, int64, half, float, double}") .SetShapeFn([](InferenceContext* c) { + ShapeHandle images, delta_h, scale_s, scale_v; + + TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 3, &images)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &delta_h)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &scale_s)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &scale_v)); + + DimensionHandle channels; + TF_RETURN_IF_ERROR(c->WithValue(c->Dim(c->input(0), -1), 3, &channels)); return shape_inference::UnchangedShapeWithRankAtLeast(c, 3); }) .Doc(R"Doc( diff --git a/tensorflow_addons/image/distort_image_ops_test.py b/tensorflow_addons/image/distort_image_ops_test.py index 0a03f35495..a06521fcf4 100644 --- a/tensorflow_addons/image/distort_image_ops_test.py +++ b/tensorflow_addons/image/distort_image_ops_test.py @@ -23,11 +23,10 @@ import tensorflow as tf from tensorflow_addons.image import distort_image_ops -# from tensorflow_addons.utils import test_utils +from tensorflow_addons.utils import test_utils -# TODO: #373 Get this to run in graph mode as well -# @test_utils.run_all_in_graph_and_eager_modes +@test_utils.run_all_in_graph_and_eager_modes class AdjustHueInYiqTest(tf.test.TestCase): def _adjust_hue_in_yiq_np(self, x_np, delta_h): """Rotate hue in YIQ space. @@ -101,15 +100,18 @@ def test_adjust_random_hue_in_yiq(self): y_tf = self._adjust_hue_in_yiq_tf(x_np, delta_h) self.assertAllClose(y_tf, y_np, rtol=2e-4, atol=1e-4) - def test_invalid_shapes(self): + def test_invalid_rank(self): + msg = "Shape must be at least rank 3 but is rank 2" x_np = np.random.rand(2, 3) * 255. delta_h = np.random.rand() * 2.0 - 1.0 - with self.assertRaises((tf.errors.InvalidArgumentError, ValueError)): + with self.assertRaisesRegex(ValueError, msg): self.evaluate(self._adjust_hue_in_yiq_tf(x_np, delta_h)) + + def test_invalid_channels(self): + msg = "Dimension must be 3 but is 4" x_np = np.random.rand(4, 2, 4) * 255. delta_h = np.random.rand() * 2.0 - 1.0 - with self.assertRaisesOpError("input must have 3 channels " - "but instead has 4 channels"): + with self.assertRaisesRegex(ValueError, msg): self.evaluate(self._adjust_hue_in_yiq_tf(x_np, delta_h)) def test_adjust_hsv_in_yiq_unknown_shape(self): @@ -132,8 +134,7 @@ def test_random_hsv_in_yiq_unknown_shape(self): self.assertAllEqual(fn(image_tf), fn(image_tf)) -# TODO: #373 Get this to run in graph mode as well -# @test_utils.run_all_in_graph_and_eager_modes +@test_utils.run_all_in_graph_and_eager_modes class AdjustValueInYiqTest(tf.test.TestCase): def _adjust_value_in_yiq_np(self, x_np, scale): return x_np * scale @@ -180,20 +181,22 @@ def test_adjust_random_value_in_yiq(self): y_tf = self._adjust_value_in_yiq_tf(x_np, scale) self.assertAllClose(y_tf, y_np, rtol=2e-4, atol=1e-4) - def test_invalid_shapes(self): + def test_invalid_rank(self): + msg = "Shape must be at least rank 3 but is rank 2" x_np = np.random.rand(2, 3) * 255. scale = np.random.rand() * 2.0 - 1.0 - with self.assertRaises((tf.errors.InvalidArgumentError, ValueError)): + with self.assertRaisesRegex(ValueError, msg): self.evaluate(self._adjust_value_in_yiq_tf(x_np, scale)) + + def test_invalid_channels(self): + msg = "Dimension must be 3 but is 4" x_np = np.random.rand(4, 2, 4) * 255. scale = np.random.rand() * 2.0 - 1.0 - with self.assertRaisesOpError("input must have 3 channels " - "but instead has 4 channels"): + with self.assertRaisesRegex(ValueError, msg): self.evaluate(self._adjust_value_in_yiq_tf(x_np, scale)) -# TODO: #373 Get this to run in graph mode as well -# @test_utils.run_all_in_graph_and_eager_modes +@test_utils.run_all_in_graph_and_eager_modes class AdjustSaturationInYiqTest(tf.test.TestCase): def _adjust_saturation_in_yiq_tf(self, x_np, scale): x = tf.constant(x_np) @@ -244,15 +247,18 @@ def test_adjust_random_saturation_in_yiq(self): y_tf = self._adjust_saturation_in_yiq_tf(x_np, scale) self.assertAllClose(y_tf, y_baseline, rtol=2e-4, atol=1e-4) - def test_invalid_shapes(self): + def test_invalid_rank(self): + msg = "Shape must be at least rank 3 but is rank 2" x_np = np.random.rand(2, 3) * 255. scale = np.random.rand() * 2.0 - 1.0 - with self.assertRaises((tf.errors.InvalidArgumentError, ValueError)): + with self.assertRaisesRegex(ValueError, msg): self.evaluate(self._adjust_saturation_in_yiq_tf(x_np, scale)) + + def test_invalid_channels(self): + msg = "Dimension must be 3 but is 4" x_np = np.random.rand(4, 2, 4) * 255. scale = np.random.rand() * 2.0 - 1.0 - with self.assertRaisesOpError("input must have 3 channels " - "but instead has 4 channels"): + with self.assertRaisesRegex(ValueError, msg): self.evaluate(self._adjust_saturation_in_yiq_tf(x_np, scale))