@@ -15,6 +15,7 @@ limitations under the License.
1515
1616#include " tensorflow/core/framework/op_kernel.h"
1717#include " tensorflow_io/arrow/kernels/arrow_kernels.h"
18+ #include " tensorflow_io/core/kernels/io_interface.h"
1819#include " parquet/api/reader.h"
1920
2021namespace tensorflow {
@@ -218,5 +219,173 @@ REGISTER_KERNEL_BUILDER(Name("ReadParquet").Device(DEVICE_CPU),
218219
219220
220221} // namespace
222+
223+
224+ class ParquetIndexable : public IOIndexableInterface {
225+ public:
226+ ParquetIndexable (Env* env)
227+ : env_(env) {}
228+
229+ ~ParquetIndexable () {}
230+ Status Init (const std::vector<string>& input, const std::vector<string>& metadata, const void * memory_data, const int64 memory_size) override {
231+ if (input.size () > 1 ) {
232+ return errors::InvalidArgument (" more than 1 filename is not supported" );
233+ }
234+ const string& filename = input[0 ];
235+ file_.reset (new SizedRandomAccessFile (env_, filename, memory_data, memory_size));
236+ TF_RETURN_IF_ERROR (file_->GetFileSize (&file_size_));
237+
238+ parquet_file_.reset (new ArrowRandomAccessFile (file_.get (), file_size_));
239+ parquet_reader_ = parquet::ParquetFileReader::Open (parquet_file_);
240+ parquet_metadata_ = parquet_reader_->metadata ();
241+
242+ shapes_.clear ();
243+ dtypes_.clear ();
244+ columns_.clear ();
245+ for (size_t i = 0 ; i < parquet_metadata_->num_columns (); i++) {
246+ ::tensorflow::DataType dtype;
247+ switch (parquet_metadata_->schema ()->Column (i)->physical_type ()) {
248+ case parquet::Type::BOOLEAN:
249+ dtype = ::tensorflow::DT_BOOL;
250+ break ;
251+ case parquet::Type::INT32:
252+ dtype = ::tensorflow::DT_INT32;
253+ break ;
254+ case parquet::Type::INT64:
255+ dtype = ::tensorflow::DT_INT64;
256+ break ;
257+ case parquet::Type::INT96: // Deprecated, thrown out exception when access with __getitem__
258+ dtype = ::tensorflow::DT_INT64;
259+ break ;
260+ case parquet::Type::FLOAT:
261+ dtype = ::tensorflow::DT_FLOAT;
262+ break ;
263+ case parquet::Type::DOUBLE:
264+ dtype = ::tensorflow::DT_DOUBLE;
265+ break ;
266+ case parquet::Type::BYTE_ARRAY:
267+ dtype = ::tensorflow::DT_STRING;
268+ break ;
269+ case parquet::Type::FIXED_LEN_BYTE_ARRAY:
270+ dtype = ::tensorflow::DT_STRING;
271+ break ;
272+ default :
273+ return errors::InvalidArgument (" parquet data type is not supported: " , parquet_metadata_->schema ()->Column (i)->physical_type ());
274+ break ;
275+ }
276+ shapes_.push_back (TensorShape ({static_cast <int64>(parquet_metadata_->num_rows ())}));
277+ dtypes_.push_back (dtype);
278+ columns_.push_back (parquet_metadata_->schema ()->Column (i)->path ().get ()->ToDotString ());
279+ }
280+
281+ return Status::OK ();
282+ }
283+ Status Spec (std::vector<PartialTensorShape>& shapes, std::vector<DataType>& dtypes) override {
284+ shapes.clear ();
285+ for (size_t i = 0 ; i < shapes_.size (); i++) {
286+ shapes.push_back (shapes_[i]);
287+ }
288+ dtypes.clear ();
289+ for (size_t i = 0 ; i < dtypes_.size (); i++) {
290+ dtypes.push_back (dtypes_[i]);
291+ }
292+ return Status::OK ();
293+ }
294+
295+ Status Extra (std::vector<Tensor>* extra) override {
296+ // Expose columns
297+ Tensor columns (DT_STRING, TensorShape ({static_cast <int64>(columns_.size ())}));
298+ for (size_t i = 0 ; i < columns_.size (); i++) {
299+ columns.flat <string>()(i) = columns_[i];
300+ }
301+ extra->push_back (columns);
302+ return Status::OK ();
303+ }
304+
305+ Status GetItem (const int64 start, const int64 stop, const int64 step, const int64 component, Tensor* tensor) override {
306+ if (step != 1 ) {
307+ return errors::InvalidArgument (" step " , step, " is not supported" );
308+ }
309+ int64 row_group_offset = 0 ;
310+ for (int row_group = 0 ; row_group < parquet_metadata_->num_row_groups (); row_group++) {
311+ std::shared_ptr<parquet::RowGroupReader> row_group_reader = parquet_reader_->RowGroup (row_group);
312+ // Skip if row group is not within [start..stop]
313+ if ((row_group_offset + row_group_reader->metadata ()->num_rows () < start) || (stop <= row_group_offset)) {
314+ row_group_offset += row_group_reader->metadata ()->num_rows ();
315+ continue ;
316+ }
317+ // Find row_to_read range
318+ int64 row_to_read_start = row_group_offset > start ? row_group_offset : start;
319+ int64 row_to_read_final = (row_group_offset + row_group_reader->metadata ()->num_rows ()) < (stop) ? (row_group_offset + row_group_reader->metadata ()->num_rows ()) : (stop);
320+ int64 row_to_read_count = row_to_read_final - row_to_read_start;
321+
322+ // TODO: parquet is RowGroup based so ideally the RowGroup should be cached
323+ // with the hope of indexing and slicing happens on each row. For now no caching
324+ // is done yet.
325+ std::shared_ptr<parquet::ColumnReader> column_reader = row_group_reader->Column (component);
326+
327+ // buffer to fill location is tensor.data()[row_to_read_start - start]
328+
329+ #define PARQUET_PROCESS_TYPE (ptype, type ) { \
330+ parquet::TypedColumnReader<ptype>* reader = \
331+ static_cast <parquet::TypedColumnReader<ptype>*>( \
332+ column_reader.get ()); \
333+ if (row_to_read_start > row_group_offset) { \
334+ reader->Skip (row_to_read_start - row_group_offset); \
335+ } \
336+ ptype::c_type* value = (ptype::c_type *)(void *)(&(tensor->flat <type>().data ()[row_to_read_start - start])); \
337+ int64_t values_read; \
338+ int64_t levels_read = reader->ReadBatch (row_to_read_count, nullptr , nullptr , value, &values_read); \
339+ if (!(levels_read == values_read && levels_read == row_to_read_count)) { \
340+ return errors::InvalidArgument (" null value in column: " , columns_[component]); \
341+ } \
342+ }
343+ switch (parquet_metadata_->schema ()->Column (component)->physical_type ()) {
344+ case parquet::Type::BOOLEAN:
345+ PARQUET_PROCESS_TYPE (parquet::BooleanType, bool );
346+ break ;
347+ case parquet::Type::INT32:
348+ PARQUET_PROCESS_TYPE (parquet::Int32Type, int32);
349+ break ;
350+ case parquet::Type::INT64:
351+ PARQUET_PROCESS_TYPE (parquet::Int64Type, int64);
352+ break ;
353+ case parquet::Type::FLOAT:
354+ PARQUET_PROCESS_TYPE (parquet::FloatType, float );
355+ break ;
356+ case parquet::Type::DOUBLE:
357+ PARQUET_PROCESS_TYPE (parquet::DoubleType, double );
358+ break ;
359+ default :
360+ return errors::InvalidArgument (" invalid data type: " , parquet_metadata_->schema ()->Column (component)->physical_type ());
361+ }
362+ row_group_offset += row_group_reader->metadata ()->num_rows ();
363+ }
364+ return Status::OK ();
365+ }
366+
367+ string DebugString () const override {
368+ mutex_lock l (mu_);
369+ return strings::StrCat (" ParquetIndexable" );
370+ }
371+ private:
372+ mutable mutex mu_;
373+ Env* env_ GUARDED_BY (mu_);
374+ std::unique_ptr<SizedRandomAccessFile> file_ GUARDED_BY (mu_);
375+ uint64 file_size_ GUARDED_BY (mu_);
376+ std::shared_ptr<ArrowRandomAccessFile> parquet_file_;
377+ std::unique_ptr<::parquet::ParquetFileReader> parquet_reader_;
378+ std::shared_ptr<::parquet::FileMetaData> parquet_metadata_;
379+
380+ std::vector<DataType> dtypes_;
381+ std::vector<TensorShape> shapes_;
382+ std::vector<string> columns_;
383+ std::vector<int > columns_index_;
384+ };
385+
// Register the CPU kernels for the "ParquetIndexableInit" and
// "ParquetIndexableGetItem" ops. Both are instantiations of the generic
// IOInterfaceInitOp / IOIndexableGetItemOp templates over ParquetIndexable.
386+ REGISTER_KERNEL_BUILDER (Name(" ParquetIndexableInit" ).Device(DEVICE_CPU),
387+ IOInterfaceInitOp<ParquetIndexable>);
388+ REGISTER_KERNEL_BUILDER (Name(" ParquetIndexableGetItem" ).Device(DEVICE_CPU),
389+ IOIndexableGetItemOp<ParquetIndexable>);
221390} // namespace data
222391} // namespace tensorflow
0 commit comments