Skip to content

Commit c6f86d1

Browse files
committed
Support 2.0 Dataset
Signed-off-by: Yong Tang <[email protected]>
1 parent 3593b5e commit c6f86d1

File tree

4 files changed, with 167 additions and 222 deletions.
tensorflow_io/core/python/ops/data_ops.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
"""Dataset."""
16+
from __future__ import absolute_import
17+
from __future__ import division
18+
from __future__ import print_function
19+
20+
import tensorflow as tf
21+
22+
class Dataset(tf.compat.v2.data.Dataset):
    """Base class for datasets produced by a single dataset-creating callable.

    Subclasses supply the callable, its input, the batch size and the
    element dtypes/shapes; this class wires them into the parent
    `tf.compat.v2.data.Dataset`.
    """

    def __init__(self, fn, data_input, batch, dtypes, shapes):
        """Record the configuration and initialise the parent dataset.

        Args:
          fn: Callable invoked as
            `fn(data_input, batch, output_types=dtypes, output_shapes=shapes)`;
            its return value is passed to the parent constructor.
          data_input: First positional argument forwarded to `fn`.
          batch: Second positional argument forwarded to `fn`. `None` is
            replaced by 0.
          dtypes: Sequence of element dtypes, forwarded as `output_types`.
          shapes: Sequence of element shapes, forwarded as `output_shapes`.
            Each entry must provide `as_list()` (see `_element_structure`).
        """
        # Attributes are assigned before the parent constructor runs, in the
        # same order as before, so they exist by the time the parent needs them.
        self._fn = fn
        self._data_input = data_input
        self._batch = batch if batch is not None else 0
        self._dtypes = dtypes
        self._shapes = shapes
        dataset_handle = fn(
            self._data_input,
            self._batch,
            output_types=self._dtypes,
            output_shapes=self._shapes)
        super(Dataset, self).__init__(dataset_handle)

    def _inputs(self):
        """Return the input datasets of this dataset; always an empty list."""
        return []

    @property
    def _element_structure(self):
        """Describe the structure of one element of this dataset.

        Returns:
          A single `TensorStructure` when exactly one dtype/shape pair is
          configured, otherwise a `NestedStructure` wrapping the list of
          per-component `TensorStructure` objects.
        """
        structures = []
        for dtype, shape in zip(self._dtypes, self._shapes):
            structures.append(
                tf.data.experimental.TensorStructure(dtype, shape.as_list()))
        if len(structures) == 1:
            return structures[0]
        return tf.data.experimental.NestedStructure(structures)

tensorflow_io/mnist/python/ops/mnist_ops.py

Lines changed: 28 additions & 104 deletions
Original file line numberDiff line numberDiff line change
@@ -18,124 +18,48 @@
1818
from __future__ import print_function
1919

2020
import tensorflow as tf
21-
from tensorflow import dtypes
22-
from tensorflow.compat.v1 import data
21+
from tensorflow_io.core.python.ops import data_ops as data_ops
2322
from tensorflow_io.core.python.ops import core_ops as mnist_ops
2423

25-
class _MNISTBaseDataset(data.Dataset):
26-
"""A MNIST Dataset
27-
"""
28-
29-
def __init__(self, batch, mnist_op_class):
30-
"""Create a MNISTReader.
31-
32-
Args:
33-
mnist_op_class: The op of the dataset, either
34-
mnist_ops.mnist_image_dataset or mnist_ops.mnist_label_dataset.
35-
filenames: A `tf.string` tensor containing one or more filenames.
36-
"""
37-
self._batch = batch
38-
self._func = mnist_op_class
39-
super(_MNISTBaseDataset, self).__init__()
40-
41-
def _inputs(self):
42-
return []
43-
44-
def _as_variant_tensor(self):
45-
return self._func(
46-
self._data_input,
47-
self._batch,
48-
output_types=self.output_types,
49-
output_shapes=self.output_shapes)
50-
51-
@property
52-
def output_classes(self):
53-
return tf.Tensor
54-
55-
@property
56-
def output_types(self):
57-
return tuple([dtypes.uint8])
58-
59-
class MNISTImageDataset(_MNISTBaseDataset):
60-
"""A MNIST Image Dataset
24+
class MNISTLabelDataset(data_ops.Dataset):
25+
"""A MNISTLabelDataset
6126
"""
6227

6328
def __init__(self, filename, batch=None):
64-
"""Create a MNISTReader.
65-
29+
"""Create a MNISTLabelDataset.
6630
Args:
6731
filename: A `tf.string` tensor containing one or more filenames.
6832
"""
6933
batch = 0 if batch is None else batch
70-
self._data_input = mnist_ops.mnist_image_input(filename, ["none", "gz"])
71-
super(MNISTImageDataset, self).__init__(
72-
batch, mnist_ops.mnist_image_dataset)
73-
74-
@property
75-
def output_shapes(self):
76-
return tuple([
77-
tf.TensorShape([None, None])]) if self._batch == 0 else tuple([
78-
tf.TensorShape([None, None, None])])
79-
34+
dtypes = [tf.uint8]
35+
shapes = [
36+
tf.TensorShape([])] if batch == 0 else [
37+
tf.TensorShape([batch])]
38+
super(MNISTLabelDataset, self).__init__(
39+
mnist_ops.mnist_label_dataset,
40+
mnist_ops.mnist_label_input(filename, ["none", "gz"]),
41+
batch, dtypes, shapes)
8042

81-
class MNISTLabelDataset(_MNISTBaseDataset):
82-
"""A MNIST Label Dataset
43+
class MNISTImageDataset(data_ops.Dataset):
44+
"""A MNISTImageDataset
8345
"""
8446

8547
def __init__(self, filename, batch=None):
86-
"""Create a MNISTReader.
87-
48+
"""Create a MNISTImageDataset.
8849
Args:
8950
filename: A `tf.string` tensor containing one or more filenames.
9051
"""
9152
batch = 0 if batch is None else batch
92-
self._data_input = mnist_ops.mnist_label_input(filename, ["none", "gz"])
93-
super(MNISTLabelDataset, self).__init__(
94-
batch, mnist_ops.mnist_label_dataset)
95-
96-
@property
97-
def output_shapes(self):
98-
return tuple([
99-
tf.TensorShape([])]) if self._batch == 0 else tuple([
100-
tf.TensorShape([None])])
101-
102-
class MNISTDataset(data.Dataset):
103-
"""A MNIST Dataset
104-
"""
105-
106-
def __init__(self, image, label, batch=None):
107-
"""Create a MNISTReader.
108-
109-
Args:
110-
image: A `tf.string` tensor containing image filename.
111-
label: A `tf.string` tensor containing label filename.
112-
"""
113-
self._image = image
114-
self._label = label
115-
self._batch = 0 if batch is None else batch
116-
super(MNISTDataset, self).__init__()
117-
118-
def _inputs(self):
119-
return []
120-
121-
def _as_variant_tensor(self):
122-
return data.Dataset.zip( # pylint: disable=protected-access
123-
(MNISTImageDataset(self._image, self._batch),
124-
MNISTLabelDataset(self._label, self._batch))
125-
)._as_variant_tensor()
126-
127-
@property
128-
def output_shapes(self):
129-
return (
130-
tf.TensorShape([None, None]),
131-
tf.TensorShape([])) if self._batch == 0 else (
132-
tf.TensorShape([None, None, None]),
133-
tf.TensorShape([None]))
134-
135-
@property
136-
def output_classes(self):
137-
return tf.Tensor, tf.Tensor
138-
139-
@property
140-
def output_types(self):
141-
return dtypes.uint8, dtypes.uint8
53+
dtypes = [tf.uint8]
54+
shapes = [
55+
tf.TensorShape([None, None])] if batch == 0 else [
56+
tf.TensorShape([batch, None, None])]
57+
super(MNISTImageDataset, self).__init__(
58+
mnist_ops.mnist_image_dataset,
59+
mnist_ops.mnist_image_input(filename, ["none", "gz"]),
60+
batch, dtypes, shapes)
61+
62+
def MNISTDataset(image_filename, label_filename, batch=None):
    """Build a dataset that zips an MNIST image dataset with a label dataset.

    Args:
      image_filename: Filename passed to `MNISTImageDataset`.
      label_filename: Filename passed to `MNISTLabelDataset`.
      batch: Batch size forwarded unchanged to both component datasets.

    Returns:
      The result of `data_ops.Dataset.zip` over the `(images, labels)` pair.
    """
    images = MNISTImageDataset(image_filename, batch)
    labels = MNISTLabelDataset(label_filename, batch)
    return data_ops.Dataset.zip((images, labels))

tests/test_mnist.py

Lines changed: 51 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -19,82 +19,66 @@
1919
from __future__ import print_function
2020

2121
import os
22-
import numpy as np
23-
2422
import tensorflow as tf
25-
tf.compat.v1.disable_eager_execution()
26-
27-
from tensorflow import errors # pylint: disable=wrong-import-position
28-
from tensorflow import test # pylint: disable=wrong-import-position
29-
from tensorflow.compat.v1 import data # pylint: disable=wrong-import-position
30-
31-
from tensorflow_io import mnist as mnist_io # pylint: disable=wrong-import-position
32-
23+
import tensorflow_io.mnist as mnist_io
3324

34-
class MNISTDatasetTest(test.TestCase):
35-
"""MNISTDatasetTest"""
36-
def test_mnist_dataset(self):
37-
"""Test case for MNIST Dataset.
38-
"""
39-
mnist_filename = os.path.join(
40-
os.path.dirname(os.path.abspath(__file__)),
41-
"test_mnist",
42-
"mnist.npz")
43-
with np.load(mnist_filename) as f:
44-
(x_test, y_test) = f['x_test'], f['y_test']
25+
def test_mnist_tutorial():
26+
"""test_mnist_tutorial"""
27+
image_filename = os.path.join(
28+
os.path.dirname(os.path.abspath(__file__)),
29+
"test_mnist",
30+
"t10k-images-idx3-ubyte.gz")
31+
label_filename = os.path.join(
32+
os.path.dirname(os.path.abspath(__file__)),
33+
"test_mnist",
34+
"t10k-labels-idx1-ubyte.gz")
35+
d_train = mnist_io.MNISTDataset(
36+
image_filename,
37+
label_filename,
38+
batch=1000)
4539

46-
image_filename = os.path.join(
47-
os.path.dirname(os.path.abspath(__file__)),
48-
"test_mnist",
49-
"t10k-images-idx3-ubyte.gz")
50-
label_filename = os.path.join(
51-
os.path.dirname(os.path.abspath(__file__)),
52-
"test_mnist",
53-
"t10k-labels-idx1-ubyte.gz")
40+
d_train = d_train.map(lambda x, y: (tf.image.convert_image_dtype(x, tf.float32), y))
5441

55-
image_dataset = mnist_io.MNISTImageDataset(image_filename, batch=3)
56-
label_dataset = mnist_io.MNISTLabelDataset(label_filename, batch=3)
42+
model = tf.keras.models.Sequential([
43+
tf.keras.layers.Flatten(input_shape=(28, 28)),
44+
tf.keras.layers.Dense(512, activation=tf.nn.relu),
45+
tf.keras.layers.Dropout(0.2),
46+
tf.keras.layers.Dense(10, activation=tf.nn.softmax)
47+
])
48+
model.compile(optimizer='adam',
49+
loss='sparse_categorical_crossentropy',
50+
metrics=['accuracy'])
5751

58-
dataset = mnist_io.MNISTDataset(
59-
image_filename, label_filename)
52+
model.fit(d_train, epochs=5)
6053

61-
iterator = data.Dataset.zip(
62-
(image_dataset, label_dataset)).make_initializable_iterator()
63-
init_op = iterator.initializer
64-
get_next = iterator.get_next()
54+
def test_mnist_tutorial_uncompressed():
55+
"""test_mnist_tutorial_uncompressed"""
56+
image_filename = os.path.join(
57+
os.path.dirname(os.path.abspath(__file__)),
58+
"test_mnist",
59+
"t10k-images-idx3-ubyte")
60+
label_filename = os.path.join(
61+
os.path.dirname(os.path.abspath(__file__)),
62+
"test_mnist",
63+
"t10k-labels-idx1-ubyte")
64+
d_train = mnist_io.MNISTDataset(
65+
image_filename,
66+
label_filename,
67+
batch=1)
6568

66-
with self.cached_session() as sess:
67-
sess.run(init_op)
68-
l = len(y_test)
69-
for i in range(0, l-1, 3):
70-
v_x = x_test[i:i+3]
71-
v_y = y_test[i:i+3]
72-
m_x, m_y = sess.run(get_next)
73-
self.assertAllEqual(v_y, m_y)
74-
self.assertAllEqual(v_x, m_x)
75-
v_x = x_test[l-1:l]
76-
v_y = y_test[l-1:l]
77-
m_x, m_y = sess.run(get_next)
78-
self.assertAllEqual(v_y, m_y)
79-
self.assertAllEqual(v_x, m_x)
80-
with self.assertRaises(errors.OutOfRangeError):
81-
sess.run(get_next)
69+
d_train = d_train.map(lambda x, y: (tf.image.convert_image_dtype(x, tf.float32), y))
8270

83-
iterator = dataset.make_initializable_iterator()
84-
init_op = iterator.initializer
85-
get_next = iterator.get_next()
71+
model = tf.keras.models.Sequential([
72+
tf.keras.layers.Flatten(input_shape=(28, 28)),
73+
tf.keras.layers.Dense(512, activation=tf.nn.relu),
74+
tf.keras.layers.Dropout(0.2),
75+
tf.keras.layers.Dense(10, activation=tf.nn.softmax)
76+
])
77+
model.compile(optimizer='adam',
78+
loss='sparse_categorical_crossentropy',
79+
metrics=['accuracy'])
8680

87-
with self.cached_session() as sess:
88-
sess.run(init_op)
89-
l = len(y_test)
90-
for i in range(l):
91-
v_x = x_test[i]
92-
v_y = y_test[i]
93-
m_x, m_y = sess.run(get_next)
94-
self.assertAllEqual(v_y, m_y)
95-
self.assertAllEqual(v_x, m_x)
96-
with self.assertRaises(errors.OutOfRangeError):
97-
sess.run(get_next)
81+
model.fit(d_train, epochs=5)
9882

9983
if __name__ == "__main__":
    # The `test` module is no longer imported here, so calling `test.main()`
    # raised NameError. The tests are plain functions rather than TestCase
    # methods, so invoke them directly when the file is run as a script.
    test_mnist_tutorial()
    test_mnist_tutorial_uncompressed()

0 commit comments

Comments
 (0)