@@ -58,30 +58,6 @@ def test_decode_webp(self):
5858
5959 self .assertAllEqual (webp_v , png )
6060
61- def test_webp_file_dataset (self ):
62- """Test case for WebPDataset.
63- """
64- filename = os .path .join (
65- os .path .dirname (os .path .abspath (__file__ )), "test_image" , "sample.webp" )
66-
67- num_repeats = 2
68-
69- dataset = image_io .WebPDataset ([filename , filename ])
70- # Repeat 2 times (2 * 2 = 4 images)
71- dataset = dataset .repeat (num_repeats )
72- # Drop alpha channel
73- dataset = dataset .map (lambda x : x [:, :, :3 ])
74- # Resize to 224 * 224
75- dataset = dataset .map (lambda x : tf .keras .applications .resnet50 .preprocess_input (tf .image .resize (x , (224 , 224 ))))
76- # Batch to 3, still have 4 images (3 + 1)
77- dataset = dataset .batch (1 )
78- model = tf .keras .applications .resnet50 .ResNet50 (weights = 'imagenet' )
79- y = model .predict (dataset )
80- p = tf .keras .applications .resnet50 .decode_predictions (y , top = 1 )
81- for i in p :
82- assert i [0 ][1 ] == 'pineapple' # not truly a pineapple, though
83- assert len (p ) == 4
84-
8561 def test_tiff_file_dataset (self ):
8662 """Test case for TIFFDataset.
8763 """
@@ -198,5 +174,29 @@ def test_draw_bounding_box(self):
198174 # self.assertAllEqual(bb_image_v, ex_image_v)
199175 _ = bb_image_p .eval ()
200176
177+ def test_webp_file_dataset ():
178+ """Test case for WebPDataset.
179+ """
180+ filename = os .path .join (
181+ os .path .dirname (os .path .abspath (__file__ )), "test_image" , "sample.webp" )
182+
183+ num_repeats = 2
184+
185+ dataset = image_io .WebPDataset ([filename , filename ])
186+ # Repeat 2 times (2 * 2 = 4 images)
187+ dataset = dataset .repeat (num_repeats )
188+ # Drop alpha channel
189+ dataset = dataset .map (lambda x : x [:, :, :3 ])
190+ # Resize to 224 * 224
191+ dataset = dataset .map (lambda x : tf .keras .applications .resnet50 .preprocess_input (tf .image .resize (x , (224 , 224 ))))
192+ # Batch to 3, still have 4 images (3 + 1)
193+ dataset = dataset .batch (1 )
194+ model = tf .keras .applications .resnet50 .ResNet50 (weights = 'imagenet' )
195+ y = model .predict (dataset )
196+ p = tf .keras .applications .resnet50 .decode_predictions (y , top = 1 )
197+ for i in p :
198+ assert i [0 ][1 ] == 'pineapple' # not truly a pineapple, though
199+ assert len (p ) == 4
200+
201201if __name__ == "__main__" :
202202 test .main ()
0 commit comments