class CSVIOTensor(io_tensor_ops._TableIOTensor): # pylint: disable=protected-access
  """Table-style IOTensor whose columns come from a CSV file.

  The constructor creates a CSV resource through `core_ops.csv_indexable_init`,
  queries the shape and dtype of every column it reports through
  `core_ops.csv_indexable_spec`, and hands the resulting specs to the parent
  class together with `core_ops.csv_indexable_get_item` as the item accessor.
  """

  #=============================================================================
  # Constructor (private)
  #=============================================================================
  def __init__(self,
               filename,
               internal=False):
    with tf.name_scope("CSVIOTensor") as scope:
      # A random suffix keeps the shared resource name unique even when the
      # same file is opened more than once.
      unique_name = "%s/%s" % (filename, uuid.uuid4().hex)
      resource, names = core_ops.csv_indexable_init(
          filename, container=scope, shared_name=unique_name)
      # The op returns column names as bytes; decode them to str.
      columns = [name.decode() for name in names.numpy().tolist()]

      def _column_spec(column):
        """Builds the TensorSpec the resource reports for one column."""
        shape, dtype = core_ops.csv_indexable_spec(resource, column)
        return tf.TensorSpec(
            tf.TensorShape(shape), tf.as_dtype(dtype.numpy()), column)

      spec = tuple(_column_spec(column) for column in columns)
      super(CSVIOTensor, self).__init__(
          spec, columns,
          resource, core_ops.csv_indexable_get_item,
          internal=internal)
+ + """ + with tf.name_scope(kwargs.get("name", "IOFromCSV")): + return csv_io_tensor_ops.CSVIOTensor(filename, internal=True) diff --git a/tensorflow_io/text/BUILD b/tensorflow_io/text/BUILD index 12c169dec..5563d3302 100644 --- a/tensorflow_io/text/BUILD +++ b/tensorflow_io/text/BUILD @@ -10,6 +10,7 @@ load( cc_library( name = "text_ops", srcs = [ + "kernels/csv_kernels.cc", "kernels/csv_output.cc", "kernels/text_kernels.cc", "kernels/text_output.cc", @@ -23,6 +24,7 @@ cc_library( ], linkstatic = True, deps = [ + "//tensorflow_io/arrow:arrow_ops", "//tensorflow_io/core:dataset_ops", "//tensorflow_io/core:output_ops", "//tensorflow_io/core:sequence_ops", diff --git a/tensorflow_io/text/kernels/csv_kernels.cc b/tensorflow_io/text/kernels/csv_kernels.cc new file mode 100644 index 000000000..3c5cfa86e --- /dev/null +++ b/tensorflow_io/text/kernels/csv_kernels.cc @@ -0,0 +1,225 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow_io/core/kernels/stream.h" +#include "tensorflow/core/lib/io/buffered_inputstream.h" +#include "tensorflow_io/core/kernels/io_interface.h" +#include "tensorflow_io/core/kernels/stream.h" +#include "arrow/memory_pool.h" +#include "arrow/csv/reader.h" +#include "arrow/table.h" +#include "tensorflow_io/arrow/kernels/arrow_kernels.h" + +namespace tensorflow { +namespace data { + +class CSVIndexable : public IOIndexableInterface { + public: + CSVIndexable(Env* env) + : env_(env) {} + + ~CSVIndexable() {} + Status Init(const std::vector& input, const std::vector& metadata, const void* memory_data, const int64 memory_size) override { + if (input.size() > 1) { + return errors::InvalidArgument("more than 1 filename is not supported"); + } + const string& filename = input[0]; + file_.reset(new SizedRandomAccessFile(env_, filename, memory_data, memory_size)); + TF_RETURN_IF_ERROR(file_->GetFileSize(&file_size_)); + + csv_file_.reset(new ArrowRandomAccessFile(file_.get(), file_size_)); + + ::arrow::Status status; + + status = ::arrow::csv::TableReader::Make(::arrow::default_memory_pool(), csv_file_, ::arrow::csv::ReadOptions::Defaults(), ::arrow::csv::ParseOptions::Defaults(), ::arrow::csv::ConvertOptions::Defaults(), &reader_); + if (!status.ok()) { + return errors::InvalidArgument("unable to make a TableReader: ", status); + } + status = reader_->Read(&table_); + if (!status.ok()) { + return errors::InvalidArgument("unable to read table: ", status); + } + + for (int i = 0; i < table_->num_columns(); i++) { + ::tensorflow::DataType dtype; + switch (table_->column(i)->type()->id()) { + case ::arrow::Type::BOOL: + dtype = ::tensorflow::DT_BOOL; + break; + case ::arrow::Type::UINT8: + dtype= ::tensorflow::DT_UINT8; + break; + case ::arrow::Type::INT8: + dtype= ::tensorflow::DT_INT8; + break; + case ::arrow::Type::UINT16: + 
dtype= ::tensorflow::DT_UINT16; + break; + case ::arrow::Type::INT16: + dtype= ::tensorflow::DT_INT16; + break; + case ::arrow::Type::UINT32: + dtype= ::tensorflow::DT_UINT32; + break; + case ::arrow::Type::INT32: + dtype= ::tensorflow::DT_INT32; + break; + case ::arrow::Type::UINT64: + dtype= ::tensorflow::DT_UINT64; + break; + case ::arrow::Type::INT64: + dtype= ::tensorflow::DT_INT64; + break; + case ::arrow::Type::HALF_FLOAT: + dtype= ::tensorflow::DT_HALF; + break; + case ::arrow::Type::FLOAT: + dtype= ::tensorflow::DT_FLOAT; + break; + case ::arrow::Type::DOUBLE: + dtype= ::tensorflow::DT_DOUBLE; + break; + case ::arrow::Type::STRING: + case ::arrow::Type::BINARY: + case ::arrow::Type::FIXED_SIZE_BINARY: + case ::arrow::Type::DATE32: + case ::arrow::Type::DATE64: + case ::arrow::Type::TIMESTAMP: + case ::arrow::Type::TIME32: + case ::arrow::Type::TIME64: + case ::arrow::Type::INTERVAL: + case ::arrow::Type::DECIMAL: + case ::arrow::Type::LIST: + case ::arrow::Type::STRUCT: + case ::arrow::Type::UNION: + case ::arrow::Type::DICTIONARY: + case ::arrow::Type::MAP: + default: + return errors::InvalidArgument("arrow data type is not supported: ", table_->column(i)->type()->ToString()); + } + shapes_.push_back(TensorShape({static_cast(table_->num_rows())})); + dtypes_.push_back(dtype); + columns_.push_back(table_->column(i)->name()); + columns_index_[table_->column(i)->name()] = i; + } + + return Status::OK(); + } + Status Component(Tensor* component) override { + *component = Tensor(DT_STRING, TensorShape({static_cast(columns_.size())})); + for (size_t i = 0; i < columns_.size(); i++) { + component->flat()(i) = columns_[i]; + } + return Status::OK(); + } + Status Spec(const Tensor& component, PartialTensorShape* shape, DataType* dtype) override { + if (columns_index_.find(component.scalar()()) == columns_index_.end()) { + return errors::InvalidArgument("component ", component.scalar()(), " is invalid"); + } + int64 column_index = 
columns_index_[component.scalar()()]; + *shape = shapes_[column_index]; + *dtype = dtypes_[column_index]; + return Status::OK(); + } + + Status GetItem(const int64 start, const int64 stop, const int64 step, const Tensor& component, Tensor* tensor) override { + if (step != 1) { + return errors::InvalidArgument("step ", step, " is not supported"); + } + if (columns_index_.find(component.scalar()()) == columns_index_.end()) { + return errors::InvalidArgument("component ", component.scalar()(), " is invalid"); + } + int64 column_index = columns_index_[component.scalar()()]; + + std::shared_ptr<::arrow::Column> slice = table_->column(column_index)->Slice(start, stop); + + #define PROCESS_TYPE(TTYPE,ATYPE) { \ + int64 curr_index = 0; \ + for (auto chunk : slice->data()->chunks()) { \ + for (int64_t item = 0; item < chunk->length(); item++) { \ + tensor->flat()(curr_index) = (dynamic_cast(chunk.get()))->Value(item); \ + curr_index++; \ + } \ + } \ + } + switch (tensor->dtype()) { + case DT_BOOL: + PROCESS_TYPE(bool, ::arrow::BooleanArray); + break; + case DT_INT8: + PROCESS_TYPE(int8, ::arrow::NumericArray<::arrow::Int8Type>); + break; + case DT_UINT8: + PROCESS_TYPE(uint8, ::arrow::NumericArray<::arrow::UInt8Type>); + break; + case DT_INT16: + PROCESS_TYPE(int16, ::arrow::NumericArray<::arrow::Int16Type>); + break; + case DT_UINT16: + PROCESS_TYPE(uint16, ::arrow::NumericArray<::arrow::UInt16Type>); + break; + case DT_INT32: + PROCESS_TYPE(int32, ::arrow::NumericArray<::arrow::Int32Type>); + break; + case DT_UINT32: + PROCESS_TYPE(uint32, ::arrow::NumericArray<::arrow::UInt32Type>); + break; + case DT_INT64: + PROCESS_TYPE(int64, ::arrow::NumericArray<::arrow::Int64Type>); + break; + case DT_UINT64: + PROCESS_TYPE(uint64, ::arrow::NumericArray<::arrow::UInt64Type>); + break; + case DT_FLOAT: + PROCESS_TYPE(float, ::arrow::NumericArray<::arrow::FloatType>); + break; + case DT_DOUBLE: + PROCESS_TYPE(double, ::arrow::NumericArray<::arrow::DoubleType>); + break; + default: + 
return errors::InvalidArgument("data type is not supported: ", DataTypeString(tensor->dtype())); + } + + return Status::OK(); + } + + string DebugString() const override { + mutex_lock l(mu_); + return strings::StrCat("CSVIndexable"); + } + private: + mutable mutex mu_; + Env* env_ GUARDED_BY(mu_); + std::unique_ptr file_ GUARDED_BY(mu_); + uint64 file_size_ GUARDED_BY(mu_); + std::shared_ptr csv_file_; + std::shared_ptr<::arrow::csv::TableReader> reader_; + std::shared_ptr<::arrow::Table> table_; + + std::vector dtypes_; + std::vector shapes_; + std::vector columns_; + std::unordered_map columns_index_; +}; + +REGISTER_KERNEL_BUILDER(Name("CSVIndexableInit").Device(DEVICE_CPU), + IOInterfaceInitOp); +REGISTER_KERNEL_BUILDER(Name("CSVIndexableSpec").Device(DEVICE_CPU), + IOInterfaceSpecOp); +REGISTER_KERNEL_BUILDER(Name("CSVIndexableGetItem").Device(DEVICE_CPU), + IOIndexableGetItemOp); +} // namespace data +} // namespace tensorflow diff --git a/tensorflow_io/text/ops/text_ops.cc b/tensorflow_io/text/ops/text_ops.cc index e8544ca22..986004a9f 100644 --- a/tensorflow_io/text/ops/text_ops.cc +++ b/tensorflow_io/text/ops/text_ops.cc @@ -135,4 +135,45 @@ REGISTER_OP("TextOutputSequenceSetItem") .SetIsStateful() .SetShapeFn(shape_inference::ScalarShape); +REGISTER_OP("CSVIndexableInit") + .Input("input: string") + .Output("output: resource") + .Output("component: string") + .Attr("container: string = ''") + .Attr("shared_name: string = ''") + .SetShapeFn([](shape_inference::InferenceContext* c) { + c->set_output(0, c->Scalar()); + c->set_output(1, c->MakeShape({})); + return Status::OK(); + }); + +REGISTER_OP("CSVIndexableSpec") + .Input("input: resource") + .Input("component: string") + .Output("shape: int64") + .Output("dtype: int64") + .SetShapeFn([](shape_inference::InferenceContext* c) { + c->set_output(0, c->MakeShape({c->UnknownDim()})); + c->set_output(1, c->MakeShape({})); + return Status::OK(); + }); + +REGISTER_OP("CSVIndexableGetItem") + .Input("input: 
resource") + .Input("start: int64") + .Input("stop: int64") + .Input("step: int64") + .Input("component: string") + .Output("output: dtype") + .Attr("shape: shape") + .Attr("dtype: type") + .SetShapeFn([](shape_inference::InferenceContext* c) { + PartialTensorShape shape; + TF_RETURN_IF_ERROR(c->GetAttr("shape", &shape)); + shape_inference::ShapeHandle entry; + TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(shape, &entry)); + c->set_output(0, entry); + return Status::OK(); + }); + } // namespace tensorflow diff --git a/tests/test_csv_eager.py b/tests/test_csv_eager.py new file mode 100644 index 000000000..e803fe0f7 --- /dev/null +++ b/tests/test_csv_eager.py @@ -0,0 +1,54 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not +# use this file except in compliance with the License. You may obtain a copy of +# the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations under +# the License. 
def test_csv_format():
  """Writes bool/int64/double columns to a CSV file and reads them back.

  Each column read through `tfio.IOTensor.from_csv` must report 100 rows, a
  dtype matching the column name, and the values that were written.
  """
  data = {
      # `np.bool` was a deprecated alias that newer NumPy releases removed.
      'bool': np.asarray([e % 2 for e in range(100)], np.bool_),
      'int64': np.asarray(range(100), np.int64),
      'double': np.asarray(range(100), np.float64),
  }
  df = pd.DataFrame(data).sort_index(axis=1)
  f = tempfile.NamedTemporaryFile(delete=False, mode="w")
  try:
    with f:
      df.to_csv(f, index=False)

    df = pd.read_csv(f.name)

    csv = tfio.IOTensor.from_csv(f.name)
    for column in df.columns:
      assert csv(column).shape == [100]
      assert csv(column).dtype == column
      assert np.all(csv(column).to_tensor().numpy() == data[column])
  finally:
    # Remove the temporary file even when an assertion above fails.
    os.unlink(f.name)

if __name__ == "__main__":
  # `test` was never imported in this module; use the `tf.test` module that is
  # already in scope.
  tf.test.main()