Skip to content

Commit 2d25fc9

Browse files
authored
Add tfio.IOTensor.from_avro support (#440)
* Add tfio.IOTensor.from_avro support Avro is a columnar file format that naturally fits table/column data. The Avro file itself is not directly indexable. However, it is pseudo-indexable as it consists of blocks, with each block specifying file offset/size and a count of items. So indexing could be done by small range iteration. It would be desirable to make Avro indexable as it will be much more convenient with increased flexibility. This PR adds tfio.IOTensor.from_avro support so that it is possible to access avro data through natural __getitem__ operations. Signed-off-by: Yong Tang <[email protected]> * Add a Partitions function to Avro, so that it is possible to dynamically adjust the capacity of the chunk size when reading. Signed-off-by: Yong Tang <[email protected]> * Rename to io_stream.h for consistency Signed-off-by: Yong Tang <[email protected]> * Remove the need to pass component, unless needed explicitly Signed-off-by: Yong Tang <[email protected]> * Move Partitions to a generic location and support dataset Signed-off-by: Yong Tang <[email protected]>
1 parent 5bbb94c commit 2d25fc9

29 files changed

+587
-75
lines changed

tensorflow_io/arrow/kernels/arrow_dataset_ops.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ limitations under the License.
1818
#include "arrow/util/io-util.h"
1919
#include "tensorflow/core/framework/dataset.h"
2020
#include "tensorflow/core/graph/graph.h"
21-
#include "tensorflow_io/core/kernels/stream.h"
21+
#include "tensorflow_io/core/kernels/io_stream.h"
2222
#include "tensorflow_io/arrow/kernels/arrow_kernels.h"
2323
#include "tensorflow_io/arrow/kernels/arrow_stream_client.h"
2424
#include "tensorflow_io/arrow/kernels/arrow_util.h"

tensorflow_io/arrow/kernels/arrow_kernels.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ limitations under the License.
1616
#ifndef TENSORFLOW_IO_ARROW_KERNELS_H_
1717
#define TENSORFLOW_IO_ARROW_KERNELS_H_
1818

19-
#include "kernels/stream.h"
19+
#include "tensorflow_io/core/kernels/io_stream.h"
2020
#include "arrow/io/api.h"
2121
#include "arrow/buffer.h"
2222
#include "arrow/type.h"

tensorflow_io/audio/kernels/audio_kernels.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ limitations under the License.
1414
==============================================================================*/
1515

1616
#include "tensorflow_io/core/kernels/io_interface.h"
17-
#include "tensorflow_io/core/kernels/stream.h"
17+
#include "tensorflow_io/core/kernels/io_stream.h"
1818

1919
namespace tensorflow {
2020
namespace data {

tensorflow_io/audio/ops/audio_ops.cc

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@ REGISTER_OP("WAVIndexableInit")
3131

3232
REGISTER_OP("WAVIndexableSpec")
3333
.Input("input: resource")
34-
.Input("component: int64")
3534
.Output("shape: int64")
3635
.Output("dtype: int64")
3736
.Output("rate: int32")
@@ -46,7 +45,6 @@ REGISTER_OP("WAVIndexableRead")
4645
.Input("input: resource")
4746
.Input("start: int64")
4847
.Input("stop: int64")
49-
.Input("component: int64")
5048
.Output("value: dtype")
5149
.Attr("shape: shape")
5250
.Attr("dtype: type")

tensorflow_io/avro/kernels/avro_kernels.cc

Lines changed: 231 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,8 @@ limitations under the License.
1414
==============================================================================*/
1515

1616
#include "tensorflow/core/framework/op_kernel.h"
17-
#include "tensorflow_io/core/kernels/stream.h"
17+
#include "tensorflow_io/core/kernels/io_interface.h"
18+
#include "tensorflow_io/core/kernels/io_stream.h"
1819
#include "api/DataFile.hh"
1920
#include "api/Compiler.hh"
2021
#include "api/Generic.hh"
@@ -287,6 +288,235 @@ REGISTER_KERNEL_BUILDER(Name("ReadAvro").Device(DEVICE_CPU),
287288
ReadAvroOp);
288289

289290

291+
290292
} // namespace
293+
294+
class AvroIndexable : public IOIndexableInterface {
295+
public:
296+
AvroIndexable(Env* env)
297+
: env_(env) {}
298+
299+
~AvroIndexable() {}
300+
Status Init(const std::vector<string>& input, const std::vector<string>& metadata, const void* memory_data, const int64 memory_size) override {
301+
if (input.size() > 1) {
302+
return errors::InvalidArgument("more than 1 filename is not supported");
303+
}
304+
const string& filename = input[0];
305+
file_.reset(new SizedRandomAccessFile(env_, filename, memory_data, memory_size));
306+
TF_RETURN_IF_ERROR(file_->GetFileSize(&file_size_));
307+
308+
string schema;
309+
for (size_t i = 0; i < metadata.size(); i++) {
310+
if (metadata[i].find_first_of("schema: ") == 0) {
311+
schema = metadata[i].substr(8);
312+
}
313+
}
314+
315+
string error;
316+
std::istringstream ss(schema);
317+
if (!(avro::compileJsonSchema(ss, reader_schema_, error))) {
318+
return errors::Internal("Avro schema error: ", error);
319+
}
320+
321+
for (int i = 0; i < reader_schema_.root()->names(); i++) {
322+
columns_.push_back(reader_schema_.root()->nameAt(i));
323+
columns_index_[reader_schema_.root()->nameAt(i)] = i;
324+
}
325+
326+
avro::GenericDatum datum(reader_schema_.root());
327+
const avro::GenericRecord& record = datum.value<avro::GenericRecord>();
328+
for (size_t i = 0; i < reader_schema_.root()->names(); i++) {
329+
const avro::GenericDatum& field = record.field(columns_[i]);
330+
::tensorflow::DataType dtype;
331+
switch(field.type()) {
332+
case avro::AVRO_BOOL:
333+
dtype = DT_BOOL;
334+
break;
335+
case avro::AVRO_INT:
336+
dtype = DT_INT32;
337+
break;
338+
case avro::AVRO_LONG:
339+
dtype = DT_INT64;
340+
break;
341+
case avro::AVRO_FLOAT:
342+
dtype = DT_FLOAT;
343+
break;
344+
case avro::AVRO_DOUBLE:
345+
dtype = DT_DOUBLE;
346+
break;
347+
case avro::AVRO_STRING:
348+
dtype = DT_STRING;
349+
break;
350+
case avro::AVRO_BYTES:
351+
dtype = DT_STRING;
352+
break;
353+
case avro::AVRO_FIXED:
354+
dtype = DT_STRING;
355+
break;
356+
case avro::AVRO_ENUM:
357+
dtype = DT_STRING;
358+
break;
359+
default:
360+
return errors::InvalidArgument("Avro type unsupported: ", field.type());
361+
}
362+
dtypes_.emplace_back(dtype);
363+
}
364+
365+
// Find out the total number of rows
366+
reader_stream_.reset(new AvroInputStream(file_.get()));
367+
reader_.reset(new avro::DataFileReader<avro::GenericDatum>(std::move(reader_stream_), reader_schema_));
368+
369+
avro::DecoderPtr decoder = avro::binaryDecoder();
370+
371+
int64 total = 0;
372+
373+
reader_->sync(0);
374+
int64 offset = reader_->previousSync();
375+
while (offset < file_size_) {
376+
StringPiece result;
377+
string buffer(16, 0x00);
378+
TF_RETURN_IF_ERROR(file_->Read(offset, buffer.size(), &result, &buffer[0]));
379+
std::unique_ptr<avro::InputStream> in = avro::memoryInputStream((const uint8_t*)result.data(), result.size());
380+
decoder->init(*in);
381+
long items = decoder->decodeLong();
382+
383+
total += static_cast<int64>(items);
384+
positions_.emplace_back(std::pair<int64, int64>(static_cast<int64>(items), offset));
385+
386+
reader_->sync(offset);
387+
offset = reader_->previousSync();
388+
}
389+
390+
for (size_t i = 0; i < columns_.size(); i++) {
391+
shapes_.emplace_back(TensorShape({total}));
392+
}
393+
return Status::OK();
394+
}
395+
396+
Status Partitions(std::vector<int64> *partitions) override {
397+
partitions->clear();
398+
// positions_ are pairs of <items, offset>
399+
for (size_t i = 0; i < positions_.size(); i++) {
400+
partitions->emplace_back(positions_[i].first);
401+
}
402+
return Status::OK();
403+
}
404+
405+
Status Components(Tensor* components) override {
406+
*components = Tensor(DT_STRING, TensorShape({static_cast<int64>(columns_.size())}));
407+
for (size_t i = 0; i < columns_.size(); i++) {
408+
components->flat<string>()(i) = columns_[i];
409+
}
410+
return Status::OK();
411+
}
412+
Status Spec(const Tensor& component, PartialTensorShape* shape, DataType* dtype, bool label) override {
413+
if (columns_index_.find(component.scalar<string>()()) == columns_index_.end()) {
414+
return errors::InvalidArgument("component ", component.scalar<string>()(), " is invalid");
415+
}
416+
int64 column_index = columns_index_[component.scalar<string>()()];
417+
*shape = shapes_[column_index];
418+
*dtype = dtypes_[column_index];
419+
return Status::OK();
420+
}
421+
422+
Status Read(const int64 start, const int64 stop, const Tensor& component, Tensor* value, Tensor* label) override {
423+
const string& column = component.scalar<string>()();
424+
avro::GenericDatum datum(reader_schema_);
425+
426+
// Find the start sync point
427+
int64 item_index_sync = 0;
428+
for (size_t i = 0; i < positions_.size(); i++, item_index_sync += positions_[i].first) {
429+
if (item_index_sync >= stop) {
430+
continue;
431+
}
432+
if (item_index_sync + positions_[i].first <= start) {
433+
continue;
434+
}
435+
// TODO: Avro is sync point partitioned and each block is very similiar to
436+
// Row Group of parquet. Ideally each block should be cached with the hope
437+
// that slicing and indexing will happend around the same block across multiple
438+
// rows. Caching is not done yet.
439+
440+
// Seek to sync
441+
reader_->seek(positions_[i].second);
442+
for (int64 item_index = item_index_sync; item_index < (item_index_sync + positions_[i].first) && item_index < stop; item_index++) {
443+
// Read anyway
444+
if (!reader_->read(datum)) {
445+
return errors::Internal("unable to read record at: ", item_index);
446+
}
447+
// Assign only when in range
448+
if (item_index >= start) {
449+
const avro::GenericRecord& record = datum.value<avro::GenericRecord>();
450+
const avro::GenericDatum& field = record.field(column);
451+
switch(field.type()) {
452+
case avro::AVRO_BOOL:
453+
value->flat<bool>()(item_index - start) = field.value<bool>();
454+
break;
455+
case avro::AVRO_INT:
456+
value->flat<int32>()(item_index - start) = field.value<int32_t>();
457+
break;
458+
case avro::AVRO_LONG:
459+
value->flat<int64>()(item_index - start) = field.value<int64_t>();
460+
break;
461+
case avro::AVRO_FLOAT:
462+
value->flat<float>()(item_index - start) = field.value<float>();
463+
break;
464+
case avro::AVRO_DOUBLE:
465+
value->flat<double>()(item_index - start) = field.value<double>();
466+
break;
467+
case avro::AVRO_STRING:
468+
value->flat<string>()(item_index - start) = field.value<string>();
469+
break;
470+
case avro::AVRO_BYTES: {
471+
const std::vector<uint8_t>& field_value = field.value<std::vector<uint8_t>>();
472+
value->flat<string>()(item_index - start) = string((char *)&field_value[0], field_value.size());
473+
}
474+
break;
475+
case avro::AVRO_FIXED: {
476+
const std::vector<uint8_t>& field_value = field.value<avro::GenericFixed>().value();
477+
value->flat<string>()(item_index - start) = string((char *)&field_value[0], field_value.size());
478+
}
479+
break;
480+
case avro::AVRO_ENUM:
481+
value->flat<string>()(item_index - start) = field.value<avro::GenericEnum>().symbol();
482+
break;
483+
default:
484+
return errors::InvalidArgument("unsupported data type: ", field.type());
485+
}
486+
}
487+
}
488+
}
489+
return Status::OK();
490+
}
491+
492+
string DebugString() const override {
493+
mutex_lock l(mu_);
494+
return strings::StrCat("AvroIndexable");
495+
}
496+
private:
497+
mutable mutex mu_;
498+
Env* env_ GUARDED_BY(mu_);
499+
std::unique_ptr<SizedRandomAccessFile> file_ GUARDED_BY(mu_);
500+
uint64 file_size_ GUARDED_BY(mu_);
501+
avro::ValidSchema reader_schema_;
502+
std::unique_ptr<avro::InputStream> reader_stream_;
503+
std::unique_ptr<avro::DataFileReader<avro::GenericDatum>> reader_;
504+
std::vector<std::pair<int64, int64>> positions_; // <items/sync> pair
505+
506+
std::vector<DataType> dtypes_;
507+
std::vector<TensorShape> shapes_;
508+
std::vector<string> columns_;
509+
std::unordered_map<string, int64> columns_index_;
510+
};
511+
512+
REGISTER_KERNEL_BUILDER(Name("AvroIndexableInit").Device(DEVICE_CPU),
513+
IOInterfaceInitOp<AvroIndexable>);
514+
REGISTER_KERNEL_BUILDER(Name("AvroIndexableSpec").Device(DEVICE_CPU),
515+
IOInterfaceSpecOp<AvroIndexable>);
516+
REGISTER_KERNEL_BUILDER(Name("AvroIndexablePartitions").Device(DEVICE_CPU),
517+
IOIndexablePartitionsOp<AvroIndexable>);
518+
REGISTER_KERNEL_BUILDER(Name("AvroIndexableRead").Device(DEVICE_CPU),
519+
IOIndexableReadOp<AvroIndexable>);
520+
291521
} // namespace data
292522
} // namespace tensorflow

tensorflow_io/avro/ops/avro_ops.cc

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,4 +45,54 @@ REGISTER_OP("ReadAvro")
4545
return Status::OK();
4646
});
4747

48+
// Creates the AvroIndexable resource from filename(s) and metadata
// (e.g. the "schema: <json>" entry) and emits the resource handle
// alongside the component (column) names.
REGISTER_OP("AvroIndexableInit")
    .Input("input: string")
    .Input("metadata: string")
    .Output("resource: resource")
    .Output("component: string")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      c->set_output(0, c->Scalar());
      // NOTE(review): the kernel's Components() emits a 1-D vector of
      // column names while a scalar shape is declared here — confirm
      // whether {c->UnknownDim()} was intended.
      c->set_output(1, c->MakeShape({}));
      return Status::OK();
    });

// Queries the shape and dtype of a single named component (column).
REGISTER_OP("AvroIndexableSpec")
    .Input("input: resource")
    .Input("component: string")
    .Output("shape: int64")
    .Output("dtype: int64")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      c->set_output(0, c->MakeShape({c->UnknownDim()}));
      c->set_output(1, c->MakeShape({}));
      return Status::OK();
    });

// Reads rows [start, stop) of one component; output shape comes from
// the "shape" attr supplied by the caller.
REGISTER_OP("AvroIndexableRead")
    .Input("input: resource")
    .Input("start: int64")
    .Input("stop: int64")
    .Input("component: string")
    .Output("value: dtype")
    .Attr("filter: list(string) = []")
    .Attr("shape: shape")
    .Attr("dtype: type")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      PartialTensorShape shape;
      TF_RETURN_IF_ERROR(c->GetAttr("shape", &shape));
      shape_inference::ShapeHandle entry;
      TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(shape, &entry));
      c->set_output(0, entry);
      return Status::OK();
    });

// Returns the per-block item counts so readers can size chunks along
// Avro block (sync point) boundaries.
REGISTER_OP("AvroIndexablePartitions")
    .Input("input: resource")
    .Output("partitions: int64")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      c->set_output(0, c->MakeShape({c->UnknownDim()}));
      return Status::OK();
    });
97+
4898
} // namespace tensorflow

tensorflow_io/avro/python/ops/avro_ops.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,18 @@
1717
from __future__ import division
1818
from __future__ import print_function
1919

20+
import warnings
21+
2022
import tensorflow as tf
2123
from tensorflow_io.core.python.ops import core_ops
2224
from tensorflow_io.core.python.ops import data_ops
2325

26+
warnings.warn(
27+
"The tensorflow_io.avro.AvroDataset is "
28+
"deprecated. Please look for tfio.IOTensor.from_avro "
29+
"for reading Avro files into tensorflow.",
30+
DeprecationWarning)
31+
2432
def list_avro_columns(filename, schema, **kwargs):
2533
"""list_avro_columns"""
2634
if not tf.executing_eagerly():

tensorflow_io/core/BUILD

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ cc_library(
3030
srcs = [
3131
"kernels/dataset_ops.h",
3232
"kernels/io_interface.h",
33-
"kernels/stream.h",
33+
"kernels/io_stream.h",
3434
],
3535
copts = tf_io_copts(),
3636
includes = [

tensorflow_io/core/kernels/archive_kernels.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ limitations under the License.
2020
#include "tensorflow/core/lib/io/random_inputstream.h"
2121
#include "tensorflow/core/lib/io/zlib_compression_options.h"
2222
#include "tensorflow/core/lib/io/zlib_inputstream.h"
23-
#include "tensorflow_io/core/kernels/stream.h"
23+
#include "tensorflow_io/core/kernels/io_stream.h"
2424

2525
namespace tensorflow {
2626
namespace data {

0 commit comments

Comments
 (0)