Skip to content

Commit 200cd33

Browse files
committed
Add tfio.IOTensor.from_parquet support
Note: this PR depends on PR 438. Parquet columnar file format that naturally fits into a table/column data. Since Parquet file itself is indexable, degenerating parquet into an iterable dataset is not desirable as it loses convenience and flexibility. This PR adds tfio.IOTensor.from_parquet support so that it is possible to acess parquet data through natual `__getitem__` operations. Signed-off-by: Yong Tang <[email protected]>
1 parent 6844b4f commit 200cd33

File tree

6 files changed

+298
-27
lines changed

6 files changed

+298
-27
lines changed

tensorflow_io/core/python/ops/io_tensor.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from tensorflow_io.core.python.ops import json_io_tensor_ops
2424
from tensorflow_io.core.python.ops import kafka_io_tensor_ops
2525
from tensorflow_io.core.python.ops import prometheus_io_tensor_ops
26+
from tensorflow_io.core.python.ops import parquet_io_tensor_ops
2627

2728
class IOTensor(io_tensor_ops._IOTensor): # pylint: disable=protected-access
2829
"""IOTensor
@@ -287,3 +288,20 @@ def from_prometheus(cls,
287288
with tf.name_scope(kwargs.get("name", "IOFromPrometheus")):
288289
return prometheus_io_tensor_ops.PrometheusIOTensor(
289290
query, endpoint=kwargs.get("endpoint", None), internal=True)
291+
292+
@classmethod
293+
def from_parquet(cls,
294+
filename,
295+
**kwargs):
296+
"""Creates an `IOTensor` from a parquet file.
297+
298+
Args:
299+
filename: A string, the filename of a parquet file.
300+
name: A name prefix for the IOTensor (optional).
301+
302+
Returns:
303+
A `IOTensor`.
304+
305+
"""
306+
with tf.name_scope(kwargs.get("name", "IOFromParquet")):
307+
return parquet_io_tensor_ops.ParquetIOTensor(filename, internal=True)
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
"""ParquetIOTensor"""
16+
from __future__ import absolute_import
17+
from __future__ import division
18+
from __future__ import print_function
19+
20+
import uuid
21+
22+
import tensorflow as tf
23+
from tensorflow_io.core.python.ops import io_tensor_ops
24+
from tensorflow_io.core.python.ops import core_ops
25+
26+
class ParquetIOTensor(io_tensor_ops._TableIOTensor): # pylint: disable=protected-access
27+
"""ParquetIOTensor"""
28+
29+
#=============================================================================
30+
# Constructor (private)
31+
#=============================================================================
32+
def __init__(self,
33+
filename,
34+
internal=False):
35+
with tf.name_scope("ParquetIOTensor") as scope:
36+
resource, shapes, dtypes, columns = core_ops.parquet_indexable_init(
37+
filename,
38+
container=scope,
39+
shared_name="%s/%s" % (filename, uuid.uuid4().hex))
40+
shapes = [
41+
tf.TensorShape(
42+
[None if dim < 0 else dim for dim in e.numpy() if dim != 0]
43+
) for e in tf.unstack(shapes)]
44+
dtypes = [tf.as_dtype(e.numpy()) for e in tf.unstack(dtypes)]
45+
columns = [e.numpy().decode() for e in tf.unstack(columns)]
46+
spec = tuple([tf.TensorSpec(shape, dtype, column) for (
47+
shape, dtype, column) in zip(shapes, dtypes, columns)])
48+
super(ParquetIOTensor, self).__init__(
49+
spec, columns,
50+
resource, core_ops.parquet_indexable_get_item,
51+
internal=internal)

tensorflow_io/parquet/kernels/parquet_kernels.cc

Lines changed: 169 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ limitations under the License.
1515

1616
#include "tensorflow/core/framework/op_kernel.h"
1717
#include "tensorflow_io/arrow/kernels/arrow_kernels.h"
18+
#include "tensorflow_io/core/kernels/io_interface.h"
1819
#include "parquet/api/reader.h"
1920

2021
namespace tensorflow {
@@ -218,5 +219,173 @@ REGISTER_KERNEL_BUILDER(Name("ReadParquet").Device(DEVICE_CPU),
218219

219220

220221
} // namespace
222+
223+
224+
class ParquetIndexable : public IOIndexableInterface {
225+
public:
226+
ParquetIndexable(Env* env)
227+
: env_(env) {}
228+
229+
~ParquetIndexable() {}
230+
Status Init(const std::vector<string>& input, const std::vector<string>& metadata, const void* memory_data, const int64 memory_size) override {
231+
if (input.size() > 1) {
232+
return errors::InvalidArgument("more than 1 filename is not supported");
233+
}
234+
const string& filename = input[0];
235+
file_.reset(new SizedRandomAccessFile(env_, filename, memory_data, memory_size));
236+
TF_RETURN_IF_ERROR(file_->GetFileSize(&file_size_));
237+
238+
parquet_file_.reset(new ArrowRandomAccessFile(file_.get(), file_size_));
239+
parquet_reader_ = parquet::ParquetFileReader::Open(parquet_file_);
240+
parquet_metadata_ = parquet_reader_->metadata();
241+
242+
shapes_.clear();
243+
dtypes_.clear();
244+
columns_.clear();
245+
for (size_t i = 0; i < parquet_metadata_->num_columns(); i++) {
246+
::tensorflow::DataType dtype;
247+
switch(parquet_metadata_->schema()->Column(i)->physical_type()) {
248+
case parquet::Type::BOOLEAN:
249+
dtype = ::tensorflow::DT_BOOL;
250+
break;
251+
case parquet::Type::INT32:
252+
dtype = ::tensorflow::DT_INT32;
253+
break;
254+
case parquet::Type::INT64:
255+
dtype = ::tensorflow::DT_INT64;
256+
break;
257+
case parquet::Type::INT96: // Deprecated, thrown out exception when access with __getitem__
258+
dtype = ::tensorflow::DT_INT64;
259+
break;
260+
case parquet::Type::FLOAT:
261+
dtype = ::tensorflow::DT_FLOAT;
262+
break;
263+
case parquet::Type::DOUBLE:
264+
dtype = ::tensorflow::DT_DOUBLE;
265+
break;
266+
case parquet::Type::BYTE_ARRAY:
267+
dtype = ::tensorflow::DT_STRING;
268+
break;
269+
case parquet::Type::FIXED_LEN_BYTE_ARRAY:
270+
dtype = ::tensorflow::DT_STRING;
271+
break;
272+
default:
273+
return errors::InvalidArgument("parquet data type is not supported: ", parquet_metadata_->schema()->Column(i)->physical_type());
274+
break;
275+
}
276+
shapes_.push_back(TensorShape({static_cast<int64>(parquet_metadata_->num_rows())}));
277+
dtypes_.push_back(dtype);
278+
columns_.push_back(parquet_metadata_->schema()->Column(i)->path().get()->ToDotString());
279+
}
280+
281+
return Status::OK();
282+
}
283+
Status Spec(std::vector<PartialTensorShape>& shapes, std::vector<DataType>& dtypes) override {
284+
shapes.clear();
285+
for (size_t i = 0; i < shapes_.size(); i++) {
286+
shapes.push_back(shapes_[i]);
287+
}
288+
dtypes.clear();
289+
for (size_t i = 0; i < dtypes_.size(); i++) {
290+
dtypes.push_back(dtypes_[i]);
291+
}
292+
return Status::OK();
293+
}
294+
295+
Status Extra(std::vector<Tensor>* extra) override {
296+
// Expose columns
297+
Tensor columns(DT_STRING, TensorShape({static_cast<int64>(columns_.size())}));
298+
for (size_t i = 0; i < columns_.size(); i++) {
299+
columns.flat<string>()(i) = columns_[i];
300+
}
301+
extra->push_back(columns);
302+
return Status::OK();
303+
}
304+
305+
Status GetItem(const int64 start, const int64 stop, const int64 step, const int64 component, Tensor* tensor) override {
306+
if (step != 1) {
307+
return errors::InvalidArgument("step ", step, " is not supported");
308+
}
309+
int64 row_group_offset = 0;
310+
for (int row_group = 0; row_group < parquet_metadata_->num_row_groups(); row_group++) {
311+
std::shared_ptr<parquet::RowGroupReader> row_group_reader = parquet_reader_->RowGroup(row_group);
312+
// Skip if row group is not within [start..stop]
313+
if ((row_group_offset + row_group_reader->metadata()->num_rows() < start) || (stop <= row_group_offset)) {
314+
row_group_offset += row_group_reader->metadata()->num_rows();
315+
continue;
316+
}
317+
// Find row_to_read range
318+
int64 row_to_read_start = row_group_offset > start ? row_group_offset : start;
319+
int64 row_to_read_final = (row_group_offset + row_group_reader->metadata()->num_rows()) < (stop) ? (row_group_offset + row_group_reader->metadata()->num_rows()) : (stop);
320+
int64 row_to_read_count = row_to_read_final - row_to_read_start;
321+
322+
// TODO: parquet is RowGroup based so ideally the RowGroup should be cached
323+
// with the hope of indexing and slicing happens on each row. For now no caching
324+
// is done yet.
325+
std::shared_ptr<parquet::ColumnReader> column_reader = row_group_reader->Column(component);
326+
327+
// buffer to fill location is tensor.data()[row_to_read_start - start]
328+
329+
#define PARQUET_PROCESS_TYPE(ptype, type) { \
330+
parquet::TypedColumnReader<ptype>* reader = \
331+
static_cast<parquet::TypedColumnReader<ptype>*>( \
332+
column_reader.get()); \
333+
if (row_to_read_start > row_group_offset) { \
334+
reader->Skip(row_to_read_start - row_group_offset); \
335+
} \
336+
ptype::c_type* value = (ptype::c_type *)(void *)(&(tensor->flat<type>().data()[row_to_read_start - start])); \
337+
int64_t values_read; \
338+
int64_t levels_read = reader->ReadBatch(row_to_read_count, nullptr, nullptr, value, &values_read); \
339+
if (!(levels_read == values_read && levels_read == row_to_read_count)) { \
340+
return errors::InvalidArgument("null value in column: ", columns_[component]); \
341+
} \
342+
}
343+
switch (parquet_metadata_->schema()->Column(component)->physical_type()) {
344+
case parquet::Type::BOOLEAN:
345+
PARQUET_PROCESS_TYPE(parquet::BooleanType, bool);
346+
break;
347+
case parquet::Type::INT32:
348+
PARQUET_PROCESS_TYPE(parquet::Int32Type, int32);
349+
break;
350+
case parquet::Type::INT64:
351+
PARQUET_PROCESS_TYPE(parquet::Int64Type, int64);
352+
break;
353+
case parquet::Type::FLOAT:
354+
PARQUET_PROCESS_TYPE(parquet::FloatType, float);
355+
break;
356+
case parquet::Type::DOUBLE:
357+
PARQUET_PROCESS_TYPE(parquet::DoubleType, double);
358+
break;
359+
default:
360+
return errors::InvalidArgument("invalid data type: ", parquet_metadata_->schema()->Column(component)->physical_type());
361+
}
362+
row_group_offset += row_group_reader->metadata()->num_rows();
363+
}
364+
return Status::OK();
365+
}
366+
367+
string DebugString() const override {
368+
mutex_lock l(mu_);
369+
return strings::StrCat("ParquetIndexable");
370+
}
371+
private:
372+
mutable mutex mu_;
373+
Env* env_ GUARDED_BY(mu_);
374+
std::unique_ptr<SizedRandomAccessFile> file_ GUARDED_BY(mu_);
375+
uint64 file_size_ GUARDED_BY(mu_);
376+
std::shared_ptr<ArrowRandomAccessFile> parquet_file_;
377+
std::unique_ptr<::parquet::ParquetFileReader> parquet_reader_;
378+
std::shared_ptr<::parquet::FileMetaData> parquet_metadata_;
379+
380+
std::vector<DataType> dtypes_;
381+
std::vector<TensorShape> shapes_;
382+
std::vector<string> columns_;
383+
std::vector<int> columns_index_;
384+
};
385+
386+
REGISTER_KERNEL_BUILDER(Name("ParquetIndexableInit").Device(DEVICE_CPU),
387+
IOInterfaceInitOp<ParquetIndexable>);
388+
REGISTER_KERNEL_BUILDER(Name("ParquetIndexableGetItem").Device(DEVICE_CPU),
389+
IOIndexableGetItemOp<ParquetIndexable>);
221390
} // namespace data
222391
} // namespace tensorflow

tensorflow_io/parquet/ops/parquet_ops.cc

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,41 @@ limitations under the License.
1919

2020
namespace tensorflow {
2121

22+
REGISTER_OP("ParquetIndexableInit")
23+
.Input("input: string")
24+
.Output("output: resource")
25+
.Output("shapes: int64")
26+
.Output("dtypes: int64")
27+
.Output("columns: string")
28+
.Attr("container: string = ''")
29+
.Attr("shared_name: string = ''")
30+
.SetIsStateful()
31+
.SetShapeFn([](shape_inference::InferenceContext* c) {
32+
c->set_output(0, c->Scalar());
33+
c->set_output(1, c->MakeShape({c->UnknownDim()}));
34+
c->set_output(2, c->MakeShape({c->UnknownDim(), c->UnknownDim()}));
35+
c->set_output(3, c->MakeShape({c->UnknownDim()}));
36+
return Status::OK();
37+
});
38+
39+
REGISTER_OP("ParquetIndexableGetItem")
40+
.Input("input: resource")
41+
.Input("start: int64")
42+
.Input("stop: int64")
43+
.Input("step: int64")
44+
.Input("component: int64")
45+
.Output("output: dtype")
46+
.Attr("shape: shape")
47+
.Attr("dtype: type")
48+
.SetShapeFn([](shape_inference::InferenceContext* c) {
49+
PartialTensorShape shape;
50+
TF_RETURN_IF_ERROR(c->GetAttr("shape", &shape));
51+
shape_inference::ShapeHandle entry;
52+
TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(shape, &entry));
53+
c->set_output(0, entry);
54+
return Status::OK();
55+
});
56+
2257
REGISTER_OP("ListParquetColumns")
2358
.Input("filename: string")
2459
.Input("memory: string")

tensorflow_io/parquet/python/ops/parquet_ops.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,18 @@
1717
from __future__ import division
1818
from __future__ import print_function
1919

20+
import warnings
21+
2022
import tensorflow as tf
2123
from tensorflow_io.core.python.ops import core_ops as parquet_ops
2224
from tensorflow_io.core.python.ops import data_ops
2325

26+
warnings.warn(
27+
"The tensorflow_io.parquet.ParquetDataset is "
28+
"deprecated. Please look for tfio.IOTensor.from_parquet "
29+
"for reading parquet files into tensorflow.",
30+
DeprecationWarning)
31+
2432
def list_parquet_columns(filename, **kwargs):
2533
"""list_parquet_columns"""
2634
if not tf.executing_eagerly():

tests/test_parquet_eager.py

Lines changed: 17 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
import tensorflow as tf
2525
if not (hasattr(tf, "version") and tf.version.VERSION.startswith("2.")):
2626
tf.compat.v1.enable_eager_execution()
27-
import tensorflow_io.parquet as parquet_io # pylint: disable=wrong-import-position
27+
import tensorflow_io as tfio # pylint: disable=wrong-import-position
2828

2929
# Note: The sample file is generated from:
3030
# `parquet-cpp/examples/low-level-api/reader_writer`
@@ -47,18 +47,27 @@ def test_parquet():
4747
"parquet_cpp_example.parquet")
4848
filename = "file://" + filename
4949

50-
specs = parquet_io.list_parquet_columns(filename)
50+
parquet = tfio.IOTensor.from_parquet(filename)
5151
columns = [
5252
'boolean_field',
5353
'int32_field',
5454
'int64_field',
55+
'int96_field',
5556
'float_field',
56-
'double_field']
57-
p0 = parquet_io.read_parquet(filename, specs['boolean_field'])
58-
p1 = parquet_io.read_parquet(filename, specs['int32_field'])
59-
p2 = parquet_io.read_parquet(filename, specs['int64_field'])
60-
p4 = parquet_io.read_parquet(filename, specs['float_field'])
61-
p5 = parquet_io.read_parquet(filename, specs['double_field'])
57+
'double_field',
58+
'ba_field',
59+
'flba_field']
60+
assert parquet.columns == columns
61+
p0 = parquet('boolean_field')
62+
p1 = parquet('int32_field')
63+
p2 = parquet('int64_field')
64+
p4 = parquet('float_field')
65+
p5 = parquet('double_field')
66+
assert p0.dtype == tf.bool
67+
assert p1.dtype == tf.int32
68+
assert p2.dtype == tf.int64
69+
assert p4.dtype == tf.float32
70+
assert p5.dtype == tf.float64
6271

6372
for i in range(500): # 500 rows.
6473
v0 = ((i % 2) == 0)
@@ -72,24 +81,5 @@ def test_parquet():
7281
assert np.isclose(v4, p4[i].numpy())
7382
assert np.isclose(v5, p5[i].numpy())
7483

75-
dataset = tf.compat.v2.data.Dataset.zip(
76-
tuple(
77-
[parquet_io.ParquetDataset(filename, column) for column in columns])
78-
).apply(tf.data.experimental.unbatch())
79-
i = 0
80-
for p in dataset:
81-
v0 = ((i % 2) == 0)
82-
v1 = i
83-
v2 = i * 1000 * 1000 * 1000 * 1000
84-
v4 = 1.1 * i
85-
v5 = 1.1111111 * i
86-
p0, p1, p2, p4, p5 = p
87-
assert v0 == p0.numpy()
88-
assert v1 == p1.numpy()
89-
assert v2 == p2.numpy()
90-
assert np.isclose(v4, p4.numpy())
91-
assert np.isclose(v5, p5.numpy())
92-
i += 1
93-
9484
if __name__ == "__main__":
9585
test.main()

0 commit comments

Comments
 (0)