From c6f86d18bdfb99451391c4c149f284129da67cb0 Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Sat, 8 Jun 2019 21:05:48 +0000 Subject: [PATCH] Support 2.0 Dataset Signed-off-by: Yong Tang --- tensorflow_io/core/python/ops/data_ops.py | 50 ++++++++ tensorflow_io/mnist/python/ops/mnist_ops.py | 132 +++++--------------- tests/test_mnist.py | 118 ++++++++--------- tests/test_mnist_eager.py | 89 ++++++------- 4 files changed, 167 insertions(+), 222 deletions(-) create mode 100644 tensorflow_io/core/python/ops/data_ops.py diff --git a/tensorflow_io/core/python/ops/data_ops.py b/tensorflow_io/core/python/ops/data_ops.py new file mode 100644 index 000000000..e0247c3a8 --- /dev/null +++ b/tensorflow_io/core/python/ops/data_ops.py @@ -0,0 +1,50 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Dataset.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf + +class Dataset(tf.compat.v2.data.Dataset): + """A base Dataset""" + + def __init__(self, fn, data_input, batch, dtypes, shapes): + """Create a base Dataset.""" + self._fn = fn + self._data_input = data_input + self._batch = 0 if batch is None else batch + self._dtypes = dtypes + self._shapes = shapes + super(Dataset, self).__init__(fn( + self._data_input, + self._batch, + output_types=self._dtypes, + output_shapes=self._shapes)) + + def _inputs(self): + return [] + + @property + def _element_structure(self): + e = [ + tf.data.experimental.TensorStructure( + p, q.as_list()) for (p, q) in zip( + self._dtypes, self._shapes) + ] + if len(e) == 1: + return e[0] + return tf.data.experimental.NestedStructure(e) diff --git a/tensorflow_io/mnist/python/ops/mnist_ops.py b/tensorflow_io/mnist/python/ops/mnist_ops.py index 280a4dea4..b1ab22917 100644 --- a/tensorflow_io/mnist/python/ops/mnist_ops.py +++ b/tensorflow_io/mnist/python/ops/mnist_ops.py @@ -18,124 +18,48 @@ from __future__ import print_function import tensorflow as tf -from tensorflow import dtypes -from tensorflow.compat.v1 import data +from tensorflow_io.core.python.ops import data_ops as data_ops from tensorflow_io.core.python.ops import core_ops as mnist_ops -class _MNISTBaseDataset(data.Dataset): - """A MNIST Dataset - """ - - def __init__(self, batch, mnist_op_class): - """Create a MNISTReader. - - Args: - mnist_op_class: The op of the dataset, either - mnist_ops.mnist_image_dataset or mnist_ops.mnist_label_dataset. - filenames: A `tf.string` tensor containing one or more filenames. - """ - self._batch = batch - self._func = mnist_op_class - super(_MNISTBaseDataset, self).__init__() - - def _inputs(self): - return [] - - def _as_variant_tensor(self): - return self._func( - self._data_input, - self._batch, - output_types=self.output_types, - output_shapes=self.output_shapes) - - @property - def output_classes(self): - return tf.Tensor - - @property - def output_types(self): - return tuple([dtypes.uint8]) - -class MNISTImageDataset(_MNISTBaseDataset): - """A MNIST Image Dataset +class MNISTLabelDataset(data_ops.Dataset): + """A MNISTLabelDataset """ def __init__(self, filename, batch=None): - """Create a MNISTReader. - + """Create a MNISTLabelDataset. Args: filenames: A `tf.string` tensor containing one or more filenames. """ batch = 0 if batch is None else batch - self._data_input = mnist_ops.mnist_image_input(filename, ["none", "gz"]) - super(MNISTImageDataset, self).__init__( - batch, mnist_ops.mnist_image_dataset) - - @property - def output_shapes(self): - return tuple([ - tf.TensorShape([None, None])]) if self._batch == 0 else tuple([ - tf.TensorShape([None, None, None])]) - + dtypes = [tf.uint8] + shapes = [ + tf.TensorShape([])] if batch == 0 else [ + tf.TensorShape([batch])] + super(MNISTLabelDataset, self).__init__( + mnist_ops.mnist_label_dataset, + mnist_ops.mnist_label_input(filename, ["none", "gz"]), + batch, dtypes, shapes) -class MNISTLabelDataset(_MNISTBaseDataset): - """A MNIST Label Dataset +class MNISTImageDataset(data_ops.Dataset): + """A MNISTImageDataset """ def __init__(self, filename, batch=None): - """Create a MNISTReader. - + """Create a MNISTImageDataset. Args: filenames: A `tf.string` tensor containing one or more filenames. """ batch = 0 if batch is None else batch - self._data_input = mnist_ops.mnist_label_input(filename, ["none", "gz"]) - super(MNISTLabelDataset, self).__init__( - batch, mnist_ops.mnist_label_dataset) - - @property - def output_shapes(self): - return tuple([ - tf.TensorShape([])]) if self._batch == 0 else tuple([ - tf.TensorShape([None])]) - -class MNISTDataset(data.Dataset): - """A MNIST Dataset - """ - - def __init__(self, image, label, batch=None): - """Create a MNISTReader. - - Args: - image: A `tf.string` tensor containing image filename. - label: A `tf.string` tensor containing label filename. - """ - self._image = image - self._label = label - self._batch = 0 if batch is None else batch - super(MNISTDataset, self).__init__() - - def _inputs(self): - return [] - - def _as_variant_tensor(self): - return data.Dataset.zip( # pylint: disable=protected-access - (MNISTImageDataset(self._image, self._batch), - MNISTLabelDataset(self._label, self._batch)) - )._as_variant_tensor() - - @property - def output_shapes(self): - return ( - tf.TensorShape([None, None]), - tf.TensorShape([])) if self._batch == 0 else ( - tf.TensorShape([None, None, None]), - tf.TensorShape([None])) - - @property - def output_classes(self): - return tf.Tensor, tf.Tensor - - @property - def output_types(self): - return dtypes.uint8, dtypes.uint8 + dtypes = [tf.uint8] + shapes = [ + tf.TensorShape([None, None])] if batch == 0 else [ + tf.TensorShape([batch, None, None])] + super(MNISTImageDataset, self).__init__( + mnist_ops.mnist_image_dataset, + mnist_ops.mnist_image_input(filename, ["none", "gz"]), + batch, dtypes, shapes) + +def MNISTDataset(image_filename, label_filename, batch=None): + return data_ops.Dataset.zip(( + MNISTImageDataset(image_filename, batch), + MNISTLabelDataset(label_filename, batch))) diff --git a/tests/test_mnist.py b/tests/test_mnist.py index c1ec00087..afca27eb7 100644 --- a/tests/test_mnist.py +++ b/tests/test_mnist.py @@ -19,82 +19,66 @@ from __future__ import print_function import os -import numpy as np - import tensorflow as tf -tf.compat.v1.disable_eager_execution() - -from tensorflow import errors # pylint: disable=wrong-import-position -from tensorflow import test # pylint: disable=wrong-import-position -from tensorflow.compat.v1 import data # pylint: disable=wrong-import-position - -from tensorflow_io import mnist as mnist_io # pylint: disable=wrong-import-position - +import tensorflow_io.mnist as mnist_io -class MNISTDatasetTest(test.TestCase): - """MNISTDatasetTest""" - def test_mnist_dataset(self): - """Test case for MNIST Dataset. - """ - mnist_filename = os.path.join( - os.path.dirname(os.path.abspath(__file__)), - "test_mnist", - "mnist.npz") - with np.load(mnist_filename) as f: - (x_test, y_test) = f['x_test'], f['y_test'] +def test_mnist_tutorial(): + """test_mnist_tutorial""" + image_filename = os.path.join( + os.path.dirname(os.path.abspath(__file__)), + "test_mnist", + "t10k-images-idx3-ubyte.gz") + label_filename = os.path.join( + os.path.dirname(os.path.abspath(__file__)), + "test_mnist", + "t10k-labels-idx1-ubyte.gz") + d_train = mnist_io.MNISTDataset( + image_filename, + label_filename, + batch=1000) - image_filename = os.path.join( - os.path.dirname(os.path.abspath(__file__)), - "test_mnist", - "t10k-images-idx3-ubyte.gz") - label_filename = os.path.join( - os.path.dirname(os.path.abspath(__file__)), - "test_mnist", - "t10k-labels-idx1-ubyte.gz") + d_train = d_train.map(lambda x, y: (tf.image.convert_image_dtype(x, tf.float32), y)) - image_dataset = mnist_io.MNISTImageDataset(image_filename, batch=3) - label_dataset = mnist_io.MNISTLabelDataset(label_filename, batch=3) + model = tf.keras.models.Sequential([ + tf.keras.layers.Flatten(input_shape=(28, 28)), + tf.keras.layers.Dense(512, activation=tf.nn.relu), + tf.keras.layers.Dropout(0.2), + tf.keras.layers.Dense(10, activation=tf.nn.softmax) + ]) + model.compile(optimizer='adam', + loss='sparse_categorical_crossentropy', + metrics=['accuracy']) - dataset = mnist_io.MNISTDataset( - image_filename, label_filename) + model.fit(d_train, epochs=5) - iterator = data.Dataset.zip( - (image_dataset, label_dataset)).make_initializable_iterator() - init_op = iterator.initializer - get_next = iterator.get_next() +def test_mnist_tutorial_uncompressed(): + """test_mnist_tutorial_uncompressed""" + image_filename = os.path.join( + os.path.dirname(os.path.abspath(__file__)), + "test_mnist", + "t10k-images-idx3-ubyte") + label_filename = os.path.join( + os.path.dirname(os.path.abspath(__file__)), + "test_mnist", + "t10k-labels-idx1-ubyte") + d_train = mnist_io.MNISTDataset( + image_filename, + label_filename, + batch=1) - with self.cached_session() as sess: - sess.run(init_op) - l = len(y_test) - for i in range(0, l-1, 3): - v_x = x_test[i:i+3] - v_y = y_test[i:i+3] - m_x, m_y = sess.run(get_next) - self.assertAllEqual(v_y, m_y) - self.assertAllEqual(v_x, m_x) - v_x = x_test[l-1:l] - v_y = y_test[l-1:l] - m_x, m_y = sess.run(get_next) - self.assertAllEqual(v_y, m_y) - self.assertAllEqual(v_x, m_x) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) + d_train = d_train.map(lambda x, y: (tf.image.convert_image_dtype(x, tf.float32), y)) - iterator = dataset.make_initializable_iterator() - init_op = iterator.initializer - get_next = iterator.get_next() + model = tf.keras.models.Sequential([ + tf.keras.layers.Flatten(input_shape=(28, 28)), + tf.keras.layers.Dense(512, activation=tf.nn.relu), + tf.keras.layers.Dropout(0.2), + tf.keras.layers.Dense(10, activation=tf.nn.softmax) + ]) + model.compile(optimizer='adam', + loss='sparse_categorical_crossentropy', + metrics=['accuracy']) - with self.cached_session() as sess: - sess.run(init_op) - l = len(y_test) - for i in range(l): - v_x = x_test[i] - v_y = y_test[i] - m_x, m_y = sess.run(get_next) - self.assertAllEqual(v_y, m_y) - self.assertAllEqual(v_x, m_x) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) + model.fit(d_train, epochs=5) if __name__ == "__main__": test.main() diff --git a/tests/test_mnist_eager.py b/tests/test_mnist_eager.py index 963a66a68..88be998d7 100644 --- a/tests/test_mnist_eager.py +++ b/tests/test_mnist_eager.py @@ -19,16 +19,22 @@ from __future__ import print_function import os -import pytest - +import numpy as np import tensorflow as tf -import tensorflow_io.mnist as mnist_io +if not (hasattr(tf, "version") and tf.version.VERSION.startswith("2.")): + tf.compat.v1.enable_eager_execution() +import tensorflow_io.mnist as mnist_io # pylint: disable=wrong-import-position + +def test_mnist_dataset(): + """Test case for MNIST Dataset. + """ + mnist_filename = os.path.join( + os.path.dirname(os.path.abspath(__file__)), + "test_mnist", + "mnist.npz") + with np.load(mnist_filename) as f: + (x_test, y_test) = f['x_test'], f['y_test'] -@pytest.mark.skipif( - not (hasattr(tf, "version") and - tf.version.VERSION.startswith("2.0.")), reason=None) -def test_mnist_tutorial(): - """test_mnist_tutorial""" image_filename = os.path.join( os.path.dirname(os.path.abspath(__file__)), "test_mnist", @@ -37,54 +43,35 @@ def test_mnist_tutorial(): os.path.dirname(os.path.abspath(__file__)), "test_mnist", "t10k-labels-idx1-ubyte.gz") - d_train = mnist_io.MNISTDataset( - image_filename, - label_filename) - - d_train = d_train.map(lambda x, y: (tf.image.convert_image_dtype(x, tf.float32), y)).batch(1000) - - model = tf.keras.models.Sequential([ - tf.keras.layers.Flatten(input_shape=(28, 28)), - tf.keras.layers.Dense(512, activation=tf.nn.relu), - tf.keras.layers.Dropout(0.2), - tf.keras.layers.Dense(10, activation=tf.nn.softmax) - ]) - model.compile(optimizer='adam', - loss='sparse_categorical_crossentropy', - metrics=['accuracy']) - model.fit(d_train, epochs=5) + image_dataset = mnist_io.MNISTImageDataset(image_filename) + label_dataset = mnist_io.MNISTLabelDataset(label_filename) -@pytest.mark.skipif( - not (hasattr(tf, "version") and - tf.version.VERSION.startswith("2.0.")), reason=None) -def test_mnist_tutorial_uncompressed(): - """test_mnist_tutorial_uncompressed""" - image_filename = os.path.join( - os.path.dirname(os.path.abspath(__file__)), - "test_mnist", - "t10k-images-idx3-ubyte") - label_filename = os.path.join( - os.path.dirname(os.path.abspath(__file__)), - "test_mnist", - "t10k-labels-idx1-ubyte") - d_train = mnist_io.MNISTDataset( - image_filename, - label_filename) + i = 0 + for m_x in image_dataset: + v_x = x_test[i] + assert np.alltrue(v_x == m_x.numpy()) + i += 1 + assert i == len(y_test) - d_train = d_train.map(lambda x, y: (tf.image.convert_image_dtype(x, tf.float32), y)).batch(1) + i = 0 + for m_y in label_dataset: + v_y = y_test[i] + assert np.alltrue(v_y == m_y.numpy()) + i += 1 + assert i == len(y_test) - model = tf.keras.models.Sequential([ - tf.keras.layers.Flatten(input_shape=(28, 28)), - tf.keras.layers.Dense(512, activation=tf.nn.relu), - tf.keras.layers.Dropout(0.2), - tf.keras.layers.Dense(10, activation=tf.nn.softmax) - ]) - model.compile(optimizer='adam', - loss='sparse_categorical_crossentropy', - metrics=['accuracy']) + dataset = mnist_io.MNISTDataset( + image_filename, label_filename) - model.fit(d_train, epochs=5) + i = 0 + for (m_x, m_y) in dataset: + v_x = x_test[i] + v_y = y_test[i] + assert np.alltrue(v_y == m_y.numpy()) + assert np.alltrue(v_x == m_x.numpy()) + i += 1 + assert i == len(y_test) if __name__ == "__main__": test.main()