From d55a9b59120ec60c2516196290e04bfb76611ac3 Mon Sep 17 00:00:00 2001
From: Yong Tang
Date: Sat, 24 Aug 2019 01:04:44 +0000
Subject: [PATCH] Add tfio.IOTensor.from_feather support

Note: this PR depends on PR 438.

Feather is a columnar file format that is often used with pandas. This PR
adds indexing and slicing support to bring Feather to parity with the
Parquet file format, by adding tfio.IOTensor.from_feather support so that
it is possible to access Feather files through natural `__getitem__`
operations.

Signed-off-by: Yong Tang
---
 tensorflow_io/arrow/kernels/arrow_kernels.cc  | 220 ++++++++++++++++++
 tensorflow_io/arrow/ops/dataset_ops.cc        |  34 +++
 .../core/python/ops/feather_io_tensor_ops.py  |  51 ++++
 tensorflow_io/core/python/ops/io_tensor.py    |  18 ++
 tests/test_feather_eager.py                   |  58 +++++
 5 files changed, 381 insertions(+)
 create mode 100644 tensorflow_io/core/python/ops/feather_io_tensor_ops.py
 create mode 100644 tests/test_feather_eager.py

diff --git a/tensorflow_io/arrow/kernels/arrow_kernels.cc b/tensorflow_io/arrow/kernels/arrow_kernels.cc
index 90d337ed0..3cf6c85d0 100644
--- a/tensorflow_io/arrow/kernels/arrow_kernels.cc
+++ b/tensorflow_io/arrow/kernels/arrow_kernels.cc
@@ -15,11 +15,13 @@ limitations under the License.
 
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow_io/arrow/kernels/arrow_kernels.h"
+#include "tensorflow_io/core/kernels/io_interface.h"
 #include "arrow/io/api.h"
 #include "arrow/ipc/feather.h"
 #include "arrow/ipc/feather_generated.h"
 #include "arrow/buffer.h"
 #include "arrow/adapters/tensorflow/convert.h"
+#include "arrow/table.h"
 
 namespace tensorflow {
 namespace data {
@@ -173,5 +175,223 @@ REGISTER_KERNEL_BUILDER(Name("ListFeatherColumns").Device(DEVICE_CPU),
                         ListFeatherColumnsOp);
 
 }  // namespace
+
+
+class FeatherIndexable : public IOIndexableInterface {
+ public:
+  FeatherIndexable(Env* env)
+  : env_(env) {}
+
+  ~FeatherIndexable() {}
+  Status Init(const std::vector<string>& input, const std::vector<string>& metadata, const void* memory_data, const int64 memory_size) override {
+    if (input.size() > 1) {
+      return errors::InvalidArgument("more than 1 filename is not supported");
+    }
+
+    const string& filename = input[0];
+    file_.reset(new SizedRandomAccessFile(env_, filename, memory_data, memory_size));
+    TF_RETURN_IF_ERROR(file_->GetFileSize(&file_size_));
+
+    // FEA1.....[metadata][uint32 metadata_length]FEA1
+    static constexpr const char* kFeatherMagicBytes = "FEA1";
+
+    size_t header_length = strlen(kFeatherMagicBytes);
+    size_t footer_length = sizeof(uint32) + strlen(kFeatherMagicBytes);
+
+    string buffer;
+    buffer.resize(header_length > footer_length ? header_length : footer_length);
+
+    StringPiece result;
+
+    TF_RETURN_IF_ERROR(file_->Read(0, header_length, &result, &buffer[0]));
+    if (memcmp(buffer.data(), kFeatherMagicBytes, header_length) != 0) {
+      return errors::InvalidArgument("not a feather file");
+    }
+
+    TF_RETURN_IF_ERROR(file_->Read(file_size_ - footer_length, footer_length, &result, &buffer[0]));
+    if (memcmp(buffer.data() + sizeof(uint32), kFeatherMagicBytes, footer_length - sizeof(uint32)) != 0) {
+      return errors::InvalidArgument("incomplete feather file");
+    }
+
+    uint32 metadata_length = *reinterpret_cast<const uint32*>(buffer.data());
+
+    buffer.resize(metadata_length);
+
+    TF_RETURN_IF_ERROR(file_->Read(file_size_ - footer_length - metadata_length, metadata_length, &result, &buffer[0]));
+
+    const ::arrow::ipc::feather::fbs::CTable* table = ::arrow::ipc::feather::fbs::GetCTable(buffer.data());
+
+    if (table->version() < ::arrow::ipc::feather::kFeatherVersion) {
+      return errors::InvalidArgument("feather file is old: ", table->version(), " vs. ", ::arrow::ipc::feather::kFeatherVersion);
+    }
+
+    for (int i = 0; i < table->columns()->size(); i++) {
+      ::tensorflow::DataType dtype = ::tensorflow::DataType::DT_INVALID;
+      switch (table->columns()->Get(i)->values()->type()) {
+      case ::arrow::ipc::feather::fbs::Type_BOOL:
+        dtype = ::tensorflow::DataType::DT_BOOL;
+        break;
+      case ::arrow::ipc::feather::fbs::Type_INT8:
+        dtype = ::tensorflow::DataType::DT_INT8;
+        break;
+      case ::arrow::ipc::feather::fbs::Type_INT16:
+        dtype = ::tensorflow::DataType::DT_INT16;
+        break;
+      case ::arrow::ipc::feather::fbs::Type_INT32:
+        dtype = ::tensorflow::DataType::DT_INT32;
+        break;
+      case ::arrow::ipc::feather::fbs::Type_INT64:
+        dtype = ::tensorflow::DataType::DT_INT64;
+        break;
+      case ::arrow::ipc::feather::fbs::Type_UINT8:
+        dtype = ::tensorflow::DataType::DT_UINT8;
+        break;
+      case ::arrow::ipc::feather::fbs::Type_UINT16:
+        dtype = ::tensorflow::DataType::DT_UINT16;
+        break;
+      case ::arrow::ipc::feather::fbs::Type_UINT32:
+        dtype = ::tensorflow::DataType::DT_UINT32;
+        break;
+      case ::arrow::ipc::feather::fbs::Type_UINT64:
+        dtype = ::tensorflow::DataType::DT_UINT64;
+        break;
+      case ::arrow::ipc::feather::fbs::Type_FLOAT:
+        dtype = ::tensorflow::DataType::DT_FLOAT;
+        break;
+      case ::arrow::ipc::feather::fbs::Type_DOUBLE:
+        dtype = ::tensorflow::DataType::DT_DOUBLE;
+        break;
+      case ::arrow::ipc::feather::fbs::Type_UTF8:
+      case ::arrow::ipc::feather::fbs::Type_BINARY:
+      case ::arrow::ipc::feather::fbs::Type_CATEGORY:
+      case ::arrow::ipc::feather::fbs::Type_TIMESTAMP:
+      case ::arrow::ipc::feather::fbs::Type_DATE:
+      case ::arrow::ipc::feather::fbs::Type_TIME:
+      // case ::arrow::ipc::feather::fbs::Type_LARGE_UTF8:
+      // case ::arrow::ipc::feather::fbs::Type_LARGE_BINARY:
+      default:
+        break;
+      }
+      shapes_.push_back(TensorShape({static_cast<int64>(table->num_rows())}));
+      dtypes_.push_back(dtype);
+      columns_.push_back(table->columns()->Get(i)->name()->str());
+    }
+
+    return Status::OK();
+  }
+  Status Spec(std::vector<TensorShape>& shapes, std::vector<DataType>& dtypes) override {
+    shapes.clear();
+    for (size_t i = 0; i < shapes_.size(); i++) {
+      shapes.push_back(shapes_[i]);
+    }
+    dtypes.clear();
+    for (size_t i = 0; i < dtypes_.size(); i++) {
+      dtypes.push_back(dtypes_[i]);
+    }
+    return Status::OK();
+  }
+
+  Status Extra(std::vector<Tensor>* extra) override {
+    // Expose columns
+    Tensor columns(DT_STRING, TensorShape({static_cast<int64>(columns_.size())}));
+    for (size_t i = 0; i < columns_.size(); i++) {
+      columns.flat<string>()(i) = columns_[i];
+    }
+    extra->push_back(columns);
+    return Status::OK();
+  }
+
+  Status GetItem(const int64 start, const int64 stop, const int64 step, const int64 component, Tensor* tensor) override {
+    if (step != 1) {
+      return errors::InvalidArgument("step ", step, " is not supported");
+    }
+
+    if (feather_file_.get() == nullptr) {
+      feather_file_.reset(new ArrowRandomAccessFile(file_.get(), file_size_));
+      arrow::Status s = arrow::ipc::feather::TableReader::Open(feather_file_, &reader_);
+      if (!s.ok()) {
+        return errors::Internal(s.ToString());
+      }
+    }
+
+    std::shared_ptr<::arrow::Column> column;
+    arrow::Status s = reader_->GetColumn(component, &column);
+    if (!s.ok()) {
+      return errors::Internal(s.ToString());
+    }
+
+    std::shared_ptr<::arrow::Column> slice = column->Slice(start, stop);
+
+    #define FEATHER_PROCESS_TYPE(TTYPE,ATYPE) { \
+        int64 curr_index = 0; \
+        for (auto chunk : slice->data()->chunks()) { \
+          for (int64_t item = 0; item < chunk->length(); item++) { \
+            tensor->flat<TTYPE>()(curr_index) = (dynamic_cast<ATYPE*>(chunk.get()))->Value(item); \
+            curr_index++; \
+          } \
+        } \
+      }
+    switch (tensor->dtype()) {
+    case DT_BOOL:
+      FEATHER_PROCESS_TYPE(bool, ::arrow::BooleanArray);
+      break;
+    case DT_INT8:
+      FEATHER_PROCESS_TYPE(int8, ::arrow::NumericArray<::arrow::Int8Type>);
+      break;
+    case DT_UINT8:
+      FEATHER_PROCESS_TYPE(uint8, ::arrow::NumericArray<::arrow::UInt8Type>);
+      break;
+    case DT_INT16:
+      FEATHER_PROCESS_TYPE(int16, ::arrow::NumericArray<::arrow::Int16Type>);
+      break;
+    case DT_UINT16:
+      FEATHER_PROCESS_TYPE(uint16, ::arrow::NumericArray<::arrow::UInt16Type>);
+      break;
+    case DT_INT32:
+      FEATHER_PROCESS_TYPE(int32, ::arrow::NumericArray<::arrow::Int32Type>);
+      break;
+    case DT_UINT32:
+      FEATHER_PROCESS_TYPE(uint32, ::arrow::NumericArray<::arrow::UInt32Type>);
+      break;
+    case DT_INT64:
+      FEATHER_PROCESS_TYPE(int64, ::arrow::NumericArray<::arrow::Int64Type>);
+      break;
+    case DT_UINT64:
+      FEATHER_PROCESS_TYPE(uint64, ::arrow::NumericArray<::arrow::UInt64Type>);
+      break;
+    case DT_FLOAT:
+      FEATHER_PROCESS_TYPE(float, ::arrow::NumericArray<::arrow::FloatType>);
+      break;
+    case DT_DOUBLE:
+      FEATHER_PROCESS_TYPE(double, ::arrow::NumericArray<::arrow::DoubleType>);
+      break;
+    default:
+      return errors::InvalidArgument("data type is not supported: ", DataTypeString(tensor->dtype()));
+    }
+
+    return Status::OK();
+  }
+
+  string DebugString() const override {
+    mutex_lock l(mu_);
+    return strings::StrCat("FeatherIndexable");
+  }
+ private:
+  mutable mutex mu_;
+  Env* env_ GUARDED_BY(mu_);
+  std::unique_ptr<SizedRandomAccessFile> file_ GUARDED_BY(mu_);
+  uint64 file_size_ GUARDED_BY(mu_);
+  std::shared_ptr<ArrowRandomAccessFile> feather_file_ GUARDED_BY(mu_);
+  std::unique_ptr<arrow::ipc::feather::TableReader> reader_ GUARDED_BY(mu_);
+
+  std::vector<DataType> dtypes_;
+  std::vector<TensorShape> shapes_;
+  std::vector<string> columns_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("FeatherIndexableInit").Device(DEVICE_CPU),
+                        IOInterfaceInitOp<FeatherIndexable>);
+REGISTER_KERNEL_BUILDER(Name("FeatherIndexableGetItem").Device(DEVICE_CPU),
+                        IOIndexableGetItemOp<FeatherIndexable>);
 }  // namespace data
 }  // namespace tensorflow
diff --git a/tensorflow_io/arrow/ops/dataset_ops.cc b/tensorflow_io/arrow/ops/dataset_ops.cc
index ae875d7e7..ebfd6303f 100644
--- a/tensorflow_io/arrow/ops/dataset_ops.cc
+++ b/tensorflow_io/arrow/ops/dataset_ops.cc
@@ -100,4 +100,38 @@ REGISTER_OP("ListFeatherColumns")
     return Status::OK();
   });
 
+REGISTER_OP("FeatherIndexableInit")
+    .Input("input: string")
+    .Output("output: resource")
+    .Output("shapes: int64")
+    .Output("dtypes: int64")
+    .Output("columns: string")
+    .Attr("container: string = ''")
+    .Attr("shared_name: string = ''")
+    .SetIsStateful()
+    .SetShapeFn([](shape_inference::InferenceContext* c) {
+      c->set_output(0, c->Scalar());
+      c->set_output(1, c->MakeShape({c->UnknownDim()}));
+      c->set_output(2, c->MakeShape({c->UnknownDim(), c->UnknownDim()}));
+      c->set_output(3, c->MakeShape({c->UnknownDim()}));
+      return Status::OK();
+    });
+
+REGISTER_OP("FeatherIndexableGetItem")
+    .Input("input: resource")
+    .Input("start: int64")
+    .Input("stop: int64")
+    .Input("step: int64")
+    .Input("component: int64")
+    .Output("output: dtype")
+    .Attr("shape: shape")
+    .Attr("dtype: type")
+    .SetShapeFn([](shape_inference::InferenceContext* c) {
+      PartialTensorShape shape;
+      TF_RETURN_IF_ERROR(c->GetAttr("shape", &shape));
+      shape_inference::ShapeHandle entry;
+      TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(shape, &entry));
+      c->set_output(0, entry);
+      return Status::OK();
+    });
 }  // namespace tensorflow
diff --git a/tensorflow_io/core/python/ops/feather_io_tensor_ops.py b/tensorflow_io/core/python/ops/feather_io_tensor_ops.py
new file mode 100644
index 000000000..8e2b598f4
--- /dev/null
+++ b/tensorflow_io/core/python/ops/feather_io_tensor_ops.py
@@ -0,0 +1,51 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""FeatherIOTensor"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import uuid
+
+import tensorflow as tf
+from tensorflow_io.core.python.ops import io_tensor_ops
+from tensorflow_io.core.python.ops import core_ops
+
+class FeatherIOTensor(io_tensor_ops._TableIOTensor):  # pylint: disable=protected-access
+  """FeatherIOTensor"""
+
+  #=============================================================================
+  # Constructor (private)
+  #=============================================================================
+  def __init__(self,
+               filename,
+               internal=False):
+    with tf.name_scope("FeatherIOTensor") as scope:
+      resource, shapes, dtypes, columns = core_ops.feather_indexable_init(
+          filename,
+          container=scope,
+          shared_name="%s/%s" % (filename, uuid.uuid4().hex))
+      shapes = [
+          tf.TensorShape(
+              [None if dim < 0 else dim for dim in e.numpy() if dim != 0]
+          ) for e in tf.unstack(shapes)]
+      dtypes = [tf.as_dtype(e.numpy()) for e in tf.unstack(dtypes)]
+      columns = [e.numpy().decode() for e in tf.unstack(columns)]
+      spec = tuple([tf.TensorSpec(shape, dtype, column) for (
+          shape, dtype, column) in zip(shapes, dtypes, columns)])
+      super(FeatherIOTensor, self).__init__(
+          spec, columns,
+          resource, core_ops.feather_indexable_get_item,
+          internal=internal)
diff --git a/tensorflow_io/core/python/ops/io_tensor.py b/tensorflow_io/core/python/ops/io_tensor.py
index 24358ec39..8ba4a3cb3 100644
--- a/tensorflow_io/core/python/ops/io_tensor.py
+++ b/tensorflow_io/core/python/ops/io_tensor.py
@@ -23,6 +23,7 @@ from tensorflow_io.core.python.ops import json_io_tensor_ops
 from tensorflow_io.core.python.ops import kafka_io_tensor_ops
 from tensorflow_io.core.python.ops import prometheus_io_tensor_ops
+from tensorflow_io.core.python.ops import feather_io_tensor_ops
 
 class IOTensor(io_tensor_ops._IOTensor):  # pylint: disable=protected-access
   """IOTensor
@@ -287,3 +288,20 @@ def from_prometheus(cls,
     with tf.name_scope(kwargs.get("name", "IOFromPrometheus")):
       return prometheus_io_tensor_ops.PrometheusIOTensor(
           query, endpoint=kwargs.get("endpoint", None), internal=True)
+
+  @classmethod
+  def from_feather(cls,
+                   filename,
+                   **kwargs):
+    """Creates an `IOTensor` from a feather file.
+
+    Args:
+      filename: A string, the filename of a feather file.
+      name: A name prefix for the IOTensor (optional).
+
+    Returns:
+      An `IOTensor`.
+
+    """
+    with tf.name_scope(kwargs.get("name", "IOFromFeather")):
+      return feather_io_tensor_ops.FeatherIOTensor(filename, internal=True)
diff --git a/tests/test_feather_eager.py b/tests/test_feather_eager.py
new file mode 100644
index 000000000..831559b3a
--- /dev/null
+++ b/tests/test_feather_eager.py
@@ -0,0 +1,58 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may not
+# use this file except in compliance with the License. You may obtain a copy of
+# the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations under
+# the License.
+# ==============================================================================
+"""Tests for Feather"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import tempfile
+
+import numpy as np
+import pandas as pd
+
+import tensorflow as tf
+if not (hasattr(tf, "version") and tf.version.VERSION.startswith("2.")):
+  tf.compat.v1.enable_eager_execution()
+import tensorflow_io as tfio  # pylint: disable=wrong-import-position
+
+def test_feather_format():
+  """test_feather_format"""
+  data = {
+      'bool': np.asarray([e % 2 for e in range(100)], np.bool),
+      'int8': np.asarray(range(100), np.int8),
+      'int16': np.asarray(range(100), np.int16),
+      'int32': np.asarray(range(100), np.int32),
+      'int64': np.asarray(range(100), np.int64),
+      'float': np.asarray(range(100), np.float32),
+      'double': np.asarray(range(100), np.float64),
+  }
+  df = pd.DataFrame(data).sort_index(axis=1)
+  with tempfile.NamedTemporaryFile(delete=False) as f:
+    df.to_feather(f)
+
+  df = pd.read_feather(f.name)
+
+  feather = tfio.IOTensor.from_feather(f.name)
+  for column in df.columns:
+    assert feather(column).shape == [100]
+    assert feather(column).dtype == column
+    assert np.all(feather(column).to_tensor().numpy() == data[column])
+
+  os.unlink(f.name)
+
+if __name__ == "__main__":
+  tf.test.main()
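
Example usage (a minimal sketch, not part of the patch): it assumes this change
and its PR 438 dependency are installed, and that pandas/pyarrow are available
to write the Feather file. The file path and column name are hypothetical, and
the slicing line relies on the `__getitem__` support described in the commit
message; the column access pattern mirrors the test above.

    import numpy as np
    import pandas as pd
    import tensorflow_io as tfio

    # Write a small Feather file with pandas (hypothetical path).
    df = pd.DataFrame({"label": np.arange(100, dtype=np.int64)})
    df.to_feather("/tmp/example.feather")

    # Open it as an IOTensor and pick a column lazily.
    feather = tfio.IOTensor.from_feather("/tmp/example.feather")
    label = feather("label")

    print(label.dtype)        # column dtype (int64)
    print(label.shape)        # column shape ([100])
    print(label[0:5])         # indexing/slicing through __getitem__
    print(label.to_tensor())  # materialize the whole column as a tf.Tensor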