@@ -15,10 +15,12 @@ limitations under the License.
 
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow_io/arrow/kernels/arrow_kernels.h"
+#include "tensorflow_io/core/kernels/io_interface.h"
 #include "arrow/io/api.h"
 #include "arrow/ipc/feather.h"
 #include "arrow/ipc/feather_generated.h"
 #include "arrow/buffer.h"
+#include "arrow/table.h"
 
 namespace tensorflow {
 namespace data {
@@ -155,5 +157,251 @@ REGISTER_KERNEL_BUILDER(Name("ListFeatherColumns").Device(DEVICE_CPU),
 
 
 }  // namespace
+
+
+class FeatherIndexable : public IOIndexableInterface {
+ public:
+  FeatherIndexable(Env* env)
+  : env_(env) {}
+
+  ~FeatherIndexable() {}
+  Status Init(const std::vector<string>& input, const std::vector<string>& metadata, const void* memory_data, const int64 memory_size) override {
+    if (input.size() > 1) {
+      return errors::InvalidArgument("more than 1 filename is not supported");
+    }
+
+    const string& filename = input[0];
+    file_.reset(new SizedRandomAccessFile(env_, filename, memory_data, memory_size));
+    TF_RETURN_IF_ERROR(file_->GetFileSize(&file_size_));
+
+    // FEA1.....[metadata][uint32 metadata_length]FEA1
+    static constexpr const char* kFeatherMagicBytes = "FEA1";
+
+    size_t header_length = strlen(kFeatherMagicBytes);
+    size_t footer_length = sizeof(uint32) + strlen(kFeatherMagicBytes);
+
+    string buffer;
+    buffer.resize(header_length > footer_length ? header_length : footer_length);
+
+    StringPiece result;
+
+    TF_RETURN_IF_ERROR(file_->Read(0, header_length, &result, &buffer[0]));
+    if (memcmp(buffer.data(), kFeatherMagicBytes, header_length) != 0) {
+      return errors::InvalidArgument("not a feather file");
+    }
+
+    TF_RETURN_IF_ERROR(file_->Read(file_size_ - footer_length, footer_length, &result, &buffer[0]));
+    if (memcmp(buffer.data() + sizeof(uint32), kFeatherMagicBytes, footer_length - sizeof(uint32)) != 0) {
+      return errors::InvalidArgument("incomplete feather file");
+    }
+
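+    // The 4 bytes immediately before the trailing magic hold the length of the
+    // flatbuffer metadata block, which is read next and parsed as a CTable.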
+    uint32 metadata_length = *reinterpret_cast<const uint32*>(buffer.data());
+
+    buffer.resize(metadata_length);
+
+    TF_RETURN_IF_ERROR(file_->Read(file_size_ - footer_length - metadata_length, metadata_length, &result, &buffer[0]));
+
+    const ::arrow::ipc::feather::fbs::CTable* table = ::arrow::ipc::feather::fbs::GetCTable(buffer.data());
+
+    if (table->version() < ::arrow::ipc::feather::kFeatherVersion) {
+      return errors::InvalidArgument("feather file is old: ", table->version(), " vs. ", ::arrow::ipc::feather::kFeatherVersion);
+    }
+
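+    // Column selection: metadata entries of the form "column: <name>" restrict the
+    // indexable to those columns; with no such entries, every column is exposed.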
+    std::vector<string> columns;
+    for (size_t i = 0; i < metadata.size(); i++) {
+      if (metadata[i].find("column: ") == 0) {
+        columns.emplace_back(metadata[i].substr(8));
+      }
+    }
+
+    columns_index_.clear();
+    if (columns.size() == 0) {
+      for (int i = 0; i < table->columns()->size(); i++) {
+        columns_index_.push_back(i);
+      }
+    } else {
+      std::unordered_map<string, int> columns_map;
+      for (int i = 0; i < table->columns()->size(); i++) {
+        columns_map[table->columns()->Get(i)->name()->str()] = i;
+      }
+      for (size_t i = 0; i < columns.size(); i++) {
+        columns_index_.push_back(columns_map[columns[i]]);
+      }
+    }
+
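+    // Map each selected column's Feather value type to a TensorFlow dtype; types
+    // without a mapping here (UTF8, BINARY, CATEGORY, TIMESTAMP, DATE, TIME) are
+    // left as DT_INVALID.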
+    for (size_t i = 0; i < columns_index_.size(); i++) {
+      int column_index = columns_index_[i];
+      ::tensorflow::DataType dtype = ::tensorflow::DataType::DT_INVALID;
+      switch (table->columns()->Get(column_index)->values()->type()) {
+        case ::arrow::ipc::feather::fbs::Type_BOOL:
+          dtype = ::tensorflow::DataType::DT_BOOL;
+          break;
+        case ::arrow::ipc::feather::fbs::Type_INT8:
+          dtype = ::tensorflow::DataType::DT_INT8;
+          break;
+        case ::arrow::ipc::feather::fbs::Type_INT16:
+          dtype = ::tensorflow::DataType::DT_INT16;
+          break;
+        case ::arrow::ipc::feather::fbs::Type_INT32:
+          dtype = ::tensorflow::DataType::DT_INT32;
+          break;
+        case ::arrow::ipc::feather::fbs::Type_INT64:
+          dtype = ::tensorflow::DataType::DT_INT64;
+          break;
+        case ::arrow::ipc::feather::fbs::Type_UINT8:
+          dtype = ::tensorflow::DataType::DT_UINT8;
+          break;
+        case ::arrow::ipc::feather::fbs::Type_UINT16:
+          dtype = ::tensorflow::DataType::DT_UINT16;
+          break;
+        case ::arrow::ipc::feather::fbs::Type_UINT32:
+          dtype = ::tensorflow::DataType::DT_UINT32;
+          break;
+        case ::arrow::ipc::feather::fbs::Type_UINT64:
+          dtype = ::tensorflow::DataType::DT_UINT64;
+          break;
+        case ::arrow::ipc::feather::fbs::Type_FLOAT:
+          dtype = ::tensorflow::DataType::DT_FLOAT;
+          break;
+        case ::arrow::ipc::feather::fbs::Type_DOUBLE:
+          dtype = ::tensorflow::DataType::DT_DOUBLE;
+          break;
+        case ::arrow::ipc::feather::fbs::Type_UTF8:
+        case ::arrow::ipc::feather::fbs::Type_BINARY:
+        case ::arrow::ipc::feather::fbs::Type_CATEGORY:
+        case ::arrow::ipc::feather::fbs::Type_TIMESTAMP:
+        case ::arrow::ipc::feather::fbs::Type_DATE:
+        case ::arrow::ipc::feather::fbs::Type_TIME:
+        // case ::arrow::ipc::feather::fbs::Type_LARGE_UTF8:
+        // case ::arrow::ipc::feather::fbs::Type_LARGE_BINARY:
+        default:
+          break;
+      }
+      dtypes_.push_back(dtype);
+      shapes_.push_back(TensorShape({static_cast<int64>(table->num_rows())}));
+      columns_.push_back(table->columns()->Get(column_index)->name()->str());
+    }
+
+    return Status::OK();
+  }
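+  // Reports the dtype and shape (one dimension: the number of rows) of each selected column.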
+  Status Spec(std::vector<DataType>& dtypes, std::vector<PartialTensorShape>& shapes) override {
+    dtypes.clear();
+    for (size_t i = 0; i < dtypes_.size(); i++) {
+      dtypes.push_back(dtypes_[i]);
+    }
+    shapes.clear();
+    for (size_t i = 0; i < shapes_.size(); i++) {
+      shapes.push_back(shapes_[i]);
+    }
+    return Status::OK();
+  }
+
+  Status Extra(std::vector<Tensor>* extra) override {
+    // Expose columns
+    Tensor columns(DT_STRING, TensorShape({static_cast<int64>(columns_.size())}));
+    for (size_t i = 0; i < columns_.size(); i++) {
+      columns.flat<string>()(i) = columns_[i];
+    }
+    extra->push_back(columns);
+    return Status::OK();
+  }
+
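+  // Copies rows [start, stop) of every selected column into the caller-provided
+  // tensors; only step == 1 is supported.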
+  Status GetItem(const int64 start, const int64 stop, const int64 step, std::vector<Tensor>& tensors) override {
+    if (step != 1) {
+      return errors::InvalidArgument("step ", step, " is not supported");
+    }
+
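+    // Open the Feather table reader lazily, on the first GetItem call.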
+    if (feather_file_.get() == nullptr) {
+      feather_file_.reset(new ArrowRandomAccessFile(file_.get(), file_size_));
+      arrow::Status s = arrow::ipc::feather::TableReader::Open(feather_file_, &reader_);
+      if (!s.ok()) {
+        return errors::Internal(s.ToString());
+      }
+    }
+
+    for (size_t i = 0; i < tensors.size(); i++) {
+      int column_index = columns_index_[i];
+
+      std::shared_ptr<arrow::Column> column;
+      arrow::Status s = reader_->GetColumn(column_index, &column);
+      if (!s.ok()) {
+        return errors::Internal(s.ToString());
+      }
+
+      // Column::Slice takes (offset, length), so convert the [start, stop) range.
+      std::shared_ptr<::arrow::Column> slice = column->Slice(start, stop - start);
+
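+      // Copy the slice chunk by chunk into the output tensor; ATYPE is the concrete
+      // Arrow array class whose element type matches the TensorFlow dtype TTYPE.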
+#define FEATHER_PROCESS_TYPE(TTYPE, ATYPE) { \
+        int64 curr_index = 0; \
+        for (auto chunk : slice->data()->chunks()) { \
+          for (int64_t item = 0; item < chunk->length(); item++) { \
+            tensors[i].flat<TTYPE>()(curr_index) = (dynamic_cast<ATYPE*>(chunk.get()))->Value(item); \
+            curr_index++; \
+          } \
+        } \
+      }
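+      // Dispatch on the output tensor's dtype to the matching Arrow array type.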
+      switch (tensors[i].dtype()) {
+        case DT_BOOL:
+          FEATHER_PROCESS_TYPE(bool, ::arrow::BooleanArray);
+          break;
+        case DT_INT8:
+          FEATHER_PROCESS_TYPE(int8, ::arrow::NumericArray<::arrow::Int8Type>);
+          break;
+        case DT_UINT8:
+          FEATHER_PROCESS_TYPE(uint8, ::arrow::NumericArray<::arrow::UInt8Type>);
+          break;
+        case DT_INT16:
+          FEATHER_PROCESS_TYPE(int16, ::arrow::NumericArray<::arrow::Int16Type>);
+          break;
+        case DT_UINT16:
+          FEATHER_PROCESS_TYPE(uint16, ::arrow::NumericArray<::arrow::UInt16Type>);
+          break;
+        case DT_INT32:
+          FEATHER_PROCESS_TYPE(int32, ::arrow::NumericArray<::arrow::Int32Type>);
+          break;
+        case DT_UINT32:
+          FEATHER_PROCESS_TYPE(uint32, ::arrow::NumericArray<::arrow::UInt32Type>);
+          break;
+        case DT_INT64:
+          FEATHER_PROCESS_TYPE(int64, ::arrow::NumericArray<::arrow::Int64Type>);
+          break;
+        case DT_UINT64:
+          FEATHER_PROCESS_TYPE(uint64, ::arrow::NumericArray<::arrow::UInt64Type>);
+          break;
+        case DT_FLOAT:
+          FEATHER_PROCESS_TYPE(float, ::arrow::NumericArray<::arrow::FloatType>);
+          break;
+        case DT_DOUBLE:
+          FEATHER_PROCESS_TYPE(double, ::arrow::NumericArray<::arrow::DoubleType>);
+          break;
+        default:
+          return errors::InvalidArgument("data type is not supported: ", DataTypeString(tensors[i].dtype()));
+      }
+    }
+
+    return Status::OK();
+  }
+
+  string DebugString() const override {
+    mutex_lock l(mu_);
+    return strings::StrCat("FeatherIndexable");
+  }
+ private:
+  mutable mutex mu_;
+  Env* env_ GUARDED_BY(mu_);
+  std::unique_ptr<SizedRandomAccessFile> file_ GUARDED_BY(mu_);
+  uint64 file_size_ GUARDED_BY(mu_);
+  std::shared_ptr<ArrowRandomAccessFile> feather_file_ GUARDED_BY(mu_);
+  std::unique_ptr<arrow::ipc::feather::TableReader> reader_ GUARDED_BY(mu_);
+
+  std::vector<DataType> dtypes_;
+  std::vector<TensorShape> shapes_;
+  std::vector<string> columns_;
+  std::vector<int> columns_index_;
+};
+
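+// Wire the class into the FeatherIndexableInit/FeatherIndexableGetItem ops via the
+// generic IOInterface kernel templates.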
+REGISTER_KERNEL_BUILDER(Name("FeatherIndexableInit").Device(DEVICE_CPU),
+                        IOInterfaceInitOp<FeatherIndexable>);
+REGISTER_KERNEL_BUILDER(Name("FeatherIndexableGetItem").Device(DEVICE_CPU),
+                        IOIndexableGetItemOp<FeatherIndexable>);
 }  // namespace data
 }  // namespace tensorflow