Skip to content

Commit a504bf0

Browse files
committed
Add tfio.IOTensor.from_feather support
Note: this PR depends on PR 437. Feather is a columnar file format that is often seen with pandas. This PR adds indexing and slicing support to bring Feather to parity with the Parquet file format, by adding tfio.IOTensor.from_feather support so that it is possible to access Feather files through natural `__getitem__` operations. Signed-off-by: Yong Tang <[email protected]>
1 parent 525d55f commit a504bf0

File tree

5 files changed

+411
-0
lines changed

5 files changed

+411
-0
lines changed

tensorflow_io/arrow/kernels/arrow_kernels.cc

Lines changed: 248 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,12 @@ limitations under the License.
1515

1616
#include "tensorflow/core/framework/op_kernel.h"
1717
#include "tensorflow_io/arrow/kernels/arrow_kernels.h"
18+
#include "tensorflow_io/core/kernels/io_interface.h"
1819
#include "arrow/io/api.h"
1920
#include "arrow/ipc/feather.h"
2021
#include "arrow/ipc/feather_generated.h"
2122
#include "arrow/buffer.h"
23+
#include "arrow/table.h"
2224

2325
namespace tensorflow {
2426
namespace data {
@@ -155,5 +157,251 @@ REGISTER_KERNEL_BUILDER(Name("ListFeatherColumns").Device(DEVICE_CPU),
155157

156158

157159
} // namespace
160+
161+
162+
class FeatherIndexable : public IOIndexableInterface {
163+
public:
164+
FeatherIndexable(Env* env)
165+
: env_(env) {}
166+
167+
~FeatherIndexable() {}
168+
Status Init(const std::vector<string>& input, const std::vector<string>& metadata, const void* memory_data, const int64 memory_size) override {
169+
if (input.size() > 1) {
170+
return errors::InvalidArgument("more than 1 filename is not supported");
171+
}
172+
173+
const string& filename = input[0];
174+
file_.reset(new SizedRandomAccessFile(env_, filename, memory_data, memory_size));
175+
TF_RETURN_IF_ERROR(file_->GetFileSize(&file_size_));
176+
177+
// FEA1.....[metadata][uint32 metadata_length]FEA1
178+
static constexpr const char* kFeatherMagicBytes = "FEA1";
179+
180+
size_t header_length = strlen(kFeatherMagicBytes);
181+
size_t footer_length = sizeof(uint32) + strlen(kFeatherMagicBytes);
182+
183+
string buffer;
184+
buffer.resize(header_length > footer_length ? header_length : footer_length);
185+
186+
StringPiece result;
187+
188+
TF_RETURN_IF_ERROR(file_->Read(0, header_length, &result, &buffer[0]));
189+
if (memcmp(buffer.data(), kFeatherMagicBytes, header_length) != 0) {
190+
return errors::InvalidArgument("not a feather file");
191+
}
192+
193+
TF_RETURN_IF_ERROR(file_->Read(file_size_ - footer_length, footer_length, &result, &buffer[0]));
194+
if (memcmp(buffer.data() + sizeof(uint32), kFeatherMagicBytes, footer_length - sizeof(uint32)) != 0) {
195+
return errors::InvalidArgument("incomplete feather file");
196+
}
197+
198+
uint32 metadata_length = *reinterpret_cast<const uint32*>(buffer.data());
199+
200+
buffer.resize(metadata_length);
201+
202+
TF_RETURN_IF_ERROR(file_->Read(file_size_ - footer_length - metadata_length, metadata_length, &result, &buffer[0]));
203+
204+
const ::arrow::ipc::feather::fbs::CTable* table = ::arrow::ipc::feather::fbs::GetCTable(buffer.data());
205+
206+
if (table->version() < ::arrow::ipc::feather::kFeatherVersion) {
207+
return errors::InvalidArgument("feather file is old: ", table->version(), " vs. ", ::arrow::ipc::feather::kFeatherVersion);
208+
}
209+
210+
std::vector<string> columns;
211+
for (size_t i = 0; i < metadata.size(); i++) {
212+
if (metadata[i].find_first_of("column: ") == 0) {
213+
columns.emplace_back(metadata[i].substr(8));
214+
}
215+
}
216+
217+
columns_index_.clear();
218+
if (columns.size() == 0) {
219+
for (int i = 0; i < table->columns()->size(); i++) {
220+
columns_index_.push_back(i);
221+
}
222+
} else {
223+
std::unordered_map<string, int> columns_map;
224+
for (int i = 0; i < table->columns()->size(); i++) {
225+
columns_map[table->columns()->Get(i)->name()->str()] = i;
226+
}
227+
for (size_t i = 0; i < columns.size(); i++) {
228+
columns_index_.push_back(columns_map[columns[i]]);
229+
}
230+
}
231+
232+
for (size_t i = 0; i < columns_index_.size(); i++) {
233+
int column_index = columns_index_[i];
234+
::tensorflow::DataType dtype = ::tensorflow::DataType::DT_INVALID;
235+
switch (table->columns()->Get(column_index)->values()->type()) {
236+
case ::arrow::ipc::feather::fbs::Type_BOOL:
237+
dtype = ::tensorflow::DataType::DT_BOOL;
238+
break;
239+
case ::arrow::ipc::feather::fbs::Type_INT8:
240+
dtype = ::tensorflow::DataType::DT_INT8;
241+
break;
242+
case ::arrow::ipc::feather::fbs::Type_INT16:
243+
dtype = ::tensorflow::DataType::DT_INT16;
244+
break;
245+
case ::arrow::ipc::feather::fbs::Type_INT32:
246+
dtype = ::tensorflow::DataType::DT_INT32;
247+
break;
248+
case ::arrow::ipc::feather::fbs::Type_INT64:
249+
dtype = ::tensorflow::DataType::DT_INT64;
250+
break;
251+
case ::arrow::ipc::feather::fbs::Type_UINT8:
252+
dtype = ::tensorflow::DataType::DT_UINT8;
253+
break;
254+
case ::arrow::ipc::feather::fbs::Type_UINT16:
255+
dtype = ::tensorflow::DataType::DT_UINT16;
256+
break;
257+
case ::arrow::ipc::feather::fbs::Type_UINT32:
258+
dtype = ::tensorflow::DataType::DT_UINT32;
259+
break;
260+
case ::arrow::ipc::feather::fbs::Type_UINT64:
261+
dtype = ::tensorflow::DataType::DT_UINT64;
262+
break;
263+
case ::arrow::ipc::feather::fbs::Type_FLOAT:
264+
dtype = ::tensorflow::DataType::DT_FLOAT;
265+
break;
266+
case ::arrow::ipc::feather::fbs::Type_DOUBLE:
267+
dtype = ::tensorflow::DataType::DT_DOUBLE;
268+
break;
269+
case ::arrow::ipc::feather::fbs::Type_UTF8:
270+
case ::arrow::ipc::feather::fbs::Type_BINARY:
271+
case ::arrow::ipc::feather::fbs::Type_CATEGORY:
272+
case ::arrow::ipc::feather::fbs::Type_TIMESTAMP:
273+
case ::arrow::ipc::feather::fbs::Type_DATE:
274+
case ::arrow::ipc::feather::fbs::Type_TIME:
275+
// case ::arrow::ipc::feather::fbs::Type_LARGE_UTF8:
276+
// case ::arrow::ipc::feather::fbs::Type_LARGE_BINARY:
277+
default:
278+
break;
279+
}
280+
dtypes_.push_back(dtype);
281+
shapes_.push_back(TensorShape({static_cast<int64>(table->num_rows())}));
282+
columns_.push_back(table->columns()->Get(column_index)->name()->str());
283+
}
284+
285+
return Status::OK();
286+
}
287+
Status Spec(std::vector<DataType>& dtypes, std::vector<PartialTensorShape>& shapes) override {
288+
dtypes.clear();
289+
for (size_t i = 0; i < dtypes_.size(); i++) {
290+
dtypes.push_back(dtypes_[i]);
291+
}
292+
shapes.clear();
293+
for (size_t i = 0; i < shapes_.size(); i++) {
294+
shapes.push_back(shapes_[i]);
295+
}
296+
return Status::OK();
297+
}
298+
299+
Status Extra(std::vector<Tensor>* extra) override {
300+
// Expose columns
301+
Tensor columns(DT_STRING, TensorShape({static_cast<int64>(columns_.size())}));
302+
for (size_t i = 0; i < columns_.size(); i++) {
303+
columns.flat<string>()(i) = columns_[i];
304+
}
305+
extra->push_back(columns);
306+
return Status::OK();
307+
}
308+
309+
Status GetItem(const int64 start, const int64 stop, const int64 step, std::vector<Tensor>& tensors) override {
310+
if (step != 1) {
311+
return errors::InvalidArgument("step ", step, " is not supported");
312+
}
313+
314+
if (feather_file_.get() == nullptr) {
315+
feather_file_.reset(new ArrowRandomAccessFile(file_.get(), file_size_));
316+
arrow::Status s = arrow::ipc::feather::TableReader::Open(feather_file_, &reader_);
317+
if (!s.ok()) {
318+
return errors::Internal(s.ToString());
319+
}
320+
}
321+
322+
for (size_t i = 0; i < tensors.size(); i++) {
323+
int column_index = columns_index_[i];
324+
325+
std::shared_ptr<arrow::Column> column;
326+
arrow::Status s = reader_->GetColumn(column_index, &column);
327+
if (!s.ok()) {
328+
return errors::Internal(s.ToString());
329+
}
330+
331+
std::shared_ptr<::arrow::Column> slice = column->Slice(start, stop);
332+
333+
#define FEATHER_PROCESS_TYPE(TTYPE,ATYPE) { \
334+
int64 curr_index = 0; \
335+
for (auto chunk : slice->data()->chunks()) { \
336+
for (int64_t item = 0; item < chunk->length(); item++) { \
337+
tensors[i].flat<TTYPE>()(curr_index) = (dynamic_cast<ATYPE *>(chunk.get()))->Value(item); \
338+
curr_index++; \
339+
} \
340+
} \
341+
}
342+
switch (tensors[i].dtype()) {
343+
case DT_BOOL:
344+
FEATHER_PROCESS_TYPE(bool, ::arrow::BooleanArray);
345+
break;
346+
case DT_INT8:
347+
FEATHER_PROCESS_TYPE(int8, ::arrow::NumericArray<::arrow::Int8Type>);
348+
break;
349+
case DT_UINT8:
350+
FEATHER_PROCESS_TYPE(uint8, ::arrow::NumericArray<::arrow::UInt8Type>);
351+
break;
352+
case DT_INT16:
353+
FEATHER_PROCESS_TYPE(int16, ::arrow::NumericArray<::arrow::Int16Type>);
354+
break;
355+
case DT_UINT16:
356+
FEATHER_PROCESS_TYPE(uint16, ::arrow::NumericArray<::arrow::UInt16Type>);
357+
break;
358+
case DT_INT32:
359+
FEATHER_PROCESS_TYPE(int32, ::arrow::NumericArray<::arrow::Int32Type>);
360+
break;
361+
case DT_UINT32:
362+
FEATHER_PROCESS_TYPE(uint32, ::arrow::NumericArray<::arrow::UInt32Type>);
363+
break;
364+
case DT_INT64:
365+
FEATHER_PROCESS_TYPE(int64, ::arrow::NumericArray<::arrow::Int64Type>);
366+
break;
367+
case DT_UINT64:
368+
FEATHER_PROCESS_TYPE(uint64, ::arrow::NumericArray<::arrow::UInt64Type>);
369+
break;
370+
case DT_FLOAT:
371+
FEATHER_PROCESS_TYPE(float, ::arrow::NumericArray<::arrow::FloatType>);
372+
break;
373+
case DT_DOUBLE:
374+
FEATHER_PROCESS_TYPE(double, ::arrow::NumericArray<::arrow::DoubleType>);
375+
break;
376+
default:
377+
return errors::InvalidArgument("data type is not supported: ", DataTypeString(tensors[i].dtype()));
378+
}
379+
}
380+
381+
return Status::OK();
382+
}
383+
384+
string DebugString() const override {
385+
mutex_lock l(mu_);
386+
return strings::StrCat("FeatherIndexable");
387+
}
388+
private:
389+
mutable mutex mu_;
390+
Env* env_ GUARDED_BY(mu_);
391+
std::unique_ptr<SizedRandomAccessFile> file_ GUARDED_BY(mu_);
392+
uint64 file_size_ GUARDED_BY(mu_);
393+
std::shared_ptr<ArrowRandomAccessFile> feather_file_ GUARDED_BY(mu_);
394+
std::unique_ptr<arrow::ipc::feather::TableReader> reader_ GUARDED_BY(mu_);
395+
396+
std::vector<DataType> dtypes_;
397+
std::vector<TensorShape> shapes_;
398+
std::vector<string> columns_;
399+
std::vector<int> columns_index_;
400+
};
401+
402+
// CPU kernel that creates a FeatherIndexable resource.
REGISTER_KERNEL_BUILDER(
    Name("FeatherIndexableInit").Device(DEVICE_CPU),
    IOInterfaceInitOp<FeatherIndexable>);

// CPU kernel that reads items from a FeatherIndexable resource.
REGISTER_KERNEL_BUILDER(
    Name("FeatherIndexableGetItem").Device(DEVICE_CPU),
    IOIndexableGetItemOp<FeatherIndexable>);
158406
} // namespace data
159407
} // namespace tensorflow

tensorflow_io/arrow/ops/dataset_ops.cc

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,4 +100,43 @@ REGISTER_OP("ListFeatherColumns")
100100
return Status::OK();
101101
});
102102

103+
// Op that opens a Feather file as a resource.
// Outputs: the scalar resource handle, a vector of dtypes, a matrix of
// shapes and a vector of column names; all sizes are unknown at graph
// construction time.
REGISTER_OP("FeatherIndexableInit")
    .Input("input: string")
    .Input("metadata: string")
    .Output("output: resource")
    .Output("dtypes: int64")
    .Output("shapes: int64")
    .Output("columns: string")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .SetIsStateful()
    .SetShapeFn([](shape_inference::InferenceContext* ctx) {
      // output: scalar resource handle
      ctx->set_output(0, ctx->Scalar());
      // dtypes: rank 1, unknown length
      ctx->set_output(1, ctx->Vector(ctx->UnknownDim()));
      // shapes: rank 2, both dimensions unknown
      ctx->set_output(2, ctx->Matrix(ctx->UnknownDim(), ctx->UnknownDim()));
      // columns: rank 1, unknown length
      ctx->set_output(3, ctx->Vector(ctx->UnknownDim()));
      return Status::OK();
    });
120+
121+
// Op that reads the items in [start, stop) with the given step from a
// Feather resource. It produces one output per entry of the `dtype` attr;
// the static shape of each output is taken from the matching entry of the
// `shape` attr.
REGISTER_OP("FeatherIndexableGetItem")
    .Input("input: resource")
    .Input("start: int64")
    .Input("stop: int64")
    .Input("step: int64")
    .Output("output: dtype")
    .Attr("dtype: list(type) >= 1")
    .Attr("shape: list(shape) >= 1")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      std::vector<PartialTensorShape> shape;
      TF_RETURN_IF_ERROR(c->GetAttr("shape", &shape));
      const int num_outputs = c->num_outputs();
      // The number of outputs follows the length of the `dtype` attr.
      if (shape.size() != static_cast<size_t>(num_outputs)) {
        return errors::InvalidArgument("`shape` must be the same length as `dtype` (", shape.size(), " vs. ", num_outputs, ")");
      }
      for (int i = 0; i < num_outputs; ++i) {
        shape_inference::ShapeHandle entry;
        TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(shape[i], &entry));
        c->set_output(i, entry);
      }
      return Status::OK();
    });
103142
} // namespace tensorflow
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
"""FeatherIOTensor"""
16+
from __future__ import absolute_import
17+
from __future__ import division
18+
from __future__ import print_function
19+
20+
import uuid
21+
22+
import tensorflow as tf
23+
from tensorflow_io.core.python.ops import io_tensor_ops
24+
from tensorflow_io.core.python.ops import core_ops
25+
26+
class FeatherIOTensor(io_tensor_ops._TableIOTensor): # pylint: disable=protected-access
  """Table-style IOTensor whose data comes from a Feather file.

  The constructor runs the `feather_indexable_init` op on `filename` and
  hands the resulting resource, dtypes, shapes and column names to the
  base class, together with `feather_indexable_get_item` as the reader.
  """

  #=============================================================================
  # Constructor (private)
  #=============================================================================
  def __init__(self, filename, columns=None, internal=False):
    with tf.name_scope("FeatherIOTensor") as scope:
      # Each requested column is passed to the op as a "column: <name>" entry;
      # no entries are sent when `columns` is None.
      if columns is None:
        metadata = []
      else:
        metadata = ["column: "+name for name in columns]
      # The random suffix gives every instance its own resource name.
      resource_name = "%s/%s" % (filename, uuid.uuid4().hex)
      outputs = core_ops.feather_indexable_init(
          filename,
          metadata=metadata,
          container=scope,
          shared_name=resource_name)
      resource, dtypes, shapes, columns = outputs
      self._filename = filename
      super(FeatherIOTensor, self).__init__(
          shapes, dtypes, columns, filename,
          resource, core_ops.feather_indexable_get_item,
          internal=internal)

tensorflow_io/core/python/ops/io_tensor.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from tensorflow_io.core.python.ops import audio_io_tensor_ops
2323
from tensorflow_io.core.python.ops import json_io_tensor_ops
2424
from tensorflow_io.core.python.ops import kafka_io_tensor_ops
25+
from tensorflow_io.core.python.ops import feather_io_tensor_ops
2526

2627
class IOTensor(io_tensor_ops._BaseIOTensor): # pylint: disable=protected-access
2728
"""IOTensor
@@ -264,3 +265,20 @@ def from_kafka(cls,
264265
servers=kwargs.get("servers", None),
265266
configuration=kwargs.get("configuration", None),
266267
internal=True)
268+
269+
@classmethod
def from_feather(cls,
                 filename,
                 **kwargs):
  """Creates an `IOTensor` from a feather file.

  Args:
    filename: A string, the filename of a feather file.
    columns: A list of column names to read (optional). It is forwarded
      to `FeatherIOTensor`; when omitted, no column selection is passed.
    name: A name prefix for the IOTensor (optional).

  Returns:
    A `IOTensor`.

  """
  # The default scope name used to be the copy-pasted "IOFromJSON".
  with tf.name_scope(kwargs.get("name", "IOFromFeather")):
    return feather_io_tensor_ops.FeatherIOTensor(
        filename,
        columns=kwargs.get("columns", None),
        internal=True)

0 commit comments

Comments
 (0)