From ce430900258fd6bf70e5024e206dbd446ad6bc5f Mon Sep 17 00:00:00 2001
From: Yong Tang
Date: Sat, 27 Jul 2019 05:17:58 +0000
Subject: [PATCH 1/4] Rework ParquetDataset for easier access and better cache size in eager mode

This fix is part of the effort to improve the overall Dataset API for easier access and a better cache size in eager mode. See 382 and 366 for related discussions.

To be able to read a file either from a filename or from memory, this PR adds a SizedRandomAccessFile which accepts an optional memory buffer as the file content. This is useful when processing compressed files or archives, where we can read the uncompressed file content directly into memory.

The previous limitation of Dataset was that it is an iterable, so the sequence length is unknown until graph runtime. In this PR we provide a helper function to read the specs of a parquet file, so the length is known up front. This also opens other avenues, such as mapping a parquet file with __getitem__ and __len__. Further, a parquet file can be read into a Tensor and processed easily (e.g. with a pandas-like API). The read_parquet_specs approach could be applied similarly to HDF5, where it matters even more: an HDF5 file can contain datasets of different sizes.

Summary:
1) Two basic C++ kernel ops are implemented: read_parquet_specs and read_parquet.
2) One ParquetDataset that is a Python-only implementation (no C++ dataset kernel anymore).
3) ParquetDataset supports eager and graph mode. In graph mode, dtype and shape are provided explicitly by the user; in eager mode, only the column name is needed.
4) read_parquet works in eager and graph mode, and can read records either in full or in slices.
5) read_parquet_specs works in eager mode only (a known limitation).

For cache batch vs. batch in tf.keras:
1) Added a hidden `capacity` argument to adjust the cache batch size.
2) The batch size passed to tf.keras is unrelated to `capacity`, but `rebatch` can be used to change it at the end of the pipeline.
3) `capacity` can be padded so that `rebatch` only cuts a slice from one chunk. If it is not padded to the `batch_size` used in tf.keras, `rebatch` will likely copy across chunk boundaries.
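A rough sketch of the intended eager-mode usage, using the names as of this patch (later commits in the series rename read_parquet_specs). The file path is a placeholder, and the column names are taken from the parquet_cpp_example.parquet test fixture used in tests/test_parquet_eager.py:

```python
# Assumes eager execution is enabled (TF 2.x, or
# tf.compat.v1.enable_eager_execution() on TF 1.x).
import tensorflow_io.parquet as parquet_io

# Placeholder path; the tests use tests/test_parquet/parquet_cpp_example.parquet.
filename = "file:///path/to/parquet_cpp_example.parquet"

# Column specs (name -> tf.TensorSpec with dtype and length) are known
# before graph runtime, which is the key difference from the old dataset.
specs = parquet_io.read_parquet_specs(filename)

# Read a whole column into a Tensor (slices are possible via start=).
int32_column = parquet_io.read_parquet(filename, specs["int32_field"])
print(int32_column[0].numpy())

# Or stream a column through a Dataset; in eager mode only the column
# name is needed, dtype and shape are looked up from the specs.
dataset = parquet_io.ParquetDataset(filename, "int32_field")
for value in dataset:
    print(value.numpy())
```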
Signed-off-by: Yong Tang --- tensorflow_io/core/BUILD | 1 + tensorflow_io/parquet/BUILD | 10 +- tensorflow_io/parquet/__init__.py | 6 + .../parquet/kernels/parquet_kernels.cc | 318 ++++++++++++++++++ tensorflow_io/parquet/ops/parquet_ops.cc | 29 +- .../parquet/python/ops/parquet_ops.py | 106 +++--- tests/test_parquet.py | 155 --------- tests/test_parquet_eager.py | 94 ++++++ 8 files changed, 490 insertions(+), 229 deletions(-) create mode 100644 tensorflow_io/parquet/kernels/parquet_kernels.cc delete mode 100644 tests/test_parquet.py create mode 100644 tests/test_parquet_eager.py diff --git a/tensorflow_io/core/BUILD b/tensorflow_io/core/BUILD index 07bdc0b36..06a8205e2 100644 --- a/tensorflow_io/core/BUILD +++ b/tensorflow_io/core/BUILD @@ -135,6 +135,7 @@ cc_binary( "//tensorflow_io/json:json_ops", "//tensorflow_io/lmdb:lmdb_ops", "//tensorflow_io/mnist:mnist_ops", + "//tensorflow_io/parquet:parquet_ops", "//tensorflow_io/prometheus:prometheus_ops", "//tensorflow_io/text:text_ops", "@libarchive", diff --git a/tensorflow_io/parquet/BUILD b/tensorflow_io/parquet/BUILD index 5f74f1547..f6768043f 100644 --- a/tensorflow_io/parquet/BUILD +++ b/tensorflow_io/parquet/BUILD @@ -7,18 +7,16 @@ load( "tf_io_copts", ) -cc_binary( - name = "python/ops/_parquet_ops.so", +cc_library( + name = "parquet_ops", srcs = [ - "kernels/parquet_input.cc", + "kernels/parquet_kernels.cc", "ops/parquet_ops.cc", ], copts = tf_io_copts(), - linkshared = 1, + linkstatic = True, deps = [ "//tensorflow_io/core:dataset_ops", "@arrow", - "@local_config_tf//:libtensorflow_framework", - "@local_config_tf//:tf_header_lib", ], ) diff --git a/tensorflow_io/parquet/__init__.py b/tensorflow_io/parquet/__init__.py index e3bb3bba5..6e66d7cf4 100644 --- a/tensorflow_io/parquet/__init__.py +++ b/tensorflow_io/parquet/__init__.py @@ -15,6 +15,8 @@ """Parquet Dataset. @@ParquetDataset +@@read_parquet +@@read_parquet_specs """ from __future__ import absolute_import @@ -22,11 +24,15 @@ from __future__ import print_function from tensorflow_io.parquet.python.ops.parquet_ops import ParquetDataset +from tensorflow_io.parquet.python.ops.parquet_ops import read_parquet +from tensorflow_io.parquet.python.ops.parquet_ops import read_parquet_specs from tensorflow.python.util.all_util import remove_undocumented _allowed_symbols = [ "ParquetDataset", + "read_parquet", + "read_parquet_specs", ] remove_undocumented(__name__, allowed_exception_list=_allowed_symbols) diff --git a/tensorflow_io/parquet/kernels/parquet_kernels.cc b/tensorflow_io/parquet/kernels/parquet_kernels.cc new file mode 100644 index 000000000..00fc4624d --- /dev/null +++ b/tensorflow_io/parquet/kernels/parquet_kernels.cc @@ -0,0 +1,318 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#define EIGEN_USE_THREADS + +#include "kernels/dataset_ops.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "parquet/api/reader.h" + +namespace tensorflow { +namespace data { +namespace { + +class ParquetRandomAccessFile : public ::arrow::io::RandomAccessFile { +public: + explicit ParquetRandomAccessFile(tensorflow::RandomAccessFile *file, int64 size) + : file_(file) + , size_(size) { } + + ~ParquetRandomAccessFile() {} + arrow::Status Close() override { + return arrow::Status::OK(); + } + arrow::Status Tell(int64_t* position) const override { + return arrow::Status::NotImplemented("Tell"); + } + arrow::Status Seek(int64_t position) override { + return arrow::Status::NotImplemented("Seek"); + } + arrow::Status Read(int64_t nbytes, int64_t* bytes_read, void* out) override { + return arrow::Status::NotImplemented("Read (void*)"); + } + arrow::Status Read(int64_t nbytes, std::shared_ptr* out) override { + return arrow::Status::NotImplemented("Read (Buffer*)"); + } + arrow::Status GetSize(int64_t* size) override { + *size = size_; + return arrow::Status::OK(); + } + bool supports_zero_copy() const override { + return false; + } + arrow::Status ReadAt(int64_t position, int64_t nbytes, int64_t* bytes_read, void* out) override { + StringPiece result; + Status status = file_->Read(position, nbytes, &result, (char*)out); + if (!(status.ok() || errors::IsOutOfRange(status))) { + return arrow::Status::IOError(status.error_message()); + } + *bytes_read = result.size(); + return arrow::Status::OK(); + } + arrow::Status ReadAt(int64_t position, int64_t nbytes, std::shared_ptr* out) override { + string buffer; + buffer.resize(nbytes); + StringPiece result; + Status status = file_->Read(position, nbytes, &result, (char*)(&buffer[0])); + if (!(status.ok() || errors::IsOutOfRange(status))) { + return arrow::Status::IOError(status.error_message()); + } + buffer.resize(result.size()); + return arrow::Buffer::FromString(buffer, out); + } +private: + tensorflow::RandomAccessFile* file_; + int64 size_; +}; + +class ReadParquetSpecsOp : public OpKernel { + public: + explicit ReadParquetSpecsOp(OpKernelConstruction* context) : OpKernel(context) { + env_ = context->env(); + } + + void Compute(OpKernelContext* context) override { + const Tensor& filename_tensor = context->input(0); + const string filename = filename_tensor.scalar()(); + + std::unique_ptr file; + OP_REQUIRES_OK(context, env_->NewRandomAccessFile(filename, &file)); + uint64 size = 0; + OP_REQUIRES_OK(context, env_->GetFileSize(filename, &size)); + + std::shared_ptr parquet_file(new ParquetRandomAccessFile(file.get(), size)); + std::shared_ptr<::parquet::FileMetaData> metadata = ::parquet::ReadMetaData(parquet_file); + + std::vector columns; + std::vector dtypes; + std::vector counts; + columns.reserve(metadata->num_columns()); + dtypes.reserve(metadata->num_columns()); + counts.reserve(metadata->num_columns()); + for (int i = 0; i < metadata->num_columns(); i++) { + string dtype = ""; + switch(metadata->schema()->Column(i)->physical_type()) { + case parquet::Type::BOOLEAN: + dtype = "bool"; + break; + case parquet::Type::INT32: + dtype = "int32"; + break; + case parquet::Type::INT64: + dtype = "int64"; + break; + case parquet::Type::FLOAT: + dtype = "float"; + break; + case parquet::Type::DOUBLE: + dtype = "double"; + break; + default: + // Unsupported data type INT96, BYTE_ARRAY, FIXED_LEN_BYTE_ARRAY + break; + } + if (dtype == "") { + continue; + } + 
columns.push_back(metadata->schema()->Column(i)->path().get()->ToDotString()); + dtypes.push_back(dtype); + counts.push_back(metadata->num_rows()); + } + + TensorShape output_shape = filename_tensor.shape(); + output_shape.AddDim(columns.size()); + + Tensor* columns_tensor; + OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &columns_tensor)); + Tensor* dtypes_tensor; + OP_REQUIRES_OK(context, context->allocate_output(1, output_shape, &dtypes_tensor)); + + output_shape.AddDim(1); + + Tensor* shapes_tensor; + OP_REQUIRES_OK(context, context->allocate_output(2, output_shape, &shapes_tensor)); + + for (int i = 0; i < columns.size(); i++) { + columns_tensor->flat()(i) = columns[i]; + dtypes_tensor->flat()(i) = dtypes[i]; + shapes_tensor->flat()(i) = counts[i]; + } + } + private: + mutex mu_; + Env* env_ GUARDED_BY(mu_); +}; + +// Note: This SizedRandomAccessFile should only lives within Compute() +// of the kernel as buffer could be released by outside. +class SizedRandomAccessFile : public tensorflow::RandomAccessFile { + public: + SizedRandomAccessFile(Env* env, const string& filename, const string& optional_memory) + : file_(nullptr) + , size_status_(Status::OK()) + , size_(optional_memory.size()) + , buffer_(optional_memory) { + if (size_ == 0) { + size_status_ = env->GetFileSize(filename, &size_); + if (size_status_.ok()) { + size_status_ = env->NewRandomAccessFile(filename, &file_); + } + } + } + + virtual ~SizedRandomAccessFile() {} + Status Read(uint64 offset, size_t n, StringPiece* result, char* scratch) const override { + if (file_.get() != nullptr) { + return file_.get()->Read(offset, n, result, scratch); + } + size_t bytes_to_read = 0; + if (offset < size_) { + bytes_to_read = (offset + n < size_) ? n : (size_ - offset); + } + if (bytes_to_read > 0) { + memcpy(scratch, buffer_.data(), bytes_to_read); + } + *result = StringPiece(scratch, bytes_to_read); + if (bytes_to_read < n) { + return errors::OutOfRange("EOF reached"); + } + return Status::OK(); + } + Status GetFileSize(uint64* size) { + if (size_status_.ok()) { + *size = size_; + } + return size_status_; + } + private: + std::unique_ptr file_; + Status size_status_; + uint64 size_; + const string& buffer_; +}; + +class ReadParquetOp : public OpKernel { + public: + explicit ReadParquetOp(OpKernelConstruction* context) : OpKernel(context) { + env_ = context->env(); + } + + void Compute(OpKernelContext* context) override { + const Tensor& filename_tensor = context->input(0); + const string& filename = filename_tensor.scalar()(); + + const Tensor& column_tensor = context->input(1); + const string& column = column_tensor.scalar()(); + + const Tensor& start_tensor = context->input(2); + const int64 start = start_tensor.scalar()(); + + const Tensor& count_tensor = context->input(3); + int64 count = count_tensor.scalar()(); + + const Tensor& memory_tensor = context->input(4); + const string& memory = memory_tensor.scalar()(); + + std::unique_ptr file(new SizedRandomAccessFile(env_, filename, memory)); + uint64 size; + OP_REQUIRES_OK(context, file->GetFileSize(&size)); + + std::shared_ptr parquet_file(new ParquetRandomAccessFile(file.get(), size)); + std::unique_ptr<::parquet::ParquetFileReader> parquet_reader = parquet::ParquetFileReader::Open(parquet_file); + std::shared_ptr<::parquet::FileMetaData> file_metadata = parquet_reader->metadata(); + int column_index = 0; + while (column_index < file_metadata->num_columns()) { + if (file_metadata->schema()->Column(column_index)->path().get()->ToDotString() == column) { + 
break; + } + column_index++; + } + OP_REQUIRES(context, (column_index < file_metadata->num_columns()), errors::InvalidArgument("unable to find column: ", column)); + + if (start + count > file_metadata->num_rows()) { + count = file_metadata->num_rows() - start; + } + + TensorShape output_shape({count}); + + Tensor* output_tensor; + OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output_tensor)); + + int64 row_group_offset = 0; + for (int row_group = 0; row_group < file_metadata->num_row_groups(); row_group++) { + std::shared_ptr row_group_reader = parquet_reader->RowGroup(row_group); + // Skip if row group is not within [start..start+count] + if ((row_group_offset + row_group_reader->metadata()->num_rows() < start) || (start + count <= row_group_offset)) { + row_group_offset += row_group_reader->metadata()->num_rows(); + continue; + } + // Find row_to_read range + int64 row_to_read_start = row_group_offset > start ? row_group_offset : start; + int64 row_to_read_final = (row_group_offset + row_group_reader->metadata()->num_rows()) < (start + count) ? (row_group_offset + row_group_reader->metadata()->num_rows()) : (start + count); + int64 row_to_read_count = row_to_read_final - row_to_read_start; + + std::shared_ptr column_reader = row_group_reader->Column(column_index); + + // buffer to fill location is tensor.data()[row_to_read_start - start] + + #define PROCESS_TYPE(ptype, type) \ + { \ + parquet::TypedColumnReader* reader = \ + static_cast*>( \ + column_reader.get()); \ + if (row_to_read_start > row_group_offset) { \ + reader->Skip(row_to_read_start - row_group_offset); \ + } \ + ptype::c_type* value = (ptype::c_type *)(void *)(&(output_tensor->flat().data()[row_to_read_start - start])); \ + int64_t values_read; \ + int64_t levels_read = reader->ReadBatch(row_to_read_count, nullptr, nullptr, value, &values_read); \ + OP_REQUIRES(context, (levels_read == values_read && levels_read == row_to_read_count), errors::InvalidArgument("null value in column: ", column)); \ + } + switch (file_metadata->schema()->Column(column_index)->physical_type()) { + case parquet::Type::BOOLEAN: + PROCESS_TYPE(parquet::BooleanType, bool); + break; + case parquet::Type::INT32: + PROCESS_TYPE(parquet::Int32Type, int32); + break; + case parquet::Type::INT64: + PROCESS_TYPE(parquet::Int64Type, int64); + break; + case parquet::Type::FLOAT: + PROCESS_TYPE(parquet::FloatType, float); + break; + case parquet::Type::DOUBLE: + PROCESS_TYPE(parquet::DoubleType, double); + break; + default: + OP_REQUIRES(context, false, errors::InvalidArgument("invalid data type: ", file_metadata->schema()->Column(column_index)->physical_type())); + } + row_group_offset += row_group_reader->metadata()->num_rows(); + } + } + private: + mutex mu_; + Env* env_ GUARDED_BY(mu_); +}; + +REGISTER_KERNEL_BUILDER(Name("ReadParquetSpecs").Device(DEVICE_CPU), + ReadParquetSpecsOp); +REGISTER_KERNEL_BUILDER(Name("ReadParquet").Device(DEVICE_CPU), + ReadParquetOp); + + +} // namespace +} // namespace data +} // namespace tensorflow diff --git a/tensorflow_io/parquet/ops/parquet_ops.cc b/tensorflow_io/parquet/ops/parquet_ops.cc index e38adb8f1..8d7a81caf 100644 --- a/tensorflow_io/parquet/ops/parquet_ops.cc +++ b/tensorflow_io/parquet/ops/parquet_ops.cc @@ -19,27 +19,26 @@ limitations under the License. 
namespace tensorflow { -REGISTER_OP("ParquetInput") - .Input("source: string") - .Output("handle: variant") - .Attr("filters: list(string) = []") - .Attr("columns: list(string) = []") - .Attr("schema: string = ''") +REGISTER_OP("ReadParquetSpecs") + .Input("filename: string") + .Output("columns: string") + .Output("dtypes: string") + .Output("shapes: int64") .SetShapeFn([](shape_inference::InferenceContext* c) { c->set_output(0, c->MakeShape({c->UnknownDim()})); return Status::OK(); }); -REGISTER_OP("ParquetDataset") - .Input("input: T") - .Input("batch: int64") - .Output("handle: variant") - .Attr("output_types: list(type) >= 1") - .Attr("output_shapes: list(shape) >= 1") - .Attr("T: {string, variant} = DT_VARIANT") - .SetIsStateful() +REGISTER_OP("ReadParquet") + .Input("filename: string") + .Input("column: string") + .Input("start: int64") + .Input("count: int64") + .Input("memory: string") + .Attr("dtype: type") + .Output("output: dtype") .SetShapeFn([](shape_inference::InferenceContext* c) { - c->set_output(0, c->MakeShape({})); + c->set_output(0, c->MakeShape({c->UnknownDim()})); return Status::OK(); }); diff --git a/tensorflow_io/parquet/python/ops/parquet_ops.py b/tensorflow_io/parquet/python/ops/parquet_ops.py index 2f1e575ad..76acd8b69 100644 --- a/tensorflow_io/parquet/python/ops/parquet_ops.py +++ b/tensorflow_io/parquet/python/ops/parquet_ops.py @@ -18,68 +18,68 @@ from __future__ import print_function import tensorflow as tf -from tensorflow.compat.v1 import data -from tensorflow_io import _load_library -parquet_ops = _load_library('_parquet_ops.so') +from tensorflow_io.core.python.ops import core_ops as parquet_ops +from tensorflow_io.core.python.ops import data_ops -class ParquetDataset(data.Dataset): +def read_parquet_specs(filename): + """read_parquet_specs""" + if not tf.executing_eagerly(): + raise NotImplementedError("read_parquet_spect only support eager mode") + columns, dtypes, shapes = parquet_ops.read_parquet_specs(filename) + entries = zip(tf.unstack(columns), tf.unstack(dtypes), tf.unstack(shapes)) + return dict([(column.numpy(), tf.TensorSpec( + shape.numpy(), dtype.numpy(), column.numpy())) for ( + column, dtype, shape) in entries]) + +def read_parquet(filename, spec, start=0, **kwargs): + """read_parquet""" + memory = kwargs.get("memory", "") + return parquet_ops.read_parquet( + filename, spec.name, + start=start, count=spec.shape[0] - start, dtype=spec.dtype, + memory=memory) + +class ParquetDataset(data_ops.BaseDataset): """A Parquet Dataset that reads the parquet file.""" - def __init__(self, filename, columns, dtypes=None, batch=None): + def __init__(self, filename, column, batch=None, **kwargs): """Create a `ParquetDataset`. `ParquetDataset` allows a user to read data from a parquet file. - For example: - - ```python - dataset = tf.contrib.parquet.ParquetDataset( - "/foo/bar.parquet", [0, 1], (tf.bool, tf.int32)) - iterator = dataset.make_one_shot_iterator() - next_element = iterator.get_next() - # Prints the rows of the result set of the column [0, 1]. - while True: - try: - print(sess.run(next_element)) - except tf.errors.OutOfRangeError: - break - ``` Args: - filename: A 0-D or 1-D `tf.string` tensor containing one or more - filenames. - columns: A 0-D or 1-D `tf.int32` tensor containing the columns to extract. - dtypes: A tuple of `tf.DType` objects representing the types of the - columns returned. + filename: filename of the parquet file to read. + column: column name to read. 
""" - self._data_input = parquet_ops.parquet_input( - filename, ["none", "gz"], columns=columns) - self._columns = columns - self._dtypes = dtypes - self._batch = 0 if batch is None else batch - super(ParquetDataset, self).__init__() - - def _inputs(self): - return [] - - def _as_variant_tensor(self): - return parquet_ops.parquet_dataset( - self._data_input, - self._batch, - output_types=self.output_types, - output_shapes=self.output_shapes) + # Note: count and dtype could be in kwargs if in graph mode. + if not tf.executing_eagerly(): + count = kwargs.get("count") + dtype = kwargs.get("dtype") + else: + specs = read_parquet_specs(filename) + count = specs[column].shape[0] + dtype = specs[column].dtype - @property - def output_classes(self): - return tuple([tf.Tensor for _ in self._columns]) + batch = 0 if batch is None else batch + shape = tf.TensorShape([]) if ( + batch is None or batch == 0) else tf.TensorShape([None]) - @property - def output_shapes(self): - return tuple( - [tf.TensorShape([]) for _ in self._columns] - ) if self._batch is None else tuple( - [tf.TensorShape([None]) for _ in self._columns] - ) + # capacity is the rough count for each chunk in dataset + # not directly related to batch, will be padded to batch though + capacity = kwargs.get("capacity", 65536) + if batch is not None and batch != 0 and capacity > batch: + capacity = (capacity // batch) * batch + entry_start = range(0, count, capacity) + entry_count = [min(capacity, count - start) for start in entry_start] + dataset = data_ops.BaseDataset.from_tensor_slices( + (tf.constant(entry_start, tf.int64), tf.constant(entry_count, tf.int64)) + ).map(lambda start, count: parquet_ops.read_parquet( + filename, column, start, count, dtype=dtype, memory="")) + if batch is None or batch == 0: + self._dataset = dataset.unbatch() + else: + # TODO: convert to rebatch for performance + self._dataset = dataset.unbatch().batch(batch) - @property - def output_types(self): - return self._dtypes + super(ParquetDataset, self).__init__( + self._dataset._variant_tensor, [dtype], [shape]) # pylint: disable=protected-access diff --git a/tests/test_parquet.py b/tests/test_parquet.py deleted file mode 100644 index 20af0ff20..000000000 --- a/tests/test_parquet.py +++ /dev/null @@ -1,155 +0,0 @@ -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); you may not -# use this file except in compliance with the License. You may obtain a copy of -# the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -# License for the specific language governing permissions and limitations under -# the License. 
-# ============================================================================== -"""Tests for ParquetDataset.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import os - -import tensorflow as tf -tf.compat.v1.disable_eager_execution() - -from tensorflow import dtypes # pylint: disable=wrong-import-position -from tensorflow import errors # pylint: disable=wrong-import-position -from tensorflow import test # pylint: disable=wrong-import-position -from tensorflow.compat.v1 import data # pylint: disable=wrong-import-position - -import tensorflow_io.parquet as parquet_io # pylint: disable=wrong-import-position - -class ParquetDatasetTest(test.TestCase): - """ParquetDatasetTest""" - def test_parquet_dataset(self): - """Test case for ParquetDataset. - - Note: The sample file is generated from: - `parquet-cpp/examples/low-level-api/reader_writer` - This test extracts columns of [0, 1, 2, 4, 5] - with column data types of [bool, int32, int64, float, double]. - Please check `parquet-cpp/examples/low-level-api/reader-writer.cc` - to find details of how records are generated: - Column 0 (bool): True for even rows and False otherwise. - Column 1 (int32): Equal to row_index. - Column 2 (int64): Equal to row_index * 1000 * 1000 * 1000 * 1000. - Column 4 (float): Equal to row_index * 1.1. - Column 5 (double): Equal to row_index * 1.1111111. - """ - filename = os.path.join( - os.path.dirname(os.path.abspath(__file__)), - "test_parquet", - "parquet_cpp_example.parquet") - filename = "file://" + filename - columns = [ - 'boolean_field', - 'int32_field', - 'int64_field', - 'float_field', - 'double_field'] - output_types = ( - dtypes.bool, dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float64) - num_repeats = 2 - - dataset = parquet_io.ParquetDataset( - [filename], columns, output_types).repeat(num_repeats) - iterator = data.make_initializable_iterator(dataset) - init_op = iterator.initializer - get_next = iterator.get_next() - - with self.test_session() as sess: - sess.run(init_op) - for _ in range(num_repeats): # Dataset is repeated. - for i in range(500): # 500 rows. 
- v0 = ((i % 2) == 0) - v1 = i - v2 = i * 1000 * 1000 * 1000 * 1000 - v4 = 1.1 * i - v5 = 1.1111111 * i - vv = sess.run(get_next) - self.assertAllClose((v0, v1, v2, v4, v5), vv) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - dataset = parquet_io.ParquetDataset( - [filename], columns, output_types, batch=1) - iterator = data.make_initializable_iterator(dataset) - init_op = iterator.initializer - get_next = iterator.get_next() - - with self.test_session() as sess: - sess.run(init_op) - for i in range(500): - v0 = ((i % 2) == 0) - v1 = i - v2 = i * 1000 * 1000 * 1000 * 1000 - v4 = 1.1 * i - v5 = 1.1111111 * i - vv = sess.run(get_next) - self.assertAllClose(([v0], [v1], [v2], [v4], [v5]), vv) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - dataset = parquet_io.ParquetDataset( - [filename, filename], columns, output_types, batch=3) - iterator = data.make_initializable_iterator(dataset) - init_op = iterator.initializer - get_next = iterator.get_next() - - with self.test_session() as sess: - sess.run(init_op) - for ii in range(0, 999, 3): - v0, v1, v2, v4, v5 = [], [], [], [], [] - for i in [ii % 500, (ii + 1) % 500, (ii + 2) % 500]: - v0.append((i % 2) == 0) - v1.append(i) - v2.append(i * 1000 * 1000 * 1000 * 1000) - v4.append(1.1 * i) - v5.append(1.1111111 * i) - vv = sess.run(get_next) - self.assertAllClose((v0, v1, v2, v4, v5), vv) - i = 999 % 500 - v0 = ((i % 2) == 0) - v1 = i - v2 = i * 1000 * 1000 * 1000 * 1000 - v4 = 1.1 * i - v5 = 1.1111111 * i - vv = sess.run(get_next) - self.assertAllClose(([v0], [v1], [v2], [v4], [v5]), vv) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - # With compression - filename = filename + '.gz' - dataset = parquet_io.ParquetDataset( - [filename], columns, output_types).repeat(num_repeats) - iterator = data.make_initializable_iterator(dataset) - init_op = iterator.initializer - get_next = iterator.get_next() - - with self.test_session() as sess: - sess.run(init_op) - for _ in range(num_repeats): # Dataset is repeated. - for i in range(500): # 500 rows. - v0 = ((i % 2) == 0) - v1 = i - v2 = i * 1000 * 1000 * 1000 * 1000 - v4 = 1.1 * i - v5 = 1.1111111 * i - vv = sess.run(get_next) - self.assertAllClose((v0, v1, v2, v4, v5), vv) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) -if __name__ == "__main__": - test.main() diff --git a/tests/test_parquet_eager.py b/tests/test_parquet_eager.py new file mode 100644 index 000000000..4dd67ed7f --- /dev/null +++ b/tests/test_parquet_eager.py @@ -0,0 +1,94 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not +# use this file except in compliance with the License. You may obtain a copy of +# the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations under +# the License. 
+# ============================================================================== +"""Tests for read_parquet and ParquetDataset.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import numpy as np + +import tensorflow as tf +if not (hasattr(tf, "version") and tf.version.VERSION.startswith("2.")): + tf.compat.v1.enable_eager_execution() +import tensorflow_io.parquet as parquet_io # pylint: disable=wrong-import-position + +# Note: The sample file is generated from: +# `parquet-cpp/examples/low-level-api/reader_writer` +# This test extracts columns of [0, 1, 2, 4, 5] +# with column data types of [bool, int32, int64, float, double]. +# Please check `parquet-cpp/examples/low-level-api/reader-writer.cc` +# to find details of how records are generated: +# Column 0 (bool): True for even rows and False otherwise. +# Column 1 (int32): Equal to row_index. +# Column 2 (int64): Equal to row_index * 1000 * 1000 * 1000 * 1000. +# Column 4 (float): Equal to row_index * 1.1. +# Column 5 (double): Equal to row_index * 1.1111111. +def test_parquet(): + """Test case for read_parquet. + + """ + filename = os.path.join( + os.path.dirname(os.path.abspath(__file__)), + "test_parquet", + "parquet_cpp_example.parquet") + filename = "file://" + filename + + specs = parquet_io.read_parquet_specs(filename) + columns = [ + 'boolean_field', + 'int32_field', + 'int64_field', + 'float_field', + 'double_field'] + p0 = parquet_io.read_parquet(filename, specs['boolean_field']) + p1 = parquet_io.read_parquet(filename, specs['int32_field']) + p2 = parquet_io.read_parquet(filename, specs['int64_field']) + p4 = parquet_io.read_parquet(filename, specs['float_field']) + p5 = parquet_io.read_parquet(filename, specs['double_field']) + + for i in range(500): # 500 rows. 
+ v0 = ((i % 2) == 0) + v1 = i + v2 = i * 1000 * 1000 * 1000 * 1000 + v4 = 1.1 * i + v5 = 1.1111111 * i + assert v0 == p0[i].numpy() + assert v1 == p1[i].numpy() + assert v2 == p2[i].numpy() + assert np.isclose(v4, p4[i].numpy()) + assert np.isclose(v5, p5[i].numpy()) + + dataset = tf.compat.v2.data.Dataset.zip( + tuple( + [parquet_io.ParquetDataset(filename, column) for column in columns])) + i = 0 + for p in dataset: + v0 = ((i % 2) == 0) + v1 = i + v2 = i * 1000 * 1000 * 1000 * 1000 + v4 = 1.1 * i + v5 = 1.1111111 * i + p0, p1, p2, p4, p5 = p + assert v0 == p0.numpy() + assert v1 == p1.numpy() + assert v2 == p2.numpy() + assert np.isclose(v4, p4.numpy()) + assert np.isclose(v5, p5.numpy()) + i += 1 + +if __name__ == "__main__": + test.main() From 2b9980d54ad3ee7b339539caf4aba37a13e52865 Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Sun, 28 Jul 2019 05:49:06 +0000 Subject: [PATCH 2/4] Fix build failures Signed-off-by: Yong Tang --- tensorflow_io/parquet/__init__.py | 6 +- .../parquet/kernels/parquet_input.cc | 315 ------------------ .../parquet/kernels/parquet_kernels.cc | 114 +++---- tensorflow_io/parquet/ops/parquet_ops.cc | 3 +- .../parquet/python/ops/parquet_ops.py | 28 +- tests/test_parquet_eager.py | 2 +- 6 files changed, 79 insertions(+), 389 deletions(-) delete mode 100644 tensorflow_io/parquet/kernels/parquet_input.cc diff --git a/tensorflow_io/parquet/__init__.py b/tensorflow_io/parquet/__init__.py index 6e66d7cf4..e64ed8b83 100644 --- a/tensorflow_io/parquet/__init__.py +++ b/tensorflow_io/parquet/__init__.py @@ -16,7 +16,7 @@ @@ParquetDataset @@read_parquet -@@read_parquet_specs +@@read_parquet_columns """ from __future__ import absolute_import @@ -25,14 +25,14 @@ from tensorflow_io.parquet.python.ops.parquet_ops import ParquetDataset from tensorflow_io.parquet.python.ops.parquet_ops import read_parquet -from tensorflow_io.parquet.python.ops.parquet_ops import read_parquet_specs +from tensorflow_io.parquet.python.ops.parquet_ops import read_parquet_columns from tensorflow.python.util.all_util import remove_undocumented _allowed_symbols = [ "ParquetDataset", "read_parquet", - "read_parquet_specs", + "read_parquet_columns", ] remove_undocumented(__name__, allowed_exception_list=_allowed_symbols) diff --git a/tensorflow_io/parquet/kernels/parquet_input.cc b/tensorflow_io/parquet/kernels/parquet_input.cc deleted file mode 100644 index e1808d559..000000000 --- a/tensorflow_io/parquet/kernels/parquet_input.cc +++ /dev/null @@ -1,315 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "kernels/dataset_ops.h" -#include "tensorflow/core/lib/io/buffered_inputstream.h" -#include "parquet/api/reader.h" - -namespace tensorflow { -namespace data { - -class ParquetRandomAccessFile : public ::arrow::io::RandomAccessFile { -public: - explicit ParquetRandomAccessFile(io::InputStreamInterface* s) - : input_stream_(nullptr) - , buffered_stream_(nullptr) { - input_stream_ = dynamic_cast(s); - if (input_stream_ == nullptr) { - buffered_stream_.reset(new SizedRandomAccessBufferedStream(s)); - input_stream_ = buffered_stream_.get(); - } - } - ~ParquetRandomAccessFile() {} - arrow::Status Close() override { - return arrow::Status::OK(); - } - arrow::Status Tell(int64_t* position) const override { - return arrow::Status::NotImplemented("Tell"); - } - arrow::Status Seek(int64_t position) override { - return arrow::Status::NotImplemented("Seek"); - } - arrow::Status Read(int64_t nbytes, int64_t* bytes_read, void* out) override { - return arrow::Status::NotImplemented("Read (void*)"); - } - arrow::Status Read(int64_t nbytes, std::shared_ptr* out) override { - return arrow::Status::NotImplemented("Read (Buffer*)"); - } - arrow::Status GetSize(int64_t* size) override { - uint64 size_value = 0; - Status status = input_stream_->GetFileSize(&size_value); - if (!status.ok()) { - return arrow::Status::IOError(status.error_message()); - } - *size = size_value; - return arrow::Status::OK(); - } - bool supports_zero_copy() const override { - return false; - } - arrow::Status ReadAt(int64_t position, int64_t nbytes, int64_t* bytes_read, void* out) override { - StringPiece result; - Status status = input_stream_->Read(position, nbytes, &result, (char *)out); - if (!(status.ok() || errors::IsOutOfRange(status))) { - return arrow::Status::IOError(status.error_message()); - } - *bytes_read = result.size(); - return arrow::Status::OK(); - } - arrow::Status ReadAt(int64_t position, int64_t nbytes, std::shared_ptr* out) override { - string buffer; - buffer.resize(nbytes); - StringPiece result; - Status status = input_stream_->Read(position, nbytes, &result, &buffer[0]); - if (!(status.ok() || errors::IsOutOfRange(status))) { - return arrow::Status::IOError(status.error_message()); - } - buffer.resize(result.size()); - return arrow::Buffer::FromString(buffer, out); - } -private: - SizedRandomAccessInputStreamInterface* input_stream_; - std::unique_ptr buffered_stream_; -}; - -class ParquetInputStream{ -public: - explicit ParquetInputStream(io::InputStreamInterface* s, const std::vector& columns) - : input_stream_(new ParquetRandomAccessFile(s)) - , column_names_(columns) { - } - Status ReadHeader() { - parquet_reader_ = parquet::ParquetFileReader::Open(input_stream_); - file_metadata_ = parquet_reader_->metadata(); - columns_ = std::vector(column_names_.size(), -1); - dtypes_ = std::vector(column_names_.size()); - for (size_t i = 0; i < column_names_.size(); i++) { - for (int j = 0; j < file_metadata_->schema()->num_columns(); j++) { - if (column_names_[i] == file_metadata_->schema()->Column(j)->path().get()->ToDotString()) { - columns_[i] = j; - switch(file_metadata_->schema()->Column(j)->physical_type()) { - case parquet::Type::BOOLEAN: - dtypes_[i] = DT_BOOL; - break; - case parquet::Type::INT32: - dtypes_[i] = DT_INT32; - break; - case parquet::Type::INT64: - dtypes_[i] = DT_INT64; - break; - case parquet::Type::FLOAT: - dtypes_[i] = DT_FLOAT; - break; - case parquet::Type::DOUBLE: - dtypes_[i] = DT_DOUBLE; - 
break; - default: - return errors::InvalidArgument("data type is not supported for column ", column_names_[i]); - } - break; - } - } - if (columns_[i] < 0) { - return errors::InvalidArgument("unable to find column ", column_names_[i]); - } - } - current_row_group_ = 0; - TF_RETURN_IF_ERROR(ReadRowGroup()); - return Status::OK(); - } - DataType DType(int64 i) { - return dtypes_[i]; - } - int64 Columns() { - return (int64)columns_.size(); - } - Status ReadRowGroup() { - if (current_row_group_ < file_metadata_->num_row_groups()) { - row_group_reader_ = parquet_reader_->RowGroup(current_row_group_); - column_readers_.clear(); - for (size_t i = 0; i < columns_.size(); i++) { - int64 column = columns_[i]; - std::shared_ptr column_reader = - row_group_reader_->Column(column); - column_readers_.emplace_back(column_reader); - } - } - current_row_ = 0; - return Status::OK(); - } - ~ParquetInputStream() { - current_row_ = 0; - column_readers_.clear(); - row_group_reader_.reset(); - current_row_group_ = 0; - file_metadata_.reset(); - parquet_reader_.reset(); - } - Status ReadRecord(int64 index, int64 record_to_read, std::vector* out_tensors, int64* record_read) { - while (current_row_group_ < file_metadata_->num_row_groups()) { - if (current_row_ < row_group_reader_->metadata()->num_rows()) { - // Read columns to outputs. - // TODO: Read more than one value at a time. - for (size_t i = 0; i < columns_.size(); i++) { - DataType dtype = dtypes_[i]; - std::shared_ptr column_reader = column_readers_[i]; - TF_RETURN_IF_ERROR(GetTensorValue(current_row_, dtype, column_reader.get(), &(*out_tensors)[i], index)); - } - ++current_row_; - *record_read = 1; - return Status::OK(); - } - // We have reached the end of the current row group, so maybe - // move on to next row group. - current_row_ = 0; - row_group_reader_.reset(); - ++current_row_group_; - TF_RETURN_IF_ERROR(ReadRowGroup()); - } - return Status::OK(); - } -private: - template - Status FillTensorValue(parquet::ColumnReader* column_reader, - typename DType::c_type* value) { - parquet::TypedColumnReader* reader = - static_cast*>(column_reader); - // Read one value at a time. The number of rows read is returned. 
- // values_read contains the number of non-null rows - int64_t values_read = 0; - int64_t rows_read = reader->ReadBatch(1, nullptr, nullptr, value, &values_read); - // Ensure only one value is read and there are no NULL values in the - // rows read - if (rows_read != 1) { - return errors::Internal("rows_read (", rows_read, ") != 1 or values_read (", values_read, ") != 1"); - } - return Status::OK(); - } - Status GetTensorValue(int64 row, const DataType& data_type, parquet::ColumnReader* column_reader, Tensor* tensor, int64 index) { - switch (data_type) { - case DT_INT32: { - parquet::TypedColumnReader* reader = - static_cast*>( - column_reader); - int32_t value; - TF_RETURN_IF_ERROR( - FillTensorValue(reader, &value)); - tensor->flat()(index) = value; - } break; - case DT_INT64: { - parquet::TypedColumnReader* reader = - static_cast*>( - column_reader); - int64_t value; - TF_RETURN_IF_ERROR( - FillTensorValue(reader, &value)); - tensor->flat()(index) = value; - } break; - case DT_FLOAT: { - parquet::TypedColumnReader* reader = - static_cast*>( - column_reader); - float value; - TF_RETURN_IF_ERROR( - FillTensorValue(reader, &value)); - tensor->flat()(index) = value; - } break; - case DT_DOUBLE: { - parquet::TypedColumnReader* reader = - static_cast*>( - column_reader); - double value; - TF_RETURN_IF_ERROR( - FillTensorValue(reader, &value)); - tensor->flat()(index) = value; - } break; - case DT_BOOL: { - parquet::TypedColumnReader* reader = - static_cast*>( - column_reader); - bool value; - TF_RETURN_IF_ERROR( - FillTensorValue(reader, &value)); - tensor->flat()(index) = value; - } break; - default: - return errors::Unimplemented( - DataTypeString(data_type), - " is currently not supported in ParquetDataset"); - } - return Status::OK(); - } - std::shared_ptr<::arrow::io::RandomAccessFile> input_stream_; - std::vector column_names_; - std::vector columns_; - std::vector dtypes_; - std::unique_ptr parquet_reader_; - std::shared_ptr file_metadata_; - int64 current_row_group_ = 0; - std::shared_ptr row_group_reader_; - std::vector> column_readers_; - int64 current_row_ = 0; -}; - -class ParquetInput: public FileInput { - public: - Status ReadRecord(io::InputStreamInterface* s, IteratorContext* ctx, std::unique_ptr& state, int64 record_to_read, int64* record_read, std::vector* out_tensors) const override { - if (state.get() == nullptr) { - state.reset(new ParquetInputStream(s, columns())); - TF_RETURN_IF_ERROR(state.get()->ReadHeader()); - } - // Let's allocate enough space for Tensor, if more than read, replace. 
- for (int64 i = 0; i < state.get()->Columns(); i++) { - Tensor tensor(ctx->allocator({}), state.get()->DType(i), {record_to_read}); - out_tensors->emplace_back(std::move(tensor)); - } - while ((*record_read) < record_to_read) { - int64 count = 0; - TF_RETURN_IF_ERROR(state.get()->ReadRecord((*record_read), record_to_read - (*record_read), out_tensors, &count)); - (*record_read) += count; - if (count == 0) { - break; - } - } - if (*record_read < record_to_read) { - if (*record_read == 0) { - out_tensors->clear(); - } - for (size_t i = 0; i < out_tensors->size(); i++) { - Tensor tensor = (*out_tensors)[i].Slice(0, *record_read); - (*out_tensors)[i] = std::move(tensor); - } - } - return Status::OK(); - } - Status FromStream(io::InputStreamInterface* s) override { - return Status::OK(); - } - void EncodeAttributes(VariantTensorData* data) const override { - } - bool DecodeAttributes(const VariantTensorData& data) override { - return true; - } - protected: -}; - -REGISTER_UNARY_VARIANT_DECODE_FUNCTION(ParquetInput, "tensorflow::data::ParquetInput"); - -REGISTER_KERNEL_BUILDER(Name("ParquetInput").Device(DEVICE_CPU), - FileInputOp); -REGISTER_KERNEL_BUILDER(Name("ParquetDataset").Device(DEVICE_CPU), - FileInputDatasetOp); -} // namespace data -} // namespace tensorflow diff --git a/tensorflow_io/parquet/kernels/parquet_kernels.cc b/tensorflow_io/parquet/kernels/parquet_kernels.cc index 00fc4624d..13c85ac86 100644 --- a/tensorflow_io/parquet/kernels/parquet_kernels.cc +++ b/tensorflow_io/parquet/kernels/parquet_kernels.cc @@ -23,6 +23,54 @@ namespace tensorflow { namespace data { namespace { +// Note: This SizedRandomAccessFile should only lives within Compute() +// of the kernel as buffer could be released by outside. +class SizedRandomAccessFile : public tensorflow::RandomAccessFile { + public: + SizedRandomAccessFile(Env* env, const string& filename, const string& optional_memory) + : file_(nullptr) + , size_status_(Status::OK()) + , size_(optional_memory.size()) + , buffer_(optional_memory) { + if (size_ == 0) { + size_status_ = env->GetFileSize(filename, &size_); + if (size_status_.ok()) { + size_status_ = env->NewRandomAccessFile(filename, &file_); + } + } + } + + virtual ~SizedRandomAccessFile() {} + Status Read(uint64 offset, size_t n, StringPiece* result, char* scratch) const override { + if (file_.get() != nullptr) { + return file_.get()->Read(offset, n, result, scratch); + } + size_t bytes_to_read = 0; + if (offset < size_) { + bytes_to_read = (offset + n < size_) ? 
n : (size_ - offset); + } + if (bytes_to_read > 0) { + memcpy(scratch, buffer_.data(), bytes_to_read); + } + *result = StringPiece(scratch, bytes_to_read); + if (bytes_to_read < n) { + return errors::OutOfRange("EOF reached"); + } + return Status::OK(); + } + Status GetFileSize(uint64* size) { + if (size_status_.ok()) { + *size = size_; + } + return size_status_; + } + private: + std::unique_ptr file_; + Status size_status_; + uint64 size_; + const string& buffer_; +}; + class ParquetRandomAccessFile : public ::arrow::io::RandomAccessFile { public: explicit ParquetRandomAccessFile(tensorflow::RandomAccessFile *file, int64 size) @@ -77,9 +125,9 @@ class ParquetRandomAccessFile : public ::arrow::io::RandomAccessFile { int64 size_; }; -class ReadParquetSpecsOp : public OpKernel { +class ReadParquetColumnsOp : public OpKernel { public: - explicit ReadParquetSpecsOp(OpKernelConstruction* context) : OpKernel(context) { + explicit ReadParquetColumnsOp(OpKernelConstruction* context) : OpKernel(context) { env_ = context->env(); } @@ -87,10 +135,12 @@ class ReadParquetSpecsOp : public OpKernel { const Tensor& filename_tensor = context->input(0); const string filename = filename_tensor.scalar()(); - std::unique_ptr file; - OP_REQUIRES_OK(context, env_->NewRandomAccessFile(filename, &file)); - uint64 size = 0; - OP_REQUIRES_OK(context, env_->GetFileSize(filename, &size)); + const Tensor& memory_tensor = context->input(1); + const string& memory = memory_tensor.scalar()(); + + std::unique_ptr file(new SizedRandomAccessFile(env_, filename, memory)); + uint64 size; + OP_REQUIRES_OK(context, file->GetFileSize(&size)); std::shared_ptr parquet_file(new ParquetRandomAccessFile(file.get(), size)); std::shared_ptr<::parquet::FileMetaData> metadata = ::parquet::ReadMetaData(parquet_file); @@ -155,54 +205,6 @@ class ReadParquetSpecsOp : public OpKernel { Env* env_ GUARDED_BY(mu_); }; -// Note: This SizedRandomAccessFile should only lives within Compute() -// of the kernel as buffer could be released by outside. -class SizedRandomAccessFile : public tensorflow::RandomAccessFile { - public: - SizedRandomAccessFile(Env* env, const string& filename, const string& optional_memory) - : file_(nullptr) - , size_status_(Status::OK()) - , size_(optional_memory.size()) - , buffer_(optional_memory) { - if (size_ == 0) { - size_status_ = env->GetFileSize(filename, &size_); - if (size_status_.ok()) { - size_status_ = env->NewRandomAccessFile(filename, &file_); - } - } - } - - virtual ~SizedRandomAccessFile() {} - Status Read(uint64 offset, size_t n, StringPiece* result, char* scratch) const override { - if (file_.get() != nullptr) { - return file_.get()->Read(offset, n, result, scratch); - } - size_t bytes_to_read = 0; - if (offset < size_) { - bytes_to_read = (offset + n < size_) ? 
n : (size_ - offset); - } - if (bytes_to_read > 0) { - memcpy(scratch, buffer_.data(), bytes_to_read); - } - *result = StringPiece(scratch, bytes_to_read); - if (bytes_to_read < n) { - return errors::OutOfRange("EOF reached"); - } - return Status::OK(); - } - Status GetFileSize(uint64* size) { - if (size_status_.ok()) { - *size = size_; - } - return size_status_; - } - private: - std::unique_ptr file_; - Status size_status_; - uint64 size_; - const string& buffer_; -}; - class ReadParquetOp : public OpKernel { public: explicit ReadParquetOp(OpKernelConstruction* context) : OpKernel(context) { @@ -307,8 +309,8 @@ class ReadParquetOp : public OpKernel { Env* env_ GUARDED_BY(mu_); }; -REGISTER_KERNEL_BUILDER(Name("ReadParquetSpecs").Device(DEVICE_CPU), - ReadParquetSpecsOp); +REGISTER_KERNEL_BUILDER(Name("ReadParquetColumns").Device(DEVICE_CPU), + ReadParquetColumnsOp); REGISTER_KERNEL_BUILDER(Name("ReadParquet").Device(DEVICE_CPU), ReadParquetOp); diff --git a/tensorflow_io/parquet/ops/parquet_ops.cc b/tensorflow_io/parquet/ops/parquet_ops.cc index 8d7a81caf..dd5e58c89 100644 --- a/tensorflow_io/parquet/ops/parquet_ops.cc +++ b/tensorflow_io/parquet/ops/parquet_ops.cc @@ -19,8 +19,9 @@ limitations under the License. namespace tensorflow { -REGISTER_OP("ReadParquetSpecs") +REGISTER_OP("ReadParquetColumns") .Input("filename: string") + .Input("memory: string") .Output("columns: string") .Output("dtypes: string") .Output("shapes: int64") diff --git a/tensorflow_io/parquet/python/ops/parquet_ops.py b/tensorflow_io/parquet/python/ops/parquet_ops.py index 76acd8b69..7d4b07781 100644 --- a/tensorflow_io/parquet/python/ops/parquet_ops.py +++ b/tensorflow_io/parquet/python/ops/parquet_ops.py @@ -21,22 +21,24 @@ from tensorflow_io.core.python.ops import core_ops as parquet_ops from tensorflow_io.core.python.ops import data_ops -def read_parquet_specs(filename): - """read_parquet_specs""" +def read_parquet_columns(filename, **kwargs): + """read_parquet_columns""" if not tf.executing_eagerly(): raise NotImplementedError("read_parquet_spect only support eager mode") - columns, dtypes, shapes = parquet_ops.read_parquet_specs(filename) + memory = kwargs.get("memory", "") + columns, dtypes, shapes = parquet_ops.read_parquet_columns( + filename, memory=memory) entries = zip(tf.unstack(columns), tf.unstack(dtypes), tf.unstack(shapes)) - return dict([(column.numpy(), tf.TensorSpec( - shape.numpy(), dtype.numpy(), column.numpy())) for ( + return dict([(column.numpy().decode(), tf.TensorSpec( + shape.numpy(), dtype.numpy().decode(), column.numpy().decode())) for ( column, dtype, shape) in entries]) -def read_parquet(filename, spec, start=0, **kwargs): +def read_parquet(filename, column, start=0, **kwargs): """read_parquet""" memory = kwargs.get("memory", "") return parquet_ops.read_parquet( - filename, spec.name, - start=start, count=spec.shape[0] - start, dtype=spec.dtype, + filename, column.name, + start=start, count=column.shape[0] - start, dtype=column.dtype, memory=memory) class ParquetDataset(data_ops.BaseDataset): @@ -56,9 +58,9 @@ def __init__(self, filename, column, batch=None, **kwargs): count = kwargs.get("count") dtype = kwargs.get("dtype") else: - specs = read_parquet_specs(filename) - count = specs[column].shape[0] - dtype = specs[column].dtype + columns = read_parquet_columns(filename) + count = columns[column].shape[0] + dtype = columns[column].dtype batch = 0 if batch is None else batch shape = tf.TensorShape([]) if ( @@ -76,10 +78,10 @@ def __init__(self, filename, column, batch=None, 
**kwargs): ).map(lambda start, count: parquet_ops.read_parquet( filename, column, start, count, dtype=dtype, memory="")) if batch is None or batch == 0: - self._dataset = dataset.unbatch() + self._dataset = dataset.apply(tf.data.experimental.unbatch()) else: # TODO: convert to rebatch for performance - self._dataset = dataset.unbatch().batch(batch) + self._dataset = dataset.apply(tf.data.experimental.unbatch()).batch(batch) super(ParquetDataset, self).__init__( self._dataset._variant_tensor, [dtype], [shape]) # pylint: disable=protected-access diff --git a/tests/test_parquet_eager.py b/tests/test_parquet_eager.py index 4dd67ed7f..d49c5c236 100644 --- a/tests/test_parquet_eager.py +++ b/tests/test_parquet_eager.py @@ -47,7 +47,7 @@ def test_parquet(): "parquet_cpp_example.parquet") filename = "file://" + filename - specs = parquet_io.read_parquet_specs(filename) + specs = parquet_io.read_parquet_columns(filename) columns = [ 'boolean_field', 'int32_field', From 59019969a55d36cdf58e18b64ac66a6b3a3e4a7c Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Sun, 28 Jul 2019 16:35:30 +0000 Subject: [PATCH 3/4] Rename read_parquet_columns => list_parquet_columns Signed-off-by: Yong Tang --- tensorflow_io/parquet/__init__.py | 6 +++--- tensorflow_io/parquet/kernels/parquet_kernels.cc | 8 ++++---- tensorflow_io/parquet/ops/parquet_ops.cc | 2 +- tensorflow_io/parquet/python/ops/parquet_ops.py | 8 ++++---- tests/test_parquet_eager.py | 2 +- 5 files changed, 13 insertions(+), 13 deletions(-) diff --git a/tensorflow_io/parquet/__init__.py b/tensorflow_io/parquet/__init__.py index e64ed8b83..4f7ca258d 100644 --- a/tensorflow_io/parquet/__init__.py +++ b/tensorflow_io/parquet/__init__.py @@ -16,7 +16,7 @@ @@ParquetDataset @@read_parquet -@@read_parquet_columns +@@list_parquet_columns """ from __future__ import absolute_import @@ -25,14 +25,14 @@ from tensorflow_io.parquet.python.ops.parquet_ops import ParquetDataset from tensorflow_io.parquet.python.ops.parquet_ops import read_parquet -from tensorflow_io.parquet.python.ops.parquet_ops import read_parquet_columns +from tensorflow_io.parquet.python.ops.parquet_ops import list_parquet_columns from tensorflow.python.util.all_util import remove_undocumented _allowed_symbols = [ "ParquetDataset", "read_parquet", - "read_parquet_columns", + "list_parquet_columns", ] remove_undocumented(__name__, allowed_exception_list=_allowed_symbols) diff --git a/tensorflow_io/parquet/kernels/parquet_kernels.cc b/tensorflow_io/parquet/kernels/parquet_kernels.cc index 13c85ac86..df40937fb 100644 --- a/tensorflow_io/parquet/kernels/parquet_kernels.cc +++ b/tensorflow_io/parquet/kernels/parquet_kernels.cc @@ -125,9 +125,9 @@ class ParquetRandomAccessFile : public ::arrow::io::RandomAccessFile { int64 size_; }; -class ReadParquetColumnsOp : public OpKernel { +class ListParquetColumnsOp : public OpKernel { public: - explicit ReadParquetColumnsOp(OpKernelConstruction* context) : OpKernel(context) { + explicit ListParquetColumnsOp(OpKernelConstruction* context) : OpKernel(context) { env_ = context->env(); } @@ -309,8 +309,8 @@ class ReadParquetOp : public OpKernel { Env* env_ GUARDED_BY(mu_); }; -REGISTER_KERNEL_BUILDER(Name("ReadParquetColumns").Device(DEVICE_CPU), - ReadParquetColumnsOp); +REGISTER_KERNEL_BUILDER(Name("ListParquetColumns").Device(DEVICE_CPU), + ListParquetColumnsOp); REGISTER_KERNEL_BUILDER(Name("ReadParquet").Device(DEVICE_CPU), ReadParquetOp); diff --git a/tensorflow_io/parquet/ops/parquet_ops.cc b/tensorflow_io/parquet/ops/parquet_ops.cc index dd5e58c89..d12b9b7b2 
100644 --- a/tensorflow_io/parquet/ops/parquet_ops.cc +++ b/tensorflow_io/parquet/ops/parquet_ops.cc @@ -19,7 +19,7 @@ limitations under the License. namespace tensorflow { -REGISTER_OP("ReadParquetColumns") +REGISTER_OP("ListParquetColumns") .Input("filename: string") .Input("memory: string") .Output("columns: string") diff --git a/tensorflow_io/parquet/python/ops/parquet_ops.py b/tensorflow_io/parquet/python/ops/parquet_ops.py index 7d4b07781..1d97f3a21 100644 --- a/tensorflow_io/parquet/python/ops/parquet_ops.py +++ b/tensorflow_io/parquet/python/ops/parquet_ops.py @@ -21,12 +21,12 @@ from tensorflow_io.core.python.ops import core_ops as parquet_ops from tensorflow_io.core.python.ops import data_ops -def read_parquet_columns(filename, **kwargs): - """read_parquet_columns""" +def list_parquet_columns(filename, **kwargs): + """list_parquet_columns""" if not tf.executing_eagerly(): raise NotImplementedError("read_parquet_spect only support eager mode") memory = kwargs.get("memory", "") - columns, dtypes, shapes = parquet_ops.read_parquet_columns( + columns, dtypes, shapes = parquet_ops.list_parquet_columns( filename, memory=memory) entries = zip(tf.unstack(columns), tf.unstack(dtypes), tf.unstack(shapes)) return dict([(column.numpy().decode(), tf.TensorSpec( @@ -58,7 +58,7 @@ def __init__(self, filename, column, batch=None, **kwargs): count = kwargs.get("count") dtype = kwargs.get("dtype") else: - columns = read_parquet_columns(filename) + columns = list_parquet_columns(filename) count = columns[column].shape[0] dtype = columns[column].dtype diff --git a/tests/test_parquet_eager.py b/tests/test_parquet_eager.py index d49c5c236..ff3e6363c 100644 --- a/tests/test_parquet_eager.py +++ b/tests/test_parquet_eager.py @@ -47,7 +47,7 @@ def test_parquet(): "parquet_cpp_example.parquet") filename = "file://" + filename - specs = parquet_io.read_parquet_columns(filename) + specs = parquet_io.list_parquet_columns(filename) columns = [ 'boolean_field', 'int32_field', From fe9eb2acd1244dd05bf0caa9fd77b06ecc9da787 Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Wed, 31 Jul 2019 04:51:57 +0000 Subject: [PATCH 4/4] Remove batch args, and add test in graph mode Signed-off-by: Yong Tang --- .../parquet/kernels/parquet_kernels.cc | 86 +++++------------- tensorflow_io/parquet/ops/parquet_ops.cc | 6 +- .../parquet/python/ops/parquet_ops.py | 48 +++++----- tests/test_parquet.py | 89 +++++++++++++++++++ tests/test_parquet_eager.py | 3 +- 5 files changed, 139 insertions(+), 93 deletions(-) create mode 100644 tests/test_parquet.py diff --git a/tensorflow_io/parquet/kernels/parquet_kernels.cc b/tensorflow_io/parquet/kernels/parquet_kernels.cc index df40937fb..e6caf99a8 100644 --- a/tensorflow_io/parquet/kernels/parquet_kernels.cc +++ b/tensorflow_io/parquet/kernels/parquet_kernels.cc @@ -13,64 +13,14 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#define EIGEN_USE_THREADS - -#include "kernels/dataset_ops.h" #include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow_io/core/kernels/stream.h" #include "parquet/api/reader.h" namespace tensorflow { namespace data { namespace { -// Note: This SizedRandomAccessFile should only lives within Compute() -// of the kernel as buffer could be released by outside. 
-class SizedRandomAccessFile : public tensorflow::RandomAccessFile { - public: - SizedRandomAccessFile(Env* env, const string& filename, const string& optional_memory) - : file_(nullptr) - , size_status_(Status::OK()) - , size_(optional_memory.size()) - , buffer_(optional_memory) { - if (size_ == 0) { - size_status_ = env->GetFileSize(filename, &size_); - if (size_status_.ok()) { - size_status_ = env->NewRandomAccessFile(filename, &file_); - } - } - } - - virtual ~SizedRandomAccessFile() {} - Status Read(uint64 offset, size_t n, StringPiece* result, char* scratch) const override { - if (file_.get() != nullptr) { - return file_.get()->Read(offset, n, result, scratch); - } - size_t bytes_to_read = 0; - if (offset < size_) { - bytes_to_read = (offset + n < size_) ? n : (size_ - offset); - } - if (bytes_to_read > 0) { - memcpy(scratch, buffer_.data(), bytes_to_read); - } - *result = StringPiece(scratch, bytes_to_read); - if (bytes_to_read < n) { - return errors::OutOfRange("EOF reached"); - } - return Status::OK(); - } - Status GetFileSize(uint64* size) { - if (size_status_.ok()) { - *size = size_; - } - return size_status_; - } - private: - std::unique_ptr file_; - Status size_status_; - uint64 size_; - const string& buffer_; -}; - class ParquetRandomAccessFile : public ::arrow::io::RandomAccessFile { public: explicit ParquetRandomAccessFile(tensorflow::RandomAccessFile *file, int64 size) @@ -138,7 +88,7 @@ class ListParquetColumnsOp : public OpKernel { const Tensor& memory_tensor = context->input(1); const string& memory = memory_tensor.scalar()(); - std::unique_ptr file(new SizedRandomAccessFile(env_, filename, memory)); + std::unique_ptr file(new SizedRandomAccessFile(env_, filename, memory.data(), memory.size())); uint64 size; OP_REQUIRES_OK(context, file->GetFileSize(&size)); @@ -218,16 +168,16 @@ class ReadParquetOp : public OpKernel { const Tensor& column_tensor = context->input(1); const string& column = column_tensor.scalar()(); - const Tensor& start_tensor = context->input(2); - const int64 start = start_tensor.scalar()(); + const Tensor& memory_tensor = context->input(2); + const string& memory = memory_tensor.scalar()(); - const Tensor& count_tensor = context->input(3); - int64 count = count_tensor.scalar()(); + const Tensor& start_tensor = context->input(3); + int64 start = start_tensor.scalar()(); - const Tensor& memory_tensor = context->input(4); - const string& memory = memory_tensor.scalar()(); + const Tensor& stop_tensor = context->input(4); + int64 stop = stop_tensor.scalar()(); - std::unique_ptr file(new SizedRandomAccessFile(env_, filename, memory)); + std::unique_ptr file(new SizedRandomAccessFile(env_, filename, memory.data(), memory.size())); uint64 size; OP_REQUIRES_OK(context, file->GetFileSize(&size)); @@ -243,11 +193,17 @@ class ReadParquetOp : public OpKernel { } OP_REQUIRES(context, (column_index < file_metadata->num_columns()), errors::InvalidArgument("unable to find column: ", column)); - if (start + count > file_metadata->num_rows()) { - count = file_metadata->num_rows() - start; + if (start > file_metadata->num_rows()) { + start = file_metadata->num_rows(); + } + if (stop < 0) { + stop = file_metadata->num_rows(); + } + if (stop > file_metadata->num_rows()) { + stop = file_metadata->num_rows(); } - TensorShape output_shape({count}); + TensorShape output_shape({stop - start}); Tensor* output_tensor; OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output_tensor)); @@ -255,14 +211,14 @@ class ReadParquetOp : public OpKernel { int64 
row_group_offset = 0; for (int row_group = 0; row_group < file_metadata->num_row_groups(); row_group++) { std::shared_ptr row_group_reader = parquet_reader->RowGroup(row_group); - // Skip if row group is not within [start..start+count] - if ((row_group_offset + row_group_reader->metadata()->num_rows() < start) || (start + count <= row_group_offset)) { + // Skip if row group is not within [start..stop] + if ((row_group_offset + row_group_reader->metadata()->num_rows() < start) || (stop <= row_group_offset)) { row_group_offset += row_group_reader->metadata()->num_rows(); continue; } // Find row_to_read range int64 row_to_read_start = row_group_offset > start ? row_group_offset : start; - int64 row_to_read_final = (row_group_offset + row_group_reader->metadata()->num_rows()) < (start + count) ? (row_group_offset + row_group_reader->metadata()->num_rows()) : (start + count); + int64 row_to_read_final = (row_group_offset + row_group_reader->metadata()->num_rows()) < (stop) ? (row_group_offset + row_group_reader->metadata()->num_rows()) : (stop); int64 row_to_read_count = row_to_read_final - row_to_read_start; std::shared_ptr column_reader = row_group_reader->Column(column_index); diff --git a/tensorflow_io/parquet/ops/parquet_ops.cc b/tensorflow_io/parquet/ops/parquet_ops.cc index d12b9b7b2..32c318f8c 100644 --- a/tensorflow_io/parquet/ops/parquet_ops.cc +++ b/tensorflow_io/parquet/ops/parquet_ops.cc @@ -27,15 +27,17 @@ REGISTER_OP("ListParquetColumns") .Output("shapes: int64") .SetShapeFn([](shape_inference::InferenceContext* c) { c->set_output(0, c->MakeShape({c->UnknownDim()})); + c->set_output(1, c->MakeShape({c->UnknownDim()})); + c->set_output(2, c->MakeShape({c->UnknownDim(), c->UnknownDim()})); return Status::OK(); }); REGISTER_OP("ReadParquet") .Input("filename: string") .Input("column: string") - .Input("start: int64") - .Input("count: int64") .Input("memory: string") + .Input("start: int64") + .Input("stop: int64") .Attr("dtype: type") .Output("output: dtype") .SetShapeFn([](shape_inference::InferenceContext* c) { diff --git a/tensorflow_io/parquet/python/ops/parquet_ops.py b/tensorflow_io/parquet/python/ops/parquet_ops.py index 1d97f3a21..a0cb9407a 100644 --- a/tensorflow_io/parquet/python/ops/parquet_ops.py +++ b/tensorflow_io/parquet/python/ops/parquet_ops.py @@ -24,7 +24,7 @@ def list_parquet_columns(filename, **kwargs): """list_parquet_columns""" if not tf.executing_eagerly(): - raise NotImplementedError("read_parquet_spect only support eager mode") + raise NotImplementedError("list_parquet_columns only supports eager mode") memory = kwargs.get("memory", "") columns, dtypes, shapes = parquet_ops.list_parquet_columns( filename, memory=memory) @@ -33,18 +33,23 @@ def list_parquet_columns(filename, **kwargs): shape.numpy(), dtype.numpy().decode(), column.numpy().decode())) for ( column, dtype, shape) in entries]) -def read_parquet(filename, column, start=0, **kwargs): +def read_parquet(filename, column, **kwargs): """read_parquet""" memory = kwargs.get("memory", "") + start = kwargs.get("start", 0) + stop = kwargs.get("stop", None) + if stop is None and column.shape[0] is not None: + stop = column.shape[0] + if stop is None: + stop = -1 return parquet_ops.read_parquet( - filename, column.name, - start=start, count=column.shape[0] - start, dtype=column.dtype, - memory=memory) + filename, column.name, memory=memory, + start=start, stop=stop, dtype=column.dtype) class ParquetDataset(data_ops.BaseDataset): """A Parquet Dataset that reads the parquet file.""" - def __init__(self, 
filename, column, batch=None, **kwargs): + def __init__(self, filename, column, **kwargs): """Create a `ParquetDataset`. `ParquetDataset` allows a user to read data from a parquet file. @@ -53,35 +58,28 @@ def __init__(self, filename, column, batch=None, **kwargs): filename: filename of the parquet file to read. column: column name to read. """ - # Note: count and dtype could be in kwargs if in graph mode. + # Note: start, stop and dtype could be in kwargs if in graph mode. if not tf.executing_eagerly(): - count = kwargs.get("count") + start = kwargs.get("start") + stop = kwargs.get("stop") dtype = kwargs.get("dtype") else: columns = list_parquet_columns(filename) - count = columns[column].shape[0] + start = 0 + stop = columns[column].shape[0] dtype = columns[column].dtype - batch = 0 if batch is None else batch - shape = tf.TensorShape([]) if ( - batch is None or batch == 0) else tf.TensorShape([None]) + shape = tf.TensorShape([None]) # capacity is the rough count for each chunk in dataset - # not directly related to batch, will be padded to batch though capacity = kwargs.get("capacity", 65536) - if batch is not None and batch != 0 and capacity > batch: - capacity = (capacity // batch) * batch - entry_start = range(0, count, capacity) - entry_count = [min(capacity, count - start) for start in entry_start] + entry_start = list(range(start, stop, capacity)) + entry_stop = entry_start[1:] + [stop] dataset = data_ops.BaseDataset.from_tensor_slices( - (tf.constant(entry_start, tf.int64), tf.constant(entry_count, tf.int64)) - ).map(lambda start, count: parquet_ops.read_parquet( - filename, column, start, count, dtype=dtype, memory="")) - if batch is None or batch == 0: - self._dataset = dataset.apply(tf.data.experimental.unbatch()) - else: - # TODO: convert to rebatch for performance - self._dataset = dataset.apply(tf.data.experimental.unbatch()).batch(batch) + (tf.constant(entry_start, tf.int64), tf.constant(entry_stop, tf.int64)) + ).map(lambda start, stop: parquet_ops.read_parquet( + filename, column, memory="", start=start, stop=stop, dtype=dtype)) + self._dataset = dataset super(ParquetDataset, self).__init__( self._dataset._variant_tensor, [dtype], [shape]) # pylint: disable=protected-access diff --git a/tests/test_parquet.py b/tests/test_parquet.py new file mode 100644 index 000000000..e66d67461 --- /dev/null +++ b/tests/test_parquet.py @@ -0,0 +1,89 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not +# use this file except in compliance with the License. You may obtain a copy of +# the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations under +# the License. 
+# ============================================================================== +"""Tests for ParquetDataset.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import pytest +import numpy as np + +import tensorflow as tf +tf.compat.v1.disable_eager_execution() +import tensorflow_io.parquet as parquet_io # pylint: disable=wrong-import-position + +# Note: The sample file is generated from: +# `parquet-cpp/examples/low-level-api/reader_writer` +# This test extracts columns of [0, 1, 2, 4, 5] +# with column data types of [bool, int32, int64, float, double]. +# Please check `parquet-cpp/examples/low-level-api/reader-writer.cc` +# to find details of how records are generated: +# Column 0 (bool): True for even rows and False otherwise. +# Column 1 (int32): Equal to row_index. +# Column 2 (int64): Equal to row_index * 1000 * 1000 * 1000 * 1000. +# Column 4 (float): Equal to row_index * 1.1. +# Column 5 (double): Equal to row_index * 1.1111111. +def test_parquet(): + """Test case for ParquetDataset.""" + filename = os.path.join( + os.path.dirname(os.path.abspath(__file__)), + "test_parquet", + "parquet_cpp_example.parquet") + filename = "file://" + filename + + columns = [ + 'boolean_field', + 'int32_field', + 'int64_field', + 'float_field', + 'double_field'] + dtypes = [ + tf.bool, + tf.int32, + tf.int64, + tf.float32, + tf.double] + + dataset = tf.compat.v2.data.Dataset.zip( + tuple([parquet_io.ParquetDataset( + filename, column, dtype=dtype, + start=0, stop=500) for ( + column, dtype) in zip(columns, dtypes)])).apply( + tf.data.experimental.unbatch()) + + iterator = tf.compat.v1.data.make_initializable_iterator(dataset) + init_op = iterator.initializer + get_next = iterator.get_next() + with tf.compat.v1.Session() as sess: + sess.run(init_op) + for i in range(500): + v0 = ((i % 2) == 0) + v1 = i + v2 = i * 1000 * 1000 * 1000 * 1000 + v4 = 1.1 * i + v5 = 1.1111111 * i + p0, p1, p2, p4, p5 = sess.run(get_next) + assert v0 == p0 + assert v1 == p1 + assert v2 == p2 + assert np.isclose(v4, p4) + assert np.isclose(v5, p5) + with pytest.raises(tf.errors.OutOfRangeError): + sess.run(get_next) + +if __name__ == "__main__": + test.main() diff --git a/tests/test_parquet_eager.py b/tests/test_parquet_eager.py index ff3e6363c..d80440a20 100644 --- a/tests/test_parquet_eager.py +++ b/tests/test_parquet_eager.py @@ -74,7 +74,8 @@ def test_parquet(): dataset = tf.compat.v2.data.Dataset.zip( tuple( - [parquet_io.ParquetDataset(filename, column) for column in columns])) + [parquet_io.ParquetDataset(filename, column) for column in columns]) + ).apply(tf.data.experimental.unbatch()) i = 0 for p in dataset: v0 = ((i % 2) == 0)
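Usage note: the sketch below shows, in eager mode, how the pieces introduced by this series fit together. It is only illustrative, not part of the patch: the file path is a placeholder for the parquet_cpp_example.parquet sample used by the tests above, and chaining apply/take assumes the standard tf.data methods exposed through data_ops.BaseDataset.

import tensorflow as tf
import tensorflow_io.parquet as parquet_io

# Placeholder path; point this at the sample file generated by
# parquet-cpp/examples/low-level-api/reader_writer (see the tests above).
filename = "file:///path/to/tests/test_parquet/parquet_cpp_example.parquet"

# list_parquet_columns is eager-only and returns a dict of column name
# to tf.TensorSpec, so dtype and row count are known before any read.
columns = parquet_io.list_parquet_columns(filename)
spec = columns["int64_field"]
print(spec.shape, spec.dtype)  # e.g. (500,) tf.int64 for the sample file

# read_parquet materializes rows [start, stop) of one column as a Tensor.
values = parquet_io.read_parquet(filename, spec, start=0, stop=100)
print(values.shape)  # (100,)

# ParquetDataset emits chunks of up to `capacity` rows; unbatch to iterate
# record by record, as test_parquet_eager.py does above.
dataset = parquet_io.ParquetDataset(filename, "int64_field").apply(
    tf.data.experimental.unbatch())
for value in dataset.take(5):
  print(value.numpy())

In graph mode the same ParquetDataset works with dtype, start, and stop passed explicitly, as exercised in tests/test_parquet.py above.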