2323import tensorflow as tf
2424from tensorflow_addons .image import distort_image_ops
2525
26- # from tensorflow_addons.utils import test_utils
26+ from tensorflow_addons .utils import test_utils
2727
2828
29- # TODO: #373 Get this to run in graph mode as well
30- # @test_utils.run_all_in_graph_and_eager_modes
29+ @test_utils .run_all_in_graph_and_eager_modes
3130class AdjustHueInYiqTest (tf .test .TestCase ):
3231 def _adjust_hue_in_yiq_np (self , x_np , delta_h ):
3332 """Rotate hue in YIQ space.
@@ -101,15 +100,18 @@ def test_adjust_random_hue_in_yiq(self):
101100 y_tf = self ._adjust_hue_in_yiq_tf (x_np , delta_h )
102101 self .assertAllClose (y_tf , y_np , rtol = 2e-4 , atol = 1e-4 )
103102
104- def test_invalid_shapes (self ):
103+ def test_invalid_rank (self ):
104+ msg = "Shape must be at least rank 3 but is rank 2"
105105 x_np = np .random .rand (2 , 3 ) * 255.
106106 delta_h = np .random .rand () * 2.0 - 1.0
107- with self .assertRaises (( tf . errors . InvalidArgumentError , ValueError ) ):
107+ with self .assertRaisesRegex ( ValueError , msg ):
108108 self .evaluate (self ._adjust_hue_in_yiq_tf (x_np , delta_h ))
109+
110+ def test_invalid_channels (self ):
111+ msg = "Dimension must be 3 but is 4"
109112 x_np = np .random .rand (4 , 2 , 4 ) * 255.
110113 delta_h = np .random .rand () * 2.0 - 1.0
111- with self .assertRaisesOpError ("input must have 3 channels "
112- "but instead has 4 channels" ):
114+ with self .assertRaisesRegex (ValueError , msg ):
113115 self .evaluate (self ._adjust_hue_in_yiq_tf (x_np , delta_h ))
114116
115117 def test_adjust_hsv_in_yiq_unknown_shape (self ):
@@ -132,8 +134,7 @@ def test_random_hsv_in_yiq_unknown_shape(self):
132134 self .assertAllEqual (fn (image_tf ), fn (image_tf ))
133135
134136
135- # TODO: #373 Get this to run in graph mode as well
136- # @test_utils.run_all_in_graph_and_eager_modes
137+ @test_utils .run_all_in_graph_and_eager_modes
137138class AdjustValueInYiqTest (tf .test .TestCase ):
138139 def _adjust_value_in_yiq_np (self , x_np , scale ):
139140 return x_np * scale
@@ -180,20 +181,22 @@ def test_adjust_random_value_in_yiq(self):
180181 y_tf = self ._adjust_value_in_yiq_tf (x_np , scale )
181182 self .assertAllClose (y_tf , y_np , rtol = 2e-4 , atol = 1e-4 )
182183
183- def test_invalid_shapes (self ):
184+ def test_invalid_rank (self ):
185+ msg = "Shape must be at least rank 3 but is rank 2"
184186 x_np = np .random .rand (2 , 3 ) * 255.
185187 scale = np .random .rand () * 2.0 - 1.0
186- with self .assertRaises (( tf . errors . InvalidArgumentError , ValueError ) ):
188+ with self .assertRaisesRegex ( ValueError , msg ):
187189 self .evaluate (self ._adjust_value_in_yiq_tf (x_np , scale ))
190+
191+ def test_invalid_channels (self ):
192+ msg = "Dimension must be 3 but is 4"
188193 x_np = np .random .rand (4 , 2 , 4 ) * 255.
189194 scale = np .random .rand () * 2.0 - 1.0
190- with self .assertRaisesOpError ("input must have 3 channels "
191- "but instead has 4 channels" ):
195+ with self .assertRaisesRegex (ValueError , msg ):
192196 self .evaluate (self ._adjust_value_in_yiq_tf (x_np , scale ))
193197
194198
195- # TODO: #373 Get this to run in graph mode as well
196- # @test_utils.run_all_in_graph_and_eager_modes
199+ @test_utils .run_all_in_graph_and_eager_modes
197200class AdjustSaturationInYiqTest (tf .test .TestCase ):
198201 def _adjust_saturation_in_yiq_tf (self , x_np , scale ):
199202 x = tf .constant (x_np )
@@ -244,15 +247,18 @@ def test_adjust_random_saturation_in_yiq(self):
244247 y_tf = self ._adjust_saturation_in_yiq_tf (x_np , scale )
245248 self .assertAllClose (y_tf , y_baseline , rtol = 2e-4 , atol = 1e-4 )
246249
247- def test_invalid_shapes (self ):
250+ def test_invalid_rank (self ):
251+ msg = "Shape must be at least rank 3 but is rank 2"
248252 x_np = np .random .rand (2 , 3 ) * 255.
249253 scale = np .random .rand () * 2.0 - 1.0
250- with self .assertRaises (( tf . errors . InvalidArgumentError , ValueError ) ):
254+ with self .assertRaisesRegex ( ValueError , msg ):
251255 self .evaluate (self ._adjust_saturation_in_yiq_tf (x_np , scale ))
256+
257+ def test_invalid_channels (self ):
258+ msg = "Dimension must be 3 but is 4"
252259 x_np = np .random .rand (4 , 2 , 4 ) * 255.
253260 scale = np .random .rand () * 2.0 - 1.0
254- with self .assertRaisesOpError ("input must have 3 channels "
255- "but instead has 4 channels" ):
261+ with self .assertRaisesRegex (ValueError , msg ):
256262 self .evaluate (self ._adjust_saturation_in_yiq_tf (x_np , scale ))
257263
258264
0 commit comments