@@ -16,15 +16,9 @@ limitations under the License.
1616#include " tensorflow/core/framework/op_kernel.h"
1717#include " arrow/io/api.h"
1818#include " arrow/ipc/feather.h"
19- #include " arrow/table.h"
19+ #include " arrow/ipc/feather_generated.h"
20+ #include " arrow/buffer.h"
2021
21- namespace arrow {
22- namespace adapters {
23- namespace tensorflow {
24- Status GetTensorFlowType (std::shared_ptr<DataType> dtype, ::tensorflow::DataType* out);
25- }
26- }
27- }
2822namespace tensorflow {
2923namespace data {
3024namespace {
@@ -150,75 +144,93 @@ class ListFeatherColumnsOp : public OpKernel {
150144 uint64 size;
151145 OP_REQUIRES_OK (context, file->GetFileSize (&size));
152146
153- std::shared_ptr<ArrowRandomAccessFile> feather_file (new ArrowRandomAccessFile (file.get (), size));
147+ // FEA1.....[metadata][uint32 metadata_length]FEA1
148+ static constexpr const char * kFeatherMagicBytes = " FEA1" ;
149+
150+ size_t header_length = strlen (kFeatherMagicBytes );
151+ size_t footer_length = sizeof (uint32) + strlen (kFeatherMagicBytes );
152+
153+ string buffer;
154+ buffer.resize (header_length > footer_length ? header_length : footer_length);
155+
156+ StringPiece result;
157+
158+ OP_REQUIRES_OK (context, file->Read (0 , header_length, &result, &buffer[0 ]));
159+ OP_REQUIRES (context, !memcmp (buffer.data (), kFeatherMagicBytes , header_length), errors::InvalidArgument (" not a feather file" ));
160+
161+ OP_REQUIRES_OK (context, file->Read (size - footer_length, footer_length, &result, &buffer[0 ]));
162+ OP_REQUIRES (context, !memcmp (buffer.data () + sizeof (uint32), kFeatherMagicBytes , footer_length - sizeof (uint32)), errors::InvalidArgument (" incomplete feather file" ));
154163
155- std::unique_ptr<arrow::ipc::feather::TableReader> reader;
156- arrow::Status s = arrow::ipc::feather::TableReader::Open (feather_file, &reader);
157- OP_REQUIRES (context, s.ok (), errors::Internal (s.ToString ()));
164+ uint32 metadata_length = *reinterpret_cast <const uint32*>(buffer.data ());
165+
166+ buffer.resize (metadata_length);
167+
168+ OP_REQUIRES_OK (context, file->Read (size - footer_length - metadata_length, metadata_length, &result, &buffer[0 ]));
169+
170+ const ::arrow::ipc::feather::fbs::CTable* table = ::arrow::ipc::feather::fbs::GetCTable (buffer.data ());
171+
172+ OP_REQUIRES (context, (table->version () >= ::arrow::ipc::feather::kFeatherVersion ), errors::InvalidArgument (" feather file is old: " , table->version (), " vs. " , ::arrow::ipc::feather::kFeatherVersion ));
158173
159174 std::vector<string> columns;
160175 std::vector<string> dtypes;
161176 std::vector<int64> counts;
162- columns.reserve (reader->num_columns ());
163- dtypes.reserve (reader->num_columns ());
164- counts.reserve (reader->num_columns ());
165- for (int i = 0 ; i < reader->num_columns (); i++) {
166- std::shared_ptr<arrow::Column> column;
167- s = reader->GetColumn (i, &column);
168- OP_REQUIRES (context, s.ok (), errors::Internal (s.ToString ()));
169-
170- ::tensorflow::DataType data_type;
171- s = ::arrow::adapters::tensorflow::GetTensorFlowType (column->type (), &data_type);
172- if (!s.ok ()) {
173- continue ;
174- }
177+ columns.reserve (table->columns ()->size ());
178+ dtypes.reserve (table->columns ()->size ());
179+ counts.reserve (table->columns ()->size ());
180+
181+ for (int64 i = 0 ; i < table->columns ()->size (); i++) {
175182 string dtype = " " ;
176- switch (data_type ) {
177- case ::tensorflow::DT_BOOL :
183+ switch (table-> columns ()-> Get (i)-> values ()-> type () ) {
184+ case ::arrow::ipc::feather::fbs::Type_BOOL :
178185 dtype = " bool" ;
179186 break ;
180- case ::tensorflow::DT_UINT8:
181- dtype = " uint8" ;
182- break ;
183- case ::tensorflow::DT_INT8:
187+ case ::arrow::ipc::feather::fbs::Type_INT8:
184188 dtype = " int8" ;
185189 break ;
186- case ::tensorflow::DT_UINT16:
187- dtype = " uint16" ;
188- break ;
189- case ::tensorflow::DT_INT16:
190+ case ::arrow::ipc::feather::fbs::Type_INT16:
190191 dtype = " int16" ;
191192 break ;
192- case ::tensorflow::DT_UINT32:
193- dtype = " uint32" ;
194- break ;
195- case ::tensorflow::DT_INT32:
193+ case ::arrow::ipc::feather::fbs::Type_INT32:
196194 dtype = " int32" ;
197195 break ;
198- case ::tensorflow::DT_UINT64:
199- dtype = " uint64" ;
200- break ;
201- case ::tensorflow::DT_INT64:
196+ case ::arrow::ipc::feather::fbs::Type_INT64:
202197 dtype = " int64" ;
203198 break ;
204- case ::tensorflow::DT_HALF :
205- dtype = " half " ;
199+ case ::arrow::ipc::feather::fbs::Type_UINT8 :
200+ dtype = " uint8 " ;
206201 break ;
207- case ::tensorflow::DT_FLOAT:
202+ case ::arrow::ipc::feather::fbs::Type_UINT16:
203+ dtype = " uint16" ;
204+ break ;
205+ case ::arrow::ipc::feather::fbs::Type_UINT32:
206+ dtype = " uint32" ;
207+ break ;
208+ case ::arrow::ipc::feather::fbs::Type_UINT64:
209+ dtype = " uint64" ;
210+ break ;
211+ case ::arrow::ipc::feather::fbs::Type_FLOAT:
208212 dtype = " float" ;
209213 break ;
210- case ::tensorflow::DT_DOUBLE :
214+ case ::arrow::ipc::feather::fbs::Type_DOUBLE :
211215 dtype = " double" ;
212216 break ;
217+ case ::arrow::ipc::feather::fbs::Type_UTF8:
218+ case ::arrow::ipc::feather::fbs::Type_BINARY:
219+ case ::arrow::ipc::feather::fbs::Type_CATEGORY:
220+ case ::arrow::ipc::feather::fbs::Type_TIMESTAMP:
221+ case ::arrow::ipc::feather::fbs::Type_DATE:
222+ case ::arrow::ipc::feather::fbs::Type_TIME:
223+ // case ::arrow::ipc::feather::fbs::Type_LARGE_UTF8:
224+ // case ::arrow::ipc::feather::fbs::Type_LARGE_BINARY:
213225 default :
214- break ;
226+ break ;
215227 }
216228 if (dtype == " " ) {
217229 continue ;
218230 }
219- columns.push_back (reader-> GetColumnName (i ));
231+ columns.push_back (table-> columns ()-> Get (i)-> name ()-> str ( ));
220232 dtypes.push_back (dtype);
221- counts.push_back (reader ->num_rows ());
233+ counts.push_back (table ->num_rows ());
222234 }
223235
224236 TensorShape output_shape = filename_tensor.shape ();
@@ -234,7 +246,7 @@ class ListFeatherColumnsOp : public OpKernel {
234246 Tensor* shapes_tensor;
235247 OP_REQUIRES_OK (context, context->allocate_output (2 , output_shape, &shapes_tensor));
236248
237- for (int i = 0 ; i < columns.size (); i++) {
249+ for (size_t i = 0 ; i < columns.size (); i++) {
238250 columns_tensor->flat <string>()(i) = columns[i];
239251 dtypes_tensor->flat <string>()(i) = dtypes[i];
240252 shapes_tensor->flat <int64>()(i) = counts[i];
0 commit comments