Skip to content

Commit c34aa02

Browse files
committed
Use FlatBuffers to read the Feather metadata directly, to avoid reading the whole file through the Feather API.
Signed-off-by: Yong Tang <[email protected]>
1 parent 8aaaf3f commit c34aa02

File tree

1 file changed

+63
-51
lines changed

1 file changed

+63
-51
lines changed

tensorflow_io/arrow/kernels/arrow_kernels.cc

Lines changed: 63 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -16,15 +16,9 @@ limitations under the License.
1616
#include "tensorflow/core/framework/op_kernel.h"
1717
#include "arrow/io/api.h"
1818
#include "arrow/ipc/feather.h"
19-
#include "arrow/table.h"
19+
#include "arrow/ipc/feather_generated.h"
20+
#include "arrow/buffer.h"
2021

21-
namespace arrow {
22-
namespace adapters {
23-
namespace tensorflow {
24-
Status GetTensorFlowType(std::shared_ptr<DataType> dtype, ::tensorflow::DataType* out);
25-
}
26-
}
27-
}
2822
namespace tensorflow {
2923
namespace data {
3024
namespace {
@@ -150,75 +144,93 @@ class ListFeatherColumnsOp : public OpKernel {
150144
uint64 size;
151145
OP_REQUIRES_OK(context, file->GetFileSize(&size));
152146

153-
std::shared_ptr<ArrowRandomAccessFile> feather_file(new ArrowRandomAccessFile(file.get(), size));
147+
// FEA1.....[metadata][uint32 metadata_length]FEA1
148+
static constexpr const char* kFeatherMagicBytes = "FEA1";
149+
150+
size_t header_length = strlen(kFeatherMagicBytes);
151+
size_t footer_length = sizeof(uint32) + strlen(kFeatherMagicBytes);
152+
153+
string buffer;
154+
buffer.resize(header_length > footer_length ? header_length : footer_length);
155+
156+
StringPiece result;
157+
158+
OP_REQUIRES_OK(context, file->Read(0, header_length, &result, &buffer[0]));
159+
OP_REQUIRES(context, !memcmp(buffer.data(), kFeatherMagicBytes, header_length), errors::InvalidArgument("not a feather file"));
160+
161+
OP_REQUIRES_OK(context, file->Read(size - footer_length, footer_length, &result, &buffer[0]));
162+
OP_REQUIRES(context, !memcmp(buffer.data() + sizeof(uint32), kFeatherMagicBytes, footer_length - sizeof(uint32)), errors::InvalidArgument("incomplete feather file"));
154163

155-
std::unique_ptr<arrow::ipc::feather::TableReader> reader;
156-
arrow::Status s = arrow::ipc::feather::TableReader::Open(feather_file, &reader);
157-
OP_REQUIRES(context, s.ok(), errors::Internal(s.ToString()));
164+
uint32 metadata_length = *reinterpret_cast<const uint32*>(buffer.data());
165+
166+
buffer.resize(metadata_length);
167+
168+
OP_REQUIRES_OK(context, file->Read(size - footer_length - metadata_length, metadata_length, &result, &buffer[0]));
169+
170+
const ::arrow::ipc::feather::fbs::CTable* table = ::arrow::ipc::feather::fbs::GetCTable(buffer.data());
171+
172+
OP_REQUIRES(context, (table->version() >= ::arrow::ipc::feather::kFeatherVersion), errors::InvalidArgument("feather file is old: ", table->version(), " vs. ", ::arrow::ipc::feather::kFeatherVersion));
158173

159174
std::vector<string> columns;
160175
std::vector<string> dtypes;
161176
std::vector<int64> counts;
162-
columns.reserve(reader->num_columns());
163-
dtypes.reserve(reader->num_columns());
164-
counts.reserve(reader->num_columns());
165-
for (int i = 0; i < reader->num_columns(); i++) {
166-
std::shared_ptr<arrow::Column> column;
167-
s = reader->GetColumn(i, &column);
168-
OP_REQUIRES(context, s.ok(), errors::Internal(s.ToString()));
169-
170-
::tensorflow::DataType data_type;
171-
s = ::arrow::adapters::tensorflow::GetTensorFlowType(column->type(), &data_type);
172-
if (!s.ok()) {
173-
continue;
174-
}
177+
columns.reserve(table->columns()->size());
178+
dtypes.reserve(table->columns()->size());
179+
counts.reserve(table->columns()->size());
180+
181+
for (int64 i = 0; i < table->columns()->size(); i++) {
175182
string dtype = "";
176-
switch (data_type) {
177-
case ::tensorflow::DT_BOOL:
183+
switch (table->columns()->Get(i)->values()->type()) {
184+
case ::arrow::ipc::feather::fbs::Type_BOOL:
178185
dtype = "bool";
179186
break;
180-
case ::tensorflow::DT_UINT8:
181-
dtype = "uint8";
182-
break;
183-
case ::tensorflow::DT_INT8:
187+
case ::arrow::ipc::feather::fbs::Type_INT8:
184188
dtype = "int8";
185189
break;
186-
case ::tensorflow::DT_UINT16:
187-
dtype = "uint16";
188-
break;
189-
case ::tensorflow::DT_INT16:
190+
case ::arrow::ipc::feather::fbs::Type_INT16:
190191
dtype = "int16";
191192
break;
192-
case ::tensorflow::DT_UINT32:
193-
dtype = "uint32";
194-
break;
195-
case ::tensorflow::DT_INT32:
193+
case ::arrow::ipc::feather::fbs::Type_INT32:
196194
dtype = "int32";
197195
break;
198-
case ::tensorflow::DT_UINT64:
199-
dtype = "uint64";
200-
break;
201-
case ::tensorflow::DT_INT64:
196+
case ::arrow::ipc::feather::fbs::Type_INT64:
202197
dtype = "int64";
203198
break;
204-
case ::tensorflow::DT_HALF:
205-
dtype = "half";
199+
case ::arrow::ipc::feather::fbs::Type_UINT8:
200+
dtype = "uint8";
206201
break;
207-
case ::tensorflow::DT_FLOAT:
202+
case ::arrow::ipc::feather::fbs::Type_UINT16:
203+
dtype = "uint16";
204+
break;
205+
case ::arrow::ipc::feather::fbs::Type_UINT32:
206+
dtype = "uint32";
207+
break;
208+
case ::arrow::ipc::feather::fbs::Type_UINT64:
209+
dtype = "uint64";
210+
break;
211+
case ::arrow::ipc::feather::fbs::Type_FLOAT:
208212
dtype = "float";
209213
break;
210-
case ::tensorflow::DT_DOUBLE:
214+
case ::arrow::ipc::feather::fbs::Type_DOUBLE:
211215
dtype = "double";
212216
break;
217+
case ::arrow::ipc::feather::fbs::Type_UTF8:
218+
case ::arrow::ipc::feather::fbs::Type_BINARY:
219+
case ::arrow::ipc::feather::fbs::Type_CATEGORY:
220+
case ::arrow::ipc::feather::fbs::Type_TIMESTAMP:
221+
case ::arrow::ipc::feather::fbs::Type_DATE:
222+
case ::arrow::ipc::feather::fbs::Type_TIME:
223+
// case ::arrow::ipc::feather::fbs::Type_LARGE_UTF8:
224+
// case ::arrow::ipc::feather::fbs::Type_LARGE_BINARY:
213225
default:
214-
break;
226+
break;
215227
}
216228
if (dtype == "") {
217229
continue;
218230
}
219-
columns.push_back(reader->GetColumnName(i));
231+
columns.push_back(table->columns()->Get(i)->name()->str());
220232
dtypes.push_back(dtype);
221-
counts.push_back(reader->num_rows());
233+
counts.push_back(table->num_rows());
222234
}
223235

224236
TensorShape output_shape = filename_tensor.shape();
@@ -234,7 +246,7 @@ class ListFeatherColumnsOp : public OpKernel {
234246
Tensor* shapes_tensor;
235247
OP_REQUIRES_OK(context, context->allocate_output(2, output_shape, &shapes_tensor));
236248

237-
for (int i = 0; i < columns.size(); i++) {
249+
for (size_t i = 0; i < columns.size(); i++) {
238250
columns_tensor->flat<string>()(i) = columns[i];
239251
dtypes_tensor->flat<string>()(i) = dtypes[i];
240252
shapes_tensor->flat<int64>()(i) = counts[i];

0 commit comments

Comments
 (0)