Skip to content

Commit 0098fcb

Browse files
committed
Move tf.keras to separate function in test
Signed-off-by: Yong Tang <[email protected]>
1 parent 04070cc commit 0098fcb

File tree

1 file changed

+24
-24
lines changed

1 file changed

+24
-24
lines changed

tests/test_image.py

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -58,30 +58,6 @@ def test_decode_webp(self):
5858

5959
self.assertAllEqual(webp_v, png)
6060

61-
def test_webp_file_dataset(self):
62-
"""Test case for WebPDataset.
63-
"""
64-
filename = os.path.join(
65-
os.path.dirname(os.path.abspath(__file__)), "test_image", "sample.webp")
66-
67-
num_repeats = 2
68-
69-
dataset = image_io.WebPDataset([filename, filename])
70-
# Repeat 2 times (2 * 2 = 4 images)
71-
dataset = dataset.repeat(num_repeats)
72-
# Drop alpha channel
73-
dataset = dataset.map(lambda x: x[:, :, :3])
74-
# Resize to 224 * 224
75-
dataset = dataset.map(lambda x: tf.keras.applications.resnet50.preprocess_input(tf.image.resize(x, (224, 224))))
76-
# Batch to 3, still have 4 images (3 + 1)
77-
dataset = dataset.batch(1)
78-
model = tf.keras.applications.resnet50.ResNet50(weights='imagenet')
79-
y = model.predict(dataset)
80-
p = tf.keras.applications.resnet50.decode_predictions(y, top=1)
81-
for i in p:
82-
assert i[0][1] == 'pineapple' # not truly a pineapple, though
83-
assert len(p) == 4
84-
8561
def test_tiff_file_dataset(self):
8662
"""Test case for TIFFDataset.
8763
"""
@@ -198,5 +174,29 @@ def test_draw_bounding_box(self):
198174
# self.assertAllEqual(bb_image_v, ex_image_v)
199175
_ = bb_image_p.eval()
200176

177+
def test_webp_file_dataset():
178+
"""Test case for WebPDataset.
179+
"""
180+
filename = os.path.join(
181+
os.path.dirname(os.path.abspath(__file__)), "test_image", "sample.webp")
182+
183+
num_repeats = 2
184+
185+
dataset = image_io.WebPDataset([filename, filename])
186+
# Repeat 2 times (2 * 2 = 4 images)
187+
dataset = dataset.repeat(num_repeats)
188+
# Drop alpha channel
189+
dataset = dataset.map(lambda x: x[:, :, :3])
190+
# Resize to 224 * 224
191+
dataset = dataset.map(lambda x: tf.keras.applications.resnet50.preprocess_input(tf.image.resize(x, (224, 224))))
192+
# Batch to 3, still have 4 images (3 + 1)
193+
dataset = dataset.batch(1)
194+
model = tf.keras.applications.resnet50.ResNet50(weights='imagenet')
195+
y = model.predict(dataset)
196+
p = tf.keras.applications.resnet50.decode_predictions(y, top=1)
197+
for i in p:
198+
assert i[0][1] == 'pineapple' # not truly a pineapple, though
199+
assert len(p) == 4
200+
201201
if __name__ == "__main__":
202202
test.main()

0 commit comments

Comments
 (0)