Skip to content

Commit 1c2ceba

Browse files
authored
Add list_feather_columns function in eager mode (tensorflow#404)
* Add list_feather_columns function in eager mode. This PR adds a `list_feather_columns` function in eager mode, so that it is possible to get the column name and spec information for the feather format. It implements an `::arrow::io::RandomAccessFile` interface, making it possible to read files through scheme file systems (e.g., s3, gcs, azfs). The `::arrow::io::RandomAccessFile` is the same as in Parquet PR 384, so the two could be combined. Signed-off-by: Yong Tang <[email protected]>
* Use flatbuffers to read the feather metadata, to avoid reading the whole file through the feather API. Signed-off-by: Yong Tang <[email protected]>
* Keep unsupported datatypes, so that they can still be processed in Python (based on review comment). Signed-off-by: Yong Tang <[email protected]>
* Combine .so files into one place to reduce the whl package size. Signed-off-by: Yong Tang <[email protected]>
* Combine ArrowRandomAccessFile and ParquetRandomAccessFile, as they are the same. Signed-off-by: Yong Tang <[email protected]>
1 parent f05b3f0 commit 1c2ceba

File tree

10 files changed

+317
-68
lines changed

10 files changed

+317
-68
lines changed

tensorflow_io/arrow/BUILD

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,21 +7,22 @@ load(
77
"tf_io_copts",
88
)
99

10-
cc_binary(
11-
name = "python/ops/_arrow_ops.so",
10+
cc_library(
11+
name = "arrow_ops",
1212
srcs = [
1313
"kernels/arrow_dataset_ops.cc",
14+
"kernels/arrow_kernels.cc",
15+
"kernels/arrow_kernels.h",
1416
"kernels/arrow_stream_client.h",
1517
"kernels/arrow_stream_client_unix.cc",
1618
"kernels/arrow_util.cc",
1719
"kernels/arrow_util.h",
1820
"ops/dataset_ops.cc",
1921
],
2022
copts = tf_io_copts(),
21-
linkshared = 1,
23+
linkstatic = True,
2224
deps = [
25+
"//tensorflow_io/core:dataset_ops",
2326
"@arrow",
24-
"@local_config_tf//:libtensorflow_framework",
25-
"@local_config_tf//:tf_header_lib",
2627
],
2728
)

tensorflow_io/arrow/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
@@ArrowDataset
1818
@@ArrowFeatherDataset
1919
@@ArrowStreamDataset
20+
@@list_feather_columns
2021
"""
2122

2223
from __future__ import absolute_import
@@ -26,13 +27,15 @@
2627
from tensorflow_io.arrow.python.ops.arrow_dataset_ops import ArrowDataset
2728
from tensorflow_io.arrow.python.ops.arrow_dataset_ops import ArrowFeatherDataset
2829
from tensorflow_io.arrow.python.ops.arrow_dataset_ops import ArrowStreamDataset
30+
from tensorflow_io.arrow.python.ops.arrow_dataset_ops import list_feather_columns
2931

3032
from tensorflow.python.util.all_util import remove_undocumented
3133

3234
_allowed_symbols = [
3335
"ArrowDataset",
3436
"ArrowFeatherDataset",
3537
"ArrowStreamDataset",
38+
"list_feather_columns",
3639
]
3740

3841
remove_undocumented(__name__, allowed_exception_list=_allowed_symbols)
Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,159 @@
1+
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#include "tensorflow/core/framework/op_kernel.h"
17+
#include "tensorflow_io/arrow/kernels/arrow_kernels.h"
18+
#include "arrow/io/api.h"
19+
#include "arrow/ipc/feather.h"
20+
#include "arrow/ipc/feather_generated.h"
21+
#include "arrow/buffer.h"
22+
23+
namespace tensorflow {
24+
namespace data {
25+
namespace {
26+
27+
class ListFeatherColumnsOp : public OpKernel {
28+
public:
29+
explicit ListFeatherColumnsOp(OpKernelConstruction* context) : OpKernel(context) {
30+
env_ = context->env();
31+
}
32+
33+
void Compute(OpKernelContext* context) override {
34+
const Tensor& filename_tensor = context->input(0);
35+
const string filename = filename_tensor.scalar<string>()();
36+
37+
const Tensor& memory_tensor = context->input(1);
38+
const string& memory = memory_tensor.scalar<string>()();
39+
std::unique_ptr<SizedRandomAccessFile> file(new SizedRandomAccessFile(env_, filename, memory.data(), memory.size()));
40+
uint64 size;
41+
OP_REQUIRES_OK(context, file->GetFileSize(&size));
42+
43+
// FEA1.....[metadata][uint32 metadata_length]FEA1
44+
static constexpr const char* kFeatherMagicBytes = "FEA1";
45+
46+
size_t header_length = strlen(kFeatherMagicBytes);
47+
size_t footer_length = sizeof(uint32) + strlen(kFeatherMagicBytes);
48+
49+
string buffer;
50+
buffer.resize(header_length > footer_length ? header_length : footer_length);
51+
52+
StringPiece result;
53+
54+
OP_REQUIRES_OK(context, file->Read(0, header_length, &result, &buffer[0]));
55+
OP_REQUIRES(context, !memcmp(buffer.data(), kFeatherMagicBytes, header_length), errors::InvalidArgument("not a feather file"));
56+
57+
OP_REQUIRES_OK(context, file->Read(size - footer_length, footer_length, &result, &buffer[0]));
58+
OP_REQUIRES(context, !memcmp(buffer.data() + sizeof(uint32), kFeatherMagicBytes, footer_length - sizeof(uint32)), errors::InvalidArgument("incomplete feather file"));
59+
60+
uint32 metadata_length = *reinterpret_cast<const uint32*>(buffer.data());
61+
62+
buffer.resize(metadata_length);
63+
64+
OP_REQUIRES_OK(context, file->Read(size - footer_length - metadata_length, metadata_length, &result, &buffer[0]));
65+
66+
const ::arrow::ipc::feather::fbs::CTable* table = ::arrow::ipc::feather::fbs::GetCTable(buffer.data());
67+
68+
OP_REQUIRES(context, (table->version() >= ::arrow::ipc::feather::kFeatherVersion), errors::InvalidArgument("feather file is old: ", table->version(), " vs. ", ::arrow::ipc::feather::kFeatherVersion));
69+
70+
std::vector<string> columns;
71+
std::vector<string> dtypes;
72+
std::vector<int64> counts;
73+
columns.reserve(table->columns()->size());
74+
dtypes.reserve(table->columns()->size());
75+
counts.reserve(table->columns()->size());
76+
77+
for (int64 i = 0; i < table->columns()->size(); i++) {
78+
DataType dtype = ::tensorflow::DataType::DT_INVALID;
79+
switch (table->columns()->Get(i)->values()->type()) {
80+
case ::arrow::ipc::feather::fbs::Type_BOOL:
81+
dtype = ::tensorflow::DataType::DT_BOOL;
82+
break;
83+
case ::arrow::ipc::feather::fbs::Type_INT8:
84+
dtype = ::tensorflow::DataType::DT_INT8;
85+
break;
86+
case ::arrow::ipc::feather::fbs::Type_INT16:
87+
dtype = ::tensorflow::DataType::DT_INT16;
88+
break;
89+
case ::arrow::ipc::feather::fbs::Type_INT32:
90+
dtype = ::tensorflow::DataType::DT_INT32;
91+
break;
92+
case ::arrow::ipc::feather::fbs::Type_INT64:
93+
dtype = ::tensorflow::DataType::DT_INT64;
94+
break;
95+
case ::arrow::ipc::feather::fbs::Type_UINT8:
96+
dtype = ::tensorflow::DataType::DT_UINT8;
97+
break;
98+
case ::arrow::ipc::feather::fbs::Type_UINT16:
99+
dtype = ::tensorflow::DataType::DT_UINT16;
100+
break;
101+
case ::arrow::ipc::feather::fbs::Type_UINT32:
102+
dtype = ::tensorflow::DataType::DT_UINT32;
103+
break;
104+
case ::arrow::ipc::feather::fbs::Type_UINT64:
105+
dtype = ::tensorflow::DataType::DT_UINT64;
106+
break;
107+
case ::arrow::ipc::feather::fbs::Type_FLOAT:
108+
dtype = ::tensorflow::DataType::DT_FLOAT;
109+
break;
110+
case ::arrow::ipc::feather::fbs::Type_DOUBLE:
111+
dtype = ::tensorflow::DataType::DT_DOUBLE;
112+
break;
113+
case ::arrow::ipc::feather::fbs::Type_UTF8:
114+
case ::arrow::ipc::feather::fbs::Type_BINARY:
115+
case ::arrow::ipc::feather::fbs::Type_CATEGORY:
116+
case ::arrow::ipc::feather::fbs::Type_TIMESTAMP:
117+
case ::arrow::ipc::feather::fbs::Type_DATE:
118+
case ::arrow::ipc::feather::fbs::Type_TIME:
119+
// case ::arrow::ipc::feather::fbs::Type_LARGE_UTF8:
120+
// case ::arrow::ipc::feather::fbs::Type_LARGE_BINARY:
121+
default:
122+
break;
123+
}
124+
columns.push_back(table->columns()->Get(i)->name()->str());
125+
dtypes.push_back(::tensorflow::DataTypeString(dtype));
126+
counts.push_back(table->num_rows());
127+
}
128+
129+
TensorShape output_shape = filename_tensor.shape();
130+
output_shape.AddDim(columns.size());
131+
132+
Tensor* columns_tensor;
133+
OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &columns_tensor));
134+
Tensor* dtypes_tensor;
135+
OP_REQUIRES_OK(context, context->allocate_output(1, output_shape, &dtypes_tensor));
136+
137+
output_shape.AddDim(1);
138+
139+
Tensor* shapes_tensor;
140+
OP_REQUIRES_OK(context, context->allocate_output(2, output_shape, &shapes_tensor));
141+
142+
for (size_t i = 0; i < columns.size(); i++) {
143+
columns_tensor->flat<string>()(i) = columns[i];
144+
dtypes_tensor->flat<string>()(i) = dtypes[i];
145+
shapes_tensor->flat<int64>()(i) = counts[i];
146+
}
147+
}
148+
private:
149+
mutex mu_;
150+
Env* env_ GUARDED_BY(mu_);
151+
};
152+
153+
REGISTER_KERNEL_BUILDER(Name("ListFeatherColumns").Device(DEVICE_CPU),
154+
ListFeatherColumnsOp);
155+
156+
157+
} // namespace
158+
} // namespace data
159+
} // namespace tensorflow
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#include "kernels/stream.h"
17+
#include "arrow/io/api.h"
18+
#include "arrow/buffer.h"
19+
20+
namespace tensorflow {
21+
namespace data {
22+
23+
// NOTE: Both SizedRandomAccessFile and ArrowRandomAccessFile overlap
24+
// with another PR. Will remove duplicate once PR merged
25+
26+
class ArrowRandomAccessFile : public ::arrow::io::RandomAccessFile {
27+
public:
28+
explicit ArrowRandomAccessFile(tensorflow::RandomAccessFile *file, int64 size)
29+
: file_(file)
30+
, size_(size) { }
31+
32+
~ArrowRandomAccessFile() {}
33+
arrow::Status Close() override {
34+
return arrow::Status::OK();
35+
}
36+
arrow::Status Tell(int64_t* position) const override {
37+
return arrow::Status::NotImplemented("Tell");
38+
}
39+
arrow::Status Seek(int64_t position) override {
40+
return arrow::Status::NotImplemented("Seek");
41+
}
42+
arrow::Status Read(int64_t nbytes, int64_t* bytes_read, void* out) override {
43+
return arrow::Status::NotImplemented("Read (void*)");
44+
}
45+
arrow::Status Read(int64_t nbytes, std::shared_ptr<arrow::Buffer>* out) override {
46+
return arrow::Status::NotImplemented("Read (Buffer*)");
47+
}
48+
arrow::Status GetSize(int64_t* size) override {
49+
*size = size_;
50+
return arrow::Status::OK();
51+
}
52+
bool supports_zero_copy() const override {
53+
return false;
54+
}
55+
arrow::Status ReadAt(int64_t position, int64_t nbytes, int64_t* bytes_read, void* out) override {
56+
StringPiece result;
57+
Status status = file_->Read(position, nbytes, &result, (char*)out);
58+
if (!(status.ok() || errors::IsOutOfRange(status))) {
59+
return arrow::Status::IOError(status.error_message());
60+
}
61+
*bytes_read = result.size();
62+
return arrow::Status::OK();
63+
}
64+
arrow::Status ReadAt(int64_t position, int64_t nbytes, std::shared_ptr<arrow::Buffer>* out) override {
65+
string buffer;
66+
buffer.resize(nbytes);
67+
StringPiece result;
68+
Status status = file_->Read(position, nbytes, &result, (char*)(&buffer[0]));
69+
if (!(status.ok() || errors::IsOutOfRange(status))) {
70+
return arrow::Status::IOError(status.error_message());
71+
}
72+
buffer.resize(result.size());
73+
return arrow::Buffer::FromString(buffer, out);
74+
}
75+
private:
76+
tensorflow::RandomAccessFile* file_;
77+
int64 size_;
78+
};
79+
} // namespace data
80+
} // namespace tensorflow

tensorflow_io/arrow/ops/dataset_ops.cc

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,4 +67,18 @@ Creates a dataset that connects to a host serving Arrow RecordBatches in stream
6767
endpoints: One or more host addresses that are serving an Arrow stream.
6868
)doc");
6969

70+
71+
// Op that lists the columns of a feather file. Documented with .Doc() for
// consistency with the other ops registered in this file.
REGISTER_OP("ListFeatherColumns")
    .Input("filename: string")
    .Input("memory: string")
    .Output("columns: string")
    .Output("dtypes: string")
    .Output("shapes: int64")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
       // Number of columns is only known at runtime, so every output
       // dimension is unknown at graph-construction time.
       c->set_output(0, c->MakeShape({c->UnknownDim()}));
       c->set_output(1, c->MakeShape({c->UnknownDim()}));
       c->set_output(2, c->MakeShape({c->UnknownDim(), c->UnknownDim()}));
       return Status::OK();
    })
    .Doc(R"doc(
Lists the columns of a feather file.

filename: Path of the feather file to inspect.
memory: Optional in-memory content of the file.
columns: Names of the columns in the file.
dtypes: TensorFlow dtype string for each column.
shapes: Shape (row count) for each column.
)doc");
83+
7084
} // namespace tensorflow

tensorflow_io/arrow/python/ops/arrow_dataset_ops.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,7 @@
3030
from tensorflow.compat.v2 import data
3131
from tensorflow.python.data.ops.dataset_ops import flat_structure
3232
from tensorflow.python.data.util import structure as structure_lib
33-
from tensorflow_io import _load_library
34-
arrow_ops = _load_library('_arrow_ops.so')
33+
from tensorflow_io.core.python.ops import core_ops
3534

3635
if hasattr(tf, "nest"):
3736
from tensorflow import nest # pylint: disable=ungrouped-imports
@@ -183,7 +182,7 @@ def __init__(self,
183182
"auto" (size to number of records in Arrow record batch)
184183
"""
185184
super(ArrowDataset, self).__init__(
186-
partial(arrow_ops.arrow_dataset, serialized_batches),
185+
partial(core_ops.arrow_dataset, serialized_batches),
187186
columns,
188187
output_types,
189188
output_shapes,
@@ -316,7 +315,7 @@ def __init__(self,
316315
dtype=dtypes.string,
317316
name="filenames")
318317
super(ArrowFeatherDataset, self).__init__(
319-
partial(arrow_ops.arrow_feather_dataset, filenames),
318+
partial(core_ops.arrow_feather_dataset, filenames),
320319
columns,
321320
output_types,
322321
output_shapes,
@@ -401,7 +400,7 @@ def __init__(self,
401400
dtype=dtypes.string,
402401
name="endpoints")
403402
super(ArrowStreamDataset, self).__init__(
404-
partial(arrow_ops.arrow_stream_dataset, endpoints),
403+
partial(core_ops.arrow_stream_dataset, endpoints),
405404
columns,
406405
output_types,
407406
output_shapes,
@@ -594,3 +593,15 @@ def gen_record_batches():
594593
batch_size=batch_size,
595594
batch_mode='keep_remainder',
596595
record_batch_iter_factory=gen_record_batches)
596+
597+
def list_feather_columns(filename, **kwargs):
  """Lists column names and specs of a feather file.

  Args:
    filename: A string, the path of the feather file to inspect.
    **kwargs: Optional keyword arguments. Supports `memory`, an in-memory
      content of the file (defaults to `""`, meaning read from `filename`).

  Returns:
    A dict mapping each column name to a `tf.TensorSpec` carrying the
    column's shape (row count), dtype, and name.

  Raises:
    NotImplementedError: If not running in eager mode.
  """
  if not tf.executing_eagerly():
    raise NotImplementedError("list_feather_columns only support eager mode")
  memory = kwargs.get("memory", "")
  columns, dtypes_, shapes = core_ops.list_feather_columns(
      filename, memory=memory)
  entries = zip(tf.unstack(columns), tf.unstack(dtypes_), tf.unstack(shapes))
  # Eager tensors are converted to Python values: bytes -> str for names and
  # dtypes, ndarray -> shape for the row count.
  return {
      column.numpy().decode(): tf.TensorSpec(
          shape.numpy(), dtype.numpy().decode(), column.numpy().decode())
      for (column, dtype, shape) in entries
  }

tensorflow_io/core/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,7 @@ cc_binary(
125125
linkshared = 1,
126126
deps = [
127127
":core_ops",
128+
"//tensorflow_io/arrow:arrow_ops",
128129
"//tensorflow_io/audio:audio_ops",
129130
"//tensorflow_io/avro:avro_ops",
130131
"//tensorflow_io/azure:azfs_ops",

tensorflow_io/parquet/BUILD

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ cc_library(
1616
copts = tf_io_copts(),
1717
linkstatic = True,
1818
deps = [
19+
"//tensorflow_io/arrow:arrow_ops",
1920
"//tensorflow_io/core:dataset_ops",
20-
"@arrow",
2121
],
2222
)

0 commit comments

Comments
 (0)