diff --git a/tensorflow_io/image/BUILD b/tensorflow_io/image/BUILD
index d04b7cdfa..635d9f618 100644
--- a/tensorflow_io/image/BUILD
+++ b/tensorflow_io/image/BUILD
@@ -6,6 +6,7 @@ cc_binary(
     name = "python/ops/_image_ops.so",
     srcs = [
         "kernels/gif_dataset_ops.cc",
+        "kernels/image_dataset_ops.cc",
         "kernels/tiff_dataset_ops.cc",
         "kernels/webp_dataset_ops.cc",
         "ops/dataset_ops.cc",
@@ -17,7 +18,9 @@ cc_binary(
     ],
     linkshared = 1,
     deps = [
+        "//tensorflow_io/core:dataset_ops",
         "@giflib",
+        "@libarchive",
         "@libtiff",
         "@libwebp",
         "@local_config_tf//:libtensorflow_framework",
diff --git a/tensorflow_io/image/__init__.py b/tensorflow_io/image/__init__.py
index 5e175e1ab..32d5371c0 100644
--- a/tensorflow_io/image/__init__.py
+++ b/tensorflow_io/image/__init__.py
@@ -14,6 +14,7 @@
 # ==============================================================================
 """Image Dataset.
 
+@@ImageDataset
 @@WebPDataset
 @@TIFFDataset
 @@GIFDataset
@@ -24,14 +25,17 @@ from __future__ import division
 from __future__ import print_function
 
+from tensorflow_io.image.python.ops.image_dataset_ops import ImageDataset
 from tensorflow_io.image.python.ops.image_dataset_ops import WebPDataset
 from tensorflow_io.image.python.ops.image_dataset_ops import TIFFDataset
 from tensorflow_io.image.python.ops.image_dataset_ops import GIFDataset
 from tensorflow_io.image.python.ops.image_dataset_ops import decode_webp
+
 from tensorflow.python.util.all_util import remove_undocumented
 
 _allowed_symbols = [
+    "ImageDataset",
     "WebPDataset",
     "TIFFDataset",
     "GIFDataset",
diff --git a/tensorflow_io/image/kernels/image_dataset_ops.cc b/tensorflow_io/image/kernels/image_dataset_ops.cc
new file mode 100644
index 000000000..e016650e0
--- /dev/null
+++ b/tensorflow_io/image/kernels/image_dataset_ops.cc
@@ -0,0 +1,225 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "kernels/dataset_ops.h"
+
+#include "webp/encode.h"
+#include "imageio/webpdec.h"
+#include "imageio/metadata.h"
+extern "C" {
+#include "tiff.h"
+#include "tiffio.h"
+}
+#include "tiffio.hxx"
+#include <sstream>
+
+namespace tensorflow {
+namespace data {
+
+// Buffers the whole input stream in memory so that libtiff can seek within it.
+class ImageStream {
+ public:
+  explicit ImageStream()
+    : eof_(false)
+    , tiff_(nullptr, TIFFClose) {
+  }
+  explicit ImageStream(io::InputStreamInterface& in, const string& header, const Status& status)
+    : eof_(false)
+    , tiff_(nullptr, TIFFClose)
+    , stream_(header, std::ios_base::ate | std::ios_base::in | std::ios_base::out) {
+    Status s;
+    do {
+      string buffer;
+      s = in.ReadNBytes(4096, &buffer);
+      if (s.ok() || errors::IsOutOfRange(s)) {
+        stream_ << buffer;
+      }
+    } while (s.ok());
+    tiff_.reset(TIFFStreamOpen("[in memory]", static_cast<std::istream*>(&stream_)));
+  }
+  int64 eof_ = false;
+  std::unique_ptr<TIFF, void (*)(TIFF*)> tiff_;
+ private:
+  std::stringstream stream_;
+};
+
+class ImageInput: public DataInput<ImageStream> {
+ public:
+  Status ReadRecord(io::InputStreamInterface& s, IteratorContext* ctx, std::unique_ptr<ImageStream>& state, int64* returned, std::vector<Tensor>* out_tensors) const override {
+    if (format_ == "webp") {
+      if (state.get() == nullptr) {
+        state.reset(new ImageStream());
+      }
+      if (state->eof_) {
+        *returned = 0;
+        return Status::OK();
+      }
+      string buffer;
+      TF_RETURN_IF_ERROR(s.ReadNBytes(filesize_, &buffer));
+
+      int64 height = shape_[0];
+      int64 width = shape_[1];
+      int64 channel = shape_[2];
+
+      WebPDecoderConfig config;
+      WebPInitDecoderConfig(&config);
+      int r = WebPGetFeatures(reinterpret_cast<const uint8_t*>(buffer.data()), buffer.size(), &config.input);
+      if (r != VP8_STATUS_OK) {
+        return errors::InvalidArgument("file could not be featured as WebP: ", r);
+      }
+
+      if (height != config.input.height || width != config.input.width) {
+        return errors::InvalidArgument("height and width (", config.input.height, ", ", config.input.width, ") does not match data before (", height, ", ", width, ")");
+      }
+
+      *returned = 1;
+      state->eof_ = true;
+
+      Tensor value_tensor(ctx->allocator({}), DT_UINT8, {height, width, channel});
+
+      // Decode directly into the tensor's buffer.
+      config.output.colorspace = MODE_RGBA;
+      config.output.u.RGBA.rgba = value_tensor.flat<uint8>().data();
+      config.output.u.RGBA.stride = width * channel;
+      config.output.u.RGBA.size = height * width * channel;
+      config.output.is_external_memory = 1;
+      r = WebPDecode(reinterpret_cast<const uint8_t*>(buffer.data()), buffer.size(), &config);
+      if (r != VP8_STATUS_OK) {
+        return errors::InvalidArgument("file could not be decoded as WebP: ", r);
+      }
+
+      out_tensors->emplace_back(std::move(value_tensor));
+      return Status::OK();
+    }
+    else if (format_ == "tiff") {
+      if (state.get() == nullptr) {
+        state.reset(new ImageStream(s, "", Status::OK()));
+      }
+      if (state->eof_) {
+        *returned = 0;
+        return Status::OK();
+      }
+
+      int64 height = shape_[0];
+      int64 width = shape_[1];
+      int64 channel = shape_[2];
+
+      Tensor value_tensor(ctx->allocator({}), DT_UINT8, {height, width, channel});
+      // Tensor is aligned
+      uint32* raster = reinterpret_cast<uint32*>(value_tensor.flat<uint8>().data());
+      if (!TIFFReadRGBAImageOriented(state->tiff_.get(), width, height, raster, ORIENTATION_TOPLEFT, 0)) {
+        return errors::InvalidArgument("unable to process tiff");
+      }
+      out_tensors->emplace_back(std::move(value_tensor));
+
+      *returned = 1;
+      // A multi-page TIFF keeps producing records until there is no next directory.
+      if (!TIFFReadDirectory(state->tiff_.get())) {
+        state->eof_ = true;
+      }
+      return Status::OK();
+    }
+    return errors::Unimplemented("format ", format_, " has not been supported yet");
+  }
+  Status FromStream(io::InputStreamInterface& s) override {
+    string header;
+    Status status = s.ReadNBytes(4096, &header);
+    if (!(status.ok() || errors::IsOutOfRange(status))) {
+      return status;
+    }
+    if (header.size() >= 12 && memcmp(&header.data()[0], "RIFF", 4) == 0 && memcmp(&header.data()[8], "WEBP", 4) == 0) {
+      // 4k should be enough to capture the WebP header... If not we will adjust later.
+      WebPDecoderConfig config;
+      WebPInitDecoderConfig(&config);
+      int returned = WebPGetFeatures(reinterpret_cast<const uint8_t*>(header.data()), header.size(), &config.input);
+      if (returned != VP8_STATUS_OK) {
+        return errors::InvalidArgument("file could not be decoded from stream as WebP: ", returned);
+      }
+      // Note: Always decode with channel = 4.
+      int32 height = config.input.height;
+      int32 width = config.input.width;
+      static const int32 channel = 4;
+      // Skip to the end to find out the size of the WebP file as we need it in the next run.
+      Status status = s.SkipNBytes(std::numeric_limits<int64>::max());
+      if (!(status.ok() || errors::IsOutOfRange(status))) {
+        return status;
+      }
+
+      shape_ = absl::InlinedVector<int64, 4>({height, width, channel});
+      filesize_ = s.Tell();
+      format_ = "webp";
+
+      return Status::OK();
+    } else if (header.size() >= 4 && memcmp(&header.data()[0], "II*\0", 4) == 0) {
+      // Read everything.
+      ImageStream is(s, header, status);
+      if (is.tiff_.get() == nullptr) {
+        return errors::InvalidArgument("unable to open file");
+      }
+      uint32 width, height;
+      TIFFGetField(is.tiff_.get(), TIFFTAG_IMAGEWIDTH, &width);
+      TIFFGetField(is.tiff_.get(), TIFFTAG_IMAGELENGTH, &height);
+      static int64 channel = 4;
+      shape_ = absl::InlinedVector<int64, 4>({height, width, channel});
+      filesize_ = s.Tell();
+      format_ = "tiff";
+
+      return Status::OK();
+    }
+    return errors::InvalidArgument("unknown image file format");
+  }
+  void EncodeAttributes(VariantTensorData* data) const override {
+    data->tensors_.emplace_back(Tensor(DT_INT64, TensorShape({3})));
+    data->tensors_.back().flat<int64>()(0) = shape_[0];
+    data->tensors_.back().flat<int64>()(1) = shape_[1];
+    data->tensors_.back().flat<int64>()(2) = shape_[2];
+
+    data->tensors_.emplace_back(Tensor(DT_INT64, TensorShape({})));
+    data->tensors_.back().scalar<int64>()() = filesize_;
+
+    data->tensors_.emplace_back(Tensor(DT_STRING, TensorShape({})));
+    data->tensors_.back().scalar<string>()() = format_;
+  }
+  bool DecodeAttributes(const VariantTensorData& data) override {
+    size_t format_index = data.tensors().size() - 1;
+    format_ = data.tensors(format_index).scalar<string>()();
+
+    size_t filesize_index = data.tensors().size() - 2;
+    filesize_ = data.tensors(filesize_index).scalar<int64>()();
+
+    size_t shape_index = data.tensors().size() - 3;
+    shape_ = absl::InlinedVector<int64, 4>({
+        data.tensors(shape_index).flat<int64>()(0),
+        data.tensors(shape_index).flat<int64>()(1),
+        data.tensors(shape_index).flat<int64>()(2),
+    });
+
+    return true;
+  }
+  const string& format() const {
+    return format_;
+  }
+ protected:
+  absl::InlinedVector<int64, 4> shape_;
+  int64 filesize_;
+  string format_;
+};
+
+REGISTER_UNARY_VARIANT_DECODE_FUNCTION(ImageInput, "tensorflow::data::ImageInput");
+
+REGISTER_KERNEL_BUILDER(Name("ImageInput").Device(DEVICE_CPU),
+                        DataInputOp<ImageInput>);
+REGISTER_KERNEL_BUILDER(Name("ImageDataset").Device(DEVICE_CPU),
+                        InputDatasetOp<ImageInput>);
+
+}  // namespace data
+}  // namespace tensorflow
diff --git a/tensorflow_io/image/ops/dataset_ops.cc b/tensorflow_io/image/ops/dataset_ops.cc
index cff41872f..e592c3b6c 100644
--- a/tensorflow_io/image/ops/dataset_ops.cc
+++ b/tensorflow_io/image/ops/dataset_ops.cc
@@ -58,4 +58,24 @@ REGISTER_OP("DecodeWebP")
       return Status::OK();
     });
 
+REGISTER_OP("ImageDataset") + .Input("input: T") + .Output("handle: variant") + .Attr("output_types: list(type) >= 1") + .Attr("output_shapes: list(shape) >= 1") + .Attr("T: {string, variant} = DT_VARIANT") + .SetIsStateful() + .SetShapeFn([](shape_inference::InferenceContext* c) { + c->set_output(0, c->MakeShape({c->UnknownDim(), c->UnknownDim(), c->UnknownDim(), c->UnknownDim()})); + return Status::OK(); + }); + +REGISTER_OP("ImageInput") + .Input("source: string") + .Output("handle: variant") + .Attr("filters: list(string) = []") + .SetShapeFn([](shape_inference::InferenceContext* c) { + c->set_output(0, c->MakeShape({c->UnknownDim()})); + return Status::OK(); + }); } // namespace tensorflow diff --git a/tensorflow_io/image/python/ops/image_dataset_ops.py b/tensorflow_io/image/python/ops/image_dataset_ops.py index c2f6be4ba..4d22c7d1e 100644 --- a/tensorflow_io/image/python/ops/image_dataset_ops.py +++ b/tensorflow_io/image/python/ops/image_dataset_ops.py @@ -23,6 +23,39 @@ from tensorflow_io import _load_library image_ops = _load_library('_image_ops.so') +class ImageDataset(data.Dataset): + """An Image Dataset + """ + + def __init__(self, filename): + """Create an ImageDataset. + + Args: + filename: A `tf.string` tensor containing one or more filenames. + """ + self._data_input = image_ops.image_input(filename) + super(ImageDataset, self).__init__() + + def _inputs(self): + return [] + + def _as_variant_tensor(self): + return image_ops.image_dataset( + self._data_input, + output_types=self.output_types, + output_shapes=self.output_shapes) + + @property + def output_classes(self): + return tensorflow.Tensor + + @property + def output_shapes(self): + return tuple([tensorflow.TensorShape([None, None, None])]) + + @property + def output_types(self): + return tuple([dtypes.uint8]) class WebPDataset(data.Dataset): """A WebP Image File Dataset that reads the WebP file.""" diff --git a/tests/test_image.py b/tests/test_image.py index 4fcd16326..91b4ab7c7 100644 --- a/tests/test_image.py +++ b/tests/test_image.py @@ -172,5 +172,83 @@ def test_gif_file_dataset(self): with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) + def test_webp_file_image_dataset(self): + """Test case for ImageDataset. + """ + width = 400 + height = 301 + channel = 4 + png_file = os.path.join( + os.path.dirname(os.path.abspath(__file__)), "test_image", "sample.png") + with open(png_file, 'rb') as f: + png_contents = f.read() + with self.cached_session(): + image_p = image.decode_png(png_contents, channels=channel) + image_v = image_p.eval() + self.assertEqual(image_v.shape, (height, width, channel)) + + filename = os.path.join( + os.path.dirname(os.path.abspath(__file__)), "test_image", "sample.webp") + + num_repeats = 2 + + dataset = image_io.ImageDataset([filename]).repeat( + num_repeats) + iterator = dataset.make_initializable_iterator() + init_op = iterator.initializer + get_next = iterator.get_next() + + with self.cached_session() as sess: + sess.run(init_op) + for _ in range(num_repeats): + v = sess.run(get_next) + self.assertAllEqual(image_v, v) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + def test_tiff_file_image_dataset(self): + """Test case for ImageDataset. 
+ """ + width = 560 + height = 320 + channels = 4 + + images = [] + for filename in [ + "small-00.png", + "small-01.png", + "small-02.png", + "small-03.png", + "small-04.png"]: + with open( + os.path.join(os.path.dirname(os.path.abspath(__file__)), + "test_image", + filename), 'rb') as f: + png_contents = f.read() + with self.cached_session(): + image_p = image.decode_png(png_contents, channels=channels) + image_v = image_p.eval() + self.assertEqual(image_v.shape, (height, width, channels)) + images.append(image_v) + + filename = os.path.join( + os.path.dirname(os.path.abspath(__file__)), "test_image", "small.tiff") + filename = "file://" + filename + + num_repeats = 2 + + dataset = image_io.ImageDataset([filename]).repeat(num_repeats) + iterator = dataset.make_initializable_iterator() + init_op = iterator.initializer + get_next = iterator.get_next() + with self.cached_session() as sess: + sess.run(init_op) + for _ in range(num_repeats): + for i in range(5): + v = sess.run(get_next) + self.assertAllEqual(images[i], v) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + if __name__ == "__main__": test.main() diff --git a/third_party/libtiff.BUILD b/third_party/libtiff.BUILD index 6f55df923..b590c84c2 100644 --- a/third_party/libtiff.BUILD +++ b/third_party/libtiff.BUILD @@ -10,6 +10,7 @@ cc_library( srcs = glob( [ "libtiff/tif_*.c", + "libtiff/tif_stream.cxx", ], exclude = [ "libtiff/tif_win32.c", @@ -17,6 +18,7 @@ cc_library( ), hdrs = glob([ "libtiff/*.h", + "libtiff/tiffio.hxx", ]) + [ "libtiff/tif_config.h", "libtiff/tiffconf.h",