Skip to content

Commit 497290a

Browse files
committed
Add tfio.IOTensor.from_feather support
Note: this PR depends on PR 438. Feather is a columnar file format that is often used with pandas. This PR adds indexing and slicing support to bring Feather to parity with the Parquet file format, by adding tfio.IOTensor.from_feather support so that it is possible to access Feather files through natural `__getitem__` operations. Signed-off-by: Yong Tang <[email protected]>
1 parent 6844b4f commit 497290a

File tree

5 files changed

+381
-0
lines changed

5 files changed

+381
-0
lines changed

tensorflow_io/arrow/kernels/arrow_kernels.cc

Lines changed: 220 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,13 @@ limitations under the License.
1515

1616
#include "tensorflow/core/framework/op_kernel.h"
1717
#include "tensorflow_io/arrow/kernels/arrow_kernels.h"
18+
#include "tensorflow_io/core/kernels/io_interface.h"
1819
#include "arrow/io/api.h"
1920
#include "arrow/ipc/feather.h"
2021
#include "arrow/ipc/feather_generated.h"
2122
#include "arrow/buffer.h"
2223
#include "arrow/adapters/tensorflow/convert.h"
24+
#include "arrow/table.h"
2325

2426
namespace tensorflow {
2527
namespace data {
@@ -173,5 +175,223 @@ REGISTER_KERNEL_BUILDER(Name("ListFeatherColumns").Device(DEVICE_CPU),
173175

174176

175177
} // namespace
178+
179+
180+
class FeatherIndexable : public IOIndexableInterface {
181+
public:
182+
FeatherIndexable(Env* env)
183+
: env_(env) {}
184+
185+
~FeatherIndexable() {}
186+
Status Init(const std::vector<string>& input, const std::vector<string>& metadata, const void* memory_data, const int64 memory_size) override {
187+
if (input.size() > 1) {
188+
return errors::InvalidArgument("more than 1 filename is not supported");
189+
}
190+
191+
const string& filename = input[0];
192+
file_.reset(new SizedRandomAccessFile(env_, filename, memory_data, memory_size));
193+
TF_RETURN_IF_ERROR(file_->GetFileSize(&file_size_));
194+
195+
// FEA1.....[metadata][uint32 metadata_length]FEA1
196+
static constexpr const char* kFeatherMagicBytes = "FEA1";
197+
198+
size_t header_length = strlen(kFeatherMagicBytes);
199+
size_t footer_length = sizeof(uint32) + strlen(kFeatherMagicBytes);
200+
201+
string buffer;
202+
buffer.resize(header_length > footer_length ? header_length : footer_length);
203+
204+
StringPiece result;
205+
206+
TF_RETURN_IF_ERROR(file_->Read(0, header_length, &result, &buffer[0]));
207+
if (memcmp(buffer.data(), kFeatherMagicBytes, header_length) != 0) {
208+
return errors::InvalidArgument("not a feather file");
209+
}
210+
211+
TF_RETURN_IF_ERROR(file_->Read(file_size_ - footer_length, footer_length, &result, &buffer[0]));
212+
if (memcmp(buffer.data() + sizeof(uint32), kFeatherMagicBytes, footer_length - sizeof(uint32)) != 0) {
213+
return errors::InvalidArgument("incomplete feather file");
214+
}
215+
216+
uint32 metadata_length = *reinterpret_cast<const uint32*>(buffer.data());
217+
218+
buffer.resize(metadata_length);
219+
220+
TF_RETURN_IF_ERROR(file_->Read(file_size_ - footer_length - metadata_length, metadata_length, &result, &buffer[0]));
221+
222+
const ::arrow::ipc::feather::fbs::CTable* table = ::arrow::ipc::feather::fbs::GetCTable(buffer.data());
223+
224+
if (table->version() < ::arrow::ipc::feather::kFeatherVersion) {
225+
return errors::InvalidArgument("feather file is old: ", table->version(), " vs. ", ::arrow::ipc::feather::kFeatherVersion);
226+
}
227+
228+
for (int i = 0; i < table->columns()->size(); i++) {
229+
::tensorflow::DataType dtype = ::tensorflow::DataType::DT_INVALID;
230+
switch (table->columns()->Get(i)->values()->type()) {
231+
case ::arrow::ipc::feather::fbs::Type_BOOL:
232+
dtype = ::tensorflow::DataType::DT_BOOL;
233+
break;
234+
case ::arrow::ipc::feather::fbs::Type_INT8:
235+
dtype = ::tensorflow::DataType::DT_INT8;
236+
break;
237+
case ::arrow::ipc::feather::fbs::Type_INT16:
238+
dtype = ::tensorflow::DataType::DT_INT16;
239+
break;
240+
case ::arrow::ipc::feather::fbs::Type_INT32:
241+
dtype = ::tensorflow::DataType::DT_INT32;
242+
break;
243+
case ::arrow::ipc::feather::fbs::Type_INT64:
244+
dtype = ::tensorflow::DataType::DT_INT64;
245+
break;
246+
case ::arrow::ipc::feather::fbs::Type_UINT8:
247+
dtype = ::tensorflow::DataType::DT_UINT8;
248+
break;
249+
case ::arrow::ipc::feather::fbs::Type_UINT16:
250+
dtype = ::tensorflow::DataType::DT_UINT16;
251+
break;
252+
case ::arrow::ipc::feather::fbs::Type_UINT32:
253+
dtype = ::tensorflow::DataType::DT_UINT32;
254+
break;
255+
case ::arrow::ipc::feather::fbs::Type_UINT64:
256+
dtype = ::tensorflow::DataType::DT_UINT64;
257+
break;
258+
case ::arrow::ipc::feather::fbs::Type_FLOAT:
259+
dtype = ::tensorflow::DataType::DT_FLOAT;
260+
break;
261+
case ::arrow::ipc::feather::fbs::Type_DOUBLE:
262+
dtype = ::tensorflow::DataType::DT_DOUBLE;
263+
break;
264+
case ::arrow::ipc::feather::fbs::Type_UTF8:
265+
case ::arrow::ipc::feather::fbs::Type_BINARY:
266+
case ::arrow::ipc::feather::fbs::Type_CATEGORY:
267+
case ::arrow::ipc::feather::fbs::Type_TIMESTAMP:
268+
case ::arrow::ipc::feather::fbs::Type_DATE:
269+
case ::arrow::ipc::feather::fbs::Type_TIME:
270+
// case ::arrow::ipc::feather::fbs::Type_LARGE_UTF8:
271+
// case ::arrow::ipc::feather::fbs::Type_LARGE_BINARY:
272+
default:
273+
break;
274+
}
275+
shapes_.push_back(TensorShape({static_cast<int64>(table->num_rows())}));
276+
dtypes_.push_back(dtype);
277+
columns_.push_back(table->columns()->Get(i)->name()->str());
278+
}
279+
280+
return Status::OK();
281+
}
282+
Status Spec(std::vector<PartialTensorShape>& shapes, std::vector<DataType>& dtypes) override {
283+
shapes.clear();
284+
for (size_t i = 0; i < shapes_.size(); i++) {
285+
shapes.push_back(shapes_[i]);
286+
}
287+
dtypes.clear();
288+
for (size_t i = 0; i < dtypes_.size(); i++) {
289+
dtypes.push_back(dtypes_[i]);
290+
}
291+
return Status::OK();
292+
}
293+
294+
Status Extra(std::vector<Tensor>* extra) override {
295+
// Expose columns
296+
Tensor columns(DT_STRING, TensorShape({static_cast<int64>(columns_.size())}));
297+
for (size_t i = 0; i < columns_.size(); i++) {
298+
columns.flat<string>()(i) = columns_[i];
299+
}
300+
extra->push_back(columns);
301+
return Status::OK();
302+
}
303+
304+
Status GetItem(const int64 start, const int64 stop, const int64 step, const int64 component, Tensor* tensor) override {
305+
if (step != 1) {
306+
return errors::InvalidArgument("step ", step, " is not supported");
307+
}
308+
309+
if (feather_file_.get() == nullptr) {
310+
feather_file_.reset(new ArrowRandomAccessFile(file_.get(), file_size_));
311+
arrow::Status s = arrow::ipc::feather::TableReader::Open(feather_file_, &reader_);
312+
if (!s.ok()) {
313+
return errors::Internal(s.ToString());
314+
}
315+
}
316+
317+
std::shared_ptr<arrow::Column> column;
318+
arrow::Status s = reader_->GetColumn(component, &column);
319+
if (!s.ok()) {
320+
return errors::Internal(s.ToString());
321+
}
322+
323+
std::shared_ptr<::arrow::Column> slice = column->Slice(start, stop);
324+
325+
#define FEATHER_PROCESS_TYPE(TTYPE,ATYPE) { \
326+
int64 curr_index = 0; \
327+
for (auto chunk : slice->data()->chunks()) { \
328+
for (int64_t item = 0; item < chunk->length(); item++) { \
329+
tensor->flat<TTYPE>()(curr_index) = (dynamic_cast<ATYPE *>(chunk.get()))->Value(item); \
330+
curr_index++; \
331+
} \
332+
} \
333+
}
334+
switch (tensor->dtype()) {
335+
case DT_BOOL:
336+
FEATHER_PROCESS_TYPE(bool, ::arrow::BooleanArray);
337+
break;
338+
case DT_INT8:
339+
FEATHER_PROCESS_TYPE(int8, ::arrow::NumericArray<::arrow::Int8Type>);
340+
break;
341+
case DT_UINT8:
342+
FEATHER_PROCESS_TYPE(uint8, ::arrow::NumericArray<::arrow::UInt8Type>);
343+
break;
344+
case DT_INT16:
345+
FEATHER_PROCESS_TYPE(int16, ::arrow::NumericArray<::arrow::Int16Type>);
346+
break;
347+
case DT_UINT16:
348+
FEATHER_PROCESS_TYPE(uint16, ::arrow::NumericArray<::arrow::UInt16Type>);
349+
break;
350+
case DT_INT32:
351+
FEATHER_PROCESS_TYPE(int32, ::arrow::NumericArray<::arrow::Int32Type>);
352+
break;
353+
case DT_UINT32:
354+
FEATHER_PROCESS_TYPE(uint32, ::arrow::NumericArray<::arrow::UInt32Type>);
355+
break;
356+
case DT_INT64:
357+
FEATHER_PROCESS_TYPE(int64, ::arrow::NumericArray<::arrow::Int64Type>);
358+
break;
359+
case DT_UINT64:
360+
FEATHER_PROCESS_TYPE(uint64, ::arrow::NumericArray<::arrow::UInt64Type>);
361+
break;
362+
case DT_FLOAT:
363+
FEATHER_PROCESS_TYPE(float, ::arrow::NumericArray<::arrow::FloatType>);
364+
break;
365+
case DT_DOUBLE:
366+
FEATHER_PROCESS_TYPE(double, ::arrow::NumericArray<::arrow::DoubleType>);
367+
break;
368+
default:
369+
return errors::InvalidArgument("data type is not supported: ", DataTypeString(tensor->dtype()));
370+
}
371+
372+
return Status::OK();
373+
}
374+
375+
string DebugString() const override {
376+
mutex_lock l(mu_);
377+
return strings::StrCat("FeatherIndexable");
378+
}
379+
private:
380+
mutable mutex mu_;
381+
Env* env_ GUARDED_BY(mu_);
382+
std::unique_ptr<SizedRandomAccessFile> file_ GUARDED_BY(mu_);
383+
uint64 file_size_ GUARDED_BY(mu_);
384+
std::shared_ptr<ArrowRandomAccessFile> feather_file_ GUARDED_BY(mu_);
385+
std::unique_ptr<arrow::ipc::feather::TableReader> reader_ GUARDED_BY(mu_);
386+
387+
std::vector<DataType> dtypes_;
388+
std::vector<TensorShape> shapes_;
389+
std::vector<string> columns_;
390+
};
391+
392+
// CPU kernels backing the FeatherIndexableInit / FeatherIndexableGetItem ops,
// both instantiated over FeatherIndexable.
REGISTER_KERNEL_BUILDER(
    Name("FeatherIndexableInit").Device(DEVICE_CPU),
    IOInterfaceInitOp<FeatherIndexable>);
REGISTER_KERNEL_BUILDER(
    Name("FeatherIndexableGetItem").Device(DEVICE_CPU),
    IOIndexableGetItemOp<FeatherIndexable>);
176396
} // namespace data
177397
} // namespace tensorflow

tensorflow_io/arrow/ops/dataset_ops.cc

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,4 +100,38 @@ REGISTER_OP("ListFeatherColumns")
100100
return Status::OK();
101101
});
102102

103+
// Creates a Feather indexable resource from a filename.
//
// Outputs:
//   output:  scalar resource handle.
//   shapes:  rank-2 int64 tensor, one row of dimensions per column.
//   dtypes:  rank-1 int64 tensor, one dtype enum per column.
//   columns: rank-1 string tensor, one name per column.
REGISTER_OP("FeatherIndexableInit")
  .Input("input: string")
  .Output("output: resource")
  .Output("shapes: int64")
  .Output("dtypes: int64")
  .Output("columns: string")
  .Attr("container: string = ''")
  .Attr("shared_name: string = ''")
  .SetIsStateful()
  .SetShapeFn([](shape_inference::InferenceContext* c) {
    c->set_output(0, c->Scalar());
    // `shapes` is the rank-2 output and `dtypes` the rank-1 output; the
    // static shapes must follow the declared output order above.
    c->set_output(1, c->MakeShape({c->UnknownDim(), c->UnknownDim()}));
    c->set_output(2, c->MakeShape({c->UnknownDim()}));
    c->set_output(3, c->MakeShape({c->UnknownDim()}));
    return Status::OK();
   });
119+
120+
// Reads a [start, stop) range with the given step from one component of a
// Feather indexable resource. The static output shape is taken verbatim from
// the `shape` attr; the element type comes from the `dtype` attr.
REGISTER_OP("FeatherIndexableGetItem")
  .Input("input: resource")
  .Input("start: int64")
  .Input("stop: int64")
  .Input("step: int64")
  .Input("component: int64")
  .Output("output: dtype")
  .Attr("shape: shape")
  .Attr("dtype: type")
  .SetShapeFn([](shape_inference::InferenceContext* c) {
    PartialTensorShape declared_shape;
    TF_RETURN_IF_ERROR(c->GetAttr("shape", &declared_shape));

    shape_inference::ShapeHandle output_shape;
    TF_RETURN_IF_ERROR(
        c->MakeShapeFromPartialTensorShape(declared_shape, &output_shape));

    c->set_output(0, output_shape);
    return Status::OK();
   });
103137
} // namespace tensorflow
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
"""FeatherIOTensor"""
16+
from __future__ import absolute_import
17+
from __future__ import division
18+
from __future__ import print_function
19+
20+
import uuid
21+
22+
import tensorflow as tf
23+
from tensorflow_io.core.python.ops import io_tensor_ops
24+
from tensorflow_io.core.python.ops import core_ops
25+
26+
class FeatherIOTensor(io_tensor_ops._TableIOTensor): # pylint: disable=protected-access
  """Table-style `IOTensor` backed by a Feather file.

  Each column of the file becomes one component, described by a
  `tf.TensorSpec` built from the shape, dtype and column name reported by
  the `feather_indexable_init` op.
  """

  #=============================================================================
  # Constructor (private)
  #=============================================================================
  def __init__(self,
               filename,
               internal=False):
    with tf.name_scope("FeatherIOTensor") as scope:
      # A random suffix keeps the resource name unique per instance.
      unique_name = "%s/%s" % (filename, uuid.uuid4().hex)
      resource, shapes, dtypes, columns = core_ops.feather_indexable_init(
          filename,
          container=scope,
          shared_name=unique_name)

      # Dimensions equal to 0 are dropped and negative dimensions become
      # unknown (None) in the resulting TensorShape.
      shape_list = []
      for row in tf.unstack(shapes):
        dims = []
        for dim in row.numpy():
          if dim == 0:
            continue
          dims.append(None if dim < 0 else dim)
        shape_list.append(tf.TensorShape(dims))

      dtype_list = []
      for value in tf.unstack(dtypes):
        dtype_list.append(tf.as_dtype(value.numpy()))

      column_names = []
      for name in tf.unstack(columns):
        column_names.append(name.numpy().decode())

      spec = tuple(
          tf.TensorSpec(shape, dtype, column)
          for shape, dtype, column in zip(shape_list, dtype_list, column_names))

      super(FeatherIOTensor, self).__init__(
          spec, column_names,
          resource, core_ops.feather_indexable_get_item,
          internal=internal)

tensorflow_io/core/python/ops/io_tensor.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from tensorflow_io.core.python.ops import json_io_tensor_ops
2424
from tensorflow_io.core.python.ops import kafka_io_tensor_ops
2525
from tensorflow_io.core.python.ops import prometheus_io_tensor_ops
26+
from tensorflow_io.core.python.ops import feather_io_tensor_ops
2627

2728
class IOTensor(io_tensor_ops._IOTensor): # pylint: disable=protected-access
2829
"""IOTensor
@@ -287,3 +288,20 @@ def from_prometheus(cls,
287288
with tf.name_scope(kwargs.get("name", "IOFromPrometheus")):
288289
return prometheus_io_tensor_ops.PrometheusIOTensor(
289290
query, endpoint=kwargs.get("endpoint", None), internal=True)
291+
292+
@classmethod
def from_feather(cls,
                 filename,
                 **kwargs):
  """Creates an `IOTensor` from a Feather file.

  Args:
    filename: A string, the filename of a Feather file.
    name: A name prefix for the IOTensor (optional); defaults to
      "IOFromFeather" when not given.

  Returns:
    A `IOTensor`.

  """
  scope_name = kwargs.get("name", "IOFromFeather")
  with tf.name_scope(scope_name):
    return feather_io_tensor_ops.FeatherIOTensor(filename, internal=True)

0 commit comments

Comments
 (0)