diff --git a/tensorflow_io/core/BUILD b/tensorflow_io/core/BUILD
index 07bdc0b36..9e38c53f5 100644
--- a/tensorflow_io/core/BUILD
+++ b/tensorflow_io/core/BUILD
@@ -61,6 +61,7 @@ cc_library(
 cc_library(
     name = "core_ops",
     srcs = [
+        "kernels/archive_kernels.cc",
         "kernels/rebatch_dataset_op.cc",
         "ops/core_ops.cc",
     ],
diff --git a/tensorflow_io/core/kernels/archive_kernels.cc b/tensorflow_io/core/kernels/archive_kernels.cc
new file mode 100644
index 000000000..2149687b9
--- /dev/null
+++ b/tensorflow_io/core/kernels/archive_kernels.cc
@@ -0,0 +1,243 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <archive.h>
+#include <archive_entry.h>
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/lib/io/inputstream_interface.h"
+#include "tensorflow/core/lib/io/random_inputstream.h"
+#include "tensorflow/core/lib/io/zlib_compression_options.h"
+#include "tensorflow/core/lib/io/zlib_inputstream.h"
+#include "tensorflow_io/core/kernels/stream.h"
+
+namespace tensorflow {
+namespace data {
+namespace {
+
+class ArchiveRandomAccessFile : public SizedRandomAccessFile {
+ public:
+  ArchiveRandomAccessFile(Env* env, const string& filename,
+                          const void* optional_memory_buff,
+                          const size_t optional_memory_size)
+      : SizedRandomAccessFile(env, filename, optional_memory_buff,
+                              optional_memory_size) {}
+  ~ArchiveRandomAccessFile() {}
+  // Read callback handed to libarchive: fills callback_read_buffer_ with the
+  // next chunk of the underlying file and returns the number of bytes read.
+  static ssize_t CallbackRead(struct archive* a, void* client_data,
+                              const void** buff) {
+    class ArchiveRandomAccessFile* p =
+        (class ArchiveRandomAccessFile*)client_data;
+    StringPiece data(p->callback_read_buffer_,
+                     sizeof(p->callback_read_buffer_));
+    Status s = p->Read(p->callback_read_offset_,
+                       sizeof(p->callback_read_buffer_), &data,
+                       p->callback_read_buffer_);
+    if (!s.ok()) {
+      if (!errors::IsOutOfRange(s)) {
+        return -1;
+      }
+    }
+    p->callback_read_offset_ += data.size();
+    *buff = p->callback_read_buffer_;
+    return data.size();
+  }
+  // Scratch buffer and read offset used by CallbackRead.
+  char callback_read_buffer_[4096];
+  int64 callback_read_offset_ = 0;
+};
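+
+// The class above adapts TensorFlow's random-access file to libarchive's
+// pull model. A minimal sketch of the wiring used by the kernels below
+// (error handling elided):
+//
+//   std::unique_ptr<SizedRandomAccessFile> file(...);
+//   struct archive* a = archive_read_new();
+//   archive_read_support_filter_gzip(a);
+//   archive_read_support_format_tar(a);
+//   archive_read_open(a, file.get(), NULL,
+//                     ArchiveRandomAccessFile::CallbackRead, NULL);
+//   // libarchive now invokes CallbackRead(a, file.get(), &buff) whenever
+//   // it needs more bytes from the underlying file.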
+
+class ListArchiveEntriesOp : public OpKernel {
+ public:
+  explicit ListArchiveEntriesOp(OpKernelConstruction* context)
+      : OpKernel(context) {
+    env_ = context->env();
+    OP_REQUIRES_OK(context, context->GetAttr("filters", &filters_));
+  }
+
+  void Compute(OpKernelContext* context) override {
+    const Tensor& filename_tensor = context->input(0);
+    const string filename = filename_tensor.scalar<string>()();
+
+    const Tensor& memory_tensor = context->input(1);
+    const string memory = memory_tensor.scalar<string>()();
+
+    std::unique_ptr<SizedRandomAccessFile> file(
+        new ArchiveRandomAccessFile(env_, filename, memory.data(),
+                                    memory.size()));
+    uint64 size;
+    OP_REQUIRES_OK(context, file->GetFileSize(&size));
+
+    std::unique_ptr<struct archive, void (*)(struct archive*)> archive(
+        archive_read_new(), [](struct archive* a) { archive_read_free(a); });
+    for (const string& filter : filters_) {
+      if (filter == "none") {
+        archive_read_support_filter_none(archive.get());
+        archive_read_support_format_raw(archive.get());
+      }
+      if (filter == "gz") {
+        archive_read_support_filter_gzip(archive.get());
+        archive_read_support_format_raw(archive.get());
+      }
+      if (filter == "tar.gz") {
+        archive_read_support_filter_gzip(archive.get());
+        archive_read_support_format_tar(archive.get());
+      }
+    }
+
+    OP_REQUIRES(
+        context,
+        (archive_read_open(archive.get(), file.get(), NULL,
+                           ArchiveRandomAccessFile::CallbackRead,
+                           NULL) == ARCHIVE_OK),
+        errors::InvalidArgument("unable to open data input for ", filename,
+                                ": ", archive_error_string(archive.get())));
+
+    string format;
+    std::vector<string> entries;
+    struct archive_entry* entry;
+    while (archive_read_next_header(archive.get(), &entry) == ARCHIVE_OK) {
+      string entryname = archive_entry_pathname(entry);
+      entries.emplace_back(entryname);
+
+      string archive_format(archive_format_name(archive.get()));
+      string archive_filter = (archive_filter_count(archive.get()) > 0)
+                                  ? archive_filter_name(archive.get(), 0)
+                                  : "";
+      // Find out the format by matching the detected format/filter pair
+      // against the requested filters.
+      if (format == "") {
+        for (const string& filter : filters_) {
+          if (filter == "none") {
+            if (archive_format == "raw" && archive_filter == "none") {
+              format = "none";
+              break;
+            }
+          }
+          if (filter == "gz") {
+            if (archive_format == "raw" && archive_filter == "gzip") {
+              format = "gz";
+              break;
+            }
+          }
+          if (filter == "tar.gz") {
+            if (archive_format == "GNU tar format" &&
+                archive_filter == "gzip") {
+              format = "tar.gz";
+              break;
+            }
+          }
+        }
+        // No supported format/filter combination could be detected.
+        OP_REQUIRES(context, format != "",
+                    errors::InvalidArgument("unsupported archive: ",
+                                            archive_format, "|",
+                                            archive_filter));
+      }
+    }
+
+    Tensor* format_tensor;
+    OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape({}),
+                                                     &format_tensor));
+    format_tensor->scalar<string>()() = format;
+
+    Tensor* entries_tensor;
+    OP_REQUIRES_OK(context,
+                   context->allocate_output(
+                       1, TensorShape({static_cast<int64>(entries.size())}),
+                       &entries_tensor));
+
+    for (size_t i = 0; i < entries.size(); i++) {
+      entries_tensor->flat<string>()(i) = entries[i];
+    }
+  }
+
+ private:
+  mutex mu_;
+  Env* env_ GUARDED_BY(mu_);
+  std::vector<string> filters_ GUARDED_BY(mu_);
+};
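+
+// For reference, the two kernels in this file implement the following
+// contract (a hypothetical invocation; names and values are illustrative
+// only):
+//
+//   ListArchiveEntries(filename="a.tar.gz", memory="",
+//                      filters=["gz", "tar.gz"])
+//       -> format  = "tar.gz"
+//       -> entries = ["a.1", "a.2"]
+//   ReadArchive(filename="a.tar.gz", format="tar.gz",
+//               entries=["a.1", "a.2"], memory="")
+//       -> output  = [<bytes of a.1>, <bytes of a.2>]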
+
+class ReadArchiveOp : public OpKernel {
+ public:
+  explicit ReadArchiveOp(OpKernelConstruction* context) : OpKernel(context) {
+    env_ = context->env();
+  }
+
+  void Compute(OpKernelContext* context) override {
+    const Tensor& filename_tensor = context->input(0);
+    const string filename = filename_tensor.scalar<string>()();
+
+    const Tensor& format_tensor = context->input(1);
+    const string format = format_tensor.scalar<string>()();
+
+    const Tensor& entries_tensor = context->input(2);
+    std::unordered_map<string, int64> entries;
+    for (int64 i = 0; i < entries_tensor.NumElements(); i++) {
+      OP_REQUIRES(context,
+                  entries.find(entries_tensor.flat<string>()(i)) ==
+                      entries.end(),
+                  errors::InvalidArgument("duplicate entries: ",
+                                          entries_tensor.flat<string>()(i)));
+      entries[entries_tensor.flat<string>()(i)] = i;
+    }
+
+    const Tensor& memory_tensor = context->input(3);
+    const string memory = memory_tensor.scalar<string>()();
+
+    Tensor* output_tensor;
+    OP_REQUIRES_OK(context,
+                   context->allocate_output(
+                       0, TensorShape({static_cast<int64>(entries.size())}),
+                       &output_tensor));
+
+    std::unique_ptr<SizedRandomAccessFile> file(
+        new ArchiveRandomAccessFile(env_, filename, memory.data(),
+                                    memory.size()));
+    uint64 size;
+    OP_REQUIRES_OK(context, file->GetFileSize(&size));
+
+    if (format == "none") {
+      // Treat "none" as a normal file and read it whole.
+      string output_string;
+      output_string.resize(size);
+      StringPiece result;
+      OP_REQUIRES_OK(context,
+                     file->Read(0, size, &result, &output_string[0]));
+      output_tensor->flat<string>()(0) = std::move(output_string);
+      return;
+    }
+
+    std::unique_ptr<struct archive, void (*)(struct archive*)> archive(
+        archive_read_new(), [](struct archive* a) { archive_read_free(a); });
+    if (format == "gz") {
+      // Treat gz files specially. libarchive always seems to have issues
+      // with text files, so use ZlibInputStream instead; libarchive is
+      // mostly used for archives (not compression) anyway.
+      io::RandomAccessInputStream file_stream(file.get());
+      io::ZlibCompressionOptions zlib_compression_options =
+          io::ZlibCompressionOptions::GZIP();
+      io::ZlibInputStream compression_stream(&file_stream, 65536, 65536,
+                                             zlib_compression_options);
+      string output_string;
+      Status status = compression_stream.ReadNBytes(INT_MAX, &output_string);
+      // OutOfRange is expected here: INT_MAX bytes are requested to drain
+      // the stream, so hitting EOF is the normal termination condition.
+      if (!status.ok() && !errors::IsOutOfRange(status)) {
+        context->SetStatus(status);
+        return;
+      }
+      output_tensor->flat<string>()(0) = std::move(output_string);
+      return;
+    }
+
+    if (format == "tar.gz") {
+      archive_read_support_filter_gzip(archive.get());
+      archive_read_support_format_tar(archive.get());
+    } else {
+      OP_REQUIRES(context, false,
+                  errors::InvalidArgument("unsupported format: ", format));
+    }
+
+    OP_REQUIRES(
+        context,
+        (archive_read_open(archive.get(), file.get(), NULL,
+                           ArchiveRandomAccessFile::CallbackRead,
+                           NULL) == ARCHIVE_OK),
+        errors::InvalidArgument("unable to open data input for ", filename,
+                                ": ", archive_error_string(archive.get())));
+
+    struct archive_entry* entry;
+    while (archive_read_next_header(archive.get(), &entry) == ARCHIVE_OK) {
+      string entryname = archive_entry_pathname(entry);
+      if (entries.find(entryname) != entries.end()) {
+        size_t bytes_to_read = archive_entry_size(entry);
+        string output_string;
+        output_string.resize(bytes_to_read);
+        size_t bytes_read = 0;
+        while (bytes_read < bytes_to_read) {
+          ssize_t size = archive_read_data(archive.get(),
+                                           &output_string[bytes_read],
+                                           bytes_to_read - bytes_read);
+          if (size == 0) {
+            break;
+          }
+          bytes_read += size;
+        }
+        output_tensor->flat<string>()(entries[entryname]) =
+            std::move(output_string);
+      }
+    }
+  }
+
+ private:
+  mutex mu_;
+  Env* env_ GUARDED_BY(mu_);
+};
+
+REGISTER_KERNEL_BUILDER(Name("ListArchiveEntries").Device(DEVICE_CPU),
+                        ListArchiveEntriesOp);
+
+REGISTER_KERNEL_BUILDER(Name("ReadArchive").Device(DEVICE_CPU),
+                        ReadArchiveOp);
+
+}  // namespace
+}  // namespace data
+}  // namespace tensorflow
diff --git a/tensorflow_io/core/ops/core_ops.cc b/tensorflow_io/core/ops/core_ops.cc
index 051b18a33..61760e457 100644
--- a/tensorflow_io/core/ops/core_ops.cc
+++ b/tensorflow_io/core/ops/core_ops.cc
@@ -1,4 +1,4 @@
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -34,4 +34,27 @@ REGISTER_OP("AdjustBatchDataset")
       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
       return shape_inference::ScalarShape(c);
     });
+REGISTER_OP("ListArchiveEntries")
+    .Input("filename: string")
+    .Input("memory: string")
+    .Output("format: string")
+    .Output("entries: string")
+    .Attr("filters: list(string) = []")
+    .SetShapeFn([](shape_inference::InferenceContext* c) {
+      c->set_output(0, c->Scalar());
+      c->set_output(1, c->MakeShape({c->UnknownDim()}));
+      return Status::OK();
+    });
+
+REGISTER_OP("ReadArchive")
+    .Input("filename: string")
+    .Input("format: string")
+    .Input("entries: string")
+    .Input("memory: string")
+    .Output("output: string")
+    .SetShapeFn([](shape_inference::InferenceContext* c) {
+      c->set_output(0, c->MakeShape({c->UnknownDim()}));
+      return Status::OK();
+    });
+
 }  // namespace tensorflow
diff --git a/tensorflow_io/core/python/ops/archive_ops.py b/tensorflow_io/core/python/ops/archive_ops.py
new file mode 100644
index 000000000..918d50187
--- /dev/null
+++ b/tensorflow_io/core/python/ops/archive_ops.py
@@ -0,0 +1,34 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
"""Archive."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow_io.core.python.ops import core_ops
+
+def list_archive_entries(filename, filters, **kwargs):
+  """List the entries of an archive, detecting the matching format."""
+  memory = kwargs.get("memory", "")
+  if not isinstance(filters, list):
+    filters = [filters]
+  return core_ops.list_archive_entries(
+      filename, filters=filters, memory=memory)
+
+def read_archive(filename, format, entries, **kwargs): # pylint: disable=redefined-builtin
+  """Read the given entries out of an archive."""
+  memory = kwargs.get("memory", "")
+  return core_ops.read_archive(
+      filename, format, entries, memory=memory)
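+
+# Example usage (a sketch; the path and file below are illustrative only --
+# any gzip-compressed file would do):
+#
+#   format, entries = list_archive_entries(
+#       "file:///tmp/data.parquet.gz", ["gz", "tar.gz"])
+#   elements = read_archive("file:///tmp/data.parquet.gz", format, entries)
+#   # elements is a string tensor with one element per archive entry.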
diff --git a/tests/test_archive_eager.py b/tests/test_archive_eager.py
new file mode 100644
index 000000000..a68b12b6d
--- /dev/null
+++ b/tests/test_archive_eager.py
@@ -0,0 +1,121 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may not
+# use this file except in compliance with the License. You may obtain a copy of
+# the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations under
+# the License.
+# ==============================================================================
+"""Tests for read_archive."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import tensorflow as tf
+from tensorflow import test  # used by test.main() below
+if not (hasattr(tf, "version") and tf.version.VERSION.startswith("2.")):
+  tf.compat.v1.enable_eager_execution()
+import tensorflow_io.core.python.ops.archive_ops as archive_io # pylint: disable=wrong-import-position
+
+def test_gz():
+  """test_gz"""
+  filename = os.path.join(
+      os.path.dirname(os.path.abspath(__file__)),
+      "test_parquet",
+      "parquet_cpp_example.parquet.gz")
+  filename = "file://" + filename
+
+  format, entries = archive_io.list_archive_entries(filename, ["gz", "tar.gz"]) # pylint: disable=redefined-builtin
+  assert format.numpy().decode() == "gz"
+  assert entries.shape == [1]
+
+  elements = archive_io.read_archive(filename, format, entries)
+  assert elements.shape == [1]
+
+  expected_filename = os.path.join(
+      os.path.dirname(os.path.abspath(__file__)),
+      "test_parquet",
+      "parquet_cpp_example.parquet")
+  expected_filename = "file://" + expected_filename
+
+  assert elements[0].numpy() == tf.io.read_file(expected_filename).numpy()
+
+def test_none():
+  """test_none"""
+  filename = os.path.join(
+      os.path.dirname(os.path.abspath(__file__)),
+      "test_parquet",
+      "parquet_cpp_example.parquet")
+  filename = "file://" + filename
+
+  format, entries = archive_io.list_archive_entries(filename, ["none", "gz"]) # pylint: disable=redefined-builtin
+  assert format.numpy().decode() == "none"
+  assert entries.shape == [1]
+
+  elements = archive_io.read_archive(filename, format, entries)
+  assert elements.shape == [1]
+
+  assert elements[0].numpy() == tf.io.read_file(filename).numpy()
+
+def test_tar_gz():
+  """test_tar_gz"""
+  filename = os.path.join(
+      os.path.dirname(os.path.abspath(__file__)),
+      "test_parquet",
+      "parquet_cpp_example.parquet.tar.gz")
+  filename = "file://" + filename
+
+  format, entries = archive_io.list_archive_entries(filename, ["gz", "tar.gz"]) # pylint: disable=redefined-builtin
+  assert format.numpy().decode() == "tar.gz"
+  assert entries.shape == [2]
+  assert entries[0].numpy().decode() == "parquet_cpp_example.parquet.1"
+  assert entries[1].numpy().decode() == "parquet_cpp_example.parquet.2"
+
+  elements = archive_io.read_archive(filename, format, entries)
+  assert elements.shape == [2]
+
+  expected_filename = os.path.join(
+      os.path.dirname(os.path.abspath(__file__)),
+      "test_parquet",
+      "parquet_cpp_example.parquet")
+  expected_filename = "file://" + expected_filename
+
+  assert elements[0].numpy() == tf.io.read_file(expected_filename).numpy()
+  assert elements[1].numpy() == tf.io.read_file(expected_filename).numpy()
+
+def test_dataset():
+  """test_dataset"""
+  filename = os.path.join(
+      os.path.dirname(os.path.abspath(__file__)),
+      "test_parquet",
+      "parquet_cpp_example.parquet.tar.gz")
+  filename = "file://" + filename
+
+  # This is a demo implementation of ArchiveDataset
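+  # (a sketch of the intended flow; shapes noted are for this tar.gz file):
+  #   1. list_archive_entries maps the filename to (format, entries).
+  #   2. filename and format are broadcast to the shape of entries so the
+  #      three tensors line up element-wise.
+  #   3. unbatch() yields one (filename, format, entry) triple per entry,
+  #      and read_archive extracts each entry's bytes.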
+  dataset = tf.compat.v2.data.Dataset.from_tensor_slices(([filename])).map(
+      lambda f: () + (f,) + archive_io.list_archive_entries(f, "tar.gz")
+  ).map(
+      lambda f, format, e: (tf.broadcast_to(f, tf.shape(e)),
+                            tf.broadcast_to(format, tf.shape(e)), e)
+  ).apply(tf.data.experimental.unbatch()).map(archive_io.read_archive)
+
+  expected_filename = os.path.join(
+      os.path.dirname(os.path.abspath(__file__)),
+      "test_parquet",
+      "parquet_cpp_example.parquet")
+  expected_filename = "file://" + expected_filename
+
+  i = 0
+  for entry in dataset:
+    assert entry.numpy() == tf.io.read_file(expected_filename).numpy()
+    i += 1
+  assert i == 2
+
+if __name__ == "__main__":
+  test.main()
diff --git a/tests/test_parquet/parquet_cpp_example.parquet.tar.gz b/tests/test_parquet/parquet_cpp_example.parquet.tar.gz
new file mode 100644
index 000000000..a3cc77405
Binary files /dev/null and b/tests/test_parquet/parquet_cpp_example.parquet.tar.gz differ