Skip to content

Commit 00ee796

Browse files
committed
Split tests in eager and non-eager mode
Signed-off-by: Yong Tang <[email protected]>
1 parent 0aa22a4 commit 00ee796

File tree

2 files changed

+54
-24
lines changed

2 files changed

+54
-24
lines changed

tests/test_image.py

Lines changed: 0 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -174,29 +174,5 @@ def test_draw_bounding_box(self):
174174
# self.assertAllEqual(bb_image_v, ex_image_v)
175175
_ = bb_image_p.eval()
176176

177-
def test_webp_file_dataset():
178-
"""Test case for WebPDataset.
179-
"""
180-
filename = os.path.join(
181-
os.path.dirname(os.path.abspath(__file__)), "test_image", "sample.webp")
182-
183-
num_repeats = 2
184-
185-
dataset = image_io.WebPDataset([filename, filename])
186-
# Repeat 2 times (2 * 2 = 4 images)
187-
dataset = dataset.repeat(num_repeats)
188-
# Drop alpha channel
189-
dataset = dataset.map(lambda x: x[:, :, :3])
190-
# Resize to 224 * 224
191-
dataset = dataset.map(lambda x: tf.keras.applications.resnet50.preprocess_input(tf.image.resize(x, (224, 224))))
192-
# Batch to 3, still have 4 images (3 + 1)
193-
dataset = dataset.batch(1)
194-
model = tf.keras.applications.resnet50.ResNet50(weights='imagenet')
195-
y = model.predict(dataset)
196-
p = tf.keras.applications.resnet50.decode_predictions(y, top=1)
197-
for i in p:
198-
assert i[0][1] == 'pineapple' # not truly a pineapple, though
199-
assert len(p) == 4
200-
201177
if __name__ == "__main__":
202178
test.main()

tests/test_image_eager.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
4+
# use this file except in compliance with the License. You may obtain a copy of
5+
# the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11+
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12+
# License for the specific language governing permissions and limitations under
13+
# the License.
14+
# ==============================================================================
15+
"""Tests for Image Dataset."""
16+
17+
from __future__ import absolute_import
18+
from __future__ import division
19+
from __future__ import print_function
20+
21+
import os
22+
23+
import tensorflow as tf
24+
if not (hasattr(tf, "version") and tf.version.VERSION.startswith("2.")):
25+
tf.compat.v1.enable_eager_execution()
26+
import tensorflow_io.image as image_io # pylint: disable=wrong-import-position
27+
28+
29+
def test_webp_file_dataset():
30+
"""Test case for WebPDataset.
31+
"""
32+
filename = os.path.join(
33+
os.path.dirname(os.path.abspath(__file__)), "test_image", "sample.webp")
34+
35+
num_repeats = 2
36+
37+
dataset = image_io.WebPDataset([filename, filename])
38+
# Repeat 2 times (2 * 2 = 4 images)
39+
dataset = dataset.repeat(num_repeats)
40+
# Drop alpha channel
41+
dataset = dataset.map(lambda x: x[:, :, :3])
42+
# Resize to 224 * 224
43+
dataset = dataset.map(lambda x: tf.keras.applications.resnet50.preprocess_input(tf.image.resize(x, (224, 224))))
44+
# Batch to 3, still have 4 images (3 + 1)
45+
dataset = dataset.batch(1)
46+
model = tf.keras.applications.resnet50.ResNet50(weights='imagenet')
47+
y = model.predict(dataset)
48+
p = tf.keras.applications.resnet50.decode_predictions(y, top=1)
49+
for i in p:
50+
assert i[0][1] == 'pineapple' # not truly a pineapple, though
51+
assert len(p) == 4
52+
53+
if __name__ == "__main__":
54+
test.main()

0 commit comments

Comments
 (0)