Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions tensorflow_addons/custom_ops/image/cc/ops/distort_image_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@ limitations under the License.

namespace tensorflow {

using shape_inference::DimensionHandle;
using shape_inference::InferenceContext;
using shape_inference::ShapeHandle;

// --------------------------------------------------------------------------
REGISTER_OP("AdjustHsvInYiq")
Expand All @@ -30,6 +32,15 @@ REGISTER_OP("AdjustHsvInYiq")
.Output("output: T")
.Attr("T: {uint8, int8, int16, int32, int64, half, float, double}")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle images, delta_h, scale_s, scale_v;

TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 3, &images));
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &delta_h));
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &scale_s));
TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &scale_v));

DimensionHandle channels;
TF_RETURN_IF_ERROR(c->WithValue(c->Dim(c->input(0), -1), 3, &channels));
return shape_inference::UnchangedShapeWithRankAtLeast(c, 3);
})
.Doc(R"Doc(
Expand Down
44 changes: 25 additions & 19 deletions tensorflow_addons/image/distort_image_ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,10 @@
import tensorflow as tf
from tensorflow_addons.image import distort_image_ops

# from tensorflow_addons.utils import test_utils
from tensorflow_addons.utils import test_utils


# TODO: #373 Get this to run in graph mode as well
# @test_utils.run_all_in_graph_and_eager_modes
@test_utils.run_all_in_graph_and_eager_modes
class AdjustHueInYiqTest(tf.test.TestCase):
def _adjust_hue_in_yiq_np(self, x_np, delta_h):
"""Rotate hue in YIQ space.
Expand Down Expand Up @@ -101,15 +100,18 @@ def test_adjust_random_hue_in_yiq(self):
y_tf = self._adjust_hue_in_yiq_tf(x_np, delta_h)
self.assertAllClose(y_tf, y_np, rtol=2e-4, atol=1e-4)

def test_invalid_shapes(self):
def test_invalid_rank(self):
msg = "Shape must be at least rank 3 but is rank 2"
x_np = np.random.rand(2, 3) * 255.
delta_h = np.random.rand() * 2.0 - 1.0
with self.assertRaises((tf.errors.InvalidArgumentError, ValueError)):
with self.assertRaisesRegex(ValueError, msg):
self.evaluate(self._adjust_hue_in_yiq_tf(x_np, delta_h))

def test_invalid_channels(self):
msg = "Dimension must be 3 but is 4"
x_np = np.random.rand(4, 2, 4) * 255.
delta_h = np.random.rand() * 2.0 - 1.0
with self.assertRaisesOpError("input must have 3 channels "
"but instead has 4 channels"):
with self.assertRaisesRegex(ValueError, msg):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I did not figure out why the check failure here is not raised in the graph mode in the GPU version of tensorflow-2.0-preview, but it is raised in the CPU version.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I also have no idea about it... I'd suppose it should either fail or pass regardless of the device.

self.evaluate(self._adjust_hue_in_yiq_tf(x_np, delta_h))

def test_adjust_hsv_in_yiq_unknown_shape(self):
Expand All @@ -132,8 +134,7 @@ def test_random_hsv_in_yiq_unknown_shape(self):
self.assertAllEqual(fn(image_tf), fn(image_tf))


# TODO: #373 Get this to run in graph mode as well
# @test_utils.run_all_in_graph_and_eager_modes
@test_utils.run_all_in_graph_and_eager_modes
class AdjustValueInYiqTest(tf.test.TestCase):
def _adjust_value_in_yiq_np(self, x_np, scale):
return x_np * scale
Expand Down Expand Up @@ -180,20 +181,22 @@ def test_adjust_random_value_in_yiq(self):
y_tf = self._adjust_value_in_yiq_tf(x_np, scale)
self.assertAllClose(y_tf, y_np, rtol=2e-4, atol=1e-4)

def test_invalid_shapes(self):
def test_invalid_rank(self):
msg = "Shape must be at least rank 3 but is rank 2"
x_np = np.random.rand(2, 3) * 255.
scale = np.random.rand() * 2.0 - 1.0
with self.assertRaises((tf.errors.InvalidArgumentError, ValueError)):
with self.assertRaisesRegex(ValueError, msg):
self.evaluate(self._adjust_value_in_yiq_tf(x_np, scale))

def test_invalid_channels(self):
msg = "Dimension must be 3 but is 4"
x_np = np.random.rand(4, 2, 4) * 255.
scale = np.random.rand() * 2.0 - 1.0
with self.assertRaisesOpError("input must have 3 channels "
"but instead has 4 channels"):
with self.assertRaisesRegex(ValueError, msg):
self.evaluate(self._adjust_value_in_yiq_tf(x_np, scale))


# TODO: #373 Get this to run in graph mode as well
# @test_utils.run_all_in_graph_and_eager_modes
@test_utils.run_all_in_graph_and_eager_modes
class AdjustSaturationInYiqTest(tf.test.TestCase):
def _adjust_saturation_in_yiq_tf(self, x_np, scale):
x = tf.constant(x_np)
Expand Down Expand Up @@ -244,15 +247,18 @@ def test_adjust_random_saturation_in_yiq(self):
y_tf = self._adjust_saturation_in_yiq_tf(x_np, scale)
self.assertAllClose(y_tf, y_baseline, rtol=2e-4, atol=1e-4)

def test_invalid_shapes(self):
def test_invalid_rank(self):
msg = "Shape must be at least rank 3 but is rank 2"
x_np = np.random.rand(2, 3) * 255.
scale = np.random.rand() * 2.0 - 1.0
with self.assertRaises((tf.errors.InvalidArgumentError, ValueError)):
with self.assertRaisesRegex(ValueError, msg):
self.evaluate(self._adjust_saturation_in_yiq_tf(x_np, scale))

def test_invalid_channels(self):
msg = "Dimension must be 3 but is 4"
x_np = np.random.rand(4, 2, 4) * 255.
scale = np.random.rand() * 2.0 - 1.0
with self.assertRaisesOpError("input must have 3 channels "
"but instead has 4 channels"):
with self.assertRaisesRegex(ValueError, msg):
self.evaluate(self._adjust_saturation_in_yiq_tf(x_np, scale))


Expand Down