diff --git a/tensorflow_io/core/python/ops/io_tensor.py b/tensorflow_io/core/python/ops/io_tensor.py
index d25531695..abc18a2a2 100644
--- a/tensorflow_io/core/python/ops/io_tensor.py
+++ b/tensorflow_io/core/python/ops/io_tensor.py
@@ -29,6 +29,7 @@
 from tensorflow_io.core.python.ops import csv_io_tensor_ops
 from tensorflow_io.core.python.ops import avro_io_tensor_ops
 from tensorflow_io.core.python.ops import ffmpeg_io_tensor_ops
+from tensorflow_io.core.python.ops import parquet_io_tensor_ops
 
 class IOTensor(io_tensor_ops._IOTensor):  # pylint: disable=protected-access
   """IOTensor
@@ -422,3 +423,20 @@ def from_ffmpeg(cls,
     with tf.name_scope(kwargs.get("name", "IOFromFFmpeg")):
       return ffmpeg_io_tensor_ops.FFmpegIOTensor(
           filename, internal=True)
+
+  @classmethod
+  def from_parquet(cls,
+                   filename,
+                   **kwargs):
+    """Creates an `IOTensor` from a Parquet file.
+
+    Args:
+      filename: A string, the filename of a Parquet file.
+      name: A name prefix for the IOTensor (optional).
+
+    Returns:
+      An `IOTensor`.
+
+    """
+    with tf.name_scope(kwargs.get("name", "IOFromParquet")):
+      return parquet_io_tensor_ops.ParquetIOTensor(filename, internal=True)
diff --git a/tensorflow_io/core/python/ops/parquet_io_tensor_ops.py b/tensorflow_io/core/python/ops/parquet_io_tensor_ops.py
new file mode 100644
index 000000000..408c329a5
--- /dev/null
+++ b/tensorflow_io/core/python/ops/parquet_io_tensor_ops.py
@@ -0,0 +1,60 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""ParquetIOTensor"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import uuid
+
+import tensorflow as tf
+from tensorflow_io.core.python.ops import io_tensor_ops
+from tensorflow_io.core.python.ops import core_ops
+
+class ParquetIOTensor(io_tensor_ops._TableIOTensor):  # pylint: disable=protected-access
+  """ParquetIOTensor"""
+
+  #=============================================================================
+  # Constructor (private)
+  #=============================================================================
+  def __init__(self,
+               filename,
+               capacity=None,
+               internal=False):
+    with tf.name_scope("ParquetIOTensor") as scope:
+      resource, columns = core_ops.parquet_indexable_init(
+          filename,
+          container=scope,
+          shared_name="%s/%s" % (filename, uuid.uuid4().hex))
+      partitions = None
+      if capacity is not None:
+        partitions = core_ops.parquet_indexable_partitions(resource)
+        partitions = partitions.numpy().tolist()
+        if capacity > 0:
+          partitions = [
+              v for e in partitions for v in list(
+                  [capacity] * (e // capacity) + [e % capacity])]
+      columns = [column.decode() for column in columns.numpy().tolist()]
+      spec = []
+      for column in columns:
+        shape, dtype = core_ops.parquet_indexable_spec(resource, column)
+        shape = tf.TensorShape(shape.numpy())
+        dtype = tf.as_dtype(dtype.numpy())
+        spec.append(tf.TensorSpec(shape, dtype, column))
+      spec = tuple(spec)
+      super(ParquetIOTensor, self).__init__(
+          spec, columns,
+          resource, core_ops.parquet_indexable_read,
+          partitions=partitions, internal=internal)
diff --git a/tensorflow_io/parquet/kernels/parquet_kernels.cc b/tensorflow_io/parquet/kernels/parquet_kernels.cc
index 5e5da6560..72385c06f 100644
--- a/tensorflow_io/parquet/kernels/parquet_kernels.cc
+++ b/tensorflow_io/parquet/kernels/parquet_kernels.cc
@@ -14,6 +14,7 @@ limitations under the License.
 ==============================================================================*/
 
 #include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow_io/core/kernels/io_interface.h"
 #include "tensorflow_io/arrow/kernels/arrow_kernels.h"
 #include "parquet/api/reader.h"
 
@@ -218,5 +219,184 @@ REGISTER_KERNEL_BUILDER(Name("ReadParquet").Device(DEVICE_CPU),
 
 }  // namespace
 
+
+
+class ParquetIndexable : public IOIndexableInterface {
+ public:
+  ParquetIndexable(Env* env)
+  : env_(env) {}
+
+  ~ParquetIndexable() {}
+  Status Init(const std::vector<string>& input, const std::vector<string>& metadata, const void* memory_data, const int64 memory_size) override {
+    if (input.size() > 1) {
+      return errors::InvalidArgument("more than 1 filename is not supported");
+    }
+    const string& filename = input[0];
+    file_.reset(new SizedRandomAccessFile(env_, filename, memory_data, memory_size));
+    TF_RETURN_IF_ERROR(file_->GetFileSize(&file_size_));
+
+    parquet_file_.reset(new ArrowRandomAccessFile(file_.get(), file_size_));
+    parquet_reader_ = parquet::ParquetFileReader::Open(parquet_file_);
+    parquet_metadata_ = parquet_reader_->metadata();
+
+    shapes_.clear();
+    dtypes_.clear();
+    columns_.clear();
+    for (size_t i = 0; i < parquet_metadata_->num_columns(); i++) {
+      ::tensorflow::DataType dtype;
+      switch (parquet_metadata_->schema()->Column(i)->physical_type()) {
+      case parquet::Type::BOOLEAN:
+        dtype = ::tensorflow::DT_BOOL;
+        break;
+      case parquet::Type::INT32:
+        dtype = ::tensorflow::DT_INT32;
+        break;
+      case parquet::Type::INT64:
+        dtype = ::tensorflow::DT_INT64;
+        break;
+      case parquet::Type::INT96:  // Deprecated; an exception is thrown when accessed with __getitem__
+        dtype = ::tensorflow::DT_INT64;
+        break;
+      case parquet::Type::FLOAT:
+        dtype = ::tensorflow::DT_FLOAT;
+        break;
+      case parquet::Type::DOUBLE:
+        dtype = ::tensorflow::DT_DOUBLE;
+        break;
+      case parquet::Type::BYTE_ARRAY:
+        dtype = ::tensorflow::DT_STRING;
+        break;
+      case parquet::Type::FIXED_LEN_BYTE_ARRAY:
+        dtype = ::tensorflow::DT_STRING;
+        break;
+      default:
+        return errors::InvalidArgument("parquet data type is not supported: ", parquet_metadata_->schema()->Column(i)->physical_type());
+        break;
+      }
+      shapes_.push_back(TensorShape({static_cast<int64>(parquet_metadata_->num_rows())}));
+      dtypes_.push_back(dtype);
+      columns_.push_back(parquet_metadata_->schema()->Column(i)->path().get()->ToDotString());
+      columns_index_[parquet_metadata_->schema()->Column(i)->path().get()->ToDotString()] = i;
+    }
+
+    return Status::OK();
+  }
+  Status Partitions(std::vector<int64>* partitions) override {
+    partitions->clear();
+    for (int row_group = 0; row_group < parquet_metadata_->num_row_groups(); row_group++) {
+      std::shared_ptr<parquet::RowGroupReader> row_group_reader = parquet_reader_->RowGroup(row_group);
+      partitions->push_back(row_group_reader->metadata()->num_rows());
+    }
+    return Status::OK();
+  }
+  Status Components(std::vector<string>* components) override {
+    components->clear();
+    for (size_t i = 0; i < columns_.size(); i++) {
+      components->push_back(columns_[i]);
+    }
+    return Status::OK();
+  }
+  Status Spec(const string& component, PartialTensorShape* shape, DataType* dtype, bool label) override {
+    if (columns_index_.find(component) == columns_index_.end()) {
+      return errors::InvalidArgument("component ", component, " is invalid");
+    }
+    int64 column_index = columns_index_[component];
+    *shape = shapes_[column_index];
+    *dtype = dtypes_[column_index];
+    return Status::OK();
+  }
+
+  Status Read(const int64 start, const int64 stop,
+              const string& component, Tensor* value, Tensor* label) override {
+    if (columns_index_.find(component) == columns_index_.end()) {
+      return errors::InvalidArgument("component ", component, " is invalid");
+    }
+    int64 column_index = columns_index_[component];
+    const string& column = component;
+
+    int64 row_group_offset = 0;
+    for (int row_group = 0; row_group < parquet_metadata_->num_row_groups(); row_group++) {
+      std::shared_ptr<parquet::RowGroupReader> row_group_reader = parquet_reader_->RowGroup(row_group);
+      // Skip if the row group is not within [start..stop]
+      if ((row_group_offset + row_group_reader->metadata()->num_rows() < start) || (stop <= row_group_offset)) {
+        row_group_offset += row_group_reader->metadata()->num_rows();
+        continue;
+      }
+      // Find the row_to_read range
+      int64 row_to_read_start = row_group_offset > start ? row_group_offset : start;
+      int64 row_to_read_final = (row_group_offset + row_group_reader->metadata()->num_rows()) < (stop) ? (row_group_offset + row_group_reader->metadata()->num_rows()) : (stop);
+      int64 row_to_read_count = row_to_read_final - row_to_read_start;
+
+      // TODO: parquet is RowGroup based, so ideally the RowGroup should be cached
+      // so that indexing and slicing can happen within each row group. For now no
+      // caching is done yet.
+      std::shared_ptr<parquet::ColumnReader> column_reader = row_group_reader->Column(column_index);
+
+      // The buffer location to fill is value.data()[row_to_read_start - start].
+
+      #define PARQUET_PROCESS_TYPE(ptype, type) { \
+        parquet::TypedColumnReader<ptype>* reader = \
+            static_cast<parquet::TypedColumnReader<ptype>*>( \
+                column_reader.get()); \
+        if (row_to_read_start > row_group_offset) { \
+          reader->Skip(row_to_read_start - row_group_offset); \
+        } \
+        ptype::c_type* value_p = (ptype::c_type *)(void *)(&(value->flat<type>().data()[row_to_read_start - start])); \
+        int64_t values_read; \
+        int64_t levels_read = reader->ReadBatch(row_to_read_count, nullptr, nullptr, value_p, &values_read); \
+        if (!(levels_read == values_read && levels_read == row_to_read_count)) { \
+          return errors::InvalidArgument("null value in column: ", column); \
+        } \
+      }
+      switch (parquet_metadata_->schema()->Column(column_index)->physical_type()) {
+      case parquet::Type::BOOLEAN:
+        PARQUET_PROCESS_TYPE(parquet::BooleanType, bool);
+        break;
+      case parquet::Type::INT32:
+        PARQUET_PROCESS_TYPE(parquet::Int32Type, int32);
+        break;
+      case parquet::Type::INT64:
+        PARQUET_PROCESS_TYPE(parquet::Int64Type, int64);
+        break;
+      case parquet::Type::FLOAT:
+        PARQUET_PROCESS_TYPE(parquet::FloatType, float);
+        break;
+      case parquet::Type::DOUBLE:
+        PARQUET_PROCESS_TYPE(parquet::DoubleType, double);
+        break;
+      default:
+        return errors::InvalidArgument("invalid data type: ", parquet_metadata_->schema()->Column(column_index)->physical_type());
+      }
+      row_group_offset += row_group_reader->metadata()->num_rows();
+    }
+    return Status::OK();
+  }
+
+  string DebugString() const override {
+    mutex_lock l(mu_);
+    return strings::StrCat("ParquetIndexable");
+  }
+ private:
+  mutable mutex mu_;
+  Env* env_ GUARDED_BY(mu_);
+  std::unique_ptr<SizedRandomAccessFile> file_ GUARDED_BY(mu_);
+  uint64 file_size_ GUARDED_BY(mu_);
+  std::shared_ptr<ArrowRandomAccessFile> parquet_file_;
+  std::unique_ptr<::parquet::ParquetFileReader> parquet_reader_;
+  std::shared_ptr<::parquet::FileMetaData> parquet_metadata_;
+
+  std::vector<DataType> dtypes_;
+  std::vector<TensorShape> shapes_;
+  std::vector<string> columns_;
+  std::unordered_map<string, int64> columns_index_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("ParquetIndexableInit").Device(DEVICE_CPU),
+                        IOInterfaceInitOp<ParquetIndexable>);
+REGISTER_KERNEL_BUILDER(Name("ParquetIndexableSpec").Device(DEVICE_CPU),
+                        IOInterfaceSpecOp<ParquetIndexable>);
+REGISTER_KERNEL_BUILDER(Name("ParquetIndexablePartitions").Device(DEVICE_CPU), + IOIndexablePartitionsOp); +REGISTER_KERNEL_BUILDER(Name("ParquetIndexableRead").Device(DEVICE_CPU), + IOIndexableReadOp); } // namespace data } // namespace tensorflow diff --git a/tensorflow_io/parquet/ops/parquet_ops.cc b/tensorflow_io/parquet/ops/parquet_ops.cc index 32c318f8c..bc22a94fc 100644 --- a/tensorflow_io/parquet/ops/parquet_ops.cc +++ b/tensorflow_io/parquet/ops/parquet_ops.cc @@ -44,5 +44,52 @@ REGISTER_OP("ReadParquet") c->set_output(0, c->MakeShape({c->UnknownDim()})); return Status::OK(); }); +REGISTER_OP("ParquetIndexableInit") + .Input("input: string") + .Output("resource: resource") + .Output("components: string") + .Attr("container: string = ''") + .Attr("shared_name: string = ''") + .SetShapeFn([](shape_inference::InferenceContext* c) { + c->set_output(0, c->Scalar()); + c->set_output(1, c->MakeShape({})); + return Status::OK(); + }); + +REGISTER_OP("ParquetIndexableSpec") + .Input("input: resource") + .Output("shape: int64") + .Output("dtype: int64") + .Attr("component: string") + .SetShapeFn([](shape_inference::InferenceContext* c) { + c->set_output(0, c->MakeShape({c->UnknownDim()})); + c->set_output(1, c->MakeShape({})); + return Status::OK(); + }); + +REGISTER_OP("ParquetIndexableRead") + .Input("input: resource") + .Input("start: int64") + .Input("stop: int64") + .Output("value: dtype") + .Attr("component: string") + .Attr("shape: shape") + .Attr("dtype: type") + .SetShapeFn([](shape_inference::InferenceContext* c) { + PartialTensorShape shape; + TF_RETURN_IF_ERROR(c->GetAttr("shape", &shape)); + shape_inference::ShapeHandle entry; + TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(shape, &entry)); + c->set_output(0, entry); + return Status::OK(); + }); + +REGISTER_OP("ParquetIndexablePartitions") + .Input("input: resource") + .Output("partitions: int64") + .SetShapeFn([](shape_inference::InferenceContext* c) { + c->set_output(0, c->MakeShape({c->UnknownDim()})); + return Status::OK(); + }); } // namespace tensorflow diff --git a/tensorflow_io/parquet/python/ops/parquet_ops.py b/tensorflow_io/parquet/python/ops/parquet_ops.py index a0cb9407a..a581af000 100644 --- a/tensorflow_io/parquet/python/ops/parquet_ops.py +++ b/tensorflow_io/parquet/python/ops/parquet_ops.py @@ -17,10 +17,18 @@ from __future__ import division from __future__ import print_function +import warnings + import tensorflow as tf from tensorflow_io.core.python.ops import core_ops as parquet_ops from tensorflow_io.core.python.ops import data_ops +warnings.warn( + "The tensorflow_io.parquet.ParquetDataset is " + "deprecated. 
+    "to read Parquet files into TensorFlow.",
+    DeprecationWarning)
+
 def list_parquet_columns(filename, **kwargs):
   """list_parquet_columns"""
   if not tf.executing_eagerly():
diff --git a/tests/test_avro_eager.py b/tests/test_avro_eager.py
index a55e1c2f8..11574f0d1 100644
--- a/tests/test_avro_eager.py
+++ b/tests/test_avro_eager.py
@@ -119,6 +119,7 @@ def test_avro_dataset_partition():
     assert im.numpy() == 100.0 + i
     assert re.numpy() == 100.0 * i
     i += 1
+  assert i == 100
 
 if __name__ == "__main__":
   test.main()
diff --git a/tests/test_parquet_eager.py b/tests/test_parquet_eager.py
index d80440a20..933e53f1e 100644
--- a/tests/test_parquet_eager.py
+++ b/tests/test_parquet_eager.py
@@ -24,7 +24,13 @@
 import tensorflow as tf
 if not (hasattr(tf, "version") and tf.version.VERSION.startswith("2.")):
   tf.compat.v1.enable_eager_execution()
-import tensorflow_io.parquet as parquet_io  # pylint: disable=wrong-import-position
+import tensorflow_io as tfio  # pylint: disable=wrong-import-position
+
+filename = os.path.join(
+    os.path.dirname(os.path.abspath(__file__)),
+    "test_parquet",
+    "parquet_cpp_example.parquet")
+filename = "file://" + filename
 
 # Note: The sample file is generated from:
 # `parquet-cpp/examples/low-level-api/reader_writer`
@@ -41,24 +47,27 @@
 def test_parquet():
   """Test case for read_parquet.
   """
-  filename = os.path.join(
-      os.path.dirname(os.path.abspath(__file__)),
-      "test_parquet",
-      "parquet_cpp_example.parquet")
-  filename = "file://" + filename
-
-  specs = parquet_io.list_parquet_columns(filename)
+  parquet = tfio.IOTensor.from_parquet(filename)
   columns = [
       'boolean_field',
       'int32_field',
       'int64_field',
+      'int96_field',
      'float_field',
-      'double_field']
-  p0 = parquet_io.read_parquet(filename, specs['boolean_field'])
-  p1 = parquet_io.read_parquet(filename, specs['int32_field'])
-  p2 = parquet_io.read_parquet(filename, specs['int64_field'])
-  p4 = parquet_io.read_parquet(filename, specs['float_field'])
-  p5 = parquet_io.read_parquet(filename, specs['double_field'])
+      'double_field',
+      'ba_field',
+      'flba_field']
+  assert parquet.columns == columns
+  p0 = parquet('boolean_field')
+  p1 = parquet('int32_field')
+  p2 = parquet('int64_field')
+  p4 = parquet('float_field')
+  p5 = parquet('double_field')
+  assert p0.dtype == tf.bool
+  assert p1.dtype == tf.int32
+  assert p2.dtype == tf.int64
+  assert p4.dtype == tf.float32
+  assert p5.dtype == tf.float64
 
   for i in range(500):  # 500 rows.
     v0 = ((i % 2) == 0)
@@ -72,24 +81,25 @@
     assert np.isclose(v4, p4[i].numpy())
     assert np.isclose(v5, p5[i].numpy())
 
-  dataset = tf.compat.v2.data.Dataset.zip(
-      tuple(
-          [parquet_io.ParquetDataset(filename, column) for column in columns])
-  ).apply(tf.data.experimental.unbatch())
-  i = 0
-  for p in dataset:
-    v0 = ((i % 2) == 0)
-    v1 = i
-    v2 = i * 1000 * 1000 * 1000 * 1000
-    v4 = 1.1 * i
-    v5 = 1.1111111 * i
-    p0, p1, p2, p4, p5 = p
-    assert v0 == p0.numpy()
-    assert v1 == p1.numpy()
-    assert v2 == p2.numpy()
-    assert np.isclose(v4, p4.numpy())
-    assert np.isclose(v5, p5.numpy())
-    i += 1
+def test_parquet_partition():
+  """test_parquet_partition"""
+  for capacity in [
+      1, 2, 3,
+      11, 12, 13,
+      50, 51, 100, 200]:
+    parquet = tfio.IOTensor.from_parquet(
+        filename, capacity=capacity)
+    assert np.all(
+        parquet("int32_field").to_tensor().numpy() == [i for i in range(500)])
+    for step in [
+        1, 2, 3,
+        10, 11, 12, 13,
+        50, 51, 52, 53]:
+      indices = list(range(0, 100, step))
+      for start, stop in zip(indices, indices[1:] + [100]):
+        expected = [i for i in range(start, stop)]
+        items = parquet("int32_field")[start:stop]
+        assert np.all(items.numpy() == expected)
 
 if __name__ == "__main__":
   test.main()
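
For reference, a minimal usage sketch of the API this patch adds (not part of the diff). It assumes eager execution; the file path below is illustrative, while the column names, slicing, to_tensor(), and the optional capacity argument all mirror tests/test_parquet_eager.py above.

import tensorflow_io as tfio

# Illustrative path; the tests above use tests/test_parquet/parquet_cpp_example.parquet.
# capacity (optional) re-splits the row-group partitions into chunks of at most
# `capacity` rows, as exercised by test_parquet_partition.
parquet = tfio.IOTensor.from_parquet(
    "file:///tmp/parquet_cpp_example.parquet", capacity=100)

print(parquet.columns)                  # column names discovered by ParquetIndexableInit
p = parquet("int32_field")              # per-column view of the IOTensor
print(p.dtype)                          # tf.int32
print(p[0:10].numpy())                  # slicing dispatches to ParquetIndexableRead
print(p.to_tensor().numpy().shape)      # materialize the whole column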