|
18 | 18 | from __future__ import print_function |
19 | 19 |
|
20 | 20 | import tensorflow as tf |
21 | | -from tensorflow import dtypes |
22 | | -from tensorflow.compat.v1 import data |
| 21 | +from tensorflow_io.core.python.ops import data_ops as data_ops |
23 | 22 | from tensorflow_io.core.python.ops import core_ops as mnist_ops |
24 | 23 |
|
25 | | -class _MNISTBaseDataset(data.Dataset): |
26 | | - """A MNIST Dataset |
27 | | - """ |
28 | | - |
29 | | - def __init__(self, batch, mnist_op_class): |
30 | | - """Create a MNISTReader. |
31 | | -
|
32 | | - Args: |
33 | | - mnist_op_class: The op of the dataset, either |
34 | | - mnist_ops.mnist_image_dataset or mnist_ops.mnist_label_dataset. |
35 | | - filenames: A `tf.string` tensor containing one or more filenames. |
36 | | - """ |
37 | | - self._batch = batch |
38 | | - self._func = mnist_op_class |
39 | | - super(_MNISTBaseDataset, self).__init__() |
40 | | - |
41 | | - def _inputs(self): |
42 | | - return [] |
43 | | - |
44 | | - def _as_variant_tensor(self): |
45 | | - return self._func( |
46 | | - self._data_input, |
47 | | - self._batch, |
48 | | - output_types=self.output_types, |
49 | | - output_shapes=self.output_shapes) |
50 | | - |
51 | | - @property |
52 | | - def output_classes(self): |
53 | | - return tf.Tensor |
54 | | - |
55 | | - @property |
56 | | - def output_types(self): |
57 | | - return tuple([dtypes.uint8]) |
58 | | - |
59 | | -class MNISTImageDataset(_MNISTBaseDataset): |
60 | | - """A MNIST Image Dataset |
| 24 | +class MNISTLabelDataset(data_ops.Dataset): |
| 25 | + """A MNISTLabelDataset |
61 | 26 | """ |
62 | 27 |
|
63 | 28 | def __init__(self, filename, batch=None): |
64 | | - """Create a MNISTReader. |
65 | | -
|
| 29 | + """Create a MNISTLabelDataset. |
66 | 30 | Args: |
67 | 31 | filenames: A `tf.string` tensor containing one or more filenames. |
68 | 32 | """ |
69 | 33 | batch = 0 if batch is None else batch |
70 | | - self._data_input = mnist_ops.mnist_image_input(filename, ["none", "gz"]) |
71 | | - super(MNISTImageDataset, self).__init__( |
72 | | - batch, mnist_ops.mnist_image_dataset) |
73 | | - |
74 | | - @property |
75 | | - def output_shapes(self): |
76 | | - return tuple([ |
77 | | - tf.TensorShape([None, None])]) if self._batch == 0 else tuple([ |
78 | | - tf.TensorShape([None, None, None])]) |
79 | | - |
| 34 | + dtypes = [tf.uint8] |
| 35 | + shapes = [ |
| 36 | + tf.TensorShape([])] if batch == 0 else [ |
| 37 | + tf.TensorShape([batch])] |
| 38 | + super(MNISTLabelDataset, self).__init__( |
| 39 | + mnist_ops.mnist_label_dataset, |
| 40 | + mnist_ops.mnist_label_input(filename, ["none", "gz"]), |
| 41 | + batch, dtypes, shapes) |
80 | 42 |
|
81 | | -class MNISTLabelDataset(_MNISTBaseDataset): |
82 | | - """A MNIST Label Dataset |
| 43 | +class MNISTImageDataset(data_ops.Dataset): |
| 44 | + """A MNISTImageDataset |
83 | 45 | """ |
84 | 46 |
|
85 | 47 | def __init__(self, filename, batch=None): |
86 | | - """Create a MNISTReader. |
87 | | -
|
| 48 | + """Create a MNISTImageDataset. |
88 | 49 | Args: |
89 | 50 | filenames: A `tf.string` tensor containing one or more filenames. |
90 | 51 | """ |
91 | 52 | batch = 0 if batch is None else batch |
92 | | - self._data_input = mnist_ops.mnist_label_input(filename, ["none", "gz"]) |
93 | | - super(MNISTLabelDataset, self).__init__( |
94 | | - batch, mnist_ops.mnist_label_dataset) |
95 | | - |
96 | | - @property |
97 | | - def output_shapes(self): |
98 | | - return tuple([ |
99 | | - tf.TensorShape([])]) if self._batch == 0 else tuple([ |
100 | | - tf.TensorShape([None])]) |
101 | | - |
102 | | -class MNISTDataset(data.Dataset): |
103 | | - """A MNIST Dataset |
104 | | - """ |
105 | | - |
106 | | - def __init__(self, image, label, batch=None): |
107 | | - """Create a MNISTReader. |
108 | | -
|
109 | | - Args: |
110 | | - image: A `tf.string` tensor containing image filename. |
111 | | - label: A `tf.string` tensor containing label filename. |
112 | | - """ |
113 | | - self._image = image |
114 | | - self._label = label |
115 | | - self._batch = 0 if batch is None else batch |
116 | | - super(MNISTDataset, self).__init__() |
117 | | - |
118 | | - def _inputs(self): |
119 | | - return [] |
120 | | - |
121 | | - def _as_variant_tensor(self): |
122 | | - return data.Dataset.zip( # pylint: disable=protected-access |
123 | | - (MNISTImageDataset(self._image, self._batch), |
124 | | - MNISTLabelDataset(self._label, self._batch)) |
125 | | - )._as_variant_tensor() |
126 | | - |
127 | | - @property |
128 | | - def output_shapes(self): |
129 | | - return ( |
130 | | - tf.TensorShape([None, None]), |
131 | | - tf.TensorShape([])) if self._batch == 0 else ( |
132 | | - tf.TensorShape([None, None, None]), |
133 | | - tf.TensorShape([None])) |
134 | | - |
135 | | - @property |
136 | | - def output_classes(self): |
137 | | - return tf.Tensor, tf.Tensor |
138 | | - |
139 | | - @property |
140 | | - def output_types(self): |
141 | | - return dtypes.uint8, dtypes.uint8 |
| 53 | + dtypes = [tf.uint8] |
| 54 | + shapes = [ |
| 55 | + tf.TensorShape([None, None])] if batch == 0 else [ |
| 56 | + tf.TensorShape([batch, None, None])] |
| 57 | + super(MNISTImageDataset, self).__init__( |
| 58 | + mnist_ops.mnist_image_dataset, |
| 59 | + mnist_ops.mnist_image_input(filename, ["none", "gz"]), |
| 60 | + batch, dtypes, shapes) |
| 61 | + |
| 62 | +def MNISTDataset(image_filename, label_filename, batch=None): |
| 63 | + return data_ops.Dataset.zip(( |
| 64 | + MNISTImageDataset(image_filename, batch), |
| 65 | + MNISTLabelDataset(label_filename, batch))) |
0 commit comments