From 1baf9ee65842631b7337358f8dd22ef2bdb1b000 Mon Sep 17 00:00:00 2001
From: Yong Tang
Date: Sat, 20 Jul 2019 17:49:05 +0000
Subject: [PATCH 1/5] Add FileDataset to read whole content of file into
 tf.data pipeline

When we started the project, we assumed data read through a tf.data
pipeline would always be record-like, that is, each file would contain
multiple records. This is the case in many situations such as text or
video, where the natural boundaries within a file are newlines or
individual frames. But for many image formats such as WebP or JPEG,
each file is already one image, so there is no need to partition it
further. In addition, people still prefer a `decode_xxx` call in many
situations.

This PR adds a FileDataset which takes files (and potentially
compressed files) and feeds each whole file content as a string into
tf.data. FileDataset is not a fit where individual files are really
large (e.g., 100GB of text records). However, it is good enough for
image file usage. As an example, this PR also converts WebPDataset to
use FileDataset + decode_webp.

Note: a TIFF file could have multiple images with different shapes, so
it is not a fit for FileDataset either.

Signed-off-by: Yong Tang
---
 tensorflow_io/core/BUILD                  | 17 ++++
 tensorflow_io/core/kernels/file_input.cc  | 92 +++++++++++++++++++
 tensorflow_io/core/ops/file_ops.cc        | 46 ++++++++++
 tensorflow_io/core/python/ops/data_ops.py | 16 ++++
 .../image/python/ops/image_dataset_ops.py | 47 ++++------
 tests/test_image.py                       | 41 +++------
 6 files changed, 202 insertions(+), 57 deletions(-)
 create mode 100644 tensorflow_io/core/kernels/file_input.cc
 create mode 100644 tensorflow_io/core/ops/file_ops.cc

diff --git a/tensorflow_io/core/BUILD b/tensorflow_io/core/BUILD
index 2fca41710..a55d73d88 100644
--- a/tensorflow_io/core/BUILD
+++ b/tensorflow_io/core/BUILD
@@ -57,6 +57,22 @@ cc_library(
     ],
 )
 
+cc_library(
+    name = "file_ops",
+    srcs = [
+        "kernels/file_input.cc",
+        "ops/file_ops.cc",
+    ],
+    copts = tf_io_copts(),
+    includes = [
+        ".",
+    ],
+    linkstatic = True,
+    deps = [
+        ":dataset_ops",
+    ],
+)
+
 cc_library(
     name = "ffmpeg_3.4",
     srcs = [
@@ -107,6 +123,7 @@ cc_binary(
     copts = tf_io_copts(),
     linkshared = 1,
     deps = [
+        ":file_ops",
         "//tensorflow_io/audio:audio_ops",
         "//tensorflow_io/avro:avro_ops",
         "//tensorflow_io/azure:azfs_ops",
diff --git a/tensorflow_io/core/kernels/file_input.cc b/tensorflow_io/core/kernels/file_input.cc
new file mode 100644
index 000000000..6155c9de5
--- /dev/null
+++ b/tensorflow_io/core/kernels/file_input.cc
@@ -0,0 +1,92 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "kernels/dataset_ops.h"
+#include "tensorflow/core/lib/io/buffered_inputstream.h"
+
+namespace tensorflow {
+namespace data {
+
+class FileContentInput: public FileInput<bool> {
+ public:
+  Status ReadRecord(io::InputStreamInterface* s, IteratorContext* ctx, std::unique_ptr<bool>& state, int64 record_to_read, int64* record_read, std::vector<Tensor>* out_tensors) const override {
+    if (record_to_read != 1) {
+      return errors::InvalidArgument("FileDataset only accepts reading one record at a time");
+    }
+    if (state.get() == nullptr) {
+      state.reset(new bool(true));
+    } else {
+      // We only read the file once.
+      return Status::OK();
+    }
+    int64 chunk_size = 4096;
+    SizedRandomAccessInputStreamInterface* sized_stream = dynamic_cast<SizedRandomAccessInputStreamInterface*>(s);
+    if (sized_stream != nullptr) {
+      // First try to find out the size of the file; if available, read the whole file in one chunk.
+      uint64 file_size = 0;
+      if (sized_stream->GetFileSize(&file_size) == Status::OK()) {
+        chunk_size = file_size;
+      }
+    }
+    std::vector<string> entries;
+    Status status = Status::OK();
+    while (status.ok()) {
+      string buffer;
+      status = s->ReadNBytes(chunk_size, &buffer);
+      if (status.ok() || errors::IsOutOfRange(status)) {
+        entries.emplace_back(std::move(buffer));
+      }
+    }
+    if (!errors::IsOutOfRange(status)) {
+      return status;
+    }
+    Tensor value_tensor(ctx->allocator({}), DT_STRING, {1});
+    if (entries.size() == 1) {
+      value_tensor.flat<string>()((*record_read)) = std::move(entries[0]);
+    } else {
+      string buffer;
+      int64 total_size = 0;
+      for (size_t i = 0; i < entries.size(); i++) {
+        total_size += entries[i].size();
+      }
+      buffer.reserve(total_size);
+      for (size_t i = 0; i < entries.size(); i++) {
+        buffer.append(entries[i]);
+      }
+      value_tensor.flat<string>()((*record_read)) = std::move(buffer);
+    }
+    (*record_read)++;
+    out_tensors->emplace_back(std::move(value_tensor));
+    return Status::OK();
+  }
+  Status FromStream(io::InputStreamInterface* s) override {
+    return Status::OK();
+  }
+  void EncodeAttributes(VariantTensorData* data) const override {
+  }
+  bool DecodeAttributes(const VariantTensorData& data) override {
+    return true;
+  }
+ protected:
+};
+
+REGISTER_UNARY_VARIANT_DECODE_FUNCTION(FileContentInput, "tensorflow::data::FileContentInput");
+
+REGISTER_KERNEL_BUILDER(Name("FileInput").Device(DEVICE_CPU),
+                        FileInputOp<FileContentInput>);
+REGISTER_KERNEL_BUILDER(Name("FileDataset").Device(DEVICE_CPU),
+                        FileInputDatasetOp<FileContentInput>);
+}  // namespace data
+}  // namespace tensorflow
diff --git a/tensorflow_io/core/ops/file_ops.cc b/tensorflow_io/core/ops/file_ops.cc
new file mode 100644
index 000000000..8445121ee
--- /dev/null
+++ b/tensorflow_io/core/ops/file_ops.cc
@@ -0,0 +1,46 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/shape_inference.h"
+
+namespace tensorflow {
+
+REGISTER_OP("FileInput")
+    .Input("source: string")
+    .Output("handle: variant")
+    .Attr("filters: list(string) = []")
+    .Attr("columns: list(string) = []")
+    .Attr("schema: string = ''")
+    .SetShapeFn([](shape_inference::InferenceContext* c) {
+      c->set_output(0, c->MakeShape({c->UnknownDim()}));
+      return Status::OK();
+    });
+
+REGISTER_OP("FileDataset")
+    .Input("input: T")
+    .Input("batch: int64")
+    .Output("handle: variant")
+    .Attr("output_types: list(type) >= 1")
+    .Attr("output_shapes: list(shape) >= 1")
+    .Attr("T: {string, variant} = DT_VARIANT")
+    .SetIsStateful()
+    .SetShapeFn([](shape_inference::InferenceContext* c) {
+      c->set_output(0, c->MakeShape({}));
+      return Status::OK();
+    });
+
+}  // namespace tensorflow
diff --git a/tensorflow_io/core/python/ops/data_ops.py b/tensorflow_io/core/python/ops/data_ops.py
index d9d20541b..ec366c4ce 100644
--- a/tensorflow_io/core/python/ops/data_ops.py
+++ b/tensorflow_io/core/python/ops/data_ops.py
@@ -18,6 +18,7 @@ from __future__ import print_function
 
 import tensorflow as tf
+from tensorflow_io.core.python.ops import core_ops
 
 # Note: BaseDataset could be used by Dataset implementations
 # that does not utilize DataInput implementation.
@@ -60,3 +61,18 @@ def __init__(self, fn, data_input, batch, dtypes, shapes):
         self._batch,
         output_types=self._dtypes,
         output_shapes=self._shapes), self._batch, self._dtypes, self._shapes)
+
+class FileDataset(BaseDataset):
+  """A FileDataset that read file content as string"""
+
+  def __init__(self, filename):
+    """Create a FileDataset."""
+    self._data_input = core_ops.file_input(filename)
+    self._batch = 0
+    self._dtypes = [tf.string]
+    self._shapes = [tf.TensorShape([])]
+    super(FileDataset, self).__init__(core_ops.file_dataset(
+        self._data_input,
+        self._batch,
+        output_types=self._dtypes,
+        output_shapes=self._shapes), self._batch, self._dtypes, self._shapes)
diff --git a/tensorflow_io/image/python/ops/image_dataset_ops.py b/tensorflow_io/image/python/ops/image_dataset_ops.py
index 085c01493..0617621fc 100644
--- a/tensorflow_io/image/python/ops/image_dataset_ops.py
+++ b/tensorflow_io/image/python/ops/image_dataset_ops.py
@@ -21,39 +21,9 @@
 from tensorflow import dtypes
 from tensorflow.compat.v1 import data
 from tensorflow_io import _load_library
+from tensorflow_io.core.python.ops import data_ops as data_ops
 image_ops = _load_library('_image_ops.so')
 
-class WebPDataset(data.Dataset):
-  """A WebP Image File Dataset that reads the WebP file."""
-
-  def __init__(self, filenames):
-    """Create a `WebPDataset`.
-
-    filenames: A `tf.string` tensor containing one or more filenames.
-    """
-    self._filenames = tf.convert_to_tensor(
-        filenames, dtype=dtypes.string, name="filenames")
-    super(WebPDataset, self).__init__()
-
-  def _inputs(self):
-    return []
-
-  def _as_variant_tensor(self):
-    return image_ops.web_p_dataset(self._filenames)
-
-  @property
-  def output_classes(self):
-    return tf.Tensor
-
-  @property
-  def output_shapes(self):
-    return tf.TensorShape([None, None, None])
-
-  @property
-  def output_types(self):
-    return dtypes.uint8
-
 class TIFFDataset(data.Dataset):
   """A TIFF Image File Dataset that reads the TIFF file."""
@@ -146,3 +116,18 @@ def draw_bounding_boxes(images, boxes, texts=None, colors=None, name=None):
     colors = [[]]
   return image_ops.draw_bounding_boxes_v3(
       images, boxes, colors, texts, name=name)
+
+class WebPDataset(data_ops.BaseDataset):
+  """A WebP Image File Dataset that reads the WebP file."""
+
+  def __init__(self, filename):
+    """Create a `WebPDataset`.
+
+    filename: A `tf.string` tensor containing one or more filenames.
+    """
+    self._batch = 0
+    self._dtypes = [dtypes.uint8]
+    self._shapes = [tf.TensorShape([None, None, None])]
+    self._dataset = data_ops.FileDataset(filename).map(decode_webp)
+    super(WebPDataset, self).__init__(
+        self._dataset._variant_tensor, self._batch, self._dtypes, self._shapes)  # pylint: disable=protected-access
diff --git a/tests/test_image.py b/tests/test_image.py
index 424f39f2e..57f7b4a09 100644
--- a/tests/test_image.py
+++ b/tests/test_image.py
@@ -58,40 +58,29 @@ def test_decode_webp(self):
 
     self.assertAllEqual(webp_v, png)
 
-
   def test_webp_file_dataset(self):
     """Test case for WebPDataset.
     """
-    width = 400
-    height = 301
-    channel = 4
-    png_file = os.path.join(
-        os.path.dirname(os.path.abspath(__file__)), "test_image", "sample.png")
-    with open(png_file, 'rb') as f:
-      png_contents = f.read()
-    with self.cached_session():
-      image_p = image.decode_png(png_contents, channels=channel)
-      image_v = image_p.eval()
-      self.assertEqual(image_v.shape, (height, width, channel))
-
     filename = os.path.join(
         os.path.dirname(os.path.abspath(__file__)), "test_image", "sample.webp")
 
     num_repeats = 2
 
-    dataset = image_io.WebPDataset([filename]).repeat(
-        num_repeats)
-    iterator = dataset.make_initializable_iterator()
-    init_op = iterator.initializer
-    get_next = iterator.get_next()
-
-    with self.cached_session() as sess:
-      sess.run(init_op)
-      for _ in range(num_repeats):  # Dataset is repeated.
-        v = sess.run(get_next)
-        self.assertAllEqual(image_v, v)
-      with self.assertRaises(errors.OutOfRangeError):
-        sess.run(get_next)
+    dataset = image_io.WebPDataset([filename, filename])
+    # Repeat 2 times (2 * 2 = 4 images)
+    dataset = dataset.repeat(num_repeats)
+    # Drop alpha channel
+    dataset = dataset.map(lambda x: x[:, :, :3])
+    # Resize to 224 * 224
+    dataset = dataset.map(lambda x: tf.keras.applications.resnet50.preprocess_input(tf.image.resize(x, (224, 224))))
+    # Batch with size 1, so there are still 4 images (4 batches of 1)
+    dataset = dataset.batch(1)
+    model = tf.keras.applications.resnet50.ResNet50(weights='imagenet')
+    y = model.predict(dataset)
+    p = tf.keras.applications.resnet50.decode_predictions(y, top=1)
+    for i in p:
+      assert i[0][1] == 'pineapple'  # not truly a pineapple, though
+    assert len(p) == 4
 
   def test_tiff_file_dataset(self):
     """Test case for TIFFDataset.
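
For context, the pipeline this first patch enables is FileDataset composed with an image decoder. A minimal usage sketch, assuming the in-tree module paths above, that `decode_webp` is exported from `tensorflow_io.image` as in the tests, and a hypothetical local file `sample.webp`:

    import tensorflow as tf
    from tensorflow_io.core.python.ops import data_ops
    import tensorflow_io.image as image_io

    # FileDataset emits each whole file as a single scalar string tensor;
    # decode_webp then turns those bytes into an RGBA uint8 image tensor.
    dataset = data_ops.FileDataset("sample.webp")
    dataset = dataset.map(image_io.decode_webp)

This is the same composition the converted WebPDataset uses internally: data_ops.FileDataset(filename).map(decode_webp).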
From ba6c6ad0509097b8b8a14e3e8ec6d22623881736 Mon Sep 17 00:00:00 2001
From: Yong Tang
Date: Sat, 20 Jul 2019 18:02:50 +0000
Subject: [PATCH 2/5] Remove WebPDataset's C++ implementation, and replace
 with FileDataset+decode_webp

Signed-off-by: Yong Tang
---
 .../image/kernels/webp_dataset_ops.cc  | 145 ------------------
 tensorflow_io/image/ops/dataset_ops.cc |   9 --
 2 files changed, 154 deletions(-)

diff --git a/tensorflow_io/image/kernels/webp_dataset_ops.cc b/tensorflow_io/image/kernels/webp_dataset_ops.cc
index f920883f0..6a5adad1c 100644
--- a/tensorflow_io/image/kernels/webp_dataset_ops.cc
+++ b/tensorflow_io/image/kernels/webp_dataset_ops.cc
@@ -23,149 +23,6 @@ limitations under the License.
 namespace tensorflow {
 namespace data {
 namespace {
-class WebPDatasetOp : public DatasetOpKernel {
- public:
-  using DatasetOpKernel::DatasetOpKernel;
-  explicit WebPDatasetOp(OpKernelConstruction* ctx)
-      : DatasetOpKernel(ctx) {
-  }
-  void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override {
-    const Tensor* filenames_tensor;
-    OP_REQUIRES_OK(ctx, ctx->input("filenames", &filenames_tensor));
-    OP_REQUIRES(
-        ctx, filenames_tensor->dims() <= 1,
-        errors::InvalidArgument("`filenames` must be a scalar or a vector."));
-
-    std::vector<string> filenames;
-    filenames.reserve(filenames_tensor->NumElements());
-    for (int i = 0; i < filenames_tensor->NumElements(); ++i) {
-      filenames.push_back(filenames_tensor->flat<string>()(i));
-    }
-
-    *output = new Dataset(ctx, filenames);
-  }
-
- private:
-  class Dataset : public DatasetBase {
-   public:
-    Dataset(OpKernelContext* ctx, const std::vector<string>& filenames)
-        : DatasetBase(DatasetContext(ctx)),
-          filenames_(filenames) {}
-
-    std::unique_ptr<IteratorBase> MakeIteratorInternal(
-        const string& prefix) const override {
-      return std::unique_ptr<IteratorBase>(
-          new Iterator({this, strings::StrCat(prefix, "::WebP")}));
-    }
-
-    const DataTypeVector& output_dtypes() const override {
-      static DataTypeVector* dtypes = new DataTypeVector({DT_UINT8});
-      return *dtypes;
-    }
-
-    const std::vector<PartialTensorShape>& output_shapes() const override {
-      static std::vector<PartialTensorShape>* shapes =
-          new std::vector<PartialTensorShape>({{-1, -1, -1}});
-      return *shapes;
-    }
-
-    string DebugString() const override {
-      return "WebPDatasetOp::Dataset";
-    }
-
-   protected:
-    Status AsGraphDefInternal(SerializationContext* ctx,
-                              DatasetGraphDefBuilder* b,
-                              Node** output) const override {
-      Node* filenames = nullptr;
-      TF_RETURN_IF_ERROR(b->AddVector(filenames_, &filenames));
-      TF_RETURN_IF_ERROR(b->AddDataset(this, {filenames}, output));
-      return Status::OK();
-    }
-
-   private:
-    class Iterator : public DatasetIterator<Dataset> {
-     public:
-      explicit Iterator(const Params& params)
-          : DatasetIterator<Dataset>(params) {}
-
-      Status GetNextInternal(IteratorContext* ctx,
-                             std::vector<Tensor>* out_tensors,
-                             bool* end_of_sequence) override {
-        mutex_lock l(mu_);
-        if (current_file_index_ == 0) {
-          TF_RETURN_IF_ERROR(SetupStreamsLocked(ctx->env()));
-        }
-        if (current_file_index_ < dataset()->filenames_.size()) {
-          const string& filename = dataset()->filenames_[current_file_index_];
-          uint64 size = 0;
-          TF_RETURN_IF_ERROR(ctx->env()->GetFileSize(filename, &size));
-          std::unique_ptr<RandomAccessFile> file;
-          TF_RETURN_IF_ERROR(ctx->env()->NewRandomAccessFile(filename, &file));
-          std::unique_ptr<io::RandomAccessInputStream> stream(new io::RandomAccessInputStream(file.get()));
-          string data;
-          TF_RETURN_IF_ERROR(stream->ReadNBytes(size, &data));
-          WebPDecoderConfig config;
-          WebPInitDecoderConfig(&config);
-          int returned = WebPGetFeatures(reinterpret_cast<const uint8_t*>(data.c_str()),
-                                         size, &config.input);
-          if (returned != VP8_STATUS_OK) {
-            return
-                errors::InvalidArgument("File could not be decoded as WebP: ", filename);
-          }
-          // TODO (yongtang): Set channel = 4 for now.
-          static const int channel = 4;
-          Tensor value_tensor(ctx->allocator({}), DT_UINT8, {config.input.height, config.input.width, channel});
-
-          config.output.colorspace = MODE_RGBA;
-          config.output.u.RGBA.rgba = value_tensor.flat<uint8>().data();
-          config.output.u.RGBA.stride = config.input.width * channel;
-          config.output.u.RGBA.size = config.input.height * config.input.width * channel;
-          config.output.is_external_memory = 1;
-          returned = WebPDecode(reinterpret_cast<const uint8_t*>(data.c_str()), size, &config);
-          if (returned != VP8_STATUS_OK) {
-            return errors::InvalidArgument("File could not be decoded as WebP: ", filename);
-          }
-          out_tensors->emplace_back(std::move(value_tensor));
-          *end_of_sequence = false;
-          ++current_file_index_;
-          return Status::OK();
-        }
-        *end_of_sequence = true;
-        return Status::OK();
-      }
-
-     protected:
-      Status SaveInternal(IteratorStateWriter* writer) override {
-        return errors::Unimplemented("SaveInternal is currently not supported");
-      }
-
-      Status RestoreInternal(IteratorContext* ctx,
-                             IteratorStateReader* reader) override {
-        return errors::Unimplemented(
-            "RestoreInternal is currently not supported");
-      }
-
-     private:
-      // Sets up WebP streams to read from the topic at
-      // `current_file_index_`.
-      Status SetupStreamsLocked(Env* env) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
-        if (current_file_index_ >= dataset()->filenames_.size()) {
-          return errors::InvalidArgument(
-              "current_file_index_:", current_file_index_,
-              " >= filenames_.size():", dataset()->filenames_.size());
-        }
-        return Status::OK();
-      }
-
-      mutex mu_;
-      size_t current_file_index_ GUARDED_BY(mu_) = 0;
-    };
-
-    const std::vector<string> filenames_;
-    const DataTypeVector output_types_;
-  };
-  DataTypeVector output_types_;
-};
 
 class DecodeWebPOp : public OpKernel {
  public:
@@ -205,8 +62,6 @@ class DecodeWebPOp : public OpKernel {
   // TODO (yongtang): Set channels_ = 4 for now.
   static const int channels_ = 4;
 };
-REGISTER_KERNEL_BUILDER(Name("WebPDataset").Device(DEVICE_CPU),
-                        WebPDatasetOp);
 REGISTER_KERNEL_BUILDER(Name("DecodeWebP").Device(DEVICE_CPU),
                         DecodeWebPOp);
 
diff --git a/tensorflow_io/image/ops/dataset_ops.cc b/tensorflow_io/image/ops/dataset_ops.cc
index e2ded05d3..fce6a31e7 100644
--- a/tensorflow_io/image/ops/dataset_ops.cc
+++ b/tensorflow_io/image/ops/dataset_ops.cc
@@ -19,15 +19,6 @@ limitations under the License.
 
 namespace tensorflow {
 
-REGISTER_OP("WebPDataset")
-    .Input("filenames: string")
-    .Output("handle: variant")
-    .SetIsStateful()
-    .SetShapeFn([](shape_inference::InferenceContext* c) {
-      c->set_output(0, c->MakeShape({c->UnknownDim(), c->UnknownDim(), c->UnknownDim()}));
-      return Status::OK();
-    });
-
 REGISTER_OP("TIFFDataset")
     .Input("filenames: string")
     .Output("handle: variant")

From c2636c2612248b4a0bd47e9799d9ffc03a26461f Mon Sep 17 00:00:00 2001
From: Yong Tang
Date: Sun, 21 Jul 2019 18:03:40 +0000
Subject: [PATCH 3/5] Address review comments

Signed-off-by: Yong Tang
---
 tensorflow_io/core/kernels/file_input.cc            | 11 ++++-------
 tensorflow_io/core/ops/file_ops.cc                  |  2 +-
 tensorflow_io/core/python/ops/data_ops.py           |  2 +-
 tensorflow_io/image/python/ops/image_dataset_ops.py |  2 +-
 4 files changed, 7 insertions(+), 10 deletions(-)

diff --git a/tensorflow_io/core/kernels/file_input.cc b/tensorflow_io/core/kernels/file_input.cc
index 6155c9de5..5d8e06780 100644
--- a/tensorflow_io/core/kernels/file_input.cc
+++ b/tensorflow_io/core/kernels/file_input.cc
@@ -1,4 +1,4 @@
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -42,10 +42,12 @@ class FileContentInput: public FileInput<bool> {
     }
     std::vector<string> entries;
     Status status = Status::OK();
+    int64 total_size = 0;
     while (status.ok()) {
       string buffer;
       status = s->ReadNBytes(chunk_size, &buffer);
       if (status.ok() || errors::IsOutOfRange(status)) {
+        total_size += buffer.size();
         entries.emplace_back(std::move(buffer));
       }
     }
@@ -57,12 +59,8 @@ class FileContentInput: public FileInput<bool> {
       value_tensor.flat<string>()((*record_read)) = std::move(entries[0]);
     } else {
       string buffer;
-      int64 total_size = 0;
-      for (size_t i = 0; i < entries.size(); i++) {
-        total_size += entries[i].size();
-      }
       buffer.reserve(total_size);
-      for (size_t i = 0; i < entries.size(); i++) {
+      for (size_t i = 0; i < entries.size(); ++i) {
         buffer.append(entries[i]);
       }
       value_tensor.flat<string>()((*record_read)) = std::move(buffer);
@@ -79,7 +77,6 @@ class FileContentInput: public FileInput<bool> {
   bool DecodeAttributes(const VariantTensorData& data) override {
     return true;
   }
- protected:
 };
 
 REGISTER_UNARY_VARIANT_DECODE_FUNCTION(FileContentInput, "tensorflow::data::FileContentInput");
diff --git a/tensorflow_io/core/ops/file_ops.cc b/tensorflow_io/core/ops/file_ops.cc
index 8445121ee..1f36c6de4 100644
--- a/tensorflow_io/core/ops/file_ops.cc
+++ b/tensorflow_io/core/ops/file_ops.cc
@@ -1,4 +1,4 @@
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
diff --git a/tensorflow_io/core/python/ops/data_ops.py b/tensorflow_io/core/python/ops/data_ops.py
index ec366c4ce..bd08f5883 100644
--- a/tensorflow_io/core/python/ops/data_ops.py
+++ b/tensorflow_io/core/python/ops/data_ops.py
@@ -63,7 +63,7 @@ def __init__(self, fn, data_input, batch, dtypes, shapes):
         output_shapes=self._shapes), self._batch, self._dtypes, self._shapes)
 
 class FileDataset(BaseDataset):
-  """A FileDataset that read file content as string"""
+  """A FileDataset that reads file content as string"""
 
   def __init__(self, filename):
     """Create a FileDataset."""
diff --git a/tensorflow_io/image/python/ops/image_dataset_ops.py b/tensorflow_io/image/python/ops/image_dataset_ops.py
index 0617621fc..cb154ddce 100644
--- a/tensorflow_io/image/python/ops/image_dataset_ops.py
+++ b/tensorflow_io/image/python/ops/image_dataset_ops.py
@@ -21,7 +21,7 @@
 from tensorflow import dtypes
 from tensorflow.compat.v1 import data
 from tensorflow_io import _load_library
-from tensorflow_io.core.python.ops import data_ops as data_ops
+from tensorflow_io.core.python.ops import data_ops
 image_ops = _load_library('_image_ops.so')
 
 class TIFFDataset(data.Dataset):

From 0aa22a443d7a17673ed6b9fd3b0d4c2e90eb2e54 Mon Sep 17 00:00:00 2001
From: Yong Tang
Date: Mon, 22 Jul 2019 19:38:37 +0000
Subject: [PATCH 4/5] Move tf.keras to separate function in test

Signed-off-by: Yong Tang
---
 tests/test_image.py | 48 ++++++++++++++++++++++-----------------------
 1 file changed, 24 insertions(+), 24 deletions(-)

diff --git a/tests/test_image.py b/tests/test_image.py
index 57f7b4a09..42f8c5a18 100644
--- a/tests/test_image.py
+++ b/tests/test_image.py
@@ -58,30 +58,6 @@ def test_decode_webp(self):
 
     self.assertAllEqual(webp_v, png)
 
-  def test_webp_file_dataset(self):
-    """Test case for WebPDataset.
-    """
-    filename = os.path.join(
-        os.path.dirname(os.path.abspath(__file__)), "test_image", "sample.webp")
-
-    num_repeats = 2
-
-    dataset = image_io.WebPDataset([filename, filename])
-    # Repeat 2 times (2 * 2 = 4 images)
-    dataset = dataset.repeat(num_repeats)
-    # Drop alpha channel
-    dataset = dataset.map(lambda x: x[:, :, :3])
-    # Resize to 224 * 224
-    dataset = dataset.map(lambda x: tf.keras.applications.resnet50.preprocess_input(tf.image.resize(x, (224, 224))))
-    # Batch with size 1, so there are still 4 images (4 batches of 1)
-    dataset = dataset.batch(1)
-    model = tf.keras.applications.resnet50.ResNet50(weights='imagenet')
-    y = model.predict(dataset)
-    p = tf.keras.applications.resnet50.decode_predictions(y, top=1)
-    for i in p:
-      assert i[0][1] == 'pineapple'  # not truly a pineapple, though
-    assert len(p) == 4
-
   def test_tiff_file_dataset(self):
     """Test case for TIFFDataset.
     """
@@ -198,5 +174,29 @@ def test_draw_bounding_box(self):
     #   self.assertAllEqual(bb_image_v, ex_image_v)
     _ = bb_image_p.eval()
 
+def test_webp_file_dataset():
+  """Test case for WebPDataset.
+  """
+  filename = os.path.join(
+      os.path.dirname(os.path.abspath(__file__)), "test_image", "sample.webp")
+
+  num_repeats = 2
+
+  dataset = image_io.WebPDataset([filename, filename])
+  # Repeat 2 times (2 * 2 = 4 images)
+  dataset = dataset.repeat(num_repeats)
+  # Drop alpha channel
+  dataset = dataset.map(lambda x: x[:, :, :3])
+  # Resize to 224 * 224
+  dataset = dataset.map(lambda x: tf.keras.applications.resnet50.preprocess_input(tf.image.resize(x, (224, 224))))
+  # Batch with size 1, so there are still 4 images (4 batches of 1)
+  dataset = dataset.batch(1)
+  model = tf.keras.applications.resnet50.ResNet50(weights='imagenet')
+  y = model.predict(dataset)
+  p = tf.keras.applications.resnet50.decode_predictions(y, top=1)
+  for i in p:
+    assert i[0][1] == 'pineapple'  # not truly a pineapple, though
+  assert len(p) == 4
+
 if __name__ == "__main__":
   test.main()

From 00ee7961e12e357733c8ee416988a36c6f05f2c4 Mon Sep 17 00:00:00 2001
From: Yong Tang
Date: Mon, 22 Jul 2019 21:18:05 +0000
Subject: [PATCH 5/5] Split tests into eager and non-eager mode

Signed-off-by: Yong Tang
---
 tests/test_image.py       | 24 -----------------
 tests/test_image_eager.py | 54 +++++++++++++++++++++++++++++++++++++++
 2 files changed, 54 insertions(+), 24 deletions(-)
 create mode 100644 tests/test_image_eager.py

diff --git a/tests/test_image.py b/tests/test_image.py
index 42f8c5a18..9991c9742 100644
--- a/tests/test_image.py
+++ b/tests/test_image.py
@@ -174,29 +174,5 @@ def test_draw_bounding_box(self):
     #   self.assertAllEqual(bb_image_v, ex_image_v)
     _ = bb_image_p.eval()
 
-def test_webp_file_dataset():
-  """Test case for WebPDataset.
-  """
-  filename = os.path.join(
-      os.path.dirname(os.path.abspath(__file__)), "test_image", "sample.webp")
-
-  num_repeats = 2
-
-  dataset = image_io.WebPDataset([filename, filename])
-  # Repeat 2 times (2 * 2 = 4 images)
-  dataset = dataset.repeat(num_repeats)
-  # Drop alpha channel
-  dataset = dataset.map(lambda x: x[:, :, :3])
-  # Resize to 224 * 224
-  dataset = dataset.map(lambda x: tf.keras.applications.resnet50.preprocess_input(tf.image.resize(x, (224, 224))))
-  # Batch with size 1, so there are still 4 images (4 batches of 1)
-  dataset = dataset.batch(1)
-  model = tf.keras.applications.resnet50.ResNet50(weights='imagenet')
-  y = model.predict(dataset)
-  p = tf.keras.applications.resnet50.decode_predictions(y, top=1)
-  for i in p:
-    assert i[0][1] == 'pineapple'  # not truly a pineapple, though
-  assert len(p) == 4
-
 if __name__ == "__main__":
   test.main()
diff --git a/tests/test_image_eager.py b/tests/test_image_eager.py
new file mode 100644
index 000000000..ca197f71a
--- /dev/null
+++ b/tests/test_image_eager.py
@@ -0,0 +1,54 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may not
+# use this file except in compliance with the License. You may obtain a copy of
+# the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations under
+# the License.
+# ==============================================================================
+"""Tests for Image Dataset."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+import tensorflow as tf
+if not (hasattr(tf, "version") and tf.version.VERSION.startswith("2.")):
+  tf.compat.v1.enable_eager_execution()
+import tensorflow_io.image as image_io  # pylint: disable=wrong-import-position
+
+
+def test_webp_file_dataset():
+  """Test case for WebPDataset.
+  """
+  filename = os.path.join(
+      os.path.dirname(os.path.abspath(__file__)), "test_image", "sample.webp")
+
+  num_repeats = 2
+
+  dataset = image_io.WebPDataset([filename, filename])
+  # Repeat 2 times (2 * 2 = 4 images)
+  dataset = dataset.repeat(num_repeats)
+  # Drop alpha channel
+  dataset = dataset.map(lambda x: x[:, :, :3])
+  # Resize to 224 * 224
+  dataset = dataset.map(lambda x: tf.keras.applications.resnet50.preprocess_input(tf.image.resize(x, (224, 224))))
+  # Batch with size 1, so there are still 4 images (4 batches of 1)
+  dataset = dataset.batch(1)
+  model = tf.keras.applications.resnet50.ResNet50(weights='imagenet')
+  y = model.predict(dataset)
+  p = tf.keras.applications.resnet50.decode_predictions(y, top=1)
+  for i in p:
+    assert i[0][1] == 'pineapple'  # not truly a pineapple, though
+  assert len(p) == 4
+
+if __name__ == "__main__":
+  tf.test.main()
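
For reference, a minimal eager-mode sketch of the converted WebPDataset after this series, assuming a hypothetical local file `sample.webp`:

    import tensorflow_io.image as image_io

    # WebPDataset is now FileDataset + decode_webp under the hood; with
    # eager execution enabled it can be iterated over directly.
    dataset = image_io.WebPDataset(["sample.webp"])
    for image in dataset:
      print(image.shape)  # (height, width, 4): decode_webp outputs RGBA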