@@ -15,11 +15,13 @@ limitations under the License.
1515
1616#include " tensorflow/core/framework/op_kernel.h"
1717#include " tensorflow_io/arrow/kernels/arrow_kernels.h"
18+ #include " tensorflow_io/core/kernels/io_interface.h"
1819#include " arrow/io/api.h"
1920#include " arrow/ipc/feather.h"
2021#include " arrow/ipc/feather_generated.h"
2122#include " arrow/buffer.h"
2223#include " arrow/adapters/tensorflow/convert.h"
24+ #include " arrow/table.h"
2325
2426namespace tensorflow {
2527namespace data {
@@ -173,5 +175,223 @@ REGISTER_KERNEL_BUILDER(Name("ListFeatherColumns").Device(DEVICE_CPU),
173175
174176
175177} // namespace
178+
179+
180+ class FeatherIndexable : public IOIndexableInterface {
181+ public:
182+ FeatherIndexable (Env* env)
183+ : env_(env) {}
184+
185+ ~FeatherIndexable () {}
186+ Status Init (const std::vector<string>& input, const std::vector<string>& metadata, const void * memory_data, const int64 memory_size) override {
187+ if (input.size () > 1 ) {
188+ return errors::InvalidArgument (" more than 1 filename is not supported" );
189+ }
190+
191+ const string& filename = input[0 ];
192+ file_.reset (new SizedRandomAccessFile (env_, filename, memory_data, memory_size));
193+ TF_RETURN_IF_ERROR (file_->GetFileSize (&file_size_));
194+
195+ // FEA1.....[metadata][uint32 metadata_length]FEA1
196+ static constexpr const char * kFeatherMagicBytes = " FEA1" ;
197+
198+ size_t header_length = strlen (kFeatherMagicBytes );
199+ size_t footer_length = sizeof (uint32) + strlen (kFeatherMagicBytes );
200+
201+ string buffer;
202+ buffer.resize (header_length > footer_length ? header_length : footer_length);
203+
204+ StringPiece result;
205+
206+ TF_RETURN_IF_ERROR (file_->Read (0 , header_length, &result, &buffer[0 ]));
207+ if (memcmp (buffer.data (), kFeatherMagicBytes , header_length) != 0 ) {
208+ return errors::InvalidArgument (" not a feather file" );
209+ }
210+
211+ TF_RETURN_IF_ERROR (file_->Read (file_size_ - footer_length, footer_length, &result, &buffer[0 ]));
212+ if (memcmp (buffer.data () + sizeof (uint32), kFeatherMagicBytes , footer_length - sizeof (uint32)) != 0 ) {
213+ return errors::InvalidArgument (" incomplete feather file" );
214+ }
215+
216+ uint32 metadata_length = *reinterpret_cast <const uint32*>(buffer.data ());
217+
218+ buffer.resize (metadata_length);
219+
220+ TF_RETURN_IF_ERROR (file_->Read (file_size_ - footer_length - metadata_length, metadata_length, &result, &buffer[0 ]));
221+
222+ const ::arrow::ipc::feather::fbs::CTable* table = ::arrow::ipc::feather::fbs::GetCTable (buffer.data ());
223+
224+ if (table->version () < ::arrow::ipc::feather::kFeatherVersion ) {
225+ return errors::InvalidArgument (" feather file is old: " , table->version (), " vs. " , ::arrow::ipc::feather::kFeatherVersion );
226+ }
227+
228+ for (int i = 0 ; i < table->columns ()->size (); i++) {
229+ ::tensorflow::DataType dtype = ::tensorflow::DataType::DT_INVALID;
230+ switch (table->columns ()->Get (i)->values ()->type ()) {
231+ case ::arrow::ipc::feather::fbs::Type_BOOL:
232+ dtype = ::tensorflow::DataType::DT_BOOL;
233+ break ;
234+ case ::arrow::ipc::feather::fbs::Type_INT8:
235+ dtype = ::tensorflow::DataType::DT_INT8;
236+ break ;
237+ case ::arrow::ipc::feather::fbs::Type_INT16:
238+ dtype = ::tensorflow::DataType::DT_INT16;
239+ break ;
240+ case ::arrow::ipc::feather::fbs::Type_INT32:
241+ dtype = ::tensorflow::DataType::DT_INT32;
242+ break ;
243+ case ::arrow::ipc::feather::fbs::Type_INT64:
244+ dtype = ::tensorflow::DataType::DT_INT64;
245+ break ;
246+ case ::arrow::ipc::feather::fbs::Type_UINT8:
247+ dtype = ::tensorflow::DataType::DT_UINT8;
248+ break ;
249+ case ::arrow::ipc::feather::fbs::Type_UINT16:
250+ dtype = ::tensorflow::DataType::DT_UINT16;
251+ break ;
252+ case ::arrow::ipc::feather::fbs::Type_UINT32:
253+ dtype = ::tensorflow::DataType::DT_UINT32;
254+ break ;
255+ case ::arrow::ipc::feather::fbs::Type_UINT64:
256+ dtype = ::tensorflow::DataType::DT_UINT64;
257+ break ;
258+ case ::arrow::ipc::feather::fbs::Type_FLOAT:
259+ dtype = ::tensorflow::DataType::DT_FLOAT;
260+ break ;
261+ case ::arrow::ipc::feather::fbs::Type_DOUBLE:
262+ dtype = ::tensorflow::DataType::DT_DOUBLE;
263+ break ;
264+ case ::arrow::ipc::feather::fbs::Type_UTF8:
265+ case ::arrow::ipc::feather::fbs::Type_BINARY:
266+ case ::arrow::ipc::feather::fbs::Type_CATEGORY:
267+ case ::arrow::ipc::feather::fbs::Type_TIMESTAMP:
268+ case ::arrow::ipc::feather::fbs::Type_DATE:
269+ case ::arrow::ipc::feather::fbs::Type_TIME:
270+ // case ::arrow::ipc::feather::fbs::Type_LARGE_UTF8:
271+ // case ::arrow::ipc::feather::fbs::Type_LARGE_BINARY:
272+ default :
273+ break ;
274+ }
275+ shapes_.push_back (TensorShape ({static_cast <int64>(table->num_rows ())}));
276+ dtypes_.push_back (dtype);
277+ columns_.push_back (table->columns ()->Get (i)->name ()->str ());
278+ }
279+
280+ return Status::OK ();
281+ }
282+ Status Spec (std::vector<PartialTensorShape>& shapes, std::vector<DataType>& dtypes) override {
283+ shapes.clear ();
284+ for (size_t i = 0 ; i < shapes_.size (); i++) {
285+ shapes.push_back (shapes_[i]);
286+ }
287+ dtypes.clear ();
288+ for (size_t i = 0 ; i < dtypes_.size (); i++) {
289+ dtypes.push_back (dtypes_[i]);
290+ }
291+ return Status::OK ();
292+ }
293+
294+ Status Extra (std::vector<Tensor>* extra) override {
295+ // Expose columns
296+ Tensor columns (DT_STRING, TensorShape ({static_cast <int64>(columns_.size ())}));
297+ for (size_t i = 0 ; i < columns_.size (); i++) {
298+ columns.flat <string>()(i) = columns_[i];
299+ }
300+ extra->push_back (columns);
301+ return Status::OK ();
302+ }
303+
304+ Status GetItem (const int64 start, const int64 stop, const int64 step, const int64 component, Tensor* tensor) override {
305+ if (step != 1 ) {
306+ return errors::InvalidArgument (" step " , step, " is not supported" );
307+ }
308+
309+ if (feather_file_.get () == nullptr ) {
310+ feather_file_.reset (new ArrowRandomAccessFile (file_.get (), file_size_));
311+ arrow::Status s = arrow::ipc::feather::TableReader::Open (feather_file_, &reader_);
312+ if (!s.ok ()) {
313+ return errors::Internal (s.ToString ());
314+ }
315+ }
316+
317+ std::shared_ptr<arrow::Column> column;
318+ arrow::Status s = reader_->GetColumn (component, &column);
319+ if (!s.ok ()) {
320+ return errors::Internal (s.ToString ());
321+ }
322+
323+ std::shared_ptr<::arrow::Column> slice = column->Slice (start, stop);
324+
325+ #define FEATHER_PROCESS_TYPE (TTYPE,ATYPE ) { \
326+ int64 curr_index = 0 ; \
327+ for (auto chunk : slice->data ()->chunks ()) { \
328+ for (int64_t item = 0 ; item < chunk->length (); item++) { \
329+ tensor->flat <TTYPE>()(curr_index) = (dynamic_cast <ATYPE *>(chunk.get ()))->Value (item); \
330+ curr_index++; \
331+ } \
332+ } \
333+ }
334+ switch (tensor->dtype ()) {
335+ case DT_BOOL:
336+ FEATHER_PROCESS_TYPE (bool , ::arrow::BooleanArray);
337+ break ;
338+ case DT_INT8:
339+ FEATHER_PROCESS_TYPE (int8, ::arrow::NumericArray<::arrow::Int8Type>);
340+ break ;
341+ case DT_UINT8:
342+ FEATHER_PROCESS_TYPE (uint8, ::arrow::NumericArray<::arrow::UInt8Type>);
343+ break ;
344+ case DT_INT16:
345+ FEATHER_PROCESS_TYPE (int16, ::arrow::NumericArray<::arrow::Int16Type>);
346+ break ;
347+ case DT_UINT16:
348+ FEATHER_PROCESS_TYPE (uint16, ::arrow::NumericArray<::arrow::UInt16Type>);
349+ break ;
350+ case DT_INT32:
351+ FEATHER_PROCESS_TYPE (int32, ::arrow::NumericArray<::arrow::Int32Type>);
352+ break ;
353+ case DT_UINT32:
354+ FEATHER_PROCESS_TYPE (uint32, ::arrow::NumericArray<::arrow::UInt32Type>);
355+ break ;
356+ case DT_INT64:
357+ FEATHER_PROCESS_TYPE (int64, ::arrow::NumericArray<::arrow::Int64Type>);
358+ break ;
359+ case DT_UINT64:
360+ FEATHER_PROCESS_TYPE (uint64, ::arrow::NumericArray<::arrow::UInt64Type>);
361+ break ;
362+ case DT_FLOAT:
363+ FEATHER_PROCESS_TYPE (float , ::arrow::NumericArray<::arrow::FloatType>);
364+ break ;
365+ case DT_DOUBLE:
366+ FEATHER_PROCESS_TYPE (double , ::arrow::NumericArray<::arrow::DoubleType>);
367+ break ;
368+ default :
369+ return errors::InvalidArgument (" data type is not supported: " , DataTypeString (tensor->dtype ()));
370+ }
371+
372+ return Status::OK ();
373+ }
374+
375+ string DebugString () const override {
376+ mutex_lock l (mu_);
377+ return strings::StrCat (" FeatherIndexable" );
378+ }
379+ private:
380+ mutable mutex mu_;
381+ Env* env_ GUARDED_BY (mu_);
382+ std::unique_ptr<SizedRandomAccessFile> file_ GUARDED_BY (mu_);
383+ uint64 file_size_ GUARDED_BY (mu_);
384+ std::shared_ptr<ArrowRandomAccessFile> feather_file_ GUARDED_BY (mu_);
385+ std::unique_ptr<arrow::ipc::feather::TableReader> reader_ GUARDED_BY (mu_);
386+
387+ std::vector<DataType> dtypes_;
388+ std::vector<TensorShape> shapes_;
389+ std::vector<string> columns_;
390+ };
391+
392+ REGISTER_KERNEL_BUILDER (Name(" FeatherIndexableInit" ).Device(DEVICE_CPU),
393+ IOInterfaceInitOp<FeatherIndexable>);
394+ REGISTER_KERNEL_BUILDER (Name(" FeatherIndexableGetItem" ).Device(DEVICE_CPU),
395+ IOIndexableGetItemOp<FeatherIndexable>);
176396} // namespace data
177397} // namespace tensorflow
0 commit comments