@@ -14,7 +14,8 @@ limitations under the License.
1414==============================================================================*/
1515
1616#include " tensorflow/core/framework/op_kernel.h"
17- #include " tensorflow_io/core/kernels/stream.h"
17+ #include " tensorflow_io/core/kernels/io_interface.h"
18+ #include " tensorflow_io/core/kernels/io_stream.h"
1819#include " api/DataFile.hh"
1920#include " api/Compiler.hh"
2021#include " api/Generic.hh"
@@ -287,6 +288,235 @@ REGISTER_KERNEL_BUILDER(Name("ReadAvro").Device(DEVICE_CPU),
287288 ReadAvroOp);
288289
289290
291+
290292} // namespace
293+
294+ class AvroIndexable : public IOIndexableInterface {
295+ public:
296+ AvroIndexable (Env* env)
297+ : env_(env) {}
298+
299+ ~AvroIndexable () {}
300+ Status Init (const std::vector<string>& input, const std::vector<string>& metadata, const void * memory_data, const int64 memory_size) override {
301+ if (input.size () > 1 ) {
302+ return errors::InvalidArgument (" more than 1 filename is not supported" );
303+ }
304+ const string& filename = input[0 ];
305+ file_.reset (new SizedRandomAccessFile (env_, filename, memory_data, memory_size));
306+ TF_RETURN_IF_ERROR (file_->GetFileSize (&file_size_));
307+
308+ string schema;
309+ for (size_t i = 0 ; i < metadata.size (); i++) {
310+ if (metadata[i].find_first_of (" schema: " ) == 0 ) {
311+ schema = metadata[i].substr (8 );
312+ }
313+ }
314+
315+ string error;
316+ std::istringstream ss (schema);
317+ if (!(avro::compileJsonSchema (ss, reader_schema_, error))) {
318+ return errors::Internal (" Avro schema error: " , error);
319+ }
320+
321+ for (int i = 0 ; i < reader_schema_.root ()->names (); i++) {
322+ columns_.push_back (reader_schema_.root ()->nameAt (i));
323+ columns_index_[reader_schema_.root ()->nameAt (i)] = i;
324+ }
325+
326+ avro::GenericDatum datum (reader_schema_.root ());
327+ const avro::GenericRecord& record = datum.value <avro::GenericRecord>();
328+ for (size_t i = 0 ; i < reader_schema_.root ()->names (); i++) {
329+ const avro::GenericDatum& field = record.field (columns_[i]);
330+ ::tensorflow::DataType dtype;
331+ switch (field.type ()) {
332+ case avro::AVRO_BOOL:
333+ dtype = DT_BOOL;
334+ break ;
335+ case avro::AVRO_INT:
336+ dtype = DT_INT32;
337+ break ;
338+ case avro::AVRO_LONG:
339+ dtype = DT_INT64;
340+ break ;
341+ case avro::AVRO_FLOAT:
342+ dtype = DT_FLOAT;
343+ break ;
344+ case avro::AVRO_DOUBLE:
345+ dtype = DT_DOUBLE;
346+ break ;
347+ case avro::AVRO_STRING:
348+ dtype = DT_STRING;
349+ break ;
350+ case avro::AVRO_BYTES:
351+ dtype = DT_STRING;
352+ break ;
353+ case avro::AVRO_FIXED:
354+ dtype = DT_STRING;
355+ break ;
356+ case avro::AVRO_ENUM:
357+ dtype = DT_STRING;
358+ break ;
359+ default :
360+ return errors::InvalidArgument (" Avro type unsupported: " , field.type ());
361+ }
362+ dtypes_.emplace_back (dtype);
363+ }
364+
365+ // Find out the total number of rows
366+ reader_stream_.reset (new AvroInputStream (file_.get ()));
367+ reader_.reset (new avro::DataFileReader<avro::GenericDatum>(std::move (reader_stream_), reader_schema_));
368+
369+ avro::DecoderPtr decoder = avro::binaryDecoder ();
370+
371+ int64 total = 0 ;
372+
373+ reader_->sync (0 );
374+ int64 offset = reader_->previousSync ();
375+ while (offset < file_size_) {
376+ StringPiece result;
377+ string buffer (16 , 0x00 );
378+ TF_RETURN_IF_ERROR (file_->Read (offset, buffer.size (), &result, &buffer[0 ]));
379+ std::unique_ptr<avro::InputStream> in = avro::memoryInputStream ((const uint8_t *)result.data (), result.size ());
380+ decoder->init (*in);
381+ long items = decoder->decodeLong ();
382+
383+ total += static_cast <int64>(items);
384+ positions_.emplace_back (std::pair<int64, int64>(static_cast <int64>(items), offset));
385+
386+ reader_->sync (offset);
387+ offset = reader_->previousSync ();
388+ }
389+
390+ for (size_t i = 0 ; i < columns_.size (); i++) {
391+ shapes_.emplace_back (TensorShape ({total}));
392+ }
393+ return Status::OK ();
394+ }
395+
396+ Status Partitions (std::vector<int64> *partitions) override {
397+ partitions->clear ();
398+ // positions_ are pairs of <items, offset>
399+ for (size_t i = 0 ; i < positions_.size (); i++) {
400+ partitions->emplace_back (positions_[i].first );
401+ }
402+ return Status::OK ();
403+ }
404+
405+ Status Components (Tensor* components) override {
406+ *components = Tensor (DT_STRING, TensorShape ({static_cast <int64>(columns_.size ())}));
407+ for (size_t i = 0 ; i < columns_.size (); i++) {
408+ components->flat <string>()(i) = columns_[i];
409+ }
410+ return Status::OK ();
411+ }
412+ Status Spec (const Tensor& component, PartialTensorShape* shape, DataType* dtype, bool label) override {
413+ if (columns_index_.find (component.scalar <string>()()) == columns_index_.end ()) {
414+ return errors::InvalidArgument (" component " , component.scalar <string>()(), " is invalid" );
415+ }
416+ int64 column_index = columns_index_[component.scalar <string>()()];
417+ *shape = shapes_[column_index];
418+ *dtype = dtypes_[column_index];
419+ return Status::OK ();
420+ }
421+
422+ Status Read (const int64 start, const int64 stop, const Tensor& component, Tensor* value, Tensor* label) override {
423+ const string& column = component.scalar <string>()();
424+ avro::GenericDatum datum (reader_schema_);
425+
426+ // Find the start sync point
427+ int64 item_index_sync = 0 ;
428+ for (size_t i = 0 ; i < positions_.size (); i++, item_index_sync += positions_[i].first ) {
429+ if (item_index_sync >= stop) {
430+ continue ;
431+ }
432+ if (item_index_sync + positions_[i].first <= start) {
433+ continue ;
434+ }
435+ // TODO: Avro is sync point partitioned and each block is very similiar to
436+ // Row Group of parquet. Ideally each block should be cached with the hope
437+ // that slicing and indexing will happend around the same block across multiple
438+ // rows. Caching is not done yet.
439+
440+ // Seek to sync
441+ reader_->seek (positions_[i].second );
442+ for (int64 item_index = item_index_sync; item_index < (item_index_sync + positions_[i].first ) && item_index < stop; item_index++) {
443+ // Read anyway
444+ if (!reader_->read (datum)) {
445+ return errors::Internal (" unable to read record at: " , item_index);
446+ }
447+ // Assign only when in range
448+ if (item_index >= start) {
449+ const avro::GenericRecord& record = datum.value <avro::GenericRecord>();
450+ const avro::GenericDatum& field = record.field (column);
451+ switch (field.type ()) {
452+ case avro::AVRO_BOOL:
453+ value->flat <bool >()(item_index - start) = field.value <bool >();
454+ break ;
455+ case avro::AVRO_INT:
456+ value->flat <int32>()(item_index - start) = field.value <int32_t >();
457+ break ;
458+ case avro::AVRO_LONG:
459+ value->flat <int64>()(item_index - start) = field.value <int64_t >();
460+ break ;
461+ case avro::AVRO_FLOAT:
462+ value->flat <float >()(item_index - start) = field.value <float >();
463+ break ;
464+ case avro::AVRO_DOUBLE:
465+ value->flat <double >()(item_index - start) = field.value <double >();
466+ break ;
467+ case avro::AVRO_STRING:
468+ value->flat <string>()(item_index - start) = field.value <string>();
469+ break ;
470+ case avro::AVRO_BYTES: {
471+ const std::vector<uint8_t >& field_value = field.value <std::vector<uint8_t >>();
472+ value->flat <string>()(item_index - start) = string ((char *)&field_value[0 ], field_value.size ());
473+ }
474+ break ;
475+ case avro::AVRO_FIXED: {
476+ const std::vector<uint8_t >& field_value = field.value <avro::GenericFixed>().value ();
477+ value->flat <string>()(item_index - start) = string ((char *)&field_value[0 ], field_value.size ());
478+ }
479+ break ;
480+ case avro::AVRO_ENUM:
481+ value->flat <string>()(item_index - start) = field.value <avro::GenericEnum>().symbol ();
482+ break ;
483+ default :
484+ return errors::InvalidArgument (" unsupported data type: " , field.type ());
485+ }
486+ }
487+ }
488+ }
489+ return Status::OK ();
490+ }
491+
492+ string DebugString () const override {
493+ mutex_lock l (mu_);
494+ return strings::StrCat (" AvroIndexable" );
495+ }
496+ private:
497+ mutable mutex mu_;
498+ Env* env_ GUARDED_BY (mu_);
499+ std::unique_ptr<SizedRandomAccessFile> file_ GUARDED_BY (mu_);
500+ uint64 file_size_ GUARDED_BY (mu_);
501+ avro::ValidSchema reader_schema_;
502+ std::unique_ptr<avro::InputStream> reader_stream_;
503+ std::unique_ptr<avro::DataFileReader<avro::GenericDatum>> reader_;
504+ std::vector<std::pair<int64, int64>> positions_; // <items/sync> pair
505+
506+ std::vector<DataType> dtypes_;
507+ std::vector<TensorShape> shapes_;
508+ std::vector<string> columns_;
509+ std::unordered_map<string, int64> columns_index_;
510+ };
511+
// Kernel registrations for AvroIndexable: each statement below registers, for
// DEVICE_CPU, the op named in Name(...) with the corresponding generic op
// template (init, spec, partitions, read) instantiated on AvroIndexable.
512+ REGISTER_KERNEL_BUILDER (Name(" AvroIndexableInit" ).Device(DEVICE_CPU),
513+ IOInterfaceInitOp<AvroIndexable>);
514+ REGISTER_KERNEL_BUILDER (Name(" AvroIndexableSpec" ).Device(DEVICE_CPU),
515+ IOInterfaceSpecOp<AvroIndexable>);
516+ REGISTER_KERNEL_BUILDER (Name(" AvroIndexablePartitions" ).Device(DEVICE_CPU),
517+ IOIndexablePartitionsOp<AvroIndexable>);
518+ REGISTER_KERNEL_BUILDER (Name(" AvroIndexableRead" ).Device(DEVICE_CPU),
519+ IOIndexableReadOp<AvroIndexable>);
520+
291521} // namespace data
292522} // namespace tensorflow
0 commit comments