From 5c0575b3dc666d2801bb40b2d3e9f6f8dbdb019d Mon Sep 17 00:00:00 2001
From: Yong Tang
Date: Sat, 24 Aug 2019 03:07:59 +0000
Subject: [PATCH 1/2] Add tfio.IOTensor.from_csv support (experimental with
 Apache Arrow's CSV parser)

CSV is one of the most widely used file formats, and TensorFlow's main repo
already has a CsvDataset that can conveniently be used for iteration, either
fed into tf.keras directly or accessed through a for loop. There are still
some reasons to have a CSV input processor that gives indexing and slicing
access. The most notable is that, while a CSV file itself is technically only
splittable (not truly indexable), in reality, especially in data science, a
CSV file is by default almost always loaded into memory. And because of its
wide usage, it is more convenient and more flexible to have a CSV processor
that allows indexing and slicing.

This PR takes the indexing and slicing approach and builds the parser on top
of Arrow. One advantage of Arrow is that its CSV parser options are closer to
the widely used pandas options, which makes it easy to import CSV files
created by pandas.

Signed-off-by: Yong Tang
---
 .../core/python/ops/csv_io_tensor_ops.py    |  51 ++++
 tensorflow_io/core/python/ops/io_tensor.py  |  18 ++
 tensorflow_io/text/BUILD                    |   2 +
 tensorflow_io/text/kernels/csv_kernels.cc   | 225 ++++++++++++++++++
 tensorflow_io/text/ops/text_ops.cc          |  41 ++++
 tests/test_csv_eager.py                     |  54 +++++
 third_party/arrow.BUILD                     |   2 +
 7 files changed, 393 insertions(+)
 create mode 100644 tensorflow_io/core/python/ops/csv_io_tensor_ops.py
 create mode 100644 tensorflow_io/text/kernels/csv_kernels.cc
 create mode 100644 tests/test_csv_eager.py

diff --git a/tensorflow_io/core/python/ops/csv_io_tensor_ops.py b/tensorflow_io/core/python/ops/csv_io_tensor_ops.py
new file mode 100644
index 000000000..b5b77117f
--- /dev/null
+++ b/tensorflow_io/core/python/ops/csv_io_tensor_ops.py
@@ -0,0 +1,51 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""CSVIOTensor"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import uuid
+
+import tensorflow as tf
+from tensorflow_io.core.python.ops import io_tensor_ops
+from tensorflow_io.core.python.ops import core_ops
+
+class CSVIOTensor(io_tensor_ops._TableIOTensor):  # pylint: disable=protected-access
+  """CSVIOTensor"""
+
+  #=============================================================================
+  # Constructor (private)
+  #=============================================================================
+  def __init__(self,
+               filename,
+               internal=False):
+    with tf.name_scope("CSVIOTensor") as scope:
+      resource, columns = core_ops.csv_indexable_init(
+          filename,
+          container=scope,
+          shared_name="%s/%s" % (filename, uuid.uuid4().hex))
+      columns = [column.decode() for column in columns.numpy().tolist()]
+      spec = []
+      for column in columns:
+        shape, dtype = core_ops.csv_indexable_spec(resource, column)
+        shape = tf.TensorShape(shape)
+        dtype = tf.as_dtype(dtype.numpy())
+        spec.append(tf.TensorSpec(shape, dtype, column))
+      spec = tuple(spec)
+      super(CSVIOTensor, self).__init__(
+          spec, columns,
+          resource, core_ops.csv_indexable_get_item,
+          internal=internal)
diff --git a/tensorflow_io/core/python/ops/io_tensor.py b/tensorflow_io/core/python/ops/io_tensor.py
index e452ed40c..e86cf60c4 100644
--- a/tensorflow_io/core/python/ops/io_tensor.py
+++ b/tensorflow_io/core/python/ops/io_tensor.py
@@ -26,6 +26,7 @@
 from tensorflow_io.core.python.ops import lmdb_io_tensor_ops
 from tensorflow_io.core.python.ops import prometheus_io_tensor_ops
 from tensorflow_io.core.python.ops import feather_io_tensor_ops
+from tensorflow_io.core.python.ops import csv_io_tensor_ops
 
 class IOTensor(io_tensor_ops._IOTensor):  # pylint: disable=protected-access
   """IOTensor
@@ -364,3 +365,20 @@ def from_hdf5(cls,
     """
     with tf.name_scope(kwargs.get("name", "IOFromHDF5")):
       return hdf5_io_tensor_ops.HDF5IOTensor(filename, internal=True)
+
+  @classmethod
+  def from_csv(cls,
+               filename,
+               **kwargs):
+    """Creates an `IOTensor` from a CSV file.
+
+    Args:
+      filename: A string, the filename of a CSV file.
+      name: A name prefix for the IOTensor (optional).
+
+    Returns:
+      An `IOTensor`.
+
+    """
+    with tf.name_scope(kwargs.get("name", "IOFromCSV")):
+      return csv_io_tensor_ops.CSVIOTensor(filename, internal=True)
diff --git a/tensorflow_io/text/BUILD b/tensorflow_io/text/BUILD
index 12c169dec..5563d3302 100644
--- a/tensorflow_io/text/BUILD
+++ b/tensorflow_io/text/BUILD
@@ -10,6 +10,7 @@ load(
 cc_library(
     name = "text_ops",
     srcs = [
+        "kernels/csv_kernels.cc",
         "kernels/csv_output.cc",
         "kernels/text_kernels.cc",
         "kernels/text_output.cc",
@@ -23,6 +24,7 @@ cc_library(
     ],
     linkstatic = True,
     deps = [
+        "//tensorflow_io/arrow:arrow_ops",
         "//tensorflow_io/core:dataset_ops",
        "//tensorflow_io/core:output_ops",
        "//tensorflow_io/core:sequence_ops",
diff --git a/tensorflow_io/text/kernels/csv_kernels.cc b/tensorflow_io/text/kernels/csv_kernels.cc
new file mode 100644
index 000000000..3c5cfa86e
--- /dev/null
+++ b/tensorflow_io/text/kernels/csv_kernels.cc
@@ -0,0 +1,225 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow_io/core/kernels/stream.h"
+#include "tensorflow/core/lib/io/buffered_inputstream.h"
+#include "tensorflow_io/core/kernels/io_interface.h"
+#include "tensorflow_io/core/kernels/stream.h"
+#include "arrow/memory_pool.h"
+#include "arrow/csv/reader.h"
+#include "arrow/table.h"
+#include "tensorflow_io/arrow/kernels/arrow_kernels.h"
+
+namespace tensorflow {
+namespace data {
+
+class CSVIndexable : public IOIndexableInterface {
+ public:
+  CSVIndexable(Env* env)
+  : env_(env) {}
+
+  ~CSVIndexable() {}
+  Status Init(const std::vector<string>& input, const std::vector<string>& metadata, const void* memory_data, const int64 memory_size) override {
+    if (input.size() > 1) {
+      return errors::InvalidArgument("more than 1 filename is not supported");
+    }
+    const string& filename = input[0];
+    file_.reset(new SizedRandomAccessFile(env_, filename, memory_data, memory_size));
+    TF_RETURN_IF_ERROR(file_->GetFileSize(&file_size_));
+
+    csv_file_.reset(new ArrowRandomAccessFile(file_.get(), file_size_));
+
+    ::arrow::Status status;
+
+    status = ::arrow::csv::TableReader::Make(::arrow::default_memory_pool(), csv_file_, ::arrow::csv::ReadOptions::Defaults(), ::arrow::csv::ParseOptions::Defaults(), ::arrow::csv::ConvertOptions::Defaults(), &reader_);
+    if (!status.ok()) {
+      return errors::InvalidArgument("unable to make a TableReader: ", status);
+    }
+    status = reader_->Read(&table_);
+    if (!status.ok()) {
+      return errors::InvalidArgument("unable to read table: ", status);
+    }
+
+    for (int i = 0; i < table_->num_columns(); i++) {
+      ::tensorflow::DataType dtype;
+      switch (table_->column(i)->type()->id()) {
+      case ::arrow::Type::BOOL:
+        dtype = ::tensorflow::DT_BOOL;
+        break;
+      case ::arrow::Type::UINT8:
+        dtype = ::tensorflow::DT_UINT8;
+        break;
+      case ::arrow::Type::INT8:
+        dtype = ::tensorflow::DT_INT8;
+        break;
+      case ::arrow::Type::UINT16:
+        dtype = ::tensorflow::DT_UINT16;
+        break;
+      case ::arrow::Type::INT16:
+        dtype = ::tensorflow::DT_INT16;
+        break;
+      case ::arrow::Type::UINT32:
+        dtype = ::tensorflow::DT_UINT32;
+        break;
+      case ::arrow::Type::INT32:
+        dtype = ::tensorflow::DT_INT32;
+        break;
+      case ::arrow::Type::UINT64:
+        dtype = ::tensorflow::DT_UINT64;
+        break;
+      case ::arrow::Type::INT64:
+        dtype = ::tensorflow::DT_INT64;
+        break;
+      case ::arrow::Type::HALF_FLOAT:
+        dtype = ::tensorflow::DT_HALF;
+        break;
+      case ::arrow::Type::FLOAT:
+        dtype = ::tensorflow::DT_FLOAT;
+        break;
+      case ::arrow::Type::DOUBLE:
+        dtype = ::tensorflow::DT_DOUBLE;
+        break;
+      case ::arrow::Type::STRING:
+      case ::arrow::Type::BINARY:
+      case ::arrow::Type::FIXED_SIZE_BINARY:
+      case ::arrow::Type::DATE32:
+      case ::arrow::Type::DATE64:
+      case ::arrow::Type::TIMESTAMP:
+      case ::arrow::Type::TIME32:
+      case ::arrow::Type::TIME64:
+      case ::arrow::Type::INTERVAL:
+      case ::arrow::Type::DECIMAL:
+      case ::arrow::Type::LIST:
+      case ::arrow::Type::STRUCT:
+      case ::arrow::Type::UNION:
+      case ::arrow::Type::DICTIONARY:
+      case ::arrow::Type::MAP:
+      default:
+        return errors::InvalidArgument("arrow data type is not supported: ", table_->column(i)->type()->ToString());
+      }
+      shapes_.push_back(TensorShape({static_cast<int64>(table_->num_rows())}));
+      dtypes_.push_back(dtype);
+      columns_.push_back(table_->column(i)->name());
+      columns_index_[table_->column(i)->name()] = i;
+    }
+
+    return Status::OK();
+  }
+  Status Component(Tensor* component) override {
+    *component = Tensor(DT_STRING, TensorShape({static_cast<int64>(columns_.size())}));
+    for (size_t i = 0; i < columns_.size(); i++) {
+      component->flat<string>()(i) = columns_[i];
+    }
+    return Status::OK();
+  }
+  Status Spec(const Tensor& component, PartialTensorShape* shape, DataType* dtype) override {
+    if (columns_index_.find(component.scalar<string>()()) == columns_index_.end()) {
+      return errors::InvalidArgument("component ", component.scalar<string>()(), " is invalid");
+    }
+    int64 column_index = columns_index_[component.scalar<string>()()];
+    *shape = shapes_[column_index];
+    *dtype = dtypes_[column_index];
+    return Status::OK();
+  }
+
+  Status GetItem(const int64 start, const int64 stop, const int64 step, const Tensor& component, Tensor* tensor) override {
+    if (step != 1) {
+      return errors::InvalidArgument("step ", step, " is not supported");
+    }
+    if (columns_index_.find(component.scalar<string>()()) == columns_index_.end()) {
+      return errors::InvalidArgument("component ", component.scalar<string>()(), " is invalid");
+    }
+    int64 column_index = columns_index_[component.scalar<string>()()];
+
+    std::shared_ptr<::arrow::Column> slice = table_->column(column_index)->Slice(start, stop - start);
+
+    #define PROCESS_TYPE(TTYPE,ATYPE) { \
+        int64 curr_index = 0; \
+        for (auto chunk : slice->data()->chunks()) { \
+          for (int64_t item = 0; item < chunk->length(); item++) { \
+            tensor->flat<TTYPE>()(curr_index) = (dynamic_cast<ATYPE*>(chunk.get()))->Value(item); \
+            curr_index++; \
+          } \
+        } \
+      }
+    switch (tensor->dtype()) {
+    case DT_BOOL:
+      PROCESS_TYPE(bool, ::arrow::BooleanArray);
+      break;
+    case DT_INT8:
+      PROCESS_TYPE(int8, ::arrow::NumericArray<::arrow::Int8Type>);
+      break;
+    case DT_UINT8:
+      PROCESS_TYPE(uint8, ::arrow::NumericArray<::arrow::UInt8Type>);
+      break;
+    case DT_INT16:
+      PROCESS_TYPE(int16, ::arrow::NumericArray<::arrow::Int16Type>);
+      break;
+    case DT_UINT16:
+      PROCESS_TYPE(uint16, ::arrow::NumericArray<::arrow::UInt16Type>);
+      break;
+    case DT_INT32:
+      PROCESS_TYPE(int32, ::arrow::NumericArray<::arrow::Int32Type>);
+      break;
+    case DT_UINT32:
+      PROCESS_TYPE(uint32, ::arrow::NumericArray<::arrow::UInt32Type>);
+      break;
+    case DT_INT64:
+      PROCESS_TYPE(int64, ::arrow::NumericArray<::arrow::Int64Type>);
+      break;
+    case DT_UINT64:
+      PROCESS_TYPE(uint64, ::arrow::NumericArray<::arrow::UInt64Type>);
+      break;
+    case DT_FLOAT:
+      PROCESS_TYPE(float, ::arrow::NumericArray<::arrow::FloatType>);
+      break;
+    case DT_DOUBLE:
+      PROCESS_TYPE(double, ::arrow::NumericArray<::arrow::DoubleType>);
+      break;
+    default:
+      return errors::InvalidArgument("data type is not supported: ", DataTypeString(tensor->dtype()));
+    }
+
+    return Status::OK();
+  }
+
+  string DebugString() const override {
+    mutex_lock l(mu_);
+    return strings::StrCat("CSVIndexable");
+  }
+ private:
+  mutable mutex mu_;
+  Env* env_ GUARDED_BY(mu_);
+  std::unique_ptr<SizedRandomAccessFile> file_ GUARDED_BY(mu_);
+  uint64 file_size_ GUARDED_BY(mu_);
+  std::shared_ptr<ArrowRandomAccessFile> csv_file_;
+  std::shared_ptr<::arrow::csv::TableReader> reader_;
+  std::shared_ptr<::arrow::Table> table_;
+
+  std::vector<DataType> dtypes_;
+  std::vector<TensorShape> shapes_;
+  std::vector<string> columns_;
+  std::unordered_map<std::string, int64> columns_index_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("CSVIndexableInit").Device(DEVICE_CPU),
+                        IOInterfaceInitOp<CSVIndexable>);
+REGISTER_KERNEL_BUILDER(Name("CSVIndexableSpec").Device(DEVICE_CPU),
+                        IOInterfaceSpecOp<CSVIndexable>);
+REGISTER_KERNEL_BUILDER(Name("CSVIndexableGetItem").Device(DEVICE_CPU),
+                        IOIndexableGetItemOp<CSVIndexable>);
+}  // namespace data
+}  // namespace tensorflow
diff --git a/tensorflow_io/text/ops/text_ops.cc b/tensorflow_io/text/ops/text_ops.cc
index e8544ca22..986004a9f 100644
--- a/tensorflow_io/text/ops/text_ops.cc
+++ b/tensorflow_io/text/ops/text_ops.cc
@@ -135,4 +135,45 @@ REGISTER_OP("TextOutputSequenceSetItem")
     .SetIsStateful()
     .SetShapeFn(shape_inference::ScalarShape);
 
+REGISTER_OP("CSVIndexableInit")
+    .Input("input: string")
+    .Output("output: resource")
+    .Output("component: string")
+    .Attr("container: string = ''")
+    .Attr("shared_name: string = ''")
+    .SetShapeFn([](shape_inference::InferenceContext* c) {
+       c->set_output(0, c->Scalar());
+       c->set_output(1, c->MakeShape({}));
+       return Status::OK();
+    });
+
+REGISTER_OP("CSVIndexableSpec")
+    .Input("input: resource")
+    .Input("component: string")
+    .Output("shape: int64")
+    .Output("dtype: int64")
+    .SetShapeFn([](shape_inference::InferenceContext* c) {
+       c->set_output(0, c->MakeShape({c->UnknownDim()}));
+       c->set_output(1, c->MakeShape({}));
+       return Status::OK();
+    });
+
+REGISTER_OP("CSVIndexableGetItem")
+    .Input("input: resource")
+    .Input("start: int64")
+    .Input("stop: int64")
+    .Input("step: int64")
+    .Input("component: string")
+    .Output("output: dtype")
+    .Attr("shape: shape")
+    .Attr("dtype: type")
+    .SetShapeFn([](shape_inference::InferenceContext* c) {
+       PartialTensorShape shape;
+       TF_RETURN_IF_ERROR(c->GetAttr("shape", &shape));
+       shape_inference::ShapeHandle entry;
+       TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(shape, &entry));
+       c->set_output(0, entry);
+       return Status::OK();
+    });
+
 }  // namespace tensorflow
diff --git a/tests/test_csv_eager.py b/tests/test_csv_eager.py
new file mode 100644
index 000000000..6a0620ab5
--- /dev/null
+++ b/tests/test_csv_eager.py
@@ -0,0 +1,54 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may not
+# use this file except in compliance with the License. You may obtain a copy of
+# the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations under
+# the License.
+# ==============================================================================
+"""Tests for CSV"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import tempfile
+
+import numpy as np
+import pandas as pd
+
+import tensorflow as tf
+if not (hasattr(tf, "version") and tf.version.VERSION.startswith("2.")):
+  tf.compat.v1.enable_eager_execution()
+import tensorflow_io as tfio  # pylint: disable=wrong-import-position
+
+def test_csv_format():
+  """test_csv_format"""
+  data = {
+      'bool': np.asarray([e % 2 for e in range(100)], np.bool),
+      'int64': np.asarray(range(100), np.int64),
+      'double': np.asarray(range(100), np.float64),
+  }
+  df = pd.DataFrame(data).sort_index(axis=1)
+  with tempfile.NamedTemporaryFile(delete=False) as f:
+    df.to_csv(f, index=False)
+
+  df = pd.read_csv(f.name)
+
+  csv = tfio.IOTensor.from_csv(f.name)
+  for column in df.columns:
+    assert csv(column).shape == [100]
+    assert csv(column).dtype == column
+    assert np.all(csv(column).to_tensor().numpy() == data[column])
+
+  os.unlink(f.name)
+
+if __name__ == "__main__":
+  test.main()
diff --git a/third_party/arrow.BUILD b/third_party/arrow.BUILD
index d52d8242e..b573bfc6a 100644
--- a/third_party/arrow.BUILD
+++ b/third_party/arrow.BUILD
@@ -47,6 +47,8 @@ cc_library(
         "cpp/src/arrow/io/*.h",
         "cpp/src/arrow/ipc/*.cc",
         "cpp/src/arrow/ipc/*.h",
+        "cpp/src/arrow/csv/*.cc",
+        "cpp/src/arrow/csv/*.h",
         "cpp/src/arrow/json/*.cc",
         "cpp/src/arrow/json/*.h",
         "cpp/src/arrow/util/*.cc",

From 4e9ec5fe0b45c3194605d0542706dd4124a8ee06 Mon Sep 17 00:00:00 2001
From: Yong Tang
Date: Wed, 11 Sep 2019 18:04:41 +0000
Subject: [PATCH 2/2] Fix Python 3 error

Signed-off-by: Yong Tang
---
 tests/test_csv_eager.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/test_csv_eager.py b/tests/test_csv_eager.py
index 6a0620ab5..e803fe0f7 100644
--- a/tests/test_csv_eager.py
+++ b/tests/test_csv_eager.py
@@ -37,7 +37,7 @@ def test_csv_format():
       'double': np.asarray(range(100), np.float64),
   }
   df = pd.DataFrame(data).sort_index(axis=1)
-  with tempfile.NamedTemporaryFile(delete=False) as f:
+  with tempfile.NamedTemporaryFile(delete=False, mode="w") as f:
     df.to_csv(f, index=False)
 
   df = pd.read_csv(f.name)
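
Reviewer note (not part of the patch): the sketch below shows how the new API is
intended to be used, mirroring tests/test_csv_eager.py above. The file path and
column names are illustrative, and the final line assumes that slicing a
per-column IOTensor with [start:stop] is routed to the CSVIndexableGetItem op
(step fixed at 1), which is the access pattern this PR is built around; only the
calls exercised by the test (from_csv, per-column call, shape, dtype, to_tensor)
are taken directly from the patch.

  # Usage sketch under the assumptions stated above.
  import numpy as np
  import pandas as pd
  import tensorflow_io as tfio

  # Arrow's CSV parse options are close to pandas defaults, so a file written
  # by pandas round-trips without extra configuration.
  df = pd.DataFrame({
      "int64": np.arange(100, dtype=np.int64),
      "double": np.arange(100, dtype=np.float64),
  })
  df.to_csv("example.csv", index=False)  # illustrative path

  csv = tfio.IOTensor.from_csv("example.csv")
  for column in df.columns:
    col = csv(column)                    # per-column access by header name
    print(column, col.shape, col.dtype)  # shape of [num_rows], dtype inferred by Arrow
    print(col.to_tensor()[:5])           # materialize the column, then slice the tf.Tensor

  # Assumed: direct slicing maps to CSVIndexableGetItem(start=10, stop=20, step=1).
  print(csv("double")[10:20])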